From c8479501e8b7c09af950cac5b4feb6ffc18b940a Mon Sep 17 00:00:00 2001 From: Thump604 Date: Sun, 24 May 2026 08:57:39 -0500 Subject: [PATCH] Add SimpleEngine prefix trie cache --- tests/test_cli.py | 3 + tests/test_lifecycle_cli.py | 9 + tests/test_model_registry.py | 3 + tests/test_simple_engine_prefix_trie_cache.py | 212 ++++++++++++++++++ vllm_mlx/cli.py | 37 +++ vllm_mlx/engine/simple.py | 151 ++++++++++++- vllm_mlx/lifecycle.py | 3 + vllm_mlx/model_registry.py | 33 +++ vllm_mlx/server.py | 39 ++++ 9 files changed, 488 insertions(+), 2 deletions(-) create mode 100644 tests/test_simple_engine_prefix_trie_cache.py diff --git a/tests/test_cli.py b/tests/test_cli.py index c24d738c1..4b120fe4f 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -45,6 +45,9 @@ def _serve_args(**overrides): "prefill_batch_size": 8, "prefill_step_size": 512, "prefix_cache_size": 100, + "prefix_trie_cache": False, + "prefix_trie_cache_size": 32, + "prefix_trie_cache_memory_mb": None, "rate_limit": 0, "reasoning_parser": None, "served_model_name": None, diff --git a/tests/test_lifecycle_cli.py b/tests/test_lifecycle_cli.py index 18805ac63..e34dae6ff 100644 --- a/tests/test_lifecycle_cli.py +++ b/tests/test_lifecycle_cli.py @@ -323,6 +323,9 @@ def fake_load_model(*args, **kwargs): specprefill_threshold=8192, specprefill_keep_pct=0.3, specprefill_draft_model=None, + prefix_trie_cache=False, + prefix_trie_cache_size=32, + prefix_trie_cache_memory_mb=None, mcp_config=None, api_key=None, rate_limit=0, @@ -417,6 +420,9 @@ def __init__(self, **kwargs): specprefill_threshold=8192, specprefill_keep_pct=0.3, specprefill_draft_model=None, + prefix_trie_cache=False, + prefix_trie_cache_size=32, + prefix_trie_cache_memory_mb=None, mcp_config=None, api_key=None, rate_limit=0, @@ -534,6 +540,9 @@ def test_serve_command_describes_lazy_startup_without_claiming_model_is_loaded( specprefill_threshold=8192, specprefill_keep_pct=0.3, specprefill_draft_model=None, + prefix_trie_cache=False, + prefix_trie_cache_size=32, + prefix_trie_cache_memory_mb=None, mcp_config=None, api_key=None, rate_limit=0, diff --git a/tests/test_model_registry.py b/tests/test_model_registry.py index c2204bd33..4f0f7a982 100644 --- a/tests/test_model_registry.py +++ b/tests/test_model_registry.py @@ -75,6 +75,9 @@ def _defaults() -> RegistryServeDefaults: specprefill_threshold=8192, specprefill_keep_pct=0.3, specprefill_draft_model=None, + prefix_trie_cache=False, + prefix_trie_cache_size=32, + prefix_trie_cache_memory_mb=None, stream_interval=1, gpu_memory_utilization=0.9, scheduler_config=None, diff --git a/tests/test_simple_engine_prefix_trie_cache.py b/tests/test_simple_engine_prefix_trie_cache.py new file mode 100644 index 000000000..0f9a40ac7 --- /dev/null +++ b/tests/test_simple_engine_prefix_trie_cache.py @@ -0,0 +1,212 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for SimpleEngine's optional mlx-lm prompt trie cache.""" + +import hashlib +from types import SimpleNamespace +from unittest.mock import patch + +import mlx.core as mx +import pytest + +from vllm_mlx.engine.simple import SimpleEngine + +pytestmark = pytest.mark.anyio + + +class FakeTokenizer: + bos_token = None + eos_token_ids = [] + + def apply_chat_template(self, messages, **_kwargs): + rendered = "" + for message in messages: + rendered += f"<|{message['role']}|>{message.get('content', '')}\n" + return rendered + "<|assistant|>" + + def encode(self, text, add_special_tokens=True): + tokens = [ord(ch) for ch in text] + return ([1] if add_special_tokens else []) + tokens + + +class FakeCache: + def __init__(self): + self.state = ( + mx.array([[1]], dtype=mx.float32), + mx.array([[2]], dtype=mx.float32), + ) + self.nbytes = 8 + + def is_trimmable(self): + return False + + +class FakeModel: + def __call__(self, *_args, **_kwargs): + return mx.zeros((1, 1, 4), dtype=mx.float32) + + +class NoFetchTrie: + nbytes = 0 + + def __len__(self): + return 0 + + def fetch_nearest_cache(self, *_args, **_kwargs): + pytest.fail("prefix trie lookup should not run on exact snapshot hit") + + def insert_cache(self, *_args, **_kwargs): + return None + + +def _engine(**kwargs): + engine = SimpleEngine("test-model", **kwargs) + engine._loaded = True + engine._supports_system_kv_cache = True + engine._model = SimpleNamespace(model=FakeModel(), tokenizer=FakeTokenizer()) + return engine + + +def _responses(tokens): + def fake_stream_generate(*_args, **kwargs): + seen_prompts = fake_stream_generate.seen_prompts + seen_prompts.append(kwargs["prompt"].tolist()) + for token in tokens: + yield SimpleNamespace(text=chr(token), token=token, finish_reason="stop") + + fake_stream_generate.seen_prompts = [] + return fake_stream_generate + + +async def _collect(engine, messages): + return [ + chunk + async for chunk in engine.stream_chat( + messages, + max_tokens=4, + temperature=0.0, + top_p=1.0, + ) + ] + + +async def test_prefix_trie_cache_reuses_growing_conversation_prefix(): + engine = _engine(prefix_trie_cache=True, prefix_trie_cache_size=8) + fake_stream_generate = _responses([ord("X")]) + + with ( + patch("mlx_lm.models.cache.make_prompt_cache", return_value=[FakeCache()]), + patch("mlx_lm.stream_generate", side_effect=fake_stream_generate), + ): + await _collect( + engine, + [ + {"role": "system", "content": "Rules"}, + {"role": "user", "content": "first"}, + ], + ) + await _collect( + engine, + [ + {"role": "system", "content": "Rules"}, + {"role": "user", "content": "first"}, + {"role": "assistant", "content": "X"}, + {"role": "user", "content": "second"}, + ], + ) + + stats = engine.get_stats()["prefix_trie_cache"] + assert stats["hits"] == 1 + assert stats["tokens_saved"] > 0 + assert stats["inserts"] == 2 + assert len(fake_stream_generate.seen_prompts[1]) < len( + FakeTokenizer().encode( + FakeTokenizer().apply_chat_template( + [ + {"role": "system", "content": "Rules"}, + {"role": "user", "content": "first"}, + {"role": "assistant", "content": "X"}, + {"role": "user", "content": "second"}, + ] + ) + ) + ) + + +async def test_existing_exact_snapshot_hit_wins_before_prefix_trie_lookup(): + tokenizer = FakeTokenizer() + engine = _engine(prefix_trie_cache=True) + messages = [ + {"role": "system", "content": "Rules"}, + {"role": "user", "content": "first"}, + ] + rendered_a = tokenizer.apply_chat_template( + [ + {"role": "system", "content": "Rules"}, + {"role": "user", "content": "Alpha"}, + ] + ) + rendered_b = tokenizer.apply_chat_template( + [ + {"role": "system", "content": "Rules"}, + {"role": "user", "content": "Bravo"}, + ] + ) + boundary = next(i for i, (a, b) in enumerate(zip(rendered_a, rendered_b)) if a != b) + prefix = rendered_a[:boundary] + engine._system_kv_hash = hashlib.sha256(prefix.encode()).hexdigest()[:16] + engine._system_kv_token_count = len( + tokenizer.encode(prefix, add_special_tokens=True) + ) + engine._system_kv_snapshot = [FakeCache().state] + engine._prefix_trie_cache = NoFetchTrie() + + with ( + patch("mlx_lm.models.cache.make_prompt_cache", return_value=[FakeCache()]), + patch("mlx_lm.stream_generate", side_effect=_responses([ord("Y")])), + ): + await _collect(engine, messages) + + assert engine.get_stats()["prefix_trie_cache"]["lookups"] == 0 + + +async def test_prefix_trie_cache_is_disabled_by_default(): + engine = _engine() + + with ( + patch("mlx_lm.models.cache.make_prompt_cache", return_value=[FakeCache()]), + patch("mlx_lm.stream_generate", side_effect=_responses([ord("Z")])), + ): + await _collect( + engine, + [ + {"role": "system", "content": "Rules"}, + {"role": "user", "content": "first"}, + ], + ) + + assert "prefix_trie_cache" not in engine.get_stats() + + +async def test_prefix_trie_cache_honors_entry_bound(): + engine = _engine(prefix_trie_cache=True, prefix_trie_cache_size=1) + + with ( + patch("mlx_lm.models.cache.make_prompt_cache", return_value=[FakeCache()]), + patch("mlx_lm.stream_generate", side_effect=_responses([ord("A"), ord("B")])), + ): + await _collect( + engine, + [ + {"role": "system", "content": "Rules"}, + {"role": "user", "content": "one"}, + ], + ) + await _collect( + engine, + [ + {"role": "system", "content": "Other rules"}, + {"role": "user", "content": "two"}, + ], + ) + + assert engine.get_stats()["prefix_trie_cache"]["entries"] == 1 diff --git a/vllm_mlx/cli.py b/vllm_mlx/cli.py index a0e7dee4e..1a2626d3d 100644 --- a/vllm_mlx/cli.py +++ b/vllm_mlx/cli.py @@ -332,6 +332,16 @@ def serve_command(args): f"threshold={args.specprefill_threshold}, " f"keep={args.specprefill_keep_pct*100:.0f}%)" ) + if args.prefix_trie_cache: + memory = ( + f", memory={args.prefix_trie_cache_memory_mb}MB" + if args.prefix_trie_cache_memory_mb is not None + else "" + ) + print( + "Prefix trie cache: enabled " + f"(max_entries={args.prefix_trie_cache_size}{memory})" + ) if mllm_draft_model: print( "MLLM draft model: enabled " @@ -349,6 +359,9 @@ def serve_command(args): specprefill_threshold=args.specprefill_threshold, specprefill_keep_pct=args.specprefill_keep_pct, specprefill_draft_model=args.specprefill_draft_model, + prefix_trie_cache=args.prefix_trie_cache, + prefix_trie_cache_size=args.prefix_trie_cache_size, + prefix_trie_cache_memory_mb=args.prefix_trie_cache_memory_mb, stream_interval=args.stream_interval if args.continuous_batching else 1, gpu_memory_utilization=args.gpu_memory_utilization, scheduler_config=scheduler_config, @@ -375,6 +388,9 @@ def serve_command(args): specprefill_threshold=args.specprefill_threshold, specprefill_keep_pct=args.specprefill_keep_pct, specprefill_draft_model=args.specprefill_draft_model, + prefix_trie_cache=args.prefix_trie_cache, + prefix_trie_cache_size=args.prefix_trie_cache_size, + prefix_trie_cache_memory_mb=args.prefix_trie_cache_memory_mb, mllm_draft_model=mllm_draft_model, mllm_draft_kind=mllm_draft_kind, mllm_draft_block_size=mllm_draft_block_size, @@ -1247,6 +1263,27 @@ def create_parser() -> argparse.ArgumentParser: help="Path to small draft model for SpecPrefill importance scoring. " "Must share the same tokenizer as the target model.", ) + serve_parser.add_argument( + "--prefix-trie-cache", + action="store_true", + default=False, + help=( + "Enable mlx-lm LRUPromptCache for pure-LLM SimpleEngine chat. " + "Default off; exact system-prefix snapshots still take precedence." + ), + ) + serve_parser.add_argument( + "--prefix-trie-cache-size", + type=make_positive_int_arg_parser("--prefix-trie-cache-size"), + default=32, + help="Maximum prompt-cache trie entries for --prefix-trie-cache.", + ) + serve_parser.add_argument( + "--prefix-trie-cache-memory-mb", + type=make_positive_int_arg_parser("--prefix-trie-cache-memory-mb"), + default=None, + help="Optional prompt-cache trie memory cap in MB.", + ) # MLLM speculative draft/assistant model serve_parser.add_argument( "--mllm-draft-model", diff --git a/vllm_mlx/engine/simple.py b/vllm_mlx/engine/simple.py index 0fb28f55e..b791549c2 100644 --- a/vllm_mlx/engine/simple.py +++ b/vllm_mlx/engine/simple.py @@ -130,6 +130,9 @@ def __init__( mllm_draft_model: str | None = None, mllm_draft_kind: str | None = None, mllm_draft_block_size: int | None = None, + prefix_trie_cache: bool = False, + prefix_trie_cache_size: int = 32, + prefix_trie_cache_memory_mb: int | None = None, ): """ Initialize the simple engine. @@ -150,6 +153,9 @@ def __init__( mllm_draft_model: Optional MLLM speculative draft/assistant model path mllm_draft_kind: Optional mlx-vlm draft kind, for example "mtp" mllm_draft_block_size: Optional speculative block size for mlx-vlm + prefix_trie_cache: Enable mlx-lm LRUPromptCache on pure-LLM stream_chat + prefix_trie_cache_size: Maximum prompt-cache trie entries + prefix_trie_cache_memory_mb: Optional prompt-cache trie memory cap in MB """ self._model_name = model_name self._created_at = time.time() @@ -168,6 +174,19 @@ def __init__( self._mllm_draft_model_path = mllm_draft_model self._mllm_draft_kind = mllm_draft_kind self._mllm_draft_block_size = mllm_draft_block_size + self._prefix_trie_cache_enabled = prefix_trie_cache + self._prefix_trie_cache_size = max(1, prefix_trie_cache_size) + self._prefix_trie_cache_memory_mb = prefix_trie_cache_memory_mb + self._prefix_trie_cache = None + self._prefix_trie_cache_lock = threading.Lock() + self._prefix_trie_cache_stats = { + "lookups": 0, + "hits": 0, + "misses": 0, + "inserts": 0, + "skips": 0, + "tokens_saved": 0, + } # KV cache size limit self._max_kv_size = max_kv_size @@ -196,6 +215,87 @@ def __init__( # restore would silently desynchronize. Probed once in ``start()``. self._supports_system_kv_cache: bool = False + def _ensure_prefix_trie_cache(self) -> Any | None: + """Return the optional mlx-lm prompt trie cache, creating it lazily.""" + if not self._prefix_trie_cache_enabled: + return None + if self._is_mllm: + self._prefix_trie_cache_stats["skips"] += 1 + return None + with self._prefix_trie_cache_lock: + if self._prefix_trie_cache is None: + from mlx_lm.models.cache import LRUPromptCache + + max_bytes = ( + self._prefix_trie_cache_memory_mb * 1024 * 1024 + if self._prefix_trie_cache_memory_mb is not None + else 1 << 63 + ) + self._prefix_trie_cache = LRUPromptCache( + max_size=self._prefix_trie_cache_size, + max_bytes=max_bytes, + ) + return self._prefix_trie_cache + + def _fetch_prefix_trie_cache( + self, model: Any, tokens: list[int] + ) -> tuple[Any | None, list[int] | None, int]: + """Fetch a nearest prompt-cache trie entry for a full prompt token list.""" + prefix_trie = self._ensure_prefix_trie_cache() + if prefix_trie is None: + return None, None, 0 + + self._prefix_trie_cache_stats["lookups"] += 1 + try: + with self._prefix_trie_cache_lock: + trie_cache, trie_rest = prefix_trie.fetch_nearest_cache(model, tokens) + if trie_cache is None or trie_rest is None or len(trie_rest) >= len(tokens): + self._prefix_trie_cache_stats["misses"] += 1 + return None, None, 0 + + if len(trie_rest) == 0: + from mlx_lm.models.cache import can_trim_prompt_cache, trim_prompt_cache + + if not can_trim_prompt_cache(trie_cache): + raise ValueError("exact prefix-trie cache hit is not trimmable") + trim_prompt_cache(trie_cache, 1) + trie_rest = [tokens[-1]] + + tokens_saved = len(tokens) - len(trie_rest) + self._prefix_trie_cache_stats["hits"] += 1 + self._prefix_trie_cache_stats["tokens_saved"] += tokens_saved + return trie_cache, list(trie_rest), tokens_saved + except Exception as e: + self._prefix_trie_cache_stats["skips"] += 1 + logger.debug("Prefix trie cache lookup skipped after failure (%s)", e) + return None, None, 0 + + def _insert_prefix_trie_cache( + self, model: Any, cache_key: list[int], prompt_cache: Any + ) -> None: + """Insert a completed prompt cache into the optional prompt trie cache.""" + prefix_trie = self._ensure_prefix_trie_cache() + if prefix_trie is None or not cache_key: + return + try: + with self._prefix_trie_cache_lock: + prefix_trie.insert_cache(model, cache_key, prompt_cache) + self._prefix_trie_cache_stats["inserts"] += 1 + except Exception as e: + self._prefix_trie_cache_stats["skips"] += 1 + logger.debug("Prefix trie cache insert skipped (%s)", e) + + def _prefix_trie_cache_snapshot(self) -> tuple[int, int]: + """Return current prompt-trie entry and byte counts.""" + with self._prefix_trie_cache_lock: + entries = ( + len(self._prefix_trie_cache) + if self._prefix_trie_cache is not None + else 0 + ) + nbytes = self._prefix_trie_cache.nbytes if self._prefix_trie_cache else 0 + return entries, nbytes + @property def model_name(self) -> str: """Get the model name.""" @@ -397,6 +497,8 @@ async def stop(self) -> None: self._system_kv_hash = None self._system_kv_token_count = 0 self._supports_system_kv_cache = False + with self._prefix_trie_cache_lock: + self._prefix_trie_cache = None logger.info("SimpleEngine stopped") def _should_route_text_through_text_model( @@ -955,8 +1057,12 @@ def run_stream(): system_tokens = None system_token_count = 0 full_token_count = 0 + full_tokens_list: list[int] | None = None system_hash = None kv_cache_eligible = False + prefix_trie_hit = False + prefix_trie_rest_tokens: list[int] | None = None + prefix_trie_tokens_saved = 0 # Snapshot reference captured at gate time so a concurrent MISS that # reassigns ``self._system_kv_snapshot`` between the gate and the # restore (which runs later inside ``_run_blocking_serialized``) @@ -1146,6 +1252,22 @@ def _with_user(user_content: str) -> list[dict[str, Any]]: len(suffix_tokens), system_hash, ) + trie_cache, trie_rest, trie_tokens_saved = ( + self._fetch_prefix_trie_cache( + self._model.model, full_tokens_list + ) + ) + if trie_cache is not None and trie_rest is not None: + prefix_trie_hit = True + hit_snapshot = trie_cache + prefix_trie_rest_tokens = trie_rest + prefix_trie_tokens_saved = trie_tokens_saved + logger.info( + "Prefix trie cache HIT (stream_chat): " + "reusing %d tokens, prefilling %d new", + prefix_trie_tokens_saved, + len(trie_rest), + ) if kv_cache_eligible: # Cache-aware path: drive mlx-lm directly with a pre-populated cache. @@ -1175,7 +1297,11 @@ def _run_with_cache() -> None: model = self._model.model sampler = make_sampler(temp=temperature, top_p=top_p) - if cache_hit: + cache_key = list(full_tokens_list or []) + + if prefix_trie_hit: + bc = hit_snapshot + elif cache_hit: bc = make_prompt_cache(model) # Restore from the closure-local reference captured at the # gate, never from ``self._system_kv_snapshot`` directly: @@ -1219,7 +1345,10 @@ def _run_with_cache() -> None: cache_mb, ) - prompt_arr = mx.array(suffix_tokens) + prompt_tokens_for_decode = ( + prefix_trie_rest_tokens if prefix_trie_hit else suffix_tokens + ) + prompt_arr = mx.array(prompt_tokens_for_decode) for resp in mlx_stream_generate( model, tokenizer, @@ -1230,8 +1359,17 @@ def _run_with_cache() -> None: ): if abort_event.is_set(): break + token = getattr(resp, "token", None) + if token is not None: + try: + cache_key.append(int(token)) + except TypeError: + cache_key.append(int(token.item())) _emit_response(resp) + if not abort_event.is_set() and bc is not None: + self._insert_prefix_trie_cache(model, cache_key, bc) + async def _produce_responses() -> None: try: await self._run_blocking_serialized( @@ -2236,6 +2374,15 @@ def get_stats(self) -> dict[str, Any]: "memory_mb": round(cache_bytes / 1e6, 1), } + if self._prefix_trie_cache_enabled: + trie_entries, trie_bytes = self._prefix_trie_cache_snapshot() + stats["prefix_trie_cache"] = { + "enabled": True, + **self._prefix_trie_cache_stats, + "entries": trie_entries, + "memory_mb": round(trie_bytes / 1e6, 1), + } + # Include Metal memory stats try: import mlx.core as mx diff --git a/vllm_mlx/lifecycle.py b/vllm_mlx/lifecycle.py index d5cb86141..9e1e178a9 100644 --- a/vllm_mlx/lifecycle.py +++ b/vllm_mlx/lifecycle.py @@ -41,6 +41,9 @@ class ModelSpec: specprefill_threshold: int = 8192 specprefill_keep_pct: float = 0.3 specprefill_draft_model: str | None = None + prefix_trie_cache: bool = False + prefix_trie_cache_size: int = 32 + prefix_trie_cache_memory_mb: int | None = None @dataclass diff --git a/vllm_mlx/model_registry.py b/vllm_mlx/model_registry.py index 2ab9e88b1..c9156fd5b 100644 --- a/vllm_mlx/model_registry.py +++ b/vllm_mlx/model_registry.py @@ -116,6 +116,9 @@ class RegistryServeDefaults: specprefill_threshold: int specprefill_keep_pct: float specprefill_draft_model: str | None + prefix_trie_cache: bool + prefix_trie_cache_size: int + prefix_trie_cache_memory_mb: int | None stream_interval: int gpu_memory_utilization: float scheduler_config: SchedulerConfig | None @@ -155,6 +158,9 @@ class RegisteredModel: specprefill_threshold: int | None = None specprefill_keep_pct: float | None = None specprefill_draft_model: str | None = None + prefix_trie_cache: bool | None = None + prefix_trie_cache_size: int | None = None + prefix_trie_cache_memory_mb: int | None = None stream_interval: int | None = None gpu_memory_utilization: float | None = None estimated_memory_bytes: int | None = None @@ -174,6 +180,9 @@ class ResolvedModelConfig: specprefill_threshold: int specprefill_keep_pct: float specprefill_draft_model: str | None + prefix_trie_cache: bool + prefix_trie_cache_size: int + prefix_trie_cache_memory_mb: int | None stream_interval: int gpu_memory_utilization: float scheduler_config: SchedulerConfig | None @@ -349,6 +358,9 @@ def load_registry_config( specprefill_threshold=item.get("specprefill_threshold"), specprefill_keep_pct=item.get("specprefill_keep_pct"), specprefill_draft_model=item.get("specprefill_draft_model"), + prefix_trie_cache=item.get("prefix_trie_cache"), + prefix_trie_cache_size=item.get("prefix_trie_cache_size"), + prefix_trie_cache_memory_mb=item.get("prefix_trie_cache_memory_mb"), stream_interval=item.get("stream_interval"), gpu_memory_utilization=item.get("gpu_memory_utilization"), estimated_memory_bytes=estimated_bytes, @@ -798,6 +810,9 @@ async def _instantiate_model( specprefill_threshold=config.specprefill_threshold, specprefill_keep_pct=config.specprefill_keep_pct, specprefill_draft_model=config.specprefill_draft_model, + prefix_trie_cache=config.prefix_trie_cache, + prefix_trie_cache_size=config.prefix_trie_cache_size, + prefix_trie_cache_memory_mb=config.prefix_trie_cache_memory_mb, ) await engine.start() @@ -896,6 +911,21 @@ def _resolve_model_config( if entry.specprefill_draft_model is not None else self._defaults.specprefill_draft_model ) + prefix_trie_cache = ( + entry.prefix_trie_cache + if entry.prefix_trie_cache is not None + else self._defaults.prefix_trie_cache + ) + prefix_trie_cache_size = ( + entry.prefix_trie_cache_size + if entry.prefix_trie_cache_size is not None + else self._defaults.prefix_trie_cache_size + ) + prefix_trie_cache_memory_mb = ( + entry.prefix_trie_cache_memory_mb + if entry.prefix_trie_cache_memory_mb is not None + else self._defaults.prefix_trie_cache_memory_mb + ) stream_interval = ( entry.stream_interval if entry.stream_interval is not None @@ -919,6 +949,9 @@ def _resolve_model_config( specprefill_threshold=specprefill_threshold, specprefill_keep_pct=specprefill_keep_pct, specprefill_draft_model=specprefill_draft_model, + prefix_trie_cache=prefix_trie_cache, + prefix_trie_cache_size=prefix_trie_cache_size, + prefix_trie_cache_memory_mb=prefix_trie_cache_memory_mb, stream_interval=stream_interval, gpu_memory_utilization=gpu_memory_utilization, scheduler_config=scheduler_config, diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py index 1de223f05..0240d4b5e 100644 --- a/vllm_mlx/server.py +++ b/vllm_mlx/server.py @@ -1162,6 +1162,9 @@ def _build_engine(spec: ModelSpec) -> BaseEngine: specprefill_keep_pct=spec.specprefill_keep_pct, specprefill_draft_model=spec.specprefill_draft_model, max_kv_size=max_kv_size, + prefix_trie_cache=spec.prefix_trie_cache, + prefix_trie_cache_size=spec.prefix_trie_cache_size, + prefix_trie_cache_memory_mb=spec.prefix_trie_cache_memory_mb, ) @@ -2915,6 +2918,9 @@ def load_model( specprefill_threshold: int = 8192, specprefill_keep_pct: float = 0.3, specprefill_draft_model: str = None, + prefix_trie_cache: bool = False, + prefix_trie_cache_size: int = 32, + prefix_trie_cache_memory_mb: int | None = None, mllm_draft_model: str | None = None, mllm_draft_kind: str | None = None, mllm_draft_block_size: int | None = None, @@ -2940,6 +2946,9 @@ def load_model( specprefill_threshold: Minimum suffix tokens to trigger SpecPrefill (default: 8192) specprefill_keep_pct: Fraction of tokens to keep (default: 0.3) specprefill_draft_model: Path to small draft model for SpecPrefill scoring + prefix_trie_cache: Enable mlx-lm LRUPromptCache for pure-LLM SimpleEngine chat + prefix_trie_cache_size: Maximum prompt-cache trie entries + prefix_trie_cache_memory_mb: Optional prompt-cache trie memory cap in MB mllm_draft_model: Optional MLLM speculative draft/assistant model path. mllm_draft_kind: Optional mlx-vlm draft kind, for example "mtp". mllm_draft_block_size: Optional speculative block size passed to mlx-vlm. @@ -3036,6 +3045,9 @@ def load_model( specprefill_threshold=specprefill_threshold, specprefill_keep_pct=specprefill_keep_pct, specprefill_draft_model=specprefill_draft_model, + prefix_trie_cache=prefix_trie_cache, + prefix_trie_cache_size=prefix_trie_cache_size, + prefix_trie_cache_memory_mb=prefix_trie_cache_memory_mb, ) _residency_manager = ResidencyManager( _engine_factory, @@ -3091,6 +3103,9 @@ def load_model( mllm_draft_model=mllm_draft_model, mllm_draft_kind=mllm_draft_kind, mllm_draft_block_size=mllm_draft_block_size, + prefix_trie_cache=prefix_trie_cache, + prefix_trie_cache_size=prefix_trie_cache_size, + prefix_trie_cache_memory_mb=prefix_trie_cache_memory_mb, ) # Start SimpleEngine synchronously (no background loop) # Use new_event_loop() for Python 3.10+ compatibility (get_event_loop() is deprecated) @@ -6364,6 +6379,9 @@ def main(): mllm_draft_model=args.mllm_draft_model, mllm_draft_kind=args.mllm_draft_kind, mllm_draft_block_size=args.mllm_draft_block_size, + prefix_trie_cache=args.prefix_trie_cache, + prefix_trie_cache_size=args.prefix_trie_cache_size, + prefix_trie_cache_memory_mb=args.prefix_trie_cache_memory_mb, auto_unload_idle_seconds=args.auto_unload_idle_seconds, lazy_load_model=args.lazy_load_model, ) @@ -6448,6 +6466,27 @@ def create_parser() -> argparse.ArgumentParser: default=None, help="Draft block size passed to mlx-vlm for --mllm-draft-model.", ) + parser.add_argument( + "--prefix-trie-cache", + action="store_true", + default=False, + help=( + "Enable mlx-lm LRUPromptCache for pure-LLM SimpleEngine chat. " + "Default off; exact system-prefix snapshots still take precedence." + ), + ) + parser.add_argument( + "--prefix-trie-cache-size", + type=make_positive_int_arg_parser("--prefix-trie-cache-size"), + default=32, + help="Maximum prompt-cache trie entries for --prefix-trie-cache.", + ) + parser.add_argument( + "--prefix-trie-cache-memory-mb", + type=make_positive_int_arg_parser("--prefix-trie-cache-memory-mb"), + default=None, + help="Optional prompt-cache trie memory cap in MB.", + ) parser.add_argument( "--mcp-config", type=str,