
Gemma 4 text route broken on main: 'There is no Stream(gpu, N) in current thread' (lazy RoPE._freqs × thread-bound streams) #613

@ursk


Symptom

Every text-route generation on a Gemma 4 MLLM (26B-A4B MoE and 31B dense both reproduce) fails with HTTP 500 on current main:

File ".../mlx_lm/generate.py", line 442, in generate_step
    mx.eval([c.state for c in prompt_cache])
RuntimeError: There is no Stream(gpu, 2) in current thread.

git bisect between caa8838 (good) and a48c86c (bad) lands on 90f759a “fix(text-model): dispatch gemma4 text models (#595)”. However, #595 is exposure, not cause: it enables the mlx_lm TextModel fast route for Gemma 4, and that route’s threading was already unsound. Models whose TextModel never took this route before could not hit the fault, which is why it went unnoticed.

Root cause (three layers)

1. MLX streams are thread-bound (mlx 0.31.2). A lazy graph whose ops were recorded under a stream created on thread X can only be evaluated on thread X. This holds for plain mx.new_stream streams, not just ThreadLocalStream:

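import threading
import mlx.core as mx

# Record a lazy graph under a stream created on the main thread.
s = mx.new_stream(mx.default_device())
with mx.stream(s):
    a = mx.ones((4, 4)) * 2  # not evaluated yet

def w():
    # Forcing the graph from a different thread fails on mlx 0.31.2.
    try:
        mx.eval(a)
    except RuntimeError as e:
        print(e)  # There is no Stream(gpu, 0) in current thread.

t = threading.Thread(target=w)
t.start(); t.join()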

2. A lazy array escapes the build thread. build_text_model constructs the TextModel on the load thread. The scaled-RoPE classes in rope_utils (Llama3RoPE, YarnRoPE, SuScaledRoPE, ProportionalRoPE) compute self._freqs lazily in __init__, and nn.Module.parameters() excludes underscore-prefixed attributes — so _freqs is never realized by weight loading or any mx.eval(model.parameters()). It stays a lazy graph tagged to the load thread’s stream.
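The gap between a parameters()-style walk and a full walk can be illustrated without MLX. The sketch below is a pure-Python stand-in, not mlx_lm code: FakeArray, FakeModule and walk_arrays are hypothetical names, and FakeModule only mimics the assumed nn.Module behaviour of storing array, list and dict attributes as dict entries. With include_private=False the walk skips underscore-prefixed keys the way parameters() does, so _freqs never shows up.

```python
from typing import Any, Iterator, Tuple


class FakeArray:
    """Hypothetical stand-in for a lazy mx.array."""

    def __init__(self, name: str) -> None:
        self.name = name


class FakeModule(dict):
    """Hypothetical stand-in for nn.Module: arrays, lists, dicts and submodules live in the dict."""

    def __setattr__(self, key: str, value: Any) -> None:
        if isinstance(value, (FakeArray, dict, list, tuple)):
            self[key] = value
        else:
            super().__setattr__(key, value)


def walk_arrays(node: Any, prefix: str = "", include_private: bool = True) -> Iterator[Tuple[str, FakeArray]]:
    """Yield (dotted_path, array) for every array reachable from node."""
    if isinstance(node, FakeArray):
        yield prefix, node
    elif isinstance(node, dict):  # FakeModule is a dict subclass, so modules land here too
        for key, value in node.items():
            if not include_private and key.startswith("_"):
                continue  # mirrors the parameters() filter on private keys
            yield from walk_arrays(value, f"{prefix}.{key}" if prefix else key, include_private)
    elif isinstance(node, (list, tuple)):
        for i, value in enumerate(node):
            yield from walk_arrays(value, f"{prefix}.{i}" if prefix else str(i), include_private)


# A one-layer toy model: a public projection weight plus a scaled RoPE holding private _freqs.
rope = FakeModule()
rope._freqs = FakeArray("freqs")  # lazily derived in __init__, private
attn = FakeModule()
attn.q_proj = FakeArray("q_proj")
attn.rope = rope
model = FakeModule()
model.layers = [attn]

public_paths = [p for p, _ in walk_arrays(model, include_private=False)]
all_paths = [p for p, _ in walk_arrays(model)]
print(public_paths)  # ['layers.0.q_proj']
print(all_paths)     # ['layers.0.q_proj', 'layers.0.rope._freqs']
```

Against a real model, the targeted fix amounts to collecting the arrays from the include_private=True walk and passing them to mx.eval on the build thread; how the actual PR enumerates them may differ.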

3. Generation runs on arbitrary pool threads. SimpleEngine._run_blocking_serialized executes via asyncio.to_thread (shared executor, any idle thread) and rebinds fresh streams per call (bind_generation_streams). The first forward pass through a full_attention layer therefore evaluates _freqs from a different thread and crashes. Gemma 4’s layers 0–4 are sliding_attention (plain nn.RoPE, no _freqs); layer 5 is the first full_attention layer with scaled RoPE. Stepwise per-layer bisection confirms this: the forward pass fails on entering layer 5, and mx.eval(model.layers[5].self_attn.rope._freqs) from a worker thread reproduces the error directly. After all module-held arrays (including private ones) are force-realized on the build thread, the same cross-thread forward pass succeeds.
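The three layers can be reproduced in miniature with a toy model. ThreadBoundLazy and build_model are hypothetical and only imitate the behaviour described above (a pending value that can be forced only on the thread that recorded it); no MLX is involved. Forced lazily on a pool thread, the value fails like the layer 5 RoPE; realized on the build thread first, it is plain data and crosses threads safely.

```python
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Optional


class ThreadBoundLazy:
    """Toy lazy value: the pending computation may only be forced on its recording thread."""

    def __init__(self, fn: Callable[[], float]) -> None:
        self._fn: Optional[Callable[[], float]] = fn
        self._owner = threading.get_ident()  # thread that "recorded the graph"
        self._value: Optional[float] = None

    def eval(self) -> float:
        if self._fn is not None:  # still lazy
            if threading.get_ident() != self._owner:
                raise RuntimeError("lazy value forced off its recording thread")
            self._value, self._fn = self._fn(), None
        return self._value


def build_model(realize: bool) -> ThreadBoundLazy:
    freqs = ThreadBoundLazy(lambda: 1.0 / 10000.0)  # plays the role of rope._freqs
    if realize:
        freqs.eval()  # force it on the build thread, as in the targeted fix
    return freqs


with ThreadPoolExecutor(max_workers=1) as pool:
    lazy = build_model(realize=False)  # built on this thread, left lazy
    try:
        pool.submit(lazy.eval).result()  # forced on the worker thread
    except RuntimeError:
        crashed = True
    else:
        crashed = False
    realized = build_model(realize=True)
    value = pool.submit(realized.eval).result()  # already data, safe anywhere

print(crashed, value)
```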

Fixes

  • Targeted unblock: realize every module-held array (including private attributes) before the model leaves the build thread. PR incoming: one mx.eval walk plus a regression test.
  • Structural (recommended follow-up): the to_thread-onto-any-pool-thread + per-call stream rebinding pattern is fragile under thread-bound streams; any lazy state crossing serialized calls (cache snapshots, prefix caches, lazily-derived weights) is a latent crash. Pinning all MLX work to a single dedicated worker thread (ThreadPoolExecutor(max_workers=1)) makes stream affinity moot. We run this in production downstream and it has been stable; happy to upstream it as a follow-up PR if there’s interest.
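A minimal sketch of the structural option, assuming an asyncio server shaped like SimpleEngine: every MLX call, including the model build, goes through one dedicated ThreadPoolExecutor(max_workers=1) via loop.run_in_executor instead of asyncio.to_thread. MLX_EXECUTOR, run_serialized and fake_generate are hypothetical names, not the downstream implementation; the only point shown is that all calls land on the same OS thread, so stream affinity no longer matters.

```python
import asyncio
import threading
from concurrent.futures import ThreadPoolExecutor

# One long-lived worker thread owns all MLX work (model build and generation).
MLX_EXECUTOR = ThreadPoolExecutor(max_workers=1, thread_name_prefix="mlx-worker")


async def run_serialized(fn, *args):
    """Run fn(*args) on the dedicated MLX thread instead of an arbitrary pool thread."""
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(MLX_EXECUTOR, fn, *args)


def fake_generate(prompt: str) -> int:
    # Stand-in for a generate_step call; reports which thread ran it.
    return threading.get_ident()


async def main() -> set:
    idents = await asyncio.gather(*(run_serialized(fake_generate, f"p{i}") for i in range(8)))
    return set(idents)


thread_ids = asyncio.run(main())
MLX_EXECUTOR.shutdown()
print(len(thread_ids))  # 1: every call ran on the same worker thread
```

A single worker also serializes calls for free, so a separate lock around generation becomes unnecessary; the trade-off is that no two MLX calls can ever overlap.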

Environment

mlx 0.31.2, mlx_lm 0.31.3, M3 Ultra 96 GB, macOS 25.3.0.
