diff --git a/tests/test_text_model_from_vlm.py b/tests/test_text_model_from_vlm.py index 654c8c5e..4674a5bd 100644 --- a/tests/test_text_model_from_vlm.py +++ b/tests/test_text_model_from_vlm.py @@ -204,3 +204,79 @@ def test_weight_sharing(): break else: pytest.fail("No layer with self_attn found") + + +def test_build_text_model_realizes_private_lazy_arrays(tmp_path, monkeypatch): + """Lazy private arrays (e.g. RoPE._freqs) must be realized at build time. + + MLX lazy graphs are tagged to the stream of the thread that recorded + them; nn.Module.parameters() excludes underscore-prefixed attributes, so + a private lazy array built on the load thread survives into generation + and fails with "There is no Stream(gpu, N) in current thread" when a + worker on another thread evaluates it. Regression: Gemma 4's scaled-RoPE + _freqs broke every MLLM text-route generation once #595 enabled the + route. + """ + import threading + + import mlx.core as mx + import mlx.nn as nn + + model_path = tmp_path / "gemma4" + model_path.mkdir() + (model_path / "config.json").write_text( + json.dumps({"text_config": {"model_type": "gemma4_text"}}) + ) + + class FakeRope(nn.Module): + def __init__(self): + super().__init__() + # Lazy graph, like rope_utils' scaled-RoPE _freqs computation. + self._freqs = mx.exp(mx.arange(0, 8, dtype=mx.float32) * -0.5) + + class GemmaModel(nn.Module): + def __init__(self, args): + super().__init__() + self.args_value = args + self.rope = FakeRope() + + def load_weights(self, weights, strict=False): + pass + + class GemmaModelArgs: + @classmethod + def from_dict(cls, config): + return "gemma4-args" + + gemma_module = types.ModuleType("mlx_lm.models.gemma4_text") + gemma_module.Model = GemmaModel + gemma_module.ModelArgs = GemmaModelArgs + + class FakeLanguageModel: + def parameters(self): + return {} + + class FakeVlmModel: + language_model = FakeLanguageModel() + + monkeypatch.setitem(sys.modules, "mlx_lm.models.gemma4_text", gemma_module) + # No tree_flatten patch here (unlike the dispatch test above): the fix's + # module walk relies on the real helper, and tree_flatten({}) is [] anyway. + + text_model = build_text_model(FakeVlmModel(), model_path) + assert isinstance(text_model, GemmaModel) + + # The private array must be evaluable from a different thread, which + # only holds if build_text_model realized it on the build thread. + errors = [] + + def cross_thread_eval(): + try: + mx.eval(text_model.rope._freqs) + except RuntimeError as e: # pragma: no cover - the regression itself + errors.append(e) + + t = threading.Thread(target=cross_thread_eval) + t.start() + t.join() + assert not errors, f"private lazy array not realized at build: {errors[0]}" diff --git a/vllm_mlx/text_model_from_vlm.py b/vllm_mlx/text_model_from_vlm.py index c99821b1..1083bdda 100644 --- a/vllm_mlx/text_model_from_vlm.py +++ b/vllm_mlx/text_model_from_vlm.py @@ -151,6 +151,23 @@ def _class_predicate(path, module): # to the slow Python recurrence instead of the Metal kernel. text_model.train(False) + # Realize every array the model holds before it leaves the build + # thread — including underscore-private module attributes such as + # RoPE._freqs, which parameters() excludes. MLX lazy graphs are tagged + # to the stream of the thread that recorded them; a lazy array + # surviving into generation dies with "There is no Stream(gpu, N) in + # current thread" the moment a worker on another thread evaluates it + # (Gemma 4: the scaled-RoPE _freqs of the first full_attention layer). + if hasattr(text_model, "modules"): + mx.eval( + [ + v + for module in text_model.modules() + for v in module.values() + if isinstance(v, mx.array) + ] + ) + return text_model except ImportError as e: