Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 60 additions & 1 deletion tests/test_simple_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import asyncio
import threading
from concurrent.futures import ThreadPoolExecutor
from types import SimpleNamespace
from unittest.mock import MagicMock, patch

Expand Down Expand Up @@ -1270,6 +1271,8 @@ async def test_start_keeps_text_routing_for_mllm_without_mtp(self):
"""MLLM text-only routing must stay available when MTP is disabled."""
from vllm_mlx.engine.simple import SimpleEngine

event_loop_thread = threading.get_ident()
captured = {}
text_model = MagicMock()
text_model.mtp = None
tokenizer = MagicMock()
Expand All @@ -1279,21 +1282,28 @@ async def test_start_keeps_text_routing_for_mllm_without_mtp(self):
mock_mllm.model = MagicMock()
mock_mllm.get_tokenizer.return_value = tokenizer

def build_text_model(*_args, **_kwargs):
captured["build_thread"] = threading.get_ident()
return text_model

with (
patch(
"vllm_mlx.models.mllm.MLXMultimodalLM",
return_value=mock_mllm,
),
patch(
"vllm_mlx.text_model_from_vlm.build_text_model",
return_value=text_model,
side_effect=build_text_model,
),
):
engine = SimpleEngine("qwen3.6-27b", force_mllm=True, mtp=False)
await engine.start()

assert engine._text_model is text_model
assert engine._text_tokenizer is tokenizer
assert captured["build_thread"] != event_loop_thread
assert engine._text_model_owner_thread == captured["build_thread"]
await engine.stop()

@pytest.mark.anyio
async def test_mllm_nonstream_text_only_routes_without_mtp(self):
Expand Down Expand Up @@ -1688,6 +1698,55 @@ def fake_make_logits_processors(**kwargs):
penalty_processor,
]

@pytest.mark.anyio
async def test_stream_generate_text_normal_path_uses_text_model_owner_worker(self):
    """Text generation for a VLM-derived TextModel stays on the owner worker.

    The engine is wired up by hand (no ``start()``) with a single-worker
    executor whose thread id is recorded as the text-model owner.
    ``mlx_lm.stream_generate`` is replaced by a fake that records the thread
    it runs on, so the test can check that generation happened exactly once,
    on the owner thread, and not on the thread running the test coroutine.
    """
    from vllm_mlx.engine.simple import SimpleEngine

    loop_thread_id = threading.get_ident()
    seen_thread_ids = []

    def fake_stream_generate(_model, _tokenizer, **_kwargs):
        # Record which thread actually drives the generator.
        seen_thread_ids.append(threading.get_ident())
        yield SimpleNamespace(text="Hello", finish_reason="stop")

    tokenizer = MagicMock(bos_token=None, eos_token_id=42)
    tokenizer.apply_chat_template.return_value = "<|im_start|>user\nhello"
    tokenizer.encode.return_value = [1, 2, 3]

    owner_executor = ThreadPoolExecutor(max_workers=1)

    engine = SimpleEngine("test-model", force_mllm=True, mtp=False)
    engine._loaded = True
    engine._text_model = MagicMock(mtp=None)
    engine._text_tokenizer = tokenizer
    engine._text_model_executor = owner_executor

    # Ask the single worker for its own thread id; that thread is the owner.
    owner_thread = owner_executor.submit(threading.get_ident).result(timeout=1.0)
    engine._text_model_owner_thread = owner_thread

    outputs = []
    try:
        with patch("mlx_lm.stream_generate", side_effect=fake_stream_generate):
            stream = engine._stream_generate_text(
                messages=[{"role": "user", "content": "hello"}],
                max_tokens=16,
                temperature=0.7,
                top_p=0.9,
            )
            async for chunk in stream:
                outputs.append(chunk)
    finally:
        # Always release the worker thread, even when generation fails.
        await engine.stop()

    assert outputs[-1].text == "Hello"
    assert seen_thread_ids == [owner_thread]
    assert seen_thread_ids[0] != loop_thread_id

@pytest.mark.anyio
async def test_stream_generate_text_disables_mtp_when_logits_processors_active(
self,
Expand Down
71 changes: 64 additions & 7 deletions vllm_mlx/engine/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import uuid
from collections import OrderedDict, deque
from collections.abc import AsyncIterator
from concurrent.futures import ThreadPoolExecutor
from typing import Any

# Re-entrancy guard for SimpleEngine._track_request_stream so that
Expand Down Expand Up @@ -203,6 +204,8 @@ def __init__(
# Per-request routing state (MLLM+MTP mode)
self._text_model = None
self._text_tokenizer = None
self._text_model_owner_thread: int | None = None
self._text_model_executor: ThreadPoolExecutor | None = None

# SpecPrefill draft model (loaded at start if enabled)
self._draft_model = None
Expand Down Expand Up @@ -348,15 +351,31 @@ async def start(self) -> None:
# on the slower mlx_vlm multimodal path.
if self._is_mllm and self._should_route_text_through_text_model():
try:
from ..text_model_from_vlm import build_text_model

self._text_model = build_text_model(
self._model.model, self._model_name
def build_text_route():
from ..text_model_from_vlm import build_text_model

text_model = build_text_model(
self._model.model, self._model_name
)
if text_model is None:
return None, None, None
return (
text_model,
self._model.get_tokenizer(),
threading.get_ident(),
)

(
self._text_model,
self._text_tokenizer,
self._text_model_owner_thread,
) = await self._run_blocking_serialized(
build_text_route,
executor=self._ensure_text_model_executor(),
)

if self._text_model is not None:
self._text_tokenizer = self._model.get_tokenizer()

# Apply Qwen3.5 eos_token fix (matches MLXLanguageModel.load)
if "qwen3" in self._model_name.lower():
self._text_tokenizer.eos_token = "<|im_end|>"
Expand Down Expand Up @@ -416,11 +435,13 @@ async def start(self) -> None:
else:
self._text_model = None
self._text_tokenizer = None
self._text_model_owner_thread = None

except Exception as e:
logger.error("MLLM text routing setup failed: %s", e)
self._text_model = None
self._text_tokenizer = None
self._text_model_owner_thread = None

# Load SpecPrefill draft model (small model for importance scoring)
if self._specprefill_enabled and self._specprefill_draft_model_path:
Expand Down Expand Up @@ -471,6 +492,11 @@ async def stop(self) -> None:
self._model = None
self._text_model = None
self._text_tokenizer = None
self._text_model_owner_thread = None
text_model_executor = self._text_model_executor
self._text_model_executor = None
if text_model_executor is not None:
text_model_executor.shutdown(wait=False, cancel_futures=True)
self._draft_model = None
self._loaded = False
self._system_kv_cache.clear()
Expand All @@ -485,7 +511,24 @@ def _should_route_text_through_text_model(
"""Return whether text-only MLLM requests may use mlx_lm TextModel."""
return not (mllm_draft_requested and self._mllm_draft_model_path is not None)

async def _run_blocking_serialized(self, func, /, *args, on_cancel=None, **kwargs):
def _ensure_text_model_executor(self) -> ThreadPoolExecutor:
"""Return the stable owner executor for VLM-derived TextModel calls."""
if self._text_model_executor is None:
self._text_model_executor = ThreadPoolExecutor(
max_workers=1,
thread_name_prefix="simple-text-model",
)
return self._text_model_executor

async def _run_blocking_serialized(
self,
func,
/,
*args,
on_cancel=None,
executor: ThreadPoolExecutor | None = None,
**kwargs,
):
"""Run a blocking MLX operation under the generation lock.

Cancellation must not release the async lock before the worker thread
Expand All @@ -498,7 +541,11 @@ def run_bound():
_bind_worker_generation_streams()
return func(*args, **kwargs)

task = asyncio.create_task(asyncio.to_thread(run_bound))
if executor is None:
task = asyncio.create_task(asyncio.to_thread(run_bound))
else:
loop = asyncio.get_running_loop()
task = asyncio.ensure_future(loop.run_in_executor(executor, run_bound))
try:
return await asyncio.shield(task)
except asyncio.CancelledError:
Expand Down Expand Up @@ -1927,6 +1974,7 @@ def make_cache_with_snapshot(
make_cache_with_snapshot,
self._text_model,
hit_candidate[0],
executor=self._text_model_executor,
)
)
# Bump LRU position now that we know we'll use it.
Expand Down Expand Up @@ -2070,6 +2118,14 @@ def _resume_after_processor_retirement(
def _run_all():
nonlocal backbone_cache, prompt_to_send

if (
self._text_model_owner_thread is not None
and threading.get_ident() != self._text_model_owner_thread
):
raise RuntimeError(
"VLM TextModel generation must run on its owner thread"
)

model = self._text_model
can_retire_processors = _processors_can_retire(all_processors)
use_mtp = (
Expand Down Expand Up @@ -2404,6 +2460,7 @@ async def _produce_responses() -> None:
await self._run_blocking_serialized(
_run_all,
on_cancel=abort_event.set,
executor=self._text_model_executor,
)
except asyncio.CancelledError:
raise
Expand Down
Loading