Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions omlx/process_memory_enforcer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import logging
import subprocess
import sys
from collections import deque
from typing import TYPE_CHECKING, Any

import mlx.core as mx
Expand Down Expand Up @@ -291,6 +292,7 @@ def __init__(
hard_threshold: float = 0.95,
prefill_safe_zone_ratio: float = 0.80,
prefill_min_chunk_tokens: int = 32,
prefill_transient_margin_gb: float = 0.0,
):
"""
Initialize the process memory enforcer.
Expand All @@ -317,6 +319,11 @@ def __init__(
prefill_safe_zone_ratio: Fraction of hard cap below which prefill
runs at full chunk size; above triggers adaptive shrink.
prefill_min_chunk_tokens: Floor for adaptive shrink.
prefill_transient_margin_gb: Conservative margin added to the
modelled per-chunk prefill peak by the scheduler's
forward-front gate, covering the MoE expert-dequant activation
spike that estimate_prefill_peak_bytes does not model.
Propagated to each scheduler. 0 = no extra margin.
"""
self._engine_pool = engine_pool
self._memory_guard_tier = self._normalize_tier(memory_guard_tier)
Expand All @@ -331,6 +338,9 @@ def __init__(
self._hard_threshold = hard_threshold
self._prefill_safe_zone_ratio = prefill_safe_zone_ratio
self._prefill_min_chunk_tokens = prefill_min_chunk_tokens
self._prefill_transient_margin_bytes = max(
0, int(prefill_transient_margin_gb * 1024**3)
)
self._task: asyncio.Task | None = None
self._running = False
# Most recently observed pressure level, consumed by scheduler /
Expand All @@ -340,6 +350,13 @@ def __init__(
# or the call failed). Used by the admin dashboard to surface a
# warning when the kernel iogpu.wired_limit_mb is below this.
self._metal_wired_limit_request: int = 0
# Rolling window of recent usage readings + their high-water mark.
# Prefill memory dips into a trough between chunks, so an instantaneous
# reading taken mid-prefill can come out low; preflight admission consults
# this peak instead so it does not admit a request that would drive the
# next prefill chunk into the memory cap. Updated on every poll iteration.
self._usage_window: deque[int] = deque(maxlen=5)
self._recent_peak_bytes: int = 0

@staticmethod
def _normalize_tier(tier: str) -> str:
Expand Down Expand Up @@ -503,6 +520,10 @@ def get_final_ceiling(self) -> int:
"""Public accessor used by engine_pool pre-load admission."""
return self._get_hard_limit_bytes()

def recent_peak_bytes(self) -> int:
    """Return the rolling high-water mark of recent memory usage, in bytes.

    This is the largest reading currently held in the enforcer's short
    usage window, which is refreshed by the enforcement poll. It is 0
    until the first reading has been recorded.
    """
    high_water = self._recent_peak_bytes
    return high_water

def _soft_bytes(self) -> int:
"""Soft watermark: ceiling * soft_threshold."""
ceiling = self._get_hard_limit_bytes()
Expand Down Expand Up @@ -589,6 +610,10 @@ def _propagate_memory_limit(self) -> None:
scheduler._admission_paused = admission_paused
scheduler._prefill_safe_zone_ratio = self._prefill_safe_zone_ratio
scheduler._prefill_min_chunk_tokens = self._prefill_min_chunk_tokens
scheduler._prefill_transient_margin_bytes = (
self._prefill_transient_margin_bytes
)
scheduler._memory_recent_peak_bytes = self._recent_peak_bytes
bg = getattr(scheduler, "batch_generator", None)
if bg is not None and hasattr(bg, "_memory_limit_bytes"):
bg._memory_limit_bytes = soft_limit
Expand Down Expand Up @@ -671,6 +696,8 @@ async def _check_and_enforce(self) -> None:
return

current = self._current_usage_bytes()
self._usage_window.append(current)
self._recent_peak_bytes = max(self._usage_window) if self._usage_window else current
soft = int(ceiling * self._soft_threshold)
hard = int(ceiling * self._hard_threshold)
prev_level = self._pressure_level
Expand Down
130 changes: 129 additions & 1 deletion omlx/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -796,10 +796,23 @@ def __init__(
# soft_threshold. Schedulers stop admitting new prefills while this is
# set; in-flight requests proceed.
self._admission_paused: bool = False
# Recent high-water memory usage, propagated from ProcessMemoryEnforcer.
# Preflight admission maxes the instant reading against this so it does
# not wave through a request during a prefill trough that would wall
# the next chunk. 0 until the enforcer sets it.
self._memory_recent_peak_bytes: int = 0
# Adaptive prefill throttle params, propagated from enforcer.
# Until set, _adaptive_chunk_size is a no-op (returns requested as-is).
self._prefill_safe_zone_ratio: float = 0.80
self._prefill_min_chunk_tokens: int = 32
# Conservative transient margin (bytes) added to the modelled per-chunk
# prefill peak by the forward-front gate (_prefill_forward_gate).
# estimate_prefill_peak_bytes only models KV + SDPA; it does NOT model
# the MoE expert-dequant activation spike, which on glm4.5-air-106b
# (MoE) is the dominant single-step transient. Sized from the observed
# worst-case single-step current jump on m5max (see MemorySettings
# .prefill_transient_margin_gb). 0 until the enforcer propagates it.
self._prefill_transient_margin_bytes: int = 0
# EWMA estimator of per-token chunk transient bytes, used by
# _adaptive_chunk_size in the caution zone. Owned per-scheduler.
_tracker_model_id = ""
Expand Down Expand Up @@ -1729,6 +1742,97 @@ def _apply_turboquant_kv_convert(self, prompt_cache: list[Any]) -> None:
f"cache layers to {bits}-bit{skip_msg}"
)

def _prefill_forward_gate(
self,
chunk_tokens: int,
*,
request_id: str,
loop_label: str,
) -> None:
"""Forward-FRONT memory gate: refuse a prefill chunk BEFORE it runs.

The chunk-end check (after self.model(...) + mx.eval) only fires once
the transient has already been allocated -- on Apple Silicon a chunk
that overshoots the Metal cap kernel-panics the whole machine, so an
after-the-fact Python check never runs. This predicts the next chunk's
peak and raises BEFORE the forward when it would exceed the hard cap,
so the request is aborted cleanly (the #1405 cleanup paths convert this
RuntimeError into a finish_reason="error" output) instead of crashing.

predicted_peak = current(high-water) + estimate(KV+SDPA) + margin
- current: max(active, phys_footprint, recent_peak). recent_peak is
the enforcer's rolling high-water mark, so a mid-prefill trough in
the instant reading does not mask the real footprint.
- estimate: memory_monitor.estimate_prefill_peak_bytes models this
chunk's KV + SDPA activation only.
- margin: _prefill_transient_margin_bytes covers what the estimate
does NOT model -- chiefly the MoE expert-dequant activation spike,
the dominant single-step transient on MoE models like glm4.5-air.

No-op (returns) when the guard is off, the hard limit is unset, the
memory_monitor is missing, or the estimate is unavailable (0) -- in
every such case the legacy chunk-end check remains the only line of
defense, exactly as before this gate existed.

At chunk granularity the KV+SDPA estimate is small, so the margin is
the effective trip point (gate fires once current ~> cap - margin).
Correctness depends on `current` reflecting the true high-water mark:
this iteration runs right after the previous chunk's
_sync_and_clear_cache, when mx.get_active_memory() troughs, so the
gate leans on get_phys_footprint() + the propagated recent_peak to
avoid reading a trough. See MemorySettings.prefill_transient_margin_gb.

Raises:
RuntimeError: when the predicted peak exceeds the hard limit.
"""
if not self._prefill_memory_guard:
return
if self._memory_hard_limit_bytes <= 0:
return
if self.memory_monitor is None:
return
if chunk_tokens <= 0:
return

estimate = self.memory_monitor.estimate_prefill_peak_bytes(
chunk_tokens, self.config.prefill_step_size
)
if estimate == 0:
return # can't estimate this model -> leave it to the chunk-end check

predicted_transient = estimate + self._prefill_transient_margin_bytes
current = max(
mx.get_active_memory(),
get_phys_footprint(),
self._memory_recent_peak_bytes,
)
predicted_peak = current + predicted_transient

if predicted_peak > self._memory_hard_limit_bytes:
logger.warning(
"[memgate:%s] rid=%s refusing prefill chunk (n=%d) BEFORE "
"forward: predicted peak %.3fGB = current %.3fGB + transient "
"%.3fGB (estimate %.3fGB + margin %.3fGB) exceeds hard cap "
"%.3fGB. Aborting request to avoid a Metal-cap kernel panic.",
loop_label,
request_id,
chunk_tokens,
predicted_peak / 1024**3,
current / 1024**3,
predicted_transient / 1024**3,
estimate / 1024**3,
self._prefill_transient_margin_bytes / 1024**3,
self._memory_hard_limit_bytes / 1024**3,
)
raise RuntimeError(
"Prefill refused before forward: predicted peak "
f"{predicted_peak / 1024**3:.1f}GB (current "
f"{current / 1024**3:.1f}GB + transient "
f"{predicted_transient / 1024**3:.1f}GB) would exceed the "
f"memory ceiling {self._memory_hard_limit_bytes / 1024**3:.1f}GB. "
"Reduce context length or increase --max-process-memory."
)

def _do_external_prefill(
self,
request: "Request",
Expand Down Expand Up @@ -1885,6 +1989,15 @@ def _do_external_prefill(
extra_kwargs, n_to_process
)

# Forward-FRONT gate: predict this chunk's peak and refuse BEFORE
# the forward if it would breach the Metal cap (post-forward checks
# cannot save us -- the overshoot kernel-panics the machine).
self._prefill_forward_gate(
n_to_process,
request_id=request.request_id,
loop_label="external",
)

_throttle_pre = get_phys_footprint()
self.model(input_arr[:, :n_to_process], cache=prompt_cache, **model_kwargs)
mx.eval([c.state for c in prompt_cache])
Expand Down Expand Up @@ -2223,6 +2336,17 @@ def _step_prefill_chunk(self, state: _PrefillState) -> bool:

chunk = state.tokens_remaining[:, :n]
state.tokens_remaining = state.tokens_remaining[:, n:]

# Forward-FRONT gate: predict this chunk's peak and refuse BEFORE the
# forward if it would breach the Metal cap. Mirrors the external loop;
# raises RuntimeError that _advance_chunked_prefills converts into a
# finish_reason="error" output without crashing the machine.
self._prefill_forward_gate(
n,
request_id=state.request.request_id,
loop_label="chunked_step",
)

_throttle_pre = get_phys_footprint()
self.model(chunk, cache=state.cache)
mx.eval([c.state for c in state.cache])
Expand Down Expand Up @@ -4541,7 +4665,11 @@ def _preflight_memory_check(self, request: "Request") -> str | None:
if peak == 0:
return None # can't estimate, skip

current = max(mx.get_active_memory(), get_phys_footprint())
current = max(
mx.get_active_memory(),
get_phys_footprint(),
self._memory_recent_peak_bytes,
)

if current + peak > self._memory_hard_limit_bytes:
from .utils.hardware import format_bytes
Expand Down
1 change: 1 addition & 0 deletions omlx/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,7 @@ async def lifespan(app: FastAPI):
hard_threshold=mem_cfg.hard_threshold,
prefill_safe_zone_ratio=mem_cfg.prefill_safe_zone_ratio,
prefill_min_chunk_tokens=mem_cfg.prefill_min_chunk_tokens,
prefill_transient_margin_gb=mem_cfg.prefill_transient_margin_gb,
)
_server_state.process_memory_enforcer = enforcer
_server_state.engine_pool._process_memory_enforcer = enforcer
Expand Down
31 changes: 31 additions & 0 deletions omlx/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,33 @@ class MemorySettings:
# aborted via the same cleanup path the hard-limit RuntimeError uses.
prefill_safe_zone_ratio: float = 0.80
prefill_min_chunk_tokens: int = 32
# Conservative transient margin added to the modelled per-chunk prefill
# peak by the scheduler's forward-FRONT memory gate. The model in
# memory_monitor.estimate_prefill_peak_bytes only accounts for KV + SDPA;
# it does NOT model the MoE expert-dequant activation spike, which on a
# MoE model (glm4.5-air-106b) is the dominant single-step transient. The
# gate refuses a chunk before its forward when current + estimate + this
# margin would breach the hard cap, so the transient never actually lands
# on the Metal ceiling (which would kernel-panic the whole machine; an
# after-the-fact Python check cannot catch that).
#
# At chunk granularity the KV+SDPA estimate is tiny (~0.3GB for a
# 256-token GLM chunk), so this margin IS the safety mechanism. The
# load-bearing guarantee is: margin > the worst-case single-step memory
# jump. Across the 2026-06-06 m5max glm4.5-air-106b crash log the max
# trough->peak single-step delta was 7.44GB (peak overshoot reached
# 110.4GB vs a 107.5GB cap). With margin=10GB the gate effectively fires
# once current exceeds ~cap - margin (~97GB); since every observed forward
# jumps <=7.44GB < 10GB, any chunk that would land above the cap starts
# from a current already past that trip point, so the gate refuses it
# before the forward. The 10GB default is ceil(7.44) = 8GB plus 2GB of
# padding for an unobserved larger spike. NOTE: this holds only while
# `current` is read at true high-water
# at gate time (it relies on phys_footprint stickiness + recent_peak to
# mask the post-_sync_and_clear_cache trough in mx.get_active_memory());
# watch the [memgate]/[memcheck] logs on hardware and raise the margin if
# a sub-trip-point trough precedes an over-cap chunk. Set to 0 to disable
# the extra margin (the gate then uses the bare KV+SDPA estimate).
prefill_transient_margin_gb: float = 10.0

def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary."""
Expand All @@ -402,6 +429,7 @@ def to_dict(self) -> dict[str, Any]:
"hard_threshold": self.hard_threshold,
"prefill_safe_zone_ratio": self.prefill_safe_zone_ratio,
"prefill_min_chunk_tokens": self.prefill_min_chunk_tokens,
"prefill_transient_margin_gb": self.prefill_transient_margin_gb,
}

@classmethod
Expand Down Expand Up @@ -440,6 +468,9 @@ def from_dict(cls, data: dict[str, Any]) -> MemorySettings:
prefill_min_chunk_tokens=int(
data.get("prefill_min_chunk_tokens", 32)
),
prefill_transient_margin_gb=float(
data.get("prefill_transient_margin_gb", 10.0)
),
)


Expand Down
57 changes: 57 additions & 0 deletions tests/test_process_memory_enforcer.py
Original file line number Diff line number Diff line change
Expand Up @@ -965,6 +965,63 @@ async def test_check_and_enforce_walks_caps_on_soft(self, enforcer):
scheduler.adjust_store_cache_cap.assert_called_with("soft")


class TestRecentPeakTracking:
    """Recent-peak (rolling high-water) tracking across enforcer poll ticks."""

    @pytest.mark.asyncio
    async def test_recent_peak_is_window_max(self, enforcer):
        """The reported peak equals the largest reading fed over the ticks.

        The window update sits after the ceiling check's early return in
        _check_and_enforce, so the fixture's positive ceiling (10 GB) is
        needed for the update to run at all.
        """
        gib = 1024**3
        samples = [3 * gib, 5 * gib, 2 * gib, 4 * gib]
        with patch("omlx.process_memory_enforcer.mx") as mock_mx:
            with patch(
                "omlx.process_memory_enforcer.get_phys_footprint", return_value=0
            ):
                mock_mx.get_active_memory.side_effect = _cycling(samples)
                # One poll per sample.
                for _ in samples:
                    await enforcer._check_and_enforce()

        assert enforcer.recent_peak_bytes() == 5 * gib

    @pytest.mark.asyncio
    async def test_recent_peak_drops_after_window_slides(self, enforcer):
        """A high reading stops counting once it leaves the maxlen=5 window."""
        gib = 1024**3
        high = 9 * gib
        low = 1 * gib
        # One high reading followed by five low ones: the sixth tick evicts
        # the high reading from the 5-slot window.
        samples = [high, low, low, low, low, low]
        with patch("omlx.process_memory_enforcer.mx") as mock_mx:
            with patch(
                "omlx.process_memory_enforcer.get_phys_footprint", return_value=0
            ):
                mock_mx.get_active_memory.side_effect = _cycling(samples)
                for _ in samples:
                    await enforcer._check_and_enforce()

        # Only low readings remain in the window after six ticks.
        assert enforcer.recent_peak_bytes() == low

    def test_propagates_recent_peak_to_scheduler(self, enforcer):
        """_propagate_memory_limit copies the recent peak onto each scheduler."""
        expected_peak = 7 * 1024**3

        sched = MagicMock(spec=[])
        sched._memory_limit_bytes = 0
        sched._memory_hard_limit_bytes = 0
        sched._memory_recent_peak_bytes = 0
        eng = MagicMock(spec=[])
        eng.scheduler = sched
        pool_entry = _make_entry("model-a", engine=eng)
        enforcer._engine_pool._entries = {"model-a": pool_entry}

        enforcer._recent_peak_bytes = expected_peak
        enforcer._propagate_memory_limit()

        assert sched._memory_recent_peak_bytes == expected_peak


class TestProperties:
"""Tests for enforcer properties."""

Expand Down
Loading