diff --git a/omlx/process_memory_enforcer.py b/omlx/process_memory_enforcer.py index 7d11fee31..41f3370d1 100644 --- a/omlx/process_memory_enforcer.py +++ b/omlx/process_memory_enforcer.py @@ -33,6 +33,7 @@ import logging import subprocess import sys +from collections import deque from typing import TYPE_CHECKING, Any import mlx.core as mx @@ -291,6 +292,7 @@ def __init__( hard_threshold: float = 0.95, prefill_safe_zone_ratio: float = 0.80, prefill_min_chunk_tokens: int = 32, + prefill_transient_margin_gb: float = 0.0, ): """ Initialize the process memory enforcer. @@ -317,6 +319,11 @@ def __init__( prefill_safe_zone_ratio: Fraction of hard cap below which prefill runs at full chunk size; above triggers adaptive shrink. prefill_min_chunk_tokens: Floor for adaptive shrink. + prefill_transient_margin_gb: Conservative margin added to the + modelled per-chunk prefill peak by the scheduler's + forward-front gate, covering the MoE expert-dequant activation + spike that estimate_prefill_peak_bytes does not model. + Propagated to each scheduler. 0 = no extra margin. """ self._engine_pool = engine_pool self._memory_guard_tier = self._normalize_tier(memory_guard_tier) @@ -331,6 +338,9 @@ def __init__( self._hard_threshold = hard_threshold self._prefill_safe_zone_ratio = prefill_safe_zone_ratio self._prefill_min_chunk_tokens = prefill_min_chunk_tokens + self._prefill_transient_margin_bytes = max( + 0, int(prefill_transient_margin_gb * 1024**3) + ) self._task: asyncio.Task | None = None self._running = False # Most recently observed pressure level, consumed by scheduler / @@ -340,6 +350,13 @@ def __init__( # or the call failed). Used by the admin dashboard to surface a # warning when the kernel iogpu.wired_limit_mb is below this. self._metal_wired_limit_request: int = 0 + # Rolling window of recent usage readings + their high-water mark. + # Prefill memory dips into a trough between chunks, so the instant + # reading can read low mid-prefill; preflight admission consults this + # peak instead so it does not wave through a request that will wall + # the next chunk. Updated on every poll iteration. + self._usage_window: deque[int] = deque(maxlen=5) + self._recent_peak_bytes: int = 0 @staticmethod def _normalize_tier(tier: str) -> str: @@ -503,6 +520,10 @@ def get_final_ceiling(self) -> int: """Public accessor used by engine_pool pre-load admission.""" return self._get_hard_limit_bytes() + def recent_peak_bytes(self) -> int: + """Recent high-water memory usage over the last few poll ticks.""" + return self._recent_peak_bytes + def _soft_bytes(self) -> int: """Soft watermark: ceiling * soft_threshold.""" ceiling = self._get_hard_limit_bytes() @@ -589,6 +610,10 @@ def _propagate_memory_limit(self) -> None: scheduler._admission_paused = admission_paused scheduler._prefill_safe_zone_ratio = self._prefill_safe_zone_ratio scheduler._prefill_min_chunk_tokens = self._prefill_min_chunk_tokens + scheduler._prefill_transient_margin_bytes = ( + self._prefill_transient_margin_bytes + ) + scheduler._memory_recent_peak_bytes = self._recent_peak_bytes bg = getattr(scheduler, "batch_generator", None) if bg is not None and hasattr(bg, "_memory_limit_bytes"): bg._memory_limit_bytes = soft_limit @@ -671,6 +696,8 @@ async def _check_and_enforce(self) -> None: return current = self._current_usage_bytes() + self._usage_window.append(current) + self._recent_peak_bytes = max(self._usage_window) if self._usage_window else current soft = int(ceiling * self._soft_threshold) hard = int(ceiling * self._hard_threshold) prev_level = self._pressure_level diff --git a/omlx/scheduler.py b/omlx/scheduler.py index 253845bf9..e99038ede 100644 --- a/omlx/scheduler.py +++ b/omlx/scheduler.py @@ -796,10 +796,26 @@ def __init__( # soft_threshold. Schedulers stop admitting new prefills while this is # set; in-flight requests proceed. self._admission_paused: bool = False + # Recent high-water memory usage, propagated from ProcessMemoryEnforcer. + # The entry-budget admission (_predicted_prefill_peak_bytes) and the + # forward gate max the instant reading against this so they do not wave + # through a request during a prefill trough that would wall the next + # chunk. Folded in only while requests are in-flight (see + # _predicted_prefill_peak_bytes). 0 until the enforcer sets it. + self._memory_recent_peak_bytes: int = 0 # Adaptive prefill throttle params, propagated from enforcer. # Until set, _adaptive_chunk_size is a no-op (returns requested as-is). self._prefill_safe_zone_ratio: float = 0.80 self._prefill_min_chunk_tokens: int = 32 + # Conservative transient margin (bytes) added to the modelled prefill + # peak by both entry guards: _predicted_prefill_peak_bytes (admission) + # and _prefill_forward_gate (per-chunk backstop). + # estimate_prefill_peak_bytes only models KV + SDPA; it does NOT model + # the MoE expert-dequant activation spike, which on glm4.5-air-106b + # (MoE) is the dominant single-step transient. Sized from the observed + # worst-case single-step current jump on m5max (see MemorySettings + # .prefill_transient_margin_gb). 0 until the enforcer propagates it. + self._prefill_transient_margin_bytes: int = 0 # EWMA estimator of per-token chunk transient bytes, used by # _adaptive_chunk_size in the caution zone. Owned per-scheduler. _tracker_model_id = "" @@ -1729,6 +1745,107 @@ def _apply_turboquant_kv_convert(self, prompt_cache: list[Any]) -> None: f"cache layers to {bits}-bit{skip_msg}" ) + def _prefill_forward_gate( + self, + chunk_tokens: int, + *, + request_id: str, + loop_label: str, + ) -> None: + """Forward-FRONT memory gate: refuse a prefill chunk BEFORE it runs. + + BACKSTOP to the primary entry-budget admission + (_predicted_prefill_peak_bytes + the _schedule_waiting defer). Admission + snapshots memory once, at admission time; this re-reads before every + chunk and so catches CONCURRENT DRIFT -- another in-flight request's KV + growing during this request's long prefill -- that the admission + snapshot cannot see. Because admission uses the full-prompt estimate + (>= this per-chunk estimate) and the same margin, a correctly-admitted + request cannot trip this gate under static memory: it only fires when + memory drifted up after admission. + + The chunk-end check (after self.model(...) + mx.eval) only fires once + the transient has already been allocated -- on Apple Silicon a chunk + that overshoots the Metal cap kernel-panics the whole machine, so an + after-the-fact Python check never runs. This predicts the next chunk's + peak and raises BEFORE the forward when it would exceed the hard cap, + so the request is aborted cleanly (the #1405 cleanup paths convert this + RuntimeError into a finish_reason="error" output) instead of crashing. + + predicted_peak = current(high-water) + estimate(KV+SDPA) + margin + - current: max(active, phys_footprint, recent_peak). recent_peak is + the enforcer's rolling high-water mark, so a mid-prefill trough in + the instant reading does not mask the real footprint. + - estimate: memory_monitor.estimate_prefill_peak_bytes models this + chunk's KV + SDPA activation only. + - margin: _prefill_transient_margin_bytes covers what the estimate + does NOT model -- chiefly the MoE expert-dequant activation spike, + the dominant single-step transient on MoE models like glm4.5-air. + + No-op (returns) when the guard is off, the hard limit is unset, the + memory_monitor is missing, or the estimate is unavailable (0) -- in + every such case the legacy chunk-end check remains the only line of + defense, exactly as before this gate existed. + + At chunk granularity the KV+SDPA estimate is small, so the margin is + the effective trip point (gate fires once current ~> cap - margin). + Correctness depends on `current` reflecting the true high-water mark: + this iteration runs right after the previous chunk's + _sync_and_clear_cache, when mx.get_active_memory() troughs, so the + gate leans on get_phys_footprint() + the propagated recent_peak to + avoid reading a trough. See MemorySettings.prefill_transient_margin_gb. + + Raises: + RuntimeError: when the predicted peak exceeds the hard limit. + """ + if not self._prefill_memory_guard: + return + if self._memory_hard_limit_bytes <= 0: + return + if self.memory_monitor is None: + return + if chunk_tokens <= 0: + return + + estimate = self.memory_monitor.estimate_prefill_peak_bytes( + chunk_tokens, self.config.prefill_step_size + ) + if estimate == 0: + return # can't estimate this model -> leave it to the chunk-end check + + predicted_transient = estimate + self._prefill_transient_margin_bytes + current = max( + mx.get_active_memory(), + get_phys_footprint(), + self._memory_recent_peak_bytes, + ) + predicted_peak = current + predicted_transient + + if predicted_peak > self._memory_hard_limit_bytes: + logger.warning( + "[memgate:%s] rid=%s refusing prefill chunk (n=%d) BEFORE " + "forward: predicted peak %.3fGB = current %.3fGB + transient " + "%.3fGB (estimate %.3fGB + margin %.3fGB) exceeds hard cap " + "%.3fGB. Aborting request to avoid a Metal-cap kernel panic.", + loop_label, + request_id, + chunk_tokens, + predicted_peak / 1024**3, + current / 1024**3, + predicted_transient / 1024**3, + estimate / 1024**3, + self._prefill_transient_margin_bytes / 1024**3, + self._memory_hard_limit_bytes / 1024**3, + ) + raise RuntimeError( + "Prefill refused before forward: predicted peak " + f"{predicted_peak / 1024**3:.1f}GB (current " + f"{current / 1024**3:.1f}GB + transient " + f"{predicted_transient / 1024**3:.1f}GB) would exceed the " + f"memory ceiling {self._memory_hard_limit_bytes / 1024**3:.1f}GB. " + "Reduce context length or increase --max-process-memory." + ) + def _do_external_prefill( self, request: "Request", @@ -1885,6 +2002,15 @@ def _do_external_prefill( extra_kwargs, n_to_process ) + # Forward-FRONT gate: predict this chunk's peak and refuse BEFORE + # the forward if it would breach the Metal cap (post-forward checks + # cannot save us -- the overshoot kernel-panics the machine). + self._prefill_forward_gate( + n_to_process, + request_id=request.request_id, + loop_label="external", + ) + _throttle_pre = get_phys_footprint() self.model(input_arr[:, :n_to_process], cache=prompt_cache, **model_kwargs) mx.eval([c.state for c in prompt_cache]) @@ -2223,6 +2349,17 @@ def _step_prefill_chunk(self, state: _PrefillState) -> bool: chunk = state.tokens_remaining[:, :n] state.tokens_remaining = state.tokens_remaining[:, n:] + + # Forward-FRONT gate: predict this chunk's peak and refuse BEFORE the + # forward if it would breach the Metal cap. Mirrors the external loop; + # raises RuntimeError that _advance_chunked_prefills converts into a + # finish_reason="error" output without crashing the machine. + self._prefill_forward_gate( + n, + request_id=state.request.request_id, + loop_label="chunked_step", + ) + _throttle_pre = get_phys_footprint() self.model(chunk, cache=state.cache) mx.eval([c.state for c in state.cache]) @@ -4506,20 +4643,52 @@ def get_num_running(self) -> int: """Get number of running requests.""" return len(self.running) - def _preflight_memory_check(self, request: "Request") -> str | None: + def _has_inflight_kv(self) -> bool: + """True when any request holds (or is still accumulating) KV cache -- + a running decode (self.running) OR an in-flight chunked prefill + (self.prefilling). A waiting request has not prefilled yet, so it holds + no KV. Admission folds the recent_peak high-water and may DEFER only + while this is True; with nothing in flight there is nothing to drain, so + a request is admitted instead of queued forever (and a lone request too + big to fit even alone is rejected by _preflight_memory_check). + + Crucially this includes self.prefilling: the documented crash is + concurrent CHUNKED prefills stacking, and a request mid-chunked-prefill + lives in self.prefilling, not self.running. Gating on self.running alone + would wave the 2nd, 3rd, ... prefill straight past the budget into the + exact stack this guard exists to prevent. """ - Estimate whether prefill would exceed memory limits. - - Computes worst-case peak memory for the last prefill chunk - (model weights + KV cache + SDPA attention matrix) and rejects - if it would exceed the hard limit. - - For head_dim > 128, MLX SDPA uses a fallback that materializes - the full attention matrix [B, n_q, chunk, kv_len] in float32. - For head_dim <= 128, MLX uses a fused kernel with O(n) memory. - - Returns: - Error message string if request should be rejected, None if OK. + return bool(self.running or self.prefilling) + + def _predicted_prefill_peak_bytes(self, request: "Request") -> int | None: + """Predicted memory peak if this request's prefill ran now, or None. + + Shared predicate for both entry-point memory guards: the budget + admission in _schedule_waiting (admit / queue) and the lone-request + rejection in _preflight_memory_check (reject). + + predicted = current(high-water) + estimate_prefill_peak(KV+SDPA) + margin + + - current: max(active, phys_footprint). The enforcer recent_peak + high-water is folded in ONLY while work is in flight + (_has_inflight_kv): mid-generation/mid-prefill it masks the + active-memory trough left by the last _sync_and_clear_cache, but + when idle it can be a stale prior-batch peak that would false-reject + a lone request. + - estimate: memory_monitor.estimate_prefill_peak_bytes over the full + new (uncached) prompt -- KV for the whole prompt + last-chunk SDPA. + This is >= the per-chunk estimate the forward gate uses, so a + request admitted here cannot trip the gate under static memory (the + gate is a concurrent-drift backstop, not a second veto on the happy + path -- see _prefill_forward_gate). + - margin: _prefill_transient_margin_bytes carries the un-modelled MoE + expert-dequant transient, which is sub-poll and so invisible to + every memory read (see MemorySettings.prefill_transient_margin_gb). + + Returns None (caller treats as "fits") when the guard is off, the hard + limit is unset, the monitor is missing, there are no new tokens, or the + model cannot be estimated (estimate == 0) -- in each case admission + falls back to the per-concurrency cap and the legacy in-prefill checks. """ if not self._prefill_memory_guard: return None @@ -4528,31 +4697,50 @@ def _preflight_memory_check(self, request: "Request") -> str | None: if self.memory_monitor is None: return None - prompt_tokens = request.num_prompt_tokens - cached_tokens = request.cached_tokens or 0 - new_tokens = max(prompt_tokens - cached_tokens, 0) - + new_tokens = max(request.num_prompt_tokens - (request.cached_tokens or 0), 0) if new_tokens == 0: return None - peak = self.memory_monitor.estimate_prefill_peak_bytes( + estimate = self.memory_monitor.estimate_prefill_peak_bytes( new_tokens, self.config.prefill_step_size ) - if peak == 0: - return None # can't estimate, skip + if estimate == 0: + return None # can't estimate this model -> skip current = max(mx.get_active_memory(), get_phys_footprint()) + if self._has_inflight_kv(): + current = max(current, self._memory_recent_peak_bytes) + return current + estimate + self._prefill_transient_margin_bytes + + def _preflight_memory_check(self, request: "Request") -> str | None: + """Reject a request whose prefill cannot fit even on its own. - if current + peak > self._memory_hard_limit_bytes: - from .utils.hardware import format_bytes + Entry-point sibling of the budget admission in _schedule_waiting. The + admission DEFERS (queues) a request that would breach the cap only + because other requests are in-flight; this REJECTS a request that would + breach even with nothing else running -- nothing will drain to make + room, so queuing it would deadlock. The #1405 cleanup converts the + returned error into a finish_reason="error" output -- the request is + refused cleanly instead of crashing the box. - return ( - f"Prefill would require ~{format_bytes(current + peak)} peak " - f"(current {format_bytes(current)} + KV+SDPA {format_bytes(peak)}) " - f"but limit is {format_bytes(self._memory_hard_limit_bytes)}. " - f"Reduce context length or increase --max-process-memory." - ) - return None + Returns: + Error message string if request should be rejected, None if OK. + """ + predicted = self._predicted_prefill_peak_bytes(request) + if predicted is None: + return None + if predicted <= self._memory_hard_limit_bytes: + return None + + from .utils.hardware import format_bytes + + return ( + f"Prefill would need ~{format_bytes(predicted)} peak " + f"(KV+SDPA estimate + {format_bytes(self._prefill_transient_margin_bytes)} " + f"transient margin) but the memory ceiling is " + f"{format_bytes(self._memory_hard_limit_bytes)}. " + f"Reduce context length or increase --max-process-memory." + ) def _schedule_waiting( self, @@ -4650,6 +4838,41 @@ def _schedule_waiting( request.remaining_tokens = request.prompt_token_ids tokens_to_process = request.prompt_token_ids + # Memory-budget admission. Predict this request's prefill peak; if + # admitting it alongside the in-flight work would breach the hard + # cap, leave it in the waiting queue (do NOT reject) and stop + # admitting for this step. step() re-invokes _schedule_waiting every + # step, so it is re-checked and admitted once in-flight work frees + # its KV. This is what makes concurrency adaptive without a fixed + # per-model cap: a memory-rich model packs many requests, an 85GB + # model collapses to 1 and queues the rest. + # + # Gate on _has_inflight_kv() (running OR prefilling), not self.running + # alone: a request mid-chunked-prefill is in self.prefilling and is + # the dominant KV holder during the very stack this guards against + # (concurrent prefills). The first request (nothing in flight) is + # never deferred -- nothing would drain to admit it -- so a lone + # request that cannot fit even alone falls through to + # _preflight_memory_check below and is rejected instead. + if self._has_inflight_kv(): + predicted_peak = self._predicted_prefill_peak_bytes(request) + if ( + predicted_peak is not None + and predicted_peak > self._memory_hard_limit_bytes + ): + self.waiting.appendleft(request) + logger.info( + "[memadmit] deferring request %s: predicted prefill " + "peak %d + in-flight (%d running, %d prefilling) would " + "breach hard cap %d; left queued", + request.request_id, + predicted_peak, + len(self.running), + len(self.prefilling), + self._memory_hard_limit_bytes, + ) + break + # SpecPrefill requests must be alone in the batch (RoPE patching # affects the entire model). Also block scheduling if another # specprefill request is already running (offset RoPE active). diff --git a/omlx/server.py b/omlx/server.py index 0d5df6745..9f0d9fb69 100644 --- a/omlx/server.py +++ b/omlx/server.py @@ -359,6 +359,7 @@ async def lifespan(app: FastAPI): hard_threshold=mem_cfg.hard_threshold, prefill_safe_zone_ratio=mem_cfg.prefill_safe_zone_ratio, prefill_min_chunk_tokens=mem_cfg.prefill_min_chunk_tokens, + prefill_transient_margin_gb=mem_cfg.prefill_transient_margin_gb, ) _server_state.process_memory_enforcer = enforcer _server_state.engine_pool._process_memory_enforcer = enforcer diff --git a/omlx/settings.py b/omlx/settings.py index adba2a6d8..6dcd97480 100644 --- a/omlx/settings.py +++ b/omlx/settings.py @@ -391,6 +391,37 @@ class MemorySettings: # aborted via the same cleanup path the hard-limit RuntimeError uses. prefill_safe_zone_ratio: float = 0.80 prefill_min_chunk_tokens: int = 32 + # Conservative transient margin added to the modelled prefill peak by BOTH + # memory guards: the entry-point budget admission (primary -- scheduler. + # _predicted_prefill_peak_bytes, decides admit/queue/reject at admission + # time) and the forward-FRONT chunk gate (backstop -- scheduler. + # _prefill_forward_gate, refuses a chunk before its forward). The model in + # memory_monitor.estimate_prefill_peak_bytes only accounts for KV + SDPA; + # it does NOT model the MoE expert-dequant activation spike, which on a MoE + # model (glm4.5-air-106b) is the dominant single-step transient. Either + # guard refuses when current + estimate + this margin would breach the hard + # cap, so the transient never actually lands on the Metal ceiling (which + # would kernel-panic the whole machine -- an after-the-fact Python check + # cannot catch it). + # + # The load-bearing guarantee is: margin > the worst-case single-step memory + # jump. The jump is SUB-POLL -- it rises and falls faster than the + # enforcer's 1s sample, so it is invisible to every memory read (active, + # phys_footprint, recent_peak) by construction. It therefore MUST be carried + # by this margin, not by reading the footprint more cleverly. Across the + # 2026-06-06 m5max glm4.5-air-106b crash log the max trough->peak single-step + # delta was 7.44GB and the peak overshoot reached 110.4GB vs a 107.5GB cap, + # i.e. an effective transient up to ~10.6GB above the pre-step baseline. + # margin=10 was too small (10 < 10.6 -> admitted a step that then breached); + # 12 = ceil(10.6) padded, the value that would have refused that step from + # its true pre-step baseline. Both guards read `current` at high-water + # (max active / phys_footprint / -- when requests are in-flight -- the + # enforcer recent_peak); the baseline (KV+weights) is what they see and it + # is the determining factor, the sub-poll spike rides on the margin. Watch + # the [memgate]/[memcheck] logs on hardware and raise the margin if a step + # ever breaches from a baseline below cap - margin. Set to 0 to disable the + # extra margin (the guards then use the bare KV+SDPA estimate). + prefill_transient_margin_gb: float = 12.0 def to_dict(self) -> dict[str, Any]: """Convert to dictionary.""" @@ -402,6 +433,7 @@ def to_dict(self) -> dict[str, Any]: "hard_threshold": self.hard_threshold, "prefill_safe_zone_ratio": self.prefill_safe_zone_ratio, "prefill_min_chunk_tokens": self.prefill_min_chunk_tokens, + "prefill_transient_margin_gb": self.prefill_transient_margin_gb, } @classmethod @@ -440,6 +472,9 @@ def from_dict(cls, data: dict[str, Any]) -> MemorySettings: prefill_min_chunk_tokens=int( data.get("prefill_min_chunk_tokens", 32) ), + prefill_transient_margin_gb=float( + data.get("prefill_transient_margin_gb", 12.0) + ), ) diff --git a/tests/test_process_memory_enforcer.py b/tests/test_process_memory_enforcer.py index b456ecc66..64ae81c0b 100644 --- a/tests/test_process_memory_enforcer.py +++ b/tests/test_process_memory_enforcer.py @@ -965,6 +965,63 @@ async def test_check_and_enforce_walks_caps_on_soft(self, enforcer): scheduler.adjust_store_cache_cap.assert_called_with("soft") +class TestRecentPeakTracking: + """Tests for recent-peak high-water tracking across poll ticks.""" + + @pytest.mark.asyncio + async def test_recent_peak_is_window_max(self, enforcer): + """After several poll ticks, recent_peak == max over the window. + + The window update lives after the ceiling > 0 early return in + _check_and_enforce, so the fixture's positive ceiling (10 GB) is + required for the update to run at all. + """ + readings = [3 * 1024**3, 5 * 1024**3, 2 * 1024**3, 4 * 1024**3] + with patch("omlx.process_memory_enforcer.mx") as mock_mx, patch( + "omlx.process_memory_enforcer.get_phys_footprint", return_value=0 + ): + mock_mx.get_active_memory.side_effect = _cycling(readings) + for _ in readings: + await enforcer._check_and_enforce() + + assert enforcer.recent_peak_bytes() == 5 * 1024**3 + + @pytest.mark.asyncio + async def test_recent_peak_drops_after_window_slides(self, enforcer): + """Old high readings age out once they leave the maxlen=5 window.""" + # Feed one big reading, then enough small ones to push it out of the + # 5-slot window. + big = 9 * 1024**3 + small = 1 * 1024**3 + readings = [big, small, small, small, small, small] + with patch("omlx.process_memory_enforcer.mx") as mock_mx, patch( + "omlx.process_memory_enforcer.get_phys_footprint", return_value=0 + ): + mock_mx.get_active_memory.side_effect = _cycling(readings) + for _ in readings: + await enforcer._check_and_enforce() + + # After 6 ticks the first (big) reading has slid out of the window, + # leaving only small readings. + assert enforcer.recent_peak_bytes() == small + + def test_propagates_recent_peak_to_scheduler(self, enforcer): + """_propagate_memory_limit pushes recent_peak onto each scheduler.""" + scheduler = MagicMock(spec=[]) + scheduler._memory_limit_bytes = 0 + scheduler._memory_hard_limit_bytes = 0 + scheduler._memory_recent_peak_bytes = 0 + engine = MagicMock(spec=[]) + engine.scheduler = scheduler + entry = _make_entry("model-a", engine=engine) + enforcer._engine_pool._entries = {"model-a": entry} + + enforcer._recent_peak_bytes = 7 * 1024**3 + enforcer._propagate_memory_limit() + + assert scheduler._memory_recent_peak_bytes == 7 * 1024**3 + + class TestProperties: """Tests for enforcer properties.""" diff --git a/tests/test_scheduler_admission.py b/tests/test_scheduler_admission.py index 0ed0c6458..6954f3c89 100644 --- a/tests/test_scheduler_admission.py +++ b/tests/test_scheduler_admission.py @@ -2,13 +2,15 @@ """Tests for scheduler admission control (queue depth cap + admission_paused).""" from collections import deque -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import pytest from omlx.exceptions import SchedulerQueueFullError from omlx.scheduler import Scheduler +GB = 1024**3 + @pytest.fixture def scheduler(): @@ -100,3 +102,125 @@ def test_default_false(self): s._prefill_memory_guard = False s._admission_paused = False assert s._admission_paused is False + + +def _preflight_scheduler( + hard_limit: int, recent_peak: int, peak: int, *, running=None +): + """Build a bare Scheduler wired for _preflight_memory_check. + + `peak` is the value the (mocked) memory_monitor estimates for the + prefill chunk; `recent_peak` is the propagated high-water mark. `running` + seeds self.running -- _predicted_prefill_peak_bytes folds recent_peak into + `current` only while requests are in-flight, so an empty dict (idle, the + default) means a lone request is judged on the instant reading alone. + Transient margin is set to 0 here so these tests isolate the recent_peak + folding behaviour; margin handling is covered separately. + """ + s = Scheduler.__new__(Scheduler) + s._prefill_memory_guard = True + s._memory_hard_limit_bytes = hard_limit + s._memory_recent_peak_bytes = recent_peak + s._prefill_transient_margin_bytes = 0 + s.running = running if running is not None else {} + s.prefilling = deque() + s.config = MagicMock(prefill_step_size=2048) + s.memory_monitor = MagicMock() + s.memory_monitor.estimate_prefill_peak_bytes = MagicMock(return_value=peak) + return s + + +def _preflight_request(): + r = MagicMock() + r.num_prompt_tokens = 8192 + r.cached_tokens = 0 + return r + + +class TestPreflightRecentPeak: + """_preflight_memory_check folds the recent high-water mark into `current` + while requests are in-flight, so it does not wave through a request during a + prefill trough that would wall the next chunk -- but it ignores recent_peak + when idle, so a stale prior-batch peak does not false-reject a lone request. + """ + + def test_rejects_on_recent_peak_when_inflight_and_instant_is_low(self): + """In-flight + instant active/phys low but recent_peak high -> reject. + + Models the mid-prefill trough: another request is running, the instant + reading dipped after a _sync_and_clear_cache, but recent_peak still + reflects the real in-flight footprint. Numbers are picked so low + peak + fits (an instant-only read would admit) but recent_peak + peak exceeds + the hard limit. This pins the high-water fold. + """ + hard_limit = 100 * GB + peak = 20 * GB + low = 10 * GB + high = 85 * GB + # Sanity: an instant-only read (low + peak) would have passed. + assert low + peak <= hard_limit + # Folding recent_peak (high + peak) must exceed the limit. + assert high + peak > hard_limit + + s = _preflight_scheduler( + hard_limit=hard_limit, + recent_peak=high, + peak=peak, + running={"r-other": object()}, + ) + with patch("omlx.scheduler.mx") as mock_mx, patch( + "omlx.scheduler.get_phys_footprint", return_value=low + ): + mock_mx.get_active_memory.return_value = low + result = s._preflight_memory_check(_preflight_request()) + + assert result is not None + assert "Prefill would need" in result + + def test_admits_when_recent_peak_also_low(self): + """Control: in-flight, recent_peak low too -> the request passes.""" + hard_limit = 100 * GB + peak = 20 * GB + low = 10 * GB + + s = _preflight_scheduler( + hard_limit=hard_limit, + recent_peak=low, + peak=peak, + running={"r-other": object()}, + ) + with patch("omlx.scheduler.mx") as mock_mx, patch( + "omlx.scheduler.get_phys_footprint", return_value=low + ): + mock_mx.get_active_memory.return_value = low + result = s._preflight_memory_check(_preflight_request()) + + assert result is None + + def test_lone_request_ignores_stale_recent_peak(self): + """Idle (no in-flight requests): a high recent_peak is NOT folded in, + so a lone request that fits on the instant reading is admitted. + + recent_peak when idle is a stale prior-batch high; folding it would + false-reject a request that physically fits. Same numbers as the + in-flight reject test, only running is empty -> opposite outcome. + """ + hard_limit = 100 * GB + peak = 20 * GB + low = 10 * GB + high = 85 * GB + # Folding recent_peak would exceed the limit (the in-flight case)... + assert high + peak > hard_limit + # ...but idle, only the instant reading counts, which fits. + assert low + peak <= hard_limit + + s = _preflight_scheduler( + hard_limit=hard_limit, recent_peak=high, peak=peak, running={} + ) + with patch("omlx.scheduler.mx") as mock_mx, patch( + "omlx.scheduler.get_phys_footprint", return_value=low + ): + mock_mx.get_active_memory.return_value = low + result = s._preflight_memory_check(_preflight_request()) + + assert result is None diff --git a/tests/test_scheduler_budget_admission.py b/tests/test_scheduler_budget_admission.py new file mode 100644 index 000000000..254dc8f63 --- /dev/null +++ b/tests/test_scheduler_budget_admission.py @@ -0,0 +1,360 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the entry-point memory-budget admission and its relationship to the +per-chunk forward gate. + +The shared predicate _predicted_prefill_peak_bytes drives two outcomes: + - _schedule_waiting DEFERS (re-queues, does NOT reject) a request that would + breach the hard cap only because other requests are in-flight -- adaptive + concurrency. + - _preflight_memory_check REJECTS a lone request that cannot fit even alone. + +_prefill_forward_gate stays as a concurrent-drift backstop. Because admission +estimates the FULL prompt while the gate estimates a single CHUNK (and both use +the same margin), a correctly-admitted request cannot trip the gate under static +memory; the gate only fires when memory drifts up after admission. +""" + +from collections import deque +from unittest.mock import MagicMock, patch + +import pytest + +from omlx.memory_monitor import MemoryMonitor +from omlx.scheduler import Scheduler + +GB = 1024**3 + + +def _admission_scheduler( + *, + hard_limit, + recent_peak, + estimate, + margin, + running, + prefilling=None, + guard=True, + monitor=True, +): + """Bare Scheduler wired for _predicted_prefill_peak_bytes / admission.""" + s = Scheduler.__new__(Scheduler) + s._prefill_memory_guard = guard + s._memory_hard_limit_bytes = hard_limit + s._memory_recent_peak_bytes = recent_peak + s._prefill_transient_margin_bytes = margin + s.running = running + s.prefilling = prefilling if prefilling is not None else deque() + s.config = MagicMock(prefill_step_size=2048) + if monitor: + s.memory_monitor = MagicMock() + s.memory_monitor.estimate_prefill_peak_bytes = MagicMock( + return_value=estimate + ) + else: + s.memory_monitor = None + return s + + +def _request(*, num_prompt_tokens=8192, cached_tokens=0): + r = MagicMock() + r.request_id = "rid-1" + r.num_prompt_tokens = num_prompt_tokens + r.cached_tokens = cached_tokens + return r + + +class TestPredictedPrefillPeakBytes: + """Direct tests of the shared admission predicate.""" + + def test_sums_current_estimate_and_margin(self): + s = _admission_scheduler( + hard_limit=100 * GB, + recent_peak=0, + estimate=20 * GB, + margin=12 * GB, + running={}, + ) + with patch("omlx.scheduler.mx") as mx, patch( + "omlx.scheduler.get_phys_footprint", return_value=30 * GB + ): + mx.get_active_memory.return_value = 30 * GB + predicted = s._predicted_prefill_peak_bytes(_request()) + assert predicted == 30 * GB + 20 * GB + 12 * GB + + def test_folds_recent_peak_when_running_or_prefilling(self): + """In-flight (running OR prefilling) -> recent_peak high-water folded; + fully idle -> instant reading only.""" + common = dict( + hard_limit=100 * GB, recent_peak=90 * GB, estimate=10 * GB, margin=0 + ) + s_idle = _admission_scheduler(**common, running={}, prefilling=deque()) + s_running = _admission_scheduler(**common, running={"o": object()}) + s_prefilling = _admission_scheduler( + **common, running={}, prefilling=deque([object()]) + ) + with patch("omlx.scheduler.mx") as mx, patch( + "omlx.scheduler.get_phys_footprint", return_value=10 * GB + ): + mx.get_active_memory.return_value = 10 * GB + idle = s_idle._predicted_prefill_peak_bytes(_request()) + running = s_running._predicted_prefill_peak_bytes(_request()) + prefilling = s_prefilling._predicted_prefill_peak_bytes(_request()) + assert idle == 10 * GB + 10 * GB # instant reading only + assert running == 90 * GB + 10 * GB # folded: a decode is in flight + assert prefilling == 90 * GB + 10 * GB # folded: a chunked prefill too + + def test_cached_tokens_reduce_new_tokens(self): + s = _admission_scheduler( + hard_limit=100 * GB, + recent_peak=0, + estimate=5 * GB, + margin=0, + running={}, + ) + with patch("omlx.scheduler.mx") as mx, patch( + "omlx.scheduler.get_phys_footprint", return_value=0 + ): + mx.get_active_memory.return_value = 0 + s._predicted_prefill_peak_bytes( + _request(num_prompt_tokens=8192, cached_tokens=8000) + ) + # estimate is asked for the 192 UNCACHED tokens, not the full 8192. + s.memory_monitor.estimate_prefill_peak_bytes.assert_called_once_with( + 192, 2048 + ) + + def test_none_when_guard_off(self): + s = _admission_scheduler( + hard_limit=100 * GB, + recent_peak=0, + estimate=5 * GB, + margin=0, + running={}, + guard=False, + ) + assert s._predicted_prefill_peak_bytes(_request()) is None + + def test_none_when_hard_limit_unset(self): + s = _admission_scheduler( + hard_limit=0, + recent_peak=0, + estimate=5 * GB, + margin=0, + running={}, + ) + assert s._predicted_prefill_peak_bytes(_request()) is None + + def test_none_when_monitor_missing(self): + s = _admission_scheduler( + hard_limit=100 * GB, + recent_peak=0, + estimate=5 * GB, + margin=0, + running={}, + monitor=False, + ) + assert s._predicted_prefill_peak_bytes(_request()) is None + + def test_none_when_no_new_tokens(self): + s = _admission_scheduler( + hard_limit=100 * GB, + recent_peak=0, + estimate=5 * GB, + margin=0, + running={}, + ) + req = _request(num_prompt_tokens=4096, cached_tokens=4096) + assert s._predicted_prefill_peak_bytes(req) is None + + def test_none_when_estimate_zero(self): + s = _admission_scheduler( + hard_limit=100 * GB, + recent_peak=0, + estimate=0, + margin=0, + running={}, + ) + with patch("omlx.scheduler.mx") as mx, patch( + "omlx.scheduler.get_phys_footprint", return_value=0 + ): + mx.get_active_memory.return_value = 0 + assert s._predicted_prefill_peak_bytes(_request()) is None + + +class TestScheduleWaitingBudgetDefer: + """The defer branch in _schedule_waiting QUEUES, it does not reject.""" + + def _defer_scheduler( + self, *, hard_limit, recent_peak, estimate, margin, running=None, prefilling=None + ): + s = _admission_scheduler( + hard_limit=hard_limit, + recent_peak=recent_peak, + estimate=estimate, + margin=margin, + running=running if running is not None else {"r-other": object()}, + prefilling=prefilling, + ) + s.config = MagicMock(max_num_seqs=8, prefill_step_size=2048) + s._admission_paused = False + s._memory_limit_bytes = 0 # bypass the coarse generation soft-guard + s.batch_generator = MagicMock() + s._ensure_batch_generator = MagicMock() + return s + + def _waiting_request(self): + req = _request() + req.prompt_cache = None + req.remaining_tokens = None + req.prompt_token_ids = [1, 2, 3] + return req + + def test_defers_and_requeues_when_inflight_would_breach(self): + # predicted = max(10, 95) + 20 + 12 = 127 > 100 -> defer. + s = self._defer_scheduler( + hard_limit=100 * GB, + recent_peak=95 * GB, + estimate=20 * GB, + margin=12 * GB, + ) + req = self._waiting_request() + s.waiting = deque([req]) + + with patch("omlx.scheduler.mx") as mx, patch( + "omlx.scheduler.get_phys_footprint", return_value=10 * GB + ): + mx.get_active_memory.return_value = 10 * GB + scheduled, rejected = s._schedule_waiting() + + # Nothing scheduled, nothing REJECTED -- the request is left queued. + assert scheduled == [] + assert rejected == [] + assert list(s.waiting) == [req] + + def test_defers_when_only_a_chunked_prefill_is_in_flight(self): + """The headline case: running is EMPTY but another request is + mid-chunked-prefill (self.prefilling). The new request must still be + deferred -- gating on self.running alone would stack a second prefill, + which is exactly the documented crash. + """ + # predicted = max(10, 95) + 20 + 12 = 127 > 100 -> defer. + s = self._defer_scheduler( + hard_limit=100 * GB, + recent_peak=95 * GB, + estimate=20 * GB, + margin=12 * GB, + running={}, # nothing decoding... + prefilling=deque([object()]), # ...but a chunked prefill is in flight + ) + req = self._waiting_request() + s.waiting = deque([req]) + + with patch("omlx.scheduler.mx") as mx, patch( + "omlx.scheduler.get_phys_footprint", return_value=10 * GB + ): + mx.get_active_memory.return_value = 10 * GB + scheduled, rejected = s._schedule_waiting() + + assert scheduled == [] + assert rejected == [] + assert list(s.waiting) == [req] # queued behind the in-flight prefill + + def test_does_not_defer_when_predicted_fits(self): + # predicted = max(10, 30) + 20 + 12 = 62 <= 100 -> the budget branch + # does not fire. Asserted at the predicate so we do not have to drive + # the heavy insert path: a fitting request is never re-queued by the + # budget defer. + s = self._defer_scheduler( + hard_limit=100 * GB, + recent_peak=30 * GB, + estimate=20 * GB, + margin=12 * GB, + ) + with patch("omlx.scheduler.mx") as mx, patch( + "omlx.scheduler.get_phys_footprint", return_value=10 * GB + ): + mx.get_active_memory.return_value = 10 * GB + predicted = s._predicted_prefill_peak_bytes(self._waiting_request()) + assert predicted is not None + assert predicted <= s._memory_hard_limit_bytes + + +class TestAdmissionDominatesGate: + """Property: a request admitted by the entry budget cannot trip the + per-chunk forward gate under static memory; the gate is a drift backstop.""" + + def _monitor(self): + m = MemoryMonitor(max_kv_cache_memory=64 * GB) + m.set_model_info( + num_layers=32, + num_kv_heads=8, + head_dim=128, + dtype_size=2, + num_attention_heads=32, + ) + return m + + def _scheduler_with(self, monitor, *, cap, current, margin): + s = Scheduler.__new__(Scheduler) + s._prefill_memory_guard = True + s._memory_hard_limit_bytes = cap + s._memory_recent_peak_bytes = current + s._prefill_transient_margin_bytes = margin + s.running = {"r-other": object()} + s.prefilling = deque() + s.config = MagicMock(prefill_step_size=2048) + s.memory_monitor = monitor + return s + + def test_admitted_request_never_trips_gate_under_static_memory(self): + monitor = self._monitor() + full_prompt, chunk, step = 8192, 256, 2048 + admission_estimate = monitor.estimate_prefill_peak_bytes(full_prompt, step) + gate_estimate = monitor.estimate_prefill_peak_bytes(chunk, step) + # Real estimate is monotonic in prompt length -> full >= per-chunk. + assert 0 < gate_estimate <= admission_estimate + + current = 80 * GB + margin = 12 * GB + # cap set exactly at the admission boundary: admission JUST passes. + cap = current + admission_estimate + margin + s = self._scheduler_with(monitor, cap=cap, current=current, margin=margin) + + with patch("omlx.scheduler.mx") as mx, patch( + "omlx.scheduler.get_phys_footprint", return_value=current + ): + mx.get_active_memory.return_value = current + predicted = s._predicted_prefill_peak_bytes( + _request(num_prompt_tokens=full_prompt) + ) + assert predicted is not None and predicted <= cap # admitted + + # Same static memory: the gate, on any chunk of this request, must + # NOT raise (gate predicted = current + gate_estimate + margin + # <= current + admission_estimate + margin = cap). + s._prefill_forward_gate( + chunk, request_id="rid-1", loop_label="external" + ) # no RuntimeError + + def test_gate_fires_when_memory_drifts_up_after_admission(self): + monitor = self._monitor() + full_prompt, chunk, step = 8192, 256, 2048 + admission_estimate = monitor.estimate_prefill_peak_bytes(full_prompt, step) + + current = 80 * GB + margin = 12 * GB + cap = current + admission_estimate + margin + s = self._scheduler_with(monitor, cap=cap, current=current, margin=margin) + + # Other in-flight requests grow KV during this prefill: current drifts + # up well past where admission snapshotted it. The gate re-reads and + # catches what the admission snapshot could not. + drifted = current + admission_estimate + with patch("omlx.scheduler.mx") as mx, patch( + "omlx.scheduler.get_phys_footprint", return_value=drifted + ): + mx.get_active_memory.return_value = drifted + with pytest.raises(RuntimeError, match="refused before forward"): + s._prefill_forward_gate( + chunk, request_id="rid-1", loop_label="external" + ) diff --git a/tests/test_scheduler_prefill_forward_gate.py b/tests/test_scheduler_prefill_forward_gate.py new file mode 100644 index 000000000..b4b301a31 --- /dev/null +++ b/tests/test_scheduler_prefill_forward_gate.py @@ -0,0 +1,355 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the forward-FRONT prefill memory gate (P0c). + +The gate (_prefill_forward_gate) predicts a prefill chunk's peak memory +BEFORE running self.model(...) and raises RuntimeError when it would breach +the hard cap, so the request is aborted cleanly instead of the transient +landing on the Metal ceiling and kernel-panicking the machine. The legacy +chunk-END check only fires after the allocation has already happened, which +on Apple Silicon is too late. + +Strategy: pure mocks, no model load. The discriminating assertion is that +when the predicted peak exceeds the cap the model forward is NOT called -- +on pre-change code (no forward-front gate) the forward WOULD run. +""" + +from unittest.mock import MagicMock, patch + +import mlx.core as mx +import pytest + +from omlx.request import Request, RequestStatus, SamplingParams +from omlx.scheduler import Scheduler, SchedulerConfig, _PrefillState + +GB = 1024**3 + + +# --------------------------------------------------------------------------- +# Direct unit tests of _prefill_forward_gate +# --------------------------------------------------------------------------- + + +def _gate_scheduler( + *, + hard_limit: int, + recent_peak: int, + estimate: int, + margin: int, + guard: bool = True, + monitor: bool = True, +): + """Build a bare Scheduler wired only for _prefill_forward_gate.""" + s = Scheduler.__new__(Scheduler) + s._prefill_memory_guard = guard + s._memory_hard_limit_bytes = hard_limit + s._memory_recent_peak_bytes = recent_peak + s._prefill_transient_margin_bytes = margin + s.config = MagicMock(prefill_step_size=2048) + if monitor: + s.memory_monitor = MagicMock() + s.memory_monitor.estimate_prefill_peak_bytes = MagicMock( + return_value=estimate + ) + else: + s.memory_monitor = None + return s + + +def _call_gate(s, chunk_tokens, *, instant): + """Invoke the gate with patched instant memory probes.""" + with patch("omlx.scheduler.mx") as mock_mx, patch( + "omlx.scheduler.get_phys_footprint", return_value=instant + ): + mock_mx.get_active_memory.return_value = instant + s._prefill_forward_gate( + chunk_tokens, request_id="rid-1", loop_label="external" + ) + + +class TestPrefillForwardGateUnit: + """Direct tests of the gate predicate.""" + + def test_raises_when_predicted_peak_exceeds_cap(self): + """current(high-water) + estimate + margin > cap -> RuntimeError. + + Numbers chosen so the instant reading alone (low) + estimate would + fit, but the high-water recent_peak + estimate + margin overflow. + """ + hard = 107 * GB + estimate = 2 * GB + margin = 10 * GB + instant = 50 * GB + recent_peak = 96 * GB + # Instant + estimate (no margin) fits; this is the trough the legacy + # check could read. + assert instant + estimate <= hard + # High-water + estimate + margin overflows -> must refuse. + assert recent_peak + estimate + margin > hard + + s = _gate_scheduler( + hard_limit=hard, + recent_peak=recent_peak, + estimate=estimate, + margin=margin, + ) + with pytest.raises(RuntimeError, match="refused before forward"): + _call_gate(s, 256, instant=instant) + + def test_passes_when_predicted_peak_fits(self): + """current + estimate + margin <= cap -> no raise.""" + hard = 107 * GB + s = _gate_scheduler( + hard_limit=hard, + recent_peak=80 * GB, + estimate=2 * GB, + margin=10 * GB, + ) + # 80 + 2 + 10 = 92 < 107. + _call_gate(s, 256, instant=80 * GB) # must not raise + + def test_margin_is_what_tips_it_over(self): + """Without the margin it would pass; the margin alone forces refusal. + + Pins that the margin term is actually applied (not dropped). + """ + hard = 100 * GB + estimate = 1 * GB + instant = 90 * GB + recent_peak = 90 * GB + # current + estimate (no margin) = 91 < 100 -> would pass. + assert recent_peak + estimate < hard + # current + estimate + margin = 101 > 100 -> must refuse. + margin = 10 * GB + assert recent_peak + estimate + margin > hard + + s = _gate_scheduler( + hard_limit=hard, + recent_peak=recent_peak, + estimate=estimate, + margin=margin, + ) + with pytest.raises(RuntimeError): + _call_gate(s, 256, instant=instant) + + # Same setup, margin=0 -> passes (control). + s0 = _gate_scheduler( + hard_limit=hard, + recent_peak=recent_peak, + estimate=estimate, + margin=0, + ) + _call_gate(s0, 256, instant=instant) # must not raise + + def test_uses_recent_peak_high_water_not_just_instant(self): + """A mid-prefill trough in the instant reading must not mask the + real footprint: recent_peak high + low instant still refuses.""" + hard = 107 * GB + s = _gate_scheduler( + hard_limit=hard, + recent_peak=100 * GB, # real footprint + estimate=2 * GB, + margin=10 * GB, + ) + # Instant reads a trough at 50GB; without recent_peak it would pass. + assert 50 * GB + 2 * GB + 10 * GB < hard + with pytest.raises(RuntimeError): + _call_gate(s, 256, instant=50 * GB) + + def test_noop_when_guard_off(self): + s = _gate_scheduler( + hard_limit=107 * GB, + recent_peak=200 * GB, + estimate=200 * GB, + margin=10 * GB, + guard=False, + ) + _call_gate(s, 256, instant=200 * GB) # guard off -> never raises + + def test_noop_when_hard_limit_unset(self): + s = _gate_scheduler( + hard_limit=0, + recent_peak=200 * GB, + estimate=200 * GB, + margin=10 * GB, + ) + _call_gate(s, 256, instant=200 * GB) # no limit -> never raises + + def test_noop_when_monitor_missing(self): + s = _gate_scheduler( + hard_limit=107 * GB, + recent_peak=200 * GB, + estimate=200 * GB, + margin=10 * GB, + monitor=False, + ) + _call_gate(s, 256, instant=200 * GB) # no monitor -> never raises + + def test_noop_when_estimate_zero(self): + """estimate==0 means the model can't be estimated -> leave it to the + legacy chunk-end check, do not raise here.""" + s = _gate_scheduler( + hard_limit=107 * GB, + recent_peak=200 * GB, + estimate=0, + margin=10 * GB, + ) + _call_gate(s, 256, instant=200 * GB) # estimate 0 -> never raises + + def test_noop_when_chunk_zero(self): + s = _gate_scheduler( + hard_limit=107 * GB, + recent_peak=200 * GB, + estimate=2 * GB, + margin=10 * GB, + ) + _call_gate(s, 0, instant=200 * GB) # nothing to process -> never raises + + +# --------------------------------------------------------------------------- +# Integration: gate fires BEFORE the model forward in the real chunked loop +# --------------------------------------------------------------------------- + + +def _integration_scheduler(*, hard_gb: float, estimate_bytes: int, margin_gb: float): + """Scheduler with a mock model, hard cap on but soft off (so the adaptive + throttle passes through and only the forward-front gate can fire).""" + model = MagicMock() + model.layers = [] + tokenizer = MagicMock() + tokenizer.eos_token_id = 2 + config = SchedulerConfig( + max_num_seqs=8, + prefill_step_size=256, + chunked_prefill=True, + paged_cache_block_size=0, + ) + s = Scheduler(model=model, tokenizer=tokenizer, config=config) + s.batch_generator = MagicMock() + # Soft limit 0 -> _adaptive_chunk_size is a pure passthrough. + s._memory_limit_bytes = 0 + s._memory_hard_limit_bytes = int(hard_gb * GB) + s._prefill_memory_guard = True + s._prefill_transient_margin_bytes = int(margin_gb * GB) + s.memory_monitor = MagicMock() + s.memory_monitor.estimate_prefill_peak_bytes = MagicMock( + return_value=estimate_bytes + ) + return s, model + + +def _prefill_state(n_tokens: int) -> _PrefillState: + req = Request( + request_id="rid-int", + prompt=list(range(n_tokens + 1)), + sampling_params=SamplingParams(max_tokens=8), + ) + req.prompt_token_ids = list(range(n_tokens + 1)) + req.num_prompt_tokens = n_tokens + 1 + req.status = RequestStatus.WAITING + return _PrefillState( + request=req, + cache=[], + tokens_remaining=mx.array(list(range(n_tokens)))[None], + last_token=[n_tokens], + tokens_processed=0, + base_size=0, + emitted_boundaries={}, + boundary_enabled=False, + block_size=0, + total_length=n_tokens + 1, + ) + + +class TestForwardGateBlocksForward: + """The gate must abort the chunk BEFORE self.model(...) runs.""" + + def test_over_cap_does_not_call_model_forward(self): + """Predicted peak over cap -> RuntimeError raised and model NOT called. + + This is the discriminating assertion that pins the fix: pre-change + code (no forward-front gate) reaches self.model(chunk, ...) and the + transient lands on the cap (kernel panic on real hardware). With the + gate, the forward never runs. + """ + # recent_peak high (set via instant probes) + estimate + margin > cap. + s, model = _integration_scheduler( + hard_gb=107.0, estimate_bytes=2 * GB, margin_gb=10 * 1.0 + ) + state = _prefill_state(n_tokens=200) + + high = int(100 * GB) + with patch( + "omlx.scheduler.mx.get_active_memory", return_value=high + ), patch("omlx.scheduler.get_phys_footprint", return_value=high), patch( + "omlx.scheduler.mx.eval" + ) as mock_eval: + with pytest.raises(RuntimeError, match="refused before forward"): + s._step_prefill_chunk(state) + + # The whole point: the model forward must not have executed. + model.assert_not_called() + mock_eval.assert_not_called() + + def test_under_cap_runs_model_forward(self): + """Predicted peak under cap -> forward runs as normal (control).""" + s, model = _integration_scheduler( + hard_gb=107.0, estimate_bytes=1 * GB, margin_gb=2.0 + ) + state = _prefill_state(n_tokens=200) + + low = int(50 * GB) # 50 + 1 + 2 = 53 < 107 + with patch( + "omlx.scheduler.mx.get_active_memory", return_value=low + ), patch("omlx.scheduler.get_phys_footprint", return_value=low), patch( + "omlx.scheduler.mx.eval" + ), patch("omlx.scheduler._sync_and_clear_cache"), patch( + "omlx.scheduler.get_prefill_tracker" + ): + done = s._step_prefill_chunk(state) + + # Forward ran exactly once; prefill consumed the only chunk. + assert model.call_count == 1 + assert done is True + + +class TestForwardGateExternalLoopWiring: + """Sanity that the external loop wiring calls the gate before the forward. + + Patch _prefill_forward_gate to raise; the model forward must not run. + Uses a tiny text-only request through _do_external_prefill. + """ + + def test_external_loop_calls_gate_before_forward(self): + model = MagicMock() + model.layers = [] + tokenizer = MagicMock() + tokenizer.eos_token_id = 2 + config = SchedulerConfig( + max_num_seqs=8, + prefill_step_size=256, + chunked_prefill=False, + paged_cache_block_size=0, + ) + s = Scheduler(model=model, tokenizer=tokenizer, config=config) + + req = Request( + request_id="rid-ext", + prompt=[1, 2, 3, 4, 5], + sampling_params=SamplingParams(max_tokens=8), + ) + req.prompt_token_ids = [1, 2, 3, 4, 5] + req.num_prompt_tokens = 5 + + with patch.object( + s, + "_prefill_forward_gate", + side_effect=RuntimeError("Prefill refused before forward"), + ) as mock_gate, patch( + "omlx.scheduler.make_prompt_cache", return_value=[] + ): + with pytest.raises(RuntimeError, match="refused before forward"): + s._do_external_prefill(req, [1, 2, 3, 4, 5], None) + + mock_gate.assert_called_once() + # Gate raised -> forward must not have run. + model.assert_not_called()