diff --git a/omlx/admin/routes.py b/omlx/admin/routes.py
index 06ceb808b..ad11bb030 100644
--- a/omlx/admin/routes.py
+++ b/omlx/admin/routes.py
@@ -3483,6 +3483,13 @@ def _build_runtime_cache_observability(
}
cache_dir = global_settings.cache.get_ssd_cache_dir(global_settings.base_path)
+ cache_cfg = global_settings.cache
+ try:
+ cfg_disk_max = cache_cfg.get_ssd_cache_max_size_bytes(global_settings.base_path)
+ except (ValueError, OSError, TypeError) as exc:
+ logger.warning("Could not read SSD cache max size from config: %s", exc)
+ cfg_disk_max = 0
+
payload = {
"base_path": str(global_settings.base_path),
"ssd_cache_dir": str(cache_dir),
@@ -3491,6 +3498,16 @@ def _build_runtime_cache_observability(
"total_num_files": 0,
"total_size_bytes": 0,
"effective_block_sizes": [],
+ "disk_max_bytes": cfg_disk_max,
+ "hot_cache_max_bytes": 0,
+ "hot_cache_size_bytes": 0,
+ "hot_cache_entries": 0,
+ # MRU partial cache feature gate. Per-model occupancy lives on each
+ # models[] entry; there is no payload-level entries aggregate
+ # because MRU tail slots are per-model, not a shared budget — only
+ # this max-entries sum is kept, purely so the dashboard can tell
+ # whether the feature is configured for any loaded model.
+ "mru_partial_max_entries": 0,
}
engine_pool = _get_engine_pool()
@@ -3602,11 +3619,28 @@ def _build_runtime_cache_observability(
"last_tokens_to_next_block": last_tokens_to_next_block,
"num_files": int(ssd_stats.get("num_files", 0) or 0),
"total_size_bytes": int(ssd_stats.get("total_size_bytes", 0) or 0),
+ "max_size_bytes": int(ssd_stats.get("max_size_bytes", 0) or 0),
"hot_cache_max_bytes": int(ssd_stats.get("hot_cache_max_bytes", 0) or 0),
"hot_cache_size_bytes": int(ssd_stats.get("hot_cache_size_bytes", 0) or 0),
"hot_cache_entries": int(ssd_stats.get("hot_cache_entries", 0) or 0),
+ "mru_partial_entries": int(
+ prefix_stats.get("mru_partial_entries", 0) or 0
+ ),
+ "mru_partial_max_entries": int(
+ prefix_stats.get("mru_partial_max_entries", 0) or 0
+ ),
+ # Tri-state: None (unknown / no inference yet), True (eligible),
+ # False (model uses non-sliceable cache layers — every stash
+ # refused at the safety gate; dashboard renders 'N/A (see log)').
+ "mru_partial_supported": prefix_stats.get(
+ "mru_partial_supported", None
+ ),
}
+ cache_rates = runtime_stats.get("cache_rates")
+ if cache_rates:
+ model_payload["cache_rates"] = cache_rates
+
payload["models"].append(model_payload)
payload["total_num_files"] += model_payload["num_files"]
payload["total_size_bytes"] += model_payload["total_size_bytes"]
@@ -3616,6 +3650,33 @@ def _build_runtime_cache_observability(
payload["effective_block_sizes"] = sorted(block_sizes)
+ # Aggregate hot-cache and disk-max across models.
+ # hot_cache_max sums across models (each model reserves its own slice of
+ # the same process-wide hot cache budget) so the gauge denominator matches
+ # the summed numerator. disk_max keeps the config fallback via max()
+ # because a single SSD cache directory is shared — the effective cap is
+ # the largest configured limit, not a per-model sum.
+ hot_cache_max = 0
+ disk_max = payload["disk_max_bytes"]
+ hot_cache_size_total = 0
+ hot_cache_entries_total = 0
+ mru_max_entries_total = 0
+ for m in payload["models"]:
+ hot_cache_size_total += m.get("hot_cache_size_bytes", 0)
+ hot_cache_entries_total += m.get("hot_cache_entries", 0)
+ hot_cache_max += m.get("hot_cache_max_bytes", 0)
+ disk_max = max(disk_max, m.get("max_size_bytes", 0))
+ # MRU: only the max-entries sum is kept, and only as a feature-on
+ # gate for the dashboard. Per-model occupancy is on each models[]
+ # entry; an aggregate live count would be meaningless because the
+ # slots are per-model, not a shared budget.
+ mru_max_entries_total += m.get("mru_partial_max_entries", 0)
+ payload["hot_cache_max_bytes"] = hot_cache_max
+ payload["hot_cache_size_bytes"] = hot_cache_size_total
+ payload["hot_cache_entries"] = hot_cache_entries_total
+ payload["disk_max_bytes"] = disk_max
+ payload["mru_partial_max_entries"] = mru_max_entries_total
+
# Fallback: if no loaded models contributed stats, scan the cache
# directory directly so the dashboard still shows real disk usage.
if payload["total_num_files"] == 0 and cache_dir.exists():
@@ -3870,6 +3931,30 @@ async def clear_alltime_stats(is_admin: bool = Depends(require_admin)):
return {"status": "ok"}
+def _iter_loaded_schedulers():
+    """Generate ``(model_id, scheduler)`` pairs for every loaded model.
+
+    Starting from each engine-pool entry, follows the attribute chain
+    ``entry.engine -> ._engine -> .engine -> .scheduler``. A model is
+    silently skipped when it is not reported as loaded, has no pool entry
+    or engine, or when any link of that chain is missing. Shared by
+    ``clear_ssd_cache`` and ``clear_hot_cache``.
+    """
+    pool = _get_engine_pool()
+    if pool is None:
+        return
+    for info in pool.get_status().get("models", []):
+        model_id = info.get("id")
+        if not (model_id and info.get("loaded")):
+            continue
+        entry = pool._entries.get(model_id)
+        if entry is None or entry.engine is None:
+            continue
+        # Descend one attribute at a time; stop at the first missing link.
+        node = entry.engine
+        for attr_name in ("_engine", "engine", "scheduler"):
+            node = getattr(node, attr_name, None)
+            if node is None:
+                break
+        if node is not None:
+            yield model_id, node
+
+
@router.post("/api/ssd-cache/clear")
async def clear_ssd_cache(is_admin: bool = Depends(require_admin)):
"""Clear all SSD cache files for all loaded models.
@@ -3880,38 +3965,33 @@ async def clear_ssd_cache(is_admin: bool = Depends(require_admin)):
"""
total_deleted = 0
- # Phase 1: clear via loaded models' cache managers (updates in-memory index)
- engine_pool = _get_engine_pool()
- if engine_pool is not None:
- for model_info in engine_pool.get_status().get("models", []):
- model_id = model_info.get("id")
- if not model_id or not model_info.get("loaded"):
- continue
-
- entry = engine_pool._entries.get(model_id)
- if entry is None or entry.engine is None:
- continue
-
- async_core = getattr(entry.engine, "_engine", None)
- core = (
- getattr(async_core, "engine", None) if async_core is not None else None
- )
- scheduler = (
- getattr(core, "scheduler", None) if core is not None else None
- )
+ for model_id, scheduler in _iter_loaded_schedulers():
+ ssd_manager = getattr(scheduler, "paged_ssd_cache_manager", None)
+ if ssd_manager is not None:
+ try:
+ total_deleted += ssd_manager.clear()
+ except Exception as exc:
+ logger.warning(
+ "Failed to clear SSD cache for model '%s': %s",
+ model_id,
+ exc,
+ )
- if scheduler is not None:
- ssd_manager = getattr(scheduler, "paged_ssd_cache_manager", None)
- if ssd_manager is not None:
- try:
- deleted = ssd_manager.clear()
- total_deleted += deleted
- except Exception as exc:
- logger.warning(
- "Failed to clear SSD cache for model '%s': %s",
- model_id,
- exc,
- )
+ # MRU partials chain from paged-block hashes whose KV bytes are
+ # gone after the ssd_manager.clear() above. Drop them so the
+ # admin "clear all warm caches" intent is honoured symmetrically.
+ # Single-tier behaviour (no clear) is the surviving-stash hazard
+ # the peer review caught for this endpoint.
+ block_aware_cache = getattr(scheduler, "block_aware_cache", None)
+ if block_aware_cache is not None:
+ try:
+ block_aware_cache.clear_mru_partials()
+ except Exception as exc:
+ logger.warning(
+ "Failed to clear MRU partials for model '%s': %s",
+ model_id,
+ exc,
+ )
# Phase 2: remove any remaining files on disk (covers unloaded models)
global_settings = _get_global_settings()
@@ -3937,6 +4017,31 @@ async def clear_ssd_cache(is_admin: bool = Depends(require_admin)):
return {"status": "ok", "total_deleted": total_deleted}
+@router.post("/api/hot-cache/clear")
+async def clear_hot_cache(is_admin: bool = Depends(require_admin)):
+    """Clear the in-memory (hot) cache for all loaded models.
+
+    No filesystem fallback needed — hot cache is in-memory only and does
+    not survive process restart.
+
+    For every loaded model this also resets the scheduler's cache rate
+    tracker, when one is attached. Both steps are best-effort per model:
+    a failure is logged and the remaining models are still processed, so
+    the endpoint never aborts half-way through the pool.
+
+    Returns:
+        ``{"status": "ok", "total_cleared": <int>}`` where ``total_cleared``
+        sums the counts returned by each manager's ``clear_hot_cache()``.
+    """
+    total_cleared = 0
+    for model_id, scheduler in _iter_loaded_schedulers():
+        ssd_manager = getattr(scheduler, "paged_ssd_cache_manager", None)
+        if ssd_manager is not None and hasattr(ssd_manager, "clear_hot_cache"):
+            try:
+                total_cleared += ssd_manager.clear_hot_cache()
+            except Exception as exc:
+                logger.warning(
+                    "Failed to clear hot cache for model '%s': %s",
+                    model_id,
+                    exc,
+                )
+        rate_tracker = getattr(scheduler, "_cache_rate_tracker", None)
+        if rate_tracker is not None:
+            # Guarded like the cache clear above: one misbehaving tracker
+            # must not turn the whole request into a 500 after earlier
+            # models have already been cleared.
+            try:
+                rate_tracker.clear()
+            except Exception as exc:
+                logger.warning(
+                    "Failed to reset cache rate tracker for model '%s': %s",
+                    model_id,
+                    exc,
+                )
+    return {"status": "ok", "total_cleared": total_cleared}
+
+
@router.post("/api/cache/probe")
async def probe_cache(
request: CacheProbeRequest,
diff --git a/omlx/admin/static/css/dashboard.css b/omlx/admin/static/css/dashboard.css
index a95382df8..6af38022b 100644
--- a/omlx/admin/static/css/dashboard.css
+++ b/omlx/admin/static/css/dashboard.css
@@ -63,6 +63,10 @@
[data-theme="dark"] .hover\:text-neutral-700:hover { color: var(--text-primary) !important; }
[data-theme="dark"] .hover\:text-neutral-600:hover { color: var(--text-secondary) !important; }
+ /* === Gauge track (visible in both themes) === */
+ .gauge-track { background-color: #e5e5e5; }
+ [data-theme="dark"] .gauge-track { background-color: #3f3f46 !important; }
+
/* === Active nav tab (bg-white with shadow inside dark nav) === */
[data-theme="dark"] .shadow-sm { box-shadow: 0 1px 2px 0 rgba(0, 0, 0, 0.3) !important; }
diff --git a/omlx/admin/static/js/dashboard.js b/omlx/admin/static/js/dashboard.js
index ec8ddbf90..cb3acbf3f 100644
--- a/omlx/admin/static/js/dashboard.js
+++ b/omlx/admin/static/js/dashboard.js
@@ -160,6 +160,10 @@
total_num_files: 0,
total_size_bytes: 0,
effective_block_sizes: [],
+ hot_cache_size_bytes: 0,
+ hot_cache_entries: 0,
+ hot_cache_max_bytes: 0,
+ disk_max_bytes: 0,
},
},
alltimeStats: {
@@ -190,6 +194,7 @@
showClearStatsConfirm: false,
showClearAlltimeConfirm: false,
showClearSsdCacheConfirm: false,
+ showClearHotCacheConfirm: false,
_statsRefreshTimer: null,
// Log viewer state
@@ -2149,7 +2154,8 @@
async clearSsdCache() {
try {
- await fetch('/admin/api/ssd-cache/clear', { method: 'POST' });
+ const resp = await fetch('/admin/api/ssd-cache/clear', { method: 'POST' });
+ if (!resp.ok) console.error('SSD cache clear failed:', resp.status);
this.showClearSsdCacheConfirm = false;
await this.loadStats();
} catch (err) {
@@ -2158,6 +2164,18 @@
}
},
+        // POST to the hot-cache clear endpoint, close the confirmation
+        // dialog, then reload stats. A non-2xx response is only logged; a
+        // thrown error (from the request or the stats reload) is logged
+        // and still closes the dialog.
+        async clearHotCache() {
+            try {
+                const response = await fetch('/admin/api/hot-cache/clear', { method: 'POST' });
+                if (!response.ok) {
+                    console.error('Hot cache clear failed:', response.status);
+                }
+                this.showClearHotCacheConfirm = false;
+                await this.loadStats();
+            } catch (error) {
+                console.error('Failed to clear hot cache:', error);
+                this.showClearHotCacheConfirm = false;
+            }
+        },
+
startStatsRefresh() {
this.stopStatsRefresh();
this._statsRefreshTimer = setInterval(() => {
@@ -2178,6 +2196,39 @@
return num.toLocaleString();
},
+        // Return cumulative cache counters for the dashboard.
+        //   - selectedModel set: that model's server-computed
+        //     cache_rates.cumulative object, unchanged ({} if absent).
+        //   - selectedModel empty: raw counters summed across all models,
+        //     with the three hit-rate ratios re-derived from the sums.
+        // prefix_match_efficiency is not aggregated: the cumulative payload
+        // does not carry the matched/requested token totals it is built from.
+        cacheObsCumulative(stats, selectedModel) {
+            const entries = stats.runtime_cache?.models || [];
+            if (entries.length === 0) return {};
+
+            if (selectedModel) {
+                const entry = entries.find(m => m.id === selectedModel);
+                return entry?.cache_rates?.cumulative || {};
+            }
+
+            // prefix_tokens_saved is included so the all-models view exposes
+            // the same counter the single-model view does.
+            const sumKeys = [
+                'prefix_hits', 'prefix_misses', 'prefix_tokens_saved', 'evictions',
+                'ssd_hot_hits', 'ssd_disk_loads', 'ssd_saves',
+                'hot_cache_evictions', 'hot_cache_promotions',
+                'mru_partial_stashes', 'mru_partial_hits',
+                'mru_partial_evictions', 'mru_partial_tokens_saved',
+            ];
+            const agg = {};
+
+            for (const m of entries) {
+                const c = m.cache_rates?.cumulative;
+                // Models with no snapshot yet report an empty object; skip them.
+                if (!c || Object.keys(c).length === 0) continue;
+                for (const k of sumKeys) {
+                    agg[k] = (agg[k] || 0) + (c[k] || 0);
+                }
+            }
+
+            const ph = agg.prefix_hits || 0;
+            const pm = agg.prefix_misses || 0;
+            const sh = agg.ssd_hot_hits || 0;
+            const sd = agg.ssd_disk_loads || 0;
+            const ms = agg.mru_partial_stashes || 0;
+            const mh = agg.mru_partial_hits || 0;
+            // Ratios must come from the summed counters, not from averaging
+            // the per-model ratios.
+            agg.prefix_hit_rate = (ph + pm) > 0 ? ph / (ph + pm) : 0;
+            agg.ssd_hot_rate = (sh + sd) > 0 ? sh / (sh + sd) : 0;
+            agg.mru_partial_hit_rate = ms > 0 ? mh / ms : 0;
+
+            return agg;
+        },
+
getStatFontClass(value) {
if (value >= 1000000000) return 'text-2xl';
if (value >= 1000000) return 'text-3xl';
@@ -2239,6 +2290,38 @@
return 'bg-red-400';
},
+        // Hot-cache fill level as a percentage, capped at 100.
+        // Yields 0 when no runtime cache payload or no capacity is reported.
+        get runtimeHotCachePercent() {
+            const cache = this.stats.runtime_cache;
+            const capacity = cache ? cache.hot_cache_max_bytes : 0;
+            if (!capacity) return 0;
+            const filled = (cache.hot_cache_size_bytes / capacity) * 100;
+            return Math.min(100, filled);
+        },
+
+        // Feature-on gate for the MRU rate strip and the per-model
+        // "MRU Tails" column. The payload-level mru_partial_max_entries is
+        // read only to answer "is MRU configured for any loaded model";
+        // there is intentionally no aggregate MRU gauge because the slots
+        // are per-model rather than a shared budget.
+        get mruEnabled() {
+            const configuredSlots = this.stats.runtime_cache?.mru_partial_max_entries || 0;
+            return configuredSlots > 0;
+        },
+
+        // True when the payload reports a non-zero hot-cache capacity.
+        get hotCacheEnabled() {
+            const capacity = this.stats.runtime_cache?.hot_cache_max_bytes || 0;
+            return capacity > 0;
+        },
+
+        // Grid column classes for the cache-rates strip, chosen by how many
+        // of the optional tiers (hot cache, MRU) are enabled.
+        get cacheRatesGridCols() {
+            const enabledTiers = [this.hotCacheEnabled, this.mruEnabled].filter(Boolean).length;
+            switch (enabledTiers) {
+                case 2:
+                    return 'grid-cols-2 sm:grid-cols-6';
+                case 1:
+                    return 'grid-cols-2 sm:grid-cols-4';
+                default:
+                    return 'grid-cols-2';
+            }
+        },
+
+        // SSD cache fill level as a percentage, capped at 100.
+        // Yields 0 when no runtime cache payload or no disk limit is reported.
+        get runtimeSsdCachePercent() {
+            const cache = this.stats.runtime_cache;
+            const limit = cache ? cache.disk_max_bytes : 0;
+            if (!limit) return 0;
+            const filled = (cache.total_size_bytes / limit) * 100;
+            return Math.min(100, filled);
+        },
+
get activeModelsMemoryPercent() {
const am = this.stats.active_models;
if (!am || !am.model_memory_max) return 0;
diff --git a/omlx/admin/templates/dashboard/_status.html b/omlx/admin/templates/dashboard/_status.html
index 1d6e04f40..f1284044c 100644
--- a/omlx/admin/templates/dashboard/_status.html
+++ b/omlx/admin/templates/dashboard/_status.html
@@ -282,8 +282,51 @@
{{ t('status.head
Runtime Cache Observability
-
+
+
+
Memory
+
+
+
+
+
+ Clear memory cache?
+
+
+
+
+
+
+
|
+
+
+
+
+
+
@@ -336,8 +416,11 @@ {{ t('status.head
| Block Size |
Indexed Blocks |
Sub-block Cache |
- Cache Files |
- Cache Size |
+ SSD Files |
+ SSD Size |
+ Memory Entries |
+ Memory Size |
+ MRU Tails |
@@ -358,6 +441,15 @@ {{ t('status.head
|
|
+ |
+ |
+
+ N/A (see log)
+
+ |
diff --git a/omlx/cache/observability.py b/omlx/cache/observability.py
new file mode 100644
index 000000000..c53c8a2d2
--- /dev/null
+++ b/omlx/cache/observability.py
@@ -0,0 +1,169 @@
+# SPDX-License-Identifier: Apache-2.0
+import threading
+import time
+from collections import deque
+from typing import Any
+
+
+_DEFAULT_WINDOWS = (60, 300, 900)
+_MAX_SNAPSHOTS = 90
+_MIN_INTERVAL = 10.0
+
+
+class CacheRateTracker:
+    """Bounded history of cache counter snapshots, used to derive rates.
+
+    Callers hand in counter dicts through ``maybe_snapshot``; ``get_rates``
+    compares the newest snapshot with an older one to produce per-window
+    figures plus cumulative totals. Every access to the snapshot history
+    happens while holding an internal lock.
+    """
+
+    def __init__(
+        self,
+        max_snapshots: int = _MAX_SNAPSHOTS,
+        min_interval: float = _MIN_INTERVAL,
+    ):
+        """Create an empty tracker.
+
+        Args:
+            max_snapshots: Capacity of the history; once full, appending
+                drops the oldest snapshot.
+            min_interval: Minimum seconds between two stored snapshots.
+        """
+        self._lock = threading.Lock()
+        self._min_interval = min_interval
+        # Each item is (time.monotonic() stamp, private copy of the counters).
+        self._snapshots: deque[tuple[float, dict[str, int]]] = deque(
+            maxlen=max_snapshots
+        )
+
+    def maybe_snapshot(self, counters: dict[str, int]) -> bool:
+        """Store a copy of ``counters`` unless the last snapshot is too recent.
+
+        Returns:
+            ``True`` when a snapshot was recorded, ``False`` when it was
+            skipped because fewer than ``min_interval`` seconds have passed
+            since the previous one.
+        """
+        with self._lock:
+            stamp = time.monotonic()
+            history = self._snapshots
+            if history:
+                since_last = stamp - history[-1][0]
+                if since_last < self._min_interval:
+                    return False
+            history.append((stamp, dict(counters)))
+            return True
+
+    def get_rates(
+        self, windows: tuple[int, ...] = _DEFAULT_WINDOWS
+    ) -> dict[str, Any]:
+        """Compute per-window and cumulative figures from stored snapshots.
+
+        Windows are measured back from the newest snapshot's timestamp, not
+        from the current time. For each window the baseline is the oldest
+        snapshot that still falls inside it. A window whose baseline is
+        less than one second older than the newest snapshot maps to an
+        empty dict.
+
+        Args:
+            windows: Window lengths in seconds.
+
+        Returns:
+            ``{"windows": {label: stats}, "cumulative": stats}``; both
+            values are empty dicts when no snapshot has been stored.
+        """
+        with self._lock:
+            if not self._snapshots:
+                return {"windows": {}, "cumulative": {}}
+
+            latest_ts, latest_counters = self._snapshots[-1]
+
+            per_window = {}
+            for span in windows:
+                label = _window_label(span)
+                # History is ordered oldest -> newest, so the first match is
+                # the oldest snapshot inside the window; fall back to the
+                # oldest snapshot overall when nothing matches.
+                base_ts, base_counters = next(
+                    (
+                        snap
+                        for snap in self._snapshots
+                        if (latest_ts - snap[0]) <= span
+                    ),
+                    self._snapshots[0],
+                )
+                age = latest_ts - base_ts
+                per_window[label] = (
+                    {}
+                    if age < 1.0
+                    else _compute_window(base_counters, latest_counters, age)
+                )
+
+            return {
+                "windows": per_window,
+                "cumulative": _compute_cumulative(latest_counters),
+            }
+
+    def snapshot_and_get_rates(
+        self,
+        counters: dict[str, int],
+        windows: tuple[int, ...] = _DEFAULT_WINDOWS,
+    ) -> dict[str, Any]:
+        """Record ``counters`` (subject to the interval limit), then return rates."""
+        self.maybe_snapshot(counters)
+        return self.get_rates(windows)
+
+    def clear(self) -> None:
+        """Discard every stored snapshot."""
+        with self._lock:
+            self._snapshots.clear()
+
+
+def _window_label(seconds: int) -> str:
+    """Return a short, unique label for a window length given in seconds.
+
+    Whole minutes are rendered as ``"<n>m"`` (60 -> ``"1m"``, 900 ->
+    ``"15m"``); everything else is rendered in seconds (30 -> ``"30s"``,
+    90 -> ``"90s"``). Labels are used as dict keys, so two different
+    window lengths must never share one: flooring 90 to ``"1m"`` would
+    overwrite the real 60-second window.
+    """
+    if seconds >= 60 and seconds % 60 == 0:
+        return f"{seconds // 60}m"
+    return f"{seconds}s"
+
+
+def _safe_ratio(numerator: int, denominator: int) -> float:
+    """Divide ``numerator`` by ``denominator``, yielding 0.0 for a zero denominator."""
+    return numerator / denominator if denominator != 0 else 0.0
+
+
+def _compute_window(
+    old: dict[str, int], new: dict[str, int], elapsed: float
+) -> dict[str, Any]:
+    """Summarize how the counters changed between two snapshots.
+
+    Args:
+        old: Baseline counters.
+        new: Most recent counters.
+        elapsed: Seconds between the two snapshots.
+
+    Returns:
+        Per-counter deltas plus derived ratios (rounded to 4 places) and
+        the eviction rate per minute (rounded to 2 places). Missing keys
+        count as 0 and a negative delta is clamped to 0 — presumably to
+        tolerate counter resets; confirm with the counter producers.
+    """
+    tracked = (
+        "prefix_hits",
+        "prefix_misses",
+        "evictions",
+        "ssd_hot_hits",
+        "ssd_disk_loads",
+        "prefix_tokens_matched",
+        "prefix_tokens_requested",
+        "mru_partial_stashes",
+        "mru_partial_hits",
+        "mru_partial_evictions",
+        "mru_partial_tokens_saved",
+    )
+    grew = {
+        name: max(0, new.get(name, 0) - old.get(name, 0)) for name in tracked
+    }
+
+    per_minute = elapsed / 60.0
+    if per_minute > 0:
+        eviction_rate = round(grew["evictions"] / per_minute, 2)
+    else:
+        eviction_rate = 0.0
+
+    hit_total = grew["prefix_hits"] + grew["prefix_misses"]
+    ssd_total = grew["ssd_hot_hits"] + grew["ssd_disk_loads"]
+
+    return {
+        "prefix_hit_rate": round(_safe_ratio(grew["prefix_hits"], hit_total), 4),
+        "prefix_hits": grew["prefix_hits"],
+        "prefix_misses": grew["prefix_misses"],
+        "prefix_match_efficiency": round(
+            _safe_ratio(
+                grew["prefix_tokens_matched"], grew["prefix_tokens_requested"]
+            ),
+            4,
+        ),
+        "evictions": grew["evictions"],
+        "eviction_rate_per_min": eviction_rate,
+        "ssd_hot_hits": grew["ssd_hot_hits"],
+        "ssd_disk_loads": grew["ssd_disk_loads"],
+        "ssd_hot_rate": round(_safe_ratio(grew["ssd_hot_hits"], ssd_total), 4),
+        # MRU stash payoff: hits divided by stashes within the window.
+        # Near 0 means stashes were mostly wasted (unrepeated prompts);
+        # near 1 means almost every stash was reused.
+        "mru_partial_stashes": grew["mru_partial_stashes"],
+        "mru_partial_hits": grew["mru_partial_hits"],
+        "mru_partial_evictions": grew["mru_partial_evictions"],
+        "mru_partial_tokens_saved": grew["mru_partial_tokens_saved"],
+        "mru_partial_hit_rate": round(
+            _safe_ratio(grew["mru_partial_hits"], grew["mru_partial_stashes"]), 4
+        ),
+    }
+
+
+def _compute_cumulative(counters: dict[str, int]) -> dict[str, Any]:
+    """Build the since-start summary from a single counter snapshot.
+
+    Counters are passed through as-is (missing keys count as 0) and the
+    hit-rate / efficiency ratios are derived from them, rounded to 4
+    places.
+    """
+    read = counters.get
+    hits, misses = read("prefix_hits", 0), read("prefix_misses", 0)
+    hot_hits, disk_loads = read("ssd_hot_hits", 0), read("ssd_disk_loads", 0)
+    matched = read("prefix_tokens_matched", 0)
+    requested = read("prefix_tokens_requested", 0)
+    stashes, stash_hits = read("mru_partial_stashes", 0), read("mru_partial_hits", 0)
+
+    summary: dict[str, Any] = {}
+    summary["prefix_hits"] = hits
+    summary["prefix_misses"] = misses
+    summary["prefix_hit_rate"] = round(_safe_ratio(hits, hits + misses), 4)
+    summary["prefix_tokens_saved"] = read("prefix_tokens_saved", 0)
+    summary["prefix_match_efficiency"] = round(_safe_ratio(matched, requested), 4)
+    summary["evictions"] = read("evictions", 0)
+    summary["ssd_hot_hits"] = hot_hits
+    summary["ssd_disk_loads"] = disk_loads
+    summary["ssd_saves"] = read("ssd_saves", 0)
+    summary["hot_cache_evictions"] = read("hot_cache_evictions", 0)
+    summary["hot_cache_promotions"] = read("hot_cache_promotions", 0)
+    summary["ssd_hot_rate"] = round(
+        _safe_ratio(hot_hits, hot_hits + disk_loads), 4
+    )
+    summary["mru_partial_stashes"] = stashes
+    summary["mru_partial_hits"] = stash_hits
+    summary["mru_partial_evictions"] = read("mru_partial_evictions", 0)
+    summary["mru_partial_tokens_saved"] = read("mru_partial_tokens_saved", 0)
+    summary["mru_partial_hit_rate"] = round(_safe_ratio(stash_hits, stashes), 4)
+    return summary
diff --git a/omlx/cache/paged_ssd_cache.py b/omlx/cache/paged_ssd_cache.py
index 7d5c0d6c7..be52bc8e5 100644
--- a/omlx/cache/paged_ssd_cache.py
+++ b/omlx/cache/paged_ssd_cache.py
@@ -2035,6 +2035,20 @@ def enforce_size_limit(self) -> int:
)
return freed
+    def clear_hot_cache(self) -> int:
+        """Drop every entry from the in-memory (hot) cache.
+
+        Empties the hot cache mapping and resets its byte total to zero
+        while holding the hot cache lock.
+
+        Returns:
+            Number of entries that were present before the clear.
+        """
+        with self._hot_cache_lock:
+            cleared = len(self._hot_cache)
+            self._hot_cache.clear()
+            self._hot_cache_total_bytes = 0
+        if cleared > 0:
+            logger.info("Cleared %d hot cache entries", cleared)
+        return cleared
+
def clear(self) -> int:
"""
Clear all SSD cache files.
diff --git a/omlx/cache/prefix_cache.py b/omlx/cache/prefix_cache.py
index 4f2bd1d32..697c9a9e9 100644
--- a/omlx/cache/prefix_cache.py
+++ b/omlx/cache/prefix_cache.py
@@ -9,6 +9,7 @@
import logging
import math
import time
+from collections import OrderedDict
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any
@@ -31,7 +32,7 @@
)
from .paged_ssd_cache import PagedSSDCacheManager
from .stats import PrefixCacheStats
-from .type_registry import CacheTypeRegistry
+from .type_registry import KNOWN_SLICEABLE_CACHE_TYPES, CacheTypeRegistry
logger = logging.getLogger(__name__)
@@ -44,6 +45,59 @@ class BlockCacheEntry:
last_access: float
+@dataclass
+class _MRUPartialBlock:
+ """One entry in the bounded LRU stash for trailing sub-block partials.
+
+ The paged SSD cache only persists full ``block_size`` blocks; trailing
+ sub-block tokens (e.g. 139 out of 256) are otherwise discarded and
+ re-prefilled on every repeat request. The MRU stash keeps the partial
+ in memory so an immediate repeat skips re-prefilling those tail tokens.
+
+ Entries are stored in ``BlockAwarePrefixCache._mru_partials`` as an
+ ``OrderedDict`` keyed by ``parent_hash``. New entries land at the tail;
+ LRU eviction pops the oldest when the dict exceeds
+ ``mru_partial_max_entries``. A successful apply promotes the matched
+ entry to the tail (move_to_end). Apply-time miss for a found entry
+ pops only that key, leaving siblings intact. Same-key replacement on
+ stash is correct LRU put behavior.
+
+ Stash and apply are gated on **uniform layer sliceability**: if any
+ layer in the model is non-sliceable (RotatingKVCache, ArraysCache,
+ etc.) no entries are ever written. Splicing into only the sliceable
+ layers would create per-layer offset skew at decode time.
+
+ Memory accounting
+ -----------------
+ ``kv_data`` holds real ``mx.array`` allocations (produced by
+ ``_clone_tensor`` → ``mx.copy``), so each entry's memory cost is
+ automatically counted by ``mx.get_active_memory()`` and is therefore
+ visible to every runtime memory enforcement and telemetry path in
+ this codebase (process enforcer, scheduler limit checks,
+ periodic-clear threshold, prefill pre-flight peak check).
+
+ Upstream of those, the engine pool reserves a fraction of each
+ model's weight size as KV headroom before admission. MRU partials
+ are one tenant of that headroom alongside in-flight prompt caches;
+ they are not separately reserved because each entry is bounded at
+ one ``block_size``-worth of KV and the default cap is small. Under
+ ``hot_cache_only=True``, the hot cache and the MRU dict share that
+ envelope — tune ``--mru-partial-max-entries`` and
+ ``--hot-cache-max-size`` together in that mode rather than treating
+ them as independent dials.
+
+ Future maintainers: do **not** add a separate accounting hook for
+ these entries. The invariant that ``kv_data`` holds ``mx.array``
+ instances (not numpy/CPU copies) is what makes the implicit
+ accounting work; ``test_kv_data_holds_mlx_arrays_for_active_memory_accounting``
+ pins it.
+ """
+
+ # Key of this entry in the stash dict. ``None`` is a valid key; it is
+ # presumably used when no full block precedes the tail — confirm
+ # against the stash/apply code.
+ parent_hash: bytes | None
+ # Token ids of the trailing sub-block tail that this entry covers.
+ tokens: list[int]
+ # KV tensors for the tail, as a list of 2-tuples. NOTE(review): looks
+ # like one (keys, values) pair per layer — confirm against the code
+ # that builds the entry.
+ kv_data: list[tuple[Any, Any]]
+
+
class BlockAwarePrefixCache(CacheManager):
"""
Prefix cache that uses PagedCacheManager for block-based storage.
@@ -82,6 +136,7 @@ def __init__(
model: Any,
paged_cache_manager: PagedCacheManager,
paged_ssd_cache_manager: PagedSSDCacheManager | None = None,
+ mru_partial_max_entries: int = 4,
):
"""
Initialize block-aware prefix cache.
@@ -90,6 +145,12 @@ def __init__(
model: The MLX model (used for identification)
paged_cache_manager: The PagedCacheManager instance for block management
paged_ssd_cache_manager: The PagedSSDCacheManager for SSD storage (required for paged SSD-only mode)
+ mru_partial_max_entries: Maximum simultaneous MRU partial-block
+ stashes. Each entry is bounded at one ``block_size`` of KV
+ memory. ``0`` disables the MRU feature entirely (silent
+ fallback to "no MRU" behavior, mirroring
+ ``hot_cache_max_size="0"`` convention). Default of 4 matches
+ the dflash ``max_entries`` precedent (PR #1120).
"""
self.model = model
self.model_key = id(model)
@@ -111,14 +172,121 @@ def __init__(
# Kept for API compatibility
self._cold_restore_callback: Callable[[int, bytes], bool] | None = None
+ # Bounded LRU map for trailing sub-block tails (see _MRUPartialBlock).
+ # Keyed by parent_hash (block_hash of the last full block in the
+ # prefix, or None for short prompts whose prefix is < block_size).
+ # Populated by store_cache via _update_mru_partial, consumed by
+ # apply_mru_partial. Mirrors the OrderedDict + LRU pattern from
+ # PagedSSDCacheManager._hot_cache. Empty dict for hybrid models.
+ self._mru_partials: OrderedDict[
+ bytes | None, _MRUPartialBlock
+ ] = OrderedDict()
+ self._mru_partial_max_entries: int = mru_partial_max_entries
+
# Statistics
self._hits = 0
self._misses = 0
self._tokens_saved = 0
self._partial_block_skips = 0
self._partial_tokens_skipped = 0
+ self._tokens_matched_total = 0
+ self._tokens_requested_total = 0
self._last_partial_tokens_skipped = 0
self._last_tokens_to_next_block = 0
+ # MRU partial cache cumulative counters (see PrefixCacheStats).
+ self._mru_partial_stashes = 0
+ self._mru_partial_hits = 0
+ self._mru_partial_evictions = 0
+ self._mru_partial_tokens_saved = 0
+
+ # Tri-state MRU eligibility flag (see PrefixCacheStats.mru_partial_supported).
+ # Latched True once a sliceable layer set is observed; latched False
+ # the first time a non-sliceable set is observed (eager check below
+ # or lazy fallback inside _update_mru_partial).
+ self._mru_partial_supported: bool | None = None
+ self._mru_partial_warn_emitted: bool = False
+ if self._mru_partial_max_entries > 0:
+ self._check_mru_eligibility_at_init(model)
+
+    def _check_mru_eligibility_at_init(self, model: Any) -> None:
+        """Probe the model's cache layer types once, at construction time.
+
+        MRU partial-block stashing is only safe when every cache layer can
+        be sliced; splicing a partial into just the sliceable layers would
+        skew per-layer offsets at decode time. For a model with any
+        non-sliceable layer the feature can never activate, so this check
+        latches ``_mru_partial_supported`` early and (via
+        ``_record_mru_unsupported``) warns the operator at load time
+        instead of leaving the dashboard at a silent "0/N entries".
+
+        The probe is best-effort. When ``model.make_cache`` is missing or
+        raises, when the cache config cannot be built, or when no layer
+        types are reported, the flag is left untouched and the lazy check
+        inside ``_update_mru_partial`` is relied on at first inference.
+
+        Args:
+            model: The model object; only ``make_cache()`` is used.
+        """
+        if not hasattr(model, "make_cache"):
+            return
+
+        try:
+            probe = model.make_cache()
+        except Exception as exc:
+            logger.debug(
+                "MRU eager eligibility check skipped (make_cache failed: %s)",
+                exc,
+            )
+            return
+
+        try:
+            config = ModelCacheConfig.from_cache_list(probe)
+            if config is None:
+                type_names = None
+            else:
+                type_names = config.get_type_names()
+        except Exception as exc:
+            logger.debug(
+                "MRU eager eligibility check skipped (ModelCacheConfig failed: %s)",
+                exc,
+            )
+            return
+        finally:
+            # The probe cache was needed only for type-name introspection;
+            # release the reference on every exit path.
+            del probe
+
+        if not type_names:
+            return
+
+        if not self._all_layers_sliceable(type_names):
+            self._record_mru_unsupported(type_names)
+            return
+        self._mru_partial_supported = True
+
+    def _record_mru_unsupported(
+        self, layer_cache_types: list[str] | None
+    ) -> None:
+        """Mark MRU as unsupported for this model and warn at most once.
+
+        Always sets ``_mru_partial_supported`` to ``False``. The warning is
+        skipped when the flag was already ``False`` on entry or when a
+        warning has been emitted before. It lists the layer types that are
+        not in ``KNOWN_SLICEABLE_CACHE_TYPES`` (or ``unknown`` when no
+        types were supplied) and is worded for operators: condition plus
+        consequence, without mechanism details.
+
+        Args:
+            layer_cache_types: Cache type names of the model's layers, or
+                ``None`` / empty when they are not known.
+        """
+        was_unsupported = self._mru_partial_supported is False
+        self._mru_partial_supported = False
+        if was_unsupported or self._mru_partial_warn_emitted:
+            return
+
+        self._mru_partial_warn_emitted = True
+        offenders = (
+            sorted(set(layer_cache_types) - KNOWN_SLICEABLE_CACHE_TYPES)
+            if layer_cache_types
+            else ["unknown"]
+        )
+        logger.warning(
+            "MRU tail cache enabled but this model is incompatible "
+            "(cache layers: %s); MRU tails will be inactive for this model.",
+            ", ".join(offenders),
+        )
def _get_model_num_layers(self, model: Any) -> int:
"""
@@ -285,6 +453,8 @@ def fetch_cache(
num_prefix_tokens = len(tokens) - len(remaining)
self._hits += 1
self._tokens_saved += num_prefix_tokens
+ self._tokens_matched_total += num_prefix_tokens
+ self._tokens_requested_total += len(tokens)
logger.debug(
f"Cache hit for {request_id}: "
@@ -310,6 +480,8 @@ def fetch_cache(
remaining = tokens[prefix_len:]
self._hits += 1
self._tokens_saved += prefix_len
+ self._tokens_matched_total += prefix_len
+ self._tokens_requested_total += len(tokens)
logger.debug(
f"Prefix index hit for {request_id}: " f"{prefix_len} tokens matched"
@@ -319,6 +491,7 @@ def fetch_cache(
# No cache hit
self._misses += 1
+ self._tokens_requested_total += len(tokens)
logger.debug(f"Cache miss for {request_id}")
return None, tokens
@@ -332,6 +505,7 @@ def store_cache(
extra_keys: tuple[Any, ...] | None = None,
extra_key_token_start: int | None = None,
extra_key_ranges: list[tuple[int, tuple[Any, ...]]] | None = None,
+ prompt_token_count: int | None = None,
) -> BlockTable | None:
"""
Store computed cache for future reuse.
@@ -352,6 +526,14 @@ def store_cache(
boundary_snapshots: Optional mapping of token_count -> extracted cache
states for intermediate block boundaries. Used to store per-block
ArraysCache state instead of placeholders in hybrid models.
+ prompt_token_count: Number of prompt tokens at the front of
+ ``tokens``. ``tokens`` is typically ``prompt + output``;
+ the MRU partial cache must stash the prompt's trailing
+ tail (the part a repeat request will resubmit), not the
+ sequence's. ``None`` means "no prompt boundary known" —
+ the MRU stash then treats the whole sequence as the
+ prompt (pre-MRU-fix behavior; correct for verbatim-repeat
+ callers like the generic CacheManager.store path).
Returns:
BlockTable for the stored cache, or None on failure
@@ -652,6 +834,22 @@ def store_cache(
last_access=time.time(),
)
+ # Stash the prompt's trailing sub-block tail in memory so an
+ # immediate repeat request can splice it back in without a
+ # re-prefill. Keyed by the prompt's last full block so the
+ # lookup in apply_mru_partial (which sees a prompt-only repeat)
+ # finds it — see _update_mru_partial.
+ self._update_mru_partial(
+ new_tokens=new_tokens,
+ cache_data=cache_data,
+ block_table=block_table,
+ existing_tokens=existing_tokens,
+ prompt_token_count=prompt_token_count,
+ is_tensor_data=bool(is_tensor_data),
+ layer_cache_types=layer_cache_types,
+ model_cache_config=model_cache_config,
+ )
+
logger.debug(
f"Stored cache for {request_id}: "
f"{len(block_table.block_ids)} blocks ({blocks_saved_to_ssd} saved to tiered cache), "
@@ -660,6 +858,369 @@ def store_cache(
return block_table
+ def has_mru_partial(self) -> bool:
+ """Whether any trailing-tail partial is currently stashed.
+
+ Used by the scheduler to decide whether to suppress a deferred
+ Metal cache clear (a fresh stash is a strong signal that the
+ same prompt may return immediately). "Any entry present" is
+ the right predicate regardless of how many slots are filled
+ — one warm prompt is enough to merit suppression.
+ """
+ return bool(self._mru_partials)
+
+ def _can_reconstruct(self) -> bool:
+ """Whether ``reconstruct_cache`` has a path to return non-``None``.
+
+ Used as a guard at both the MRU stash site (``_update_mru_partial``)
+ and the canonical reconstruct entry (``reconstruct_cache``).
+ Co-locating the two checks keeps them in lockstep — a future fetch
+ path that bypasses ``PagedSSDCacheManager`` (memory-only mode,
+ alternate backends) updates exactly one predicate.
+
+ Note the manager check is ``manager is not None``, not "SSD writes
+ are enabled." ``hot_cache_only=True`` (omlx setting / OMLX_HOT_CACHE_ONLY
+ env) keeps the manager but skips the disk writer thread; reconstruct
+ still works because ``load_block_with_metadata`` short-circuits to
+ the hot tier. The gate fires only when MLX is unavailable or no
+ manager is configured at all — typically a test/dev scenario.
+ """
+ return HAS_MLX and self.paged_ssd_cache is not None
+
+ def _all_layers_sliceable(
+ self, layer_cache_types: list[str] | None
+ ) -> bool:
+ """True iff every layer's cache type is in the known-sliceable
+ whitelist.
+
+ Why a whitelist and not ``CacheTypeRegistry.supports_block_slicing``:
+ the registry uses ``DefaultCacheHandler`` (a ``KVCacheHandler``
+ subclass that reports ``supports_block_slicing=True``) as the
+ fallback for class names with no registered handler. Several real
+ non-sliceable types — ``BatchRotatingKVCache``, ``BatchPoolingCache``,
+ ``PoolingCache`` (without the deepseek_v4 patch applied) — are
+ mapped in ``_class_name_map`` but have no registered handler, so
+ the registry would silently classify them as sliceable. The
+ whitelist is the same one the rest of the scheduler trusts for
+ snapshot-skip and partial-extraction decisions, so MRU now agrees
+ with the surrounding code rather than getting its own answer.
+
+ Hybrid models (e.g. Gemma 3, Mistral) mix sliceable layers with
+ non-sliceable ones (``RotatingKVCache``, ``ArraysCache``). Splicing
+ the partial only into the sliceable layers would create per-layer
+ offset skew at decode — undefined behaviour at the model level —
+ so we refuse the stash entirely whenever any layer is non-sliceable.
+ """
+ if not layer_cache_types:
+ # Default fallback path assumes all layers are KVCache; safe to stash.
+ return True
+ for class_name in layer_cache_types:
+ if class_name not in KNOWN_SLICEABLE_CACHE_TYPES:
+ return False
+ return True
+
+ def _update_mru_partial(
+ self,
+ *,
+ new_tokens: list[int],
+ cache_data: list[Any],
+ block_table: BlockTable,
+ existing_tokens: int,
+ prompt_token_count: int | None,
+ is_tensor_data: bool,
+ layer_cache_types: list[str] | None,
+ model_cache_config: ModelCacheConfig | None,
+ ) -> None:
+ """Stash the prompt's trailing partial from a just-completed ``store_cache``.
+
+ ``store_cache`` is given the full ``prompt + output`` sequence,
+ but a repeat request resubmits the *prompt* only — and
+ ``apply_mru_partial`` looks the entry up by the prompt's last
+ full block. So the stash must key off, and slice, the
+ **prompt's** trailing partial, not the stored sequence's.
+ ``prompt_token_count`` carries the prompt boundary; ``None``
+ falls back to "whole sequence is the prompt" (pre-fix behavior,
+ correct for verbatim-repeat callers).
+
+ On success, writes one entry into the LRU map keyed by
+ ``parent_hash`` (the prompt's last full block, or ``None`` for a
+ prompt shorter than one block). If the map is at capacity, the
+ oldest entry is evicted via ``popitem(last=False)``.
+
+ The "no eligible tail" branches (block-aligned prompt, non-tensor
+ data, no reconstruct path configured, hybrid model, extraction
+ failure, ambiguous cache layout) bare-return — they signal a
+ local "nothing to stash this time," NOT a global "wipe the map."
+ That distinction is what lets multiple distinct-prefix entries
+ coexist under interleaving.
+ """
+ if self._mru_partial_max_entries <= 0:
+ # Feature disabled. Match the hot_cache_max_size="0" convention:
+ # silent fallback to "no MRU" behavior.
+ return
+ # Sliceable-layer guard is checked first and recorded so the
+ # eligibility flag flips at the same instant the warning fires —
+ # later branches in this chain are per-call ("no eligible tail
+ # this time") and must not pollute the structural flag.
+ if not self._all_layers_sliceable(layer_cache_types):
+ self._record_mru_unsupported(layer_cache_types)
+ return
+ if self._mru_partial_supported is None:
+ self._mru_partial_supported = True
+ if not is_tensor_data or not self._can_reconstruct():
+ return
+
+ # Resolve the prompt boundary within the stored sequence. None
+ # means "no boundary known" — treat the whole stored sequence as
+ # the prompt, which reproduces the pre-fix whole-sequence stash.
+ sequence_len = existing_tokens + len(new_tokens)
+ if prompt_token_count is None:
+ prompt_token_count = sequence_len
+ else:
+ prompt_token_count = min(prompt_token_count, sequence_len)
+ # Prompt fully covered by already-cached full blocks: its tail (if
+ # any) is partial and partials are never stored as blocks, so a
+ # prompt that ends at or before existing_tokens is block-aligned
+ # and has nothing for the MRU to add.
+ if prompt_token_count <= existing_tokens:
+ return
+
+ # Prompt's trailing partial: global range
+ # [prompt_partial_start, prompt_token_count). Block-aligned
+ # prompt (partial_len == 0) → every token lands in a full paged
+ # block, nothing to stash.
+ prompt_partial_len = prompt_token_count % self.block_size
+ if prompt_partial_len == 0:
+ return
+ prompt_partial_start = prompt_token_count - prompt_partial_len
+ prompt_full_blocks = prompt_token_count // self.block_size
+ # new_tokens-relative offset of the prompt's trailing partial.
+ # existing_tokens is block-aligned and <= prompt_partial_start,
+ # so this is a valid index into new_tokens.
+ partial_start = prompt_partial_start - existing_tokens
+ partial_global_start = prompt_partial_start
+ partial_global_end = prompt_token_count
+
+ # Decide which axis-2 index range covers the trailing partial:
+ # - Global: cache_data spans the full sequence (prefix + new tail);
+ # slice at [partial_global_start:partial_global_end].
+ # - Local: cache_data spans only the newly-processed suffix;
+ # slice at [partial_start:partial_start + prompt_partial_len].
+ #
+ # We classify by exact cache length, not the previous
+ # ``cache_seq_len >= existing_tokens + 1`` heuristic — that one
+ # silently picked "local" when ``cache_seq_len == existing_tokens``,
+ # which can happen on multi-turn requests where the cache was
+ # extracted at a boundary that happens to equal the prior turn's
+ # length. Slicing local indices on a global cache there sampled
+ # tokens from the prefix, parent_hash still matched, and a future
+ # apply spliced wrong KV — silent generation corruption. Now: if
+ # the layout cannot be unambiguously classified, refuse to stash.
+ cache_seq_len = self._get_cache_seq_len(cache_data)
+ local_len = len(new_tokens)
+ if cache_seq_len >= partial_global_end:
+ # Cache is long enough to contain the global partial range.
+ p_cache_start = partial_global_start
+ p_cache_end = partial_global_end
+ elif existing_tokens == 0 or cache_seq_len == local_len:
+ # Cache covers only the new suffix (or there is no prefix).
+ p_cache_start = partial_start
+ p_cache_end = partial_start + prompt_partial_len
+ else:
+ # Ambiguous: cache_seq_len is short of global_end but does
+ # not equal local_len either. Refuse rather than guess.
+ return
+
+ # is_last_block=False is the correct intent for partial extraction:
+ # we want a token slice, not the full state of any non-sliceable
+ # layer. The non-sliceable case is already excluded above.
+ partial_kv = self._extract_block_tensor_slice(
+ cache_data,
+ p_cache_start,
+ p_cache_end,
+ model_cache_config,
+ is_last_block=False,
+ )
+ if not partial_kv:
+ return
+ # Materialize now, on this (store-cache worker) thread. The
+ # tensors are lazy ops bound to this thread's stream; the splice
+ # in apply_mru_partial runs on the separate inference thread and
+ # must not inherit a cross-thread stream dependency.
+ self._materialize_mru_kv(partial_kv)
+
+ # Key by the PROMPT's last full block (block index
+ # prompt_full_blocks - 1), so a prompt-only repeat — which is what
+ # apply_mru_partial sees — finds this entry. block_table.block_ids
+ # is ordered [fetched-prefix blocks..., newly-built blocks...] in
+ # sequence order, and the stored sequence has at least
+ # prompt_full_blocks full blocks (the prompt is a prefix of it),
+ # so the index is in range. prompt_full_blocks == 0 (prompt
+ # shorter than one block) keeps parent_hash = None — the
+ # short-prompt None-key path.
+ parent_hash: bytes | None = None
+ if prompt_full_blocks >= 1:
+ if len(block_table.block_ids) < prompt_full_blocks:
+ # Prompt's last full block isn't in the table (block build
+ # rolled back on an extraction failure). Skip rather than
+ # mis-key against the None short-prompt slot.
+ return
+ parent_bid = block_table.block_ids[prompt_full_blocks - 1]
+ parent_block = self.paged_cache.allocated_blocks.get(parent_bid)
+ if parent_block is None or not parent_block.block_hash:
+ return
+ parent_hash = parent_block.block_hash
+
+ # LRU put: pop any existing entry first so the re-insert lands at
+ # the tail with fresh order; then evict the oldest entry if over
+ # capacity. Mirrors PagedSSDCacheManager._hot_cache_put.
+ # Note: same-key replacement does NOT count as an eviction —
+ # the operator sees stash payoff via the stashes/hits ratio, not
+ # via replacement churn.
+ if parent_hash in self._mru_partials:
+ self._mru_partials.pop(parent_hash)
+ self._mru_partials[parent_hash] = _MRUPartialBlock(
+ parent_hash=parent_hash,
+ tokens=new_tokens[partial_start : partial_start + prompt_partial_len],
+ kv_data=partial_kv,
+ )
+ self._mru_partial_stashes += 1
+ while len(self._mru_partials) > self._mru_partial_max_entries:
+ self._mru_partials.popitem(last=False)
+ self._mru_partial_evictions += 1
+
+ logger.debug(
+ "Stashed MRU partial: %d tokens, parent_hash=%s, layers=%d, "
+ "entries=%d/%d",
+ prompt_partial_len,
+ parent_hash[:8].hex() + "..." if parent_hash else "None",
+ len(partial_kv),
+ len(self._mru_partials),
+ self._mru_partial_max_entries,
+ )
+
+ def apply_mru_partial(
+ self,
+ cache: list[Any],
+ block_table: BlockTable,
+ remaining_tokens: list[int],
+ ) -> tuple[list[Any], list[int], int]:
+ """Splice an MRU partial entry into a reconstructed cache, atomically.
+
+ Looks up an entry in the multi-slot LRU map keyed by the block
+ hash of the last full block in ``block_table`` (``None`` for
+ short prompts). On a match (entry exists, partial tokens are a
+ prefix of ``remaining_tokens``, layer count matches, every layer
+ accepts the concatenate), every layer's keys/values/offset are
+ advanced by ``len(partial.tokens)``, the partial tokens are
+ consumed from the front of ``remaining_tokens``, and the entry
+ is moved to the LRU tail (most-recently-used).
+
+ On a mismatch after lookup, only the matched entry is evicted;
+ sibling entries remain. A failed lookup evicts nothing.
+
+ The splice is **transactional**: replacement keys/values are
+ materialized for every layer before any layer is mutated. A
+ failure on layer N rolls everything back; the caller never sees
+ a half-mutated cache.
+
+ Args:
+ cache: Reconstructed per-layer cache objects from
+ ``reconstruct_cache``. Mutated in place on success.
+ block_table: Block table from the same fetch_cache call.
+ Used to verify the partial chains from the right prefix.
+ remaining_tokens: Tokens still needing prefill.
+
+ Returns:
+ ``(cache, remaining_tokens, tokens_applied)``. On miss,
+ ``tokens_applied == 0`` and the inputs are returned unchanged.
+ """
+ if not self._mru_partials or not remaining_tokens:
+ return cache, remaining_tokens, 0
+
+ # Compute the parent_hash key. Multi-slot freed-block guard: if
+ # block_table is non-empty but the parent paged block has been
+ # freed (allocated_blocks.get returns None) or has no block_hash,
+ # do NOT fall through to a None-keyed dict lookup — that would
+ # falsely match (or, on mismatch, evict) a short-prompt entry
+ # against a request whose parent is just gone. Return no-op
+ # instead. This race is new in multi-slot mode; single-slot
+ # tolerated it because there was only ever one entry to match.
+ last_hash: bytes | None = None
+ if block_table and block_table.block_ids:
+ last_bid = block_table.block_ids[-1]
+ last_block = self.paged_cache.allocated_blocks.get(last_bid)
+ if last_block is None or not last_block.block_hash:
+ return cache, remaining_tokens, 0
+ last_hash = last_block.block_hash
+
+ partial = self._mru_partials.get(last_hash)
+ if partial is None:
+ return cache, remaining_tokens, 0
+
+ def _evict_miss() -> tuple[list[Any], list[int], int]:
+ """Pop the matched entry and return the no-op tuple.
+
+ Used by every apply-time mismatch arm. Keeps each call
+ site to one line and the eviction-counter bookkeeping
+ in one place.
+ """
+ self._mru_partials.pop(last_hash, None)
+ self._mru_partial_evictions += 1
+ return cache, remaining_tokens, 0
+
+ n_partial = len(partial.tokens)
+ if len(remaining_tokens) < n_partial:
+ return _evict_miss()
+ if remaining_tokens[:n_partial] != partial.tokens:
+ return _evict_miss()
+
+ if len(partial.kv_data) != len(cache):
+ logger.debug(
+ "MRU partial layer count mismatch: %d vs %d, evicting entry",
+ len(partial.kv_data), len(cache),
+ )
+ return _evict_miss()
+
+ if not HAS_MLX:
+ return _evict_miss()
+
+ # Phase 1: build per-layer replacements without touching the cache.
+ # Any failure here is a clean rollback — no layer has been mutated.
+ try:
+ replacements: list[tuple[int, Any, Any]] = []
+ for layer_idx, (p_keys, p_values) in enumerate(partial.kv_data):
+ cache_obj = cache[layer_idx]
+ new_keys = mx.concatenate([cache_obj.keys, p_keys], axis=2)
+ new_values = mx.concatenate([cache_obj.values, p_values], axis=2)
+ replacements.append((layer_idx, new_keys, new_values))
+ except Exception as e:
+ logger.debug(
+ "MRU partial splice build failed: %s, evicting entry", e
+ )
+ return _evict_miss()
+
+ # Phase 2: commit. All concatenates have already succeeded; the
+ # only operations remaining are attribute writes and an integer
+ # add, which cannot raise on a well-formed cache object.
+ for layer_idx, new_keys, new_values in replacements:
+ cache_obj = cache[layer_idx]
+ cache_obj.keys = new_keys
+ cache_obj.values = new_values
+ cache_obj.offset += n_partial
+
+ # Promote this entry to the LRU tail (most-recently-used).
+ self._mru_partials.move_to_end(last_hash)
+ self._mru_partial_hits += 1
+ self._mru_partial_tokens_saved += n_partial
+
+ new_remaining = remaining_tokens[n_partial:]
+ logger.debug(
+ "Applied MRU partial: %d tokens, %d remaining, entries=%d",
+ n_partial, len(new_remaining), len(self._mru_partials),
+ )
+ return cache, new_remaining, n_partial
+
def _get_cache_seq_len(self, cache_data: list[dict[str, Any]]) -> int:
"""
Get the sequence length from cache data.
@@ -1220,6 +1781,39 @@ def _clone_tensor(self, tensor: Any) -> Any:
return mx.array(tensor)
+ def _materialize_mru_kv(self, partial_kv: list[Any]) -> None:
+ """Force-evaluate a freshly-extracted MRU partial's KV tensors.
+
+ ``_extract_block_tensor_slice`` builds the tensors as lazy
+ ``mx.copy`` ops on whichever thread ``store_cache`` runs on — the
+ ``omlx-store-cache`` worker. ``apply_mru_partial`` later splices
+ them into a live cache on the separate ``mlx-global`` inference
+ thread; while the tensors are still lazy, that thread's
+ ``mx.async_eval`` would walk the compute graph back to a
+ per-thread stream the inference thread cannot see, and MLX raises
+ ``RuntimeError: There is no Stream(gpu, N) in current thread``.
+ Evaluating here, on the worker, leaves concrete stream-free data
+ safe to splice and evaluate from any thread.
+
+ Walks the nested ``(keys, values)`` / TurboQuant ``(tag, (k, v))``
+ shapes ``_extract_block_tensor_slice`` returns and evaluates every
+ ``mx.array`` leaf in one batched call.
+ """
+ if not HAS_MLX:
+ return
+ leaves: list[Any] = []
+
+ def _collect(obj: Any) -> None:
+ if isinstance(obj, mx.array):
+ leaves.append(obj)
+ elif isinstance(obj, (list, tuple)):
+ for item in obj:
+ _collect(item)
+
+ _collect(partial_kv)
+ if leaves:
+ mx.eval(*leaves)
+
def _apply_window_padding(
self,
matched_blocks: int,
@@ -1388,15 +1982,16 @@ def reconstruct_cache(
if not block_table or not block_table.block_ids:
return None
- if not HAS_MLX:
- logger.warning("Cannot reconstruct cache: MLX not available")
- return None
-
- if self.paged_ssd_cache is None:
- logger.warning(
- "Cannot reconstruct cache: PagedSSDCacheManager not configured"
- )
+ if not self._can_reconstruct():
+ if not HAS_MLX:
+ logger.warning("Cannot reconstruct cache: MLX not available")
+ else:
+ logger.warning(
+ "Cannot reconstruct cache: PagedSSDCacheManager not configured"
+ )
return None
+ # _can_reconstruct() guarantees this; narrow for the type checker.
+ assert self.paged_ssd_cache is not None
try:
# Collect cache data from valid blocks (stop at first invalid)
@@ -2367,6 +2962,15 @@ def get_stats(self) -> PrefixCacheStats:
block_size=self.block_size,
last_partial_tokens_skipped=self._last_partial_tokens_skipped,
last_tokens_to_next_block=self._last_tokens_to_next_block,
+ tokens_matched_total=self._tokens_matched_total,
+ tokens_requested_total=self._tokens_requested_total,
+ mru_partial_stashes=self._mru_partial_stashes,
+ mru_partial_hits=self._mru_partial_hits,
+ mru_partial_evictions=self._mru_partial_evictions,
+ mru_partial_tokens_saved=self._mru_partial_tokens_saved,
+ mru_partial_entries=len(self._mru_partials),
+ mru_partial_max_entries=self._mru_partial_max_entries,
+ mru_partial_supported=self._mru_partial_supported,
)
def get_stats_dict(self) -> dict[str, Any]:
@@ -2393,19 +2997,44 @@ def get_stats_dict(self) -> dict[str, Any]:
"block_size": self.block_size,
"last_partial_tokens_skipped": self._last_partial_tokens_skipped,
"last_tokens_to_next_block": self._last_tokens_to_next_block,
+ "tokens_matched_total": self._tokens_matched_total,
+ "tokens_requested_total": self._tokens_requested_total,
"active_requests": len(self._request_tables),
+ # MRU partial cache: counters mirror the dataclass surface so the
+ # admin dashboard's `mruEnabled` gate (sourced from this dict
+ # via Scheduler.get_ssd_cache_stats) can see the configured
+ # capacity. Omitting them silently hides every MRU panel.
+ "mru_partial_stashes": self._mru_partial_stashes,
+ "mru_partial_hits": self._mru_partial_hits,
+ "mru_partial_evictions": self._mru_partial_evictions,
+ "mru_partial_tokens_saved": self._mru_partial_tokens_saved,
+ "mru_partial_entries": len(self._mru_partials),
+ "mru_partial_max_entries": self._mru_partial_max_entries,
+ "mru_partial_supported": self._mru_partial_supported,
**paged_stats,
}
def reset_stats(self) -> None:
- """Reset statistics."""
+ """Reset statistics.
+
+ Note: ``_mru_partials`` is live state (the cache itself), not a
+ counter — its size is reported as a gauge by ``get_stats()`` and
+ is not affected by stats reset. Use ``clear_mru_partials()`` if
+ the entries themselves should be wiped.
+ """
self._hits = 0
self._misses = 0
self._tokens_saved = 0
self._partial_block_skips = 0
self._partial_tokens_skipped = 0
+ self._tokens_matched_total = 0
+ self._tokens_requested_total = 0
self._last_partial_tokens_skipped = 0
self._last_tokens_to_next_block = 0
+ self._mru_partial_stashes = 0
+ self._mru_partial_hits = 0
+ self._mru_partial_evictions = 0
+ self._mru_partial_tokens_saved = 0
self.paged_cache.reset_stats()
def clear(self) -> int:
@@ -2419,9 +3048,43 @@ def clear(self) -> int:
self._request_tables.clear()
self._prefix_index.clear()
self.paged_cache.clear()
+ # MRU partials chain from paged-block hashes; once the paged
+ # cache is wiped, no entry can be safely applied. Cache-corruption
+ # recovery (Scheduler._recover_from_cache_error) routes through
+ # here, so stale partials would otherwise survive exactly the
+ # recovery path that exists because something was wrong.
+ # No eviction-counter bump here: clear() also calls reset_stats()
+ # which zeros every counter (this is the "restart everything"
+ # path). Operators tracking partial wipes specifically should
+ # use clear_mru_partials() instead — that path leaves stats alone.
+ self._mru_partials.clear()
self.reset_stats()
return cleared_count
+ def clear_mru_partials(self) -> int:
+ """Wipe only the MRU partial cache, leaving paged blocks intact.
+
+ Intended consumer: admin-triggered cache-tier clears that drop
+ the backing block storage (``clear_ssd_cache`` admin endpoint,
+ and the future ``clear_hot_cache`` endpoint once PR #1183
+ lands). Without this hook, a stash whose ``parent_hash`` chains
+ from a paged block whose underlying KV bytes were just flushed
+ from the hot/SSD tier would survive in memory and waste a
+ reconstruct attempt on the next request before being naturally
+ evicted by LRU.
+
+ Distinct from ``clear()``: this method only drops MRU entries
+ and does not touch the paged cache, prefix index, or stats
+ (other than incrementing ``mru_partial_evictions``).
+
+ Returns:
+ Number of MRU entries that were wiped.
+ """
+ n = len(self._mru_partials)
+ self._mru_partials.clear()
+ self._mru_partial_evictions += n
+ return n
+
def set_cold_restore_callback(
self,
callback: Callable[[int, bytes], bool] | None,
diff --git a/omlx/cache/stats.py b/omlx/cache/stats.py
index 412074fc7..47f2749be 100644
--- a/omlx/cache/stats.py
+++ b/omlx/cache/stats.py
@@ -88,6 +88,26 @@ class PrefixCacheStats(BaseCacheStats):
block_size: int = 0
last_partial_tokens_skipped: int = 0
last_tokens_to_next_block: int = 0
+ tokens_matched_total: int = 0
+ tokens_requested_total: int = 0
+ # MRU partial cache observability. Cumulative counters that pair
+ # naturally — hits/stashes gives the "stash payoff" rate, evictions
+ # tracks churn, tokens_saved is the direct compute-saved measure.
+ # Gauges (entries / max_entries) describe live capacity utilisation.
+ mru_partial_stashes: int = 0
+ mru_partial_hits: int = 0
+ mru_partial_evictions: int = 0
+ mru_partial_tokens_saved: int = 0
+ mru_partial_entries: int = 0
+ mru_partial_max_entries: int = 0
+ # Tri-state eligibility flag for the MRU partial-block feature:
+ # None → unknown (default; no detection has fired yet)
+ # True → model has only sliceable cache layers
+ # False → at least one layer type is not in the sliceable whitelist,
+ # so every stash attempt is refused at the safety gate.
+ # Surfaces to the admin dashboard so per-model rows can show
+ # "N/A (see log)" instead of a misleading "0/N entries" gauge.
+ mru_partial_supported: bool | None = None
_total_queries: int = field(default=0, repr=False)
@property
@@ -111,6 +131,14 @@ def reset(self) -> None:
self.partial_tokens_skipped = 0
self.last_partial_tokens_skipped = 0
self.last_tokens_to_next_block = 0
+ self.tokens_matched_total = 0
+ self.tokens_requested_total = 0
+ self.mru_partial_stashes = 0
+ self.mru_partial_hits = 0
+ self.mru_partial_evictions = 0
+ self.mru_partial_tokens_saved = 0
+ # mru_partial_entries and mru_partial_max_entries are gauges
+ # populated by get_stats() from live state — not reset here.
self._total_queries = 0
diff --git a/omlx/cache/type_registry.py b/omlx/cache/type_registry.py
index 832c4a71e..add800abc 100644
--- a/omlx/cache/type_registry.py
+++ b/omlx/cache/type_registry.py
@@ -23,6 +23,33 @@
logger = logging.getLogger(__name__)
+# Authoritative whitelist of cache class names whose KV state is safe to
+# slice along the sequence axis (axis=2). Used by:
+# - scheduler.py: snapshot-skip gating, partial extraction
+# - prefix_cache.py: MRU partial stash/apply gate
+#
+# Important: this is NOT derivable from CacheTypeHandler.supports_block_slicing
+# alone, because DefaultCacheHandler (used as a fallback for class names with
+# no registered handler) inherits from KVCacheHandler and reports
+# supports_block_slicing=True. That makes Batch* and Pool* class names
+# silently sliceable via the registry, which is wrong. Code that needs to
+# decide "may I slice this layer's KV" MUST consult this whitelist by
+# class-name string, not the registry.
+KNOWN_SLICEABLE_CACHE_TYPES = frozenset(
+ {
+ "KVCache",
+ "BatchKVCache",
+ "QuantizedKVCache",
+ "TurboQuantKVCache",
+ "BatchTurboQuantKVCache",
+ # ChunkedKVCache is included once the batch=1 patch in scheduler.py
+ # installs its extract/filter/size pass-throughs (PR #1152); without
+ # that patch, Llama-4 requests fall back to the snapshot path.
+ "ChunkedKVCache",
+ }
+)
+
+
class CacheTypeRegistry:
"""Registry for cache type handlers.
diff --git a/omlx/cli.py b/omlx/cli.py
index cff502809..5fdec8905 100644
--- a/omlx/cli.py
+++ b/omlx/cli.py
@@ -222,6 +222,15 @@ def serve_command(args):
else:
scheduler_config.hot_cache_max_size = 0
+ # MRU partial cache: CLI arg > settings. Independent of SSD/hot
+ # configuration — the MRU stash works in any mode where
+ # reconstruct_cache has a path (which is gated by manager presence,
+ # not by SSD writes; see BlockAwarePrefixCache._can_reconstruct).
+ if args.mru_partial_max_entries is not None:
+ scheduler_config.mru_partial_max_entries = int(args.mru_partial_max_entries)
+ else:
+ scheduler_config.mru_partial_max_entries = settings.cache.mru_partial_max_entries
+
if args.no_cache:
print("Mode: Multi-model serving (no oMLX cache, mlx-lm BatchGenerator only)")
elif paged_ssd_cache_dir:
@@ -595,6 +604,17 @@ def main():
default=None,
help="Maximum in-memory hot cache size (e.g., '8GB', '4GB'). Default: 0 (disabled)",
)
+ serve_parser.add_argument(
+ "--mru-partial-max-entries",
+ type=int,
+ default=None,
+ help=(
+ "Maximum simultaneous MRU partial-block stashes. Each entry is "
+ "bounded at one block_size of KV memory. 0 disables the feature. "
+ "Default: 4. Under --hot-cache-only this shares the in-memory KV "
+ "headroom envelope with --hot-cache-max-size; tune both together."
+ ),
+ )
serve_parser.add_argument(
"--no-cache",
action="store_true",
diff --git a/omlx/scheduler.py b/omlx/scheduler.py
index fab20ed7b..89152ae78 100644
--- a/omlx/scheduler.py
+++ b/omlx/scheduler.py
@@ -37,6 +37,7 @@
from mlx_lm.models.cache import make_prompt_cache
from mlx_lm.sample_utils import make_logits_processors
+from .cache.observability import CacheRateTracker
from .cache.paged_cache import PagedCacheManager
from .cache.prefix_cache import BlockAwarePrefixCache
from .exceptions import is_cache_corruption_error
@@ -411,19 +412,15 @@ def _patched_ppb_prompt(self, tokens):
# Cache class names known to be sliceable (no boundary snapshots needed).
-# ChunkedKVCache is included once the batch=1 patch above installs its
-# extract/filter/size pass-throughs; without it Llama-4 requests fall
-# back to the snapshot path unnecessarily.
-_KNOWN_SLICEABLE_CACHE_TYPES = frozenset(
- {
- "KVCache",
- "BatchKVCache",
- "QuantizedKVCache",
- "TurboQuantKVCache",
- "BatchTurboQuantKVCache",
- "ChunkedKVCache",
- }
-)
+# Canonical home: omlx/cache/type_registry.py. ChunkedKVCache is included
+# there once the batch=1 patch above installs its extract/filter/size
+# pass-throughs (PR #1152); without it Llama-4 requests fall back to the
+# snapshot path unnecessarily.
+from omlx.cache.type_registry import KNOWN_SLICEABLE_CACHE_TYPES
+
+# Module-local alias kept for backwards compatibility with existing
+# call sites in this file.
+_KNOWN_SLICEABLE_CACHE_TYPES = KNOWN_SLICEABLE_CACHE_TYPES
def _prompt_cache_needs_snapshots(prompt_cache: list[Any]) -> bool:
@@ -552,6 +549,11 @@ class SchedulerConfig:
hot_cache_only: bool = False
paged_ssd_cache_max_size: int = 100 * 1024 * 1024 * 1024 # 100GB default
hot_cache_max_size: int = 0 # In-memory hot cache size in bytes (0 = disabled)
+ # Bounded LRU stash for trailing sub-block partials of previous
+ # prefills, keyed by parent-block hash. Each entry holds at most
+ # one block_size of KV memory. ``0`` disables the feature; default
+ # of 4 matches the dflash max_entries precedent (PR #1120).
+ mru_partial_max_entries: int = 4
# Model identification (for cache isolation between different models)
model_name: str = "" # OpenAI API model name (e.g., "mlx-community/Llama-3.2-3B")
@@ -781,6 +783,7 @@ def __init__(
self.paged_cache_manager: PagedCacheManager | None = None
self.block_aware_cache: BlockAwarePrefixCache | None = None
self.paged_ssd_cache_manager: PagedSSDCacheManager | None = None
+ self._cache_rate_tracker = CacheRateTracker()
self.memory_monitor: MemoryMonitor | None = None
# Initialize paged SSD cache if paged_ssd_cache_dir is specified
@@ -801,6 +804,7 @@ def __init__(
self.block_aware_cache = BlockAwarePrefixCache(
model=model,
paged_cache_manager=self.paged_cache_manager,
+ mru_partial_max_entries=self.config.mru_partial_max_entries,
)
# Initialize paged SSD cache
@@ -910,6 +914,16 @@ def __init__(
# None = no deferred clear pending; int = step at which to fire.
self._deferred_clear_at: int | None = None
+ # Per-completion budget for suppressing the deferred Metal cache
+ # clear when the prefix cache has a warm MRU partial (a strong
+ # signal that the same prompt may return immediately and would
+ # benefit from the still-resident lazy KV tensors).
+ # The budget is reset to True on every _cleanup_finished() that
+ # arms _deferred_clear_at. Once spent (suppression has happened
+ # once), the deferred clear fires at the next deadline regardless
+ # of MRU state, bounding total deferral at 2x _DEFERRED_CLEAR_DELAY.
+ self._mru_clear_suppression_available: bool = False
+
# Cache XTC special tokens (newline + EOS) — stable per tokenizer.
# Must be after _is_harmony_model / _generation_config_eos init
# since _get_xtc_special_tokens() delegates to _get_stop_tokens().
@@ -1025,6 +1039,7 @@ def _async_store_cache_worker(
extra_keys: tuple[Any, ...] | None,
extra_key_token_start: int | None,
extra_key_ranges: list[tuple[int, tuple[Any, ...]]] | None,
+ prompt_token_count: int,
) -> None:
"""Run store_cache + paged_cache cleanup off the inference thread.
@@ -1067,6 +1082,7 @@ def _async_store_cache_worker(
extra_keys=extra_keys,
extra_key_token_start=extra_key_token_start,
extra_key_ranges=extra_key_ranges,
+ prompt_token_count=prompt_token_count,
)
if block_table is None and self.paged_cache_manager is not None:
block_table = self.paged_cache_manager.get_block_table(request_id)
@@ -3305,6 +3321,38 @@ def add_request(self, request: Request) -> None:
request.remaining_tokens = request.prompt_token_ids[
block_table.num_tokens :
]
+ # Splice the in-memory MRU partial onto the reconstructed
+ # cache when the trailing tokens match. Saves the
+ # re-prefill of the sub-block tail on exact-repeat
+ # prompts. The splice is a no-op when no partial is
+ # stashed or the trailing tokens differ.
+ #
+ # Accounting note: the partial advances cached_tokens
+ # but is NOT a stored paged block, so shared_prefix_blocks
+ # stays at the count of paged blocks reused. After a
+ # successful splice the invariant relaxes from
+ # cached_tokens == shared_prefix_blocks * block_size
+ # to
+ # cached_tokens >= shared_prefix_blocks * block_size
+ # with cached_tokens - shared_prefix_blocks * block_size
+ # ∈ [0, block_size) representing the partial. Current
+ # readers (the scheduler's prefill-completion log lines
+ # downstream) tolerate the relaxed form; future readers
+ # that index block_table.block_ids by
+ # shared_prefix_blocks must NOT use cached_tokens to
+ # bound the loop.
+ if request.remaining_tokens:
+ (
+ request.prompt_cache,
+ request.remaining_tokens,
+ partial_applied,
+ ) = self.block_aware_cache.apply_mru_partial(
+ request.prompt_cache,
+ block_table,
+ request.remaining_tokens,
+ )
+ if partial_applied > 0:
+ request.cached_tokens += partial_applied
# For exact prefix hits we need cache state at (N-1) and the
# last prompt token as input to produce the first decode logit.
# Reusing cache state at N and feeding the last token again
@@ -3428,6 +3476,10 @@ def set_specprefill_draft_model(
model=draft_model,
paged_cache_manager=draft_paged,
paged_ssd_cache_manager=self.paged_ssd_cache_manager,
+ # MRU disabled on the draft cache: apply_mru_partial is
+ # only ever called on the main block_aware_cache, so a
+ # draft-cache stash is dead work that never pays off.
+ mru_partial_max_entries=0,
)
self._draft_prefix_cache.set_cold_restore_callback(
self._restore_block_from_cold
@@ -5137,6 +5189,10 @@ def _cleanup_finished(self, finished_ids: set[str]) -> None:
if pre_eval_arrays:
mx.async_eval(*pre_eval_arrays)
+ # Prompt boundary for the MRU partial stash:
+ # token_sequence_to_store is prompt+output, but a
+ # repeat request resubmits the prompt only.
+ prompt_token_count = len(request.prompt_token_ids)
if self._store_cache_executor is not None:
store_future = self._store_cache_executor.submit(
self._async_store_cache_worker,
@@ -5148,6 +5204,7 @@ def _cleanup_finished(self, finished_ids: set[str]) -> None:
request.vlm_extra_keys_for_cache,
request.vlm_extra_key_token_start_for_cache,
request.vlm_extra_key_ranges_for_cache,
+ prompt_token_count,
)
self._inflight_store_futures[request_id] = store_future
else:
@@ -5161,6 +5218,7 @@ def _cleanup_finished(self, finished_ids: set[str]) -> None:
request.vlm_extra_keys_for_cache,
request.vlm_extra_key_token_start_for_cache,
request.vlm_extra_key_ranges_for_cache,
+ prompt_token_count,
)
logger.debug(
f"Submitted async store_cache for {request_id} "
@@ -5295,6 +5353,17 @@ def _cleanup_finished(self, finished_ids: set[str]) -> None:
# finished their completeMemory() callbacks (#557).
target = self._step_counter + self._DEFERRED_CLEAR_DELAY
if self._deferred_clear_at is None or target > self._deferred_clear_at:
+ # Arm the suppression budget only when STARTING a new
+ # deferral epoch (transition from None). Subsequent
+ # completions in the same epoch may legitimately push the
+ # deadline out (for IOKit safety, #557) but must NOT
+ # refresh the budget — otherwise a hot-prompt workload
+ # whose completions arrive faster than _DEFERRED_CLEAR_DELAY
+ # could keep re-arming after the budget was spent and
+ # defer the clear forever, defeating the pool-bloat
+ # mitigation (#411). One suppression per epoch, total.
+ if self._deferred_clear_at is None:
+ self._mru_clear_suppression_available = True
self._deferred_clear_at = target
def _is_cache_corruption_error(self, error: Exception) -> bool:
@@ -5322,6 +5391,7 @@ def _recover_from_cache_error(self) -> None:
# Clear caches
if self.block_aware_cache is not None:
self.block_aware_cache.clear()
+ self._cache_rate_tracker.clear()
# Clear UID mappings
self.request_id_to_uid.clear()
@@ -5329,6 +5399,7 @@ def _recover_from_cache_error(self) -> None:
# Cancel any pending deferred Metal cache clear
self._deferred_clear_at = None
+ self._mru_clear_suppression_available = False
# Clear detokenizer state to prevent contamination after recovery
self._request_detokenizers.clear()
@@ -5560,8 +5631,26 @@ def step(self) -> SchedulerOutput:
self._deferred_clear_at is not None
and self._step_counter >= self._deferred_clear_at
):
- should_clear = True
- self._deferred_clear_at = None
+ # If the prefix cache is holding a warm MRU partial and we
+ # haven't yet spent this deferral epoch's suppression budget,
+ # defer the clear by one more _DEFERRED_CLEAR_DELAY window.
+ # The MRU partial is a strong predictor that the next
+ # request will reuse the still-resident lazy KV tensors.
+ # The budget is one-shot per epoch, so suppression adds at most
+ # one extra _DEFERRED_CLEAR_DELAY even under hot-prompt repeats.
+ if (
+ self._mru_clear_suppression_available
+ and self.block_aware_cache is not None
+ and self.block_aware_cache.has_mru_partial()
+ ):
+ self._deferred_clear_at = (
+ self._step_counter + self._DEFERRED_CLEAR_DELAY
+ )
+ self._mru_clear_suppression_available = False
+ else:
+ should_clear = True
+ self._deferred_clear_at = None
+ self._mru_clear_suppression_available = False
if should_clear:
_sync_and_clear_cache()
if (
@@ -5651,6 +5740,7 @@ def reset(self) -> None:
# Clear caches
if self.block_aware_cache is not None:
self.block_aware_cache.clear()
+ self._cache_rate_tracker.clear()
# Clear detokenizers
self._request_detokenizers.clear()
@@ -5660,6 +5750,7 @@ def reset(self) -> None:
# Cancel any pending deferred Metal cache clear
self._deferred_clear_at = None
+ self._mru_clear_suppression_available = False
def deep_reset(self) -> None:
"""
@@ -6083,6 +6174,41 @@ def restore_cold_blocks_for_request(self, request_id: str) -> int:
return verified
+ def _collect_cache_counters(self) -> dict[str, int] | None:
+ if self.block_aware_cache is None:
+ return None
+
+ prefix_stats = self.block_aware_cache.get_stats()
+ counters = {
+ "prefix_hits": prefix_stats.hits,
+ "prefix_misses": prefix_stats.misses,
+ "prefix_tokens_matched": prefix_stats.tokens_matched_total,
+ "prefix_tokens_requested": prefix_stats.tokens_requested_total,
+ "prefix_tokens_saved": prefix_stats.tokens_saved,
+ "evictions": prefix_stats.evictions,
+ "mru_partial_stashes": prefix_stats.mru_partial_stashes,
+ "mru_partial_hits": prefix_stats.mru_partial_hits,
+ "mru_partial_evictions": prefix_stats.mru_partial_evictions,
+ "mru_partial_tokens_saved": prefix_stats.mru_partial_tokens_saved,
+ "mru_partial_entries": prefix_stats.mru_partial_entries,
+ "mru_partial_max_entries": prefix_stats.mru_partial_max_entries,
+ }
+
+ if self.paged_ssd_cache_manager is not None:
+ ssd = self.paged_ssd_cache_manager.get_stats()
+ hot_hits = ssd.hot_cache_hits
+ total_loads = ssd.loads
+ counters.update({
+ "ssd_hot_hits": hot_hits,
+ "ssd_disk_loads": max(0, total_loads - hot_hits),
+ "ssd_saves": ssd.saves,
+ "ssd_errors": ssd.errors,
+ "hot_cache_evictions": ssd.hot_cache_evictions,
+ "hot_cache_promotions": ssd.hot_cache_promotions,
+ })
+
+ return counters
+
def get_ssd_cache_stats(self) -> dict[str, Any] | None:
"""Get paged SSD + prefix cache observability statistics."""
stats = {}
@@ -6091,15 +6217,18 @@ def get_ssd_cache_stats(self) -> dict[str, Any] | None:
stats["ssd_cache"] = self.paged_ssd_cache_manager.get_stats()
if self.paged_cache_manager is not None:
- # In paged SSD-only mode, all cache data is on paged SSD
stats["indexed_blocks"] = self.paged_cache_manager.cold_block_count
stats["block_size"] = self.config.paged_cache_block_size
if self.block_aware_cache is not None:
- # Expose prefix-cache observability so UI can distinguish
- # "0 indexed blocks" from "sub-block cached ( Path:
"""
@@ -295,6 +300,7 @@ def to_dict(self) -> dict[str, Any]:
"ssd_cache_max_size": self.ssd_cache_max_size,
"hot_cache_max_size": self.hot_cache_max_size,
"initial_cache_blocks": self.initial_cache_blocks,
+ "mru_partial_max_entries": self.mru_partial_max_entries,
}
@classmethod
@@ -307,6 +313,7 @@ def from_dict(cls, data: dict[str, Any]) -> CacheSettings:
ssd_cache_max_size=data.get("ssd_cache_max_size", "auto"),
hot_cache_max_size=data.get("hot_cache_max_size", "0"),
initial_cache_blocks=data.get("initial_cache_blocks", 256),
+ mru_partial_max_entries=data.get("mru_partial_max_entries", 4),
)
diff --git a/tests/test_admin_api_key.py b/tests/test_admin_api_key.py
index 8766b39ce..eb03106b6 100644
--- a/tests/test_admin_api_key.py
+++ b/tests/test_admin_api_key.py
@@ -655,6 +655,7 @@ def test_runtime_cache_uses_model_scoped_ssd_stats(self):
mock_settings = MagicMock()
mock_settings.base_path = Path("/tmp/omlx-base")
mock_settings.cache.get_ssd_cache_dir.return_value = cache_dir
+ mock_settings.cache.get_ssd_cache_max_size_bytes.return_value = 0
shared_ssd_stats = {
"num_files": 999,
@@ -744,9 +745,13 @@ def test_runtime_cache_uses_model_scoped_ssd_stats(self):
"last_tokens_to_next_block": 0,
"num_files": 3,
"total_size_bytes": 4096,
+ "max_size_bytes": 0,
"hot_cache_max_bytes": 0,
"hot_cache_size_bytes": 0,
"hot_cache_entries": 0,
+ "mru_partial_entries": 0,
+ "mru_partial_max_entries": 0,
+ "mru_partial_supported": None,
},
{
"id": "model-b",
@@ -760,9 +765,13 @@ def test_runtime_cache_uses_model_scoped_ssd_stats(self):
"last_tokens_to_next_block": 0,
"num_files": 7,
"total_size_bytes": 8192,
+ "max_size_bytes": 0,
"hot_cache_max_bytes": 0,
"hot_cache_size_bytes": 0,
"hot_cache_entries": 0,
+ "mru_partial_entries": 0,
+ "mru_partial_max_entries": 0,
+ "mru_partial_supported": None,
},
]
manager_a.get_stats_for_model.assert_called_once_with("/models/model-a")
@@ -775,6 +784,7 @@ def test_runtime_cache_ignores_single_model_stats_failure(self):
mock_settings = MagicMock()
mock_settings.base_path = Path("/tmp/omlx-base")
mock_settings.cache.get_ssd_cache_dir.return_value = cache_dir
+ mock_settings.cache.get_ssd_cache_max_size_bytes.return_value = 0
bad_scheduler = MagicMock()
bad_scheduler.get_ssd_cache_stats.side_effect = RuntimeError("boom")
@@ -833,6 +843,7 @@ def test_runtime_cache_marks_sub_block_cached_when_indexed_blocks_zero(self):
mock_settings = MagicMock()
mock_settings.base_path = Path("/tmp/omlx-base")
mock_settings.cache.get_ssd_cache_dir.return_value = cache_dir
+ mock_settings.cache.get_ssd_cache_max_size_bytes.return_value = 0
scheduler = MagicMock()
scheduler.get_ssd_cache_stats.return_value = {
@@ -876,6 +887,61 @@ def test_runtime_cache_marks_sub_block_cached_when_indexed_blocks_zero(self):
assert model_payload["last_partial_tokens_skipped"] == 577
+class TestClearSSDCacheAlsoWipesMRUPartials:
+ """The ``/api/ssd-cache/clear`` admin endpoint must also clear the
+ MRU partial cache. Without this, partials chain from paged-block
+ hashes whose backing KV bytes were just flushed by ssd_manager.clear()
+ and the operator's "drop all warm caches" intent is only half-honored.
+ """
+
+ def _scheduler_with_mocks(self):
+ """Build a SimpleNamespace scheduler with mock ssd manager and
+ block_aware_cache that we can introspect after the endpoint runs."""
+ ssd_manager = MagicMock()
+ ssd_manager.clear.return_value = 5 # arbitrary deleted count
+ block_aware_cache = MagicMock()
+ block_aware_cache.clear_mru_partials.return_value = 3
+ return SimpleNamespace(
+ paged_ssd_cache_manager=ssd_manager,
+ block_aware_cache=block_aware_cache,
+ )
+
+ def test_endpoint_calls_clear_mru_partials_on_each_scheduler(self):
+ scheduler_a = self._scheduler_with_mocks()
+ scheduler_b = self._scheduler_with_mocks()
+
+ with patch.object(
+ admin_routes,
+ "_iter_loaded_schedulers",
+ return_value=iter([("model-a", scheduler_a), ("model-b", scheduler_b)]),
+ ), patch.object(admin_routes, "_get_global_settings", return_value=None):
+ asyncio.run(admin_routes.clear_ssd_cache(is_admin=True))
+
+ scheduler_a.paged_ssd_cache_manager.clear.assert_called_once()
+ scheduler_a.block_aware_cache.clear_mru_partials.assert_called_once()
+ scheduler_b.paged_ssd_cache_manager.clear.assert_called_once()
+ scheduler_b.block_aware_cache.clear_mru_partials.assert_called_once()
+
+ def test_mru_clear_failure_does_not_block_other_scheduler(self):
+ """A failure in one scheduler's clear_mru_partials must be logged
+ but not prevent other schedulers from being cleared."""
+ scheduler_a = self._scheduler_with_mocks()
+ scheduler_a.block_aware_cache.clear_mru_partials.side_effect = RuntimeError(
+ "boom"
+ )
+ scheduler_b = self._scheduler_with_mocks()
+
+ with patch.object(
+ admin_routes,
+ "_iter_loaded_schedulers",
+ return_value=iter([("model-a", scheduler_a), ("model-b", scheduler_b)]),
+ ), patch.object(admin_routes, "_get_global_settings", return_value=None):
+ asyncio.run(admin_routes.clear_ssd_cache(is_admin=True))
+
+ # Scheduler B must still have been cleared despite A's failure.
+ scheduler_b.block_aware_cache.clear_mru_partials.assert_called_once()
+
+
class TestGlobalSettingsValidation:
"""Tests for stricter GlobalSettingsRequest validation."""
diff --git a/tests/test_cache_observability.py b/tests/test_cache_observability.py
new file mode 100644
index 000000000..d0da07dca
--- /dev/null
+++ b/tests/test_cache_observability.py
@@ -0,0 +1,241 @@
+# tests/test_cache_observability.py
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for cache observability module."""
+
+import threading
+import time
+from unittest.mock import patch
+
+import pytest
+
+from omlx.cache.observability import CacheRateTracker
+
+
+# Cumulative counter keys from Scheduler._collect_cache_counters (it also
+# emits mru_partial_entries / mru_partial_max_entries, omitted here). New
+# counters go here and in the production code in lockstep; explicit kwargs
+# would duplicate the schema in both the signature and the dict construction.
+_COUNTER_KEYS = (
+ "prefix_hits",
+ "prefix_misses",
+ "prefix_tokens_matched",
+ "prefix_tokens_requested",
+ "prefix_tokens_saved",
+ "evictions",
+ "ssd_hot_hits",
+ "ssd_disk_loads",
+ "ssd_saves",
+ "ssd_errors",
+ "hot_cache_evictions",
+ "hot_cache_promotions",
+ "mru_partial_stashes",
+ "mru_partial_hits",
+ "mru_partial_evictions",
+ "mru_partial_tokens_saved",
+)
+
+
+def _make_counters(**overrides):
+ """Build a counter dict for snapshot testing.
+
+ All keys default to ``0``; override individual values via kwargs.
+ Unknown keys raise ``ValueError``, preserving the typo-catching that the
+ explicit-signature form used to provide.
+ """
+ unknown = set(overrides) - set(_COUNTER_KEYS)
+ if unknown:
+ raise ValueError(
+ f"Unknown counter keys: {sorted(unknown)}. "
+ f"Known: {sorted(_COUNTER_KEYS)}"
+ )
+ counters = {k: 0 for k in _COUNTER_KEYS}
+ counters.update(overrides)
+ return counters
+
+
+class TestCacheRateTrackerSnapshot:
+
+ def test_empty_tracker_returns_empty_rates(self):
+ tracker = CacheRateTracker()
+ result = tracker.get_rates()
+ assert result == {"windows": {}, "cumulative": {}}
+
+ def test_first_snapshot_always_accepted(self):
+ tracker = CacheRateTracker(min_interval=10.0)
+ assert tracker.maybe_snapshot(_make_counters()) is True
+
+ def test_snapshot_rejected_within_min_interval(self):
+ tracker = CacheRateTracker(min_interval=10.0)
+ tracker.maybe_snapshot(_make_counters())
+ assert tracker.maybe_snapshot(_make_counters()) is False
+
+ def test_snapshot_accepted_after_min_interval(self):
+ tracker = CacheRateTracker(min_interval=0.0)
+ tracker.maybe_snapshot(_make_counters())
+ assert tracker.maybe_snapshot(_make_counters()) is True
+
+ def test_deque_overflow_evicts_oldest(self):
+ tracker = CacheRateTracker(max_snapshots=3, min_interval=0.0)
+ for i in range(5):
+ tracker.maybe_snapshot(_make_counters(prefix_hits=i))
+ result = tracker.get_rates()
+ assert result["cumulative"]["prefix_hits"] == 4
+
+
+class TestCacheRateTrackerRates:
+
+ def _tracker_with_two_snapshots(self, old_counters, new_counters, elapsed=60.0):
+ tracker = CacheRateTracker(min_interval=0.0)
+ fake_time = [1000.0]
+
+ def mock_monotonic():
+ return fake_time[0]
+
+ with patch("omlx.cache.observability.time.monotonic", side_effect=mock_monotonic):
+ tracker.maybe_snapshot(old_counters)
+
+ fake_time[0] = 1000.0 + elapsed
+ with patch("omlx.cache.observability.time.monotonic", side_effect=mock_monotonic):
+ tracker.maybe_snapshot(new_counters)
+
+ with patch("omlx.cache.observability.time.monotonic", return_value=fake_time[0]):
+ return tracker.get_rates(windows=(60, 300, 900))
+
+ def test_steady_state_prefix_hit_rate(self):
+ old = _make_counters(prefix_hits=100, prefix_misses=50)
+ new = _make_counters(prefix_hits=200, prefix_misses=75)
+ result = self._tracker_with_two_snapshots(old, new, elapsed=60.0)
+ assert result["windows"]["1m"]["prefix_hit_rate"] == 0.8
+
+ def test_zero_activity_window_no_nan(self):
+ counters = _make_counters(prefix_hits=50, prefix_misses=10)
+ result = self._tracker_with_two_snapshots(counters, counters, elapsed=60.0)
+ assert result["windows"]["1m"]["prefix_hit_rate"] == 0.0
+ assert result["windows"]["1m"]["prefix_match_efficiency"] == 0.0
+ assert result["windows"]["1m"]["eviction_rate_per_min"] == 0.0
+
+ def test_eviction_rate_per_min(self):
+ old = _make_counters(evictions=10)
+ new = _make_counters(evictions=40)
+ result = self._tracker_with_two_snapshots(old, new, elapsed=300.0)
+ assert result["windows"]["5m"]["eviction_rate_per_min"] == 6.0
+
+ def test_prefix_match_efficiency(self):
+ old = _make_counters(prefix_tokens_matched=0, prefix_tokens_requested=0)
+ new = _make_counters(prefix_tokens_matched=600, prefix_tokens_requested=1000)
+ result = self._tracker_with_two_snapshots(old, new, elapsed=60.0)
+ assert result["windows"]["1m"]["prefix_match_efficiency"] == 0.6
+
+ def test_ssd_hot_rate(self):
+ old = _make_counters(ssd_hot_hits=0, ssd_disk_loads=0)
+ new = _make_counters(ssd_hot_hits=80, ssd_disk_loads=20)
+ result = self._tracker_with_two_snapshots(old, new, elapsed=60.0)
+ assert result["windows"]["1m"]["ssd_hot_rate"] == 0.8
+
+ def test_mru_partial_hit_rate(self):
+ """Stash payoff = hits / stashes. Workload that stashed 100 and
+ only got 75 hits has rate 0.75; high enough to justify the
+ feature, low enough that a smaller capacity might do."""
+ old = _make_counters(mru_partial_stashes=0, mru_partial_hits=0)
+ new = _make_counters(mru_partial_stashes=100, mru_partial_hits=75)
+ result = self._tracker_with_two_snapshots(old, new, elapsed=60.0)
+ assert result["windows"]["1m"]["mru_partial_hit_rate"] == 0.75
+ assert result["cumulative"]["mru_partial_hit_rate"] == 0.75
+
+ def test_mru_partial_zero_stashes_no_nan(self):
+ """If no stashes happened in the window, hit_rate must be 0.0
+ not NaN. Mirrors the prefix_hit_rate empty-window guard."""
+ counters = _make_counters(mru_partial_stashes=0, mru_partial_hits=0)
+ result = self._tracker_with_two_snapshots(counters, counters, elapsed=60.0)
+ assert result["windows"]["1m"]["mru_partial_hit_rate"] == 0.0
+ assert result["cumulative"]["mru_partial_hit_rate"] == 0.0
+
+ def test_mru_partial_tokens_saved_delta(self):
+ """Tokens saved is the direct compute-saved measure; it
+ accumulates regardless of hit rate."""
+ old = _make_counters(mru_partial_tokens_saved=0)
+ new = _make_counters(mru_partial_tokens_saved=12345)
+ result = self._tracker_with_two_snapshots(old, new, elapsed=60.0)
+ assert result["windows"]["1m"]["mru_partial_tokens_saved"] == 12345
+ assert result["cumulative"]["mru_partial_tokens_saved"] == 12345
+
+ def test_insufficient_data_returns_empty_window(self):
+ tracker = CacheRateTracker(min_interval=0.0)
+
+ with patch("omlx.cache.observability.time.monotonic", return_value=1000.0):
+ tracker.maybe_snapshot(_make_counters(prefix_hits=10))
+
+ with patch("omlx.cache.observability.time.monotonic", return_value=1000.5):
+ tracker.maybe_snapshot(_make_counters(prefix_hits=20))
+
+ with patch("omlx.cache.observability.time.monotonic", return_value=1000.5):
+ result = tracker.get_rates(windows=(60,))
+ assert result["windows"]["1m"] == {}
+
+ def test_cumulative_uses_latest_snapshot(self):
+ old = _make_counters(prefix_hits=10, prefix_misses=5)
+ new = _make_counters(prefix_hits=100, prefix_misses=20)
+ result = self._tracker_with_two_snapshots(old, new, elapsed=60.0)
+ assert result["cumulative"]["prefix_hits"] == 100
+ assert result["cumulative"]["prefix_misses"] == 20
+ assert abs(result["cumulative"]["prefix_hit_rate"] - 0.8333) < 0.001
+
+
+class TestCacheRateTrackerSnapshotAndGetRates:
+
+ def test_combines_snapshot_and_rates(self):
+ tracker = CacheRateTracker(min_interval=0.0)
+
+ with patch("omlx.cache.observability.time.monotonic", return_value=1000.0):
+ tracker.maybe_snapshot(_make_counters(prefix_hits=0))
+
+ with patch("omlx.cache.observability.time.monotonic", return_value=1060.0):
+ result = tracker.snapshot_and_get_rates(
+ _make_counters(prefix_hits=80, prefix_misses=20)
+ )
+
+ assert result["windows"]["1m"]["prefix_hit_rate"] == 0.8
+ assert result["cumulative"]["prefix_hits"] == 80
+
+
+class TestCacheRateTrackerThreadSafety:
+
+ def test_concurrent_snapshot_and_read(self):
+ tracker = CacheRateTracker(min_interval=0.0)
+ errors = []
+ stop = threading.Event()
+
+ def writer():
+ i = 0
+ while not stop.is_set():
+ try:
+ tracker.maybe_snapshot(_make_counters(prefix_hits=i))
+ i += 1
+ except Exception as e:
+ errors.append(e)
+
+ def reader():
+ while not stop.is_set():
+ try:
+ tracker.get_rates()
+ except Exception as e:
+ errors.append(e)
+
+ threads = [threading.Thread(target=writer), threading.Thread(target=reader)]
+ for t in threads:
+ t.start()
+ time.sleep(0.2)
+ stop.set()
+ for t in threads:
+ t.join(timeout=2.0)
+
+ assert errors == [], f"Thread errors: {errors}"
+
+
+class TestCacheRateTrackerClear:
+
+ def test_clear_resets_state(self):
+ tracker = CacheRateTracker(min_interval=0.0)
+ tracker.maybe_snapshot(_make_counters(prefix_hits=100))
+ tracker.clear()
+ assert tracker.get_rates() == {"windows": {}, "cumulative": {}}
diff --git a/tests/test_prefix_cache.py b/tests/test_prefix_cache.py
index ded2395f5..42d4e5d19 100644
--- a/tests/test_prefix_cache.py
+++ b/tests/test_prefix_cache.py
@@ -6,6 +6,7 @@
PagedCacheManager for block-based storage with SSD persistence.
"""
+import logging
import time
from pathlib import Path
from typing import Any, Dict, List, Optional
@@ -20,7 +21,11 @@
PagedCacheManager,
compute_block_hash,
)
-from omlx.cache.prefix_cache import BlockAwarePrefixCache, BlockCacheEntry
+from omlx.cache.prefix_cache import (
+ BlockAwarePrefixCache,
+ BlockCacheEntry,
+ _MRUPartialBlock,
+)
from omlx.cache.stats import PrefixCacheStats
@@ -2413,3 +2418,1752 @@ def test_store_cache_last_block_with_snapshot_uses_snapshot_meta(self, mx):
f"Last block should use snapshot offset=8, not shared offset=11, "
f"got {b2_meta[1]}"
)
+
+
+def _get_mru_partial(cache, parent_hash):
+ """Test-only accessor for one MRU partial entry by parent_hash key.
+
+ Lives in the test module rather than on ``BlockAwarePrefixCache``
+ itself: production code has no consumer that looks up entries by
+ arbitrary key (the scheduler only needs the boolean predicate via
+ ``has_mru_partial()`` and the dict lookup inside ``apply_mru_partial``).
+ Tests use this helper to assert on individual entries without
+ coupling to the internal ``_mru_partials`` container shape.
+ """
+ return cache._mru_partials.get(parent_hash)
+
+
+def _layer(mx, n_tokens, *, class_name="KVCache", head_dim=4, n_kv_heads=1, fill=1.0):
+ """Build a layer-state dict for store_cache.
+
+ ``class_name`` selects the cache type (e.g. ``"KVCache"``,
+ ``"RotatingKVCache"``, ``"BatchRotatingKVCache"``). Both
+ ``cache_type`` and ``class_name`` keys are populated with the same
+ string — ``store_cache`` consults whichever the layer provides.
+ """
+ return {
+ "state": (
+ mx.full((1, n_kv_heads, n_tokens, head_dim), fill),
+ mx.full((1, n_kv_heads, n_tokens, head_dim), fill),
+ ),
+ "cache_type": class_name,
+ "class_name": class_name,
+ }
+
+
+def _kv_layer(mx, n_tokens, head_dim=4, n_kv_heads=1, fill=1.0):
+ return _layer(
+ mx, n_tokens,
+ class_name="KVCache",
+ head_dim=head_dim, n_kv_heads=n_kv_heads, fill=fill,
+ )
+
+
+def _rotating_layer(mx, n_tokens, head_dim=4, n_kv_heads=1):
+ return _layer(
+ mx, n_tokens,
+ class_name="RotatingKVCache",
+ head_dim=head_dim, n_kv_heads=n_kv_heads,
+ )
+
+
+def _make_reconstructed_cache(mx, n_layers, n_tokens, head_dim=4):
+ """Build a list of MockKVCache objects matching what reconstruct_cache
+ would produce: keys.shape[2] == offset, valid region only."""
+ class MockKVCache:
+ def __init__(self, k, v, offset):
+ self.keys = k
+ self.values = v
+ self.offset = offset
+
+ return [
+ MockKVCache(
+ mx.ones((1, 1, n_tokens, head_dim)),
+ mx.ones((1, 1, n_tokens, head_dim)),
+ n_tokens,
+ )
+ for _ in range(n_layers)
+ ]
+
+
+def _make_mru_cache(paged_cache, mock_ssd, max_entries=4, num_layers=4):
+ """Construct a ``BlockAwarePrefixCache`` with a custom MRU capacity."""
+ return BlockAwarePrefixCache(
+ model=MockModel(num_layers=num_layers),
+ paged_cache_manager=paged_cache,
+ paged_ssd_cache_manager=mock_ssd,
+ mru_partial_max_entries=max_entries,
+ )
+
+
+def _stash_with_prefix(cache, mx, prefix_marker, tail_token):
+ """Stash a partial under a distinct parent_hash for multi-slot tests.
+
+ Builds a prompt whose first 4 tokens are unique to ``prefix_marker``
+ (forcing a unique parent block hash) and whose 5th token is the
+ partial tail. Returns ``(block_table, parent_hash)``.
+ """
+ tokens = [prefix_marker * 10 + i for i in range(4)] + [tail_token]
+ cache_data = [_kv_layer(mx, 5) for _ in range(4)]
+ block_table = cache.store_cache(f"req-{prefix_marker}", tokens, cache_data)
+ parent_hash = cache.paged_cache.allocated_blocks[
+ block_table.block_ids[-1]
+ ].block_hash
+ return block_table, parent_hash
+
+
+class TestMRUPartialBlockCache:
+ """Tests for the MRU partial block cache.
+
+ The MRU is a bounded LRU dict of trailing sub-block tails keyed by
+ ``parent_hash``. It lets exact-repeat requests skip re-prefilling
+ those tail tokens, and tolerates interleaving (multi-user / multi-
+ conversation workloads) by keeping multiple distinct-prefix entries
+ coexistent up to ``mru_partial_max_entries``.
+
+ Threat-model coverage these tests enforce:
+
+ - **Hybrid refusal:** when any layer is non-sliceable (RotatingKVCache,
+ ArraysCache, etc.), the stash is suppressed entirely. Splicing into
+ only the sliceable layers would create per-layer offset skew at
+ decode time.
+ - **Transactional splice:** if any layer's concatenate fails, no layer
+ sees a mutated keys/values/offset. Half-mutated caches are silent
+ generation corruption.
+ - **Real round-trip:** ``store_cache`` populates entries via the
+ production extraction path; ``apply_mru_partial`` then splices.
+ Tests do not hand-build ``_MRUPartialBlock`` objects for splice
+ cases — that hides the extraction-vs-apply boundary the original
+ single-slot branch's tests missed.
+ """
+
+ @pytest.fixture
+ def mx(self):
+ try:
+ import mlx.core as mx
+ return mx
+ except ImportError:
+ pytest.skip("MLX not available")
+
+ @pytest.fixture
+ def paged_cache(self):
+ return PagedCacheManager(
+ block_size=4,
+ max_blocks=100,
+ model_name="test-model",
+ initial_blocks=100,
+ )
+
+ @pytest.fixture
+ def mock_ssd(self):
+ """An SSD manager mock — present, not used.
+
+ The MRU stash gates on ``paged_ssd_cache is not None`` because in
+ the no-SSD configuration ``reconstruct_cache`` returns ``None`` and
+ ``apply_mru_partial`` is unreachable; stashing then would only
+ produce dead memory. Tests that exercise stash/apply directly
+ need an SSD instance present even though the mocked save/load
+ paths are not exercised.
+ """
+ mock = MagicMock()
+ mock.save_block.return_value = True
+ mock.load_block.return_value = None
+ mock.load_block_with_metadata.return_value = (None, None)
+ mock.has_block.return_value = False
+ return mock
+
+ @pytest.fixture
+ def prefix_cache(self, paged_cache, mock_ssd):
+ return BlockAwarePrefixCache(
+ model=MockModel(num_layers=4),
+ paged_cache_manager=paged_cache,
+ paged_ssd_cache_manager=mock_ssd,
+ )
+
+ # --- initial state ---
+
+ def test_init_state_empty(self, prefix_cache):
+ assert not prefix_cache._mru_partials
+ assert prefix_cache.has_mru_partial() is False
+
+ # --- stash semantics on uniformly sliceable layers ---
+
+ def test_stash_after_store_with_trailing_tokens(self, prefix_cache, mx):
+ """6 tokens, block_size=4 → 1 full block + 2 trailing → stash captured."""
+ tokens = [10, 20, 30, 40, 50, 60]
+ cache_data = [_kv_layer(mx, 6) for _ in range(4)]
+
+ block_table = prefix_cache.store_cache("req-stash", tokens, cache_data)
+
+ parent_hash = prefix_cache.paged_cache.allocated_blocks[
+ block_table.block_ids[-1]
+ ].block_hash
+ partial = _get_mru_partial(prefix_cache, parent_hash)
+ assert partial is not None
+ assert partial.tokens == [50, 60]
+ assert len(partial.kv_data) == 4
+ assert prefix_cache.has_mru_partial() is True
+
+ def test_no_stash_when_block_aligned(self, prefix_cache, mx):
+ """Block-aligned tokens leave no trailing partial → no entry written."""
+ tokens = [10, 20, 30, 40]
+ cache_data = [_kv_layer(mx, 4) for _ in range(4)]
+
+ prefix_cache.store_cache("req-aligned", tokens, cache_data)
+
+ assert not prefix_cache._mru_partials
+ assert prefix_cache.has_mru_partial() is False
+
+ def test_same_prefix_store_replaces_entry(self, prefix_cache, mx):
+ """Same prefix → same parent_hash → same dict key → replace.
+
+ Two stores with identical prefix tokens but different tails
+ collide on the same key (parent_hash chains from identical
+ prefix blocks). The newer tail replaces the older one in the
+ single dict entry — that is correct LRU put behavior.
+ """
+ for tail in (50, 99):
+ tokens = [10, 20, 30, 40, tail]
+ cache_data = [_kv_layer(mx, 5) for _ in range(4)]
+ prefix_cache.store_cache(f"req-{tail}", tokens, cache_data)
+
+ # Exactly one entry; its tokens are the latest tail.
+ assert len(prefix_cache._mru_partials) == 1
+ partial = next(iter(prefix_cache._mru_partials.values()))
+ assert partial.tokens == [99]
+
+ def test_no_eligible_tail_does_not_evict_siblings(
+ self, prefix_cache, mx
+ ):
+ """Behavioral change vs single-slot: a block-aligned store (no
+ trailing tail) MUST NOT wipe sibling entries from other prefixes.
+
+ Single-slot mode used to clear the lone slot in this branch.
+ Multi-slot mode treats "nothing eligible to stash this time"
+ as a local signal — sibling entries for distinct prefixes are
+ unrelated and stay.
+ """
+ # First: stash a partial via prefix A.
+ prefix_cache.store_cache(
+ "req-a", [10, 20, 30, 40, 50], [_kv_layer(mx, 5) for _ in range(4)]
+ )
+ assert len(prefix_cache._mru_partials) == 1
+ before_key = next(iter(prefix_cache._mru_partials.keys()))
+
+ # Second: block-aligned store on a DIFFERENT prefix — no tail to
+ # stash, but must not evict the existing sibling.
+ prefix_cache.store_cache(
+ "req-b", [11, 22, 33, 44], [_kv_layer(mx, 4) for _ in range(4)]
+ )
+ assert len(prefix_cache._mru_partials) == 1
+ assert next(iter(prefix_cache._mru_partials.keys())) == before_key
+
+ def test_stash_records_parent_hash_from_last_block(self, prefix_cache, mx):
+ """Stashed entry is keyed by the hash of the last full block."""
+ tokens = [10, 20, 30, 40, 50, 60]
+ cache_data = [_kv_layer(mx, 6) for _ in range(4)]
+
+ block_table = prefix_cache.store_cache("req-hash", tokens, cache_data)
+
+ # The last (and only) block's hash should be the dict key AND
+ # the partial's stored parent_hash.
+ last_block = prefix_cache.paged_cache.allocated_blocks[
+ block_table.block_ids[-1]
+ ]
+ assert last_block.block_hash is not None
+ partial = _get_mru_partial(prefix_cache, last_block.block_hash)
+ assert partial is not None
+ assert partial.parent_hash == last_block.block_hash
+
+ # --- threat model: hybrid refusal (B1, B2) ---
+
+ def test_refuse_stash_when_any_layer_non_sliceable_hybrid(
+ self, paged_cache, mock_ssd, mx
+ ):
+ """Hybrid model (KVCache + RotatingKVCache): no stash.
+
+ Splicing into only the sliceable layers produces per-layer offset
+ skew at decode time (review B2). The only correct behavior is to
+ refuse the partial entirely for hybrid models.
+ """
+ from omlx.cache.hybrid_cache import ModelCacheConfig
+
+ cache = BlockAwarePrefixCache(
+ model=MockModel(num_layers=2),
+ paged_cache_manager=paged_cache,
+ paged_ssd_cache_manager=mock_ssd,
+ )
+ tokens = [10, 20, 30, 40, 50, 60]
+ cache_data = [
+ _kv_layer(mx, 6),
+ _rotating_layer(mx, 6),
+ ]
+ config = ModelCacheConfig.from_type_list(
+ ["KVCache", "RotatingKVCache"], model_name="test"
+ )
+
+ cache.store_cache("req-hybrid", tokens, cache_data, model_cache_config=config)
+
+ assert not cache._mru_partials
+ assert cache.has_mru_partial() is False
+
+ def test_refuse_stash_when_all_layers_non_sliceable(self, prefix_cache, mx):
+ """Pure RotatingKVCache model also refuses stash."""
+ from omlx.cache.hybrid_cache import ModelCacheConfig
+
+ tokens = [10, 20, 30, 40, 50]
+ cache_data = [_rotating_layer(mx, 5) for _ in range(4)]
+ config = ModelCacheConfig.from_type_list(
+ ["RotatingKVCache"] * 4, model_name="test"
+ )
+
+ prefix_cache.store_cache(
+ "req-rotating", tokens, cache_data, model_cache_config=config
+ )
+
+ assert not prefix_cache._mru_partials
+
+ def test_refuse_stash_when_layer_falls_through_to_default_handler(
+ self, prefix_cache, mx
+ ):
+ """Non-sliceable types whose handler is unregistered (fall through
+ to ``DefaultCacheHandler``, which inherits ``KVCacheHandler``'s
+ ``supports_block_slicing=True``) must still be refused.
+
+ Concrete case: ``BatchRotatingKVCache`` is mapped in
+ ``_class_name_map`` to ``BATCH_ROTATING_KVCACHE`` but no handler
+ is registered for that enum. The original rewrite's
+ registry-based gate would have classified it as sliceable,
+ recreating exactly the silent-corruption hazard the rewrite was
+ supposed to close, just from a different angle. The fix uses an
+ explicit class-name whitelist (``KNOWN_SLICEABLE_CACHE_TYPES``)
+ instead of the registry.
+ """
+ from omlx.cache.hybrid_cache import ModelCacheConfig
+ from omlx.cache.type_registry import (
+ KNOWN_SLICEABLE_CACHE_TYPES,
+ CacheTypeRegistry,
+ )
+
+ # Sanity: the registry would lie about this class name.
+ handler = CacheTypeRegistry.get_handler_by_class_name("BatchRotatingKVCache")
+ assert handler.supports_block_slicing is True
+ # And the whitelist correctly excludes it.
+ assert "BatchRotatingKVCache" not in KNOWN_SLICEABLE_CACHE_TYPES
+
+ tokens = [10, 20, 30, 40, 50, 60]
+ # cache_data shape doesn't matter — store_cache must refuse before
+ # any extraction is attempted.
+ cache_data = [_kv_layer(mx, 6) for _ in range(4)]
+ config = ModelCacheConfig.from_type_list(
+ ["BatchRotatingKVCache"] * 4, model_name="test"
+ )
+
+ prefix_cache.store_cache(
+ "req-batch-rotating", tokens, cache_data, model_cache_config=config
+ )
+
+ assert not prefix_cache._mru_partials
+ assert prefix_cache.has_mru_partial() is False
+
+ # --- threat model: stale-slot eviction at clear() (C2) ---
+
+ def test_clear_wipes_mru_partials(self, prefix_cache, mx):
+ """``BlockAwarePrefixCache.clear()`` must drop the entire MRU dict.
+
+ The scheduler's cache-corruption recovery routes through
+ ``clear()``. Surviving partials chain from paged-block hashes
+ whose backing blocks were just freed; the dict is wiped so no
+ entry can survive into the recovery path that exists because
+ something was wrong.
+ """
+ prefix_cache.store_cache(
+ "req-clear",
+ [10, 20, 30, 40, 50, 60],
+ [_kv_layer(mx, 6) for _ in range(4)],
+ )
+ assert bool(prefix_cache._mru_partials)
+
+ prefix_cache.clear()
+
+ assert not prefix_cache._mru_partials
+ assert prefix_cache.has_mru_partial() is False
+
+ # --- threat model: H2 ambiguous cache layout ---
+
+ def test_refuse_stash_on_ambiguous_cache_layout(
+ self, prefix_cache, mx
+ ):
+ """Cache lengths that don't unambiguously map to global or local
+ indexing must refuse the stash.
+
+ Multi-turn requests can produce ``cache_seq_len ==
+ existing_tokens`` or shapes between local and global. The
+ previous heuristic (``cache_seq_len >= existing_tokens + 1``)
+ silently picked "local" on the boundary, slicing local indices
+ out of a global-indexed cache and capturing tokens from the
+ prefix instead of the trailing tail. parent_hash still matched,
+ and a future apply spliced wrong KV — silent generation
+ corruption.
+
+ Drive that boundary directly: cache_seq_len falls strictly
+ between global_end and local_len.
+ """
+ # First turn: cache 4 tokens.
+ prefix_cache.store_cache(
+ "req-turn-1",
+ [1, 2, 3, 4],
+ [_kv_layer(mx, 4) for _ in range(4)],
+ )
+
+ # Second turn: 8 tokens total — the 4 cached by turn 1 plus 4 new ones.
+ # Hand in cache_data whose cache_seq_len is 6 — strictly between:
+ # - local_len = len(new_tokens) = len(tokens) - existing_tokens = 4
+ # - global_end = existing_tokens + new_count = 4 + 4 = 8
+ # global_end (8) > cache_seq_len (6) > local_len (4): ambiguous.
+ full_tokens = [1, 2, 3, 4, 5, 6, 7, 8]
+ cache_data = [_kv_layer(mx, 6) for _ in range(4)]
+
+ prefix_cache.store_cache(
+ "req-turn-2-ambiguous", full_tokens, cache_data
+ )
+
+ # Refuse rather than guess. No partial may be stashed from a
+ # guessed-wrong turn-2 layout, and block-aligned turn 1 left no
+ # stash either, so the MRU dict must be empty.
+ assert not prefix_cache._mru_partials
+
+ # --- accounting invariant ---
+
+ def test_kv_data_holds_mlx_arrays_for_active_memory_accounting(
+ self, prefix_cache, mx
+ ):
+ """The MRU slot's memory must flow through ``mx.get_active_memory()``.
+
+ The codebase enforces all KV-memory limits via ``mx.get_active_memory()``
+ (process_memory_enforcer, the three scheduler memory checkpoints,
+ the periodic-clear threshold, telemetry). The MRU slot has no
+ separate accounting hook — it relies on the invariant that
+ ``kv_data`` holds real ``mx.array`` allocations, which MLX counts
+ in active memory automatically.
+
+ A "helpful" future change that stored CPU-side copies (e.g.
+ ``np.ndarray`` to dodge a perceived GPU-memory cost) would silently
+ escape every existing memory limit and only manifest as system OOM
+ under load. Pin the invariant so that change is caught at test
+ time, not in production.
+ """
+ block_table = prefix_cache.store_cache(
+ "req-accounting",
+ [10, 20, 30, 40, 50, 60],
+ [_kv_layer(mx, 6) for _ in range(4)],
+ )
+
+ parent_hash = prefix_cache.paged_cache.allocated_blocks[
+ block_table.block_ids[-1]
+ ].block_hash
+ partial = _get_mru_partial(prefix_cache, parent_hash)
+ assert partial is not None
+ assert len(partial.kv_data) == 4
+ for layer_idx, (keys, values) in enumerate(partial.kv_data):
+ assert isinstance(keys, mx.array), (
+ f"layer {layer_idx} keys is {type(keys).__name__}, not mx.array. "
+ f"MRU memory accounting depends on mx.array storage so the "
+ f"slot is visible to mx.get_active_memory()."
+ )
+ assert isinstance(values, mx.array), (
+ f"layer {layer_idx} values is {type(values).__name__}, "
+ f"not mx.array. See above."
+ )
+
+ # --- threat model: no-reconstruct-path config ---
+
+ def test_no_stash_when_paged_ssd_cache_is_none(self, paged_cache, mx):
+ """Without a ``PagedSSDCacheManager`` instance, ``reconstruct_cache``
+ returns ``None`` (``_can_reconstruct() is False``) and
+ ``apply_mru_partial`` is unreachable from the scheduler. Stashing
+ in this configuration would only produce dead memory.
+
+ Note: this is distinct from ``hot_cache_only=True``, where the
+ manager IS present (the disk writer thread is what's disabled,
+ not the manager itself). In that mode the MRU stash IS expected
+ to populate — ``load_block_with_metadata`` short-circuits to the
+ hot tier and reconstruct still works. The gate keys on manager
+ presence, not on whether SSD writes are happening.
+ """
+ cache = BlockAwarePrefixCache(
+ model=MockModel(num_layers=4),
+ paged_cache_manager=paged_cache,
+ paged_ssd_cache_manager=None,
+ )
+
+ cache.store_cache(
+ "req-no-ssd",
+ [10, 20, 30, 40, 50, 60],
+ [_kv_layer(mx, 6) for _ in range(4)],
+ )
+
+ assert not cache._mru_partials
+ assert cache.has_mru_partial() is False
+
+ def test_can_reconstruct_helper_reflects_manager_presence(
+ self, paged_cache, mock_ssd
+ ):
+ """``_can_reconstruct`` is the canonical predicate keeping the
+ MRU stash gate and the ``reconstruct_cache`` guard in lockstep.
+
+ It returns False only when no manager is configured at all.
+ ``hot_cache_only=True`` configurations (manager present, disk
+ writer disabled) return True because reconstruct still works
+ via the hot-tier short-circuit in
+ ``PagedSSDCacheManager.load_block_with_metadata``.
+ """
+ cache_with = BlockAwarePrefixCache(
+ model=MockModel(num_layers=2),
+ paged_cache_manager=paged_cache,
+ paged_ssd_cache_manager=mock_ssd,
+ )
+ assert cache_with._can_reconstruct() is True
+
+ cache_without = BlockAwarePrefixCache(
+ model=MockModel(num_layers=2),
+ paged_cache_manager=paged_cache,
+ paged_ssd_cache_manager=None,
+ )
+ assert cache_without._can_reconstruct() is False
+
+ # --- apply: real round-trip ---
+
+ def test_apply_round_trip_exact_match(self, prefix_cache, mx):
+ """Real store → apply round-trip: partial produced by extraction
+ is consumed by the splice path, no hand-built _MRUPartialBlock."""
+ tokens = [10, 20, 30, 40, 50, 60]
+ cache_data = [_kv_layer(mx, 6) for _ in range(4)]
+ block_table = prefix_cache.store_cache("req-rt", tokens, cache_data)
+
+ # Reconstructed cache: 4 layers × 4 tokens (the prefix only).
+ reconstructed = _make_reconstructed_cache(mx, n_layers=4, n_tokens=4)
+ remaining = [50, 60]
+
+ result, new_remaining, applied = prefix_cache.apply_mru_partial(
+ reconstructed, block_table, remaining,
+ )
+
+ assert applied == 2
+ assert new_remaining == []
+ assert all(layer.offset == 6 for layer in result)
+ assert all(layer.keys.shape[2] == 6 for layer in result)
+ assert all(layer.values.shape[2] == 6 for layer in result)
+
+ def test_apply_round_trip_prefix_match_leaves_extra_tokens(
+ self, prefix_cache, mx
+ ):
+ """When remaining is longer than the partial, the partial covers
+ its prefix and the rest is left for normal prefill."""
+ tokens = [10, 20, 30, 40, 50, 60]
+ cache_data = [_kv_layer(mx, 6) for _ in range(4)]
+ block_table = prefix_cache.store_cache("req-rt-prefix", tokens, cache_data)
+
+ reconstructed = _make_reconstructed_cache(mx, n_layers=4, n_tokens=4)
+ remaining = [50, 60, 70, 80] # partial is [50, 60]; [70, 80] left over
+
+ _, new_remaining, applied = prefix_cache.apply_mru_partial(
+ reconstructed, block_table, remaining,
+ )
+
+ assert applied == 2
+ assert new_remaining == [70, 80]
+
+ # --- apply: eviction reasons ---
+
+ def test_apply_noop_on_parent_hash_mismatch_preserves_sibling(
+ self, prefix_cache, mx, paged_cache
+ ):
+ """A request keyed by a parent_hash that isn't in the dict
+ returns no-op WITHOUT evicting unrelated sibling entries.
+
+ Behavioral change from single-slot: the single slot used to be
+ evicted whenever the lookup key didn't match. In multi-slot,
+ the lookup simply misses and other entries are preserved.
+ """
+ # Stash a partial under prefix A.
+ tokens = [10, 20, 30, 40, 50, 60]
+ block_table_a = prefix_cache.store_cache(
+ "req-a", tokens, [_kv_layer(mx, 6) for _ in range(4)]
+ )
+ before = dict(prefix_cache._mru_partials)
+ assert len(before) == 1
+
+ # Construct a synthetic block_table pointing at a block whose
+ # hash is NOT a key in the dict (simulate "request for a
+ # different prefix that has its own paged block").
+ other_block = paged_cache.allocate_block()
+ other_block.block_hash = b"\x00" * 32 # not in the MRU dict
+ synthetic_bt = BlockTable(request_id="req-other")
+ synthetic_bt.block_ids.append(other_block.block_id)
+ synthetic_bt.num_tokens = 4
+
+ reconstructed = _make_reconstructed_cache(mx, n_layers=4, n_tokens=4)
+ _, new_remaining, applied = prefix_cache.apply_mru_partial(
+ reconstructed, synthetic_bt, [50, 60],
+ )
+
+ assert applied == 0
+ assert new_remaining == [50, 60]
+ # Prefix A's entry must still be present — no false eviction.
+ assert dict(prefix_cache._mru_partials) == before
+
+ def test_apply_evicts_on_token_mismatch(self, prefix_cache, mx):
+ """Different trailing tokens → partial cannot apply, evict."""
+ tokens = [10, 20, 30, 40, 50, 60]
+ cache_data = [_kv_layer(mx, 6) for _ in range(4)]
+ block_table = prefix_cache.store_cache("req-evict-t", tokens, cache_data)
+
+ reconstructed = _make_reconstructed_cache(mx, n_layers=4, n_tokens=4)
+ _, new_remaining, applied = prefix_cache.apply_mru_partial(
+ reconstructed, block_table, [99, 60], # first token doesn't match
+ )
+
+ assert applied == 0
+ assert new_remaining == [99, 60]
+ assert not prefix_cache._mru_partials
+
+ def test_apply_evicts_on_remaining_shorter_than_partial(
+ self, prefix_cache, mx
+ ):
+ """If remaining_tokens is shorter than the partial it cannot match."""
+ tokens = [10, 20, 30, 40, 50, 60, 70]
+ cache_data = [_kv_layer(mx, 7) for _ in range(4)]
+ block_table = prefix_cache.store_cache("req-evict-s", tokens, cache_data)
+ # Partial is [50, 60, 70]; remaining is shorter → must evict.
+
+ reconstructed = _make_reconstructed_cache(mx, n_layers=4, n_tokens=4)
+ _, new_remaining, applied = prefix_cache.apply_mru_partial(
+ reconstructed, block_table, [50, 60],
+ )
+
+ assert applied == 0
+ assert not prefix_cache._mru_partials
+
+ def test_apply_evicts_on_layer_count_mismatch(self, prefix_cache, mx):
+ """If the reconstructed cache layer count differs from the
+ stashed partial, evict — likely a model swap or bug, not safe."""
+ tokens = [10, 20, 30, 40, 50, 60]
+ cache_data = [_kv_layer(mx, 6) for _ in range(4)]
+ block_table = prefix_cache.store_cache("req-evict-lc", tokens, cache_data)
+
+ # Reconstructed has only 2 layers, partial has 4 → mismatch.
+ reconstructed = _make_reconstructed_cache(mx, n_layers=2, n_tokens=4)
+ _, _, applied = prefix_cache.apply_mru_partial(
+ reconstructed, block_table, [50, 60],
+ )
+
+ assert applied == 0
+ assert not prefix_cache._mru_partials
+
+ def test_apply_noop_when_no_stash(self, prefix_cache, paged_cache, mx):
+ """No partial → no-op, remaining unchanged."""
+ block_table = paged_cache.create_block_table("req-noop")
+ reconstructed = _make_reconstructed_cache(mx, n_layers=4, n_tokens=0)
+
+ result, remaining, applied = prefix_cache.apply_mru_partial(
+ reconstructed, block_table, [10, 20],
+ )
+
+ assert applied == 0
+ assert remaining == [10, 20]
+ assert result is reconstructed
+
+ def test_apply_noop_when_remaining_empty(self, prefix_cache, mx):
+ """Empty remaining → exact prefix hit already; no MRU work."""
+ tokens = [10, 20, 30, 40, 50, 60]
+ cache_data = [_kv_layer(mx, 6) for _ in range(4)]
+ block_table = prefix_cache.store_cache("req-noop-empty", tokens, cache_data)
+
+ reconstructed = _make_reconstructed_cache(mx, n_layers=4, n_tokens=4)
+ _, _, applied = prefix_cache.apply_mru_partial(
+ reconstructed, block_table, [],
+ )
+
+ assert applied == 0
+ # Stash must NOT be evicted on empty-remaining no-op — it could
+ # still match a *future* request that does have tail tokens.
+ assert bool(prefix_cache._mru_partials)
+
+ # --- threat model: transactional splice rollback (B3) ---
+
+ def test_splice_failure_does_not_mutate_any_layer(
+ self, prefix_cache, mx
+ ):
+ """If any layer's concatenate fails, NO layer is mutated.
+
+ Review B3: the original implementation's try/except wrapped the
+ whole loop, so failure on layer N>0 left layers 0..N-1 mutated
+ with cache.offset += n_partial while the caller was told nothing
+ was applied. The rewrite must build replacements first and commit
+ atomically.
+ """
+ tokens = [10, 20, 30, 40, 50, 60]
+ cache_data = [_kv_layer(mx, 6) for _ in range(4)]
+ block_table = prefix_cache.store_cache("req-rollback", tokens, cache_data)
+
+ reconstructed = _make_reconstructed_cache(mx, n_layers=4, n_tokens=4)
+ # Snapshot the pre-splice state of every layer for an after-comparison.
+ before_offsets = [layer.offset for layer in reconstructed]
+ before_key_shapes = [layer.keys.shape for layer in reconstructed]
+
+ # Make mx.concatenate explode on the third call (layer 1's keys).
+ # Calls go: layer0 keys, layer0 values, layer1 keys (boom).
+ from omlx.cache import prefix_cache as pc_mod
+ real_concatenate = pc_mod.mx.concatenate
+ call_count = {"n": 0}
+
+ def flaky_concatenate(*args, **kwargs):
+ call_count["n"] += 1
+ if call_count["n"] == 3:
+ raise RuntimeError("synthetic concatenate failure")
+ return real_concatenate(*args, **kwargs)
+
+ with patch.object(pc_mod.mx, "concatenate", side_effect=flaky_concatenate):
+ _, new_remaining, applied = prefix_cache.apply_mru_partial(
+ reconstructed, block_table, [50, 60],
+ )
+
+ assert applied == 0
+ assert new_remaining == [50, 60]
+ # No layer's offset advanced.
+ assert [layer.offset for layer in reconstructed] == before_offsets
+ # No layer's keys shape changed.
+ assert [layer.keys.shape for layer in reconstructed] == before_key_shapes
+ # Slot is evicted on splice failure (don't retry a failing partial).
+ assert not prefix_cache._mru_partials
+
+ # --- threat model: multi-turn (existing_tokens > 0) ---
+
+ def test_stash_correct_indices_when_existing_tokens_present(
+ self, prefix_cache, mx
+ ):
+ """When store_cache is called with existing_tokens > 0 (multi-turn),
+ the stash slices the partial from the correct cache region.
+
+ cache_data is full-sequence (system prompt + new turn), so the
+ partial extraction must use global indices, not relative ones.
+ """
+ # Pretend a previous turn already cached 4 tokens.
+ prev_tokens = [1, 2, 3, 4]
+ prev_cache = [_kv_layer(mx, 4) for _ in range(4)]
+ prefix_cache.store_cache("req-turn-1", prev_tokens, prev_cache)
+
+ # Second turn: 4 prev + 4 new = 8 prefix block tokens, then 2 trailing.
+ # Distinct fill values let us verify the stash sliced the *right*
+ # tokens — the partial should contain the trailing region's data
+ # (fill=2.0), not the prefix region (fill=1.0).
+ full_tokens = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+ full_cache = []
+ for _ in range(4):
+ # First 8 positions = old (1.0), last 2 positions = new (2.0)
+ keys = mx.concatenate(
+ [
+ mx.full((1, 1, 8, 4), 1.0),
+ mx.full((1, 1, 2, 4), 2.0),
+ ],
+ axis=2,
+ )
+ values = mx.concatenate(
+ [
+ mx.full((1, 1, 8, 4), 1.0),
+ mx.full((1, 1, 2, 4), 2.0),
+ ],
+ axis=2,
+ )
+ full_cache.append({
+ "state": (keys, values),
+ "cache_type": "KVCache",
+ "class_name": "KVCache",
+ })
+
+ block_table = prefix_cache.store_cache(
+ "req-turn-2", full_tokens, full_cache
+ )
+
+ parent_hash = prefix_cache.paged_cache.allocated_blocks[
+ block_table.block_ids[-1]
+ ].block_hash
+ partial = _get_mru_partial(prefix_cache, parent_hash)
+ assert partial is not None
+ assert partial.tokens == [9, 10]
+ # Each layer's stashed slice must be the trailing region (fill=2.0).
+ for keys, values in partial.kv_data:
+ assert keys.shape[2] == 2
+ assert mx.allclose(keys, mx.full((1, 1, 2, 4), 2.0))
+ assert mx.allclose(values, mx.full((1, 1, 2, 4), 2.0))
+
+
+class TestMRUPartialMultiSlot:
+ """Multi-slot LRU semantics: coexistence, capacity, eviction discipline.
+
+ These tests cover the mechanics that single-slot mode could not
+ exercise — multiple entries keyed by distinct ``parent_hash`` values,
+ LRU promotion on apply success, sibling preservation on apply miss,
+ capacity-bounded eviction, ``max_entries=0`` feature disable, and
+ the freed-paged-block guard introduced by the multi-slot design.
+ """
+
+ @pytest.fixture
+ def mx(self):
+ try:
+ import mlx.core as mx
+ return mx
+ except ImportError:
+ pytest.skip("MLX not available")
+
+ @pytest.fixture
+ def paged_cache(self):
+ return PagedCacheManager(
+ block_size=4,
+ max_blocks=100,
+ model_name="test-model",
+ initial_blocks=100,
+ )
+
+ @pytest.fixture
+ def mock_ssd(self):
+ mock = MagicMock()
+ mock.save_block.return_value = True
+ mock.load_block.return_value = None
+ mock.load_block_with_metadata.return_value = (None, None)
+ mock.has_block.return_value = False
+ return mock
+
+ # --- multi-entry coexistence ---
+
+ def test_distinct_prefixes_coexist_as_separate_entries(
+ self, paged_cache, mock_ssd, mx
+ ):
+ """Two stashes with different parent prefixes produce two entries."""
+ cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4)
+ _, hash_a = _stash_with_prefix(cache, mx, prefix_marker=1, tail_token=99)
+ _, hash_b = _stash_with_prefix(cache, mx, prefix_marker=2, tail_token=88)
+
+ assert hash_a != hash_b
+ assert len(cache._mru_partials) == 2
+ assert hash_a in cache._mru_partials
+ assert hash_b in cache._mru_partials
+
+ # --- LRU mechanics (parameterized) ---
+
+ @pytest.mark.parametrize(
+ "scenario",
+ [
+ # (scenario_id, capacity, stash_order, expected_surviving_markers)
+ # Capacity respected; oldest evicted on overflow.
+ ("evict_oldest_at_capacity", 2, [1, 2, 3], [2, 3]),
+ # Below capacity, all retained, insertion order preserved.
+ ("under_capacity_keeps_all", 4, [1, 2, 3], [1, 2, 3]),
+ ],
+ ids=lambda s: s[0] if isinstance(s, tuple) else str(s),
+ )
+ def test_lru_capacity_bounds(
+ self, paged_cache, mock_ssd, mx, scenario
+ ):
+ _, capacity, order, expected = scenario
+ cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=capacity)
+ hashes = {}
+ for marker in order:
+ _, h = _stash_with_prefix(
+ cache, mx, prefix_marker=marker, tail_token=900 + marker
+ )
+ hashes[marker] = h
+
+ expected_keys = [hashes[m] for m in expected]
+ assert list(cache._mru_partials.keys()) == expected_keys
+
+ def test_apply_success_promotes_entry_to_lru_tail(
+ self, paged_cache, mock_ssd, mx
+ ):
+ """Applying an entry moves it to the LRU tail; a subsequent
+ capacity-eviction drops a now-older sibling, not the just-used
+ entry."""
+ cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=2)
+ bt_a, hash_a = _stash_with_prefix(
+ cache, mx, prefix_marker=1, tail_token=901
+ )
+ _, hash_b = _stash_with_prefix(
+ cache, mx, prefix_marker=2, tail_token=902
+ )
+ assert list(cache._mru_partials.keys()) == [hash_a, hash_b]
+
+ # Apply A → A promoted to tail.
+ reconstructed = _make_reconstructed_cache(mx, n_layers=4, n_tokens=4)
+ # Remaining must equal A's stashed tokens. Tail token was 901,
+ # placed at index 4 of A's prompt.
+ _, _, applied = cache.apply_mru_partial(reconstructed, bt_a, [901])
+ assert applied == 1
+ assert list(cache._mru_partials.keys()) == [hash_b, hash_a]
+
+ # Stash C at capacity 2 → B evicted (oldest after promote), A kept.
+ _, hash_c = _stash_with_prefix(
+ cache, mx, prefix_marker=3, tail_token=903
+ )
+ assert list(cache._mru_partials.keys()) == [hash_a, hash_c]
+ assert hash_b not in cache._mru_partials
+
+ # --- max_entries=0 disables ---
+
+ def test_max_entries_zero_disables_stashing(
+ self, paged_cache, mock_ssd, mx
+ ):
+ cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=0)
+ _stash_with_prefix(cache, mx, prefix_marker=1, tail_token=99)
+
+ assert len(cache._mru_partials) == 0
+ assert cache.has_mru_partial() is False
+
+ # --- clear_mru_partials() leaves non-MRU cache state alone ---
+
+ def test_clear_mru_partials_wipes_only_partials(
+ self, paged_cache, mock_ssd, mx
+ ):
+ """``clear_mru_partials()`` is the admin-clear hook. It wipes the
+ MRU dict but must not touch ``paged_cache``, the prefix index,
+ or stats — those have their own clear paths."""
+ cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4)
+ bt, _ = _stash_with_prefix(cache, mx, prefix_marker=1, tail_token=99)
+
+ prefix_index_before = dict(cache._prefix_index)
+ request_tables_before = dict(cache._request_tables)
+ assert len(cache._mru_partials) == 1
+ assert prefix_index_before # was populated by store_cache
+
+ n_wiped = cache.clear_mru_partials()
+
+ assert n_wiped == 1
+ assert len(cache._mru_partials) == 0
+ # Paged blocks, prefix index, request tables all unchanged.
+ assert cache._prefix_index == prefix_index_before
+ assert cache._request_tables == request_tables_before
+ assert bt.block_ids[-1] in cache.paged_cache.allocated_blocks
+
+ # --- freed-block guard (new in multi-slot) ---
+
+ def test_apply_noop_when_parent_block_freed(
+ self, paged_cache, mock_ssd, mx
+ ):
+ """If the parent paged block is freed between stash and apply,
+ the apply path must not fall through to a None-keyed lookup
+ (which could falsely match a short-prompt entry).
+
+ This race is new in multi-slot: single-slot tolerated it because
+ there was only ever one slot to match against.
+ """
+ cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4)
+ # Stash a short-prompt entry (parent_hash=None) — the false-match
+ # bait for the freed-block scenario.
+ short_tokens = [99, 100, 101] # < block_size=4
+ cache.store_cache(
+ "req-short", short_tokens, [_kv_layer(mx, 3) for _ in range(4)]
+ )
+ # Confirm the short-prompt entry landed under None.
+ assert None in cache._mru_partials
+
+ # Construct a block_table whose last block has been freed.
+ freed_block = paged_cache.allocate_block()
+ freed_block_id = freed_block.block_id
+ paged_cache.free_block(freed_block_id)
+ bt = BlockTable(request_id="req-freed")
+ bt.block_ids.append(freed_block_id)
+ bt.num_tokens = 4
+
+ reconstructed = _make_reconstructed_cache(mx, n_layers=4, n_tokens=4)
+ _, new_remaining, applied = cache.apply_mru_partial(
+ reconstructed, bt, [99, 100, 101]
+ )
+
+ # Must NOT splice the short-prompt entry even though the
+ # remaining tokens happen to match.
+ assert applied == 0
+ assert new_remaining == [99, 100, 101]
+ # Short-prompt entry preserved (not falsely evicted by the guard).
+ assert None in cache._mru_partials
+
+ # --- short-prompt None-key coexists with hash-keyed entry ---
+
+ def test_short_prompt_none_key_coexists_with_block_aligned_entry(
+ self, paged_cache, mock_ssd, mx
+ ):
+ cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4)
+ # Short prompt (< block_size) → parent_hash=None
+ cache.store_cache(
+ "req-short", [10, 20, 30],
+ [_kv_layer(mx, 3) for _ in range(4)],
+ )
+ # Longer prompt → distinct parent_hash
+ _, hash_long = _stash_with_prefix(
+ cache, mx, prefix_marker=1, tail_token=99
+ )
+
+ assert None in cache._mru_partials
+ assert hash_long in cache._mru_partials
+ assert len(cache._mru_partials) == 2
+
+
+class TestMRUPartialCounters:
+ """The observability counters mirror PR #1183's pattern so operators
+ can answer "is the MRU cache paying off" with the same dashboard
+ surface they use for prefix-hit and memory-hit rates.
+
+ Counters: ``mru_partial_stashes``, ``mru_partial_hits``,
+ ``mru_partial_evictions``, ``mru_partial_tokens_saved``.
+ Gauges: ``mru_partial_entries``, ``mru_partial_max_entries``.
+ """
+
+ @pytest.fixture
+ def mx(self):
+ try:
+ import mlx.core as mx
+ return mx
+ except ImportError:
+ pytest.skip("MLX not available")
+
+ @pytest.fixture
+ def paged_cache(self):
+ return PagedCacheManager(
+ block_size=4,
+ max_blocks=100,
+ model_name="test-model",
+ initial_blocks=100,
+ )
+
+ @pytest.fixture
+ def mock_ssd(self):
+ mock = MagicMock()
+ mock.save_block.return_value = True
+ mock.load_block.return_value = None
+ mock.load_block_with_metadata.return_value = (None, None)
+ mock.has_block.return_value = False
+ return mock
+
+ def test_initial_counters_are_zero(self, paged_cache, mock_ssd):
+ cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4)
+ stats = cache.get_stats()
+ assert stats.mru_partial_stashes == 0
+ assert stats.mru_partial_hits == 0
+ assert stats.mru_partial_evictions == 0
+ assert stats.mru_partial_tokens_saved == 0
+ assert stats.mru_partial_entries == 0
+ assert stats.mru_partial_max_entries == 4
+
+ def test_stash_increments_stash_counter(self, paged_cache, mock_ssd, mx):
+ cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4)
+ _stash_with_prefix(cache, mx, prefix_marker=1, tail_token=99)
+ _stash_with_prefix(cache, mx, prefix_marker=2, tail_token=88)
+
+ stats = cache.get_stats()
+ assert stats.mru_partial_stashes == 2
+ assert stats.mru_partial_entries == 2
+
+ def test_same_key_replacement_counts_as_stash_not_eviction(
+ self, paged_cache, mock_ssd, mx
+ ):
+ """Replacing an existing entry under the same key counts as a
+ stash but NOT as an eviction. Eviction is reserved for entries
+ that leave the dict (capacity overflow, apply-miss, clear)."""
+ cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4)
+ _stash_with_prefix(cache, mx, prefix_marker=1, tail_token=99)
+ _stash_with_prefix(cache, mx, prefix_marker=1, tail_token=77)
+
+ stats = cache.get_stats()
+ assert stats.mru_partial_stashes == 2
+ assert stats.mru_partial_evictions == 0
+ assert stats.mru_partial_entries == 1
+
+ def test_capacity_overflow_increments_eviction_counter(
+ self, paged_cache, mock_ssd, mx
+ ):
+ cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=2)
+ for i in (1, 2, 3):
+ _stash_with_prefix(cache, mx, prefix_marker=i, tail_token=100 + i)
+
+ stats = cache.get_stats()
+ assert stats.mru_partial_stashes == 3
+ assert stats.mru_partial_evictions == 1 # one entry pushed out
+ assert stats.mru_partial_entries == 2
+
+ def test_apply_success_increments_hits_and_tokens_saved(
+ self, paged_cache, mock_ssd, mx
+ ):
+ cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4)
+ bt, _ = _stash_with_prefix(
+ cache, mx, prefix_marker=1, tail_token=901
+ )
+
+ reconstructed = _make_reconstructed_cache(mx, n_layers=4, n_tokens=4)
+ _, _, applied = cache.apply_mru_partial(reconstructed, bt, [901])
+ assert applied == 1
+
+ stats = cache.get_stats()
+ assert stats.mru_partial_hits == 1
+ assert stats.mru_partial_tokens_saved == 1
+ assert stats.mru_partial_evictions == 0 # success, not eviction
+
+ def test_apply_miss_on_found_key_increments_eviction(
+ self, paged_cache, mock_ssd, mx
+ ):
+ """Token-mismatch eviction pops the matched key and counts."""
+ cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4)
+ bt, _ = _stash_with_prefix(
+ cache, mx, prefix_marker=1, tail_token=99
+ )
+
+ reconstructed = _make_reconstructed_cache(mx, n_layers=4, n_tokens=4)
+ _, _, applied = cache.apply_mru_partial(reconstructed, bt, [77]) # wrong tail
+ assert applied == 0
+
+ stats = cache.get_stats()
+ assert stats.mru_partial_hits == 0
+ assert stats.mru_partial_evictions == 1
+
+ def test_clear_mru_partials_counts_all_wiped_entries(
+ self, paged_cache, mock_ssd, mx
+ ):
+ cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4)
+ _stash_with_prefix(cache, mx, prefix_marker=1, tail_token=99)
+ _stash_with_prefix(cache, mx, prefix_marker=2, tail_token=88)
+ _stash_with_prefix(cache, mx, prefix_marker=3, tail_token=77)
+
+ n = cache.clear_mru_partials()
+ assert n == 3
+
+ stats = cache.get_stats()
+ assert stats.mru_partial_evictions == 3
+ assert stats.mru_partial_entries == 0
+
+ def test_clear_wipes_partials_and_resets_counters(
+ self, paged_cache, mock_ssd, mx
+ ):
+ """clear() is the "restart everything" path (cache-corruption
+ recovery). It wipes the dict AND resets every counter,
+ including mru_partial_evictions — incrementing evictions just
+ to have them zeroed by the same call would be incoherent.
+ Operators tracking partial wipes specifically use
+ clear_mru_partials() instead.
+ """
+ cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4)
+ _stash_with_prefix(cache, mx, prefix_marker=1, tail_token=99)
+ _stash_with_prefix(cache, mx, prefix_marker=2, tail_token=88)
+
+ cache.clear()
+
+ stats = cache.get_stats()
+ assert stats.mru_partial_entries == 0 # dict wiped
+ assert stats.mru_partial_stashes == 0 # counters reset
+ assert stats.mru_partial_evictions == 0
+ assert stats.mru_partial_hits == 0
+
+ def test_reset_stats_zeros_mru_counters_but_keeps_live_state(
+ self, paged_cache, mock_ssd, mx
+ ):
+ """reset_stats() is the analyst's reset — it zeros cumulative
+ counters but leaves the live cache state alone. Use
+ clear_mru_partials() if entries should be dropped too."""
+ cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4)
+ _stash_with_prefix(cache, mx, prefix_marker=1, tail_token=99)
+
+ cache.reset_stats()
+
+ stats = cache.get_stats()
+ assert stats.mru_partial_stashes == 0
+ assert stats.mru_partial_hits == 0
+ assert stats.mru_partial_evictions == 0
+ assert stats.mru_partial_tokens_saved == 0
+ # But the live entry is still there.
+ assert stats.mru_partial_entries == 1
+
+ def test_get_stats_dict_mirrors_dataclass_after_round_trip(
+ self, paged_cache, mock_ssd, mx
+ ):
+ """``get_stats_dict`` must surface every MRU field that ``get_stats``
+ (the dataclass) does. The admin dashboard reads MRU state via the
+ dict path (``Scheduler.get_ssd_cache_stats`` -> ``get_stats_dict``);
+ when the dict drops any of these keys, the admin payload's
+ ``mru_partial_max_entries`` aggregates to 0 and the dashboard's
+ ``mruEnabled`` gate hides every MRU panel even when the feature
+ is enabled.
+
+ Uses the production round-trip (real stashes via ``store_cache``,
+ real apply via ``apply_mru_partial``, real capacity overflow) so
+ the live gauge ``mru_partial_entries`` and every counter reach the
+ dict via the same path the scheduler exercises.
+ """
+ cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=2)
+ # Stash two distinct prefixes (both fit in the cap). The first
+ # one (``bt_kept``) will be applied below to bump hits; the LRU
+ # touch from a successful apply leaves it at the MRU end, so the
+ # later capacity-overflow stash evicts the *other* survivor and
+ # not the one we just touched.
+ bt_kept, _ = _stash_with_prefix(
+ cache, mx, prefix_marker=2, tail_token=88
+ )
+ _stash_with_prefix(cache, mx, prefix_marker=3, tail_token=77)
+
+ reconstructed = _make_reconstructed_cache(mx, n_layers=4, n_tokens=4)
+ _, _, applied = cache.apply_mru_partial(reconstructed, bt_kept, [88])
+ assert applied == 1 # guard: the apply path actually fired
+
+ # Third stash forces capacity-overflow eviction of the un-touched
+ # entry. After this: 3 stashes, 1 hit, 1 token saved, 1 eviction,
+ # 2 live entries.
+ _stash_with_prefix(cache, mx, prefix_marker=4, tail_token=66)
+
+ stats = cache.get_stats()
+ stats_dict = cache.get_stats_dict()
+
+ # Every MRU field on the dataclass must surface in the dict with
+ # the same value — this is the contract the admin route depends on.
+ for field in (
+ "mru_partial_stashes",
+ "mru_partial_hits",
+ "mru_partial_evictions",
+ "mru_partial_tokens_saved",
+ "mru_partial_entries",
+ "mru_partial_max_entries",
+ ):
+ assert field in stats_dict, f"{field} missing from get_stats_dict()"
+ assert stats_dict[field] == getattr(stats, field), (
+ f"{field}: dict={stats_dict[field]} dataclass={getattr(stats, field)}"
+ )
+
+ # Sanity: the round-trip actually moved every counter off its
+ # initial zero, so a future regression that hardwires zeros into
+ # the dict would still fail this test.
+ assert stats_dict["mru_partial_stashes"] == 3
+ assert stats_dict["mru_partial_hits"] == 1
+ assert stats_dict["mru_partial_evictions"] == 1 # capacity overflow
+ assert stats_dict["mru_partial_tokens_saved"] == 1
+ assert stats_dict["mru_partial_entries"] == 2
+ assert stats_dict["mru_partial_max_entries"] == 2
+
+
+def _model_with_make_cache(num_layers: int, layer_class_names: list[str]):
+ """Build a MockModel whose ``make_cache()`` returns objects whose
+ ``type(obj).__name__`` matches the requested cache class names.
+
+ ``ModelCacheConfig.from_cache_list`` identifies cache types by class
+ name (with isinstance fallback for SizedArraysCache only), so dynamic
+ classes are enough to exercise the eager init-time eligibility check
+ without pulling in real mlx-lm cache implementations.
+ """
+ cache_objs = [
+ type(name, (object,), {"max_size": 64})()
+ for name in layer_class_names
+ ]
+ model = MockModel(num_layers=num_layers)
+ model.make_cache = lambda: cache_objs # type: ignore[attr-defined]
+ return model
+
+
+class TestMRUPartialEligibility:
+ """The ``mru_partial_supported`` tri-state flag and its one-shot
+ warning. Surfaces structurally-incompatible models on the admin
+ dashboard so operators see ``N/A (see log)`` instead of a misleading
+ ``0/N entries`` gauge. Mirrors the prior-art Pattern B (real
+ ``store_cache`` round-trip) for the lazy fallback path, and exercises
+ the eager init-time path through a model with ``make_cache()``.
+ """
+
+ @pytest.fixture
+ def mx(self):
+ try:
+ import mlx.core as mx
+ return mx
+ except ImportError:
+ pytest.skip("MLX not available")
+
+ @pytest.fixture
+ def paged_cache(self):
+ return PagedCacheManager(
+ block_size=4,
+ max_blocks=100,
+ model_name="test-model",
+ initial_blocks=100,
+ )
+
+ @pytest.fixture
+ def mock_ssd(self):
+ mock = MagicMock()
+ mock.save_block.return_value = True
+ mock.load_block.return_value = None
+ mock.load_block_with_metadata.return_value = (None, None)
+ mock.has_block.return_value = False
+ return mock
+
+ def test_supported_is_none_without_make_cache_and_no_inference(
+ self, paged_cache, mock_ssd
+ ):
+ """``MockModel`` has no ``make_cache``; eager check bare-returns and
+ lazy fallback hasn't fired yet — flag stays ``None``."""
+ cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4)
+ stats = cache.get_stats()
+ assert stats.mru_partial_supported is None
+ assert cache.get_stats_dict()["mru_partial_supported"] is None
+
+ def test_supported_latches_true_on_sliceable_observation(
+ self, paged_cache, mock_ssd, mx
+ ):
+ """A successful KVCache stash latches ``supported=True`` via the
+ lazy path (eager skipped because MockModel has no make_cache)."""
+ cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4)
+ _stash_with_prefix(cache, mx, prefix_marker=1, tail_token=99)
+ assert cache.get_stats().mru_partial_supported is True
+
+ def test_supported_latches_false_lazy_on_non_sliceable(
+ self, paged_cache, mock_ssd, mx, caplog
+ ):
+ """A store_cache with RotatingKVCache layers latches
+ ``supported=False`` and emits exactly one warning."""
+ from omlx.cache.hybrid_cache import ModelCacheConfig
+
+ cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4)
+ # Eager init skipped (MockModel has no make_cache), so flag is
+ # None at this point — the lazy path is what we're testing.
+ assert cache.get_stats().mru_partial_supported is None
+
+ tokens = [10, 20, 30, 40, 50]
+ cache_data = [_rotating_layer(mx, 5) for _ in range(4)]
+ config = ModelCacheConfig.from_type_list(
+ ["RotatingKVCache"] * 4, model_name="test"
+ )
+ with caplog.at_level(logging.WARNING, logger="omlx.cache.prefix_cache"):
+ cache.store_cache("req-rot", tokens, cache_data, model_cache_config=config)
+
+ stats = cache.get_stats()
+ assert stats.mru_partial_supported is False
+ assert stats.mru_partial_stashes == 0 # gate refused, no stash
+ warns = [r for r in caplog.records if r.levelno == logging.WARNING]
+ assert len(warns) == 1
+ assert "MRU tails will be inactive" in warns[0].getMessage()
+ assert "incompatible" in warns[0].getMessage()
+ assert "RotatingKVCache" in warns[0].getMessage()
+
+ def test_warning_does_not_repeat_on_subsequent_non_sliceable(
+ self, paged_cache, mock_ssd, mx, caplog
+ ):
+ """Once the flag is latched False, further non-sliceable store_cache
+ calls must NOT re-emit the warning (operator log spam guard)."""
+ from omlx.cache.hybrid_cache import ModelCacheConfig
+
+ cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4)
+ config = ModelCacheConfig.from_type_list(
+ ["RotatingKVCache"] * 4, model_name="test"
+ )
+ with caplog.at_level(logging.WARNING, logger="omlx.cache.prefix_cache"):
+ for i in range(3):
+ cache.store_cache(
+ f"req-rot-{i}",
+ [10 * i + j for j in range(5)],
+ [_rotating_layer(mx, 5) for _ in range(4)],
+ model_cache_config=config,
+ )
+
+ warns = [r for r in caplog.records if r.levelno == logging.WARNING]
+ assert len(warns) == 1
+ assert cache.get_stats().mru_partial_supported is False
+
+ def test_eager_check_latches_false_at_init_with_non_sliceable_make_cache(
+ self, paged_cache, mock_ssd, caplog
+ ):
+ """When ``model.make_cache()`` is available and returns non-sliceable
+ cache instances, the flag latches False at construction and the
+ warning fires BEFORE any inference — true model-load-time signal."""
+ model = _model_with_make_cache(
+ num_layers=4,
+ layer_class_names=["RotatingKVCache"] * 4,
+ )
+ with caplog.at_level(logging.WARNING, logger="omlx.cache.prefix_cache"):
+ cache = BlockAwarePrefixCache(
+ model=model,
+ paged_cache_manager=paged_cache,
+ paged_ssd_cache_manager=mock_ssd,
+ mru_partial_max_entries=4,
+ )
+
+ assert cache.get_stats().mru_partial_supported is False
+ warns = [r for r in caplog.records if r.levelno == logging.WARNING]
+ assert len(warns) == 1
+ assert "RotatingKVCache" in warns[0].getMessage()
+
+ def test_eager_check_latches_true_at_init_with_sliceable_make_cache(
+ self, paged_cache, mock_ssd, caplog
+ ):
+ """When ``model.make_cache()`` returns only sliceable cache
+ instances, the flag latches True at construction and no warning
+ is emitted."""
+ model = _model_with_make_cache(
+ num_layers=4,
+ layer_class_names=["KVCache"] * 4,
+ )
+ with caplog.at_level(logging.WARNING, logger="omlx.cache.prefix_cache"):
+ cache = BlockAwarePrefixCache(
+ model=model,
+ paged_cache_manager=paged_cache,
+ paged_ssd_cache_manager=mock_ssd,
+ mru_partial_max_entries=4,
+ )
+
+ assert cache.get_stats().mru_partial_supported is True
+ assert not [r for r in caplog.records if r.levelno == logging.WARNING]
+
+ def test_eager_check_skipped_when_feature_disabled(
+ self, paged_cache, mock_ssd, caplog
+ ):
+ """``max_entries=0`` disables the feature; no eager check runs even
+ for an obviously incompatible model. No warning, no flag change —
+ the operator already opted out by setting the capacity to zero."""
+ model = _model_with_make_cache(
+ num_layers=4,
+ layer_class_names=["RotatingKVCache"] * 4,
+ )
+ with caplog.at_level(logging.WARNING, logger="omlx.cache.prefix_cache"):
+ cache = BlockAwarePrefixCache(
+ model=model,
+ paged_cache_manager=paged_cache,
+ paged_ssd_cache_manager=mock_ssd,
+ mru_partial_max_entries=0,
+ )
+
+ assert cache.get_stats().mru_partial_supported is None
+ assert not [r for r in caplog.records if r.levelno == logging.WARNING]
+
+ def test_eager_check_survives_make_cache_failure(
+ self, paged_cache, mock_ssd, caplog
+ ):
+ """If ``model.make_cache()`` raises, the eager check bare-returns
+ and the flag stays ``None``. Lazy fallback picks up at first
+ inference instead — no startup crash."""
+ model = MockModel(num_layers=4)
+ model.make_cache = lambda: (_ for _ in ()).throw( # type: ignore[attr-defined]
+ RuntimeError("model not fully initialized")
+ )
+ cache = BlockAwarePrefixCache(
+ model=model,
+ paged_cache_manager=paged_cache,
+ paged_ssd_cache_manager=mock_ssd,
+ mru_partial_max_entries=4,
+ )
+ assert cache.get_stats().mru_partial_supported is None
+
+
+def _store_seq(cache, mx, request_id, tokens, *, prompt_token_count=None):
+ """``store_cache`` a full token sequence; return the BlockTable.
+
+ ``cache_data`` spans the whole sequence so ``_update_mru_partial``
+ takes the global index path (the cache covers ``prompt + output``),
+ matching the production resubmission layout.
+ """
+ cache_data = [_kv_layer(mx, len(tokens)) for _ in range(4)]
+ return cache.store_cache(
+ request_id, tokens, cache_data, prompt_token_count=prompt_token_count,
+ )
+
+
+class TestMRUPromptBoundaryStash:
+ """The MRU stash must key off the *prompt's* trailing partial, not the
+ stored sequence's.
+
+ ``store_cache`` is handed ``prompt + output``, but a repeat request
+ resubmits the prompt only and ``apply_mru_partial`` looks the entry
+ up by the prompt's last full block. Before the prompt boundary was
+ threaded in, the stash keyed off ``prompt + output``'s last full
+ block — a key a prompt-only resubmit could never compute — so the
+ feature never produced a hit for ordinary chat completions.
+
+ block_size is 4 in this fixture.
+ """
+
+ @pytest.fixture
+ def mx(self):
+ try:
+ import mlx.core as mx
+ return mx
+ except ImportError:
+ pytest.skip("MLX not available")
+
+ @pytest.fixture
+ def paged_cache(self):
+ return PagedCacheManager(
+ block_size=4,
+ max_blocks=100,
+ model_name="test-model",
+ initial_blocks=100,
+ )
+
+ @pytest.fixture
+ def mock_ssd(self):
+ mock = MagicMock()
+ mock.save_block.return_value = True
+ mock.load_block.return_value = None
+ mock.load_block_with_metadata.return_value = (None, None)
+ mock.has_block.return_value = False
+ return mock
+
+ def test_prompt_boundary_stash_hits_on_prompt_only_resubmit(
+ self, paged_cache, mock_ssd, mx
+ ):
+ """The decisive regression test: store ``prompt + output`` with the
+ prompt boundary, then resubmit the prompt only — apply must HIT.
+ """
+ cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4)
+ # prompt = 1 full block [10..13] + 2-token tail [14,15]; output = 3.
+ prompt = [10, 11, 12, 13, 14, 15]
+ stored = prompt + [90, 91, 92]
+ store_bt = _store_seq(
+ cache, mx, "store", stored, prompt_token_count=len(prompt)
+ )
+
+ # Keyed by the prompt's last full block (block 0) — NOT the stored
+ # sequence's last full block (block 1) — and stashing the prompt
+ # tail [14,15], not the sequence tail [92].
+ prompt_block = paged_cache.allocated_blocks[store_bt.block_ids[0]]
+ assert prompt_block.block_hash in cache._mru_partials
+ assert cache._mru_partials[prompt_block.block_hash].tokens == [14, 15]
+
+ # Simulate fetch_cache(prompt): a block table with the prompt's
+ # full blocks only — apply_mru_partial keys off block_ids[-1].
+ fetch_bt = BlockTable(request_id="resubmit")
+ fetch_bt.block_ids = [store_bt.block_ids[0]]
+ fetch_bt.num_tokens = 4
+ reconstructed = _make_reconstructed_cache(mx, n_layers=4, n_tokens=4)
+
+ _, new_remaining, applied = cache.apply_mru_partial(
+ reconstructed, fetch_bt, [14, 15]
+ )
+ assert applied == 2
+ assert new_remaining == []
+ stats = cache.get_stats()
+ assert stats.mru_partial_hits == 1
+ assert stats.mru_partial_tokens_saved == 2
+
+ def test_whole_sequence_stash_misses_on_prompt_only_resubmit(
+ self, paged_cache, mock_ssd, mx
+ ):
+ """Pins the original bug: with no prompt boundary
+ (``prompt_token_count=None``) the stash keys off the stored
+ sequence's last full block, which a prompt-only resubmit never
+ reaches — 0 hits, and 0 evictions because the lookup key is never
+ found (no entry to evict).
+ """
+ cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4)
+ prompt = [10, 11, 12, 13, 14, 15]
+ stored = prompt + [90, 91, 92]
+ store_bt = _store_seq(cache, mx, "store", stored) # boundary unknown
+
+ # None falls back to whole-sequence: keyed off block 1 (the stored
+ # sequence's last full block), unreachable by a prompt-only fetch.
+ seq_block = paged_cache.allocated_blocks[store_bt.block_ids[1]]
+ assert seq_block.block_hash in cache._mru_partials
+
+ fetch_bt = BlockTable(request_id="resubmit")
+ fetch_bt.block_ids = [store_bt.block_ids[0]]
+ fetch_bt.num_tokens = 4
+ reconstructed = _make_reconstructed_cache(mx, n_layers=4, n_tokens=4)
+
+ _, _, applied = cache.apply_mru_partial(reconstructed, fetch_bt, [14, 15])
+ assert applied == 0
+ stats = cache.get_stats()
+ assert stats.mru_partial_hits == 0
+ assert stats.mru_partial_evictions == 0 # key miss — nothing evicted
+
+ def test_block_aligned_prompt_does_not_stash(
+ self, paged_cache, mock_ssd, mx
+ ):
+ """A prompt that is an exact multiple of block_size has no partial
+ tail — every prompt token lands in a full paged block, so the MRU
+ has nothing to add."""
+ cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4)
+ prompt = [10, 11, 12, 13, 14, 15, 16, 17] # 8 tokens = 2 full blocks
+ stored = prompt + [90, 91, 92]
+ _store_seq(cache, mx, "store", stored, prompt_token_count=len(prompt))
+ assert not cache._mru_partials
+ assert cache.get_stats().mru_partial_stashes == 0
+
+ def test_short_prompt_stashes_under_none_key(
+ self, paged_cache, mock_ssd, mx
+ ):
+ """A prompt shorter than one block has no last full block; the
+ stash is keyed by None (the short-prompt path) and holds the
+ whole prompt as its tail."""
+ cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4)
+ prompt = [10, 11, 12] # 3 tokens < block_size 4
+ stored = prompt + [90, 91, 92, 93, 94]
+ _store_seq(cache, mx, "store", stored, prompt_token_count=len(prompt))
+ assert None in cache._mru_partials
+ assert cache._mru_partials[None].tokens == [10, 11, 12]
+
+ def test_prompt_boundary_stash_with_existing_cached_prefix(
+ self, paged_cache, mock_ssd, mx
+ ):
+ """Resubmission path: store_cache runs with ``existing_tokens > 0``
+ (the prompt's leading blocks are already cached) and works in
+ ``new_tokens`` space. The prompt-boundary arithmetic must still
+ resolve the prompt's last full block — not an index shifted by
+ ``existing_tokens``.
+ """
+ cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4)
+ # prompt = 2 full blocks [10..17] + 2-token tail [18,19]; output = 3.
+ prompt = [10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
+ stored = prompt + [90, 91, 92]
+ # Round 1: store only the prompt's two leading full blocks under
+ # the same request_id, so round 2's store sees existing_tokens > 0.
+ _store_seq(cache, mx, "req", prompt[:8], prompt_token_count=8)
+ assert not cache._mru_partials # block-aligned round-1, no stash
+ # Round 2: same request_id — existing_tokens == 8.
+ store_bt = _store_seq(
+ cache, mx, "req", stored, prompt_token_count=len(prompt)
+ )
+
+ # Prompt's last full block is index 1 ([14..17]); the stash must
+ # key off it despite existing_tokens=8 and new_tokens-space math.
+ prompt_block = paged_cache.allocated_blocks[store_bt.block_ids[1]]
+ assert prompt_block.block_hash in cache._mru_partials
+ assert cache._mru_partials[prompt_block.block_hash].tokens == [18, 19]
+
+ fetch_bt = BlockTable(request_id="resubmit")
+ fetch_bt.block_ids = store_bt.block_ids[:2]
+ fetch_bt.num_tokens = 8
+ reconstructed = _make_reconstructed_cache(mx, n_layers=4, n_tokens=8)
+ _, new_remaining, applied = cache.apply_mru_partial(
+ reconstructed, fetch_bt, [18, 19]
+ )
+ assert applied == 2
+ assert new_remaining == []
+ assert cache.get_stats().mru_partial_hits == 1
+
+
+class TestMRUPartialCrossThreadSafety:
+ """An MRU partial is extracted on the store-cache worker thread but
+ spliced into a live cache on the separate inference thread.
+
+ ``_extract_block_tensor_slice`` builds the partial as lazy
+ ``mx.copy`` ops; an unevaluated tensor carries a pending op bound to
+ the worker's per-thread MLX stream, and evaluating the splice on the
+ inference thread raises ``RuntimeError: There is no Stream(gpu, N)
+ in current thread``. ``_update_mru_partial`` must materialize the
+ partial at stash time so the stashed data is concrete and
+ stream-free.
+ """
+
+ @pytest.fixture
+ def mx(self):
+ try:
+ import mlx.core as mx
+ return mx
+ except ImportError:
+ pytest.skip("MLX not available")
+
+ @pytest.fixture
+ def paged_cache(self):
+ return PagedCacheManager(
+ block_size=4,
+ max_blocks=100,
+ model_name="test-model",
+ initial_blocks=100,
+ )
+
+ @pytest.fixture
+ def mock_ssd(self):
+ mock = MagicMock()
+ mock.save_block.return_value = True
+ mock.load_block.return_value = None
+ mock.load_block_with_metadata.return_value = (None, None)
+ mock.has_block.return_value = False
+ return mock
+
+ def test_materialize_mru_kv_handles_extract_shapes(
+ self, paged_cache, mock_ssd, mx
+ ):
+ """``_materialize_mru_kv`` evaluates every ``mx.array`` leaf across
+ the plain ``(keys, values)`` and TurboQuant ``(tag, (k, v))``
+ shapes ``_extract_block_tensor_slice`` returns, and tolerates the
+ non-array tag string and an empty list."""
+ cache = _make_mru_cache(paged_cache, mock_ssd)
+ plain = [(mx.ones((1, 1, 2, 4)), mx.ones((1, 1, 2, 4)))]
+ tagged = [
+ ("__turboquant_v2__", (mx.ones((1, 1, 2, 4)), mx.ones((1, 1, 2, 4))))
+ ]
+ # None of these should raise.
+ cache._materialize_mru_kv(plain)
+ cache._materialize_mru_kv(tagged)
+ cache._materialize_mru_kv([])
+
+ def test_stashed_partial_splices_across_threads(
+ self, paged_cache, mock_ssd, mx
+ ):
+ """Extract+stash on a worker thread, splice+evaluate on the main
+ thread. Without stash-time materialization the final ``mx.eval``
+ raises a foreign-stream ``RuntimeError``; with it, the cross-
+ thread handoff is clean.
+ """
+ import concurrent.futures
+
+ cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4)
+ prompt = [10, 11, 12, 13, 14, 15]
+ stored = prompt + [90, 91, 92]
+ cache_data = [_kv_layer(mx, len(stored)) for _ in range(4)]
+ # Mirror production: the inference thread materializes the
+ # extracted cache (mx.async_eval + the worker's mx.synchronize)
+ # before the store worker runs. Without this the cache_data
+ # arrays would still be lazy ops bound to THIS thread's stream —
+ # a different cross-thread failure than the one under test.
+ for layer in cache_data:
+ mx.eval(*layer["state"])
+
+ # Stash on a dedicated worker thread, mirroring the production
+ # _store_cache_executor (a pool distinct from the inference one).
+ with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
+ store_bt = pool.submit(
+ cache.store_cache,
+ "store",
+ stored,
+ cache_data,
+ prompt_token_count=len(prompt),
+ ).result()
+
+ # Splice on this (the "inference") thread.
+ fetch_bt = BlockTable(request_id="resubmit")
+ fetch_bt.block_ids = [store_bt.block_ids[0]]
+ fetch_bt.num_tokens = 4
+ reconstructed = _make_reconstructed_cache(mx, n_layers=4, n_tokens=4)
+
+ spliced, new_remaining, applied = cache.apply_mru_partial(
+ reconstructed, fetch_bt, [14, 15]
+ )
+ assert applied == 2
+ assert new_remaining == []
+
+ # Force evaluation on this thread — the point a partial still
+ # carrying the worker thread's stream would fail.
+ for layer in spliced:
+ mx.eval(layer.keys, layer.values)
+
+
+class TestHasMRUPartial:
+ """The has_mru_partial() accessor is the public API the scheduler
+ uses to decide whether to suppress the deferred Metal cache clear."""
+
+ def test_has_mru_partial_reflects_dict_emptiness(self):
+ cache = BlockAwarePrefixCache(
+ model=MockModel(num_layers=2),
+ paged_cache_manager=PagedCacheManager(
+ block_size=4, max_blocks=10, model_name="t", initial_blocks=10,
+ ),
+ paged_ssd_cache_manager=None,
+ )
+ assert cache.has_mru_partial() is False
+
+ cache._mru_partials[b"x"] = _MRUPartialBlock(
+ parent_hash=b"x", tokens=[1], kv_data=[],
+ )
+ assert cache.has_mru_partial() is True
+
+ cache._mru_partials.clear()
+ assert cache.has_mru_partial() is False
diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py
index 5e87be0f3..892db445c 100644
--- a/tests/test_scheduler.py
+++ b/tests/test_scheduler.py
@@ -1360,6 +1360,253 @@ def test_cleanup_finished_extends_deferred_clear_for_concurrent_completions(
assert scheduler._deferred_clear_at > first_target
+class TestMRUDeferredClearSuppression:
+ """Tests for the MRU-aware deferred-clear suppression.
+
+ When the prefix cache holds a warm MRU partial (the previous request
+ just left a sub-block tail in memory), the pending deferred clear is
+ postponed by one more ``_DEFERRED_CLEAR_DELAY`` window so the still-
+ resident lazy KV tensors aren't dropped before a likely repeat.
+
+ The suppression is **budgeted at one shot per deferral epoch** (armed
+ only by the completion that opens the epoch), so it cannot defer the
+ clear indefinitely under hot-prompt repeats — suppression stretches a
+ deadline to at most 2x ``_DEFERRED_CLEAR_DELAY`` however warm the MRU.
+ """
+
+ def _drive_clear_gate(self, scheduler):
+ """Run the step()-tail gate that decides whether to clear.
+
+ Mirrors the structure of step()'s end-of-step block — the
+ smallest slice of step() needed to exercise the suppression
+ logic without standing up a full BatchGenerator.
+ """
+ from omlx import scheduler as sched_mod
+
+ with patch.object(sched_mod, "_sync_and_clear_cache") as mock_clear:
+ scheduler._step_counter += 1
+ should_clear = scheduler._should_periodic_clear_cache()
+ if (
+ scheduler._deferred_clear_at is not None
+ and scheduler._step_counter >= scheduler._deferred_clear_at
+ ):
+ if (
+ scheduler._mru_clear_suppression_available
+ and scheduler.block_aware_cache is not None
+ and scheduler.block_aware_cache.has_mru_partial()
+ ):
+ scheduler._deferred_clear_at = (
+ scheduler._step_counter
+ + scheduler._DEFERRED_CLEAR_DELAY
+ )
+ scheduler._mru_clear_suppression_available = False
+ else:
+ should_clear = True
+ scheduler._deferred_clear_at = None
+ scheduler._mru_clear_suppression_available = False
+ if should_clear:
+ sched_mod._sync_and_clear_cache()
+ return mock_clear.called
+
+ def test_suppression_budget_armed_on_completion(
+ self, mock_model, mock_tokenizer
+ ):
+ """Each completion that arms _deferred_clear_at also arms the budget."""
+ scheduler = Scheduler(model=mock_model, tokenizer=mock_tokenizer)
+
+ request = Request(
+ request_id="req-arm",
+ prompt="hello",
+ sampling_params=SamplingParams(),
+ )
+ request.prompt_token_ids = [1, 2]
+ request.num_prompt_tokens = 2
+ request.output_token_ids = [3]
+ scheduler.running["req-arm"] = request
+ scheduler.requests["req-arm"] = request
+
+ with patch("omlx.scheduler.mx"):
+ scheduler._cleanup_finished({"req-arm"})
+
+ assert scheduler._deferred_clear_at is not None
+ assert scheduler._mru_clear_suppression_available is True
+
+ def test_clear_fires_at_deadline_when_no_mru(
+ self, mock_model, mock_tokenizer
+ ):
+ """Without a warm MRU, the deferred clear fires at its deadline."""
+ scheduler = Scheduler(model=mock_model, tokenizer=mock_tokenizer)
+ scheduler.block_aware_cache = MagicMock()
+ scheduler.block_aware_cache.has_mru_partial.return_value = False
+
+ scheduler._deferred_clear_at = scheduler._step_counter + 1
+ scheduler._mru_clear_suppression_available = True
+
+ cleared = self._drive_clear_gate(scheduler)
+
+ assert cleared is True
+ assert scheduler._deferred_clear_at is None
+ assert scheduler._mru_clear_suppression_available is False
+
+ def test_clear_deferred_once_when_mru_warm(
+ self, mock_model, mock_tokenizer
+ ):
+ """A warm MRU at the deadline defers the clear by one more window."""
+ scheduler = Scheduler(model=mock_model, tokenizer=mock_tokenizer)
+ scheduler.block_aware_cache = MagicMock()
+ scheduler.block_aware_cache.has_mru_partial.return_value = True
+
+ scheduler._deferred_clear_at = scheduler._step_counter + 1
+ scheduler._mru_clear_suppression_available = True
+
+ cleared = self._drive_clear_gate(scheduler)
+
+ assert cleared is False
+ # Deadline pushed out by one more DELAY window
+ assert scheduler._deferred_clear_at == (
+ scheduler._step_counter + Scheduler._DEFERRED_CLEAR_DELAY
+ )
+ # Budget spent — next deadline cannot suppress again
+ assert scheduler._mru_clear_suppression_available is False
+
+ def test_clear_fires_at_second_deadline_even_if_mru_still_warm(
+ self, mock_model, mock_tokenizer
+ ):
+ """Budget is one-shot: the second deadline fires regardless of MRU.
+
+ This is the bound that protects against infinite deferral under
+ hot-prompt repeats (review B6).
+ """
+ scheduler = Scheduler(model=mock_model, tokenizer=mock_tokenizer)
+ scheduler.block_aware_cache = MagicMock()
+ scheduler.block_aware_cache.has_mru_partial.return_value = True
+
+ # First deadline → suppressed
+ scheduler._deferred_clear_at = scheduler._step_counter + 1
+ scheduler._mru_clear_suppression_available = True
+ first_cleared = self._drive_clear_gate(scheduler)
+ assert first_cleared is False
+
+ # Advance to the second deadline. MRU still warm, but no budget.
+ scheduler._step_counter = scheduler._deferred_clear_at - 1
+ second_cleared = self._drive_clear_gate(scheduler)
+
+ assert second_cleared is True
+ assert scheduler._deferred_clear_at is None
+
+ def test_clear_fires_immediately_after_mru_evicted(
+ self, mock_model, mock_tokenizer
+ ):
+ """If the MRU evicts before the deadline, the clear fires normally."""
+ scheduler = Scheduler(model=mock_model, tokenizer=mock_tokenizer)
+ scheduler.block_aware_cache = MagicMock()
+ # MRU was warm at completion but evicted by the time we reach deadline
+ scheduler.block_aware_cache.has_mru_partial.return_value = False
+
+ scheduler._deferred_clear_at = scheduler._step_counter + 1
+ scheduler._mru_clear_suppression_available = True
+
+ cleared = self._drive_clear_gate(scheduler)
+
+ assert cleared is True
+ # The gate also clears the budget on this path; its value is moot —
+ # the next completion opens a new epoch and re-arms it either way.
+ assert scheduler._deferred_clear_at is None
+
+ def test_new_epoch_completion_arms_budget(
+ self, mock_model, mock_tokenizer
+ ):
+ """A completion that STARTS a new deferral epoch (transition from
+ ``_deferred_clear_at is None``) arms the budget. Together with
+ the next test, this pins the contract: budget is one-shot per
+ epoch."""
+ scheduler = Scheduler(model=mock_model, tokenizer=mock_tokenizer)
+ scheduler._mru_clear_suppression_available = False # spent
+ assert scheduler._deferred_clear_at is None # no epoch in flight
+
+ request = Request(
+ request_id="req-new-epoch",
+ prompt="hello",
+ sampling_params=SamplingParams(),
+ )
+ request.prompt_token_ids = [1, 2]
+ request.num_prompt_tokens = 2
+ request.output_token_ids = [3]
+ scheduler.running["req-new-epoch"] = request
+ scheduler.requests["req-new-epoch"] = request
+
+ with patch("omlx.scheduler.mx"):
+ scheduler._cleanup_finished({"req-new-epoch"})
+
+ assert scheduler._mru_clear_suppression_available is True
+
+ def test_completion_within_open_epoch_does_not_refresh_budget(
+ self, mock_model, mock_tokenizer
+ ):
+ """**The hot-prompt invariant.** A completion that lands while a
+ deferral is already pending must NOT refresh the budget.
+
+ Pre-fix behaviour: every completion re-armed
+ ``_mru_clear_suppression_available = True``. Under hot-prompt
+ repeats — the very workload this feature targets — completions
+ arrive faster than ``_DEFERRED_CLEAR_DELAY``, so the budget
+ kept refreshing after being spent and the deferred clear could
+ be pushed forever, defeating the pool-bloat mitigation (#411).
+
+ Post-fix: budget is armed only on the transition from ``None``,
+ spent at the deadline if MRU is still warm, and stays spent for
+ the rest of the epoch regardless of how many further completions
+ land.
+ """
+ scheduler = Scheduler(model=mock_model, tokenizer=mock_tokenizer)
+
+ # First completion opens the epoch and arms the budget.
+ req1 = Request(
+ request_id="req-open",
+ prompt="a",
+ sampling_params=SamplingParams(),
+ )
+ req1.prompt_token_ids = [1, 2]
+ req1.num_prompt_tokens = 2
+ req1.output_token_ids = [3]
+ scheduler.running["req-open"] = req1
+ scheduler.requests["req-open"] = req1
+
+ with patch("omlx.scheduler.mx"):
+ scheduler._cleanup_finished({"req-open"})
+
+ assert scheduler._deferred_clear_at is not None
+ assert scheduler._mru_clear_suppression_available is True
+
+ # Simulate the budget already being spent (suppression fired
+ # at first attempted deadline).
+ scheduler._mru_clear_suppression_available = False
+
+ # A second completion arrives while the same epoch is still
+ # open. It may legitimately push the deadline out — that's the
+ # #557 invariant — but it must NOT refresh the spent budget.
+ scheduler._step_counter += 3
+ req2 = Request(
+ request_id="req-mid-epoch",
+ prompt="b",
+ sampling_params=SamplingParams(),
+ )
+ req2.prompt_token_ids = [4, 5]
+ req2.num_prompt_tokens = 2
+ req2.output_token_ids = [6]
+ scheduler.running["req-mid-epoch"] = req2
+ scheduler.requests["req-mid-epoch"] = req2
+
+ with patch("omlx.scheduler.mx"):
+ scheduler._cleanup_finished({"req-mid-epoch"})
+
+ # Deadline pushed (per #557), budget unchanged (per the C3 fix).
+ assert scheduler._deferred_clear_at == (
+ scheduler._step_counter + Scheduler._DEFERRED_CLEAR_DELAY
+ )
+ assert scheduler._mru_clear_suppression_available is False
+
+
class TestPeriodicClearGating:
"""Tests for the conditional periodic clear (#978/#1040 mitigation)."""
diff --git a/tests/test_settings.py b/tests/test_settings.py
index f3a1f54b2..d3ccc4c68 100644
--- a/tests/test_settings.py
+++ b/tests/test_settings.py
@@ -342,6 +342,7 @@ def test_to_dict(self):
"ssd_cache_max_size": "50GB",
"hot_cache_max_size": "0",
"initial_cache_blocks": 256,
+ "mru_partial_max_entries": 4,
}
def test_from_dict(self):