Skip to content
50 changes: 44 additions & 6 deletions omlx/process_memory_enforcer.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,10 @@ def _get_dynamic_ceiling(self) -> int:
def _get_hard_limit_bytes(self) -> int:
"""Final hard ceiling = min(static, dynamic, metal_cap).

Thin wrapper over ``_get_ceiling_breakdown`` that discards the
component breakdown. Hot callers that don't need to know which
ceiling is binding should keep using this helper.

`metal_cap` is the effective Metal allocation cap (kernel
iogpu.wired_limit_mb when set, otherwise Apple's
max_recommended_working_set_size). Including it here means oMLX
Expand All @@ -562,17 +566,37 @@ def _get_hard_limit_bytes(self) -> int:
Returns 0 if the memory guard is disabled (callers treat 0 as
"no limit").
"""
return self._get_ceiling_breakdown()["hard_limit"]

def _get_ceiling_breakdown(self) -> dict[str, int]:
    """Compute the hard limit AND the three component ceilings.

    Used by ``_propagate_memory_limit`` to push the breakdown to
    schedulers so the prefill-rejection error message can identify
    which constraint is binding and suggest the right remedy. Every
    component is computed exactly once per call, so the Metal-cap
    lookup is not repeated.

    Returns:
        A dict with the keys ``static``, ``dynamic``, ``metal_cap`` and
        ``hard_limit``. All four are 0 when the prefill memory guard is
        disabled. Otherwise ``hard_limit`` is the minimum of ``static``,
        ``dynamic`` and, only when it is positive, ``metal_cap``.
        Under the ``custom`` tier ``dynamic`` is the user-pinned custom
        ceiling clamped at 0 instead of the computed dynamic ceiling.
    """
    if not self._prefill_memory_guard:
        # Guard disabled: always hand back the full dict shape so callers
        # can subscript the result unconditionally.
        return {"static": 0, "dynamic": 0, "metal_cap": 0, "hard_limit": 0}

    static_ceiling = self._get_static_ceiling()
    if self._memory_guard_tier == "custom":
        # A negative custom value is clamped to 0 rather than propagated.
        dynamic_ceiling = max(0, self._memory_guard_custom_ceiling_bytes)
    else:
        dynamic_ceiling = self._get_dynamic_ceiling()
    metal_cap = self._get_effective_metal_cap_bytes()

    # NOTE(review): only a non-positive metal_cap is skipped. A zero
    # static or dynamic ceiling still participates in min(), which makes
    # hard_limit 0 — confirm that is the intended "no limit" behavior.
    candidates = [static_ceiling, dynamic_ceiling]
    if metal_cap > 0:
        candidates.append(metal_cap)

    return {
        "static": static_ceiling,
        "dynamic": dynamic_ceiling,
        "metal_cap": metal_cap,
        "hard_limit": min(candidates),
    }

def get_final_ceiling(self) -> int:
"""Public accessor used by engine_pool pre-load admission."""
Expand Down Expand Up @@ -755,7 +779,8 @@ def _propagate_memory_limit(self) -> None:
Called on every enforcer tick so the dynamic ceiling reaches the
schedulers as fast as the poll interval allows.
"""
ceiling = self._get_hard_limit_bytes()
breakdown = self._get_ceiling_breakdown()
ceiling = breakdown["hard_limit"]
soft_limit = int(ceiling * self._soft_threshold) if ceiling > 0 else 0
admission_paused = self._pressure_level != "ok"
for entry in self._engine_pool._entries.values():
Expand Down Expand Up @@ -805,6 +830,19 @@ def _propagate_memory_limit(self) -> None:
scheduler._memory_hard_limit_bytes = ceiling
scheduler._memory_abort_limit_bytes = self._get_abort_limit_bytes()
scheduler._prefill_abort_margin = self._get_prefill_abort_margin()
# Propagate the component ceilings too so the rejection
# message in ``_preflight_memory_check`` can name the binding
# constraint and steer the user toward the right remedy
# (close apps for dynamic, raise sysctl for metal_cap, raise
# tier or reduce context for static).
scheduler._memory_static_ceiling_bytes = breakdown["static"]
scheduler._memory_dynamic_ceiling_bytes = breakdown["dynamic"]
scheduler._memory_metal_cap_bytes = breakdown["metal_cap"]
# Tier name disambiguates dynamic = computed reclaimable
# (safe/balanced/aggressive) from dynamic = user-pinned
# custom_ceiling_bytes (custom). The advice ladder needs
# the distinction to point at the right knob.
scheduler._memory_guard_tier = self._memory_guard_tier
scheduler._prefill_memory_guard = self._prefill_memory_guard
scheduler._admission_paused = admission_paused
scheduler._prefill_safe_zone_ratio = self._prefill_safe_zone_ratio
Expand Down
122 changes: 102 additions & 20 deletions omlx/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1227,6 +1227,18 @@ def __init__(
# executor thread. The background memory enforcer reads this cached
# value during active decode instead of touching MLX/Metal directly.
self._last_mlx_active_memory_bytes: int = 0
# Component ceilings — propagated alongside the hard limit so the
# rejection-path error message can identify which constraint is
# binding and suggest the right remedy (close apps / raise tier /
# raise iogpu.wired_limit_mb / reduce context). 0 = not set yet.
self._memory_static_ceiling_bytes: int = 0
self._memory_dynamic_ceiling_bytes: int = 0
self._memory_metal_cap_bytes: int = 0
# Tier name propagated alongside the breakdown. For ``custom`` the
# "dynamic" ceiling is the user-pinned ``custom_ceiling_bytes``
# rather than computed reclaimable memory, so the advice ladder
# must steer the user to that knob instead of "close other apps".
self._memory_guard_tier: str = "balanced"
self._prefill_memory_guard: bool = False # set by ProcessMemoryEnforcer
# Set to True by ProcessMemoryEnforcer when phys_footprint crosses
# soft_threshold. Schedulers stop admitting new prefills while this is
Expand Down Expand Up @@ -6039,15 +6051,11 @@ def _preflight_memory_check(
)
from .utils.hardware import format_bytes

usage_gb = current / (1024**3)
ceiling_gb = hard_limit / (1024**3)
message = (
f"Prefill would require ~{format_bytes(estimated)} peak "
f"(current {format_bytes(current)} + KV+SDPA {format_bytes(peak)}) "
f"but ceiling is {format_bytes(hard_limit)} "
f"(usage {usage_gb:.1f} GB, ceiling {ceiling_gb:.1f} GB). "
f"Reduce context length, free system memory, or loosen "
f"memory_guard_tier (safe → balanced → aggressive)."
message = self._format_rejection_message(
estimated=estimated,
current=current,
peak=peak,
hard_limit=hard_limit,
)
return _PreflightRejection(
message=message,
Expand All @@ -6074,6 +6082,86 @@ def _preflight_memory_check(
return safety_rejection
return None

def _format_rejection_message(
    self,
    *,
    estimated: int,
    current: int,
    peak: int,
    hard_limit: int,
) -> str:
    """Build the prefill-rejection diagnostic.

    Identifies which of static / dynamic / metal_cap is binding so the
    message can steer the user to the right remedy (close apps for
    dynamic, raise sysctl for metal_cap, raise tier or reduce context
    for static). Component ceilings are propagated by
    ``ProcessMemoryEnforcer._propagate_memory_limit``; if a caller
    wired this scheduler outside that path the components stay 0 and
    we fall back to a generic message.

    Args:
        estimated: Projected peak bytes reported in the message.
        current: Current memory usage in bytes.
        peak: Additional bytes the prefill is expected to need
            (reported as "KV+SDPA").
        hard_limit: The ceiling that was exceeded; a component counts
            as binding when it is non-zero and equal to this value.

    Returns:
        A two-sentence message: the numbers, then the advice.
    """
    from .utils.hardware import format_bytes

    static = self._memory_static_ceiling_bytes
    dynamic = self._memory_dynamic_ceiling_bytes
    metal_cap = self._memory_metal_cap_bytes

    # A component is binding when it was propagated (non-zero) and is
    # exactly the limit we hit; ties list every matching component.
    components = (
        ("static", static),
        ("dynamic", dynamic),
        ("metal_cap", metal_cap),
    )
    binding = [
        name for name, value in components if value and value == hard_limit
    ]
    binding_str = "/".join(binding) if binding else "effective"

    # Order remedies by likelihood of helping for the binding cause.
    # Dynamic-bound on a reclaim tier (safe/balanced/aggressive) means
    # reclaimable memory is low right now even though the static cap
    # has room. Dynamic-bound under ``custom`` means the user pinned the
    # ceiling there, so ``custom_ceiling_bytes`` is the knob. Metal-cap
    # bound means the kernel sysctl is the ceiling. Static-bound (or no
    # breakdown available) leaves ``memory_guard_tier`` / context length
    # as the levers.
    is_custom = self._memory_guard_tier == "custom"
    if "dynamic" in binding and is_custom:
        advice = (
            "raise custom_ceiling_bytes in admin Memory settings "
            f"(currently pinned at {format_bytes(dynamic)}), "
            "or reduce context length"
        )
    elif "dynamic" in binding and static and static > dynamic:
        headroom = max(0, dynamic - current)
        advice = (
            "close other apps to free RAM "
            f"(static cap is {format_bytes(static)} but only "
            f"{format_bytes(headroom)} is reclaimable right now), "
            "raise memory_guard_tier (safe → balanced → aggressive), "
            "or reduce context length"
        )
    elif "metal_cap" in binding:
        advice = (
            "raise kernel iogpu.wired_limit_mb in Terminal "
            f"(currently caps Metal at {format_bytes(metal_cap)}), "
            "or reduce context length"
        )
    else:
        advice = (
            "reduce context length or raise memory_guard_tier "
            "(safe → balanced → aggressive)"
        )

    # Upper-case only the first character. ``str.capitalize()`` also
    # lowercases the rest of the string, which would mangle "RAM",
    # "Memory", "Terminal" and any upper-case unit letters in the
    # formatted byte sizes.
    advice_sentence = advice[:1].upper() + advice[1:]

    return (
        f"Prefill would require ~{format_bytes(estimated)} peak "
        f"(current {format_bytes(current)} + KV+SDPA {format_bytes(peak)}) "
        f"but {binding_str} ceiling is {format_bytes(hard_limit)}. "
        f"{advice_sentence}."
    )

def preflight_or_raise(
self,
*,
Expand Down Expand Up @@ -6117,17 +6205,11 @@ def preflight_or_raise(
request_id = f"preflight-{_uuid.uuid4().hex[:8]}"

if current + peak > self._memory_hard_limit_bytes:
from .utils.hardware import format_bytes

usage_gb = current / (1024**3)
ceiling_gb = self._memory_hard_limit_bytes / (1024**3)
message = (
f"Prefill would require ~{format_bytes(current + peak)} peak "
f"(current {format_bytes(current)} + KV+SDPA {format_bytes(peak)}) "
f"but ceiling is {format_bytes(self._memory_hard_limit_bytes)} "
f"(usage {usage_gb:.1f} GB, ceiling {ceiling_gb:.1f} GB). "
f"Reduce context length, free system memory, or loosen "
f"memory_guard_tier (safe → balanced → aggressive)."
message = self._format_rejection_message(
estimated=current + peak,
current=current,
peak=peak,
hard_limit=self._memory_hard_limit_bytes,
)

logger.warning(
Expand Down
137 changes: 137 additions & 0 deletions tests/test_engine_preflight.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,3 +656,140 @@ def test_scheduler_add_request_cleans_block_table_on_rejection(monkeypatch):
# The request must not have entered self.waiting.
assert req not in scheduler.waiting
assert req.request_id not in scheduler.requests


# ---------------------------------------------------------------------------
# Rejection message identifies the binding ceiling
# ---------------------------------------------------------------------------


class TestRejectionMessageNamesBindingCeiling:
    """A rejection message must name the binding component ceiling.

    Of the three ceilings (static / dynamic / metal_cap), the one equal
    to the hard limit decides which remedy the message suggests. The
    earlier generic text ("reduce context length, free system memory, or
    loosen memory_guard_tier") did not say which knob to turn; notably a
    metal_cap-bound rejection on a host where ``iogpu.wired_limit_mb``
    was never raised cannot be fixed by freeing system memory.
    """

    def _arm_ceilings(
        self,
        sched,
        *,
        static: int,
        dynamic: int,
        metal_cap: int,
        tier: str = "balanced",
    ) -> None:
        """Write the propagated ceiling fields straight onto ``sched``.

        Stands in for ``ProcessMemoryEnforcer._propagate_memory_limit``:
        the binding-aware message reads only these fields plus
        ``_memory_hard_limit_bytes``, which is set here to the smallest
        positive component.
        """
        positive_ceilings = [
            value for value in (static, dynamic, metal_cap) if value > 0
        ]
        sched._prefill_memory_guard = True
        sched._memory_hard_limit_bytes = min(positive_ceilings)
        sched._memory_static_ceiling_bytes = static
        sched._memory_dynamic_ceiling_bytes = dynamic
        sched._memory_metal_cap_bytes = metal_cap
        sched._memory_guard_tier = tier

    def _force_rejection(self, sched, monkeypatch):
        """Drive ``_preflight_memory_check`` into its rejection branch
        and return the rejection so the caller can inspect its message."""
        import omlx.scheduler as scheduler_mod

        # The peak estimate exceeds every ceiling used in this class, so
        # the rejection branch fires deterministically.
        monitor = MagicMock()
        monitor.estimate_prefill_peak_bytes.return_value = 512 * 1024**3
        sched.memory_monitor = monitor

        # Report zero current usage from both memory probes.
        monkeypatch.setattr(scheduler_mod.mx, "get_active_memory", lambda: 0)
        monkeypatch.setattr(scheduler_mod, "get_phys_footprint", lambda: 0)
        # Neutralise the LRU eviction retry that would otherwise run
        # before the rejection is produced.
        monkeypatch.setattr(
            sched,
            "_raise_prefill_eviction_if_available",
            lambda **kw: None,
        )

        req = MagicMock(
            request_id="binding-test",
            num_prompt_tokens=65536,
            cached_tokens=0,
        )
        rej = sched._preflight_memory_check(req)
        assert rej is not None, "rejection branch must fire when peak > ceiling"
        return rej

    def _rejection_for(self, monkeypatch, **ceilings):
        """Create a scheduler, arm it with ``ceilings`` and reject once."""
        sched = _make_scheduler()
        self._arm_ceilings(sched, **ceilings)
        return self._force_rejection(sched, monkeypatch)

    def test_metal_cap_binding_names_sysctl(self, monkeypatch):
        rej = self._rejection_for(
            monkeypatch,
            static=64 * 1024**3,
            dynamic=32 * 1024**3,
            metal_cap=16 * 1024**3,
        )
        assert "iogpu.wired_limit_mb" in rej.message, (
            f"metal_cap binding must steer user to the sysctl knob; got: {rej.message}"
        )
        assert "metal_cap ceiling" in rej.message

    def test_dynamic_binding_under_custom_names_admin_setting(self, monkeypatch):
        rej = self._rejection_for(
            monkeypatch,
            static=64 * 1024**3,
            dynamic=16 * 1024**3,
            metal_cap=48 * 1024**3,
            tier="custom",
        )
        assert "custom_ceiling_bytes" in rej.message, (
            "dynamic binding under custom tier must point at the admin "
            f"Memory setting, not 'close other apps'; got: {rej.message}"
        )
        assert "close other apps" not in rej.message.lower()

    def test_dynamic_binding_under_reclaim_tier_names_apps(self, monkeypatch):
        # Balanced tier with static above dynamic: the message should
        # suggest closing apps and/or raising the tier.
        rej = self._rejection_for(
            monkeypatch,
            static=64 * 1024**3,
            dynamic=16 * 1024**3,
            metal_cap=48 * 1024**3,
            tier="balanced",
        )
        assert "close other apps" in rej.message.lower(), (
            "dynamic binding on a reclaim tier should suggest closing "
            f"apps; got: {rej.message}"
        )
        assert "memory_guard_tier" in rej.message

    def test_static_binding_falls_back_to_generic_advice(self, monkeypatch):
        # Static is the smallest non-zero ceiling here.
        rej = self._rejection_for(
            monkeypatch,
            static=16 * 1024**3,
            dynamic=64 * 1024**3,
            metal_cap=48 * 1024**3,
        )
        assert "memory_guard_tier" in rej.message
        assert "iogpu.wired_limit_mb" not in rej.message
        assert "custom_ceiling_bytes" not in rej.message
Loading