fix: run BatchedEngine MLLM on dedicated MLXWorkerThread to prevent cross-thread stream errors #479

Closed
xykong wants to merge 4 commits into waybarrios:main from xykong:fix/batched-engine-stream-thread

@xykong xykong commented May 1, 2026

Summary

Fixes `RuntimeError: There is no Stream(gpu, N) in current thread` when running BatchedEngine (`--continuous-batching`) with MLX >= 0.31.2 on Apple Silicon.

Problem

MLX 0.31.2 made Metal CommandEncoders thread-local. When BatchedEngine._prepare_mllm_model() loads the model on the event-loop thread but MLLMScheduler._process_loop() runs step() on a different thread, the runtime raises:

RuntimeError: There is no Stream(gpu, N) in current thread

This happens because MLX arrays created during model loading are tagged with the loading thread's stream, which doesn't exist in the inference thread's TLS.
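
The failure mode can be illustrated without MLX. The sketch below is only an analogy built on Python's `threading.local`; the `StreamRegistry` class and its methods are hypothetical and say nothing about how MLX implements streams internally. It shows that a handle registered in one thread's thread-local storage cannot be found from another thread.

```python
import threading


class StreamRegistry:
    """Hypothetical stand-in for a per-thread stream table (not MLX's implementation)."""

    def __init__(self):
        self._tls = threading.local()

    def new_stream(self, index):
        # Register the stream in the *calling* thread's storage only.
        if not hasattr(self._tls, "streams"):
            self._tls.streams = set()
        self._tls.streams.add(index)
        return index

    def use(self, index):
        # Look the stream up in the *calling* thread's storage.
        if index not in getattr(self._tls, "streams", set()):
            raise RuntimeError(f"There is no Stream(gpu, {index}) in current thread")
        return index


registry = StreamRegistry()
stream = registry.new_stream(1)  # stands in for model load on the main thread

errors = []


def inference():
    try:
        registry.use(stream)  # stands in for step() on a different thread
    except RuntimeError as exc:
        errors.append(str(exc))


t = threading.Thread(target=inference)
t.start()
t.join()
print(errors)
```

The same lookup succeeds when it runs on the thread that registered the stream, which is the property the fix relies on.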

Fix

  1. mlx_streams.py: Add MLXWorkerThread — a persistent single-thread executor that guarantees model load + inference share the same thread-local stream context.

  2. engine/batched.py: Make _prepare_mllm_model() async and run model loading on MLXWorkerThread. Pass the worker to MLLMScheduler.

  3. mllm_scheduler.py: Accept mlx_worker parameter. When provided, submit step() and preprocessing to the worker thread. Falls back to the legacy event-loop path (with bind_generation_streams()) when no worker is provided.

  4. mllm_batch_generator.py: Remove MLLMBatchGenerator._stream (class-level mx.new_stream()) in favor of the worker thread's default stream, eliminating the last source of cross-thread stream references.
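
A minimal sketch of the persistent single-thread executor idea from item 1, using only the standard library. This is not the PR's `MLXWorkerThread`: the class name `WorkerThread`, the `shutdown()` method, and the internals are assumptions, and only the `submit(loop, fn, *args)` call shape is taken from the snippets quoted later in the review. `threading.get_ident` stands in for model loading and `step()`.

```python
import asyncio
import queue
import threading


class WorkerThread:
    """Persistent single-thread executor (illustrative; not the PR's actual class)."""

    def __init__(self, name="mlx-worker"):
        self._tasks = queue.Queue()
        self._thread = threading.Thread(target=self._run, name=name, daemon=True)
        self._thread.start()

    def _run(self):
        while True:
            item = self._tasks.get()
            if item is None:  # shutdown sentinel
                break
            loop, future, fn, args = item
            try:
                result = fn(*args)
            except Exception as exc:
                # Futures must be completed on the event-loop thread.
                loop.call_soon_threadsafe(future.set_exception, exc)
            else:
                loop.call_soon_threadsafe(future.set_result, result)

    def submit(self, loop, fn, *args):
        # Call from the event-loop thread; the returned future resolves
        # once fn(*args) has run on the worker thread.
        future = loop.create_future()
        self._tasks.put((loop, future, fn, args))
        return future

    def shutdown(self):
        self._tasks.put(None)
        self._thread.join()


async def main():
    loop = asyncio.get_running_loop()
    worker = WorkerThread()
    load_tid = await worker.submit(loop, threading.get_ident)  # "model load"
    step_tid = await worker.submit(loop, threading.get_ident)  # "step()"
    worker.shutdown()
    return load_tid, step_tid


load_tid, step_tid = asyncio.run(main())
print(load_tid == step_tid)
```

Because every task goes through one queue consumed by one thread, anything thread-local that is created during loading is still present when inference runs.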

Testing

Tested on M4 Max 128GB with mlx-community/gemma-4-26b-a4b-it-4bit:

| Metric | Result |
| --- | --- |
| Single request (short) | 93 tok/s |
| Single request (long) | 82 tok/s |
| 4x concurrent aggregate | 179 tok/s |
| TTFT (short) | 35 ms |

Previously this configuration crashed immediately on the first request.

Backward Compatibility

  • The legacy event-loop path (no MLXWorkerThread) is preserved as fallback when mlx_worker=None
  • SimpleEngine path is unaffected
  • LLM-only BatchedEngine path is unaffected

MLX >= 0.31.2 makes Metal command encoders thread-local. When
BatchedEngine loads the model on the event-loop thread but runs
inference on a worker thread, mx.eval() raises:

  RuntimeError: There is no Stream(gpu, N) in current thread

Fix by:
1. Adding MLXWorkerThread to mlx_streams.py — a persistent single-thread
   executor that guarantees model load + inference share the same
   thread-local stream context.
2. Moving _prepare_mllm_model() to run on MLXWorkerThread via async submit.
3. Passing the worker to MLLMScheduler so _process_loop submits step()
   to the same thread.
4. Removing MLLMBatchGenerator._stream (class-level mx.new_stream) in
   favor of the worker thread default stream, eliminating the last
   source of cross-thread stream references.

The legacy event-loop path (no worker thread) is preserved as fallback
with bind_generation_streams() for backward compatibility.

Tested on M4 Max 128GB with gemma-4-26b-a4b-it-4bit:
- Single request: 82-93 tok/s generation
- 4x concurrent: 179 tok/s aggregate throughput

@janhilgard janhilgard left a comment

Collaborator

The overall approach is solid — model load + inference on the same persistent MLXWorkerThread is the right fix for MLX 0.31.2 thread-local CommandEncoders. The legacy fallback path is a nice touch for backward compatibility. No streaming regression here (unlike #478 for SimpleEngine).

A few issues to address:


1. `if True:` replaces `with mx.stream(...)` (should dedent instead)

if True:  # use thread-default stream
    logits = self.language_model(last_token, cache=request_cache)

This appears 4 times. The `if True:` block is dead code that exists only to preserve indentation. Just remove the `with mx.stream(...)` wrapper and dedent the body. A `# noqa` or a comment is fine, but a no-op `if True:` is a code smell that will confuse future readers.

2. Preprocessing unnecessarily serialized on the MLX worker thread

The old code ran text-only preprocessing in the default thread-pool executor (loop.run_in_executor(None, ...)):

await loop.run_in_executor(None, bg._preprocess_request, req)

The new worker-thread path submits preprocessing to the same single MLX worker:

await self._mlx_worker.submit(loop, bg._preprocess_request, req)

The comment says "Preprocessing creates MLX arrays (input_ids) so it must run on the same worker thread" — but _preprocess_request does Jinja2 template rendering + tokenizer.encode(), which produces Python lists, not mx.array. The mx.array() conversion happens inside _process_prompts() called from step().

Running preprocessing on the single MLX worker thread serializes it with step(), which means a 30-second preprocessing of a 40K+ token conversation blocks ALL inference for 30 seconds. The old approach of offloading to the thread pool kept inference unblocked during preprocessing. Consider keeping run_in_executor for preprocessing.

3. Lost slow-preprocessing logging and dynamic n_yields

The old code logged slow preprocessing:

elapsed = time.perf_counter() - tic
if elapsed > 1.0:
    logger.info(f"Preprocessing {req.request_id[:12]}: {n_tok} tokens in {elapsed:.2f}s")

And had dynamic yield counts:

n_yields = 10 if elapsed > 1.0 else 5

The worker-thread path drops both — no timing, always 5 yields. Please preserve the slow-preprocessing logging (it's very useful for debugging production issues) and the dynamic yield count.
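
One way to keep both behaviours on either path is to wrap the preprocessing call in a small timing helper. The sketch below is an assumption about how that could look, not code from the PR: `timed_preprocess`, `SLOW_PREPROCESS_SECONDS`, and the injectable `clock` parameter are hypothetical names, while the 1.0 s threshold, the log message, and the 10/5 yield counts follow the snippets quoted above.

```python
import logging
import time

logger = logging.getLogger("preprocess_timing")

SLOW_PREPROCESS_SECONDS = 1.0


def timed_preprocess(preprocess, request_id, clock=time.perf_counter):
    """Run preprocess(), log slow calls, and return (n_tokens, n_yields)."""
    tic = clock()
    n_tok = preprocess()
    elapsed = clock() - tic
    slow = elapsed > SLOW_PREPROCESS_SECONDS
    if slow:
        logger.info(
            "Preprocessing %s: %d tokens in %.2fs", request_id[:12], n_tok, elapsed
        )
    # Yield to the event loop more often after a slow preprocessing pass.
    n_yields = 10 if slow else 5
    return n_tok, n_yields


# Usage with a fake clock that reports 2.5 s of elapsed time.
ticks = iter([0.0, 2.5])
result = timed_preprocess(
    lambda: 40_000, "req-0123456789abcdef", clock=lambda: next(ticks)
)
print(result)  # (40000, 10)
```

Since the helper only wraps a callable, it works the same whether the call is dispatched with `run_in_executor` or submitted to the worker thread.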

4. mx.new_thread_local_stream(mx.gpu) should live in MLXWorkerThread._run(), not in _load_on_worker

Currently the thread-local stream is initialized inside _load_on_worker():

def _load_on_worker():
    mx.new_thread_local_stream(mx.gpu)
    inst = MLXMultimodalLM(...)

But the MLXWorkerThread docstring promises "mx.new_stream() is called exactly once (during thread init)." If someone submits a task before model loading, there's no stream. Move mx.new_thread_local_stream(mx.gpu) into MLXWorkerThread._run() at the top, before the task loop.
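
A sketch of the suggested ordering, with the one-time setup at the top of `_run()` so it happens before any task is served. The class `InitFirstWorker` and the `thread_init` hook are hypothetical; in the PR the hook body would be the `mx.new_thread_local_stream(mx.gpu)` call quoted above, which is replaced here by a plain `threading.local` flag so the example runs without MLX.

```python
import queue
import threading
from concurrent.futures import Future


class InitFirstWorker:
    """Worker that runs a one-time thread_init hook before serving tasks."""

    def __init__(self, thread_init):
        self._thread_init = thread_init
        self._tasks = queue.Queue()
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()

    def _run(self):
        # Per-thread setup happens here, before the task loop, so every
        # task sees it regardless of submission order.
        self._thread_init()
        while True:
            item = self._tasks.get()
            if item is None:  # shutdown sentinel
                break
            future, fn, args = item
            try:
                future.set_result(fn(*args))
            except Exception as exc:
                future.set_exception(exc)

    def submit(self, fn, *args):
        future = Future()
        self._tasks.put((future, fn, args))
        return future

    def shutdown(self):
        self._tasks.put(None)
        self._thread.join()


state = threading.local()


def init_stream():
    # Placeholder for the real stream setup on the worker thread.
    state.stream_ready = True


worker = InitFirstWorker(init_stream)
# A task submitted before any "model load" still finds the stream initialized.
ready = worker.submit(lambda: getattr(state, "stream_ready", False)).result(timeout=5)
worker.shutdown()
print(ready)  # True
```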

5. Overlap with PR #478 on mlx_streams.py

Both #478 and #479 add MLXWorkerThread to mlx_streams.py. If both land, there will be a merge conflict. Consider coordinating — either land #478 first and rebase #479 on top, or extract MLXWorkerThread into a shared PR that both depend on.

6. Comments stripped from legacy path

Several explanatory comments were removed from the legacy path (early preprocessing phase rationale, health-check yield explanation). These comments document why the code works the way it does and are valuable for maintainability. Please keep them.


Summary

The MLXWorkerThread + BatchedEngine integration is architecturally sound. Main blockers: (1) preprocessing should stay on the thread pool to avoid blocking inference, (2) `if True:` should be a proper dedent, (3) the slow-preprocessing logging should be preserved. The rest are minor cleanups.

xykong added 3 commits May 1, 2026 23:00
…aterialization

- Add _tokenize_text_only(): CPU-only tokenization (Jinja2 template +
  tokenizer) that produces a plain Python list. Safe to run on any thread
  via the default ThreadPoolExecutor, unblocking the MLX worker for
  generation during long prompt tokenization (10-30s for 40K+ tokens).

- Add _materialize_tokens(): fast mx.array() conversion on the MLX worker
  thread. Microseconds per request, ensures arrays are on the correct
  stream.

- Update _process_loop in scheduler to use the two-phase approach:
  Phase 1: CPU tokenization on thread pool (parallel, non-blocking)
  Phase 2: mx.array() on worker thread (serial, fast)

- Remove dead "if True:  # use thread-default stream" blocks and dedent
  their bodies (leftover from mx.stream removal).
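
The two-phase split described above can be sketched with the standard library alone. Everything here is illustrative: `tokenize_text_only` and `materialize_tokens` are toy stand-ins for the methods named in the commit message, and a one-thread `ThreadPoolExecutor` plays the role of the MLX worker thread.

```python
import asyncio
import threading
from concurrent.futures import ThreadPoolExecutor


def tokenize_text_only(text):
    # Phase 1: CPU-only work producing a plain Python list (toy tokenizer).
    return [ord(ch) for ch in text]


def materialize_tokens(token_ids):
    # Phase 2: stand-in for mx.array(token_ids); records which thread ran it.
    return {"ids": tuple(token_ids), "thread": threading.current_thread().name}


async def preprocess(texts):
    loop = asyncio.get_running_loop()
    # A one-thread pool plays the role of the dedicated MLX worker.
    with ThreadPoolExecutor(max_workers=1, thread_name_prefix="mlx-worker") as mlx_worker:
        # Phase 1 runs in parallel on the default executor.
        token_lists = await asyncio.gather(
            *(loop.run_in_executor(None, tokenize_text_only, t) for t in texts)
        )
        # Phase 2 runs serially on the single worker thread.
        return [
            await loop.run_in_executor(mlx_worker, materialize_tokens, ids)
            for ids in token_lists
        ]


results = asyncio.run(preprocess(["hi", "ok"]))
print(results)
```

Only the cheap second phase occupies the worker, so a long tokenization no longer delays generation steps queued on the same thread.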

Benchmark: aggregate throughput peaks at ~197 tok/s (16 concurrent),
sweet spot 4 concurrent at 177 tok/s with 44.7 tok/s per-request.

Two fixes for multimodal (vision) support:

1. prepare_for_start(): Make MLLM path a no-op. The async
   _prepare_mllm_model() requires MLXWorkerThread and event loop,
   so model loading must be deferred to _start_mllm() which is
   properly awaited. Calling it synchronously caused a
   RuntimeWarning (coroutine never awaited) and broken model state.

2. _process_prompts(): Skip prefix cache lookup when the request
   has pixel_values. The VLM forward pass must run with pixel_values
   to encode vision features into the KV cache. Previously, a prefix
   cache hit would route multimodal requests through the language-model-only
   path, which cannot process image placeholder tokens — producing
   garbage output that ignored the image content entirely.

   The existing image_token_index guard did not catch this because
   some models (e.g. Gemma 4) do not set config.image_token_index.
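
A sketch of the guard described in item 2, under the assumption that a request carries an optional `pixel_values` field. The `Request` dataclass, `lookup_prefix`, and the dict-based cache are hypothetical and much simpler than the real `_process_prompts()`; the point is only that the check keys on the request's own image data rather than on a model config attribute.

```python
from dataclasses import dataclass
from typing import Any, Optional, Sequence


@dataclass
class Request:
    """Hypothetical request shape; only the fields needed for the guard."""

    token_ids: Sequence[int]
    pixel_values: Optional[Any] = None


def lookup_prefix(request, prefix_cache):
    """Return cached state for a text-only request, or None to force a full pass."""
    if request.pixel_values is not None:
        # Multimodal request: the vision features must be encoded by the
        # full VLM forward pass, so never reuse a text-only prefix.
        return None
    return prefix_cache.get(tuple(request.token_ids))


cache = {(1, 2, 3): "cached-kv"}
text_hit = lookup_prefix(Request([1, 2, 3]), cache)
image_hit = lookup_prefix(Request([1, 2, 3], pixel_values=object()), cache)
print(text_hit, image_hit)  # cached-kv None
```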
Quantized Gemma 4 models (e.g. 4-bit) sometimes emit tool call
arguments with single quotes instead of the expected <|"|"> delimiter
tokens. For example: {location:'Tokyo'} instead of
{location:<|"|">Tokyo<|"|">}.

Add a pre-processing step (Step 0.5) in _gemma4_args_to_json that
converts single-quoted strings to the canonical <|"|">-delimited
format before the existing parsing pipeline runs.
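
The conversion can be sketched with a single regular expression. This is a simplified illustration, not the code in `_gemma4_args_to_json`: the function name `normalize_single_quotes` is hypothetical, and the pattern ignores escaped quotes and apostrophes inside values.

```python
import re

DELIM = '<|"|">'

# Matches a single-quoted run with no embedded single quotes.
_SINGLE_QUOTED = re.compile(r"'([^']*)'")


def normalize_single_quotes(args):
    """Rewrite 'value' as <|"|">value<|"|"> (simplified; no escape handling)."""
    return _SINGLE_QUOTED.sub(lambda m: f"{DELIM}{m.group(1)}{DELIM}", args)


converted = normalize_single_quotes("{location:'Tokyo'}")
print(converted)  # {location:<|"|">Tokyo<|"|">}
```

Input that already uses the canonical delimiters and contains no single quotes passes through unchanged.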
@CarloArpini

Thanks for the great work on this PR which builds on top of #478; the MLLM path fix is solid and the two-phase preprocessing split is a nice touch.

I wanted to flag that the LLM-only BatchedEngine path (--continuous-batching with a text-only model, i.e. mllm=False) is still crashing with the same RuntimeError: There is no Stream(gpu, N) in current thread on MLX 0.31.2. The crash occurs in scheduler.step() when it calls self.batch_generator.next() on a thread that doesn't own the GPU stream created during BatchGenerator initialization.

Full traceback:

  File "vllm_mlx/scheduler.py", line 2455, in step
    result = self.batch_generator.next()
  File "mlx_lm/generate.py", line 1855, in next
    return self._next()
  File "mlx_lm/generate.py", line 1841, in _next
    self._prompt_batch.prompt(prompts)
  File "mlx_lm/generate.py", line 1161, in prompt
    mx.eval([c.state for c in self.prompt_cache])
RuntimeError: There is no Stream(gpu, 1) in current thread.

Server startup confirms LLM-only path: BatchedEngine loaded: <model> (mllm=False).

TBH I don't have enough expertise to confirm this, but maybe the fix can follow the same pattern as this PR: the SchedulerCore (or wherever step() is dispatched) needs to submit batch_generator.next() to the same persistent MLXWorkerThread on which the BatchGenerator was created, mirroring what you've done here for MLLMScheduler._process_loop().

Environment: Apple M3 Ultra 512GB, macOS, MLX 0.31.2, mlx-lm 0.31.3, DeepSeek-R1-Distill-Llama-8B (text-only). Reproduced consistently across the original waybarrios/main, adriangalilea fork (PR #478), and this branch.

@waybarrios

Owner

Hi @xykong, thanks for chasing the cross-thread stream bug. The MLXWorkerThread approach looks solid and the test results on Gemma look great. CI is red on three small issues; happy to push a follow-up commit to your branch if you'd prefer, otherwise here's what to fix.

The merge conflict is in vllm_mlx/mlx_streams.py and it's documentation only. main extended the bind_generation_streams docstring with a paragraph about admission/ownership semantics. Your PR added a .. deprecated:: block. Both texts are complementary, so the cleanest resolution after git fetch origin main && git merge origin/main is to keep both:

    MLX streams are thread-local. If a model is loaded on one thread and
    generation runs on another, module-level generation streams created during
    import can point at a stream that does not exist in the worker thread.

    This intentionally creates a fresh stream for the current worker call and
    replaces module-level generation_stream handles under a process-local lock.
    It is an admission/ownership fix, not a batching optimization; callers
    should invoke it at worker-entry boundaries rather than inside token loops.

    .. deprecated::
        Prefer ``MLXWorkerThread`` which keeps model load and inference on the
        same persistent thread, avoiding the need to rebind streams entirely.
    """

Then the lint failure on vllm_mlx/mllm_batch_generator.py:787. The from mlx_vlm.utils import prepare_inputs is unused after your preprocessing refactor in a157942 (the function now calls tokenizer(...) directly). Just delete that import; ruff check --fix will do it for you.

Finally, the failing test in tests/test_mllm_continuous_batching.py:1259. The mock FakeMLLMScheduler.__init__ doesn't accept the new mlx_worker kwarg that BatchedEngine._start_mllm now passes. One line:

        class FakeMLLMScheduler:
            def __init__(self, model, processor, config, mlx_worker=None):
                captured["scheduler_config"] = config
                captured["mlx_worker"] = mlx_worker

Capturing mlx_worker matches the existing capture style for config_kwargs and leaves the door open for a future assertion that the worker is wired correctly. This same change fixes the 3.11, 3.13, and tests aggregator jobs.

Let me know which path you want and I'll re-run CI when it lands.

waybarrios pushed a commit that referenced this pull request Jun 9, 2026
)

* fix(ssd-cache): snapshot KV layers on producer thread + handle bf16 → numpy

Two distinct bugs that together prevented --ssd-cache-dir from being usable
on any of the top models (qwen3-coder-30b, gpt-oss-20b, gpt-oss-120b) under
--continuous-batching with --enable-prefix-cache.

Bug 1: SSD writer thread crashed with std::runtime_error: There is no
Stream(gpu, N) in current thread.

When a prefix-cache eviction triggered ssd_tier.enqueue_spill(), the
mx.array references for layer.keys / layer.values were placed on a
queue and picked up by a background writer thread. That writer thread
then called np.array(mx_array) inside serialize_layer, which forces an
MLX materialization. MLX 0.31+ ties each request's compute to a
thread-local Stream(gpu, N) (where N is the request uid); the writer
thread has no such stream registered, so the conversion aborts the
whole process.

Same root cause as issue #496 / PR #479, but at a new site: the SSD
spill writer is its own background daemon thread, not the BatchedEngine
MLLM thread #479 fixes.

Fix: add LayerSerializer.snapshot_layer(layer) -> dict, called on the
producer (request handler) thread inside enqueue_spill BEFORE the queue
handoff. The snapshot materializes each mx.array to numpy while the
request's stream is still valid, and the writer thread now only handles
numpy + disk bytes. serialize_layer's signature changes from
(layer, idx, path) -> metadata to (snapshot, idx, path) -> metadata so
the writer thread can't accidentally touch MLX.
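
The snapshot-before-handoff pattern can be sketched as follows. This is an illustration only: `ThreadBoundArray` here is a hypothetical reimplementation of the kind of fake the regression test is described as using, `snapshot_layer` is a free function rather than the `LayerSerializer` method, and plain lists stand in for numpy arrays.

```python
import queue
import threading


class ThreadBoundArray:
    """Fake tensor that can only be read on the thread that created it."""

    def __init__(self, values):
        self._values = list(values)
        self._owner = threading.get_ident()

    def to_list(self):
        if threading.get_ident() != self._owner:
            raise RuntimeError("There is no Stream(gpu, N) in current thread")
        return list(self._values)


def snapshot_layer(keys, values):
    # Runs on the producer thread: materialize into plain Python data.
    return {"keys": keys.to_list(), "values": values.to_list()}


spill_queue = queue.Queue()
written = []


def writer_loop():
    # The writer only ever sees plain data, never thread-bound arrays.
    while True:
        snapshot = spill_queue.get()
        if snapshot is None:  # shutdown sentinel
            break
        written.append(snapshot)


writer = threading.Thread(target=writer_loop, daemon=True)
writer.start()

keys, values = ThreadBoundArray([1, 2]), ThreadBoundArray([3, 4])
spill_queue.put(snapshot_layer(keys, values))  # snapshot BEFORE the handoff
spill_queue.put(None)
writer.join()
print(written)  # [{'keys': [1, 2], 'values': [3, 4]}]
```

Had the raw `ThreadBoundArray` objects been queued instead, the `to_list()` call on the writer thread would raise, mirroring the crash described in Bug 1.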

Bug 2: np.array(mx.array<bfloat16>) raises RuntimeError "Item size 2
for PEP 3118 buffer format string B does not match the dtype B item
size 1" — numpy has no native bfloat16. qwen3-coder-30b and many other
modern MLX models use bf16 KV cache, so this was masked behind Bug 1
on every model that actually evicts.

Fix: add ssd_cache._mx_to_numpy_safe(arr) which try/excepts the bf16
case and falls back to arr.astype(mx.float32) + mx.eval + np.array.
Original dtype is recorded in the snapshot (and propagated through the
manifest) so _reconstruct_ssd_layers in scheduler.py can cast back to
bf16 after deserialization. fp32 storage is 2x bf16 on disk; transient
2x RAM during snapshot — acceptable for the spill path.
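
The commit does not spell out why float32 is a safe intermediate, so the following is background rather than the author's stated reasoning: bfloat16 has the same layout as the upper 16 bits of an IEEE-754 float32, so widening to fp32 and casting back recovers the original bits for non-NaN values. The pure-Python sketch below checks that on a few bit patterns; the helper names are hypothetical and no MLX or numpy call is involved.

```python
import struct


def bf16_bits_to_float(bits):
    # bfloat16 is the top 16 bits of an IEEE-754 float32.
    return struct.unpack("<f", struct.pack("<I", (bits & 0xFFFF) << 16))[0]


def float_to_bf16_bits(value):
    # Truncation is exact for values that originated as bfloat16.
    return struct.unpack("<I", struct.pack("<f", value))[0] >> 16


# 0.0, 1.0, about -3.14, and the largest finite bfloat16.
samples = [0x0000, 0x3F80, 0xC049, 0x7F7F]
round_trip = [float_to_bf16_bits(bf16_bits_to_float(b)) for b in samples]
print(round_trip == samples)  # True
```

Each value occupies 4 bytes instead of 2 while widened, which matches the 2x on-disk cost noted above.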

Plus: fix 4 unit tests in test_ssd_cache.py that were failing on
0.4.0rc1 base (Bug 1 was masking these too). The MemoryCacheConfig
default min_prefix_tokens=128 rejected the 10-token sequences these
tests used. Tests now pass min_prefix_tokens=1 explicitly.

Plus: new regression test test_snapshot_runs_on_caller_thread that
uses a ThreadBoundArray fake to assert enqueue_spill must materialize
layer tensors on the calling thread, not the writer thread.

Runtime verified on M5 Pro Max 128GB with Qwen3-Coder-30B
(MLX-4bit, bf16 KV): start vllm-mlx with --continuous-batching
--enable-prefix-cache --cache-memory-mb 50 --ssd-cache-dir /tmp/test
--ssd-cache-max-gb 5. Send 9 distinct 200+ token prompts. Before:
server aborts on request 4 with Stream(gpu, 3). After: server completes
all 9 requests, evictions=8 in /v1/cache/stats, 400 safetensors files
~462MB written to /tmp/test/data, no crash.

* chore: trim verbose comments on ssd-cache spill fix

Drop the multi-paragraph docstrings on LayerSerializer / snapshot_layer /
enqueue_spill / _writer_loop / _write_entry and the long test docstring
explaining the threading invariant. The threading model (snapshot on
producer thread, persist on writer thread) is stated once in the
LayerSerializer class docstring and otherwise lives in the PR description.

Kept: brief one-line reasons on _mx_to_numpy_safe (bf16 buffer mismatch),
on _mx_dtype_from_name in scheduler.py (why None is acceptable), and on
min_prefix_tokens=1 in the 4 fixed unit tests.

No behavior change. 106 tests still pass; lint + black clean.

@waybarrios

Owner

Closing as superseded, same story as #478. Main's design keeps MLLM step() on the event-loop thread (see mllm_scheduler) and prepare_for_start runs inline since 3396981, so the dedicated worker thread approach moves in the opposite direction from what landed.

@waybarrios waybarrios closed this Jun 11, 2026
4 participants