Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ def _serve_args(**overrides):
"prefill_batch_size": 8,
"prefill_step_size": 512,
"prefix_cache_size": 100,
"prefix_trie_cache": False,
"prefix_trie_cache_size": 32,
"prefix_trie_cache_memory_mb": None,
"rate_limit": 0,
"reasoning_parser": None,
"served_model_name": None,
Expand Down
9 changes: 9 additions & 0 deletions tests/test_lifecycle_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,9 @@ def fake_load_model(*args, **kwargs):
specprefill_threshold=8192,
specprefill_keep_pct=0.3,
specprefill_draft_model=None,
prefix_trie_cache=False,
prefix_trie_cache_size=32,
prefix_trie_cache_memory_mb=None,
mcp_config=None,
api_key=None,
rate_limit=0,
Expand Down Expand Up @@ -417,6 +420,9 @@ def __init__(self, **kwargs):
specprefill_threshold=8192,
specprefill_keep_pct=0.3,
specprefill_draft_model=None,
prefix_trie_cache=False,
prefix_trie_cache_size=32,
prefix_trie_cache_memory_mb=None,
mcp_config=None,
api_key=None,
rate_limit=0,
Expand Down Expand Up @@ -534,6 +540,9 @@ def test_serve_command_describes_lazy_startup_without_claiming_model_is_loaded(
specprefill_threshold=8192,
specprefill_keep_pct=0.3,
specprefill_draft_model=None,
prefix_trie_cache=False,
prefix_trie_cache_size=32,
prefix_trie_cache_memory_mb=None,
mcp_config=None,
api_key=None,
rate_limit=0,
Expand Down
3 changes: 3 additions & 0 deletions tests/test_model_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ def _defaults() -> RegistryServeDefaults:
specprefill_threshold=8192,
specprefill_keep_pct=0.3,
specprefill_draft_model=None,
prefix_trie_cache=False,
prefix_trie_cache_size=32,
prefix_trie_cache_memory_mb=None,
stream_interval=1,
gpu_memory_utilization=0.9,
scheduler_config=None,
Expand Down
212 changes: 212 additions & 0 deletions tests/test_simple_engine_prefix_trie_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
# SPDX-License-Identifier: Apache-2.0
"""Tests for SimpleEngine's optional mlx-lm prompt trie cache."""

import hashlib
from types import SimpleNamespace
from unittest.mock import patch

import mlx.core as mx
import pytest

from vllm_mlx.engine.simple import SimpleEngine

pytestmark = pytest.mark.anyio


class FakeTokenizer:
bos_token = None
eos_token_ids = []

def apply_chat_template(self, messages, **_kwargs):
rendered = ""
for message in messages:
rendered += f"<|{message['role']}|>{message.get('content', '')}\n"
return rendered + "<|assistant|>"

def encode(self, text, add_special_tokens=True):
tokens = [ord(ch) for ch in text]
return ([1] if add_special_tokens else []) + tokens


class FakeCache:
def __init__(self):
self.state = (
mx.array([[1]], dtype=mx.float32),
mx.array([[2]], dtype=mx.float32),
)
self.nbytes = 8

def is_trimmable(self):
return False


class FakeModel:
def __call__(self, *_args, **_kwargs):
return mx.zeros((1, 1, 4), dtype=mx.float32)


class NoFetchTrie:
nbytes = 0

def __len__(self):
return 0

def fetch_nearest_cache(self, *_args, **_kwargs):
pytest.fail("prefix trie lookup should not run on exact snapshot hit")

def insert_cache(self, *_args, **_kwargs):
return None


def _engine(**kwargs):
engine = SimpleEngine("test-model", **kwargs)
engine._loaded = True
engine._supports_system_kv_cache = True
engine._model = SimpleNamespace(model=FakeModel(), tokenizer=FakeTokenizer())
return engine


def _responses(tokens):
def fake_stream_generate(*_args, **kwargs):
seen_prompts = fake_stream_generate.seen_prompts
seen_prompts.append(kwargs["prompt"].tolist())
for token in tokens:
yield SimpleNamespace(text=chr(token), token=token, finish_reason="stop")

fake_stream_generate.seen_prompts = []
return fake_stream_generate


async def _collect(engine, messages):
return [
chunk
async for chunk in engine.stream_chat(
messages,
max_tokens=4,
temperature=0.0,
top_p=1.0,
)
]


async def test_prefix_trie_cache_reuses_growing_conversation_prefix():
engine = _engine(prefix_trie_cache=True, prefix_trie_cache_size=8)
fake_stream_generate = _responses([ord("X")])

with (
patch("mlx_lm.models.cache.make_prompt_cache", return_value=[FakeCache()]),
patch("mlx_lm.stream_generate", side_effect=fake_stream_generate),
):
await _collect(
engine,
[
{"role": "system", "content": "Rules"},
{"role": "user", "content": "first"},
],
)
await _collect(
engine,
[
{"role": "system", "content": "Rules"},
{"role": "user", "content": "first"},
{"role": "assistant", "content": "X"},
{"role": "user", "content": "second"},
],
)

stats = engine.get_stats()["prefix_trie_cache"]
assert stats["hits"] == 1
assert stats["tokens_saved"] > 0
assert stats["inserts"] == 2
assert len(fake_stream_generate.seen_prompts[1]) < len(
FakeTokenizer().encode(
FakeTokenizer().apply_chat_template(
[
{"role": "system", "content": "Rules"},
{"role": "user", "content": "first"},
{"role": "assistant", "content": "X"},
{"role": "user", "content": "second"},
]
)
)
)


async def test_existing_exact_snapshot_hit_wins_before_prefix_trie_lookup():
tokenizer = FakeTokenizer()
engine = _engine(prefix_trie_cache=True)
messages = [
{"role": "system", "content": "Rules"},
{"role": "user", "content": "first"},
]
rendered_a = tokenizer.apply_chat_template(
[
{"role": "system", "content": "Rules"},
{"role": "user", "content": "Alpha"},
]
)
rendered_b = tokenizer.apply_chat_template(
[
{"role": "system", "content": "Rules"},
{"role": "user", "content": "Bravo"},
]
)
boundary = next(i for i, (a, b) in enumerate(zip(rendered_a, rendered_b)) if a != b)
prefix = rendered_a[:boundary]
engine._system_kv_hash = hashlib.sha256(prefix.encode()).hexdigest()[:16]
engine._system_kv_token_count = len(
tokenizer.encode(prefix, add_special_tokens=True)
)
engine._system_kv_snapshot = [FakeCache().state]
engine._prefix_trie_cache = NoFetchTrie()

with (
patch("mlx_lm.models.cache.make_prompt_cache", return_value=[FakeCache()]),
patch("mlx_lm.stream_generate", side_effect=_responses([ord("Y")])),
):
await _collect(engine, messages)

assert engine.get_stats()["prefix_trie_cache"]["lookups"] == 0


async def test_prefix_trie_cache_is_disabled_by_default():
engine = _engine()

with (
patch("mlx_lm.models.cache.make_prompt_cache", return_value=[FakeCache()]),
patch("mlx_lm.stream_generate", side_effect=_responses([ord("Z")])),
):
await _collect(
engine,
[
{"role": "system", "content": "Rules"},
{"role": "user", "content": "first"},
],
)

assert "prefix_trie_cache" not in engine.get_stats()


async def test_prefix_trie_cache_honors_entry_bound():
engine = _engine(prefix_trie_cache=True, prefix_trie_cache_size=1)

with (
patch("mlx_lm.models.cache.make_prompt_cache", return_value=[FakeCache()]),
patch("mlx_lm.stream_generate", side_effect=_responses([ord("A"), ord("B")])),
):
await _collect(
engine,
[
{"role": "system", "content": "Rules"},
{"role": "user", "content": "one"},
],
)
await _collect(
engine,
[
{"role": "system", "content": "Other rules"},
{"role": "user", "content": "two"},
],
)

assert engine.get_stats()["prefix_trie_cache"]["entries"] == 1
37 changes: 37 additions & 0 deletions vllm_mlx/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,16 @@ def serve_command(args):
f"threshold={args.specprefill_threshold}, "
f"keep={args.specprefill_keep_pct*100:.0f}%)"
)
if args.prefix_trie_cache:
memory = (
f", memory={args.prefix_trie_cache_memory_mb}MB"
if args.prefix_trie_cache_memory_mb is not None
else ""
)
print(
"Prefix trie cache: enabled "
f"(max_entries={args.prefix_trie_cache_size}{memory})"
)
if mllm_draft_model:
print(
"MLLM draft model: enabled "
Expand All @@ -349,6 +359,9 @@ def serve_command(args):
specprefill_threshold=args.specprefill_threshold,
specprefill_keep_pct=args.specprefill_keep_pct,
specprefill_draft_model=args.specprefill_draft_model,
prefix_trie_cache=args.prefix_trie_cache,
prefix_trie_cache_size=args.prefix_trie_cache_size,
prefix_trie_cache_memory_mb=args.prefix_trie_cache_memory_mb,
stream_interval=args.stream_interval if args.continuous_batching else 1,
gpu_memory_utilization=args.gpu_memory_utilization,
scheduler_config=scheduler_config,
Expand All @@ -375,6 +388,9 @@ def serve_command(args):
specprefill_threshold=args.specprefill_threshold,
specprefill_keep_pct=args.specprefill_keep_pct,
specprefill_draft_model=args.specprefill_draft_model,
prefix_trie_cache=args.prefix_trie_cache,
prefix_trie_cache_size=args.prefix_trie_cache_size,
prefix_trie_cache_memory_mb=args.prefix_trie_cache_memory_mb,
mllm_draft_model=mllm_draft_model,
mllm_draft_kind=mllm_draft_kind,
mllm_draft_block_size=mllm_draft_block_size,
Expand Down Expand Up @@ -1247,6 +1263,27 @@ def create_parser() -> argparse.ArgumentParser:
help="Path to small draft model for SpecPrefill importance scoring. "
"Must share the same tokenizer as the target model.",
)
serve_parser.add_argument(
"--prefix-trie-cache",
action="store_true",
default=False,
help=(
"Enable mlx-lm LRUPromptCache for pure-LLM SimpleEngine chat. "
"Default off; exact system-prefix snapshots still take precedence."
),
)
serve_parser.add_argument(
"--prefix-trie-cache-size",
type=make_positive_int_arg_parser("--prefix-trie-cache-size"),
default=32,
help="Maximum prompt-cache trie entries for --prefix-trie-cache.",
)
serve_parser.add_argument(
"--prefix-trie-cache-memory-mb",
type=make_positive_int_arg_parser("--prefix-trie-cache-memory-mb"),
default=None,
help="Optional prompt-cache trie memory cap in MB.",
)
# MLLM speculative draft/assistant model
serve_parser.add_argument(
"--mllm-draft-model",
Expand Down
Loading
Loading