diff --git a/tests/test_simple_engine.py b/tests/test_simple_engine.py index 607428e6c..99975e8a1 100644 --- a/tests/test_simple_engine.py +++ b/tests/test_simple_engine.py @@ -3,6 +3,7 @@ import asyncio import threading +from concurrent.futures import ThreadPoolExecutor from types import SimpleNamespace from unittest.mock import MagicMock, patch @@ -1270,6 +1271,8 @@ async def test_start_keeps_text_routing_for_mllm_without_mtp(self): """MLLM text-only routing must stay available when MTP is disabled.""" from vllm_mlx.engine.simple import SimpleEngine + event_loop_thread = threading.get_ident() + captured = {} text_model = MagicMock() text_model.mtp = None tokenizer = MagicMock() @@ -1279,6 +1282,10 @@ async def test_start_keeps_text_routing_for_mllm_without_mtp(self): mock_mllm.model = MagicMock() mock_mllm.get_tokenizer.return_value = tokenizer + def build_text_model(*_args, **_kwargs): + captured["build_thread"] = threading.get_ident() + return text_model + with ( patch( "vllm_mlx.models.mllm.MLXMultimodalLM", @@ -1286,7 +1293,7 @@ async def test_start_keeps_text_routing_for_mllm_without_mtp(self): ), patch( "vllm_mlx.text_model_from_vlm.build_text_model", - return_value=text_model, + side_effect=build_text_model, ), ): engine = SimpleEngine("qwen3.6-27b", force_mllm=True, mtp=False) @@ -1294,6 +1301,9 @@ async def test_start_keeps_text_routing_for_mllm_without_mtp(self): assert engine._text_model is text_model assert engine._text_tokenizer is tokenizer + assert captured["build_thread"] != event_loop_thread + assert engine._text_model_owner_thread == captured["build_thread"] + await engine.stop() @pytest.mark.anyio async def test_mllm_nonstream_text_only_routes_without_mtp(self): @@ -1688,6 +1698,55 @@ def fake_make_logits_processors(**kwargs): penalty_processor, ] + @pytest.mark.anyio + async def test_stream_generate_text_normal_path_uses_text_model_owner_worker(self): + """VLM-derived TextModel generation must run on its build-owner worker.""" + from vllm_mlx.engine.simple import SimpleEngine + + event_loop_thread = threading.get_ident() + generation_threads = [] + + def fake_stream_generate(_model, _tokenizer, **_kwargs): + generation_threads.append(threading.get_ident()) + yield SimpleNamespace(text="Hello", finish_reason="stop") + + tokenizer = MagicMock() + tokenizer.apply_chat_template.return_value = "<|im_start|>user\nhello" + tokenizer.bos_token = None + tokenizer.eos_token_id = 42 + tokenizer.encode.return_value = [1, 2, 3] + + engine = SimpleEngine("test-model", force_mllm=True, mtp=False) + engine._loaded = True + engine._text_model = MagicMock() + engine._text_model.mtp = None + engine._text_tokenizer = tokenizer + engine._text_model_executor = ThreadPoolExecutor(max_workers=1) + + def bind_owner_thread(): + engine._text_model_owner_thread = threading.get_ident() + + engine._text_model_executor.submit(bind_owner_thread).result(timeout=1.0) + owner_thread = engine._text_model_owner_thread + + try: + with patch("mlx_lm.stream_generate", side_effect=fake_stream_generate): + outputs = [ + chunk + async for chunk in engine._stream_generate_text( + messages=[{"role": "user", "content": "hello"}], + max_tokens=16, + temperature=0.7, + top_p=0.9, + ) + ] + finally: + await engine.stop() + + assert outputs[-1].text == "Hello" + assert generation_threads == [owner_thread] + assert generation_threads[0] != event_loop_thread + @pytest.mark.anyio async def test_stream_generate_text_disables_mtp_when_logits_processors_active( self, diff --git a/vllm_mlx/engine/simple.py b/vllm_mlx/engine/simple.py index 41c460038..83fc7d969 100644 --- a/vllm_mlx/engine/simple.py +++ b/vllm_mlx/engine/simple.py @@ -16,6 +16,7 @@ import uuid from collections import OrderedDict, deque from collections.abc import AsyncIterator +from concurrent.futures import ThreadPoolExecutor from typing import Any # Re-entrancy guard for SimpleEngine._track_request_stream so that @@ -203,6 +204,8 @@ def __init__( # Per-request routing state (MLLM+MTP mode) self._text_model = None self._text_tokenizer = None + self._text_model_owner_thread: int | None = None + self._text_model_executor: ThreadPoolExecutor | None = None # SpecPrefill draft model (loaded at start if enabled) self._draft_model = None @@ -348,15 +351,31 @@ async def start(self) -> None: # on the slower mlx_vlm multimodal path. if self._is_mllm and self._should_route_text_through_text_model(): try: - from ..text_model_from_vlm import build_text_model - self._text_model = build_text_model( - self._model.model, self._model_name + def build_text_route(): + from ..text_model_from_vlm import build_text_model + + text_model = build_text_model( + self._model.model, self._model_name + ) + if text_model is None: + return None, None, None + return ( + text_model, + self._model.get_tokenizer(), + threading.get_ident(), + ) + + ( + self._text_model, + self._text_tokenizer, + self._text_model_owner_thread, + ) = await self._run_blocking_serialized( + build_text_route, + executor=self._ensure_text_model_executor(), ) if self._text_model is not None: - self._text_tokenizer = self._model.get_tokenizer() - # Apply Qwen3.5 eos_token fix (matches MLXLanguageModel.load) if "qwen3" in self._model_name.lower(): self._text_tokenizer.eos_token = "<|im_end|>" @@ -416,11 +435,13 @@ async def start(self) -> None: else: self._text_model = None self._text_tokenizer = None + self._text_model_owner_thread = None except Exception as e: logger.error("MLLM text routing setup failed: %s", e) self._text_model = None self._text_tokenizer = None + self._text_model_owner_thread = None # Load SpecPrefill draft model (small model for importance scoring) if self._specprefill_enabled and self._specprefill_draft_model_path: @@ -471,6 +492,11 @@ async def stop(self) -> None: self._model = None self._text_model = None self._text_tokenizer = None + self._text_model_owner_thread = None + text_model_executor = self._text_model_executor + self._text_model_executor = None + if text_model_executor is not None: + text_model_executor.shutdown(wait=False, cancel_futures=True) self._draft_model = None self._loaded = False self._system_kv_cache.clear() @@ -485,7 +511,24 @@ def _should_route_text_through_text_model( """Return whether text-only MLLM requests may use mlx_lm TextModel.""" return not (mllm_draft_requested and self._mllm_draft_model_path is not None) - async def _run_blocking_serialized(self, func, /, *args, on_cancel=None, **kwargs): + def _ensure_text_model_executor(self) -> ThreadPoolExecutor: + """Return the stable owner executor for VLM-derived TextModel calls.""" + if self._text_model_executor is None: + self._text_model_executor = ThreadPoolExecutor( + max_workers=1, + thread_name_prefix="simple-text-model", + ) + return self._text_model_executor + + async def _run_blocking_serialized( + self, + func, + /, + *args, + on_cancel=None, + executor: ThreadPoolExecutor | None = None, + **kwargs, + ): """Run a blocking MLX operation under the generation lock. Cancellation must not release the async lock before the worker thread @@ -498,7 +541,11 @@ def run_bound(): _bind_worker_generation_streams() return func(*args, **kwargs) - task = asyncio.create_task(asyncio.to_thread(run_bound)) + if executor is None: + task = asyncio.create_task(asyncio.to_thread(run_bound)) + else: + loop = asyncio.get_running_loop() + task = asyncio.ensure_future(loop.run_in_executor(executor, run_bound)) try: return await asyncio.shield(task) except asyncio.CancelledError: @@ -1927,6 +1974,7 @@ def make_cache_with_snapshot( make_cache_with_snapshot, self._text_model, hit_candidate[0], + executor=self._text_model_executor, ) ) # Bump LRU position now that we know we'll use it. @@ -2070,6 +2118,14 @@ def _resume_after_processor_retirement( def _run_all(): nonlocal backbone_cache, prompt_to_send + if ( + self._text_model_owner_thread is not None + and threading.get_ident() != self._text_model_owner_thread + ): + raise RuntimeError( + "VLM TextModel generation must run on its owner thread" + ) + model = self._text_model can_retire_processors = _processors_can_retire(all_processors) use_mtp = ( @@ -2404,6 +2460,7 @@ async def _produce_responses() -> None: await self._run_blocking_serialized( _run_all, on_cancel=abort_event.set, + executor=self._text_model_executor, ) except asyncio.CancelledError: raise