Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions tests/test_text_model_from_vlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,3 +204,79 @@ def test_weight_sharing():
break
else:
pytest.fail("No layer with self_attn found")


def test_build_text_model_realizes_private_lazy_arrays(tmp_path, monkeypatch):
"""Lazy private arrays (e.g. RoPE._freqs) must be realized at build time.

MLX lazy graphs are tagged to the stream of the thread that recorded
them; nn.Module.parameters() excludes underscore-prefixed attributes, so
a private lazy array built on the load thread survives into generation
and fails with "There is no Stream(gpu, N) in current thread" when a
worker on another thread evaluates it. Regression: Gemma 4's scaled-RoPE
_freqs broke every MLLM text-route generation once #595 enabled the
route.
"""
import threading

import mlx.core as mx
import mlx.nn as nn

model_path = tmp_path / "gemma4"
model_path.mkdir()
(model_path / "config.json").write_text(
json.dumps({"text_config": {"model_type": "gemma4_text"}})
)

class FakeRope(nn.Module):
def __init__(self):
super().__init__()
# Lazy graph, like rope_utils' scaled-RoPE _freqs computation.
self._freqs = mx.exp(mx.arange(0, 8, dtype=mx.float32) * -0.5)

class GemmaModel(nn.Module):
def __init__(self, args):
super().__init__()
self.args_value = args
self.rope = FakeRope()

def load_weights(self, weights, strict=False):
pass

class GemmaModelArgs:
@classmethod
def from_dict(cls, config):
return "gemma4-args"

gemma_module = types.ModuleType("mlx_lm.models.gemma4_text")
gemma_module.Model = GemmaModel
gemma_module.ModelArgs = GemmaModelArgs

class FakeLanguageModel:
def parameters(self):
return {}

class FakeVlmModel:
language_model = FakeLanguageModel()

monkeypatch.setitem(sys.modules, "mlx_lm.models.gemma4_text", gemma_module)
# No tree_flatten patch here (unlike the dispatch test above): the fix's
# module walk relies on the real helper, and tree_flatten({}) is [] anyway.

text_model = build_text_model(FakeVlmModel(), model_path)
assert isinstance(text_model, GemmaModel)

# The private array must be evaluable from a different thread, which
# only holds if build_text_model realized it on the build thread.
errors = []

def cross_thread_eval():
try:
mx.eval(text_model.rope._freqs)
except RuntimeError as e: # pragma: no cover - the regression itself
errors.append(e)

t = threading.Thread(target=cross_thread_eval)
t.start()
t.join()
assert not errors, f"private lazy array not realized at build: {errors[0]}"
17 changes: 17 additions & 0 deletions vllm_mlx/text_model_from_vlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,23 @@ def _class_predicate(path, module):
# to the slow Python recurrence instead of the Metal kernel.
text_model.train(False)

# Realize every array the model holds before it leaves the build
# thread — including underscore-private module attributes such as
# RoPE._freqs, which parameters() excludes. MLX lazy graphs are tagged
# to the stream of the thread that recorded them; a lazy array
# surviving into generation dies with "There is no Stream(gpu, N) in
# current thread" the moment a worker on another thread evaluates it
# (Gemma 4: the scaled-RoPE _freqs of the first full_attention layer).
if hasattr(text_model, "modules"):
mx.eval(
[
v
for module in text_model.modules()
for v in module.values()
if isinstance(v, mx.array)
]
)

return text_model

except ImportError as e:
Expand Down
Loading