From 35feba7d7ff42b2487b997e842d41f661012c663 Mon Sep 17 00:00:00 2001 From: yuanwei Date: Fri, 5 Jun 2026 16:34:50 -0700 Subject: [PATCH] fix(memory): refuse oversized prefill chunks before the forward (anti-panic) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Stop the m5max whole-machine watchdog panic caused by a large model's prefill transient breaching the Metal cap. The model loads and runs normally; only an individual request that cannot fit is refused (503-class), not the model. Root cause: oMLX bounds memory reactively (enforcer polls phys every 1s) and the in-prefill memory check ran AFTER self.model()+mx.eval. On Apple UMA a chunk that overshoots the Metal wired limit kernel-panics the whole machine, so the post-forward Python check never runs. glm4.5-air-106b (85GB) on a 128GB box left ~22GB for KV+prefill; a per-chunk MoE-dequant transient (measured up to 7.4GB, peaks hit 110GB vs the 107.5GB cap) crossed the cap and rebooted the box twice. - Scheduler._prefill_forward_gate: before each prefill forward (external loop + chunked-step mirror), predict current(high-water: max active/phys/recent_peak) + estimate_prefill_peak_bytes(KV+SDPA) + a conservative transient margin; if it would exceed the hard cap, raise BEFORE the forward. The existing #1405 cleanup converts that RuntimeError into a finish_reason="error" output -- request refused cleanly, machine not crashed. Legacy post-forward check stays as backstop. - New MemorySettings.prefill_transient_margin_gb (default 10GB), propagated settings -> enforcer -> scheduler. estimate_prefill_peak_bytes does not model the MoE expert-dequant spike, so this margin carries the guarantee: margin (10) > worst observed single-step jump (7.4GB). - Preflight now also maxes against the enforcer recent high-water mark. - Reverts model_load_prefill_headroom_gb load rejection (user feedback: a model must be loadable; refuse the request, not the model). Honest residual: the gate reads current just after the prior chunk's cache clear (active trough), leaning on phys_footprint stickiness + recent_peak to avoid a trough misread; a misread could still admit a crashing chunk. Not a literal never-panic guarantee -- needs on-hardware [memgate] log validation; the real fix is preemptive KV offload (separate work). Tests: 12 new (gate raise/pass, margin-is-the-trip, no-op guards, integration asserting model forward NOT called over-cap), verified to fail with gate neutered. Full suite 4542 pass / 3 known api_key fail / 19 skip on m2max -- zero regression. --- 止住 m5max 因大模型 prefill 瞬时撞穿 Metal cap 导致的整机 watchdog panic. 模型正常 load 正常用; 只拒放不下的单个请求(503 级), 不禁模型. 根因: oMLX 内存是 reactive 管控(enforcer 每 1s poll phys), prefill 内存检查在 self.model()+mx.eval 之后. Apple UMA 上 chunk 撞穿 Metal wired limit 直接整机 kernel panic, forward 后的 Python 检查根本来不及跑. glm4.5-air-106b(85GB)在 128GB 机只剩 ~22GB 给 KV+prefill; 每 chunk 的 MoE 反量化瞬时(实测达 7.4GB, 峰冲到 110GB vs 107.5 cap)撞穿后整机重启两次. - Scheduler._prefill_forward_gate: 每个 prefill forward 前(external loop + chunked-step mirror), 预测 current(高水位 max active/phys/recent_peak) + estimate_prefill_peak_bytes(KV+SDPA) + 保守 transient margin; 超 hard cap 就在 forward 前 raise. 现有 #1405 cleanup 把 RuntimeError 转成 finish_reason=error 输出 -- 干净拒请求, 不崩机. forward 后旧检查留作兜底. - 新增 MemorySettings.prefill_transient_margin_gb(默认 10GB), settings -> enforcer -> scheduler 传播. estimate 不含 MoE 反量化瞬时, margin 扛保证: margin(10) > 实测 最坏单步跳变(7.4GB). - preflight 也改用 enforcer 近期高水位取 max. - 撤掉 model_load_prefill_headroom_gb 拒 load(用户反馈: 模型必须能 load, 拒请求不拒模型). 诚实残余: gate 读 current 在上个 chunk cache clear 后(active 谷), 靠 phys_footprint 黏性 + recent_peak 避免读谷; 读到谷仍可能放行会崩的 chunk. 不是字面"绝不 panic" -- 需真机 [memgate] log 验证; 真治本是抢占式 KV offload(独立工作). 测试: 12 个新增, 验证 gate 失效时会 fail. 完整套件 m2max 4542 pass / 3 已知 api_key fail / 19 skip -- 零回归. --- omlx/process_memory_enforcer.py | 27 ++ omlx/scheduler.py | 130 ++++++- omlx/server.py | 1 + omlx/settings.py | 31 ++ tests/test_process_memory_enforcer.py | 57 +++ tests/test_scheduler_admission.py | 78 +++- tests/test_scheduler_prefill_forward_gate.py | 355 +++++++++++++++++++ 7 files changed, 677 insertions(+), 2 deletions(-) create mode 100644 tests/test_scheduler_prefill_forward_gate.py diff --git a/omlx/process_memory_enforcer.py b/omlx/process_memory_enforcer.py index 7d11fee31..41f3370d1 100644 --- a/omlx/process_memory_enforcer.py +++ b/omlx/process_memory_enforcer.py @@ -33,6 +33,7 @@ import logging import subprocess import sys +from collections import deque from typing import TYPE_CHECKING, Any import mlx.core as mx @@ -291,6 +292,7 @@ def __init__( hard_threshold: float = 0.95, prefill_safe_zone_ratio: float = 0.80, prefill_min_chunk_tokens: int = 32, + prefill_transient_margin_gb: float = 0.0, ): """ Initialize the process memory enforcer. @@ -317,6 +319,11 @@ def __init__( prefill_safe_zone_ratio: Fraction of hard cap below which prefill runs at full chunk size; above triggers adaptive shrink. prefill_min_chunk_tokens: Floor for adaptive shrink. + prefill_transient_margin_gb: Conservative margin added to the + modelled per-chunk prefill peak by the scheduler's + forward-front gate, covering the MoE expert-dequant activation + spike that estimate_prefill_peak_bytes does not model. + Propagated to each scheduler. 0 = no extra margin. """ self._engine_pool = engine_pool self._memory_guard_tier = self._normalize_tier(memory_guard_tier) @@ -331,6 +338,9 @@ def __init__( self._hard_threshold = hard_threshold self._prefill_safe_zone_ratio = prefill_safe_zone_ratio self._prefill_min_chunk_tokens = prefill_min_chunk_tokens + self._prefill_transient_margin_bytes = max( + 0, int(prefill_transient_margin_gb * 1024**3) + ) self._task: asyncio.Task | None = None self._running = False # Most recently observed pressure level, consumed by scheduler / @@ -340,6 +350,13 @@ def __init__( # or the call failed). Used by the admin dashboard to surface a # warning when the kernel iogpu.wired_limit_mb is below this. self._metal_wired_limit_request: int = 0 + # Rolling window of recent usage readings + their high-water mark. + # Prefill memory dips into a trough between chunks, so the instant + # reading can read low mid-prefill; preflight admission consults this + # peak instead so it does not wave through a request that will wall + # the next chunk. Updated on every poll iteration. + self._usage_window: deque[int] = deque(maxlen=5) + self._recent_peak_bytes: int = 0 @staticmethod def _normalize_tier(tier: str) -> str: @@ -503,6 +520,10 @@ def get_final_ceiling(self) -> int: """Public accessor used by engine_pool pre-load admission.""" return self._get_hard_limit_bytes() + def recent_peak_bytes(self) -> int: + """Recent high-water memory usage over the last few poll ticks.""" + return self._recent_peak_bytes + def _soft_bytes(self) -> int: """Soft watermark: ceiling * soft_threshold.""" ceiling = self._get_hard_limit_bytes() @@ -589,6 +610,10 @@ def _propagate_memory_limit(self) -> None: scheduler._admission_paused = admission_paused scheduler._prefill_safe_zone_ratio = self._prefill_safe_zone_ratio scheduler._prefill_min_chunk_tokens = self._prefill_min_chunk_tokens + scheduler._prefill_transient_margin_bytes = ( + self._prefill_transient_margin_bytes + ) + scheduler._memory_recent_peak_bytes = self._recent_peak_bytes bg = getattr(scheduler, "batch_generator", None) if bg is not None and hasattr(bg, "_memory_limit_bytes"): bg._memory_limit_bytes = soft_limit @@ -671,6 +696,8 @@ async def _check_and_enforce(self) -> None: return current = self._current_usage_bytes() + self._usage_window.append(current) + self._recent_peak_bytes = max(self._usage_window) if self._usage_window else current soft = int(ceiling * self._soft_threshold) hard = int(ceiling * self._hard_threshold) prev_level = self._pressure_level diff --git a/omlx/scheduler.py b/omlx/scheduler.py index 253845bf9..86ff37fc0 100644 --- a/omlx/scheduler.py +++ b/omlx/scheduler.py @@ -796,10 +796,23 @@ def __init__( # soft_threshold. Schedulers stop admitting new prefills while this is # set; in-flight requests proceed. self._admission_paused: bool = False + # Recent high-water memory usage, propagated from ProcessMemoryEnforcer. + # Preflight admission maxes the instant reading against this so it does + # not wave through a request during a prefill trough that would wall + # the next chunk. 0 until the enforcer sets it. + self._memory_recent_peak_bytes: int = 0 # Adaptive prefill throttle params, propagated from enforcer. # Until set, _adaptive_chunk_size is a no-op (returns requested as-is). self._prefill_safe_zone_ratio: float = 0.80 self._prefill_min_chunk_tokens: int = 32 + # Conservative transient margin (bytes) added to the modelled per-chunk + # prefill peak by the forward-front gate (_prefill_forward_gate). + # estimate_prefill_peak_bytes only models KV + SDPA; it does NOT model + # the MoE expert-dequant activation spike, which on glm4.5-air-106b + # (MoE) is the dominant single-step transient. Sized from the observed + # worst-case single-step current jump on m5max (see MemorySettings + # .prefill_transient_margin_gb). 0 until the enforcer propagates it. + self._prefill_transient_margin_bytes: int = 0 # EWMA estimator of per-token chunk transient bytes, used by # _adaptive_chunk_size in the caution zone. Owned per-scheduler. _tracker_model_id = "" @@ -1729,6 +1742,97 @@ def _apply_turboquant_kv_convert(self, prompt_cache: list[Any]) -> None: f"cache layers to {bits}-bit{skip_msg}" ) + def _prefill_forward_gate( + self, + chunk_tokens: int, + *, + request_id: str, + loop_label: str, + ) -> None: + """Forward-FRONT memory gate: refuse a prefill chunk BEFORE it runs. + + The chunk-end check (after self.model(...) + mx.eval) only fires once + the transient has already been allocated -- on Apple Silicon a chunk + that overshoots the Metal cap kernel-panics the whole machine, so an + after-the-fact Python check never runs. This predicts the next chunk's + peak and raises BEFORE the forward when it would exceed the hard cap, + so the request is aborted cleanly (the #1405 cleanup paths convert this + RuntimeError into a finish_reason="error" output) instead of crashing. + + predicted_peak = current(high-water) + estimate(KV+SDPA) + margin + - current: max(active, phys_footprint, recent_peak). recent_peak is + the enforcer's rolling high-water mark, so a mid-prefill trough in + the instant reading does not mask the real footprint. + - estimate: memory_monitor.estimate_prefill_peak_bytes models this + chunk's KV + SDPA activation only. + - margin: _prefill_transient_margin_bytes covers what the estimate + does NOT model -- chiefly the MoE expert-dequant activation spike, + the dominant single-step transient on MoE models like glm4.5-air. + + No-op (returns) when the guard is off, the hard limit is unset, the + memory_monitor is missing, or the estimate is unavailable (0) -- in + every such case the legacy chunk-end check remains the only line of + defense, exactly as before this gate existed. + + At chunk granularity the KV+SDPA estimate is small, so the margin is + the effective trip point (gate fires once current ~> cap - margin). + Correctness depends on `current` reflecting the true high-water mark: + this iteration runs right after the previous chunk's + _sync_and_clear_cache, when mx.get_active_memory() troughs, so the + gate leans on get_phys_footprint() + the propagated recent_peak to + avoid reading a trough. See MemorySettings.prefill_transient_margin_gb. + + Raises: + RuntimeError: when the predicted peak exceeds the hard limit. + """ + if not self._prefill_memory_guard: + return + if self._memory_hard_limit_bytes <= 0: + return + if self.memory_monitor is None: + return + if chunk_tokens <= 0: + return + + estimate = self.memory_monitor.estimate_prefill_peak_bytes( + chunk_tokens, self.config.prefill_step_size + ) + if estimate == 0: + return # can't estimate this model -> leave it to the chunk-end check + + predicted_transient = estimate + self._prefill_transient_margin_bytes + current = max( + mx.get_active_memory(), + get_phys_footprint(), + self._memory_recent_peak_bytes, + ) + predicted_peak = current + predicted_transient + + if predicted_peak > self._memory_hard_limit_bytes: + logger.warning( + "[memgate:%s] rid=%s refusing prefill chunk (n=%d) BEFORE " + "forward: predicted peak %.3fGB = current %.3fGB + transient " + "%.3fGB (estimate %.3fGB + margin %.3fGB) exceeds hard cap " + "%.3fGB. Aborting request to avoid a Metal-cap kernel panic.", + loop_label, + request_id, + chunk_tokens, + predicted_peak / 1024**3, + current / 1024**3, + predicted_transient / 1024**3, + estimate / 1024**3, + self._prefill_transient_margin_bytes / 1024**3, + self._memory_hard_limit_bytes / 1024**3, + ) + raise RuntimeError( + "Prefill refused before forward: predicted peak " + f"{predicted_peak / 1024**3:.1f}GB (current " + f"{current / 1024**3:.1f}GB + transient " + f"{predicted_transient / 1024**3:.1f}GB) would exceed the " + f"memory ceiling {self._memory_hard_limit_bytes / 1024**3:.1f}GB. " + "Reduce context length or increase --max-process-memory." + ) + def _do_external_prefill( self, request: "Request", @@ -1885,6 +1989,15 @@ def _do_external_prefill( extra_kwargs, n_to_process ) + # Forward-FRONT gate: predict this chunk's peak and refuse BEFORE + # the forward if it would breach the Metal cap (post-forward checks + # cannot save us -- the overshoot kernel-panics the machine). + self._prefill_forward_gate( + n_to_process, + request_id=request.request_id, + loop_label="external", + ) + _throttle_pre = get_phys_footprint() self.model(input_arr[:, :n_to_process], cache=prompt_cache, **model_kwargs) mx.eval([c.state for c in prompt_cache]) @@ -2223,6 +2336,17 @@ def _step_prefill_chunk(self, state: _PrefillState) -> bool: chunk = state.tokens_remaining[:, :n] state.tokens_remaining = state.tokens_remaining[:, n:] + + # Forward-FRONT gate: predict this chunk's peak and refuse BEFORE the + # forward if it would breach the Metal cap. Mirrors the external loop; + # raises RuntimeError that _advance_chunked_prefills converts into a + # finish_reason="error" output without crashing the machine. + self._prefill_forward_gate( + n, + request_id=state.request.request_id, + loop_label="chunked_step", + ) + _throttle_pre = get_phys_footprint() self.model(chunk, cache=state.cache) mx.eval([c.state for c in state.cache]) @@ -4541,7 +4665,11 @@ def _preflight_memory_check(self, request: "Request") -> str | None: if peak == 0: return None # can't estimate, skip - current = max(mx.get_active_memory(), get_phys_footprint()) + current = max( + mx.get_active_memory(), + get_phys_footprint(), + self._memory_recent_peak_bytes, + ) if current + peak > self._memory_hard_limit_bytes: from .utils.hardware import format_bytes diff --git a/omlx/server.py b/omlx/server.py index 0d5df6745..9f0d9fb69 100644 --- a/omlx/server.py +++ b/omlx/server.py @@ -359,6 +359,7 @@ async def lifespan(app: FastAPI): hard_threshold=mem_cfg.hard_threshold, prefill_safe_zone_ratio=mem_cfg.prefill_safe_zone_ratio, prefill_min_chunk_tokens=mem_cfg.prefill_min_chunk_tokens, + prefill_transient_margin_gb=mem_cfg.prefill_transient_margin_gb, ) _server_state.process_memory_enforcer = enforcer _server_state.engine_pool._process_memory_enforcer = enforcer diff --git a/omlx/settings.py b/omlx/settings.py index adba2a6d8..d9a4c631f 100644 --- a/omlx/settings.py +++ b/omlx/settings.py @@ -391,6 +391,33 @@ class MemorySettings: # aborted via the same cleanup path the hard-limit RuntimeError uses. prefill_safe_zone_ratio: float = 0.80 prefill_min_chunk_tokens: int = 32 + # Conservative transient margin added to the modelled per-chunk prefill + # peak by the scheduler's forward-FRONT memory gate. The model in + # memory_monitor.estimate_prefill_peak_bytes only accounts for KV + SDPA; + # it does NOT model the MoE expert-dequant activation spike, which on a + # MoE model (glm4.5-air-106b) is the dominant single-step transient. The + # gate refuses a chunk before its forward when current + estimate + this + # margin would breach the hard cap, so the transient never actually lands + # on the Metal ceiling (which would kernel-panic the whole machine, an + # after-the-fact Python check cannot catch it). + # + # At chunk granularity the KV+SDPA estimate is tiny (~0.3GB for a + # 256-token GLM chunk), so this margin IS the safety mechanism. The + # load-bearing guarantee is: margin > the worst-case single-step memory + # jump. Across the 2026-06-06 m5max glm4.5-air-106b crash log the max + # trough->peak single-step delta was 7.44GB (peak overshoot reached + # 110.4GB vs a 107.5GB cap). With margin=10GB the gate effectively fires + # once current exceeds ~cap - margin (~97GB); since every observed forward + # jumps <=7.44GB < 10GB, any chunk that would land above the cap starts + # from a current already past that trip point, so the gate refuses it + # before the forward. 10 = ceil(7.44) padded for an unobserved larger + # spike. NOTE: this holds only while `current` is read at true high-water + # at gate time (it relies on phys_footprint stickiness + recent_peak to + # mask the post-_sync_and_clear_cache trough in mx.get_active_memory()); + # watch the [memgate]/[memcheck] logs on hardware and raise the margin if + # a sub-trip-point trough precedes an over-cap chunk. Set to 0 to disable + # the extra margin (the gate then uses the bare KV+SDPA estimate). + prefill_transient_margin_gb: float = 10.0 def to_dict(self) -> dict[str, Any]: """Convert to dictionary.""" @@ -402,6 +429,7 @@ def to_dict(self) -> dict[str, Any]: "hard_threshold": self.hard_threshold, "prefill_safe_zone_ratio": self.prefill_safe_zone_ratio, "prefill_min_chunk_tokens": self.prefill_min_chunk_tokens, + "prefill_transient_margin_gb": self.prefill_transient_margin_gb, } @classmethod @@ -440,6 +468,9 @@ def from_dict(cls, data: dict[str, Any]) -> MemorySettings: prefill_min_chunk_tokens=int( data.get("prefill_min_chunk_tokens", 32) ), + prefill_transient_margin_gb=float( + data.get("prefill_transient_margin_gb", 10.0) + ), ) diff --git a/tests/test_process_memory_enforcer.py b/tests/test_process_memory_enforcer.py index b456ecc66..64ae81c0b 100644 --- a/tests/test_process_memory_enforcer.py +++ b/tests/test_process_memory_enforcer.py @@ -965,6 +965,63 @@ async def test_check_and_enforce_walks_caps_on_soft(self, enforcer): scheduler.adjust_store_cache_cap.assert_called_with("soft") +class TestRecentPeakTracking: + """Tests for recent-peak high-water tracking across poll ticks.""" + + @pytest.mark.asyncio + async def test_recent_peak_is_window_max(self, enforcer): + """After several poll ticks, recent_peak == max over the window. + + The window update lives after the ceiling > 0 early return in + _check_and_enforce, so the fixture's positive ceiling (10 GB) is + required for the update to run at all. + """ + readings = [3 * 1024**3, 5 * 1024**3, 2 * 1024**3, 4 * 1024**3] + with patch("omlx.process_memory_enforcer.mx") as mock_mx, patch( + "omlx.process_memory_enforcer.get_phys_footprint", return_value=0 + ): + mock_mx.get_active_memory.side_effect = _cycling(readings) + for _ in readings: + await enforcer._check_and_enforce() + + assert enforcer.recent_peak_bytes() == 5 * 1024**3 + + @pytest.mark.asyncio + async def test_recent_peak_drops_after_window_slides(self, enforcer): + """Old high readings age out once they leave the maxlen=5 window.""" + # Feed one big reading, then enough small ones to push it out of the + # 5-slot window. + big = 9 * 1024**3 + small = 1 * 1024**3 + readings = [big, small, small, small, small, small] + with patch("omlx.process_memory_enforcer.mx") as mock_mx, patch( + "omlx.process_memory_enforcer.get_phys_footprint", return_value=0 + ): + mock_mx.get_active_memory.side_effect = _cycling(readings) + for _ in readings: + await enforcer._check_and_enforce() + + # After 6 ticks the first (big) reading has slid out of the window, + # leaving only small readings. + assert enforcer.recent_peak_bytes() == small + + def test_propagates_recent_peak_to_scheduler(self, enforcer): + """_propagate_memory_limit pushes recent_peak onto each scheduler.""" + scheduler = MagicMock(spec=[]) + scheduler._memory_limit_bytes = 0 + scheduler._memory_hard_limit_bytes = 0 + scheduler._memory_recent_peak_bytes = 0 + engine = MagicMock(spec=[]) + engine.scheduler = scheduler + entry = _make_entry("model-a", engine=engine) + enforcer._engine_pool._entries = {"model-a": entry} + + enforcer._recent_peak_bytes = 7 * 1024**3 + enforcer._propagate_memory_limit() + + assert scheduler._memory_recent_peak_bytes == 7 * 1024**3 + + class TestProperties: """Tests for enforcer properties.""" diff --git a/tests/test_scheduler_admission.py b/tests/test_scheduler_admission.py index 0ed0c6458..1640ef8d3 100644 --- a/tests/test_scheduler_admission.py +++ b/tests/test_scheduler_admission.py @@ -2,13 +2,15 @@ """Tests for scheduler admission control (queue depth cap + admission_paused).""" from collections import deque -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import pytest from omlx.exceptions import SchedulerQueueFullError from omlx.scheduler import Scheduler +GB = 1024**3 + @pytest.fixture def scheduler(): @@ -100,3 +102,77 @@ def test_default_false(self): s._prefill_memory_guard = False s._admission_paused = False assert s._admission_paused is False + + +def _preflight_scheduler(hard_limit: int, recent_peak: int, peak: int): + """Build a bare Scheduler wired for _preflight_memory_check. + + `peak` is the value the (mocked) memory_monitor estimates for the + prefill chunk; `recent_peak` is the propagated high-water mark. + """ + s = Scheduler.__new__(Scheduler) + s._prefill_memory_guard = True + s._memory_hard_limit_bytes = hard_limit + s._memory_recent_peak_bytes = recent_peak + s.config = MagicMock(prefill_step_size=2048) + s.memory_monitor = MagicMock() + s.memory_monitor.estimate_prefill_peak_bytes = MagicMock(return_value=peak) + return s + + +def _preflight_request(): + r = MagicMock() + r.num_prompt_tokens = 8192 + r.cached_tokens = 0 + return r + + +class TestPreflightRecentPeak: + """_preflight_memory_check uses the recent high-water mark, not just the + instant reading, so it does not wave through a request during a prefill + trough that would wall the next chunk.""" + + def test_rejects_on_recent_peak_when_instant_is_low(self): + """Instant active/phys low but recent_peak high -> reject. + + Picks numbers so that low + peak fits (pre-change behaviour would + admit) but recent_peak + peak exceeds the hard limit. This pins the + fix. + """ + hard_limit = 100 * GB + peak = 20 * GB + low = 10 * GB + high = 85 * GB + # Sanity: old code (low + peak) would have passed. + assert low + peak <= hard_limit + # New code (high + peak) must exceed the limit. + assert high + peak > hard_limit + + s = _preflight_scheduler( + hard_limit=hard_limit, recent_peak=high, peak=peak + ) + with patch("omlx.scheduler.mx") as mock_mx, patch( + "omlx.scheduler.get_phys_footprint", return_value=low + ): + mock_mx.get_active_memory.return_value = low + result = s._preflight_memory_check(_preflight_request()) + + assert result is not None + assert "Prefill would require" in result + + def test_admits_when_recent_peak_also_low(self): + """Control: when recent_peak is low too, the request passes.""" + hard_limit = 100 * GB + peak = 20 * GB + low = 10 * GB + + s = _preflight_scheduler( + hard_limit=hard_limit, recent_peak=low, peak=peak + ) + with patch("omlx.scheduler.mx") as mock_mx, patch( + "omlx.scheduler.get_phys_footprint", return_value=low + ): + mock_mx.get_active_memory.return_value = low + result = s._preflight_memory_check(_preflight_request()) + + assert result is None diff --git a/tests/test_scheduler_prefill_forward_gate.py b/tests/test_scheduler_prefill_forward_gate.py new file mode 100644 index 000000000..b4b301a31 --- /dev/null +++ b/tests/test_scheduler_prefill_forward_gate.py @@ -0,0 +1,355 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the forward-FRONT prefill memory gate (P0c). + +The gate (_prefill_forward_gate) predicts a prefill chunk's peak memory +BEFORE running self.model(...) and raises RuntimeError when it would breach +the hard cap, so the request is aborted cleanly instead of the transient +landing on the Metal ceiling and kernel-panicking the machine. The legacy +chunk-END check only fires after the allocation has already happened, which +on Apple Silicon is too late. + +Strategy: pure mocks, no model load. The discriminating assertion is that +when the predicted peak exceeds the cap the model forward is NOT called -- +on pre-change code (no forward-front gate) the forward WOULD run. +""" + +from unittest.mock import MagicMock, patch + +import mlx.core as mx +import pytest + +from omlx.request import Request, RequestStatus, SamplingParams +from omlx.scheduler import Scheduler, SchedulerConfig, _PrefillState + +GB = 1024**3 + + +# --------------------------------------------------------------------------- +# Direct unit tests of _prefill_forward_gate +# --------------------------------------------------------------------------- + + +def _gate_scheduler( + *, + hard_limit: int, + recent_peak: int, + estimate: int, + margin: int, + guard: bool = True, + monitor: bool = True, +): + """Build a bare Scheduler wired only for _prefill_forward_gate.""" + s = Scheduler.__new__(Scheduler) + s._prefill_memory_guard = guard + s._memory_hard_limit_bytes = hard_limit + s._memory_recent_peak_bytes = recent_peak + s._prefill_transient_margin_bytes = margin + s.config = MagicMock(prefill_step_size=2048) + if monitor: + s.memory_monitor = MagicMock() + s.memory_monitor.estimate_prefill_peak_bytes = MagicMock( + return_value=estimate + ) + else: + s.memory_monitor = None + return s + + +def _call_gate(s, chunk_tokens, *, instant): + """Invoke the gate with patched instant memory probes.""" + with patch("omlx.scheduler.mx") as mock_mx, patch( + "omlx.scheduler.get_phys_footprint", return_value=instant + ): + mock_mx.get_active_memory.return_value = instant + s._prefill_forward_gate( + chunk_tokens, request_id="rid-1", loop_label="external" + ) + + +class TestPrefillForwardGateUnit: + """Direct tests of the gate predicate.""" + + def test_raises_when_predicted_peak_exceeds_cap(self): + """current(high-water) + estimate + margin > cap -> RuntimeError. + + Numbers chosen so the instant reading alone (low) + estimate would + fit, but the high-water recent_peak + estimate + margin overflow. + """ + hard = 107 * GB + estimate = 2 * GB + margin = 10 * GB + instant = 50 * GB + recent_peak = 96 * GB + # Instant + estimate (no margin) fits; this is the trough the legacy + # check could read. + assert instant + estimate <= hard + # High-water + estimate + margin overflows -> must refuse. + assert recent_peak + estimate + margin > hard + + s = _gate_scheduler( + hard_limit=hard, + recent_peak=recent_peak, + estimate=estimate, + margin=margin, + ) + with pytest.raises(RuntimeError, match="refused before forward"): + _call_gate(s, 256, instant=instant) + + def test_passes_when_predicted_peak_fits(self): + """current + estimate + margin <= cap -> no raise.""" + hard = 107 * GB + s = _gate_scheduler( + hard_limit=hard, + recent_peak=80 * GB, + estimate=2 * GB, + margin=10 * GB, + ) + # 80 + 2 + 10 = 92 < 107. + _call_gate(s, 256, instant=80 * GB) # must not raise + + def test_margin_is_what_tips_it_over(self): + """Without the margin it would pass; the margin alone forces refusal. + + Pins that the margin term is actually applied (not dropped). + """ + hard = 100 * GB + estimate = 1 * GB + instant = 90 * GB + recent_peak = 90 * GB + # current + estimate (no margin) = 91 < 100 -> would pass. + assert recent_peak + estimate < hard + # current + estimate + margin = 101 > 100 -> must refuse. + margin = 10 * GB + assert recent_peak + estimate + margin > hard + + s = _gate_scheduler( + hard_limit=hard, + recent_peak=recent_peak, + estimate=estimate, + margin=margin, + ) + with pytest.raises(RuntimeError): + _call_gate(s, 256, instant=instant) + + # Same setup, margin=0 -> passes (control). + s0 = _gate_scheduler( + hard_limit=hard, + recent_peak=recent_peak, + estimate=estimate, + margin=0, + ) + _call_gate(s0, 256, instant=instant) # must not raise + + def test_uses_recent_peak_high_water_not_just_instant(self): + """A mid-prefill trough in the instant reading must not mask the + real footprint: recent_peak high + low instant still refuses.""" + hard = 107 * GB + s = _gate_scheduler( + hard_limit=hard, + recent_peak=100 * GB, # real footprint + estimate=2 * GB, + margin=10 * GB, + ) + # Instant reads a trough at 50GB; without recent_peak it would pass. + assert 50 * GB + 2 * GB + 10 * GB < hard + with pytest.raises(RuntimeError): + _call_gate(s, 256, instant=50 * GB) + + def test_noop_when_guard_off(self): + s = _gate_scheduler( + hard_limit=107 * GB, + recent_peak=200 * GB, + estimate=200 * GB, + margin=10 * GB, + guard=False, + ) + _call_gate(s, 256, instant=200 * GB) # guard off -> never raises + + def test_noop_when_hard_limit_unset(self): + s = _gate_scheduler( + hard_limit=0, + recent_peak=200 * GB, + estimate=200 * GB, + margin=10 * GB, + ) + _call_gate(s, 256, instant=200 * GB) # no limit -> never raises + + def test_noop_when_monitor_missing(self): + s = _gate_scheduler( + hard_limit=107 * GB, + recent_peak=200 * GB, + estimate=200 * GB, + margin=10 * GB, + monitor=False, + ) + _call_gate(s, 256, instant=200 * GB) # no monitor -> never raises + + def test_noop_when_estimate_zero(self): + """estimate==0 means the model can't be estimated -> leave it to the + legacy chunk-end check, do not raise here.""" + s = _gate_scheduler( + hard_limit=107 * GB, + recent_peak=200 * GB, + estimate=0, + margin=10 * GB, + ) + _call_gate(s, 256, instant=200 * GB) # estimate 0 -> never raises + + def test_noop_when_chunk_zero(self): + s = _gate_scheduler( + hard_limit=107 * GB, + recent_peak=200 * GB, + estimate=2 * GB, + margin=10 * GB, + ) + _call_gate(s, 0, instant=200 * GB) # nothing to process -> never raises + + +# --------------------------------------------------------------------------- +# Integration: gate fires BEFORE the model forward in the real chunked loop +# --------------------------------------------------------------------------- + + +def _integration_scheduler(*, hard_gb: float, estimate_bytes: int, margin_gb: float): + """Scheduler with a mock model, hard cap on but soft off (so the adaptive + throttle passes through and only the forward-front gate can fire).""" + model = MagicMock() + model.layers = [] + tokenizer = MagicMock() + tokenizer.eos_token_id = 2 + config = SchedulerConfig( + max_num_seqs=8, + prefill_step_size=256, + chunked_prefill=True, + paged_cache_block_size=0, + ) + s = Scheduler(model=model, tokenizer=tokenizer, config=config) + s.batch_generator = MagicMock() + # Soft limit 0 -> _adaptive_chunk_size is a pure passthrough. + s._memory_limit_bytes = 0 + s._memory_hard_limit_bytes = int(hard_gb * GB) + s._prefill_memory_guard = True + s._prefill_transient_margin_bytes = int(margin_gb * GB) + s.memory_monitor = MagicMock() + s.memory_monitor.estimate_prefill_peak_bytes = MagicMock( + return_value=estimate_bytes + ) + return s, model + + +def _prefill_state(n_tokens: int) -> _PrefillState: + req = Request( + request_id="rid-int", + prompt=list(range(n_tokens + 1)), + sampling_params=SamplingParams(max_tokens=8), + ) + req.prompt_token_ids = list(range(n_tokens + 1)) + req.num_prompt_tokens = n_tokens + 1 + req.status = RequestStatus.WAITING + return _PrefillState( + request=req, + cache=[], + tokens_remaining=mx.array(list(range(n_tokens)))[None], + last_token=[n_tokens], + tokens_processed=0, + base_size=0, + emitted_boundaries={}, + boundary_enabled=False, + block_size=0, + total_length=n_tokens + 1, + ) + + +class TestForwardGateBlocksForward: + """The gate must abort the chunk BEFORE self.model(...) runs.""" + + def test_over_cap_does_not_call_model_forward(self): + """Predicted peak over cap -> RuntimeError raised and model NOT called. + + This is the discriminating assertion that pins the fix: pre-change + code (no forward-front gate) reaches self.model(chunk, ...) and the + transient lands on the cap (kernel panic on real hardware). With the + gate, the forward never runs. + """ + # recent_peak high (set via instant probes) + estimate + margin > cap. + s, model = _integration_scheduler( + hard_gb=107.0, estimate_bytes=2 * GB, margin_gb=10 * 1.0 + ) + state = _prefill_state(n_tokens=200) + + high = int(100 * GB) + with patch( + "omlx.scheduler.mx.get_active_memory", return_value=high + ), patch("omlx.scheduler.get_phys_footprint", return_value=high), patch( + "omlx.scheduler.mx.eval" + ) as mock_eval: + with pytest.raises(RuntimeError, match="refused before forward"): + s._step_prefill_chunk(state) + + # The whole point: the model forward must not have executed. + model.assert_not_called() + mock_eval.assert_not_called() + + def test_under_cap_runs_model_forward(self): + """Predicted peak under cap -> forward runs as normal (control).""" + s, model = _integration_scheduler( + hard_gb=107.0, estimate_bytes=1 * GB, margin_gb=2.0 + ) + state = _prefill_state(n_tokens=200) + + low = int(50 * GB) # 50 + 1 + 2 = 53 < 107 + with patch( + "omlx.scheduler.mx.get_active_memory", return_value=low + ), patch("omlx.scheduler.get_phys_footprint", return_value=low), patch( + "omlx.scheduler.mx.eval" + ), patch("omlx.scheduler._sync_and_clear_cache"), patch( + "omlx.scheduler.get_prefill_tracker" + ): + done = s._step_prefill_chunk(state) + + # Forward ran exactly once; prefill consumed the only chunk. + assert model.call_count == 1 + assert done is True + + +class TestForwardGateExternalLoopWiring: + """Sanity that the external loop wiring calls the gate before the forward. + + Patch _prefill_forward_gate to raise; the model forward must not run. + Uses a tiny text-only request through _do_external_prefill. + """ + + def test_external_loop_calls_gate_before_forward(self): + model = MagicMock() + model.layers = [] + tokenizer = MagicMock() + tokenizer.eos_token_id = 2 + config = SchedulerConfig( + max_num_seqs=8, + prefill_step_size=256, + chunked_prefill=False, + paged_cache_block_size=0, + ) + s = Scheduler(model=model, tokenizer=tokenizer, config=config) + + req = Request( + request_id="rid-ext", + prompt=[1, 2, 3, 4, 5], + sampling_params=SamplingParams(max_tokens=8), + ) + req.prompt_token_ids = [1, 2, 3, 4, 5] + req.num_prompt_tokens = 5 + + with patch.object( + s, + "_prefill_forward_gate", + side_effect=RuntimeError("Prefill refused before forward"), + ) as mock_gate, patch( + "omlx.scheduler.make_prompt_cache", return_value=[] + ): + with pytest.raises(RuntimeError, match="refused before forward"): + s._do_external_prefill(req, [1, 2, 3, 4, 5], None) + + mock_gate.assert_called_once() + # Gate raised -> forward must not have run. + model.assert_not_called()