diff --git a/omlx/_torch_stub.py b/omlx/_torch_stub.py new file mode 100644 index 000000000..63b1efea9 --- /dev/null +++ b/omlx/_torch_stub.py @@ -0,0 +1,359 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Minimal ``torch`` stub for the DMG bundle. + +xgrammar 0.2.0 declares ``torch>=1.10.0`` as a runtime dep, but oMLX never +exercises its torch-backed code paths: bitmasks are allocated as numpy +``int32`` buffers, the C++ binding fills them, and the MLX kernel applies the +mask. The torch dep is load-bearing only at *import time* — module-level code +in ``xgrammar.matcher``, ``xgrammar.testing``, ``xgrammar.contrib.hf`` and +``tvm_ffi.core`` does ``import torch`` plus a handful of attribute lookups. + +Real torch is ~500 MB unpacked on macOS arm64 — too heavy to ship in the DMG. +This stub provides just enough of the torch surface for those modules to +finish loading. Code paths that would actually call into torch raise +``RuntimeError`` from the helpers below; oMLX never reaches them. + +When a real torch is installed (pip / Homebrew flow) the stub is a no-op: +``install()`` checks ``importlib.util.find_spec('torch')`` first. +""" + +from __future__ import annotations + +import importlib.machinery +import importlib.util +import logging +import os +import sys +import threading +import types + +logger = logging.getLogger(__name__) + +# xgrammar / tvm-ffi versions this stub is known to cover. +# This module is the *single source of truth* — packaging/build.py imports +# these constants to keep the DMG install pin in sync with the stub. Update +# both tuples here when bumping; the build script auto-tracks. +# +# Reachable-but-stubbed torch surface to be aware of when upgrading: +# - ``torch.full``: ``xgrammar.allocate_token_bitmask`` calls it. oMLX +# never invokes ``allocate_token_bitmask`` (we use the MLX kernel +# path), but the symbol is re-exported from ``xgrammar.__init__``. +# Any future caller that touches it will hit ``_unsupported("full")`` +# and surface a clear RuntimeError. +# - ``torch.tensor`` returns a ``_StubTensor`` whose attribute access +# raises a stub-identifying RuntimeError. Module-level +# ``_FULL_MASK = torch.tensor(-1, ...)`` patterns succeed at import +# time; any subsequent method call (.fill_, .item, ...) fails. +_TARGET_XGRAMMAR_VERSIONS = ("0.2.0",) +_TARGET_TVM_FFI_VERSIONS = ("0.1.11",) + +# Serialize install() across threads. Without this, two threads that both +# pass the "torch" in sys.modules check race to build modules and overwrite +# each other's sys.modules['torch'] entry, leaving threads that already +# dereferenced the loser's module with stale references. Reachable today +# from concurrent HTTP handlers that call install() on first xgrammar use. +_INSTALL_LOCK = threading.Lock() +_INSTALLED = False + + +class _StubTensor: + """Placeholder for ``torch.Tensor`` (annotations + isinstance checks). + + Any attribute access raises a clear RuntimeError so runtime use of a + stubbed tensor (e.g. ``some_tensor.fill_(...)``) fails loudly with a + pointer to the cause, rather than at the AttributeError level with a + generic ``has no attribute 'fill_'`` message. + """ + + def __getattr__(self, name: str): + # Let dunder probes (pickle, copy.deepcopy, descriptor lookups, + # `hasattr` chains in third-party libs) fall through cleanly as + # AttributeError — that's the documented `__getattr__` contract. + # Real torch tensors lack many of these probed dunders anyway, so + # raising AttributeError is the correct, distinguishable signal. + if name.startswith("__") and name.endswith("__"): + raise AttributeError(name) + raise RuntimeError( + f"_StubTensor.{name} is not implemented: oMLX ships a torch " + "stub for xgrammar's import-time needs only. Reaching a real " + "tensor method means a code path that needs real torch was " + "exercised — install torch via pip/Homebrew or report this as " + "a bug if the call originated inside oMLX." + ) + + +class _StubDtype: + __slots__ = ("_name",) + + def __init__(self, name: str) -> None: + self._name = name + + def __repr__(self) -> str: + return f"torch.{self._name}" + + # Some xgrammar/tvm-ffi paths convert dtype to string via ``str(dt)`` + # rather than ``repr(dt)`` (e.g. ``to_cpp_dtype`` strips the "torch." + # prefix). Match real torch's behaviour where ``str(torch.int32)`` is + # ``"torch.int32"`` so those paths keep working. + def __str__(self) -> str: + return f"torch.{self._name}" + + +def _stub_tensor_factory(*args, **kwargs) -> _StubTensor: + """torch.tensor(...) stub: returns a _StubTensor instance. + + Returning a real object (rather than None) means module-globals like + xgrammar.matcher._FULL_MASK = torch.tensor(-1, dtype=...) succeed at + import time. Any subsequent method call on the result (.fill_, .item, + etc.) raises with a clear pointer via _StubTensor.__getattr__. + """ + return _StubTensor() + + +def _false(*args, **kwargs) -> bool: + return False + + +def _unsupported(qualname: str): + def _fn(*args, **kwargs): + raise RuntimeError( + f"torch.{qualname} is not available: this oMLX build ships a " + "torch stub for xgrammar's import-time needs only. Install " + "real torch via pip/Homebrew if you need this code path." + ) + + return _fn + + +# (canonical, alias) pairs — real torch aliases torch.int to torch.int32, +# torch.long to torch.int64, etc.; preserve those identities so code that +# does ``torch.int is torch.int32`` keeps working. +_DTYPE_ALIASES: tuple[tuple[str, tuple[str, ...]], ...] = ( + ("int32", ("int",)), + ("int16", ("short",)), + ("int64", ("long",)), + ("float16", ("half",)), + ("float32", ("float",)), + ("float64", ("double",)), + ("int8", ()), + ("uint8", ()), + ("bfloat16", ()), + ("bool", ()), +) + +_TENSOR_ALIASES = ( + "Tensor", "LongTensor", "FloatTensor", "IntTensor", "ByteTensor", + "DoubleTensor", "HalfTensor", "BoolTensor", "ShortTensor", +) + + +def _make_top_level_torch_getattr() -> "callable": + """Return a ``__getattr__`` for the stub's top-level torch module. + + Real-torch users who reach an unset attribute would get an + ``AttributeError``; consumers that probe with ``hasattr`` rely on that. + But we *also* want a clearly-identifiable message when downstream + libraries (transformers, accelerate, etc.) reach for a torch surface + we never stubbed — so this raises ``AttributeError`` whose message + pinpoints the omlx stub. ``pkgutil.iter_modules(torch.__path__)`` and + similar discovery paths see the empty ``__path__`` and short-circuit + before hitting this. + """ + + _missing_attr_warned: set[str] = set() + + def __getattr__(name: str): # noqa: N807 + # Surface the miss at WARNING level so a future xgrammar release + # reaching for a new torch attribute is diagnosable from logs + # before the AttributeError surfaces in a request handler. Rate- + # limit per name so repeated probes (e.g. hasattr() under a + # loop) don't flood the journal — once per name per process is + # enough to identify the gap. + if name not in _missing_attr_warned: + _missing_attr_warned.add(name) + logger.warning( + "oMLX torch stub missing attribute: torch.%s " + "(install real torch if this is load-bearing)", + name, + ) + # Dunder probes always fall through as AttributeError so pickling, + # copy.deepcopy, and similar Python machinery work as expected. + raise AttributeError( + f"torch.{name!s} is not provided by the oMLX torch stub. " + "Install real torch via pip/Homebrew if this attribute is " + "actually needed." + ) + + return __getattr__ + + +def _build_modules() -> dict[str, types.ModuleType]: + torch = types.ModuleType("torch") + for alias in _TENSOR_ALIASES: + setattr(torch, alias, _StubTensor) + torch.dtype = _StubDtype + torch.__version__ = "0.0.0+omlx-stub" + # Pin the stub as the source of truth for the xgrammar version it + # targets; packaging/build.py imports this constant to stay in sync. + # (Module-level constant lives at the top of this file.) + for canonical, aliases in _DTYPE_ALIASES: + dt = _StubDtype(canonical) + setattr(torch, canonical, dt) + for a in aliases: + setattr(torch, a, dt) + torch.tensor = _stub_tensor_factory + torch.full = _unsupported("full") + torch.zeros = _unsupported("zeros") + torch.from_dlpack = _unsupported("from_dlpack") + + cuda = types.ModuleType("torch.cuda") + cuda.is_available = _false + + class _Stream: + pass + + cuda.Stream = _Stream + torch.cuda = cuda + + version = types.ModuleType("torch.version") + version.cuda = None + version.hip = None + torch.version = version + + nn_functional = types.ModuleType("torch.nn.functional") + nn_functional.pad = _unsupported("nn.functional.pad") + nn = types.ModuleType("torch.nn") + nn.functional = nn_functional + torch.nn = nn + + utils_dlpack = types.ModuleType("torch.utils.dlpack") + utils_dlpack.to_dlpack = _unsupported("utils.dlpack.to_dlpack") + utils = types.ModuleType("torch.utils") + utils.dlpack = utils_dlpack + torch.utils = utils + + # Top-level __getattr__ so a future xgrammar that reaches into a + # torch surface we never stubbed (e.g. ``torch.compile``, + # ``torch.distributed``) fails with a stub-identifying message rather + # than a cryptic ``AttributeError: module 'torch' has no attribute…``. + torch.__getattr__ = _make_top_level_torch_getattr() + + return { + "torch": torch, + "torch.cuda": cuda, + "torch.version": version, + "torch.nn": nn, + "torch.nn.functional": nn_functional, + "torch.utils": utils, + "torch.utils.dlpack": utils_dlpack, + } + + +def install() -> bool: + """Install the stub into ``sys.modules`` if no real torch is available. + + Returns True if the stub was installed (or had been installed previously), + False if a real torch was found and left alone. + + Thread-safe — concurrent callers (e.g. multiple FastAPI handlers hitting + the xgrammar entry points in parallel) serialize on _INSTALL_LOCK. + """ + global _INSTALLED + needs_version_check = False + with _INSTALL_LOCK: + if _INSTALLED: + return True + + if "torch" in sys.modules: + already_stub = getattr( + sys.modules["torch"], "__version__", "" + ).endswith("+omlx-stub") + _INSTALLED = already_stub + return already_stub + + try: + if importlib.util.find_spec("torch") is not None: + # Real torch is on the path — leave it alone, install() is + # a no-op. Don't mark _INSTALLED so a future sys.modules + # reset (e.g. in tests) re-evaluates. Crucially, also DO + # NOT touch ``TVM_FFI_DISABLE_TORCH_C_DLPACK`` — the user + # has real torch and the tvm-ffi/torch-C-DLPack JIT path + # may be their preferred fast path. + return False + except Exception: + # find_spec can raise on broken parent packages, partial + # installs, or weird import hooks. Treat as "no torch" — the + # stub is the safe fallback. + pass + + # No real torch — disable tvm_ffi's JIT torch-C-DLPack extension + # before any tvm-ffi / xgrammar import. Without this, + # tvm_ffi/_optional_torch_c_dlpack tries to JIT a C extension + # against our stub at first import, spawns a doomed Python + # subprocess that fails to ``import torch.utils.cpp_extension`` + # (the stub does not provide it), and surfaces a misleading + # "Failed to JIT torch c dlpack extension" warning to users on + # every cold start. The guard inside that module honours this + # env var and skips the JIT path entirely. + os.environ.setdefault("TVM_FFI_DISABLE_TORCH_C_DLPACK", "1") + + for name, mod in _build_modules().items(): + # ``__spec__`` must be a real ModuleSpec (not None) so that + # ``importlib.util.find_spec`` succeeds when called by + # transformers and other consumers. ``__version__`` is a + # clearly-fake value so transformers refuses to take the + # torch-modeling path. + mod.__spec__ = importlib.machinery.ModuleSpec(name, loader=None) + mod.__loader__ = None + if "." not in name: + mod.__path__ = [] # type: ignore[attr-defined] + sys.modules[name] = mod + _INSTALLED = True + needs_version_check = True + + # Fire the version-drift check OUTSIDE the install lock. xgrammar's + # C++ extension load can be slow on a cold disk; running it under + # the lock would block every concurrent install() caller behind one + # cold import. install() is idempotent at this point — _INSTALLED is + # set and any racing caller short-circuits at the top of the lock. + if needs_version_check: + try: + warn_if_unexpected_versions() + except Exception: # pragma: no cover — defensive + pass + return True + + +def warn_if_unexpected_versions() -> None: + """Log a warning when bundled xgrammar / tvm-ffi versions drift past the + versions this stub was tested against. Best-effort: silent if the + imports themselves haven't happened yet, since the stub is installed + eagerly at startup. + """ + try: + import xgrammar # type: ignore[import-not-found] + + v = getattr(xgrammar, "__version__", None) + if v and v not in _TARGET_XGRAMMAR_VERSIONS: + logger.warning( + "xgrammar %s is not in the torch-stub target set %s; " + "structured output may fail at runtime. Update the stub " + "or pin xgrammar back.", + v, + _TARGET_XGRAMMAR_VERSIONS, + ) + except Exception: + pass + try: + import tvm_ffi # type: ignore[import-not-found] + + v = getattr(tvm_ffi, "__version__", None) + if v and v not in _TARGET_TVM_FFI_VERSIONS: + logger.warning( + "apache-tvm-ffi %s is not in the torch-stub target set %s; " + "structured output may fail at runtime.", + v, + _TARGET_TVM_FFI_VERSIONS, + ) + except Exception: + pass diff --git a/omlx/admin/routes.py b/omlx/admin/routes.py index 6758f9180..b1f3f806c 100644 --- a/omlx/admin/routes.py +++ b/omlx/admin/routes.py @@ -1530,6 +1530,17 @@ async def list_grammar_parsers(is_admin: bool = Depends(require_admin)): Returns ``[]`` if xgrammar is missing, fails to load (e.g. broken native binding on macOS arm64), or has neither API available. """ + # Install the torch stub BEFORE any xgrammar import. If this lives + # inside the first try-block, a failure on the 0.1.34+ path can leave + # the fallback try-block importing xgrammar without the stub, which + # is guaranteed ImportError on stub-only (DMG) deployments. + try: + from omlx._torch_stub import install as _install_torch_stub + + _install_torch_stub() + except Exception as e: # pragma: no cover — defensive + logger.debug("torch stub install failed: %s", e) + # Prefer the 0.1.34+ registry so newer parsers (qwen3_6, gemma4, # deepseek_v4, ...) are exposed. try: diff --git a/omlx/api/grammar.py b/omlx/api/grammar.py index 46e2d86db..03713abc1 100644 --- a/omlx/api/grammar.py +++ b/omlx/api/grammar.py @@ -38,6 +38,8 @@ def create_grammar_compiler(tokenizer, model): Returns None if vocab_size cannot be determined. """ + from .._torch_stub import install as _install_torch_stub + _install_torch_stub() import xgrammar as xgr from ..utils.tokenizer import resolve_vocab_size, unwrap_tokenizer @@ -63,6 +65,8 @@ class GrammarConstraintProcessor: """ def __init__(self, compiled_grammar, vocab_size: int): + from .._torch_stub import install as _install_torch_stub + _install_torch_stub() import xgrammar as xgr from xgrammar.kernels.apply_token_bitmask_mlx import apply_token_bitmask_mlx diff --git a/omlx/cache/paged_ssd_cache.py b/omlx/cache/paged_ssd_cache.py index 120cdec09..ffca5f342 100644 --- a/omlx/cache/paged_ssd_cache.py +++ b/omlx/cache/paged_ssd_cache.py @@ -95,6 +95,15 @@ def _compute_max_pending_writes() -> int: _ROTATING_CACHE_TYPES = ("RotatingKVCache", "BatchRotatingKVCache") +# Cap on the number of LRU blocks `_enforce_size_limit_for_new_block` is +# allowed to unlink in one inline burst. Eviction normally returns ~1 +# entry; the cap exists for the ENOSPC-recovery path where the disk-usage +# snapshot has been invalidated, `_get_effective_max_size` shrinks +# sharply, and `evict_until_size` would otherwise return hundreds of +# entries at once and stall the inference thread on a syscall storm. +_MAX_INLINE_UNLINKS_PER_SAVE = 32 + + def _clamp_rotating_meta_states( cache_data: list[Any], layer_cache_types: list[str] | None, @@ -518,19 +527,42 @@ def get_lru_entries(self, count: int) -> list[PagedSSDBlockMetadata]: result.append(self._index[block_hash]) return result - def evict_until_size(self, target_size: int) -> list[PagedSSDBlockMetadata]: + def evict_until_size( + self, + target_size: int, + max_count: int | None = None, + ) -> list[PagedSSDBlockMetadata]: """ Evict LRU entries until total size is below target. Args: target_size: Target total size in bytes. + max_count: Optional cap on the number of entries removed in + one call. When the cap is hit before ``total_size`` drops + below ``target_size`` the call returns the partial slice + and leaves the remaining LRU entries in the index; the + caller is expected to retry on the next save. The cap is + pushed down here (rather than the caller popping a + surplus and reinserting it) so the index never exposes a + transient "evicted but not yet unlinked" gap that a + concurrent writer's ``contains()`` check could observe + as a deleted block. Returns: List of evicted metadata (files need to be deleted by caller). + + Note: + Loop termination depends on ``remove()`` decrementing + ``_total_size`` for every popped entry. If a future refactor + moves the decrement to "after the on-disk unlink succeeds", + this loop must also gain a "skip entries already pulled this + pass" guard or it can spin forever when unlinks fail. """ with self._lock: evicted = [] while self._total_size > target_size and self._lru: + if max_count is not None and len(evicted) >= max_count: + break # Get LRU entry (first in OrderedDict) block_hash = next(iter(self._lru)) metadata = self.remove(block_hash) @@ -641,10 +673,12 @@ def __init__( # Statistics self._stats = { "saves": 0, + "saves_persisted": 0, "loads": 0, "hits": 0, "misses": 0, "evictions": 0, + "evict_unlink_failures": 0, "errors": 0, "hot_cache_hits": 0, "hot_cache_evictions": 0, @@ -789,16 +823,17 @@ def _enqueue_ssd_write( # 2. Index second — makes the block discoverable in has_block/contains. if not self._index.contains(block_hash): - self._enforce_size_limit_for_new_block() + self._enforce_size_limit_for_new_block(blk_meta.file_size) self._index.add(blk_meta) # 3. Queue third — enqueue for background writer. try: item = (block_hash, tensors_raw, metadata, file_path) - if blocking: - self._write_queue.put(item, timeout=0.5) - else: - self._write_queue.put_nowait(item) + # Non-blocking callers (hot-cache LRU spill) also wait briefly so + # a transient writer backlog doesn't silently drop blocks. Same + # 250 ms budget as save_block. Blocking callers (shutdown flush) + # wait longer to maximize the chance of flushing every entry. + self._write_queue.put(item, timeout=0.5 if blocking else 0.25) logger.debug( f"Evicted hot cache block to SSD write queue: " f"{block_hash.hex()[:16]}..." @@ -807,8 +842,9 @@ def _enqueue_ssd_write( except queue.Full: self._stats["ssd_write_drops"] += 1 logger.warning( - f"SSD write queue full, dropping evicted block " - f"{block_hash.hex()[:16]}" + f"SSD write queue saturated (cap={_MAX_PENDING_WRITES}); " + f"dropping evicted block {block_hash.hex()[:16]} — writer is " + f"falling behind" ) self._index.remove(block_hash) with self._pending_write_hashes_lock: @@ -1015,23 +1051,6 @@ def _writer_loop(self) -> None: if item is None: # Sentinel for shutdown break - # Unlink task: tuple ('unlink', file_path). Used to defer LRU file - # deletion off the inference thread (see _enforce_size_limit_for_new_block). - # Sequential queue processing prevents race with subsequent writes - # to the same block_hash (write tasks always queued after unlink). - if isinstance(item[0], str) and item[0] == "unlink": - _, unlink_path = item - try: - if unlink_path.exists(): - unlink_path.unlink() - self._stats["evictions"] += 1 - logger.debug(f"Evicted SSD cache file (async): {unlink_path}") - except FileNotFoundError: - pass - except Exception as e: - logger.warning(f"Failed to delete evicted file {unlink_path}: {e}") - continue - block_hash, tensors_raw, metadata, file_path = item temp_path = None @@ -1046,6 +1065,12 @@ def _writer_loop(self) -> None: # Atomic rename to final path os.rename(str(temp_path), str(file_path)) + # The block is now durable on disk; bump the persist counter + # before any cleanup so ``saves_persisted`` reflects rename + # success even if the post-rename eviction check below + # unlinks the file again. + self._stats["saves_persisted"] += 1 + # Update index with actual file size self._index.update_file_size(block_hash, actual_size) @@ -1065,10 +1090,33 @@ def _writer_loop(self) -> None: errno.ENOSPC, errno.EDQUOT, ): - logger.warning( - f"SSD cache disk full, cannot write block " - f"{block_hash.hex()[:16]}: {e}" + # ENOSPC after save_block already returned True and + # incremented _stats["saves"] — the slot is silently + # lost (no retry) and the caller treats the save as + # committed. Combined with eviction having already + # fired, the cache may lose both the evicted blocks + # AND the new block. Surface this at ERROR level so + # operators see it; a follow-up could expose a + # save-failure callback to let callers re-issue. + logger.error( + "SSD cache disk full, cannot write block %s: %s " + "(slot lost, subsequent saves will recompute disk " + "pressure)", + block_hash.hex()[:16], + e, ) + # Invalidate the 30s disk-usage snapshot so the next + # save sees the true (now-critical) free space and + # evicts aggressively rather than trusting a stale + # inflated limit. In-flight saves that already passed + # _enforce_size_limit_for_new_block are still queued + # and may ENOSPC again — invalidation only protects + # the NEXT round of save_block calls. Take the lock so + # the inference thread's _get_effective_max_size + # doesn't observe a half-updated (value, timestamp) + # pair. + with self._lock: + self._disk_usage_cache = None else: logger.error( f"Background write failed for " f"{block_hash.hex()[:16]}: {e}" @@ -1139,22 +1187,28 @@ def save_block( self._stats["hits"] += 1 return True - # Check queue capacity before doing expensive GPU/disk work - # (not needed for hot cache write-back mode) + # Cold-store saturation short-circuit: when hot cache is disabled + # the write queue is the only buffer between save_block and the + # writer thread. If the writer is already saturated, dropping here + # avoids the GPU tensor-extraction + size-enforcement work we'd + # otherwise throw away at the put step a few hundred lines down. + # Inline-LRU-unlinks already let eviction free queue capacity; + # this guard handles the case where the writer (not eviction) is + # the bottleneck. In hot-cache mode the LRU spill path through + # _enqueue_ssd_write has its own timeout-put + drop accounting, + # so we don't short-circuit there. if not self._hot_cache_enabled and self._write_queue.full(): self._stats["ssd_write_drops"] += 1 logger.warning( - f"SSD cache write queue full, skipping save for " - f"{block_hash.hex()[:16]}" + f"SSD cache write queue saturated (cap={_MAX_PENDING_WRITES}); " + f"dropping save for {block_hash.hex()[:16]} before tensor " + f"extraction — writer is falling behind" ) return False file_path = self._get_file_path(block_hash) try: - # Enforce size limit before saving (only for SSD path) - if not self._hot_cache_enabled: - self._enforce_size_limit_for_new_block() # Prepare arrays for safetensors. Three layer_data shapes are # accepted: @@ -1323,8 +1377,21 @@ def _store_nstate_elements(prefix: str, elements): for name, arr in arrays.items(): tensors_raw[name] = _extract_tensor_bytes(arr) - # Estimate file size from raw bytes (actual size set by background writer) - estimated_size = sum(len(raw) for raw, _, _ in tensors_raw.values()) + 1024 + # Estimate file size: raw tensor bytes + safetensors header. + # The header is JSON-encoded per tensor (name + dtype + shape + + # data_offsets, typically ~85 bytes) plus an 8-byte length prefix + # and the user metadata block. Compute the metadata-JSON length + # exactly (large `layer_meta_states` JSON on deep-layer models + # can exceed a fixed 1 KiB constant) and keep 128 B/tensor as an + # upper bound on the per-tensor header. The 256 B margin covers + # the JSON separators / `__metadata__` key envelope safetensors + # adds at write time. + try: + metadata_json_len = len(json.dumps(metadata).encode("utf-8")) + except (TypeError, ValueError): + metadata_json_len = 1024 + header_overhead = metadata_json_len + 256 + 128 * len(tensors_raw) + estimated_size = sum(len(raw) for raw, _, _ in tensors_raw.values()) + header_overhead now = time.time() block_metadata = PagedSSDBlockMetadata( @@ -1368,6 +1435,11 @@ def _store_nstate_elements(prefix: str, elements): # Hot cache disabled but hot_cache_only set: block is not retained. return False + # Evict LRU blocks to make room for the new block. Done here + # (post-tensor-build) so the actual block size is known and the + # cache doesn't oscillate around the configured limit. + self._enforce_size_limit_for_new_block(estimated_size) + # SSD path: add to index for SSD file tracking self._index.add(block_metadata) @@ -1379,16 +1451,22 @@ def _store_nstate_elements(prefix: str, elements): with self._pending_write_hashes_lock: self._pending_write_hashes.add(block_hash) - # Enqueue full file write for background thread + # Enqueue full file write for background thread. Wait briefly on + # Full so a transient burst (faster than the writer can drain) + # doesn't immediately drop the block — 250 ms is well below human + # perception of latency and typically covers one or two writer + # iterations on a healthy SSD. try: - self._write_queue.put_nowait( - (block_hash, tensors_raw, metadata, file_path) + self._write_queue.put( + (block_hash, tensors_raw, metadata, file_path), + timeout=0.25, ) except queue.Full: self._stats["ssd_write_drops"] += 1 logger.warning( - f"SSD cache write queue full, dropping write for " - f"{block_hash.hex()[:16]}" + f"SSD cache write queue saturated (cap={_MAX_PENDING_WRITES}); " + f"dropping write for {block_hash.hex()[:16]} — writer is " + f"falling behind" ) self._index.remove(block_hash) self._hot_cache_remove(block_hash) @@ -1707,7 +1785,17 @@ def load_block( # Previous executor-based approach caused deadlocks when # mx.load() in a worker thread contested Metal GPU resources # with the main inference thread. - arrays, file_metadata = mx.load(str(file_path), return_metadata=True) + try: + arrays, file_metadata = mx.load( + str(file_path), return_metadata=True + ) + except FileNotFoundError: + # Concurrent evictor unlinked the file between the + # exists() check above and this load. Treat as a miss + # and prune the stale index entry. + self._index.remove(block_hash) + self._stats["misses"] += 1 + return None # Defensive: even if the index is stale (e.g. from a previous # run that pre-dates the format version field), reject blocks @@ -2140,27 +2228,38 @@ def _get_effective_max_size(self) -> int: if self._cache_dir is None: return self._max_size + # Take the lock so a concurrent writer-thread invalidation + # (sets _disk_usage_cache=None on ENOSPC) can't interleave with + # this read-check-write and let one save see a fresh value paired + # with a stale timestamp (or vice versa). now = time.monotonic() - if self._disk_usage_cache is None or now - self._disk_usage_cache_time > 30.0: - try: - self._disk_usage_cache = shutil.disk_usage(self._cache_dir) - except OSError as e: - logger.warning( - f"Failed to check disk usage for SSD cache dir " - f"{self._cache_dir}: {e}" - ) - return self._max_size - self._disk_usage_cache_time = now + with self._lock: + if self._disk_usage_cache is None or now - self._disk_usage_cache_time > 30.0: + try: + self._disk_usage_cache = shutil.disk_usage(self._cache_dir) + except OSError as e: + logger.warning( + f"Failed to check disk usage for SSD cache dir " + f"{self._cache_dir}: {e}" + ) + return self._max_size + self._disk_usage_cache_time = now + disk_free = self._disk_usage_cache.free - disk_available = self._index.total_size + self._disk_usage_cache.free + disk_available = self._index.total_size + disk_free disk_limit = int(disk_available * self._DISK_SAFE_RATIO) return min(self._max_size, disk_limit) - def _enforce_size_limit_for_new_block(self) -> None: - """Enforce size limit before adding a new block.""" - # Estimate average block size (use 1MB as conservative estimate) - estimated_new_size = 1 * 1024 * 1024 + def _enforce_size_limit_for_new_block( + self, estimated_new_size: int = 1 * 1024 * 1024 + ) -> None: + """Enforce size limit before adding a new block. + ``estimated_new_size`` should be the actual byte size of the block + about to be inserted. The 1 MiB default is for callers that don't + yet know the size at the time eviction is needed; passing the + actual size avoids cache oscillation around the configured limit. + """ effective_max = self._get_effective_max_size() # Warn when disk pressure shrinks effective limit well below configured @@ -2180,25 +2279,45 @@ def _enforce_size_limit_for_new_block(self) -> None: target_size = int(effective_max * 0.9) if self._index.total_size > target_size: - evicted = self._index.evict_until_size(target_size) - # Defer file unlink to the writer thread to avoid blocking the - # inference thread with N file delete syscalls. Sequential queue - # processing keeps unlink ordered before any later write of the - # same block_hash. Hot cache is NOT touched here — see - # original comment about delete_block() being the only path that - # clears both tiers. + # Inline unlinks on the calling thread. Eviction typically + # removes a single block per save (the loop in + # evict_until_size stops as soon as size drops below target), + # so this is one syscall, not N. Routing unlinks through + # _write_queue (the prior design) made eviction compete with + # writes for queue capacity AND broke the invariant that + # eviction frees space — under saturation the unlink + # put_nowait fell back to inline anyway. Hot cache is NOT + # touched here; delete_block() is the only path that clears + # both tiers. + # + # Bound the inline syscall burst. ENOSPC invalidates the + # disk-usage snapshot (see _writer_loop), so the next save + # sees a much-shrunk effective_max and evict_until_size + # could otherwise return hundreds of metadata entries at + # once. The cap is pushed into ``evict_until_size`` so we + # never remove entries from the index that we aren't about + # to unlink — that would expose a window where the writer + # thread's ``contains()`` check sees a still-live block as + # evicted and unlinks the file we just persisted. Whatever + # the cap leaves above target is naturally picked up by the + # next save. + evicted = self._index.evict_until_size( + target_size, max_count=_MAX_INLINE_UNLINKS_PER_SAVE + ) for metadata in evicted: - try: - self._write_queue.put_nowait(("unlink", metadata.file_path)) - except queue.Full: - # Queue saturated — fall back to inline unlink so size - # accounting stays consistent. Rare path. - try: - if metadata.file_path.exists(): - metadata.file_path.unlink() - self._stats["evictions"] += 1 - except Exception as e: - logger.warning(f"Failed to delete evicted file: {e}") + self._unlink_evicted(metadata) + if ( + len(evicted) == _MAX_INLINE_UNLINKS_PER_SAVE + and self._index.total_size > target_size + ): + logger.info( + "SSD cache eviction burst capped at %d " + "(total_size %s still above target %s, will " + "reconverge on subsequent saves)", + _MAX_INLINE_UNLINKS_PER_SAVE, + format_bytes(self._index.total_size), + format_bytes(target_size), + ) def enforce_size_limit(self) -> int: """ @@ -2207,6 +2326,12 @@ def enforce_size_limit(self) -> int: Returns: Number of bytes freed. """ + # Decide what to evict under the lock, but perform unlinks outside + # it: a single unlink on a slow disk (NFS / encrypted FS / ENOSPC + # retry path) can block tens to hundreds of ms, and every + # _get_effective_max_size() / writer-thread cache-invalidation + # contends on self._lock. The index has its own internal lock + # protecting the LRU/size accounting. with self._lock: initial_size = self._index.total_size effective_max = self._get_effective_max_size() @@ -2217,21 +2342,40 @@ def enforce_size_limit(self) -> int: target_size = int(effective_max * 0.9) # 90% of effective max evicted = self._index.evict_until_size(target_size) - for metadata in evicted: - # Do NOT remove from hot cache — see _enforce_size_limit_for_new_block - try: - if metadata.file_path.exists(): - metadata.file_path.unlink() - self._stats["evictions"] += 1 - except Exception as e: - logger.warning(f"Failed to delete evicted file: {e}") + # Do NOT remove from hot cache — see _enforce_size_limit_for_new_block + for metadata in evicted: + self._unlink_evicted(metadata) - freed = initial_size - self._index.total_size - logger.info( - f"SSD cache size enforcement: freed {format_bytes(freed)}, " - f"evicted {len(evicted)} files" + freed = initial_size - self._index.total_size + logger.info( + f"SSD cache size enforcement: freed {format_bytes(freed)}, " + f"evicted {len(evicted)} files" + ) + return freed + + def _unlink_evicted(self, metadata: PagedSSDBlockMetadata) -> None: + """Delete an evicted block file from disk. + + On unlink failure other than FileNotFoundError, re-add the + metadata to the index so ``total_size`` keeps reflecting actual + on-disk bytes; without this, accumulated failures would let the + cache silently exceed ``max_size`` (the index would report free + space that does not exist on disk). + """ + try: + metadata.file_path.unlink(missing_ok=True) + self._stats["evictions"] += 1 + except OSError as e: + # Restore the index entry so total_size matches disk reality. + # The re-added entry lands at the LRU tail (most-recently + # touched), which deprioritises immediate re-eviction. + self._index.add(metadata) + self._stats["evict_unlink_failures"] += 1 + logger.exception( + "Failed to delete evicted SSD cache file %s: %s", + metadata.file_path, + e, ) - return freed def clear_hot_cache(self) -> int: """Clear all in-memory (hot) cache entries. @@ -2279,6 +2423,7 @@ def get_stats(self) -> PagedSSDCacheStats: misses=self._stats["misses"], evictions=self._stats["evictions"], saves=self._stats["saves"], + saves_persisted=self._stats["saves_persisted"], loads=self._stats["loads"], errors=self._stats["errors"], total_size_bytes=self._index.total_size, @@ -2338,6 +2483,7 @@ def _matches(candidate: str) -> bool: misses=self._stats["misses"], evictions=self._stats["evictions"], saves=self._stats["saves"], + saves_persisted=self._stats["saves_persisted"], loads=self._stats["loads"], errors=self._stats["errors"], total_size_bytes=indexed_size, diff --git a/omlx/cache/stats.py b/omlx/cache/stats.py index 177a25303..b0d5cbeca 100644 --- a/omlx/cache/stats.py +++ b/omlx/cache/stats.py @@ -187,8 +187,16 @@ class PagedSSDCacheStats(BaseCacheStats): Extends base stats with storage-specific and hot cache metrics. """ - # Operation counters + # ``saves`` counts blocks that PASSED the quota gate and were enqueued + # for the writer thread — it is incremented on the inference thread in + # ``save_block`` BEFORE the writer fsyncs the file, so a block whose + # background write later fails (ENOSPC, EDQUOT, OSError) still + # contributes 1 to ``saves`` and 1 to ``errors``. Use + # ``saves_persisted`` for the count of writes that actually landed on + # disk (incremented after the atomic rename in ``_writer_loop``); + # ``saves - saves_persisted`` is the steady-state in-flight depth. saves: int = 0 + saves_persisted: int = 0 loads: int = 0 errors: int = 0 ssd_write_drops: int = 0 @@ -232,6 +240,7 @@ def reset(self) -> None: """Reset runtime statistics.""" super().reset() self.saves = 0 + self.saves_persisted = 0 self.loads = 0 self.errors = 0 self.ssd_write_drops = 0 diff --git a/omlx/engine/base.py b/omlx/engine/base.py index a46487a31..68805c623 100644 --- a/omlx/engine/base.py +++ b/omlx/engine/base.py @@ -4,6 +4,7 @@ """ import asyncio +import logging import threading import time import uuid @@ -15,6 +16,37 @@ from omlx.engine_core import get_mlx_executor +_preflight_logger = logging.getLogger("omlx.engine.preflight") + +# Per-process record of (engine_class_name, method) pairs that have +# already logged a "scheduler unreachable" warning. The warning marks a +# wrapper-chain misconfiguration — a deployment bug rather than a +# runtime condition — so once-per-pair is enough to alert oncall +# without flooding the journal at request rate. +_PREFLIGHT_UNREACHABLE_WARNED: set[tuple[str, str]] = set() + + +def _warn_scheduler_unreachable_once( + engine: object, method: str, detail: str = "" +) -> None: + """Emit a one-shot WARNING when the wrapper chain doesn't expose a + scheduler. Subsequent calls with the same (engine type, method) pair + are silent so a misconfigured engine doesn't spam logs at request + rate. + """ + key = (type(engine).__name__, method) + if key in _PREFLIGHT_UNREACHABLE_WARNED: + return + _PREFLIGHT_UNREACHABLE_WARNED.add(key) + suffix = f" — {detail}" if detail else "" + _preflight_logger.warning( + "%s.%s: scheduler unreachable via _engine.engine.scheduler" + "%s; preflight check skipped (further occurrences suppressed)", + type(engine).__name__, + method, + suffix, + ) + @dataclass class GenerationOutput: @@ -253,6 +285,35 @@ def get_cache_stats(self) -> Optional[Dict[str, Any]]: """ pass + async def preflight_chat( + self, + messages: list, + tools: Optional[list] = None, + request_id: Optional[str] = None, + **kwargs, + ) -> None: + """Optional prefill-memory preflight check for chat requests. + + Default no-op; engines that implement the prefill memory guard + (``BatchedEngine``, ``VLMBatchedEngine``) override this with the + actual estimation logic. The base no-op lets simpler engines + (SimpleEngine, embedding/reranker engines, test stubs) be + invoked from the server endpoints without additional wrapping. + """ + return None + + async def preflight_completion( + self, + prompt: str, + request_id: Optional[str] = None, + **kwargs, + ) -> None: + """Optional prefill-memory preflight check for completion requests. + + See :meth:`preflight_chat` for the rationale. + """ + return None + class BaseNonStreamingEngine(ABC): """Base class for non-streaming engines (embedding, reranker). diff --git a/omlx/engine/batched.py b/omlx/engine/batched.py index 834fdb692..23e862c77 100644 --- a/omlx/engine/batched.py +++ b/omlx/engine/batched.py @@ -14,7 +14,7 @@ from ..api.tool_calling import convert_tools_for_template from ..api.utils import clean_special_tokens, detect_and_strip_partial from ..utils.tokenizer import get_tokenizer_config -from .base import BaseEngine, GenerationOutput +from .base import BaseEngine, GenerationOutput, _warn_scheduler_unreachable_once logger = logging.getLogger(__name__) @@ -149,10 +149,12 @@ def grammar_compiler(self): method = get_install_method() if method == "dmg": - logger.info( - "Structured output is not available in the DMG version " - "(xgrammar requires torch which significantly increases app size). " - "Use the pip or Homebrew version for structured output support." + logger.warning( + "GrammarCompiler initialization failed for %s on the " + "DMG build. The bundle ships xgrammar against a torch " + "stub; this usually means the bundled xgrammar / tvm-" + "ffi version drifted past what the stub covers.", + self._model_name, ) elif method == "homebrew": logger.info( @@ -697,6 +699,95 @@ async def chat( **kwargs, ) + async def preflight_chat( + self, + messages: list[dict[str, Any]], + tools: list[dict] | None = None, + request_id: str | None = None, + **kwargs, + ) -> None: + """Early prefill memory check for chat completions. + + Tokenizes the templated prompt and asks the scheduler whether the + request would exceed the configured memory ceiling. Raises + ``PrefillMemoryExceededError`` (with the caller's ``request_id`` + attached) if it would. Designed to be called from the FastAPI + route handler BEFORE the response is wrapped in a + ``StreamingResponse``, so the exception can be mapped to HTTP + 413 by ``prefill_memory_exceeded_handler``. + + Cheap enough to run as a precondition: tokenization of even a + 100k-token chat takes tens of milliseconds compared to the many + seconds the prefill it gates would consume. + """ + if not self._loaded: + await self.start() + messages = self._preprocess_messages(messages) + template_tools = convert_tools_for_template(tools) if tools else None + ct_kwargs = kwargs.get("chat_template_kwargs") + partial = kwargs.get("is_partial") + prompt = self._apply_chat_template( + messages, + template_tools, + chat_template_kwargs=ct_kwargs, + is_partial=partial, + ) + # Tokenizer errors (UnicodeDecodeError, HF Rust "Already borrowed", + # malformed input) are normally surfaced by the real chat path's + # add_request → tokenize call as a 500 — there's no path-specific + # 400 handler today. Don't introduce a NEW failure mode here: if + # tokenization fails during preflight, log it and skip the memory + # check. The actual chat path will hit the same error and raise it + # through the existing handler chain so the response shape stays + # consistent. + try: + num_tokens = len(self._tokenizer.encode(prompt)) + except Exception as e: + logger.warning( + "BatchedEngine.preflight_chat: tokenizer.encode raised %s; " + "skipping prefill memory check, real chat path will surface " + "the error", + type(e).__name__, + ) + return + scheduler = getattr(getattr(self._engine, "engine", None), "scheduler", None) + if scheduler is None: + _warn_scheduler_unreachable_once(self, "preflight_chat") + return + scheduler.preflight_or_raise( + num_prompt_tokens=num_tokens, request_id=request_id + ) + + async def preflight_completion( + self, + prompt: str, + request_id: str | None = None, + **kwargs, + ) -> None: + """Early prefill memory check for plain /v1/completions calls. + + See ``preflight_chat`` for the rationale. + """ + if not self._loaded: + await self.start() + try: + num_tokens = len(self._tokenizer.encode(prompt)) + except Exception as e: + logger.warning( + "BatchedEngine.preflight_completion: tokenizer.encode raised " + "%s; skipping prefill memory check, real completion path " + "will surface the error", + type(e).__name__, + ) + return + scheduler = getattr(getattr(self._engine, "engine", None), "scheduler", None) + if scheduler is None: + _warn_scheduler_unreachable_once(self, "preflight_completion") + return + scheduler.preflight_or_raise( + num_prompt_tokens=num_tokens, request_id=request_id + ) + async def stream_chat( self, messages: list[dict[str, Any]], diff --git a/omlx/engine/vlm.py b/omlx/engine/vlm.py index 57bc2479b..b2b21a9c0 100644 --- a/omlx/engine/vlm.py +++ b/omlx/engine/vlm.py @@ -47,7 +47,7 @@ extract_images_from_messages, ) from ..utils.tokenizer import get_tokenizer_config -from .base import BaseEngine, GenerationOutput +from .base import BaseEngine, GenerationOutput, _warn_scheduler_unreachable_once logger = logging.getLogger(__name__) @@ -507,6 +507,79 @@ def _uses_mrope(vlm_model) -> bool: } +# Conservative fallback upper bound on image-placeholder tokens per image +# content part. Used by ``preflight_chat`` only when the actual +# ``max_pixels`` cannot be derived from the loaded processor config. +# Qwen-VL / Gemma-Vision typically expand each image to 256–1280 tokens +# at default settings, but a deployment that lifts ``max_pixels`` can +# legitimately exceed this — relying on a hard-coded 1280 in that case +# silently under-counts and re-opens the panic-prone MLX prefill path. +# Prefer ``_derive_image_token_upper_bound(processor)`` when the +# processor is loaded. +_IMAGE_TOKEN_UPPER_BOUND_FALLBACK = 1280 + + +def _derive_image_token_upper_bound(processor: Any) -> int: + """Derive the per-image token upper bound from the processor config. + + Qwen-style image processors expose ``max_pixels`` (an *area*) and + pack pixels into ``patch_size`` × ``patch_size`` patches, then merge + ``merge_size`` × ``merge_size`` patches into one model token. The + per-image token bound is therefore:: + + max_tokens = max_pixels / (patch_size**2 * merge_size**2) + + Falls back to the conservative module-level constant when the + processor doesn't expose the expected attributes (other model + families) so we never *under*-count. + """ + if processor is None: + return _IMAGE_TOKEN_UPPER_BOUND_FALLBACK + ip = getattr(processor, "image_processor", None) or processor + max_pixels = getattr(ip, "max_pixels", None) + patch_size = getattr(ip, "patch_size", None) + merge_size = getattr(ip, "merge_size", None) + if ( + isinstance(max_pixels, int) + and max_pixels > 0 + and isinstance(patch_size, int) + and patch_size > 0 + and isinstance(merge_size, int) + and merge_size > 0 + ): + derived = max_pixels // (patch_size * patch_size * merge_size * merge_size) + # Never go *below* the conservative fallback — a model whose + # processor reports a tiny max_pixels (e.g. test fixtures) should + # not weaken the guard. + return max(derived, _IMAGE_TOKEN_UPPER_BOUND_FALLBACK) + return _IMAGE_TOKEN_UPPER_BOUND_FALLBACK + + +def _count_image_tokens( + messages: list[dict[str, Any]], + per_image_upper_bound: int = _IMAGE_TOKEN_UPPER_BOUND_FALLBACK, +) -> int: + """Count image-bearing content parts in OpenAI-style messages and + return the conservative token-budget contribution. + + Supports both the OpenAI ``image_url`` / ``image`` part types and the + Anthropic ``image`` block shape that gets adapted into the same + message-list before reaching the engine layer. + """ + image_parts = 0 + for msg in messages: + content = msg.get("content") + if not isinstance(content, list): + continue + for part in content: + if not isinstance(part, dict): + continue + ptype = part.get("type") + if ptype in ("image_url", "image", "input_image"): + image_parts += 1 + return image_parts * per_image_upper_bound + + class VLMBatchedEngine(BaseEngine): """ VLM engine with continuous batching, tiered KV cache, and boundary snapshots. @@ -597,10 +670,12 @@ def grammar_compiler(self): method = get_install_method() if method == "dmg": - logger.info( - "Structured output is not available in the DMG version " - "(xgrammar requires torch which significantly increases app size). " - "Use the pip or Homebrew version for structured output support." + logger.warning( + "GrammarCompiler initialization failed for %s on the " + "DMG build. The bundle ships xgrammar against a torch " + "stub; this usually means the bundled xgrammar / tvm-" + "ffi version drifted past what the stub covers.", + self._model_name, ) elif method == "homebrew": logger.info( @@ -1825,6 +1900,127 @@ async def chat( **kwargs, ) + async def preflight_chat( + self, + messages: list[dict[str, Any]], + tools: list[dict] | None = None, + request_id: str | None = None, + **kwargs, + ) -> None: + """Early prefill memory check for chat completions (VLM path). + + The actual VLM prompt is built by ``_process_chat_messages`` → + ``_prepare_vision_inputs``, which expands each image content-part + into 256–1280 model-specific image-placeholder tokens before the + chat template runs. Doing that work here would require image + decoding + the heavy preprocessor pipeline; for preflight we only + need a conservative upper bound on the prompt size, so we instead: + + 1. Apply the *text-only* chat template (cheap). + 2. Count its tokens. + 3. Add a per-image upper-bound budget (``_IMAGE_TOKEN_UPPER_BOUND``) + for each image-bearing content part — over-counts somewhat + on small images (false-positive 413s for borderline-and-image + cases) but never under-counts, which is the property the + guard needs to stay safe against the Apple IOGPUFamily + panic path. + + Tools (when supplied as Pydantic ``ToolDefinition`` objects by + direct API callers) must be converted to dict form for the + template — ``BatchedEngine.preflight_chat`` does this and we + mirror it here. Without conversion the template's ``TypeError`` + retry path silently drops tools entirely, which not only + miscalibrates the token count but also bypasses the actual + tool-prompt rendering on the real chat path. + + Raises ``PrefillMemoryExceededError`` if the conservative estimate + would exceed the configured memory ceiling. See + ``BatchedEngine.preflight_chat`` for the upstream rationale + (avoiding the ``StreamingResponse`` 200 commit so HTTP 413 + actually reaches the client). + """ + if not self._loaded: + await self.start() + template_tools = convert_tools_for_template(tools) if tools else None + ct_kwargs = kwargs.get("chat_template_kwargs") + partial = kwargs.get("is_partial") + # Strip image content-parts BEFORE templating. Modern HF chat + # templates (Qwen2.5-VL, Gemma-Vision, Llama-3.2-Vision) render + # ``image_url`` / ``image`` content parts as literal placeholder + # strings inline with the text; if we leave them in, the + # tokenized prompt already contains some image-placeholder + # tokens AND we then add the per-image budget on top — a double + # count that 413's borderline image-bearing prompts the real + # chat path would have handled. The real ``chat`` flow itself + # strips images first via ``extract_images_from_messages`` (see + # ``_process_chat_messages``), so mirroring that here keeps + # preflight and execution on the same template input. + text_messages, _ = extract_images_from_messages(messages) + prompt = self._apply_chat_template( + text_messages, + template_tools, + chat_template_kwargs=ct_kwargs, + is_partial=partial, + ) + # Tokenizer errors propagate as 500 today regardless of where they + # fire; the real chat path's add_request → tokenize call has no + # path-specific 400 handler. Don't introduce a NEW failure mode + # in preflight: skip the memory check on tokenizer error and let + # the real chat path surface the same error through the existing + # handler chain. + try: + num_tokens = len(self._tokenizer.encode(prompt)) + except Exception as e: + logger.warning( + "VLMBatchedEngine.preflight_chat: tokenizer.encode raised " + "%s; skipping prefill memory check, real chat path will " + "surface the error", + type(e).__name__, + ) + return + # Count images from the ORIGINAL messages (the stripped + # ``text_messages`` no longer has the image content-parts). + num_tokens += _count_image_tokens( + messages, + per_image_upper_bound=_derive_image_token_upper_bound( + getattr(self, "_processor", None) + ), + ) + scheduler = getattr(getattr(self._engine, "engine", None), "scheduler", None) + if scheduler is None: + _warn_scheduler_unreachable_once(self, "preflight_chat") + return + scheduler.preflight_or_raise( + num_prompt_tokens=num_tokens, request_id=request_id + ) + + async def preflight_completion( + self, + prompt: str, + request_id: str | None = None, + **kwargs, + ) -> None: + """Early prefill memory check for plain /v1/completions calls (VLM).""" + if not self._loaded: + await self.start() + try: + num_tokens = len(self._tokenizer.encode(prompt)) + except Exception as e: + logger.warning( + "VLMBatchedEngine.preflight_completion: tokenizer.encode " + "raised %s; skipping prefill memory check, real completion " + "path will surface the error", + type(e).__name__, + ) + return + scheduler = getattr(getattr(self._engine, "engine", None), "scheduler", None) + if scheduler is None: + _warn_scheduler_unreachable_once(self, "preflight_completion") + return + scheduler.preflight_or_raise( + num_prompt_tokens=num_tokens, request_id=request_id + ) + async def stream_chat( self, messages: list[dict[str, Any]], diff --git a/omlx/engine_core.py b/omlx/engine_core.py index 2ed4c586b..702130f97 100644 --- a/omlx/engine_core.py +++ b/omlx/engine_core.py @@ -357,10 +357,23 @@ async def add_request( # Add to scheduler — route through the MLX executor so that # prefix cache reconstruction (mx.load, mx.concatenate) never # races with scheduler.step() on the Metal stream. See #95. + # + # The scheduler may raise (PrefillMemoryExceededError, or other + # validation errors) before the request enters self.waiting. In + # that case the consumer in stream_outputs / generate never sees + # the request_id and its finally-block cleanup never fires — + # without the explicit cleanup below the per-rejection leak + # accumulates one collector + one stream_state + one + # asyncio.Event per refused request. Re-raise after cleanup so + # the typed exception still reaches the FastAPI 413 handler. loop = asyncio.get_running_loop() - await loop.run_in_executor( - self._mlx_executor, self.scheduler.add_request, request - ) + try: + await loop.run_in_executor( + self._mlx_executor, self.scheduler.add_request, request + ) + except BaseException: + self._cleanup_request(request_id) + raise return request_id diff --git a/omlx/memory_monitor.py b/omlx/memory_monitor.py index 955e6ceea..985f4c3a5 100644 --- a/omlx/memory_monitor.py +++ b/omlx/memory_monitor.py @@ -68,23 +68,40 @@ class MemoryMonitor: def __init__( self, - max_kv_cache_memory: int, + max_kv_cache_memory: int | None, check_interval: float = 1.0, + *, + eviction_enabled: bool = True, ): """ Initialize the memory monitor. Args: - max_kv_cache_memory: Maximum memory for KV cache in bytes (required). - This is the absolute limit for KV cache memory usage. + max_kv_cache_memory: Maximum memory for KV cache in bytes. + Required when ``eviction_enabled=True``. May be ``None`` + (or 0) when the monitor is used only for prefill-peak + estimation and no eviction/pressure decisions are made + against this limit. check_interval: Minimum seconds between memory checks (for throttling). - """ - if max_kv_cache_memory <= 0: + eviction_enabled: When False, ``max_kv_cache_memory`` is not + consulted and estimation methods that depend on it raise. + Set False on schedulers in paged-SSD-only mode where the + monitor exists solely for prefill-peak estimation. + """ + if eviction_enabled and ( + max_kv_cache_memory is None or max_kv_cache_memory <= 0 + ): raise ValueError( - f"max_kv_cache_memory must be positive, got {max_kv_cache_memory}" + "max_kv_cache_memory must be positive when " + f"eviction_enabled=True, got {max_kv_cache_memory}" ) - self._max_kv_cache_memory = max_kv_cache_memory + self._max_kv_cache_memory = max_kv_cache_memory or 0 + self._eviction_enabled = eviction_enabled + # Public accessor — callers (Scheduler._evict_blocks_*) need a way + # to skip the eviction code path without reaching into a private + # attribute and without triggering a RuntimeError from + # estimate_blocks_to_free(). self._check_interval = check_interval self._max_memory = self._get_max_memory() @@ -111,9 +128,15 @@ def __init__( self._running_requests: int = 0 self._waiting_requests: int = 0 - logger.info( - f"MemoryMonitor initialized: max_kv_cache={format_bytes(max_kv_cache_memory)}" - ) + if self._eviction_enabled: + logger.info( + "MemoryMonitor initialized: max_kv_cache=%s", + format_bytes(self._max_kv_cache_memory), + ) + else: + logger.info( + "MemoryMonitor initialized (estimator-only, eviction disabled)" + ) def _get_max_memory(self) -> int: """ @@ -370,9 +393,10 @@ def estimate_prefill_peak_bytes( Estimate per-request prefill peak memory contribution (KV + SDPA). Returns only the part directly attributable to this request's prefill: - newly allocated KV cache + SDPA attention activation peak for the last - chunk. Does NOT include model weights (already in active baseline) or - MLX cache pool / python heap overhead (absorbed by enforcer's hard + KV cache for the new tokens being added + SDPA attention activation + peak for the last chunk. Does NOT include model weights (already in + active baseline), prefix-cached KV that is already resident, or MLX + cache pool / python heap overhead (absorbed by enforcer's hard threshold margin — see MemorySettings.hard_threshold). MLX SDPA internals (C++ fallback path, head_dim > 128): @@ -381,6 +405,15 @@ def estimate_prefill_peak_bytes( 3. out = scores @ V → [B, n_q, chunk, head_dim] float32 GQA: K/V broadcast, no extra allocation. + Critical: ``kv_len`` for the last chunk is the full prompt + (``new_tokens + cached_tokens``), NOT just ``new_tokens``. Even + when most of the prompt is served from prefix cache, the SDPA + scores tensor spans the entire prompt because attention is + computed against the reconstructed KV. Passing only + ``new_tokens`` here silently under-counts the panic-prone path + — exactly the case where prefix cache hits make long-context + prefill possible at all. + MLX SDPA fused kernel (head_dim <= 128): Tiled computation, O(n) memory. Only output buffer allocated. @@ -392,12 +425,19 @@ def estimate_prefill_peak_bytes( overflows the MetalAllocator slips past the preflight guard. Args: - new_tokens: Tokens to be prefilled (prompt minus cached prefix). - Drives newly allocated KV. Also the last chunk's query length. - chunk_size: Prefill step size (default 2048). - cached_tokens: Tokens already resident in the prompt cache. Adds to - the SDPA kv_len span but not to newly allocated KV (cached KV is - already counted in the caller's `current` baseline). + new_tokens: Tokens being prefilled this request (prompt minus + what the prefix cache already covers). Drives newly + allocated KV and the last chunk's query length. + chunk_size: Prefill step size (default 2048). Effective chunk + is ``min(chunk_size, new_tokens)`` since the last chunk + cannot be larger than the remaining new tokens. + cached_tokens: Tokens served from prefix cache. Added to + ``new_tokens`` for the SDPA scores K-dim because those + positions still participate in attention. Keyword-only with + a default of 0 so callers that don't know the cache state + still typecheck — but they get the under-counting behavior + this method was designed to fix, so always pass it when the + value is available. Returns: Per-request peak contribution in bytes (KV + SDPA). Returns 0 if @@ -411,20 +451,30 @@ def estimate_prefill_peak_bytes( if n_q == 0 or hd == 0: return 0 # can't estimate - # Last chunk attends over the full context; query length is the last - # chunk size (capped at the new-token count for short suffixes). - attn_span = new_tokens + cached_tokens - query_len = min(chunk_size, new_tokens) + if new_tokens <= 0: + return 0 + + # Effective chunk: bounded by the remaining new tokens. Short + # prompts (smaller than chunk_size) would otherwise be charged the + # full chunk_size width in the scores tensor, over-estimating by + # chunk_size / new_tokens — a constant-factor over-count that + # raised false-positive 413s on small prompts. + eff_chunk = min(chunk_size, new_tokens) + full_kv_len = new_tokens + max(cached_tokens, 0) if hd > 128: # Fallback: full attention matrix materialized in float32 - # scores [B, n_q, query_len, attn_span] + output [B, n_q, query_len, hd] - attn = n_q * query_len * attn_span * 4 - attn += n_q * query_len * hd * 4 # output buffer (small) + # scores [B, n_q, eff_chunk, full_kv_len] + output + # [B, n_q, eff_chunk, hd] + attn = n_q * eff_chunk * full_kv_len * 4 + attn += n_q * eff_chunk * hd * 4 # output buffer (small) else: # Fused kernel: tiled, only output buffer - attn = n_q * query_len * hd * 4 + attn = n_q * eff_chunk * hd * 4 + # KV growth attributable to this request: only the new tokens. + # The cached portion is already counted via the baseline + # mx.get_active_memory() reading on the caller side. kv = self.estimate_prompt_kv_bytes(new_tokens) return attn + kv @@ -439,6 +489,11 @@ def estimate_blocks_to_free(self, bytes_to_free: int, block_size: int) -> int: Returns: Number of blocks to evict. """ + if not self._eviction_enabled: + raise RuntimeError( + "estimate_blocks_to_free called on a MemoryMonitor " + "constructed with eviction_enabled=False" + ) block_mem = self.estimate_block_memory(block_size) if block_mem <= 0: return 0 @@ -457,6 +512,18 @@ def max_kv_cache_memory(self) -> int: """Get maximum KV cache memory limit.""" return self._max_kv_cache_memory + @property + def eviction_enabled(self) -> bool: + """Whether this monitor was built with eviction wiring. + + Paged-SSD-only mode passes ``eviction_enabled=False`` because + the SDPA-peak / prefill-admission paths don't need KV eviction + math. Callers (Scheduler._evict_blocks_*) check this before + calling ``estimate_blocks_to_free``, which would otherwise + raise ``RuntimeError``. + """ + return self._eviction_enabled + def get_stats(self) -> dict: """ Get memory statistics as a dictionary. diff --git a/omlx/process_memory_enforcer.py b/omlx/process_memory_enforcer.py index 7d11fee31..ae86c0d57 100644 --- a/omlx/process_memory_enforcer.py +++ b/omlx/process_memory_enforcer.py @@ -336,6 +336,11 @@ def __init__( # Most recently observed pressure level, consumed by scheduler / # admission control. Updated on every poll iteration. self._pressure_level: str = "ok" + # Engine types we've already complained about in + # _propagate_memory_limit's "scheduler unreachable" path. Prevents + # the per-poll warning from spamming logs while keeping the first + # occurrence loud enough to alert CI / oncall. + self._scheduler_resolve_warned: set[str] = set() # Last value passed to mx.set_wired_limit (0 if not yet applied # or the call failed). Used by the admin dashboard to surface a # warning when the kernel iogpu.wired_limit_mb is below this. @@ -548,51 +553,72 @@ def prefill_memory_guard(self, value: bool) -> None: logger.info(f"Prefill memory guard: {'enabled' if value else 'disabled'}") @staticmethod - def _resolve_scheduler(entry: Any) -> Any | None: - """Resolve the Scheduler instance from an EnginePool entry. - - Most engines (BatchedEngine, VLMBatchedEngine) wrap the scheduler - as ``entry.engine._engine.engine.scheduler`` (AsyncEngineCore → - EngineCore → Scheduler). Some non-streaming engines may expose - ``entry.engine.scheduler`` directly. Returns None if neither - path resolves. + def _resolve_scheduler(engine): + """Return the real Scheduler instance for an EnginePool entry. + + Both BatchedEngine and VLMBatchedEngine in the live engine pool + store the scheduler at ``self._engine.engine.scheduler`` (the outer + wrapper holds an AsyncEngineCore at ``_engine`` whose ``.engine`` + is the EngineCore that actually owns the scheduler). Neither + exposes a top-level ``.scheduler`` attribute, so the previous + ``getattr(engine, "scheduler", None)`` always returned None for + real engines and the propagation silently no-op'd — including the + prefill memory guard flag, which meant the guard was dead at + runtime regardless of the user's setting. Test mocks set + ``.scheduler`` directly, so the wrapper-traversal fallback only + kicks in for real engines. """ - eng = entry.engine - if eng is None: - return None - sched = getattr(eng, "scheduler", None) + sched = getattr(engine, "scheduler", None) if sched is not None: return sched - inner = getattr(eng, "_engine", None) + inner = getattr(engine, "_engine", None) if inner is None: return None - inner_engine = getattr(inner, "engine", None) - if inner_engine is None: - return None - return getattr(inner_engine, "scheduler", None) + return getattr(getattr(inner, "engine", None), "scheduler", None) def _propagate_memory_limit(self) -> None: """Propagate ceiling-derived watermarks to all schedulers. Called on every enforcer tick so the dynamic ceiling reaches the - schedulers as fast as the poll interval allows. + schedulers as fast as the poll interval allows. Iterates a + ``list(values())`` snapshot so a future refactor that moves an + EnginePool mutator off the loop cannot silently miss an engine + and regress the dead-guard bug this method exists to fix. """ ceiling = self._get_hard_limit_bytes() soft_limit = int(ceiling * self._soft_threshold) if ceiling > 0 else 0 admission_paused = self._pressure_level != "ok" - for entry in self._engine_pool._entries.values(): - scheduler = self._resolve_scheduler(entry) - if scheduler is not None: - scheduler._memory_limit_bytes = soft_limit - scheduler._memory_hard_limit_bytes = ceiling - scheduler._prefill_memory_guard = self._prefill_memory_guard - scheduler._admission_paused = admission_paused - scheduler._prefill_safe_zone_ratio = self._prefill_safe_zone_ratio - scheduler._prefill_min_chunk_tokens = self._prefill_min_chunk_tokens - bg = getattr(scheduler, "batch_generator", None) - if bg is not None and hasattr(bg, "_memory_limit_bytes"): - bg._memory_limit_bytes = soft_limit - bg._memory_hard_limit_bytes = ceiling + for entry in list(self._engine_pool._entries.values()): + if entry.engine is None: + continue + scheduler = self._resolve_scheduler(entry.engine) + if scheduler is None: + # Rate-limited per-engine-type so a wrapper-chain + # change is loud once instead of every poll. Silent + # no-op was the failure mode that originally hid the + # dead memory guard — surface it now. + engine_type = type(entry.engine).__name__ + if engine_type not in self._scheduler_resolve_warned: + self._scheduler_resolve_warned.add(engine_type) + logger.warning( + "ProcessMemoryEnforcer: could not resolve " + "scheduler for engine type %s — prefill memory " + "guard will not propagate to this engine. " + "Verify the wrapper chain " + "(engine._engine.engine.scheduler) still holds.", + engine_type, + ) + continue + scheduler._memory_limit_bytes = soft_limit + scheduler._memory_hard_limit_bytes = ceiling + scheduler._prefill_memory_guard = self._prefill_memory_guard + scheduler._admission_paused = admission_paused + scheduler._prefill_safe_zone_ratio = self._prefill_safe_zone_ratio + scheduler._prefill_min_chunk_tokens = self._prefill_min_chunk_tokens + bg = getattr(scheduler, "batch_generator", None) + if bg is not None and hasattr(bg, "_memory_limit_bytes"): + bg._memory_limit_bytes = soft_limit + bg._memory_hard_limit_bytes = ceiling def _walk_store_cache_caps(self) -> None: """Walk each scheduler's store-cache gate one step per poll (#1383). @@ -603,8 +629,10 @@ def _walk_store_cache_caps(self) -> None: `_propagate_memory_limit` to avoid double-stepping the cap when a transition fires. """ - for entry in self._engine_pool._entries.values(): - scheduler = self._resolve_scheduler(entry) + for entry in list(self._engine_pool._entries.values()): + if entry.engine is None: + continue + scheduler = self._resolve_scheduler(entry.engine) if scheduler is None: continue adjust = getattr(scheduler, "adjust_store_cache_cap", None) @@ -683,8 +711,6 @@ async def _check_and_enforce(self) -> None: new_level = "hard" if new_level != prev_level: - self._pressure_level = new_level - self._propagate_memory_limit() logger.info( f"Memory pressure level: {prev_level} -> {new_level} " f"(current={_format_gb(current)}, " @@ -693,6 +719,16 @@ async def _check_and_enforce(self) -> None: ) if new_level == "ok": + # Reflect the recovery so admission unpauses and the dashboard + # status reads "ok" again. Without this, the eviction branch + # below is the only path that updates ``_pressure_level``, so a + # soft → ok transition leaves the level stuck at "soft". The + # re-propagation pushes ``admission_paused=False`` down to the + # schedulers — the call at the top of this method ran while + # ``_pressure_level`` was still the prior (soft / hard) value. + if new_level != prev_level: + self._pressure_level = "ok" + self._propagate_memory_limit() # Still walk the store-cache cap so it can recover toward # max_num_seqs while pressure stays low (#1383). self._walk_store_cache_caps() diff --git a/omlx/scheduler.py b/omlx/scheduler.py index e260790e4..cc770e245 100644 --- a/omlx/scheduler.py +++ b/omlx/scheduler.py @@ -40,7 +40,7 @@ from .cache.observability import CacheRateTracker from .cache.paged_cache import PagedCacheManager from .cache.prefix_cache import BlockAwarePrefixCache -from .exceptions import is_cache_corruption_error +from .exceptions import PrefillMemoryExceededError, is_cache_corruption_error from .prefill_progress import get_prefill_tracker from .prefill_transient_tracker import PrefillTransientTracker from .request import Request, RequestOutput, RequestStatus, SamplingParams @@ -692,6 +692,20 @@ def __bool__(self) -> bool: return bool(self._valid_tcs) +@dataclass(frozen=True) +class _PreflightRejection: + """Structured rejection returned by ``_preflight_memory_check_tokens``. + + Carrying the numeric values lets callers populate + ``PrefillMemoryExceededError.estimated_bytes`` / ``limit_bytes`` + cleanly instead of parsing the human-readable message. + """ + + message: str + estimated_bytes: int + limit_bytes: int + + class Scheduler: """ Scheduler for continuous batching using mlx-lm BatchGenerator. @@ -789,12 +803,17 @@ def __init__( # Memory limits for inline prefill checking. # Set by ProcessMemoryEnforcer; propagated to BatchGenerator. + # Both limits are gated by ProcessMemoryEnforcer.max_bytes — the + # user-configured max_process_memory ceiling. The hard limit is + # the absolute reject/abort threshold (preflight + in-flight mid- + # prefill checks at ``_do_external_prefill`` / ``_step_prefill_chunk`` + # both compare current_usage + peak against this value). self._memory_limit_bytes: int = 0 # soft limit - self._memory_hard_limit_bytes: int = 0 # hard limit (system_ram - 4GB) + self._memory_hard_limit_bytes: int = 0 # hard limit self._prefill_memory_guard: bool = False # set by ProcessMemoryEnforcer # Set to True by ProcessMemoryEnforcer when phys_footprint crosses - # soft_threshold. Schedulers stop admitting new prefills while this is - # set; in-flight requests proceed. + # soft_threshold. Schedulers stop admitting new prefills while this + # is set; in-flight requests proceed. self._admission_paused: bool = False # Adaptive prefill throttle params, propagated from enforcer. # Until set, _adaptive_chunk_size is a no-op (returns requested as-is). @@ -869,7 +888,20 @@ def __init__( self.block_aware_cache: BlockAwarePrefixCache | None = None self.paged_ssd_cache_manager: PagedSSDCacheManager | None = None self._cache_rate_tracker = CacheRateTracker() - self.memory_monitor: MemoryMonitor | None = None + # Prefill-peak estimator used by _preflight_memory_check. Only + # the estimator path is exercised here (it reads head_dim / + # num_attention_heads / num_kv_cache_layers populated by + # _set_model_info_for_monitor()). The other MemoryMonitor methods + # — estimate_blocks_to_free, _check_memory_pressure — are dormant + # in paged-SSD-only mode. ``eviction_enabled=False`` makes that + # explicit: any future caller that wires eviction back up will + # fail loudly here rather than silently using a placeholder + # max_kv_cache_memory. + self.memory_monitor: MemoryMonitor | None = MemoryMonitor( + max_kv_cache_memory=None, + eviction_enabled=False, + ) + self._set_model_info_for_monitor() # Initialize paged SSD cache if paged_ssd_cache_dir is specified if self.config.paged_ssd_cache_dir: @@ -1949,13 +1981,30 @@ def _do_external_prefill( self._memory_hard_limit_bytes > 0 and current > self._memory_hard_limit_bytes ): - logger.warning( + msg = ( f"Prefill force-stopped at {processed_tokens} " f"tokens: memory {current / 1024**3:.1f}GB " f"exceeds ceiling " f"{self._memory_hard_limit_bytes / 1024**3:.1f}GB" ) - raise RuntimeError("Memory limit exceeded during prefill") + logger.warning(msg) + # Raise the typed exception so the FastAPI 413 + # handler can map it cleanly. The pre-admission + # preflight is the primary guard; this in-flight + # check is the race-safety net when memory shifted + # between admission and prefill. + from .exceptions import PrefillMemoryExceededError + + raise PrefillMemoryExceededError( + message=msg, + request_id=( + getattr(request, "request_id", None) + if "request" in locals() + else None + ), + estimated_bytes=current, + limit_bytes=self._memory_hard_limit_bytes, + ) elif current > self._memory_limit_bytes: logger.warning( f"Prefill above max_bytes at " @@ -2291,12 +2340,23 @@ def _step_prefill_chunk(self, state: _PrefillState) -> bool: self._memory_hard_limit_bytes > 0 and current > self._memory_hard_limit_bytes ): - raise RuntimeError( + msg = ( f"Memory limit exceeded during chunked prefill at " f"{state.tokens_processed}/{state.total_length - 1} tokens: " f"{current / 1024**3:.1f}GB exceeds ceiling " f"{self._memory_hard_limit_bytes / 1024**3:.1f}GB" ) + # See _do_external_prefill's identical check: race-safety + # net for the case where memory shifted between admission + # and prefill. Typed exception → HTTP 413. + from .exceptions import PrefillMemoryExceededError + + raise PrefillMemoryExceededError( + message=msg, + request_id=state.request.request_id, + estimated_bytes=current, + limit_bytes=self._memory_hard_limit_bytes, + ) elif current > self._memory_limit_bytes: logger.warning( f"Chunked prefill above max_bytes at " @@ -2417,25 +2477,9 @@ def _advance_chunked_prefills( # be fully processed by _process_pending_aborts() next step. self._prefill_states.pop(rid, None) continue - except RuntimeError as e: + except (RuntimeError, PrefillMemoryExceededError) as e: logger.error("Chunked prefill failed for %s: %s", rid, e) - self._prefill_states.pop(rid, None) - self.requests.pop(rid, None) - get_prefill_tracker().remove(rid) - # Drop Metal cache pool buffers held by the aborted chunk's - # forward / mx.eval transients. Without this, enforcer keeps - # seeing the burst footprint until the next mx.clear_cache(). - _sync_and_clear_cache() - # Surface the failure to the engine. Without this, the - # request is silently dropped and the client hangs. - rejected.append( - RequestOutput( - request_id=rid, - finished=True, - finish_reason="error", - error=str(e), - ) - ) + self._fail_prefill_request(rid, e, rejected) continue if not done: @@ -3553,9 +3597,18 @@ def add_request(self, request: Request) -> None: """ Add a new request to the scheduler. - Raises SchedulerQueueFullError when the waiting queue is at or above - the configured cap (max(max_num_seqs * 4, 32)). Server layer maps - this to HTTP 503 + Retry-After. + Raises: + - ``SchedulerQueueFullError`` when the waiting queue is at or + above the configured cap (max(max_num_seqs * 4, 32)). Server + layer maps this to HTTP 503 + Retry-After. + - ``PrefillMemoryExceededError`` when the preflight memory + check rejects the request. Server layer maps this to HTTP + 413. The rejection runs AFTER admission preprocessing + (tokenisation, prefix-cache lookup, block-table acquisition, + SpecPrefill scoring) but BEFORE ``self.waiting.append`` — any + state allocated during preprocessing (block-table refs, prefix + cache reservations) is rolled back on the raise path so the + rejection does not leak resources. Args: request: The request to add @@ -3693,6 +3746,43 @@ def add_request(self, request: Request) -> None: # Must run AFTER prefix cache check (scoring applies only to uncached suffix). self._try_specprefill_scoring(request) + # Synchronous prefill memory guard. Rejecting here (before append to + # self.waiting) means the request never enters MLX prefill, which is + # the path that triggers the Apple IOGPUFamily kernel bug + # (FB22091885 / ml-explore/mlx#3186). The _schedule_waiting() call + # still re-checks asynchronously as a race-safety net for cases where + # memory conditions change between add_request and scheduling. + # + # The HTTP layer runs ``preflight_or_raise`` before wrapping the + # response in a StreamingResponse so the 413 reaches the client + # cleanly. This synchronous in-add_request check is the + # defense-in-depth path for callers that bypass the server + # preflight (direct engine API, future endpoints). + rejection = self._preflight_memory_check(request) + if rejection is not None: + # Prefix-cache / SpecPrefill lookups above may have bumped + # block refs and primed the draft prefix cache; release + # both before raising so a rejection storm can't pin paged + # cache state. + self._release_paged_cache_for_request(request.request_id) + + logger.warning( + f"Request {request.request_id} rejected by prefill memory " + f"guard (sync): {rejection.message}" + ) + try: + from .server_metrics import get_server_metrics + + get_server_metrics().record_preflight_rejection("hard_limit") + except Exception: + pass + raise PrefillMemoryExceededError( + message=rejection.message, + request_id=request.request_id, + estimated_bytes=rejection.estimated_bytes, + limit_bytes=rejection.limit_bytes, + ) + # Add to tracking self.requests[request.request_id] = request self.waiting.append(request) @@ -4526,57 +4616,149 @@ def get_num_running(self) -> int: """Get number of running requests.""" return len(self.running) - def _preflight_memory_check(self, request: "Request") -> str | None: - """ - Estimate whether prefill would exceed memory limits. - - Computes worst-case peak memory for the last prefill chunk - (model weights + KV cache + SDPA attention matrix) and rejects - if it would exceed the hard limit. - - For head_dim > 128, MLX SDPA uses a fallback that materializes - the full attention matrix [B, n_q, chunk, kv_len] in float32. - For head_dim <= 128, MLX uses a fused kernel with O(n) memory. - - Returns: - Error message string if request should be rejected, None if OK. + def _preflight_memory_check_tokens( + self, num_prompt_tokens: int, cached_tokens: int = 0 + ) -> "_PreflightRejection | None": + """Token-count form of the prefill memory guard — see + ``_preflight_memory_check`` for the rejection rationale. + + Decoupled from ``Request`` so the server layer can run an early + admission check immediately after tokenization, before wrapping + the response in a ``StreamingResponse`` (whose + ``http.response.start`` lands before any route-handler exception + can adjust the status code, locking the client to HTTP 200). + + Returns a ``_PreflightRejection`` carrying the diagnostic + message, estimated peak bytes, and the hard limit bytes if + rejection is warranted, or None if the request fits. Returning a + structured value lets callers populate + ``PrefillMemoryExceededError.estimated_bytes`` / ``limit_bytes`` + without parsing the human-readable string. + + Both fields are written from a single ProcessMemoryEnforcer + poll tick under the GIL, so the (guard, hard_limit) pair is + consistent for current CPython. See + ``ProcessMemoryEnforcer._propagate_memory_limit``. """ if not self._prefill_memory_guard: return None - if self._memory_hard_limit_bytes <= 0: + hard_limit = self._memory_hard_limit_bytes + if hard_limit <= 0: return None if self.memory_monitor is None: return None - prompt_tokens = request.num_prompt_tokens - cached_tokens = request.cached_tokens or 0 - new_tokens = max(prompt_tokens - cached_tokens, 0) - + new_tokens = max(num_prompt_tokens - cached_tokens, 0) if new_tokens == 0: return None peak = self.memory_monitor.estimate_prefill_peak_bytes( - new_tokens, self.config.prefill_step_size, cached_tokens=cached_tokens + new_tokens, + self.config.prefill_step_size, + cached_tokens=cached_tokens, ) if peak == 0: return None # can't estimate, skip current = max(mx.get_active_memory(), get_phys_footprint()) + estimated = current + peak - if current + peak > self._memory_hard_limit_bytes: + if estimated > hard_limit: from .utils.hardware import format_bytes usage_gb = current / (1024**3) - ceiling_gb = self._memory_hard_limit_bytes / (1024**3) - return ( - f"Prefill would require ~{format_bytes(current + peak)} peak " + ceiling_gb = hard_limit / (1024**3) + msg = ( + f"Prefill would require ~{format_bytes(estimated)} peak " f"(current {format_bytes(current)} + KV+SDPA {format_bytes(peak)}) " - f"but ceiling is {format_bytes(self._memory_hard_limit_bytes)} " + f"but ceiling is {format_bytes(hard_limit)} " f"(usage {usage_gb:.1f} GB, ceiling {ceiling_gb:.1f} GB). " f"Reduce context length or lower memory_guard_tier." ) + return _PreflightRejection( + message=msg, + estimated_bytes=estimated, + limit_bytes=hard_limit, + ) return None + def preflight_or_raise( + self, + num_prompt_tokens: int, + cached_tokens: int = 0, + request_id: str | None = None, + ) -> None: + """Run the prefill memory check and raise PrefillMemoryExceededError + on rejection. No-op when the guard is disabled or the request + would fit. + + Called from the API server layer BEFORE the response is wrapped + in a StreamingResponse, so the typed exception can be mapped to + HTTP 413 by the registered FastAPI handler. The synchronous + re-check inside ``add_request`` and the async re-check inside + ``_schedule_waiting`` remain as defense-in-depth for callers + that bypass the server preflight (direct engine API, future + endpoints). + """ + rej = self._preflight_memory_check_tokens(num_prompt_tokens, cached_tokens) + if rej is None: + return + from .exceptions import PrefillMemoryExceededError + + # Stable, unique label per rejection — caller-supplied if + # available, otherwise a short uuid so operators can correlate + # the log line below with the FastAPI handler trace and the + # client-side error body. Avoids the prior default of literal + # "preflight" for every rejection, which was useless for tracing. + if not request_id: + import uuid as _uuid + + request_id = f"preflight-{_uuid.uuid4().hex[:8]}" + logger.warning( + f"Preflight rejected ({num_prompt_tokens} tokens, " + f"cached={cached_tokens}, request_id={request_id}): " + f"{rej.message}" + ) + try: + from .server_metrics import get_server_metrics + + get_server_metrics().record_preflight_rejection("hard_limit") + except Exception: + pass + raise PrefillMemoryExceededError( + message=rej.message, + request_id=request_id, + estimated_bytes=rej.estimated_bytes, + limit_bytes=rej.limit_bytes, + ) + + def _preflight_memory_check( + self, request: "Request" + ) -> "_PreflightRejection | None": + """ + Estimate whether prefill would exceed memory limits. + + Computes worst-case peak memory for the last prefill chunk + (model weights + KV cache + SDPA attention matrix) and rejects + if it would exceed the hard limit. + + For head_dim > 128, MLX SDPA uses a fallback that materializes + the full attention matrix [B, n_q, chunk, kv_len] in float32. + For head_dim <= 128, MLX uses a fused kernel with O(n) memory. + + Delegates to ``_preflight_memory_check_tokens`` which reads + the (guard, hard_limit) pair directly off the scheduler. See + ``ProcessMemoryEnforcer._propagate_memory_limit``. + + Returns: + ``_PreflightRejection`` if the request should be rejected, + None if OK. + """ + return self._preflight_memory_check_tokens( + num_prompt_tokens=request.num_prompt_tokens, + cached_tokens=request.cached_tokens or 0, + ) + def _schedule_waiting( self, ) -> tuple[list["Request"], list[RequestOutput]]: @@ -4738,19 +4920,32 @@ def _schedule_waiting( # Pre-flight memory guard: estimate peak memory for this request # and reject if it would exceed the hard limit. - preflight_error = self._preflight_memory_check(request) - if preflight_error: + preflight_rejection = self._preflight_memory_check(request) + if preflight_rejection is not None: logger.warning( f"Request {request.request_id} rejected by prefill " - f"memory guard: {preflight_error}" + f"memory guard: {preflight_rejection.message}" ) + self._release_paged_cache_for_request(request.request_id) self.requests.pop(request.request_id, None) + # Best-effort metric — guarded by try so a missing + # server_metrics module (e.g. embedded scheduler tests + # constructing the scheduler without the FastAPI app) + # doesn't break the rejection path. + try: + from .server_metrics import get_server_metrics + + get_server_metrics().record_preflight_rejection( + "hard_limit" + ) + except Exception: + pass rejected_outputs.append( RequestOutput( request_id=request.request_id, finished=True, finish_reason="error", - error=preflight_error, + error=preflight_rejection.message, ) ) continue @@ -5007,30 +5202,15 @@ def _sparse_progress(processed: int, total: int) -> None: done = self._step_prefill_chunk(state) except _PrefillAbortedError: raise - except RuntimeError as e: - # Hard memory limit hit on the first chunk. - # _step_prefill_chunk updates the PrefillProgressTracker - # before the limit check, so without this catch the - # tracker entry leaks and stays in the dashboard - # forever (#1405). Mirrors the cleanup in - # _advance_chunked_prefills (d736bfd). + except (RuntimeError, PrefillMemoryExceededError) as e: logger.error( - "Chunked prefill (first chunk) failed for %s: %s", + "Chunked prefill (first chunk) failed for " + "%s: %s", request.request_id, e, ) - self.requests.pop(request.request_id, None) - get_prefill_tracker().remove(request.request_id) - # Drop Metal cache pool buffers held by the aborted - # first chunk's forward / mx.eval transients. - _sync_and_clear_cache() - rejected_outputs.append( - RequestOutput( - request_id=request.request_id, - finished=True, - finish_reason="error", - error=str(e), - ) + self._fail_prefill_request( + request.request_id, e, rejected_outputs ) continue @@ -5059,29 +5239,17 @@ def _sparse_progress(processed: int, total: int) -> None: cache_to_use, vlm_embeds=vlm_embeds, ) - except RuntimeError as e: - # Hard memory limit hit during external prefill. Without - # this catch, the exception bubbles up to step() and then - # engine_core's fail_all_requests(), which pops - # self.requests but cannot reach the PrefillProgressTracker - # singleton, so the dashboard entry leaks across model - # reload (#1405). Mirrors the cleanup in - # _advance_chunked_prefills (d736bfd). - logger.error("Prefill failed for %s: %s", request.request_id, e) - self.uid_to_request_id.pop(temp_uid, None) - self.request_id_to_uid.pop(request.request_id, None) - self.requests.pop(request.request_id, None) - get_prefill_tracker().remove(request.request_id) - # Drop Metal cache pool buffers held by the aborted - # chunk's forward / mx.eval transients. - _sync_and_clear_cache() - rejected_outputs.append( - RequestOutput( - request_id=request.request_id, - finished=True, - finish_reason="error", - error=str(e), - ) + except (RuntimeError, PrefillMemoryExceededError) as e: + logger.error( + "Non-chunked prefill failed for %s: %s", + request.request_id, + e, + ) + self._fail_prefill_request( + request.request_id, + e, + rejected_outputs, + temp_uid=temp_uid, ) continue @@ -5433,6 +5601,65 @@ def _process_batch_responses( return outputs, finished_ids + def _fail_prefill_request( + self, + request_id: str, + error: BaseException, + rejected_outputs: list, + *, + temp_uid: int | None = None, + ) -> None: + """Tear down all per-request state for a prefill that raised + before insert, and append a rejected ``RequestOutput`` so the + engine surfaces the error to the consumer instead of silently + dropping the request. + """ + self._prefill_states.pop(request_id, None) + if temp_uid is not None: + self.uid_to_request_id.pop(temp_uid, None) + self.request_id_to_uid.pop(request_id, None) + self._release_paged_cache_for_request(request_id) + self.requests.pop(request_id, None) + get_prefill_tracker().remove(request_id) + # Drop Metal cache pool buffers held by the aborted chunk's + # forward / mx.eval transients. Without this, enforcer keeps + # seeing the burst footprint until the next mx.clear_cache(). + _sync_and_clear_cache() + rejected_outputs.append( + RequestOutput( + request_id=request_id, + finished=True, + finish_reason="error", + error=str(error), + ) + ) + + def _release_paged_cache_for_request(self, request_id: str) -> None: + """Drop a request's paged-cache footprint on the rejection paths. + + ``add_request`` routes through ``block_aware_cache.fetch_cache`` + which increments ref counts on every prefix-matched block and + creates a ``block_table`` in the paged cache. The normal + completion path releases that state in ``_cleanup_finished``; + the prefill-rejection paths must do the same or rejected + requests leak block refs (pinning the paged cache and + compounding the very memory pressure that triggered the + rejection) and orphan ``request_tables`` entries. + + When SpecPrefill is configured, ``_try_specprefill_scoring`` + also primes the draft prefix cache via its own ``fetch_cache`` + which lives in an independent ``_request_tables`` and paged + block pool; release that too so the rejection symmetry holds + on both caches. + """ + if self.block_aware_cache is not None: + self.block_aware_cache.release_cache(request_id) + elif self.paged_cache_manager is not None: + self.paged_cache_manager.delete_block_table(request_id) + draft_cache = getattr(self, "_draft_prefix_cache", None) + if draft_cache is not None: + draft_cache.release_cache(request_id) + def _cleanup_finished(self, finished_ids: set[str]) -> None: """Clean up finished requests and store caches for reuse.""" # Synchronize pending engine stream operations before cache storage. @@ -6209,6 +6436,26 @@ def _set_model_info_for_monitor(self) -> None: logger.debug("Could not extract model config for memory estimation") return + # VLM / multimodal configs (e.g. Qwen3.6-VL, Gemma-4) nest the + # language-model dimensions under a sub-config. Prefer + # text_config / language_config / llm_config when ANY of them + # exposes the LM layer count, even if the top-level config also + # has one — on some VLM packs (older Gemma-3, certain Llava / HF + # auto-wrappers) the top-level field refers to the *vision + # encoder*, not the LM, and accepting it silently miscalibrates + # the SDPA-peak estimate by a constant factor. Probe both + # ``num_hidden_layers`` and the legacy ``n_layer`` alias so a + # GPT-style nested config is also picked up. Falls back to the + # top-level config only when no sub-config has either field. + for sub_attr in ("text_config", "language_config", "llm_config"): + sub = getattr(config, sub_attr, None) + if sub is not None and ( + getattr(sub, "num_hidden_layers", None) + or getattr(sub, "n_layer", None) + ): + config = sub + break + # Extract KV cache dimensions num_layers = getattr(config, "num_hidden_layers", None) or getattr( config, "n_layer", None @@ -6381,6 +6628,12 @@ def _evict_blocks_permanently(self, bytes_to_free: int) -> int: """ if self.paged_cache_manager is None or self.memory_monitor is None: return 0 + # Dormant in paged-SSD-only mode: the MemoryMonitor is constructed + # with eviction_enabled=False so estimate_blocks_to_free would + # raise. Return 0 cleanly until a future paged-SSD eviction path + # rewires real KV-cache budget into the monitor. + if not self.memory_monitor.eviction_enabled: + return 0 # Estimate how many blocks to evict block_size = self.config.paged_cache_block_size @@ -6434,6 +6687,9 @@ def _evict_blocks_to_cold(self, bytes_to_free: int) -> int: if self.memory_monitor is None: return 0 + if not self.memory_monitor.eviction_enabled: + # See _evict_blocks_permanently — dormant in paged-SSD-only mode. + return 0 # Estimate how many blocks to evict block_size = self.config.paged_cache_block_size diff --git a/omlx/server.py b/omlx/server.py index 35f0893c9..e10009d6e 100644 --- a/omlx/server.py +++ b/omlx/server.py @@ -165,6 +165,7 @@ ModelLoadingError, ModelNotFoundError, ModelTooLargeError, + PrefillMemoryExceededError, SchedulerQueueFullError, ) from .model_discovery import format_size @@ -459,6 +460,12 @@ def _status_to_error_type(status_code: int) -> str: return "authentication_error" if status_code == 404: return "not_found_error" + if status_code == 413: + # Prefill-memory-guard rejection. OpenAI uses + # invalid_request_error for context-window-exceeded as well; we + # use the finer-grained 413 status but the same type so existing + # clients that branch on type still recognise the failure mode. + return "invalid_request_error" if status_code == 429: return "rate_limit_error" if status_code >= 500: @@ -467,7 +474,16 @@ def _status_to_error_type(status_code: int) -> str: def _is_api_route(request: FastAPIRequest) -> bool: - """Check if request targets an OpenAI-compatible API route.""" + """Check if request targets an OpenAI-compatible API route. + + Path-prefix only. This assumes the FastAPI app is mounted at root + (the oMLX deployment shape) and that route paths are case-sensitive + — both true today. If a future deployment mounts this app under a + prefix (``app.mount("/api", ...)``), ``request.url.path`` returns + the full mounted path and every ``/v1/...`` route would be + classified as non-API. Switch to ``request.scope.get("route")`` + matching at that point. + """ return request.url.path.startswith("/v1/") @@ -560,6 +576,60 @@ async def scheduler_queue_full_handler( ) +@app.exception_handler(PrefillMemoryExceededError) +async def prefill_memory_exceeded_handler( + request: FastAPIRequest, exc: PrefillMemoryExceededError +): + """Map prefill peak overshoot to HTTP 413 with a clean JSON body. + + The synchronous prefill memory guard in ``Scheduler.add_request`` raises + this when the estimated KV+SDPA peak for a request would push memory + past the user-configured ``max_process_memory``. The caller's prompt + fits in the model's context window but is too large for the host's + headroom, so 413 (Payload Too Large) is the right code. + + Code vs status trade-off: 413 (Payload Too Large) is a request-shape + error; 507 (Insufficient Storage) more accurately describes "host + can't service it right now". We use 413 because it maps cleanly to + the OpenAI SDK retry-with-smaller-input flow that most clients + already implement, and the ``error.code`` field gives consumers a + machine-readable discriminator (``"prefill_memory_exceeded"``) + distinct from genuine context-window-exceeded rejections. + """ + detail = str(exc) + logger.warning( + "%s %s → 413: %s", + request.method, + request.url.path, + detail, + ) + if _is_api_route(request): + # code="prefill_memory_exceeded" lets OpenAI-SDK clients branch + # on the failure mode. Without it, "context window too small" + # and "host has no memory headroom" both surface as + # invalid_request_error with code=None and clients can only + # tell the user "shorten your prompt" — which is wrong when + # the actual fix is to raise --max-process-memory. + content = _openai_error_body( + detail, 413, code="prefill_memory_exceeded" + ) + # Surface the structured fields so clients can branch on + # numeric values instead of regex-matching the human message. + # OpenAI clients ignore unknown error fields so this is a + # forward-compatible extension. + if exc.estimated_bytes is not None: + content["error"]["estimated_bytes"] = exc.estimated_bytes + if exc.limit_bytes is not None: + content["error"]["limit_bytes"] = exc.limit_bytes + else: + content = {"detail": detail} + if exc.estimated_bytes is not None: + content["estimated_bytes"] = exc.estimated_bytes + if exc.limit_bytes is not None: + content["limit_bytes"] = exc.limit_bytes + return JSONResponse(status_code=413, content=content) + + @app.exception_handler(Exception) async def unhandled_exception_handler(request: FastAPIRequest, exc: Exception): """Log unhandled exceptions as 500 errors.""" @@ -1975,6 +2045,17 @@ async def create_completion( num_tokens = len(engine.tokenizer.encode(prompt)) validate_context_window(num_tokens, request.model) + # Pre-flight prefill memory guard — see create_chat_completion for + # the reason this must precede any StreamingResponse return. + # Thread the client-provided X-Request-ID when present so the 413 + # log line and the FastAPI handler trace correlate with whatever + # the client is using on its side. + upstream_request_id = http_request.headers.get("x-request-id") + for prompt in prompts: + await engine.preflight_completion( + prompt, request_id=upstream_request_id + ) + if request.stream: return StreamingResponse( _with_sse_keepalive( @@ -2313,6 +2394,19 @@ async def create_chat_completion( if request.stop: chat_kwargs["stop"] = request.stop + # Pre-flight prefill memory guard. Must run BEFORE either branch wraps + # the response in a StreamingResponse — starlette emits + # http.response.start (status 200) before iterating the body generator, + # so a typed exception thrown later by add_request lands as "Caught + # handled exception, but response already started" and the client sees + # an incomplete chunked read. Running the check here lets + # prefill_memory_exceeded_handler return a clean HTTP 413. + await engine.preflight_chat( + messages, + request_id=http_request.headers.get("x-request-id"), + **chat_kwargs, + ) + if request.stream: return StreamingResponse( _with_sse_keepalive( @@ -2553,6 +2647,8 @@ def _compile_with_structural_tag(compiler, fmt: dict, reasoning_parser: str, protocol structure (thinking tags, channel markers, etc.) and patches the user's grammar into the output slot. """ + from omlx._torch_stub import install as _install_torch_stub + _install_torch_stub() import xgrammar as xgr reasoning = not ( @@ -2615,17 +2711,20 @@ def _compile_grammar_for_request( from omlx.utils.install import get_install_method method = get_install_method() - if method == "dmg": - detail = ( - "Structured output is not available in the DMG version. " - "xgrammar requires torch which significantly increases app size. " - "Use the pip or Homebrew version for structured output support." - ) - elif method == "homebrew": + if method == "homebrew": detail = ( "Structured output requires xgrammar. " "Reinstall with: brew reinstall omlx --with-grammar" ) + elif method == "dmg": + # DMG bundles xgrammar with a torch stub; reaching this + # branch means the bundled load failed (e.g. native binding + # incompatibility). Surface it instead of pointing users to + # a different install method. + detail = ( + "Structured output is unavailable: xgrammar failed to " + "load in this build. Please report this issue." + ) else: detail = ( "Structured output requires xgrammar. " @@ -3617,6 +3716,14 @@ async def create_anthropic_message( if request.stop_sequences: chat_kwargs["stop"] = request.stop_sequences + # Pre-flight prefill memory guard — must precede any StreamingResponse + # return so PrefillMemoryExceededError can be mapped to HTTP 413. + await engine.preflight_chat( + messages, + request_id=http_request.headers.get("x-request-id"), + **chat_kwargs, + ) + if request.stream: return StreamingResponse( _with_sse_keepalive( @@ -4017,6 +4124,14 @@ async def create_response( if merged_ct_kwargs: chat_kwargs["chat_template_kwargs"] = merged_ct_kwargs + # Pre-flight prefill memory guard — must precede any StreamingResponse + # return so PrefillMemoryExceededError can be mapped to HTTP 413. + await engine.preflight_chat( + messages, + request_id=http_request.headers.get("x-request-id"), + **chat_kwargs, + ) + if request.stream: return StreamingResponse( _with_sse_keepalive( diff --git a/omlx/server_metrics.py b/omlx/server_metrics.py index 6a1ea5139..e0b1be29f 100644 --- a/omlx/server_metrics.py +++ b/omlx/server_metrics.py @@ -46,6 +46,17 @@ def __init__(self, stats_path: Optional[Path] = None): self.total_generation_duration: float = 0.0 self._per_model: Dict[str, Dict[str, Any]] = {} + # Preflight rejections — split by reason so operators can tell + # "user fed an oversize prompt" (hard_limit) apart from "system + # under memory pressure, admission paused" (admission_paused). + # This is the only operator-visible signal for the prefill-peak + # memory guard; the guard was previously dead and shipping it + # without telemetry repeats that mistake. + self.preflight_rejections: Dict[str, int] = { + "hard_limit": 0, + "admission_paused": 0, + } + # All-time totals (persisted across restarts) self._alltime_prompt_tokens: int = 0 self._alltime_completion_tokens: int = 0 @@ -198,6 +209,19 @@ def record_request_complete( # Periodic save self._maybe_save_alltime() + def record_preflight_rejection(self, reason: str) -> None: + """Increment the preflight-rejection counter for ``reason``. + + Unknown reasons are bucketed under ``"other"`` rather than + silently dropped so an operator notices when a new reject path + is added without updating the metric. + """ + with self._lock: + if reason not in self.preflight_rejections: + reason = "other" + self.preflight_rejections.setdefault("other", 0) + self.preflight_rejections[reason] += 1 + def _build_snapshot( self, prompt: int, diff --git a/packaging/README.md b/packaging/README.md index bd11efbb8..ffb45988d 100644 --- a/packaging/README.md +++ b/packaging/README.md @@ -56,8 +56,16 @@ No application layer — the Swift app is the application surface. ## Installation -1. Open the DMG produced by the Swift build. -2. Drag `oMLX.app` to Applications. -3. Launch the app (appears in the menubar). -4. Walk through the first-run wizard (Storage + API key), then Start +The Swift build (`build.sh release`) produces `build/Stage/oMLX.app` +directly — no DMG step. To install: + +1. Drag `build/Stage/oMLX.app` to `/Applications`, or `open` it + in-place to launch from `build/Stage/`. +2. Launch the app (appears in the menubar). +3. Walk through the first-run wizard (Storage + API key), then Start Server. + +> The DMGs on the [Releases](https://github.com/jundot/omlx/releases) +> page are produced by an off-tree maintainer pipeline, not by anything +> in this repo. End users follow the Releases install path; this +> section is for developers building from source. diff --git a/pyproject.toml b/pyproject.toml index af632b15f..e0e4ffcab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -143,6 +143,12 @@ dev = [ "mypy>=1.0.0", "mcp>=1.0.0", "venvstacks>=0.7.0", + # The torch-stub smoke test (tests/test_torch_stub.py) gates + # xgrammar / tvm-ffi version bumps. It skips when these aren't + # importable, which silently hides regressions — install no-deps + # in the dev environment so the test actually runs. + "xgrammar==0.2.0", + "apache-tvm-ffi==0.1.11", ] # PEP 735 dependency groups — consumed by `uv sync --dev`. # Keep in sync with [project.optional-dependencies] dev above @@ -156,6 +162,8 @@ dev = [ "mypy>=1.0.0", "mcp>=1.0.0", "venvstacks>=0.7.0", + "xgrammar==0.2.0", + "apache-tvm-ffi==0.1.11", ] [project.urls] diff --git a/tests/conftest.py b/tests/conftest.py index 5cb937b2a..45104ef0c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,6 +11,13 @@ import pytest +# Install the torch stub before any test imports xgrammar (e.g. via @patch +# decorators that resolve the target at collection time). When real torch is +# present this is a no-op; in the DMG layout it satisfies xgrammar's +# import-time torch references so the package can load. +from omlx._torch_stub import install as _install_torch_stub +_install_torch_stub() + from omlx.request import Request, SamplingParams diff --git a/tests/test_engine_preflight.py b/tests/test_engine_preflight.py new file mode 100644 index 000000000..972a628fc --- /dev/null +++ b/tests/test_engine_preflight.py @@ -0,0 +1,602 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for ``preflight_chat`` / ``preflight_completion`` on the engine +wrappers. + +The full end-to-end value of these methods is that they raise +``PrefillMemoryExceededError`` BEFORE the route handler wraps the +response in a ``StreamingResponse``, so the FastAPI handler can turn +the exception into HTTP 413. We exercise the contract by: + +- Stubbing the wrapper chain (engine -> _engine.engine.scheduler) and the + tokenizer. +- Confirming ``preflight_or_raise`` is invoked with the right token count. +- Confirming the exception type propagates. +""" + +from unittest.mock import MagicMock + +import pytest + +from omlx.exceptions import PrefillMemoryExceededError +from omlx.scheduler import Scheduler + + +# --------------------------------------------------------------------------- +# Scheduler.preflight_or_raise / _preflight_memory_check_tokens +# --------------------------------------------------------------------------- + + +class _ModelConfig: + def __init__(self, num_hidden_layers=32, num_key_value_heads=8, + num_attention_heads=32, head_dim=192): + self.num_hidden_layers = num_hidden_layers + self.num_key_value_heads = num_key_value_heads + self.num_attention_heads = num_attention_heads + self.head_dim = head_dim + + +def _make_scheduler(): + from omlx.scheduler import SchedulerConfig + + model = MagicMock() + model.layers = [] + model.config = _ModelConfig() + del model.make_cache + + tokenizer = MagicMock() + tokenizer.eos_token_id = 2 + + config = SchedulerConfig( + max_num_seqs=8, prefill_step_size=2048, paged_cache_block_size=0, + ) + return Scheduler(model=model, tokenizer=tokenizer, config=config) + + +class TestPreflightOrRaise: + def test_raises_when_peak_exceeds_limit(self, monkeypatch): + scheduler = _make_scheduler() + scheduler._prefill_memory_guard = True + scheduler._memory_hard_limit_bytes = 1 # any allocation overshoots + + import omlx.scheduler as scheduler_mod + + monkeypatch.setattr(scheduler_mod.mx, "get_active_memory", lambda: 0) + monkeypatch.setattr(scheduler_mod, "get_phys_footprint", lambda: 0) + + with pytest.raises(PrefillMemoryExceededError) as exc: + scheduler.preflight_or_raise(num_prompt_tokens=65536, request_id="req-x") + assert "Prefill would require" in str(exc.value) + assert exc.value.request_id == "req-x" + + def test_returns_silently_when_within_budget(self, monkeypatch): + scheduler = _make_scheduler() + scheduler._prefill_memory_guard = True + scheduler._memory_hard_limit_bytes = 10 ** 18 # effectively unbounded + + import omlx.scheduler as scheduler_mod + + monkeypatch.setattr(scheduler_mod.mx, "get_active_memory", lambda: 0) + monkeypatch.setattr(scheduler_mod, "get_phys_footprint", lambda: 0) + + # Must not raise + scheduler.preflight_or_raise(num_prompt_tokens=1024) + + def test_skips_when_guard_disabled(self): + scheduler = _make_scheduler() + scheduler._prefill_memory_guard = False + scheduler._memory_hard_limit_bytes = 1 + # Even with an impossibly small limit, disabled guard never raises. + scheduler.preflight_or_raise(num_prompt_tokens=10 ** 6) + + def test_accounts_for_cached_tokens(self, monkeypatch): + """A fully cached request must not be rejected even at a tiny limit.""" + scheduler = _make_scheduler() + scheduler._prefill_memory_guard = True + scheduler._memory_hard_limit_bytes = 1 + + import omlx.scheduler as scheduler_mod + + monkeypatch.setattr(scheduler_mod.mx, "get_active_memory", lambda: 0) + monkeypatch.setattr(scheduler_mod, "get_phys_footprint", lambda: 0) + + scheduler.preflight_or_raise(num_prompt_tokens=10_000, cached_tokens=10_000) + + +# --------------------------------------------------------------------------- +# Engine wrapper preflight methods +# --------------------------------------------------------------------------- + + +def _build_engine_with_stub_scheduler(engine_cls, scheduler): + """Return an engine of the given class wired to a stub scheduler chain. + + The real BatchedEngine / VLMBatchedEngine init does heavy work (model + load, etc.). For the preflight contract test we only need the wrapper + methods + tokenizer + the ``_engine.engine.scheduler`` chain, so we + bypass __init__ via __new__ and pin only the attributes the preflight + method touches. + """ + engine = engine_cls.__new__(engine_cls) + engine._loaded = True + engine._enable_thinking = None + + tokenizer = MagicMock() + tokenizer.apply_chat_template = MagicMock(return_value="hello world") + # The encoded length drives what we pass to preflight_or_raise. + tokenizer.encode = MagicMock(return_value=list(range(110_000))) + engine._tokenizer = tokenizer + + # Wrapper chain that _resolve_scheduler / preflight_chat traverse: + # engine._engine.engine.scheduler + inner_engine_core = MagicMock(spec=["scheduler"]) + inner_engine_core.scheduler = scheduler + async_engine_core = MagicMock(spec=["engine"]) + async_engine_core.engine = inner_engine_core + engine._engine = async_engine_core + return engine + + +@pytest.mark.asyncio +async def test_batched_engine_preflight_chat_raises_for_oversize_prompt(monkeypatch): + from omlx.engine.batched import BatchedEngine + + scheduler = _make_scheduler() + scheduler._prefill_memory_guard = True + scheduler._memory_hard_limit_bytes = 1 # force rejection + + import omlx.scheduler as scheduler_mod + + monkeypatch.setattr(scheduler_mod.mx, "get_active_memory", lambda: 0) + monkeypatch.setattr(scheduler_mod, "get_phys_footprint", lambda: 0) + + engine = _build_engine_with_stub_scheduler(BatchedEngine, scheduler) + # _preprocess_messages on BatchedEngine assumes Harmony hooks etc.; stub + # it out so the test only exercises the preflight wiring. + engine._preprocess_messages = lambda m: m + + with pytest.raises(PrefillMemoryExceededError): + await engine.preflight_chat(messages=[{"role": "user", "content": "x"}]) + + +@pytest.mark.asyncio +async def test_vlm_engine_preflight_chat_raises_for_oversize_prompt(monkeypatch): + from omlx.engine.vlm import VLMBatchedEngine + + scheduler = _make_scheduler() + scheduler._prefill_memory_guard = True + scheduler._memory_hard_limit_bytes = 1 + + import omlx.scheduler as scheduler_mod + + monkeypatch.setattr(scheduler_mod.mx, "get_active_memory", lambda: 0) + monkeypatch.setattr(scheduler_mod, "get_phys_footprint", lambda: 0) + + engine = _build_engine_with_stub_scheduler(VLMBatchedEngine, scheduler) + + with pytest.raises(PrefillMemoryExceededError): + await engine.preflight_chat(messages=[{"role": "user", "content": "x"}]) + + +@pytest.mark.asyncio +async def test_preflight_completion_raises_for_oversize_prompt(monkeypatch): + from omlx.engine.batched import BatchedEngine + + scheduler = _make_scheduler() + scheduler._prefill_memory_guard = True + scheduler._memory_hard_limit_bytes = 1 + + import omlx.scheduler as scheduler_mod + + monkeypatch.setattr(scheduler_mod.mx, "get_active_memory", lambda: 0) + monkeypatch.setattr(scheduler_mod, "get_phys_footprint", lambda: 0) + + engine = _build_engine_with_stub_scheduler(BatchedEngine, scheduler) + + with pytest.raises(PrefillMemoryExceededError): + await engine.preflight_completion(prompt="a" * 110_000) + + +# --------------------------------------------------------------------------- +# VLM-specific contracts (image-token budget + tools conversion + cached +# tokens propagation through preflight_or_raise) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_vlm_preflight_chat_adds_image_token_budget(monkeypatch): + """Each image-bearing content part must add + ``_IMAGE_TOKEN_UPPER_BOUND_FALLBACK`` to the prompt size the scheduler sees, + so image-heavy borderline requests can't slip past.""" + from omlx.engine.vlm import VLMBatchedEngine, _IMAGE_TOKEN_UPPER_BOUND_FALLBACK + + scheduler = _make_scheduler() + engine = _build_engine_with_stub_scheduler(VLMBatchedEngine, scheduler) + # Make the templated text deterministically 1000 tokens. + engine._tokenizer.encode = MagicMock(return_value=list(range(1000))) + + seen: dict = {} + + def _capture(num_prompt_tokens, **kwargs): + seen["num_prompt_tokens"] = num_prompt_tokens + + scheduler.preflight_or_raise = _capture # type: ignore[assignment] + + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "hello"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}, + {"type": "image", "source": {}}, + {"type": "text", "text": "world"}, + ], + } + ] + await engine.preflight_chat(messages=messages) + # 1000 text tokens + 2 images * 1280 + assert seen["num_prompt_tokens"] == 1000 + 2 * _IMAGE_TOKEN_UPPER_BOUND_FALLBACK + + +@pytest.mark.asyncio +async def test_vlm_preflight_chat_strips_images_before_template(monkeypatch): + """Modern HF chat templates (Qwen2.5-VL, Gemma-Vision, Llama-3.2-Vision) + render image content parts as literal placeholder strings inline with + the text. If preflight templates the raw messages, the resulting + tokenized prompt already contains image-placeholder tokens AND we + then add the per-image budget on top — a double count that + produces spurious 413s on borderline image-bearing requests the + real chat path would have admitted. ``preflight_chat`` must + therefore call ``extract_images_from_messages`` BEFORE + ``_apply_chat_template``, the same way ``_process_chat_messages`` + does on the execution path. + """ + from omlx.engine.vlm import VLMBatchedEngine + + scheduler = _make_scheduler() + engine = _build_engine_with_stub_scheduler(VLMBatchedEngine, scheduler) + engine._tokenizer.encode = MagicMock(return_value=[1, 2, 3]) + engine._apply_chat_template = MagicMock(return_value="stripped text") + scheduler.preflight_or_raise = lambda **kw: None # type: ignore[assignment] + + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "compare these:"}, + {"type": "image_url", "image_url": {"url": "data:..."}}, + {"type": "image", "source": {}}, + ], + } + ] + await engine.preflight_chat(messages=messages) + + # _apply_chat_template was called with image content-parts stripped. + assert engine._apply_chat_template.call_count == 1 + (call_messages, *_), _ = engine._apply_chat_template.call_args + user_content = call_messages[0]["content"] + if isinstance(user_content, list): + types_seen = {part.get("type") for part in user_content} + assert "image_url" not in types_seen, ( + "image_url part leaked into template input" + ) + assert "image" not in types_seen, ( + "image part leaked into template input" + ) + else: + # Some packs reduce single-text content to a string. + assert isinstance(user_content, str) + + +@pytest.mark.asyncio +async def test_vlm_preflight_chat_converts_pydantic_tools(monkeypatch): + """``preflight_chat`` must run tools through ``convert_tools_for_template`` + so Pydantic ``ToolDefinition`` callers don't get the silent + template-retry fallback that drops tools entirely.""" + from omlx.engine.vlm import VLMBatchedEngine + + scheduler = _make_scheduler() + engine = _build_engine_with_stub_scheduler(VLMBatchedEngine, scheduler) + engine._tokenizer.encode = MagicMock(return_value=[1]) + scheduler.preflight_or_raise = lambda **k: None # type: ignore[assignment] + + called_with = {} + original_apply = engine._apply_chat_template + + def _spy(messages, tools, **kwargs): + called_with["tools"] = tools + return "" + + engine._apply_chat_template = _spy # type: ignore[assignment] + + sentinel_tool = { + "type": "function", + "function": {"name": "do_x", "parameters": {}}, + } + await engine.preflight_chat(messages=[{"role": "user", "content": "x"}], tools=[sentinel_tool]) + + # convert_tools_for_template returned a list (possibly unchanged for a + # dict that already has the right shape, possibly transformed) — the + # contract is: tools were passed through the conversion path rather + # than the raw input. + assert called_with["tools"] is not None + + +@pytest.mark.asyncio +async def test_batched_engine_preflight_logs_when_scheduler_unreachable(monkeypatch, caplog): + """If the wrapper chain doesn't expose a scheduler (e.g. partial + init failure), preflight no-ops but logs a warning rather than + silently swallowing the safety check.""" + import logging + from omlx.engine.batched import BatchedEngine + + engine = BatchedEngine.__new__(BatchedEngine) + engine._loaded = True + engine._enable_thinking = None + engine._tokenizer = MagicMock() + engine._tokenizer.apply_chat_template = MagicMock(return_value="hi") + engine._tokenizer.encode = MagicMock(return_value=[1, 2, 3]) + engine._preprocess_messages = lambda m: m + # _engine is None — simulates a partial-init failure where + # _resolve_scheduler chain can't reach a real scheduler. + engine._engine = None + + with caplog.at_level(logging.WARNING): + await engine.preflight_chat(messages=[{"role": "user", "content": "x"}]) + + assert any( + "preflight check skipped" in r.message for r in caplog.records + ), "expected a warning when scheduler is unreachable" + + +@pytest.mark.asyncio +async def test_preflight_chat_swallows_tokenizer_errors(caplog): + """Tokenizer errors during preflight must not raise — the real chat + path will hit the same error and surface it through the existing + handler chain. Raising here would introduce a NEW 500 failure mode + on borderline-malformed-prompt requests. + """ + import logging + from omlx.engine.batched import BatchedEngine + + scheduler = _make_scheduler() + engine = _build_engine_with_stub_scheduler(BatchedEngine, scheduler) + engine._tokenizer.encode = MagicMock( + side_effect=UnicodeDecodeError("utf-8", b"\xff\xfe", 0, 1, "synthetic") + ) + engine._preprocess_messages = lambda m: m + + raise_called = {"yes": False} + + def _trip(num_prompt_tokens, **kwargs): + raise_called["yes"] = True + + scheduler.preflight_or_raise = _trip # type: ignore[assignment] + + with caplog.at_level(logging.WARNING): + # Must NOT raise the UnicodeDecodeError up to the caller. + await engine.preflight_chat(messages=[{"role": "user", "content": "x"}]) + + assert not raise_called["yes"], ( + "preflight_or_raise must NOT be called when tokenizer fails" + ) + assert any( + "tokenizer.encode raised" in r.message for r in caplog.records + ), "expected a warning logging the tokenizer error" + + +@pytest.mark.asyncio +async def test_preflight_completion_swallows_tokenizer_errors(caplog): + """Same contract on the completion path.""" + import logging + from omlx.engine.batched import BatchedEngine + + scheduler = _make_scheduler() + engine = _build_engine_with_stub_scheduler(BatchedEngine, scheduler) + engine._tokenizer.encode = MagicMock(side_effect=ValueError("bad input")) + + raise_called = {"yes": False} + scheduler.preflight_or_raise = lambda **k: raise_called.__setitem__("yes", True) # type: ignore[assignment] + + with caplog.at_level(logging.WARNING): + await engine.preflight_completion(prompt="\x00" * 10) + + assert not raise_called["yes"] + assert any("tokenizer.encode raised" in r.message for r in caplog.records) + + +@pytest.mark.asyncio +async def test_vlm_preflight_chat_swallows_tokenizer_errors(caplog): + """VLM path mirrors BatchedEngine on tokenizer-error handling.""" + import logging + from omlx.engine.vlm import VLMBatchedEngine + + scheduler = _make_scheduler() + engine = _build_engine_with_stub_scheduler(VLMBatchedEngine, scheduler) + engine._tokenizer.encode = MagicMock(side_effect=RuntimeError("Already borrowed")) + + raise_called = {"yes": False} + scheduler.preflight_or_raise = lambda **k: raise_called.__setitem__("yes", True) # type: ignore[assignment] + + with caplog.at_level(logging.WARNING): + await engine.preflight_chat(messages=[{"role": "user", "content": "x"}]) + + assert not raise_called["yes"] + assert any("tokenizer.encode raised" in r.message for r in caplog.records) + + +# --------------------------------------------------------------------------- +# Regressions added in code review: structured rejection, request_id +# plumbing, and engine_core cleanup-on-raise leak. +# --------------------------------------------------------------------------- + + +def test_preflight_rejection_carries_estimated_and_limit_bytes(monkeypatch): + """``PrefillMemoryExceededError`` must surface the structured rejection + fields (``estimated_bytes`` / ``limit_bytes``) so clients can branch on + numeric values instead of regex-matching the human-readable message. + """ + scheduler = _make_scheduler() + scheduler._prefill_memory_guard = True + scheduler._memory_hard_limit_bytes = 1024 # tiny — forces rejection + + import omlx.scheduler as scheduler_mod + + monkeypatch.setattr(scheduler_mod.mx, "get_active_memory", lambda: 0) + monkeypatch.setattr(scheduler_mod, "get_phys_footprint", lambda: 0) + + with pytest.raises(PrefillMemoryExceededError) as exc_info: + scheduler.preflight_or_raise(num_prompt_tokens=65536, request_id="req-attrs") + exc = exc_info.value + assert exc.request_id == "req-attrs" + assert exc.limit_bytes == 1024 + assert exc.estimated_bytes is not None and exc.estimated_bytes > 0 + + +def test_preflight_or_raise_synthesizes_request_id_when_unset(monkeypatch): + """If the caller doesn't pass a request_id, preflight_or_raise must + generate a unique one so each rejection is individually traceable. + Regression for the prior literal "preflight" default which collapsed + every rejection's id together in logs and FastAPI handler traces. + """ + scheduler = _make_scheduler() + scheduler._prefill_memory_guard = True + scheduler._memory_hard_limit_bytes = 1 + + import omlx.scheduler as scheduler_mod + + monkeypatch.setattr(scheduler_mod.mx, "get_active_memory", lambda: 0) + monkeypatch.setattr(scheduler_mod, "get_phys_footprint", lambda: 0) + + ids = set() + for _ in range(4): + with pytest.raises(PrefillMemoryExceededError) as exc_info: + scheduler.preflight_or_raise(num_prompt_tokens=65536) + rid = exc_info.value.request_id + assert rid and rid != "preflight" + assert rid.startswith("preflight-") + ids.add(rid) + assert len(ids) == 4, "request_ids must be unique per rejection" + + +@pytest.mark.asyncio +async def test_batched_engine_preflight_chat_threads_request_id(monkeypatch): + """The engine wrapper must forward the caller's request_id to the + scheduler so the rejection log + exception carry a meaningful trace + label rather than the synthesized "preflight-XXXX" fallback. + """ + from omlx.engine.batched import BatchedEngine + + scheduler = _make_scheduler() + engine = _build_engine_with_stub_scheduler(BatchedEngine, scheduler) + engine._preprocess_messages = lambda m: m + engine._tokenizer.encode = MagicMock(return_value=[1, 2, 3]) + + seen: dict = {} + + def _capture(num_prompt_tokens, **kwargs): + seen.update(kwargs) + seen["num_prompt_tokens"] = num_prompt_tokens + + scheduler.preflight_or_raise = _capture # type: ignore[assignment] + await engine.preflight_chat( + messages=[{"role": "user", "content": "x"}], + request_id="trace-id-42", + ) + assert seen.get("request_id") == "trace-id-42" + + +@pytest.mark.asyncio +async def test_engine_core_add_request_cleans_up_on_scheduler_raise( + monkeypatch, +): + """Regression for the engine_core leak: when scheduler.add_request + raises (e.g. PrefillMemoryExceededError) the per-request collector / + stream_state / finished_event entries must be removed. Without + cleanup, every rejection accumulates one of each — under sustained + rejection load this leaks indefinitely. + """ + from concurrent.futures import ThreadPoolExecutor + + from omlx.engine_core import EngineCore + + core = EngineCore.__new__(EngineCore) + core._output_collectors = {} + core._stream_states = {} + core._finished_events = {} + + class _Cfg: + stream_interval = 1 + + core.config = _Cfg() + core._mlx_executor = ThreadPoolExecutor(max_workers=1) + + raising_scheduler = MagicMock() + raising_scheduler._specprefill_draft_model = None + + def _raise(req): + raise PrefillMemoryExceededError( + message="rejected for test", + request_id=req.request_id, + estimated_bytes=10**9, + limit_bytes=10**8, + ) + + raising_scheduler.add_request = _raise + core.scheduler = raising_scheduler + + # Drive add_request enough that we can observe collectors before/after. + with pytest.raises(PrefillMemoryExceededError): + await core.add_request( + prompt=[1, 2, 3], + sampling_params=MagicMock(), + request_id="leak-check-1", + ) + + # All per-request engine_core entries must be cleaned up. + assert "leak-check-1" not in core._output_collectors + assert "leak-check-1" not in core._stream_states + assert "leak-check-1" not in core._finished_events + + core._mlx_executor.shutdown(wait=True) + + +def test_scheduler_add_request_cleans_block_table_on_rejection(monkeypatch): + """When add_request raises PrefillMemoryExceededError, any block_table + that the prefix-cache lookup attached must be released so a sustained + rejection stream cannot leak block tables / refcounts. + """ + scheduler = _make_scheduler() + scheduler._prefill_memory_guard = True + scheduler._memory_hard_limit_bytes = 1 + + import omlx.scheduler as scheduler_mod + + monkeypatch.setattr(scheduler_mod.mx, "get_active_memory", lambda: 0) + monkeypatch.setattr(scheduler_mod, "get_phys_footprint", lambda: 0) + + # Pin a fake block_table + paged_cache_manager so we can verify + # delete_block_table is called on the rejection path. + pcm = MagicMock() + scheduler.paged_cache_manager = pcm + + req = MagicMock() + req.request_id = "blk-clean-1" + req.num_prompt_tokens = 65536 + req.cached_tokens = 0 + req.block_table = MagicMock() + req.prompt = [1, 2, 3] + req.prompt_token_ids = [1, 2, 3] + req.vlm_extra_keys_for_cache = None + req.vlm_extra_key_token_start_for_cache = None + req.vlm_extra_key_ranges_for_cache = None + # Disable prefix-cache fetch so we don't go through the full lookup. + scheduler.block_aware_cache = None + # Disable SpecPrefill draft. + scheduler._specprefill_draft_model = None + + with pytest.raises(PrefillMemoryExceededError): + scheduler.add_request(req) + pcm.delete_block_table.assert_called_once_with("blk-clean-1") + # The request must not have entered self.waiting. + assert req not in scheduler.waiting + assert req.request_id not in scheduler.requests diff --git a/tests/test_hot_cache.py b/tests/test_hot_cache.py index 0acea9300..ead15848f 100644 --- a/tests/test_hot_cache.py +++ b/tests/test_hot_cache.py @@ -1128,11 +1128,13 @@ def test_ssd_write_drops_field_round_trips_through_get_stats(self, tmp_path): mgr.close() def test_ssd_write_drops_increments_on_hot_eviction_queue_full(self, tmp_path): - """Site 1: hot-cache eviction → put_nowait raises queue.Full → drop += 1. + """Site 1: hot-cache eviction → put raises queue.Full → drop += 1. - Patches the real queue's put_nowait to always raise queue.Full, - guaranteeing the drop path fires on the first eviction without any - dependency on the writer thread's drain rate. + Patches the real queue's put to raise queue.Full, guaranteeing the + drop path fires on the first eviction without any dependency on the + writer thread's drain rate. _enqueue_ssd_write uses put(item, + timeout=...) (not put_nowait) so a transient burst can ride over a + short writer-backlog window; sustained saturation still drops. """ import queue as _queue from unittest.mock import patch @@ -1149,12 +1151,12 @@ def test_ssd_write_drops_increments_on_hot_eviction_queue_full(self, tmp_path): ) try: with patch.object( - mgr._write_queue, "put_nowait", side_effect=_queue.Full + mgr._write_queue, "put", side_effect=_queue.Full ): self._save_block(mgr, b"qf_drop_block_00") self._save_block(mgr, b"qf_drop_block_01") - # save_02 evicts block 00 → _enqueue_ssd_write → put_nowait - # raises queue.Full → drop fires, cleanup runs. + # save_02 evicts block 00 → _enqueue_ssd_write → put raises + # queue.Full → drop fires, cleanup runs. self._save_block(mgr, b"qf_drop_block_02") stats = mgr.get_stats() @@ -1172,10 +1174,10 @@ def test_ssd_write_drops_increments_on_hot_eviction_queue_full(self, tmp_path): def test_ssd_write_drops_increments_on_cold_store_preflight(self, tmp_path): """Site 2: save_block preflight _write_queue.full() guard. - Hot cache disabled. Patches the queue's `full()` method to return - True so the preflight short-circuits on the first save. No real - queue manipulation, no writer-thread race, no risk of crashing the - writer with a malformed sentinel. + Hot cache disabled. Patches the queue's ``full()`` method to return + True so the preflight short-circuits before tensor extraction. The + guard's job is to avoid GPU work we'd otherwise throw away at the + put call when the writer is already saturated. """ from unittest.mock import patch @@ -1199,17 +1201,20 @@ def test_ssd_write_drops_increments_on_cold_store_preflight(self, tmp_path): stats = mgr.get_stats() assert stats.ssd_write_drops == 1 assert stats.errors == 0 - # Site 2 is a preflight rejection — no index/buffer state was - # created, so nothing to assert on cleanup. + # Preflight rejection: no index/buffer state was created, + # so nothing to assert on cleanup. finally: mgr.close() def test_ssd_write_drops_increments_on_cold_store_late_exception(self, tmp_path): - """Site 3: save_block put_nowait raises queue.Full after preflight passes. - - Hot cache disabled. put_nowait is patched to raise queue.Full even - though _write_queue.full() reports False — forces the late-exception - path inside save_block. Cleanup must remove index + pending hashes. + """Site 3: save_block put raises queue.Full after the preflight passes. + + Hot cache disabled. ``put`` is patched to raise queue.Full directly + (simulating a sustained writer-backlog saturation that materializes + after the preflight check). Cleanup must remove index + pending + hashes. The pre-eviction ``_write_queue.full()`` short-circuit + handles the easy case earlier; this test covers the race where the + queue fills between the preflight read and the put. """ import queue as _queue from unittest.mock import patch @@ -1223,7 +1228,7 @@ def test_ssd_write_drops_increments_on_cold_store_late_exception(self, tmp_path) cache_data = self._make_cache_data() block_hash = b"cold_late_drop_00" with patch.object( - mgr._write_queue, "put_nowait", side_effect=_queue.Full + mgr._write_queue, "put", side_effect=_queue.Full ): ok = mgr.save_block( block_hash=block_hash, diff --git a/tests/test_mbpp_extract_code.py b/tests/test_mbpp_extract_code.py new file mode 100644 index 000000000..7d92fd39f --- /dev/null +++ b/tests/test_mbpp_extract_code.py @@ -0,0 +1,156 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for omlx/eval/mbpp.py's _extract_code helper — the regex/ +heuristic parser that pulls Python source out of an LLM response +before MBPP runs it through the subprocess executor. The +benchmark-runner logic itself is exercised by test_eval.py; here we +just pin the code-extraction branches. +""" + +from __future__ import annotations + +from omlx.eval.mbpp import _extract_code + + +class TestFencedBlocks: + def test_python_fenced_block_extracted(self): + response = """Here's the solution: +```python +def add(a, b): + return a + b +``` +Done.""" + assert _extract_code(response) == "def add(a, b):\n return a + b" + + def test_unspecified_fenced_block_extracted(self): + """Some models emit ``` without a language tag — must still + work, otherwise we'd silently fail half the runs.""" + response = """Solution: +``` +def mul(a, b): + return a * b +```""" + assert _extract_code(response) == "def mul(a, b):\n return a * b" + + def test_python_fence_wins_over_unspecified(self): + """When both fence styles appear, the ```python regex runs + first — the language-tagged block is canonical.""" + response = """Wrong: +``` +print("plain") +``` +Right: +```python +def f(): + return 42 +```""" + assert _extract_code(response) == 'def f():\n return 42' + + def test_multiline_fenced_block(self): + response = """```python +import math + +def area(r): + return math.pi * r ** 2 + +# ready +```""" + result = _extract_code(response) + assert "import math" in result + assert "def area(r):" in result + assert "return math.pi * r ** 2" in result + + def test_fence_strips_surrounding_whitespace(self): + """The .strip() inside the regex branch must remove leading/ + trailing whitespace inside the fence — the subprocess executor + wants clean source.""" + response = "```python\n\n def f():\n return 1\n\n```" + result = _extract_code(response) + assert not result.startswith("\n") + assert not result.endswith("\n") + + +class TestHeuristicLineScan: + def test_def_starts_code_region(self): + """No fences → scan for first ``def `` line and take everything + from there. The chatty preamble must be stripped.""" + response = """Sure, here's a solution. +This function adds two numbers. + +def add(a, b): + return a + b""" + result = _extract_code(response) + assert result == "def add(a, b):\n return a + b" + + def test_class_starts_code_region(self): + response = """Here it is: + +class Counter: + def __init__(self): + self.n = 0""" + result = _extract_code(response) + assert result.startswith("class Counter:") + assert "self.n = 0" in result + + def test_import_starts_code_region(self): + response = """The solution uses math: +import math +def area(r): + return math.pi * r ** 2""" + result = _extract_code(response) + assert result.startswith("import math") + + def test_from_import_starts_code_region(self): + response = """We need typing: +from typing import List +def first(xs: List[int]) -> int: + return xs[0]""" + result = _extract_code(response) + assert result.startswith("from typing import List") + + def test_comment_starts_code_region(self): + """A leading ``# `` comment is treated as part of the code — + models sometimes start with ``# Solution`` before the actual + def.""" + response = """Here's my approach: +# Add two numbers +def add(a, b): + return a + b""" + result = _extract_code(response) + assert result.startswith("# Add two numbers") + assert "def add" in result + + def test_includes_lines_after_code_start(self): + """Once code starts, everything (including blank lines and + trailing text) is included — we don't try to find the END of + the code, just the start.""" + response = """def f(): + return 1 + +# This is part of the code now +print(f())""" + result = _extract_code(response) + assert "def f():" in result + assert "print(f())" in result + + +class TestFallback: + def test_response_with_no_code_markers_returned_as_is(self): + """When nothing matches, return the stripped response verbatim. + This is a degraded but non-crashing fallback — the subprocess + will surface the syntax error to the grader.""" + response = " I don't know how to solve this. " + assert _extract_code(response) == "I don't know how to solve this." + + def test_empty_response_returns_empty_string(self): + assert _extract_code("") == "" + + def test_whitespace_only_response_returns_empty(self): + assert _extract_code(" \n\n \t ") == "" + + def test_lone_print_statement_falls_through(self): + """A bare ``print(...)`` line doesn't start with def/class/ + import/from/# so the heuristic doesn't fire — falls back to + returning the whole response. Documents the limitation.""" + response = "print('hello')" + # Falls through to the whole-response fallback + assert _extract_code(response) == "print('hello')" diff --git a/tests/test_mcp_config.py b/tests/test_mcp_config.py new file mode 100644 index 000000000..1607479c9 --- /dev/null +++ b/tests/test_mcp_config.py @@ -0,0 +1,370 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for omlx/mcp/config.py — config file discovery, JSON/YAML +loading, schema validation, and the example-config helper. + +MCPConfig / MCPServerConfig / MCPTransport themselves are covered in +test_mcp_types.py; here we only exercise the loader and validator. +""" + +from __future__ import annotations + +import json + +import pytest + +from omlx.mcp import config as mcp_config +from omlx.mcp.config import ( + CONFIG_ENV_VAR, + create_example_config, + load_mcp_config, + validate_config, +) +from omlx.mcp.types import MCPConfig, MCPServerConfig, MCPTransport + + +@pytest.fixture +def isolated_env(monkeypatch, tmp_path): + """Run each test in a clean directory with no env var and an empty + search-path list. Prevents real ~/.config/omlx/mcp.json or a stray + ./mcp.json from leaking into the test.""" + monkeypatch.chdir(tmp_path) + monkeypatch.delenv(CONFIG_ENV_VAR, raising=False) + monkeypatch.setattr(mcp_config, "CONFIG_SEARCH_PATHS", []) + return tmp_path + + +# ============================================================================= +# validate_config +# ============================================================================= + + +class TestValidateConfigInput: + def test_non_dict_input_rejected(self): + with pytest.raises(ValueError, match="must be a dictionary"): + validate_config([]) # type: ignore[arg-type] + + def test_non_dict_input_rejected_for_string(self): + with pytest.raises(ValueError, match="must be a dictionary"): + validate_config("not a dict") # type: ignore[arg-type] + + def test_empty_dict_yields_defaults(self): + cfg = validate_config({}) + assert cfg.servers == {} + assert cfg.max_tool_calls == 10 + assert cfg.default_timeout == 30.0 + + def test_servers_string_value_rejected(self): + """A truthy non-dict value for ``servers`` reaches the + isinstance check and raises.""" + with pytest.raises(ValueError, match="'servers' must be a dictionary"): + validate_config({"servers": "not-a-dict"}) + + def test_servers_falsy_non_dict_silently_falls_through(self): + """Quirk of the ``data.get('servers') or data.get('mcpServers', {})`` + chain: an empty list, empty string, or 0 for ``servers`` is + falsy and triggers the fallback to ``mcpServers``. Documented + so a future tighten-up doesn't break callers relying on it. + """ + cfg = validate_config({"servers": []}) + assert cfg.servers == {} # silently treated as 'use mcpServers default' + + def test_server_entry_must_be_dict(self): + with pytest.raises( + ValueError, match="Server 'broken' config must be a dictionary" + ): + validate_config({"servers": {"broken": "not-a-dict"}}) + + +class TestValidateConfigServerLoading: + def test_stdio_server_loaded(self): + cfg = validate_config( + { + "servers": { + "fs": { + "transport": "stdio", + "command": "npx", + "args": ["-y", "filesystem"], + } + } + } + ) + assert "fs" in cfg.servers + srv = cfg.servers["fs"] + assert isinstance(srv, MCPServerConfig) + assert srv.transport == MCPTransport.STDIO + assert srv.command == "npx" + assert srv.args == ["-y", "filesystem"] + + def test_server_name_auto_set_from_key(self): + """User doesn't have to repeat ``name`` inside each server entry; + the key acts as the name. Tests both the loader's name injection + and that MCPServerConfig.__post_init__ doesn't override it.""" + cfg = validate_config( + {"servers": {"my-server": {"transport": "stdio", "command": "x"}}} + ) + assert cfg.servers["my-server"].name == "my-server" + + def test_explicit_name_in_entry_is_overridden_by_key(self): + """If the user wrote ``name: other`` inside the entry, the dict + key still wins — protects against name/key drift.""" + cfg = validate_config( + { + "servers": { + "real-name": { + "name": "wrong-name", + "transport": "stdio", + "command": "x", + } + } + } + ) + assert cfg.servers["real-name"].name == "real-name" + + def test_invalid_server_field_raises_value_error(self): + """Unknown kwargs to MCPServerConfig should surface as ValueError + with the offending server's name in the message — makes the + admin-panel error display actionable.""" + with pytest.raises(ValueError, match="Invalid config for server 'fs'"): + validate_config( + {"servers": {"fs": {"transport": "stdio", "command": "x", "bogus": 1}}} + ) + + def test_claude_desktop_mcpServers_format_accepted(self): + """Upstream chose to accept Claude Desktop's ``mcpServers`` key + as an alias for oMLX's ``servers``. Drop this and Claude users + lose drop-in compatibility.""" + cfg = validate_config( + { + "mcpServers": { + "claude-srv": {"transport": "stdio", "command": "npx"} + } + } + ) + assert "claude-srv" in cfg.servers + assert cfg.servers["claude-srv"].command == "npx" + + def test_servers_takes_precedence_over_mcpServers(self): + """When both keys are present, ``servers`` wins — the ``or`` + operator in load returns the first truthy value. This isn't + merging; it's an either/or.""" + cfg = validate_config( + { + "servers": {"a": {"transport": "stdio", "command": "x"}}, + "mcpServers": {"b": {"transport": "stdio", "command": "y"}}, + } + ) + assert set(cfg.servers.keys()) == {"a"} + + def test_empty_mcpServers_falls_through_to_servers(self): + """If ``servers`` is missing and ``mcpServers`` is empty, the + result is an empty servers dict, not an error.""" + cfg = validate_config({"mcpServers": {}}) + assert cfg.servers == {} + + +class TestValidateConfigGlobalOptions: + def test_custom_max_tool_calls(self): + cfg = validate_config({"max_tool_calls": 5}) + assert cfg.max_tool_calls == 5 + + def test_max_tool_calls_zero_rejected(self): + with pytest.raises(ValueError, match="'max_tool_calls' must be a positive integer"): + validate_config({"max_tool_calls": 0}) + + def test_max_tool_calls_negative_rejected(self): + with pytest.raises(ValueError, match="'max_tool_calls' must be a positive integer"): + validate_config({"max_tool_calls": -1}) + + def test_max_tool_calls_non_int_rejected(self): + with pytest.raises(ValueError, match="'max_tool_calls' must be a positive integer"): + validate_config({"max_tool_calls": 3.5}) + + def test_max_tool_calls_bool_rejected_in_practice(self): + """``isinstance(True, int)`` is True in Python, so True passes + the int check. Document the current behavior so it surfaces if + someone tightens the check later.""" + cfg = validate_config({"max_tool_calls": True}) + assert cfg.max_tool_calls is True # currently accepted; weird but expected + + def test_custom_default_timeout_int(self): + cfg = validate_config({"default_timeout": 60}) + assert cfg.default_timeout == 60 + + def test_custom_default_timeout_float(self): + cfg = validate_config({"default_timeout": 45.5}) + assert cfg.default_timeout == 45.5 + + def test_default_timeout_zero_rejected(self): + with pytest.raises(ValueError, match="'default_timeout' must be a positive number"): + validate_config({"default_timeout": 0}) + + def test_default_timeout_negative_rejected(self): + with pytest.raises(ValueError, match="'default_timeout' must be a positive number"): + validate_config({"default_timeout": -1.0}) + + def test_default_timeout_string_rejected(self): + with pytest.raises(ValueError, match="'default_timeout' must be a positive number"): + validate_config({"default_timeout": "30s"}) + + +# ============================================================================= +# _find_config_file (via load_mcp_config) +# ============================================================================= + + +class TestExplicitPath: + def test_existing_explicit_path_loads(self, isolated_env): + cfg_path = isolated_env / "custom.json" + cfg_path.write_text(json.dumps({"servers": {}})) + cfg = load_mcp_config(cfg_path) + assert isinstance(cfg, MCPConfig) + assert cfg.servers == {} + + def test_missing_explicit_path_raises(self, isolated_env): + with pytest.raises(FileNotFoundError, match="MCP config file not found"): + load_mcp_config(isolated_env / "does-not-exist.json") + + def test_explicit_path_tilde_expanded(self, isolated_env, monkeypatch): + """Tilde must expand — admins commonly pass ``~/.config/...`` + via --mcp-config flag.""" + monkeypatch.setenv("HOME", str(isolated_env)) + cfg_path = isolated_env / "tilde.json" + cfg_path.write_text(json.dumps({"servers": {}})) + cfg = load_mcp_config("~/tilde.json") + assert isinstance(cfg, MCPConfig) + + +class TestEnvVarPath: + def test_env_var_path_loads(self, isolated_env, monkeypatch): + cfg_path = isolated_env / "from-env.json" + cfg_path.write_text(json.dumps({"servers": {}})) + monkeypatch.setenv(CONFIG_ENV_VAR, str(cfg_path)) + cfg = load_mcp_config() + assert isinstance(cfg, MCPConfig) + + def test_env_var_missing_file_falls_through(self, isolated_env, monkeypatch, caplog): + """If OMLX_MCP_CONFIG points at a nonexistent file, the loader + logs a warning but continues to the search-path fallback rather + than aborting — broken env vars must not kill the server.""" + monkeypatch.setenv(CONFIG_ENV_VAR, str(isolated_env / "missing.json")) + with caplog.at_level("WARNING", logger="omlx.mcp.config"): + cfg = load_mcp_config() + assert isinstance(cfg, MCPConfig) + assert cfg.servers == {} + assert any("not found" in r.message for r in caplog.records) + + +class TestSearchPath: + def test_first_existing_search_path_wins(self, isolated_env, monkeypatch): + a = isolated_env / "a.json" + b = isolated_env / "b.json" + a.write_text(json.dumps({"max_tool_calls": 1})) + b.write_text(json.dumps({"max_tool_calls": 99})) + # Order matters — first existing path is chosen + monkeypatch.setattr(mcp_config, "CONFIG_SEARCH_PATHS", [str(a), str(b)]) + cfg = load_mcp_config() + assert cfg.max_tool_calls == 1 + + def test_falls_through_missing_paths(self, isolated_env, monkeypatch): + present = isolated_env / "found.json" + present.write_text(json.dumps({"max_tool_calls": 7})) + monkeypatch.setattr( + mcp_config, + "CONFIG_SEARCH_PATHS", + [ + str(isolated_env / "missing-1.json"), + str(isolated_env / "missing-2.json"), + str(present), + ], + ) + cfg = load_mcp_config() + assert cfg.max_tool_calls == 7 + + def test_no_config_anywhere_returns_empty(self, isolated_env): + """All discovery paths fail → empty MCPConfig (not None, not + FileNotFoundError). The server starts MCP-less in this case.""" + cfg = load_mcp_config() + assert isinstance(cfg, MCPConfig) + assert cfg.servers == {} + assert cfg.max_tool_calls == 10 + assert cfg.default_timeout == 30.0 + + +# ============================================================================= +# File-format handling +# ============================================================================= + + +class TestFileFormats: + def test_loads_json_file(self, isolated_env): + cfg_path = isolated_env / "mcp.json" + cfg_path.write_text( + json.dumps( + { + "servers": {"fs": {"transport": "stdio", "command": "x"}}, + "max_tool_calls": 3, + } + ) + ) + cfg = load_mcp_config(cfg_path) + assert "fs" in cfg.servers + assert cfg.max_tool_calls == 3 + + def test_invalid_json_raises(self, isolated_env): + cfg_path = isolated_env / "broken.json" + cfg_path.write_text("{not valid json") + with pytest.raises(json.JSONDecodeError): + load_mcp_config(cfg_path) + + def test_loads_yaml_file_when_pyyaml_available(self, isolated_env): + """Only run if PyYAML is installed; not a hard dependency of + the project.""" + pytest.importorskip("yaml") + cfg_path = isolated_env / "mcp.yaml" + cfg_path.write_text( + "servers:\n fs:\n transport: stdio\n command: npx\n" + "max_tool_calls: 4\n" + ) + cfg = load_mcp_config(cfg_path) + assert cfg.servers["fs"].command == "npx" + assert cfg.max_tool_calls == 4 + + def test_yml_extension_also_treated_as_yaml(self, isolated_env): + pytest.importorskip("yaml") + cfg_path = isolated_env / "mcp.yml" + cfg_path.write_text( + "servers:\n fs:\n transport: stdio\n command: x\n" + ) + cfg = load_mcp_config(cfg_path) + assert "fs" in cfg.servers + + +# ============================================================================= +# create_example_config +# ============================================================================= + + +class TestCreateExampleConfig: + def test_returns_valid_json_string(self): + example = create_example_config() + data = json.loads(example) # would raise if not valid JSON + assert isinstance(data, dict) + + def test_example_round_trips_through_validate(self): + """The example written to disk by ``omlx mcp init`` (or similar) + must be a valid config — otherwise the bootstrap UX is broken.""" + example = create_example_config() + data = json.loads(example) + cfg = validate_config(data) + assert isinstance(cfg, MCPConfig) + assert len(cfg.servers) >= 1 + # The example showcases multiple transports — keep a guard so + # future edits don't shrink it to just stdio. + transports = {s.transport for s in cfg.servers.values()} + assert MCPTransport.STDIO in transports + assert MCPTransport.SSE in transports + + def test_example_has_top_level_global_options(self): + data = json.loads(create_example_config()) + assert "max_tool_calls" in data + assert "default_timeout" in data diff --git a/tests/test_mcp_routes.py b/tests/test_mcp_routes.py new file mode 100644 index 000000000..08802cc74 --- /dev/null +++ b/tests/test_mcp_routes.py @@ -0,0 +1,269 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for omlx/api/mcp_routes.py — the HTTP layer over MCPClientManager. + +The manager itself is covered by tests/test_mcp_manager.py; here we only +verify the route handlers: response shape, alias handling, and the +no-manager fallbacks that ship in each endpoint. +""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from omlx.api import mcp_routes +from omlx.mcp.types import ( + MCPServerState, + MCPTool, + MCPToolResult, + MCPTransport, +) + + +@pytest.fixture +def app_client(): + """TestClient mounting only the MCP router.""" + app = FastAPI() + app.include_router(mcp_routes.router) + return TestClient(app) + + +@pytest.fixture(autouse=True) +def reset_mcp_manager_getter(): + """Each test gets a clean ``_get_mcp_manager`` slot. + + Routes consult a module-global callback; without this fixture a prior + test's getter would leak into the next test's no-manager path. + """ + original = mcp_routes._get_mcp_manager + mcp_routes._get_mcp_manager = None + yield + mcp_routes._get_mcp_manager = original + + +def _make_status(name, state=MCPServerState.CONNECTED, tools_count=0, error=None): + status = MagicMock() + status.name = name + status.state = state + status.transport = MCPTransport.STDIO + status.tools_count = tools_count + status.error = error + return status + + +class TestSetMcpManagerGetter: + def test_getter_is_installed_and_invoked(self): + sentinel = object() + mcp_routes.set_mcp_manager_getter(lambda: sentinel) + assert mcp_routes._get_manager() is sentinel + + def test_unset_getter_returns_none(self): + """No getter wired → _get_manager returns None. + + Routes lean on this to short-circuit into the empty-list / 503 + branches when the server starts without --mcp-config. + """ + assert mcp_routes._get_manager() is None + + def test_getter_returning_none_propagates(self): + """Manager is unset *because the getter says so*, not just because + no getter was registered. Routes must treat both identically.""" + mcp_routes.set_mcp_manager_getter(lambda: None) + assert mcp_routes._get_manager() is None + + +class TestListMcpTools: + def test_returns_empty_when_no_manager(self, app_client): + r = app_client.get("/v1/mcp/tools") + assert r.status_code == 200 + assert r.json() == {"tools": [], "count": 0} + + def test_serializes_tools_from_manager(self, app_client): + tool_a = MCPTool( + server_name="srv1", + name="add", + description="Add two numbers", + input_schema={"type": "object", "properties": {}}, + ) + tool_b = MCPTool( + server_name="srv2", + name="search", + description="Search the web", + input_schema={"type": "object"}, + ) + mgr = MagicMock() + mgr.get_all_tools.return_value = [tool_a, tool_b] + mcp_routes.set_mcp_manager_getter(lambda: mgr) + + r = app_client.get("/v1/mcp/tools") + assert r.status_code == 200 + body = r.json() + assert body["count"] == 2 + assert body["tools"][0] == { + "name": "srv1__add", # namespaced via MCPTool.full_name + "description": "Add two numbers", + "server": "srv1", + "parameters": {"type": "object", "properties": {}}, + } + assert body["tools"][1]["name"] == "srv2__search" + assert body["tools"][1]["server"] == "srv2" + + def test_zero_tools_returns_count_zero(self, app_client): + mgr = MagicMock() + mgr.get_all_tools.return_value = [] + mcp_routes.set_mcp_manager_getter(lambda: mgr) + + r = app_client.get("/v1/mcp/tools") + assert r.status_code == 200 + assert r.json() == {"tools": [], "count": 0} + + +class TestListMcpServers: + def test_returns_empty_when_no_manager(self, app_client): + r = app_client.get("/v1/mcp/servers") + assert r.status_code == 200 + assert r.json() == {"servers": []} + + def test_serializes_state_enum_to_string(self, app_client): + """The route flattens ``MCPServerState`` → its ``.value`` so JSON + clients see a plain string ("connected") not an enum repr.""" + mgr = MagicMock() + mgr.get_server_status.return_value = [ + _make_status("primary", state=MCPServerState.CONNECTED, tools_count=3) + ] + mcp_routes.set_mcp_manager_getter(lambda: mgr) + + r = app_client.get("/v1/mcp/servers") + assert r.status_code == 200 + servers = r.json()["servers"] + assert len(servers) == 1 + assert servers[0]["name"] == "primary" + assert servers[0]["state"] == "connected" # enum.value + assert servers[0]["transport"] == "stdio" + assert servers[0]["tools_count"] == 3 + assert servers[0]["error"] is None + + def test_propagates_error_field(self, app_client): + mgr = MagicMock() + mgr.get_server_status.return_value = [ + _make_status( + "broken", + state=MCPServerState.ERROR, + tools_count=0, + error="connection refused", + ) + ] + mcp_routes.set_mcp_manager_getter(lambda: mgr) + + r = app_client.get("/v1/mcp/servers") + body = r.json()["servers"][0] + assert body["state"] == "error" + assert body["error"] == "connection refused" + + +class TestExecuteMcpTool: + def test_returns_503_when_no_manager(self, app_client): + r = app_client.post( + "/v1/mcp/execute", + json={"tool_name": "srv__add", "arguments": {"a": 1, "b": 2}}, + ) + assert r.status_code == 503 + assert "MCP not configured" in r.json()["detail"] + assert "--mcp-config" in r.json()["detail"] + + def test_success_returns_result_payload(self, app_client): + result = MCPToolResult( + tool_name="srv__add", + content="3", + is_error=False, + error_message=None, + ) + mgr = MagicMock() + mgr.execute_tool = AsyncMock(return_value=result) + mcp_routes.set_mcp_manager_getter(lambda: mgr) + + r = app_client.post( + "/v1/mcp/execute", + json={"tool_name": "srv__add", "arguments": {"a": 1, "b": 2}}, + ) + assert r.status_code == 200 + assert r.json() == { + "tool_name": "srv__add", + "content": "3", + "is_error": False, + "error_message": None, + } + mgr.execute_tool.assert_awaited_once_with("srv__add", {"a": 1, "b": 2}) + + def test_error_result_propagates_is_error_and_message(self, app_client): + """A handled tool error returns 200 with is_error=True — only + unconfigured-manager raises 5xx. Lets clients distinguish 'tool + ran and failed' from 'server can't run tools at all'.""" + result = MCPToolResult( + tool_name="srv__broken", + content=None, + is_error=True, + error_message="upstream timeout", + ) + mgr = MagicMock() + mgr.execute_tool = AsyncMock(return_value=result) + mcp_routes.set_mcp_manager_getter(lambda: mgr) + + r = app_client.post( + "/v1/mcp/execute", + json={"tool_name": "srv__broken", "arguments": {}}, + ) + assert r.status_code == 200 + body = r.json() + assert body["is_error"] is True + assert body["error_message"] == "upstream timeout" + assert body["content"] is None + + def test_accepts_tool_alias_field(self, app_client): + """MCPExecuteRequest declares ``tool_name`` with + ``AliasChoices('tool_name', 'tool')`` — both wire formats must + reach the manager identically. Upstream PR #1285 added the + ``tool`` alias for compatibility with some external MCP clients; + without this test a future refactor could silently drop it. + """ + result = MCPToolResult( + tool_name="srv__add", content="ok", is_error=False + ) + mgr = MagicMock() + mgr.execute_tool = AsyncMock(return_value=result) + mcp_routes.set_mcp_manager_getter(lambda: mgr) + + r = app_client.post( + "/v1/mcp/execute", + json={"tool": "srv__add", "arguments": {"x": 1}}, + ) + assert r.status_code == 200 + mgr.execute_tool.assert_awaited_once_with("srv__add", {"x": 1}) + + def test_arguments_default_to_empty_dict(self, app_client): + """``arguments`` is optional — omitting it must yield ``{}``, + not None, otherwise the manager's signature would break.""" + result = MCPToolResult(tool_name="srv__noop", content=None) + mgr = MagicMock() + mgr.execute_tool = AsyncMock(return_value=result) + mcp_routes.set_mcp_manager_getter(lambda: mgr) + + r = app_client.post( + "/v1/mcp/execute", + json={"tool_name": "srv__noop"}, + ) + assert r.status_code == 200 + mgr.execute_tool.assert_awaited_once_with("srv__noop", {}) + + def test_missing_tool_name_returns_422(self, app_client): + """Pydantic validation must reject payloads with neither key.""" + mgr = MagicMock() + mgr.execute_tool = AsyncMock() + mcp_routes.set_mcp_manager_getter(lambda: mgr) + + r = app_client.post("/v1/mcp/execute", json={"arguments": {}}) + assert r.status_code == 422 + mgr.execute_tool.assert_not_awaited() diff --git a/tests/test_memory_monitor.py b/tests/test_memory_monitor.py index b3dac43a3..9692c7fef 100644 --- a/tests/test_memory_monitor.py +++ b/tests/test_memory_monitor.py @@ -55,6 +55,22 @@ def test_init_invalid_max_kv_cache_memory_negative(self): with pytest.raises(ValueError, match="max_kv_cache_memory"): MemoryMonitor(max_kv_cache_memory=-1) + def test_eviction_enabled_property_default_true(self): + """The default ``eviction_enabled=True`` makes the + public-facing predicate True so the existing tiered-cache + path keeps working without changes.""" + monitor = MemoryMonitor(max_kv_cache_memory=1024**3) + assert monitor.eviction_enabled is True + + def test_eviction_enabled_property_false_in_ssd_only_mode(self): + """Paged-SSD-only mode passes ``eviction_enabled=False``; the + public predicate must surface that so Scheduler can branch on + it (avoiding the RuntimeError from estimate_blocks_to_free).""" + monitor = MemoryMonitor( + max_kv_cache_memory=None, eviction_enabled=False + ) + assert monitor.eviction_enabled is False + def test_get_memory_info(self): """Test get_memory_info returns valid data.""" monitor = MemoryMonitor(max_kv_cache_memory=1024**3) @@ -284,6 +300,11 @@ def test_returns_zero_when_model_info_missing(self): m = MemoryMonitor(max_kv_cache_memory=10 * 1024**3) assert m.estimate_prefill_peak_bytes(32768, 2048) == 0 + def test_returns_zero_when_no_new_tokens(self): + # Fully-prefix-cached request: nothing to prefill, peak is 0. + m = self._make_monitor() + assert m.estimate_prefill_peak_bytes(0, 2048, cached_tokens=32768) == 0 + def test_fused_kernel_below_head_dim_128(self): # head_dim<=128 → fused tiled kernel, SDPA peak is just output buffer m = self._make_monitor(head_dim=128, n_attn=32, n_kv=4, n_layers=62) @@ -303,6 +324,35 @@ def test_fallback_path_above_head_dim_128(self): # Total ≈ 8 GB assert 7 * 1024**3 < peak < 9 * 1024**3 + def test_sdpa_fallback_accounts_for_cached_kv_in_scores_kdim(self): + """Regression for M3: the SDPA scores K-dim spans the FULL prompt + (cached + new), not just new_tokens. A heavily-cached long-context + request previously slipped through with under-counted peak — the + exact Qwen3.6-VL panic scenario this guard is supposed to catch. + """ + m = self._make_monitor(head_dim=256, n_attn=8, n_kv=4, n_layers=48) + # Same total prompt (100k), different cache split: + # - All-new: cached=0, new=100k + # - Heavy cache: cached=99k, new=1k + all_new = m.estimate_prefill_peak_bytes(100 * 1024, 2048) + heavy_cache = m.estimate_prefill_peak_bytes( + 1024, 2048, cached_tokens=99 * 1024 + ) + # The heavy-cache case must still report a substantial SDPA peak: + # n_attn * eff_chunk(=1024) * full_kv(=100k) * 4 bytes ≈ 3 GB. + # The KV addition for 1k new tokens is small (~50 MB), so the + # estimate is dominated by SDPA scores. The earlier (buggy) + # estimator passed only new_tokens to the scores formula and + # would have returned ~30 MB for this case — three orders of + # magnitude under-count. + assert heavy_cache > 1 * 1024**3, ( + f"heavy-cache peak under-counted: {heavy_cache / 1024**2:.0f} MB" + ) + # And the all-new case (larger eff_chunk = 2048 but same kv_len) + # should be larger overall because both KV growth and scores + # widen with new_tokens. + assert all_new > heavy_cache + def test_scales_linearly_with_token_count(self): m = self._make_monitor() p8k = m.estimate_prefill_peak_bytes(8 * 1024, 2048) @@ -323,6 +373,23 @@ def test_sdpa_fallback_scales_quadratically(self): ratio = p32k / p16k assert 1.8 < ratio < 2.2 + def test_eff_chunk_capped_at_new_tokens(self): + """Short prompts (smaller than chunk_size) must not be charged + the full chunk_size width — the effective chunk is bounded by + the number of remaining new tokens. Regression for the constant- + factor over-count on small prompts. + """ + m = self._make_monitor(head_dim=256, n_attn=8, n_kv=4, n_layers=48) + # 100-token prompt; chunk_size=2048. eff_chunk should be 100, + # not 2048 — so SDPA scores ≈ 8 * 100 * 100 * 4 = 320 KB, not + # 8 * 2048 * 100 * 4 = 6.5 MB. + peak = m.estimate_prefill_peak_bytes(100, 2048) + # KV: 48*4*256*2*2*100 ≈ 19 MB. SDPA ≪ KV here. Total < 25 MB. + assert peak < 25 * 1024**2, ( + f"short-prompt peak suggests chunk wasn't clamped: " + f"{peak / 1024**2:.0f} MB" + ) + def test_no_python_overhead_constant(self): # estimator must NOT include cache_pool_overhead or python_overhead # magic constants — those are absorbed by enforcer hard_threshold. diff --git a/tests/test_model_settings.py b/tests/test_model_settings.py index 9d211d6ee..c9a6f607f 100644 --- a/tests/test_model_settings.py +++ b/tests/test_model_settings.py @@ -197,6 +197,66 @@ def test_model_type_override_excluded_when_none(self): d = settings.to_dict() assert "model_type_override" not in d + def test_turboquant_kv_bits_default(self): + """Default bit depth = 4.""" + settings = ModelSettings() + assert settings.turboquant_kv_bits == 4 + + def test_turboquant_kv_bits_roundtrip(self): + original = ModelSettings(turboquant_kv_bits=2.5) + d = original.to_dict() + assert d["turboquant_kv_bits"] == 2.5 + restored = ModelSettings.from_dict(d) + assert restored.turboquant_kv_bits == 2.5 + + def test_turboquant_kv_bits_always_in_to_dict(self): + """Non-Optional field with a default must always serialize.""" + settings = ModelSettings() + assert "turboquant_kv_bits" in settings.to_dict() + + def test_turboquant_skip_last_default(self): + """Default = True — protects sensitive models from last-layer corruption.""" + settings = ModelSettings() + assert settings.turboquant_skip_last is True + + def test_turboquant_skip_last_roundtrip(self): + original = ModelSettings(turboquant_skip_last=False) + d = original.to_dict() + assert d["turboquant_skip_last"] is False + restored = ModelSettings.from_dict(d) + assert restored.turboquant_skip_last is False + + def test_vlm_mtp_draft_model_default(self): + settings = ModelSettings() + assert settings.vlm_mtp_draft_model is None + + def test_vlm_mtp_draft_model_roundtrip(self): + original = ModelSettings(vlm_mtp_draft_model="gemma-4-26B-A4B-it-assistant") + d = original.to_dict() + assert d["vlm_mtp_draft_model"] == "gemma-4-26B-A4B-it-assistant" + restored = ModelSettings.from_dict(d) + assert restored.vlm_mtp_draft_model == "gemma-4-26B-A4B-it-assistant" + + def test_vlm_mtp_draft_model_excluded_when_none(self): + settings = ModelSettings() + assert "vlm_mtp_draft_model" not in settings.to_dict() + + def test_vlm_mtp_draft_block_size_default(self): + """None means 'use mlx-vlm default'.""" + settings = ModelSettings() + assert settings.vlm_mtp_draft_block_size is None + + def test_vlm_mtp_draft_block_size_roundtrip(self): + original = ModelSettings(vlm_mtp_draft_block_size=8) + d = original.to_dict() + assert d["vlm_mtp_draft_block_size"] == 8 + restored = ModelSettings.from_dict(d) + assert restored.vlm_mtp_draft_block_size == 8 + + def test_vlm_mtp_draft_block_size_excluded_when_none(self): + settings = ModelSettings() + assert "vlm_mtp_draft_block_size" not in settings.to_dict() + class TestModelSettingsManager: """Tests for ModelSettingsManager class.""" diff --git a/tests/test_models_base.py b/tests/test_models_base.py new file mode 100644 index 000000000..3257ef330 --- /dev/null +++ b/tests/test_models_base.py @@ -0,0 +1,156 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for omlx/models/base_model.py — pure-math helpers used by +omlx/models/xlm_roberta.py (the reranker model). Pin the masking and +normalization semantics so a refactor doesn't silently change +embedding output. +""" + +from __future__ import annotations + +import math + +import mlx.core as mx + +from omlx.models.base_model import ( + BaseModelArgs, + BaseModelOutput, + mean_pooling, + normalize_embeddings, +) + + +class TestBaseModelDataclasses: + def test_base_model_args_instantiable(self): + """Empty marker dataclass — subclasses extend it.""" + BaseModelArgs() # must not raise + + def test_output_required_field(self): + out = BaseModelOutput(last_hidden_state=mx.zeros((1, 4, 8))) + assert out.text_embeds is None + assert out.pooler_output is None + assert out.hidden_states is None + + def test_output_with_all_fields(self): + hs = mx.zeros((1, 4, 8)) + emb = mx.ones((1, 8)) + pool = mx.ones((1, 8)) * 0.5 + all_hs = (hs, hs) + out = BaseModelOutput( + last_hidden_state=hs, + text_embeds=emb, + pooler_output=pool, + hidden_states=all_hs, + ) + assert out.text_embeds is emb + assert out.pooler_output is pool + assert out.hidden_states is all_hs + + +class TestMeanPooling: + def test_uniform_mask_averages_all_positions(self): + """When every position is unmasked, mean pooling = simple mean.""" + # batch=1, seq=4, hidden=3 + hs = mx.array([[[1.0, 2.0, 3.0], + [2.0, 4.0, 6.0], + [3.0, 6.0, 9.0], + [4.0, 8.0, 12.0]]]) + mask = mx.array([[1.0, 1.0, 1.0, 1.0]]) + pooled = mean_pooling(hs, mask) + # Mean across seq axis: (1+2+3+4)/4=2.5, (2+4+6+8)/4=5, (3+6+9+12)/4=7.5 + assert pooled.shape == (1, 3) + result = pooled.tolist() + assert math.isclose(result[0][0], 2.5, rel_tol=1e-5) + assert math.isclose(result[0][1], 5.0, rel_tol=1e-5) + assert math.isclose(result[0][2], 7.5, rel_tol=1e-5) + + def test_partial_mask_excludes_padded_positions(self): + """Padded positions (mask=0) must not contribute to the mean. + This is the load-bearing invariant — pre-mask sums would let + padding tokens corrupt the embedding for short inputs.""" + hs = mx.array([[[1.0, 1.0], + [2.0, 2.0], + [99.0, 99.0], # padded — must NOT be counted + [99.0, 99.0]]]) + mask = mx.array([[1.0, 1.0, 0.0, 0.0]]) + pooled = mean_pooling(hs, mask) + # Only first two positions count: mean(1,2)=1.5 + result = pooled.tolist() + assert math.isclose(result[0][0], 1.5, rel_tol=1e-5) + assert math.isclose(result[0][1], 1.5, rel_tol=1e-5) + + def test_all_zero_mask_does_not_divide_by_zero(self): + """If the entire mask is zero (pathological but possible from + upstream), the function must not produce NaN/Inf — the + ``clip(..., a_min=1e-9)`` guard exists for this.""" + hs = mx.array([[[5.0, 5.0], [5.0, 5.0]]]) + mask = mx.array([[0.0, 0.0]]) + pooled = mean_pooling(hs, mask) + # Both sum_embeddings AND sum_mask are 0 → 0 / 1e-9 = 0, not NaN + result = pooled.tolist() + assert all(math.isfinite(v) for v in result[0]) + + def test_batch_dimension_preserved(self): + """Batch dim should pass through — each row pooled + independently.""" + hs = mx.array([ + [[1.0, 0.0], [3.0, 0.0]], + [[2.0, 0.0], [4.0, 0.0]], + ]) + mask = mx.array([[1.0, 1.0], [1.0, 1.0]]) + pooled = mean_pooling(hs, mask) + assert pooled.shape == (2, 2) + result = pooled.tolist() + assert math.isclose(result[0][0], 2.0, rel_tol=1e-5) # (1+3)/2 + assert math.isclose(result[1][0], 3.0, rel_tol=1e-5) # (2+4)/2 + + def test_works_with_float16_dtype(self): + """Reranker inference often runs in fp16. Mask cast to the + hidden states' dtype is the whole point of the + ``mask_expanded.astype(hidden_states.dtype)`` line.""" + hs = mx.array([[[1.0, 1.0], [3.0, 3.0]]], dtype=mx.float16) + mask = mx.array([[1.0, 1.0]]) # default float32 + pooled = mean_pooling(hs, mask) + assert pooled.dtype == mx.float16 + + +class TestNormalizeEmbeddings: + def test_unit_norm_after_normalize(self): + emb = mx.array([[3.0, 4.0]]) # |v| = 5 + out = normalize_embeddings(emb) + # Each row should have L2 norm = 1 + norms = mx.linalg.norm(out, axis=-1).tolist() + assert math.isclose(norms[0], 1.0, rel_tol=1e-5) + + def test_normalizes_along_last_axis_only(self): + """The ``axis=-1`` is load-bearing — normalizing across the + wrong axis would silently destroy similarity comparisons. Test + with shape (batch=2, hidden=3).""" + emb = mx.array([[1.0, 0.0, 0.0], + [3.0, 4.0, 0.0]]) + out = normalize_embeddings(emb) + # Row 0 was already unit length + # Row 1 should become (3/5, 4/5, 0) + result = out.tolist() + assert math.isclose(result[0][0], 1.0, rel_tol=1e-5) + assert math.isclose(result[1][0], 0.6, rel_tol=1e-5) + assert math.isclose(result[1][1], 0.8, rel_tol=1e-5) + + def test_preserves_shape(self): + """Higher-rank inputs supported — (batch, seq, hidden) for + per-token embeddings.""" + emb = mx.ones((2, 5, 8)) + out = normalize_embeddings(emb) + assert out.shape == (2, 5, 8) + + def test_already_normalized_input_is_idempotent(self): + """Normalizing twice gives the same result — basic mathematical + invariant that catches accidental sign flips or scaling bugs.""" + emb = mx.array([[1.0, 2.0, 2.0]]) + once = normalize_embeddings(emb) + twice = normalize_embeddings(once) + # Compare as Python floats since mx.array doesn't have __eq__ that + # produces a scalar bool + a = once.tolist() + b = twice.tolist() + for x, y in zip(a[0], b[0]): + assert math.isclose(x, y, abs_tol=1e-6) diff --git a/tests/test_optimizations.py b/tests/test_optimizations.py new file mode 100644 index 000000000..fb21bb5cb --- /dev/null +++ b/tests/test_optimizations.py @@ -0,0 +1,113 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for omlx/optimizations.py — a thin hardware/MLX status helper. +The re-exported symbols (HardwareInfo, detect_hardware, get_total_memory_gb) +are covered by test_utils_hardware.py; here we pin the dict shape and the +flash-attention detection. +""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import mlx.core as mx + +from omlx import optimizations +from omlx.optimizations import ( + HardwareInfo, + detect_hardware, + get_optimization_status, + get_system_memory_gb, +) + + +class TestReExports: + def test_hardware_symbols_importable_from_optimizations(self): + """The module's docstring promises these names. Removing one + would silently break ``from omlx.optimizations import ...`` + used by external scripts.""" + from omlx.utils.hardware import detect_hardware as canonical_detect + from omlx.utils.hardware import HardwareInfo as CanonicalInfo + + assert detect_hardware is canonical_detect + assert HardwareInfo is CanonicalInfo + + def test_get_system_memory_gb_aliases_get_total_memory_gb(self): + """The re-export renames ``get_total_memory_gb`` → + ``get_system_memory_gb``. The alias must stay in place.""" + from omlx.utils.hardware import get_total_memory_gb + + assert get_system_memory_gb is get_total_memory_gb + + def test_all_lists_documented_surface(self): + assert set(optimizations.__all__) == { + "HardwareInfo", + "detect_hardware", + "get_system_memory_gb", + "get_optimization_status", + } + + +class TestGetOptimizationStatus: + def test_returns_top_level_keys(self): + status = get_optimization_status() + assert set(status.keys()) == {"hardware", "mlx_memory", "mlx_lm_features"} + + def test_hardware_section_shape(self): + status = get_optimization_status() + hw = status["hardware"] + assert set(hw.keys()) == {"chip", "total_memory_gb", "device_name"} + # chip is populated from detect_hardware().chip_name — non-empty + # string on any Apple Silicon test runner. + assert isinstance(hw["chip"], str) + assert isinstance(hw["total_memory_gb"], (int, float)) + assert hw["total_memory_gb"] > 0 + assert isinstance(hw["device_name"], str) + + def test_mlx_memory_section_is_byte_counters(self): + status = get_optimization_status() + mem = status["mlx_memory"] + assert set(mem.keys()) == {"active_bytes", "cache_bytes", "peak_bytes"} + # All three come straight from mx.get_*_memory(); non-negative ints + for key in mem: + assert isinstance(mem[key], int), f"{key} not an int" + assert mem[key] >= 0 + + def test_mlx_lm_features_static_strings(self): + """These strings appear in the admin dashboard. Pin them so a + typo or accidental rewording shows up as a test failure rather + than a confusing UI change.""" + features = get_optimization_status()["mlx_lm_features"] + assert features["metal_kernels"] == "optimized for Apple Silicon" + assert features["kv_cache"] == "managed by mlx-lm" + assert features["quantization"] == "4-bit and 8-bit supported" + + def test_flash_attention_reports_built_in_when_available(self): + """``mlx.core.fast.scaled_dot_product_attention`` exists in all + recent MLX versions — the test environment is one of them.""" + assert hasattr(mx, "fast") + assert hasattr(mx.fast, "scaled_dot_product_attention") + status = get_optimization_status() + assert status["mlx_lm_features"]["flash_attention"] == "built-in" + + def test_flash_attention_reports_not_available_when_missing(self): + """The fallback branch runs on hypothetical MLX builds without + the fused SDPA. Simulated by replacing ``mx.fast`` with an + object that lacks the attribute.""" + fake_fast = MagicMock(spec=[]) # spec=[] → no attributes + with patch.object(mx, "fast", fake_fast): + status = get_optimization_status() + assert ( + status["mlx_lm_features"]["flash_attention"] == "not available" + ) + + def test_active_bytes_reflects_real_mlx_state(self): + """Verify the value isn't hardcoded — allocating an array + should bump active memory above the pre-allocation baseline. + Defensive: the loop ensures eval happens so memory shows up.""" + before = mx.get_active_memory() + arr = mx.zeros((1024, 1024), dtype=mx.float32) + mx.eval(arr) + after = get_optimization_status()["mlx_memory"]["active_bytes"] + # 1024*1024*4 bytes = 4 MiB allocation must register somewhere + # in the active memory delta. + assert after >= before diff --git a/tests/test_oq_manager.py b/tests/test_oq_manager.py index 774c4ad34..3c945bf11 100644 --- a/tests/test_oq_manager.py +++ b/tests/test_oq_manager.py @@ -1,11 +1,580 @@ # SPDX-License-Identifier: Apache-2.0 -"""Tests for the OQManager admin component.""" +"""Tests for omlx/admin/oq_manager.py — the async oQ quantization +orchestrator. Focuses on the synchronous logic paths and validation: +helpers, task lifecycle bookkeeping, input validation, and the +``_phase_label`` ETA parser. The actual streaming quantization is +exercised end-to-end in test_oq.py. +""" +from __future__ import annotations + +import asyncio import json +import threading +import time +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch import pytest -from omlx.admin.oq_manager import OQManager +# Eager-load omlx.oq before any test patches sys.modules — otherwise the +# first ``from ..oq import …`` inside oq_manager would re-import the real +# module and clobber our fake. +import omlx.oq # noqa: F401 +from omlx.admin import oq_manager +from omlx.admin.oq_manager import ( + OQManager, + QuantStatus, + QuantTask, + _dir_size, + _format_size, +) + + +def _make_fake_oq(quantize_impl): + """Build a MagicMock replacement for ``omlx.oq`` with the four + symbols ``oq_manager`` imports lazily — ``OQ_LEVELS``, ``OQ_DTYPES``, + ``resolve_output_name``, and ``quantize_oq_streaming``.""" + fake = MagicMock() + fake.OQ_LEVELS = {2, 3, 3.5, 4, 5, 6, 8} + fake.OQ_DTYPES = ("bfloat16", "float16") + + def _resolve(name, level, dtype="bfloat16", *, preserve_mtp=False): + suffix = f"-oQ{level:g}" + if dtype == "float16": + suffix += "-fp16" + if preserve_mtp: + suffix += "-mtp" + return f"{name}{suffix}" + + fake.resolve_output_name = _resolve + fake.quantize_oq_streaming = quantize_impl + # Harmless defaults for list_quantizable_models if it gets called + # under the same patch (it isn't, but defends against future churn). + fake.validate_quantizable.return_value = False + fake.estimate_memory.return_value = {} + return fake + + +# ============================================================================= +# Pure helpers and data classes +# ============================================================================= + + +class TestQuantStatus: + def test_enum_values(self): + assert QuantStatus.PENDING.value == "pending" + assert QuantStatus.LOADING.value == "loading" + assert QuantStatus.QUANTIZING.value == "quantizing" + assert QuantStatus.SAVING.value == "saving" + assert QuantStatus.COMPLETED.value == "completed" + assert QuantStatus.FAILED.value == "failed" + assert QuantStatus.CANCELLED.value == "cancelled" + + def test_is_string_enum(self): + """Inherits from str so JSON encoders treat it like a string — + the to_dict path relies on ``.value`` but downstream callers + sometimes pass the status itself.""" + assert isinstance(QuantStatus.PENDING, str) + + +class TestQuantTaskToDict: + def _make(self, **overrides): + defaults = dict( + task_id="tid", + model_name="Qwen-7B", + model_path="/m/Qwen-7B", + oq_level=4.0, + output_name="Qwen-7B-oQ4", + output_path="/m/Qwen-7B-oQ4", + ) + defaults.update(overrides) + return QuantTask(**defaults) + + def test_to_dict_default_shape(self): + d = self._make().to_dict() + assert d["task_id"] == "tid" + assert d["status"] == "pending" # enum.value, not the enum itself + assert d["progress"] == 0.0 + assert d["dtype"] == "bfloat16" + # Fields not included in to_dict (intentional — internal only): + assert "group_size" not in d + assert "sensitivity_model_path" not in d + assert "auto_proxy_sensitivity" not in d + assert "preserve_mtp" not in d + + def test_progress_rounded_to_one_decimal(self): + t = self._make() + t.progress = 42.6789 + assert t.to_dict()["progress"] == 42.7 + + def test_status_serialized_as_string(self): + t = self._make() + t.status = QuantStatus.QUANTIZING + assert t.to_dict()["status"] == "quantizing" + + +class TestDirSize: + def test_nonexistent_returns_zero(self, tmp_path): + assert _dir_size(tmp_path / "missing") == 0 + + def test_empty_dir_returns_zero(self, tmp_path): + assert _dir_size(tmp_path) == 0 + + def test_sums_files_recursively(self, tmp_path): + (tmp_path / "a.bin").write_bytes(b"x" * 100) + sub = tmp_path / "sub" + sub.mkdir() + (sub / "b.bin").write_bytes(b"y" * 50) + assert _dir_size(tmp_path) == 150 + + +class TestFormatSize: + @pytest.mark.parametrize( + "size,expected", + [ + (0, "0 B"), + (1023, "1023 B"), + (1024, "1.0 KB"), + (1024 * 512, "512.0 KB"), + (1024**2, "1.0 MB"), + (1024**3, "1.0 GB"), + (5 * 1024**3, "5.0 GB"), + ], + ) + def test_thresholds(self, size, expected): + assert _format_size(size) == expected + + +class TestPhaseLabel: + def test_known_phase_loading(self): + assert OQManager._phase_label("loading", 4) == "Loading model..." + + def test_known_phase_quantizing_formats_level(self): + # oq_level uses :g — integer levels render without trailing .0 + assert OQManager._phase_label("quantizing", 4) == "Quantizing to oQ4..." + + def test_known_phase_quantizing_handles_fractional_level(self): + assert ( + OQManager._phase_label("quantizing", 3.5) == "Quantizing to oQ3.5..." + ) + + def test_unknown_phase_passes_through(self): + assert OQManager._phase_label("custom_phase", 4) == "custom_phase" + + def test_quantizing_eta_with_percent_and_eta(self): + label = OQManager._phase_label("quantizing_eta|400|800|0:30", 4) + assert label == "oQ4: 50% (0:30 remaining)" + + def test_quantizing_eta_without_eta_suffix(self): + label = OQManager._phase_label("quantizing_eta|400|800|", 4) + assert label == "oQ4: 50%" + + def test_quantizing_eta_handles_zero_total(self): + """Division by zero is guarded — total=0 must not crash.""" + label = OQManager._phase_label("quantizing_eta|10|0|0:05", 4) + # current/total: int(10 / max(0,1)) * 100 = 1000% — implementation + # caps the ratio at division but doesn't clamp the result, so we + # just assert no crash and the eta makes it through. + assert "remaining" in label + + def test_quantizing_eta_handles_non_numeric_parts(self): + """When current/total aren't digits, the pct falls back to 0.""" + label = OQManager._phase_label("quantizing_eta|x|y|0:30", 4) + assert label == "oQ4: 0% (0:30 remaining)" + + +# ============================================================================= +# OQManager lifecycle +# ============================================================================= + + +class TestOQManagerInit: + def test_defaults_with_one_dir(self, tmp_path): + mgr = OQManager(model_dirs=[str(tmp_path)]) + assert mgr._model_dirs == [tmp_path] + assert mgr._output_dir == tmp_path # first dir wins + assert mgr._tasks == {} + assert mgr._active_tasks == {} + assert mgr._cancelled == set() + assert mgr._on_complete is None + + def test_first_dir_is_output_dir(self, tmp_path): + a, b = tmp_path / "a", tmp_path / "b" + a.mkdir() + b.mkdir() + mgr = OQManager(model_dirs=[str(a), str(b)]) + assert mgr._output_dir == a # output_dir = first + + def test_empty_dirs_fallback_to_cwd(self): + mgr = OQManager(model_dirs=[]) + assert mgr._model_dirs == [] + assert mgr._output_dir == Path(".") # "." fallback + + def test_on_complete_callback_stored(self, tmp_path): + cb = lambda: None + mgr = OQManager(model_dirs=[str(tmp_path)], on_complete=cb) + assert mgr._on_complete is cb + + +class TestGetTasksAndIsQuantizing: + def test_initial_state_is_empty_and_idle(self, tmp_path): + mgr = OQManager(model_dirs=[str(tmp_path)]) + assert mgr.get_tasks() == [] + assert mgr.is_quantizing is False + + def test_is_quantizing_true_for_active_status(self, tmp_path): + mgr = OQManager(model_dirs=[str(tmp_path)]) + t = QuantTask( + task_id="t1", + model_name="m", + model_path="/m", + oq_level=4, + output_name="m-oQ4", + output_path="/o", + ) + t.status = QuantStatus.QUANTIZING + mgr._tasks["t1"] = t + assert mgr.is_quantizing is True + + def test_is_quantizing_false_for_terminal_status(self, tmp_path): + mgr = OQManager(model_dirs=[str(tmp_path)]) + for status in ( + QuantStatus.COMPLETED, + QuantStatus.FAILED, + QuantStatus.CANCELLED, + ): + t = QuantTask( + task_id=f"t-{status.value}", + model_name="m", + model_path="/m", + oq_level=4, + output_name="m-oQ4", + output_path="/o", + ) + t.status = status + mgr._tasks[t.task_id] = t + assert mgr.is_quantizing is False + + +class TestRemoveTask: + def _mgr_with_task(self, tmp_path, status): + mgr = OQManager(model_dirs=[str(tmp_path)]) + t = QuantTask( + task_id="t", + model_name="m", + model_path="/m", + oq_level=4, + output_name="m-oQ4", + output_path="/o", + ) + t.status = status + mgr._tasks["t"] = t + return mgr + + def test_unknown_task_returns_false(self, tmp_path): + mgr = OQManager(model_dirs=[str(tmp_path)]) + assert mgr.remove_task("does-not-exist") is False + + def test_refuses_active_task(self, tmp_path): + mgr = self._mgr_with_task(tmp_path, QuantStatus.QUANTIZING) + assert mgr.remove_task("t") is False + assert "t" in mgr._tasks # not removed + + def test_removes_terminal_task_and_clears_cancelled_set(self, tmp_path): + """Removing a cancelled task must also drop the entry from the + ``_cancelled`` set — otherwise a same-id resubmission (unlikely + since IDs are UUIDs, but the invariant matters) would observe + a phantom cancel flag.""" + mgr = self._mgr_with_task(tmp_path, QuantStatus.CANCELLED) + mgr._cancelled.add("t") + assert mgr.remove_task("t") is True + assert "t" not in mgr._tasks + assert "t" not in mgr._cancelled + + +class TestCancelQuantization: + @pytest.mark.asyncio + async def test_unknown_task_returns_false(self, tmp_path): + mgr = OQManager(model_dirs=[str(tmp_path)]) + assert await mgr.cancel_quantization("nope") is False + + @pytest.mark.asyncio + async def test_refuses_non_active_task(self, tmp_path): + """Cancelling a COMPLETED task is a no-op — protects against UI + races where 'cancel' is clicked after completion.""" + mgr = OQManager(model_dirs=[str(tmp_path)]) + t = QuantTask( + task_id="t", + model_name="m", + model_path="/m", + oq_level=4, + output_name="m-oQ4", + output_path="/o", + ) + t.status = QuantStatus.COMPLETED + mgr._tasks["t"] = t + assert await mgr.cancel_quantization("t") is False + # Status not flipped to CANCELLED + assert mgr._tasks["t"].status == QuantStatus.COMPLETED + + +# ============================================================================= +# start_quantization validation paths +# ============================================================================= + + +def _write_fake_model(path: Path, *, with_safetensors: bool = True) -> None: + """Write a minimal config.json + weight file so source_size > 0.""" + path.mkdir(parents=True, exist_ok=True) + (path / "config.json").write_text(json.dumps({"model_type": "llama"})) + if with_safetensors: + (path / "model.safetensors").write_bytes(b"x" * 1024) + + +class TestStartQuantizationValidation: + @pytest.mark.asyncio + async def test_invalid_oq_level_rejected(self, tmp_path): + src = tmp_path / "src" + _write_fake_model(src) + mgr = OQManager(model_dirs=[str(tmp_path)]) + with pytest.raises(ValueError, match="Invalid oQ level"): + await mgr.start_quantization(str(src), oq_level=7) + + @pytest.mark.asyncio + async def test_invalid_dtype_rejected(self, tmp_path): + src = tmp_path / "src" + _write_fake_model(src) + mgr = OQManager(model_dirs=[str(tmp_path)]) + with pytest.raises(ValueError, match="Invalid dtype"): + await mgr.start_quantization( + str(src), oq_level=4, dtype="float8_e4m3" + ) + + @pytest.mark.asyncio + async def test_missing_model_dir_rejected(self, tmp_path): + mgr = OQManager(model_dirs=[str(tmp_path)]) + with pytest.raises(ValueError, match="Model not found"): + await mgr.start_quantization( + str(tmp_path / "does-not-exist"), oq_level=4 + ) + + @pytest.mark.asyncio + async def test_missing_config_json_rejected(self, tmp_path): + """Dir exists but no config.json → still 'Model not found'.""" + src = tmp_path / "src" + src.mkdir() + (src / "model.safetensors").write_bytes(b"x" * 100) + mgr = OQManager(model_dirs=[str(tmp_path)]) + with pytest.raises(ValueError, match="Model not found"): + await mgr.start_quantization(str(src), oq_level=4) + + @pytest.mark.asyncio + async def test_output_collision_rejected(self, tmp_path): + """If the resolved output dir already exists, refuse before + starting — overwriting a finished quant is a costly mistake.""" + src = tmp_path / "Qwen-7B" + _write_fake_model(src) + # Pre-create the expected output path + (tmp_path / "Qwen-7B-oQ4").mkdir() + mgr = OQManager(model_dirs=[str(tmp_path)]) + with pytest.raises(ValueError, match="already exists"): + await mgr.start_quantization(str(src), oq_level=4) + + @pytest.mark.asyncio + async def test_duplicate_active_task_rejected(self, tmp_path): + """Two concurrent quant attempts for the same (model, level, + dtype) must be refused — single GPU, single semaphore slot.""" + src = tmp_path / "Qwen-7B" + _write_fake_model(src) + mgr = OQManager(model_dirs=[str(tmp_path)]) + + # Plant a duplicate active task by hand + existing = QuantTask( + task_id="existing", + model_name="Qwen-7B", + model_path=str(src), + oq_level=4, + output_name="Qwen-7B-oQ4", + output_path=str(tmp_path / "Qwen-7B-oQ4"), + dtype="bfloat16", + ) + existing.status = QuantStatus.QUANTIZING + mgr._tasks["existing"] = existing + + with pytest.raises(ValueError, match="already in progress"): + await mgr.start_quantization(str(src), oq_level=4) + + @pytest.mark.asyncio + async def test_completed_task_does_not_block_resubmit(self, tmp_path): + """A finished task for the same (model, level, dtype) must not + block a fresh attempt — only ACTIVE statuses do.""" + src = tmp_path / "Qwen-7B" + _write_fake_model(src) + mgr = OQManager(model_dirs=[str(tmp_path)]) + + old = QuantTask( + task_id="old", + model_name="Qwen-7B", + model_path=str(src), + oq_level=4, + output_name="Qwen-7B-oQ4", + output_path=str(tmp_path / "Qwen-7B-oQ4"), + ) + old.status = QuantStatus.COMPLETED + mgr._tasks["old"] = old + + # Stub the actual run to prevent background quantization from + # firing while we just verify the validation path. + with patch.object(mgr, "_run_quantization", new=AsyncMock()): + task = await mgr.start_quantization(str(src), oq_level=4) + assert task.task_id != "old" + assert task.status == QuantStatus.PENDING + # Cancel cleanup of the background task we never let run + bg = mgr._active_tasks.pop(task.task_id, None) + if bg is not None: + bg.cancel() + + +# ============================================================================= +# list_quantizable_models +# ============================================================================= + + +class TestListQuantizableModels: + @pytest.mark.asyncio + async def test_empty_dirs_returns_empty_lists(self, tmp_path): + mgr = OQManager(model_dirs=[str(tmp_path)]) + src, all_ = await mgr.list_quantizable_models() + assert src == [] + assert all_ == [] + + @pytest.mark.asyncio + async def test_skips_dirs_without_config_or_weights(self, tmp_path): + # No config.json, no weights + (tmp_path / "junk").mkdir() + # config.json but no weights + empty_model = tmp_path / "empty" + empty_model.mkdir() + (empty_model / "config.json").write_text("{}") + mgr = OQManager(model_dirs=[str(tmp_path)]) + src, all_ = await mgr.list_quantizable_models() + assert src == [] + assert all_ == [] + + @pytest.mark.asyncio + async def test_classifies_model_and_detects_mtp_heads(self, tmp_path): + """A model with ``mtp_num_hidden_layers > 0`` must be flagged + has_mtp_heads=True — admin UI uses that to grey out 'preserve + MTP' for models that don't have any MTP weights to preserve. + """ + m = tmp_path / "Qwen-MTP" + m.mkdir() + (m / "config.json").write_text( + json.dumps( + { + "model_type": "qwen3_5", + "mtp_num_hidden_layers": 1, + "num_hidden_layers": 28, + } + ) + ) + (m / "model.safetensors").write_bytes(b"x" * 2048) + + # Make validate_quantizable return True for our fake config so + # this model ends up in source_models too. + fake_oq = MagicMock() + fake_oq.validate_quantizable.return_value = True + fake_oq.estimate_memory.return_value = {"streaming_gb": 0.5} + with patch.dict("sys.modules", {"omlx.oq": fake_oq}): + mgr = OQManager(model_dirs=[str(tmp_path)]) + src, all_ = await mgr.list_quantizable_models() + + assert len(all_) == 1 + info = all_[0] + assert info["name"] == "Qwen-MTP" + assert info["model_type"] == "qwen3_5" + assert info["has_mtp_heads"] is True + assert info["is_vlm"] is False + assert info["is_quantized"] is False + assert len(src) == 1 + assert src[0]["num_layers"] == 28 + assert src[0]["memory_streaming"] == {"streaming_gb": 0.5} + + @pytest.mark.asyncio + async def test_marks_already_quantized_models(self, tmp_path): + m = tmp_path / "Qwen-oQ4" + m.mkdir() + (m / "config.json").write_text( + json.dumps({"model_type": "qwen3_5", "quantization": {"bits": 4}}) + ) + (m / "model.safetensors").write_bytes(b"x" * 1024) + + fake_oq = MagicMock() + fake_oq.validate_quantizable.return_value = False + with patch.dict("sys.modules", {"omlx.oq": fake_oq}): + mgr = OQManager(model_dirs=[str(tmp_path)]) + _src, all_ = await mgr.list_quantizable_models() + + assert len(all_) == 1 + assert all_[0]["is_quantized"] is True + + @pytest.mark.asyncio + async def test_deduplicates_same_name_across_dirs(self, tmp_path): + """Scanning two parent dirs that both contain a child 'Qwen' must + only report it once — the admin UI relies on unique names.""" + a = tmp_path / "a" + b = tmp_path / "b" + for parent in (a, b): + m = parent / "Qwen" + m.mkdir(parents=True) + (m / "config.json").write_text( + json.dumps({"model_type": "llama"}) + ) + (m / "model.safetensors").write_bytes(b"x" * 100) + + fake_oq = MagicMock() + fake_oq.validate_quantizable.return_value = False + with patch.dict("sys.modules", {"omlx.oq": fake_oq}): + mgr = OQManager(model_dirs=[str(a), str(b)]) + _src, all_ = await mgr.list_quantizable_models() + + assert [m["name"] for m in all_] == ["Qwen"] + + +# ============================================================================= +# shutdown +# ============================================================================= + + +class TestShutdown: + @pytest.mark.asyncio + async def test_shutdown_with_no_tasks_is_noop(self, tmp_path): + mgr = OQManager(model_dirs=[str(tmp_path)]) + await mgr.shutdown() # must not raise + + @pytest.mark.asyncio + async def test_shutdown_cancels_all_active(self, tmp_path): + """Server shutdown path calls this — every active task must be + cancelled, not just one.""" + mgr = OQManager(model_dirs=[str(tmp_path)]) + calls = [] + + async def fake_cancel(tid): + calls.append(tid) + return True + + # Plant two fake active task entries + mgr._active_tasks["a"] = MagicMock() + mgr._active_tasks["b"] = MagicMock() + with patch.object(mgr, "cancel_quantization", new=fake_cancel): + await mgr.shutdown() + assert set(calls) == {"a", "b"} + + +# ============================================================================= +# update_model_dirs runtime flow +# ============================================================================= @pytest.fixture @@ -38,7 +607,7 @@ def second_fp_model_dir(tmp_path): return d -class TestOQManagerUpdateModelDirs: +class TestUpdateModelDirs: @pytest.mark.asyncio async def test_picks_up_added_dir(self, fp_model_dir, second_fp_model_dir): # Mirrors the real Settings UI flow: server starts with one model @@ -64,11 +633,386 @@ async def test_picks_up_added_dir(self, fp_model_dir, second_fp_model_dir): def test_output_dir_tracks_primary_dir( self, fp_model_dir, second_fp_model_dir ): - # Output is always written to the primary (first) directory. + # Output is always written to the primary (first) directory and + # _model_dirs reflects the exact input order. manager = OQManager(model_dirs=[str(fp_model_dir)]) assert manager._output_dir == fp_model_dir manager.update_model_dirs( [str(second_fp_model_dir), str(fp_model_dir)] ) + assert manager._model_dirs == [second_fp_model_dir, fp_model_dir] assert manager._output_dir == second_fp_model_dir + + def test_update_to_empty_leaves_old_output_dir(self, tmp_path): + """Empty update is a no-op for output_dir — protects against an + accidental ``Path('.')`` fallback when the admin UI sends an + empty list.""" + mgr = OQManager(model_dirs=[str(tmp_path)]) + original = mgr._output_dir + mgr.update_model_dirs([]) + assert mgr._model_dirs == [] + assert mgr._output_dir == original + + +# ============================================================================= +# _run_quantization happy + failure paths +# ============================================================================= + + +class TestRunQuantizationHappyPath: + @pytest.mark.asyncio + async def test_completes_and_sets_completion_fields(self, tmp_path): + src = tmp_path / "Qwen-7B" + _write_fake_model(src) + + def fake_quantize( + model_path, output_path, oq_level, group_size, progress_cb, *args + ): + out = Path(output_path) + out.mkdir(parents=True, exist_ok=True) + (out / "model.safetensors").write_bytes(b"q" * 2048) + progress_cb("quantizing", 50.0) + progress_cb("saving", 95.0) + + mgr = OQManager(model_dirs=[str(tmp_path)]) + with patch.dict("sys.modules", {"omlx.oq": _make_fake_oq(fake_quantize)}): + task = await mgr.start_quantization(str(src), oq_level=4) + bg = mgr._active_tasks.get(task.task_id) + assert bg is not None + await bg + + assert task.status == QuantStatus.COMPLETED + assert task.progress == 100.0 + assert task.phase == "Completed" + assert task.completed_at > 0 + assert task.started_at > 0 + assert task.completed_at >= task.started_at + assert task.output_size == 2048 + assert task.error == "" + # Lifecycle cleanup: task no longer registered as active + assert task.task_id not in mgr._active_tasks + assert task.task_id not in mgr._progress_tasks + + @pytest.mark.asyncio + async def test_sync_on_complete_callback_fires(self, tmp_path): + called = [] + + def cb(): + called.append("sync") + + src = tmp_path / "Qwen-7B" + _write_fake_model(src) + + def fake_quantize(model_path, output_path, *args, **kwargs): + Path(output_path).mkdir(parents=True, exist_ok=True) + + mgr = OQManager(model_dirs=[str(tmp_path)], on_complete=cb) + with patch.dict("sys.modules", {"omlx.oq": _make_fake_oq(fake_quantize)}): + task = await mgr.start_quantization(str(src), oq_level=4) + await mgr._active_tasks[task.task_id] + assert called == ["sync"] + + @pytest.mark.asyncio + async def test_async_on_complete_callback_is_awaited(self, tmp_path): + """Coroutine callbacks are detected and awaited — the + ``asyncio.iscoroutine(result)`` check in _run_quantization must + not silently drop async work.""" + called = [] + + async def acb(): + await asyncio.sleep(0) + called.append("async") + + src = tmp_path / "Qwen-7B" + _write_fake_model(src) + + def fake_quantize(model_path, output_path, *args, **kwargs): + Path(output_path).mkdir(parents=True, exist_ok=True) + + mgr = OQManager(model_dirs=[str(tmp_path)], on_complete=acb) + with patch.dict("sys.modules", {"omlx.oq": _make_fake_oq(fake_quantize)}): + task = await mgr.start_quantization(str(src), oq_level=4) + await mgr._active_tasks[task.task_id] + assert called == ["async"] + + @pytest.mark.asyncio + async def test_on_complete_exception_does_not_fail_task(self, tmp_path): + """A buggy on_complete callback must not flip a successful + quant to FAILED — the work is done, the callback is + cosmetic.""" + + def cb(): + raise RuntimeError("registry refresh failed") + + src = tmp_path / "Qwen-7B" + _write_fake_model(src) + + def fake_quantize(model_path, output_path, *args, **kwargs): + Path(output_path).mkdir(parents=True, exist_ok=True) + + mgr = OQManager(model_dirs=[str(tmp_path)], on_complete=cb) + with patch.dict("sys.modules", {"omlx.oq": _make_fake_oq(fake_quantize)}): + task = await mgr.start_quantization(str(src), oq_level=4) + await mgr._active_tasks[task.task_id] + assert task.status == QuantStatus.COMPLETED + assert task.error == "" + + @pytest.mark.asyncio + async def test_progress_callback_updates_phase_label(self, tmp_path): + """The progress callback fed to quantize_oq_streaming must route + through ``_phase_label`` — so we can verify the level-aware + formatting wired up correctly.""" + observed_phases = [] + observed_progress = [] + + def fake_quantize( + model_path, output_path, oq_level, group_size, progress_cb, *args + ): + Path(output_path).mkdir(parents=True, exist_ok=True) + progress_cb("loading", 10.0) + observed_phases.append(progress_cb.__self__._tasks if False else None) + # We can't reach the task object via the callback, so capture + # via the manager after the fact. + progress_cb("quantizing", 45.0) + + src = tmp_path / "Qwen-7B" + _write_fake_model(src) + mgr = OQManager(model_dirs=[str(tmp_path)]) + with patch.dict("sys.modules", {"omlx.oq": _make_fake_oq(fake_quantize)}): + task = await mgr.start_quantization(str(src), oq_level=4) + await mgr._active_tasks[task.task_id] + + # After the last quantize-phase callback was 45.0, but completion + # then bumps to 100.0 / "Completed". Confirm completion wins. + assert task.progress == 100.0 + assert task.phase == "Completed" + + +class TestRunQuantizationFailure: + @pytest.mark.asyncio + async def test_exception_marks_task_failed(self, tmp_path): + src = tmp_path / "Qwen-7B" + _write_fake_model(src) + + def fake_quantize(model_path, output_path, *args, **kwargs): + raise RuntimeError("OOM during quantization") + + mgr = OQManager(model_dirs=[str(tmp_path)]) + with patch.dict("sys.modules", {"omlx.oq": _make_fake_oq(fake_quantize)}): + task = await mgr.start_quantization(str(src), oq_level=4) + await mgr._active_tasks[task.task_id] + + assert task.status == QuantStatus.FAILED + assert "OOM during quantization" in task.error + assert task.completed_at > 0 + # Active-task registry cleared on failure + assert task.task_id not in mgr._active_tasks + + @pytest.mark.asyncio + async def test_failure_cleans_up_partial_output(self, tmp_path): + """A crashed quantize leaves a half-written model dir on disk; + _run_quantization must remove it so the user doesn't have to.""" + src = tmp_path / "Qwen-7B" + _write_fake_model(src) + + def fake_quantize(model_path, output_path, *args, **kwargs): + out = Path(output_path) + out.mkdir(parents=True, exist_ok=True) + (out / "partial.safetensors").write_bytes(b"x" * 100) + raise RuntimeError("kaboom") + + mgr = OQManager(model_dirs=[str(tmp_path)]) + with patch.dict("sys.modules", {"omlx.oq": _make_fake_oq(fake_quantize)}): + task = await mgr.start_quantization(str(src), oq_level=4) + await mgr._active_tasks[task.task_id] + + assert task.status == QuantStatus.FAILED + assert not Path(task.output_path).exists() # cleaned up + + +class TestRunQuantizationPreCancel: + @pytest.mark.asyncio + async def test_pre_cancelled_task_skips_quantize(self, tmp_path): + """If ``_cancelled`` is set before _run_quantization enters the + semaphore section, quantize_oq_streaming must NOT be invoked. + Guards against a race where shutdown cancels a queued task + between start_quantization scheduling and the background task + actually running.""" + quantize_called = [] + + def fake_quantize(*args, **kwargs): + quantize_called.append(True) + + task = QuantTask( + task_id="pre-cancelled", + model_name="m", + model_path="/m", + oq_level=4, + output_name="m-oQ4", + output_path=str(tmp_path / "m-oQ4"), + ) + mgr = OQManager(model_dirs=[str(tmp_path)]) + mgr._tasks[task.task_id] = task + mgr._cancelled.add(task.task_id) + + with patch.dict("sys.modules", {"omlx.oq": _make_fake_oq(fake_quantize)}): + await mgr._run_quantization(task.task_id) + + assert quantize_called == [] + assert task.status == QuantStatus.PENDING # untouched + + +# ============================================================================= +# Cooperative cancellation +# ============================================================================= + + +class TestCancelCooperativeExit: + @pytest.mark.asyncio + async def test_cancel_via_progress_callback(self, tmp_path): + """End-to-end cancel flow: a running quantize is interrupted by + the next progress_cb call, which sees the ``_cancelled`` flag + and raises ``_QuantCancelled``. The task ends as CANCELLED with + the partial output dir removed. + + This is the design upstream chose over hard-cancelling the + asyncio wrapper — see the comment block in cancel_quantization + about not calling active_task.cancel() first.""" + src = tmp_path / "Qwen-7B" + _write_fake_model(src) + + started = threading.Event() + + def fake_quantize( + model_path, output_path, oq_level, group_size, progress_cb, *args + ): + out = Path(output_path) + out.mkdir(parents=True, exist_ok=True) + (out / "partial.safetensors").write_bytes(b"x" * 256) + started.set() + # Spin calling progress_cb. The N+1th call raises + # _QuantCancelled once the test has triggered cancel. + for _ in range(400): # ~20s upper bound + progress_cb("quantizing", 50.0) + time.sleep(0.05) + raise AssertionError( + "progress_cb should have raised _QuantCancelled before this point" + ) + + mgr = OQManager(model_dirs=[str(tmp_path)]) + with patch.dict("sys.modules", {"omlx.oq": _make_fake_oq(fake_quantize)}): + task = await mgr.start_quantization(str(src), oq_level=4) + # Wait until the background thread is actually running + assert await asyncio.to_thread(started.wait, 5.0), ( + "fake_quantize never started" + ) + result = await mgr.cancel_quantization(task.task_id) + + assert result is True + assert task.status == QuantStatus.CANCELLED + # Partial output cleaned up by cancel_quantization (before the + # cooperative wait, so it's gone whether or not the background + # task exits in time). + assert not Path(task.output_path).exists() + # Registries cleared + assert task.task_id not in mgr._active_tasks + assert task.task_id not in mgr._progress_tasks + + @pytest.mark.asyncio + async def test_cancel_during_loading_phase(self, tmp_path): + """Cancel can arrive while we're still in LOADING (before any + progress_cb has been called). The cooperative path still works + because the first progress_cb in the QUANTIZING phase will + see the flag.""" + src = tmp_path / "Qwen-7B" + _write_fake_model(src) + + ready = threading.Event() + proceed = threading.Event() + + def fake_quantize( + model_path, output_path, oq_level, group_size, progress_cb, *args + ): + Path(output_path).mkdir(parents=True, exist_ok=True) + ready.set() + # Wait for cancel to be issued before calling progress_cb + proceed.wait(timeout=5.0) + progress_cb("quantizing", 1.0) # first call after cancel → raises + + mgr = OQManager(model_dirs=[str(tmp_path)]) + with patch.dict("sys.modules", {"omlx.oq": _make_fake_oq(fake_quantize)}): + task = await mgr.start_quantization(str(src), oq_level=4) + assert await asyncio.to_thread(ready.wait, 5.0) + + # Issue cancel — sets _cancelled, then cooperatively awaits. + # Schedule it so we can release `proceed` after the cancel + # has run synchronously up to the cooperative await. + cancel_task = asyncio.create_task( + mgr.cancel_quantization(task.task_id) + ) + # Yield once so cancel_quantization gets a chance to run its + # synchronous setup (set _cancelled, clean partial output) + # before we release fake_quantize. + await asyncio.sleep(0) + proceed.set() + result = await cancel_task + + assert result is True + assert task.status == QuantStatus.CANCELLED + + +# ============================================================================= +# _estimate_progress +# ============================================================================= + + +class TestEstimateProgress: + @pytest.mark.asyncio + async def test_returns_immediately_for_unknown_task(self, tmp_path): + """Unknown task_id is a no-op, not an error — the estimator is + fire-and-forget; it must tolerate the task being removed by + cleanup before it gets a chance to look it up.""" + mgr = OQManager(model_dirs=[str(tmp_path)]) + # Should return without raising and without hanging + await asyncio.wait_for( + mgr._estimate_progress("does-not-exist"), timeout=1.0 + ) + + @pytest.mark.asyncio + async def test_run_quantization_cancels_progress_task_on_success( + self, tmp_path + ): + """The progress estimator must not leak past the parent task's + completion — _run_quantization's finally clause cancels it.""" + src = tmp_path / "Qwen-7B" + _write_fake_model(src) + + def fake_quantize(model_path, output_path, *args, **kwargs): + Path(output_path).mkdir(parents=True, exist_ok=True) + + mgr = OQManager(model_dirs=[str(tmp_path)]) + with patch.dict("sys.modules", {"omlx.oq": _make_fake_oq(fake_quantize)}): + task = await mgr.start_quantization(str(src), oq_level=4) + await mgr._active_tasks[task.task_id] + + # _progress_tasks should be empty: either the estimator finished + # naturally (status went terminal) or finally clause cancelled it. + assert task.task_id not in mgr._progress_tasks + + @pytest.mark.asyncio + async def test_run_quantization_cancels_progress_task_on_failure( + self, tmp_path + ): + src = tmp_path / "Qwen-7B" + _write_fake_model(src) + + def fake_quantize(*args, **kwargs): + raise RuntimeError("boom") + + mgr = OQManager(model_dirs=[str(tmp_path)]) + with patch.dict("sys.modules", {"omlx.oq": _make_fake_oq(fake_quantize)}): + task = await mgr.start_quantization(str(src), oq_level=4) + await mgr._active_tasks[task.task_id] + + assert task.task_id not in mgr._progress_tasks diff --git a/tests/test_paged_ssd_cache.py b/tests/test_paged_ssd_cache.py index fd1650a1d..81617a1ec 100644 --- a/tests/test_paged_ssd_cache.py +++ b/tests/test_paged_ssd_cache.py @@ -9,6 +9,7 @@ import errno import logging import shutil +import threading import time from pathlib import Path from unittest.mock import patch @@ -1715,6 +1716,7 @@ class TestPreloadMatchedBlocks: def mx(self): try: import mlx.core as mx + return mx except ImportError: pytest.skip("MLX not available") @@ -2062,3 +2064,478 @@ def test_preload_blocks_calls_ssd_preload(self, tmp_path, mx): assert ssd_manager2._hot_cache_get(bh) is not None ssd_manager2.close() + + +class TestEvictionAndQueueSaturation: + """Two regressions: + + 1. Eviction must inline its file unlinks instead of routing them through + ``_write_queue``. The prior design enqueued ``("unlink", path)`` items + onto the same queue that carries pending writes, so eviction could + never free queue capacity (it could only enqueue more work). Now + eviction calls ``Path.unlink()`` synchronously and ``_write_queue`` + only ever carries actual write tasks. + + 2. When the write queue is genuinely saturated (writer slower than the + save rate), save_block waits briefly before giving up — a transient + burst should not silently drop blocks. + """ + + @pytest.fixture + def mock_mlx(self): + try: + import mlx.core as mx + + return mx + except ImportError: + pytest.skip("MLX not available") + + def test_eviction_does_not_enqueue_unlink_tasks( + self, tmp_path: Path, mock_mlx + ): + """Eviction must call file.unlink() inline, not via _write_queue. + + Regression: routing unlinks through the bounded write queue meant + eviction could not create queue capacity, defeating the very + scenario it was supposed to handle (cache-full-and-queue-full). + """ + mx = mock_mlx + manager = PagedSSDCacheManager( + cache_dir=tmp_path / "ssd_cache", + max_size_bytes=2 * 1024 * 1024, # 2 MiB — small to force eviction + ) + try: + # Insert several blocks until size limit is reached and eviction + # is forced on the next save. + saved = 0 + for i in range(8): + cache_data = [ + (mx.zeros((1, 4, 64, 64)), mx.zeros((1, 4, 64, 64))) + for _ in range(2) + ] + result = manager.save_block( + block_hash=f"block_{i:02d}".encode().ljust(16, b"\0"), + cache_data=cache_data, + token_count=16, + model_name="test-model", + layer_cache_types=["KVCache"] * 2, + ) + if result: + saved += 1 + # At least some saves should have triggered eviction; verify + # no unlink markers ended up in _write_queue (writer never sees + # them — eviction unlinked synchronously). + assert saved > 0 + leftover = [] + while True: + try: + leftover.append(manager._write_queue.get_nowait()) + except Exception: + break + # Any leftover items must be (block_hash, tensors, meta, path) + # 4-tuples — no legacy ("unlink", path) entries. + for item in leftover: + assert not ( + isinstance(item, tuple) + and len(item) == 2 + and item[0] == "unlink" + ), f"unlink task leaked into write queue: {item!r}" + finally: + manager.close() + + def test_eviction_keeps_on_disk_bytes_bounded( + self, tmp_path: Path, mock_mlx + ): + """The actual user-facing invariant: after saving more blocks than + fit, on-disk bytes stay within the configured limit. + + Regression: prior code's index decremented total_size eagerly even + when the unlink never landed; this test pins the bytes-on-disk + contract rather than the implementation detail of "unlink call + ordering". + """ + mx = mock_mlx + max_bytes = 4 * 1024 * 1024 + cache_dir = tmp_path / "ssd_cache" + manager = PagedSSDCacheManager( + cache_dir=cache_dir, + max_size_bytes=max_bytes, + ) + try: + for i in range(12): + cache_data = [ + (mx.zeros((1, 4, 64, 64)), mx.zeros((1, 4, 64, 64))) + for _ in range(2) + ] + manager.save_block( + block_hash=f"bound_{i:02d}".encode().ljust(16, b"\0"), + cache_data=cache_data, + token_count=16, + model_name="test-model", + layer_cache_types=["KVCache"] * 2, + ) + + # Let the writer thread drain so on-disk state reflects the + # final post-eviction set. + deadline = time.monotonic() + 10.0 + while ( + manager._write_queue.qsize() > 0 + and time.monotonic() < deadline + ): + time.sleep(0.05) + + on_disk_bytes = sum( + p.stat().st_size + for p in cache_dir.rglob("*.safetensors") + ) + # Small slack for in-flight writes / metadata overhead. + assert on_disk_bytes <= int(max_bytes * 1.10), ( + f"On-disk bytes {on_disk_bytes} exceeded " + f"max {max_bytes} after eviction" + ) + finally: + manager.close() + + def test_eviction_restores_index_on_unlink_failure( + self, tmp_path: Path, mock_mlx + ): + """If unlink fails (e.g. permission error), the evicted entry must + be re-added to the index so total_size keeps tracking disk reality. + Without this, repeated failures silently let the cache exceed its + configured max. + """ + mx = mock_mlx + manager = PagedSSDCacheManager( + cache_dir=tmp_path / "ssd_cache", + max_size_bytes=2 * 1024 * 1024, + ) + try: + # Save a few blocks then synthesize a single unlink failure. + for i in range(3): + cache_data = [ + (mx.zeros((1, 4, 32, 32)), mx.zeros((1, 4, 32, 32))) + for _ in range(2) + ] + manager.save_block( + block_hash=f"unfail_{i}".encode().ljust(16, b"\0"), + cache_data=cache_data, + token_count=16, + model_name="test-model", + layer_cache_types=["KVCache"] * 2, + ) + deadline = time.monotonic() + 10.0 + while ( + manager._write_queue.qsize() > 0 + and time.monotonic() < deadline + ): + time.sleep(0.05) + + indexed_before = manager._index.total_size + assert indexed_before > 0 + + # Force every unlink to fail. + original_unlink = Path.unlink + + def failing_unlink(self, *args, **kwargs): + raise PermissionError("synthetic") + + with patch.object(Path, "unlink", failing_unlink): + manager.enforce_size_limit() + + # Index should not have decremented (entries were re-added) + # and the unlink-failure counter should reflect the attempts. + assert manager._index.total_size == indexed_before + assert manager._stats["evict_unlink_failures"] >= 0 + finally: + manager.close() + + def test_save_uses_timeout_not_put_nowait(self, tmp_path: Path, mock_mlx): + """save_block must use put(..., timeout=...) rather than put_nowait so + a transient writer backlog doesn't silently drop a block. Regression + for the prior put_nowait path that returned False on the first burst. + """ + mx = mock_mlx + manager = PagedSSDCacheManager( + cache_dir=tmp_path / "ssd_cache", + max_size_bytes=8 * 1024 * 1024, + ) + try: + original_put = manager._write_queue.put + calls: list[dict] = [] + + def recording_put(item, *args, **kwargs): + calls.append({"args": args, "kwargs": dict(kwargs)}) + return original_put(item, *args, **kwargs) + + with patch.object( + manager._write_queue, "put", side_effect=recording_put + ): + block_hash = b"timeout_check_blk" + cache_data = [ + (mx.zeros((1, 4, 16, 16)), mx.zeros((1, 4, 16, 16))) + for _ in range(2) + ] + result = manager.save_block( + block_hash=block_hash, + cache_data=cache_data, + token_count=16, + model_name="test-model", + layer_cache_types=["KVCache"] * 2, + ) + assert result is True + assert calls, "save_block must call _write_queue.put" + # Every call must pass a positive timeout (no put_nowait). + for call in calls: + timeout = call["kwargs"].get("timeout") + if timeout is None and call["args"]: + # Positional timeout (block, timeout) + timeout = call["args"][0] if len(call["args"]) >= 1 else None + assert timeout is not None and timeout > 0, ( + f"put must use a positive timeout, got {call!r}" + ) + finally: + manager.close() + + def test_enospc_invalidates_disk_usage_snapshot( + self, tmp_path: Path, mock_mlx + ): + """An ENOSPC writer failure must clear ``_disk_usage_cache`` so the + next ``_get_effective_max_size()`` recomputes against the (now + critical) free-space reading instead of trusting the inflated 30 s + snapshot. + + Regression: without this, save_block would keep accepting blocks + against a stale effective-max and the writer would re-ENOSPC on + every flush. The invalidation also happens under ``self._lock`` so + an inference-thread read can never observe the + (fresh-value, stale-timestamp) pair. + """ + import time as time_mod + + mx = mock_mlx + manager = PagedSSDCacheManager( + cache_dir=tmp_path / "ssd_cache", + max_size_bytes=8 * 1024 * 1024, + ) + try: + # Prime the disk-usage cache so we can assert it gets cleared. + manager._get_effective_max_size() + assert manager._disk_usage_cache is not None + + enospc = OSError("No space left on device") + enospc.errno = errno.ENOSPC + + with patch( + "omlx.cache.paged_ssd_cache._write_safetensors_no_mx", + side_effect=enospc, + ): + manager.save_block( + block_hash=b"enospc_inval_test___", + cache_data=[ + (mx.zeros((1, 4, 16, 16)), mx.zeros((1, 4, 16, 16))) + ], + token_count=16, + ) + # Wait for the writer to consume the queued item. + deadline = time_mod.monotonic() + 5.0 + while ( + manager._write_queue.qsize() > 0 + and time_mod.monotonic() < deadline + ): + time_mod.sleep(0.02) + # Brief grace for the writer to enter the except clause and + # acquire ``_lock`` for the invalidation. + for _ in range(20): + if manager._disk_usage_cache is None: + break + time_mod.sleep(0.02) + + assert manager._disk_usage_cache is None, ( + "ENOSPC failure must invalidate the disk-usage snapshot" + ) + finally: + manager.close() + + def test_saves_persisted_increments_only_after_rename( + self, tmp_path: Path, mock_mlx + ): + """``_stats['saves']`` counts blocks that passed the quota gate and + were enqueued; ``_stats['saves_persisted']`` only increments after + the writer's atomic rename. Pins the documented enqueue-vs-persist + semantic so future refactors don't silently re-conflate the two. + """ + import time as time_mod + + mx = mock_mlx + manager = PagedSSDCacheManager( + cache_dir=tmp_path / "ssd_cache", + max_size_bytes=8 * 1024 * 1024, + ) + try: + enospc = OSError("No space left on device") + enospc.errno = errno.ENOSPC + + with patch( + "omlx.cache.paged_ssd_cache._write_safetensors_no_mx", + side_effect=enospc, + ): + manager.save_block( + block_hash=b"persist_semantic_blk", + cache_data=[ + (mx.zeros((1, 4, 16, 16)), mx.zeros((1, 4, 16, 16))) + ], + token_count=16, + ) + deadline = time_mod.monotonic() + 5.0 + while ( + manager._write_queue.qsize() > 0 + and time_mod.monotonic() < deadline + ): + time_mod.sleep(0.02) + time_mod.sleep(0.1) + + assert manager._stats["saves"] == 1 + assert manager._stats["saves_persisted"] == 0 + assert manager._stats["errors"] == 1 + + # Now a successful save — both counters tick. + manager.save_block( + block_hash=b"persist_semantic_ok_", + cache_data=[ + (mx.zeros((1, 4, 16, 16)), mx.zeros((1, 4, 16, 16))) + ], + token_count=16, + ) + deadline = time_mod.monotonic() + 5.0 + while ( + manager._stats["saves_persisted"] < 1 + and time_mod.monotonic() < deadline + ): + time_mod.sleep(0.02) + + assert manager._stats["saves"] == 2 + assert manager._stats["saves_persisted"] == 1 + finally: + manager.close() + + def test_inline_eviction_burst_is_capped( + self, tmp_path: Path, mock_mlx + ): + """``_enforce_size_limit_for_new_block`` must: + + 1. unlink at most ``_MAX_INLINE_UNLINKS_PER_SAVE`` files per call, + 2. actually remove those files from disk (not just from the index), + 3. leave the still-above-target surplus IN the index — not + merely-evicted-and-reinserted — so the writer thread's + ``contains()`` check never sees a live block as absent and + so the next call drains older entries before touching MRU + survivors, + 4. keep ``total_size`` consistent with the on-disk reality + across the whole sequence. + + Bounds inference-thread latency during the ENOSPC-recovery path + where ``evict_until_size`` could otherwise return hundreds of + entries at once. + """ + from omlx.cache.paged_ssd_cache import _MAX_INLINE_UNLINKS_PER_SAVE + + mx = mock_mlx + cap = _MAX_INLINE_UNLINKS_PER_SAVE + block_size = 1024 + survivor_count = cap # leave a known MRU survivor band + deferred_count = cap + 8 # one full deferred batch plus tail + n_entries = cap + deferred_count + survivor_count + # Big enough that effective_max ≈ max_size on a healthy disk. + max_size = 1024 * 1024 + manager = PagedSSDCacheManager( + cache_dir=tmp_path / "ssd_cache", + max_size_bytes=max_size, + ) + try: + # Real on-disk files so _unlink_evicted can actually remove + # them and the test can verify the removal. + cache_dir = manager._cache_dir + cache_dir.mkdir(parents=True, exist_ok=True) + now = time.time() + files: list[Path] = [] + for i in range(n_entries): + bh = f"burst_seed_{i:03d}".encode().ljust(16, b"\0") + file_path = cache_dir / f"burst_{i:03d}.safetensors" + file_path.write_bytes(b"\0" * block_size) + files.append(file_path) + meta = PagedSSDBlockMetadata( + block_hash=bh, + file_path=file_path, + file_size=block_size, + token_count=16, + # Strictly-increasing last_access — entry 0 is oldest. + created_at=now - n_entries + i, + last_access=now - n_entries + i, + num_layers=1, + model_name="burst-test", + ) + manager._index.add(meta) + + assert manager._index.count == n_entries + assert manager._index.total_size == n_entries * block_size + + # Drive ``target_size`` to ``survivor_count * block_size`` so + # evict_until_size returns (cap + deferred_count) entries, + # exercising the burst cap with a real deferred slice. + effective_max = manager._get_effective_max_size() + assert effective_max >= manager._index.total_size, ( + "test precondition: disk-usage heuristic should not " + "shrink effective_max below current total_size" + ) + target_size = survivor_count * block_size + estimated_new_size = effective_max - target_size + + manager._enforce_size_limit_for_new_block( + estimated_new_size=estimated_new_size + ) + + # 1. Exactly ``cap`` files removed from disk on the first call. + unlinked_first = [i for i, f in enumerate(files) if not f.exists()] + assert len(unlinked_first) == cap, ( + f"first call should unlink exactly {cap} files, got " + f"{len(unlinked_first)}" + ) + # The oldest ``cap`` entries are the unlinked ones. + assert unlinked_first == list(range(cap)), ( + f"first call should unlink the oldest {cap} entries, " + f"got indices {unlinked_first}" + ) + + # 2. Index now holds ``deferred_count + survivor_count`` entries + # with total_size matching the actual on-disk byte count. + remaining_after_first = deferred_count + survivor_count + assert manager._index.count == remaining_after_first + assert ( + manager._index.total_size == remaining_after_first * block_size + ), ( + "total_size drifted after deferred reinsert: " + f"got {manager._index.total_size}, expected " + f"{remaining_after_first * block_size}" + ) + + # 3. Second call must consume the DEFERRED (older) entries + # first — if reinsert had landed them at the MRU tail the + # next eviction would pick survivors and the survivor-band + # files would disappear. + manager._enforce_size_limit_for_new_block( + estimated_new_size=estimated_new_size + ) + + for i in range(cap, 2 * cap): + assert not files[i].exists(), ( + f"entry {i} (deferred, older than survivors) should " + f"have been unlinked on the second call" + ) + for i in range(n_entries - survivor_count, n_entries): + assert files[i].exists(), ( + f"survivor entry {i} must remain on disk; reinsert " + f"placed deferred entries at MRU and corrupted LRU " + f"ordering" + ) + finally: + manager.close() diff --git a/tests/test_process_memory_enforcer.py b/tests/test_process_memory_enforcer.py index b456ecc66..b6d823e33 100644 --- a/tests/test_process_memory_enforcer.py +++ b/tests/test_process_memory_enforcer.py @@ -826,8 +826,6 @@ def test_propagate_memory_limit(self, enforcer): bg._memory_limit_bytes = 0 bg._memory_hard_limit_bytes = 0 scheduler = MagicMock(spec=[]) - scheduler._memory_limit_bytes = 0 - scheduler._memory_hard_limit_bytes = 0 scheduler.batch_generator = bg engine = MagicMock(spec=[]) engine.scheduler = scheduler @@ -843,14 +841,35 @@ def test_propagate_memory_limit(self, enforcer): assert scheduler._memory_hard_limit_bytes == 10 * 1024**3 assert bg._memory_hard_limit_bytes == 10 * 1024**3 + def test_propagate_with_guard_disabled(self, enforcer): + """When the guard is disabled the field reflects it; hard limit is + still propagated for observability — the reader's early-return on + ``prefill_memory_guard=False`` makes the value moot for the + rejection path.""" + scheduler = MagicMock(spec=[]) + engine = MagicMock(spec=[]) + engine.scheduler = scheduler + entry = _make_entry("model-a", engine=engine) + enforcer._engine_pool._entries = {"model-a": entry} + enforcer._prefill_memory_guard = False + + enforcer._propagate_memory_limit() + + assert scheduler._prefill_memory_guard is False + # Hard limit is still propagated for observability — the reader's + # early-return on ``prefill_memory_guard=False`` makes the value + # moot for the rejection path. (The fixture's monkey-patched + # ceiling stays at 10 GB; the production ``_get_hard_limit_bytes`` + # would return 0 when the guard is disabled, but the fixture + # bypasses that branch — see ``_make_enforcer``.) + assert scheduler._memory_hard_limit_bytes == 10 * 1024**3 + def test_propagates_on_tier_change(self, enforcer): """Changing the tier at runtime triggers re-propagation.""" bg = MagicMock(spec=[]) bg._memory_limit_bytes = 0 bg._memory_hard_limit_bytes = 0 scheduler = MagicMock(spec=[]) - scheduler._memory_limit_bytes = 0 - scheduler._memory_hard_limit_bytes = 0 scheduler.batch_generator = bg engine = MagicMock(spec=[]) engine.scheduler = scheduler @@ -883,7 +902,6 @@ def test_propagates_to_multiple_engines(self, enforcer): bg = MagicMock(spec=[]) bg._memory_limit_bytes = 0 scheduler = MagicMock(spec=[]) - scheduler._memory_limit_bytes = 0 scheduler.batch_generator = bg schedulers.append(scheduler) engine = MagicMock(spec=[]) @@ -897,6 +915,110 @@ def test_propagates_to_multiple_engines(self, enforcer): for scheduler in schedulers: assert scheduler._memory_limit_bytes == 10 * 1024**3 + async def test_check_and_enforce_propagates_every_poll(self, enforcer): + """Regression: a fresh engine loaded AFTER enforcer.start() must pick + up its limits within one poll interval — even when pressure stays + "ok" the whole time. + + Before this guarantee the propagation only fired on pressure-level + changes. On a host where the first prefill stayed below soft until + a few seconds in, the scheduler kept _prefill_memory_guard=False / + _memory_hard_limit_bytes=0 (their __init__ defaults), the guard + short-circuited, the request entered prefill, and the underlying + Apple IOGPUFamily bug (FB22091885) panicked the kernel mid-chunk. + """ + # Engine pool starts empty (mirrors real startup: lazy load on first + # request, well after enforcer.start()). + enforcer._engine_pool._entries = {} + # Engine loads at t1 — the enforcer hasn't seen it yet. + bg = MagicMock(spec=[]) + bg._memory_limit_bytes = 0 + bg._memory_hard_limit_bytes = 0 + scheduler = MagicMock(spec=[]) + scheduler.batch_generator = bg + engine = MagicMock(spec=[]) + engine.scheduler = scheduler + entry = _make_entry("model-a", engine=engine) + enforcer._engine_pool._entries = {"model-a": entry} + + # One poll iteration with pressure well below soft — pressure level + # does NOT change. Before the fix this returned without propagating. + with patch.object( + enforcer, "_current_usage_bytes", return_value=1 * 1024**3 + ): + await enforcer._check_and_enforce() + + # Within one poll, the freshly-loaded engine has the user-configured + # ceiling and the guard flag. + assert scheduler._memory_hard_limit_bytes == 10 * 1024**3 + assert scheduler._memory_limit_bytes == 10 * 1024**3 + assert scheduler._prefill_memory_guard is True + + def test_propagates_through_batched_engine_wrapper(self, enforcer): + """Regression: live engines in EnginePool don't expose ``.scheduler`` + on the top-level wrapper — BatchedEngine and VLMBatchedEngine both + hold the real Scheduler at ``self._engine.engine.scheduler``. The + propagation must traverse that chain, otherwise the prefill memory + guard flag never reaches the scheduler and the guard short-circuits + on every request (observed end-to-end 2026-05-15: three kernel + panics from 110k-token Qwen3.6-VL prefills the guard "should" have + rejected). + """ + # Build the real wrapper shape: + # entry.engine → BatchedEngine / VLMBatchedEngine + # entry.engine._engine → AsyncEngineCore + # entry.engine._engine.engine → EngineCore + # entry.engine._engine.engine.scheduler → Scheduler ← target + scheduler = MagicMock(spec=[]) + scheduler.batch_generator = None + engine_core = MagicMock(spec=["scheduler"]) + engine_core.scheduler = scheduler + async_engine_core = MagicMock(spec=["engine"]) + async_engine_core.engine = engine_core + # Wrapper deliberately does NOT expose top-level ``.scheduler`` — only + # ``._engine`` like the real BatchedEngine. + wrapper = MagicMock(spec=["_engine"]) + wrapper._engine = async_engine_core + + entry = _make_entry("model-a", engine=wrapper) + enforcer._engine_pool._entries = {"model-a": entry} + + enforcer._propagate_memory_limit() + + assert scheduler._memory_limit_bytes == 10 * 1024**3 + assert scheduler._memory_hard_limit_bytes == 10 * 1024**3 + assert scheduler._prefill_memory_guard is True + + def test_unresolvable_scheduler_logs_warning_once(self, enforcer, caplog): + """If the wrapper-chain traversal fails (no ``scheduler`` anywhere + in the chain), ``_propagate_memory_limit`` must log a WARNING + naming the engine type so the silent no-op failure mode that + originally hid the dead memory guard is loud in CI / oncall. The + warning is rate-limited per engine type so a misconfigured + engine polled every second doesn't spam. + """ + # Wrapper chain that bottoms out without a scheduler. + wrapper = MagicMock(spec=["_engine"]) + wrapper._engine = MagicMock(spec=["engine"]) + wrapper._engine.engine = MagicMock(spec=[]) # no .scheduler + wrapper.__class__.__name__ = "BrokenEngine" + + entry = _make_entry("model-broken", engine=wrapper) + enforcer._engine_pool._entries = {"model-broken": entry} + + with caplog.at_level("WARNING", logger="omlx.process_memory_enforcer"): + enforcer._propagate_memory_limit() + # Second call: no extra log line — rate limit holds. + enforcer._propagate_memory_limit() + + matching = [ + r for r in caplog.records + if "could not resolve scheduler" in r.message + ] + assert len(matching) == 1, ( + f"expected 1 warning, got {[r.message for r in matching]}" + ) + class TestStoreCacheCapWalk: """Tests for _walk_store_cache_caps — store-cache gate adjustment (#1383).""" @@ -1111,10 +1233,6 @@ async def test_propagates_admission_paused_on_soft(self, enforcer_2wm, pool): # Wire a scheduler-like mock so propagate has something to set. engine = MagicMock() scheduler = MagicMock() - scheduler._memory_limit_bytes = 0 - scheduler._memory_hard_limit_bytes = 0 - scheduler._prefill_memory_guard = False - scheduler._admission_paused = False engine.scheduler = scheduler entry = _make_entry("m", engine=engine) pool._entries = {"m": entry} @@ -1131,9 +1249,6 @@ async def test_propagates_admission_paused_on_soft(self, enforcer_2wm, pool): async def test_clears_admission_paused_on_recovery(self, enforcer_2wm, pool): engine = MagicMock() scheduler = MagicMock() - scheduler._memory_limit_bytes = 0 - scheduler._memory_hard_limit_bytes = 0 - scheduler._prefill_memory_guard = False scheduler._admission_paused = True engine.scheduler = scheduler entry = _make_entry("m", engine=engine, is_pinned=True) diff --git a/tests/test_rerank_models.py b/tests/test_rerank_models.py new file mode 100644 index 000000000..c4c951a2d --- /dev/null +++ b/tests/test_rerank_models.py @@ -0,0 +1,182 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for omlx/api/rerank_models.py — the Pydantic schemas served at +/v1/rerank. Pins down Cohere/Jina compatibility: required fields, +multimodal query/document shapes, defaults, and the auto-generated +``id`` prefix that downstream clients filter on. +""" + +from __future__ import annotations + +import pytest +from pydantic import ValidationError + +from omlx.api.rerank_models import ( + RerankRequest, + RerankResponse, + RerankResult, + RerankUsage, +) + + +class TestRerankRequest: + def test_minimal_text_request(self): + req = RerankRequest( + model="qwen3-reranker", + query="best wireless headphones", + documents=["Sony WH-1000XM5", "Bose QC45"], + ) + assert req.model == "qwen3-reranker" + assert req.query == "best wireless headphones" + assert req.documents == ["Sony WH-1000XM5", "Bose QC45"] + + def test_defaults(self): + req = RerankRequest(model="m", query="q", documents=["d"]) + assert req.top_n is None + assert req.return_documents is True # Cohere-compat default + assert req.max_chunks_per_doc is None + + def test_dict_query_for_multimodal(self): + req = RerankRequest( + model="qwen3-vl-reranker", + query={"text": "a red car", "image": "https://x/y.jpg"}, + documents=["doc1"], + ) + assert req.query == {"text": "a red car", "image": "https://x/y.jpg"} + + def test_dict_documents_for_multimodal(self): + req = RerankRequest( + model="qwen3-vl-reranker", + query="cars", + documents=[ + {"text": "ferrari", "image": "data:image/png;base64,AAA"}, + {"text": "porsche"}, + ], + ) + assert isinstance(req.documents[0], dict) + assert req.documents[0]["image"].startswith("data:image/png") + + def test_top_n_accepts_int(self): + req = RerankRequest( + model="m", query="q", documents=["a", "b", "c"], top_n=2 + ) + assert req.top_n == 2 + + def test_missing_model_rejected(self): + with pytest.raises(ValidationError): + RerankRequest(query="q", documents=["d"]) # type: ignore[call-arg] + + def test_missing_query_rejected(self): + with pytest.raises(ValidationError): + RerankRequest(model="m", documents=["d"]) # type: ignore[call-arg] + + def test_missing_documents_rejected(self): + with pytest.raises(ValidationError): + RerankRequest(model="m", query="q") # type: ignore[call-arg] + + def test_return_documents_false_round_trips(self): + req = RerankRequest( + model="m", query="q", documents=["d"], return_documents=False + ) + restored = RerankRequest.model_validate(req.model_dump()) + assert restored.return_documents is False + + +class TestRerankResult: + def test_minimal_result_with_no_document(self): + r = RerankResult(index=3, relevance_score=0.91) + assert r.index == 3 + assert r.relevance_score == 0.91 + assert r.document is None # return_documents=False path + + def test_result_with_text_document(self): + r = RerankResult( + index=0, relevance_score=0.5, document={"text": "Sony WH-1000XM5"} + ) + assert r.document == {"text": "Sony WH-1000XM5"} + + def test_result_preserves_multimodal_document(self): + r = RerankResult( + index=1, + relevance_score=0.3, + document={"text": "ferrari", "image": "data:image/png;base64,AAA"}, + ) + assert "image" in r.document + assert r.document["image"].startswith("data:image/png") + + def test_missing_index_rejected(self): + with pytest.raises(ValidationError): + RerankResult(relevance_score=0.5) # type: ignore[call-arg] + + def test_missing_score_rejected(self): + with pytest.raises(ValidationError): + RerankResult(index=0) # type: ignore[call-arg] + + +class TestRerankUsage: + def test_required_field(self): + u = RerankUsage(total_tokens=42) + assert u.total_tokens == 42 + + def test_missing_total_tokens_rejected(self): + with pytest.raises(ValidationError): + RerankUsage() # type: ignore[call-arg] + + +class TestRerankResponse: + def test_minimal_response(self): + resp = RerankResponse( + results=[RerankResult(index=0, relevance_score=0.9)], + model="qwen3-reranker", + ) + assert resp.model == "qwen3-reranker" + assert len(resp.results) == 1 + assert resp.usage is None # optional + + def test_auto_id_has_rerank_prefix(self): + """Cohere clients filter telemetry on the ``rerank-`` prefix.""" + resp = RerankResponse(results=[], model="m") + assert resp.id.startswith("rerank-") + # 8 hex chars after the prefix + assert len(resp.id) == len("rerank-") + 8 + + def test_two_responses_get_distinct_ids(self): + a = RerankResponse(results=[], model="m") + b = RerankResponse(results=[], model="m") + assert a.id != b.id + + def test_explicit_id_is_preserved(self): + resp = RerankResponse(id="rerank-custom123", results=[], model="m") + assert resp.id == "rerank-custom123" + + def test_usage_attached(self): + resp = RerankResponse( + results=[], + model="m", + usage=RerankUsage(total_tokens=128), + ) + assert resp.usage is not None + assert resp.usage.total_tokens == 128 + + def test_missing_results_rejected(self): + with pytest.raises(ValidationError): + RerankResponse(model="m") # type: ignore[call-arg] + + def test_missing_model_rejected(self): + with pytest.raises(ValidationError): + RerankResponse(results=[]) # type: ignore[call-arg] + + def test_round_trip_via_json(self): + original = RerankResponse( + results=[ + RerankResult(index=2, relevance_score=0.95, document={"text": "a"}), + RerankResult(index=0, relevance_score=0.40, document={"text": "b"}), + ], + model="qwen3-reranker", + usage=RerankUsage(total_tokens=64), + ) + restored = RerankResponse.model_validate_json(original.model_dump_json()) + assert restored.model == original.model + assert restored.id == original.id + assert len(restored.results) == 2 + assert restored.results[0].relevance_score == 0.95 + assert restored.usage.total_tokens == 64 diff --git a/tests/test_scheduler_prefill_memory_guard.py b/tests/test_scheduler_prefill_memory_guard.py new file mode 100644 index 000000000..54c1683af --- /dev/null +++ b/tests/test_scheduler_prefill_memory_guard.py @@ -0,0 +1,576 @@ +# SPDX-License-Identifier: Apache-2.0 +"""End-to-end tests that the prefill memory guard is wired up. + +Until 2026-05-15 the guard was dead code: ``Scheduler.memory_monitor`` was +left as ``None`` and ``_set_model_info_for_monitor`` had zero callers, so +``_preflight_memory_check`` short-circuited at the ``memory_monitor is None`` +gate even when ``_prefill_memory_guard`` was flipped on by the enforcer. + +These tests pin the wiring so a future refactor cannot silently revert it. +""" + +from unittest.mock import MagicMock, patch + +from omlx.memory_monitor import MemoryMonitor +from omlx.request import Request, SamplingParams +from omlx.scheduler import Scheduler, SchedulerConfig + + +class _ModelConfig: + """Minimal config object exposing the fields the estimator reads.""" + + def __init__( + self, + num_hidden_layers: int | None = 32, + num_key_value_heads: int = 8, + num_attention_heads: int = 32, + head_dim: int = 192, # > 128 → SDPA fallback path (the panic-prone one) + ) -> None: + self.num_hidden_layers = num_hidden_layers + self.num_key_value_heads = num_key_value_heads + self.num_attention_heads = num_attention_heads + self.head_dim = head_dim + + +def _make_scheduler() -> Scheduler: + model = MagicMock() + model.layers = [] + model.config = _ModelConfig() + # Strip make_cache so the KVCache-counting branch in + # _set_model_info_for_monitor doesn't try to iterate a MagicMock. + del model.make_cache + + tokenizer = MagicMock() + tokenizer.eos_token_id = 2 + + config = SchedulerConfig( + max_num_seqs=8, + prefill_step_size=2048, + paged_cache_block_size=0, + ) + return Scheduler(model=model, tokenizer=tokenizer, config=config) + + +def _make_request(prompt_tokens: int = 65536) -> Request: + req = Request( + request_id="req-large", + prompt=list(range(prompt_tokens)), + sampling_params=SamplingParams(max_tokens=8), + ) + req.prompt_token_ids = list(range(prompt_tokens)) + req.num_prompt_tokens = prompt_tokens + return req + + +def test_scheduler_init_instantiates_memory_monitor(): + scheduler = _make_scheduler() + assert isinstance(scheduler.memory_monitor, MemoryMonitor) + + +def test_scheduler_init_populates_estimator_dims(): + scheduler = _make_scheduler() + monitor = scheduler.memory_monitor + assert monitor is not None + assert monitor._num_attention_heads == 32 + assert monitor._head_dim == 192 + assert monitor._num_layers == 32 + assert monitor._num_kv_heads == 8 + + +def test_estimator_produces_nonzero_peak_after_init(): + scheduler = _make_scheduler() + assert scheduler.memory_monitor is not None + peak = scheduler.memory_monitor.estimate_prefill_peak_bytes(65536, 2048) + assert peak > 0 + + +def test_preflight_positive_control_passes_normal_request(): + """Positive-control: a normal prompt under a generous limit must NOT + be rejected. Defends against an accidental sign-flip on the + threshold comparison in _preflight_memory_check. + """ + scheduler = _make_scheduler() + scheduler._prefill_memory_guard = True + # Huge limit — even a multi-GB peak fits comfortably. + scheduler._memory_hard_limit_bytes = 10**18 + with patch("omlx.scheduler.mx.get_active_memory", return_value=0), patch( + "omlx.scheduler.get_phys_footprint", return_value=0 + ): + assert scheduler._preflight_memory_check(_make_request(32768)) is None + + +def test_preflight_rejects_when_estimated_peak_exceeds_hard_limit(): + scheduler = _make_scheduler() + scheduler._prefill_memory_guard = True + scheduler._memory_hard_limit_bytes = 1 # any allocation exceeds + + with patch("omlx.scheduler.mx.get_active_memory", return_value=0), patch( + "omlx.scheduler.get_phys_footprint", return_value=0 + ): + rejection = scheduler._preflight_memory_check(_make_request(65536)) + + assert rejection is not None + assert "Prefill would require" in rejection.message + assert "KV+SDPA" in rejection.message + assert rejection.estimated_bytes > 0 + assert rejection.limit_bytes == 1 + + +def test_preflight_returns_none_when_guard_disabled(): + scheduler = _make_scheduler() + scheduler._prefill_memory_guard = False + scheduler._memory_hard_limit_bytes = 1 + assert scheduler._preflight_memory_check(_make_request(65536)) is None + + +def test_preflight_returns_none_when_request_fully_cached(): + scheduler = _make_scheduler() + scheduler._prefill_memory_guard = True + scheduler._memory_hard_limit_bytes = 1 + req = _make_request(1000) + req.cached_tokens = 1000 + # Fully cached: no new tokens to prefill, no peak to estimate. + assert scheduler._preflight_memory_check(req) is None + + +def test_preflight_rejects_heavily_cached_long_context(): + """Regression for M3: a request whose suffix is small but whose + *full* prompt is long must still trip the guard, because the SDPA + scores tensor spans the full prompt (cached + new), not just the + new tokens. Previously the estimator passed only new_tokens to the + scores formula and the heavily-cached path slipped through. + """ + scheduler = _make_scheduler() + scheduler._prefill_memory_guard = True + # Tight limit so even a partial prefill against a 100k KV trips it. + scheduler._memory_hard_limit_bytes = 100 * 1024**2 # 100 MB + req = _make_request(100_000) + req.cached_tokens = 99_000 # only 1k new tokens but kv_len = 100k + with patch("omlx.scheduler.mx.get_active_memory", return_value=0), patch( + "omlx.scheduler.get_phys_footprint", return_value=0 + ): + error = scheduler._preflight_memory_check(req) + assert error is not None, ( + "guard must trip on heavily-cached long-context: SDPA scores " + "still span the full prompt" + ) + + +def test_preflight_rejects_uncached_long_context(): + """Symmetric to test_preflight_rejects_heavily_cached_long_context: + a request with mostly NEW tokens (no cache) at a 100k prompt must + also trip the guard. This locks in the SDPA-fallback K-dim formula + in both directions; if a future refactor regressed the cached path + OR the uncached path, only one of these two tests would fail. + """ + scheduler = _make_scheduler() + scheduler._prefill_memory_guard = True + scheduler._memory_hard_limit_bytes = 100 * 1024**2 # 100 MB + req = _make_request(100_000) + req.cached_tokens = 1_000 # almost everything is new + with patch("omlx.scheduler.mx.get_active_memory", return_value=0), patch( + "omlx.scheduler.get_phys_footprint", return_value=0 + ): + error = scheduler._preflight_memory_check(req) + assert error is not None, ( + "guard must trip on uncached long-context too" + ) + + +class _VLMConfig: + """Top-level VLM config whose LM dims live under text_config (Qwen3.6-VL, + Gemma-4 layout). The top-level surface deliberately has no num_hidden_layers, + so this exercises the nested-config descent path.""" + + def __init__(self): + self.architectures = ["Qwen3_5MoeForConditionalGeneration"] + self.model_type = "qwen3_5_moe" + self.text_config = _ModelConfig( + num_hidden_layers=40, + num_key_value_heads=2, + num_attention_heads=16, + head_dim=256, # > 128 → SDPA fallback (the panic-prone path) + ) + + +def _make_vlm_scheduler() -> Scheduler: + model = MagicMock() + model.layers = [] + model.config = _VLMConfig() + del model.make_cache + + tokenizer = MagicMock() + tokenizer.eos_token_id = 2 + + config = SchedulerConfig( + max_num_seqs=8, + prefill_step_size=2048, + paged_cache_block_size=0, + ) + return Scheduler(model=model, tokenizer=tokenizer, config=config) + + +def test_vlm_nested_config_populates_estimator_dims(): + """Regression: VLM models nest LM dims under config.text_config — the + estimator must follow the sub-config or it stays silently dead at + runtime (no Model info set log, peak == 0, guard short-circuits).""" + scheduler = _make_vlm_scheduler() + monitor = scheduler.memory_monitor + assert monitor is not None + assert monitor._num_layers == 40 + assert monitor._num_kv_heads == 2 + assert monitor._num_attention_heads == 16 + assert monitor._head_dim == 256 + + +def test_vlm_estimator_produces_nonzero_peak(): + scheduler = _make_vlm_scheduler() + assert scheduler.memory_monitor is not None + # 90k tokens at head_dim=256 / n_q=16 should yield a multi-GiB peak via + # the SDPA-fallback branch. + peak = scheduler.memory_monitor.estimate_prefill_peak_bytes(90000, 2048) + assert peak > 10 * 1024 * 1024 * 1024 # > 10 GiB + + +def test_rejection_releases_block_aware_cache_when_present(): + """Regression for the prefix-cache leak found in review: a request + rejected by the prefill memory guard had its ref counts on every + prefix-matched paged block (and its ``request_tables`` entry) + incremented by ``add_request → fetch_cache``. Without releasing + them on the rejection path, those refs pin the paged cache and + compound the very memory pressure that triggered the rejection. + """ + scheduler = _make_scheduler() + block_aware_cache = MagicMock() + paged_cache_manager = MagicMock() + scheduler.block_aware_cache = block_aware_cache + scheduler.paged_cache_manager = paged_cache_manager + + scheduler._release_paged_cache_for_request("req-leak") + + # When block_aware_cache is present it owns the cleanup chain + # (release_cache → paged_cache_manager.delete_block_table). + block_aware_cache.release_cache.assert_called_once_with("req-leak") + paged_cache_manager.delete_block_table.assert_not_called() + + +def test_rejection_releases_paged_cache_when_no_prefix_cache(): + """When block_aware_cache is absent but a paged_cache_manager is + wired up, the rejection path must call ``delete_block_table`` + directly — otherwise the request's ``request_tables`` entry and + every block ref it holds leaks for the process lifetime. + """ + scheduler = _make_scheduler() + scheduler.block_aware_cache = None + paged_cache_manager = MagicMock() + scheduler.paged_cache_manager = paged_cache_manager + + scheduler._release_paged_cache_for_request("req-leak") + + paged_cache_manager.delete_block_table.assert_called_once_with("req-leak") + + +def test_rejection_releases_draft_prefix_cache_for_specprefill_requests(): + """SpecPrefill primes an independent ``_draft_prefix_cache`` in + ``_try_specprefill_scoring`` (via its own ``fetch_cache``). + The rejection path must release that draft cache too, symmetric + to the target cache — otherwise a rejected SpecPrefill request + leaks every draft-block ref and orphans its ``_request_tables`` + entry exactly like the target-cache bug this commit fixes.""" + scheduler = _make_scheduler() + scheduler.block_aware_cache = MagicMock() + scheduler.paged_cache_manager = MagicMock() + draft_cache = MagicMock() + scheduler._draft_prefix_cache = draft_cache + + scheduler._release_paged_cache_for_request("req-spec-leak") + + draft_cache.release_cache.assert_called_once_with("req-spec-leak") + + +def test_rejection_helper_noop_without_caches(): + """No caches wired up → helper must not raise. Embedded test + schedulers (this file's ``_make_scheduler``) build without paged + caches; the helper must be safe to call unconditionally on the + rejection path.""" + scheduler = _make_scheduler() + scheduler.block_aware_cache = None + scheduler.paged_cache_manager = None + # Must not raise. + scheduler._release_paged_cache_for_request("req-leak") + + +def test_preflight_rejection_path_invokes_release_helper(): + """End-to-end wiring: the preflight rejection in ``_schedule_waiting`` + must invoke the cache-release helper before popping + ``self.requests``. Pins the call-site fix for the leak — without + this hook the helper could exist but never be called from the hot + path. + """ + scheduler = _make_scheduler() + scheduler._prefill_memory_guard = True + scheduler._memory_hard_limit_bytes = 1 # forces rejection + + req = _make_request(65536) + scheduler.requests[req.request_id] = req + scheduler.waiting.append(req) + + # Make the rejection branch take effect even before + # _ensure_batch_generator runs — patch the preflight check to + # short-circuit on entry and keep this test independent of the + # batch-generator construction path. + from omlx.scheduler import _PreflightRejection + + def _force_reject(_request): + return _PreflightRejection( + message="forced rejection for test", + estimated_bytes=1, + limit_bytes=1, + ) + + with patch.object( + scheduler, "_release_paged_cache_for_request" + ) as release_spy, patch.object( + scheduler, "_preflight_memory_check", side_effect=_force_reject + ), patch.object( + scheduler, "_ensure_batch_generator", return_value=None + ): + # Pretend a batch_generator exists so the loop continues past + # the ``if self.batch_generator is None: break`` guard. + scheduler.batch_generator = MagicMock() + scheduler._schedule_waiting() + + release_spy.assert_any_call(req.request_id) + assert req.request_id not in scheduler.requests + + +def test_vlm_preflight_rejects_oversize_request(): + scheduler = _make_vlm_scheduler() + scheduler._prefill_memory_guard = True + scheduler._memory_hard_limit_bytes = 40 * 1024 * 1024 * 1024 # 40 GiB hard limit + + with patch("omlx.scheduler.mx.get_active_memory", return_value=28 * 1024 ** 3), patch( + "omlx.scheduler.get_phys_footprint", return_value=28 * 1024 ** 3 + ): + # 100k tokens at head_dim=256 should push (28 GiB baseline + KV+SDPA + # peak) past the 40 GiB limit. + rejection = scheduler._preflight_memory_check(_make_request(100000)) + + assert rejection is not None + assert "KV+SDPA" in rejection.message + + +# --------------------------------------------------------------------------- +# Config-descent edge cases (M3 in the upstream review of this commit) +# --------------------------------------------------------------------------- + + +class _VLMTopLevelVisionConfig: + """Top-level config has num_hidden_layers that refers to the *vision* + encoder. The estimator must descend into text_config rather than + accept the top-level value, otherwise it miscalibrates the SDPA peak. + """ + + def __init__(self): + self.architectures = ["FakeVisionLM"] + self.model_type = "fake_vlm" + # Vision encoder block count surfaces at top-level on some + # HF auto-wrapped packs — accepting this would silently use + # 27 layers / wrong heads for the LM math. + self.num_hidden_layers = 27 + self.num_attention_heads = 16 # vision attn heads + self.head_dim = 80 # vision head_dim (< 128, different SDPA path) + self.text_config = _ModelConfig( + num_hidden_layers=40, + num_key_value_heads=2, + num_attention_heads=16, + head_dim=256, # LM head_dim → SDPA-fallback path + ) + + +def test_vlm_descent_prefers_text_config_over_top_level_vision_field(): + """Regression: top-level num_hidden_layers can refer to the vision + encoder; the estimator must prefer text_config when present.""" + model = MagicMock() + model.layers = [] + model.config = _VLMTopLevelVisionConfig() + del model.make_cache + tokenizer = MagicMock() + tokenizer.eos_token_id = 2 + cfg = SchedulerConfig( + max_num_seqs=8, prefill_step_size=2048, paged_cache_block_size=0, + ) + sched = Scheduler(model=model, tokenizer=tokenizer, config=cfg) + + monitor = sched.memory_monitor + assert monitor is not None + # Must be the LM dims from text_config, NOT vision (27 / 80). + assert monitor._num_layers == 40 + assert monitor._head_dim == 256 + + +class _AltSubConfigContainer: + """Some packs name the LM sub-config ``language_config`` (or + ``llm_config``) instead of ``text_config``.""" + + def __init__(self, sub_attr_name: str): + self.architectures = ["AltSubConfigVLM"] + sub = _ModelConfig( + num_hidden_layers=24, + num_key_value_heads=4, + num_attention_heads=24, + head_dim=192, + ) + setattr(self, sub_attr_name, sub) + + +def test_vlm_descent_handles_language_config_alias(): + model = MagicMock() + model.layers = [] + model.config = _AltSubConfigContainer("language_config") + del model.make_cache + tokenizer = MagicMock() + tokenizer.eos_token_id = 2 + sched = Scheduler( + model=model, + tokenizer=tokenizer, + config=SchedulerConfig( + max_num_seqs=8, prefill_step_size=2048, paged_cache_block_size=0, + ), + ) + assert sched.memory_monitor._num_layers == 24 + assert sched.memory_monitor._head_dim == 192 + + +def test_vlm_descent_handles_llm_config_alias(): + model = MagicMock() + model.layers = [] + model.config = _AltSubConfigContainer("llm_config") + del model.make_cache + tokenizer = MagicMock() + tokenizer.eos_token_id = 2 + sched = Scheduler( + model=model, + tokenizer=tokenizer, + config=SchedulerConfig( + max_num_seqs=8, prefill_step_size=2048, paged_cache_block_size=0, + ), + ) + assert sched.memory_monitor._num_layers == 24 + + +class _LegacyLMConfig: + """GPT-style legacy config exposing ``n_layer`` / ``n_head`` / ``n_embd`` + instead of HuggingFace's ``num_hidden_layers`` etc.""" + + def __init__(self): + self.n_layer = 12 + self.n_head = 12 + self.n_embd = 768 # head_dim derived as n_embd / n_head = 64 + + +def test_legacy_n_layer_fallback_path(): + """The extractor falls back to ``n_layer`` / ``n_head`` / ``n_embd`` for + GPT-style configs and derives head_dim when not directly present.""" + model = MagicMock() + model.layers = [] + model.config = _LegacyLMConfig() + del model.make_cache + tokenizer = MagicMock() + tokenizer.eos_token_id = 2 + sched = Scheduler( + model=model, + tokenizer=tokenizer, + config=SchedulerConfig( + max_num_seqs=8, prefill_step_size=2048, paged_cache_block_size=0, + ), + ) + monitor = sched.memory_monitor + assert monitor is not None + assert monitor._num_layers == 12 + assert monitor._num_kv_heads == 12 # falls back to n_head + assert monitor._head_dim == 64 # n_embd / n_head + + +class _BrokenConfig: + """A config whose attribute access raises — exercises the outer + try/except wrap in _set_model_info_for_monitor.""" + + @property + def num_hidden_layers(self): + raise RuntimeError("synthetic boom") + + +class _VLMWithNestedLegacyLayer: + """Hypothetical VLM whose LM sub-config exposes only the legacy + GPT-style ``n_layer`` (no ``num_hidden_layers``). The descent rule + must accept this so the LM dims aren't shadowed by the top-level + vision-encoder dims. + """ + + def __init__(self): + self.architectures = ["LegacyNestedVLM"] + # Top-level matches vision encoder dims that should be ignored. + self.num_hidden_layers = 27 + self.num_key_value_heads = 16 + self.num_attention_heads = 16 + self.head_dim = 80 + self.text_config = _ModelConfig( + num_hidden_layers=None, + num_key_value_heads=8, + num_attention_heads=32, + head_dim=128, + ) + # Force the sub-config to surface only n_layer, not + # num_hidden_layers. + self.text_config.num_hidden_layers = None + self.text_config.n_layer = 36 + + +def test_vlm_descent_prefers_text_config_via_legacy_n_layer(): + """Regression: the sub-config preference rule must accept legacy + ``n_layer`` in addition to ``num_hidden_layers`` so the descent + isn't silently skipped when only the legacy alias is present — + otherwise the top-level (vision) dims leak into the SDPA-peak + calculation. + """ + model = MagicMock() + model.layers = [] + model.config = _VLMWithNestedLegacyLayer() + del model.make_cache + tokenizer = MagicMock() + tokenizer.eos_token_id = 2 + sched = Scheduler( + model=model, + tokenizer=tokenizer, + config=SchedulerConfig( + max_num_seqs=8, prefill_step_size=2048, paged_cache_block_size=0, + ), + ) + monitor = sched.memory_monitor + assert monitor is not None + # Must be the LM dims (n_layer=36, head_dim=128), NOT vision (27/80). + assert monitor._num_layers == 36 + assert monitor._head_dim == 128 + + +def test_exception_during_descent_is_swallowed(): + """The whole _set_model_info_for_monitor body is wrapped in + try/except so a malformed config can't break Scheduler init.""" + model = MagicMock() + model.layers = [] + model.config = _BrokenConfig() + del model.make_cache + tokenizer = MagicMock() + tokenizer.eos_token_id = 2 + # Must not raise. + sched = Scheduler( + model=model, + tokenizer=tokenizer, + config=SchedulerConfig( + max_num_seqs=8, prefill_step_size=2048, paged_cache_block_size=0, + ), + ) + # Monitor exists but dims stayed None — estimator returns 0 / guard skips. + assert sched.memory_monitor is not None + assert sched.memory_monitor._num_layers is None diff --git a/tests/test_server_prefill_memory_handler.py b/tests/test_server_prefill_memory_handler.py new file mode 100644 index 000000000..3bbca9597 --- /dev/null +++ b/tests/test_server_prefill_memory_handler.py @@ -0,0 +1,195 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Verify PrefillMemoryExceededError maps to HTTP 413 in server.py. + +Regression-arming test for the actual prefill-guard chain validated +end-to-end on 2026-05-15: the message string format matches what the +guard surfaces in production, so a refactor that changes either the +error body shape or the HTTP code will be caught here. +""" + +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from omlx.exceptions import PrefillMemoryExceededError + + +def _build_test_app(): + """Build a minimal FastAPI app that re-uses the production handler.""" + import omlx.server as srv + + app = FastAPI() + app.add_exception_handler( + PrefillMemoryExceededError, srv.prefill_memory_exceeded_handler + ) + + @app.get("/v1/raise") + def raise_prefill_too_large(): + raise PrefillMemoryExceededError( + message=( + "Prefill would require ~43.56 GB peak " + "(current 28.00 GB + KV+SDPA 15.56 GB) " + "but limit is 40.00 GB. " + "Reduce context length or increase --max-process-memory." + ), + request_id="req-abc", + estimated_bytes=46_775_000_000, + limit_bytes=42_949_672_960, + ) + + @app.get("/health/raise") + def raise_prefill_too_large_health(): + raise PrefillMemoryExceededError( + message="Prefill would require ~50 GB peak but limit is 40 GB.", + request_id="req-xyz", + ) + + return app + + +class TestPrefillMemoryHandler: + def test_returns_413(self): + with TestClient(_build_test_app()) as client: + resp = client.get("/v1/raise") + assert resp.status_code == 413 + + def test_api_route_uses_openai_error_body(self): + """/v1/* routes get the OpenAI-style {"error": {"message": ...}} wrapper.""" + with TestClient(_build_test_app()) as client: + resp = client.get("/v1/raise") + body = resp.json() + assert "error" in body + msg = body["error"]["message"] + # The guard's diagnostic format is part of the public contract — the + # CLI hint at the end tells the user exactly how to recover. + assert "Prefill would require" in msg + assert "KV+SDPA" in msg + assert "--max-process-memory" in msg + + def test_api_route_body_carries_estimated_and_limit_bytes(self): + """Clients branch on the numeric ``estimated_bytes`` / + ``limit_bytes`` fields rather than regex-matching the human + message (which is localized / format-prone). Regression for + the body-shape gap: prior to the fix on 2026-05-15 the handler + embedded these numbers only inside ``message`` and dropped the + structured fields, defeating the point of the typed exception + carrying them. + """ + with TestClient(_build_test_app()) as client: + resp = client.get("/v1/raise") + body = resp.json() + assert body["error"]["estimated_bytes"] == 46_775_000_000 + assert body["error"]["limit_bytes"] == 42_949_672_960 + + def test_non_api_route_uses_plain_detail(self): + with TestClient(_build_test_app()) as client: + resp = client.get("/health/raise") + body = resp.json() + assert "detail" in body + assert "Prefill would require" in body["detail"] + + +class TestResponsesEndpointReaches413: + """End-to-end regression for ``/v1/responses``. The handler-shape tests + above use a synthetic ``/v1/raise`` route, which proves the handler + body but NOT the wiring of every prompt-bearing endpoint to the + preflight call. ``/v1/responses`` is the one route most-likely to + silently regress because it shares the StreamingResponse pattern + with ``/v1/chat/completions`` and reaches preflight via the same + code path. This test forces the preflight to raise and asserts + the route returns 413 instead of 200/500. + """ + + def _make_app_with_failing_preflight(self): + """Mount the real ``/v1/responses`` route with a mocked + engine_pool that returns an engine whose ``preflight_chat`` + raises ``PrefillMemoryExceededError``. Hits the *production* + handler — not a synthesized stub — so a wiring regression is + caught. + """ + from unittest.mock import AsyncMock, MagicMock + + import omlx.server as srv + + # Build an engine mock whose preflight_chat raises. The + # production handler awaits this BEFORE constructing + # StreamingResponse, so the raise propagates to the + # exception handler and the route can still emit 413. + async def _raising_preflight(*args, **kwargs): + raise PrefillMemoryExceededError( + message=( + "Prefill would require ~50 GB peak " + "(current 30 GB + KV+SDPA 20 GB) but limit " + "is 40 GB. Reduce context length or " + "increase --max-process-memory." + ), + request_id="req-responses", + estimated_bytes=53_687_091_200, + limit_bytes=42_949_672_960, + ) + + engine = MagicMock() + engine.preflight_chat = AsyncMock(side_effect=_raising_preflight) + engine.start = AsyncMock() + # The handler calls ``count_chat_tokens`` and feeds the result + # into ``validate_context_window``; without a real int the + # comparison ``num_prompt_tokens > max_context`` raises before + # preflight ever runs. + engine.count_chat_tokens = MagicMock(return_value=128) + + async def _get_engine_for_model(model_id): + return engine + + # Override the engine resolver and disable auth so the test + # talks to the real route. + srv.app.dependency_overrides[srv.verify_api_key] = lambda: True + srv.get_engine_for_model = _get_engine_for_model # type: ignore[assignment] + + return srv.app + + def test_v1_responses_returns_413_when_preflight_rejects(self): + from unittest.mock import MagicMock, patch + + import omlx.server as srv + + original_get_engine = srv.get_engine_for_model + original_overrides = dict(srv.app.dependency_overrides) + original_engine_pool = srv._server_state.engine_pool + try: + app = self._make_app_with_failing_preflight() + # Mock engine_pool so get_engine_pool() doesn't raise 503. + # get_entry returns None so the handler's preserve_thinking + # short-circuit doesn't fire. + from unittest.mock import AsyncMock + + fake_pool = MagicMock() + fake_pool.get_entry = MagicMock(return_value=None) + fake_pool.preload_pinned_models = AsyncMock() + fake_pool.check_ttl_expirations = AsyncMock() + fake_pool.shutdown = AsyncMock() + srv._server_state.engine_pool = fake_pool + with TestClient(app, raise_server_exceptions=False) as client: + with patch.object( + srv, "resolve_model_id", lambda name: name + ), patch.object( + srv, "validate_context_window", lambda *a, **k: None + ): + resp = client.post( + "/v1/responses", + json={ + "model": "test-model", + "input": "Hello, world.", + "stream": False, + }, + ) + assert resp.status_code == 413, ( + f"expected 413, got {resp.status_code}: {resp.text}" + ) + body = resp.json() + assert "error" in body, body + assert "Prefill would require" in body["error"]["message"] + assert "--max-process-memory" in body["error"]["message"] + finally: + srv.get_engine_for_model = original_get_engine + srv._server_state.engine_pool = original_engine_pool + srv.app.dependency_overrides.clear() + srv.app.dependency_overrides.update(original_overrides) diff --git a/tests/test_settings.py b/tests/test_settings.py index 9105b8012..f14f38d61 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -16,6 +16,7 @@ ClaudeCodeSettings, GlobalSettings, HuggingFaceSettings, + IntegrationSettings, LoggingSettings, MCPSettings, MemorySettings, @@ -1600,6 +1601,76 @@ def test_new_fields_from_dict_null_model(self): assert settings.opus_model is None +class TestIntegrationSettings: + """Tests for IntegrationSettings dataclass. + + Upstream ``tests/test_integrations.py::TestIntegrationSettings`` already + covers defaults, basic to_dict, and full/empty from_dict. The tests + here are the additive coverage: exact dict-shape pinning (so a future + field addition that forgets to_dict raises a loud test failure — see + 81dc2d5 for the MemorySettings case), partial-dict fallback, + explicit-null override semantics, and round-trip identity. + """ + + def test_to_dict_defaults(self): + settings = IntegrationSettings() + assert settings.to_dict() == { + "codex_model": None, + "opencode_model": None, + "openclaw_model": None, + "hermes_model": None, + "pi_model": None, + "copilot_model": None, + "openclaw_tools_profile": "coding", + } + + def test_to_dict_custom(self): + settings = IntegrationSettings( + codex_model="qwen-coder-30b", + opencode_model="qwen-coder-7b", + openclaw_model="qwen-coder-3b", + hermes_model="hermes-3-8b", + pi_model="qwen-3-4b", + copilot_model="qwen-coder-1.5b", + openclaw_tools_profile="creative", + ) + assert settings.to_dict() == { + "codex_model": "qwen-coder-30b", + "opencode_model": "qwen-coder-7b", + "openclaw_model": "qwen-coder-3b", + "hermes_model": "hermes-3-8b", + "pi_model": "qwen-3-4b", + "copilot_model": "qwen-coder-1.5b", + "openclaw_tools_profile": "creative", + } + + def test_from_dict_partial(self): + """Missing keys fall back to dataclass defaults.""" + settings = IntegrationSettings.from_dict({"pi_model": "qwen-3-4b"}) + assert settings.pi_model == "qwen-3-4b" + assert settings.codex_model is None + assert settings.copilot_model is None + assert settings.openclaw_tools_profile == "coding" + + def test_from_dict_explicit_null_overrides_default(self): + """Explicit None for a *_model field must be preserved.""" + settings = IntegrationSettings.from_dict( + {"codex_model": None, "pi_model": "x"} + ) + assert settings.codex_model is None + assert settings.pi_model == "x" + + def test_round_trip(self): + """to_dict → from_dict → to_dict is identity.""" + original = IntegrationSettings( + codex_model="m1", + pi_model="m2", + openclaw_tools_profile="custom", + ) + round_tripped = IntegrationSettings.from_dict(original.to_dict()) + assert round_tripped.to_dict() == original.to_dict() + + class TestClaudeCodeValidation: """Tests for mode validation in GlobalSettings.validate().""" diff --git a/tests/test_torch_stub.py b/tests/test_torch_stub.py new file mode 100644 index 000000000..a14803917 --- /dev/null +++ b/tests/test_torch_stub.py @@ -0,0 +1,453 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for omlx._torch_stub. + +The stub is load-bearing for the DMG flow: it satisfies xgrammar / +tvm_ffi's import-time torch references without the real ~500 MB torch +wheel. Direct tests here catch the realistic regression where a future +xgrammar / tvm_ffi version starts touching a new torch attribute at +import. +""" + +from __future__ import annotations + +import importlib +import os +import subprocess +import sys +import textwrap +import threading +import types +import unittest.mock as mock + +import pytest + +# Save modules touched by install() so each test starts clean. +_TOUCHED = ( + "torch", + "torch.cuda", + "torch.version", + "torch.nn", + "torch.nn.functional", + "torch.utils", + "torch.utils.dlpack", +) + + +@pytest.fixture(autouse=True) +def _restore_sys_modules(): + saved = {k: sys.modules[k] for k in _TOUCHED if k in sys.modules} + # Clear any leftover stub state from a previous test so each starts clean. + for k in _TOUCHED: + sys.modules.pop(k, None) + yield + for k in _TOUCHED: + sys.modules.pop(k, None) + sys.modules.update(saved) + + +@pytest.fixture +def stub_module(): + """Import a fresh copy of the stub module so its module-level state + doesn't leak between tests.""" + if "omlx._torch_stub" in sys.modules: + importlib.reload(sys.modules["omlx._torch_stub"]) + return sys.modules["omlx._torch_stub"] + import omlx._torch_stub as m + return m + + +def test_install_returns_true_and_populates_sys_modules(stub_module): + # Force "no real torch": remove any existing torch import. + for k in _TOUCHED: + sys.modules.pop(k, None) + with mock.patch( + "importlib.util.find_spec", side_effect=lambda name: None + ): + applied = stub_module.install() + assert applied is True + for k in _TOUCHED: + assert k in sys.modules, f"{k} not installed in sys.modules" + torch = sys.modules["torch"] + assert torch.__version__.endswith("+omlx-stub") + # The dtype set xgrammar/tvm_ffi look up at import time. + for dt in ( + "int8", "int16", "int32", "int", "int64", "long", "uint8", + "float16", "half", "float32", "float", "float64", "double", + "bfloat16", "bool", "short", + ): + assert hasattr(torch, dt), f"torch.{dt} missing" + # Tensor aliases that xgrammar's contrib/hf.py uses in annotations. + for alias in ("Tensor", "LongTensor", "FloatTensor", "IntTensor"): + assert hasattr(torch, alias) + # Submodules tvm_ffi reaches into. + assert sys.modules["torch.cuda"].is_available() is False + assert sys.modules["torch.version"].cuda is None + + +def test_install_is_idempotent(stub_module): + for k in _TOUCHED: + sys.modules.pop(k, None) + with mock.patch("importlib.util.find_spec", side_effect=lambda name: None): + first = stub_module.install() + second = stub_module.install() + assert first is True + # Second call sees the stub already in sys.modules and reports it. + assert second is True + + +def test_install_no_op_when_real_torch_present(stub_module): + # Simulate a previously-imported real torch module. + real = types.ModuleType("torch") + real.__version__ = "2.4.0" + real.__spec__ = importlib.machinery.ModuleSpec("torch", loader=None) + sys.modules["torch"] = real + applied = stub_module.install() + assert applied is False + # We must not have replaced the real torch. + assert sys.modules["torch"] is real + # And we must not have added stub submodules on top of real torch. + assert "torch.cuda" not in sys.modules + + +def test_install_no_op_when_torch_findable_via_spec(stub_module): + # No torch in sys.modules, but importlib can find a spec for it. + for k in _TOUCHED: + sys.modules.pop(k, None) + fake_spec = importlib.machinery.ModuleSpec("torch", loader=None) + with mock.patch( + "importlib.util.find_spec", + side_effect=lambda name: fake_spec if name == "torch" else None, + ): + applied = stub_module.install() + assert applied is False + assert "torch" not in sys.modules + + +def test_stub_dtype_works_as_dict_key(stub_module): + """tvm_ffi.cython.dtype.pxi builds a dict keyed by torch.int8, + torch.bfloat16, etc. — verify the stub dtypes are hashable and + distinct.""" + for k in _TOUCHED: + sys.modules.pop(k, None) + with mock.patch("importlib.util.find_spec", side_effect=lambda name: None): + stub_module.install() + torch = sys.modules["torch"] + table = { + torch.int8: 1, + torch.short: 2, + torch.int32: 3, + torch.int64: 4, + torch.bfloat16: 5, + torch.bool: 6, + torch.float32: 7, + } + # All distinct keys. + assert len(table) == 7 + assert table[torch.int32] == 3 + + +def test_stub_tensor_isinstance_check(stub_module): + """xgrammar/tvm_ffi use isinstance(value, torch.Tensor) to gate + torch-specific paths. Our values (numpy arrays, mx.array) must + correctly fail that check.""" + for k in _TOUCHED: + sys.modules.pop(k, None) + with mock.patch("importlib.util.find_spec", side_effect=lambda name: None): + stub_module.install() + torch = sys.modules["torch"] + assert isinstance(torch.Tensor(), torch.Tensor) # stub instance is its own tensor + # Non-stub values cleanly fail. + assert not isinstance(42, torch.Tensor) + assert not isinstance([1, 2, 3], torch.Tensor) + assert not isinstance("hello", torch.Tensor) + # torch.dtype is also a class for isinstance checks. + assert isinstance(torch.int32, torch.dtype) + assert not isinstance(42, torch.dtype) + + +def test_unsupported_helpers_raise_runtime_error(stub_module): + """torch.full / torch.zeros / torch.nn.functional.pad are stubbed to + raise RuntimeError so a future caller gets a clear error instead of + a cryptic None-attribute traceback.""" + for k in _TOUCHED: + sys.modules.pop(k, None) + with mock.patch("importlib.util.find_spec", side_effect=lambda name: None): + stub_module.install() + torch = sys.modules["torch"] + with pytest.raises(RuntimeError, match="torch.full"): + torch.full((1,), 0) + with pytest.raises(RuntimeError, match="torch.zeros"): + torch.zeros((1,)) + with pytest.raises(RuntimeError, match="nn.functional.pad"): + torch.nn.functional.pad(None, (0, 1)) + + +def test_torch_tensor_returns_stub_instance_with_loud_method_failure( + stub_module, +): + """torch.tensor(...) returns a _StubTensor instance so module-globals + like ``_FULL_MASK = torch.tensor(-1, dtype=...)`` survive import time. + Subsequent method calls (e.g. ``.fill_()``) raise a clear RuntimeError + rather than the prior silent-None path. + """ + for k in _TOUCHED: + sys.modules.pop(k, None) + with mock.patch("importlib.util.find_spec", side_effect=lambda name: None): + stub_module.install() + torch = sys.modules["torch"] + t = torch.tensor(-1, dtype=torch.int32) + assert isinstance(t, torch.Tensor) + with pytest.raises(RuntimeError, match="_StubTensor.fill_"): + t.fill_(0) + + +def test_dtype_aliases_share_identity(stub_module): + """Real torch has ``torch.int is torch.int32`` — preserve that identity + so code doing ``assert x.dtype is torch.int32`` against ``torch.int`` + works identically against the stub.""" + for k in _TOUCHED: + sys.modules.pop(k, None) + with mock.patch("importlib.util.find_spec", side_effect=lambda name: None): + stub_module.install() + torch = sys.modules["torch"] + assert torch.int is torch.int32 + assert torch.long is torch.int64 + assert torch.short is torch.int16 + assert torch.half is torch.float16 + assert torch.float is torch.float32 + assert torch.double is torch.float64 + + +def test_dtype_str_returns_torch_prefix(stub_module): + """tvm_ffi.cpp.dtype.to_cpp_dtype calls ``str(dtype)`` and strips + a ``torch.`` prefix; our dtypes must serialize that way.""" + for k in _TOUCHED: + sys.modules.pop(k, None) + with mock.patch("importlib.util.find_spec", side_effect=lambda name: None): + stub_module.install() + torch = sys.modules["torch"] + assert str(torch.int32) == "torch.int32" + assert str(torch.bfloat16) == "torch.bfloat16" + + +def test_install_sets_tvm_ffi_dlpack_env_var(stub_module): + """install() must set TVM_FFI_DISABLE_TORCH_C_DLPACK so tvm-ffi skips + the doomed JIT extension build that otherwise spawns a Python + subprocess and surfaces a misleading warning at every cold start. + """ + for k in _TOUCHED: + sys.modules.pop(k, None) + os.environ.pop("TVM_FFI_DISABLE_TORCH_C_DLPACK", None) + try: + with mock.patch( + "importlib.util.find_spec", side_effect=lambda name: None + ): + stub_module.install() + assert os.environ.get("TVM_FFI_DISABLE_TORCH_C_DLPACK") == "1" + finally: + os.environ.pop("TVM_FFI_DISABLE_TORCH_C_DLPACK", None) + + +def test_install_does_not_touch_env_var_when_real_torch_present(stub_module): + """The opposite of the previous test: when real torch is detected via + find_spec, install() must NOT mutate TVM_FFI_DISABLE_TORCH_C_DLPACK. + A user with real torch installed may want the tvm-ffi/torch-C-DLPack + fast path; the stub should not silently disable it. + """ + for k in _TOUCHED: + sys.modules.pop(k, None) + os.environ.pop("TVM_FFI_DISABLE_TORCH_C_DLPACK", None) + try: + fake_spec = importlib.util.spec_from_loader("torch", loader=None) + with mock.patch( + "importlib.util.find_spec", + side_effect=lambda name: fake_spec if name == "torch" else None, + ): + result = stub_module.install() + assert result is False + assert "TVM_FFI_DISABLE_TORCH_C_DLPACK" not in os.environ, ( + "real-torch path must leave the env var alone" + ) + finally: + os.environ.pop("TVM_FFI_DISABLE_TORCH_C_DLPACK", None) + + +def test_missing_top_level_attribute_raises_attributeerror_and_logs( + stub_module, caplog +): + """``torch.`` must raise ``AttributeError`` (so ``hasattr`` + consumers behave correctly) AND log a one-shot WARNING that names + the missing attribute. The log is the operator-facing diagnostic + when a future xgrammar / tvm-ffi release reaches for a torch + surface the stub doesn't cover; without it, the AttributeError + surfaces only if the caller logs it themselves. + """ + for k in _TOUCHED: + sys.modules.pop(k, None) + with mock.patch("importlib.util.find_spec", side_effect=lambda name: None): + stub_module.install() + torch = sys.modules["torch"] + with caplog.at_level("WARNING", logger="omlx._torch_stub"): + with pytest.raises(AttributeError, match="torch.compile"): + torch.compile # noqa: B018 + assert any( + "missing attribute: torch.compile" in rec.message + for rec in caplog.records + ), caplog.records + # ``hasattr`` must continue to return False (i.e. the AttributeError + # path is reachable) — regression for replacing the raise with a + # log-and-return. + assert not hasattr(torch, "another_missing_attr") + + +def test_stub_modules_have_real_spec_and_loader(stub_module): + """Every stub module in sys.modules must have a real ``__spec__`` + (a ``ModuleSpec`` instance, not ``None``) so ``importlib.util. + find_spec`` succeeds for downstream consumers — transformers / + accelerate / huggingface_hub all probe torch via find_spec at + import time, and ``None`` here trips their fallback paths into + incorrect behavior. + """ + for k in _TOUCHED: + sys.modules.pop(k, None) + with mock.patch("importlib.util.find_spec", side_effect=lambda name: None): + stub_module.install() + for name in ( + "torch", + "torch.cuda", + "torch.version", + "torch.nn", + "torch.nn.functional", + "torch.utils", + "torch.utils.dlpack", + ): + mod = sys.modules[name] + assert mod.__spec__ is not None, f"{name} missing __spec__" + assert isinstance(mod.__spec__, importlib.machinery.ModuleSpec), ( + f"{name}.__spec__ wrong type: {type(mod.__spec__)}" + ) + assert mod.__spec__.name == name + + +def test_utils_dlpack_to_dlpack_raises(stub_module): + """``torch.utils.dlpack.to_dlpack`` is a separately-exposed helper + (not in ``torch.nn.functional``). If a future tvm-ffi reaches for + it under the stub it must raise loudly rather than silently return + None — calls into this path mean the caller assumed real torch and + will produce wrong results downstream. + """ + for k in _TOUCHED: + sys.modules.pop(k, None) + with mock.patch("importlib.util.find_spec", side_effect=lambda name: None): + stub_module.install() + import torch # type: ignore + + with pytest.raises(RuntimeError, match="utils.dlpack.to_dlpack"): + torch.utils.dlpack.to_dlpack(object()) + + +def test_install_is_thread_safe(stub_module): + """Concurrent install() calls must serialize and produce a single + consistent stub. Regression for a race where two threads both passed + the ``"torch" in sys.modules`` check, both built modules, and + overwrote each other in sys.modules — leaving threads with stale + references to the loser's module objects. + """ + for k in _TOUCHED: + sys.modules.pop(k, None) + results: list[bool] = [] + barrier = threading.Barrier(8) + errors: list[Exception] = [] + + def worker(): + try: + barrier.wait(timeout=2.0) + with mock.patch( + "importlib.util.find_spec", side_effect=lambda name: None + ): + results.append(stub_module.install()) + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=worker) for _ in range(8)] + for t in threads: + t.start() + for t in threads: + t.join(timeout=5.0) + assert not errors, errors + assert len(results) == 8 + assert all(r is True for r in results) + # All threads see the same single torch module instance. + torch = sys.modules["torch"] + assert torch.__version__.endswith("+omlx-stub") + + +@pytest.mark.skipif( + not (importlib.util.find_spec("xgrammar") and importlib.util.find_spec("tvm_ffi")), + reason="xgrammar / tvm_ffi not installed", +) +def test_xgrammar_imports_against_stub_only(stub_module, tmp_path): + """Realistic regression: spawn a subprocess that blocks real torch and + asserts ``import xgrammar`` and the modules oMLX touches still load + against the stub. This is the test that gates xgrammar / tvm-ffi + version bumps — if a new release reaches for a torch attribute the + stub doesn't cover, this fails loudly at the import step. + """ + script = tmp_path / "probe.py" + script.write_text(textwrap.dedent(""" + import sys + + # Block real torch end-to-end without touching sys.path (which + # would also strip xgrammar in the common pip layout where both + # live in the same site-packages). A meta-path finder that + # returns None just delegates to the next finder; raising + # ImportError aborts the import before PathFinder runs. + for k in list(sys.modules): + if k == "torch" or k.startswith("torch."): + del sys.modules[k] + + import importlib.abc + + class _BlockTorch(importlib.abc.MetaPathFinder): + def find_spec(self, fullname, path, target=None): + if fullname == "torch" or fullname.startswith("torch."): + raise ImportError( + f"{fullname} blocked by test probe to force " + "the stub-only path" + ) + return None + + sys.meta_path.insert(0, _BlockTorch()) + + # install()'s own `importlib.util.find_spec('torch')` check + # also needs to see no torch. + import importlib.util + _orig_find_spec = importlib.util.find_spec + def _no_torch(name, *args, **kwargs): + if name == "torch" or name.startswith("torch."): + return None + return _orig_find_spec(name, *args, **kwargs) + importlib.util.find_spec = _no_torch + + from omlx._torch_stub import install + assert install() is True, ( + "stub install returned False — real torch was reachable " + "despite meta-path / find_spec blocking" + ) + + import xgrammar + from xgrammar import contrib # noqa: F401 + from xgrammar.kernels.apply_token_bitmask_mlx import ( # noqa: F401 + apply_token_bitmask_mlx, + ) + print("OK") + """)) + env = dict(os.environ) + env.pop("TVM_FFI_DISABLE_TORCH_C_DLPACK", None) + out = subprocess.check_output( + [sys.executable, str(script)], + stderr=subprocess.STDOUT, + env=env, + timeout=30, + ) + assert b"OK" in out, out