From 7b2e849f560af5071f199e29e3cbc87139c4f96b Mon Sep 17 00:00:00 2001
From: cfbraun
Date: Wed, 27 May 2026 09:48:56 +0200
Subject: [PATCH 01/10] fix(boundary-store): serialize cleanup_all() and cleanup_request() with the writer thread (#1423)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

test_cleanup_all_drains_queue failed ~20% of the time, in isolation and
in suites. The race:

  T0  save()          puts item in queue
  T1  writer          queue.get() pulls item (already off the queue)
  T2  cleanup_all()   drains queue, finds it empty
  T3  cleanup_all()   rmtree(snapshot_dir) + mkdir(snapshot_dir)
  T4  writer          mkdir(req_dir) + temp write + rename(final)

Result: ``req-X/N.safetensors`` survives the supposed cleanup, the
caller (scheduler / shutdown) believes the snapshot dir is empty, and
the leftover file is the disk shadow of a request that should have been
forgotten.

Add ``_writer_busy`` (threading.Lock). The writer holds it for the
entire duration of each ``_process_write_item`` call. ``cleanup_all()``
drains the queue first (no new items can start) and then acquires
``_writer_busy``, so any item the writer had already pulled finishes
before rmtree runs.

Same class of race exists for ``cleanup_request()``: writer pulls a
request's item before cleanup_request sets the cancelled-counter, then
mkdir's the req_dir back into existence after the rmtree, and the
late-rename rescue at the tail of ``_process_write_item`` catches it in
most timings but a small window remains. Apply the same
``_writer_busy`` barrier (bounded, see below) so the two cleanup paths
are symmetric.

Per-path timeouts so cleanup_request can yield faster than cleanup_all:

- ``_CLEANUP_ALL_TIMEOUT_S = 5.0``: cleanup_all runs at reset / startup
  where blocking longer is tolerable in exchange for a stronger
  orphan-avoidance guarantee.
- ``_CLEANUP_REQUEST_TIMEOUT_S = 2.0``: cleanup_request is called from
  the scheduler's abort hot path (~3 sites) where bounding latency
  matters more than chasing one in-flight write.

The bounded acquires have a logged-warning fallback. Without that bound
a slow disk inside ``_write_safetensors_no_mx`` could deadlock the
scheduler's hot path: cleanup_all() is called from request abort /
reset (scheduler.py:5400 / 5596 / 5737), not just shutdown. The
worst-case fallback behaviour is an orphaned file under the recreated
snapshot dir until next startup's constructor cleanup, which is the
exact same state the pre-fix code produced on every cleanup_all.

Counter pop on cleanup_request now only runs when the ``_writer_busy``
acquire succeeded. On the timeout fallback the writer is still mid-item
and may not yet have consulted ``_is_cancelled``; popping the counter
there would defeat the late-rename rescue that the docstring advertises
as the timeout-fallback safety net. The previous code popped
unconditionally.

Counter bump itself is now additive (``get + count``) and skipped
entirely when ``count == 0``. This closes two distinct bugs:

- Skip-on-zero: cleanup_request("X") for an rid with NO pending items
  previously wrote ``cancelled[X] = 0`` then popped on the acquired
  path. On the timeout fallback the pop never ran and the ``X: 0``
  entry lingered for the process lifetime, so every later ``save()``
  under that rid (or any reuse of the string) was silently discarded by
  the writer's ``_is_cancelled`` gates, which check key membership, not
  value > 0. Counter must only exist when there is at least one
  in-flight item to drain it. Regression:
  test_cleanup_request_no_pending_does_not_pin_counter_on_timeout (new).
- Additive: a re-entrant cleanup_request("X") for an rid that already
  has an in-flight cancellation must NOT overwrite the previous count.
  The writer's ``cleared_by_cleanup`` branch + ``_writer_busy`` lock
  together close the file-write race today (so no orphan file slips
  out), but the per-item dec_cancelled bookkeeping has to balance:
  overwriting drops remaining decs on the floor and on the next
  ``save()`` under the same rid the writer would see a non-zero counter
  from the earlier batch and silently discard the new item.

shutdown() now accepts ``cleanup=True`` so callers that want both
operations can express the cleanup-before-shutdown ordering in one call
instead of sequencing them manually. The cleanup_all() warning at the
top of the function catches misordered callers that still call
cleanup_all() AFTER shutdown: the writer no longer reacquires
``_writer_busy`` past the sentinel, so a post-shutdown cleanup degrades
to an in-memory-only clear.

Tests:

- ``test_cleanup_all_drains_queue``: no longer sleeps; relies on the
  lock to guarantee ordering, runs deterministically.
- ``test_cleanup_all_blocks_until_writer_finishes_pinned_item``:
  monkey-patches ``_write_safetensors_no_mx`` to block on an Event,
  pins the writer mid-item, asserts cleanup_all does NOT return until
  the writer releases. Without the lock this test fails
  deterministically, rather than at the original ~20% flake rate.
- ``test_cleanup_request_blocks_until_writer_finishes_pinned_item``:
  symmetric for cleanup_request.
- ``test_cleanup_request_keeps_counter_on_timeout``: regression for the
  bug where the timeout-fallback pop dropped the late-rename rescue's
  safety net.
- ``test_cleanup_request_timeout_drains_counter_on_writer_early_return``:
  the writer's cleared_by_cleanup early-return must still dec_cancelled
  or the counter pins the rid forever.
- ``test_cleanup_request_no_pending_does_not_pin_counter_on_timeout``
  (new): regression for the skip-on-zero bug above.
Pins the writer with an unrelated save, calls cleanup_request("never-saved-rid") past the 0.1s timeout, asserts the rid does NOT appear in ``_cancelled_requests`` AND that a subsequent save under the same rid is not silently discarded. - ``test_shutdown_cleanup_true_runs_cleanup_before_setting_flag`` — pins the cleanup-before-shutdown ordering of the new ``shutdown(cleanup=True)`` path. 24/24 boundary-store tests pass. (cherry picked from commit 4f3a9b994b997af280020d60b27afa0482f29830) --- omlx/cache/boundary_snapshot_store.py | 486 +++++++++++++++++----- tests/test_boundary_snapshot_store.py | 557 +++++++++++++++++++++++++- 2 files changed, 937 insertions(+), 106 deletions(-) diff --git a/omlx/cache/boundary_snapshot_store.py b/omlx/cache/boundary_snapshot_store.py index 902172f10..3cb58f203 100644 --- a/omlx/cache/boundary_snapshot_store.py +++ b/omlx/cache/boundary_snapshot_store.py @@ -56,6 +56,17 @@ class BoundarySnapshotSSDStore: Snapshots are stored under ``base_dir/_boundary_snapshots/``. """ + # Timeouts applied when acquiring _writer_busy from each cleanup + # path. cleanup_request is called from the scheduler's abort hot + # path (~3 sites) and must yield faster than cleanup_all, which + # also runs at startup / reset where blocking longer is tolerable + # in exchange for a stronger orphan-avoidance guarantee. The + # worst-case impact on the timeout fallback is identical in both + # paths — an orphan file in the recreated dir until the next + # constructor cleanup — so the only knob is per-call latency. + _CLEANUP_ALL_TIMEOUT_S = 5.0 + _CLEANUP_REQUEST_TIMEOUT_S = 2.0 + def __init__(self, base_dir: Path) -> None: self._snapshot_dir = base_dir / "_boundary_snapshots" # Clean up orphaned files from previous crashes. @@ -75,14 +86,24 @@ def __init__(self, base_dir: Path) -> None: self._pending_writes: dict[tuple[str, int], dict] = {} self._pending_lock = threading.Lock() - # Cancelled requests with remaining queue item counts. 
Writer + # Cancelled requests with remaining queue item counts. Writer # thread decrements on each skip; entry is deleted when count - # reaches zero, preventing unbounded growth. + # reaches zero, preventing unbounded growth. All access is + # guarded by ``_cancelled_lock`` — the dict was previously + # mutated unlocked from cleanup_request, cleanup_all, and the + # writer thread, creating lost-cancellation and counter- + # underflow races. self._cancelled_requests: dict[str, int] = {} + self._cancelled_lock = threading.Lock() # Background writer thread. self._write_queue: queue.Queue = queue.Queue(maxsize=_MAX_PENDING_WRITES) self._shutdown = threading.Event() + # Held by the writer for the duration of each item's processing. + # cleanup_all() acquires it after draining the queue so the writer + # can't be mid-item (creating files inside the just-cleaned dir) + # when rmtree runs. + self._writer_busy = threading.Lock() self._writer_thread = threading.Thread( target=self._writer_loop, name="boundary-snapshot-writer", @@ -154,13 +175,30 @@ def save( try: self._write_queue.put_nowait((pw_key, tensors_raw, metadata, file_path)) except queue.Full: + # Roll back the pending + registry entries: with no + # queue item the writer can never decrement + # _cancelled_requests for this entry, so if a later + # cleanup_request counts it the rid stays pinned in + # _cancelled_requests forever and every subsequent + # save under that rid is silently discarded by the + # _is_cancelled gates. The previous "stays in memory + # only" promise was already broken because cleanup + # discards the in-memory copy anyway. logger.warning( - "Boundary snapshot write queue full, snapshot %s/%d " - "stays in memory only", + "Boundary snapshot write queue full, dropping " + "snapshot %s/%d", request_id, token_count, ) - # Still returns True — data is in pending_writes for read-back. 
+ with self._pending_lock: + self._pending_writes.pop(pw_key, None) + with self._registry_lock: + req_files = self._file_registry.get(request_id) + if req_files is not None: + req_files.pop(token_count, None) + if not req_files: + self._file_registry.pop(request_id, None) + return False return True @@ -241,57 +279,256 @@ def cleanup_request(self, request_id: str) -> None: the worker's :meth:`load` calls and silently strip block storage. :class:`omlx.scheduler.Scheduler` defers this call until the ``store_future`` for ``request_id`` is done. + + Acquires ``_writer_busy`` after marking the request cancelled so + the writer thread can finish any item it is mid-processing first. + Without this barrier the writer can pull an item, ``mkdir`` the + request directory, write its temp file, then ``os.rename`` it + into the final path *after* we have rmtree'd — leaving an + orphaned file behind. The ``_cancelled_requests`` counter (held + under ``_cancelled_lock``) catches the late-rename case if + ``_writer_busy.acquire`` times out. + + Bounded with a timeout so a stuck I/O on the writer thread + cannot deadlock request abort paths (called from scheduler's + hot path at ~3 sites). + + The cancelled-counter is bumped additively and only when at + least one pending item exists for the rid — see the inline + comment at the bump site for the two distinct bugs that + rules out (stale ``rid: 0`` after a timeout for an empty + cleanup, and overwrites racing with re-entrant cleanup_request + calls for the same rid). """ - # Count remaining queue items and mark as cancelled. The writer - # thread decrements the count on each skip and removes the entry - # when it reaches zero. + if self._shutdown.is_set(): + # After shutdown the writer no longer reacquires + # _writer_busy per-item, so cleanup_request cannot + # synchronise with it. Best-effort: just drop in-memory + # state. Files (if any leaked through shutdown) are removed + # by the next constructor cleanup_all. 
+ with self._pending_lock: + for k in [k for k in self._pending_writes if k[0] == request_id]: + del self._pending_writes[k] + with self._registry_lock: + self._file_registry.pop(request_id, None) + logger.warning( + "cleanup_request(%s) called after shutdown — running " + "in-memory-only", request_id, + ) + return + + # Atomically: count pending items for this rid, drop them, mark + # the rid cancelled. Holding both locks during the snapshot is + # required to keep the counter consistent with what the writer + # will see — a save() call from another thread cannot interleave + # an enqueue between our count and our cancellation mark. + # + # The bump is additive (``get + count``) and skipped entirely + # when ``count == 0``. Both rules close real bugs: + # * Skip-on-zero: cleanup_request("X") for an rid with no + # pending items previously wrote ``cancelled[X] = 0`` then + # popped it on the acquired path. On the timeout fallback + # the pop never runs and the ``X: 0`` entry lingers for + # the process lifetime — every subsequent save() under + # that rid (or any later reuse of the same string) is + # discarded by the writer's ``_is_cancelled`` gates, + # which check key membership not value > 0. The counter + # must only exist when there is at least one in-flight + # item to drain it. + # * Additive: a re-entrant cleanup_request("X") for an rid + # that already has an in-flight cancellation must NOT + # overwrite the previous count. The writer's + # ``cleared_by_cleanup`` branch + ``_writer_busy`` lock + # together close the file-write race today, but the + # per-item dec_cancelled bookkeeping still has to balance. + # Overwriting drops the remaining decs on the floor; on + # the next ``save()`` under the same rid the writer would + # see a non-zero counter from the earlier batch and + # silently discard the new item. 
with self._pending_lock: - count = sum(1 for k in self._pending_writes if k[0] == request_id) keys_to_remove = [k for k in self._pending_writes if k[0] == request_id] + count = len(keys_to_remove) for key in keys_to_remove: del self._pending_writes[key] - self._cancelled_requests[request_id] = count + if count > 0: + with self._cancelled_lock: + self._cancelled_requests[request_id] = ( + self._cancelled_requests.get(request_id, 0) + count + ) # Remove from registry. with self._registry_lock: self._file_registry.pop(request_id, None) - # Remove files. - req_dir = self._snapshot_dir / request_id - if req_dir.exists(): - try: - shutil.rmtree(req_dir) - except Exception as e: - logger.debug("Failed to clean up snapshots for %s: %s", request_id, e) + # Wait briefly for the writer to finish any item it had already + # pulled. If it's genuinely stuck (slow disk, dead thread) fall + # back to the cancelled-counter rescue rather than blocking the + # caller. + acquired = self._writer_busy.acquire( + timeout=self._CLEANUP_REQUEST_TIMEOUT_S + ) + try: + # Remove files. + req_dir = self._snapshot_dir / request_id + if req_dir.exists(): + try: + shutil.rmtree(req_dir) + except Exception as e: + logger.debug( + "Failed to clean up snapshots for %s: %s", request_id, e + ) + finally: + if acquired: + self._writer_busy.release() + # Counter entry has done its job — we own the lock so all + # _is_cancelled-gated work has either run or skipped. Drop + # the counter so a future racing save() can't leave it + # elevated forever. CRITICAL: only pop on the acquired + # path. On timeout the writer is still mid-item and may + # not yet have consulted ``_is_cancelled``; popping here + # would defeat the late-rename rescue that the docstring + # advertises as the timeout-fallback safety net. 
+ with self._cancelled_lock: + self._cancelled_requests.pop(request_id, None) + else: + logger.warning( + "cleanup_request(%s): writer thread did not yield " + "within %.1fs; relying on cancelled-counter rescue " + "for late-rename safety", + request_id, + self._CLEANUP_REQUEST_TIMEOUT_S, + ) def cleanup_all(self) -> None: - """Delete all snapshot files (for reset/startup).""" + """Delete all snapshot files (for reset/startup). + + Synchronizes with the background writer: we drain the queue to + prevent it from starting a new item, then acquire ``_writer_busy`` + to wait until any item it had already pulled finishes. Without + this barrier the writer can create ``req-X/temp.safetensors`` + and ``os.rename`` it to its final path *after* we've already + rmtree'd and recreated the snapshot directory, leaving an + orphaned file behind. + + Threading: concurrent ``save()`` is safe because the writer + consults ``_pending_writes`` and ``_is_cancelled`` while + holding ``_writer_busy``, and ``cleanup_all`` clears both + under the same lock before rmtree. The earlier "must run on + the save() thread" constraint is therefore no longer required. + + Invariant enforcement: ``cleanup_all`` must run BEFORE + ``shutdown()`` to actually synchronise with the writer. Once + ``_shutdown`` is set the writer drops the per-item + ``_writer_busy`` acquire, so a post-shutdown ``cleanup_all`` + cannot block on the writer and degrades to an in-memory + clear. Callers that need both should pass ``shutdown( + cleanup=True)`` instead of sequencing the calls themselves. + """ + if self._shutdown.is_set(): + # See cleanup_request: best-effort in-memory clear only. 
+ with self._pending_lock: + self._pending_writes.clear() + with self._registry_lock: + self._file_registry.clear() + with self._cancelled_lock: + self._cancelled_requests.clear() + logger.warning( + "cleanup_all called after shutdown — running in-memory-only; " + "callers wanting on-disk cleanup should use " + "shutdown(cleanup=True) instead" + ) + return + # Drain write queue so the writer thread doesn't process stale - # items after the directory is deleted. + # items after the directory is deleted. Put_nowait the sentinel + # back so shutdown still sees it; on Full just drop and let + # shutdown re-issue. while True: try: item = self._write_queue.get_nowait() if item is None: # Sentinel — put it back for shutdown. - self._write_queue.put(item) + try: + self._write_queue.put_nowait(item) + except queue.Full: + # Drop the sentinel; shutdown will re-enqueue. + # If cleanup_all is the LAST call before process + # exit without an explicit shutdown(), the writer + # thread will only be reaped on daemon teardown. + logger.debug( + "cleanup_all: dropped writer-sentinel on Full" + ) break except queue.Empty: break - with self._pending_lock: - self._pending_writes.clear() - with self._registry_lock: - self._file_registry.clear() - self._cancelled_requests.clear() - - if self._snapshot_dir.exists(): - try: - shutil.rmtree(self._snapshot_dir) - except Exception as e: - logger.debug("Failed to clean up all boundary snapshots: %s", e) - self._snapshot_dir.mkdir(parents=True, exist_ok=True) + # Wait for the writer to finish any item it had already pulled. + # When we own _writer_busy the writer is between items, and we + # just drained the queue so no new item can start. Bounded so a + # stuck writer (slow disk, dead thread) cannot deadlock callers + # — scheduler calls cleanup_all() from its abort / reset hot + # path. After the timeout we proceed anyway: the worst case is + # an orphaned file in the recreated directory, which next + # startup's cleanup_all() will clear. 
+ acquired = self._writer_busy.acquire( + timeout=self._CLEANUP_ALL_TIMEOUT_S + ) + try: + if not acquired: + logger.warning( + "cleanup_all: writer thread did not yield within " + "%.1fs; proceeding with rmtree — late-rename may " + "orphan a file under the recreated snapshot dir " + "until next startup.", + self._CLEANUP_ALL_TIMEOUT_S, + ) + with self._pending_lock: + self._pending_writes.clear() + with self._registry_lock: + self._file_registry.clear() + with self._cancelled_lock: + # Only safe to clear when we own _writer_busy — otherwise + # a writer mid-_dec_cancelled would race. On timeout we + # leave the counter intact so the rescue path stays + # effective for in-flight items. + if acquired: + self._cancelled_requests.clear() + + if self._snapshot_dir.exists(): + try: + shutil.rmtree(self._snapshot_dir) + except Exception as e: + logger.debug( + "Failed to clean up all boundary snapshots: %s", e + ) + self._snapshot_dir.mkdir(parents=True, exist_ok=True) + finally: + if acquired: + self._writer_busy.release() + + def shutdown(self, *, cleanup: bool = False) -> None: + """Stop background writer thread. - def shutdown(self) -> None: - """Stop background writer thread.""" + Parameters + ---------- + cleanup : bool + When True, run ``cleanup_all()`` first, then signal shutdown. + This enforces the cleanup-before-shutdown ordering invariant + in one call. Callers that pass ``cleanup=False`` (the + default) and *also* want cleanup MUST call ``cleanup_all()`` + themselves before ``shutdown()``. + + Invariant: if a caller wants to combine ``cleanup_all()`` with + shutdown, the cleanup MUST run BEFORE ``_shutdown.set()`` / + sentinel-enqueue. Once the sentinel is in the queue and the + writer has consumed it, the writer no longer reacquires + ``_writer_busy`` after each item, so a subsequent + ``cleanup_all`` would wait its full 5s timeout if the writer is + already mid-final-item, then proceed unsynchronised. 
The + ``cleanup=True`` path handles this ordering; the warning at + the top of ``cleanup_all`` catches misordered callers. + """ + if cleanup: + self.cleanup_all() self._shutdown.set() try: self._write_queue.put_nowait(None) # Sentinel @@ -303,13 +540,21 @@ def shutdown(self) -> None: # Internal # ------------------------------------------------------------------ + def _is_cancelled(self, request_id: str) -> bool: + """Thread-safe check for cancellation.""" + with self._cancelled_lock: + return request_id in self._cancelled_requests + def _dec_cancelled(self, request_id: str) -> None: - """Decrement cancelled counter; remove entry when exhausted.""" - remaining = self._cancelled_requests.get(request_id, 0) - 1 - if remaining <= 0: - self._cancelled_requests.pop(request_id, None) - else: - self._cancelled_requests[request_id] = remaining + """Decrement cancelled counter under lock; remove entry when + exhausted. Atomic read-modify-write closes the underflow race + between two writer-thread iterations / cleanup_all clears.""" + with self._cancelled_lock: + remaining = self._cancelled_requests.get(request_id, 0) - 1 + if remaining <= 0: + self._cancelled_requests.pop(request_id, None) + else: + self._cancelled_requests[request_id] = remaining def _file_path(self, request_id: str, token_count: int) -> Path: return self._snapshot_dir / request_id / f"{token_count}.safetensors" @@ -325,74 +570,117 @@ def _writer_loop(self) -> None: if item is None: # Sentinel break - pw_key, tensors_raw, metadata, file_path = item + # Hold _writer_busy for the entire item's lifetime so + # cleanup_all() can serialize with us — otherwise it can + # rmtree the snapshot directory while we're mid-write and + # we'd recreate ``req-X/`` underneath it, leaving an + # orphaned file after the cleanup returns. + with self._writer_busy: + self._process_write_item(item) + + def _process_write_item(self, item) -> None: + """Process one (pw_key, tensors_raw, metadata, file_path) queue item. 
+ + Extracted from ``_writer_loop`` so the busy-lock can wrap it + cleanly. Called only on the writer thread. + """ + pw_key, tensors_raw, metadata, file_path = item + + # If cleanup_all or cleanup_request cleared this key from + # _pending_writes while the item was in the writer's local hand + # (i.e. between ``get()`` and entering ``with _writer_busy``), + # treat the write as cancelled. This closes the late-rename + # window where cleanup runs entirely between the writer's pull + # and its busy-lock acquisition. + with self._pending_lock: + cleared_by_cleanup = pw_key not in self._pending_writes + if cleared_by_cleanup: + # If a timed-out cleanup_request bumped ``_cancelled_requests`` + # before clearing pending_writes, this item is one of the N + # the counter is waiting on. Without this decrement the + # counter would never reach zero, leaving the rid pinned in + # ``_cancelled_requests`` for the process lifetime and + # causing every subsequent write under that rid (or any + # later reuse of the same string) to be silently discarded. + if self._is_cancelled(pw_key[0]): + self._dec_cancelled(pw_key[0]) + return + + # Skip writes for cancelled/cleaned-up requests. + if self._is_cancelled(pw_key[0]): + with self._pending_lock: + self._pending_writes.pop(pw_key, None) + try: + req_dir = file_path.parent + if req_dir.exists(): + shutil.rmtree(req_dir) + except Exception: + pass + self._dec_cancelled(pw_key[0]) + return + + temp_path = None + try: + file_path.parent.mkdir(parents=True, exist_ok=True) + temp_path = file_path.with_name(file_path.stem + "_tmp.safetensors") + _write_safetensors_no_mx(str(temp_path), tensors_raw, metadata) - # Skip writes for cancelled/cleaned-up requests. - if pw_key[0] in self._cancelled_requests: + # Request may have been cleaned up while serializing. 
+ if self._is_cancelled(pw_key[0]): + try: + if temp_path.exists(): + temp_path.unlink() + except Exception: + pass with self._pending_lock: self._pending_writes.pop(pw_key, None) + self._dec_cancelled(pw_key[0]) + return + + os.rename(str(temp_path), str(file_path)) + + # Cleanup may race with a queued write; remove any late file. + if self._is_cancelled(pw_key[0]): + try: + if file_path.exists(): + file_path.unlink() + except Exception: + pass + req_dir = file_path.parent try: - req_dir = file_path.parent if req_dir.exists(): shutil.rmtree(req_dir) except Exception: pass self._dec_cancelled(pw_key[0]) - continue - - temp_path = None - try: - file_path.parent.mkdir(parents=True, exist_ok=True) - temp_path = file_path.with_name(file_path.stem + "_tmp.safetensors") - _write_safetensors_no_mx(str(temp_path), tensors_raw, metadata) - - # Request may have been cleaned up while serializing. - if pw_key[0] in self._cancelled_requests: - try: - if temp_path.exists(): - temp_path.unlink() - except Exception: - pass - with self._pending_lock: - self._pending_writes.pop(pw_key, None) - self._dec_cancelled(pw_key[0]) - continue - - os.rename(str(temp_path), str(file_path)) - - # Cleanup may race with a queued write; remove any late file. - if pw_key[0] in self._cancelled_requests: - try: - if file_path.exists(): - file_path.unlink() - except Exception: - pass - req_dir = file_path.parent - try: - if req_dir.exists(): - shutil.rmtree(req_dir) - except Exception: - pass - self._dec_cancelled(pw_key[0]) - except Exception as e: - logger.debug("Background snapshot write failed: %s", e) - for p in (temp_path, file_path): - try: - if p is not None and p.exists(): - p.unlink() - except Exception: - pass - finally: - # Remove extracted cache objects from pending writes to free - # memory, but keep tensors_raw for read-back until file is on - # disk. 
- with self._pending_lock: - pending = self._pending_writes.get(pw_key) - if pending is not None: - pending.pop("extracted", None) - # If file was written successfully, remove entirely. - if file_path.exists(): - self._pending_writes.pop(pw_key, None) + except Exception as e: + logger.debug("Background snapshot write failed: %s", e) + for p in (temp_path, file_path): + try: + if p is not None and p.exists(): + p.unlink() + except Exception: + pass + # Same bookkeeping invariant as the early-return path: if + # cleanup_request bumped the counter and the failure was a + # side-effect of that cleanup (e.g. its rmtree pulled the + # parent dir out from under our temp write), we still owe + # one decrement. The _is_cancelled rescue blocks above all + # return before this except clause runs, so we cannot + # double-decrement. + if self._is_cancelled(pw_key[0]): + self._dec_cancelled(pw_key[0]) + finally: + # Remove extracted cache objects from pending writes to free + # memory, but keep tensors_raw for read-back until file is on + # disk. + with self._pending_lock: + pending = self._pending_writes.get(pw_key) + if pending is not None: + pending.pop("extracted", None) + # If file was written successfully, remove entirely. + if file_path.exists(): + self._pending_writes.pop(pw_key, None) def _serialize_extracted( self, diff --git a/tests/test_boundary_snapshot_store.py b/tests/test_boundary_snapshot_store.py index 5a233ab3a..6af67f978 100644 --- a/tests/test_boundary_snapshot_store.py +++ b/tests/test_boundary_snapshot_store.py @@ -221,24 +221,567 @@ def test_cleanup_request_skips_queued_writes(self): assert not req_dir.exists() def test_cleanup_all_drains_queue(self): - """cleanup_all should drain write queue before deleting directory.""" - import time - + """cleanup_all() should leave the snapshot directory empty no + matter where the writer thread was in its processing cycle. 
+ + Previously this test slept 1.0 s as a guess at the writer's + finish time and was flaky ~20% of the time: the writer could + ``os.rename`` a temp file into its final path *after* cleanup_all + had rmtree'd the directory, leaving an orphaned file. + cleanup_all now holds the writer-busy lock until any in-flight + item is done, so no sleep is required. + """ self.store.save("req-1", 1024, [MagicMock()], _mock_extract_cache_states) self.store.save("req-2", 2048, [MagicMock()], _mock_extract_cache_states) - # Cleanup all before writer thread processes items. + # cleanup_all() must synchronize with the writer. self.store.cleanup_all() - # Wait for writer to finish any in-flight work. - time.sleep(1.0) - # Snapshot directory should be clean (recreated but empty). snapshot_dir = self.base_dir / "_boundary_snapshots" assert snapshot_dir.exists() children = list(snapshot_dir.iterdir()) assert len(children) == 0 + def test_cleanup_all_blocks_until_writer_finishes_pinned_item(self): + """Deterministic regression for the writer-vs-cleanup race. + + Pins the writer mid-item with a slow ``_write_safetensors_no_mx`` + replacement, fires ``cleanup_all()`` from the test thread, and + asserts that: + 1. cleanup_all does not return before the writer finishes its + pinned item (would-be-orphaned rename), AND + 2. the snapshot directory ends up empty. + + Without the ``_writer_busy`` lock this would fail deterministically + rather than flakily — the writer's ``os.rename`` lands after the + rmtree and an orphan survives. + """ + import threading + import time + from unittest.mock import patch + + writer_in_item = threading.Event() + release_writer = threading.Event() + original_write = None + + def slow_write(*args, **kwargs): + writer_in_item.set() + # Hold the writer here so cleanup_all is forced to wait on + # _writer_busy. 1 s is plenty for the test thread to call + # cleanup_all and start blocking. 
+ release_writer.wait(timeout=5.0) + return original_write(*args, **kwargs) + + from omlx.cache import boundary_snapshot_store as mod + + original_write = mod._write_safetensors_no_mx + + with patch.object(mod, "_write_safetensors_no_mx", side_effect=slow_write): + self.store.save("req-pinned", 1024, [MagicMock()], _mock_extract_cache_states) + + # Wait until the writer has picked up the item and is inside + # the slow_write hook. + assert writer_in_item.wait(timeout=5.0), "writer never started" + + # Kick off cleanup_all from a background thread so we can + # observe that it does not complete while the writer is pinned. + cleanup_done = threading.Event() + + def _do_cleanup(): + self.store.cleanup_all() + cleanup_done.set() + + t = threading.Thread(target=_do_cleanup, name="cleanup-all-test") + t.start() + + # cleanup_all must NOT return while the writer holds _writer_busy. + assert not cleanup_done.wait(timeout=0.5), ( + "cleanup_all returned while writer was mid-item — " + "_writer_busy lock is not being honored" + ) + + # Release the writer; cleanup_all should then complete. + release_writer.set() + assert cleanup_done.wait(timeout=10.0), "cleanup_all hung" + t.join(timeout=5.0) + + # Give the writer one more tick to fully exit _process_write_item + # before asserting on the directory. + time.sleep(0.1) + snapshot_dir = self.base_dir / "_boundary_snapshots" + assert snapshot_dir.exists() + assert list(snapshot_dir.iterdir()) == [] + + def test_cleanup_request_blocks_until_writer_finishes_pinned_item(self): + """Symmetric regression to cleanup_all: cleanup_request must also + wait on the writer's in-flight item before rmtree, otherwise the + writer's late ``os.rename`` lands under the just-cleaned dir. 
+ """ + import threading + import time + from unittest.mock import patch + + writer_in_item = threading.Event() + release_writer = threading.Event() + original_write = None + + def slow_write(*args, **kwargs): + writer_in_item.set() + release_writer.wait(timeout=5.0) + return original_write(*args, **kwargs) + + from omlx.cache import boundary_snapshot_store as mod + + original_write = mod._write_safetensors_no_mx + + with patch.object(mod, "_write_safetensors_no_mx", side_effect=slow_write): + self.store.save("req-cleanup", 2048, [MagicMock()], _mock_extract_cache_states) + + assert writer_in_item.wait(timeout=5.0), "writer never started" + + cleanup_done = threading.Event() + + def _do_cleanup(): + self.store.cleanup_request("req-cleanup") + cleanup_done.set() + + t = threading.Thread(target=_do_cleanup, name="cleanup-req-test") + t.start() + + assert not cleanup_done.wait(timeout=0.5), ( + "cleanup_request returned while writer was mid-item — " + "_writer_busy lock is not being honored" + ) + + release_writer.set() + assert cleanup_done.wait(timeout=10.0), "cleanup_request hung" + t.join(timeout=5.0) + + # After cleanup_request the per-request directory must be gone. + time.sleep(0.1) + req_dir = self.base_dir / "_boundary_snapshots" / "req-cleanup" + assert not req_dir.exists() + + def test_cleanup_request_keeps_counter_on_timeout(self): + """When ``cleanup_request`` cannot acquire ``_writer_busy`` within + ``_CLEANUP_REQUEST_TIMEOUT_S``, it must NOT pop + ``_cancelled_requests[request_id]``: the counter is the rescue + path the docstring promises for the late-rename window. The + previous code popped unconditionally and silently defeated the + rescue. Regression for the real bug found in review. + """ + import threading + import time + from unittest.mock import patch + + # Pin the writer mid-item so the cleanup_request acquire times out. 
+ writer_in_item = threading.Event() + release_writer = threading.Event() + original_write = None + + def slow_write(*args, **kwargs): + writer_in_item.set() + release_writer.wait(timeout=10.0) + return original_write(*args, **kwargs) + + from omlx.cache import boundary_snapshot_store as mod + + original_write = mod._write_safetensors_no_mx + + # Tighten the timeout for the test so the test runs fast. + with patch.object( + type(self.store), "_CLEANUP_REQUEST_TIMEOUT_S", 0.1 + ), patch.object(mod, "_write_safetensors_no_mx", side_effect=slow_write): + self.store.save( + "req-timeout-rescue", + 2048, + [MagicMock()], + _mock_extract_cache_states, + ) + assert writer_in_item.wait(timeout=5.0), "writer never started" + + # cleanup_request returns once the 0.1s timeout fires — writer + # is still pinned. The counter MUST remain so _is_cancelled + # can still catch the late rename. + self.store.cleanup_request("req-timeout-rescue") + + with self.store._cancelled_lock: + assert ( + "req-timeout-rescue" in self.store._cancelled_requests + ), ( + "counter dropped on timeout — late-rename rescue " + "would be defeated" + ) + + # Let the writer finish; rescue then drops the counter via + # _is_cancelled → _dec_cancelled. + release_writer.set() + time.sleep(0.5) + + def test_cleanup_request_timeout_drains_counter_on_writer_early_return( + self, + ): + """Regression: when ``cleanup_request`` times out while + ``_cancelled_requests[rid]`` is non-zero, items that the writer + later dequeues but whose pending entry was already cleared by + cleanup must still decrement the counter on the early-return + path. Without that decrement the rid stays in + ``_cancelled_requests`` for the process lifetime and every + future write under that rid is silently discarded by the + ``_is_cancelled`` gates. 
+ """ + import threading + import time + from unittest.mock import patch + + writer_in_item = threading.Event() + release_writer = threading.Event() + original_write = None + + def slow_write(*args, **kwargs): + writer_in_item.set() + release_writer.wait(timeout=10.0) + return original_write(*args, **kwargs) + + from omlx.cache import boundary_snapshot_store as mod + + original_write = mod._write_safetensors_no_mx + + with patch.object( + type(self.store), "_CLEANUP_REQUEST_TIMEOUT_S", 0.1 + ), patch.object( + mod, "_write_safetensors_no_mx", side_effect=slow_write + ): + # Two items for the same rid: A pins the writer; B sits in + # the queue behind A. + self.store.save( + "req-drain", 2048, [MagicMock()], + _mock_extract_cache_states, + ) + self.store.save( + "req-drain", 4096, [MagicMock()], + _mock_extract_cache_states, + ) + assert writer_in_item.wait(timeout=5.0), ( + "writer never started item A" + ) + + # cleanup_request snapshots both pending items, sets + # counter=2, then times out (writer still pinned on A). + self.store.cleanup_request("req-drain") + with self.store._cancelled_lock: + assert ( + self.store._cancelled_requests.get("req-drain") == 2 + ), ( + "cleanup_request did not record both pending items " + "before timing out" + ) + + # Releasing A lets the writer finish: post-rename + # _is_cancelled fires → 2→1. The queue then advances to B + # whose pending entry was already cleared by cleanup; + # writer's early-return path MUST decrement 1→0 and pop. 
+ release_writer.set() + + deadline = time.monotonic() + 5.0 + while time.monotonic() < deadline: + with self.store._cancelled_lock: + if "req-drain" not in self.store._cancelled_requests: + break + time.sleep(0.02) + else: + with self.store._cancelled_lock: + state = dict(self.store._cancelled_requests) + raise AssertionError( + "_cancelled_requests still pins 'req-drain' after " + f"both items processed: {state}" + ) + + def test_cleanup_request_no_pending_does_not_pin_counter_on_timeout(self): + """Regression: ``cleanup_request("X")`` for an rid with NO + pending items must NOT bump ``_cancelled_requests[X] = 0``. + + Previously the unconditional bump would write ``X: 0``, then on + the acquired path pop it. On the timeout fallback the pop never + ran and the ``X: 0`` entry lingered for the process lifetime — + every subsequent ``save()`` under that rid (or any later reuse + of the same string) was silently discarded by the writer's + ``_is_cancelled`` gates, which check key membership not + value > 0. + """ + import threading + import time + from unittest.mock import patch + + # Pin the writer with an unrelated save so cleanup_request's + # _writer_busy.acquire times out without any item for our rid. + writer_in_item = threading.Event() + release_writer = threading.Event() + original_write = None + + def slow_write(*args, **kwargs): + writer_in_item.set() + release_writer.wait(timeout=10.0) + return original_write(*args, **kwargs) + + from omlx.cache import boundary_snapshot_store as mod + + original_write = mod._write_safetensors_no_mx + + with patch.object( + type(self.store), "_CLEANUP_REQUEST_TIMEOUT_S", 0.1 + ), patch.object( + mod, "_write_safetensors_no_mx", side_effect=slow_write + ): + self.store.save( + "req-blocker", 2048, [MagicMock()], + _mock_extract_cache_states, + ) + assert writer_in_item.wait(timeout=5.0), ( + "writer never started blocker item" + ) + + # cleanup_request for an rid that was NEVER saved. count==0. 
+ # _writer_busy is held by the blocker → acquire times out. + self.store.cleanup_request("never-saved-rid") + + with self.store._cancelled_lock: + assert ( + "never-saved-rid" not in self.store._cancelled_requests + ), ( + "cleanup_request bumped _cancelled_requests for an " + "rid with no pending items — the stale 0-counter " + "would silently kill every future save under that rid" + ) + + # Verify the bug's downstream consequence directly: + # a save() under the same rid must succeed, not be discarded + # by the writer's _is_cancelled gates. + release_writer.set() + time.sleep(0.2) # let blocker drain + ok = self.store.save( + "never-saved-rid", 4096, [MagicMock()], + _mock_extract_cache_states, + ) + assert ok, "save() failed" + # Wait for the writer to finish. + deadline = time.monotonic() + 5.0 + while time.monotonic() < deadline: + if self.store.has("never-saved-rid", 4096): + break + time.sleep(0.02) + file_path = ( + self.base_dir / "_boundary_snapshots" / "never-saved-rid" + / "4096.safetensors" + ) + # Either the file is on disk OR still buffered in pending — + # but it must not have been silently discarded. + with self.store._pending_lock: + still_pending = ( + "never-saved-rid", 4096 + ) in self.store._pending_writes + assert file_path.exists() or still_pending, ( + "save() under rid was silently discarded — stale " + "_cancelled_requests entry defeated the new write" + ) + + def test_save_queue_full_rolls_back_pending_and_registry(self): + """When the writer queue is saturated, ``save()`` must roll back + its pending_writes / file_registry entries and return False. + Otherwise a later ``cleanup_request`` for the same rid would + count this orphan entry into ``_cancelled_requests`` while no + queue item ever exists to decrement it — the rid would stay + pinned in the cancelled set and every subsequent save under + that rid would be silently discarded by the ``_is_cancelled`` + gates. 
+ """ + from unittest.mock import patch + import queue as _queue + + def _full(*args, **kwargs): + raise _queue.Full + + with patch.object( + self.store._write_queue, "put_nowait", side_effect=_full + ): + ok = self.store.save( + "req-qfull", 2048, [MagicMock()], + _mock_extract_cache_states, + ) + + assert ok is False, ( + "save() must return False when the queue is full so the " + "caller knows the write was dropped" + ) + with self.store._pending_lock: + assert ("req-qfull", 2048) not in self.store._pending_writes + with self.store._registry_lock: + assert "req-qfull" not in self.store._file_registry + + # cleanup_request on the same rid must NOT pin the counter. + self.store.cleanup_request("req-qfull") + with self.store._cancelled_lock: + assert "req-qfull" not in self.store._cancelled_requests + + def test_shutdown_cleanup_true_runs_cleanup_before_setting_flag(self): + """``shutdown(cleanup=True)`` must run ``cleanup_all()`` BEFORE + flipping ``_shutdown`` so the writer still reacquires + ``_writer_busy`` per item during the cleanup. Otherwise the + cleanup degrades to an in-memory-only clear (see the + post-shutdown branch in ``cleanup_all``). + """ + # Save a block first so cleanup_all has something to drain. + self.store.save("req-shutdown", 1024, [MagicMock()], _mock_extract_cache_states) + + # Use a small custom store so we can shut down without affecting + # other tests in this class. + from pathlib import Path + import tempfile + + with tempfile.TemporaryDirectory() as td: + store2 = BoundarySnapshotSSDStore(Path(td)) + try: + store2.save( + "req-x", 256, [MagicMock()], _mock_extract_cache_states + ) + snapshot_dir = store2._snapshot_dir + assert snapshot_dir.exists() + store2.shutdown(cleanup=True) + # After cleanup_all+shutdown the per-request dir is empty + # of leftover request subdirs (the cleanup itself rmtrees + # then mkdirs the parent). 
+ assert snapshot_dir.exists() + leftover = list(snapshot_dir.iterdir()) + assert leftover == [], ( + f"shutdown(cleanup=True) left files behind: {leftover}" + ) + finally: + if store2._writer_thread.is_alive(): + store2.shutdown() + + def test_cancelled_requests_dict_is_thread_safe(self): + """Concurrent cleanup_request + writer should not race on + _cancelled_requests. Without locking, the counter underflows or + cancellation can be silently lost. + """ + import threading + + # Fire many concurrent cleanup_request calls against requests + # that don't have any pending items — exercises the lock acquire + # / set / clear paths without needing real file I/O. + errors: list[Exception] = [] + + def cancel_loop(rid_prefix: str): + try: + for i in range(200): + self.store.cleanup_request(f"{rid_prefix}-{i}") + except Exception as e: + errors.append(e) + + threads = [ + threading.Thread(target=cancel_loop, args=(f"t{tid}",)) + for tid in range(4) + ] + for t in threads: + t.start() + for t in threads: + t.join(timeout=10.0) + assert not errors, errors + # The dict must not be in a corrupt state — clear() and len() + # both succeed. + with self.store._cancelled_lock: + assert len(self.store._cancelled_requests) >= 0 + + + def test_concurrent_save_cleanup_request_cleanup_all_no_orphans(self): + """Stress: concurrent save() + cleanup_request() + cleanup_all(). + + Regression target: the late-rename window where the writer pulled + an item from the queue but had not yet entered the busy-lock + critical section while cleanup ran would leave an orphaned file + under the recreated snapshot directory. The _process_write_item + pending-writes membership check closes that window. + + Test asserts: after all activity quiesces, every file on disk + also has a corresponding entry in _file_registry — i.e. no + orphans. 
+ """ + import threading + import time as _time + + stop = threading.Event() + errors: list[Exception] = [] + + def saver(rid_prefix: str): + try: + tc = 0 + while not stop.is_set(): + tc += 1 + self.store.save( + f"{rid_prefix}-{tc % 7}", + tc * 1024, + [MagicMock()], + _mock_extract_cache_states, + ) + except Exception as e: + errors.append(e) + + def cleaner(rid_prefix: str): + try: + tc = 0 + while not stop.is_set(): + tc += 1 + self.store.cleanup_request(f"{rid_prefix}-{tc % 7}") + _time.sleep(0.001) + except Exception as e: + errors.append(e) + + def all_cleaner(): + try: + while not stop.is_set(): + _time.sleep(0.05) + self.store.cleanup_all() + except Exception as e: + errors.append(e) + + threads = [ + threading.Thread(target=saver, args=("a",)), + threading.Thread(target=saver, args=("b",)), + threading.Thread(target=cleaner, args=("a",)), + threading.Thread(target=cleaner, args=("b",)), + threading.Thread(target=all_cleaner), + ] + for t in threads: + t.start() + _time.sleep(1.5) + stop.set() + for t in threads: + t.join(timeout=10.0) + assert not errors, errors + + # Let writer drain. + _time.sleep(0.5) + + # Orphan check: every .safetensors on disk must have a matching + # registry entry. The reverse direction is fine to drift (the + # registry may have entries the writer hasn't materialised yet). + snap_root = self.base_dir / "_boundary_snapshots" + on_disk = list(snap_root.rglob("*.safetensors")) + registered_paths: set[Path] = set() + with self.store._registry_lock: + for tc_to_path in self.store._file_registry.values(): + registered_paths.update(tc_to_path.values()) + + orphans = [p for p in on_disk if p not in registered_paths] + # Allow a small tolerance for in-flight temp files only — those + # have "_tmp" in the stem and are not real orphans. 
+ real_orphans = [p for p in orphans if "_tmp" not in p.stem] + assert not real_orphans, ( + f"Found {len(real_orphans)} orphaned files: " + f"{real_orphans[:5]}" + ) + # --------------------------------------------------------------------------- # _BoundarySnapshotProvider tests From 0a65ddc9b235d38414db8d5d2069d331a1494599 Mon Sep 17 00:00:00 2001 From: jundot Date: Wed, 27 May 2026 17:02:58 +0900 Subject: [PATCH 02/10] cleanup(boundary-store): drop unreachable shutdown(cleanup=) path Scheduler.shutdown() never calls _boundary_snapshot_store.shutdown(), so _shutdown.is_set() is always False in production. The post-shutdown branches in cleanup_request / cleanup_all and the cleanup= kwarg on shutdown() were dead code, kept alive only by a self-referential regression test. Drops them along with the test. (cherry picked from commit bc1c427bf62f557b2645f05584eb3e135ca29236) --- omlx/cache/boundary_snapshot_store.py | 66 +-------------------------- tests/test_boundary_snapshot_store.py | 36 --------------- 2 files changed, 2 insertions(+), 100 deletions(-) diff --git a/omlx/cache/boundary_snapshot_store.py b/omlx/cache/boundary_snapshot_store.py index 3cb58f203..33052bc6b 100644 --- a/omlx/cache/boundary_snapshot_store.py +++ b/omlx/cache/boundary_snapshot_store.py @@ -300,23 +300,6 @@ def cleanup_request(self, request_id: str) -> None: cleanup, and overwrites racing with re-entrant cleanup_request calls for the same rid). """ - if self._shutdown.is_set(): - # After shutdown the writer no longer reacquires - # _writer_busy per-item, so cleanup_request cannot - # synchronise with it. Best-effort: just drop in-memory - # state. Files (if any leaked through shutdown) are removed - # by the next constructor cleanup_all. 
- with self._pending_lock: - for k in [k for k in self._pending_writes if k[0] == request_id]: - del self._pending_writes[k] - with self._registry_lock: - self._file_registry.pop(request_id, None) - logger.warning( - "cleanup_request(%s) called after shutdown — running " - "in-memory-only", request_id, - ) - return - # Atomically: count pending items for this rid, drop them, mark # the rid cancelled. Holding both locks during the snapshot is # required to keep the counter consistent with what the writer @@ -415,30 +398,7 @@ def cleanup_all(self) -> None: holding ``_writer_busy``, and ``cleanup_all`` clears both under the same lock before rmtree. The earlier "must run on the save() thread" constraint is therefore no longer required. - - Invariant enforcement: ``cleanup_all`` must run BEFORE - ``shutdown()`` to actually synchronise with the writer. Once - ``_shutdown`` is set the writer drops the per-item - ``_writer_busy`` acquire, so a post-shutdown ``cleanup_all`` - cannot block on the writer and degrades to an in-memory - clear. Callers that need both should pass ``shutdown( - cleanup=True)`` instead of sequencing the calls themselves. """ - if self._shutdown.is_set(): - # See cleanup_request: best-effort in-memory clear only. - with self._pending_lock: - self._pending_writes.clear() - with self._registry_lock: - self._file_registry.clear() - with self._cancelled_lock: - self._cancelled_requests.clear() - logger.warning( - "cleanup_all called after shutdown — running in-memory-only; " - "callers wanting on-disk cleanup should use " - "shutdown(cleanup=True) instead" - ) - return - # Drain write queue so the writer thread doesn't process stale # items after the directory is deleted. Put_nowait the sentinel # back so shutdown still sees it; on Full just drop and let @@ -505,30 +465,8 @@ def cleanup_all(self) -> None: if acquired: self._writer_busy.release() - def shutdown(self, *, cleanup: bool = False) -> None: - """Stop background writer thread. 
- - Parameters - ---------- - cleanup : bool - When True, run ``cleanup_all()`` first, then signal shutdown. - This enforces the cleanup-before-shutdown ordering invariant - in one call. Callers that pass ``cleanup=False`` (the - default) and *also* want cleanup MUST call ``cleanup_all()`` - themselves before ``shutdown()``. - - Invariant: if a caller wants to combine ``cleanup_all()`` with - shutdown, the cleanup MUST run BEFORE ``_shutdown.set()`` / - sentinel-enqueue. Once the sentinel is in the queue and the - writer has consumed it, the writer no longer reacquires - ``_writer_busy`` after each item, so a subsequent - ``cleanup_all`` would wait its full 5s timeout if the writer is - already mid-final-item, then proceed unsynchronised. The - ``cleanup=True`` path handles this ordering; the warning at - the top of ``cleanup_all`` catches misordered callers. - """ - if cleanup: - self.cleanup_all() + def shutdown(self) -> None: + """Stop background writer thread.""" self._shutdown.set() try: self._write_queue.put_nowait(None) # Sentinel diff --git a/tests/test_boundary_snapshot_store.py b/tests/test_boundary_snapshot_store.py index 6af67f978..cc3f16161 100644 --- a/tests/test_boundary_snapshot_store.py +++ b/tests/test_boundary_snapshot_store.py @@ -624,42 +624,6 @@ def _full(*args, **kwargs): with self.store._cancelled_lock: assert "req-qfull" not in self.store._cancelled_requests - def test_shutdown_cleanup_true_runs_cleanup_before_setting_flag(self): - """``shutdown(cleanup=True)`` must run ``cleanup_all()`` BEFORE - flipping ``_shutdown`` so the writer still reacquires - ``_writer_busy`` per item during the cleanup. Otherwise the - cleanup degrades to an in-memory-only clear (see the - post-shutdown branch in ``cleanup_all``). - """ - # Save a block first so cleanup_all has something to drain. 
- self.store.save("req-shutdown", 1024, [MagicMock()], _mock_extract_cache_states) - - # Use a small custom store so we can shut down without affecting - # other tests in this class. - from pathlib import Path - import tempfile - - with tempfile.TemporaryDirectory() as td: - store2 = BoundarySnapshotSSDStore(Path(td)) - try: - store2.save( - "req-x", 256, [MagicMock()], _mock_extract_cache_states - ) - snapshot_dir = store2._snapshot_dir - assert snapshot_dir.exists() - store2.shutdown(cleanup=True) - # After cleanup_all+shutdown the per-request dir is empty - # of leftover request subdirs (the cleanup itself rmtrees - # then mkdirs the parent). - assert snapshot_dir.exists() - leftover = list(snapshot_dir.iterdir()) - assert leftover == [], ( - f"shutdown(cleanup=True) left files behind: {leftover}" - ) - finally: - if store2._writer_thread.is_alive(): - store2.shutdown() - def test_cancelled_requests_dict_is_thread_safe(self): """Concurrent cleanup_request + writer should not race on _cancelled_requests. Without locking, the counter underflows or From 89f3b99cac93b5335be2c7486fbd53c09a9c8c65 Mon Sep 17 00:00:00 2001 From: cfbraun Date: Wed, 27 May 2026 09:31:27 +0200 Subject: [PATCH 03/10] cleanup(cache): remove dead TieredCacheManager (#1422) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Verified unused: zero instantiations across the entire git history (`git log -S 'TieredCacheManager('` returns only the initial-commit declaration), zero imports outside the module's own __init__ re-export, zero test references. The class was a planned coordinator between PagedCacheManager / BlockAwarePrefixCache / PagedSSDCacheManager / MemoryMonitor — Scheduler.__init__ does that wiring directly, making the abstraction redundant. Removes 353 lines plus the re-export from omlx.cache.__all__. 
(cherry picked from commit 2916ab40059b4ac24853f851fb3fef36a6fc9833) --- omlx/cache/__init__.py | 2 - omlx/cache/tiered_manager.py | 353 ----------------------------------- 2 files changed, 355 deletions(-) delete mode 100644 omlx/cache/tiered_manager.py diff --git a/omlx/cache/__init__.py b/omlx/cache/__init__.py index da625bcab..95c2fbee6 100644 --- a/omlx/cache/__init__.py +++ b/omlx/cache/__init__.py @@ -52,7 +52,6 @@ ) # Managers -from .tiered_manager import TieredCacheManager from .recovery import CacheRecoveryManager # Factory @@ -109,7 +108,6 @@ "VisionFeatureSSDCache", "VisionFeatureSSDEntry", # Managers - "TieredCacheManager", "CacheRecoveryManager", # Factory "CacheConfig", diff --git a/omlx/cache/tiered_manager.py b/omlx/cache/tiered_manager.py deleted file mode 100644 index 0a74cc2b5..000000000 --- a/omlx/cache/tiered_manager.py +++ /dev/null @@ -1,353 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -""" -Tiered Cache Manager for oMLX. - -This module manages hot/cold tiered KV caching, enabling automatic paged SSD offloading -when GPU memory is under pressure. - -In paged SSD-only mode: -- All KV cache data is stored on paged SSD via PagedSSDCacheManager -- PagedCacheManager only stores block metadata (no GPU memory for cache data) -- BatchGenerator handles GPU memory for active inference -""" - -import logging -from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, Optional - -from omlx.utils.formatting import format_bytes - -if TYPE_CHECKING: - from ..paged_cache import PagedCacheManager - from ..prefix_cache import BlockAwarePrefixCache - from ..paged_ssd_cache import PagedSSDCacheManager - from ..memory_monitor import MemoryMonitor - -logger = logging.getLogger(__name__) - - -class TieredCacheManager: - """ - Manages hot/cold tiered KV caching. - - This class coordinates between PagedCacheManager (hot cache in GPU memory) - and PagedSSDCacheManager (cold cache on disk) to provide efficient memory usage. 
- - In paged SSD-only mode, all KV cache data is stored on paged SSD. The PagedCacheManager - only manages block metadata, and BatchGenerator handles GPU memory for - active inference. - """ - - def __init__( - self, - paged_cache_manager: Optional["PagedCacheManager"] = None, - block_aware_cache: Optional["BlockAwarePrefixCache"] = None, - paged_ssd_cache_manager: Optional["PagedSSDCacheManager"] = None, - memory_monitor: Optional["MemoryMonitor"] = None, - block_size: int = 256, - ): - """ - Initialize the tiered cache manager. - - Args: - paged_cache_manager: Manager for paged KV cache. - block_aware_cache: Block-aware prefix cache. - paged_ssd_cache_manager: Manager for paged SSD storage. - memory_monitor: Monitor for memory pressure. - block_size: Tokens per block. - """ - self.paged_cache_manager = paged_cache_manager - self.block_aware_cache = block_aware_cache - self.paged_ssd_cache_manager = paged_ssd_cache_manager - self.memory_monitor = memory_monitor - self.block_size = block_size - - @classmethod - def from_config( - cls, - paged_cache_manager: Optional["PagedCacheManager"], - block_aware_cache: Optional["BlockAwarePrefixCache"], - paged_ssd_cache_dir: Optional[str], - paged_ssd_cache_max_size: int, - block_size: int, - model: Any = None, - hot_cache_max_bytes: int = 0, - ) -> Optional["TieredCacheManager"]: - """ - Create a TieredCacheManager from configuration. - - Args: - paged_cache_manager: Manager for paged KV cache. - block_aware_cache: Block-aware prefix cache. - paged_ssd_cache_dir: Path for paged SSD cache storage (None = disabled). - paged_ssd_cache_max_size: Maximum paged SSD cache size in bytes. - block_size: Tokens per block. - model: The model (for extracting KV cache dimensions). - hot_cache_max_bytes: Maximum in-memory hot cache size (0 = disabled). - - Returns: - TieredCacheManager instance or None if tiered caching is disabled. 
- """ - # Import here to avoid circular imports - try: - from ..paged_ssd_cache import PagedSSDCacheManager - from ..memory_monitor import MemoryMonitor - except ImportError: - if paged_ssd_cache_dir: - logger.warning( - "Paged paged SSD cache requested but paged_ssd_cache/memory_monitor modules " - "not available. Install required dependencies." - ) - return None - - if not paged_cache_manager: - if paged_ssd_cache_dir: - logger.warning( - "Paged paged SSD cache requires paged cache. Ignoring paged_ssd_cache_dir." - ) - return None - - if not paged_ssd_cache_dir: - logger.debug("Paged paged SSD cache not configured (no --paged-ssd-cache-dir specified)") - return None - - try: - # Initialize paged SSD cache manager - paged_ssd_cache_manager = PagedSSDCacheManager( - cache_dir=Path(paged_ssd_cache_dir), - max_size_bytes=paged_ssd_cache_max_size, - hot_cache_max_bytes=hot_cache_max_bytes, - ) - - # Connect paged SSD cache manager to PagedCacheManager - paged_cache_manager.set_paged_ssd_cache_manager(paged_ssd_cache_manager) - - # Connect paged SSD cache manager to BlockAwarePrefixCache for paged SSD-only mode - if block_aware_cache is not None: - block_aware_cache.set_paged_ssd_cache_manager(paged_ssd_cache_manager) - - manager = cls( - paged_cache_manager=paged_cache_manager, - block_aware_cache=block_aware_cache, - paged_ssd_cache_manager=paged_ssd_cache_manager, - memory_monitor=None, # Memory monitor not used in paged SSD-only mode - block_size=block_size, - ) - - logger.info( - f"Paged paged SSD cache enabled: " - f"cache_dir={paged_ssd_cache_dir}, " - f"max_size={format_bytes(paged_ssd_cache_max_size)}, " - f"block_size={block_size} tokens" - ) - - return manager - - except Exception as e: - logger.error(f"Failed to initialize paged SSD cache: {e}") - return None - - def check_memory_pressure(self) -> bool: - """ - Check memory and evict blocks if needed. 
- - In paged SSD-only mode, memory pressure is not monitored since - KV cache data is stored on paged SSD, not GPU memory. - - Returns: - True if eviction was performed. - """ - # In paged SSD-only mode, memory_monitor is not used - # All KV cache data is on paged SSD, so no GPU memory pressure from PagedCache - return False - - def evict_blocks_permanently(self, bytes_to_free: int) -> int: - """ - Evict LRU blocks permanently (metadata cleanup). - - In paged SSD-only mode, blocks don't store data in GPU memory. - This method just removes block metadata to free up slots. - - Args: - bytes_to_free: Target bytes to free (used for estimation). - - Returns: - Number of bytes freed (estimated). - """ - if self.paged_cache_manager is None or self.memory_monitor is None: - return 0 - - # Estimate how many blocks to evict - num_blocks_to_evict = self.memory_monitor.estimate_blocks_to_free( - bytes_to_free, self.block_size - ) - - # Get evictable blocks in LRU order - evictable = self.paged_cache_manager.get_evictable_blocks(num_blocks_to_evict) - - if not evictable: - logger.debug("No evictable blocks found for permanent eviction") - return 0 - - freed = 0 - evicted_count = 0 - - for block in evictable: - # In paged SSD-only mode, just clear metadata (data is on paged SSD) - if self.paged_cache_manager.evict_block_permanently(block.block_id): - freed += self.memory_monitor.estimate_block_memory(self.block_size) - evicted_count += 1 - - if freed >= bytes_to_free: - break - - if evicted_count > 0: - logger.info( - f"Evicted {evicted_count} blocks permanently " - f"(~{format_bytes(freed)} estimated)" - ) - - return freed - - def evict_blocks_to_cold(self, bytes_to_free: int) -> int: - """ - Evict LRU blocks (with paged SSD cache configured). - - In paged SSD-only mode, data is already on paged SSD, so this just evicts - block metadata from the index. The data remains on paged SSD and can - be re-discovered if the same token sequence is requested. 
- - Args: - bytes_to_free: Target bytes to free (used for estimation). - - Returns: - Number of bytes freed (estimated). - """ - if self.paged_cache_manager is None or self.paged_ssd_cache_manager is None: - return 0 - - if self.memory_monitor is None: - return 0 - - # Estimate how many blocks to evict - num_blocks_to_evict = self.memory_monitor.estimate_blocks_to_free( - bytes_to_free, self.block_size - ) - - # Get evictable blocks in LRU order - evictable = self.paged_cache_manager.get_evictable_blocks(num_blocks_to_evict) - - if not evictable: - logger.debug("No evictable blocks found") - return 0 - - evicted_count = 0 - - for block in evictable: - # In paged SSD-only mode, data is already on paged SSD - # Just evict the block metadata - if self.paged_cache_manager.evict_block_permanently(block.block_id): - evicted_count += 1 - - # Estimate bytes freed based on block count - estimated_freed = evicted_count * self.memory_monitor.estimate_block_memory( - self.block_size - ) - - if evicted_count > 0: - logger.info( - f"Evicted {evicted_count} blocks from index " - f"(data preserved on paged SSD, ~{format_bytes(estimated_freed)} metadata freed)" - ) - - return estimated_freed - - def restore_block_from_cold(self, block_id: int, block_hash: bytes) -> bool: - """ - Restore a block from cold storage (deprecated in paged SSD-only mode). - - In paged SSD-only mode, blocks don't store cache_data. Data is loaded - directly from SSD when needed via reconstruct_cache(). - - Kept for API compatibility. - - Args: - block_id: Block ID to restore. - block_hash: Block's content hash. - - Returns: - True if block exists in cold storage. 
- """ - if self.paged_ssd_cache_manager is None or self.paged_cache_manager is None: - return False - - # In paged SSD-only mode, just verify block exists on paged SSD - if not self.paged_ssd_cache_manager.has_block(block_hash): - logger.warning(f"Block {block_id} not found in cold storage") - return False - - # Touch the block to update LRU - blocks = self.paged_cache_manager.blocks - if block_id < len(blocks): - block = blocks[block_id] - if block: - block.touch() - - logger.debug( - f"Block {block_id} verified on paged SSD (hash={block_hash.hex()[:16]}...)" - ) - return True - - def restore_cold_blocks_for_request(self, request_id: str) -> int: - """ - Verify all blocks needed for a request exist on paged SSD. - - In paged SSD-only mode, blocks don't store cache_data. This method - just verifies that blocks exist on paged SSD. - - Args: - request_id: Request ID. - - Returns: - Number of blocks verified on paged SSD. - """ - if self.paged_cache_manager is None or self.paged_ssd_cache_manager is None: - return 0 - - if self.block_aware_cache is None: - return 0 - - # Get block table for request - block_table = self.paged_cache_manager.request_tables.get(request_id) - if block_table is None: - return 0 - - verified = 0 - for block_id in block_table.block_ids: - blocks = self.paged_cache_manager.blocks - if block_id < len(blocks): - block = blocks[block_id] - if block and block.block_hash is not None: - if self.restore_block_from_cold(block_id, block.block_hash): - verified += 1 - - return verified - - def get_stats(self) -> Optional[Dict[str, Any]]: - """ - Get tiered cache statistics. - - Returns: - Dictionary with cache statistics. 
- """ - stats = {} - - if self.paged_ssd_cache_manager is not None: - stats["ssd_cache"] = self.paged_ssd_cache_manager.get_stats() - - if self.paged_cache_manager is not None: - # In paged SSD-only mode, all cache data is on paged SSD - stats["indexed_blocks"] = self.paged_cache_manager.cold_block_count - stats["block_size"] = self.block_size - - return stats if stats else None From fc26ab38c60924631d7014d73894e9ebb755f7ea Mon Sep 17 00:00:00 2001 From: ivaniguarans Date: Wed, 27 May 2026 01:05:39 -0400 Subject: [PATCH 04/10] fix(engine): per-engine threads to eliminate cross-engine stream contamination (#1304) Replace the shared _global_mlx_executor with per-EngineCore ThreadPoolExecutor + mx.Stream, and fix the MTP patch reading the module-level generation_stream instead of the per-engine stream. (cherry picked from commit 56860b325a5a4b48c1b38e8d88e1079713db3ffb) --- omlx/engine_core.py | 42 +++- omlx/patches/mlx_lm_mtp/batch_generator.py | 62 ++---- omlx/scheduler.py | 173 ++++++++-------- tests/test_engine_core.py | 27 ++- tests/test_per_engine_threads.py | 217 +++++++++++++++++++++ tests/test_scheduler.py | 68 +++++++ 6 files changed, 447 insertions(+), 142 deletions(-) create mode 100644 tests/test_per_engine_threads.py diff --git a/omlx/engine_core.py b/omlx/engine_core.py index 7577ee94a..20af9c271 100644 --- a/omlx/engine_core.py +++ b/omlx/engine_core.py @@ -78,6 +78,22 @@ def get_mlx_executor() -> concurrent.futures.ThreadPoolExecutor: return _global_mlx_executor +_wired_limit_set = False + + +def _ensure_wired_limit() -> None: + """Set Metal wired memory limit once at first engine creation. + + BatchGenerator normally calls mx.set_wired_limit() per-instance, which + races when multiple engines init concurrently (process-global setting). + We call it once here instead. 
+ """ + global _wired_limit_set + if not _wired_limit_set and mx.metal.is_available(): + mx.set_wired_limit(mx.device_info()["max_recommended_working_set_size"]) + _wired_limit_set = True + + @dataclass class EngineConfig: """Configuration for the engine.""" @@ -133,12 +149,23 @@ def __init__( ) self._owns_model = True - # Create scheduler + # Per-engine executor with dedicated mx.Stream (#1248). + # Each EngineCore gets its own thread + GPU stream so different + # models can run scheduler.step() concurrently. + _ensure_wired_limit() + self._mlx_stream = mx.new_thread_local_stream(mx.default_device()) + self._mlx_executor = concurrent.futures.ThreadPoolExecutor( + max_workers=1, + thread_name_prefix=f"mlx-engine-{self._engine_id[:8]}", + ) + + # Create scheduler with per-engine stream scheduler_config = self.config.scheduler_config or SchedulerConfig() self.scheduler = Scheduler( model=model, tokenizer=tokenizer, config=scheduler_config, + stream=self._mlx_stream, ) # Output collectors for low-latency streaming (vLLM pattern) @@ -152,11 +179,6 @@ def __init__( self._start_time: Optional[float] = None self._steps_executed = 0 - # Global single-thread executor shared across ALL engines. - # mlx-lm uses a module-level Metal stream, so concurrent MLX calls - # from different engine threads cause segfaults. See issue #85. - self._mlx_executor = get_mlx_executor() - logger.debug(f"Engine {self._engine_id} initialized") async def start(self) -> None: @@ -681,9 +703,9 @@ def close(self) -> None: self._closed = True - # Both shutdown() and deep_reset() touch generation_stream (directly + # Both shutdown() and deep_reset() touch the engine stream (directly # or via _drain_pending_async_removes / _do_abort_request). The - # stream is bound to the MLX executor thread, so dispatch both + # stream is bound to the engine's executor thread, so dispatch both # through the executor; fall back to a direct call if the executor # is already shut down. 
for fn in (self.scheduler.shutdown, self.scheduler.deep_reset): @@ -695,6 +717,10 @@ def close(self) -> None: except RuntimeError: pass + if self._mlx_executor is not None: + self._mlx_executor.shutdown(wait=True) + self._mlx_executor = None + # Clear output collectors for collector in self._output_collectors.values(): collector.clear() diff --git a/omlx/patches/mlx_lm_mtp/batch_generator.py b/omlx/patches/mlx_lm_mtp/batch_generator.py index 27d01192f..408139a4e 100644 --- a/omlx/patches/mlx_lm_mtp/batch_generator.py +++ b/omlx/patches/mlx_lm_mtp/batch_generator.py @@ -299,24 +299,6 @@ class _MtpState: # Helpers # --------------------------------------------------------------------------- -def _get_generation_stream(): - """Return the ``mlx_lm.generate`` module-level generation stream. - - The standard ``GenerationBatch._step`` runs all forward passes inside - ``mx.stream(generation_stream)``; the MTP cycle does the same so the - paged cache writes land on the same stream and ordering is preserved. - The stream lives on the *outer* ``BatchGenerator``, not on - ``GenerationBatch``, so we read it from the module. - - Note: ``mlx_lm.__init__`` re-exports a ``generate`` *function*, so - ``import mlx_lm.generate as mlg`` resolves to the function, not the - module. We use ``sys.modules`` to grab the actual module. 
- """ - import sys - - return sys.modules["mlx_lm.generate"].generation_stream - - def _resolve_sampler(gen_batch: Any): """Match ``GenerationBatch._step``'s per-sequence sampler resolution (batch=1).""" if gen_batch.samplers and gen_batch.samplers[0] is not None: @@ -498,9 +480,9 @@ def _reconcile_mtp_to_standard(gen_batch: Any, state: _MtpState) -> bool: procs = _proc_list(gen_batch) _set_singleton_mrope_delta(gen_batch) tok_arr = _ensure_uint32(mx.array(list(tokens))) - with mx.stream(_get_generation_stream()): - logits, _, _ = _call_backbone(gen_batch.model, tok_arr[None, :], new_cache) - last_logits = logits[:, -1, :] # (1, vocab) — dist after tokens[-1] + # Inherits the per-engine stream from the enclosing BatchGenerator context. + logits, _, _ = _call_backbone(gen_batch.model, tok_arr[None, :], new_cache) + last_logits = logits[:, -1, :] # (1, vocab) — dist after tokens[-1] if state.queue: next_id, next_lp_1d, _src = state.queue[0] @@ -751,10 +733,10 @@ def _post_init_mtp(gen_batch: Any) -> None: # 1-token backbone forward at main_tok with hidden state. No draft yet, # so no rollback is possible — discard gdn_states. - with mx.stream(_get_generation_stream()): - logits, hidden, _ = _call_backbone( - gen_batch.model, main_tok[:, None], gen_batch.prompt_cache - ) + # Inherits the per-engine stream from the enclosing BatchGenerator context. 
+ logits, hidden, _ = _call_backbone( + gen_batch.model, main_tok[:, None], gen_batch.prompt_cache + ) next_main_logits = logits[:, -1, :] # (1, vocab) — distribution after main_tok next_main_logits = _apply_processors(procs, prev_buf, next_main_logits) @@ -766,8 +748,7 @@ def _post_init_mtp(gen_batch: Any) -> None: mtp_cache = gen_batch.model.make_mtp_cache() hidden_at_main = hidden[:, -1:, :] # (1, 1, H) next_ids = next_main_tok.reshape(1, 1) - with mx.stream(_get_generation_stream()): - mtp_logits = gen_batch.model.mtp_forward(hidden_at_main, next_ids, mtp_cache) + mtp_logits = gen_batch.model.mtp_forward(hidden_at_main, next_ids, mtp_cache) mtp_logits_2d = mtp_logits[:, -1, :] if procs is not None: prev_with_main_and_next = mx.concatenate( @@ -911,15 +892,15 @@ def _run_verify_cycle(gen_batch: Any, state: _MtpState) -> None: # Tradeoff: backbone_ms / sample_ms split is no longer wall-clock # accurate (everything lands in sample_ms), but cumulative timing is. t0 = time.perf_counter() - with mx.stream(_get_generation_stream()): - logits, hidden, gdn_states = _call_backbone( - gen_batch.model, - inputs[None, :], - gen_batch.prompt_cache, - n_confirmed=1, - ) - verify_logits = logits[:, 0, :] - bonus_logits = logits[:, 1, :] + logits, hidden, gdn_states = _call_backbone( + gen_batch.model, + inputs[None, :], + gen_batch.prompt_cache, + n_confirmed=1, + ) + verify_logits = logits[:, 0, :] + bonus_logits = logits[:, 1, :] + mx.eval(logits) state.stats.backbone_ms += (time.perf_counter() - t0) * 1000 t0 = time.perf_counter() @@ -1067,11 +1048,10 @@ def _step_mtp( t0 = time.perf_counter() next_ids = next_main_tok.reshape(1, 1) - with mx.stream(_get_generation_stream()): - mtp_logits = gen_batch.model.mtp_forward( - hidden_at_position, next_ids, state.mtp_cache - ) - mtp_logits_2d = mtp_logits[:, -1, :] + mtp_logits = gen_batch.model.mtp_forward( + hidden_at_position, next_ids, state.mtp_cache + ) + mtp_logits_2d = mtp_logits[:, -1, :] if procs is not None and prev_buf 
is not None: prev_with_next = mx.concatenate( [prev_buf, _ensure_uint32(next_main_tok)] diff --git a/omlx/scheduler.py b/omlx/scheduler.py index 9fdd77679..202f494e3 100644 --- a/omlx/scheduler.py +++ b/omlx/scheduler.py @@ -48,6 +48,10 @@ from .utils.proc_memory import get_phys_footprint from .utils.sampling import make_sampler as omlx_make_sampler +# Module-level alias so Scheduler.__init__ can fall back to mlx-lm's default +# stream when no per-engine stream is provided. +_default_generation_stream = generation_stream + @dataclass class _VLMMTPDecodeState: @@ -98,7 +102,7 @@ class _VLMMTPResponse: _mx_buffer_access_lock = threading.RLock() -def _sync_and_clear_cache(): +def _sync_and_clear_cache(stream=None): """Synchronize in-flight GPU work before clearing the Metal buffer cache. Without synchronization, mx.clear_cache() can release Metal buffers that @@ -114,28 +118,30 @@ def _sync_and_clear_cache(): See: https://github.com/jundot/omlx/issues/300, #888, #1106 """ with _mx_buffer_access_lock: - # Generation_stream may not have in-flight work on the current thread + # The engine stream may not have in-flight work on the current thread # (e.g. external prefill submits to the default stream). On some MLX # builds mx.synchronize raises "There is no Stream(gpu, 0) in current # thread" in that case; swallow it since there is nothing to drain. + target = stream if stream is not None else _default_generation_stream try: - mx.synchronize(generation_stream) + mx.synchronize(target) except RuntimeError: pass mx.synchronize() # default stream mx.clear_cache() -def _safe_sync_generation_stream(): - """mx.synchronize(generation_stream) that tolerates cross-thread calls. +def _safe_sync_stream(stream=None): + """mx.synchronize(stream) that tolerates cross-thread calls. - Generation_stream is owned by the _mlx_executor thread. Teardown paths - that run on the main thread (via EngineCore.close) hit "no Stream in + The per-engine stream is owned by the engine's executor thread. 
Teardown + paths that run on the main thread (via EngineCore.close) hit "no Stream in current thread" RuntimeError. Swallow that specific case so cleanup can proceed; re-raise anything else so real GPU errors stay visible. """ + target = stream if stream is not None else _default_generation_stream try: - mx.synchronize(generation_stream) + mx.synchronize(target) except RuntimeError as e: if "no Stream" not in str(e): raise @@ -717,6 +723,7 @@ def __init__( model: Any, tokenizer: Any, config: SchedulerConfig | None = None, + stream: Any | None = None, ): """ Initialize the scheduler. @@ -725,6 +732,8 @@ def __init__( model: The MLX model tokenizer: The tokenizer config: Scheduler configuration + stream: Optional mx.Stream for this engine. Falls back to the + module-level _default_generation_stream when not provided. """ self.model = model # Deep-copy the tokenizer so the scheduler owns an independent Rust @@ -735,6 +744,7 @@ def __init__( # Rust RefCell. See: https://github.com/huggingface/tokenizers/issues/537 self.tokenizer = copy.deepcopy(tokenizer) self.config = copy.copy(config) if config else SchedulerConfig() + self._stream = stream if stream is not None else _default_generation_stream # Load additional EOS tokens from generation_config.json. # Some models (e.g. GLM-4.6V) define multiple EOS tokens there @@ -1117,13 +1127,18 @@ def _async_store_cache_worker( without blocking the inference thread. async_eval completes Metal command enqueueing before returning, so all commands are submitted by the time executor.submit() runs. - - This worker calls mx.synchronize() (global barrier — waits - all streams) to ensure materialization is complete before - extracting tensor bytes. Stream-scoped sync is not possible - here because generation_stream is thread-local to the - inference thread. + - This worker calls mx.synchronize(self._stream) via the + _safe_sync_stream helper to wait on the same stream where + mx.async_eval dispatched the arrays. 
A bare mx.synchronize() + with no args only blocks on the default stream (gpu:0) and + would leave the dispatched per-engine stream's work + unsynchronized, racing the buffer-protocol access below + (#1437). Stream objects are not thread-local in MLX (Metal + device is a global singleton), so mx.synchronize(stream) is + safe cross-thread; it just calls waitUntilCompleted on the + command buffer. - bfloat16 view+eval inside _extract_tensor_bytes runs on this - worker's default mx stream, isolated from generation_stream; + worker's default mx stream, isolated from self._stream; the underlying buffer is read-only at this point. - batch_generator.remove(uid) is deferred until this worker completes (handled by _drain_pending_async_removes). @@ -1140,7 +1155,7 @@ def _async_store_cache_worker( # buffer pool mid-read (#1106). with _mx_buffer_access_lock: with self._phase_timer("store_cache_worker_sync"): - mx.synchronize() + _safe_sync_stream(self._stream) block_table = self.block_aware_cache.store_cache( request_id, token_sequence_to_store, @@ -1187,7 +1202,7 @@ def _drain_pending_async_removes(self) -> None: ) # Run batch_generator.remove on the inference thread. try: - _safe_sync_generation_stream() + _safe_sync_stream(self._stream) self._remove_uid_from_active_batch(uid) if hasattr(self.model, "unregister_rope_delta"): self.model.unregister_rope_delta(uid) @@ -1602,6 +1617,7 @@ def _create_batch_generator( prefill_batch_size=1, completion_batch_size=self.config.completion_batch_size, prefill_step_size=self.config.prefill_step_size, + stream=self._stream, ) return bg @@ -1962,7 +1978,7 @@ def _do_external_prefill( raise _PrefillAbortedError(abort_uids, processed_tokens) # Reclaim Metal intermediates between prefill chunks. - _sync_and_clear_cache() + _sync_and_clear_cache(self._stream) # Emit final boundary snapshot if prompt lands exactly on boundary. 
if boundary_enabled: @@ -1976,7 +1992,7 @@ def _do_external_prefill( request, prompt_cache, total_tokens ) - _sync_and_clear_cache() + _sync_and_clear_cache(self._stream) # Restore _rope_deltas after cached VLM prefill (for decode capture) if vlm_embeds is not None and _saved_rope_deltas is not None: @@ -2291,7 +2307,7 @@ def _step_prefill_chunk(self, state: _PrefillState) -> bool: f"{self._memory_hard_limit_bytes / 1024**3:.1f}GB)" ) - _sync_and_clear_cache() + _sync_and_clear_cache(self._stream) return state.tokens_remaining.shape[1] == 0 def _emit_final_boundary_if_needed(self, state: _PrefillState) -> None: @@ -2429,7 +2445,7 @@ def _advance_chunked_prefills( # Prefill complete — emit final boundary snapshot and insert. self._prefill_states.pop(rid, None) self._emit_final_boundary_if_needed(state) - _sync_and_clear_cache() + _sync_and_clear_cache(self._stream) # Ensure a BatchGenerator exists (may not if all requests were # previously in chunked prefill with no running decode). @@ -2937,12 +2953,12 @@ def _extract_boundary_snapshot(self, uid: int) -> list[Any] | None: return None try: - # Synchronize pending generation_stream operations before + # Synchronize pending engine stream operations before # accessing batch cache tensors. 
with self._phase_timer("boundary_capture_sync"): - _safe_sync_generation_stream() + _safe_sync_stream(self._stream) with self._phase_timer("boundary_capture_extract"): - with mx.stream(generation_stream): + with mx.stream(self._stream): result = self.batch_generator.extract_cache([uid]) if uid not in result: return None @@ -3812,7 +3828,7 @@ def _route_to_vlm_mtp( last_arr = mx.array(last_tokens)[None] # (1, len_last) try: - with mx.stream(generation_stream): + with mx.stream(self._stream): out = lm( last_arr, cache=prefilled_cache, @@ -3952,7 +3968,7 @@ def _step_vlm_mtp(self) -> list[_VLMMTPResponse]: responses: list[_VLMMTPResponse] = [] for uid, state in list(self._vlm_mtp_active.items()): try: - with mx.stream(generation_stream): + with mx.stream(self._stream): token_val = next(state.generator) except StopIteration: # Round loop exited naturally — terminate with prompt cache @@ -4188,11 +4204,11 @@ def _score_progress(processed: int, total: int, phase: str) -> None: logger.debug(f"SpecPrefill: draft cache store failed: {e}") # Free draft cache from memory. Use _sync_and_clear_cache() so - # the generation_stream is drained before Metal buffers are + # the engine stream is drained before Metal buffers are # returned to the pool — a bare mx.clear_cache() here can race # with in-flight async evals and trigger a kernel panic (#557). del used_cache - _sync_and_clear_cache() + _sync_and_clear_cache(self._stream) # Mark scoring complete (auto-removes tracker entry). tracker.update(request.request_id, n_to_score, n_to_score, model_id) @@ -4339,7 +4355,7 @@ def _do_abort_request(self, request_id: str) -> bool: # that replaces references to arrays still used by in-flight # Metal command buffers. Without this barrier the Metal driver # can hit 'completeMemory() prepare count underflow'. 
- _safe_sync_generation_stream() + _safe_sync_stream(self._stream) self._remove_uid_from_active_batch(uid) if hasattr(self.model, "unregister_rope_delta"): self.model.unregister_rope_delta(uid) @@ -4460,7 +4476,7 @@ def fail_all_requests(self) -> list[str]: # state — mx.synchronize() or mx.clear_cache() can throw a C++ # exception that causes SIGABRT if uncaught (#435). try: - _sync_and_clear_cache() + _sync_and_clear_cache(self._stream) except Exception as e: logger.warning(f"Metal cache clear failed during error recovery: {e}") return failed_ids @@ -4796,13 +4812,13 @@ def _check_specprefill_abort(processed: int) -> None: ) sys_arr = sys_arr[step:] # Use _sync_and_clear_cache() instead of bare - # mx.clear_cache() to flush the generation_stream + # mx.clear_cache() to flush the engine stream # before releasing Metal buffers. A bare call here # can race with in-flight command buffers submitted # by the preceding mx.eval(), triggering the same # 'completeMemory() prepare count underflow' kernel # panic that #435 fixed elsewhere (#557). - _sync_and_clear_cache() + _sync_and_clear_cache(self._stream) if sys_arr.size > 0: _check_specprefill_abort(sys_processed) final_sys = int(sys_arr.size) @@ -4980,7 +4996,7 @@ def _sparse_progress(processed: int, total: int) -> None: if done: self._emit_final_boundary_if_needed(state) - _sync_and_clear_cache() + _sync_and_clear_cache(self._stream) get_prefill_tracker().remove(request.request_id) self._insert_prefilled_request(request, state, scheduled) else: @@ -5388,12 +5404,12 @@ def _process_batch_responses( def _cleanup_finished(self, finished_ids: set[str]) -> None: """Clean up finished requests and store caches for reuse.""" - # Synchronize pending generation_stream operations before cache storage. + # Synchronize pending engine stream operations before cache storage. # store_cache -> mx.save_safetensors triggers implicit mx.eval() which # can conflict with async Metal operations on the generation stream. 
if finished_ids: with self._phase_timer("cleanup_finished_sync"): - _safe_sync_generation_stream() + _safe_sync_stream(self._stream) # SpecPrefill: restore original RoPE if active request finished for rid in finished_ids: @@ -5439,50 +5455,49 @@ def _cleanup_finished(self, finished_ids: set[str]) -> None: ) intermediate_snapshots = None - # Boundary merge + async_eval on the inference - # thread. async_eval dispatches KV array - # materialization without blocking, so the - # inference thread can start the next request - # immediately. The worker calls - # mx.synchronize() to wait for completion - # before extracting bytes. - with ( - self._phase_timer("store_cache_main_prep"), - mx.stream(generation_stream), - ): - boundary_override = self._get_boundary_store_override( - request_id, - cacheable_sequence, - ) - if boundary_override is not None: - ( - token_sequence_to_store, - boundary_cache, - boundary_model_config, - intermediate_snapshots, - ) = boundary_override - cache_to_store = ( - self._merge_boundary_with_full_cache( - boundary_cache, request._extracted_cache - ) - ) - if boundary_model_config is not None: - model_cache_config = boundary_model_config - logger.info( - f"Using boundary cache snapshot for {request_id}: " - f"storing {len(token_sequence_to_store)}/" - f"{len(full_token_sequence)} tokens " - f"(skipping trailing partial block, " - f"{len(intermediate_snapshots) if intermediate_snapshots else 0} " - f"intermediate snapshots)" + # Inference-thread store_cache prep, timed as + # three sub-phases (boundary / collect / dispatch) + # mirroring boundary_capture_* granularity. + # async_eval dispatches KV array materialization + # without blocking; the worker calls + # mx.synchronize(self._stream) to wait before + # extracting bytes. 
+ with mx.stream(self._stream): + with self._phase_timer("store_cache_main_boundary"): + boundary_override = self._get_boundary_store_override( + request_id, + cacheable_sequence, ) - pre_eval_arrays = ( - self._collect_arrays_from_extracted_cache( - cache_to_store + if boundary_override is not None: + ( + token_sequence_to_store, + boundary_cache, + boundary_model_config, + intermediate_snapshots, + ) = boundary_override + cache_to_store = ( + self._merge_boundary_with_full_cache( + boundary_cache, request._extracted_cache + ) + ) + if boundary_model_config is not None: + model_cache_config = boundary_model_config + logger.info( + f"Using boundary cache snapshot for {request_id}: " + f"storing {len(token_sequence_to_store)}/" + f"{len(full_token_sequence)} tokens " + f"(skipping trailing partial block, " + f"{len(intermediate_snapshots) if intermediate_snapshots else 0} " + f"intermediate snapshots)" + ) + with self._phase_timer("store_cache_main_collect"): + pre_eval_arrays = ( + self._collect_arrays_from_extracted_cache( + cache_to_store + ) ) - ) - if pre_eval_arrays: - mx.async_eval(*pre_eval_arrays) + if pre_eval_arrays: + mx.async_eval(*pre_eval_arrays) if self._store_cache_executor is not None: # Gate acquire blocks if too many KV caches @@ -5587,7 +5602,7 @@ def _cleanup_finished(self, finished_ids: set[str]) -> None: # used by in-flight Metal command buffers from the previous # batch_generator.next() call. Without this barrier the Metal # driver can hit 'completeMemory() prepare count underflow'. 
- _safe_sync_generation_stream() + _safe_sync_stream(self._stream) self._remove_uid_from_active_batch(uid) if hasattr(self.model, "unregister_rope_delta"): self.model.unregister_rope_delta(uid) @@ -5870,7 +5885,7 @@ def step(self) -> SchedulerOutput: + len(responses) ) if self._tokens_since_clear_cache >= 1024: - _sync_and_clear_cache() + _sync_and_clear_cache(self._stream) self._tokens_since_clear_cache = 0 except _PrefillAbortedError: @@ -5941,7 +5956,7 @@ def step(self) -> SchedulerOutput: should_clear = True self._deferred_clear_at = None if should_clear: - _sync_and_clear_cache() + _sync_and_clear_cache(self._stream) if ( self.config.gc_cleanup_interval > 0 and self._step_counter % self.config.gc_cleanup_interval == 0 diff --git a/tests/test_engine_core.py b/tests/test_engine_core.py index a189cdce4..a74ca2c89 100644 --- a/tests/test_engine_core.py +++ b/tests/test_engine_core.py @@ -972,10 +972,8 @@ def test_get_mlx_executor_returns_singleton(self): executor2 = get_mlx_executor() assert executor1 is executor2 - def test_engines_share_mlx_executor(self, mock_model, mock_tokenizer): - """Multiple EngineCore instances must share a single MLX executor (#85).""" - from omlx.engine_core import get_mlx_executor - + def test_engines_have_per_engine_executors(self, mock_model, mock_tokenizer): + """Each EngineCore must have its own executor (#1248).""" with patch("omlx.engine_core.get_registry") as mock_registry: mock_registry.return_value.acquire.return_value = True @@ -983,8 +981,7 @@ def test_engines_share_mlx_executor(self, mock_model, mock_tokenizer): engine2 = EngineCore(model=mock_model, tokenizer=mock_tokenizer) try: - assert engine1._mlx_executor is engine2._mlx_executor - assert engine1._mlx_executor is get_mlx_executor() + assert engine1._mlx_executor is not engine2._mlx_executor finally: engine1.close() engine2.close() @@ -1040,13 +1037,14 @@ def simulated_step(task_id: str, duration: float = 0.05): ) @pytest.mark.asyncio - async def 
test_two_engine_loops_serialize_on_shared_executor( + async def test_two_engine_loops_run_concurrently_on_separate_executors( self, mock_model, mock_tokenizer ): - """Two engines running their loops must serialize step() calls (#85). + """Two engines with per-engine executors can run step() concurrently (#1248). - Creates two EngineCore instances with mock schedulers, starts both - engine loops, and verifies their scheduler.step() calls never overlap. + Each EngineCore has its own ThreadPoolExecutor and mx.Stream, so their + scheduler.step() calls can overlap. This test verifies that two engines + actually achieve concurrent execution. """ import threading import time @@ -1101,8 +1099,9 @@ def tracked_step(): assert total_steps >= 4, ( f"Expected at least 4 steps from two engines, got {total_steps}" ) - assert max_concurrent == 1, ( - f"Expected max 1 concurrent step(), got {max_concurrent}. " - f"Two engines ran MLX operations in parallel — would cause " - f"Metal command buffer races in production." + # With per-engine executors (#1248), two engines CAN run concurrently. + # max_concurrent >= 2 means both engines overlapped at least once. + assert max_concurrent >= 2, ( + f"Expected concurrent execution (max_concurrent >= 2), got {max_concurrent}. " + f"Per-engine executors should allow parallel step() calls." 
) diff --git a/tests/test_per_engine_threads.py b/tests/test_per_engine_threads.py new file mode 100644 index 000000000..284e426bd --- /dev/null +++ b/tests/test_per_engine_threads.py @@ -0,0 +1,217 @@ +"""Tests for per-engine thread isolation (issue #1248).""" + +from unittest.mock import MagicMock, patch + +import mlx.core as mx +import pytest + +from omlx.engine_core import EngineCore +from omlx.scheduler import Scheduler, SchedulerConfig + + +class TestSchedulerStreamParam: + """Scheduler must accept an explicit stream and use it instead of the + module-level generation_stream.""" + + def test_scheduler_stores_explicit_stream(self): + mock_model = MagicMock() + mock_model.model_type = "test" + mock_tokenizer = MagicMock() + mock_tokenizer.eos_token_id = 0 + + stream = mx.new_thread_local_stream(mx.default_device()) + scheduler = Scheduler( + model=mock_model, + tokenizer=mock_tokenizer, + stream=stream, + ) + assert scheduler._stream is stream + + def test_scheduler_defaults_to_generation_stream(self): + from omlx.scheduler import _default_generation_stream + + mock_model = MagicMock() + mock_model.model_type = "test" + mock_tokenizer = MagicMock() + mock_tokenizer.eos_token_id = 0 + + scheduler = Scheduler( + model=mock_model, + tokenizer=mock_tokenizer, + ) + assert scheduler._stream is _default_generation_stream + + +class TestSchedulerStreamIsolation: + """Scheduler must use self._stream in all GPU stream operations, + never the module-level generation_stream.""" + + def test_no_module_level_generation_stream_in_hot_path(self): + """After migration, scheduler.py should not reference the module-level + generation_stream anywhere in the Scheduler class body except the + __init__ default fallback and comments/docstrings.""" + import inspect + import re + + import omlx.scheduler as sched_mod + source = inspect.getsource(sched_mod.Scheduler) + + # Find bare generation_stream references that aren't: + # - _default_generation_stream (the import alias) + # - Part 
of a larger word + bare_refs = re.findall( + r'(? SIGABRT in get_command_encoder(gpu:2). + """ + + def test_safe_sync_passes_generation_stream(self): + """_safe_sync_stream() with no args must invoke mx.synchronize with + the module-level _default_generation_stream object, not call the + no-args variant. + + Regression: PR #1146 wired the worker to bare mx.synchronize() + under the (incorrect) assumption that it was a global barrier + and that stream-scoped sync was unsafe cross-thread. Both + assumptions are wrong: synchronize() defaults to a single + stream, and Stream objects are not thread-local. The worker + path now routes through this helper so the regression has a + single chokepoint to assert against. + """ + from omlx import scheduler as sched_mod + + calls = [] + + def fake_sync(*args, **kwargs): + calls.append(args) + + with patch.object(sched_mod.mx, "synchronize", side_effect=fake_sync): + sched_mod._safe_sync_stream() + + assert len(calls) == 1 + assert calls[0] and calls[0][0] is sched_mod.generation_stream, ( + f"Worker sync must target generation_stream, got: {calls}" + ) + + def test_safe_sync_swallows_no_stream_runtime_error(self): + """A 'no Stream' RuntimeError from cross-thread sync must be + swallowed so the worker can still proceed to extract bytes. + + On some MLX builds mx.synchronize(stream) raises 'There is no + Stream(gpu, X) in current thread' from a thread that has not + submitted work to that stream. In the store-cache worker that + condition means there is no in-flight gpu:2 work to drain, so + it is safe to continue. 
+ """ + from omlx import scheduler as sched_mod + + def fake_sync(*args, **kwargs): + raise RuntimeError("There is no Stream(gpu, 2) in current thread.") + + with patch.object(sched_mod.mx, "synchronize", side_effect=fake_sync): + sched_mod._safe_sync_stream() + + def test_safe_sync_propagates_other_runtime_errors(self): + """Real GPU errors must not be silently swallowed.""" + from omlx import scheduler as sched_mod + + def fake_sync(*args, **kwargs): + raise RuntimeError("Metal command buffer execution failed") + + with patch.object(sched_mod.mx, "synchronize", side_effect=fake_sync): + with pytest.raises(RuntimeError, match="command buffer execution failed"): + sched_mod._safe_sync_stream() + + class TestSchedulerFormatBytes: """Tests for Scheduler._format_bytes().""" From b7cb489d9a28be0e28d5de101904a61c2c59b7ba Mon Sep 17 00:00:00 2001 From: jundot Date: Wed, 27 May 2026 14:06:54 +0900 Subject: [PATCH 05/10] refactor(engine): remove redundant _ensure_wired_limit guard BatchGenerator.__init__ already calls mx.set_wired_limit() on each instance, and concurrent calls with the same value are race-free (verified empirically). The guard never prevented the race it claimed to fix. Follow-up to #1304. (cherry picked from commit a62f95392b73a17e1e7268170d0129b6a70a1774) --- omlx/engine_core.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/omlx/engine_core.py b/omlx/engine_core.py index 20af9c271..7baa51b01 100644 --- a/omlx/engine_core.py +++ b/omlx/engine_core.py @@ -78,22 +78,6 @@ def get_mlx_executor() -> concurrent.futures.ThreadPoolExecutor: return _global_mlx_executor -_wired_limit_set = False - - -def _ensure_wired_limit() -> None: - """Set Metal wired memory limit once at first engine creation. - - BatchGenerator normally calls mx.set_wired_limit() per-instance, which - races when multiple engines init concurrently (process-global setting). - We call it once here instead. 
- """ - global _wired_limit_set - if not _wired_limit_set and mx.metal.is_available(): - mx.set_wired_limit(mx.device_info()["max_recommended_working_set_size"]) - _wired_limit_set = True - - @dataclass class EngineConfig: """Configuration for the engine.""" @@ -152,7 +136,6 @@ def __init__( # Per-engine executor with dedicated mx.Stream (#1248). # Each EngineCore gets its own thread + GPU stream so different # models can run scheduler.step() concurrently. - _ensure_wired_limit() self._mlx_stream = mx.new_thread_local_stream(mx.default_device()) self._mlx_executor = concurrent.futures.ThreadPoolExecutor( max_workers=1, From c50d64e32e66943d9044e745d64d2a4e15783d53 Mon Sep 17 00:00:00 2001 From: cfbraun Date: Wed, 27 May 2026 09:19:56 +0200 Subject: [PATCH 06/10] test(mtp): drop monkeypatch of removed ``_get_generation_stream`` (#1445) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit #1304 (``fix(engine): per-engine threads to eliminate cross-engine stream contamination``) refactored the patched ``BatchGenerator`` to inherit its execution stream from the enclosing engine context, and removed the module-level ``_get_generation_stream`` helper as part of that. ``TestBatchGeneratorDispatch._make_reconcile_batch`` still tried to monkeypatch that name and failed at collection of every test that depends on the fixture with ``AttributeError: module ... has no attribute '_get_generation_stream'`` — taking down 4 reconcile-path tests on every CI run. The override is no longer needed: the surrounding fixture replaces ``_rebuild_singleton_cache`` and ``_call_backbone`` with fakes that do all of their work via ``np.array`` / ``mx.array`` directly, so neither MLX dispatch nor stream selection is reached. 
Tests (tests/test_mlx_lm_mtp_patch.py::TestBatchGeneratorDispatch): - test_reconcile_uses_queue_front_as_next_token - test_reconcile_empty_queue_samples_from_logits - test_reconcile_returns_false_on_empty_tokens - test_reconcile_fallback_on_rebuild_failure 11/11 TestBatchGeneratorDispatch tests pass. (cherry picked from commit e6d8a3f68683719671a6d38361a93f618dc2a131) --- tests/test_mlx_lm_mtp_patch.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_mlx_lm_mtp_patch.py b/tests/test_mlx_lm_mtp_patch.py index bc594e94f..13eb8a806 100644 --- a/tests/test_mlx_lm_mtp_patch.py +++ b/tests/test_mlx_lm_mtp_patch.py @@ -528,7 +528,10 @@ def fake_backbone(model, inputs, cache, n_confirmed=0): monkeypatch.setattr(batch_generator, "_rebuild_singleton_cache", fake_rebuild) monkeypatch.setattr(batch_generator, "_call_backbone", fake_backbone) - monkeypatch.setattr(batch_generator, "_get_generation_stream", lambda: mx.cpu) + # ``_get_generation_stream`` was removed in #1304 when the patch + # moved stream selection to the enclosing BatchGenerator context. + # The fake_backbone / fake_rebuild monkeypatches above bypass the + # actual MLX dispatch, so no stream override is needed. def greedy(lp_2d): return mx.argmax(lp_2d, axis=-1).astype(mx.uint32) From f554f1961c433578831dc0e550dbc4c9daa08c9b Mon Sep 17 00:00:00 2001 From: jundot Date: Wed, 27 May 2026 13:08:04 +0900 Subject: [PATCH 07/10] fix(scheduler): wait on generation_stream in store-cache worker sync (#1437) mx.synchronize() with no args only waits on the default stream (gpu:0), not the generation_stream (gpu:2) where mx.async_eval dispatched the arrays. The lazy memoryview() in _extract_tensor_bytes triggered a cross-thread eval that aborted at get_command_encoder(gpu:2). Route the worker sync through _safe_sync_generation_stream so the wait targets the correct stream. Closes #1437. 
(cherry picked from commit 2e698ff318c5f2f413bdc5723d0350873283bc2e) --- omlx/cache/paged_ssd_cache.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/omlx/cache/paged_ssd_cache.py b/omlx/cache/paged_ssd_cache.py index be52bc8e5..30cbed734 100644 --- a/omlx/cache/paged_ssd_cache.py +++ b/omlx/cache/paged_ssd_cache.py @@ -1280,14 +1280,19 @@ def _store_nstate_elements(prefix: str, elements): metadata.update(cache_list_meta) # Caller (scheduler._cleanup_finished, async store-cache path) - # already mx.eval's all real KV arrays on the inference thread - # before submitting to the omlx-store-cache executor. The tiny + # dispatches real KV arrays via mx.async_eval on the inference + # thread's generation_stream before submitting to the + # omlx-store-cache executor. The worker then waits on that same + # stream via mx.synchronize(generation_stream) (see + # _async_store_cache_worker) before reaching this code path, + # so the arrays are fully materialized by the time + # _extract_tensor_bytes hits the buffer protocol. The tiny # mx.zeros((1,)) placeholders allocated above are lazy nodes # whose buffer materialization happens implicitly via the buffer - # protocol. Skipping the explicit mx.eval here keeps save_block + # protocol. Skipping any explicit mx.eval here keeps save_block # off the Metal command-submission path when invoked from a # non-inference thread, which is the source of the cross-thread - # race tracked in #978/#1040. + # race tracked in #978/#1040/#1106/#1437. 
tensors_raw = {} for name, arr in arrays.items(): tensors_raw[name] = _extract_tensor_bytes(arr) From ebf2c21ab5f8e80504ee4704b8b79fad6b62b5ea Mon Sep 17 00:00:00 2001 From: jundot Date: Wed, 27 May 2026 16:12:03 +0900 Subject: [PATCH 08/10] fix(engine): materialize VLM model lazy state on loader thread mlx-vlm's load() only materializes model.language_model.parameters(), leaving frozen buffers (RoPE freqs) and sibling sub-trees (vision_tower, audio_tower) as lazy arrays bound to the loader thread's default stream. Pre-#1304 this was invisible because loader and forward shared one global thread; the per-engine executor split exposed it as "no Stream(gpu, X) in current thread" when mx.eval touches a sibling buffer during prefill. Fix: materialize the full model tree on the loader thread right after load. Verified against gemma-4-E2B-it and gemma-4-31b-it-4bit. Also fixes test_safe_sync_passes_generation_stream to match the _default_generation_stream alias introduced by #1304. Reported by @zviratko in #1304. (cherry picked from commit 9d5bed83ede00e4d0b696241cc00bef2957bb9d4) --- omlx/engine/batched.py | 8 +++++++- omlx/engine/vlm.py | 7 +++++++ omlx/utils/model_loading.py | 20 ++++++++++++++++++++ tests/test_scheduler.py | 4 ++-- 4 files changed, 36 insertions(+), 3 deletions(-) diff --git a/omlx/engine/batched.py b/omlx/engine/batched.py index 42b229594..1379af1ba 100644 --- a/omlx/engine/batched.py +++ b/omlx/engine/batched.py @@ -248,10 +248,16 @@ def _load_model_sync(): ) # Apply post-load transforms (e.g., IndexCache for DSA models) - from ..utils.model_loading import apply_post_load_transforms + from ..utils.model_loading import apply_post_load_transforms, materialize_lazy_state self._model = apply_post_load_transforms(self._model, self._model_settings) + # Materialize lazy buffers on the loader thread so per-engine + # inference threads can read them (#1304). 
+ await loop.run_in_executor( + get_mlx_executor(), materialize_lazy_state, self._model + ) + # TurboQuant KV cache: patch attention and set kv_bits on scheduler if self._model_settings is not None: tq_enabled = getattr(self._model_settings, "turboquant_kv_enabled", False) diff --git a/omlx/engine/vlm.py b/omlx/engine/vlm.py index 8dd53cd53..95eee547f 100644 --- a/omlx/engine/vlm.py +++ b/omlx/engine/vlm.py @@ -704,6 +704,13 @@ def _load_vlm_sync(): get_mlx_executor(), _load_vlm_sync ) + # Materialize lazy buffers (RoPE freqs, vision/audio towers) on the + # loader thread so per-engine inference threads can read them (#1304). + from ..utils.model_loading import materialize_lazy_state + await loop.run_in_executor( + get_mlx_executor(), materialize_lazy_state, self._vlm_model + ) + _fix_processor_none_pixels(self._processor) # Initialize vision feature cache diff --git a/omlx/utils/model_loading.py b/omlx/utils/model_loading.py index be3fe6d34..2ad4c3be7 100644 --- a/omlx/utils/model_loading.py +++ b/omlx/utils/model_loading.py @@ -8,6 +8,9 @@ from pathlib import Path from typing import Any +import mlx.core as mx +from mlx.utils import tree_flatten + logger = logging.getLogger(__name__) _VLM_TEXT_PREFIX = "language_model." @@ -324,6 +327,23 @@ def load_text_model( return load(model_name, tokenizer_config=tokenizer_config) +def materialize_lazy_state(model: Any) -> None: + """Force-evaluate every mx.array in the model tree on the loader thread. + + mlx-vlm's load() runs `mx.eval(model.language_model.parameters())`, which + leaves frozen buffers (RoPE freqs and similar) plus sibling sub-trees + (vision_tower, audio_tower) as lazy arrays bound to the loader thread's + default stream. When a different thread (e.g. an EngineCore per-engine + executor introduced in #1304) later runs forward, mx.eval hits "no + Stream(gpu, X) in current thread" because those lazy ops target a stream + that only exists on the loader thread. 
Materializing the whole tree here + makes every leaf array safe to read from any thread afterwards. + """ + arrays = [v for _, v in tree_flatten(model) if isinstance(v, mx.array)] + if arrays: + mx.eval(arrays) + + def apply_post_load_transforms(model: Any, model_settings: Any = None) -> Any: """Apply optional post-load model transforms based on settings. diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index f80a5615d..2dea7a9cc 100644 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -858,8 +858,8 @@ def fake_sync(*args, **kwargs): sched_mod._safe_sync_stream() assert len(calls) == 1 - assert calls[0] and calls[0][0] is sched_mod.generation_stream, ( - f"Worker sync must target generation_stream, got: {calls}" + assert calls[0] and calls[0][0] is sched_mod._default_generation_stream, ( + f"Worker sync must target _default_generation_stream, got: {calls}" ) def test_safe_sync_swallows_no_stream_runtime_error(self): From 414b84341183de917647a5837c8b3526bd25cfa5 Mon Sep 17 00:00:00 2001 From: jundot Date: Wed, 27 May 2026 15:24:13 +0900 Subject: [PATCH 09/10] fix(load): skip VLM MTPModule attach when checkpoint lacks mtp.* weights Some Qwen3.6 MoE VLM exports declare mtp_num_hidden_layers > 0 in config.json but ship no mtp.* weights in the safetensors (unsloth Qwen3.6 UD MLX builds across 3bit/4bit). PR #1404 unconditionally attaches MTPModule whenever the config declares MTP heads so persisted weights have a binding site; for these weight-stripped checkpoints there are no weights to bind, mlx-vlm strict load_weights fails with "Missing N parameters: language_model.mtp.*", the engine silently falls back to LLM and vision is dropped. The user just sees the model answer as if no image was attached, hence the hallucination report. Add _checkpoint_has_mtp_weights to scan model.safetensors.index.json (or the first safetensors shard header) for any mtp.* / language_model.mtp.* / model.mtp.* key. 
Flip set_mtp_attach_enabled(False) before mlx_vlm.utils.load() runs when the scan misses, so the runtime patch's __init__ wrap skips attachment. PR #1404's intended behavior is preserved when the checkpoint does ship mtp.* weights. Smoke test: unsloth_Qwen3.6-35B-A3B-UD-MLX-3bit + red square PNG. Before: "the user hasn't provided an image" (vision dropped). After: "Red square" identified correctly (prompt_tokens 25 -> 99 with vision tokens). Fixes #1426. (cherry picked from commit ff7522b6cb5968b38247623aa5eac3568682524a) --- omlx/patches/mlx_vlm_mtp/__init__.py | 35 ++++ .../mlx_vlm_mtp/qwen35_moe_vlm_runtime.py | 13 +- .../patches/mlx_vlm_mtp/qwen35_vlm_runtime.py | 17 +- omlx/utils/model_loading.py | 92 ++++++++++- tests/test_model_loading.py | 156 +++++++++++++++++- 5 files changed, 287 insertions(+), 26 deletions(-) diff --git a/omlx/patches/mlx_vlm_mtp/__init__.py b/omlx/patches/mlx_vlm_mtp/__init__.py index 700ada708..243389928 100644 --- a/omlx/patches/mlx_vlm_mtp/__init__.py +++ b/omlx/patches/mlx_vlm_mtp/__init__.py @@ -22,6 +22,41 @@ logger = logging.getLogger(__name__) + +# Process-wide gate read by ``LanguageModel.__init__`` (in both +# ``qwen35_moe_vlm_runtime`` and ``qwen35_vlm_runtime``) to decide whether +# to attach ``MTPModule`` when the config declares ``mtp_num_hidden_layers +# > 0``. Default True preserves PR #1404 behavior: VLM checkpoints that +# actually ship persisted ``mtp.*`` weights need the head bound even with +# ``mtp_enabled=False`` so strict ``load_weights`` succeeds. +# +# Caller (``utils/model_loading.py::maybe_apply_pre_load_patches``) flips +# this to False right before ``mlx_vlm.utils.load()`` runs when the +# checkpoint declares MTP heads in config but ships no ``mtp.*`` weights +# (e.g. unsloth Qwen3.6 UD MLX builds, issue #1426). Without the gate, +# strict load fails with "Missing N parameters" and the engine silently +# falls back to LLM, dropping vision. 
+# +# Single-thread MLX executor serializes loads, so this is race-free. +_MTP_ATTACH_ENABLED = True + + +def set_mtp_attach_enabled(enabled: bool) -> None: + """Toggle whether subsequent ``mlx_vlm.utils.load()`` calls attach the + VLM MTPModule. + + Independent of ``mlx_lm_mtp.set_mtp_active``: attach controls module + presence so strict load can bind persisted ``mtp.*`` weights, active + controls whether BatchGenerator actually invokes the MTP head during + decode. + """ + global _MTP_ATTACH_ENABLED + _MTP_ATTACH_ENABLED = bool(enabled) + + +def is_mtp_attach_enabled() -> bool: + return _MTP_ATTACH_ENABLED + def apply_mlx_vlm_mtp_patch() -> bool: """Apply the mlx-vlm MTP sanitize monkey-patches. diff --git a/omlx/patches/mlx_vlm_mtp/qwen35_moe_vlm_runtime.py b/omlx/patches/mlx_vlm_mtp/qwen35_moe_vlm_runtime.py index 42ea7e9f6..0b5d945cd 100644 --- a/omlx/patches/mlx_vlm_mtp/qwen35_moe_vlm_runtime.py +++ b/omlx/patches/mlx_vlm_mtp/qwen35_moe_vlm_runtime.py @@ -204,14 +204,21 @@ def _patch_vlm_language_model(q35moe_lang: Any) -> None: original_call = cls.__call__ def __init__(self, args, config=None): + from . import is_mtp_attach_enabled + original_init(self, args, config) - # Always attach MTPModule when the config declares MTP heads, so - # mlx-vlm's load_weights (which skips Model.sanitize for is_mlx_format + # Attach MTPModule when the config declares MTP heads, so mlx-vlm's + # load_weights (which skips Model.sanitize for is_mlx_format # checkpoints) can place the persisted mtp.* tensors. MTP speculative # decode invocation is gated downstream by # ``mlx_lm_mtp.batch_generator._is_mtp_eligible`` via ``is_mtp_active``. + # + # Gated by ``is_mtp_attach_enabled()`` so checkpoints that declare + # mtp_num_hidden_layers > 0 but ship no mtp.* weights (unsloth + # Qwen3.6 UD MLX builds, issue #1426) don't trip strict load_weights + # with "Missing N parameters" and silently fall back to LLM. 
n_mtp = int(getattr(args, "mtp_num_hidden_layers", 0) or 0) - if n_mtp > 0: + if n_mtp > 0 and is_mtp_attach_enabled(): self.mtp = q35moe_lang.MTPModule(args) def __call__(self, inputs, inputs_embeds=None, mask=None, cache=None, **kwargs): diff --git a/omlx/patches/mlx_vlm_mtp/qwen35_vlm_runtime.py b/omlx/patches/mlx_vlm_mtp/qwen35_vlm_runtime.py index 202ac3b04..97789db3f 100644 --- a/omlx/patches/mlx_vlm_mtp/qwen35_vlm_runtime.py +++ b/omlx/patches/mlx_vlm_mtp/qwen35_vlm_runtime.py @@ -191,19 +191,22 @@ def _patch_vlm_language_model(q35_lang: Any) -> None: original_call = cls.__call__ def __init__(self, args, config=None): + from . import is_mtp_attach_enabled + original_init(self, args, config) - # Always attach MTPModule when the config declares MTP heads, so - # mlx-vlm's load_weights (which skips Model.sanitize for is_mlx_format + # Attach MTPModule when the config declares MTP heads so mlx-vlm's + # load_weights (which skips Model.sanitize for is_mlx_format # checkpoints) can place the persisted mtp.* tensors. Whether MTP # speculative decode is actually invoked at inference time is gated # downstream by ``mlx_lm_mtp.batch_generator._is_mtp_eligible``, # which checks the process-wide ``is_mtp_active`` flag. - # Without this unconditional attach, mtp_enabled=False would fail - # VLM load with "Received N parameters not in model" and the engine - # pool would permanently downgrade the entry to BatchedEngine — - # losing vision support. + # + # Gated by ``is_mtp_attach_enabled()`` so checkpoints that declare + # mtp_num_hidden_layers > 0 but ship no mtp.* weights (unsloth + # Qwen3.6 UD MLX builds, issue #1426) don't fail strict load_weights + # with "Missing N parameters" and silently downgrade to LLM. 
n_mtp = int(getattr(args, "mtp_num_hidden_layers", 0) or 0) - if n_mtp > 0: + if n_mtp > 0 and is_mtp_attach_enabled(): self.mtp = q35_lang.MTPModule(args) def __call__(self, inputs, inputs_embeds=None, mask=None, cache=None, **kwargs): diff --git a/omlx/utils/model_loading.py b/omlx/utils/model_loading.py index 2ad4c3be7..67e5a9854 100644 --- a/omlx/utils/model_loading.py +++ b/omlx/utils/model_loading.py @@ -189,16 +189,29 @@ def maybe_apply_pre_load_patches( from ..patches.mlx_vlm_mtp import ( apply_mlx_vlm_mtp_patch, apply_mlx_vlm_mtp_runtime_patch, + set_mtp_attach_enabled, ) except Exception: pass else: - # Sanitize-preservation patch MUST run too: the stock - # mlx-vlm Model.sanitize strips every ``mtp.*`` key, so - # without this the MTPModule loads at random init (0% - # accept). Previously only wired into the oQ path; needed - # on the inference load path as well for VLM checkpoints - # that ship MTP heads (e.g. PARO + injected guru87 head). + # Decide attach-vs-skip BEFORE applying the runtime patch + # because the patch wraps ``LanguageModel.__init__`` which + # reads the flag at instantiation. Some Qwen3.6 MoE VLM + # exports (unsloth UD MLX builds, issue #1426) declare + # ``mtp_num_hidden_layers > 0`` in config.json but ship no + # ``mtp.*`` weights; attaching MTPModule there causes + # strict load_weights to fail with "Missing N parameters" + # and silently downgrade the engine to LLM, dropping + # vision. Scan the index for actual mtp.* keys and skip + # attachment when they're absent. + has_mtp_weights = _checkpoint_has_mtp_weights(model_name) + set_mtp_attach_enabled(has_mtp_weights) + + # Sanitize-preservation patch runs unconditionally: the + # stock mlx-vlm Model.sanitize strips every ``mtp.*`` key, + # so without this an MTP head with persisted weights would + # load at random init (0% accept). When mtp.* weights are + # absent the patch is a no-op on the affected paths. 
if apply_mlx_vlm_mtp_patch(): if mtp_enabled: logger.info( @@ -213,7 +226,15 @@ def maybe_apply_pre_load_patches( model_name, ) if apply_mlx_vlm_mtp_runtime_patch(): - if mtp_enabled: + if not has_mtp_weights: + logger.info( + "mlx-vlm runtime MTP patch applied for %s " + "(config declares mtp heads but checkpoint " + "ships no mtp.* weights; MTPModule attachment " + "skipped to keep strict load_weights happy)", + model_name, + ) + elif mtp_enabled: logger.info( "mlx-vlm runtime MTP patch applied for %s", model_name, @@ -297,6 +318,63 @@ def _has_mtp_heads(config: dict) -> bool: return False +_MTP_WEIGHT_PREFIXES = ( + "mtp.", + "language_model.mtp.", + "model.mtp.", + "model.language_model.mtp.", +) + + +def _checkpoint_has_mtp_weights(model_path: str | Path) -> bool: + """True iff the checkpoint at *model_path* ships any ``mtp.*`` weight tensor. + + Some Qwen3.6 MoE VLM exports declare ``mtp_num_hidden_layers > 0`` in + ``config.json`` but strip the MTP weights during conversion (e.g. + ``unsloth/Qwen3.6-35B-A3B-UD-MLX-*bit``). Attaching ``MTPModule`` for + such a checkpoint causes mlx-vlm's strict ``load_weights`` to fail with + "Missing N parameters: language_model.mtp.*", the engine falls back to + LLM, and vision is silently dropped (issue #1426). + + Reads ``model.safetensors.index.json`` when present (no shard I/O). + Falls back to the first safetensors shard's metadata header. Returns + False when neither resolves — callers treat that as "no MTP weights" + (the conservative choice: skip MTPModule attachment). 
+ """ + p = Path(model_path) + if not p.is_dir(): + return False + + index_path = p / "model.safetensors.index.json" + if index_path.exists(): + try: + data = json.loads(index_path.read_text()) + weight_map = data.get("weight_map") or {} + return any( + k.startswith(_MTP_WEIGHT_PREFIXES) for k in weight_map + ) + except Exception as e: + logger.debug( + "Failed to read %s for mtp weight scan: %s", index_path, e + ) + + shards = sorted(p.glob("*.safetensors")) + if not shards: + return False + try: + import safetensors + + with safetensors.safe_open(str(shards[0]), framework="numpy") as f: + for k in f.keys(): + if k.startswith(_MTP_WEIGHT_PREFIXES): + return True + except Exception as e: + logger.debug( + "Failed to read %s header for mtp weight scan: %s", shards[0], e + ) + return False + + def _is_mtp_compatible(config: dict, model_type: str | None) -> bool: """Decide whether the native MTP patch can be applied to this model. diff --git a/tests/test_model_loading.py b/tests/test_model_loading.py index 2f2d61284..972110ca3 100644 --- a/tests/test_model_loading.py +++ b/tests/test_model_loading.py @@ -19,6 +19,19 @@ def _write_config(tmp_path, body: str) -> str: return str(tmp_path) +def _write_mtp_index(tmp_path, has_mtp: bool) -> None: + """Drop a stub ``model.safetensors.index.json`` next to config.json so + ``_checkpoint_has_mtp_weights`` resolves deterministically in tests.""" + keys = {"language_model.model.embed_tokens.weight": "model.safetensors"} + if has_mtp: + keys["language_model.mtp.fc.weight"] = "model.safetensors" + (tmp_path / "model.safetensors.index.json").write_text( + '{"metadata": {}, "weight_map": ' + + str(keys).replace("'", '"') + + "}" + ) + + class TestNoDispatch: """Cases where the dispatcher should return None and let the caller fall back to the standard mlx-lm/mlx-vlm load path.""" @@ -164,10 +177,14 @@ class TestVlmMtpPreLoadDispatch: def _stub_patches(self, monkeypatch): """Replace the patch modules with mocks that record call order. 
- Returns the recorded-order list plus the sanitize/runtime mocks.""" + Returns the recorded-order list plus the sanitize/runtime/attach + mocks.""" calls: list[str] = [] sanitize_mock = MagicMock(side_effect=lambda: calls.append("sanitize") or True) runtime_mock = MagicMock(side_effect=lambda: calls.append("runtime") or True) + attach_mock = MagicMock( + side_effect=lambda enabled: calls.append(f"attach={enabled}") + ) # Side-step the real mlx-lm load_config monkey-patch. monkeypatch.setattr(model_loading, "_patch_mlx_lm_load_config", lambda: None) monkeypatch.setitem( @@ -184,29 +201,34 @@ def _stub_patches(self, monkeypatch): MagicMock( apply_mlx_vlm_mtp_patch=sanitize_mock, apply_mlx_vlm_mtp_runtime_patch=runtime_mock, + set_mtp_attach_enabled=attach_mock, ), ) - return calls, sanitize_mock, runtime_mock + return calls, sanitize_mock, runtime_mock, attach_mock def test_sanitize_patch_runs_before_runtime_for_vlm_mtp( self, tmp_path, monkeypatch ): - calls, sanitize_mock, runtime_mock = self._stub_patches(monkeypatch) + calls, sanitize_mock, runtime_mock, attach_mock = self._stub_patches( + monkeypatch + ) # qwen3_5 (dense VLM) declaring an MTP head under text_config. path = _write_config( tmp_path, '{"model_type": "qwen3_5", "vision_config": {}, ' '"text_config": {"mtp_num_hidden_layers": 1}}', ) + _write_mtp_index(tmp_path, has_mtp=True) settings = types.SimpleNamespace(mtp_enabled=True) maybe_apply_pre_load_patches(path, model_settings=settings, for_vlm=True) sanitize_mock.assert_called_once() runtime_mock.assert_called_once() + attach_mock.assert_called_once_with(True) # Ordering matters: the dense runtime patch assumes sanitize was # already installed by apply_mlx_vlm_mtp_patch. 
- assert calls == ["sanitize", "runtime"] + assert calls == ["attach=True", "sanitize", "runtime"] def test_vlm_patches_applied_when_mtp_disabled_for_vlm(self, tmp_path, monkeypatch): # Issue #1404: persisted ``mtp.*`` weights must still get a binding @@ -214,26 +236,63 @@ def test_vlm_patches_applied_when_mtp_disabled_for_vlm(self, tmp_path, monkeypat # even with mtp_enabled=False. Otherwise mlx-vlm's strict load_weights # fails with "parameters not in model" and the engine falls back to # LLM, silently dropping vision. - calls, sanitize_mock, runtime_mock = self._stub_patches(monkeypatch) + calls, sanitize_mock, runtime_mock, attach_mock = self._stub_patches( + monkeypatch + ) path = _write_config( tmp_path, '{"model_type": "qwen3_5", "vision_config": {}, ' '"text_config": {"mtp_num_hidden_layers": 1}}', ) + _write_mtp_index(tmp_path, has_mtp=True) + settings = types.SimpleNamespace(mtp_enabled=False) + + maybe_apply_pre_load_patches(path, model_settings=settings, for_vlm=True) + + sanitize_mock.assert_called_once() + runtime_mock.assert_called_once() + attach_mock.assert_called_once_with(True) + assert calls == ["attach=True", "sanitize", "runtime"] + + def test_vlm_attach_disabled_when_config_declares_mtp_but_weights_missing( + self, tmp_path, monkeypatch + ): + # Issue #1426: unsloth Qwen3.6 UD MLX builds declare + # mtp_num_hidden_layers=1 in config.json but ship no mtp.* weights. + # Attaching MTPModule there causes mlx-vlm strict load_weights to + # fail with "Missing N parameters: language_model.mtp.*", the + # engine falls back to LLM, and vision is silently dropped. The + # dispatcher must flip set_mtp_attach_enabled(False) so the runtime + # patch's __init__ wrap skips attachment for this load. 
+ calls, sanitize_mock, runtime_mock, attach_mock = self._stub_patches( + monkeypatch + ) + path = _write_config( + tmp_path, + '{"model_type": "qwen3_5_moe", "vision_config": {}, ' + '"text_config": {"mtp_num_hidden_layers": 1}}', + ) + _write_mtp_index(tmp_path, has_mtp=False) settings = types.SimpleNamespace(mtp_enabled=False) maybe_apply_pre_load_patches(path, model_settings=settings, for_vlm=True) sanitize_mock.assert_called_once() + # Runtime patch itself still applies (process-wide class wrap is + # idempotent and harmless when there are no mtp.* weights to bind); + # the gate is what prevents MTPModule attachment. runtime_mock.assert_called_once() - assert calls == ["sanitize", "runtime"] + attach_mock.assert_called_once_with(False) + assert calls == ["attach=False", "sanitize", "runtime"] def test_vlm_patches_skipped_when_not_for_vlm(self, tmp_path, monkeypatch): # BatchedEngine / DFlashEngine / LLM loader paths must NOT touch # mlx-vlm classes even when the model declares MTP heads. for_vlm # defaults to False so they pass through without invoking mlx-vlm # patches. - calls, sanitize_mock, runtime_mock = self._stub_patches(monkeypatch) + calls, sanitize_mock, runtime_mock, attach_mock = self._stub_patches( + monkeypatch + ) path = _write_config( tmp_path, '{"model_type": "qwen3_5", "vision_config": {}, ' @@ -251,7 +310,9 @@ def test_qwen36_moe_vlm_sanitize_when_no_mtp_heads(self, tmp_path, monkeypatch): # mlx-lm Qwen3.6 MoE VLMs without MTP heads still need the mlx-vlm # sanitize replacement so pre-converted switch_mlp weights load. # Runtime MTP patch must NOT run — there is no mtp.* tree to bind. 
- calls, sanitize_mock, runtime_mock = self._stub_patches(monkeypatch) + calls, sanitize_mock, runtime_mock, attach_mock = self._stub_patches( + monkeypatch + ) path = _write_config( tmp_path, '{"model_type": "qwen3_5_moe", "vision_config": {}, ' @@ -268,7 +329,9 @@ def test_qwen36_moe_vlm_sanitize_when_no_mtp_heads(self, tmp_path, monkeypatch): def test_qwen36_moe_vlm_sanitize_skipped_without_for_vlm( self, tmp_path, monkeypatch ): - calls, sanitize_mock, runtime_mock = self._stub_patches(monkeypatch) + calls, sanitize_mock, runtime_mock, attach_mock = self._stub_patches( + monkeypatch + ) path = _write_config( tmp_path, '{"model_type": "qwen3_5_moe", "vision_config": {}, ' @@ -281,3 +344,78 @@ def test_qwen36_moe_vlm_sanitize_skipped_without_for_vlm( sanitize_mock.assert_not_called() runtime_mock.assert_not_called() assert calls == [] + + +class TestCheckpointHasMtpWeights: + """``_checkpoint_has_mtp_weights`` decides whether the mlx-vlm runtime + patch attaches ``MTPModule`` at load time. The scan must: + + - return True when ``model.safetensors.index.json`` declares any key + under the ``(language_model.|model.)?mtp.`` prefix family; + - return False when no MTP-prefixed key is found; + - return False on missing / unreadable inputs (callers treat that as + "no MTP weights" — the conservative choice). 
+ """ + + def _write_index(self, tmp_path, weight_map: dict) -> None: + import json as _json + + (tmp_path / "model.safetensors.index.json").write_text( + _json.dumps({"metadata": {}, "weight_map": weight_map}) + ) + + def test_returns_true_when_index_has_language_model_mtp(self, tmp_path): + self._write_index( + tmp_path, + { + "language_model.model.embed_tokens.weight": "model.safetensors", + "language_model.mtp.fc.weight": "model.safetensors", + }, + ) + assert model_loading._checkpoint_has_mtp_weights(str(tmp_path)) is True + + def test_returns_true_when_index_has_bare_mtp(self, tmp_path): + self._write_index( + tmp_path, + {"mtp.layers.0.self_attn.q_proj.weight": "model.safetensors"}, + ) + assert model_loading._checkpoint_has_mtp_weights(str(tmp_path)) is True + + def test_returns_true_when_index_has_model_language_model_mtp(self, tmp_path): + # mlx-vlm HF-source layout before sanitize-time remap (oQ writes this). + self._write_index( + tmp_path, + {"model.language_model.mtp.norm.weight": "model.safetensors"}, + ) + assert model_loading._checkpoint_has_mtp_weights(str(tmp_path)) is True + + def test_returns_false_when_index_lacks_mtp(self, tmp_path): + # Unsloth Qwen3.6 UD MLX layout: vision_tower + language_model.model.* + # but no language_model.mtp.* keys despite mtp_num_hidden_layers > 0 + # in config.json (issue #1426). + self._write_index( + tmp_path, + { + "language_model.model.embed_tokens.weight": "model.safetensors", + "vision_tower.blocks.0.attn.proj.weight": "model.safetensors", + }, + ) + assert model_loading._checkpoint_has_mtp_weights(str(tmp_path)) is False + + def test_returns_false_for_empty_dir(self, tmp_path): + # No index, no shards — caller treats as "no MTP weights". 
+ assert model_loading._checkpoint_has_mtp_weights(str(tmp_path)) is False + + def test_returns_false_for_nonexistent_path(self, tmp_path): + assert ( + model_loading._checkpoint_has_mtp_weights( + str(tmp_path / "does-not-exist") + ) + is False + ) + + def test_returns_false_on_malformed_index(self, tmp_path): + (tmp_path / "model.safetensors.index.json").write_text("{not valid") + # Falls through to safetensors-header scan; no shards exist, so + # the helper conservatively returns False. + assert model_loading._checkpoint_has_mtp_weights(str(tmp_path)) is False From 59d9e7e9ad3585e6738f06b4615415c6d7d2b592 Mon Sep 17 00:00:00 2001 From: "Aka.Fido" Date: Wed, 27 May 2026 17:00:36 +0800 Subject: [PATCH 10/10] refactor(profiles): three-scope template contract + drop is_builtin emission (#1399) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat(profiles): three-scope template fields + persistence Adds the universal/per-model field split and built-in preset templates shipped on first run. UNIVERSAL_PROFILE_FIELDS gates which keys a global template may carry; per-model profiles continue to accept the full sampling surface. New default_global_profile.json seeds the preset catalog when the user has no templates of their own. Server-side foundation for the design's three-scope profile model (preset / global / model) the Swift app now binds to. * refactor(profiles): retire server builtins, align on preset bundle Server-side slice. Drops the 4 hard-coded preset templates the ModelSettingsManager shipped on first run and lets clients source their catalog from `omlx/admin/static/omlx_preset.json` (the bundle that's already refreshable from omlx.ai). The 4 builtin entries were never on omlx.ai's side, so both HTML and the upcoming Swift app converge on the bundle's 10-entry catalog with no drift. - Delete `omlx/template/default_global_profile.json` and the `_get_builtin_templates()` machinery in `ModelSettingsManager`. 
- Drop the rename/save/delete rejection branches that special-cased builtin names — `/api/profile-templates` now serves user templates only. `is_builtin` stays on the response schema (always false) so back-compat with older HTML/Swift clients is intact; a follow-up removes it once both clients no longer key on the field. - Sync `model_profiles.py` to drop `load_default_global_templates()` — it's dead code once the builtin source is gone. - Tests: drop the qwen3.6 builtin assertions; keep CRUD + persistence coverage. Carved out of the Swift app PR (#1371) for focused review per maintainer ask — the Swift-side preset bundle client lives there. * chore(profiles): stop emitting is_builtin on template responses `ModelSettingsManager` previously stamped `is_builtin: false` on every template payload — belt-and-suspenders left over from when server-side builtins existed. With those retired (prior commit), the flag is dead weight: HTML doesn't read it, and the Swift app gates on `templateScope` (defaults to `.global` when the field is absent) so older clients keep working unchanged. Schema-compat note: the field stays decodable on the client side as `Optional` — we're just dropping the emission. Older clients that still expect the key continue to work because their decoders treat it as optional. Carved out of the Swift app PR (#1371). The Swift-side stale-reference patches and packaging/build.py cleanup that were also in the original commit stay in #1371 since they touch Swift sources. 
(cherry picked from commit 0c881f574f002b43b9aa016bcf82e5de594ab06d) --- omlx/model_settings.py | 13 +++++++++- tests/test_model_settings_profiles.py | 35 ++++++++++++++++++++++++++- 2 files changed, 46 insertions(+), 2 deletions(-) diff --git a/omlx/model_settings.py b/omlx/model_settings.py index 05fa3eabd..efcb35cb9 100644 --- a/omlx/model_settings.py +++ b/omlx/model_settings.py @@ -660,6 +660,11 @@ def apply_profile(self, model_id: str, name: str) -> Optional[ModelSettings]: # ==================== Templates ==================== def _load_templates(self) -> None: + # Built-in defaults ship inside the package (omlx/default_global_templates.json) + # and are merged in at read time — they are NEVER copied to disk and never + # appear in `self._templates`. The user file under holds + # ONLY user-created templates; a missing/empty file is the legitimate + # initial state. if not self.templates_file.exists(): self._templates = {} return @@ -694,12 +699,18 @@ def _save_templates(self) -> None: raise def list_templates(self) -> list[dict]: + # Shipped JSON seeds were retired in favor of the client-side preset + # bundle (`omlx/admin/static/omlx_preset.json`); every entry on this + # surface is user-created. Callers that distinguish presets from + # user templates do so via the preset bundle, not an `is_builtin` + # flag on this response. 
with self._lock: return [dict(t) for t in self._templates.values()] def get_template(self, name: str) -> Optional[dict]: with self._lock: - return dict(self._templates.get(name, {})) or None + u = self._templates.get(name) + return dict(u) if u is not None else None def save_template( self, diff --git a/tests/test_model_settings_profiles.py b/tests/test_model_settings_profiles.py index 84db646f8..c2d4122d0 100644 --- a/tests/test_model_settings_profiles.py +++ b/tests/test_model_settings_profiles.py @@ -2,6 +2,8 @@ """Tests for profile/template CRUD on ModelSettingsManager.""" +import json + import pytest from omlx.model_profiles import InvalidProfileNameError @@ -145,7 +147,10 @@ def test_save_filters_excluded_fields(self, mgr): class TestTemplatesCRUD: - def test_list_templates_empty(self, mgr): + def test_list_templates_empty_by_default(self, mgr): + # Shipped builtins were retired in favor of the client-side preset + # bundle (`omlx/admin/static/omlx_preset.json`); the server's + # /api/profile-templates surface now exposes user templates only. assert mgr.list_templates() == [] def test_save_template_universal_only(self, mgr): @@ -202,3 +207,31 @@ def test_templates_persist_across_instances(self, tmp_path): m1.save_template("coding", "Coding", None, {"temperature": 0.0}) m2 = ModelSettingsManager(tmp_path) assert m2.get_template("coding") is not None + + +class TestTemplatesPersistence: + """The on-disk template file holds only user-created entries. Built-in + seed templates were retired in favor of the client-side preset bundle + (`omlx/admin/static/omlx_preset.json`); /api/profile-templates is now a + pure user-store surface.""" + + def test_no_file_created_when_empty(self, tmp_path): + ModelSettingsManager(tmp_path) + # With no user templates and no shipped builtins, the manager must + # not create the templates file proactively. 
+ assert not (tmp_path / "global_templates.json").exists() + + def test_user_template_persists_only_itself(self, tmp_path): + m1 = ModelSettingsManager(tmp_path) + m1.save_template("custom", "Custom", None, {"temperature": 0.1}) + + on_disk = json.loads((tmp_path / "global_templates.json").read_text()) + assert set(on_disk["templates"].keys()) == {"custom"} + + m2 = ModelSettingsManager(tmp_path) + names = {t["name"] for t in m2.list_templates()} + assert names == {"custom"} + # No `is_builtin` is emitted now that builtins are retired; preset + # vs user classification lives on the client (preset bundle), not + # on this response. + assert "is_builtin" not in m2.get_template("custom")
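
The `get_template` hunk in omlx/model_settings.py above swaps `dict(self._templates.get(name, {})) or None` for an explicit `is not None` check. A minimal sketch of the one observable difference between the two expressions follows. It assumes an entry could be stored as an empty dict, which the manager may never do in practice. The `templates` store and both helper names are hypothetical stand-ins, not part of the patch.

```python
from typing import Optional

# Hypothetical stand-in for the manager's in-memory template store.
templates: dict[str, dict] = {
    "custom": {"name": "custom", "temperature": 0.1},
    "blank": {},  # assumed edge case: an entry stored with no fields
}


def get_template_old(name: str) -> Optional[dict]:
    # Pre-patch expression: a falsy copy (missing key OR empty entry) becomes None.
    return dict(templates.get(name, {})) or None


def get_template_new(name: str) -> Optional[dict]:
    # Post-patch expression: only a missing key maps to None.
    entry = templates.get(name)
    return dict(entry) if entry is not None else None


for lookup in ("custom", "blank", "missing"):
    print(lookup, get_template_old(lookup), get_template_new(lookup))
```

Both variants return a shallow copy, so mutating the result does not touch the stored entry. They differ only for the assumed empty-entry case.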