From 21a4cb9dd2ccd70806b331afda680563351af50f Mon Sep 17 00:00:00 2001 From: xxxkkw <646567098@qq.com> Date: Wed, 27 May 2026 19:43:01 +0800 Subject: [PATCH 1/6] Add hybrid prefix cache restore coverage Fix prompt-cache serialization and server byte-limit wiring while adding conservative restore capability coverage for hybrid KV, rotating, recurrent, and chunked cache states. --- mlx_lm/models/cache.py | 200 +++++++++++++++++++-- mlx_lm/server.py | 8 +- tests/test_prefix_cache_correctness.py | 233 +++++++++++++++++++++++++ tests/test_prompt_cache.py | 54 ++++++ tests/test_server.py | 35 ++++ 5 files changed, 515 insertions(+), 15 deletions(-) create mode 100644 tests/test_prefix_cache_correctness.py diff --git a/mlx_lm/models/cache.py b/mlx_lm/models/cache.py index b84c9d650..3f1acb1a0 100644 --- a/mlx_lm/models/cache.py +++ b/mlx_lm/models/cache.py @@ -12,6 +12,18 @@ from .base import create_causal_mask +def _parse_bool_state(value): + if isinstance(value, bool): + return value + if isinstance(value, str): + lowered = value.lower() + if lowered == "true": + return True + if lowered == "false": + return False + raise ValueError(f"Invalid boolean cache metadata value: {value!r}") + + def make_prompt_cache( model: nn.Module, max_kv_size: Optional[int] = None, @@ -111,6 +123,15 @@ def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]: return [c.trim(num_tokens) for c in cache][0] +def can_restore_prompt_cache(cache: List[Any], position: int) -> bool: + return all(hasattr(c, "can_restore_to") and c.can_restore_to(position) for c in cache) + + +def restore_prompt_cache(cache: List[Any], position: int) -> bool: + results = [c.restore_to(position) for c in cache] + return all(r.restored for r in results) + + def create_attention_mask( N: int, offset: int, return_array: bool, window_size: Optional[int] ): @@ -124,6 +145,22 @@ def create_attention_mask( return "causal" +@dataclass(frozen=True) +class CacheRetainedRange: + logical_start: int + logical_end: int + physical_start: int + physical_end: int + + +@dataclass(frozen=True) +class CacheRestoreResult: + restored: bool + restored_tokens: int + replay_start: int + reason: str = "" + + class _BaseCache: @property def state(self): @@ -146,6 +183,37 @@ def meta_state(self, v): def is_trimmable(self): return False + def logical_length(self): + return self.size() + + def retained_range(self): + length = self.logical_length() + return CacheRetainedRange( + logical_start=0, + logical_end=length, + physical_start=0, + physical_end=length, + ) + + def can_restore_to(self, position: int): + return self.is_trimmable() and 0 <= position <= self.logical_length() + + def restore_to(self, position: int): + if not self.can_restore_to(position): + return CacheRestoreResult( + restored=False, + restored_tokens=0, + replay_start=0, + reason=f"{type(self).__name__} cannot restore to {position}", + ) + trimmed = self.trim(self.logical_length() - position) + return CacheRestoreResult( + restored=True, + restored_tokens=position, + replay_start=position, + reason=f"trimmed {trimmed} tokens", + ) + def size(self): """ Return the size (i.e. sequence length) of the cache. @@ -357,6 +425,9 @@ def update_and_fetch(self, keys, values): def size(self): return self.offset + def logical_length(self): + return self.offset + @property def state(self): if self.offset == self.keys.shape[2]: @@ -517,6 +588,35 @@ def update_and_fetch(self, keys, values): def size(self): return min(self.offset, self.max_size) + def logical_length(self): + return self.offset + + def retained_range(self): + if self.keys is None: + return CacheRetainedRange( + logical_start=self.offset, + logical_end=self.offset, + physical_start=0, + physical_end=0, + ) + physical_len = self.keys.shape[2] + logical_start = max(0, self.offset - physical_len) + return CacheRetainedRange( + logical_start=logical_start, + logical_end=self.offset, + physical_start=0, + physical_end=physical_len, + ) + + def can_restore_to(self, position: int): + if self.keep != 0: + return False + if not 0 <= position <= self.offset: + return False + retained = self.retained_range() + required_start = max(0, position - self.max_size) + return retained.logical_start <= required_start and position <= retained.logical_end + @property def state(self): if self.offset < self.keys.shape[2]: @@ -548,6 +648,31 @@ def trim(self, n): self._idx -= n return n + def restore_to(self, position: int): + if not self.can_restore_to(position): + return CacheRestoreResult( + restored=False, + restored_tokens=0, + replay_start=0, + reason=f"{type(self).__name__} retained range cannot restore to {position}", + ) + if self.keys is None: + self.offset = position + self._idx = 0 + return CacheRestoreResult(True, position, position, "empty cache restored") + + retained = self.retained_range() + required_start = max(0, position - self.max_size) + start = required_start - retained.logical_start + end = position - retained.logical_start + keys = self._temporal_order(self.keys) + values = self._temporal_order(self.values) + self.keys = mx.contiguous(keys[..., start:end, :]) + self.values = mx.contiguous(values[..., start:end, :]) + self.offset = position + self._idx = self.keys.shape[2] + return CacheRestoreResult(True, position, position, "restored retained rotating window") + def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache: raise NotImplementedError("RotatingKVCache Quantization NYI") @@ -698,6 +823,20 @@ def make_mask(self, N: int): else: return None + def logical_length(self): + return 0 + + def can_restore_to(self, position: int): + return False + + def restore_to(self, position: int): + return CacheRestoreResult( + restored=False, + restored_tokens=0, + replay_start=0, + reason="ArraysCache requires an exact checkpoint or model replay", + ) + @classmethod def merge(cls, caches): n_state = len(caches[0].cache) @@ -739,11 +878,13 @@ def __init__(self, chunk_size): self.start_position = 0 def maybe_trim_front(self): - # Maintain the cache below the chunk size - if self.keys is not None and self.keys.shape[2] >= self.chunk_size: - self.start_position += self.keys.shape[2] - self.chunk_size - self.keys = self.keys[..., -self.chunk_size :, :] - self.values = self.values[..., -self.chunk_size :, :] + if self.keys is not None: + length = self.offset - self.start_position + if length >= self.chunk_size: + trim_size = length - self.chunk_size + self.start_position += trim_size + self.keys = self.keys[..., trim_size:length, :] + self.values = self.values[..., trim_size:length, :] def update_and_fetch(self, keys, values): prev = self.offset - self.start_position @@ -772,18 +913,20 @@ def update_and_fetch(self, keys, values): @property def state(self): - if self.offset == self.keys.shape[2]: + end = self.offset - self.start_position + if end == self.keys.shape[2]: return self.keys, self.values else: return ( - self.keys[..., : self.offset, :], - self.values[..., : self.offset, :], + self.keys[..., :end, :], + self.values[..., :end, :], ) @state.setter def state(self, v): self.keys, self.values = v - self.offset = self.keys.shape[2] + start_position = getattr(self, "start_position", 0) + self.offset = start_position + self.keys.shape[2] def is_trimmable(self): return True @@ -795,11 +938,15 @@ def trim(self, n): @property def meta_state(self): - return tuple(map(str, (self.chunk_size, self.start_position))) + return tuple(map(str, (self.chunk_size, self.start_position, self.offset))) @meta_state.setter def meta_state(self, v): - self.chunk_size, self.start_position = map(int, v) + if len(v) == 2: + self.chunk_size, self.start_position = map(int, v) + self.offset = self.start_position + self.keys.shape[2] + else: + self.chunk_size, self.start_position, self.offset = map(int, v) def empty(self): return self.keys is None @@ -826,6 +973,27 @@ def trim(self, n): m = c.trim(n) return m + def logical_length(self): + if not self.caches: + return 0 + return max(c.logical_length() for c in self.caches) + + def can_restore_to(self, position: int): + return all( + hasattr(c, "can_restore_to") and c.can_restore_to(position) + for c in self.caches + ) + + def restore_to(self, position: int): + results = [c.restore_to(position) for c in self.caches] + restored = all(r.restored for r in results) + return CacheRestoreResult( + restored=restored, + restored_tokens=position if restored else 0, + replay_start=position if restored else 0, + reason="restored CacheList" if restored else "one or more child caches cannot restore", + ) + @property def state(self): return [c.state for c in self.caches] @@ -1312,7 +1480,7 @@ def meta_state(self, v): int, v[:3], ) - self.rotated = bool(v[3]) + self.rotated = _parse_bool_state(v[3]) def is_trimmable(self): return self._offset < self.max_size @@ -1680,9 +1848,13 @@ def fetch_nearest_cache(self, model: Any, tokens: List[int]): short_length = len(result.shorter) if result.shorter is not None else 0 if result.longer is not None and result.common_prefix > short_length: cache_entry = self._trie.get(result.model, result.longer) - if can_trim_prompt_cache(cache_entry.prompt_cache): + prefix = min(len(tokens) - 1, result.common_prefix) + if can_restore_prompt_cache(cache_entry.prompt_cache, prefix): + cache = copy.deepcopy(cache_entry.prompt_cache) + if restore_prompt_cache(cache, prefix): + return cache, tokens[prefix:] + elif can_trim_prompt_cache(cache_entry.prompt_cache): cache = copy.deepcopy(cache_entry.prompt_cache) - prefix = min(len(tokens) - 1, result.common_prefix) num_to_trim = len(result.longer) - prefix trim_prompt_cache(cache, num_to_trim) return cache, tokens[prefix:] diff --git a/mlx_lm/server.py b/mlx_lm/server.py index ce8d95817..2cca5e963 100644 --- a/mlx_lm/server.py +++ b/mlx_lm/server.py @@ -1740,7 +1740,13 @@ def run( handler_class=APIHandler, ): group = mx.distributed.init() - prompt_cache = LRUPromptCache(model_provider.cli_args.prompt_cache_size) + prompt_cache_bytes = model_provider.cli_args.prompt_cache_bytes + if prompt_cache_bytes is None: + prompt_cache_bytes = 1 << 63 + prompt_cache = LRUPromptCache( + model_provider.cli_args.prompt_cache_size, + max_bytes=prompt_cache_bytes, + ) response_generator = ResponseGenerator(model_provider, prompt_cache) if group.rank() == 0: _run_http_server(host, port, response_generator) diff --git a/tests/test_prefix_cache_correctness.py b/tests/test_prefix_cache_correctness.py new file mode 100644 index 000000000..51e7f2407 --- /dev/null +++ b/tests/test_prefix_cache_correctness.py @@ -0,0 +1,233 @@ +# Copyright © 2024 Apple Inc. + +import unittest + +import mlx.core as mx + +from mlx_lm.models.cache import ( + ArraysCache, + CacheList, + KVCache, + LRUPromptCache, + RotatingKVCache, +) + + +def make_kv_cache(length): + cache = KVCache() + if length > 0: + x = mx.arange(length, dtype=mx.float32).reshape(1, 1, length, 1) + cache.update_and_fetch(x, x) + return cache + + +def make_rotating_cache(length, max_size=4): + cache = RotatingKVCache(max_size=max_size) + for i in range(length): + x = mx.array([i], dtype=mx.float32).reshape(1, 1, 1, 1) + cache.update_and_fetch(x, x) + return cache + + +def make_prefill_rotating_cache(length, max_size=4): + cache = RotatingKVCache(max_size=max_size) + if length > 0: + x = mx.arange(length, dtype=mx.float32).reshape(1, 1, length, 1) + cache.update_and_fetch(x, x) + return cache + + +def make_arrays_cache(value): + cache = ArraysCache(size=1) + cache[0] = mx.array([[value]], dtype=mx.float32) + return cache + + +class TestHybridPrefixCacheCorrectness(unittest.TestCase): + def test_exact_hit_reuses_non_trimmable_arrays_cache(self): + lru = LRUPromptCache(max_size=4) + model = ("toy",) + prompt = [1, 2, 3] + lru.insert_cache(model, prompt, [make_arrays_cache(11), make_kv_cache(3)]) + + cache, rest = lru.fetch_nearest_cache(model, prompt) + + self.assertEqual(rest, []) + self.assertIsNotNone(cache) + self.assertTrue(mx.array_equal(cache[0][0], mx.array([[11]], dtype=mx.float32))) + + def test_shorter_checkpoint_reuses_non_trimmable_arrays_cache(self): + lru = LRUPromptCache(max_size=4) + model = ("toy",) + lru.insert_cache(model, [1, 2], [make_arrays_cache(22), make_kv_cache(2)]) + lru.insert_cache(model, [1, 2, 3, 4], [make_arrays_cache(44), make_kv_cache(4)]) + + cache, rest = lru.fetch_nearest_cache(model, [1, 2, 9]) + + self.assertEqual(rest, [9]) + self.assertIsNotNone(cache) + self.assertTrue(mx.array_equal(cache[0][0], mx.array([[22]], dtype=mx.float32))) + + def test_longer_non_trimmable_cache_is_not_used_without_checkpoint(self): + lru = LRUPromptCache(max_size=4) + model = ("toy",) + lru.insert_cache(model, [1, 2, 3, 4], [make_arrays_cache(44), make_kv_cache(4)]) + + cache, rest = lru.fetch_nearest_cache(model, [1, 2, 9]) + + self.assertIsNone(cache) + self.assertEqual(rest, [1, 2, 9]) + + def test_nested_cache_list_preserves_non_trimmable_boundary(self): + lru = LRUPromptCache(max_size=4) + model = ("toy",) + nested = CacheList(make_arrays_cache(5), make_kv_cache(4)) + lru.insert_cache(model, [1, 2, 3, 4], [nested]) + + cache, rest = lru.fetch_nearest_cache(model, [1, 2, 9]) + + self.assertIsNone(cache) + self.assertEqual(rest, [1, 2, 9]) + + def test_saturated_rotating_cache_documents_current_longer_prefix_miss(self): + lru = LRUPromptCache(max_size=4) + model = ("toy",) + lru.insert_cache(model, [1, 2, 3, 4, 5, 6], [make_rotating_cache(6, max_size=4)]) + + cache, rest = lru.fetch_nearest_cache(model, [1, 2, 3, 9]) + + self.assertIsNone(cache) + self.assertEqual(rest, [1, 2, 3, 9]) + + def test_kv_cache_can_restore_to_shorter_prefix(self): + cache = make_kv_cache(5) + + self.assertTrue(cache.can_restore_to(3)) + restored = cache.restore_to(3) + + self.assertTrue(restored.restored) + self.assertEqual(cache.offset, 3) + + def test_saturated_rotating_cache_reports_retained_logical_range(self): + cache = make_rotating_cache(8, max_size=4) + + retained = cache.retained_range() + + self.assertEqual(retained.logical_start, 4) + self.assertEqual(retained.logical_end, 8) + + def test_saturated_rotating_cache_rejects_restore_requiring_evicted_context(self): + cache = make_rotating_cache(8, max_size=4) + + self.assertFalse(cache.can_restore_to(7)) + + def test_saturated_rotating_cache_accepts_current_boundary_noop_restore(self): + cache = make_rotating_cache(8, max_size=4) + + self.assertTrue(cache.can_restore_to(8)) + + def test_unsaturated_rotating_cache_restore_slices_temporal_state(self): + cache = make_rotating_cache(3, max_size=4) + + result = cache.restore_to(2) + + self.assertTrue(result.restored) + self.assertEqual(cache.offset, 2) + self.assertEqual(cache.keys.shape[2], 2) + + def test_saturated_rotating_cache_restore_to_current_boundary_is_noop(self): + cache = make_rotating_cache(8, max_size=4) + + result = cache.restore_to(8) + + self.assertTrue(result.restored) + self.assertEqual(cache.offset, 8) + self.assertEqual(cache.keys.shape[2], 4) + + def test_saturated_rotating_cache_restore_requiring_evicted_context_does_not_mutate(self): + cache = make_rotating_cache(8, max_size=4) + before_offset = cache.offset + before_shape = cache.keys.shape + + result = cache.restore_to(7) + + self.assertFalse(result.restored) + self.assertEqual(cache.offset, before_offset) + self.assertEqual(cache.keys.shape, before_shape) + + def test_lru_uses_restorable_longer_prefill_rotating_cache(self): + lru = LRUPromptCache(max_size=4) + model = ("toy",) + lru.insert_cache( + model, + [1, 2, 3, 4, 5, 6, 7, 8], + [make_prefill_rotating_cache(8, max_size=4)], + ) + + cache, rest = lru.fetch_nearest_cache(model, [1, 2, 3, 4, 5, 6, 7, 99]) + + self.assertIsNotNone(cache) + self.assertEqual(rest, [99]) + self.assertEqual(cache[0].offset, 7) + + def test_lru_prefers_shorter_safe_rotating_checkpoint_over_unrestorable_longer_hit(self): + lru = LRUPromptCache(max_size=4) + model = ("toy",) + lru.insert_cache(model, [1, 2, 3, 4, 5, 6, 7], [make_rotating_cache(7, max_size=4)]) + lru.insert_cache(model, [1, 2, 3, 4, 5, 6, 7, 8], [make_rotating_cache(8, max_size=4)]) + + cache, rest = lru.fetch_nearest_cache(model, [1, 2, 3, 4, 5, 6, 7, 99]) + + self.assertIsNotNone(cache) + self.assertEqual(rest, [99]) + self.assertEqual(cache[0].offset, 7) + + def test_cache_list_restores_when_all_children_restore(self): + cache = CacheList(make_kv_cache(8), make_prefill_rotating_cache(8, max_size=4)) + + self.assertTrue(cache.can_restore_to(7)) + result = cache.restore_to(7) + + self.assertTrue(result.restored) + self.assertEqual(cache[0].offset, 7) + self.assertEqual(cache[1].offset, 7) + + def test_cache_list_rejects_when_any_child_cannot_restore(self): + cache = CacheList(make_arrays_cache(1), make_kv_cache(5)) + + self.assertFalse(cache.can_restore_to(3)) + + def test_recurrent_cache_prefers_shorter_checkpoint_over_unrestorable_longer_hit(self): + lru = LRUPromptCache(max_size=8) + model = ("toy",) + lru.insert_cache(model, [1, 2], [make_arrays_cache(20), make_kv_cache(2)]) + lru.insert_cache(model, [1, 2, 3, 4, 5], [make_arrays_cache(50), make_kv_cache(5)]) + + cache, rest = lru.fetch_nearest_cache(model, [1, 2, 9, 10]) + + self.assertIsNotNone(cache) + self.assertEqual(rest, [9, 10]) + self.assertTrue(mx.array_equal(cache[0][0], mx.array([[20]], dtype=mx.float32))) + + def test_arrays_cache_cannot_restore_to_shorter_prefix_by_default(self): + cache = make_arrays_cache(7) + + self.assertFalse(cache.can_restore_to(1)) + result = cache.restore_to(1) + self.assertFalse(result.restored) + self.assertIn("exact checkpoint", result.reason) + self.assertTrue(mx.array_equal(cache[0], mx.array([[7]], dtype=mx.float32))) + + def test_arrays_cache_can_be_reused_for_exact_entry_only_through_lru(self): + lru = LRUPromptCache(max_size=4) + model = ("toy",) + lru.insert_cache(model, [1, 2, 3], [make_arrays_cache(7)]) + + cache, rest = lru.fetch_nearest_cache(model, [1, 2, 3]) + + self.assertEqual(rest, []) + self.assertIsNotNone(cache) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_prompt_cache.py b/tests/test_prompt_cache.py index bd1bc75ba..ef08a957d 100644 --- a/tests/test_prompt_cache.py +++ b/tests/test_prompt_cache.py @@ -28,6 +28,60 @@ HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit" +class TestBatchRotatingKVCacheSerialization(unittest.TestCase): + def test_save_load_preserves_rotated_false(self): + with tempfile.TemporaryDirectory() as test_dir: + cache_file = os.path.join(test_dir, "prompt_cache.safetensors") + cache = BatchRotatingKVCache(max_size=4, left_padding=[0]) + x = mx.random.uniform(shape=(1, 2, 2, 4)) + cache.update_and_fetch(x, x) + cache.rotated = False + + save_prompt_cache(cache_file, [cache]) + loaded = load_prompt_cache(cache_file)[0] + + self.assertFalse(loaded.rotated) + self.assertEqual(loaded.max_size, 4) + self.assertEqual(loaded._offset, cache._offset) + self.assertEqual(loaded._idx, cache._idx) + + def test_save_load_preserves_rotated_true(self): + with tempfile.TemporaryDirectory() as test_dir: + cache_file = os.path.join(test_dir, "prompt_cache.safetensors") + cache = BatchRotatingKVCache(max_size=4, left_padding=[0]) + for _ in range(5): + x = mx.random.uniform(shape=(1, 2, 1, 4)) + cache.update_and_fetch(x, x) + self.assertTrue(cache.rotated) + + save_prompt_cache(cache_file, [cache]) + loaded = load_prompt_cache(cache_file)[0] + + self.assertTrue(loaded.rotated) + self.assertEqual(loaded.max_size, 4) + self.assertEqual(loaded._offset, cache._offset) + + +class TestChunkedKVCacheSerialization(unittest.TestCase): + def test_save_load_preserves_offsets_after_front_trim(self): + with tempfile.TemporaryDirectory() as test_dir: + cache_file = os.path.join(test_dir, "prompt_cache.safetensors") + cache = ChunkedKVCache(chunk_size=4) + x = mx.random.uniform(shape=(1, 2, 6, 4)) + cache.update_and_fetch(x, x) + cache.maybe_trim_front() + + self.assertEqual(cache.start_position, 2) + self.assertEqual(cache.offset, 6) + + save_prompt_cache(cache_file, [cache]) + loaded = load_prompt_cache(cache_file)[0] + + self.assertEqual(loaded.chunk_size, cache.chunk_size) + self.assertEqual(loaded.start_position, cache.start_position) + self.assertEqual(loaded.offset, cache.offset) + + class TestPromptCache(unittest.TestCase): @classmethod diff --git a/tests/test_server.py b/tests/test_server.py index 9a8a2ad14..08bf5b0e6 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -517,6 +517,41 @@ def keepalive_callback(processed_tokens, total_tokens): self.fail(f"Callback should handle BrokenPipeError: {e}") +class TestServerPromptCacheWiring(unittest.TestCase): + def test_run_wires_prompt_cache_bytes(self): + from unittest.mock import patch + + from mlx_lm.server import run + + class Args: + prompt_cache_size = 7 + prompt_cache_bytes = 12345 + + class Provider: + cli_args = Args() + + class Group: + def rank(self): + return 0 + + created = {} + + def fake_lru(max_size=10, max_bytes=1 << 63): + created["max_size"] = max_size + created["max_bytes"] = max_bytes + return LRUPromptCache(max_size=max_size, max_bytes=max_bytes) + + with patch("mlx_lm.server.LRUPromptCache", side_effect=fake_lru), patch( + "mlx_lm.server.ResponseGenerator" + ), patch("mlx_lm.server._run_http_server"), patch( + "mlx_lm.server.mx.distributed.init", return_value=Group() + ): + run("127.0.0.1", 0, Provider()) + + self.assertEqual(created["max_size"], 7) + self.assertEqual(created["max_bytes"], 12345) + + class TestLRUPromptCache(unittest.TestCase): def test_caching(self): cache = LRUPromptCache(max_size=10) From 9b0b4e581216c3b79e9d3f710049d5a6d578acf9 Mon Sep 17 00:00:00 2001 From: xxxkkw <646567098@qq.com> Date: Wed, 27 May 2026 20:27:17 +0800 Subject: [PATCH 2/6] Add prefix cache benchmark --- benchmarks/prefix_cache_benchmark.py | 304 +++++++++++++++++++++++++++ 1 file changed, 304 insertions(+) create mode 100644 benchmarks/prefix_cache_benchmark.py diff --git a/benchmarks/prefix_cache_benchmark.py b/benchmarks/prefix_cache_benchmark.py new file mode 100644 index 000000000..1e91d85dc --- /dev/null +++ b/benchmarks/prefix_cache_benchmark.py @@ -0,0 +1,304 @@ +import argparse +import json +import platform +import subprocess +import time +from pathlib import Path + +import mlx.core as mx + +from mlx_lm.generate import stream_generate +from mlx_lm.models.cache import LRUPromptCache, make_prompt_cache +from mlx_lm.utils import load + + +def encode(tokenizer, text): + return list(tokenizer.encode(text)) + + +def cache_nbytes(prompt_cache): + return int(sum(getattr(c, "nbytes", 0) for c in prompt_cache)) + + +def cache_logical_length(prompt_cache): + lengths = [c.logical_length() for c in prompt_cache if hasattr(c, "logical_length")] + return max(lengths) if lengths else None + + +def prefill_snapshot(model, tokens, prefill_step_size): + if not tokens: + raise ValueError("Cannot prefill an empty token sequence") + prompt_cache = make_prompt_cache(model) + token_array = mx.array(tokens) + start = time.perf_counter() + for offset in range(0, len(tokens), prefill_step_size): + chunk = token_array[offset : offset + prefill_step_size] + model(chunk[None], cache=prompt_cache) + mx.eval([c.state for c in prompt_cache]) + mx.clear_cache() + prefill_s = time.perf_counter() - start + return prompt_cache, prefill_s + + +def run_generation(model, tokenizer, tokens, prompt_cache, max_tokens, prefill_step_size): + if not tokens: + raise ValueError("Cannot generate from an empty token tail") + mx.clear_cache() + mx.reset_peak_memory() + start = time.perf_counter() + first_s = None + last = None + for response in stream_generate( + model, + tokenizer, + tokens, + max_tokens=max_tokens, + prompt_cache=prompt_cache, + prefill_step_size=prefill_step_size, + ): + if first_s is None: + first_s = time.perf_counter() - start + last = response + total_s = time.perf_counter() - start + mx.eval([c.state for c in prompt_cache]) + return { + "ttft_s": None if first_s is None else first_s, + "total_s": total_s, + "generated_tokens": 0 if last is None else int(last.generation_tokens), + "prompt_tps": None if last is None else float(last.prompt_tps), + "generation_tps": None if last is None else float(last.generation_tps), + "peak_memory_gb": float(mx.get_peak_memory()) / 1e9, + } + + +def run_cold(model, tokenizer, tokens, max_tokens, prefill_step_size): + prompt_cache = make_prompt_cache(model) + result = run_generation( + model, + tokenizer, + tokens, + prompt_cache, + max_tokens=max_tokens, + prefill_step_size=prefill_step_size, + ) + return { + "name": "cold_full_prefill", + "cache_hit_kind": "cold", + "full_prompt_tokens": len(tokens), + "cached_tokens": 0, + "replayed_tokens": len(tokens), + "cache_bytes": cache_nbytes(prompt_cache), + "cache_logical_length": cache_logical_length(prompt_cache), + **result, + } + + +def insert_snapshot(lru, model_key, key_tokens, prompt_cache, prefill_s): + lru.insert_cache(model_key, list(key_tokens), prompt_cache) + return { + "prefill_snapshot_tokens": len(key_tokens), + "inserted_key_tokens": len(key_tokens), + "cache_logical_length": cache_logical_length(prompt_cache), + "cache_bytes": cache_nbytes(prompt_cache), + "prefill_s": prefill_s, + } + + +def run_hot(name, cache_hit_kind, lru, model_key, model, tokenizer, tokens, max_tokens, prefill_step_size): + prompt_cache, rest = lru.fetch_nearest_cache(model_key, list(tokens)) + if prompt_cache is None: + prompt_cache = make_prompt_cache(model) + rest = list(tokens) + result = run_generation( + model, + tokenizer, + rest, + prompt_cache, + max_tokens=max_tokens, + prefill_step_size=prefill_step_size, + ) + return { + "name": name, + "cache_hit_kind": cache_hit_kind, + "full_prompt_tokens": len(tokens), + "cached_tokens": len(tokens) - len(rest), + "replayed_tokens": len(rest), + "cache_bytes": cache_nbytes(prompt_cache), + "cache_logical_length": cache_logical_length(prompt_cache), + **result, + } + + +def build_prompts(tokenizer, prompt, target_prefix_tokens): + if prompt: + shared_prefix = prompt + else: + sentence = ( + "Apple Silicon inference note: MLX uses unified memory and Metal kernels. " + "Prefix cache reuse matters for long multi-turn assistant workloads because shared system prompts, tools, and retrieved context should not be prefetched again. " + "This benchmark repeats technical context so prefill dominates a smoke-sized prompt. " + ) + blocks = [] + for index in range(1, 10000): + blocks.append(f"Context block {index}: {sentence}") + shared_prefix = "\n".join(blocks) + if len(encode(tokenizer, shared_prefix)) >= target_prefix_tokens: + break + suffix_a = "\nTask A: Produce a concise checklist for validating prefix-cache behavior on Apple Silicon. Include tests, profiling, and benchmark evidence.\nAnswer:" + suffix_b = "\nTask B: Produce a concise risk assessment for changing prefix-cache restore semantics. Include correctness hazards and benchmark requirements.\nAnswer:" + return shared_prefix, shared_prefix + suffix_a, shared_prefix + suffix_b + + +def rounded(obj): + if isinstance(obj, float): + return round(obj, 4) + if isinstance(obj, dict): + return {key: rounded(value) for key, value in obj.items()} + if isinstance(obj, list): + return [rounded(value) for value in obj] + return obj + + +def emit(event, **payload): + print(json.dumps(rounded({"event": event, **payload}), ensure_ascii=False), flush=True) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--model", required=True) + parser.add_argument("--prompt") + parser.add_argument("--target-prefix-tokens", type=int, default=2048) + parser.add_argument("--max-tokens", type=int, default=64) + parser.add_argument("--prefill-step-size", type=int, default=512) + args = parser.parse_args() + + repo = Path(__file__).resolve().parents[1] + try: + commit = subprocess.check_output( + ["git", "-C", str(repo), "rev-parse", "--short", "HEAD"], text=True + ).strip() + branch = subprocess.check_output( + ["git", "-C", str(repo), "branch", "--show-current"], text=True + ).strip() + except Exception: + commit = None + branch = None + + emit( + "env", + branch=branch, + commit=commit, + model=args.model, + platform=platform.platform(), + machine=platform.machine(), + ) + + start = time.perf_counter() + model, tokenizer = load(args.model) + emit("loaded", load_s=time.perf_counter() - start) + + shared_prefix, prompt_a, prompt_b = build_prompts( + tokenizer, args.prompt, args.target_prefix_tokens + ) + prefix_tokens = encode(tokenizer, shared_prefix) + tokens_a = encode(tokenizer, prompt_a) + tokens_b = encode(tokenizer, prompt_b) + emit( + "prompt", + shared_prefix_tokens=len(prefix_tokens), + prompt_a_tokens=len(tokens_a), + prompt_b_tokens=len(tokens_b), + max_tokens=args.max_tokens, + prefill_step_size=args.prefill_step_size, + ) + + warm_cache = make_prompt_cache(model) + for _ in stream_generate( + model, + tokenizer, + encode(tokenizer, "Warmup prompt for kernel compilation. Answer:"), + max_tokens=8, + prompt_cache=warm_cache, + prefill_step_size=args.prefill_step_size, + ): + pass + mx.eval([c.state for c in warm_cache]) + mx.clear_cache() + + cold = run_cold( + model, + tokenizer, + tokens_b, + max_tokens=args.max_tokens, + prefill_step_size=args.prefill_step_size, + ) + emit("result", **cold) + + exact_lru = LRUPromptCache(max_size=4) + exact_key = tokens_b[:-1] + exact_cache, exact_prefill_s = prefill_snapshot( + model, exact_key, args.prefill_step_size + ) + emit( + "snapshot", + name="exact_prefix_seed", + **insert_snapshot(exact_lru, args.model, exact_key, exact_cache, exact_prefill_s), + ) + exact_hot = run_hot( + "exact_prefix_hot_reuse", + "exact_prefix", + exact_lru, + args.model, + model, + tokenizer, + tokens_b, + max_tokens=args.max_tokens, + prefill_step_size=args.prefill_step_size, + ) + emit("result", **exact_hot) + + nearest_lru = LRUPromptCache(max_size=4) + nearest_cache, nearest_prefill_s = prefill_snapshot( + model, tokens_a, args.prefill_step_size + ) + emit( + "snapshot", + name="nearest_prefix_seed", + **insert_snapshot(nearest_lru, args.model, tokens_a, nearest_cache, nearest_prefill_s), + ) + nearest_hot = run_hot( + "nearest_prefix_hot_reuse", + "nearest_prefix", + nearest_lru, + args.model, + model, + tokenizer, + tokens_b, + max_tokens=args.max_tokens, + prefill_step_size=args.prefill_step_size, + ) + emit("result", **nearest_hot) + + emit( + "comparison", + exact_cached_tokens=exact_hot["cached_tokens"], + exact_replayed_tokens=exact_hot["replayed_tokens"], + exact_ttft_speedup_vs_cold=( + cold["ttft_s"] / exact_hot["ttft_s"] if exact_hot["ttft_s"] else None + ), + exact_total_speedup_vs_cold=( + cold["total_s"] / exact_hot["total_s"] if exact_hot["total_s"] else None + ), + nearest_cached_tokens=nearest_hot["cached_tokens"], + nearest_replayed_tokens=nearest_hot["replayed_tokens"], + nearest_ttft_speedup_vs_cold=( + cold["ttft_s"] / nearest_hot["ttft_s"] if nearest_hot["ttft_s"] else None + ), + nearest_total_speedup_vs_cold=( + cold["total_s"] / nearest_hot["total_s"] if nearest_hot["total_s"] else None + ), + ) + + +if __name__ == "__main__": + main() From b3617df83041729af5f059a9188fa313f22e9a3a Mon Sep 17 00:00:00 2001 From: xxxkkw <646567098@qq.com> Date: Wed, 27 May 2026 20:31:19 +0800 Subject: [PATCH 3/6] Add prefix cache session helper --- mlx_lm/generate.py | 28 ++++++++++++++++++++++++++ tests/test_prefix_cache_correctness.py | 15 ++++++++++++++ 2 files changed, 43 insertions(+) diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index 3573b2640..a9dd9e3ca 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -56,6 +56,34 @@ DEFAULT_QUANTIZED_KV_START = 5000 +@dataclass +class PrefixCacheLookup: + prompt_cache: Optional[Any] + cached_tokens: int + tokens_to_process: List[int] + + +class PrefixCacheSession: + def __init__(self, max_size: int = 10, max_bytes: int = 1 << 63): + self._cache = cache.LRUPromptCache(max_size=max_size, max_bytes=max_bytes) + + def lookup(self, model_key, tokens): + prompt_cache, rest = self._cache.fetch_nearest_cache(model_key, tokens) + return PrefixCacheLookup( + prompt_cache=prompt_cache, + cached_tokens=len(tokens) - len(rest), + tokens_to_process=rest, + ) + + def insert(self, model_key, tokens, prompt_cache, cache_type="assistant"): + self._cache.insert_cache( + model_key, + tokens, + prompt_cache, + cache_type=cache_type, + ) + + def str2bool(string): return string.lower() not in ["false", "f"] diff --git a/tests/test_prefix_cache_correctness.py b/tests/test_prefix_cache_correctness.py index 51e7f2407..c4d714eee 100644 --- a/tests/test_prefix_cache_correctness.py +++ b/tests/test_prefix_cache_correctness.py @@ -229,5 +229,20 @@ def test_arrays_cache_can_be_reused_for_exact_entry_only_through_lru(self): self.assertIsNotNone(cache) +class TestPrefixCacheSession(unittest.TestCase): + def test_session_returns_rest_and_cached_count(self): + from mlx_lm.generate import PrefixCacheSession + + session = PrefixCacheSession(max_size=4) + model = ("toy",) + session.insert(model, [1, 2, 3], [make_kv_cache(3)]) + + hit = session.lookup(model, [1, 2, 3, 4]) + + self.assertEqual(hit.cached_tokens, 3) + self.assertEqual(hit.tokens_to_process, [4]) + self.assertIsNotNone(hit.prompt_cache) + + if __name__ == "__main__": unittest.main() From 01b7c89ebd924a7306afd084ca2050e360e37dff Mon Sep 17 00:00:00 2001 From: xxxkkw <646567098@qq.com> Date: Wed, 27 May 2026 20:51:52 +0800 Subject: [PATCH 4/6] Document single-path prompt checkpoint boundary --- mlx_lm/server.py | 40 +++++++++++++++++++++++++++++++++++++++- tests/test_server.py | 42 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 81 insertions(+), 1 deletion(-) diff --git a/mlx_lm/server.py b/mlx_lm/server.py index 2cca5e963..ac9eb9747 100644 --- a/mlx_lm/server.py +++ b/mlx_lm/server.py @@ -1,6 +1,7 @@ # Copyright © 2023-2024 Apple Inc. import argparse +import copy import json import logging import pickle @@ -469,6 +470,36 @@ def _log_cache_stats(self): f"- {cache_type}: {n_sequences} sequences, {n_bytes / 1e9:.2f} GB" ) + @staticmethod + def _should_insert_segment_cache(segment_type: str) -> bool: + return segment_type in {"system", "user", "assistant"} + + @staticmethod + def _remaining_uncached_segments(segments, segment_types, cached_tokens): + remaining_segments = [] + remaining_types = [] + cached = cached_tokens + for segment, segment_type in zip(segments, segment_types): + segment = segment[:] + if cached >= len(segment): + cached -= len(segment) + continue + if cached > 0: + segment = segment[cached:] + cached = 0 + remaining_segments.append(segment) + remaining_types.append(segment_type) + return remaining_segments, remaining_types + + def _insert_prompt_checkpoint(self, model_key, tokens, prompt_cache, cache_type): + if self._should_insert_segment_cache(cache_type): + self.prompt_cache.insert_cache( + model_key, + tokens[:], + copy.deepcopy(prompt_cache), + cache_type=cache_type, + ) + def _next_request(self, timeout=None): request = None if not self._is_distributed or self._rank == 0: @@ -933,7 +964,9 @@ def progress(tokens_processed, tokens_total): draft_model = self.model_provider.draft_model # Prepare the prompt and state machine - prompt, _, _, initial_state = self._tokenize(tokenizer, request, args) + prompt, segments, segment_types, initial_state = self._tokenize( + tokenizer, request, args + ) sm, sequences = self._make_state_machine( self.model_provider.model_key, tokenizer, @@ -966,6 +999,11 @@ def progress(tokens_processed, tokens_total): self.model_provider.model_key, prompt ) ctx.prompt_cache_count = len(prompt) - len(rest) + remaining_segments, remaining_types = self._remaining_uncached_segments( + segments, segment_types, ctx.prompt_cache_count + ) + # Single-request segment checkpoints need prefill boundary callbacks. + del remaining_segments, remaining_types cache_key = prompt[:] if cache is None: cache = make_prompt_cache(self.model_provider.model) diff --git a/tests/test_server.py b/tests/test_server.py index 08bf5b0e6..392f4a651 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -517,6 +517,48 @@ def keepalive_callback(processed_tokens, total_tokens): self.fail(f"Callback should handle BrokenPipeError: {e}") +class TestServerPromptCheckpointBoundary(unittest.TestCase): + def test_remaining_uncached_segments_drops_cached_prefix(self): + segments = [[1, 2], [3, 4, 5], [6]] + segment_types = ["system", "user", "assistant"] + + remaining, remaining_types = ResponseGenerator._remaining_uncached_segments( + segments, segment_types, 4 + ) + + self.assertEqual(remaining, [[5], [6]]) + self.assertEqual(remaining_types, ["user", "assistant"]) + self.assertEqual(segments, [[1, 2], [3, 4, 5], [6]]) + + def test_should_insert_segment_cache_only_accepts_prompt_segments(self): + self.assertTrue(ResponseGenerator._should_insert_segment_cache("system")) + self.assertTrue(ResponseGenerator._should_insert_segment_cache("user")) + self.assertTrue(ResponseGenerator._should_insert_segment_cache("assistant")) + self.assertFalse(ResponseGenerator._should_insert_segment_cache("tool")) + + def test_insert_prompt_checkpoint_respects_segment_type(self): + class Store: + def __init__(self): + self.calls = [] + + def insert_cache(self, *args, **kwargs): + self.calls.append((args, kwargs)) + + generator = ResponseGenerator.__new__(ResponseGenerator) + generator.prompt_cache = Store() + cache = [MockCache("checkpoint")] + + generator._insert_prompt_checkpoint(("model",), [1, 2], cache, "system") + generator._insert_prompt_checkpoint(("model",), [3], cache, "tool") + + self.assertEqual(len(generator.prompt_cache.calls), 1) + args, kwargs = generator.prompt_cache.calls[0] + self.assertEqual(args[0], ("model",)) + self.assertEqual(args[1], [1, 2]) + self.assertEqual(kwargs["cache_type"], "system") + self.assertIsNot(args[2], cache) + + class TestServerPromptCacheWiring(unittest.TestCase): def test_run_wires_prompt_cache_bytes(self): from unittest.mock import patch From f4d3e411fa846adb6e7b6ec379c677199aad4373 Mon Sep 17 00:00:00 2001 From: xxxkkw <646567098@qq.com> Date: Wed, 27 May 2026 21:03:52 +0800 Subject: [PATCH 5/6] Log cache types that disable batching --- mlx_lm/server.py | 18 ++++++++++++++---- tests/test_server.py | 16 ++++++++++++++++ 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/mlx_lm/server.py b/mlx_lm/server.py index ac9eb9747..6113a97e3 100644 --- a/mlx_lm/server.py +++ b/mlx_lm/server.py @@ -324,6 +324,19 @@ def __init__(self, cli_args: argparse.Namespace): if cli_args.chat_template: self._tokenizer_config["chat_template"] = cli_args.chat_template + @staticmethod + def _supports_cache_batching(model): + cache_types = [ + type(c).__name__ for c in make_prompt_cache(model) if not hasattr(c, "merge") + ] + if cache_types: + logging.info( + "Disabling batching for model because cache type %s does not implement merge()", + ", ".join(cache_types), + ) + return False + return True + def _load(self, model_path, adapter_path=None, draft_model_path=None): if self.is_distributed and ( adapter_path is not None or draft_model_path is not None @@ -369,10 +382,7 @@ def _load(self, model_path, adapter_path=None, draft_model_path=None): ) # Compute batchability - is_batchable = draft_model is None - is_batchable = is_batchable and all( - hasattr(c, "merge") for c in make_prompt_cache(model) - ) + is_batchable = draft_model is None and self._supports_cache_batching(model) # Update the member variables self.model_key = (model_path, adapter_path, draft_model_path) diff --git a/tests/test_server.py b/tests/test_server.py index 392f4a651..6da2abad5 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -14,6 +14,7 @@ from mlx_lm.server import ( APIHandler, LRUPromptCache, + ModelProvider, Response, ResponseGenerator, _process_control_tokens, @@ -517,6 +518,21 @@ def keepalive_callback(processed_tokens, total_tokens): self.fail(f"Callback should handle BrokenPipeError: {e}") +class TestModelProviderBatchability(unittest.TestCase): + def test_non_merge_cache_logs_batching_disable_reason(self): + from unittest.mock import patch + + class NotMergeable: + pass + + with patch("mlx_lm.server.make_prompt_cache", return_value=[NotMergeable()]): + with self.assertLogs(level="INFO") as logs: + is_batchable = ModelProvider._supports_cache_batching(object()) + + self.assertFalse(is_batchable) + self.assertIn("NotMergeable", "\n".join(logs.output)) + + class TestServerPromptCheckpointBoundary(unittest.TestCase): def test_remaining_uncached_segments_drops_cached_prefix(self): segments = [[1, 2], [3, 4, 5], [6]] From 813917e4690d5b1b66602040f3394130415e539b Mon Sep 17 00:00:00 2001 From: xxxkkw <646567098@qq.com> Date: Thu, 28 May 2026 23:42:35 +0800 Subject: [PATCH 6/6] Report prefix cache fetch timings Surface lookup, deepcopy, and restore timings from prompt-cache fetches so benchmarks can separate cache-management overhead from model prefill and decode time. --- benchmarks/prefix_cache_benchmark.py | 17 +++++- mlx_lm/models/cache.py | 80 +++++++++++++++++++++---- tests/test_prefix_cache_correctness.py | 82 ++++++++++++++++++++++++++ 3 files changed, 166 insertions(+), 13 deletions(-) diff --git a/benchmarks/prefix_cache_benchmark.py b/benchmarks/prefix_cache_benchmark.py index 1e91d85dc..e1fb6a7f3 100644 --- a/benchmarks/prefix_cache_benchmark.py +++ b/benchmarks/prefix_cache_benchmark.py @@ -25,6 +25,20 @@ def cache_logical_length(prompt_cache): return max(lengths) if lengths else None +def cache_fetch_payload(stats): + fetch_ms = stats.lookup_ms + stats.deepcopy_ms + stats.restore_ms + return { + "cache_fetch_ms": fetch_ms, + "cache_lookup_ms": stats.lookup_ms, + "cache_deepcopy_ms": stats.deepcopy_ms, + "cache_restore_ms": stats.restore_ms, + "cache_fetch_nbytes": stats.cache_nbytes, + "cache_fetch_matched_tokens": stats.matched_tokens, + "cache_fetch_hit_kind": stats.hit_kind, + "cache_fetch_fallback_reason": stats.fallback_reason, + } + + def prefill_snapshot(model, tokens, prefill_step_size): if not tokens: raise ValueError("Cannot prefill an empty token sequence") @@ -105,7 +119,7 @@ def insert_snapshot(lru, model_key, key_tokens, prompt_cache, prefill_s): def run_hot(name, cache_hit_kind, lru, model_key, model, tokenizer, tokens, max_tokens, prefill_step_size): - prompt_cache, rest = lru.fetch_nearest_cache(model_key, list(tokens)) + prompt_cache, rest, stats = lru.fetch_nearest_cache_with_stats(model_key, list(tokens)) if prompt_cache is None: prompt_cache = make_prompt_cache(model) rest = list(tokens) @@ -125,6 +139,7 @@ def run_hot(name, cache_hit_kind, lru, model_key, model, tokenizer, tokens, max_ "replayed_tokens": len(rest), "cache_bytes": cache_nbytes(prompt_cache), "cache_logical_length": cache_logical_length(prompt_cache), + **cache_fetch_payload(stats), **result, } diff --git a/mlx_lm/models/cache.py b/mlx_lm/models/cache.py index 3f1acb1a0..964863e1b 100644 --- a/mlx_lm/models/cache.py +++ b/mlx_lm/models/cache.py @@ -1,6 +1,7 @@ # Copyright © 2023-2024 Apple Inc. import copy +import time from collections import deque from dataclasses import dataclass from typing import Any, Dict, List, Optional @@ -1697,6 +1698,17 @@ class PromptTrieResult: common_prefix: int # Length of common prefix with any path +@dataclass +class PromptCacheFetchStats: + lookup_ms: float = 0.0 + deepcopy_ms: float = 0.0 + restore_ms: float = 0.0 + cache_nbytes: int = 0 + matched_tokens: int = 0 + fallback_reason: str = "miss" + hit_kind: str = "miss" + + class PromptTrie: def __init__(self): self._trie = {} @@ -1840,30 +1852,74 @@ def nbytes(self): return self._n_bytes def fetch_nearest_cache(self, model: Any, tokens: List[int]): + cache, rest, _ = self.fetch_nearest_cache_with_stats(model, tokens) + return cache, rest + + def fetch_nearest_cache_with_stats(self, model: Any, tokens: List[int]): + stats = PromptCacheFetchStats() + + start = time.perf_counter() result = self._trie.search(model, tokens) + stats.lookup_ms = (time.perf_counter() - start) * 1000 + + def copy_cache(cache_entry): + start = time.perf_counter() + cache = copy.deepcopy(cache_entry.prompt_cache) + stats.deepcopy_ms += (time.perf_counter() - start) * 1000 + stats.cache_nbytes = cache_entry.nbytes + return cache + if result.exact is not None: cache_entry = self._trie.get(result.model, result.exact) - return copy.deepcopy(cache_entry.prompt_cache), [] + stats.hit_kind = "exact" + stats.fallback_reason = "exact" + stats.matched_tokens = len(tokens) + return copy_cache(cache_entry), [], stats short_length = len(result.shorter) if result.shorter is not None else 0 + fallback_reason = "miss" if result.longer is not None and result.common_prefix > short_length: cache_entry = self._trie.get(result.model, result.longer) prefix = min(len(tokens) - 1, result.common_prefix) - if can_restore_prompt_cache(cache_entry.prompt_cache, prefix): - cache = copy.deepcopy(cache_entry.prompt_cache) - if restore_prompt_cache(cache, prefix): - return cache, tokens[prefix:] - elif can_trim_prompt_cache(cache_entry.prompt_cache): - cache = copy.deepcopy(cache_entry.prompt_cache) - num_to_trim = len(result.longer) - prefix - trim_prompt_cache(cache, num_to_trim) - return cache, tokens[prefix:] + start = time.perf_counter() + can_restore = can_restore_prompt_cache(cache_entry.prompt_cache, prefix) + stats.restore_ms += (time.perf_counter() - start) * 1000 + if can_restore: + cache = copy_cache(cache_entry) + start = time.perf_counter() + restored = restore_prompt_cache(cache, prefix) + stats.restore_ms += (time.perf_counter() - start) * 1000 + if restored: + stats.hit_kind = "longer_restore" + stats.fallback_reason = "longer_restore" + stats.matched_tokens = prefix + return cache, tokens[prefix:], stats + fallback_reason = "longer_restore_failed" + else: + start = time.perf_counter() + can_trim = can_trim_prompt_cache(cache_entry.prompt_cache) + stats.restore_ms += (time.perf_counter() - start) * 1000 + if can_trim: + cache = copy_cache(cache_entry) + num_to_trim = len(result.longer) - prefix + start = time.perf_counter() + trim_prompt_cache(cache, num_to_trim) + stats.restore_ms += (time.perf_counter() - start) * 1000 + stats.hit_kind = "longer_trim" + stats.fallback_reason = "longer_trim" + stats.matched_tokens = prefix + return cache, tokens[prefix:], stats + fallback_reason = "longer_not_restorable" if short_length > 0: cache_entry = self._trie.get(result.model, result.shorter) - return copy.deepcopy(cache_entry.prompt_cache), tokens[short_length:] + stats.hit_kind = "shorter" if fallback_reason == "miss" else "fallback_shorter" + stats.fallback_reason = fallback_reason if fallback_reason != "miss" else "shorter" + stats.matched_tokens = short_length + return copy_cache(cache_entry), tokens[short_length:], stats - return None, tokens + stats.fallback_reason = fallback_reason + return None, tokens, stats def insert_cache( self, diff --git a/tests/test_prefix_cache_correctness.py b/tests/test_prefix_cache_correctness.py index c4d714eee..b5ea45997 100644 --- a/tests/test_prefix_cache_correctness.py +++ b/tests/test_prefix_cache_correctness.py @@ -56,6 +56,25 @@ def test_exact_hit_reuses_non_trimmable_arrays_cache(self): self.assertIsNotNone(cache) self.assertTrue(mx.array_equal(cache[0][0], mx.array([[11]], dtype=mx.float32))) + def test_exact_hit_reports_fetch_stats(self): + lru = LRUPromptCache(max_size=4) + model = ("toy",) + prompt = [1, 2, 3] + prompt_cache = [make_kv_cache(3)] + lru.insert_cache(model, prompt, prompt_cache) + + cache, rest, stats = lru.fetch_nearest_cache_with_stats(model, prompt) + + self.assertEqual(rest, []) + self.assertIsNotNone(cache) + self.assertEqual(stats.hit_kind, "exact") + self.assertEqual(stats.fallback_reason, "exact") + self.assertEqual(stats.matched_tokens, 3) + self.assertEqual(stats.cache_nbytes, prompt_cache[0].nbytes) + self.assertGreaterEqual(stats.lookup_ms, 0) + self.assertGreaterEqual(stats.deepcopy_ms, 0) + self.assertEqual(stats.restore_ms, 0) + def test_shorter_checkpoint_reuses_non_trimmable_arrays_cache(self): lru = LRUPromptCache(max_size=4) model = ("toy",) @@ -68,6 +87,21 @@ def test_shorter_checkpoint_reuses_non_trimmable_arrays_cache(self): self.assertIsNotNone(cache) self.assertTrue(mx.array_equal(cache[0][0], mx.array([[22]], dtype=mx.float32))) + def test_shorter_checkpoint_reports_fetch_stats(self): + lru = LRUPromptCache(max_size=4) + model = ("toy",) + lru.insert_cache(model, [1, 2], [make_arrays_cache(22), make_kv_cache(2)]) + lru.insert_cache(model, [1, 2, 3, 4], [make_arrays_cache(44), make_kv_cache(4)]) + + cache, rest, stats = lru.fetch_nearest_cache_with_stats(model, [1, 2, 3, 9]) + + self.assertEqual(rest, [3, 9]) + self.assertIsNotNone(cache) + self.assertEqual(stats.hit_kind, "fallback_shorter") + self.assertEqual(stats.fallback_reason, "longer_not_restorable") + self.assertEqual(stats.matched_tokens, 2) + self.assertGreater(stats.cache_nbytes, 0) + def test_longer_non_trimmable_cache_is_not_used_without_checkpoint(self): lru = LRUPromptCache(max_size=4) model = ("toy",) @@ -170,6 +204,27 @@ def test_lru_uses_restorable_longer_prefill_rotating_cache(self): self.assertEqual(rest, [99]) self.assertEqual(cache[0].offset, 7) + def test_restorable_longer_hit_reports_fetch_stats(self): + lru = LRUPromptCache(max_size=4) + model = ("toy",) + lru.insert_cache( + model, + [1, 2, 3, 4, 5, 6, 7, 8], + [make_prefill_rotating_cache(8, max_size=4)], + ) + + cache, rest, stats = lru.fetch_nearest_cache_with_stats( + model, [1, 2, 3, 4, 5, 6, 7, 99] + ) + + self.assertIsNotNone(cache) + self.assertEqual(rest, [99]) + self.assertEqual(cache[0].offset, 7) + self.assertEqual(stats.hit_kind, "longer_restore") + self.assertEqual(stats.fallback_reason, "longer_restore") + self.assertEqual(stats.matched_tokens, 7) + self.assertGreaterEqual(stats.restore_ms, 0) + def test_lru_prefers_shorter_safe_rotating_checkpoint_over_unrestorable_longer_hit(self): lru = LRUPromptCache(max_size=4) model = ("toy",) @@ -228,6 +283,33 @@ def test_arrays_cache_can_be_reused_for_exact_entry_only_through_lru(self): self.assertEqual(rest, []) self.assertIsNotNone(cache) + def test_miss_reports_fetch_stats(self): + lru = LRUPromptCache(max_size=4) + model = ("toy",) + lru.insert_cache(model, [1, 2, 3], [make_arrays_cache(7)]) + + cache, rest, stats = lru.fetch_nearest_cache_with_stats(model, [9, 10]) + + self.assertIsNone(cache) + self.assertEqual(rest, [9, 10]) + self.assertEqual(stats.hit_kind, "miss") + self.assertEqual(stats.fallback_reason, "miss") + self.assertEqual(stats.matched_tokens, 0) + self.assertEqual(stats.cache_nbytes, 0) + + def test_fetch_stats_does_not_mutate_stored_entry(self): + lru = LRUPromptCache(max_size=4) + model = ("toy",) + lru.insert_cache(model, [1, 2, 3], [make_kv_cache(3)]) + + cache, _, _ = lru.fetch_nearest_cache_with_stats(model, [1, 2, 3]) + cache[0].trim(1) + cache_again, rest_again, stats = lru.fetch_nearest_cache_with_stats(model, [1, 2, 3]) + + self.assertEqual(rest_again, []) + self.assertEqual(cache_again[0].offset, 3) + self.assertEqual(stats.hit_kind, "exact") + class TestPrefixCacheSession(unittest.TestCase): def test_session_returns_rest_and_cached_count(self):