Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 34 additions & 8 deletions omlx/engine_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,22 @@ def get_mlx_executor() -> concurrent.futures.ThreadPoolExecutor:
return _global_mlx_executor


_wired_limit_set = False


def _ensure_wired_limit() -> None:
"""Set Metal wired memory limit once at first engine creation.

BatchGenerator normally calls mx.set_wired_limit() per-instance, which
races when multiple engines init concurrently (process-global setting).
We call it once here instead.
"""
global _wired_limit_set
if not _wired_limit_set and mx.metal.is_available():
mx.set_wired_limit(mx.device_info()["max_recommended_working_set_size"])
_wired_limit_set = True


@dataclass
class EngineConfig:
"""Configuration for the engine."""
Expand Down Expand Up @@ -133,12 +149,23 @@ def __init__(
)
self._owns_model = True

# Create scheduler
# Per-engine executor with dedicated mx.Stream (#1248).
# Each EngineCore gets its own thread + GPU stream so different
# models can run scheduler.step() concurrently.
_ensure_wired_limit()
self._mlx_stream = mx.new_thread_local_stream(mx.default_device())
self._mlx_executor = concurrent.futures.ThreadPoolExecutor(
max_workers=1,
thread_name_prefix=f"mlx-engine-{self._engine_id[:8]}",
)

# Create scheduler with per-engine stream
scheduler_config = self.config.scheduler_config or SchedulerConfig()
self.scheduler = Scheduler(
model=model,
tokenizer=tokenizer,
config=scheduler_config,
stream=self._mlx_stream,
)

# Output collectors for low-latency streaming (vLLM pattern)
Expand All @@ -152,11 +179,6 @@ def __init__(
self._start_time: Optional[float] = None
self._steps_executed = 0

# Global single-thread executor shared across ALL engines.
# mlx-lm uses a module-level Metal stream, so concurrent MLX calls
# from different engine threads cause segfaults. See issue #85.
self._mlx_executor = get_mlx_executor()

logger.debug(f"Engine {self._engine_id} initialized")

async def start(self) -> None:
Expand Down Expand Up @@ -698,9 +720,9 @@ def close(self) -> None:

self._closed = True

# Both shutdown() and deep_reset() touch generation_stream (directly
# Both shutdown() and deep_reset() touch the engine stream (directly
# or via _drain_pending_async_removes / _do_abort_request). The
# stream is bound to the MLX executor thread, so dispatch both
# stream is bound to the engine's executor thread, so dispatch both
# through the executor; fall back to a direct call if the executor
# is already shut down.
for fn in (self.scheduler.shutdown, self.scheduler.deep_reset):
Expand All @@ -712,6 +734,10 @@ def close(self) -> None:
except RuntimeError:
pass

if self._mlx_executor is not None:
self._mlx_executor.shutdown(wait=True)
self._mlx_executor = None

# Clear output collectors
for collector in self._output_collectors.values():
collector.clear()
Expand Down
61 changes: 20 additions & 41 deletions omlx/patches/mlx_lm_mtp/batch_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,24 +308,6 @@ class _MtpState:
# Helpers
# ---------------------------------------------------------------------------

def _get_generation_stream():
"""Return the ``mlx_lm.generate`` module-level generation stream.

The standard ``GenerationBatch._step`` runs all forward passes inside
``mx.stream(generation_stream)``; the MTP cycle does the same so the
paged cache writes land on the same stream and ordering is preserved.
The stream lives on the *outer* ``BatchGenerator``, not on
``GenerationBatch``, so we read it from the module.

Note: ``mlx_lm.__init__`` re-exports a ``generate`` *function*, so
``import mlx_lm.generate as mlg`` resolves to the function, not the
module. We use ``sys.modules`` to grab the actual module.
"""
import sys

return sys.modules["mlx_lm.generate"].generation_stream


def _resolve_sampler(gen_batch: Any):
"""Match ``GenerationBatch._step``'s per-sequence sampler resolution (batch=1)."""
if gen_batch.samplers and gen_batch.samplers[0] is not None:
Expand Down Expand Up @@ -507,9 +489,9 @@ def _reconcile_mtp_to_standard(gen_batch: Any, state: _MtpState) -> bool:
procs = _proc_list(gen_batch)
_set_singleton_mrope_delta(gen_batch)
tok_arr = _ensure_uint32(mx.array(list(tokens)))
with mx.stream(_get_generation_stream()):
logits, _, _ = _call_backbone(gen_batch.model, tok_arr[None, :], new_cache)
last_logits = logits[:, -1, :] # (1, vocab) — dist after tokens[-1]
# Inherits the per-engine stream from the enclosing BatchGenerator context.
logits, _, _ = _call_backbone(gen_batch.model, tok_arr[None, :], new_cache)
last_logits = logits[:, -1, :] # (1, vocab) — dist after tokens[-1]

if state.queue:
next_id, next_lp_1d, _src = state.queue[0]
Expand Down Expand Up @@ -760,10 +742,10 @@ def _post_init_mtp(gen_batch: Any) -> None:

# 1-token backbone forward at main_tok with hidden state. No draft yet,
# so no rollback is possible — discard gdn_states.
with mx.stream(_get_generation_stream()):
logits, hidden, _ = _call_backbone(
gen_batch.model, main_tok[:, None], gen_batch.prompt_cache
)
# Inherits the per-engine stream from the enclosing BatchGenerator context.
logits, hidden, _ = _call_backbone(
gen_batch.model, main_tok[:, None], gen_batch.prompt_cache
)

next_main_logits = logits[:, -1, :] # (1, vocab) — distribution after main_tok
next_main_logits = _apply_processors(procs, prev_buf, next_main_logits)
Expand All @@ -775,8 +757,7 @@ def _post_init_mtp(gen_batch: Any) -> None:
mtp_cache = gen_batch.model.make_mtp_cache()
hidden_at_main = hidden[:, -1:, :] # (1, 1, H)
next_ids = next_main_tok.reshape(1, 1)
with mx.stream(_get_generation_stream()):
mtp_logits = gen_batch.model.mtp_forward(hidden_at_main, next_ids, mtp_cache)
mtp_logits = gen_batch.model.mtp_forward(hidden_at_main, next_ids, mtp_cache)
mtp_logits_2d = mtp_logits[:, -1, :]
if procs is not None:
prev_with_main_and_next = mx.concatenate(
Expand Down Expand Up @@ -923,15 +904,14 @@ def _run_verify_cycle(gen_batch: Any, state: _MtpState) -> None:
# per cycle (negligible vs the forward compute) and keeps the
# backbone_ms / sample_ms split accurate.
t0 = time.perf_counter()
with mx.stream(_get_generation_stream()):
logits, hidden, gdn_states = _call_backbone(
gen_batch.model,
inputs[None, :],
gen_batch.prompt_cache,
n_confirmed=1,
)
verify_logits = logits[:, 0, :]
bonus_logits = logits[:, 1, :]
logits, hidden, gdn_states = _call_backbone(
gen_batch.model,
inputs[None, :],
gen_batch.prompt_cache,
n_confirmed=1,
)
verify_logits = logits[:, 0, :]
bonus_logits = logits[:, 1, :]
mx.eval(logits)
state.stats.backbone_ms += (time.perf_counter() - t0) * 1000

Expand Down Expand Up @@ -1086,11 +1066,10 @@ def _step_mtp(

t0 = time.perf_counter()
next_ids = next_main_tok.reshape(1, 1)
with mx.stream(_get_generation_stream()):
mtp_logits = gen_batch.model.mtp_forward(
hidden_at_position, next_ids, state.mtp_cache
)
mtp_logits_2d = mtp_logits[:, -1, :]
mtp_logits = gen_batch.model.mtp_forward(
hidden_at_position, next_ids, state.mtp_cache
)
mtp_logits_2d = mtp_logits[:, -1, :]
if procs is not None and prev_buf is not None:
prev_with_next = mx.concatenate(
[prev_buf, _ensure_uint32(next_main_tok)]
Expand Down
Loading