fix: run BatchedEngine MLLM on dedicated MLXWorkerThread to prevent cross-thread stream errors #479
xykong wants to merge 4 commits
MLX >= 0.31.2 makes Metal command encoders thread-local. When BatchedEngine loads the model on the event-loop thread but runs inference on a worker thread, mx.eval() raises:
    RuntimeError: There is no Stream(gpu, N) in current thread
Fix by:
1. Adding MLXWorkerThread to mlx_streams.py: a persistent single-thread executor that guarantees model load + inference share the same thread-local stream context.
2. Moving _prepare_mllm_model() to run on MLXWorkerThread via async submit.
3. Passing the worker to MLLMScheduler so _process_loop submits step() to the same thread.
4. Removing MLLMBatchGenerator._stream (class-level mx.new_stream) in favor of the worker thread default stream, eliminating the last source of cross-thread stream references.
The legacy event-loop path (no worker thread) is preserved as fallback with bind_generation_streams() for backward compatibility.
Tested on M4 Max 128GB with gemma-4-26b-a4b-it-4bit:
- Single request: 82-93 tok/s generation
- 4x concurrent: 179 tok/s aggregate throughput
janhilgard
left a comment
The overall approach is solid — model load + inference on the same persistent MLXWorkerThread is the right fix for MLX 0.31.2 thread-local CommandEncoders. The legacy fallback path is a nice touch for backward compatibility. No streaming regression here (unlike #478 for SimpleEngine).
A few issues to address:
1. "if True:" replaces "with mx.stream(...)"; should dedent instead
    if True:  # use thread-default stream
        logits = self.language_model(last_token, cache=request_cache)
This appears 4 times. The "if True:" block is dead code that exists only to preserve indentation. Just remove the "with mx.stream(...)" wrapper and dedent the body. A # noqa or a comment is fine, but a no-op "if True:" is a code smell that will confuse future readers.
2. Preprocessing unnecessarily serialized on the MLX worker thread
The old code ran text-only preprocessing in the default thread-pool executor (loop.run_in_executor(None, ...)):
    await loop.run_in_executor(None, bg._preprocess_request, req)
The new worker-thread path submits preprocessing to the same single MLX worker:
    await self._mlx_worker.submit(loop, bg._preprocess_request, req)
The comment says "Preprocessing creates MLX arrays (input_ids) so it must run on the same worker thread", but _preprocess_request does Jinja2 template rendering + tokenizer.encode(), which produces Python lists, not mx.array. The mx.array() conversion happens inside _process_prompts() called from step().
Running preprocessing on the single MLX worker thread serializes it with step(), which means a 30-second preprocessing of a 40K+ token conversation blocks ALL inference for 30 seconds. The old approach of offloading to the thread pool kept inference unblocked during preprocessing. Consider keeping run_in_executor for preprocessing.
3. Lost slow-preprocessing logging and dynamic n_yields
The old code logged slow preprocessing:
    elapsed = time.perf_counter() - tic
    if elapsed > 1.0:
        logger.info(f"Preprocessing {req.request_id[:12]}: {n_tok} tokens in {elapsed:.2f}s")
And had dynamic yield counts:
    n_yields = 10 if elapsed > 1.0 else 5
The worker-thread path drops both: no timing, always 5 yields. Please preserve the slow-preprocessing logging (it's very useful for debugging production issues) and the dynamic yield count.
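A minimal sketch of how the timing and yield count could be carried over to the worker-thread path. The helper name timed_preprocess and its clock parameter are hypothetical; only the 1.0 s threshold, the shape of the log message, and the 10/5 yield counts come from the old code quoted above.

```python
import logging
import time
from typing import Callable, List, Tuple

logger = logging.getLogger(__name__)

SLOW_PREPROCESS_SECONDS = 1.0  # threshold used by the old code


def timed_preprocess(
    preprocess: Callable[[], List[int]],
    request_id: str,
    clock: Callable[[], float] = time.perf_counter,
) -> Tuple[List[int], int]:
    """Run preprocess(), log if it was slow, and return (tokens, n_yields)."""
    tic = clock()
    tokens = preprocess()
    elapsed = clock() - tic
    if elapsed > SLOW_PREPROCESS_SECONDS:
        logger.info(
            "Preprocessing %s: %d tokens in %.2fs",
            request_id[:12], len(tokens), elapsed,
        )
    # 10 yields after a slow run, 5 otherwise (values from the old code).
    n_yields = 10 if elapsed > SLOW_PREPROCESS_SECONDS else 5
    return tokens, n_yields
```

Because the helper only wraps a callable, it would work the same whether the callable runs on the thread pool or on the MLX worker.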
4. mx.new_thread_local_stream(mx.gpu) should live in MLXWorkerThread._run(), not in _load_on_worker
Currently the thread-local stream is initialized inside _load_on_worker():
    def _load_on_worker():
        mx.new_thread_local_stream(mx.gpu)
        inst = MLXMultimodalLM(...)
But the MLXWorkerThread docstring promises "mx.new_stream() is called exactly once (during thread init)." If someone submits a task before model loading, there's no stream. Move mx.new_thread_local_stream(mx.gpu) into MLXWorkerThread._run() at the top, before the task loop.
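The ordering being asked for can be sketched without MLX. WorkerThreadSketch below is a hypothetical stand-in, not the PR's MLXWorkerThread; the thread_init callback marks the place where per-thread stream setup would go, so that it runs once at the top of _run() and before any submitted task.

```python
import queue
import threading
from concurrent.futures import Future
from typing import Any, Callable, Optional


class WorkerThreadSketch:
    """Single persistent thread; thread_init runs once before any task."""

    def __init__(self, thread_init: Optional[Callable[[], None]] = None):
        self._tasks: "queue.Queue" = queue.Queue()
        self._thread_init = thread_init
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()

    def _run(self) -> None:
        # Per-thread setup happens here, ahead of the task loop, so even
        # the very first submitted task sees an initialized thread.
        if self._thread_init is not None:
            self._thread_init()
        while True:
            item = self._tasks.get()
            if item is None:  # shutdown sentinel
                break
            fut, fn, args = item
            try:
                fut.set_result(fn(*args))
            except Exception as exc:
                fut.set_exception(exc)

    def submit(self, fn: Callable[..., Any], *args: Any) -> Future:
        fut: Future = Future()
        self._tasks.put((fut, fn, args))
        return fut

    def shutdown(self) -> None:
        self._tasks.put(None)
        self._thread.join()
```

With this layout a task submitted before model loading still finds the thread-local state in place, which is the guarantee the docstring describes.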
5. Overlap with PR #478 on mlx_streams.py
Both #478 and #479 add MLXWorkerThread to mlx_streams.py. If both land, there will be a merge conflict. Consider coordinating — either land #478 first and rebase #479 on top, or extract MLXWorkerThread into a shared PR that both depend on.
6. Comments stripped from legacy path
Several explanatory comments were removed from the legacy path (early preprocessing phase rationale, health-check yield explanation). These comments document why the code works the way it does and are valuable for maintainability. Please keep them.
Summary
The MLXWorkerThread + BatchedEngine integration is architecturally sound. Main blockers: (1) preprocessing should stay on the thread pool to avoid blocking inference, (2) "if True:" should be a proper dedent, (3) preserve the slow-preprocessing logging and the dynamic yield count. The rest are minor cleanups.
…aterialization
- Add _tokenize_text_only(): CPU-only tokenization (Jinja2 template + tokenizer) that produces a plain Python list. Safe to run on any thread via the default ThreadPoolExecutor, unblocking the MLX worker for generation during long prompt tokenization (10-30s for 40K+ tokens).
- Add _materialize_tokens(): fast mx.array() conversion on the MLX worker thread. Microseconds per request, ensures arrays are on the correct stream.
- Update _process_loop in scheduler to use the two-phase approach:
    Phase 1: CPU tokenization on thread pool (parallel, non-blocking)
    Phase 2: mx.array() on worker thread (serial, fast)
- Remove dead "if True: # use thread-default stream" blocks and dedent their bodies (leftover from mx.stream removal).
Benchmark: aggregate throughput peaks at ~197 tok/s (16 concurrent), sweet spot 4 concurrent at 177 tok/s with 44.7 tok/s per-request.
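The two-phase split can be sketched with the standard library alone. tokenize_cpu and materialize below are hypothetical stand-ins for _tokenize_text_only() and _materialize_tokens(), and a single-thread ThreadPoolExecutor stands in for the MLX worker; no MLX calls are made.

```python
import asyncio
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import List, Tuple


def tokenize_cpu(text: str) -> List[int]:
    """Phase 1 stand-in: CPU-only work that returns a plain Python list."""
    return [ord(ch) for ch in text]


def materialize(tokens: List[int]) -> Tuple[str, Tuple[int, ...]]:
    """Phase 2 stand-in for mx.array(): records which thread built the array."""
    return threading.current_thread().name, tuple(tokens)


async def preprocess_two_phase(
    text: str, device_worker: ThreadPoolExecutor
) -> Tuple[str, Tuple[int, ...]]:
    loop = asyncio.get_running_loop()
    # Phase 1: default thread pool, so the device worker stays free for step().
    tokens = await loop.run_in_executor(None, tokenize_cpu, text)
    # Phase 2: short hop onto the single device thread.
    return await loop.run_in_executor(device_worker, materialize, tokens)


async def main() -> List[Tuple[str, Tuple[int, ...]]]:
    with ThreadPoolExecutor(max_workers=1, thread_name_prefix="mlx-worker") as worker:
        return await asyncio.gather(
            preprocess_two_phase("hi", worker),
            preprocess_two_phase("yo", worker),
        )
```

Running asyncio.run(main()) tokenizes both prompts concurrently on the default pool, while every materialize call lands on the one thread whose name starts with mlx-worker.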
Two fixes for multimodal (vision) support:
1. prepare_for_start(): Make MLLM path a no-op. The async _prepare_mllm_model() requires MLXWorkerThread and event loop, so model loading must be deferred to _start_mllm() which is properly awaited. Calling it synchronously caused a RuntimeWarning (coroutine never awaited) and broken model state.
2. _process_prompts(): Skip prefix cache lookup when the request has pixel_values. The VLM forward pass must run with pixel_values to encode vision features into the KV cache. Previously, a prefix cache hit would route multimodal requests through the language-model-only path, which cannot process image placeholder tokens, producing garbage output that ignored the image content entirely. The existing image_token_index guard did not catch this because some models (e.g. Gemma 4) do not set config.image_token_index.
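The second fix amounts to keying the guard on the request's own data rather than on model config. A minimal sketch, assuming a hypothetical RequestSketch type and can_use_prefix_cache helper (neither name is from the PR):

```python
from dataclasses import dataclass
from typing import Any, Optional, Sequence


@dataclass
class RequestSketch:
    input_ids: Sequence[int]
    pixel_values: Optional[Any] = None  # set for requests that carry images


def can_use_prefix_cache(request: RequestSketch) -> bool:
    """Only text-only requests may reuse a cached KV prefix."""
    # A request with pixel_values needs the full VLM forward pass so that
    # vision features are encoded into the KV cache.
    return request.pixel_values is None
```

This check does not depend on config.image_token_index, so it also covers models that leave that field unset.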
Quantized Gemma 4 models (e.g. 4-bit) sometimes emit tool call
arguments with single quotes instead of the expected <|"|"> delimiter
tokens. For example: {location:'Tokyo'} instead of
{location:<|"|">Tokyo<|"|">}.
Add a pre-processing step (Step 0.5) in _gemma4_args_to_json that
converts single-quoted strings to the canonical <|"|">-delimited
format before the existing parsing pipeline runs.
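A rough sketch of such a conversion step, not the code in _gemma4_args_to_json. The helper name single_quotes_to_delims is hypothetical, and the regex assumes single-quoted values contain no embedded or escaped single quotes; an apostrophe inside a value would need extra handling.

```python
import re

DELIM = '<|"|">'  # Gemma 4 string delimiter quoted in the commit message

# Matches a single-quoted run with no embedded single quote.
_SINGLE_QUOTED = re.compile(r"'([^']*)'")


def single_quotes_to_delims(args: str) -> str:
    """Rewrite 'value' as <|"|">value<|"|"> so later parsing sees one format."""
    return _SINGLE_QUOTED.sub(lambda m: f"{DELIM}{m.group(1)}{DELIM}", args)
```

Input that already uses the canonical delimiters and contains no single quotes passes through unchanged, so the step can run unconditionally ahead of the existing pipeline.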
Thanks for the great work on this PR which builds on top of #478; the MLLM path fix is solid and the two-phase preprocessing split is a nice touch. I wanted to flag that the LLM-only …
Full traceback: …
Server startup confirms LLM-only path: …
TBH I don't have enough expertise to confirm this, but maybe the fix can follow the same pattern as this PR: the …
Environment: Apple M3 Ultra 512GB, macOS, MLX 0.31.2, mlx-lm 0.31.3, DeepSeek-R1-Distill-Llama-8B (text-only). Reproduced consistently across the original …
Hi @xykong, thanks for chasing the cross-thread stream bug. The …
The merge conflict is in …
    MLX streams are thread-local. If a model is loaded on one thread and
    generation runs on another, module-level generation streams created during
    import can point at a stream that does not exist in the worker thread.
    This intentionally creates a fresh stream for the current worker call and
    replaces module-level generation_stream handles under a process-local lock.
    It is an admission/ownership fix, not a batching optimization; callers
    should invoke it at worker-entry boundaries rather than inside token loops.
    .. deprecated::
        Prefer ``MLXWorkerThread`` which keeps model load and inference on the
        same persistent thread, avoiding the need to rebind streams entirely.
    """
Then the lint failure on …
Finally, the failing test in …
    class FakeMLLMScheduler:
        def __init__(self, model, processor, config, mlx_worker=None):
            captured["scheduler_config"] = config
            captured["mlx_worker"] = mlx_worker
Capturing …
Let me know which path you want and I'll re-run CI when it lands.
* fix(ssd-cache): snapshot KV layers on producer thread + handle bf16 → numpy
Two distinct bugs that together prevented --ssd-cache-dir from being usable on any of the top models (qwen3-coder-30b, gpt-oss-20b, gpt-oss-120b) under --continuous-batching with --enable-prefix-cache.
Bug 1: SSD writer thread crashed with std::runtime_error: There is no Stream(gpu, N) in current thread. When a prefix-cache eviction triggered ssd_tier.enqueue_spill(), the mx.array references for layer.keys / layer.values were placed on a queue and picked up by a background writer thread. That writer thread then called np.array(mx_array) inside serialize_layer, which forces an MLX materialization. MLX 0.31+ ties each request's compute to a thread-local Stream(gpu, N) (where N is the request uid); the writer thread has no such stream registered, so the conversion aborts the whole process. Same root cause as issue #496 / PR #479, but at a new site: the SSD spill writer is its own background daemon thread, not the BatchedEngine MLLM thread #479 fixes.
Fix: add LayerSerializer.snapshot_layer(layer) -> dict, called on the producer (request handler) thread inside enqueue_spill BEFORE the queue handoff. The snapshot materializes each mx.array to numpy while the request's stream is still valid, and the writer thread now only handles numpy + disk bytes. serialize_layer's signature changes from (layer, idx, path) -> metadata to (snapshot, idx, path) -> metadata so the writer thread can't accidentally touch MLX.
Bug 2: np.array(mx.array<bfloat16>) raises RuntimeError "Item size 2 for PEP 3118 buffer format string B does not match the dtype B item size 1" because numpy has no native bfloat16. qwen3-coder-30b and many other modern MLX models use bf16 KV cache, so this was masked behind Bug 1 on every model that actually evicts.
Fix: add ssd_cache._mx_to_numpy_safe(arr) which try/excepts the bf16 case and falls back to arr.astype(mx.float32) + mx.eval + np.array. Original dtype is recorded in the snapshot (and propagated through the manifest) so _reconstruct_ssd_layers in scheduler.py can cast back to bf16 after deserialization. fp32 storage is 2x bf16 on disk; transient 2x RAM during snapshot, which is acceptable for the spill path.
Plus: fix 4 unit tests in test_ssd_cache.py that were failing on 0.4.0rc1 base (Bug 1 was masking these too). The MemoryCacheConfig default min_prefix_tokens=128 rejected the 10-token sequences these tests used. Tests now pass min_prefix_tokens=1 explicitly.
Plus: new regression test test_snapshot_runs_on_caller_thread that uses a ThreadBoundArray fake to assert enqueue_spill must materialize layer tensors on the calling thread, not the writer thread.
Runtime verified on M5 Pro Max 128GB with Qwen3-Coder-30B (MLX-4bit, bf16 KV): start vllm-mlx with --continuous-batching --enable-prefix-cache --cache-memory-mb 50 --ssd-cache-dir /tmp/test --ssd-cache-max-gb 5. Send 9 distinct 200+ token prompts. Before: server aborts on request 4 with Stream(gpu, 3). After: server completes all 9 requests, evictions=8 in /v1/cache/stats, 400 safetensors files ~462MB written to /tmp/test/data, no crash.
* chore: trim verbose comments on ssd-cache spill fix
Drop the multi-paragraph docstrings on LayerSerializer / snapshot_layer / enqueue_spill / _writer_loop / _write_entry and the long test docstring explaining the threading invariant. The threading model (snapshot on producer thread, persist on writer thread) is stated once in the LayerSerializer class docstring and otherwise lives in the PR description.
Kept: brief one-line reasons on _mx_to_numpy_safe (bf16 buffer mismatch), on _mx_dtype_from_name in scheduler.py (why None is acceptable), and on min_prefix_tokens=1 in the 4 fixed unit tests.
No behavior change. 106 tests still pass; lint + black clean.
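The threading invariant from the first commit (snapshot on the producer thread, persist on the writer thread) can be sketched with a fake tensor. ThreadBoundArray here is a hypothetical reimplementation of the idea named in the commit message, and SpillWriterSketch is not the real SSD tier; plain Python lists stand in for numpy arrays and no disk I/O is done.

```python
import queue
import threading
from typing import Dict, List


class ThreadBoundArray:
    """Fake tensor that can only be materialized on the thread that made it."""

    def __init__(self, values: List[float]):
        self._values = list(values)
        self._owner = threading.get_ident()

    def to_list(self) -> List[float]:
        if threading.get_ident() != self._owner:
            raise RuntimeError("no stream in current thread")
        return list(self._values)


class SpillWriterSketch:
    def __init__(self) -> None:
        self._queue: "queue.Queue" = queue.Queue()
        self.written: List[Dict[str, List[float]]] = []
        self._thread = threading.Thread(target=self._writer_loop, daemon=True)
        self._thread.start()

    def enqueue_spill(self, keys: ThreadBoundArray, values: ThreadBoundArray) -> None:
        # Materialize on the caller's thread, before the queue handoff.
        snapshot = {"keys": keys.to_list(), "values": values.to_list()}
        self._queue.put(snapshot)

    def _writer_loop(self) -> None:
        while True:
            snapshot = self._queue.get()
            if snapshot is None:  # shutdown sentinel
                break
            # Only plain Python data reaches this thread.
            self.written.append(snapshot)

    def close(self) -> None:
        self._queue.put(None)
        self._thread.join()
```

If enqueue_spill queued the ThreadBoundArray objects themselves and the writer called to_list(), the fake would raise on the writer thread, which mirrors the crash described under Bug 1.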
Summary
Fixes RuntimeError: There is no Stream(gpu, N) in current thread when running BatchedEngine (--continuous-batching) with MLX >= 0.31.2 on Apple Silicon.
Problem
MLX 0.31.2 made Metal CommandEncoders thread-local. When BatchedEngine._prepare_mllm_model() loads the model on the event-loop thread but MLLMScheduler._process_loop() runs step() on a different thread, the runtime raises:
    RuntimeError: There is no Stream(gpu, N) in current thread
This happens because MLX arrays created during model loading are tagged with the loading thread's stream, which doesn't exist in the inference thread's TLS.
Fix
- mlx_streams.py: Add MLXWorkerThread, a persistent single-thread executor that guarantees model load + inference share the same thread-local stream context.
- engine/batched.py: Make _prepare_mllm_model() async and run model loading on MLXWorkerThread. Pass the worker to MLLMScheduler.
- mllm_scheduler.py: Accept mlx_worker parameter. When provided, submit step() and preprocessing to the worker thread. Falls back to the legacy event-loop path (with bind_generation_streams()) when no worker is provided.
- mllm_batch_generator.py: Remove MLLMBatchGenerator._stream (class-level mx.new_stream()) in favor of the worker thread's default stream, eliminating the last source of cross-thread stream references.
Testing
Tested on M4 Max 128GB with mlx-community/gemma-4-26b-a4b-it-4bit:
Previously this configuration crashed immediately on the first request.
Backward Compatibility
… MLXWorkerThread) is preserved as fallback when mlx_worker=None
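The thread-affinity guarantee described under Fix can be illustrated without MLX. SingleThreadWorker is a hypothetical stand-in for MLXWorkerThread built on a one-thread ThreadPoolExecutor; the submit(loop, fn, *args) shape follows the call quoted in the review, while load_model and step are placeholders that only report their thread id.

```python
import asyncio
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Tuple


class SingleThreadWorker:
    """Hypothetical stand-in for the PR's MLXWorkerThread."""

    def __init__(self) -> None:
        self._executor = ThreadPoolExecutor(
            max_workers=1, thread_name_prefix="mlx-worker"
        )

    async def submit(
        self, loop: asyncio.AbstractEventLoop, fn: Callable[..., Any], *args: Any
    ) -> Any:
        # Every call is funneled through the same single worker thread.
        return await loop.run_in_executor(self._executor, fn, *args)

    def shutdown(self) -> None:
        self._executor.shutdown(wait=True)


def load_model() -> int:
    return threading.get_ident()  # placeholder for model loading


def step() -> int:
    return threading.get_ident()  # placeholder for one inference step


async def demo() -> Tuple[int, int, int]:
    loop = asyncio.get_running_loop()
    worker = SingleThreadWorker()
    try:
        load_tid = await worker.submit(loop, load_model)
        step_tid = await worker.submit(loop, step)
    finally:
        worker.shutdown()
    return load_tid, step_tid, threading.get_ident()
```

In this sketch the load and step calls report the same thread id, and that id differs from the event-loop thread's, which is the property the fix relies on.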