From b44852ad64effb1d5b3d754dda4337e9a9cdf0f4 Mon Sep 17 00:00:00 2001 From: yuanwei Date: Fri, 5 Jun 2026 16:34:50 -0700 Subject: [PATCH 1/2] fix(memory): refuse oversized prefill chunks before the forward (anti-panic) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Stop the m5max whole-machine watchdog panic caused by a large model's prefill transient breaching the Metal cap. The model loads and runs normally; only an individual request that cannot fit is refused (503-class), not the model. Root cause: oMLX bounds memory reactively (enforcer polls phys every 1s) and the in-prefill memory check ran AFTER self.model()+mx.eval. On Apple UMA a chunk that overshoots the Metal wired limit kernel-panics the whole machine, so the post-forward Python check never runs. glm4.5-air-106b (85GB) on a 128GB box left ~22GB for KV+prefill; a per-chunk MoE-dequant transient (measured up to 7.4GB, peaks hit 110GB vs the 107.5GB cap) crossed the cap and rebooted the box twice. - Scheduler._prefill_forward_gate: before each prefill forward (external loop + chunked-step mirror), predict current(high-water: max active/phys/recent_peak) + estimate_prefill_peak_bytes(KV+SDPA) + a conservative transient margin; if it would exceed the hard cap, raise BEFORE the forward. The existing #1405 cleanup converts that RuntimeError into a finish_reason="error" output -- request refused cleanly, machine not crashed. Legacy post-forward check stays as backstop. - New MemorySettings.prefill_transient_margin_gb (default 10GB), propagated settings -> enforcer -> scheduler. estimate_prefill_peak_bytes does not model the MoE expert-dequant spike, so this margin carries the guarantee: margin (10) > worst observed single-step jump (7.4GB). - Preflight now also maxes against the enforcer recent high-water mark. - Reverts model_load_prefill_headroom_gb load rejection (user feedback: a model must be loadable; refuse the request, not the model). Honest residual: the gate reads current just after the prior chunk's cache clear (active trough), leaning on phys_footprint stickiness + recent_peak to avoid a trough misread; a misread could still admit a crashing chunk. Not a literal never-panic guarantee -- needs on-hardware [memgate] log validation; the real fix is preemptive KV offload (separate work). Tests: 12 new (gate raise/pass, margin-is-the-trip, no-op guards, integration asserting model forward NOT called over-cap), verified to fail with gate neutered. Full suite 4542 pass / 3 known api_key fail / 19 skip on m2max -- zero regression. --- 止住 m5max 因大模型 prefill 瞬时撞穿 Metal cap 导致的整机 watchdog panic. 模型正常 load 正常用; 只拒放不下的单个请求(503 级), 不禁模型. 根因: oMLX 内存是 reactive 管控(enforcer 每 1s poll phys), prefill 内存检查在 self.model()+mx.eval 之后. Apple UMA 上 chunk 撞穿 Metal wired limit 直接整机 kernel panic, forward 后的 Python 检查根本来不及跑. glm4.5-air-106b(85GB)在 128GB 机只剩 ~22GB 给 KV+prefill; 每 chunk 的 MoE 反量化瞬时(实测达 7.4GB, 峰冲到 110GB vs 107.5 cap)撞穿后整机重启两次. - Scheduler._prefill_forward_gate: 每个 prefill forward 前(external loop + chunked-step mirror), 预测 current(高水位 max active/phys/recent_peak) + estimate_prefill_peak_bytes(KV+SDPA) + 保守 transient margin; 超 hard cap 就在 forward 前 raise. 现有 #1405 cleanup 把 RuntimeError 转成 finish_reason=error 输出 -- 干净拒请求, 不崩机. forward 后旧检查留作兜底. - 新增 MemorySettings.prefill_transient_margin_gb(默认 10GB), settings -> enforcer -> scheduler 传播. estimate 不含 MoE 反量化瞬时, margin 扛保证: margin(10) > 实测 最坏单步跳变(7.4GB). - preflight 也改用 enforcer 近期高水位取 max. - 撤掉 model_load_prefill_headroom_gb 拒 load(用户反馈: 模型必须能 load, 拒请求不拒模型). 诚实残余: gate 读 current 在上个 chunk cache clear 后(active 谷), 靠 phys_footprint 黏性 + recent_peak 避免读谷; 读到谷仍可能放行会崩的 chunk. 不是字面"绝不 panic" -- 需真机 [memgate] log 验证; 真治本是抢占式 KV offload(独立工作). 测试: 12 个新增, 验证 gate 失效时会 fail. 完整套件 m2max 4542 pass / 3 已知 api_key fail / 19 skip -- 零回归. --- omlx/process_memory_enforcer.py | 27 ++ omlx/scheduler.py | 130 ++++++- omlx/server.py | 1 + omlx/settings.py | 31 ++ tests/test_process_memory_enforcer.py | 57 +++ tests/test_scheduler_admission.py | 78 +++- tests/test_scheduler_prefill_forward_gate.py | 355 +++++++++++++++++++ 7 files changed, 677 insertions(+), 2 deletions(-) create mode 100644 tests/test_scheduler_prefill_forward_gate.py diff --git a/omlx/process_memory_enforcer.py b/omlx/process_memory_enforcer.py index 7d11fee31..41f3370d1 100644 --- a/omlx/process_memory_enforcer.py +++ b/omlx/process_memory_enforcer.py @@ -33,6 +33,7 @@ import logging import subprocess import sys +from collections import deque from typing import TYPE_CHECKING, Any import mlx.core as mx @@ -291,6 +292,7 @@ def __init__( hard_threshold: float = 0.95, prefill_safe_zone_ratio: float = 0.80, prefill_min_chunk_tokens: int = 32, + prefill_transient_margin_gb: float = 0.0, ): """ Initialize the process memory enforcer. @@ -317,6 +319,11 @@ def __init__( prefill_safe_zone_ratio: Fraction of hard cap below which prefill runs at full chunk size; above triggers adaptive shrink. prefill_min_chunk_tokens: Floor for adaptive shrink. + prefill_transient_margin_gb: Conservative margin added to the + modelled per-chunk prefill peak by the scheduler's + forward-front gate, covering the MoE expert-dequant activation + spike that estimate_prefill_peak_bytes does not model. + Propagated to each scheduler. 0 = no extra margin. """ self._engine_pool = engine_pool self._memory_guard_tier = self._normalize_tier(memory_guard_tier) @@ -331,6 +338,9 @@ def __init__( self._hard_threshold = hard_threshold self._prefill_safe_zone_ratio = prefill_safe_zone_ratio self._prefill_min_chunk_tokens = prefill_min_chunk_tokens + self._prefill_transient_margin_bytes = max( + 0, int(prefill_transient_margin_gb * 1024**3) + ) self._task: asyncio.Task | None = None self._running = False # Most recently observed pressure level, consumed by scheduler / @@ -340,6 +350,13 @@ def __init__( # or the call failed). Used by the admin dashboard to surface a # warning when the kernel iogpu.wired_limit_mb is below this. self._metal_wired_limit_request: int = 0 + # Rolling window of recent usage readings + their high-water mark. + # Prefill memory dips into a trough between chunks, so the instant + # reading can read low mid-prefill; preflight admission consults this + # peak instead so it does not wave through a request that will wall + # the next chunk. Updated on every poll iteration. + self._usage_window: deque[int] = deque(maxlen=5) + self._recent_peak_bytes: int = 0 @staticmethod def _normalize_tier(tier: str) -> str: @@ -503,6 +520,10 @@ def get_final_ceiling(self) -> int: """Public accessor used by engine_pool pre-load admission.""" return self._get_hard_limit_bytes() + def recent_peak_bytes(self) -> int: + """Recent high-water memory usage over the last few poll ticks.""" + return self._recent_peak_bytes + def _soft_bytes(self) -> int: """Soft watermark: ceiling * soft_threshold.""" ceiling = self._get_hard_limit_bytes() @@ -589,6 +610,10 @@ def _propagate_memory_limit(self) -> None: scheduler._admission_paused = admission_paused scheduler._prefill_safe_zone_ratio = self._prefill_safe_zone_ratio scheduler._prefill_min_chunk_tokens = self._prefill_min_chunk_tokens + scheduler._prefill_transient_margin_bytes = ( + self._prefill_transient_margin_bytes + ) + scheduler._memory_recent_peak_bytes = self._recent_peak_bytes bg = getattr(scheduler, "batch_generator", None) if bg is not None and hasattr(bg, "_memory_limit_bytes"): bg._memory_limit_bytes = soft_limit @@ -671,6 +696,8 @@ async def _check_and_enforce(self) -> None: return current = self._current_usage_bytes() + self._usage_window.append(current) + self._recent_peak_bytes = max(self._usage_window) if self._usage_window else current soft = int(ceiling * self._soft_threshold) hard = int(ceiling * self._hard_threshold) prev_level = self._pressure_level diff --git a/omlx/scheduler.py b/omlx/scheduler.py index 253845bf9..86ff37fc0 100644 --- a/omlx/scheduler.py +++ b/omlx/scheduler.py @@ -796,10 +796,23 @@ def __init__( # soft_threshold. Schedulers stop admitting new prefills while this is # set; in-flight requests proceed. self._admission_paused: bool = False + # Recent high-water memory usage, propagated from ProcessMemoryEnforcer. + # Preflight admission maxes the instant reading against this so it does + # not wave through a request during a prefill trough that would wall + # the next chunk. 0 until the enforcer sets it. + self._memory_recent_peak_bytes: int = 0 # Adaptive prefill throttle params, propagated from enforcer. # Until set, _adaptive_chunk_size is a no-op (returns requested as-is). self._prefill_safe_zone_ratio: float = 0.80 self._prefill_min_chunk_tokens: int = 32 + # Conservative transient margin (bytes) added to the modelled per-chunk + # prefill peak by the forward-front gate (_prefill_forward_gate). + # estimate_prefill_peak_bytes only models KV + SDPA; it does NOT model + # the MoE expert-dequant activation spike, which on glm4.5-air-106b + # (MoE) is the dominant single-step transient. Sized from the observed + # worst-case single-step current jump on m5max (see MemorySettings + # .prefill_transient_margin_gb). 0 until the enforcer propagates it. + self._prefill_transient_margin_bytes: int = 0 # EWMA estimator of per-token chunk transient bytes, used by # _adaptive_chunk_size in the caution zone. Owned per-scheduler. _tracker_model_id = "" @@ -1729,6 +1742,97 @@ def _apply_turboquant_kv_convert(self, prompt_cache: list[Any]) -> None: f"cache layers to {bits}-bit{skip_msg}" ) + def _prefill_forward_gate( + self, + chunk_tokens: int, + *, + request_id: str, + loop_label: str, + ) -> None: + """Forward-FRONT memory gate: refuse a prefill chunk BEFORE it runs. + + The chunk-end check (after self.model(...) + mx.eval) only fires once + the transient has already been allocated -- on Apple Silicon a chunk + that overshoots the Metal cap kernel-panics the whole machine, so an + after-the-fact Python check never runs. This predicts the next chunk's + peak and raises BEFORE the forward when it would exceed the hard cap, + so the request is aborted cleanly (the #1405 cleanup paths convert this + RuntimeError into a finish_reason="error" output) instead of crashing. + + predicted_peak = current(high-water) + estimate(KV+SDPA) + margin + - current: max(active, phys_footprint, recent_peak). recent_peak is + the enforcer's rolling high-water mark, so a mid-prefill trough in + the instant reading does not mask the real footprint. + - estimate: memory_monitor.estimate_prefill_peak_bytes models this + chunk's KV + SDPA activation only. + - margin: _prefill_transient_margin_bytes covers what the estimate + does NOT model -- chiefly the MoE expert-dequant activation spike, + the dominant single-step transient on MoE models like glm4.5-air. + + No-op (returns) when the guard is off, the hard limit is unset, the + memory_monitor is missing, or the estimate is unavailable (0) -- in + every such case the legacy chunk-end check remains the only line of + defense, exactly as before this gate existed. + + At chunk granularity the KV+SDPA estimate is small, so the margin is + the effective trip point (gate fires once current ~> cap - margin). + Correctness depends on `current` reflecting the true high-water mark: + this iteration runs right after the previous chunk's + _sync_and_clear_cache, when mx.get_active_memory() troughs, so the + gate leans on get_phys_footprint() + the propagated recent_peak to + avoid reading a trough. See MemorySettings.prefill_transient_margin_gb. + + Raises: + RuntimeError: when the predicted peak exceeds the hard limit. + """ + if not self._prefill_memory_guard: + return + if self._memory_hard_limit_bytes <= 0: + return + if self.memory_monitor is None: + return + if chunk_tokens <= 0: + return + + estimate = self.memory_monitor.estimate_prefill_peak_bytes( + chunk_tokens, self.config.prefill_step_size + ) + if estimate == 0: + return # can't estimate this model -> leave it to the chunk-end check + + predicted_transient = estimate + self._prefill_transient_margin_bytes + current = max( + mx.get_active_memory(), + get_phys_footprint(), + self._memory_recent_peak_bytes, + ) + predicted_peak = current + predicted_transient + + if predicted_peak > self._memory_hard_limit_bytes: + logger.warning( + "[memgate:%s] rid=%s refusing prefill chunk (n=%d) BEFORE " + "forward: predicted peak %.3fGB = current %.3fGB + transient " + "%.3fGB (estimate %.3fGB + margin %.3fGB) exceeds hard cap " + "%.3fGB. Aborting request to avoid a Metal-cap kernel panic.", + loop_label, + request_id, + chunk_tokens, + predicted_peak / 1024**3, + current / 1024**3, + predicted_transient / 1024**3, + estimate / 1024**3, + self._prefill_transient_margin_bytes / 1024**3, + self._memory_hard_limit_bytes / 1024**3, + ) + raise RuntimeError( + "Prefill refused before forward: predicted peak " + f"{predicted_peak / 1024**3:.1f}GB (current " + f"{current / 1024**3:.1f}GB + transient " + f"{predicted_transient / 1024**3:.1f}GB) would exceed the " + f"memory ceiling {self._memory_hard_limit_bytes / 1024**3:.1f}GB. " + "Reduce context length or increase --max-process-memory." + ) + def _do_external_prefill( self, request: "Request", @@ -1885,6 +1989,15 @@ def _do_external_prefill( extra_kwargs, n_to_process ) + # Forward-FRONT gate: predict this chunk's peak and refuse BEFORE + # the forward if it would breach the Metal cap (post-forward checks + # cannot save us -- the overshoot kernel-panics the machine). + self._prefill_forward_gate( + n_to_process, + request_id=request.request_id, + loop_label="external", + ) + _throttle_pre = get_phys_footprint() self.model(input_arr[:, :n_to_process], cache=prompt_cache, **model_kwargs) mx.eval([c.state for c in prompt_cache]) @@ -2223,6 +2336,17 @@ def _step_prefill_chunk(self, state: _PrefillState) -> bool: chunk = state.tokens_remaining[:, :n] state.tokens_remaining = state.tokens_remaining[:, n:] + + # Forward-FRONT gate: predict this chunk's peak and refuse BEFORE the + # forward if it would breach the Metal cap. Mirrors the external loop; + # raises RuntimeError that _advance_chunked_prefills converts into a + # finish_reason="error" output without crashing the machine. + self._prefill_forward_gate( + n, + request_id=state.request.request_id, + loop_label="chunked_step", + ) + _throttle_pre = get_phys_footprint() self.model(chunk, cache=state.cache) mx.eval([c.state for c in state.cache]) @@ -4541,7 +4665,11 @@ def _preflight_memory_check(self, request: "Request") -> str | None: if peak == 0: return None # can't estimate, skip - current = max(mx.get_active_memory(), get_phys_footprint()) + current = max( + mx.get_active_memory(), + get_phys_footprint(), + self._memory_recent_peak_bytes, + ) if current + peak > self._memory_hard_limit_bytes: from .utils.hardware import format_bytes diff --git a/omlx/server.py b/omlx/server.py index 0d5df6745..9f0d9fb69 100644 --- a/omlx/server.py +++ b/omlx/server.py @@ -359,6 +359,7 @@ async def lifespan(app: FastAPI): hard_threshold=mem_cfg.hard_threshold, prefill_safe_zone_ratio=mem_cfg.prefill_safe_zone_ratio, prefill_min_chunk_tokens=mem_cfg.prefill_min_chunk_tokens, + prefill_transient_margin_gb=mem_cfg.prefill_transient_margin_gb, ) _server_state.process_memory_enforcer = enforcer _server_state.engine_pool._process_memory_enforcer = enforcer diff --git a/omlx/settings.py b/omlx/settings.py index adba2a6d8..d9a4c631f 100644 --- a/omlx/settings.py +++ b/omlx/settings.py @@ -391,6 +391,33 @@ class MemorySettings: # aborted via the same cleanup path the hard-limit RuntimeError uses. prefill_safe_zone_ratio: float = 0.80 prefill_min_chunk_tokens: int = 32 + # Conservative transient margin added to the modelled per-chunk prefill + # peak by the scheduler's forward-FRONT memory gate. The model in + # memory_monitor.estimate_prefill_peak_bytes only accounts for KV + SDPA; + # it does NOT model the MoE expert-dequant activation spike, which on a + # MoE model (glm4.5-air-106b) is the dominant single-step transient. The + # gate refuses a chunk before its forward when current + estimate + this + # margin would breach the hard cap, so the transient never actually lands + # on the Metal ceiling (which would kernel-panic the whole machine, an + # after-the-fact Python check cannot catch it). + # + # At chunk granularity the KV+SDPA estimate is tiny (~0.3GB for a + # 256-token GLM chunk), so this margin IS the safety mechanism. The + # load-bearing guarantee is: margin > the worst-case single-step memory + # jump. Across the 2026-06-06 m5max glm4.5-air-106b crash log the max + # trough->peak single-step delta was 7.44GB (peak overshoot reached + # 110.4GB vs a 107.5GB cap). With margin=10GB the gate effectively fires + # once current exceeds ~cap - margin (~97GB); since every observed forward + # jumps <=7.44GB < 10GB, any chunk that would land above the cap starts + # from a current already past that trip point, so the gate refuses it + # before the forward. 10 = ceil(7.44) padded for an unobserved larger + # spike. NOTE: this holds only while `current` is read at true high-water + # at gate time (it relies on phys_footprint stickiness + recent_peak to + # mask the post-_sync_and_clear_cache trough in mx.get_active_memory()); + # watch the [memgate]/[memcheck] logs on hardware and raise the margin if + # a sub-trip-point trough precedes an over-cap chunk. Set to 0 to disable + # the extra margin (the gate then uses the bare KV+SDPA estimate). + prefill_transient_margin_gb: float = 10.0 def to_dict(self) -> dict[str, Any]: """Convert to dictionary.""" @@ -402,6 +429,7 @@ def to_dict(self) -> dict[str, Any]: "hard_threshold": self.hard_threshold, "prefill_safe_zone_ratio": self.prefill_safe_zone_ratio, "prefill_min_chunk_tokens": self.prefill_min_chunk_tokens, + "prefill_transient_margin_gb": self.prefill_transient_margin_gb, } @classmethod @@ -440,6 +468,9 @@ def from_dict(cls, data: dict[str, Any]) -> MemorySettings: prefill_min_chunk_tokens=int( data.get("prefill_min_chunk_tokens", 32) ), + prefill_transient_margin_gb=float( + data.get("prefill_transient_margin_gb", 10.0) + ), ) diff --git a/tests/test_process_memory_enforcer.py b/tests/test_process_memory_enforcer.py index b456ecc66..64ae81c0b 100644 --- a/tests/test_process_memory_enforcer.py +++ b/tests/test_process_memory_enforcer.py @@ -965,6 +965,63 @@ async def test_check_and_enforce_walks_caps_on_soft(self, enforcer): scheduler.adjust_store_cache_cap.assert_called_with("soft") +class TestRecentPeakTracking: + """Tests for recent-peak high-water tracking across poll ticks.""" + + @pytest.mark.asyncio + async def test_recent_peak_is_window_max(self, enforcer): + """After several poll ticks, recent_peak == max over the window. + + The window update lives after the ceiling > 0 early return in + _check_and_enforce, so the fixture's positive ceiling (10 GB) is + required for the update to run at all. + """ + readings = [3 * 1024**3, 5 * 1024**3, 2 * 1024**3, 4 * 1024**3] + with patch("omlx.process_memory_enforcer.mx") as mock_mx, patch( + "omlx.process_memory_enforcer.get_phys_footprint", return_value=0 + ): + mock_mx.get_active_memory.side_effect = _cycling(readings) + for _ in readings: + await enforcer._check_and_enforce() + + assert enforcer.recent_peak_bytes() == 5 * 1024**3 + + @pytest.mark.asyncio + async def test_recent_peak_drops_after_window_slides(self, enforcer): + """Old high readings age out once they leave the maxlen=5 window.""" + # Feed one big reading, then enough small ones to push it out of the + # 5-slot window. + big = 9 * 1024**3 + small = 1 * 1024**3 + readings = [big, small, small, small, small, small] + with patch("omlx.process_memory_enforcer.mx") as mock_mx, patch( + "omlx.process_memory_enforcer.get_phys_footprint", return_value=0 + ): + mock_mx.get_active_memory.side_effect = _cycling(readings) + for _ in readings: + await enforcer._check_and_enforce() + + # After 6 ticks the first (big) reading has slid out of the window, + # leaving only small readings. + assert enforcer.recent_peak_bytes() == small + + def test_propagates_recent_peak_to_scheduler(self, enforcer): + """_propagate_memory_limit pushes recent_peak onto each scheduler.""" + scheduler = MagicMock(spec=[]) + scheduler._memory_limit_bytes = 0 + scheduler._memory_hard_limit_bytes = 0 + scheduler._memory_recent_peak_bytes = 0 + engine = MagicMock(spec=[]) + engine.scheduler = scheduler + entry = _make_entry("model-a", engine=engine) + enforcer._engine_pool._entries = {"model-a": entry} + + enforcer._recent_peak_bytes = 7 * 1024**3 + enforcer._propagate_memory_limit() + + assert scheduler._memory_recent_peak_bytes == 7 * 1024**3 + + class TestProperties: """Tests for enforcer properties.""" diff --git a/tests/test_scheduler_admission.py b/tests/test_scheduler_admission.py index 0ed0c6458..1640ef8d3 100644 --- a/tests/test_scheduler_admission.py +++ b/tests/test_scheduler_admission.py @@ -2,13 +2,15 @@ """Tests for scheduler admission control (queue depth cap + admission_paused).""" from collections import deque -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import pytest from omlx.exceptions import SchedulerQueueFullError from omlx.scheduler import Scheduler +GB = 1024**3 + @pytest.fixture def scheduler(): @@ -100,3 +102,77 @@ def test_default_false(self): s._prefill_memory_guard = False s._admission_paused = False assert s._admission_paused is False + + +def _preflight_scheduler(hard_limit: int, recent_peak: int, peak: int): + """Build a bare Scheduler wired for _preflight_memory_check. + + `peak` is the value the (mocked) memory_monitor estimates for the + prefill chunk; `recent_peak` is the propagated high-water mark. + """ + s = Scheduler.__new__(Scheduler) + s._prefill_memory_guard = True + s._memory_hard_limit_bytes = hard_limit + s._memory_recent_peak_bytes = recent_peak + s.config = MagicMock(prefill_step_size=2048) + s.memory_monitor = MagicMock() + s.memory_monitor.estimate_prefill_peak_bytes = MagicMock(return_value=peak) + return s + + +def _preflight_request(): + r = MagicMock() + r.num_prompt_tokens = 8192 + r.cached_tokens = 0 + return r + + +class TestPreflightRecentPeak: + """_preflight_memory_check uses the recent high-water mark, not just the + instant reading, so it does not wave through a request during a prefill + trough that would wall the next chunk.""" + + def test_rejects_on_recent_peak_when_instant_is_low(self): + """Instant active/phys low but recent_peak high -> reject. + + Picks numbers so that low + peak fits (pre-change behaviour would + admit) but recent_peak + peak exceeds the hard limit. This pins the + fix. + """ + hard_limit = 100 * GB + peak = 20 * GB + low = 10 * GB + high = 85 * GB + # Sanity: old code (low + peak) would have passed. + assert low + peak <= hard_limit + # New code (high + peak) must exceed the limit. + assert high + peak > hard_limit + + s = _preflight_scheduler( + hard_limit=hard_limit, recent_peak=high, peak=peak + ) + with patch("omlx.scheduler.mx") as mock_mx, patch( + "omlx.scheduler.get_phys_footprint", return_value=low + ): + mock_mx.get_active_memory.return_value = low + result = s._preflight_memory_check(_preflight_request()) + + assert result is not None + assert "Prefill would require" in result + + def test_admits_when_recent_peak_also_low(self): + """Control: when recent_peak is low too, the request passes.""" + hard_limit = 100 * GB + peak = 20 * GB + low = 10 * GB + + s = _preflight_scheduler( + hard_limit=hard_limit, recent_peak=low, peak=peak + ) + with patch("omlx.scheduler.mx") as mock_mx, patch( + "omlx.scheduler.get_phys_footprint", return_value=low + ): + mock_mx.get_active_memory.return_value = low + result = s._preflight_memory_check(_preflight_request()) + + assert result is None diff --git a/tests/test_scheduler_prefill_forward_gate.py b/tests/test_scheduler_prefill_forward_gate.py new file mode 100644 index 000000000..b4b301a31 --- /dev/null +++ b/tests/test_scheduler_prefill_forward_gate.py @@ -0,0 +1,355 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the forward-FRONT prefill memory gate (P0c). + +The gate (_prefill_forward_gate) predicts a prefill chunk's peak memory +BEFORE running self.model(...) and raises RuntimeError when it would breach +the hard cap, so the request is aborted cleanly instead of the transient +landing on the Metal ceiling and kernel-panicking the machine. The legacy +chunk-END check only fires after the allocation has already happened, which +on Apple Silicon is too late. + +Strategy: pure mocks, no model load. The discriminating assertion is that +when the predicted peak exceeds the cap the model forward is NOT called -- +on pre-change code (no forward-front gate) the forward WOULD run. +""" + +from unittest.mock import MagicMock, patch + +import mlx.core as mx +import pytest + +from omlx.request import Request, RequestStatus, SamplingParams +from omlx.scheduler import Scheduler, SchedulerConfig, _PrefillState + +GB = 1024**3 + + +# --------------------------------------------------------------------------- +# Direct unit tests of _prefill_forward_gate +# --------------------------------------------------------------------------- + + +def _gate_scheduler( + *, + hard_limit: int, + recent_peak: int, + estimate: int, + margin: int, + guard: bool = True, + monitor: bool = True, +): + """Build a bare Scheduler wired only for _prefill_forward_gate.""" + s = Scheduler.__new__(Scheduler) + s._prefill_memory_guard = guard + s._memory_hard_limit_bytes = hard_limit + s._memory_recent_peak_bytes = recent_peak + s._prefill_transient_margin_bytes = margin + s.config = MagicMock(prefill_step_size=2048) + if monitor: + s.memory_monitor = MagicMock() + s.memory_monitor.estimate_prefill_peak_bytes = MagicMock( + return_value=estimate + ) + else: + s.memory_monitor = None + return s + + +def _call_gate(s, chunk_tokens, *, instant): + """Invoke the gate with patched instant memory probes.""" + with patch("omlx.scheduler.mx") as mock_mx, patch( + "omlx.scheduler.get_phys_footprint", return_value=instant + ): + mock_mx.get_active_memory.return_value = instant + s._prefill_forward_gate( + chunk_tokens, request_id="rid-1", loop_label="external" + ) + + +class TestPrefillForwardGateUnit: + """Direct tests of the gate predicate.""" + + def test_raises_when_predicted_peak_exceeds_cap(self): + """current(high-water) + estimate + margin > cap -> RuntimeError. + + Numbers chosen so the instant reading alone (low) + estimate would + fit, but the high-water recent_peak + estimate + margin overflow. + """ + hard = 107 * GB + estimate = 2 * GB + margin = 10 * GB + instant = 50 * GB + recent_peak = 96 * GB + # Instant + estimate (no margin) fits; this is the trough the legacy + # check could read. + assert instant + estimate <= hard + # High-water + estimate + margin overflows -> must refuse. + assert recent_peak + estimate + margin > hard + + s = _gate_scheduler( + hard_limit=hard, + recent_peak=recent_peak, + estimate=estimate, + margin=margin, + ) + with pytest.raises(RuntimeError, match="refused before forward"): + _call_gate(s, 256, instant=instant) + + def test_passes_when_predicted_peak_fits(self): + """current + estimate + margin <= cap -> no raise.""" + hard = 107 * GB + s = _gate_scheduler( + hard_limit=hard, + recent_peak=80 * GB, + estimate=2 * GB, + margin=10 * GB, + ) + # 80 + 2 + 10 = 92 < 107. + _call_gate(s, 256, instant=80 * GB) # must not raise + + def test_margin_is_what_tips_it_over(self): + """Without the margin it would pass; the margin alone forces refusal. + + Pins that the margin term is actually applied (not dropped). + """ + hard = 100 * GB + estimate = 1 * GB + instant = 90 * GB + recent_peak = 90 * GB + # current + estimate (no margin) = 91 < 100 -> would pass. + assert recent_peak + estimate < hard + # current + estimate + margin = 101 > 100 -> must refuse. + margin = 10 * GB + assert recent_peak + estimate + margin > hard + + s = _gate_scheduler( + hard_limit=hard, + recent_peak=recent_peak, + estimate=estimate, + margin=margin, + ) + with pytest.raises(RuntimeError): + _call_gate(s, 256, instant=instant) + + # Same setup, margin=0 -> passes (control). + s0 = _gate_scheduler( + hard_limit=hard, + recent_peak=recent_peak, + estimate=estimate, + margin=0, + ) + _call_gate(s0, 256, instant=instant) # must not raise + + def test_uses_recent_peak_high_water_not_just_instant(self): + """A mid-prefill trough in the instant reading must not mask the + real footprint: recent_peak high + low instant still refuses.""" + hard = 107 * GB + s = _gate_scheduler( + hard_limit=hard, + recent_peak=100 * GB, # real footprint + estimate=2 * GB, + margin=10 * GB, + ) + # Instant reads a trough at 50GB; without recent_peak it would pass. + assert 50 * GB + 2 * GB + 10 * GB < hard + with pytest.raises(RuntimeError): + _call_gate(s, 256, instant=50 * GB) + + def test_noop_when_guard_off(self): + s = _gate_scheduler( + hard_limit=107 * GB, + recent_peak=200 * GB, + estimate=200 * GB, + margin=10 * GB, + guard=False, + ) + _call_gate(s, 256, instant=200 * GB) # guard off -> never raises + + def test_noop_when_hard_limit_unset(self): + s = _gate_scheduler( + hard_limit=0, + recent_peak=200 * GB, + estimate=200 * GB, + margin=10 * GB, + ) + _call_gate(s, 256, instant=200 * GB) # no limit -> never raises + + def test_noop_when_monitor_missing(self): + s = _gate_scheduler( + hard_limit=107 * GB, + recent_peak=200 * GB, + estimate=200 * GB, + margin=10 * GB, + monitor=False, + ) + _call_gate(s, 256, instant=200 * GB) # no monitor -> never raises + + def test_noop_when_estimate_zero(self): + """estimate==0 means the model can't be estimated -> leave it to the + legacy chunk-end check, do not raise here.""" + s = _gate_scheduler( + hard_limit=107 * GB, + recent_peak=200 * GB, + estimate=0, + margin=10 * GB, + ) + _call_gate(s, 256, instant=200 * GB) # estimate 0 -> never raises + + def test_noop_when_chunk_zero(self): + s = _gate_scheduler( + hard_limit=107 * GB, + recent_peak=200 * GB, + estimate=2 * GB, + margin=10 * GB, + ) + _call_gate(s, 0, instant=200 * GB) # nothing to process -> never raises + + +# --------------------------------------------------------------------------- +# Integration: gate fires BEFORE the model forward in the real chunked loop +# --------------------------------------------------------------------------- + + +def _integration_scheduler(*, hard_gb: float, estimate_bytes: int, margin_gb: float): + """Scheduler with a mock model, hard cap on but soft off (so the adaptive + throttle passes through and only the forward-front gate can fire).""" + model = MagicMock() + model.layers = [] + tokenizer = MagicMock() + tokenizer.eos_token_id = 2 + config = SchedulerConfig( + max_num_seqs=8, + prefill_step_size=256, + chunked_prefill=True, + paged_cache_block_size=0, + ) + s = Scheduler(model=model, tokenizer=tokenizer, config=config) + s.batch_generator = MagicMock() + # Soft limit 0 -> _adaptive_chunk_size is a pure passthrough. + s._memory_limit_bytes = 0 + s._memory_hard_limit_bytes = int(hard_gb * GB) + s._prefill_memory_guard = True + s._prefill_transient_margin_bytes = int(margin_gb * GB) + s.memory_monitor = MagicMock() + s.memory_monitor.estimate_prefill_peak_bytes = MagicMock( + return_value=estimate_bytes + ) + return s, model + + +def _prefill_state(n_tokens: int) -> _PrefillState: + req = Request( + request_id="rid-int", + prompt=list(range(n_tokens + 1)), + sampling_params=SamplingParams(max_tokens=8), + ) + req.prompt_token_ids = list(range(n_tokens + 1)) + req.num_prompt_tokens = n_tokens + 1 + req.status = RequestStatus.WAITING + return _PrefillState( + request=req, + cache=[], + tokens_remaining=mx.array(list(range(n_tokens)))[None], + last_token=[n_tokens], + tokens_processed=0, + base_size=0, + emitted_boundaries={}, + boundary_enabled=False, + block_size=0, + total_length=n_tokens + 1, + ) + + +class TestForwardGateBlocksForward: + """The gate must abort the chunk BEFORE self.model(...) runs.""" + + def test_over_cap_does_not_call_model_forward(self): + """Predicted peak over cap -> RuntimeError raised and model NOT called. + + This is the discriminating assertion that pins the fix: pre-change + code (no forward-front gate) reaches self.model(chunk, ...) and the + transient lands on the cap (kernel panic on real hardware). With the + gate, the forward never runs. + """ + # recent_peak high (set via instant probes) + estimate + margin > cap. + s, model = _integration_scheduler( + hard_gb=107.0, estimate_bytes=2 * GB, margin_gb=10 * 1.0 + ) + state = _prefill_state(n_tokens=200) + + high = int(100 * GB) + with patch( + "omlx.scheduler.mx.get_active_memory", return_value=high + ), patch("omlx.scheduler.get_phys_footprint", return_value=high), patch( + "omlx.scheduler.mx.eval" + ) as mock_eval: + with pytest.raises(RuntimeError, match="refused before forward"): + s._step_prefill_chunk(state) + + # The whole point: the model forward must not have executed. + model.assert_not_called() + mock_eval.assert_not_called() + + def test_under_cap_runs_model_forward(self): + """Predicted peak under cap -> forward runs as normal (control).""" + s, model = _integration_scheduler( + hard_gb=107.0, estimate_bytes=1 * GB, margin_gb=2.0 + ) + state = _prefill_state(n_tokens=200) + + low = int(50 * GB) # 50 + 1 + 2 = 53 < 107 + with patch( + "omlx.scheduler.mx.get_active_memory", return_value=low + ), patch("omlx.scheduler.get_phys_footprint", return_value=low), patch( + "omlx.scheduler.mx.eval" + ), patch("omlx.scheduler._sync_and_clear_cache"), patch( + "omlx.scheduler.get_prefill_tracker" + ): + done = s._step_prefill_chunk(state) + + # Forward ran exactly once; prefill consumed the only chunk. + assert model.call_count == 1 + assert done is True + + +class TestForwardGateExternalLoopWiring: + """Sanity that the external loop wiring calls the gate before the forward. + + Patch _prefill_forward_gate to raise; the model forward must not run. + Uses a tiny text-only request through _do_external_prefill. + """ + + def test_external_loop_calls_gate_before_forward(self): + model = MagicMock() + model.layers = [] + tokenizer = MagicMock() + tokenizer.eos_token_id = 2 + config = SchedulerConfig( + max_num_seqs=8, + prefill_step_size=256, + chunked_prefill=False, + paged_cache_block_size=0, + ) + s = Scheduler(model=model, tokenizer=tokenizer, config=config) + + req = Request( + request_id="rid-ext", + prompt=[1, 2, 3, 4, 5], + sampling_params=SamplingParams(max_tokens=8), + ) + req.prompt_token_ids = [1, 2, 3, 4, 5] + req.num_prompt_tokens = 5 + + with patch.object( + s, + "_prefill_forward_gate", + side_effect=RuntimeError("Prefill refused before forward"), + ) as mock_gate, patch( + "omlx.scheduler.make_prompt_cache", return_value=[] + ): + with pytest.raises(RuntimeError, match="refused before forward"): + s._do_external_prefill(req, [1, 2, 3, 4, 5], None) + + mock_gate.assert_called_once() + # Gate raised -> forward must not have run. + model.assert_not_called() From c892d6804350c9ead0f5e145761880e918592465 Mon Sep 17 00:00:00 2001 From: yuanwei Date: Fri, 5 Jun 2026 23:31:23 -0700 Subject: [PATCH 2/2] fix(memory): admit prefill by memory budget, queue instead of stacking MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Make request admission memory-aware so concurrent prefills can no longer stack past the Metal cap and kernel-panic the whole m5max box. Admission was capped only by max_num_seqs (a fixed concurrency), so N requests against an 85GB model each ran prefill and their transients summed over the ~22GB of headroom. Primary fix -- entry-point budget admission (Scheduler._schedule_waiting): before admitting a waiting request while work is in flight, predict its prefill peak (_predicted_prefill_peak_bytes = current high-water + KV+SDPA estimate + transient margin); if it would breach the hard cap, leave it QUEUED (appendleft + break, NOT rejected) and re-check next step once in-flight work frees its KV. Concurrency becomes adaptive with no per-model knob: a memory-rich model packs many requests, an 85GB model collapses to 1 and queues the rest. The first request (nothing in flight) is never deferred -- a lone request that cannot fit even alone is instead REJECTED by _preflight_memory_check (now sharing the same predicate, so it too carries the margin). Reading at admission time -- between generation steps -- is a stable high-water, not the mid-prefill chunk trough the forward gate had to read. In flight means running OR prefilling (_has_inflight_kv), NOT running alone. The documented crash is concurrent CHUNKED prefills stacking, and a request mid-chunked-prefill lives in self.prefilling, not self.running. Gating on self.running alone would wave the 2nd, 3rd, ... prefill straight past the budget into the exact stack this guard exists to prevent. self.prefilling always drains, so deferring on it cannot deadlock; a lone request still hits preflight reject (both empty). Margin 10 -> 12GB. The un-modelled MoE expert-dequant transient is SUB-POLL (faster than the enforcer's 1s sample), so it is invisible to every memory read and MUST be carried by the margin, not by reading the footprint more cleverly. The 2026-06-06 m5max crash showed an effective transient up to ~10.6GB; margin 10 was the actual root cause of the prior forward-gate miss (10 < 10.6), not the trough read. 12 = ceil(10.6) padded. The forward gate (_prefill_forward_gate) is kept, un-retired, as a per-chunk CONCURRENT-DRIFT backstop: admission snapshots memory once, the gate re-reads before every chunk and catches another in-flight request's KV growing during a long prefill. Because admission uses the full-prompt estimate (>= the gate's per-chunk estimate) with the same margin, a correctly-admitted request cannot trip the gate under static memory -- it only fires on post-admission drift. The margin 12 fix repairs the gate's prior miss too. The pre-pop generation soft-guard is left as a complementary coarse early-out (soft limit, request-agnostic); the budget defer is the fine, request-aware, hard-cap check. The defer decision logs at info ([memadmit]) so it is visible during on-hardware validation alongside the enforcer [memcheck] ceiling. Honest residual: a single huge-context request on an 85GB model can still need a preflight reject (or a smaller quant) -- entry admission solves concurrent stacking, not a lone prefill that physically cannot fit. The stable-read assumption still needs on-hardware [memadmit]/[memcheck] validation under concurrent + long glm4.5 load before this is proven. Tests: 14 new (8 predicate unit tests including recent_peak folds for running OR prefilling; budget-defer requeues-not-rejects for both the running and the chunked-prefill-only case; the dominance property that an admitted request never trips the gate under static memory + the gate firing under drift; lone-request ignores stale recent_peak). Full suite 4556 pass / 3 known api_key fail / 19 skip on m2max -- zero regression vs the baseline. --- 让请求准入感知内存, 使并发 prefill 不再叠加冲过 Metal cap 把 m5max 整机 kernel panic. 之前准入只受 max_num_seqs (固定并发) 限制, N 个请求打 85GB 模型各自跑 prefill, 瞬时叠加冲过仅 ~22GB 的余量. 主修复 -- 入口预算准入 (Scheduler._schedule_waiting): 有工作在飞时, 准入下一个 waiting 请求前预测其 prefill 峰值 (_predicted_prefill_peak_bytes = 当前高水位 + KV+SDPA 估计 + 瞬时 margin); 若会冲破 hard cap, 留在队列里 (appendleft + break, 不拒绝), 下一步等在飞工作释放 KV 后重判. 并发自适应, 无 per-model 旋钮: 内存富的 模型多并发, 85GB 模型自动压到 1 其余排队. 第一个请求 (无在飞工作) 永不 defer -- 连单独都放不下的孤请求改由 _preflight_memory_check 拒绝 (现共用同一判据, 故也带 margin). 在入口 (两个 generation step 之间) 读数是稳定高水位, 不是 forward gate 不得不读的 prefill 中途谷值. "在飞" = running 或 prefilling (_has_inflight_kv), 不是只看 running. 已记录的崩溃 正是并发 chunked prefill 叠加, 而 chunked-prefill 中途的请求在 self.prefilling 不在 self.running. 只看 self.running 会把第 2、3 ... 个 prefill 直接放过预算, 叠成这个 guard 要防的栈. self.prefilling 总会 drain, 故据它 defer 不会死锁; 孤请求仍走 preflight 拒绝 (两者皆空). margin 10 -> 12GB. 未建模的 MoE 反量化瞬时是 sub-poll (快于 enforcer 1s 采样), 对所有内存读数都不可见, 必须由 margin 兜, 而非把读数读得更聪明. 2026-06-06 m5max 崩溃实测有效瞬时达 ~10.6GB; margin 10 才是之前 forward gate 漏报的真因 (10 < 10.6), 不是谷读. 12 = ceil(10.6) 加垫. forward gate (_prefill_forward_gate) 保留并解除退役, 作 per-chunk 并发漂移兜底: 准入只快照一次内存, gate 每个 chunk 前重读, 抓另一在飞请求在本次长 prefill 期间 KV 增长. 因准入用 full-prompt 估计 (>= gate 的 per-chunk 估计) 且同 margin, 被正确 准入的请求在静态内存下绝不触发 gate -- 只在准入后漂移时才触发. margin 12 也修好了 gate 之前的漏报. pre-pop generation soft-guard 保留作互补的 coarse early-out (soft limit, 不看具体 请求); budget defer 是精确的, 看请求的, hard-cap 判据. defer 决策记 info ([memadmit]) 以便真机验证时与 enforcer [memcheck] 顶值一并可见. 诚实残余: 85GB 模型上单个超长上下文请求仍可能需 preflight 拒绝 (或换小量化) -- 入口准入解决的是并发叠加, 不是物理放不下的单次 prefill. 稳定读数这一假设仍需真机 [memadmit]/[memcheck] 在并发 + 长 glm4.5 负载下验证才算证实. 测试: 14 个新增 (8 个判据单元测试含 recent_peak 在 running 或 prefilling 时叠; budget-defer 重排队不拒绝, 覆盖 running 和仅 chunked-prefill 两种; dominance 性质即 被准入的请求静态内存下绝不触发 gate + 漂移时 gate 触发; 孤请求忽略 stale recent_peak). 完整套件 m2max 4556 pass / 3 已知 api_key fail / 19 skip -- 对 baseline 零回归. --- omlx/scheduler.py | 173 ++++++++--- omlx/settings.py | 56 ++-- tests/test_scheduler_admission.py | 80 ++++- tests/test_scheduler_budget_admission.py | 360 +++++++++++++++++++++++ 4 files changed, 588 insertions(+), 81 deletions(-) create mode 100644 tests/test_scheduler_budget_admission.py diff --git a/omlx/scheduler.py b/omlx/scheduler.py index 86ff37fc0..e99038ede 100644 --- a/omlx/scheduler.py +++ b/omlx/scheduler.py @@ -797,16 +797,19 @@ def __init__( # set; in-flight requests proceed. self._admission_paused: bool = False # Recent high-water memory usage, propagated from ProcessMemoryEnforcer. - # Preflight admission maxes the instant reading against this so it does - # not wave through a request during a prefill trough that would wall - # the next chunk. 0 until the enforcer sets it. + # The entry-budget admission (_predicted_prefill_peak_bytes) and the + # forward gate max the instant reading against this so they do not wave + # through a request during a prefill trough that would wall the next + # chunk. Folded in only while requests are in-flight (see + # _predicted_prefill_peak_bytes). 0 until the enforcer sets it. self._memory_recent_peak_bytes: int = 0 # Adaptive prefill throttle params, propagated from enforcer. # Until set, _adaptive_chunk_size is a no-op (returns requested as-is). self._prefill_safe_zone_ratio: float = 0.80 self._prefill_min_chunk_tokens: int = 32 - # Conservative transient margin (bytes) added to the modelled per-chunk - # prefill peak by the forward-front gate (_prefill_forward_gate). + # Conservative transient margin (bytes) added to the modelled prefill + # peak by both entry guards: _predicted_prefill_peak_bytes (admission) + # and _prefill_forward_gate (per-chunk backstop). # estimate_prefill_peak_bytes only models KV + SDPA; it does NOT model # the MoE expert-dequant activation spike, which on glm4.5-air-106b # (MoE) is the dominant single-step transient. Sized from the observed @@ -1751,6 +1754,16 @@ def _prefill_forward_gate( ) -> None: """Forward-FRONT memory gate: refuse a prefill chunk BEFORE it runs. + BACKSTOP to the primary entry-budget admission + (_predicted_prefill_peak_bytes + the _schedule_waiting defer). Admission + snapshots memory once, at admission time; this re-reads before every + chunk and so catches CONCURRENT DRIFT -- another in-flight request's KV + growing during this request's long prefill -- that the admission + snapshot cannot see. Because admission uses the full-prompt estimate + (>= this per-chunk estimate) and the same margin, a correctly-admitted + request cannot trip this gate under static memory: it only fires when + memory drifted up after admission. + The chunk-end check (after self.model(...) + mx.eval) only fires once the transient has already been allocated -- on Apple Silicon a chunk that overshoots the Metal cap kernel-panics the whole machine, so an @@ -4630,20 +4643,52 @@ def get_num_running(self) -> int: """Get number of running requests.""" return len(self.running) - def _preflight_memory_check(self, request: "Request") -> str | None: + def _has_inflight_kv(self) -> bool: + """True when any request holds (or is still accumulating) KV cache -- + a running decode (self.running) OR an in-flight chunked prefill + (self.prefilling). A waiting request has not prefilled yet, so it holds + no KV. Admission folds the recent_peak high-water and may DEFER only + while this is True; with nothing in flight there is nothing to drain, so + a request is admitted instead of queued forever (and a lone request too + big to fit even alone is rejected by _preflight_memory_check). + + Crucially this includes self.prefilling: the documented crash is + concurrent CHUNKED prefills stacking, and a request mid-chunked-prefill + lives in self.prefilling, not self.running. Gating on self.running alone + would wave the 2nd, 3rd, ... prefill straight past the budget into the + exact stack this guard exists to prevent. """ - Estimate whether prefill would exceed memory limits. - - Computes worst-case peak memory for the last prefill chunk - (model weights + KV cache + SDPA attention matrix) and rejects - if it would exceed the hard limit. - - For head_dim > 128, MLX SDPA uses a fallback that materializes - the full attention matrix [B, n_q, chunk, kv_len] in float32. - For head_dim <= 128, MLX uses a fused kernel with O(n) memory. - - Returns: - Error message string if request should be rejected, None if OK. + return bool(self.running or self.prefilling) + + def _predicted_prefill_peak_bytes(self, request: "Request") -> int | None: + """Predicted memory peak if this request's prefill ran now, or None. + + Shared predicate for both entry-point memory guards: the budget + admission in _schedule_waiting (admit / queue) and the lone-request + rejection in _preflight_memory_check (reject). + + predicted = current(high-water) + estimate_prefill_peak(KV+SDPA) + margin + + - current: max(active, phys_footprint). The enforcer recent_peak + high-water is folded in ONLY while work is in flight + (_has_inflight_kv): mid-generation/mid-prefill it masks the + active-memory trough left by the last _sync_and_clear_cache, but + when idle it can be a stale prior-batch peak that would false-reject + a lone request. + - estimate: memory_monitor.estimate_prefill_peak_bytes over the full + new (uncached) prompt -- KV for the whole prompt + last-chunk SDPA. + This is >= the per-chunk estimate the forward gate uses, so a + request admitted here cannot trip the gate under static memory (the + gate is a concurrent-drift backstop, not a second veto on the happy + path -- see _prefill_forward_gate). + - margin: _prefill_transient_margin_bytes carries the un-modelled MoE + expert-dequant transient, which is sub-poll and so invisible to + every memory read (see MemorySettings.prefill_transient_margin_gb). + + Returns None (caller treats as "fits") when the guard is off, the hard + limit is unset, the monitor is missing, there are no new tokens, or the + model cannot be estimated (estimate == 0) -- in each case admission + falls back to the per-concurrency cap and the legacy in-prefill checks. """ if not self._prefill_memory_guard: return None @@ -4652,35 +4697,50 @@ def _preflight_memory_check(self, request: "Request") -> str | None: if self.memory_monitor is None: return None - prompt_tokens = request.num_prompt_tokens - cached_tokens = request.cached_tokens or 0 - new_tokens = max(prompt_tokens - cached_tokens, 0) - + new_tokens = max(request.num_prompt_tokens - (request.cached_tokens or 0), 0) if new_tokens == 0: return None - peak = self.memory_monitor.estimate_prefill_peak_bytes( + estimate = self.memory_monitor.estimate_prefill_peak_bytes( new_tokens, self.config.prefill_step_size ) - if peak == 0: - return None # can't estimate, skip + if estimate == 0: + return None # can't estimate this model -> skip - current = max( - mx.get_active_memory(), - get_phys_footprint(), - self._memory_recent_peak_bytes, - ) + current = max(mx.get_active_memory(), get_phys_footprint()) + if self._has_inflight_kv(): + current = max(current, self._memory_recent_peak_bytes) + return current + estimate + self._prefill_transient_margin_bytes - if current + peak > self._memory_hard_limit_bytes: - from .utils.hardware import format_bytes + def _preflight_memory_check(self, request: "Request") -> str | None: + """Reject a request whose prefill cannot fit even on its own. - return ( - f"Prefill would require ~{format_bytes(current + peak)} peak " - f"(current {format_bytes(current)} + KV+SDPA {format_bytes(peak)}) " - f"but limit is {format_bytes(self._memory_hard_limit_bytes)}. " - f"Reduce context length or increase --max-process-memory." - ) - return None + Entry-point sibling of the budget admission in _schedule_waiting. The + admission DEFERS (queues) a request that would breach the cap only + because other requests are in-flight; this REJECTS a request that would + breach even with nothing else running -- nothing will drain to make + room, so queuing it would deadlock. The #1405 cleanup converts the + returned error into a finish_reason="error" output -- the request is + refused cleanly instead of crashing the box. + + Returns: + Error message string if request should be rejected, None if OK. + """ + predicted = self._predicted_prefill_peak_bytes(request) + if predicted is None: + return None + if predicted <= self._memory_hard_limit_bytes: + return None + + from .utils.hardware import format_bytes + + return ( + f"Prefill would need ~{format_bytes(predicted)} peak " + f"(KV+SDPA estimate + {format_bytes(self._prefill_transient_margin_bytes)} " + f"transient margin) but the memory ceiling is " + f"{format_bytes(self._memory_hard_limit_bytes)}. " + f"Reduce context length or increase --max-process-memory." + ) def _schedule_waiting( self, @@ -4778,6 +4838,41 @@ def _schedule_waiting( request.remaining_tokens = request.prompt_token_ids tokens_to_process = request.prompt_token_ids + # Memory-budget admission. Predict this request's prefill peak; if + # admitting it alongside the in-flight work would breach the hard + # cap, leave it in the waiting queue (do NOT reject) and stop + # admitting for this step. step() re-invokes _schedule_waiting every + # step, so it is re-checked and admitted once in-flight work frees + # its KV. This is what makes concurrency adaptive without a fixed + # per-model cap: a memory-rich model packs many requests, an 85GB + # model collapses to 1 and queues the rest. + # + # Gate on _has_inflight_kv() (running OR prefilling), not self.running + # alone: a request mid-chunked-prefill is in self.prefilling and is + # the dominant KV holder during the very stack this guards against + # (concurrent prefills). The first request (nothing in flight) is + # never deferred -- nothing would drain to admit it -- so a lone + # request that cannot fit even alone falls through to + # _preflight_memory_check below and is rejected instead. + if self._has_inflight_kv(): + predicted_peak = self._predicted_prefill_peak_bytes(request) + if ( + predicted_peak is not None + and predicted_peak > self._memory_hard_limit_bytes + ): + self.waiting.appendleft(request) + logger.info( + "[memadmit] deferring request %s: predicted prefill " + "peak %d + in-flight (%d running, %d prefilling) would " + "breach hard cap %d; left queued", + request.request_id, + predicted_peak, + len(self.running), + len(self.prefilling), + self._memory_hard_limit_bytes, + ) + break + # SpecPrefill requests must be alone in the batch (RoPE patching # affects the entire model). Also block scheduling if another # specprefill request is already running (offset RoPE active). diff --git a/omlx/settings.py b/omlx/settings.py index d9a4c631f..6dcd97480 100644 --- a/omlx/settings.py +++ b/omlx/settings.py @@ -391,33 +391,37 @@ class MemorySettings: # aborted via the same cleanup path the hard-limit RuntimeError uses. prefill_safe_zone_ratio: float = 0.80 prefill_min_chunk_tokens: int = 32 - # Conservative transient margin added to the modelled per-chunk prefill - # peak by the scheduler's forward-FRONT memory gate. The model in + # Conservative transient margin added to the modelled prefill peak by BOTH + # memory guards: the entry-point budget admission (primary -- scheduler. + # _predicted_prefill_peak_bytes, decides admit/queue/reject at admission + # time) and the forward-FRONT chunk gate (backstop -- scheduler. + # _prefill_forward_gate, refuses a chunk before its forward). The model in # memory_monitor.estimate_prefill_peak_bytes only accounts for KV + SDPA; - # it does NOT model the MoE expert-dequant activation spike, which on a - # MoE model (glm4.5-air-106b) is the dominant single-step transient. The - # gate refuses a chunk before its forward when current + estimate + this - # margin would breach the hard cap, so the transient never actually lands - # on the Metal ceiling (which would kernel-panic the whole machine, an - # after-the-fact Python check cannot catch it). + # it does NOT model the MoE expert-dequant activation spike, which on a MoE + # model (glm4.5-air-106b) is the dominant single-step transient. Either + # guard refuses when current + estimate + this margin would breach the hard + # cap, so the transient never actually lands on the Metal ceiling (which + # would kernel-panic the whole machine -- an after-the-fact Python check + # cannot catch it). # - # At chunk granularity the KV+SDPA estimate is tiny (~0.3GB for a - # 256-token GLM chunk), so this margin IS the safety mechanism. The - # load-bearing guarantee is: margin > the worst-case single-step memory - # jump. Across the 2026-06-06 m5max glm4.5-air-106b crash log the max - # trough->peak single-step delta was 7.44GB (peak overshoot reached - # 110.4GB vs a 107.5GB cap). With margin=10GB the gate effectively fires - # once current exceeds ~cap - margin (~97GB); since every observed forward - # jumps <=7.44GB < 10GB, any chunk that would land above the cap starts - # from a current already past that trip point, so the gate refuses it - # before the forward. 10 = ceil(7.44) padded for an unobserved larger - # spike. NOTE: this holds only while `current` is read at true high-water - # at gate time (it relies on phys_footprint stickiness + recent_peak to - # mask the post-_sync_and_clear_cache trough in mx.get_active_memory()); - # watch the [memgate]/[memcheck] logs on hardware and raise the margin if - # a sub-trip-point trough precedes an over-cap chunk. Set to 0 to disable - # the extra margin (the gate then uses the bare KV+SDPA estimate). - prefill_transient_margin_gb: float = 10.0 + # The load-bearing guarantee is: margin > the worst-case single-step memory + # jump. The jump is SUB-POLL -- it rises and falls faster than the + # enforcer's 1s sample, so it is invisible to every memory read (active, + # phys_footprint, recent_peak) by construction. It therefore MUST be carried + # by this margin, not by reading the footprint more cleverly. Across the + # 2026-06-06 m5max glm4.5-air-106b crash log the max trough->peak single-step + # delta was 7.44GB and the peak overshoot reached 110.4GB vs a 107.5GB cap, + # i.e. an effective transient up to ~10.6GB above the pre-step baseline. + # margin=10 was too small (10 < 10.6 -> admitted a step that then breached); + # 12 = ceil(10.6) padded, the value that would have refused that step from + # its true pre-step baseline. Both guards read `current` at high-water + # (max active / phys_footprint / -- when requests are in-flight -- the + # enforcer recent_peak); the baseline (KV+weights) is what they see and it + # is the determining factor, the sub-poll spike rides on the margin. Watch + # the [memgate]/[memcheck] logs on hardware and raise the margin if a step + # ever breaches from a baseline below cap - margin. Set to 0 to disable the + # extra margin (the guards then use the bare KV+SDPA estimate). + prefill_transient_margin_gb: float = 12.0 def to_dict(self) -> dict[str, Any]: """Convert to dictionary.""" @@ -469,7 +473,7 @@ def from_dict(cls, data: dict[str, Any]) -> MemorySettings: data.get("prefill_min_chunk_tokens", 32) ), prefill_transient_margin_gb=float( - data.get("prefill_transient_margin_gb", 10.0) + data.get("prefill_transient_margin_gb", 12.0) ), ) diff --git a/tests/test_scheduler_admission.py b/tests/test_scheduler_admission.py index 1640ef8d3..6954f3c89 100644 --- a/tests/test_scheduler_admission.py +++ b/tests/test_scheduler_admission.py @@ -104,16 +104,26 @@ def test_default_false(self): assert s._admission_paused is False -def _preflight_scheduler(hard_limit: int, recent_peak: int, peak: int): +def _preflight_scheduler( + hard_limit: int, recent_peak: int, peak: int, *, running=None +): """Build a bare Scheduler wired for _preflight_memory_check. `peak` is the value the (mocked) memory_monitor estimates for the - prefill chunk; `recent_peak` is the propagated high-water mark. + prefill chunk; `recent_peak` is the propagated high-water mark. `running` + seeds self.running -- _predicted_prefill_peak_bytes folds recent_peak into + `current` only while requests are in-flight, so an empty dict (idle, the + default) means a lone request is judged on the instant reading alone. + Transient margin is set to 0 here so these tests isolate the recent_peak + folding behaviour; margin handling is covered separately. """ s = Scheduler.__new__(Scheduler) s._prefill_memory_guard = True s._memory_hard_limit_bytes = hard_limit s._memory_recent_peak_bytes = recent_peak + s._prefill_transient_margin_bytes = 0 + s.running = running if running is not None else {} + s.prefilling = deque() s.config = MagicMock(prefill_step_size=2048) s.memory_monitor = MagicMock() s.memory_monitor.estimate_prefill_peak_bytes = MagicMock(return_value=peak) @@ -128,28 +138,35 @@ def _preflight_request(): class TestPreflightRecentPeak: - """_preflight_memory_check uses the recent high-water mark, not just the - instant reading, so it does not wave through a request during a prefill - trough that would wall the next chunk.""" + """_preflight_memory_check folds the recent high-water mark into `current` + while requests are in-flight, so it does not wave through a request during a + prefill trough that would wall the next chunk -- but it ignores recent_peak + when idle, so a stale prior-batch peak does not false-reject a lone request. + """ - def test_rejects_on_recent_peak_when_instant_is_low(self): - """Instant active/phys low but recent_peak high -> reject. + def test_rejects_on_recent_peak_when_inflight_and_instant_is_low(self): + """In-flight + instant active/phys low but recent_peak high -> reject. - Picks numbers so that low + peak fits (pre-change behaviour would - admit) but recent_peak + peak exceeds the hard limit. This pins the - fix. + Models the mid-prefill trough: another request is running, the instant + reading dipped after a _sync_and_clear_cache, but recent_peak still + reflects the real in-flight footprint. Numbers are picked so low + peak + fits (an instant-only read would admit) but recent_peak + peak exceeds + the hard limit. This pins the high-water fold. """ hard_limit = 100 * GB peak = 20 * GB low = 10 * GB high = 85 * GB - # Sanity: old code (low + peak) would have passed. + # Sanity: an instant-only read (low + peak) would have passed. assert low + peak <= hard_limit - # New code (high + peak) must exceed the limit. + # Folding recent_peak (high + peak) must exceed the limit. assert high + peak > hard_limit s = _preflight_scheduler( - hard_limit=hard_limit, recent_peak=high, peak=peak + hard_limit=hard_limit, + recent_peak=high, + peak=peak, + running={"r-other": object()}, ) with patch("omlx.scheduler.mx") as mock_mx, patch( "omlx.scheduler.get_phys_footprint", return_value=low @@ -158,16 +175,47 @@ def test_rejects_on_recent_peak_when_instant_is_low(self): result = s._preflight_memory_check(_preflight_request()) assert result is not None - assert "Prefill would require" in result + assert "Prefill would need" in result def test_admits_when_recent_peak_also_low(self): - """Control: when recent_peak is low too, the request passes.""" + """Control: in-flight, recent_peak low too -> the request passes.""" + hard_limit = 100 * GB + peak = 20 * GB + low = 10 * GB + + s = _preflight_scheduler( + hard_limit=hard_limit, + recent_peak=low, + peak=peak, + running={"r-other": object()}, + ) + with patch("omlx.scheduler.mx") as mock_mx, patch( + "omlx.scheduler.get_phys_footprint", return_value=low + ): + mock_mx.get_active_memory.return_value = low + result = s._preflight_memory_check(_preflight_request()) + + assert result is None + + def test_lone_request_ignores_stale_recent_peak(self): + """Idle (no in-flight requests): a high recent_peak is NOT folded in, + so a lone request that fits on the instant reading is admitted. + + recent_peak when idle is a stale prior-batch high; folding it would + false-reject a request that physically fits. Same numbers as the + in-flight reject test, only running is empty -> opposite outcome. + """ hard_limit = 100 * GB peak = 20 * GB low = 10 * GB + high = 85 * GB + # Folding recent_peak would exceed the limit (the in-flight case)... + assert high + peak > hard_limit + # ...but idle, only the instant reading counts, which fits. + assert low + peak <= hard_limit s = _preflight_scheduler( - hard_limit=hard_limit, recent_peak=low, peak=peak + hard_limit=hard_limit, recent_peak=high, peak=peak, running={} ) with patch("omlx.scheduler.mx") as mock_mx, patch( "omlx.scheduler.get_phys_footprint", return_value=low diff --git a/tests/test_scheduler_budget_admission.py b/tests/test_scheduler_budget_admission.py new file mode 100644 index 000000000..254dc8f63 --- /dev/null +++ b/tests/test_scheduler_budget_admission.py @@ -0,0 +1,360 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the entry-point memory-budget admission and its relationship to the +per-chunk forward gate. + +The shared predicate _predicted_prefill_peak_bytes drives two outcomes: + - _schedule_waiting DEFERS (re-queues, does NOT reject) a request that would + breach the hard cap only because other requests are in-flight -- adaptive + concurrency. + - _preflight_memory_check REJECTS a lone request that cannot fit even alone. + +_prefill_forward_gate stays as a concurrent-drift backstop. Because admission +estimates the FULL prompt while the gate estimates a single CHUNK (and both use +the same margin), a correctly-admitted request cannot trip the gate under static +memory; the gate only fires when memory drifts up after admission. +""" + +from collections import deque +from unittest.mock import MagicMock, patch + +import pytest + +from omlx.memory_monitor import MemoryMonitor +from omlx.scheduler import Scheduler + +GB = 1024**3 + + +def _admission_scheduler( + *, + hard_limit, + recent_peak, + estimate, + margin, + running, + prefilling=None, + guard=True, + monitor=True, +): + """Bare Scheduler wired for _predicted_prefill_peak_bytes / admission.""" + s = Scheduler.__new__(Scheduler) + s._prefill_memory_guard = guard + s._memory_hard_limit_bytes = hard_limit + s._memory_recent_peak_bytes = recent_peak + s._prefill_transient_margin_bytes = margin + s.running = running + s.prefilling = prefilling if prefilling is not None else deque() + s.config = MagicMock(prefill_step_size=2048) + if monitor: + s.memory_monitor = MagicMock() + s.memory_monitor.estimate_prefill_peak_bytes = MagicMock( + return_value=estimate + ) + else: + s.memory_monitor = None + return s + + +def _request(*, num_prompt_tokens=8192, cached_tokens=0): + r = MagicMock() + r.request_id = "rid-1" + r.num_prompt_tokens = num_prompt_tokens + r.cached_tokens = cached_tokens + return r + + +class TestPredictedPrefillPeakBytes: + """Direct tests of the shared admission predicate.""" + + def test_sums_current_estimate_and_margin(self): + s = _admission_scheduler( + hard_limit=100 * GB, + recent_peak=0, + estimate=20 * GB, + margin=12 * GB, + running={}, + ) + with patch("omlx.scheduler.mx") as mx, patch( + "omlx.scheduler.get_phys_footprint", return_value=30 * GB + ): + mx.get_active_memory.return_value = 30 * GB + predicted = s._predicted_prefill_peak_bytes(_request()) + assert predicted == 30 * GB + 20 * GB + 12 * GB + + def test_folds_recent_peak_when_running_or_prefilling(self): + """In-flight (running OR prefilling) -> recent_peak high-water folded; + fully idle -> instant reading only.""" + common = dict( + hard_limit=100 * GB, recent_peak=90 * GB, estimate=10 * GB, margin=0 + ) + s_idle = _admission_scheduler(**common, running={}, prefilling=deque()) + s_running = _admission_scheduler(**common, running={"o": object()}) + s_prefilling = _admission_scheduler( + **common, running={}, prefilling=deque([object()]) + ) + with patch("omlx.scheduler.mx") as mx, patch( + "omlx.scheduler.get_phys_footprint", return_value=10 * GB + ): + mx.get_active_memory.return_value = 10 * GB + idle = s_idle._predicted_prefill_peak_bytes(_request()) + running = s_running._predicted_prefill_peak_bytes(_request()) + prefilling = s_prefilling._predicted_prefill_peak_bytes(_request()) + assert idle == 10 * GB + 10 * GB # instant reading only + assert running == 90 * GB + 10 * GB # folded: a decode is in flight + assert prefilling == 90 * GB + 10 * GB # folded: a chunked prefill too + + def test_cached_tokens_reduce_new_tokens(self): + s = _admission_scheduler( + hard_limit=100 * GB, + recent_peak=0, + estimate=5 * GB, + margin=0, + running={}, + ) + with patch("omlx.scheduler.mx") as mx, patch( + "omlx.scheduler.get_phys_footprint", return_value=0 + ): + mx.get_active_memory.return_value = 0 + s._predicted_prefill_peak_bytes( + _request(num_prompt_tokens=8192, cached_tokens=8000) + ) + # estimate is asked for the 192 UNCACHED tokens, not the full 8192. + s.memory_monitor.estimate_prefill_peak_bytes.assert_called_once_with( + 192, 2048 + ) + + def test_none_when_guard_off(self): + s = _admission_scheduler( + hard_limit=100 * GB, + recent_peak=0, + estimate=5 * GB, + margin=0, + running={}, + guard=False, + ) + assert s._predicted_prefill_peak_bytes(_request()) is None + + def test_none_when_hard_limit_unset(self): + s = _admission_scheduler( + hard_limit=0, + recent_peak=0, + estimate=5 * GB, + margin=0, + running={}, + ) + assert s._predicted_prefill_peak_bytes(_request()) is None + + def test_none_when_monitor_missing(self): + s = _admission_scheduler( + hard_limit=100 * GB, + recent_peak=0, + estimate=5 * GB, + margin=0, + running={}, + monitor=False, + ) + assert s._predicted_prefill_peak_bytes(_request()) is None + + def test_none_when_no_new_tokens(self): + s = _admission_scheduler( + hard_limit=100 * GB, + recent_peak=0, + estimate=5 * GB, + margin=0, + running={}, + ) + req = _request(num_prompt_tokens=4096, cached_tokens=4096) + assert s._predicted_prefill_peak_bytes(req) is None + + def test_none_when_estimate_zero(self): + s = _admission_scheduler( + hard_limit=100 * GB, + recent_peak=0, + estimate=0, + margin=0, + running={}, + ) + with patch("omlx.scheduler.mx") as mx, patch( + "omlx.scheduler.get_phys_footprint", return_value=0 + ): + mx.get_active_memory.return_value = 0 + assert s._predicted_prefill_peak_bytes(_request()) is None + + +class TestScheduleWaitingBudgetDefer: + """The defer branch in _schedule_waiting QUEUES, it does not reject.""" + + def _defer_scheduler( + self, *, hard_limit, recent_peak, estimate, margin, running=None, prefilling=None + ): + s = _admission_scheduler( + hard_limit=hard_limit, + recent_peak=recent_peak, + estimate=estimate, + margin=margin, + running=running if running is not None else {"r-other": object()}, + prefilling=prefilling, + ) + s.config = MagicMock(max_num_seqs=8, prefill_step_size=2048) + s._admission_paused = False + s._memory_limit_bytes = 0 # bypass the coarse generation soft-guard + s.batch_generator = MagicMock() + s._ensure_batch_generator = MagicMock() + return s + + def _waiting_request(self): + req = _request() + req.prompt_cache = None + req.remaining_tokens = None + req.prompt_token_ids = [1, 2, 3] + return req + + def test_defers_and_requeues_when_inflight_would_breach(self): + # predicted = max(10, 95) + 20 + 12 = 127 > 100 -> defer. + s = self._defer_scheduler( + hard_limit=100 * GB, + recent_peak=95 * GB, + estimate=20 * GB, + margin=12 * GB, + ) + req = self._waiting_request() + s.waiting = deque([req]) + + with patch("omlx.scheduler.mx") as mx, patch( + "omlx.scheduler.get_phys_footprint", return_value=10 * GB + ): + mx.get_active_memory.return_value = 10 * GB + scheduled, rejected = s._schedule_waiting() + + # Nothing scheduled, nothing REJECTED -- the request is left queued. + assert scheduled == [] + assert rejected == [] + assert list(s.waiting) == [req] + + def test_defers_when_only_a_chunked_prefill_is_in_flight(self): + """The headline case: running is EMPTY but another request is + mid-chunked-prefill (self.prefilling). The new request must still be + deferred -- gating on self.running alone would stack a second prefill, + which is exactly the documented crash. + """ + # predicted = max(10, 95) + 20 + 12 = 127 > 100 -> defer. + s = self._defer_scheduler( + hard_limit=100 * GB, + recent_peak=95 * GB, + estimate=20 * GB, + margin=12 * GB, + running={}, # nothing decoding... + prefilling=deque([object()]), # ...but a chunked prefill is in flight + ) + req = self._waiting_request() + s.waiting = deque([req]) + + with patch("omlx.scheduler.mx") as mx, patch( + "omlx.scheduler.get_phys_footprint", return_value=10 * GB + ): + mx.get_active_memory.return_value = 10 * GB + scheduled, rejected = s._schedule_waiting() + + assert scheduled == [] + assert rejected == [] + assert list(s.waiting) == [req] # queued behind the in-flight prefill + + def test_does_not_defer_when_predicted_fits(self): + # predicted = max(10, 30) + 20 + 12 = 62 <= 100 -> the budget branch + # does not fire. Asserted at the predicate so we do not have to drive + # the heavy insert path: a fitting request is never re-queued by the + # budget defer. + s = self._defer_scheduler( + hard_limit=100 * GB, + recent_peak=30 * GB, + estimate=20 * GB, + margin=12 * GB, + ) + with patch("omlx.scheduler.mx") as mx, patch( + "omlx.scheduler.get_phys_footprint", return_value=10 * GB + ): + mx.get_active_memory.return_value = 10 * GB + predicted = s._predicted_prefill_peak_bytes(self._waiting_request()) + assert predicted is not None + assert predicted <= s._memory_hard_limit_bytes + + +class TestAdmissionDominatesGate: + """Property: a request admitted by the entry budget cannot trip the + per-chunk forward gate under static memory; the gate is a drift backstop.""" + + def _monitor(self): + m = MemoryMonitor(max_kv_cache_memory=64 * GB) + m.set_model_info( + num_layers=32, + num_kv_heads=8, + head_dim=128, + dtype_size=2, + num_attention_heads=32, + ) + return m + + def _scheduler_with(self, monitor, *, cap, current, margin): + s = Scheduler.__new__(Scheduler) + s._prefill_memory_guard = True + s._memory_hard_limit_bytes = cap + s._memory_recent_peak_bytes = current + s._prefill_transient_margin_bytes = margin + s.running = {"r-other": object()} + s.prefilling = deque() + s.config = MagicMock(prefill_step_size=2048) + s.memory_monitor = monitor + return s + + def test_admitted_request_never_trips_gate_under_static_memory(self): + monitor = self._monitor() + full_prompt, chunk, step = 8192, 256, 2048 + admission_estimate = monitor.estimate_prefill_peak_bytes(full_prompt, step) + gate_estimate = monitor.estimate_prefill_peak_bytes(chunk, step) + # Real estimate is monotonic in prompt length -> full >= per-chunk. + assert 0 < gate_estimate <= admission_estimate + + current = 80 * GB + margin = 12 * GB + # cap set exactly at the admission boundary: admission JUST passes. + cap = current + admission_estimate + margin + s = self._scheduler_with(monitor, cap=cap, current=current, margin=margin) + + with patch("omlx.scheduler.mx") as mx, patch( + "omlx.scheduler.get_phys_footprint", return_value=current + ): + mx.get_active_memory.return_value = current + predicted = s._predicted_prefill_peak_bytes( + _request(num_prompt_tokens=full_prompt) + ) + assert predicted is not None and predicted <= cap # admitted + + # Same static memory: the gate, on any chunk of this request, must + # NOT raise (gate predicted = current + gate_estimate + margin + # <= current + admission_estimate + margin = cap). + s._prefill_forward_gate( + chunk, request_id="rid-1", loop_label="external" + ) # no RuntimeError + + def test_gate_fires_when_memory_drifts_up_after_admission(self): + monitor = self._monitor() + full_prompt, chunk, step = 8192, 256, 2048 + admission_estimate = monitor.estimate_prefill_peak_bytes(full_prompt, step) + + current = 80 * GB + margin = 12 * GB + cap = current + admission_estimate + margin + s = self._scheduler_with(monitor, cap=cap, current=current, margin=margin) + + # Other in-flight requests grow KV during this prefill: current drifts + # up well past where admission snapshotted it. The gate re-reads and + # catches what the admission snapshot could not. + drifted = current + admission_estimate + with patch("omlx.scheduler.mx") as mx, patch( + "omlx.scheduler.get_phys_footprint", return_value=drifted + ): + mx.get_active_memory.return_value = drifted + with pytest.raises(RuntimeError, match="refused before forward"): + s._prefill_forward_gate( + chunk, request_id="rid-1", loop_label="external" + )