Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions omlx/process_memory_enforcer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import logging
import subprocess
import sys
from collections import deque
from typing import TYPE_CHECKING, Any

import mlx.core as mx
Expand Down Expand Up @@ -291,6 +292,7 @@ def __init__(
hard_threshold: float = 0.95,
prefill_safe_zone_ratio: float = 0.80,
prefill_min_chunk_tokens: int = 32,
prefill_transient_margin_gb: float = 0.0,
):
"""
Initialize the process memory enforcer.
Expand All @@ -317,6 +319,11 @@ def __init__(
prefill_safe_zone_ratio: Fraction of hard cap below which prefill
runs at full chunk size; above triggers adaptive shrink.
prefill_min_chunk_tokens: Floor for adaptive shrink.
prefill_transient_margin_gb: Conservative margin added to the
modelled per-chunk prefill peak by the scheduler's
forward-front gate, covering the MoE expert-dequant activation
spike that estimate_prefill_peak_bytes does not model.
Propagated to each scheduler. 0 = no extra margin.
"""
self._engine_pool = engine_pool
self._memory_guard_tier = self._normalize_tier(memory_guard_tier)
Expand All @@ -331,6 +338,9 @@ def __init__(
self._hard_threshold = hard_threshold
self._prefill_safe_zone_ratio = prefill_safe_zone_ratio
self._prefill_min_chunk_tokens = prefill_min_chunk_tokens
self._prefill_transient_margin_bytes = max(
0, int(prefill_transient_margin_gb * 1024**3)
)
self._task: asyncio.Task | None = None
self._running = False
# Most recently observed pressure level, consumed by scheduler /
Expand All @@ -340,6 +350,13 @@ def __init__(
# or the call failed). Used by the admin dashboard to surface a
# warning when the kernel iogpu.wired_limit_mb is below this.
self._metal_wired_limit_request: int = 0
# Rolling window of recent usage readings + their high-water mark.
# Prefill memory dips into a trough between chunks, so the instant
# reading can read low mid-prefill; preflight admission consults this
# peak instead so it does not wave through a request that will wall
# the next chunk. Updated on every poll iteration.
self._usage_window: deque[int] = deque(maxlen=5)
self._recent_peak_bytes: int = 0

@staticmethod
def _normalize_tier(tier: str) -> str:
Expand Down Expand Up @@ -503,6 +520,10 @@ def get_final_ceiling(self) -> int:
"""Public accessor used by engine_pool pre-load admission."""
return self._get_hard_limit_bytes()

def recent_peak_bytes(self) -> int:
"""Recent high-water memory usage over the last few poll ticks."""
return self._recent_peak_bytes

def _soft_bytes(self) -> int:
"""Soft watermark: ceiling * soft_threshold."""
ceiling = self._get_hard_limit_bytes()
Expand Down Expand Up @@ -589,6 +610,10 @@ def _propagate_memory_limit(self) -> None:
scheduler._admission_paused = admission_paused
scheduler._prefill_safe_zone_ratio = self._prefill_safe_zone_ratio
scheduler._prefill_min_chunk_tokens = self._prefill_min_chunk_tokens
scheduler._prefill_transient_margin_bytes = (
self._prefill_transient_margin_bytes
)
scheduler._memory_recent_peak_bytes = self._recent_peak_bytes
bg = getattr(scheduler, "batch_generator", None)
if bg is not None and hasattr(bg, "_memory_limit_bytes"):
bg._memory_limit_bytes = soft_limit
Expand Down Expand Up @@ -671,6 +696,8 @@ async def _check_and_enforce(self) -> None:
return

current = self._current_usage_bytes()
self._usage_window.append(current)
self._recent_peak_bytes = max(self._usage_window) if self._usage_window else current
soft = int(ceiling * self._soft_threshold)
hard = int(ceiling * self._hard_threshold)
prev_level = self._pressure_level
Expand Down
281 changes: 252 additions & 29 deletions omlx/scheduler.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions omlx/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,7 @@ async def lifespan(app: FastAPI):
hard_threshold=mem_cfg.hard_threshold,
prefill_safe_zone_ratio=mem_cfg.prefill_safe_zone_ratio,
prefill_min_chunk_tokens=mem_cfg.prefill_min_chunk_tokens,
prefill_transient_margin_gb=mem_cfg.prefill_transient_margin_gb,
)
_server_state.process_memory_enforcer = enforcer
_server_state.engine_pool._process_memory_enforcer = enforcer
Expand Down
35 changes: 35 additions & 0 deletions omlx/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,37 @@ class MemorySettings:
# aborted via the same cleanup path the hard-limit RuntimeError uses.
prefill_safe_zone_ratio: float = 0.80
prefill_min_chunk_tokens: int = 32
# Conservative transient margin added to the modelled prefill peak by BOTH
# memory guards: the entry-point budget admission (primary -- scheduler.
# _predicted_prefill_peak_bytes, decides admit/queue/reject at admission
# time) and the forward-FRONT chunk gate (backstop -- scheduler.
# _prefill_forward_gate, refuses a chunk before its forward). The model in
# memory_monitor.estimate_prefill_peak_bytes only accounts for KV + SDPA;
# it does NOT model the MoE expert-dequant activation spike, which on a MoE
# model (glm4.5-air-106b) is the dominant single-step transient. Either
# guard refuses when current + estimate + this margin would breach the hard
# cap, so the transient never actually lands on the Metal ceiling (which
# would kernel-panic the whole machine -- an after-the-fact Python check
# cannot catch it).
#
# The load-bearing guarantee is: margin > the worst-case single-step memory
# jump. The jump is SUB-POLL -- it rises and falls faster than the
# enforcer's 1s sample, so it is invisible to every memory read (active,
# phys_footprint, recent_peak) by construction. It therefore MUST be carried
# by this margin, not by reading the footprint more cleverly. Across the
# 2026-06-06 m5max glm4.5-air-106b crash log the max trough->peak single-step
# delta was 7.44GB and the peak overshoot reached 110.4GB vs a 107.5GB cap,
# i.e. an effective transient up to ~10.6GB above the pre-step baseline.
# margin=10 was too small (10 < 10.6 -> admitted a step that then breached);
# 12 = ceil(10.6) padded, the value that would have refused that step from
# its true pre-step baseline. Both guards read `current` at high-water
# (max active / phys_footprint / -- when requests are in-flight -- the
# enforcer recent_peak); the baseline (KV+weights) is what they see and it
# is the determining factor, the sub-poll spike rides on the margin. Watch
# the [memgate]/[memcheck] logs on hardware and raise the margin if a step
# ever breaches from a baseline below cap - margin. Set to 0 to disable the
# extra margin (the guards then use the bare KV+SDPA estimate).
prefill_transient_margin_gb: float = 12.0

def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary."""
Expand All @@ -402,6 +433,7 @@ def to_dict(self) -> dict[str, Any]:
"hard_threshold": self.hard_threshold,
"prefill_safe_zone_ratio": self.prefill_safe_zone_ratio,
"prefill_min_chunk_tokens": self.prefill_min_chunk_tokens,
"prefill_transient_margin_gb": self.prefill_transient_margin_gb,
}

@classmethod
Expand Down Expand Up @@ -440,6 +472,9 @@ def from_dict(cls, data: dict[str, Any]) -> MemorySettings:
prefill_min_chunk_tokens=int(
data.get("prefill_min_chunk_tokens", 32)
),
prefill_transient_margin_gb=float(
data.get("prefill_transient_margin_gb", 12.0)
),
)


Expand Down
57 changes: 57 additions & 0 deletions tests/test_process_memory_enforcer.py
Original file line number Diff line number Diff line change
Expand Up @@ -965,6 +965,63 @@ async def test_check_and_enforce_walks_caps_on_soft(self, enforcer):
scheduler.adjust_store_cache_cap.assert_called_with("soft")


class TestRecentPeakTracking:
"""Tests for recent-peak high-water tracking across poll ticks."""

@pytest.mark.asyncio
async def test_recent_peak_is_window_max(self, enforcer):
"""After several poll ticks, recent_peak == max over the window.

The window update lives after the ceiling > 0 early return in
_check_and_enforce, so the fixture's positive ceiling (10 GB) is
required for the update to run at all.
"""
readings = [3 * 1024**3, 5 * 1024**3, 2 * 1024**3, 4 * 1024**3]
with patch("omlx.process_memory_enforcer.mx") as mock_mx, patch(
"omlx.process_memory_enforcer.get_phys_footprint", return_value=0
):
mock_mx.get_active_memory.side_effect = _cycling(readings)
for _ in readings:
await enforcer._check_and_enforce()

assert enforcer.recent_peak_bytes() == 5 * 1024**3

@pytest.mark.asyncio
async def test_recent_peak_drops_after_window_slides(self, enforcer):
"""Old high readings age out once they leave the maxlen=5 window."""
# Feed one big reading, then enough small ones to push it out of the
# 5-slot window.
big = 9 * 1024**3
small = 1 * 1024**3
readings = [big, small, small, small, small, small]
with patch("omlx.process_memory_enforcer.mx") as mock_mx, patch(
"omlx.process_memory_enforcer.get_phys_footprint", return_value=0
):
mock_mx.get_active_memory.side_effect = _cycling(readings)
for _ in readings:
await enforcer._check_and_enforce()

# After 6 ticks the first (big) reading has slid out of the window,
# leaving only small readings.
assert enforcer.recent_peak_bytes() == small

def test_propagates_recent_peak_to_scheduler(self, enforcer):
"""_propagate_memory_limit pushes recent_peak onto each scheduler."""
scheduler = MagicMock(spec=[])
scheduler._memory_limit_bytes = 0
scheduler._memory_hard_limit_bytes = 0
scheduler._memory_recent_peak_bytes = 0
engine = MagicMock(spec=[])
engine.scheduler = scheduler
entry = _make_entry("model-a", engine=engine)
enforcer._engine_pool._entries = {"model-a": entry}

enforcer._recent_peak_bytes = 7 * 1024**3
enforcer._propagate_memory_limit()

assert scheduler._memory_recent_peak_bytes == 7 * 1024**3


class TestProperties:
"""Tests for enforcer properties."""

Expand Down
126 changes: 125 additions & 1 deletion tests/test_scheduler_admission.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
"""Tests for scheduler admission control (queue depth cap + admission_paused)."""

from collections import deque
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch

import pytest

from omlx.exceptions import SchedulerQueueFullError
from omlx.scheduler import Scheduler

GB = 1024**3


@pytest.fixture
def scheduler():
Expand Down Expand Up @@ -100,3 +102,125 @@ def test_default_false(self):
s._prefill_memory_guard = False
s._admission_paused = False
assert s._admission_paused is False


def _preflight_scheduler(
hard_limit: int, recent_peak: int, peak: int, *, running=None
):
"""Build a bare Scheduler wired for _preflight_memory_check.

`peak` is the value the (mocked) memory_monitor estimates for the
prefill chunk; `recent_peak` is the propagated high-water mark. `running`
seeds self.running -- _predicted_prefill_peak_bytes folds recent_peak into
`current` only while requests are in-flight, so an empty dict (idle, the
default) means a lone request is judged on the instant reading alone.
Transient margin is set to 0 here so these tests isolate the recent_peak
folding behaviour; margin handling is covered separately.
"""
s = Scheduler.__new__(Scheduler)
s._prefill_memory_guard = True
s._memory_hard_limit_bytes = hard_limit
s._memory_recent_peak_bytes = recent_peak
s._prefill_transient_margin_bytes = 0
s.running = running if running is not None else {}
s.prefilling = deque()
s.config = MagicMock(prefill_step_size=2048)
s.memory_monitor = MagicMock()
s.memory_monitor.estimate_prefill_peak_bytes = MagicMock(return_value=peak)
return s


def _preflight_request():
r = MagicMock()
r.num_prompt_tokens = 8192
r.cached_tokens = 0
return r


class TestPreflightRecentPeak:
"""_preflight_memory_check folds the recent high-water mark into `current`
while requests are in-flight, so it does not wave through a request during a
prefill trough that would wall the next chunk -- but it ignores recent_peak
when idle, so a stale prior-batch peak does not false-reject a lone request.
"""

def test_rejects_on_recent_peak_when_inflight_and_instant_is_low(self):
"""In-flight + instant active/phys low but recent_peak high -> reject.

Models the mid-prefill trough: another request is running, the instant
reading dipped after a _sync_and_clear_cache, but recent_peak still
reflects the real in-flight footprint. Numbers are picked so low + peak
fits (an instant-only read would admit) but recent_peak + peak exceeds
the hard limit. This pins the high-water fold.
"""
hard_limit = 100 * GB
peak = 20 * GB
low = 10 * GB
high = 85 * GB
# Sanity: an instant-only read (low + peak) would have passed.
assert low + peak <= hard_limit
# Folding recent_peak (high + peak) must exceed the limit.
assert high + peak > hard_limit

s = _preflight_scheduler(
hard_limit=hard_limit,
recent_peak=high,
peak=peak,
running={"r-other": object()},
)
with patch("omlx.scheduler.mx") as mock_mx, patch(
"omlx.scheduler.get_phys_footprint", return_value=low
):
mock_mx.get_active_memory.return_value = low
result = s._preflight_memory_check(_preflight_request())

assert result is not None
assert "Prefill would need" in result

def test_admits_when_recent_peak_also_low(self):
"""Control: in-flight, recent_peak low too -> the request passes."""
hard_limit = 100 * GB
peak = 20 * GB
low = 10 * GB

s = _preflight_scheduler(
hard_limit=hard_limit,
recent_peak=low,
peak=peak,
running={"r-other": object()},
)
with patch("omlx.scheduler.mx") as mock_mx, patch(
"omlx.scheduler.get_phys_footprint", return_value=low
):
mock_mx.get_active_memory.return_value = low
result = s._preflight_memory_check(_preflight_request())

assert result is None

def test_lone_request_ignores_stale_recent_peak(self):
"""Idle (no in-flight requests): a high recent_peak is NOT folded in,
so a lone request that fits on the instant reading is admitted.

recent_peak when idle is a stale prior-batch high; folding it would
false-reject a request that physically fits. Same numbers as the
in-flight reject test, only running is empty -> opposite outcome.
"""
hard_limit = 100 * GB
peak = 20 * GB
low = 10 * GB
high = 85 * GB
# Folding recent_peak would exceed the limit (the in-flight case)...
assert high + peak > hard_limit
# ...but idle, only the instant reading counts, which fits.
assert low + peak <= hard_limit

s = _preflight_scheduler(
hard_limit=hard_limit, recent_peak=high, peak=peak, running={}
)
with patch("omlx.scheduler.mx") as mock_mx, patch(
"omlx.scheduler.get_phys_footprint", return_value=low
):
mock_mx.get_active_memory.return_value = low
result = s._preflight_memory_check(_preflight_request())

assert result is None
Loading