From 0de60746fdb519cf31c0a82eb9fa65997c7b50b2 Mon Sep 17 00:00:00 2001 From: Ivan Iguaran Date: Mon, 11 May 2026 11:18:43 -0400 Subject: [PATCH 01/18] feat(cache): add per-model cache hit-rate observability MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Server-side snapshot differencing via CacheRateTracker: stores the last 90 snapshots of cumulative cache counters (10s intervals, 15 min window) and computes rates via start/end differencing. Zero hot-path changes — snapshots are lazy, driven by dashboard polling cadence. New metrics exposed in /api/stats under cache_observability: - prefix_hit_rate (cumulative + windowed) - eviction count, ssd_hot_rate - per-model and weighted aggregate across models Dashboard: new "Cache Breakdown" card below Average Speed showing hit rate, evictions, and hot cache hits. Session-only (hidden in All-Time view since counters reset on model reload). --- omlx/admin/routes.py | 128 +++++++++---- omlx/admin/static/css/dashboard.css | 4 + omlx/admin/static/js/dashboard.js | 62 +++++- omlx/admin/templates/dashboard/_status.html | 78 +++++++- omlx/cache/observability.py | 149 +++++++++++++++ omlx/cache/paged_ssd_cache.py | 14 ++ omlx/cache/prefix_cache.py | 13 ++ omlx/cache/stats.py | 4 + omlx/scheduler.py | 42 +++- tests/test_cache_observability.py | 202 ++++++++++++++++++++ 10 files changed, 656 insertions(+), 40 deletions(-) create mode 100644 omlx/cache/observability.py create mode 100644 tests/test_cache_observability.py diff --git a/omlx/admin/routes.py b/omlx/admin/routes.py index 30522273..cb7d4989 100644 --- a/omlx/admin/routes.py +++ b/omlx/admin/routes.py @@ -3525,6 +3525,13 @@ def _build_runtime_cache_observability( } cache_dir = global_settings.cache.get_ssd_cache_dir(global_settings.base_path) + cache_cfg = global_settings.cache + try: + cfg_disk_max = cache_cfg.get_ssd_cache_max_size_bytes(global_settings.base_path) + except (ValueError, OSError, TypeError) as exc: + logger.warning("Could not read SSD cache max size from config: %s", exc) + cfg_disk_max = 0 + payload = { "base_path": str(global_settings.base_path), "ssd_cache_dir": str(cache_dir), @@ -3533,6 +3540,10 @@ def _build_runtime_cache_observability( "total_num_files": 0, "total_size_bytes": 0, "effective_block_sizes": [], + "disk_max_bytes": cfg_disk_max, + "hot_cache_max_bytes": 0, + "hot_cache_size_bytes": 0, + "hot_cache_entries": 0, } engine_pool = _get_engine_pool() @@ -3644,11 +3655,16 @@ def _build_runtime_cache_observability( "last_tokens_to_next_block": last_tokens_to_next_block, "num_files": int(ssd_stats.get("num_files", 0) or 0), "total_size_bytes": int(ssd_stats.get("total_size_bytes", 0) or 0), + "max_size_bytes": int(ssd_stats.get("max_size_bytes", 0) or 0), "hot_cache_max_bytes": int(ssd_stats.get("hot_cache_max_bytes", 0) or 0), "hot_cache_size_bytes": int(ssd_stats.get("hot_cache_size_bytes", 0) or 0), "hot_cache_entries": int(ssd_stats.get("hot_cache_entries", 0) or 0), } + cache_rates = runtime_stats.get("cache_rates") + if cache_rates: + model_payload["cache_rates"] = cache_rates + payload["models"].append(model_payload) payload["total_num_files"] += model_payload["num_files"] payload["total_size_bytes"] += model_payload["total_size_bytes"] @@ -3658,6 +3674,26 @@ def _build_runtime_cache_observability( payload["effective_block_sizes"] = sorted(block_sizes) + # Aggregate hot-cache and disk-max across models. + # hot_cache_max sums across models (each model reserves its own slice of + # the same process-wide hot cache budget) so the gauge denominator matches + # the summed numerator. disk_max keeps the config fallback via max() + # because a single SSD cache directory is shared — the effective cap is + # the largest configured limit, not a per-model sum. + hot_cache_max = 0 + disk_max = payload["disk_max_bytes"] + hot_cache_size_total = 0 + hot_cache_entries_total = 0 + for m in payload["models"]: + hot_cache_size_total += m.get("hot_cache_size_bytes", 0) + hot_cache_entries_total += m.get("hot_cache_entries", 0) + hot_cache_max += m.get("hot_cache_max_bytes", 0) + disk_max = max(disk_max, m.get("max_size_bytes", 0)) + payload["hot_cache_max_bytes"] = hot_cache_max + payload["hot_cache_size_bytes"] = hot_cache_size_total + payload["hot_cache_entries"] = hot_cache_entries_total + payload["disk_max_bytes"] = disk_max + # Fallback: if no loaded models contributed stats, scan the cache # directory directly so the dashboard still shows real disk usage. if payload["total_num_files"] == 0 and cache_dir.exists(): @@ -3912,6 +3948,30 @@ async def clear_alltime_stats(is_admin: bool = Depends(require_admin)): return {"status": "ok"} +def _iter_loaded_schedulers(): + """Yield (model_id, scheduler) for each loaded model. + + Traverses the internal engine hierarchy: pool entry → async engine → + core engine → scheduler. Both ``clear_ssd_cache`` and + ``clear_hot_cache`` share this traversal. + """ + engine_pool = _get_engine_pool() + if engine_pool is None: + return + for model_info in engine_pool.get_status().get("models", []): + model_id = model_info.get("id") + if not model_id or not model_info.get("loaded"): + continue + entry = engine_pool._entries.get(model_id) + if entry is None or entry.engine is None: + continue + async_core = getattr(entry.engine, "_engine", None) + core = getattr(async_core, "engine", None) if async_core is not None else None + scheduler = getattr(core, "scheduler", None) if core is not None else None + if scheduler is not None: + yield model_id, scheduler + + @router.post("/api/ssd-cache/clear") async def clear_ssd_cache(is_admin: bool = Depends(require_admin)): """Clear all SSD cache files for all loaded models. @@ -3922,38 +3982,17 @@ async def clear_ssd_cache(is_admin: bool = Depends(require_admin)): """ total_deleted = 0 - # Phase 1: clear via loaded models' cache managers (updates in-memory index) - engine_pool = _get_engine_pool() - if engine_pool is not None: - for model_info in engine_pool.get_status().get("models", []): - model_id = model_info.get("id") - if not model_id or not model_info.get("loaded"): - continue - - entry = engine_pool._entries.get(model_id) - if entry is None or entry.engine is None: - continue - - async_core = getattr(entry.engine, "_engine", None) - core = ( - getattr(async_core, "engine", None) if async_core is not None else None - ) - scheduler = ( - getattr(core, "scheduler", None) if core is not None else None - ) - - if scheduler is not None: - ssd_manager = getattr(scheduler, "paged_ssd_cache_manager", None) - if ssd_manager is not None: - try: - deleted = ssd_manager.clear() - total_deleted += deleted - except Exception as exc: - logger.warning( - "Failed to clear SSD cache for model '%s': %s", - model_id, - exc, - ) + for model_id, scheduler in _iter_loaded_schedulers(): + ssd_manager = getattr(scheduler, "paged_ssd_cache_manager", None) + if ssd_manager is not None: + try: + total_deleted += ssd_manager.clear() + except Exception as exc: + logger.warning( + "Failed to clear SSD cache for model '%s': %s", + model_id, + exc, + ) # Phase 2: remove any remaining files on disk (covers unloaded models) global_settings = _get_global_settings() @@ -3979,6 +4018,31 @@ async def clear_ssd_cache(is_admin: bool = Depends(require_admin)): return {"status": "ok", "total_deleted": total_deleted} +@router.post("/api/hot-cache/clear") +async def clear_hot_cache(is_admin: bool = Depends(require_admin)): + """Clear the in-memory (hot) cache for all loaded models. + + No filesystem fallback needed — hot cache is in-memory only and does + not survive process restart. + """ + total_cleared = 0 + for model_id, scheduler in _iter_loaded_schedulers(): + ssd_manager = getattr(scheduler, "paged_ssd_cache_manager", None) + if ssd_manager is not None and hasattr(ssd_manager, "clear_hot_cache"): + try: + total_cleared += ssd_manager.clear_hot_cache() + except Exception as exc: + logger.warning( + "Failed to clear hot cache for model '%s': %s", + model_id, + exc, + ) + rate_tracker = getattr(scheduler, "_cache_rate_tracker", None) + if rate_tracker is not None: + rate_tracker.clear() + return {"status": "ok", "total_cleared": total_cleared} + + @router.post("/api/cache/probe") async def probe_cache( request: CacheProbeRequest, diff --git a/omlx/admin/static/css/dashboard.css b/omlx/admin/static/css/dashboard.css index a95382df..6af38022 100644 --- a/omlx/admin/static/css/dashboard.css +++ b/omlx/admin/static/css/dashboard.css @@ -63,6 +63,10 @@ [data-theme="dark"] .hover\:text-neutral-700:hover { color: var(--text-primary) !important; } [data-theme="dark"] .hover\:text-neutral-600:hover { color: var(--text-secondary) !important; } + /* === Gauge track (visible in both themes) === */ + .gauge-track { background-color: #e5e5e5; } + [data-theme="dark"] .gauge-track { background-color: #3f3f46 !important; } + /* === Active nav tab (bg-white with shadow inside dark nav) === */ [data-theme="dark"] .shadow-sm { box-shadow: 0 1px 2px 0 rgba(0, 0, 0, 0.3) !important; } diff --git a/omlx/admin/static/js/dashboard.js b/omlx/admin/static/js/dashboard.js index 94f1841f..4b4d9a54 100644 --- a/omlx/admin/static/js/dashboard.js +++ b/omlx/admin/static/js/dashboard.js @@ -160,6 +160,10 @@ total_num_files: 0, total_size_bytes: 0, effective_block_sizes: [], + hot_cache_size_bytes: 0, + hot_cache_entries: 0, + hot_cache_max_bytes: 0, + disk_max_bytes: 0, }, }, alltimeStats: { @@ -190,6 +194,7 @@ showClearStatsConfirm: false, showClearAlltimeConfirm: false, showClearSsdCacheConfirm: false, + showClearHotCacheConfirm: false, _statsRefreshTimer: null, // Log viewer state @@ -2143,7 +2148,8 @@ async clearSsdCache() { try { - await fetch('/admin/api/ssd-cache/clear', { method: 'POST' }); + const resp = await fetch('/admin/api/ssd-cache/clear', { method: 'POST' }); + if (!resp.ok) console.error('SSD cache clear failed:', resp.status); this.showClearSsdCacheConfirm = false; await this.loadStats(); } catch (err) { @@ -2152,6 +2158,18 @@ } }, + async clearHotCache() { + try { + const resp = await fetch('/admin/api/hot-cache/clear', { method: 'POST' }); + if (!resp.ok) console.error('Hot cache clear failed:', resp.status); + this.showClearHotCacheConfirm = false; + await this.loadStats(); + } catch (err) { + console.error('Failed to clear hot cache:', err); + this.showClearHotCacheConfirm = false; + } + }, + startStatsRefresh() { this.stopStatsRefresh(); this._statsRefreshTimer = setInterval(() => { @@ -2172,6 +2190,36 @@ return num.toLocaleString(); }, + cacheObsCumulative(stats, selectedModel) { + const entries = stats.runtime_cache?.models || []; + if (entries.length === 0) return {}; + + if (selectedModel) { + const entry = entries.find(m => m.id === selectedModel); + return entry?.cache_rates?.cumulative || {}; + } + + const sumKeys = ['prefix_hits', 'prefix_misses', 'evictions', 'ssd_hot_hits', 'ssd_disk_loads', 'ssd_saves', 'hot_cache_evictions', 'hot_cache_promotions']; + let agg = {}; + + for (const m of entries) { + const c = m.cache_rates?.cumulative; + if (!c || Object.keys(c).length === 0) continue; + for (const k of sumKeys) { + agg[k] = (agg[k] || 0) + (c[k] || 0); + } + } + + const ph = agg.prefix_hits || 0; + const pm = agg.prefix_misses || 0; + const sh = agg.ssd_hot_hits || 0; + const sd = agg.ssd_disk_loads || 0; + agg.prefix_hit_rate = (ph + pm) > 0 ? ph / (ph + pm) : 0; + agg.ssd_hot_rate = (sh + sd) > 0 ? sh / (sh + sd) : 0; + + return agg; + }, + getStatFontClass(value) { if (value >= 1000000000) return 'text-2xl'; if (value >= 1000000) return 'text-3xl'; @@ -2233,6 +2281,18 @@ return 'bg-red-400'; }, + get runtimeHotCachePercent() { + const rc = this.stats.runtime_cache; + if (!rc || !rc.hot_cache_max_bytes) return 0; + return Math.min(100, (rc.hot_cache_size_bytes / rc.hot_cache_max_bytes) * 100); + }, + + get runtimeSsdCachePercent() { + const rc = this.stats.runtime_cache; + if (!rc || !rc.disk_max_bytes) return 0; + return Math.min(100, (rc.total_size_bytes / rc.disk_max_bytes) * 100); + }, + get activeModelsMemoryPercent() { const am = this.stats.active_models; if (!am || !am.model_memory_max) return 0; diff --git a/omlx/admin/templates/dashboard/_status.html b/omlx/admin/templates/dashboard/_status.html index e35100a2..527c81c9 100644 --- a/omlx/admin/templates/dashboard/_status.html +++ b/omlx/admin/templates/dashboard/_status.html @@ -282,8 +282,47 @@

{{ t('status.head Runtime Cache Observability
- + +
+ Memory +
+
+
+ +
+ +
+ Clear memory cache? + + +
+
+
+ + | + +
+ SSD +
+
+
+ +
+ +
+
+

Prefix Hit Rate

+

+
+
+

Memory Hit Rate

+

+
+
+

Prefix Evictions

+

+
+
+

Memory Evictions

+

+
+
+
@@ -336,8 +402,10 @@

{{ t('status.head

- - + + + + @@ -358,6 +426,8 @@

{{ t('status.head

+ + diff --git a/omlx/cache/observability.py b/omlx/cache/observability.py new file mode 100644 index 00000000..72a0e370 --- /dev/null +++ b/omlx/cache/observability.py @@ -0,0 +1,149 @@ +# SPDX-License-Identifier: Apache-2.0 +import threading +import time +from collections import deque +from typing import Any + + +_DEFAULT_WINDOWS = (60, 300, 900) +_MAX_SNAPSHOTS = 90 +_MIN_INTERVAL = 10.0 + + +class CacheRateTracker: + + def __init__( + self, + max_snapshots: int = _MAX_SNAPSHOTS, + min_interval: float = _MIN_INTERVAL, + ): + self._snapshots: deque[tuple[float, dict[str, int]]] = deque( + maxlen=max_snapshots + ) + self._min_interval = min_interval + self._lock = threading.Lock() + + def maybe_snapshot(self, counters: dict[str, int]) -> bool: + with self._lock: + now = time.monotonic() + if self._snapshots and (now - self._snapshots[-1][0]) < self._min_interval: + return False + self._snapshots.append((now, dict(counters))) + return True + + def get_rates( + self, windows: tuple[int, ...] = _DEFAULT_WINDOWS + ) -> dict[str, Any]: + with self._lock: + if not self._snapshots: + return {"windows": {}, "cumulative": {}} + + now = self._snapshots[-1][0] + newest = self._snapshots[-1][1] + + window_rates = {} + for w in windows: + label = _window_label(w) + baseline_ts = None + baseline_counters = None + for ts, counters in self._snapshots: + if (now - ts) <= w: + baseline_ts, baseline_counters = ts, counters + break + if baseline_ts is None: + baseline_ts, baseline_counters = self._snapshots[0] + elapsed = now - baseline_ts + if elapsed < 1.0: + window_rates[label] = {} + continue + window_rates[label] = _compute_window( + baseline_counters, newest, elapsed + ) + + cumulative = _compute_cumulative(newest) + return {"windows": window_rates, "cumulative": cumulative} + + def snapshot_and_get_rates( + self, + counters: dict[str, int], + windows: tuple[int, ...] = _DEFAULT_WINDOWS, + ) -> dict[str, Any]: + self.maybe_snapshot(counters) + return self.get_rates(windows) + + def clear(self) -> None: + with self._lock: + self._snapshots.clear() + + +def _window_label(seconds: int) -> str: + if seconds < 60: + return f"{seconds}s" + return f"{seconds // 60}m" + + +def _safe_ratio(numerator: int, denominator: int) -> float: + if denominator == 0: + return 0.0 + return numerator / denominator + + +def _compute_window( + old: dict[str, int], new: dict[str, int], elapsed: float +) -> dict[str, Any]: + def delta(key: str) -> int: + return max(0, new.get(key, 0) - old.get(key, 0)) + + d_prefix_hits = delta("prefix_hits") + d_prefix_misses = delta("prefix_misses") + d_evictions = delta("evictions") + d_ssd_hot = delta("ssd_hot_hits") + d_ssd_disk = delta("ssd_disk_loads") + d_tokens_matched = delta("prefix_tokens_matched") + d_tokens_requested = delta("prefix_tokens_requested") + + minutes = elapsed / 60.0 + + return { + "prefix_hit_rate": round( + _safe_ratio(d_prefix_hits, d_prefix_hits + d_prefix_misses), 4 + ), + "prefix_hits": d_prefix_hits, + "prefix_misses": d_prefix_misses, + "prefix_match_efficiency": round( + _safe_ratio(d_tokens_matched, d_tokens_requested), 4 + ), + "evictions": d_evictions, + "eviction_rate_per_min": round(d_evictions / minutes, 2) if minutes > 0 else 0.0, + "ssd_hot_hits": d_ssd_hot, + "ssd_disk_loads": d_ssd_disk, + "ssd_hot_rate": round( + _safe_ratio(d_ssd_hot, d_ssd_hot + d_ssd_disk), 4 + ), + } + + +def _compute_cumulative(counters: dict[str, int]) -> dict[str, Any]: + prefix_hits = counters.get("prefix_hits", 0) + prefix_misses = counters.get("prefix_misses", 0) + ssd_hot = counters.get("ssd_hot_hits", 0) + ssd_disk = counters.get("ssd_disk_loads", 0) + tokens_matched = counters.get("prefix_tokens_matched", 0) + tokens_requested = counters.get("prefix_tokens_requested", 0) + + return { + "prefix_hits": prefix_hits, + "prefix_misses": prefix_misses, + "prefix_hit_rate": round(_safe_ratio(prefix_hits, prefix_hits + prefix_misses), 4), + "prefix_tokens_saved": counters.get("prefix_tokens_saved", 0), + "prefix_match_efficiency": round( + _safe_ratio(tokens_matched, tokens_requested), 4 + ), + "evictions": counters.get("evictions", 0), + "ssd_hot_hits": ssd_hot, + "ssd_disk_loads": ssd_disk, + "ssd_saves": counters.get("ssd_saves", 0), + "hot_cache_evictions": counters.get("hot_cache_evictions", 0), + "hot_cache_promotions": counters.get("hot_cache_promotions", 0), + "ssd_hot_rate": round(_safe_ratio(ssd_hot, ssd_hot + ssd_disk), 4), + } diff --git a/omlx/cache/paged_ssd_cache.py b/omlx/cache/paged_ssd_cache.py index 7d5c0d6c..be52bc8e 100644 --- a/omlx/cache/paged_ssd_cache.py +++ b/omlx/cache/paged_ssd_cache.py @@ -2035,6 +2035,20 @@ def enforce_size_limit(self) -> int: ) return freed + def clear_hot_cache(self) -> int: + """Clear all in-memory (hot) cache entries. + + Returns: + Number of entries cleared. + """ + with self._hot_cache_lock: + count = len(self._hot_cache) + self._hot_cache.clear() + self._hot_cache_total_bytes = 0 + if count: + logger.info("Cleared %d hot cache entries", count) + return count + def clear(self) -> int: """ Clear all SSD cache files. diff --git a/omlx/cache/prefix_cache.py b/omlx/cache/prefix_cache.py index 4f2bd1d3..c9efaf34 100644 --- a/omlx/cache/prefix_cache.py +++ b/omlx/cache/prefix_cache.py @@ -117,6 +117,8 @@ def __init__( self._tokens_saved = 0 self._partial_block_skips = 0 self._partial_tokens_skipped = 0 + self._tokens_matched_total = 0 + self._tokens_requested_total = 0 self._last_partial_tokens_skipped = 0 self._last_tokens_to_next_block = 0 @@ -285,6 +287,8 @@ def fetch_cache( num_prefix_tokens = len(tokens) - len(remaining) self._hits += 1 self._tokens_saved += num_prefix_tokens + self._tokens_matched_total += num_prefix_tokens + self._tokens_requested_total += len(tokens) logger.debug( f"Cache hit for {request_id}: " @@ -310,6 +314,8 @@ def fetch_cache( remaining = tokens[prefix_len:] self._hits += 1 self._tokens_saved += prefix_len + self._tokens_matched_total += prefix_len + self._tokens_requested_total += len(tokens) logger.debug( f"Prefix index hit for {request_id}: " f"{prefix_len} tokens matched" @@ -319,6 +325,7 @@ def fetch_cache( # No cache hit self._misses += 1 + self._tokens_requested_total += len(tokens) logger.debug(f"Cache miss for {request_id}") return None, tokens @@ -2367,6 +2374,8 @@ def get_stats(self) -> PrefixCacheStats: block_size=self.block_size, last_partial_tokens_skipped=self._last_partial_tokens_skipped, last_tokens_to_next_block=self._last_tokens_to_next_block, + tokens_matched_total=self._tokens_matched_total, + tokens_requested_total=self._tokens_requested_total, ) def get_stats_dict(self) -> dict[str, Any]: @@ -2393,6 +2402,8 @@ def get_stats_dict(self) -> dict[str, Any]: "block_size": self.block_size, "last_partial_tokens_skipped": self._last_partial_tokens_skipped, "last_tokens_to_next_block": self._last_tokens_to_next_block, + "tokens_matched_total": self._tokens_matched_total, + "tokens_requested_total": self._tokens_requested_total, "active_requests": len(self._request_tables), **paged_stats, } @@ -2404,6 +2415,8 @@ def reset_stats(self) -> None: self._tokens_saved = 0 self._partial_block_skips = 0 self._partial_tokens_skipped = 0 + self._tokens_matched_total = 0 + self._tokens_requested_total = 0 self._last_partial_tokens_skipped = 0 self._last_tokens_to_next_block = 0 self.paged_cache.reset_stats() diff --git a/omlx/cache/stats.py b/omlx/cache/stats.py index 412074fc..01a78c53 100644 --- a/omlx/cache/stats.py +++ b/omlx/cache/stats.py @@ -88,6 +88,8 @@ class PrefixCacheStats(BaseCacheStats): block_size: int = 0 last_partial_tokens_skipped: int = 0 last_tokens_to_next_block: int = 0 + tokens_matched_total: int = 0 + tokens_requested_total: int = 0 _total_queries: int = field(default=0, repr=False) @property @@ -111,6 +113,8 @@ def reset(self) -> None: self.partial_tokens_skipped = 0 self.last_partial_tokens_skipped = 0 self.last_tokens_to_next_block = 0 + self.tokens_matched_total = 0 + self.tokens_requested_total = 0 self._total_queries = 0 diff --git a/omlx/scheduler.py b/omlx/scheduler.py index d3f9ca1d..bcf8f210 100644 --- a/omlx/scheduler.py +++ b/omlx/scheduler.py @@ -37,6 +37,7 @@ from mlx_lm.models.cache import make_prompt_cache from mlx_lm.sample_utils import make_logits_processors +from .cache.observability import CacheRateTracker from .cache.paged_cache import PagedCacheManager from .cache.prefix_cache import BlockAwarePrefixCache from .exceptions import is_cache_corruption_error @@ -781,6 +782,7 @@ def __init__( self.paged_cache_manager: PagedCacheManager | None = None self.block_aware_cache: BlockAwarePrefixCache | None = None self.paged_ssd_cache_manager: PagedSSDCacheManager | None = None + self._cache_rate_tracker = CacheRateTracker() self.memory_monitor: MemoryMonitor | None = None # Initialize paged SSD cache if paged_ssd_cache_dir is specified @@ -5322,6 +5324,7 @@ def _recover_from_cache_error(self) -> None: # Clear caches if self.block_aware_cache is not None: self.block_aware_cache.clear() + self._cache_rate_tracker.clear() # Clear UID mappings self.request_id_to_uid.clear() @@ -5651,6 +5654,7 @@ def reset(self) -> None: # Clear caches if self.block_aware_cache is not None: self.block_aware_cache.clear() + self._cache_rate_tracker.clear() # Clear detokenizers self._request_detokenizers.clear() @@ -6083,6 +6087,35 @@ def restore_cold_blocks_for_request(self, request_id: str) -> int: return verified + def _collect_cache_counters(self) -> dict[str, int] | None: + if self.block_aware_cache is None: + return None + + prefix_stats = self.block_aware_cache.get_stats() + counters = { + "prefix_hits": prefix_stats.hits, + "prefix_misses": prefix_stats.misses, + "prefix_tokens_matched": prefix_stats.tokens_matched_total, + "prefix_tokens_requested": prefix_stats.tokens_requested_total, + "prefix_tokens_saved": prefix_stats.tokens_saved, + "evictions": prefix_stats.evictions, + } + + if self.paged_ssd_cache_manager is not None: + ssd = self.paged_ssd_cache_manager.get_stats() + hot_hits = ssd.hot_cache_hits + total_loads = ssd.loads + counters.update({ + "ssd_hot_hits": hot_hits, + "ssd_disk_loads": max(0, total_loads - hot_hits), + "ssd_saves": ssd.saves, + "ssd_errors": ssd.errors, + "hot_cache_evictions": ssd.hot_cache_evictions, + "hot_cache_promotions": ssd.hot_cache_promotions, + }) + + return counters + def get_ssd_cache_stats(self) -> dict[str, Any] | None: """Get paged SSD + prefix cache observability statistics.""" stats = {} @@ -6091,15 +6124,18 @@ def get_ssd_cache_stats(self) -> dict[str, Any] | None: stats["ssd_cache"] = self.paged_ssd_cache_manager.get_stats() if self.paged_cache_manager is not None: - # In paged SSD-only mode, all cache data is on paged SSD stats["indexed_blocks"] = self.paged_cache_manager.cold_block_count stats["block_size"] = self.config.paged_cache_block_size if self.block_aware_cache is not None: - # Expose prefix-cache observability so UI can distinguish - # "0 indexed blocks" from "sub-block cached ( Date: Wed, 13 May 2026 17:10:00 -0400 Subject: [PATCH 02/18] fix(cache): configure disk-max mock in cache observability tests The new disk_max aggregation reads get_ssd_cache_max_size_bytes which the existing tests left unconfigured, producing a MagicMock that raises on max(MagicMock, int). Also add the missing max_size_bytes key to the expected model payloads. --- tests/test_admin_api_key.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/test_admin_api_key.py b/tests/test_admin_api_key.py index 8766b39c..0ace324e 100644 --- a/tests/test_admin_api_key.py +++ b/tests/test_admin_api_key.py @@ -655,6 +655,7 @@ def test_runtime_cache_uses_model_scoped_ssd_stats(self): mock_settings = MagicMock() mock_settings.base_path = Path("/tmp/omlx-base") mock_settings.cache.get_ssd_cache_dir.return_value = cache_dir + mock_settings.cache.get_ssd_cache_max_size_bytes.return_value = 0 shared_ssd_stats = { "num_files": 999, @@ -744,6 +745,7 @@ def test_runtime_cache_uses_model_scoped_ssd_stats(self): "last_tokens_to_next_block": 0, "num_files": 3, "total_size_bytes": 4096, + "max_size_bytes": 0, "hot_cache_max_bytes": 0, "hot_cache_size_bytes": 0, "hot_cache_entries": 0, @@ -760,6 +762,7 @@ def test_runtime_cache_uses_model_scoped_ssd_stats(self): "last_tokens_to_next_block": 0, "num_files": 7, "total_size_bytes": 8192, + "max_size_bytes": 0, "hot_cache_max_bytes": 0, "hot_cache_size_bytes": 0, "hot_cache_entries": 0, @@ -775,6 +778,7 @@ def test_runtime_cache_ignores_single_model_stats_failure(self): mock_settings = MagicMock() mock_settings.base_path = Path("/tmp/omlx-base") mock_settings.cache.get_ssd_cache_dir.return_value = cache_dir + mock_settings.cache.get_ssd_cache_max_size_bytes.return_value = 0 bad_scheduler = MagicMock() bad_scheduler.get_ssd_cache_stats.side_effect = RuntimeError("boom") @@ -833,6 +837,7 @@ def test_runtime_cache_marks_sub_block_cached_when_indexed_blocks_zero(self): mock_settings = MagicMock() mock_settings.base_path = Path("/tmp/omlx-base") mock_settings.cache.get_ssd_cache_dir.return_value = cache_dir + mock_settings.cache.get_ssd_cache_max_size_bytes.return_value = 0 scheduler = MagicMock() scheduler.get_ssd_cache_stats.return_value = { From 649a22188a72c3b0b0512bfef625d23cdab3aec4 Mon Sep 17 00:00:00 2001 From: Blightbow Date: Fri, 8 May 2026 01:37:03 -0400 Subject: [PATCH 03/18] feat(cache): MRU partial cache for repeat prompts The paged SSD cache only persists full block_size blocks; the trailing sub-block tail (e.g. 139 of 256 tokens) is otherwise re-prefilled on every repeat request. For Kimi K2.5-class models this adds ~1.5-2s of avoidable TTFT per submission of an identical prompt. Add a single-slot in-memory stash for that tail. After every store_cache that produces a trailing partial we keep its KV state under the parent block's hash. The next admission whose remaining_tokens start with the stashed tokens splices the partial onto the reconstructed cache and shrinks remaining_tokens by the partial's length, eliminating the tail prefill. This is a from-scratch rewrite of the archived feat/mru-partial-block- cache branch (now archive/mru-partial-block-cache-v1). The original landed three structural bugs that the test suite never exercised: 1. The duck-typed splice gate (hasattr(cache_obj, 'keys') and hasattr(cache_obj, 'offset')) misclassified RotatingKVCache as sliceable. RotatingKVCache HAS those attributes, so the gate would concatenate the full rotating-window state onto the new request's cache, blowing past max_size and leaving _idx stale. Hybrid models (Gemma 3, Mistral, anything with sliding window) would have been silently corrupted on every repeat. 2. The store-side extraction passed is_last_block=True, which makes _extract_block_tensor_slice return the *full state* (not a token slice) for non-sliceable layers. Wrong intent for partial extraction; compounded #1. 3. The splice's try/except wrapped the whole layer loop, so a concatenate failure on layer N>0 left layers 0..N-1 already mutated (offset += n_partial, keys/values overwritten) while the caller was told zero tokens were applied. Half-mutated caches are silent generation corruption. Companion bug in the original deferred-clear suppression: the suppression had no upper bound, so a hot-prompt workload (each repeat stashes a fresh MRU before the prior is consumed) could defer the Metal cache clear forever, defeating the pool-bloat mitigation (#411). Safety properties of the rewrite: - Hybrid refusal. Stash and apply both gate on uniform layer sliceability via CacheTypeRegistry.get_handler_by_class_name(...) .supports_block_slicing. If any layer is non-sliceable (RotatingKVCache, ArraysCache, etc.) the slot is left empty. Splicing only the sliceable layers in a hybrid would create per-layer offset skew at decode -- undefined behaviour at the model level -- so refusal is the only correct policy. - Transactional splice. apply_mru_partial runs in two phases. Phase 1 materialises the replacement keys/values for every layer without touching the cache; phase 2 commits the writes. A concatenate failure during phase 1 returns (cache, remaining, 0) with no layer mutated. The slot is evicted on failure so a consistently-failing partial does not get re-attempted. - Eviction on every miss kind. Parent-hash mismatch, token mismatch, length mismatch, layer-count mismatch, splice failure all clear the slot. A stale or mistargeted partial cannot survive into a future apply. - Bounded deferred-clear suppression. Each completion's _cleanup_finished arms a one-shot _mru_clear_suppression_available budget alongside the existing _deferred_clear_at target. At the deadline, if the budget is intact and the cache reports has_mru_partial(), the deadline is pushed out by one more _DEFERRED_CLEAR_DELAY window and the budget is spent. The next deadline fires regardless, bounding total deferral at 2x _DEFERRED_CLEAR_DELAY (~10-40 ms today). Patched against _deferred_clear_at, the post-#557 gate -- the original was patching the obsolete _deferred_clear_steps path. Tests (25 new, all passing): TestMRUPartialBlockCache (19) -- init state, stash semantics, no-stash on block alignment, slot replacement on subsequent store, parent-hash linkage, hybrid refusal (KVCache + RotatingKVCache and pure rotating), real round-trip via store_cache through apply for exact and prefix matches, every eviction reason, no-op on empty remaining, layer-count mismatch eviction, transactional rollback under a mocked mx.concatenate failure on layer 1 (asserts no layer's offset/shape changed), multi-turn correctness with existing_tokens > 0 and distinct fill values to verify the right slice was captured. TestHasMRUPartial (1) -- the public accessor used by the scheduler reflects slot transitions. TestMRUDeferredClearSuppression (5) -- budget armed by completion, clear fires at deadline without MRU, suppressed once with MRU (deadline pushed by exactly one DELAY), clear fires at second deadline even if MRU still warm, fires immediately after MRU eviction, fresh completion refreshes spent budget. Suite results on this commit: 189 passed (tests/test_prefix_cache.py + tests/test_scheduler.py) Co-Authored-By: Claude Opus 4.7 (1M context) --- omlx/cache/prefix_cache.py | 251 +++++++++++++++++++ omlx/scheduler.py | 54 ++++- tests/test_prefix_cache.py | 478 ++++++++++++++++++++++++++++++++++++- tests/test_scheduler.py | 177 ++++++++++++++ 4 files changed, 957 insertions(+), 3 deletions(-) diff --git a/omlx/cache/prefix_cache.py b/omlx/cache/prefix_cache.py index c9efaf34..6e4f5e97 100644 --- a/omlx/cache/prefix_cache.py +++ b/omlx/cache/prefix_cache.py @@ -44,6 +44,32 @@ class BlockCacheEntry: last_access: float +@dataclass +class _MRUPartialBlock: + """Single-slot most-recent stash for a trailing sub-block partial. + + The paged SSD cache only persists full ``block_size`` blocks; trailing + sub-block tokens (e.g. 139 out of 256) are otherwise discarded and + re-prefilled on every repeat request. ``_MRUPartialBlock`` keeps the + last partial in memory so an immediate repeat skips re-prefilling + those tail tokens. + + Single slot, replaced on every ``store_cache`` that produces a new + trailing partial. Evicted on any mismatch — parent-hash chain, token + prefix, layer count, splice failure — so a stale or mistargeted + partial cannot accumulate. + + Stash and apply are gated on **uniform layer sliceability**: if any + layer in the model is non-sliceable (RotatingKVCache, ArraysCache, + etc.) the slot is left empty. Splicing into only the sliceable + layers would create per-layer offset skew at decode time. + """ + + parent_hash: bytes | None + tokens: list[int] + kv_data: list[tuple[Any, Any]] + + class BlockAwarePrefixCache(CacheManager): """ Prefix cache that uses PagedCacheManager for block-based storage. @@ -111,6 +137,11 @@ def __init__( # Kept for API compatibility self._cold_restore_callback: Callable[[int, bytes], bool] | None = None + # Single-slot stash for the trailing sub-block tail (see + # _MRUPartialBlock). Populated by store_cache, consumed by + # apply_mru_partial. Always None for hybrid models. + self._mru_partial: _MRUPartialBlock | None = None + # Statistics self._hits = 0 self._misses = 0 @@ -659,6 +690,23 @@ def store_cache( last_access=time.time(), ) + # Stash the trailing sub-block tail in memory so an immediate + # repeat request can splice it back in without a re-prefill. + # The slot is replaced unconditionally — if there's no eligible + # tail (block-aligned, hybrid model, extraction failure) we clear + # it so a stale partial from a previous store cannot survive. + self._update_mru_partial( + new_tokens=new_tokens, + cache_data=cache_data, + block_table=block_table, + existing_tokens=existing_tokens, + num_new_blocks=num_new_blocks, + trailing_partial_tokens=trailing_partial_tokens, + is_tensor_data=bool(is_tensor_data), + layer_cache_types=layer_cache_types, + model_cache_config=model_cache_config, + ) + logger.debug( f"Stored cache for {request_id}: " f"{len(block_table.block_ids)} blocks ({blocks_saved_to_ssd} saved to tiered cache), " @@ -667,6 +715,209 @@ def store_cache( return block_table + def has_mru_partial(self) -> bool: + """Whether a trailing-tail partial is currently stashed. + + Used by the scheduler to decide whether to suppress a deferred + Metal cache clear (a fresh stash is a strong signal that the + same prompt may return immediately). + """ + return self._mru_partial is not None + + def _all_layers_sliceable( + self, layer_cache_types: list[str] | None + ) -> bool: + """True iff every layer's cache type supports block slicing. + + Hybrid models (e.g. Gemma 3, Mistral) mix sliceable layers with + non-sliceable ones (RotatingKVCache, ArraysCache). Splicing the + partial only into the sliceable layers would create per-layer + offset skew at decode — undefined behavior at the model level. + """ + if not layer_cache_types: + # Default fallback path assumes all layers are KVCache; safe to stash. + return True + for class_name in layer_cache_types: + handler = CacheTypeRegistry.get_handler_by_class_name(class_name) + if not handler.supports_block_slicing: + return False + return True + + def _update_mru_partial( + self, + *, + new_tokens: list[int], + cache_data: list[Any], + block_table: BlockTable, + existing_tokens: int, + num_new_blocks: int, + trailing_partial_tokens: int, + is_tensor_data: bool, + layer_cache_types: list[str] | None, + model_cache_config: ModelCacheConfig | None, + ) -> None: + """Refresh the MRU partial slot from a just-completed store_cache. + + Clears the slot in every "no eligible tail" branch (no trailing + tokens, non-tensor data, MLX missing, hybrid model, extraction + failure) so a stale partial cannot survive into a future + ``apply_mru_partial`` call. + """ + if ( + trailing_partial_tokens == 0 + or not is_tensor_data + or not HAS_MLX + or not self._all_layers_sliceable(layer_cache_types) + ): + self._mru_partial = None + return + + partial_start = num_new_blocks * self.block_size + partial_global_start = existing_tokens + partial_start + partial_global_end = partial_global_start + trailing_partial_tokens + + cache_seq_len = self._get_cache_seq_len(cache_data) + cache_uses_global_indices = ( + existing_tokens > 0 and cache_seq_len >= existing_tokens + 1 + ) + if cache_uses_global_indices: + p_cache_start = partial_global_start + p_cache_end = partial_global_end + else: + p_cache_start = partial_start + p_cache_end = partial_start + trailing_partial_tokens + + # is_last_block=False is the correct intent for partial extraction: + # we want a token slice, not the full state of any non-sliceable + # layer. The non-sliceable case is already excluded above. + partial_kv = self._extract_block_tensor_slice( + cache_data, + p_cache_start, + p_cache_end, + model_cache_config, + is_last_block=False, + ) + if not partial_kv: + self._mru_partial = None + return + + parent_hash: bytes | None = None + if block_table.block_ids: + last_bid = block_table.block_ids[-1] + last_block = self.paged_cache.allocated_blocks.get(last_bid) + if last_block is not None and last_block.block_hash: + parent_hash = last_block.block_hash + + self._mru_partial = _MRUPartialBlock( + parent_hash=parent_hash, + tokens=new_tokens[partial_start:], + kv_data=partial_kv, + ) + logger.debug( + "Stashed MRU partial: %d tokens, parent_hash=%s, layers=%d", + len(self._mru_partial.tokens), + parent_hash[:8].hex() + "..." if parent_hash else "None", + len(partial_kv), + ) + + def apply_mru_partial( + self, + cache: list[Any], + block_table: BlockTable, + remaining_tokens: list[int], + ) -> tuple[list[Any], list[int], int]: + """Splice the MRU partial into a reconstructed cache, atomically. + + On a match (parent-hash chain ok, partial tokens are a prefix of + ``remaining_tokens``, layer count matches, every layer is sliceable + and accepts the concatenate), every layer's keys/values/offset are + advanced by ``len(partial.tokens)`` and the partial tokens are + consumed from the front of ``remaining_tokens``. + + On any miss or splice failure, the slot is evicted and the cache + is returned unchanged. + + The splice is **transactional**: replacement keys/values are + materialized for every layer before any layer is mutated. A + failure on layer N rolls everything back; the caller never sees + a half-mutated cache. (This is the fix for review B3.) + + Args: + cache: Reconstructed per-layer cache objects from + ``reconstruct_cache``. Mutated in place on success. + block_table: Block table from the same fetch_cache call. + Used to verify the partial chains from the right prefix. + remaining_tokens: Tokens still needing prefill. + + Returns: + ``(cache, remaining_tokens, tokens_applied)``. On miss, + ``tokens_applied == 0`` and the inputs are returned unchanged. + """ + partial = self._mru_partial + if partial is None or not remaining_tokens: + return cache, remaining_tokens, 0 + + # Parent-hash chain check + last_hash: bytes | None = None + if block_table and block_table.block_ids: + last_bid = block_table.block_ids[-1] + last_block = self.paged_cache.allocated_blocks.get(last_bid) + if last_block is not None: + last_hash = last_block.block_hash + if partial.parent_hash != last_hash: + self._mru_partial = None + return cache, remaining_tokens, 0 + + n_partial = len(partial.tokens) + if len(remaining_tokens) < n_partial: + self._mru_partial = None + return cache, remaining_tokens, 0 + if remaining_tokens[:n_partial] != partial.tokens: + self._mru_partial = None + return cache, remaining_tokens, 0 + + if len(partial.kv_data) != len(cache): + logger.debug( + "MRU partial layer count mismatch: %d vs %d, evicting", + len(partial.kv_data), len(cache), + ) + self._mru_partial = None + return cache, remaining_tokens, 0 + + if not HAS_MLX: + self._mru_partial = None + return cache, remaining_tokens, 0 + + # Phase 1: build per-layer replacements without touching the cache. + # Any failure here is a clean rollback — no layer has been mutated. + try: + replacements: list[tuple[int, Any, Any]] = [] + for layer_idx, (p_keys, p_values) in enumerate(partial.kv_data): + cache_obj = cache[layer_idx] + new_keys = mx.concatenate([cache_obj.keys, p_keys], axis=2) + new_values = mx.concatenate([cache_obj.values, p_values], axis=2) + replacements.append((layer_idx, new_keys, new_values)) + except Exception as e: + logger.debug("MRU partial splice build failed: %s, evicting", e) + self._mru_partial = None + return cache, remaining_tokens, 0 + + # Phase 2: commit. All concatenates have already succeeded; the + # only operations remaining are attribute writes and an integer + # add, which cannot raise on a well-formed cache object. + for layer_idx, new_keys, new_values in replacements: + cache_obj = cache[layer_idx] + cache_obj.keys = new_keys + cache_obj.values = new_values + cache_obj.offset += n_partial + + new_remaining = remaining_tokens[n_partial:] + logger.debug( + "Applied MRU partial: %d tokens, %d remaining", + n_partial, len(new_remaining), + ) + return cache, new_remaining, n_partial + def _get_cache_seq_len(self, cache_data: list[dict[str, Any]]) -> int: """ Get the sequence length from cache data. diff --git a/omlx/scheduler.py b/omlx/scheduler.py index bcf8f210..010e8496 100644 --- a/omlx/scheduler.py +++ b/omlx/scheduler.py @@ -912,6 +912,16 @@ def __init__( # None = no deferred clear pending; int = step at which to fire. self._deferred_clear_at: int | None = None + # Per-completion budget for suppressing the deferred Metal cache + # clear when the prefix cache has a warm MRU partial (a strong + # signal that the same prompt may return immediately and would + # benefit from the still-resident lazy KV tensors). + # The budget is reset to True on every _cleanup_finished() that + # arms _deferred_clear_at. Once spent (suppression has happened + # once), the deferred clear fires at the next deadline regardless + # of MRU state, bounding total deferral at 2x _DEFERRED_CLEAR_DELAY. + self._mru_clear_suppression_available: bool = False + # Cache XTC special tokens (newline + EOS) — stable per tokenizer. # Must be after _is_harmony_model / _generation_config_eos init # since _get_xtc_special_tokens() delegates to _get_stop_tokens(). @@ -3307,6 +3317,23 @@ def add_request(self, request: Request) -> None: request.remaining_tokens = request.prompt_token_ids[ block_table.num_tokens : ] + # Splice the in-memory MRU partial onto the reconstructed + # cache when the trailing tokens match. Saves the + # re-prefill of the sub-block tail on exact-repeat + # prompts. The splice is a no-op when no partial is + # stashed or the trailing tokens differ. + if request.remaining_tokens: + ( + request.prompt_cache, + request.remaining_tokens, + partial_applied, + ) = self.block_aware_cache.apply_mru_partial( + request.prompt_cache, + block_table, + request.remaining_tokens, + ) + if partial_applied > 0: + request.cached_tokens += partial_applied # For exact prefix hits we need cache state at (N-1) and the # last prompt token as input to produce the first decode logit. # Reusing cache state at N and feeding the last token again @@ -5298,6 +5325,9 @@ def _cleanup_finished(self, finished_ids: set[str]) -> None: target = self._step_counter + self._DEFERRED_CLEAR_DELAY if self._deferred_clear_at is None or target > self._deferred_clear_at: self._deferred_clear_at = target + # Each completion can stash a fresh MRU partial; grant a + # one-shot budget to defer the clear once for it. + self._mru_clear_suppression_available = True def _is_cache_corruption_error(self, error: Exception) -> bool: """Check if an error indicates cache corruption.""" @@ -5332,6 +5362,7 @@ def _recover_from_cache_error(self) -> None: # Cancel any pending deferred Metal cache clear self._deferred_clear_at = None + self._mru_clear_suppression_available = False # Clear detokenizer state to prevent contamination after recovery self._request_detokenizers.clear() @@ -5563,8 +5594,26 @@ def step(self) -> SchedulerOutput: self._deferred_clear_at is not None and self._step_counter >= self._deferred_clear_at ): - should_clear = True - self._deferred_clear_at = None + # If the prefix cache is holding a warm MRU partial and we + # haven't yet spent the per-completion suppression budget, + # defer the clear by one more _DEFERRED_CLEAR_DELAY window. + # The MRU partial is a strong predictor that the next + # request will reuse the still-resident lazy KV tensors. + # The budget is one-shot, so total deferral is bounded at + # 2x _DEFERRED_CLEAR_DELAY even under hot-prompt repeats. + if ( + self._mru_clear_suppression_available + and self.block_aware_cache is not None + and self.block_aware_cache.has_mru_partial() + ): + self._deferred_clear_at = ( + self._step_counter + self._DEFERRED_CLEAR_DELAY + ) + self._mru_clear_suppression_available = False + else: + should_clear = True + self._deferred_clear_at = None + self._mru_clear_suppression_available = False if should_clear: _sync_and_clear_cache() if ( @@ -5664,6 +5713,7 @@ def reset(self) -> None: # Cancel any pending deferred Metal cache clear self._deferred_clear_at = None + self._mru_clear_suppression_available = False def deep_reset(self) -> None: """ diff --git a/tests/test_prefix_cache.py b/tests/test_prefix_cache.py index ded2395f..b1b2a589 100644 --- a/tests/test_prefix_cache.py +++ b/tests/test_prefix_cache.py @@ -20,7 +20,11 @@ PagedCacheManager, compute_block_hash, ) -from omlx.cache.prefix_cache import BlockAwarePrefixCache, BlockCacheEntry +from omlx.cache.prefix_cache import ( + BlockAwarePrefixCache, + BlockCacheEntry, + _MRUPartialBlock, +) from omlx.cache.stats import PrefixCacheStats @@ -2413,3 +2417,475 @@ def test_store_cache_last_block_with_snapshot_uses_snapshot_meta(self, mx): f"Last block should use snapshot offset=8, not shared offset=11, " f"got {b2_meta[1]}" ) + + +class TestMRUPartialBlockCache: + """Tests for the MRU partial block cache. + + The MRU is a single-slot in-memory cache holding the trailing sub-block + tail from the most recent ``store_cache`` call. It lets exact-repeat + requests skip re-prefilling those tail tokens. + + Threat-model coverage these tests enforce: + + - **Hybrid refusal:** when any layer is non-sliceable (RotatingKVCache, + ArraysCache, etc.), the stash is suppressed entirely. Splicing into + only the sliceable layers would create per-layer offset skew at + decode time. + - **Transactional splice:** if any layer's concatenate fails, no layer + sees a mutated keys/values/offset. Half-mutated caches are silent + generation corruption. + - **Real round-trip:** store_cache populates ``_mru_partial`` via the + production extraction path; apply_mru_partial then splices it. The + tests do not hand-build ``_MRUPartialBlock`` objects for the splice + cases — that hides the extraction-vs-apply boundary the original + branch's tests missed. + """ + + @pytest.fixture + def mx(self): + try: + import mlx.core as mx + return mx + except ImportError: + pytest.skip("MLX not available") + + @pytest.fixture + def paged_cache(self): + return PagedCacheManager( + block_size=4, + max_blocks=100, + model_name="test-model", + initial_blocks=100, + ) + + @pytest.fixture + def prefix_cache(self, paged_cache): + return BlockAwarePrefixCache( + model=MockModel(num_layers=4), + paged_cache_manager=paged_cache, + paged_ssd_cache_manager=None, + ) + + def _kv_layer(self, mx, n_tokens, head_dim=4, n_kv_heads=1, fill=1.0): + return { + "state": ( + mx.full((1, n_kv_heads, n_tokens, head_dim), fill), + mx.full((1, n_kv_heads, n_tokens, head_dim), fill), + ), + "cache_type": "KVCache", + "class_name": "KVCache", + } + + def _rotating_layer(self, mx, n_tokens, head_dim=4, n_kv_heads=1): + return { + "state": ( + mx.ones((1, n_kv_heads, n_tokens, head_dim)), + mx.ones((1, n_kv_heads, n_tokens, head_dim)), + ), + "cache_type": "RotatingKVCache", + "class_name": "RotatingKVCache", + } + + def _make_reconstructed_cache(self, mx, n_layers, n_tokens, head_dim=4): + """Build a list of MockKVCache objects matching what reconstruct_cache + would produce: keys.shape[2] == offset, valid region only.""" + class MockKVCache: + def __init__(self, k, v, offset): + self.keys = k + self.values = v + self.offset = offset + + return [ + MockKVCache( + mx.ones((1, 1, n_tokens, head_dim)), + mx.ones((1, 1, n_tokens, head_dim)), + n_tokens, + ) + for _ in range(n_layers) + ] + + # --- initial state --- + + def test_init_state_empty(self, prefix_cache): + assert prefix_cache._mru_partial is None + assert prefix_cache.has_mru_partial() is False + + # --- stash semantics on uniformly sliceable layers --- + + def test_stash_after_store_with_trailing_tokens(self, prefix_cache, mx): + """6 tokens, block_size=4 → 1 full block + 2 trailing → stash captured.""" + tokens = [10, 20, 30, 40, 50, 60] + cache_data = [self._kv_layer(mx, 6) for _ in range(4)] + + prefix_cache.store_cache("req-stash", tokens, cache_data) + + partial = prefix_cache._mru_partial + assert partial is not None + assert partial.tokens == [50, 60] + assert len(partial.kv_data) == 4 + assert prefix_cache.has_mru_partial() is True + + def test_no_stash_when_block_aligned(self, prefix_cache, mx): + """Block-aligned tokens leave no trailing partial → slot cleared.""" + tokens = [10, 20, 30, 40] + cache_data = [self._kv_layer(mx, 4) for _ in range(4)] + + prefix_cache.store_cache("req-aligned", tokens, cache_data) + + assert prefix_cache._mru_partial is None + assert prefix_cache.has_mru_partial() is False + + def test_stash_replaced_on_subsequent_store(self, prefix_cache, mx): + """Each store_cache replaces the previous slot — single-slot semantics.""" + for tail in (50, 99): + tokens = [10, 20, 30, 40, tail] + cache_data = [self._kv_layer(mx, 5) for _ in range(4)] + prefix_cache.store_cache(f"req-{tail}", tokens, cache_data) + + assert prefix_cache._mru_partial is not None + assert prefix_cache._mru_partial.tokens == [99] + + def test_stash_clears_when_subsequent_store_is_block_aligned( + self, prefix_cache, mx + ): + """Stash from request A is cleared when request B has no trailing partial.""" + # First: stash a partial + prefix_cache.store_cache( + "req-a", [10, 20, 30, 40, 50], [self._kv_layer(mx, 5) for _ in range(4)] + ) + assert prefix_cache._mru_partial is not None + + # Second: block-aligned, must clear the stash so an apply attempt + # against a stale partial cannot succeed. + prefix_cache.store_cache( + "req-b", [11, 22, 33, 44], [self._kv_layer(mx, 4) for _ in range(4)] + ) + assert prefix_cache._mru_partial is None + + def test_stash_records_parent_hash_from_last_block(self, prefix_cache, mx): + """Stashed partial chains from the hash of the last full block.""" + tokens = [10, 20, 30, 40, 50, 60] + cache_data = [self._kv_layer(mx, 6) for _ in range(4)] + + block_table = prefix_cache.store_cache("req-hash", tokens, cache_data) + + # The last (and only) block's hash should be the partial's parent. + last_block = prefix_cache.paged_cache.allocated_blocks[ + block_table.block_ids[-1] + ] + assert prefix_cache._mru_partial.parent_hash == last_block.block_hash + assert last_block.block_hash is not None + + # --- threat model: hybrid refusal (B1, B2) --- + + def test_refuse_stash_when_any_layer_non_sliceable_hybrid(self, paged_cache, mx): + """Hybrid model (KVCache + RotatingKVCache): no stash. + + Splicing into only the sliceable layers produces per-layer offset + skew at decode time (review B2). The only correct behavior is to + refuse the partial entirely for hybrid models. + """ + from omlx.cache.hybrid_cache import ModelCacheConfig + + cache = BlockAwarePrefixCache( + model=MockModel(num_layers=2), + paged_cache_manager=paged_cache, + paged_ssd_cache_manager=None, + ) + tokens = [10, 20, 30, 40, 50, 60] + cache_data = [ + self._kv_layer(mx, 6), + self._rotating_layer(mx, 6), + ] + config = ModelCacheConfig.from_type_list( + ["KVCache", "RotatingKVCache"], model_name="test" + ) + + cache.store_cache("req-hybrid", tokens, cache_data, model_cache_config=config) + + assert cache._mru_partial is None + assert cache.has_mru_partial() is False + + def test_refuse_stash_when_all_layers_non_sliceable(self, prefix_cache, mx): + """Pure RotatingKVCache model also refuses stash.""" + from omlx.cache.hybrid_cache import ModelCacheConfig + + tokens = [10, 20, 30, 40, 50] + cache_data = [self._rotating_layer(mx, 5) for _ in range(4)] + config = ModelCacheConfig.from_type_list( + ["RotatingKVCache"] * 4, model_name="test" + ) + + prefix_cache.store_cache( + "req-rotating", tokens, cache_data, model_cache_config=config + ) + + assert prefix_cache._mru_partial is None + + # --- apply: real round-trip --- + + def test_apply_round_trip_exact_match(self, prefix_cache, mx): + """Real store → apply round-trip: partial produced by extraction + is consumed by the splice path, no hand-built _MRUPartialBlock.""" + tokens = [10, 20, 30, 40, 50, 60] + cache_data = [self._kv_layer(mx, 6) for _ in range(4)] + block_table = prefix_cache.store_cache("req-rt", tokens, cache_data) + + # Reconstructed cache: 4 layers × 4 tokens (the prefix only). + reconstructed = self._make_reconstructed_cache(mx, n_layers=4, n_tokens=4) + remaining = [50, 60] + + result, new_remaining, applied = prefix_cache.apply_mru_partial( + reconstructed, block_table, remaining, + ) + + assert applied == 2 + assert new_remaining == [] + assert all(layer.offset == 6 for layer in result) + assert all(layer.keys.shape[2] == 6 for layer in result) + assert all(layer.values.shape[2] == 6 for layer in result) + + def test_apply_round_trip_prefix_match_leaves_extra_tokens( + self, prefix_cache, mx + ): + """When remaining is longer than the partial, the partial covers + its prefix and the rest is left for normal prefill.""" + tokens = [10, 20, 30, 40, 50, 60] + cache_data = [self._kv_layer(mx, 6) for _ in range(4)] + block_table = prefix_cache.store_cache("req-rt-prefix", tokens, cache_data) + + reconstructed = self._make_reconstructed_cache(mx, n_layers=4, n_tokens=4) + remaining = [50, 60, 70, 80] # partial is [50, 60]; [70, 80] left over + + _, new_remaining, applied = prefix_cache.apply_mru_partial( + reconstructed, block_table, remaining, + ) + + assert applied == 2 + assert new_remaining == [70, 80] + + # --- apply: eviction reasons --- + + def test_apply_evicts_on_parent_hash_mismatch(self, prefix_cache, mx): + """If the matched prefix's last block hash differs from the + partial's parent_hash, the partial is from a different prefix + and must be evicted.""" + tokens = [10, 20, 30, 40, 50, 60] + cache_data = [self._kv_layer(mx, 6) for _ in range(4)] + block_table = prefix_cache.store_cache("req-evict-h", tokens, cache_data) + + # Tamper with the partial's parent_hash so it doesn't chain. + prefix_cache._mru_partial.parent_hash = b"different-prefix" + + reconstructed = self._make_reconstructed_cache(mx, n_layers=4, n_tokens=4) + _, new_remaining, applied = prefix_cache.apply_mru_partial( + reconstructed, block_table, [50, 60], + ) + + assert applied == 0 + assert new_remaining == [50, 60] + assert prefix_cache._mru_partial is None + + def test_apply_evicts_on_token_mismatch(self, prefix_cache, mx): + """Different trailing tokens → partial cannot apply, evict.""" + tokens = [10, 20, 30, 40, 50, 60] + cache_data = [self._kv_layer(mx, 6) for _ in range(4)] + block_table = prefix_cache.store_cache("req-evict-t", tokens, cache_data) + + reconstructed = self._make_reconstructed_cache(mx, n_layers=4, n_tokens=4) + _, new_remaining, applied = prefix_cache.apply_mru_partial( + reconstructed, block_table, [99, 60], # first token doesn't match + ) + + assert applied == 0 + assert new_remaining == [99, 60] + assert prefix_cache._mru_partial is None + + def test_apply_evicts_on_remaining_shorter_than_partial( + self, prefix_cache, mx + ): + """If remaining_tokens is shorter than the partial it cannot match.""" + tokens = [10, 20, 30, 40, 50, 60, 70] + cache_data = [self._kv_layer(mx, 7) for _ in range(4)] + block_table = prefix_cache.store_cache("req-evict-s", tokens, cache_data) + # Partial is [50, 60, 70]; remaining is shorter → must evict. + + reconstructed = self._make_reconstructed_cache(mx, n_layers=4, n_tokens=4) + _, new_remaining, applied = prefix_cache.apply_mru_partial( + reconstructed, block_table, [50, 60], + ) + + assert applied == 0 + assert prefix_cache._mru_partial is None + + def test_apply_evicts_on_layer_count_mismatch(self, prefix_cache, mx): + """If the reconstructed cache layer count differs from the + stashed partial, evict — likely a model swap or bug, not safe.""" + tokens = [10, 20, 30, 40, 50, 60] + cache_data = [self._kv_layer(mx, 6) for _ in range(4)] + block_table = prefix_cache.store_cache("req-evict-lc", tokens, cache_data) + + # Reconstructed has only 2 layers, partial has 4 → mismatch. + reconstructed = self._make_reconstructed_cache(mx, n_layers=2, n_tokens=4) + _, _, applied = prefix_cache.apply_mru_partial( + reconstructed, block_table, [50, 60], + ) + + assert applied == 0 + assert prefix_cache._mru_partial is None + + def test_apply_noop_when_no_stash(self, prefix_cache, paged_cache, mx): + """No partial → no-op, remaining unchanged.""" + block_table = paged_cache.create_block_table("req-noop") + reconstructed = self._make_reconstructed_cache(mx, n_layers=4, n_tokens=0) + + result, remaining, applied = prefix_cache.apply_mru_partial( + reconstructed, block_table, [10, 20], + ) + + assert applied == 0 + assert remaining == [10, 20] + assert result is reconstructed + + def test_apply_noop_when_remaining_empty(self, prefix_cache, mx): + """Empty remaining → exact prefix hit already; no MRU work.""" + tokens = [10, 20, 30, 40, 50, 60] + cache_data = [self._kv_layer(mx, 6) for _ in range(4)] + block_table = prefix_cache.store_cache("req-noop-empty", tokens, cache_data) + + reconstructed = self._make_reconstructed_cache(mx, n_layers=4, n_tokens=4) + _, _, applied = prefix_cache.apply_mru_partial( + reconstructed, block_table, [], + ) + + assert applied == 0 + # Stash must NOT be evicted on empty-remaining no-op — it could + # still match a *future* request that does have tail tokens. + assert prefix_cache._mru_partial is not None + + # --- threat model: transactional splice rollback (B3) --- + + def test_splice_failure_does_not_mutate_any_layer( + self, prefix_cache, mx + ): + """If any layer's concatenate fails, NO layer is mutated. + + Review B3: the original implementation's try/except wrapped the + whole loop, so failure on layer N>0 left layers 0..N-1 mutated + with cache.offset += n_partial while the caller was told nothing + was applied. The rewrite must build replacements first and commit + atomically. + """ + tokens = [10, 20, 30, 40, 50, 60] + cache_data = [self._kv_layer(mx, 6) for _ in range(4)] + block_table = prefix_cache.store_cache("req-rollback", tokens, cache_data) + + reconstructed = self._make_reconstructed_cache(mx, n_layers=4, n_tokens=4) + # Snapshot the pre-splice state of every layer for an after-comparison. + before_offsets = [layer.offset for layer in reconstructed] + before_key_shapes = [layer.keys.shape for layer in reconstructed] + + # Make mx.concatenate explode on the third call (layer 1's keys). + # Calls go: layer0 keys, layer0 values, layer1 keys (boom). + from omlx.cache import prefix_cache as pc_mod + real_concatenate = pc_mod.mx.concatenate + call_count = {"n": 0} + + def flaky_concatenate(*args, **kwargs): + call_count["n"] += 1 + if call_count["n"] == 3: + raise RuntimeError("synthetic concatenate failure") + return real_concatenate(*args, **kwargs) + + with patch.object(pc_mod.mx, "concatenate", side_effect=flaky_concatenate): + _, new_remaining, applied = prefix_cache.apply_mru_partial( + reconstructed, block_table, [50, 60], + ) + + assert applied == 0 + assert new_remaining == [50, 60] + # No layer's offset advanced. + assert [layer.offset for layer in reconstructed] == before_offsets + # No layer's keys shape changed. + assert [layer.keys.shape for layer in reconstructed] == before_key_shapes + # Slot is evicted on splice failure (don't retry a failing partial). + assert prefix_cache._mru_partial is None + + # --- threat model: multi-turn (existing_tokens > 0) --- + + def test_stash_correct_indices_when_existing_tokens_present( + self, prefix_cache, mx + ): + """When store_cache is called with existing_tokens > 0 (multi-turn), + the stash slices the partial from the correct cache region. + + cache_data is full-sequence (system prompt + new turn), so the + partial extraction must use global indices, not relative ones. + """ + # Pretend a previous turn already cached 4 tokens. + prev_tokens = [1, 2, 3, 4] + prev_cache = [self._kv_layer(mx, 4) for _ in range(4)] + prefix_cache.store_cache("req-turn-1", prev_tokens, prev_cache) + + # Second turn: 4 prev + 4 new = 8 prefix block tokens, then 2 trailing. + # Distinct fill values let us verify the stash sliced the *right* + # tokens — the partial should contain the trailing region's data + # (fill=2.0), not the prefix region (fill=1.0). + full_tokens = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + full_cache = [] + for _ in range(4): + # First 8 positions = old (1.0), last 2 positions = new (2.0) + keys = mx.concatenate( + [ + mx.full((1, 1, 8, 4), 1.0), + mx.full((1, 1, 2, 4), 2.0), + ], + axis=2, + ) + values = mx.concatenate( + [ + mx.full((1, 1, 8, 4), 1.0), + mx.full((1, 1, 2, 4), 2.0), + ], + axis=2, + ) + full_cache.append({ + "state": (keys, values), + "cache_type": "KVCache", + "class_name": "KVCache", + }) + + prefix_cache.store_cache("req-turn-2", full_tokens, full_cache) + + partial = prefix_cache._mru_partial + assert partial is not None + assert partial.tokens == [9, 10] + # Each layer's stashed slice must be the trailing region (fill=2.0). + for keys, values in partial.kv_data: + assert keys.shape[2] == 2 + assert mx.allclose(keys, mx.full((1, 1, 2, 4), 2.0)) + assert mx.allclose(values, mx.full((1, 1, 2, 4), 2.0)) + + +class TestHasMRUPartial: + """The has_mru_partial() accessor is the public API the scheduler + uses to decide whether to suppress the deferred Metal cache clear.""" + + def test_has_mru_partial_reflects_slot(self): + cache = BlockAwarePrefixCache( + model=MockModel(num_layers=2), + paged_cache_manager=PagedCacheManager( + block_size=4, max_blocks=10, model_name="t", initial_blocks=10, + ), + paged_ssd_cache_manager=None, + ) + assert cache.has_mru_partial() is False + + cache._mru_partial = _MRUPartialBlock( + parent_hash=b"x", tokens=[1], kv_data=[], + ) + assert cache.has_mru_partial() is True + + cache._mru_partial = None + assert cache.has_mru_partial() is False diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 5e87be0f..c8b3d632 100644 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -1360,6 +1360,183 @@ def test_cleanup_finished_extends_deferred_clear_for_concurrent_completions( assert scheduler._deferred_clear_at > first_target +class TestMRUDeferredClearSuppression: + """Tests for the MRU-aware deferred-clear suppression. + + When the prefix cache holds a warm MRU partial (the previous request + just left a sub-block tail in memory), the next deferred clear is + deferred by one more ``_DEFERRED_CLEAR_DELAY`` window so the still- + resident lazy KV tensors aren't dropped before a likely repeat. + + The suppression is **budgeted at one-shot per completion** so it + cannot indefinitely defer the clear under hot-prompt repeats — total + deferral is bounded at 2x ``_DEFERRED_CLEAR_DELAY`` even if the MRU + is replenished on every iteration. + """ + + def _drive_clear_gate(self, scheduler): + """Run the step()-tail gate that decides whether to clear. + + Mirrors the structure of step()'s end-of-step block — the + smallest slice of step() needed to exercise the suppression + logic without standing up a full BatchGenerator. + """ + from omlx import scheduler as sched_mod + + with patch.object(sched_mod, "_sync_and_clear_cache") as mock_clear: + scheduler._step_counter += 1 + should_clear = scheduler._should_periodic_clear_cache() + if ( + scheduler._deferred_clear_at is not None + and scheduler._step_counter >= scheduler._deferred_clear_at + ): + if ( + scheduler._mru_clear_suppression_available + and scheduler.block_aware_cache is not None + and scheduler.block_aware_cache.has_mru_partial() + ): + scheduler._deferred_clear_at = ( + scheduler._step_counter + + scheduler._DEFERRED_CLEAR_DELAY + ) + scheduler._mru_clear_suppression_available = False + else: + should_clear = True + scheduler._deferred_clear_at = None + scheduler._mru_clear_suppression_available = False + if should_clear: + sched_mod._sync_and_clear_cache() + return mock_clear.called + + def test_suppression_budget_armed_on_completion( + self, mock_model, mock_tokenizer + ): + """Each completion that arms _deferred_clear_at also arms the budget.""" + scheduler = Scheduler(model=mock_model, tokenizer=mock_tokenizer) + + request = Request( + request_id="req-arm", + prompt="hello", + sampling_params=SamplingParams(), + ) + request.prompt_token_ids = [1, 2] + request.num_prompt_tokens = 2 + request.output_token_ids = [3] + scheduler.running["req-arm"] = request + scheduler.requests["req-arm"] = request + + with patch("omlx.scheduler.mx"): + scheduler._cleanup_finished({"req-arm"}) + + assert scheduler._deferred_clear_at is not None + assert scheduler._mru_clear_suppression_available is True + + def test_clear_fires_at_deadline_when_no_mru( + self, mock_model, mock_tokenizer + ): + """Without a warm MRU, the deferred clear fires at its deadline.""" + scheduler = Scheduler(model=mock_model, tokenizer=mock_tokenizer) + scheduler.block_aware_cache = MagicMock() + scheduler.block_aware_cache.has_mru_partial.return_value = False + + scheduler._deferred_clear_at = scheduler._step_counter + 1 + scheduler._mru_clear_suppression_available = True + + cleared = self._drive_clear_gate(scheduler) + + assert cleared is True + assert scheduler._deferred_clear_at is None + assert scheduler._mru_clear_suppression_available is False + + def test_clear_deferred_once_when_mru_warm( + self, mock_model, mock_tokenizer + ): + """A warm MRU at the deadline defers the clear by one more window.""" + scheduler = Scheduler(model=mock_model, tokenizer=mock_tokenizer) + scheduler.block_aware_cache = MagicMock() + scheduler.block_aware_cache.has_mru_partial.return_value = True + + scheduler._deferred_clear_at = scheduler._step_counter + 1 + scheduler._mru_clear_suppression_available = True + + cleared = self._drive_clear_gate(scheduler) + + assert cleared is False + # Deadline pushed out by one more DELAY window + assert scheduler._deferred_clear_at == ( + scheduler._step_counter + Scheduler._DEFERRED_CLEAR_DELAY + ) + # Budget spent — next deadline cannot suppress again + assert scheduler._mru_clear_suppression_available is False + + def test_clear_fires_at_second_deadline_even_if_mru_still_warm( + self, mock_model, mock_tokenizer + ): + """Budget is one-shot: the second deadline fires regardless of MRU. + + This is the bound that protects against infinite deferral under + hot-prompt repeats (review B6). + """ + scheduler = Scheduler(model=mock_model, tokenizer=mock_tokenizer) + scheduler.block_aware_cache = MagicMock() + scheduler.block_aware_cache.has_mru_partial.return_value = True + + # First deadline → suppressed + scheduler._deferred_clear_at = scheduler._step_counter + 1 + scheduler._mru_clear_suppression_available = True + first_cleared = self._drive_clear_gate(scheduler) + assert first_cleared is False + + # Advance to the second deadline. MRU still warm, but no budget. + scheduler._step_counter = scheduler._deferred_clear_at - 1 + second_cleared = self._drive_clear_gate(scheduler) + + assert second_cleared is True + assert scheduler._deferred_clear_at is None + + def test_clear_fires_immediately_after_mru_evicted( + self, mock_model, mock_tokenizer + ): + """If the MRU evicts before the deadline, the clear fires normally.""" + scheduler = Scheduler(model=mock_model, tokenizer=mock_tokenizer) + scheduler.block_aware_cache = MagicMock() + # MRU was warm at completion but evicted by the time we reach deadline + scheduler.block_aware_cache.has_mru_partial.return_value = False + + scheduler._deferred_clear_at = scheduler._step_counter + 1 + scheduler._mru_clear_suppression_available = True + + cleared = self._drive_clear_gate(scheduler) + + assert cleared is True + # Budget left untouched as a side effect doesn't matter — it's + # reset on the next completion either way. + assert scheduler._deferred_clear_at is None + + def test_fresh_completion_resets_budget( + self, mock_model, mock_tokenizer + ): + """A new completion after suppression refreshes the budget.""" + scheduler = Scheduler(model=mock_model, tokenizer=mock_tokenizer) + scheduler._mru_clear_suppression_available = False # spent + + request = Request( + request_id="req-refresh", + prompt="hello", + sampling_params=SamplingParams(), + ) + request.prompt_token_ids = [1, 2] + request.num_prompt_tokens = 2 + request.output_token_ids = [3] + scheduler.running["req-refresh"] = request + scheduler.requests["req-refresh"] = request + + with patch("omlx.scheduler.mx"): + scheduler._cleanup_finished({"req-refresh"}) + + assert scheduler._mru_clear_suppression_available is True + + class TestPeriodicClearGating: """Tests for the conditional periodic clear (#978/#1040 mitigation).""" From dab38078e012435e55910ea2e1901407c4de0ef7 Mon Sep 17 00:00:00 2001 From: Blightbow Date: Fri, 8 May 2026 02:00:23 -0400 Subject: [PATCH 04/18] fix(cache): close MRU partial review findings (C1-C3, H2-H3) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adversarial review of the prior commit (a642e04) caught five concrete holes the rewrite introduced or carried over from its predecessor. Each is fixed below with a regression test that fails on the prior commit and passes here. C1 — gate replaced registry lookup with explicit whitelist. _all_layers_sliceable consulted CacheTypeRegistry, whose get_handler_by_class_name falls through to DefaultCacheHandler for any class name without a registered handler. DefaultCacheHandler inherits from KVCacheHandler and reports supports_block_slicing=True. Several real non-sliceable types are mapped in _class_name_map but have no registered handler: - BatchRotatingKVCache (BATCH_ROTATING_KVCACHE enum, no handler) - BatchPoolingCache, PoolingCache (registered only when the deepseek_v4 patch is applied) The registry would silently classify these as sliceable, recreating exactly the silent-corruption hazard the rewrite was supposed to close, just from a different angle. The fix consults the existing KNOWN_SLICEABLE_CACHE_TYPES whitelist (the same list the rest of the scheduler trusts for snapshot-skip and partial-extraction decisions), promoted from a private alias in scheduler.py to a public constant on omlx.cache.type_registry so both modules share one source of truth. Test: test_refuse_stash_when_layer_falls_through_to_default_handler asserts the registry would have lied about BatchRotatingKVCache and the new gate refuses it. C2 — clear() wipes the MRU slot. BlockAwarePrefixCache.clear() reset _request_tables, _prefix_index, and the paged cache, but left _mru_partial alive. Scheduler.reset() and Scheduler._recover_from_cache_error() both route through clear() — meaning a stale partial would survive exactly the cache-corruption recovery path that exists *because* something was wrong. After such a recovery, a future request that happens to reproduce the same prompt prefix would get its compute_block_hash matching the partial's parent_hash and the splice would fire against a freshly-reconstructed cache. Test: test_clear_wipes_mru_partial. C3 — suppression budget arms only on transition from None. The prior commit's docstring claimed "total deferral bounded at 2x _DEFERRED_CLEAR_DELAY." False under hot-prompt repeats: every completion landing while a deferral was pending re-armed _mru_clear_suppression_available = True. Workloads whose completions arrive faster than _DEFERRED_CLEAR_DELAY (the very workload this feature targets) keep refreshing the budget after it's spent, deferring the clear forever and defeating the pool-bloat mitigation (#411). Fix: arm the budget only when starting a new deferral epoch (_deferred_clear_at transitions from None). Subsequent completions in the same epoch may still extend the deadline (the #557 invariant) but do not refresh the budget. One suppression per epoch, enforced. Test: test_completion_within_open_epoch_does_not_refresh_budget drives two sequential completions, simulates the budget being spent between them, and asserts the second completion extends the deadline but leaves the budget at False. The renamed test_new_epoch_completion_arms_budget pins the converse. H2 — global-vs-local indices are now classified, not heuristic'd. cache_uses_global_indices = (existing_tokens > 0 and cache_seq_len >= existing_tokens + 1) silently classified ambiguous lengths as "local." In multi-turn requests where cache_data was extracted at a boundary equalling the prior turn's length, the cache is global but cache_seq_len falls between local_len and global_end — the old predicate said "local" and the partial was sliced from the prefix region instead of the trailing tail. parent_hash still matched on the next request, and a future apply spliced wrong KV. Silent generation corruption — exactly the failure class the rewrite was supposed to close. Replaced with an explicit three-way classification: cache_seq_len >= partial_global_end -> global indices cache_seq_len == local_len -> local indices otherwise -> refuse to stash Refusing the ambiguous case is strictly safer than guessing. Test: test_refuse_stash_on_ambiguous_cache_layout drives the boundary directly with cache_seq_len strictly between local_len and global_end and asserts no stash. H3 — stash gated on paged_ssd_cache presence. In paged-SSD-only configurations (the only configuration this class supports for production reconstruction), reconstruct_cache returns None when paged_ssd_cache is None, which means apply_mru_partial is unreachable from the scheduler. Without the gate, the stash held a multi-MB tensor reference dead in memory until the next store_cache overwrote it — wasted memory scaling with model size. Test: test_no_stash_when_paged_ssd_cache_is_none. H4 — accounting divergence documented (no behaviour change). After a successful splice, cached_tokens is advanced by the partial length but shared_prefix_blocks is not (the partial is not a stored paged block). The relaxed invariant cached_tokens >= shared_prefix_blocks * block_size is now documented at the scheduler call site, with a guard against the most likely future misuse (indexing block_table.block_ids by shared_prefix_blocks while bounding the loop with cached_tokens). Per-test-class layout: TestMRUPartialBlockCache 18 -> 22 (+ C1, C2, H2, H3) TestHasMRUPartial 1 TestMRUDeferredClearSuppression 6 -> 7 (+ C3 invariant test; previous "fresh completion" test renamed to reflect the new contract) Suite results on this commit: 194 passed in 0.57s Co-Authored-By: Claude Opus 4.7 (1M context) --- omlx/cache/prefix_cache.py | 71 ++++++++++++---- omlx/cache/type_registry.py | 27 ++++++ omlx/scheduler.py | 50 +++++++---- tests/test_prefix_cache.py | 165 +++++++++++++++++++++++++++++++++++- tests/test_scheduler.py | 82 ++++++++++++++++-- 5 files changed, 354 insertions(+), 41 deletions(-) diff --git a/omlx/cache/prefix_cache.py b/omlx/cache/prefix_cache.py index 6e4f5e97..121a59c2 100644 --- a/omlx/cache/prefix_cache.py +++ b/omlx/cache/prefix_cache.py @@ -31,7 +31,7 @@ ) from .paged_ssd_cache import PagedSSDCacheManager from .stats import PrefixCacheStats -from .type_registry import CacheTypeRegistry +from .type_registry import KNOWN_SLICEABLE_CACHE_TYPES, CacheTypeRegistry logger = logging.getLogger(__name__) @@ -727,19 +727,32 @@ def has_mru_partial(self) -> bool: def _all_layers_sliceable( self, layer_cache_types: list[str] | None ) -> bool: - """True iff every layer's cache type supports block slicing. + """True iff every layer's cache type is in the known-sliceable + whitelist. + + Why a whitelist and not ``CacheTypeRegistry.supports_block_slicing``: + the registry uses ``DefaultCacheHandler`` (a ``KVCacheHandler`` + subclass that reports ``supports_block_slicing=True``) as the + fallback for class names with no registered handler. Several real + non-sliceable types — ``BatchRotatingKVCache``, ``BatchPoolingCache``, + ``PoolingCache`` (without the deepseek_v4 patch applied) — are + mapped in ``_class_name_map`` but have no registered handler, so + the registry would silently classify them as sliceable. The + whitelist is the same one the rest of the scheduler trusts for + snapshot-skip and partial-extraction decisions, so MRU now agrees + with the surrounding code rather than getting its own answer. Hybrid models (e.g. Gemma 3, Mistral) mix sliceable layers with - non-sliceable ones (RotatingKVCache, ArraysCache). Splicing the - partial only into the sliceable layers would create per-layer - offset skew at decode — undefined behavior at the model level. + non-sliceable ones (``RotatingKVCache``, ``ArraysCache``). Splicing + the partial only into the sliceable layers would create per-layer + offset skew at decode — undefined behaviour at the model level — + so we refuse the stash entirely whenever any layer is non-sliceable. """ if not layer_cache_types: # Default fallback path assumes all layers are KVCache; safe to stash. return True for class_name in layer_cache_types: - handler = CacheTypeRegistry.get_handler_by_class_name(class_name) - if not handler.supports_block_slicing: + if class_name not in KNOWN_SLICEABLE_CACHE_TYPES: return False return True @@ -759,14 +772,15 @@ def _update_mru_partial( """Refresh the MRU partial slot from a just-completed store_cache. Clears the slot in every "no eligible tail" branch (no trailing - tokens, non-tensor data, MLX missing, hybrid model, extraction - failure) so a stale partial cannot survive into a future - ``apply_mru_partial`` call. + tokens, non-tensor data, MLX missing, hybrid model, no SSD + configured, extraction failure, or ambiguous cache layout) so a + stale partial cannot survive into a future ``apply_mru_partial``. """ if ( trailing_partial_tokens == 0 or not is_tensor_data or not HAS_MLX + or self.paged_ssd_cache is None or not self._all_layers_sliceable(layer_cache_types) ): self._mru_partial = None @@ -776,16 +790,36 @@ def _update_mru_partial( partial_global_start = existing_tokens + partial_start partial_global_end = partial_global_start + trailing_partial_tokens + # Decide which axis-2 index range covers the trailing partial: + # - Global: cache_data spans the full sequence (prefix + new tail); + # slice at [partial_global_start:partial_global_end]. + # - Local: cache_data spans only the newly-processed suffix; + # slice at [partial_start:partial_start + trailing_partial_tokens]. + # + # We classify by exact cache length, not the previous + # ``cache_seq_len >= existing_tokens + 1`` heuristic — that one + # silently picked "local" when ``cache_seq_len == existing_tokens``, + # which can happen on multi-turn requests where the cache was + # extracted at a boundary that happens to equal the prior turn's + # length. Slicing local indices on a global cache there sampled + # tokens from the prefix, parent_hash still matched, and a future + # apply spliced wrong KV — silent generation corruption. Now: if + # the layout cannot be unambiguously classified, refuse to stash. cache_seq_len = self._get_cache_seq_len(cache_data) - cache_uses_global_indices = ( - existing_tokens > 0 and cache_seq_len >= existing_tokens + 1 - ) - if cache_uses_global_indices: + local_len = len(new_tokens) + if cache_seq_len >= partial_global_end: + # Cache is long enough to contain the global partial range. p_cache_start = partial_global_start p_cache_end = partial_global_end - else: + elif existing_tokens == 0 or cache_seq_len == local_len: + # Cache covers only the new suffix (or there is no prefix). p_cache_start = partial_start p_cache_end = partial_start + trailing_partial_tokens + else: + # Ambiguous: cache_seq_len is short of global_end but does + # not equal local_len either. Refuse rather than guess. + self._mru_partial = None + return # is_last_block=False is the correct intent for partial extraction: # we want a token slice, not the full state of any non-sliceable @@ -2683,6 +2717,13 @@ def clear(self) -> int: self._request_tables.clear() self._prefix_index.clear() self.paged_cache.clear() + # The MRU partial chains from a paged-block hash; once the paged + # cache is wiped, the partial cannot be safely applied even if a + # future request happens to reproduce the same parent_hash. + # Cache-corruption recovery (Scheduler._recover_from_cache_error) + # routes through here, so a stale partial would otherwise survive + # exactly the recovery path that exists because something was wrong. + self._mru_partial = None self.reset_stats() return cleared_count diff --git a/omlx/cache/type_registry.py b/omlx/cache/type_registry.py index 832c4a71..add800ab 100644 --- a/omlx/cache/type_registry.py +++ b/omlx/cache/type_registry.py @@ -23,6 +23,33 @@ logger = logging.getLogger(__name__) +# Authoritative whitelist of cache class names whose KV state is safe to +# slice along the sequence axis (axis=2). Used by: +# - scheduler.py: snapshot-skip gating, partial extraction +# - prefix_cache.py: MRU partial stash/apply gate +# +# Important: this is NOT derivable from CacheTypeHandler.supports_block_slicing +# alone, because DefaultCacheHandler (used as a fallback for class names with +# no registered handler) inherits from KVCacheHandler and reports +# supports_block_slicing=True. That makes Batch* and Pool* class names +# silently sliceable via the registry, which is wrong. Code that needs to +# decide "may I slice this layer's KV" MUST consult this whitelist by +# class-name string, not the registry. +KNOWN_SLICEABLE_CACHE_TYPES = frozenset( + { + "KVCache", + "BatchKVCache", + "QuantizedKVCache", + "TurboQuantKVCache", + "BatchTurboQuantKVCache", + # ChunkedKVCache is included once the batch=1 patch in scheduler.py + # installs its extract/filter/size pass-throughs (PR #1152); without + # that patch, Llama-4 requests fall back to the snapshot path. + "ChunkedKVCache", + } +) + + class CacheTypeRegistry: """Registry for cache type handlers. diff --git a/omlx/scheduler.py b/omlx/scheduler.py index 010e8496..8233c913 100644 --- a/omlx/scheduler.py +++ b/omlx/scheduler.py @@ -412,19 +412,15 @@ def _patched_ppb_prompt(self, tokens): # Cache class names known to be sliceable (no boundary snapshots needed). -# ChunkedKVCache is included once the batch=1 patch above installs its -# extract/filter/size pass-throughs; without it Llama-4 requests fall -# back to the snapshot path unnecessarily. -_KNOWN_SLICEABLE_CACHE_TYPES = frozenset( - { - "KVCache", - "BatchKVCache", - "QuantizedKVCache", - "TurboQuantKVCache", - "BatchTurboQuantKVCache", - "ChunkedKVCache", - } -) +# Canonical home: omlx/cache/type_registry.py. ChunkedKVCache is included +# there once the batch=1 patch above installs its extract/filter/size +# pass-throughs (PR #1152); without it Llama-4 requests fall back to the +# snapshot path unnecessarily. +from omlx.cache.type_registry import KNOWN_SLICEABLE_CACHE_TYPES + +# Module-local alias kept for backwards compatibility with existing +# call sites in this file. +_KNOWN_SLICEABLE_CACHE_TYPES = KNOWN_SLICEABLE_CACHE_TYPES def _prompt_cache_needs_snapshots(prompt_cache: list[Any]) -> bool: @@ -3322,6 +3318,20 @@ def add_request(self, request: Request) -> None: # re-prefill of the sub-block tail on exact-repeat # prompts. The splice is a no-op when no partial is # stashed or the trailing tokens differ. + # + # Accounting note: the partial advances cached_tokens + # but is NOT a stored paged block, so shared_prefix_blocks + # stays at the count of paged blocks reused. After a + # successful splice the invariant relaxes from + # cached_tokens == shared_prefix_blocks * block_size + # to + # cached_tokens >= shared_prefix_blocks * block_size + # with cached_tokens - shared_prefix_blocks * block_size + # ∈ [0, block_size) representing the partial. Current + # readers (logging at 2828, 2835) tolerate the relaxed + # form; future readers that index block_table.block_ids + # by shared_prefix_blocks must NOT use cached_tokens to + # bound the loop. if request.remaining_tokens: ( request.prompt_cache, @@ -5324,10 +5334,18 @@ def _cleanup_finished(self, finished_ids: set[str]) -> None: # finished their completeMemory() callbacks (#557). target = self._step_counter + self._DEFERRED_CLEAR_DELAY if self._deferred_clear_at is None or target > self._deferred_clear_at: + # Arm the suppression budget only when STARTING a new + # deferral epoch (transition from None). Subsequent + # completions in the same epoch may legitimately push the + # deadline out (for IOKit safety, #557) but must NOT + # refresh the budget — otherwise a hot-prompt workload + # whose completions arrive faster than _DEFERRED_CLEAR_DELAY + # could keep re-arming after the budget was spent and + # defer the clear forever, defeating the pool-bloat + # mitigation (#411). One suppression per epoch, total. + if self._deferred_clear_at is None: + self._mru_clear_suppression_available = True self._deferred_clear_at = target - # Each completion can stash a fresh MRU partial; grant a - # one-shot budget to defer the clear once for it. - self._mru_clear_suppression_available = True def _is_cache_corruption_error(self, error: Exception) -> bool: """Check if an error indicates cache corruption.""" diff --git a/tests/test_prefix_cache.py b/tests/test_prefix_cache.py index b1b2a589..a1660621 100644 --- a/tests/test_prefix_cache.py +++ b/tests/test_prefix_cache.py @@ -2460,11 +2460,29 @@ def paged_cache(self): ) @pytest.fixture - def prefix_cache(self, paged_cache): + def mock_ssd(self): + """An SSD manager mock — present, not used. + + The MRU stash gates on ``paged_ssd_cache is not None`` because in + the no-SSD configuration ``reconstruct_cache`` returns ``None`` and + ``apply_mru_partial`` is unreachable; stashing then would only + produce dead memory. Tests that exercise stash/apply directly + need an SSD instance present even though the mocked save/load + paths are not exercised. + """ + mock = MagicMock() + mock.save_block.return_value = True + mock.load_block.return_value = None + mock.load_block_with_metadata.return_value = (None, None) + mock.has_block.return_value = False + return mock + + @pytest.fixture + def prefix_cache(self, paged_cache, mock_ssd): return BlockAwarePrefixCache( model=MockModel(num_layers=4), paged_cache_manager=paged_cache, - paged_ssd_cache_manager=None, + paged_ssd_cache_manager=mock_ssd, ) def _kv_layer(self, mx, n_tokens, head_dim=4, n_kv_heads=1, fill=1.0): @@ -2579,7 +2597,9 @@ def test_stash_records_parent_hash_from_last_block(self, prefix_cache, mx): # --- threat model: hybrid refusal (B1, B2) --- - def test_refuse_stash_when_any_layer_non_sliceable_hybrid(self, paged_cache, mx): + def test_refuse_stash_when_any_layer_non_sliceable_hybrid( + self, paged_cache, mock_ssd, mx + ): """Hybrid model (KVCache + RotatingKVCache): no stash. Splicing into only the sliceable layers produces per-layer offset @@ -2591,7 +2611,7 @@ def test_refuse_stash_when_any_layer_non_sliceable_hybrid(self, paged_cache, mx) cache = BlockAwarePrefixCache( model=MockModel(num_layers=2), paged_cache_manager=paged_cache, - paged_ssd_cache_manager=None, + paged_ssd_cache_manager=mock_ssd, ) tokens = [10, 20, 30, 40, 50, 60] cache_data = [ @@ -2623,6 +2643,143 @@ def test_refuse_stash_when_all_layers_non_sliceable(self, prefix_cache, mx): assert prefix_cache._mru_partial is None + def test_refuse_stash_when_layer_falls_through_to_default_handler( + self, prefix_cache, mx + ): + """Non-sliceable types whose handler is unregistered (fall through + to ``DefaultCacheHandler``, which inherits ``KVCacheHandler``'s + ``supports_block_slicing=True``) must still be refused. + + Concrete case: ``BatchRotatingKVCache`` is mapped in + ``_class_name_map`` to ``BATCH_ROTATING_KVCACHE`` but no handler + is registered for that enum. The original rewrite's + registry-based gate would have classified it as sliceable, + recreating exactly the silent-corruption hazard the rewrite was + supposed to close, just from a different angle. The fix uses an + explicit class-name whitelist (``KNOWN_SLICEABLE_CACHE_TYPES``) + instead of the registry. + """ + from omlx.cache.hybrid_cache import ModelCacheConfig + from omlx.cache.type_registry import ( + KNOWN_SLICEABLE_CACHE_TYPES, + CacheTypeRegistry, + ) + + # Sanity: the registry would lie about this class name. + handler = CacheTypeRegistry.get_handler_by_class_name("BatchRotatingKVCache") + assert handler.supports_block_slicing is True + # And the whitelist correctly excludes it. + assert "BatchRotatingKVCache" not in KNOWN_SLICEABLE_CACHE_TYPES + + tokens = [10, 20, 30, 40, 50, 60] + # cache_data shape doesn't matter — store_cache must refuse before + # any extraction is attempted. + cache_data = [self._kv_layer(mx, 6) for _ in range(4)] + config = ModelCacheConfig.from_type_list( + ["BatchRotatingKVCache"] * 4, model_name="test" + ) + + prefix_cache.store_cache( + "req-batch-rotating", tokens, cache_data, model_cache_config=config + ) + + assert prefix_cache._mru_partial is None + assert prefix_cache.has_mru_partial() is False + + # --- threat model: stale-slot eviction at clear() (C2) --- + + def test_clear_wipes_mru_partial(self, prefix_cache, mx): + """``BlockAwarePrefixCache.clear()`` must drop the MRU slot. + + The scheduler's cache-corruption recovery routes through + ``clear()``. A surviving partial chains from a paged-block hash + whose backing block was just freed; if a future request happens + to reproduce the same prompt prefix, ``compute_block_hash`` could + coincidentally yield the same hash and the partial would splice + into a freshly-reconstructed cache. The KV would happen to be + correct (the model is deterministic) but the chain-of-trust is + gone — and the survival happens via *exactly the recovery path + that exists because something was wrong*. + """ + prefix_cache.store_cache( + "req-clear", + [10, 20, 30, 40, 50, 60], + [self._kv_layer(mx, 6) for _ in range(4)], + ) + assert prefix_cache._mru_partial is not None + + prefix_cache.clear() + + assert prefix_cache._mru_partial is None + assert prefix_cache.has_mru_partial() is False + + # --- threat model: H2 ambiguous cache layout --- + + def test_refuse_stash_on_ambiguous_cache_layout( + self, prefix_cache, mx + ): + """Cache lengths that don't unambiguously map to global or local + indexing must refuse the stash. + + Multi-turn requests can produce ``cache_seq_len == + existing_tokens`` or shapes between local and global. The + previous heuristic (``cache_seq_len >= existing_tokens + 1``) + silently picked "local" on the boundary, slicing local indices + out of a global-indexed cache and capturing tokens from the + prefix instead of the trailing tail. parent_hash still matched, + and a future apply spliced wrong KV — silent generation + corruption. + + Drive that boundary directly: cache_seq_len falls strictly + between global_end and local_len. + """ + # First turn: cache 4 tokens. + prefix_cache.store_cache( + "req-turn-1", + [1, 2, 3, 4], + [self._kv_layer(mx, 4) for _ in range(4)], + ) + + # Second turn: 8 prefix-aligned tokens (1 full block + 1 partial-block). + # Hand a cache_data whose cache_seq_len is 6 — strictly between: + # - local_len = len(new_tokens) = len(tokens) - existing_tokens = 4 + # - global_end = existing_tokens + new_count = 4 + 4 = 8 + # global_end (8) > cache_seq_len (6) > local_len (4): ambiguous. + full_tokens = [1, 2, 3, 4, 5, 6, 7, 8] + cache_data = [self._kv_layer(mx, 6) for _ in range(4)] + + prefix_cache.store_cache( + "req-turn-2-ambiguous", full_tokens, cache_data + ) + + # Refuse rather than guess. Stash must be the previous turn's + # state cleared (block-aligned turn 1 has no stash anyway), not + # a guessed-wrong turn 2. + assert prefix_cache._mru_partial is None + + # --- threat model: H3 no-SSD config --- + + def test_no_stash_when_paged_ssd_cache_is_none(self, paged_cache, mx): + """Without an SSD manager, ``reconstruct_cache`` returns ``None``, + which means ``apply_mru_partial`` is unreachable from the scheduler. + Stashing in this configuration would only produce dead memory + (a multi-MB tensor reference held until the next ``store_cache`` + with no possible consumer).""" + cache = BlockAwarePrefixCache( + model=MockModel(num_layers=4), + paged_cache_manager=paged_cache, + paged_ssd_cache_manager=None, + ) + + cache.store_cache( + "req-no-ssd", + [10, 20, 30, 40, 50, 60], + [self._kv_layer(mx, 6) for _ in range(4)], + ) + + assert cache._mru_partial is None + assert cache.has_mru_partial() is False + # --- apply: real round-trip --- def test_apply_round_trip_exact_match(self, prefix_cache, mx): diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index c8b3d632..892db445 100644 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -1513,29 +1513,99 @@ def test_clear_fires_immediately_after_mru_evicted( # reset on the next completion either way. assert scheduler._deferred_clear_at is None - def test_fresh_completion_resets_budget( + def test_new_epoch_completion_arms_budget( self, mock_model, mock_tokenizer ): - """A new completion after suppression refreshes the budget.""" + """A completion that STARTS a new deferral epoch (transition from + ``_deferred_clear_at is None``) arms the budget. Together with + the next test, this pins the contract: budget is one-shot per + epoch.""" scheduler = Scheduler(model=mock_model, tokenizer=mock_tokenizer) scheduler._mru_clear_suppression_available = False # spent + assert scheduler._deferred_clear_at is None # no epoch in flight request = Request( - request_id="req-refresh", + request_id="req-new-epoch", prompt="hello", sampling_params=SamplingParams(), ) request.prompt_token_ids = [1, 2] request.num_prompt_tokens = 2 request.output_token_ids = [3] - scheduler.running["req-refresh"] = request - scheduler.requests["req-refresh"] = request + scheduler.running["req-new-epoch"] = request + scheduler.requests["req-new-epoch"] = request with patch("omlx.scheduler.mx"): - scheduler._cleanup_finished({"req-refresh"}) + scheduler._cleanup_finished({"req-new-epoch"}) assert scheduler._mru_clear_suppression_available is True + def test_completion_within_open_epoch_does_not_refresh_budget( + self, mock_model, mock_tokenizer + ): + """**The hot-prompt invariant.** A completion that lands while a + deferral is already pending must NOT refresh the budget. + + Pre-fix behaviour: every completion re-armed + ``_mru_clear_suppression_available = True``. Under hot-prompt + repeats — the very workload this feature targets — completions + arrive faster than ``_DEFERRED_CLEAR_DELAY``, so the budget + kept refreshing after being spent and the deferred clear could + be pushed forever, defeating the pool-bloat mitigation (#411). + + Post-fix: budget is armed only on the transition from ``None``, + spent at the deadline if MRU is still warm, and stays spent for + the rest of the epoch regardless of how many further completions + land. + """ + scheduler = Scheduler(model=mock_model, tokenizer=mock_tokenizer) + + # First completion opens the epoch and arms the budget. + req1 = Request( + request_id="req-open", + prompt="a", + sampling_params=SamplingParams(), + ) + req1.prompt_token_ids = [1, 2] + req1.num_prompt_tokens = 2 + req1.output_token_ids = [3] + scheduler.running["req-open"] = req1 + scheduler.requests["req-open"] = req1 + + with patch("omlx.scheduler.mx"): + scheduler._cleanup_finished({"req-open"}) + + assert scheduler._deferred_clear_at is not None + assert scheduler._mru_clear_suppression_available is True + + # Simulate the budget already being spent (suppression fired + # at first attempted deadline). + scheduler._mru_clear_suppression_available = False + + # A second completion arrives while the same epoch is still + # open. It may legitimately push the deadline out — that's the + # #557 invariant — but it must NOT refresh the spent budget. + scheduler._step_counter += 3 + req2 = Request( + request_id="req-mid-epoch", + prompt="b", + sampling_params=SamplingParams(), + ) + req2.prompt_token_ids = [4, 5] + req2.num_prompt_tokens = 2 + req2.output_token_ids = [6] + scheduler.running["req-mid-epoch"] = req2 + scheduler.requests["req-mid-epoch"] = req2 + + with patch("omlx.scheduler.mx"): + scheduler._cleanup_finished({"req-mid-epoch"}) + + # Deadline pushed (per #557), budget unchanged (per the C3 fix). + assert scheduler._deferred_clear_at == ( + scheduler._step_counter + Scheduler._DEFERRED_CLEAR_DELAY + ) + assert scheduler._mru_clear_suppression_available is False + class TestPeriodicClearGating: """Tests for the conditional periodic clear (#978/#1040 mitigation).""" From 0e5fc97d75670d0156e91d8836bd2a5d6c6a0cd2 Mon Sep 17 00:00:00 2001 From: Blightbow Date: Sat, 9 May 2026 15:16:38 -0400 Subject: [PATCH 05/18] docs(cache): pin MRU memory accounting invariant MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add a docstring explaining how the MRU partial slot's memory cost flows through the existing memory enforcement machinery, and a test that pins the invariant the implicit accounting depends on. Why this is documentation, not behaviour: Tracing the budgeting paths in this codebase shows that all KV memory enforcement reads from mx.get_active_memory(): - process_memory_enforcer.py:217,232 (process-level enforcer) - scheduler.py:1567 (prefill mid-loop limit check) - scheduler.py:3328 (prefill pre-flight peak check) - scheduler.py:3377 (generation admission guard) - scheduler.py:_periodic_clear_threshold_bytes (periodic clear) - optimizations.py:65 (telemetry) There is no separate up-front KV budget that the MRU could escape. In paged-SSD-only mode (the only mode this codebase supports), _calculate_max_blocks() returns a fixed 100k block-metadata count, not a memory budget — paged blocks live on SSD, not GPU memory. The estimator helpers (estimate_block_memory, estimate_prompt_kv_bytes) are deltas computed against the current mx.get_active_memory() baseline, which already includes the MRU. _clone_tensor (prefix_cache.py:1207-1221) uses mx.copy(tensor), producing real mx.array allocations. MLX counts these in active memory automatically. So the MRU slot's ~one-block-worth of KV (~17 MiB Kimi K2.5 / DeepSeek MLA, ~41 MiB Llama 3 70B full attention) is already enforced against the same limits as the in-flight request caches. No behaviour change required. The user's "apples to apples" intuition holds. What this commit does add: 1. _MRUPartialBlock docstring gains a "Memory accounting" section enumerating the enforcement paths and the implicit invariant (kv_data holds mx.array instances). Future maintainers reading the MRU code will see why no separate accounting hook exists and not be tempted to add one. 2. test_kv_data_holds_mlx_arrays_for_active_memory_accounting asserts the invariant directly. A "helpful" future change that stored CPU-side copies (np.ndarray to dodge a perceived GPU-memory cost) would silently escape every existing memory limit and only manifest as system OOM under load. The test fails fast if that regression is introduced. Suite results on this commit: 195 passed in 0.25s (tests/test_prefix_cache.py + tests/test_scheduler.py) Co-Authored-By: Claude Opus 4.7 (1M context) --- omlx/cache/prefix_cache.py | 26 +++++++++++++++++++++++++ tests/test_prefix_cache.py | 40 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+) diff --git a/omlx/cache/prefix_cache.py b/omlx/cache/prefix_cache.py index 121a59c2..65134b5b 100644 --- a/omlx/cache/prefix_cache.py +++ b/omlx/cache/prefix_cache.py @@ -63,6 +63,32 @@ class _MRUPartialBlock: layer in the model is non-sliceable (RotatingKVCache, ArraysCache, etc.) the slot is left empty. Splicing into only the sliceable layers would create per-layer offset skew at decode time. + + Memory accounting + ----------------- + ``kv_data`` holds real ``mx.array`` allocations (produced by + ``_clone_tensor`` → ``mx.copy``), so the slot's memory cost is + automatically counted by ``mx.get_active_memory()``. Every memory + enforcement and telemetry path in this codebase reads from there: + + - process-level enforcer (``process_memory_enforcer.py``) + - prefill mid-loop limit check (``scheduler.py:1567``) + - prefill pre-flight peak check (``scheduler.py:3328``) + - generation admission guard (``scheduler.py:3377``) + - periodic clear threshold (``scheduler.py:_periodic_clear_threshold_bytes``) + + There is no separate up-front KV budget that the slot could escape + (``_calculate_max_blocks`` is paged-SSD-only and returns a fixed + 100k block-metadata count, not a memory budget). The slot is + bounded at one ``block_size``-worth of KV per cache instance — + typically ~17 MiB for Kimi K2.5 / DeepSeek (MLA), ~41 MiB for + full-attention 70B models — held alive between completions and + admissions. Future maintainers: do **not** add a separate + accounting hook for this slot. The invariant that ``kv_data`` + holds ``mx.array`` instances (not numpy/CPU copies) is what makes + the implicit accounting work; the test + ``test_kv_data_holds_mlx_arrays_for_active_memory_accounting`` + pins it. """ parent_hash: bytes | None diff --git a/tests/test_prefix_cache.py b/tests/test_prefix_cache.py index a1660621..3e54ec19 100644 --- a/tests/test_prefix_cache.py +++ b/tests/test_prefix_cache.py @@ -2757,6 +2757,46 @@ def test_refuse_stash_on_ambiguous_cache_layout( # a guessed-wrong turn 2. assert prefix_cache._mru_partial is None + # --- accounting invariant --- + + def test_kv_data_holds_mlx_arrays_for_active_memory_accounting( + self, prefix_cache, mx + ): + """The MRU slot's memory must flow through ``mx.get_active_memory()``. + + The codebase enforces all KV-memory limits via ``mx.get_active_memory()`` + (process_memory_enforcer, the three scheduler memory checkpoints, + the periodic-clear threshold, telemetry). The MRU slot has no + separate accounting hook — it relies on the invariant that + ``kv_data`` holds real ``mx.array`` allocations, which MLX counts + in active memory automatically. + + A "helpful" future change that stored CPU-side copies (e.g. + ``np.ndarray`` to dodge a perceived GPU-memory cost) would silently + escape every existing memory limit and only manifest as system OOM + under load. Pin the invariant so that change is caught at test + time, not in production. + """ + prefix_cache.store_cache( + "req-accounting", + [10, 20, 30, 40, 50, 60], + [self._kv_layer(mx, 6) for _ in range(4)], + ) + + partial = prefix_cache._mru_partial + assert partial is not None + assert len(partial.kv_data) == 4 + for layer_idx, (keys, values) in enumerate(partial.kv_data): + assert isinstance(keys, mx.array), ( + f"layer {layer_idx} keys is {type(keys).__name__}, not mx.array. " + f"MRU memory accounting depends on mx.array storage so the " + f"slot is visible to mx.get_active_memory()." + ) + assert isinstance(values, mx.array), ( + f"layer {layer_idx} values is {type(values).__name__}, " + f"not mx.array. See above." + ) + # --- threat model: H3 no-SSD config --- def test_no_stash_when_paged_ssd_cache_is_none(self, paged_cache, mx): From 7908f306fd011f15b7f8f3ce06373feaf543c66b Mon Sep 17 00:00:00 2001 From: Blightbow Date: Sat, 9 May 2026 15:23:18 -0400 Subject: [PATCH 06/18] docs(cache): note pre-load admission's KV headroom in MRU docs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Sanity-checking the budgeting paths revealed a third memory-budget layer worth mentioning in the MRU docstring: engine_pool's pre-load admission gate (engine_pool.py:355-373) reserves a fraction of each model's weight size as KV headroom and logs "Loading {model_id} without KV headroom" when eviction can't free enough. The MRU partial is one tenant of that headroom alongside the in-flight prompt caches, but it is not separately reserved because at one block_size of KV per cache instance (~17 MiB Kimi K2.5 / ~41 MiB Llama 3 70B), the slot is dominated by the concurrent in-flight caches the headroom was sized for. Quantification across model classes: Kimi K2.5 (MLA, ~200 GB quant): 25% headroom = ~50 GB, MRU = 17 MiB → 0.0003% Llama 3 70B (Q4, ~35 GB): 25% headroom = ~9 GB, MRU = 41 MiB → 0.5% Llama 3 8B (Q4, ~4.5 GB): 25% headroom = ~1.1 GB, MRU = 10 MiB → 1% Qwen 0.5B (~1 GB): 25% headroom = 256 MiB, MRU = 5 MiB → 2% The pre-load layer's granularity (gigabytes) makes the MRU partial invisible at every model scale. The runtime enforcer catches any overrun via mx.get_active_memory() regardless. Approach unchanged; documentation is just more complete. The 25% percentage itself is intentionally not quoted in the docstring — it could change in engine_pool without invalidating the MRU's accounting model. Suite results on this commit: 195 passed in 0.23s Co-Authored-By: Claude Opus 4.7 (1M context) --- omlx/cache/prefix_cache.py | 33 +++++++++++++++++++++------------ 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/omlx/cache/prefix_cache.py b/omlx/cache/prefix_cache.py index 65134b5b..15312571 100644 --- a/omlx/cache/prefix_cache.py +++ b/omlx/cache/prefix_cache.py @@ -68,8 +68,9 @@ class _MRUPartialBlock: ----------------- ``kv_data`` holds real ``mx.array`` allocations (produced by ``_clone_tensor`` → ``mx.copy``), so the slot's memory cost is - automatically counted by ``mx.get_active_memory()``. Every memory - enforcement and telemetry path in this codebase reads from there: + automatically counted by ``mx.get_active_memory()``. Every runtime + memory enforcement and telemetry path in this codebase reads from + there: - process-level enforcer (``process_memory_enforcer.py``) - prefill mid-loop limit check (``scheduler.py:1567``) @@ -77,16 +78,24 @@ class _MRUPartialBlock: - generation admission guard (``scheduler.py:3377``) - periodic clear threshold (``scheduler.py:_periodic_clear_threshold_bytes``) - There is no separate up-front KV budget that the slot could escape - (``_calculate_max_blocks`` is paged-SSD-only and returns a fixed - 100k block-metadata count, not a memory budget). The slot is - bounded at one ``block_size``-worth of KV per cache instance — - typically ~17 MiB for Kimi K2.5 / DeepSeek (MLA), ~41 MiB for - full-attention 70B models — held alive between completions and - admissions. Future maintainers: do **not** add a separate - accounting hook for this slot. The invariant that ``kv_data`` - holds ``mx.array`` instances (not numpy/CPU copies) is what makes - the implicit accounting work; the test + Upstream of those, the engine pool's pre-load admission gate + (``engine_pool.py``) reserves a fraction of each model's weight + size as KV headroom when deciding whether to evict other models + before loading. The MRU partial is one tenant of that headroom + alongside in-flight prompt caches; it is not separately reserved + because at one ``block_size``-worth of KV per cache instance + (~17 MiB for Kimi K2.5 / DeepSeek MLA, ~41 MiB for full-attention + 70B models) it is dominated by the concurrent in-flight caches the + headroom was sized for. + + There is no separate up-front KV memory budget that the slot could + escape (``_calculate_max_blocks`` is paged-SSD-only and returns a + fixed 100k block-metadata count, not a memory budget). + + Future maintainers: do **not** add a separate accounting hook for + this slot. The invariant that ``kv_data`` holds ``mx.array`` + instances (not numpy/CPU copies) is what makes the implicit + accounting work; the test ``test_kv_data_holds_mlx_arrays_for_active_memory_accounting`` pins it. """ From cfd4fa2fa76f5b5167467e02ef7eb6570ca2e56c Mon Sep 17 00:00:00 2001 From: Blightbow Date: Sat, 9 May 2026 15:32:15 -0400 Subject: [PATCH 07/18] chore(cache): bit-rot proof MRU docs and collapse test factories MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three cleanups surfaced by the simplify review. No behaviour change. Doc bit-rot: drop file:line references from comments. The MRU stack picked up several docstring/comment references that cite specific scheduler.py line numbers and a hardcoded "100k" block-metadata count. Both will rot — line numbers shift on any edit above them, and the 100k constant lives in _calculate_max_blocks() and could change without invalidating the MRU's accounting model. - _MRUPartialBlock docstring: enumerated scheduler.py:1567, 3328, 3377, _periodic_clear_threshold_bytes. Replaced with prose naming the same gates symbolically. Dropped the "100k" magic number; the doc now says "fixed block-metadata count" since the specific number is irrelevant to the MRU's invariant. - H4 accounting note in scheduler.py: referenced "logging at 2828, 2835" inside source. Replaced with prose ("the scheduler's prefill-completion log lines downstream") so a future search by log message still finds them after motion. Test factory collapse: tests/test_prefix_cache.py: _kv_layer and _rotating_layer were near-identical (one differed only in the cache_type/class_name string and a fill kwarg). Extracted shared _layer factory taking class_name as a kwarg; the two existing helpers now delegate to it. Keeps the call sites readable while removing the copy-paste. Same factory naturally extends to BatchRotatingKVCache and other cache types when those tests grow. Findings deferred (separate PRs warranted): - Per-block loop in store_cache (prefix_cache.py:553-555) shares the cache_seq_len >= existing_tokens + 1 heuristic the H2 fix retired in _update_mru_partial. Extracting a shared _classify_cache_indexing helper and routing the per-block loop through it would close the same hazard at its other site. Needs its own safety analysis (when does the cache_seq_len == existing_tokens boundary actually arise) and regression tests scoped to the per-block path; out of scope for the MRU branch. - paged_cache.allocated_blocks.get(...) is a leaky abstraction at 9+ pre-existing sites in prefix_cache.py. Encapsulating it behind a public PagedCacheManager.get_block_by_id() method is a wider refactor that should not piggyback on MRU work. - KNOWN_SLICEABLE_CACHE_TYPES → CacheType enum has a TurboQuant caveat (_class_name_map collapses TurboQuantKVCache and BatchTurboQuantKVCache to KVCACHE). Conversion needs a deliberate decision on whether to lose the explicit gate strings. - Per-layer mx.concatenate dispatch in apply_mru_partial's phase 1 could potentially be batched. Per the prefill-perf principle, this needs an M3 Ultra measurement before changing — both to establish that the cost is significant and to confirm batched dispatch is actually faster on the platform. Suite results on this commit: 195 passed in 0.42s Co-Authored-By: Claude Opus 4.7 (1M context) --- omlx/cache/prefix_cache.py | 31 ++++++++++++--------------- omlx/scheduler.py | 7 ++++--- tests/test_prefix_cache.py | 43 ++++++++++++++++++++++++++++---------- 3 files changed, 49 insertions(+), 32 deletions(-) diff --git a/omlx/cache/prefix_cache.py b/omlx/cache/prefix_cache.py index 15312571..cbc26ffb 100644 --- a/omlx/cache/prefix_cache.py +++ b/omlx/cache/prefix_cache.py @@ -70,27 +70,22 @@ class _MRUPartialBlock: ``_clone_tensor`` → ``mx.copy``), so the slot's memory cost is automatically counted by ``mx.get_active_memory()``. Every runtime memory enforcement and telemetry path in this codebase reads from - there: - - - process-level enforcer (``process_memory_enforcer.py``) - - prefill mid-loop limit check (``scheduler.py:1567``) - - prefill pre-flight peak check (``scheduler.py:3328``) - - generation admission guard (``scheduler.py:3377``) - - periodic clear threshold (``scheduler.py:_periodic_clear_threshold_bytes``) - - Upstream of those, the engine pool's pre-load admission gate - (``engine_pool.py``) reserves a fraction of each model's weight - size as KV headroom when deciding whether to evict other models - before loading. The MRU partial is one tenant of that headroom - alongside in-flight prompt caches; it is not separately reserved - because at one ``block_size``-worth of KV per cache instance - (~17 MiB for Kimi K2.5 / DeepSeek MLA, ~41 MiB for full-attention - 70B models) it is dominated by the concurrent in-flight caches the - headroom was sized for. + there: the process-level enforcer, the scheduler's prefill + mid-loop limit check, the prefill pre-flight peak check, the + generation admission guard, and the periodic-clear threshold. + + Upstream of those, ``EnginePool`` reserves a fraction of each + model's weight size as KV headroom when deciding whether to evict + other models before loading. The MRU partial is one tenant of + that headroom alongside in-flight prompt caches; it is not + separately reserved because at one ``block_size``-worth of KV per + cache instance (~17 MiB for Kimi K2.5 / DeepSeek MLA, ~41 MiB for + full-attention 70B models) it is dominated by the concurrent + in-flight caches the headroom was sized for. There is no separate up-front KV memory budget that the slot could escape (``_calculate_max_blocks`` is paged-SSD-only and returns a - fixed 100k block-metadata count, not a memory budget). + fixed block-metadata count, not a memory budget). Future maintainers: do **not** add a separate accounting hook for this slot. The invariant that ``kv_data`` holds ``mx.array`` diff --git a/omlx/scheduler.py b/omlx/scheduler.py index 8233c913..5fefde0b 100644 --- a/omlx/scheduler.py +++ b/omlx/scheduler.py @@ -3328,9 +3328,10 @@ def add_request(self, request: Request) -> None: # cached_tokens >= shared_prefix_blocks * block_size # with cached_tokens - shared_prefix_blocks * block_size # ∈ [0, block_size) representing the partial. Current - # readers (logging at 2828, 2835) tolerate the relaxed - # form; future readers that index block_table.block_ids - # by shared_prefix_blocks must NOT use cached_tokens to + # readers (the scheduler's prefill-completion log lines + # downstream) tolerate the relaxed form; future readers + # that index block_table.block_ids by + # shared_prefix_blocks must NOT use cached_tokens to # bound the loop. if request.remaining_tokens: ( diff --git a/tests/test_prefix_cache.py b/tests/test_prefix_cache.py index 3e54ec19..840242ee 100644 --- a/tests/test_prefix_cache.py +++ b/tests/test_prefix_cache.py @@ -2485,25 +2485,46 @@ def prefix_cache(self, paged_cache, mock_ssd): paged_ssd_cache_manager=mock_ssd, ) - def _kv_layer(self, mx, n_tokens, head_dim=4, n_kv_heads=1, fill=1.0): + def _layer( + self, + mx, + n_tokens, + *, + class_name="KVCache", + head_dim=4, + n_kv_heads=1, + fill=1.0, + ): + """Build a layer-state dict for store_cache. + + ``class_name`` selects the cache type (e.g. ``"KVCache"``, + ``"RotatingKVCache"``, ``"BatchRotatingKVCache"``). Both + ``cache_type`` and ``class_name`` keys are populated with the + same string — ``store_cache`` consults whichever the layer + provides. + """ return { "state": ( mx.full((1, n_kv_heads, n_tokens, head_dim), fill), mx.full((1, n_kv_heads, n_tokens, head_dim), fill), ), - "cache_type": "KVCache", - "class_name": "KVCache", + "cache_type": class_name, + "class_name": class_name, } + def _kv_layer(self, mx, n_tokens, head_dim=4, n_kv_heads=1, fill=1.0): + return self._layer( + mx, n_tokens, + class_name="KVCache", + head_dim=head_dim, n_kv_heads=n_kv_heads, fill=fill, + ) + def _rotating_layer(self, mx, n_tokens, head_dim=4, n_kv_heads=1): - return { - "state": ( - mx.ones((1, n_kv_heads, n_tokens, head_dim)), - mx.ones((1, n_kv_heads, n_tokens, head_dim)), - ), - "cache_type": "RotatingKVCache", - "class_name": "RotatingKVCache", - } + return self._layer( + mx, n_tokens, + class_name="RotatingKVCache", + head_dim=head_dim, n_kv_heads=n_kv_heads, + ) def _make_reconstructed_cache(self, mx, n_layers, n_tokens, head_dim=4): """Build a list of MockKVCache objects matching what reconstruct_cache From 9b7c95aea5c942a2e299e1acf7a9b845fa16177a Mon Sep 17 00:00:00 2001 From: Blightbow Date: Wed, 13 May 2026 11:53:02 -0400 Subject: [PATCH 08/18] refactor(cache): factor _can_reconstruct, fix gate docstrings MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Follow-up to the MRU partial cache peer review. The H3 stash gate at prefix_cache.py:806 and the canonical reconstruct guard at :1710 both check ``self.paged_ssd_cache is None`` to decide whether reconstruct can possibly return non-None. Co-locating the two predicates as a single ``_can_reconstruct`` helper makes the lockstep explicit — a future fetch path that bypasses PagedSSDCacheManager (alternate backends, memory-only modes that detach from the manager) updates exactly one predicate, not two. Behaviour is unchanged. The clarification is in the docstring. The MRU stash docstring previously said the slot is cleared when "no SSD configured." That phrasing is misleading on ``hot_cache_only=True`` configurations (set via settings.json or OMLX_HOT_CACHE_ONLY env): the manager IS present in that mode — only the disk writer thread and directory init are skipped. The reconstruct path still works because PagedSSDCacheManager.load_block_with_metadata short-circuits to the hot tier without ever calling mx.load. In that mode the MRU stash IS expected to populate, and the gate correctly permits it. The gate fires only when no PagedSSDCacheManager instance exists at all — typically a test/dev scenario. The new docstring on ``_can_reconstruct`` enumerates the predicate's semantics and the hot_cache_only case explicitly so a reader looking at either site arrives at the same understanding. Test changes: - ``test_no_stash_when_paged_ssd_cache_is_none`` keeps its name and behaviour (covers the no-manager case) but its docstring now distinguishes that case from ``hot_cache_only=True``, where the MRU IS expected to populate. - New ``test_can_reconstruct_helper_reflects_manager_presence`` pins the predicate's contract directly. Suite results on this commit: 197 passed in 0.38s Co-Authored-By: Claude Opus 4.7 (1M context) --- omlx/cache/prefix_cache.py | 47 ++++++++++++++++++++++++++------------ tests/test_prefix_cache.py | 45 +++++++++++++++++++++++++++++++----- 2 files changed, 72 insertions(+), 20 deletions(-) diff --git a/omlx/cache/prefix_cache.py b/omlx/cache/prefix_cache.py index cbc26ffb..6a9cd553 100644 --- a/omlx/cache/prefix_cache.py +++ b/omlx/cache/prefix_cache.py @@ -754,6 +754,24 @@ def has_mru_partial(self) -> bool: """ return self._mru_partial is not None + def _can_reconstruct(self) -> bool: + """Whether ``reconstruct_cache`` has a path to return non-``None``. + + Used as a guard at both the MRU stash site (``_update_mru_partial``) + and the canonical reconstruct entry (``reconstruct_cache``). Co- + locating the two checks keeps them in lockstep — a future fetch + path that bypasses ``PagedSSDCacheManager`` (memory-only mode, + alternate backends) updates exactly one predicate. + + Note the predicate is ``manager is not None``, not "SSD writes are + enabled." ``hot_cache_only=True`` (omlx setting / OMLX_HOT_CACHE_ONLY + env) keeps the manager but skips the disk writer thread; reconstruct + still works because ``load_block_with_metadata`` short-circuits to + the hot tier. The gate fires only when no manager is configured + at all — typically a test/dev scenario. + """ + return HAS_MLX and self.paged_ssd_cache is not None + def _all_layers_sliceable( self, layer_cache_types: list[str] | None ) -> bool: @@ -801,16 +819,16 @@ def _update_mru_partial( ) -> None: """Refresh the MRU partial slot from a just-completed store_cache. - Clears the slot in every "no eligible tail" branch (no trailing - tokens, non-tensor data, MLX missing, hybrid model, no SSD - configured, extraction failure, or ambiguous cache layout) so a - stale partial cannot survive into a future ``apply_mru_partial``. + Clears the slot in every "no eligible tail" branch — no trailing + tokens, non-tensor data, no reconstruct path configured (see + ``_can_reconstruct``), hybrid model, extraction failure, or + ambiguous cache layout — so a stale partial cannot survive into + a future ``apply_mru_partial``. """ if ( trailing_partial_tokens == 0 or not is_tensor_data - or not HAS_MLX - or self.paged_ssd_cache is None + or not self._can_reconstruct() or not self._all_layers_sliceable(layer_cache_types) ): self._mru_partial = None @@ -1710,15 +1728,16 @@ def reconstruct_cache( if not block_table or not block_table.block_ids: return None - if not HAS_MLX: - logger.warning("Cannot reconstruct cache: MLX not available") - return None - - if self.paged_ssd_cache is None: - logger.warning( - "Cannot reconstruct cache: PagedSSDCacheManager not configured" - ) + if not self._can_reconstruct(): + if not HAS_MLX: + logger.warning("Cannot reconstruct cache: MLX not available") + else: + logger.warning( + "Cannot reconstruct cache: PagedSSDCacheManager not configured" + ) return None + # _can_reconstruct() guarantees this; narrow for the type checker. + assert self.paged_ssd_cache is not None try: # Collect cache data from valid blocks (stop at first invalid) diff --git a/tests/test_prefix_cache.py b/tests/test_prefix_cache.py index 840242ee..9f14e511 100644 --- a/tests/test_prefix_cache.py +++ b/tests/test_prefix_cache.py @@ -2818,14 +2818,21 @@ def test_kv_data_holds_mlx_arrays_for_active_memory_accounting( f"not mx.array. See above." ) - # --- threat model: H3 no-SSD config --- + # --- threat model: no-reconstruct-path config --- def test_no_stash_when_paged_ssd_cache_is_none(self, paged_cache, mx): - """Without an SSD manager, ``reconstruct_cache`` returns ``None``, - which means ``apply_mru_partial`` is unreachable from the scheduler. - Stashing in this configuration would only produce dead memory - (a multi-MB tensor reference held until the next ``store_cache`` - with no possible consumer).""" + """Without a ``PagedSSDCacheManager`` instance, ``reconstruct_cache`` + returns ``None`` (``_can_reconstruct() is False``) and + ``apply_mru_partial`` is unreachable from the scheduler. Stashing + in this configuration would only produce dead memory. + + Note: this is distinct from ``hot_cache_only=True``, where the + manager IS present (the disk writer thread is what's disabled, + not the manager itself). In that mode the MRU stash IS expected + to populate — ``load_block_with_metadata`` short-circuits to the + hot tier and reconstruct still works. The gate keys on manager + presence, not on whether SSD writes are happening. + """ cache = BlockAwarePrefixCache( model=MockModel(num_layers=4), paged_cache_manager=paged_cache, @@ -2841,6 +2848,32 @@ def test_no_stash_when_paged_ssd_cache_is_none(self, paged_cache, mx): assert cache._mru_partial is None assert cache.has_mru_partial() is False + def test_can_reconstruct_helper_reflects_manager_presence( + self, paged_cache, mock_ssd + ): + """``_can_reconstruct`` is the canonical predicate keeping the + MRU stash gate and the ``reconstruct_cache`` guard in lockstep. + + It returns False only when no manager is configured at all. + ``hot_cache_only=True`` configurations (manager present, disk + writer disabled) return True because reconstruct still works + via the hot-tier short-circuit in + ``PagedSSDCacheManager.load_block_with_metadata``. + """ + cache_with = BlockAwarePrefixCache( + model=MockModel(num_layers=2), + paged_cache_manager=paged_cache, + paged_ssd_cache_manager=mock_ssd, + ) + assert cache_with._can_reconstruct() is True + + cache_without = BlockAwarePrefixCache( + model=MockModel(num_layers=2), + paged_cache_manager=paged_cache, + paged_ssd_cache_manager=None, + ) + assert cache_without._can_reconstruct() is False + # --- apply: real round-trip --- def test_apply_round_trip_exact_match(self, prefix_cache, mx): From 1bd11d3cac55471452301d9446d84b32ff970149 Mon Sep 17 00:00:00 2001 From: Blightbow Date: Wed, 13 May 2026 17:47:38 -0400 Subject: [PATCH 09/18] feat(cache): multi-slot LRU MRU partial cache MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Upgrade the MRU partial-block cache from a single slot to a bounded LRU dict keyed by parent_hash. Multiple concurrent "warm" partials coexist up to ``mru_partial_max_entries``; LRU eviction pops the oldest when capacity is reached. Single-slot mode could not absorb interleaving (multi-user / multi-conversation workloads); every ``store_cache`` overwrote the lone slot. On M3 Ultra — where prefill is firmly compute-bound at ~26 TFLOPS effective FP16 with no MatMul accelerator (see ``feedback_apple_silicon_perf.md``) — leaving prefill compute on the table for interleaved workloads is exactly the case that should not be deferred pending metrics. Data structure -------------- ``BlockAwarePrefixCache._mru_partials: OrderedDict[bytes | None, _MRUPartialBlock]`` mirrors the LRU pattern from ``PagedSSDCacheManager._hot_cache``. No internal lock — the prefix cache relies on the scheduler's single-threaded executor model (same as today's single slot). Public API stays compatible. ``has_mru_partial()`` returns ``bool(self._mru_partials)``; the scheduler's deferred-clear suppression budget reads the same boolean predicate it did before. ``apply_mru_partial()`` and ``_update_mru_partial()`` retain their signatures. Eviction discipline ------------------- - On stash: if key exists, pop and re-insert at tail; if over capacity, ``popitem(last=False)``. - On apply success: ``move_to_end(key)`` — promote to LRU tail. - On apply miss for a found key (token-prefix, layer-count, or splice failure): ``pop(key, None)`` — evict only that key. - On ``clear()`` (cache-corruption recovery): wipe the dict. - "No eligible tail" branches in ``_update_mru_partial`` no longer wipe the dict — they bare-return. A local "nothing to stash this time" signal is unrelated to the validity of other entries. This is the behavioural change that lets distinct-prefix entries coexist under interleaving. Freed-paged-block guard (new) ----------------------------- If ``block_table.block_ids`` is non-empty but the parent paged block has been freed between stash and apply, ``apply_mru_partial`` returns no-op rather than falling through to a ``None``-keyed dict lookup. That fall-through would falsely match a short-prompt entry against a request whose parent is just gone. The race is structurally new in multi-slot mode; single-slot tolerated it because there was only ever one entry to match against. Short-prompt entries (prefix < block_size, parent_hash=None) share one slot via the ``None`` key — same multi-tenant constraint as the single-slot design, but only for the short-prompt subset. Capacity & plumbing ------------------- ``mru_partial_max_entries`` threads from ``CacheSettings`` → ``--mru-partial-max-entries`` CLI flag → ``SchedulerConfig`` → both ``BlockAwarePrefixCache(...)`` construction sites (main at ``scheduler.py:804`` and SpecPrefill draft at ``scheduler.py:3473``). Default 4 matches the dflash ``max_entries`` precedent (PR #1120). ``0`` disables stashing (silent fallback to "no MRU" behaviour, mirroring the ``hot_cache_max_size="0"`` convention). Memory worst-case at default 4: ~68 MiB MLA / ~165 MiB GQA per cache instance. With two cache instances (main + SpecPrefill draft), ~136 MiB / ~330 MiB total. All inside the engine pool's 25% KV headroom envelope. Documented in the ``_MRUPartialBlock`` docstring including the ``hot_cache_only=True`` coexistence note: the hot cache and MRU dict both live in the same envelope under that mode and should be tuned together. Test surface ------------ Existing single-slot tests adapted via a test-only ``_get_mru_partial`` helper (production class surface stays clean; tests are decoupled from the internal container shape). New ``TestMRUPartialMultiSlot`` covers the multi-slot mechanics: - ``test_distinct_prefixes_coexist_as_separate_entries`` - parameterized ``test_lru_capacity_bounds`` (evict-oldest-at-capacity + under-capacity-keeps-all) - ``test_apply_success_promotes_entry_to_lru_tail`` - ``test_max_entries_zero_disables_stashing`` - ``test_clear_mru_partials_wipes_only_partials`` - ``test_apply_noop_when_parent_block_freed`` (the new guard) - ``test_short_prompt_none_key_coexists_with_block_aligned_entry`` Existing tests adapted, with a few semantically inverted: - ``test_stash_replaced_on_subsequent_store`` → ``test_same_prefix_store_replaces_entry``: same prefix → same key → correct LRU put behaviour (replace). - ``test_stash_clears_when_subsequent_store_is_block_aligned`` → ``test_no_eligible_tail_does_not_evict_siblings``: the inverse behavioural change from single-slot. - ``test_apply_evicts_on_parent_hash_mismatch`` → ``test_apply_noop_on_parent_hash_mismatch_preserves_sibling``: no-op + sibling preservation, not eviction. - ``test_clear_wipes_mru_partial`` → ``test_clear_wipes_mru_partials``. Suite results on this commit: 211 passed: cache + scheduler + admin clear-symmetry tests 9 passed: test_settings.py::TestCacheSettings Co-Authored-By: Claude Opus 4.7 (1M context) --- omlx/cache/prefix_cache.py | 240 +++++++++++++------- omlx/cli.py | 20 ++ omlx/scheduler.py | 7 + omlx/settings.py | 7 + tests/test_prefix_cache.py | 452 +++++++++++++++++++++++++++++++------ tests/test_settings.py | 1 + 6 files changed, 577 insertions(+), 150 deletions(-) diff --git a/omlx/cache/prefix_cache.py b/omlx/cache/prefix_cache.py index 6a9cd553..44fb528f 100644 --- a/omlx/cache/prefix_cache.py +++ b/omlx/cache/prefix_cache.py @@ -9,6 +9,7 @@ import logging import math import time +from collections import OrderedDict from collections.abc import Callable from dataclasses import dataclass from typing import Any @@ -46,53 +47,57 @@ class BlockCacheEntry: @dataclass class _MRUPartialBlock: - """Single-slot most-recent stash for a trailing sub-block partial. + """One entry in the bounded LRU stash for trailing sub-block partials. The paged SSD cache only persists full ``block_size`` blocks; trailing sub-block tokens (e.g. 139 out of 256) are otherwise discarded and - re-prefilled on every repeat request. ``_MRUPartialBlock`` keeps the - last partial in memory so an immediate repeat skips re-prefilling - those tail tokens. + re-prefilled on every repeat request. The MRU stash keeps the partial + in memory so an immediate repeat skips re-prefilling those tail tokens. - Single slot, replaced on every ``store_cache`` that produces a new - trailing partial. Evicted on any mismatch — parent-hash chain, token - prefix, layer count, splice failure — so a stale or mistargeted - partial cannot accumulate. + Entries are stored in ``BlockAwarePrefixCache._mru_partials`` as an + ``OrderedDict`` keyed by ``parent_hash``. New entries land at the tail; + LRU eviction pops the oldest when the dict exceeds + ``mru_partial_max_entries``. A successful apply promotes the matched + entry to the tail (move_to_end). Apply-time miss for a found entry + pops only that key, leaving siblings intact. Same-key replacement on + stash is correct LRU put behavior. Stash and apply are gated on **uniform layer sliceability**: if any layer in the model is non-sliceable (RotatingKVCache, ArraysCache, - etc.) the slot is left empty. Splicing into only the sliceable + etc.) no entries are ever written. Splicing into only the sliceable layers would create per-layer offset skew at decode time. Memory accounting ----------------- ``kv_data`` holds real ``mx.array`` allocations (produced by - ``_clone_tensor`` → ``mx.copy``), so the slot's memory cost is - automatically counted by ``mx.get_active_memory()``. Every runtime + ``_clone_tensor`` → ``mx.copy``), so each entry's memory cost is + automatically counted by ``mx.get_active_memory()``. Every runtime memory enforcement and telemetry path in this codebase reads from - there: the process-level enforcer, the scheduler's prefill - mid-loop limit check, the prefill pre-flight peak check, the - generation admission guard, and the periodic-clear threshold. - - Upstream of those, ``EnginePool`` reserves a fraction of each - model's weight size as KV headroom when deciding whether to evict - other models before loading. The MRU partial is one tenant of - that headroom alongside in-flight prompt caches; it is not - separately reserved because at one ``block_size``-worth of KV per - cache instance (~17 MiB for Kimi K2.5 / DeepSeek MLA, ~41 MiB for - full-attention 70B models) it is dominated by the concurrent + there: the process-level enforcer, the scheduler's prefill mid-loop + limit check, the prefill pre-flight peak check, the generation + admission guard, and the periodic-clear threshold. + + Upstream of those, ``EnginePool`` reserves a fraction of each model's + weight size as KV headroom when deciding whether to evict other models + before loading. MRU partials are one tenant of that headroom alongside + in-flight prompt caches; they are not separately reserved because at + one ``block_size``-worth of KV per entry (~17 MiB for Kimi K2.5 / + DeepSeek MLA, ~41 MiB for full-attention 70B models) and the default + cap of 4 entries, the worst case (~68-165 MiB) is well below the in-flight caches the headroom was sized for. - There is no separate up-front KV memory budget that the slot could - escape (``_calculate_max_blocks`` is paged-SSD-only and returns a - fixed block-metadata count, not a memory budget). + Under ``hot_cache_only=True`` (settings or ``OMLX_HOT_CACHE_ONLY`` + env), the hot cache and the MRU dict both live in the same KV headroom + envelope. Both are bounded — the operator should be aware they share + a budget and tune ``--mru-partial-max-entries`` and + ``--hot-cache-max-size`` together rather than treating them as + independent dials. Future maintainers: do **not** add a separate accounting hook for - this slot. The invariant that ``kv_data`` holds ``mx.array`` - instances (not numpy/CPU copies) is what makes the implicit - accounting work; the test - ``test_kv_data_holds_mlx_arrays_for_active_memory_accounting`` - pins it. + these entries. The invariant that ``kv_data`` holds ``mx.array`` + instances (not numpy/CPU copies) is what makes the implicit accounting + work; the test + ``test_kv_data_holds_mlx_arrays_for_active_memory_accounting`` pins it. """ parent_hash: bytes | None @@ -138,6 +143,7 @@ def __init__( model: Any, paged_cache_manager: PagedCacheManager, paged_ssd_cache_manager: PagedSSDCacheManager | None = None, + mru_partial_max_entries: int = 4, ): """ Initialize block-aware prefix cache. @@ -146,6 +152,12 @@ def __init__( model: The MLX model (used for identification) paged_cache_manager: The PagedCacheManager instance for block management paged_ssd_cache_manager: The PagedSSDCacheManager for SSD storage (required for paged SSD-only mode) + mru_partial_max_entries: Maximum simultaneous MRU partial-block + stashes. Each entry is bounded at one ``block_size`` of KV + memory. ``0`` disables the MRU feature entirely (silent + fallback to "no MRU" behavior, mirroring + ``hot_cache_max_size="0"`` convention). Default of 4 matches + the dflash ``max_entries`` precedent (PR #1120). """ self.model = model self.model_key = id(model) @@ -167,10 +179,16 @@ def __init__( # Kept for API compatibility self._cold_restore_callback: Callable[[int, bytes], bool] | None = None - # Single-slot stash for the trailing sub-block tail (see - # _MRUPartialBlock). Populated by store_cache, consumed by - # apply_mru_partial. Always None for hybrid models. - self._mru_partial: _MRUPartialBlock | None = None + # Bounded LRU map for trailing sub-block tails (see _MRUPartialBlock). + # Keyed by parent_hash (block_hash of the last full block in the + # prefix, or None for short prompts whose prefix is < block_size). + # Populated by store_cache via _update_mru_partial, consumed by + # apply_mru_partial. Mirrors the OrderedDict + LRU pattern from + # PagedSSDCacheManager._hot_cache. Empty dict for hybrid models. + self._mru_partials: OrderedDict[ + bytes | None, _MRUPartialBlock + ] = OrderedDict() + self._mru_partial_max_entries: int = mru_partial_max_entries # Statistics self._hits = 0 @@ -746,13 +764,15 @@ def store_cache( return block_table def has_mru_partial(self) -> bool: - """Whether a trailing-tail partial is currently stashed. + """Whether any trailing-tail partial is currently stashed. Used by the scheduler to decide whether to suppress a deferred Metal cache clear (a fresh stash is a strong signal that the - same prompt may return immediately). + same prompt may return immediately). The "any entry present" + semantic is the right predicate regardless of multi-slot count + — one warm prompt is enough to merit suppression. """ - return self._mru_partial is not None + return bool(self._mru_partials) def _can_reconstruct(self) -> bool: """Whether ``reconstruct_cache`` has a path to return non-``None``. @@ -817,21 +837,29 @@ def _update_mru_partial( layer_cache_types: list[str] | None, model_cache_config: ModelCacheConfig | None, ) -> None: - """Refresh the MRU partial slot from a just-completed store_cache. - - Clears the slot in every "no eligible tail" branch — no trailing - tokens, non-tensor data, no reconstruct path configured (see - ``_can_reconstruct``), hybrid model, extraction failure, or - ambiguous cache layout — so a stale partial cannot survive into - a future ``apply_mru_partial``. + """Stash the trailing partial from a just-completed ``store_cache``. + + On success, writes one entry into the LRU map keyed by + ``parent_hash``. If the map is at capacity, the oldest entry is + evicted via ``popitem(last=False)``. + + The "no eligible tail" branches (no trailing tokens, non-tensor + data, no reconstruct path configured, hybrid model, extraction + failure, ambiguous cache layout) bare-return — they signal a + local "nothing to stash this time," NOT a global "wipe the map." + That distinction is what lets multiple distinct-prefix entries + coexist under interleaving. """ + if self._mru_partial_max_entries <= 0: + # Feature disabled. Match the hot_cache_max_size="0" convention: + # silent fallback to "no MRU" behavior. + return if ( trailing_partial_tokens == 0 or not is_tensor_data or not self._can_reconstruct() or not self._all_layers_sliceable(layer_cache_types) ): - self._mru_partial = None return partial_start = num_new_blocks * self.block_size @@ -866,7 +894,6 @@ def _update_mru_partial( else: # Ambiguous: cache_seq_len is short of global_end but does # not equal local_len either. Refuse rather than guess. - self._mru_partial = None return # is_last_block=False is the correct intent for partial extraction: @@ -880,7 +907,6 @@ def _update_mru_partial( is_last_block=False, ) if not partial_kv: - self._mru_partial = None return parent_hash: bytes | None = None @@ -890,16 +916,27 @@ def _update_mru_partial( if last_block is not None and last_block.block_hash: parent_hash = last_block.block_hash - self._mru_partial = _MRUPartialBlock( + # LRU put: pop any existing entry first so the re-insert lands at + # the tail with fresh order; then evict the oldest entry if over + # capacity. Mirrors PagedSSDCacheManager._hot_cache_put. + if parent_hash in self._mru_partials: + self._mru_partials.pop(parent_hash) + self._mru_partials[parent_hash] = _MRUPartialBlock( parent_hash=parent_hash, tokens=new_tokens[partial_start:], kv_data=partial_kv, ) + while len(self._mru_partials) > self._mru_partial_max_entries: + self._mru_partials.popitem(last=False) + logger.debug( - "Stashed MRU partial: %d tokens, parent_hash=%s, layers=%d", - len(self._mru_partial.tokens), + "Stashed MRU partial: %d tokens, parent_hash=%s, layers=%d, " + "entries=%d/%d", + trailing_partial_tokens, parent_hash[:8].hex() + "..." if parent_hash else "None", len(partial_kv), + len(self._mru_partials), + self._mru_partial_max_entries, ) def apply_mru_partial( @@ -908,21 +945,24 @@ def apply_mru_partial( block_table: BlockTable, remaining_tokens: list[int], ) -> tuple[list[Any], list[int], int]: - """Splice the MRU partial into a reconstructed cache, atomically. + """Splice an MRU partial entry into a reconstructed cache, atomically. - On a match (parent-hash chain ok, partial tokens are a prefix of - ``remaining_tokens``, layer count matches, every layer is sliceable - and accepts the concatenate), every layer's keys/values/offset are - advanced by ``len(partial.tokens)`` and the partial tokens are - consumed from the front of ``remaining_tokens``. + Looks up an entry in the multi-slot LRU map keyed by the block + hash of the last full block in ``block_table`` (``None`` for + short prompts). On a match (entry exists, partial tokens are a + prefix of ``remaining_tokens``, layer count matches, every layer + accepts the concatenate), every layer's keys/values/offset are + advanced by ``len(partial.tokens)``, the partial tokens are + consumed from the front of ``remaining_tokens``, and the entry + is moved to the LRU tail (most-recently-used). - On any miss or splice failure, the slot is evicted and the cache - is returned unchanged. + On any miss, only the matching entry is evicted; sibling entries + for other prefixes remain. The splice is **transactional**: replacement keys/values are - materialized for every layer before any layer is mutated. A + materialized for every layer before any layer is mutated. A failure on layer N rolls everything back; the caller never sees - a half-mutated cache. (This is the fix for review B3.) + a half-mutated cache. Args: cache: Reconstructed per-layer cache objects from @@ -935,39 +975,47 @@ def apply_mru_partial( ``(cache, remaining_tokens, tokens_applied)``. On miss, ``tokens_applied == 0`` and the inputs are returned unchanged. """ - partial = self._mru_partial - if partial is None or not remaining_tokens: + if not self._mru_partials or not remaining_tokens: return cache, remaining_tokens, 0 - # Parent-hash chain check + # Compute the parent_hash key. Multi-slot freed-block guard: + # if block_table is non-empty but the parent paged block has + # been freed (allocated_blocks.get returns None), do NOT fall + # through to a None-keyed dict lookup — that would falsely match + # a short-prompt entry against a request whose parent is just + # gone. Return no-op instead. This race is new in multi-slot + # mode; single-slot tolerated it because there was only ever + # one entry to match against. last_hash: bytes | None = None if block_table and block_table.block_ids: last_bid = block_table.block_ids[-1] last_block = self.paged_cache.allocated_blocks.get(last_bid) - if last_block is not None: - last_hash = last_block.block_hash - if partial.parent_hash != last_hash: - self._mru_partial = None + if last_block is None: + return cache, remaining_tokens, 0 + last_hash = last_block.block_hash + + partial = self._mru_partials.get(last_hash) + if partial is None: return cache, remaining_tokens, 0 n_partial = len(partial.tokens) if len(remaining_tokens) < n_partial: - self._mru_partial = None + self._mru_partials.pop(last_hash, None) return cache, remaining_tokens, 0 if remaining_tokens[:n_partial] != partial.tokens: - self._mru_partial = None + self._mru_partials.pop(last_hash, None) return cache, remaining_tokens, 0 if len(partial.kv_data) != len(cache): logger.debug( - "MRU partial layer count mismatch: %d vs %d, evicting", + "MRU partial layer count mismatch: %d vs %d, evicting entry", len(partial.kv_data), len(cache), ) - self._mru_partial = None + self._mru_partials.pop(last_hash, None) return cache, remaining_tokens, 0 if not HAS_MLX: - self._mru_partial = None + self._mru_partials.pop(last_hash, None) return cache, remaining_tokens, 0 # Phase 1: build per-layer replacements without touching the cache. @@ -980,8 +1028,10 @@ def apply_mru_partial( new_values = mx.concatenate([cache_obj.values, p_values], axis=2) replacements.append((layer_idx, new_keys, new_values)) except Exception as e: - logger.debug("MRU partial splice build failed: %s, evicting", e) - self._mru_partial = None + logger.debug( + "MRU partial splice build failed: %s, evicting entry", e + ) + self._mru_partials.pop(last_hash, None) return cache, remaining_tokens, 0 # Phase 2: commit. All concatenates have already succeeded; the @@ -993,10 +1043,13 @@ def apply_mru_partial( cache_obj.values = new_values cache_obj.offset += n_partial + # Promote this entry to the LRU tail (most-recently-used). + self._mru_partials.move_to_end(last_hash) + new_remaining = remaining_tokens[n_partial:] logger.debug( - "Applied MRU partial: %d tokens, %d remaining", - n_partial, len(new_remaining), + "Applied MRU partial: %d tokens, %d remaining, entries=%d", + n_partial, len(new_remaining), len(self._mru_partials), ) return cache, new_remaining, n_partial @@ -2766,16 +2819,37 @@ def clear(self) -> int: self._request_tables.clear() self._prefix_index.clear() self.paged_cache.clear() - # The MRU partial chains from a paged-block hash; once the paged - # cache is wiped, the partial cannot be safely applied even if a - # future request happens to reproduce the same parent_hash. - # Cache-corruption recovery (Scheduler._recover_from_cache_error) - # routes through here, so a stale partial would otherwise survive - # exactly the recovery path that exists because something was wrong. - self._mru_partial = None + # MRU partials chain from paged-block hashes; once the paged + # cache is wiped, no entry can be safely applied. Cache-corruption + # recovery (Scheduler._recover_from_cache_error) routes through + # here, so stale partials would otherwise survive exactly the + # recovery path that exists because something was wrong. + self._mru_partials.clear() self.reset_stats() return cleared_count + def clear_mru_partials(self) -> int: + """Wipe only the MRU partial cache, leaving paged blocks intact. + + Intended consumer: admin-triggered cache-tier clears that drop + the backing block storage (``clear_ssd_cache`` admin endpoint, + and the future ``clear_hot_cache`` endpoint once PR #1183 + lands). Without this hook, a stash whose ``parent_hash`` chains + from a paged block whose underlying KV bytes were just flushed + from the hot/SSD tier would survive in memory and waste a + reconstruct attempt on the next request before being naturally + evicted by LRU. + + Distinct from ``clear()``: this method only drops MRU entries + and does not touch the paged cache, prefix index, or stats. + + Returns: + Number of MRU entries that were wiped. + """ + n = len(self._mru_partials) + self._mru_partials.clear() + return n + def set_cold_restore_callback( self, callback: Callable[[int, bytes], bool] | None, diff --git a/omlx/cli.py b/omlx/cli.py index b53c4a29..7e3932d4 100644 --- a/omlx/cli.py +++ b/omlx/cli.py @@ -222,6 +222,15 @@ def serve_command(args): else: scheduler_config.hot_cache_max_size = 0 + # MRU partial cache: CLI arg > settings. Independent of SSD/hot + # configuration — the MRU stash works in any mode where + # reconstruct_cache has a path (which is gated by manager presence, + # not by SSD writes; see BlockAwarePrefixCache._can_reconstruct). + if args.mru_partial_max_entries is not None: + scheduler_config.mru_partial_max_entries = int(args.mru_partial_max_entries) + else: + scheduler_config.mru_partial_max_entries = settings.cache.mru_partial_max_entries + if args.no_cache: print("Mode: Multi-model serving (no oMLX cache, mlx-lm BatchGenerator only)") elif paged_ssd_cache_dir: @@ -595,6 +604,17 @@ def main(): default=None, help="Maximum in-memory hot cache size (e.g., '8GB', '4GB'). Default: 0 (disabled)", ) + serve_parser.add_argument( + "--mru-partial-max-entries", + type=int, + default=None, + help=( + "Maximum simultaneous MRU partial-block stashes. Each entry is " + "bounded at one block_size of KV memory. 0 disables the feature. " + "Default: 4. Under --hot-cache-only this shares the in-memory KV " + "headroom envelope with --hot-cache-max-size; tune both together." + ), + ) serve_parser.add_argument( "--no-cache", action="store_true", diff --git a/omlx/scheduler.py b/omlx/scheduler.py index 5fefde0b..bbf9ab67 100644 --- a/omlx/scheduler.py +++ b/omlx/scheduler.py @@ -549,6 +549,11 @@ class SchedulerConfig: hot_cache_only: bool = False paged_ssd_cache_max_size: int = 100 * 1024 * 1024 * 1024 # 100GB default hot_cache_max_size: int = 0 # In-memory hot cache size in bytes (0 = disabled) + # Bounded LRU stash for trailing sub-block partials of previous + # prefills, keyed by parent-block hash. Each entry holds at most + # one block_size of KV memory. ``0`` disables the feature; default + # of 4 matches the dflash max_entries precedent (PR #1120). + mru_partial_max_entries: int = 4 # Model identification (for cache isolation between different models) model_name: str = "" # OpenAI API model name (e.g., "mlx-community/Llama-3.2-3B") @@ -799,6 +804,7 @@ def __init__( self.block_aware_cache = BlockAwarePrefixCache( model=model, paged_cache_manager=self.paged_cache_manager, + mru_partial_max_entries=self.config.mru_partial_max_entries, ) # Initialize paged SSD cache @@ -3468,6 +3474,7 @@ def set_specprefill_draft_model( model=draft_model, paged_cache_manager=draft_paged, paged_ssd_cache_manager=self.paged_ssd_cache_manager, + mru_partial_max_entries=self.config.mru_partial_max_entries, ) self._draft_prefix_cache.set_cold_restore_callback( self._restore_block_from_cold diff --git a/omlx/settings.py b/omlx/settings.py index cdc97e74..c8673486 100644 --- a/omlx/settings.py +++ b/omlx/settings.py @@ -252,6 +252,11 @@ class CacheSettings: ssd_cache_max_size: str = "auto" # "auto" means 10% of SSD capacity hot_cache_max_size: str = "0" # "0" = disabled, e.g. "8GB" initial_cache_blocks: int = 256 # Starting blocks (grows dynamically) + # Bounded LRU stash for the trailing sub-block partial of a previous + # prefill, keyed by parent-block hash. Each entry holds at most one + # block_size of KV memory. ``0`` disables the feature; default of 4 + # matches the dflash max_entries precedent (PR #1120). + mru_partial_max_entries: int = 4 def get_ssd_cache_dir(self, base_path: Path) -> Path: """ @@ -295,6 +300,7 @@ def to_dict(self) -> dict[str, Any]: "ssd_cache_max_size": self.ssd_cache_max_size, "hot_cache_max_size": self.hot_cache_max_size, "initial_cache_blocks": self.initial_cache_blocks, + "mru_partial_max_entries": self.mru_partial_max_entries, } @classmethod @@ -307,6 +313,7 @@ def from_dict(cls, data: dict[str, Any]) -> CacheSettings: ssd_cache_max_size=data.get("ssd_cache_max_size", "auto"), hot_cache_max_size=data.get("hot_cache_max_size", "0"), initial_cache_blocks=data.get("initial_cache_blocks", 256), + mru_partial_max_entries=data.get("mru_partial_max_entries", 4), ) diff --git a/tests/test_prefix_cache.py b/tests/test_prefix_cache.py index 9f14e511..cb6fa3a8 100644 --- a/tests/test_prefix_cache.py +++ b/tests/test_prefix_cache.py @@ -2419,27 +2419,42 @@ def test_store_cache_last_block_with_snapshot_uses_snapshot_meta(self, mx): ) +def _get_mru_partial(cache, parent_hash): + """Test-only accessor for one MRU partial entry by parent_hash key. + + Lives in the test module rather than on ``BlockAwarePrefixCache`` + itself: production code has no consumer that looks up entries by + arbitrary key (the scheduler only needs the boolean predicate via + ``has_mru_partial()`` and the dict lookup inside ``apply_mru_partial``). + Tests use this helper to assert on individual entries without + coupling to the internal ``_mru_partials`` container shape. + """ + return cache._mru_partials.get(parent_hash) + + class TestMRUPartialBlockCache: """Tests for the MRU partial block cache. - The MRU is a single-slot in-memory cache holding the trailing sub-block - tail from the most recent ``store_cache`` call. It lets exact-repeat - requests skip re-prefilling those tail tokens. + The MRU is a bounded LRU dict of trailing sub-block tails keyed by + ``parent_hash``. It lets exact-repeat requests skip re-prefilling + those tail tokens, and tolerates interleaving (multi-user / multi- + conversation workloads) by keeping multiple distinct-prefix entries + coexistent up to ``mru_partial_max_entries``. Threat-model coverage these tests enforce: - **Hybrid refusal:** when any layer is non-sliceable (RotatingKVCache, - ArraysCache, etc.), the stash is suppressed entirely. Splicing into + ArraysCache, etc.), the stash is suppressed entirely. Splicing into only the sliceable layers would create per-layer offset skew at decode time. - **Transactional splice:** if any layer's concatenate fails, no layer - sees a mutated keys/values/offset. Half-mutated caches are silent + sees a mutated keys/values/offset. Half-mutated caches are silent generation corruption. - - **Real round-trip:** store_cache populates ``_mru_partial`` via the - production extraction path; apply_mru_partial then splices it. The - tests do not hand-build ``_MRUPartialBlock`` objects for the splice + - **Real round-trip:** ``store_cache`` populates entries via the + production extraction path; ``apply_mru_partial`` then splices. + Tests do not hand-build ``_MRUPartialBlock`` objects for splice cases — that hides the extraction-vs-apply boundary the original - branch's tests missed. + single-slot branch's tests missed. """ @pytest.fixture @@ -2547,7 +2562,7 @@ def __init__(self, k, v, offset): # --- initial state --- def test_init_state_empty(self, prefix_cache): - assert prefix_cache._mru_partial is None + assert not prefix_cache._mru_partials assert prefix_cache.has_mru_partial() is False # --- stash semantics on uniformly sliceable layers --- @@ -2557,64 +2572,87 @@ def test_stash_after_store_with_trailing_tokens(self, prefix_cache, mx): tokens = [10, 20, 30, 40, 50, 60] cache_data = [self._kv_layer(mx, 6) for _ in range(4)] - prefix_cache.store_cache("req-stash", tokens, cache_data) + block_table = prefix_cache.store_cache("req-stash", tokens, cache_data) - partial = prefix_cache._mru_partial + parent_hash = prefix_cache.paged_cache.allocated_blocks[ + block_table.block_ids[-1] + ].block_hash + partial = _get_mru_partial(prefix_cache, parent_hash) assert partial is not None assert partial.tokens == [50, 60] assert len(partial.kv_data) == 4 assert prefix_cache.has_mru_partial() is True def test_no_stash_when_block_aligned(self, prefix_cache, mx): - """Block-aligned tokens leave no trailing partial → slot cleared.""" + """Block-aligned tokens leave no trailing partial → no entry written.""" tokens = [10, 20, 30, 40] cache_data = [self._kv_layer(mx, 4) for _ in range(4)] prefix_cache.store_cache("req-aligned", tokens, cache_data) - assert prefix_cache._mru_partial is None + assert not prefix_cache._mru_partials assert prefix_cache.has_mru_partial() is False - def test_stash_replaced_on_subsequent_store(self, prefix_cache, mx): - """Each store_cache replaces the previous slot — single-slot semantics.""" + def test_same_prefix_store_replaces_entry(self, prefix_cache, mx): + """Same prefix → same parent_hash → same dict key → replace. + + Two stores with identical prefix tokens but different tails + collide on the same key (parent_hash chains from identical + prefix blocks). The newer tail replaces the older one in the + single dict entry — that is correct LRU put behavior. + """ for tail in (50, 99): tokens = [10, 20, 30, 40, tail] cache_data = [self._kv_layer(mx, 5) for _ in range(4)] prefix_cache.store_cache(f"req-{tail}", tokens, cache_data) - assert prefix_cache._mru_partial is not None - assert prefix_cache._mru_partial.tokens == [99] + # Exactly one entry; its tokens are the latest tail. + assert len(prefix_cache._mru_partials) == 1 + partial = next(iter(prefix_cache._mru_partials.values())) + assert partial.tokens == [99] - def test_stash_clears_when_subsequent_store_is_block_aligned( + def test_no_eligible_tail_does_not_evict_siblings( self, prefix_cache, mx ): - """Stash from request A is cleared when request B has no trailing partial.""" - # First: stash a partial + """Behavioral change vs single-slot: a block-aligned store (no + trailing tail) MUST NOT wipe sibling entries from other prefixes. + + Single-slot mode used to clear the lone slot in this branch. + Multi-slot mode treats "nothing eligible to stash this time" + as a local signal — sibling entries for distinct prefixes are + unrelated and stay. + """ + # First: stash a partial via prefix A. prefix_cache.store_cache( "req-a", [10, 20, 30, 40, 50], [self._kv_layer(mx, 5) for _ in range(4)] ) - assert prefix_cache._mru_partial is not None + assert len(prefix_cache._mru_partials) == 1 + before_key = next(iter(prefix_cache._mru_partials.keys())) - # Second: block-aligned, must clear the stash so an apply attempt - # against a stale partial cannot succeed. + # Second: block-aligned store on a DIFFERENT prefix — no tail to + # stash, but must not evict the existing sibling. prefix_cache.store_cache( "req-b", [11, 22, 33, 44], [self._kv_layer(mx, 4) for _ in range(4)] ) - assert prefix_cache._mru_partial is None + assert len(prefix_cache._mru_partials) == 1 + assert next(iter(prefix_cache._mru_partials.keys())) == before_key def test_stash_records_parent_hash_from_last_block(self, prefix_cache, mx): - """Stashed partial chains from the hash of the last full block.""" + """Stashed entry is keyed by the hash of the last full block.""" tokens = [10, 20, 30, 40, 50, 60] cache_data = [self._kv_layer(mx, 6) for _ in range(4)] block_table = prefix_cache.store_cache("req-hash", tokens, cache_data) - # The last (and only) block's hash should be the partial's parent. + # The last (and only) block's hash should be the dict key AND + # the partial's stored parent_hash. last_block = prefix_cache.paged_cache.allocated_blocks[ block_table.block_ids[-1] ] - assert prefix_cache._mru_partial.parent_hash == last_block.block_hash assert last_block.block_hash is not None + partial = _get_mru_partial(prefix_cache, last_block.block_hash) + assert partial is not None + assert partial.parent_hash == last_block.block_hash # --- threat model: hybrid refusal (B1, B2) --- @@ -2645,7 +2683,7 @@ def test_refuse_stash_when_any_layer_non_sliceable_hybrid( cache.store_cache("req-hybrid", tokens, cache_data, model_cache_config=config) - assert cache._mru_partial is None + assert not cache._mru_partials assert cache.has_mru_partial() is False def test_refuse_stash_when_all_layers_non_sliceable(self, prefix_cache, mx): @@ -2662,7 +2700,7 @@ def test_refuse_stash_when_all_layers_non_sliceable(self, prefix_cache, mx): "req-rotating", tokens, cache_data, model_cache_config=config ) - assert prefix_cache._mru_partial is None + assert not prefix_cache._mru_partials def test_refuse_stash_when_layer_falls_through_to_default_handler( self, prefix_cache, mx @@ -2704,34 +2742,30 @@ def test_refuse_stash_when_layer_falls_through_to_default_handler( "req-batch-rotating", tokens, cache_data, model_cache_config=config ) - assert prefix_cache._mru_partial is None + assert not prefix_cache._mru_partials assert prefix_cache.has_mru_partial() is False # --- threat model: stale-slot eviction at clear() (C2) --- - def test_clear_wipes_mru_partial(self, prefix_cache, mx): - """``BlockAwarePrefixCache.clear()`` must drop the MRU slot. + def test_clear_wipes_mru_partials(self, prefix_cache, mx): + """``BlockAwarePrefixCache.clear()`` must drop the entire MRU dict. The scheduler's cache-corruption recovery routes through - ``clear()``. A surviving partial chains from a paged-block hash - whose backing block was just freed; if a future request happens - to reproduce the same prompt prefix, ``compute_block_hash`` could - coincidentally yield the same hash and the partial would splice - into a freshly-reconstructed cache. The KV would happen to be - correct (the model is deterministic) but the chain-of-trust is - gone — and the survival happens via *exactly the recovery path - that exists because something was wrong*. + ``clear()``. Surviving partials chain from paged-block hashes + whose backing blocks were just freed; the dict is wiped so no + entry can survive into the recovery path that exists because + something was wrong. """ prefix_cache.store_cache( "req-clear", [10, 20, 30, 40, 50, 60], [self._kv_layer(mx, 6) for _ in range(4)], ) - assert prefix_cache._mru_partial is not None + assert bool(prefix_cache._mru_partials) prefix_cache.clear() - assert prefix_cache._mru_partial is None + assert not prefix_cache._mru_partials assert prefix_cache.has_mru_partial() is False # --- threat model: H2 ambiguous cache layout --- @@ -2776,7 +2810,7 @@ def test_refuse_stash_on_ambiguous_cache_layout( # Refuse rather than guess. Stash must be the previous turn's # state cleared (block-aligned turn 1 has no stash anyway), not # a guessed-wrong turn 2. - assert prefix_cache._mru_partial is None + assert not prefix_cache._mru_partials # --- accounting invariant --- @@ -2798,13 +2832,16 @@ def test_kv_data_holds_mlx_arrays_for_active_memory_accounting( under load. Pin the invariant so that change is caught at test time, not in production. """ - prefix_cache.store_cache( + block_table = prefix_cache.store_cache( "req-accounting", [10, 20, 30, 40, 50, 60], [self._kv_layer(mx, 6) for _ in range(4)], ) - partial = prefix_cache._mru_partial + parent_hash = prefix_cache.paged_cache.allocated_blocks[ + block_table.block_ids[-1] + ].block_hash + partial = _get_mru_partial(prefix_cache, parent_hash) assert partial is not None assert len(partial.kv_data) == 4 for layer_idx, (keys, values) in enumerate(partial.kv_data): @@ -2845,7 +2882,7 @@ def test_no_stash_when_paged_ssd_cache_is_none(self, paged_cache, mx): [self._kv_layer(mx, 6) for _ in range(4)], ) - assert cache._mru_partial is None + assert not cache._mru_partials assert cache.has_mru_partial() is False def test_can_reconstruct_helper_reflects_manager_presence( @@ -2918,25 +2955,42 @@ def test_apply_round_trip_prefix_match_leaves_extra_tokens( # --- apply: eviction reasons --- - def test_apply_evicts_on_parent_hash_mismatch(self, prefix_cache, mx): - """If the matched prefix's last block hash differs from the - partial's parent_hash, the partial is from a different prefix - and must be evicted.""" + def test_apply_noop_on_parent_hash_mismatch_preserves_sibling( + self, prefix_cache, mx, paged_cache + ): + """A request keyed by a parent_hash that isn't in the dict + returns no-op WITHOUT evicting unrelated sibling entries. + + Behavioral change from single-slot: the single slot used to be + evicted whenever the lookup key didn't match. In multi-slot, + the lookup simply misses and other entries are preserved. + """ + # Stash a partial under prefix A. tokens = [10, 20, 30, 40, 50, 60] - cache_data = [self._kv_layer(mx, 6) for _ in range(4)] - block_table = prefix_cache.store_cache("req-evict-h", tokens, cache_data) + block_table_a = prefix_cache.store_cache( + "req-a", tokens, [self._kv_layer(mx, 6) for _ in range(4)] + ) + before = dict(prefix_cache._mru_partials) + assert len(before) == 1 - # Tamper with the partial's parent_hash so it doesn't chain. - prefix_cache._mru_partial.parent_hash = b"different-prefix" + # Construct a synthetic block_table pointing at a block whose + # hash is NOT a key in the dict (simulate "request for a + # different prefix that has its own paged block"). + other_block = paged_cache.allocate_block() + other_block.block_hash = b"\x00" * 32 # not in the MRU dict + synthetic_bt = BlockTable(request_id="req-other") + synthetic_bt.block_ids.append(other_block.block_id) + synthetic_bt.num_tokens = 4 reconstructed = self._make_reconstructed_cache(mx, n_layers=4, n_tokens=4) _, new_remaining, applied = prefix_cache.apply_mru_partial( - reconstructed, block_table, [50, 60], + reconstructed, synthetic_bt, [50, 60], ) assert applied == 0 assert new_remaining == [50, 60] - assert prefix_cache._mru_partial is None + # Prefix A's entry must still be present — no false eviction. + assert dict(prefix_cache._mru_partials) == before def test_apply_evicts_on_token_mismatch(self, prefix_cache, mx): """Different trailing tokens → partial cannot apply, evict.""" @@ -2951,7 +3005,7 @@ def test_apply_evicts_on_token_mismatch(self, prefix_cache, mx): assert applied == 0 assert new_remaining == [99, 60] - assert prefix_cache._mru_partial is None + assert not prefix_cache._mru_partials def test_apply_evicts_on_remaining_shorter_than_partial( self, prefix_cache, mx @@ -2968,7 +3022,7 @@ def test_apply_evicts_on_remaining_shorter_than_partial( ) assert applied == 0 - assert prefix_cache._mru_partial is None + assert not prefix_cache._mru_partials def test_apply_evicts_on_layer_count_mismatch(self, prefix_cache, mx): """If the reconstructed cache layer count differs from the @@ -2984,7 +3038,7 @@ def test_apply_evicts_on_layer_count_mismatch(self, prefix_cache, mx): ) assert applied == 0 - assert prefix_cache._mru_partial is None + assert not prefix_cache._mru_partials def test_apply_noop_when_no_stash(self, prefix_cache, paged_cache, mx): """No partial → no-op, remaining unchanged.""" @@ -3013,7 +3067,7 @@ def test_apply_noop_when_remaining_empty(self, prefix_cache, mx): assert applied == 0 # Stash must NOT be evicted on empty-remaining no-op — it could # still match a *future* request that does have tail tokens. - assert prefix_cache._mru_partial is not None + assert bool(prefix_cache._mru_partials) # --- threat model: transactional splice rollback (B3) --- @@ -3061,7 +3115,7 @@ def flaky_concatenate(*args, **kwargs): # No layer's keys shape changed. assert [layer.keys.shape for layer in reconstructed] == before_key_shapes # Slot is evicted on splice failure (don't retry a failing partial). - assert prefix_cache._mru_partial is None + assert not prefix_cache._mru_partials # --- threat model: multi-turn (existing_tokens > 0) --- @@ -3107,9 +3161,14 @@ def test_stash_correct_indices_when_existing_tokens_present( "class_name": "KVCache", }) - prefix_cache.store_cache("req-turn-2", full_tokens, full_cache) + block_table = prefix_cache.store_cache( + "req-turn-2", full_tokens, full_cache + ) - partial = prefix_cache._mru_partial + parent_hash = prefix_cache.paged_cache.allocated_blocks[ + block_table.block_ids[-1] + ].block_hash + partial = _get_mru_partial(prefix_cache, parent_hash) assert partial is not None assert partial.tokens == [9, 10] # Each layer's stashed slice must be the trailing region (fill=2.0). @@ -3119,11 +3178,270 @@ def test_stash_correct_indices_when_existing_tokens_present( assert mx.allclose(values, mx.full((1, 1, 2, 4), 2.0)) +class TestMRUPartialMultiSlot: + """Multi-slot LRU semantics: coexistence, capacity, eviction discipline. + + These tests cover the mechanics that single-slot mode could not + exercise — multiple entries keyed by distinct ``parent_hash`` values, + LRU promotion on apply success, sibling preservation on apply miss, + capacity-bounded eviction, ``max_entries=0`` feature disable, and + the freed-paged-block guard introduced by the multi-slot design. + """ + + @pytest.fixture + def mx(self): + try: + import mlx.core as mx + return mx + except ImportError: + pytest.skip("MLX not available") + + @pytest.fixture + def paged_cache(self): + return PagedCacheManager( + block_size=4, + max_blocks=100, + model_name="test-model", + initial_blocks=100, + ) + + @pytest.fixture + def mock_ssd(self): + mock = MagicMock() + mock.save_block.return_value = True + mock.load_block.return_value = None + mock.load_block_with_metadata.return_value = (None, None) + mock.has_block.return_value = False + return mock + + def _cache(self, paged_cache, mock_ssd, max_entries=4): + return BlockAwarePrefixCache( + model=MockModel(num_layers=4), + paged_cache_manager=paged_cache, + paged_ssd_cache_manager=mock_ssd, + mru_partial_max_entries=max_entries, + ) + + def _kv_layer(self, mx, n_tokens, head_dim=4): + return { + "state": ( + mx.full((1, 1, n_tokens, head_dim), 1.0), + mx.full((1, 1, n_tokens, head_dim), 1.0), + ), + "cache_type": "KVCache", + "class_name": "KVCache", + } + + def _make_reconstructed_cache(self, mx, n_layers, n_tokens, head_dim=4): + class MockKVCache: + def __init__(self, k, v, offset): + self.keys = k + self.values = v + self.offset = offset + + return [ + MockKVCache( + mx.ones((1, 1, n_tokens, head_dim)), + mx.ones((1, 1, n_tokens, head_dim)), + n_tokens, + ) + for _ in range(n_layers) + ] + + def _stash_with_prefix(self, cache, mx, prefix_marker, tail_token): + """Store a partial under a distinct parent_hash. + + Builds a prompt whose first 4 tokens are unique to ``prefix_marker`` + (forcing a unique parent block hash) and whose 5th token is the + partial tail. Returns (block_table, parent_hash). + """ + tokens = [prefix_marker * 10 + i for i in range(4)] + [tail_token] + cache_data = [self._kv_layer(mx, 5) for _ in range(4)] + block_table = cache.store_cache( + f"req-{prefix_marker}", tokens, cache_data + ) + parent_hash = cache.paged_cache.allocated_blocks[ + block_table.block_ids[-1] + ].block_hash + return block_table, parent_hash + + # --- multi-entry coexistence --- + + def test_distinct_prefixes_coexist_as_separate_entries( + self, paged_cache, mock_ssd, mx + ): + """Two stashes with different parent prefixes produce two entries.""" + cache = self._cache(paged_cache, mock_ssd, max_entries=4) + _, hash_a = self._stash_with_prefix(cache, mx, prefix_marker=1, tail_token=99) + _, hash_b = self._stash_with_prefix(cache, mx, prefix_marker=2, tail_token=88) + + assert hash_a != hash_b + assert len(cache._mru_partials) == 2 + assert hash_a in cache._mru_partials + assert hash_b in cache._mru_partials + + # --- LRU mechanics (parameterized) --- + + @pytest.mark.parametrize( + "scenario", + [ + # (capacity, n_stashes_in_order, expected_dict_keys_in_order) + # Capacity respected; oldest evicted on overflow. + ("evict_oldest_at_capacity", 2, [1, 2, 3], [2, 3]), + # Below capacity, all retained, insertion order preserved. + ("under_capacity_keeps_all", 4, [1, 2, 3], [1, 2, 3]), + ], + ids=lambda s: s[0] if isinstance(s, tuple) else str(s), + ) + def test_lru_capacity_bounds( + self, paged_cache, mock_ssd, mx, scenario + ): + _, capacity, order, expected = scenario + cache = self._cache(paged_cache, mock_ssd, max_entries=capacity) + hashes = {} + for marker in order: + _, h = self._stash_with_prefix( + cache, mx, prefix_marker=marker, tail_token=900 + marker + ) + hashes[marker] = h + + expected_keys = [hashes[m] for m in expected] + assert list(cache._mru_partials.keys()) == expected_keys + + def test_apply_success_promotes_entry_to_lru_tail( + self, paged_cache, mock_ssd, mx + ): + """Applying an entry moves it to the LRU tail; a subsequent + capacity-eviction drops a now-older sibling, not the just-used + entry.""" + cache = self._cache(paged_cache, mock_ssd, max_entries=2) + bt_a, hash_a = self._stash_with_prefix( + cache, mx, prefix_marker=1, tail_token=901 + ) + _, hash_b = self._stash_with_prefix( + cache, mx, prefix_marker=2, tail_token=902 + ) + assert list(cache._mru_partials.keys()) == [hash_a, hash_b] + + # Apply A → A promoted to tail. + reconstructed = self._make_reconstructed_cache(mx, n_layers=4, n_tokens=4) + # Remaining must equal A's stashed tokens. Tail token was 901, + # placed at index 4 of A's prompt. + _, _, applied = cache.apply_mru_partial(reconstructed, bt_a, [901]) + assert applied == 1 + assert list(cache._mru_partials.keys()) == [hash_b, hash_a] + + # Stash C at capacity 2 → B evicted (oldest after promote), A kept. + _, hash_c = self._stash_with_prefix( + cache, mx, prefix_marker=3, tail_token=903 + ) + assert list(cache._mru_partials.keys()) == [hash_a, hash_c] + assert hash_b not in cache._mru_partials + + # --- max_entries=0 disables --- + + def test_max_entries_zero_disables_stashing( + self, paged_cache, mock_ssd, mx + ): + cache = self._cache(paged_cache, mock_ssd, max_entries=0) + self._stash_with_prefix(cache, mx, prefix_marker=1, tail_token=99) + + assert len(cache._mru_partials) == 0 + assert cache.has_mru_partial() is False + + # --- clear_mru_partials() leaves siblings alone --- + + def test_clear_mru_partials_wipes_only_partials( + self, paged_cache, mock_ssd, mx + ): + """``clear_mru_partials()`` is the admin-clear hook. It wipes the + MRU dict but must not touch ``paged_cache``, the prefix index, + or stats — those have their own clear paths.""" + cache = self._cache(paged_cache, mock_ssd, max_entries=4) + bt, _ = self._stash_with_prefix(cache, mx, prefix_marker=1, tail_token=99) + + prefix_index_before = dict(cache._prefix_index) + request_tables_before = dict(cache._request_tables) + assert len(cache._mru_partials) == 1 + assert prefix_index_before # was populated by store_cache + + n_wiped = cache.clear_mru_partials() + + assert n_wiped == 1 + assert len(cache._mru_partials) == 0 + # Paged blocks, prefix index, request tables all unchanged. + assert cache._prefix_index == prefix_index_before + assert cache._request_tables == request_tables_before + assert bt.block_ids[-1] in cache.paged_cache.allocated_blocks + + # --- freed-block guard (new in multi-slot) --- + + def test_apply_noop_when_parent_block_freed( + self, paged_cache, mock_ssd, mx + ): + """If the parent paged block is freed between stash and apply, + the apply path must not fall through to a None-keyed lookup + (which could falsely match a short-prompt entry). + + This race is new in multi-slot: single-slot tolerated it because + there was only ever one slot to match against. + """ + cache = self._cache(paged_cache, mock_ssd, max_entries=4) + # Stash a short-prompt entry (parent_hash=None) — the false-match + # bait for the freed-block scenario. + short_tokens = [99, 100, 101] # < block_size=4 + cache.store_cache( + "req-short", short_tokens, [self._kv_layer(mx, 3) for _ in range(4)] + ) + # Confirm the short-prompt entry landed under None. + assert None in cache._mru_partials + + # Construct a block_table whose last block has been freed. + freed_block = paged_cache.allocate_block() + freed_block_id = freed_block.block_id + paged_cache.free_block(freed_block_id) + bt = BlockTable(request_id="req-freed") + bt.block_ids.append(freed_block_id) + bt.num_tokens = 4 + + reconstructed = self._make_reconstructed_cache(mx, n_layers=4, n_tokens=4) + _, new_remaining, applied = cache.apply_mru_partial( + reconstructed, bt, [99, 100, 101] + ) + + # Must NOT splice the short-prompt entry even though the + # remaining tokens happen to match. + assert applied == 0 + assert new_remaining == [99, 100, 101] + # Short-prompt entry preserved (not falsely evicted by the guard). + assert None in cache._mru_partials + + # --- short-prompt None-key coexists with hash-keyed entry --- + + def test_short_prompt_none_key_coexists_with_block_aligned_entry( + self, paged_cache, mock_ssd, mx + ): + cache = self._cache(paged_cache, mock_ssd, max_entries=4) + # Short prompt (< block_size) → parent_hash=None + cache.store_cache( + "req-short", [10, 20, 30], + [self._kv_layer(mx, 3) for _ in range(4)], + ) + # Longer prompt → distinct parent_hash + _, hash_long = self._stash_with_prefix( + cache, mx, prefix_marker=1, tail_token=99 + ) + + assert None in cache._mru_partials + assert hash_long in cache._mru_partials + assert len(cache._mru_partials) == 2 + + class TestHasMRUPartial: """The has_mru_partial() accessor is the public API the scheduler uses to decide whether to suppress the deferred Metal cache clear.""" - def test_has_mru_partial_reflects_slot(self): + def test_has_mru_partial_reflects_dict_emptiness(self): cache = BlockAwarePrefixCache( model=MockModel(num_layers=2), paged_cache_manager=PagedCacheManager( @@ -3133,10 +3451,10 @@ def test_has_mru_partial_reflects_slot(self): ) assert cache.has_mru_partial() is False - cache._mru_partial = _MRUPartialBlock( + cache._mru_partials[b"x"] = _MRUPartialBlock( parent_hash=b"x", tokens=[1], kv_data=[], ) assert cache.has_mru_partial() is True - cache._mru_partial = None + cache._mru_partials.clear() assert cache.has_mru_partial() is False diff --git a/tests/test_settings.py b/tests/test_settings.py index f3a1f54b..d3ccc4c6 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -342,6 +342,7 @@ def test_to_dict(self): "ssd_cache_max_size": "50GB", "hot_cache_max_size": "0", "initial_cache_blocks": 256, + "mru_partial_max_entries": 4, } def test_from_dict(self): From cce198ffa16e2096e50a0125adcfd0124af2c3bf Mon Sep 17 00:00:00 2001 From: Blightbow Date: Wed, 13 May 2026 17:47:50 -0400 Subject: [PATCH 10/18] fix(admin): clear_ssd_cache also wipes MRU partials The ``/api/ssd-cache/clear`` admin endpoint at ``omlx/admin/routes.py:3975`` wipes the SSD-backed paged blocks per loaded scheduler but did not touch ``BlockAwarePrefixCache._mru_partials``. Surviving MRU partials then chain from paged-block hashes whose KV bytes were just flushed, violating the operator's "drop all warm caches" intent. Under ``hot_cache_only=True`` (where the hot tier IS the only persistent store) the same hazard would also apply to PR #1183's forthcoming ``/api/hot-cache/clear`` endpoint when it lands; that wiring is a one-line follow-up at the same loop using the same method. Wire ``block_aware_cache.clear_mru_partials()`` into the per-scheduler loop alongside ``ssd_manager.clear()``, with the same defensive try/except wrapper. A failure clearing one scheduler's MRU does not prevent siblings from being cleared. The standalone ``clear_mru_partials()`` method (added in the previous commit, see ``omlx/cache/prefix_cache.py``) is the public seam. Its own unit coverage lives in ``TestMRUPartialMultiSlot::test_clear_mru_partials_wipes_only_partials``; this commit adds two endpoint-level tests in ``TestClearSSDCacheAlsoWipesMRUPartials`` that pin the wiring: - ``test_endpoint_calls_clear_mru_partials_on_each_scheduler`` confirms both ``ssd_manager.clear()`` and ``block_aware_cache.clear_mru_partials()`` fire for every loaded scheduler. - ``test_mru_clear_failure_does_not_block_other_scheduler`` pins the defensive try/except: an exception in one scheduler's clear path must not stop the loop. Suite results on this commit: 62 passed: tests/test_admin_api_key.py (1 pre-existing unrelated failure remains) Co-Authored-By: Claude Opus 4.7 (1M context) --- omlx/admin/routes.py | 16 +++++++++++ tests/test_admin_api_key.py | 55 +++++++++++++++++++++++++++++++++++++ 2 files changed, 71 insertions(+) diff --git a/omlx/admin/routes.py b/omlx/admin/routes.py index cb7d4989..885747ca 100644 --- a/omlx/admin/routes.py +++ b/omlx/admin/routes.py @@ -3994,6 +3994,22 @@ async def clear_ssd_cache(is_admin: bool = Depends(require_admin)): exc, ) + # MRU partials chain from paged-block hashes whose KV bytes are + # gone after the ssd_manager.clear() above. Drop them so the + # admin "clear all warm caches" intent is honoured symmetrically. + # Single-tier behaviour (no clear) is the surviving-stash hazard + # the peer review caught for this endpoint. + block_aware_cache = getattr(scheduler, "block_aware_cache", None) + if block_aware_cache is not None: + try: + block_aware_cache.clear_mru_partials() + except Exception as exc: + logger.warning( + "Failed to clear MRU partials for model '%s': %s", + model_id, + exc, + ) + # Phase 2: remove any remaining files on disk (covers unloaded models) global_settings = _get_global_settings() if global_settings is not None: diff --git a/tests/test_admin_api_key.py b/tests/test_admin_api_key.py index 0ace324e..07af7419 100644 --- a/tests/test_admin_api_key.py +++ b/tests/test_admin_api_key.py @@ -881,6 +881,61 @@ def test_runtime_cache_marks_sub_block_cached_when_indexed_blocks_zero(self): assert model_payload["last_partial_tokens_skipped"] == 577 +class TestClearSSDCacheAlsoWipesMRUPartials: + """The ``/api/ssd-cache/clear`` admin endpoint must also clear the + MRU partial cache. Without this, partials chain from paged-block + hashes whose backing KV bytes were just flushed by ssd_manager.clear() + and the operator's "drop all warm caches" intent is only half-honored. + """ + + def _scheduler_with_mocks(self): + """Build a SimpleNamespace scheduler with mock ssd manager and + block_aware_cache that we can introspect after the endpoint runs.""" + ssd_manager = MagicMock() + ssd_manager.clear.return_value = 5 # arbitrary deleted count + block_aware_cache = MagicMock() + block_aware_cache.clear_mru_partials.return_value = 3 + return SimpleNamespace( + paged_ssd_cache_manager=ssd_manager, + block_aware_cache=block_aware_cache, + ) + + def test_endpoint_calls_clear_mru_partials_on_each_scheduler(self): + scheduler_a = self._scheduler_with_mocks() + scheduler_b = self._scheduler_with_mocks() + + with patch.object( + admin_routes, + "_iter_loaded_schedulers", + return_value=iter([("model-a", scheduler_a), ("model-b", scheduler_b)]), + ), patch.object(admin_routes, "_get_global_settings", return_value=None): + asyncio.run(admin_routes.clear_ssd_cache(is_admin=True)) + + scheduler_a.paged_ssd_cache_manager.clear.assert_called_once() + scheduler_a.block_aware_cache.clear_mru_partials.assert_called_once() + scheduler_b.paged_ssd_cache_manager.clear.assert_called_once() + scheduler_b.block_aware_cache.clear_mru_partials.assert_called_once() + + def test_mru_clear_failure_does_not_block_other_scheduler(self): + """A failure in one scheduler's clear_mru_partials must be logged + but not prevent other schedulers from being cleared.""" + scheduler_a = self._scheduler_with_mocks() + scheduler_a.block_aware_cache.clear_mru_partials.side_effect = RuntimeError( + "boom" + ) + scheduler_b = self._scheduler_with_mocks() + + with patch.object( + admin_routes, + "_iter_loaded_schedulers", + return_value=iter([("model-a", scheduler_a), ("model-b", scheduler_b)]), + ), patch.object(admin_routes, "_get_global_settings", return_value=None): + asyncio.run(admin_routes.clear_ssd_cache(is_admin=True)) + + # Scheduler B must still have been cleared despite A's failure. + scheduler_b.block_aware_cache.clear_mru_partials.assert_called_once() + + class TestGlobalSettingsValidation: """Tests for stricter GlobalSettingsRequest validation.""" From 18bc7ccbe3ecc43df9541d7ba993ee661623055f Mon Sep 17 00:00:00 2001 From: Blightbow Date: Wed, 13 May 2026 18:13:55 -0400 Subject: [PATCH 11/18] feat(cache): observability counters for MRU partial cache MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Plug the multi-slot MRU partial cache into the same observability surface that #1183 established for the prefix and hot/disk tiers. Without these counters, operators have no signal for whether ``--mru-partial-max-entries`` is tuned right; with them, the dashboard answers "is the MRU paying off" with the same shape it answers for memory and disk hit rates. Counters (cumulative, on PrefixCacheStats) ------------------------------------------ - ``mru_partial_stashes`` — every successful entry write, including same-key replacements. Same-key replacement does NOT count as an eviction; operator sees stash payoff via the hits/stashes ratio. - ``mru_partial_hits`` — every successful splice via apply_mru_partial. - ``mru_partial_evictions`` — total entries removed. Includes capacity-overflow LRU evictions, apply-time mismatch pops (token, layer-count, splice failure), and ``clear_mru_partials()`` wipes. Does NOT include full ``clear()`` (cache-corruption recovery) — that path also calls ``reset_stats()``, so incrementing evictions there would be incoherent (the increment gets zeroed immediately). Operators tracking partial-only wipes use ``clear_mru_partials()``. - ``mru_partial_tokens_saved`` — sum of ``n_partial`` across hits. The direct compute-saved measure: each unit is one token of prefill forward-pass that did NOT have to run. Gauges (live state, on PrefixCacheStats) ---------------------------------------- - ``mru_partial_entries`` — current dict length. - ``mru_partial_max_entries`` — configured capacity (operator-facing). Plumbing -------- Counters thread through ``Scheduler._collect_cache_counters`` into the existing ``CacheRateTracker``. ``observability._compute_window`` and ``_compute_cumulative`` add per-counter deltas plus a derived ``mru_partial_hit_rate`` (= hits / stashes, with the same zero-stashes-no-NaN guard the other ratios use). The admin ``_build_runtime_cache_observability`` emits per-model entries/ max_entries gauges and aggregates them at the payload level the same way ``hot_cache_entries`` and ``hot_cache_max_bytes`` do. Dashboard --------- Mirrors #1183's hot-cache surface: - **Header gauge** "MRU tails N/M entries" next to the Memory and SSD gauges, visible only when ``mru_partial_max_entries > 0``. - **Rate strip** gains "MRU Tail Hit Rate" and "MRU Tokens Saved" cells. Grid expands from 4 cells (hot cache only) to 6 cells (both tiers). When only one or the other tier is enabled, the layout stays at 4 cells with the disabled tier's cells hidden via ``x-show``. - **Per-model table** gains an "MRU Tails" column showing ``entries / max_entries`` for each loaded model. Test coverage (10 new cases) ---------------------------- ``TestMRUPartialCounters`` (8 cases) — initial zeros, stash-counter bumps, same-key-replacement-is-not-eviction, capacity-overflow eviction count, apply-success bumps hits+tokens_saved, apply-miss eviction, ``clear_mru_partials()`` bulk eviction count, ``clear()`` zeros everything semantics, ``reset_stats()`` zeros cumulative counters but preserves live entries. ``TestCacheRateTrackerRates`` (3 new cases) — mru_partial_hit_rate windowed + cumulative, zero-stashes no-NaN guard, tokens_saved delta accumulation. ``TestRuntimeCacheObservability`` updated to reflect the two new per-model payload keys (``mru_partial_entries``, ``mru_partial_max_entries``). Suite results on this commit: 250 passed: cache + scheduler + observability + admin + settings Co-Authored-By: Claude Opus 4.7 (1M context) --- omlx/admin/routes.py | 20 ++ omlx/admin/static/js/dashboard.js | 11 +- omlx/admin/templates/dashboard/_status.html | 25 ++- omlx/cache/observability.py | 20 ++ omlx/cache/prefix_cache.py | 43 +++- omlx/cache/stats.py | 16 ++ omlx/scheduler.py | 6 + tests/test_admin_api_key.py | 4 + tests/test_cache_observability.py | 35 ++++ tests/test_prefix_cache.py | 218 ++++++++++++++++++++ 10 files changed, 394 insertions(+), 4 deletions(-) diff --git a/omlx/admin/routes.py b/omlx/admin/routes.py index 885747ca..f94cbcc8 100644 --- a/omlx/admin/routes.py +++ b/omlx/admin/routes.py @@ -3544,6 +3544,11 @@ def _build_runtime_cache_observability( "hot_cache_max_bytes": 0, "hot_cache_size_bytes": 0, "hot_cache_entries": 0, + # MRU partial cache (memory-only, sub-block tail of prior prefills). + # Capacity sums across models; current entries sum across models; + # max-capacity is the highest configured per-model value. + "mru_partial_entries": 0, + "mru_partial_max_entries": 0, } engine_pool = _get_engine_pool() @@ -3659,6 +3664,12 @@ def _build_runtime_cache_observability( "hot_cache_max_bytes": int(ssd_stats.get("hot_cache_max_bytes", 0) or 0), "hot_cache_size_bytes": int(ssd_stats.get("hot_cache_size_bytes", 0) or 0), "hot_cache_entries": int(ssd_stats.get("hot_cache_entries", 0) or 0), + "mru_partial_entries": int( + prefix_stats.get("mru_partial_entries", 0) or 0 + ), + "mru_partial_max_entries": int( + prefix_stats.get("mru_partial_max_entries", 0) or 0 + ), } cache_rates = runtime_stats.get("cache_rates") @@ -3684,15 +3695,24 @@ def _build_runtime_cache_observability( disk_max = payload["disk_max_bytes"] hot_cache_size_total = 0 hot_cache_entries_total = 0 + mru_entries_total = 0 + mru_max_entries_total = 0 for m in payload["models"]: hot_cache_size_total += m.get("hot_cache_size_bytes", 0) hot_cache_entries_total += m.get("hot_cache_entries", 0) hot_cache_max += m.get("hot_cache_max_bytes", 0) disk_max = max(disk_max, m.get("max_size_bytes", 0)) + # MRU partials: each model has its own dict; entries sum across + # models the same way the hot cache does. max_entries sums for + # the same reason (each model reserves its own slice of capacity). + mru_entries_total += m.get("mru_partial_entries", 0) + mru_max_entries_total += m.get("mru_partial_max_entries", 0) payload["hot_cache_max_bytes"] = hot_cache_max payload["hot_cache_size_bytes"] = hot_cache_size_total payload["hot_cache_entries"] = hot_cache_entries_total payload["disk_max_bytes"] = disk_max + payload["mru_partial_entries"] = mru_entries_total + payload["mru_partial_max_entries"] = mru_max_entries_total # Fallback: if no loaded models contributed stats, scan the cache # directory directly so the dashboard still shows real disk usage. diff --git a/omlx/admin/static/js/dashboard.js b/omlx/admin/static/js/dashboard.js index 4b4d9a54..67d927a9 100644 --- a/omlx/admin/static/js/dashboard.js +++ b/omlx/admin/static/js/dashboard.js @@ -2199,7 +2199,7 @@ return entry?.cache_rates?.cumulative || {}; } - const sumKeys = ['prefix_hits', 'prefix_misses', 'evictions', 'ssd_hot_hits', 'ssd_disk_loads', 'ssd_saves', 'hot_cache_evictions', 'hot_cache_promotions']; + const sumKeys = ['prefix_hits', 'prefix_misses', 'evictions', 'ssd_hot_hits', 'ssd_disk_loads', 'ssd_saves', 'hot_cache_evictions', 'hot_cache_promotions', 'mru_partial_stashes', 'mru_partial_hits', 'mru_partial_evictions', 'mru_partial_tokens_saved']; let agg = {}; for (const m of entries) { @@ -2214,8 +2214,11 @@ const pm = agg.prefix_misses || 0; const sh = agg.ssd_hot_hits || 0; const sd = agg.ssd_disk_loads || 0; + const ms = agg.mru_partial_stashes || 0; + const mh = agg.mru_partial_hits || 0; agg.prefix_hit_rate = (ph + pm) > 0 ? ph / (ph + pm) : 0; agg.ssd_hot_rate = (sh + sd) > 0 ? sh / (sh + sd) : 0; + agg.mru_partial_hit_rate = ms > 0 ? mh / ms : 0; return agg; }, @@ -2287,6 +2290,12 @@ return Math.min(100, (rc.hot_cache_size_bytes / rc.hot_cache_max_bytes) * 100); }, + get runtimeMruPartialPercent() { + const rc = this.stats.runtime_cache; + if (!rc || !rc.mru_partial_max_entries) return 0; + return Math.min(100, (rc.mru_partial_entries / rc.mru_partial_max_entries) * 100); + }, + get runtimeSsdCachePercent() { const rc = this.stats.runtime_cache; if (!rc || !rc.disk_max_bytes) return 0; diff --git a/omlx/admin/templates/dashboard/_status.html b/omlx/admin/templates/dashboard/_status.html index 527c81c9..d2f95031 100644 --- a/omlx/admin/templates/dashboard/_status.html +++ b/omlx/admin/templates/dashboard/_status.html @@ -313,6 +313,17 @@

{{ t('status.head | + +
+ MRU tails +
+
+
+ +
+ |
SSD @@ -369,7 +380,7 @@

{{ t('status.head

Prefix Hit Rate

@@ -381,6 +392,11 @@

{{ t('status.head

+
+

MRU Tail Hit Rate

+

+

Prefix Evictions

{{ t('status.head

+
+

MRU Tokens Saved

+

+
{{ t('status.head

+ @@ -428,6 +450,7 @@

{{ t('status.head

+ diff --git a/omlx/cache/observability.py b/omlx/cache/observability.py index 72a0e370..c53c8a2d 100644 --- a/omlx/cache/observability.py +++ b/omlx/cache/observability.py @@ -101,6 +101,10 @@ def delta(key: str) -> int: d_ssd_disk = delta("ssd_disk_loads") d_tokens_matched = delta("prefix_tokens_matched") d_tokens_requested = delta("prefix_tokens_requested") + d_mru_stashes = delta("mru_partial_stashes") + d_mru_hits = delta("mru_partial_hits") + d_mru_evictions = delta("mru_partial_evictions") + d_mru_tokens_saved = delta("mru_partial_tokens_saved") minutes = elapsed / 60.0 @@ -120,6 +124,15 @@ def delta(key: str) -> int: "ssd_hot_rate": round( _safe_ratio(d_ssd_hot, d_ssd_hot + d_ssd_disk), 4 ), + # MRU partial stash payoff: fraction of stashes that paid off as + # an apply-time splice. Workload with rate ≈ 0 means stashes + # are mostly wasted (uniform/unrepeated prompts); rate near 1 + # means almost every stash got reused. + "mru_partial_stashes": d_mru_stashes, + "mru_partial_hits": d_mru_hits, + "mru_partial_evictions": d_mru_evictions, + "mru_partial_tokens_saved": d_mru_tokens_saved, + "mru_partial_hit_rate": round(_safe_ratio(d_mru_hits, d_mru_stashes), 4), } @@ -130,6 +143,8 @@ def _compute_cumulative(counters: dict[str, int]) -> dict[str, Any]: ssd_disk = counters.get("ssd_disk_loads", 0) tokens_matched = counters.get("prefix_tokens_matched", 0) tokens_requested = counters.get("prefix_tokens_requested", 0) + mru_stashes = counters.get("mru_partial_stashes", 0) + mru_hits = counters.get("mru_partial_hits", 0) return { "prefix_hits": prefix_hits, @@ -146,4 +161,9 @@ def _compute_cumulative(counters: dict[str, int]) -> dict[str, Any]: "hot_cache_evictions": counters.get("hot_cache_evictions", 0), "hot_cache_promotions": counters.get("hot_cache_promotions", 0), "ssd_hot_rate": round(_safe_ratio(ssd_hot, ssd_hot + ssd_disk), 4), + "mru_partial_stashes": mru_stashes, + "mru_partial_hits": mru_hits, + "mru_partial_evictions": counters.get("mru_partial_evictions", 0), + "mru_partial_tokens_saved": counters.get("mru_partial_tokens_saved", 0), + "mru_partial_hit_rate": round(_safe_ratio(mru_hits, mru_stashes), 4), } diff --git a/omlx/cache/prefix_cache.py b/omlx/cache/prefix_cache.py index 44fb528f..42947f4b 100644 --- a/omlx/cache/prefix_cache.py +++ b/omlx/cache/prefix_cache.py @@ -200,6 +200,11 @@ def __init__( self._tokens_requested_total = 0 self._last_partial_tokens_skipped = 0 self._last_tokens_to_next_block = 0 + # MRU partial cache cumulative counters (see PrefixCacheStats). + self._mru_partial_stashes = 0 + self._mru_partial_hits = 0 + self._mru_partial_evictions = 0 + self._mru_partial_tokens_saved = 0 def _get_model_num_layers(self, model: Any) -> int: """ @@ -919,6 +924,9 @@ def _update_mru_partial( # LRU put: pop any existing entry first so the re-insert lands at # the tail with fresh order; then evict the oldest entry if over # capacity. Mirrors PagedSSDCacheManager._hot_cache_put. + # Note: same-key replacement does NOT count as an eviction — + # the operator sees stash payoff via the stashes/hits ratio, not + # via replacement churn. if parent_hash in self._mru_partials: self._mru_partials.pop(parent_hash) self._mru_partials[parent_hash] = _MRUPartialBlock( @@ -926,8 +934,10 @@ def _update_mru_partial( tokens=new_tokens[partial_start:], kv_data=partial_kv, ) + self._mru_partial_stashes += 1 while len(self._mru_partials) > self._mru_partial_max_entries: self._mru_partials.popitem(last=False) + self._mru_partial_evictions += 1 logger.debug( "Stashed MRU partial: %d tokens, parent_hash=%s, layers=%d, " @@ -1001,9 +1011,11 @@ def apply_mru_partial( n_partial = len(partial.tokens) if len(remaining_tokens) < n_partial: self._mru_partials.pop(last_hash, None) + self._mru_partial_evictions += 1 return cache, remaining_tokens, 0 if remaining_tokens[:n_partial] != partial.tokens: self._mru_partials.pop(last_hash, None) + self._mru_partial_evictions += 1 return cache, remaining_tokens, 0 if len(partial.kv_data) != len(cache): @@ -1012,10 +1024,12 @@ def apply_mru_partial( len(partial.kv_data), len(cache), ) self._mru_partials.pop(last_hash, None) + self._mru_partial_evictions += 1 return cache, remaining_tokens, 0 if not HAS_MLX: self._mru_partials.pop(last_hash, None) + self._mru_partial_evictions += 1 return cache, remaining_tokens, 0 # Phase 1: build per-layer replacements without touching the cache. @@ -1032,6 +1046,7 @@ def apply_mru_partial( "MRU partial splice build failed: %s, evicting entry", e ) self._mru_partials.pop(last_hash, None) + self._mru_partial_evictions += 1 return cache, remaining_tokens, 0 # Phase 2: commit. All concatenates have already succeeded; the @@ -1045,6 +1060,8 @@ def apply_mru_partial( # Promote this entry to the LRU tail (most-recently-used). self._mru_partials.move_to_end(last_hash) + self._mru_partial_hits += 1 + self._mru_partial_tokens_saved += n_partial new_remaining = remaining_tokens[n_partial:] logger.debug( @@ -2763,6 +2780,12 @@ def get_stats(self) -> PrefixCacheStats: last_tokens_to_next_block=self._last_tokens_to_next_block, tokens_matched_total=self._tokens_matched_total, tokens_requested_total=self._tokens_requested_total, + mru_partial_stashes=self._mru_partial_stashes, + mru_partial_hits=self._mru_partial_hits, + mru_partial_evictions=self._mru_partial_evictions, + mru_partial_tokens_saved=self._mru_partial_tokens_saved, + mru_partial_entries=len(self._mru_partials), + mru_partial_max_entries=self._mru_partial_max_entries, ) def get_stats_dict(self) -> dict[str, Any]: @@ -2796,7 +2819,13 @@ def get_stats_dict(self) -> dict[str, Any]: } def reset_stats(self) -> None: - """Reset statistics.""" + """Reset statistics. + + Note: ``_mru_partials`` is live state (the cache itself), not a + counter — its size is reported as a gauge by ``get_stats()`` and + is not affected by stats reset. Use ``clear_mru_partials()`` if + the entries themselves should be wiped. + """ self._hits = 0 self._misses = 0 self._tokens_saved = 0 @@ -2806,6 +2835,10 @@ def reset_stats(self) -> None: self._tokens_requested_total = 0 self._last_partial_tokens_skipped = 0 self._last_tokens_to_next_block = 0 + self._mru_partial_stashes = 0 + self._mru_partial_hits = 0 + self._mru_partial_evictions = 0 + self._mru_partial_tokens_saved = 0 self.paged_cache.reset_stats() def clear(self) -> int: @@ -2824,6 +2857,10 @@ def clear(self) -> int: # recovery (Scheduler._recover_from_cache_error) routes through # here, so stale partials would otherwise survive exactly the # recovery path that exists because something was wrong. + # No eviction-counter bump here: clear() also calls reset_stats() + # which zeros every counter (this is the "restart everything" + # path). Operators tracking partial wipes specifically should + # use clear_mru_partials() instead — that path leaves stats alone. self._mru_partials.clear() self.reset_stats() return cleared_count @@ -2841,13 +2878,15 @@ def clear_mru_partials(self) -> int: evicted by LRU. Distinct from ``clear()``: this method only drops MRU entries - and does not touch the paged cache, prefix index, or stats. + and does not touch the paged cache, prefix index, or stats + (other than incrementing ``mru_partial_evictions``). Returns: Number of MRU entries that were wiped. """ n = len(self._mru_partials) self._mru_partials.clear() + self._mru_partial_evictions += n return n def set_cold_restore_callback( diff --git a/omlx/cache/stats.py b/omlx/cache/stats.py index 01a78c53..230824d4 100644 --- a/omlx/cache/stats.py +++ b/omlx/cache/stats.py @@ -90,6 +90,16 @@ class PrefixCacheStats(BaseCacheStats): last_tokens_to_next_block: int = 0 tokens_matched_total: int = 0 tokens_requested_total: int = 0 + # MRU partial cache observability. Cumulative counters that pair + # naturally — hits/stashes gives the "stash payoff" rate, evictions + # tracks churn, tokens_saved is the direct compute-saved measure. + # Gauges (entries / max_entries) describe live capacity utilisation. + mru_partial_stashes: int = 0 + mru_partial_hits: int = 0 + mru_partial_evictions: int = 0 + mru_partial_tokens_saved: int = 0 + mru_partial_entries: int = 0 + mru_partial_max_entries: int = 0 _total_queries: int = field(default=0, repr=False) @property @@ -115,6 +125,12 @@ def reset(self) -> None: self.last_tokens_to_next_block = 0 self.tokens_matched_total = 0 self.tokens_requested_total = 0 + self.mru_partial_stashes = 0 + self.mru_partial_hits = 0 + self.mru_partial_evictions = 0 + self.mru_partial_tokens_saved = 0 + # mru_partial_entries and mru_partial_max_entries are gauges + # populated by get_stats() from live state — not reset here. self._total_queries = 0 diff --git a/omlx/scheduler.py b/omlx/scheduler.py index bbf9ab67..a679c9c0 100644 --- a/omlx/scheduler.py +++ b/omlx/scheduler.py @@ -6175,6 +6175,12 @@ def _collect_cache_counters(self) -> dict[str, int] | None: "prefix_tokens_requested": prefix_stats.tokens_requested_total, "prefix_tokens_saved": prefix_stats.tokens_saved, "evictions": prefix_stats.evictions, + "mru_partial_stashes": prefix_stats.mru_partial_stashes, + "mru_partial_hits": prefix_stats.mru_partial_hits, + "mru_partial_evictions": prefix_stats.mru_partial_evictions, + "mru_partial_tokens_saved": prefix_stats.mru_partial_tokens_saved, + "mru_partial_entries": prefix_stats.mru_partial_entries, + "mru_partial_max_entries": prefix_stats.mru_partial_max_entries, } if self.paged_ssd_cache_manager is not None: diff --git a/tests/test_admin_api_key.py b/tests/test_admin_api_key.py index 07af7419..cb6fc4ef 100644 --- a/tests/test_admin_api_key.py +++ b/tests/test_admin_api_key.py @@ -749,6 +749,8 @@ def test_runtime_cache_uses_model_scoped_ssd_stats(self): "hot_cache_max_bytes": 0, "hot_cache_size_bytes": 0, "hot_cache_entries": 0, + "mru_partial_entries": 0, + "mru_partial_max_entries": 0, }, { "id": "model-b", @@ -766,6 +768,8 @@ def test_runtime_cache_uses_model_scoped_ssd_stats(self): "hot_cache_max_bytes": 0, "hot_cache_size_bytes": 0, "hot_cache_entries": 0, + "mru_partial_entries": 0, + "mru_partial_max_entries": 0, }, ] manager_a.get_stats_for_model.assert_called_once_with("/models/model-a") diff --git a/tests/test_cache_observability.py b/tests/test_cache_observability.py index 29239049..af16b2f2 100644 --- a/tests/test_cache_observability.py +++ b/tests/test_cache_observability.py @@ -24,6 +24,10 @@ def _make_counters( ssd_errors=0, hot_cache_evictions=0, hot_cache_promotions=0, + mru_partial_stashes=0, + mru_partial_hits=0, + mru_partial_evictions=0, + mru_partial_tokens_saved=0, ): return { "prefix_hits": prefix_hits, @@ -38,6 +42,10 @@ def _make_counters( "ssd_errors": ssd_errors, "hot_cache_evictions": hot_cache_evictions, "hot_cache_promotions": hot_cache_promotions, + "mru_partial_stashes": mru_partial_stashes, + "mru_partial_hits": mru_partial_hits, + "mru_partial_evictions": mru_partial_evictions, + "mru_partial_tokens_saved": mru_partial_tokens_saved, } @@ -120,6 +128,33 @@ def test_ssd_hot_rate(self): result = self._tracker_with_two_snapshots(old, new, elapsed=60.0) assert result["windows"]["1m"]["ssd_hot_rate"] == 0.8 + def test_mru_partial_hit_rate(self): + """Stash payoff = hits / stashes. Workload that stashed 100 and + only got 75 hits has rate 0.75; high enough to justify the + feature, low enough that a smaller capacity might do.""" + old = _make_counters(mru_partial_stashes=0, mru_partial_hits=0) + new = _make_counters(mru_partial_stashes=100, mru_partial_hits=75) + result = self._tracker_with_two_snapshots(old, new, elapsed=60.0) + assert result["windows"]["1m"]["mru_partial_hit_rate"] == 0.75 + assert result["cumulative"]["mru_partial_hit_rate"] == 0.75 + + def test_mru_partial_zero_stashes_no_nan(self): + """If no stashes happened in the window, hit_rate must be 0.0 + not NaN. Mirrors the prefix_hit_rate empty-window guard.""" + counters = _make_counters(mru_partial_stashes=0, mru_partial_hits=0) + result = self._tracker_with_two_snapshots(counters, counters, elapsed=60.0) + assert result["windows"]["1m"]["mru_partial_hit_rate"] == 0.0 + assert result["cumulative"]["mru_partial_hit_rate"] == 0.0 + + def test_mru_partial_tokens_saved_delta(self): + """Tokens saved is the direct compute-saved measure; it + accumulates regardless of hit rate.""" + old = _make_counters(mru_partial_tokens_saved=0) + new = _make_counters(mru_partial_tokens_saved=12345) + result = self._tracker_with_two_snapshots(old, new, elapsed=60.0) + assert result["windows"]["1m"]["mru_partial_tokens_saved"] == 12345 + assert result["cumulative"]["mru_partial_tokens_saved"] == 12345 + def test_insufficient_data_returns_empty_window(self): tracker = CacheRateTracker(min_interval=0.0) diff --git a/tests/test_prefix_cache.py b/tests/test_prefix_cache.py index cb6fa3a8..0fa43215 100644 --- a/tests/test_prefix_cache.py +++ b/tests/test_prefix_cache.py @@ -3437,6 +3437,224 @@ def test_short_prompt_none_key_coexists_with_block_aligned_entry( assert len(cache._mru_partials) == 2 +class TestMRUPartialCounters: + """The observability counters mirror PR #1183's pattern so operators + can answer "is the MRU cache paying off" with the same dashboard + surface they use for prefix-hit and memory-hit rates. + + Counters: ``mru_partial_stashes``, ``mru_partial_hits``, + ``mru_partial_evictions``, ``mru_partial_tokens_saved``. + Gauges: ``mru_partial_entries``, ``mru_partial_max_entries``. + """ + + @pytest.fixture + def mx(self): + try: + import mlx.core as mx + return mx + except ImportError: + pytest.skip("MLX not available") + + @pytest.fixture + def paged_cache(self): + return PagedCacheManager( + block_size=4, + max_blocks=100, + model_name="test-model", + initial_blocks=100, + ) + + @pytest.fixture + def mock_ssd(self): + mock = MagicMock() + mock.save_block.return_value = True + mock.load_block.return_value = None + mock.load_block_with_metadata.return_value = (None, None) + mock.has_block.return_value = False + return mock + + def _cache(self, paged_cache, mock_ssd, max_entries=4): + return BlockAwarePrefixCache( + model=MockModel(num_layers=4), + paged_cache_manager=paged_cache, + paged_ssd_cache_manager=mock_ssd, + mru_partial_max_entries=max_entries, + ) + + def _kv_layer(self, mx, n_tokens, head_dim=4): + return { + "state": ( + mx.full((1, 1, n_tokens, head_dim), 1.0), + mx.full((1, 1, n_tokens, head_dim), 1.0), + ), + "cache_type": "KVCache", + "class_name": "KVCache", + } + + def _make_reconstructed_cache(self, mx, n_layers, n_tokens, head_dim=4): + class MockKVCache: + def __init__(self, k, v, offset): + self.keys = k + self.values = v + self.offset = offset + + return [ + MockKVCache( + mx.ones((1, 1, n_tokens, head_dim)), + mx.ones((1, 1, n_tokens, head_dim)), + n_tokens, + ) + for _ in range(n_layers) + ] + + def _stash_with_prefix(self, cache, mx, prefix_marker, tail_token): + tokens = [prefix_marker * 10 + i for i in range(4)] + [tail_token] + cache_data = [self._kv_layer(mx, 5) for _ in range(4)] + block_table = cache.store_cache( + f"req-{prefix_marker}", tokens, cache_data + ) + parent_hash = cache.paged_cache.allocated_blocks[ + block_table.block_ids[-1] + ].block_hash + return block_table, parent_hash + + def test_initial_counters_are_zero(self, paged_cache, mock_ssd): + cache = self._cache(paged_cache, mock_ssd, max_entries=4) + stats = cache.get_stats() + assert stats.mru_partial_stashes == 0 + assert stats.mru_partial_hits == 0 + assert stats.mru_partial_evictions == 0 + assert stats.mru_partial_tokens_saved == 0 + assert stats.mru_partial_entries == 0 + assert stats.mru_partial_max_entries == 4 + + def test_stash_increments_stash_counter(self, paged_cache, mock_ssd, mx): + cache = self._cache(paged_cache, mock_ssd, max_entries=4) + self._stash_with_prefix(cache, mx, prefix_marker=1, tail_token=99) + self._stash_with_prefix(cache, mx, prefix_marker=2, tail_token=88) + + stats = cache.get_stats() + assert stats.mru_partial_stashes == 2 + assert stats.mru_partial_entries == 2 + + def test_same_key_replacement_counts_as_stash_not_eviction( + self, paged_cache, mock_ssd, mx + ): + """Replacing an existing entry under the same key counts as a + stash but NOT as an eviction. Eviction is reserved for entries + that leave the dict (capacity overflow, apply-miss, clear).""" + cache = self._cache(paged_cache, mock_ssd, max_entries=4) + self._stash_with_prefix(cache, mx, prefix_marker=1, tail_token=99) + self._stash_with_prefix(cache, mx, prefix_marker=1, tail_token=77) + + stats = cache.get_stats() + assert stats.mru_partial_stashes == 2 + assert stats.mru_partial_evictions == 0 + assert stats.mru_partial_entries == 1 + + def test_capacity_overflow_increments_eviction_counter( + self, paged_cache, mock_ssd, mx + ): + cache = self._cache(paged_cache, mock_ssd, max_entries=2) + for i in (1, 2, 3): + self._stash_with_prefix(cache, mx, prefix_marker=i, tail_token=100 + i) + + stats = cache.get_stats() + assert stats.mru_partial_stashes == 3 + assert stats.mru_partial_evictions == 1 # one entry pushed out + assert stats.mru_partial_entries == 2 + + def test_apply_success_increments_hits_and_tokens_saved( + self, paged_cache, mock_ssd, mx + ): + cache = self._cache(paged_cache, mock_ssd, max_entries=4) + bt, _ = self._stash_with_prefix( + cache, mx, prefix_marker=1, tail_token=901 + ) + + reconstructed = self._make_reconstructed_cache(mx, n_layers=4, n_tokens=4) + _, _, applied = cache.apply_mru_partial(reconstructed, bt, [901]) + assert applied == 1 + + stats = cache.get_stats() + assert stats.mru_partial_hits == 1 + assert stats.mru_partial_tokens_saved == 1 + assert stats.mru_partial_evictions == 0 # success, not eviction + + def test_apply_miss_on_found_key_increments_eviction( + self, paged_cache, mock_ssd, mx + ): + """Token-mismatch eviction pops the matched key and counts.""" + cache = self._cache(paged_cache, mock_ssd, max_entries=4) + bt, _ = self._stash_with_prefix( + cache, mx, prefix_marker=1, tail_token=99 + ) + + reconstructed = self._make_reconstructed_cache(mx, n_layers=4, n_tokens=4) + _, _, applied = cache.apply_mru_partial(reconstructed, bt, [77]) # wrong tail + assert applied == 0 + + stats = cache.get_stats() + assert stats.mru_partial_hits == 0 + assert stats.mru_partial_evictions == 1 + + def test_clear_mru_partials_counts_all_wiped_entries( + self, paged_cache, mock_ssd, mx + ): + cache = self._cache(paged_cache, mock_ssd, max_entries=4) + self._stash_with_prefix(cache, mx, prefix_marker=1, tail_token=99) + self._stash_with_prefix(cache, mx, prefix_marker=2, tail_token=88) + self._stash_with_prefix(cache, mx, prefix_marker=3, tail_token=77) + + n = cache.clear_mru_partials() + assert n == 3 + + stats = cache.get_stats() + assert stats.mru_partial_evictions == 3 + assert stats.mru_partial_entries == 0 + + def test_clear_wipes_partials_and_resets_counters( + self, paged_cache, mock_ssd, mx + ): + """clear() is the "restart everything" path (cache-corruption + recovery). It wipes the dict AND resets every counter, + including mru_partial_evictions — incrementing evictions just + to have them zeroed by the same call would be incoherent. + Operators tracking partial wipes specifically use + clear_mru_partials() instead. + """ + cache = self._cache(paged_cache, mock_ssd, max_entries=4) + self._stash_with_prefix(cache, mx, prefix_marker=1, tail_token=99) + self._stash_with_prefix(cache, mx, prefix_marker=2, tail_token=88) + + cache.clear() + + stats = cache.get_stats() + assert stats.mru_partial_entries == 0 # dict wiped + assert stats.mru_partial_stashes == 0 # counters reset + assert stats.mru_partial_evictions == 0 + assert stats.mru_partial_hits == 0 + + def test_reset_stats_zeros_mru_counters_but_keeps_live_state( + self, paged_cache, mock_ssd, mx + ): + """reset_stats() is the analyst's reset — it zeros cumulative + counters but leaves the live cache state alone. Use + clear_mru_partials() if entries should be dropped too.""" + cache = self._cache(paged_cache, mock_ssd, max_entries=4) + self._stash_with_prefix(cache, mx, prefix_marker=1, tail_token=99) + + cache.reset_stats() + + stats = cache.get_stats() + assert stats.mru_partial_stashes == 0 + assert stats.mru_partial_hits == 0 + assert stats.mru_partial_evictions == 0 + assert stats.mru_partial_tokens_saved == 0 + # But the live entry is still there. + assert stats.mru_partial_entries == 1 + + class TestHasMRUPartial: """The has_mru_partial() accessor is the public API the scheduler uses to decide whether to suppress the deferred Metal cache clear.""" From cf844c1dc90aeacb36c5417c5348b0033a903c84 Mon Sep 17 00:00:00 2001 From: Blightbow Date: Wed, 13 May 2026 18:31:32 -0400 Subject: [PATCH 12/18] =?UTF-8?q?refactor(cache):=20simplify=20MRU=20stack?= =?UTF-8?q?=20=E2=80=94=20dedupe=20test=20helpers,=20=5Fevict=5Fmiss,=20Al?= =?UTF-8?q?pine=20getters?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Five small simplifications surfaced by a three-agent code review on the MRU partial cache commit stack. No behaviour change. **Hoist test helpers to module level.** Three MRU test classes (``TestMRUPartialBlockCache``, ``TestMRUPartialMultiSlot``, ``TestMRUPartialCounters``) each redefined ``_layer``, ``_kv_layer``, ``_rotating_layer``, ``_make_reconstructed_cache``, ``_stash_with_prefix``, and ``_cache`` (factory). Hoist to module-level helpers in ``tests/test_prefix_cache.py`` (alongside the existing ``_get_mru_partial`` accessor); the duplicate methods come out, ~120 lines of repetition collapses, call sites switch from ``self._kv_layer(...)`` to ``_kv_layer(...)``. The factory for custom capacity is renamed ``_make_mru_cache(paged_cache, mock_ssd, max_entries, num_layers)``. Per-class fixtures (``mx``, ``paged_cache``, ``mock_ssd``) stay class-local to avoid leaking fixture names into unrelated test classes in the same module. **Extract ``_evict_miss`` helper in ``apply_mru_partial``.** Five arms of ``self._mru_partials.pop(last_hash, None); self._mru_partial_evictions += 1; return cache, remaining_tokens, 0`` collapse into a single inner function. Each call site is now one line, the eviction-counter bookkeeping lives in one place, and the rollback contract is harder to break by accident. **Dashboard Alpine getters.** Three getters added to the dashboard root: ``mruEnabled``, ``hotCacheEnabled``, ``cacheRatesGridCols``. The previous expressions repeated ``stats.runtime_cache?.mru_partial_max_entries > 0`` in 8 places and a three-arm ternary chain in the rate-strip grid-class binding. The HTML now reads ``x-show="mruEnabled && stats.runtime_cache?.models?.length > 0"`` and ``:class="cacheRatesGridCols"`` at the relevant sites. **``_make_counters`` driven by a key tuple.** Previously took 16 explicit kwargs and re-listed every key in the returned dict body. Now: a module-level ``_COUNTER_KEYS`` tuple is the single source of truth; the helper builds a zero-initialised dict and applies ``**overrides``. Unknown keys raise (catches the typos the explicit signature used to catch). Adding a new observability counter is now one tuple entry instead of three coordinated changes (signature, dict body, and any call sites that wanted the default). **Prune ``_MRUPartialBlock`` docstring rot-prone bits.** Dropped specific MiB numbers (~17/41/68-165 MiB) and the PR# reference to #1120 from the memory-accounting section. The numbers were informative when written but would have aged past the model and config landscape they assumed. Kept the invariant statement and the test reference; removed the calibration table. **Findings reviewed and skipped:** - Aggregating ``mru_partial_max_entries`` across loaded models was flagged as "wrong arithmetic" — actually correct, matches the deliberate hot-cache convention from #1183 (each model has its own budget, the dashboard gauge shows fleet fill = sum of entries / sum of capacities). - ``_get_cache_seq_len`` per-block redundancy in ``store_cache`` — pre-existing pattern, not regressed by this stack, defer. - Phase-1 ``mx.concatenate`` × N layers × 2 dispatch shape — defer pending M3 Ultra measurement (per ``feedback_apple_silicon_perf.md`` memory: never estimate prefill costs). - ``paged_cache.allocated_blocks.get(...)`` direct dict access at 9+ pre-existing sites — wider refactor, out of scope. - ``_all_layers_sliceable`` vs ``_prompt_cache_needs_snapshots`` co-location — different inputs (class-name strings vs live cache objects), unification would be scope-creep. Suite results on this commit: 250 passed: cache + scheduler + observability + admin + settings. Co-Authored-By: Claude Opus 4.7 (1M context) --- omlx/admin/static/js/dashboard.js | 15 + omlx/admin/templates/dashboard/_status.html | 14 +- omlx/cache/prefix_cache.py | 72 ++-- tests/test_cache_observability.py | 76 ++-- tests/test_prefix_cache.py | 388 ++++++++------------ 5 files changed, 251 insertions(+), 314 deletions(-) diff --git a/omlx/admin/static/js/dashboard.js b/omlx/admin/static/js/dashboard.js index 67d927a9..e240a710 100644 --- a/omlx/admin/static/js/dashboard.js +++ b/omlx/admin/static/js/dashboard.js @@ -2296,6 +2296,21 @@ return Math.min(100, (rc.mru_partial_entries / rc.mru_partial_max_entries) * 100); }, + get mruEnabled() { + return (this.stats.runtime_cache?.mru_partial_max_entries || 0) > 0; + }, + + get hotCacheEnabled() { + return (this.stats.runtime_cache?.hot_cache_max_bytes || 0) > 0; + }, + + get cacheRatesGridCols() { + const both = this.hotCacheEnabled && this.mruEnabled; + if (both) return 'grid-cols-2 sm:grid-cols-6'; + if (this.hotCacheEnabled || this.mruEnabled) return 'grid-cols-2 sm:grid-cols-4'; + return 'grid-cols-2'; + }, + get runtimeSsdCachePercent() { const rc = this.stats.runtime_cache; if (!rc || !rc.disk_max_bytes) return 0; diff --git a/omlx/admin/templates/dashboard/_status.html b/omlx/admin/templates/dashboard/_status.html index d2f95031..2365247f 100644 --- a/omlx/admin/templates/dashboard/_status.html +++ b/omlx/admin/templates/dashboard/_status.html @@ -314,7 +314,7 @@

{{ t('status.head | -
+
MRU tails
{{ t('status.head
- | + |
SSD @@ -380,7 +380,7 @@

{{ t('status.head

Prefix Hit Rate

@@ -392,7 +392,7 @@

{{ t('status.head

-
+

MRU Tail Hit Rate

@@ -407,7 +407,7 @@

{{ t('status.head

-
+

MRU Tokens Saved

@@ -427,7 +427,7 @@

{{ t('status.head

- + @@ -450,7 +450,7 @@

{{ t('status.head

- + diff --git a/omlx/cache/prefix_cache.py b/omlx/cache/prefix_cache.py index 42947f4b..fc9c9f6a 100644 --- a/omlx/cache/prefix_cache.py +++ b/omlx/cache/prefix_cache.py @@ -71,33 +71,26 @@ class _MRUPartialBlock: ----------------- ``kv_data`` holds real ``mx.array`` allocations (produced by ``_clone_tensor`` → ``mx.copy``), so each entry's memory cost is - automatically counted by ``mx.get_active_memory()``. Every runtime - memory enforcement and telemetry path in this codebase reads from - there: the process-level enforcer, the scheduler's prefill mid-loop - limit check, the prefill pre-flight peak check, the generation - admission guard, and the periodic-clear threshold. - - Upstream of those, ``EnginePool`` reserves a fraction of each model's - weight size as KV headroom when deciding whether to evict other models - before loading. MRU partials are one tenant of that headroom alongside - in-flight prompt caches; they are not separately reserved because at - one ``block_size``-worth of KV per entry (~17 MiB for Kimi K2.5 / - DeepSeek MLA, ~41 MiB for full-attention 70B models) and the default - cap of 4 entries, the worst case (~68-165 MiB) is well below the - in-flight caches the headroom was sized for. - - Under ``hot_cache_only=True`` (settings or ``OMLX_HOT_CACHE_ONLY`` - env), the hot cache and the MRU dict both live in the same KV headroom - envelope. Both are bounded — the operator should be aware they share - a budget and tune ``--mru-partial-max-entries`` and - ``--hot-cache-max-size`` together rather than treating them as - independent dials. + automatically counted by ``mx.get_active_memory()`` and is therefore + visible to every runtime memory enforcement and telemetry path in + this codebase (process enforcer, scheduler limit checks, + periodic-clear threshold, prefill pre-flight peak check). + + Upstream of those, the engine pool reserves a fraction of each + model's weight size as KV headroom before admission. MRU partials + are one tenant of that headroom alongside in-flight prompt caches; + they are not separately reserved because each entry is bounded at + one ``block_size``-worth of KV and the default cap is small. Under + ``hot_cache_only=True``, the hot cache and the MRU dict share that + envelope — tune ``--mru-partial-max-entries`` and + ``--hot-cache-max-size`` together in that mode rather than treating + them as independent dials. Future maintainers: do **not** add a separate accounting hook for - these entries. The invariant that ``kv_data`` holds ``mx.array`` - instances (not numpy/CPU copies) is what makes the implicit accounting - work; the test - ``test_kv_data_holds_mlx_arrays_for_active_memory_accounting`` pins it. + these entries. The invariant that ``kv_data`` holds ``mx.array`` + instances (not numpy/CPU copies) is what makes the implicit + accounting work; ``test_kv_data_holds_mlx_arrays_for_active_memory_accounting`` + pins it. """ parent_hash: bytes | None @@ -1008,29 +1001,32 @@ def apply_mru_partial( if partial is None: return cache, remaining_tokens, 0 - n_partial = len(partial.tokens) - if len(remaining_tokens) < n_partial: + def _evict_miss() -> tuple[list[Any], list[int], int]: + """Pop the matched entry and return the no-op tuple. + + Used by every apply-time mismatch arm. Inlines to one line + at each call site and keeps the eviction-counter bookkeeping + in one place. + """ self._mru_partials.pop(last_hash, None) self._mru_partial_evictions += 1 return cache, remaining_tokens, 0 + + n_partial = len(partial.tokens) + if len(remaining_tokens) < n_partial: + return _evict_miss() if remaining_tokens[:n_partial] != partial.tokens: - self._mru_partials.pop(last_hash, None) - self._mru_partial_evictions += 1 - return cache, remaining_tokens, 0 + return _evict_miss() if len(partial.kv_data) != len(cache): logger.debug( "MRU partial layer count mismatch: %d vs %d, evicting entry", len(partial.kv_data), len(cache), ) - self._mru_partials.pop(last_hash, None) - self._mru_partial_evictions += 1 - return cache, remaining_tokens, 0 + return _evict_miss() if not HAS_MLX: - self._mru_partials.pop(last_hash, None) - self._mru_partial_evictions += 1 - return cache, remaining_tokens, 0 + return _evict_miss() # Phase 1: build per-layer replacements without touching the cache. # Any failure here is a clean rollback — no layer has been mutated. @@ -1045,9 +1041,7 @@ def apply_mru_partial( logger.debug( "MRU partial splice build failed: %s, evicting entry", e ) - self._mru_partials.pop(last_hash, None) - self._mru_partial_evictions += 1 - return cache, remaining_tokens, 0 + return _evict_miss() # Phase 2: commit. All concatenates have already succeeded; the # only operations remaining are attribute writes and an integer diff --git a/tests/test_cache_observability.py b/tests/test_cache_observability.py index af16b2f2..d0da07dc 100644 --- a/tests/test_cache_observability.py +++ b/tests/test_cache_observability.py @@ -11,42 +11,46 @@ from omlx.cache.observability import CacheRateTracker -def _make_counters( - prefix_hits=0, - prefix_misses=0, - prefix_tokens_matched=0, - prefix_tokens_requested=0, - prefix_tokens_saved=0, - evictions=0, - ssd_hot_hits=0, - ssd_disk_loads=0, - ssd_saves=0, - ssd_errors=0, - hot_cache_evictions=0, - hot_cache_promotions=0, - mru_partial_stashes=0, - mru_partial_hits=0, - mru_partial_evictions=0, - mru_partial_tokens_saved=0, -): - return { - "prefix_hits": prefix_hits, - "prefix_misses": prefix_misses, - "prefix_tokens_matched": prefix_tokens_matched, - "prefix_tokens_requested": prefix_tokens_requested, - "prefix_tokens_saved": prefix_tokens_saved, - "evictions": evictions, - "ssd_hot_hits": ssd_hot_hits, - "ssd_disk_loads": ssd_disk_loads, - "ssd_saves": ssd_saves, - "ssd_errors": ssd_errors, - "hot_cache_evictions": hot_cache_evictions, - "hot_cache_promotions": hot_cache_promotions, - "mru_partial_stashes": mru_partial_stashes, - "mru_partial_hits": mru_partial_hits, - "mru_partial_evictions": mru_partial_evictions, - "mru_partial_tokens_saved": mru_partial_tokens_saved, - } +# Counter keys produced by Scheduler._collect_cache_counters. Adding a new +# observability counter means adding it here and in the production code +# in lockstep; the explicit-kwargs alternative duplicated the schema once +# in the function signature and once in the dict construction. +_COUNTER_KEYS = ( + "prefix_hits", + "prefix_misses", + "prefix_tokens_matched", + "prefix_tokens_requested", + "prefix_tokens_saved", + "evictions", + "ssd_hot_hits", + "ssd_disk_loads", + "ssd_saves", + "ssd_errors", + "hot_cache_evictions", + "hot_cache_promotions", + "mru_partial_stashes", + "mru_partial_hits", + "mru_partial_evictions", + "mru_partial_tokens_saved", +) + + +def _make_counters(**overrides): + """Build a counter dict for snapshot testing. + + All keys default to ``0``; override individual values via kwargs. + Unknown keys raise — the typo-catching the explicit-signature form + used to provide. + """ + unknown = set(overrides) - set(_COUNTER_KEYS) + if unknown: + raise ValueError( + f"Unknown counter keys: {sorted(unknown)}. " + f"Known: {sorted(_COUNTER_KEYS)}" + ) + counters = {k: 0 for k in _COUNTER_KEYS} + counters.update(overrides) + return counters class TestCacheRateTrackerSnapshot: diff --git a/tests/test_prefix_cache.py b/tests/test_prefix_cache.py index 0fa43215..cb17e5f6 100644 --- a/tests/test_prefix_cache.py +++ b/tests/test_prefix_cache.py @@ -2432,6 +2432,85 @@ def _get_mru_partial(cache, parent_hash): return cache._mru_partials.get(parent_hash) +def _layer(mx, n_tokens, *, class_name="KVCache", head_dim=4, n_kv_heads=1, fill=1.0): + """Build a layer-state dict for store_cache. + + ``class_name`` selects the cache type (e.g. ``"KVCache"``, + ``"RotatingKVCache"``, ``"BatchRotatingKVCache"``). Both + ``cache_type`` and ``class_name`` keys are populated with the same + string — ``store_cache`` consults whichever the layer provides. + """ + return { + "state": ( + mx.full((1, n_kv_heads, n_tokens, head_dim), fill), + mx.full((1, n_kv_heads, n_tokens, head_dim), fill), + ), + "cache_type": class_name, + "class_name": class_name, + } + + +def _kv_layer(mx, n_tokens, head_dim=4, n_kv_heads=1, fill=1.0): + return _layer( + mx, n_tokens, + class_name="KVCache", + head_dim=head_dim, n_kv_heads=n_kv_heads, fill=fill, + ) + + +def _rotating_layer(mx, n_tokens, head_dim=4, n_kv_heads=1): + return _layer( + mx, n_tokens, + class_name="RotatingKVCache", + head_dim=head_dim, n_kv_heads=n_kv_heads, + ) + + +def _make_reconstructed_cache(mx, n_layers, n_tokens, head_dim=4): + """Build a list of MockKVCache objects matching what reconstruct_cache + would produce: keys.shape[2] == offset, valid region only.""" + class MockKVCache: + def __init__(self, k, v, offset): + self.keys = k + self.values = v + self.offset = offset + + return [ + MockKVCache( + mx.ones((1, 1, n_tokens, head_dim)), + mx.ones((1, 1, n_tokens, head_dim)), + n_tokens, + ) + for _ in range(n_layers) + ] + + +def _make_mru_cache(paged_cache, mock_ssd, max_entries=4, num_layers=4): + """Construct a ``BlockAwarePrefixCache`` with a custom MRU capacity.""" + return BlockAwarePrefixCache( + model=MockModel(num_layers=num_layers), + paged_cache_manager=paged_cache, + paged_ssd_cache_manager=mock_ssd, + mru_partial_max_entries=max_entries, + ) + + +def _stash_with_prefix(cache, mx, prefix_marker, tail_token): + """Stash a partial under a distinct parent_hash for multi-slot tests. + + Builds a prompt whose first 4 tokens are unique to ``prefix_marker`` + (forcing a unique parent block hash) and whose 5th token is the + partial tail. Returns ``(block_table, parent_hash)``. + """ + tokens = [prefix_marker * 10 + i for i in range(4)] + [tail_token] + cache_data = [_kv_layer(mx, 5) for _ in range(4)] + block_table = cache.store_cache(f"req-{prefix_marker}", tokens, cache_data) + parent_hash = cache.paged_cache.allocated_blocks[ + block_table.block_ids[-1] + ].block_hash + return block_table, parent_hash + + class TestMRUPartialBlockCache: """Tests for the MRU partial block cache. @@ -2500,65 +2579,6 @@ def prefix_cache(self, paged_cache, mock_ssd): paged_ssd_cache_manager=mock_ssd, ) - def _layer( - self, - mx, - n_tokens, - *, - class_name="KVCache", - head_dim=4, - n_kv_heads=1, - fill=1.0, - ): - """Build a layer-state dict for store_cache. - - ``class_name`` selects the cache type (e.g. ``"KVCache"``, - ``"RotatingKVCache"``, ``"BatchRotatingKVCache"``). Both - ``cache_type`` and ``class_name`` keys are populated with the - same string — ``store_cache`` consults whichever the layer - provides. - """ - return { - "state": ( - mx.full((1, n_kv_heads, n_tokens, head_dim), fill), - mx.full((1, n_kv_heads, n_tokens, head_dim), fill), - ), - "cache_type": class_name, - "class_name": class_name, - } - - def _kv_layer(self, mx, n_tokens, head_dim=4, n_kv_heads=1, fill=1.0): - return self._layer( - mx, n_tokens, - class_name="KVCache", - head_dim=head_dim, n_kv_heads=n_kv_heads, fill=fill, - ) - - def _rotating_layer(self, mx, n_tokens, head_dim=4, n_kv_heads=1): - return self._layer( - mx, n_tokens, - class_name="RotatingKVCache", - head_dim=head_dim, n_kv_heads=n_kv_heads, - ) - - def _make_reconstructed_cache(self, mx, n_layers, n_tokens, head_dim=4): - """Build a list of MockKVCache objects matching what reconstruct_cache - would produce: keys.shape[2] == offset, valid region only.""" - class MockKVCache: - def __init__(self, k, v, offset): - self.keys = k - self.values = v - self.offset = offset - - return [ - MockKVCache( - mx.ones((1, 1, n_tokens, head_dim)), - mx.ones((1, 1, n_tokens, head_dim)), - n_tokens, - ) - for _ in range(n_layers) - ] - # --- initial state --- def test_init_state_empty(self, prefix_cache): @@ -2570,7 +2590,7 @@ def test_init_state_empty(self, prefix_cache): def test_stash_after_store_with_trailing_tokens(self, prefix_cache, mx): """6 tokens, block_size=4 → 1 full block + 2 trailing → stash captured.""" tokens = [10, 20, 30, 40, 50, 60] - cache_data = [self._kv_layer(mx, 6) for _ in range(4)] + cache_data = [_kv_layer(mx, 6) for _ in range(4)] block_table = prefix_cache.store_cache("req-stash", tokens, cache_data) @@ -2586,7 +2606,7 @@ def test_stash_after_store_with_trailing_tokens(self, prefix_cache, mx): def test_no_stash_when_block_aligned(self, prefix_cache, mx): """Block-aligned tokens leave no trailing partial → no entry written.""" tokens = [10, 20, 30, 40] - cache_data = [self._kv_layer(mx, 4) for _ in range(4)] + cache_data = [_kv_layer(mx, 4) for _ in range(4)] prefix_cache.store_cache("req-aligned", tokens, cache_data) @@ -2603,7 +2623,7 @@ def test_same_prefix_store_replaces_entry(self, prefix_cache, mx): """ for tail in (50, 99): tokens = [10, 20, 30, 40, tail] - cache_data = [self._kv_layer(mx, 5) for _ in range(4)] + cache_data = [_kv_layer(mx, 5) for _ in range(4)] prefix_cache.store_cache(f"req-{tail}", tokens, cache_data) # Exactly one entry; its tokens are the latest tail. @@ -2624,7 +2644,7 @@ def test_no_eligible_tail_does_not_evict_siblings( """ # First: stash a partial via prefix A. prefix_cache.store_cache( - "req-a", [10, 20, 30, 40, 50], [self._kv_layer(mx, 5) for _ in range(4)] + "req-a", [10, 20, 30, 40, 50], [_kv_layer(mx, 5) for _ in range(4)] ) assert len(prefix_cache._mru_partials) == 1 before_key = next(iter(prefix_cache._mru_partials.keys())) @@ -2632,7 +2652,7 @@ def test_no_eligible_tail_does_not_evict_siblings( # Second: block-aligned store on a DIFFERENT prefix — no tail to # stash, but must not evict the existing sibling. prefix_cache.store_cache( - "req-b", [11, 22, 33, 44], [self._kv_layer(mx, 4) for _ in range(4)] + "req-b", [11, 22, 33, 44], [_kv_layer(mx, 4) for _ in range(4)] ) assert len(prefix_cache._mru_partials) == 1 assert next(iter(prefix_cache._mru_partials.keys())) == before_key @@ -2640,7 +2660,7 @@ def test_no_eligible_tail_does_not_evict_siblings( def test_stash_records_parent_hash_from_last_block(self, prefix_cache, mx): """Stashed entry is keyed by the hash of the last full block.""" tokens = [10, 20, 30, 40, 50, 60] - cache_data = [self._kv_layer(mx, 6) for _ in range(4)] + cache_data = [_kv_layer(mx, 6) for _ in range(4)] block_table = prefix_cache.store_cache("req-hash", tokens, cache_data) @@ -2674,8 +2694,8 @@ def test_refuse_stash_when_any_layer_non_sliceable_hybrid( ) tokens = [10, 20, 30, 40, 50, 60] cache_data = [ - self._kv_layer(mx, 6), - self._rotating_layer(mx, 6), + _kv_layer(mx, 6), + _rotating_layer(mx, 6), ] config = ModelCacheConfig.from_type_list( ["KVCache", "RotatingKVCache"], model_name="test" @@ -2691,7 +2711,7 @@ def test_refuse_stash_when_all_layers_non_sliceable(self, prefix_cache, mx): from omlx.cache.hybrid_cache import ModelCacheConfig tokens = [10, 20, 30, 40, 50] - cache_data = [self._rotating_layer(mx, 5) for _ in range(4)] + cache_data = [_rotating_layer(mx, 5) for _ in range(4)] config = ModelCacheConfig.from_type_list( ["RotatingKVCache"] * 4, model_name="test" ) @@ -2733,7 +2753,7 @@ def test_refuse_stash_when_layer_falls_through_to_default_handler( tokens = [10, 20, 30, 40, 50, 60] # cache_data shape doesn't matter — store_cache must refuse before # any extraction is attempted. - cache_data = [self._kv_layer(mx, 6) for _ in range(4)] + cache_data = [_kv_layer(mx, 6) for _ in range(4)] config = ModelCacheConfig.from_type_list( ["BatchRotatingKVCache"] * 4, model_name="test" ) @@ -2759,7 +2779,7 @@ def test_clear_wipes_mru_partials(self, prefix_cache, mx): prefix_cache.store_cache( "req-clear", [10, 20, 30, 40, 50, 60], - [self._kv_layer(mx, 6) for _ in range(4)], + [_kv_layer(mx, 6) for _ in range(4)], ) assert bool(prefix_cache._mru_partials) @@ -2792,7 +2812,7 @@ def test_refuse_stash_on_ambiguous_cache_layout( prefix_cache.store_cache( "req-turn-1", [1, 2, 3, 4], - [self._kv_layer(mx, 4) for _ in range(4)], + [_kv_layer(mx, 4) for _ in range(4)], ) # Second turn: 8 prefix-aligned tokens (1 full block + 1 partial-block). @@ -2801,7 +2821,7 @@ def test_refuse_stash_on_ambiguous_cache_layout( # - global_end = existing_tokens + new_count = 4 + 4 = 8 # global_end (8) > cache_seq_len (6) > local_len (4): ambiguous. full_tokens = [1, 2, 3, 4, 5, 6, 7, 8] - cache_data = [self._kv_layer(mx, 6) for _ in range(4)] + cache_data = [_kv_layer(mx, 6) for _ in range(4)] prefix_cache.store_cache( "req-turn-2-ambiguous", full_tokens, cache_data @@ -2835,7 +2855,7 @@ def test_kv_data_holds_mlx_arrays_for_active_memory_accounting( block_table = prefix_cache.store_cache( "req-accounting", [10, 20, 30, 40, 50, 60], - [self._kv_layer(mx, 6) for _ in range(4)], + [_kv_layer(mx, 6) for _ in range(4)], ) parent_hash = prefix_cache.paged_cache.allocated_blocks[ @@ -2879,7 +2899,7 @@ def test_no_stash_when_paged_ssd_cache_is_none(self, paged_cache, mx): cache.store_cache( "req-no-ssd", [10, 20, 30, 40, 50, 60], - [self._kv_layer(mx, 6) for _ in range(4)], + [_kv_layer(mx, 6) for _ in range(4)], ) assert not cache._mru_partials @@ -2917,11 +2937,11 @@ def test_apply_round_trip_exact_match(self, prefix_cache, mx): """Real store → apply round-trip: partial produced by extraction is consumed by the splice path, no hand-built _MRUPartialBlock.""" tokens = [10, 20, 30, 40, 50, 60] - cache_data = [self._kv_layer(mx, 6) for _ in range(4)] + cache_data = [_kv_layer(mx, 6) for _ in range(4)] block_table = prefix_cache.store_cache("req-rt", tokens, cache_data) # Reconstructed cache: 4 layers × 4 tokens (the prefix only). - reconstructed = self._make_reconstructed_cache(mx, n_layers=4, n_tokens=4) + reconstructed = _make_reconstructed_cache(mx, n_layers=4, n_tokens=4) remaining = [50, 60] result, new_remaining, applied = prefix_cache.apply_mru_partial( @@ -2940,10 +2960,10 @@ def test_apply_round_trip_prefix_match_leaves_extra_tokens( """When remaining is longer than the partial, the partial covers its prefix and the rest is left for normal prefill.""" tokens = [10, 20, 30, 40, 50, 60] - cache_data = [self._kv_layer(mx, 6) for _ in range(4)] + cache_data = [_kv_layer(mx, 6) for _ in range(4)] block_table = prefix_cache.store_cache("req-rt-prefix", tokens, cache_data) - reconstructed = self._make_reconstructed_cache(mx, n_layers=4, n_tokens=4) + reconstructed = _make_reconstructed_cache(mx, n_layers=4, n_tokens=4) remaining = [50, 60, 70, 80] # partial is [50, 60]; [70, 80] left over _, new_remaining, applied = prefix_cache.apply_mru_partial( @@ -2968,7 +2988,7 @@ def test_apply_noop_on_parent_hash_mismatch_preserves_sibling( # Stash a partial under prefix A. tokens = [10, 20, 30, 40, 50, 60] block_table_a = prefix_cache.store_cache( - "req-a", tokens, [self._kv_layer(mx, 6) for _ in range(4)] + "req-a", tokens, [_kv_layer(mx, 6) for _ in range(4)] ) before = dict(prefix_cache._mru_partials) assert len(before) == 1 @@ -2982,7 +3002,7 @@ def test_apply_noop_on_parent_hash_mismatch_preserves_sibling( synthetic_bt.block_ids.append(other_block.block_id) synthetic_bt.num_tokens = 4 - reconstructed = self._make_reconstructed_cache(mx, n_layers=4, n_tokens=4) + reconstructed = _make_reconstructed_cache(mx, n_layers=4, n_tokens=4) _, new_remaining, applied = prefix_cache.apply_mru_partial( reconstructed, synthetic_bt, [50, 60], ) @@ -2995,10 +3015,10 @@ def test_apply_noop_on_parent_hash_mismatch_preserves_sibling( def test_apply_evicts_on_token_mismatch(self, prefix_cache, mx): """Different trailing tokens → partial cannot apply, evict.""" tokens = [10, 20, 30, 40, 50, 60] - cache_data = [self._kv_layer(mx, 6) for _ in range(4)] + cache_data = [_kv_layer(mx, 6) for _ in range(4)] block_table = prefix_cache.store_cache("req-evict-t", tokens, cache_data) - reconstructed = self._make_reconstructed_cache(mx, n_layers=4, n_tokens=4) + reconstructed = _make_reconstructed_cache(mx, n_layers=4, n_tokens=4) _, new_remaining, applied = prefix_cache.apply_mru_partial( reconstructed, block_table, [99, 60], # first token doesn't match ) @@ -3012,11 +3032,11 @@ def test_apply_evicts_on_remaining_shorter_than_partial( ): """If remaining_tokens is shorter than the partial it cannot match.""" tokens = [10, 20, 30, 40, 50, 60, 70] - cache_data = [self._kv_layer(mx, 7) for _ in range(4)] + cache_data = [_kv_layer(mx, 7) for _ in range(4)] block_table = prefix_cache.store_cache("req-evict-s", tokens, cache_data) # Partial is [50, 60, 70]; remaining is shorter → must evict. - reconstructed = self._make_reconstructed_cache(mx, n_layers=4, n_tokens=4) + reconstructed = _make_reconstructed_cache(mx, n_layers=4, n_tokens=4) _, new_remaining, applied = prefix_cache.apply_mru_partial( reconstructed, block_table, [50, 60], ) @@ -3028,11 +3048,11 @@ def test_apply_evicts_on_layer_count_mismatch(self, prefix_cache, mx): """If the reconstructed cache layer count differs from the stashed partial, evict — likely a model swap or bug, not safe.""" tokens = [10, 20, 30, 40, 50, 60] - cache_data = [self._kv_layer(mx, 6) for _ in range(4)] + cache_data = [_kv_layer(mx, 6) for _ in range(4)] block_table = prefix_cache.store_cache("req-evict-lc", tokens, cache_data) # Reconstructed has only 2 layers, partial has 4 → mismatch. - reconstructed = self._make_reconstructed_cache(mx, n_layers=2, n_tokens=4) + reconstructed = _make_reconstructed_cache(mx, n_layers=2, n_tokens=4) _, _, applied = prefix_cache.apply_mru_partial( reconstructed, block_table, [50, 60], ) @@ -3043,7 +3063,7 @@ def test_apply_evicts_on_layer_count_mismatch(self, prefix_cache, mx): def test_apply_noop_when_no_stash(self, prefix_cache, paged_cache, mx): """No partial → no-op, remaining unchanged.""" block_table = paged_cache.create_block_table("req-noop") - reconstructed = self._make_reconstructed_cache(mx, n_layers=4, n_tokens=0) + reconstructed = _make_reconstructed_cache(mx, n_layers=4, n_tokens=0) result, remaining, applied = prefix_cache.apply_mru_partial( reconstructed, block_table, [10, 20], @@ -3056,10 +3076,10 @@ def test_apply_noop_when_no_stash(self, prefix_cache, paged_cache, mx): def test_apply_noop_when_remaining_empty(self, prefix_cache, mx): """Empty remaining → exact prefix hit already; no MRU work.""" tokens = [10, 20, 30, 40, 50, 60] - cache_data = [self._kv_layer(mx, 6) for _ in range(4)] + cache_data = [_kv_layer(mx, 6) for _ in range(4)] block_table = prefix_cache.store_cache("req-noop-empty", tokens, cache_data) - reconstructed = self._make_reconstructed_cache(mx, n_layers=4, n_tokens=4) + reconstructed = _make_reconstructed_cache(mx, n_layers=4, n_tokens=4) _, _, applied = prefix_cache.apply_mru_partial( reconstructed, block_table, [], ) @@ -3083,10 +3103,10 @@ def test_splice_failure_does_not_mutate_any_layer( atomically. """ tokens = [10, 20, 30, 40, 50, 60] - cache_data = [self._kv_layer(mx, 6) for _ in range(4)] + cache_data = [_kv_layer(mx, 6) for _ in range(4)] block_table = prefix_cache.store_cache("req-rollback", tokens, cache_data) - reconstructed = self._make_reconstructed_cache(mx, n_layers=4, n_tokens=4) + reconstructed = _make_reconstructed_cache(mx, n_layers=4, n_tokens=4) # Snapshot the pre-splice state of every layer for an after-comparison. before_offsets = [layer.offset for layer in reconstructed] before_key_shapes = [layer.keys.shape for layer in reconstructed] @@ -3130,7 +3150,7 @@ def test_stash_correct_indices_when_existing_tokens_present( """ # Pretend a previous turn already cached 4 tokens. prev_tokens = [1, 2, 3, 4] - prev_cache = [self._kv_layer(mx, 4) for _ in range(4)] + prev_cache = [_kv_layer(mx, 4) for _ in range(4)] prefix_cache.store_cache("req-turn-1", prev_tokens, prev_cache) # Second turn: 4 prev + 4 new = 8 prefix block tokens, then 2 trailing. @@ -3214,66 +3234,15 @@ def mock_ssd(self): mock.has_block.return_value = False return mock - def _cache(self, paged_cache, mock_ssd, max_entries=4): - return BlockAwarePrefixCache( - model=MockModel(num_layers=4), - paged_cache_manager=paged_cache, - paged_ssd_cache_manager=mock_ssd, - mru_partial_max_entries=max_entries, - ) - - def _kv_layer(self, mx, n_tokens, head_dim=4): - return { - "state": ( - mx.full((1, 1, n_tokens, head_dim), 1.0), - mx.full((1, 1, n_tokens, head_dim), 1.0), - ), - "cache_type": "KVCache", - "class_name": "KVCache", - } - - def _make_reconstructed_cache(self, mx, n_layers, n_tokens, head_dim=4): - class MockKVCache: - def __init__(self, k, v, offset): - self.keys = k - self.values = v - self.offset = offset - - return [ - MockKVCache( - mx.ones((1, 1, n_tokens, head_dim)), - mx.ones((1, 1, n_tokens, head_dim)), - n_tokens, - ) - for _ in range(n_layers) - ] - - def _stash_with_prefix(self, cache, mx, prefix_marker, tail_token): - """Store a partial under a distinct parent_hash. - - Builds a prompt whose first 4 tokens are unique to ``prefix_marker`` - (forcing a unique parent block hash) and whose 5th token is the - partial tail. Returns (block_table, parent_hash). - """ - tokens = [prefix_marker * 10 + i for i in range(4)] + [tail_token] - cache_data = [self._kv_layer(mx, 5) for _ in range(4)] - block_table = cache.store_cache( - f"req-{prefix_marker}", tokens, cache_data - ) - parent_hash = cache.paged_cache.allocated_blocks[ - block_table.block_ids[-1] - ].block_hash - return block_table, parent_hash - # --- multi-entry coexistence --- def test_distinct_prefixes_coexist_as_separate_entries( self, paged_cache, mock_ssd, mx ): """Two stashes with different parent prefixes produce two entries.""" - cache = self._cache(paged_cache, mock_ssd, max_entries=4) - _, hash_a = self._stash_with_prefix(cache, mx, prefix_marker=1, tail_token=99) - _, hash_b = self._stash_with_prefix(cache, mx, prefix_marker=2, tail_token=88) + cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4) + _, hash_a = _stash_with_prefix(cache, mx, prefix_marker=1, tail_token=99) + _, hash_b = _stash_with_prefix(cache, mx, prefix_marker=2, tail_token=88) assert hash_a != hash_b assert len(cache._mru_partials) == 2 @@ -3297,10 +3266,10 @@ def test_lru_capacity_bounds( self, paged_cache, mock_ssd, mx, scenario ): _, capacity, order, expected = scenario - cache = self._cache(paged_cache, mock_ssd, max_entries=capacity) + cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=capacity) hashes = {} for marker in order: - _, h = self._stash_with_prefix( + _, h = _stash_with_prefix( cache, mx, prefix_marker=marker, tail_token=900 + marker ) hashes[marker] = h @@ -3314,17 +3283,17 @@ def test_apply_success_promotes_entry_to_lru_tail( """Applying an entry moves it to the LRU tail; a subsequent capacity-eviction drops a now-older sibling, not the just-used entry.""" - cache = self._cache(paged_cache, mock_ssd, max_entries=2) - bt_a, hash_a = self._stash_with_prefix( + cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=2) + bt_a, hash_a = _stash_with_prefix( cache, mx, prefix_marker=1, tail_token=901 ) - _, hash_b = self._stash_with_prefix( + _, hash_b = _stash_with_prefix( cache, mx, prefix_marker=2, tail_token=902 ) assert list(cache._mru_partials.keys()) == [hash_a, hash_b] # Apply A → A promoted to tail. - reconstructed = self._make_reconstructed_cache(mx, n_layers=4, n_tokens=4) + reconstructed = _make_reconstructed_cache(mx, n_layers=4, n_tokens=4) # Remaining must equal A's stashed tokens. Tail token was 901, # placed at index 4 of A's prompt. _, _, applied = cache.apply_mru_partial(reconstructed, bt_a, [901]) @@ -3332,7 +3301,7 @@ def test_apply_success_promotes_entry_to_lru_tail( assert list(cache._mru_partials.keys()) == [hash_b, hash_a] # Stash C at capacity 2 → B evicted (oldest after promote), A kept. - _, hash_c = self._stash_with_prefix( + _, hash_c = _stash_with_prefix( cache, mx, prefix_marker=3, tail_token=903 ) assert list(cache._mru_partials.keys()) == [hash_a, hash_c] @@ -3343,8 +3312,8 @@ def test_apply_success_promotes_entry_to_lru_tail( def test_max_entries_zero_disables_stashing( self, paged_cache, mock_ssd, mx ): - cache = self._cache(paged_cache, mock_ssd, max_entries=0) - self._stash_with_prefix(cache, mx, prefix_marker=1, tail_token=99) + cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=0) + _stash_with_prefix(cache, mx, prefix_marker=1, tail_token=99) assert len(cache._mru_partials) == 0 assert cache.has_mru_partial() is False @@ -3357,8 +3326,8 @@ def test_clear_mru_partials_wipes_only_partials( """``clear_mru_partials()`` is the admin-clear hook. It wipes the MRU dict but must not touch ``paged_cache``, the prefix index, or stats — those have their own clear paths.""" - cache = self._cache(paged_cache, mock_ssd, max_entries=4) - bt, _ = self._stash_with_prefix(cache, mx, prefix_marker=1, tail_token=99) + cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4) + bt, _ = _stash_with_prefix(cache, mx, prefix_marker=1, tail_token=99) prefix_index_before = dict(cache._prefix_index) request_tables_before = dict(cache._request_tables) @@ -3386,12 +3355,12 @@ def test_apply_noop_when_parent_block_freed( This race is new in multi-slot: single-slot tolerated it because there was only ever one slot to match against. """ - cache = self._cache(paged_cache, mock_ssd, max_entries=4) + cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4) # Stash a short-prompt entry (parent_hash=None) — the false-match # bait for the freed-block scenario. short_tokens = [99, 100, 101] # < block_size=4 cache.store_cache( - "req-short", short_tokens, [self._kv_layer(mx, 3) for _ in range(4)] + "req-short", short_tokens, [_kv_layer(mx, 3) for _ in range(4)] ) # Confirm the short-prompt entry landed under None. assert None in cache._mru_partials @@ -3404,7 +3373,7 @@ def test_apply_noop_when_parent_block_freed( bt.block_ids.append(freed_block_id) bt.num_tokens = 4 - reconstructed = self._make_reconstructed_cache(mx, n_layers=4, n_tokens=4) + reconstructed = _make_reconstructed_cache(mx, n_layers=4, n_tokens=4) _, new_remaining, applied = cache.apply_mru_partial( reconstructed, bt, [99, 100, 101] ) @@ -3421,14 +3390,14 @@ def test_apply_noop_when_parent_block_freed( def test_short_prompt_none_key_coexists_with_block_aligned_entry( self, paged_cache, mock_ssd, mx ): - cache = self._cache(paged_cache, mock_ssd, max_entries=4) + cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4) # Short prompt (< block_size) → parent_hash=None cache.store_cache( "req-short", [10, 20, 30], - [self._kv_layer(mx, 3) for _ in range(4)], + [_kv_layer(mx, 3) for _ in range(4)], ) # Longer prompt → distinct parent_hash - _, hash_long = self._stash_with_prefix( + _, hash_long = _stash_with_prefix( cache, mx, prefix_marker=1, tail_token=99 ) @@ -3473,53 +3442,8 @@ def mock_ssd(self): mock.has_block.return_value = False return mock - def _cache(self, paged_cache, mock_ssd, max_entries=4): - return BlockAwarePrefixCache( - model=MockModel(num_layers=4), - paged_cache_manager=paged_cache, - paged_ssd_cache_manager=mock_ssd, - mru_partial_max_entries=max_entries, - ) - - def _kv_layer(self, mx, n_tokens, head_dim=4): - return { - "state": ( - mx.full((1, 1, n_tokens, head_dim), 1.0), - mx.full((1, 1, n_tokens, head_dim), 1.0), - ), - "cache_type": "KVCache", - "class_name": "KVCache", - } - - def _make_reconstructed_cache(self, mx, n_layers, n_tokens, head_dim=4): - class MockKVCache: - def __init__(self, k, v, offset): - self.keys = k - self.values = v - self.offset = offset - - return [ - MockKVCache( - mx.ones((1, 1, n_tokens, head_dim)), - mx.ones((1, 1, n_tokens, head_dim)), - n_tokens, - ) - for _ in range(n_layers) - ] - - def _stash_with_prefix(self, cache, mx, prefix_marker, tail_token): - tokens = [prefix_marker * 10 + i for i in range(4)] + [tail_token] - cache_data = [self._kv_layer(mx, 5) for _ in range(4)] - block_table = cache.store_cache( - f"req-{prefix_marker}", tokens, cache_data - ) - parent_hash = cache.paged_cache.allocated_blocks[ - block_table.block_ids[-1] - ].block_hash - return block_table, parent_hash - def test_initial_counters_are_zero(self, paged_cache, mock_ssd): - cache = self._cache(paged_cache, mock_ssd, max_entries=4) + cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4) stats = cache.get_stats() assert stats.mru_partial_stashes == 0 assert stats.mru_partial_hits == 0 @@ -3529,9 +3453,9 @@ def test_initial_counters_are_zero(self, paged_cache, mock_ssd): assert stats.mru_partial_max_entries == 4 def test_stash_increments_stash_counter(self, paged_cache, mock_ssd, mx): - cache = self._cache(paged_cache, mock_ssd, max_entries=4) - self._stash_with_prefix(cache, mx, prefix_marker=1, tail_token=99) - self._stash_with_prefix(cache, mx, prefix_marker=2, tail_token=88) + cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4) + _stash_with_prefix(cache, mx, prefix_marker=1, tail_token=99) + _stash_with_prefix(cache, mx, prefix_marker=2, tail_token=88) stats = cache.get_stats() assert stats.mru_partial_stashes == 2 @@ -3543,9 +3467,9 @@ def test_same_key_replacement_counts_as_stash_not_eviction( """Replacing an existing entry under the same key counts as a stash but NOT as an eviction. Eviction is reserved for entries that leave the dict (capacity overflow, apply-miss, clear).""" - cache = self._cache(paged_cache, mock_ssd, max_entries=4) - self._stash_with_prefix(cache, mx, prefix_marker=1, tail_token=99) - self._stash_with_prefix(cache, mx, prefix_marker=1, tail_token=77) + cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4) + _stash_with_prefix(cache, mx, prefix_marker=1, tail_token=99) + _stash_with_prefix(cache, mx, prefix_marker=1, tail_token=77) stats = cache.get_stats() assert stats.mru_partial_stashes == 2 @@ -3555,9 +3479,9 @@ def test_same_key_replacement_counts_as_stash_not_eviction( def test_capacity_overflow_increments_eviction_counter( self, paged_cache, mock_ssd, mx ): - cache = self._cache(paged_cache, mock_ssd, max_entries=2) + cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=2) for i in (1, 2, 3): - self._stash_with_prefix(cache, mx, prefix_marker=i, tail_token=100 + i) + _stash_with_prefix(cache, mx, prefix_marker=i, tail_token=100 + i) stats = cache.get_stats() assert stats.mru_partial_stashes == 3 @@ -3567,12 +3491,12 @@ def test_capacity_overflow_increments_eviction_counter( def test_apply_success_increments_hits_and_tokens_saved( self, paged_cache, mock_ssd, mx ): - cache = self._cache(paged_cache, mock_ssd, max_entries=4) - bt, _ = self._stash_with_prefix( + cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4) + bt, _ = _stash_with_prefix( cache, mx, prefix_marker=1, tail_token=901 ) - reconstructed = self._make_reconstructed_cache(mx, n_layers=4, n_tokens=4) + reconstructed = _make_reconstructed_cache(mx, n_layers=4, n_tokens=4) _, _, applied = cache.apply_mru_partial(reconstructed, bt, [901]) assert applied == 1 @@ -3585,12 +3509,12 @@ def test_apply_miss_on_found_key_increments_eviction( self, paged_cache, mock_ssd, mx ): """Token-mismatch eviction pops the matched key and counts.""" - cache = self._cache(paged_cache, mock_ssd, max_entries=4) - bt, _ = self._stash_with_prefix( + cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4) + bt, _ = _stash_with_prefix( cache, mx, prefix_marker=1, tail_token=99 ) - reconstructed = self._make_reconstructed_cache(mx, n_layers=4, n_tokens=4) + reconstructed = _make_reconstructed_cache(mx, n_layers=4, n_tokens=4) _, _, applied = cache.apply_mru_partial(reconstructed, bt, [77]) # wrong tail assert applied == 0 @@ -3601,10 +3525,10 @@ def test_apply_miss_on_found_key_increments_eviction( def test_clear_mru_partials_counts_all_wiped_entries( self, paged_cache, mock_ssd, mx ): - cache = self._cache(paged_cache, mock_ssd, max_entries=4) - self._stash_with_prefix(cache, mx, prefix_marker=1, tail_token=99) - self._stash_with_prefix(cache, mx, prefix_marker=2, tail_token=88) - self._stash_with_prefix(cache, mx, prefix_marker=3, tail_token=77) + cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4) + _stash_with_prefix(cache, mx, prefix_marker=1, tail_token=99) + _stash_with_prefix(cache, mx, prefix_marker=2, tail_token=88) + _stash_with_prefix(cache, mx, prefix_marker=3, tail_token=77) n = cache.clear_mru_partials() assert n == 3 @@ -3623,9 +3547,9 @@ def test_clear_wipes_partials_and_resets_counters( Operators tracking partial wipes specifically use clear_mru_partials() instead. """ - cache = self._cache(paged_cache, mock_ssd, max_entries=4) - self._stash_with_prefix(cache, mx, prefix_marker=1, tail_token=99) - self._stash_with_prefix(cache, mx, prefix_marker=2, tail_token=88) + cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4) + _stash_with_prefix(cache, mx, prefix_marker=1, tail_token=99) + _stash_with_prefix(cache, mx, prefix_marker=2, tail_token=88) cache.clear() @@ -3641,8 +3565,8 @@ def test_reset_stats_zeros_mru_counters_but_keeps_live_state( """reset_stats() is the analyst's reset — it zeros cumulative counters but leaves the live cache state alone. Use clear_mru_partials() if entries should be dropped too.""" - cache = self._cache(paged_cache, mock_ssd, max_entries=4) - self._stash_with_prefix(cache, mx, prefix_marker=1, tail_token=99) + cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4) + _stash_with_prefix(cache, mx, prefix_marker=1, tail_token=99) cache.reset_stats() From f0a9d1a9ceb039aa9eb9ef289d22a6bfeaec8c00 Mon Sep 17 00:00:00 2001 From: Blightbow Date: Thu, 14 May 2026 18:41:07 -0400 Subject: [PATCH 13/18] fix(cache): surface MRU fields in get_stats_dict MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The admin dashboard reads MRU partial cache state via ``Scheduler.get_ssd_cache_stats`` -> ``BlockAwarePrefixCache .get_stats_dict``, which silently dropped every MRU field added by the observability-counters commit. With ``mru_partial_max_entries`` aggregating to 0 in the admin payload, the dashboard's ``mruEnabled`` gate stayed false and hid every panel — header gauge, rate-strip cells, and the per-model "MRU Tails" column — even when operators had the feature configured. The counter delta path was unaffected: ``Scheduler._collect_cache_counters`` reads the ``PrefixCacheStats`` dataclass via ``get_stats()``, so ``cache_rates.cumulative`` already carried ``mru_partial_hit_rate`` and ``mru_partial_tokens_saved``. The dashboard simply refused to render those cells because the gauge gate was false. Adds six fields to ``get_stats_dict()``: - ``mru_partial_stashes`` (counter) - ``mru_partial_hits`` (counter) - ``mru_partial_evictions`` (counter) - ``mru_partial_tokens_saved`` (counter) - ``mru_partial_entries`` (gauge, len(_mru_partials)) - ``mru_partial_max_entries`` (gauge, configured capacity) Regression test --------------- Adds ``TestMRUPartialCounters ::test_get_stats_dict_mirrors_dataclass_after_round_trip``, following the Pattern B mandate the class docstring already established (real ``store_cache`` round-trip rather than hand-built ``_MRUPartialBlock`` state). Exercises three stashes plus one successful apply against ``max_entries=2`` so every counter and gauge moves off zero, then asserts each MRU field on ``get_stats_dict()`` matches the corresponding field on the ``PrefixCacheStats`` dataclass. Sibling MRU counter tests already covered each counter individually via ``get_stats()`` (the dataclass), which is why the missing-keys regression slipped past them — none asserted on the dict surface. This test closes that loop and was verified to fail cleanly without the fix (``mru_partial_stashes missing from get_stats_dict()``). Suite: 139 passed (cache + observability). Co-Authored-By: Claude Opus 4.7 (1M context) --- omlx/cache/prefix_cache.py | 10 ++++++ tests/test_prefix_cache.py | 64 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+) diff --git a/omlx/cache/prefix_cache.py b/omlx/cache/prefix_cache.py index fc9c9f6a..931dc824 100644 --- a/omlx/cache/prefix_cache.py +++ b/omlx/cache/prefix_cache.py @@ -2809,6 +2809,16 @@ def get_stats_dict(self) -> dict[str, Any]: "tokens_matched_total": self._tokens_matched_total, "tokens_requested_total": self._tokens_requested_total, "active_requests": len(self._request_tables), + # MRU partial cache: counters mirror the dataclass surface so the + # admin dashboard's `mruEnabled` gate (sourced from this dict + # via Scheduler.get_ssd_cache_stats) can see the configured + # capacity. Omitting them silently hides every MRU panel. + "mru_partial_stashes": self._mru_partial_stashes, + "mru_partial_hits": self._mru_partial_hits, + "mru_partial_evictions": self._mru_partial_evictions, + "mru_partial_tokens_saved": self._mru_partial_tokens_saved, + "mru_partial_entries": len(self._mru_partials), + "mru_partial_max_entries": self._mru_partial_max_entries, **paged_stats, } diff --git a/tests/test_prefix_cache.py b/tests/test_prefix_cache.py index cb17e5f6..ee09a34f 100644 --- a/tests/test_prefix_cache.py +++ b/tests/test_prefix_cache.py @@ -3578,6 +3578,70 @@ def test_reset_stats_zeros_mru_counters_but_keeps_live_state( # But the live entry is still there. assert stats.mru_partial_entries == 1 + def test_get_stats_dict_mirrors_dataclass_after_round_trip( + self, paged_cache, mock_ssd, mx + ): + """``get_stats_dict`` must surface every MRU field that ``get_stats`` + (the dataclass) does. The admin dashboard reads MRU state via the + dict path (``Scheduler.get_ssd_cache_stats`` -> ``get_stats_dict``); + when the dict drops any of these keys, the admin payload's + ``mru_partial_max_entries`` aggregates to 0 and the dashboard's + ``mruEnabled`` gate hides every MRU panel even when the feature + is enabled. + + Uses the production round-trip (real stashes via ``store_cache``, + real apply via ``apply_mru_partial``, real capacity overflow) so + the live gauge ``mru_partial_entries`` and every counter reach the + dict via the same path the scheduler exercises. + """ + cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=2) + # Stash two distinct prefixes (both fit in the cap). The first + # one (``bt_kept``) will be applied below to bump hits; the LRU + # touch from a successful apply leaves it at the MRU end, so the + # later capacity-overflow stash evicts the *other* survivor and + # not the one we just touched. + bt_kept, _ = _stash_with_prefix( + cache, mx, prefix_marker=2, tail_token=88 + ) + _stash_with_prefix(cache, mx, prefix_marker=3, tail_token=77) + + reconstructed = _make_reconstructed_cache(mx, n_layers=4, n_tokens=4) + _, _, applied = cache.apply_mru_partial(reconstructed, bt_kept, [88]) + assert applied == 1 # guard: the apply path actually fired + + # Third stash forces capacity-overflow eviction of the un-touched + # entry. After this: 3 stashes, 1 hit, 1 token saved, 1 eviction, + # 2 live entries. + _stash_with_prefix(cache, mx, prefix_marker=4, tail_token=66) + + stats = cache.get_stats() + stats_dict = cache.get_stats_dict() + + # Every MRU field on the dataclass must surface in the dict with + # the same value — this is the contract the admin route depends on. + for field in ( + "mru_partial_stashes", + "mru_partial_hits", + "mru_partial_evictions", + "mru_partial_tokens_saved", + "mru_partial_entries", + "mru_partial_max_entries", + ): + assert field in stats_dict, f"{field} missing from get_stats_dict()" + assert stats_dict[field] == getattr(stats, field), ( + f"{field}: dict={stats_dict[field]} dataclass={getattr(stats, field)}" + ) + + # Sanity: the round-trip actually moved every counter off its + # initial zero, so a future regression that hardwires zeros into + # the dict would still fail this test. + assert stats_dict["mru_partial_stashes"] == 3 + assert stats_dict["mru_partial_hits"] == 1 + assert stats_dict["mru_partial_evictions"] == 1 # capacity overflow + assert stats_dict["mru_partial_tokens_saved"] == 1 + assert stats_dict["mru_partial_entries"] == 2 + assert stats_dict["mru_partial_max_entries"] == 2 + class TestHasMRUPartial: """The has_mru_partial() accessor is the public API the scheduler From 5848c49f19952460756f1644ff214fa57c5a8b3c Mon Sep 17 00:00:00 2001 From: Blightbow Date: Thu, 14 May 2026 22:52:23 -0400 Subject: [PATCH 14/18] feat(cache): flag model-incompatible MRU on dashboard + warn log MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The MRU partial-block stash safety gate refuses any layer set containing a non-sliceable cache type (``RotatingKVCache``, ``PoolingCache``, ``ArraysCache``, ``CacheList``, etc.) — splicing a partial into a sliceable subset only would cause per-layer offset skew at decode (silent generation corruption). For affected models the entire MRU feature is structurally unavailable, but the dashboard previously rendered a misleading "0/N entries" gauge that left operators puzzling over an apparent config bug. Concrete case: DeepSeek-V4-Flash (every layer is ``CacheList(RotatingKVCache, PoolingCache, PoolingCache)``). Detection --------- Adds a tri-state ``mru_partial_supported`` flag to ``BlockAwarePrefixCache`` and ``PrefixCacheStats``: - ``None`` → unknown (no introspection has resolved it yet) - ``True`` → every observed layer is sliceable - ``False`` → at least one non-sliceable layer observed; every future stash attempt is refused at the safety gate Two detection paths feed the flag: 1. **Eager (load time):** ``_check_mru_eligibility_at_init`` calls ``model.make_cache()`` once at construction, extracts type names via ``ModelCacheConfig.from_cache_list``, and resolves the flag immediately. Best-effort: if make_cache is absent or raises, falls back to lazy detection without crashing. Cache instances are dropped via ``del`` after inspection — no tensor buffers are allocated (those arrive on first prefill), and Python wrappers are GC-reclaimed. 2. **Lazy (first inference):** ``_update_mru_partial`` checks ``_all_layers_sliceable(layer_cache_types)`` on each call. First non-sliceable observation latches the flag and emits the warning; first sliceable observation latches True. The warning fires exactly once per cache instance via ``_mru_partial_warn_emitted``. Operator log message includes the offending types and the sliceable whitelist so it's grep-actionable: WARNING omlx.cache.prefix_cache: MRU tail cache disabled for this model: layer types ['RotatingKVCache', 'PoolingCache'] are not in the sliceable whitelist [...]. Splicing a partial into a non-sliceable subset would cause per-layer offset skew at decode (silent generation corruption), so every stash attempt will be refused. The admin dashboard's per-model 'MRU Tails' cell will display 'N/A (see log)'. Dashboard surface ----------------- Per-model "MRU Tails" cell renders ``N/A (see log)`` when ``mru_partial_supported === false``; otherwise renders ``entries / max_entries`` as before. Hover tooltip references the server log. Global rate-strip cells (MRU Tail Hit Rate, MRU Tokens Saved) stay as aggregates — if any loaded model is compatible, those cells still surface its payoff. Tests (8 new in TestMRUPartialEligibility) ------------------------------------------ - ``supported_is_none_without_make_cache_and_no_inference`` - ``supported_latches_true_on_sliceable_observation`` - ``supported_latches_false_lazy_on_non_sliceable`` (round-trip via ``store_cache`` with ``_rotating_layer`` factory) - ``warning_does_not_repeat_on_subsequent_non_sliceable`` - ``eager_check_latches_false_at_init_with_non_sliceable_make_cache`` - ``eager_check_latches_true_at_init_with_sliceable_make_cache`` - ``eager_check_skipped_when_feature_disabled`` - ``eager_check_survives_make_cache_failure`` Existing ``TestRuntimeCacheObservability::test_runtime_cache _uses_model_scoped_ssd_stats`` updated to include the new per-model payload key (``mru_partial_supported: None`` for mocks that don't populate ``prefix_cache``). Suite: 215 passed (cache + observability + admin). Co-Authored-By: Claude Opus 4.7 (1M context) --- omlx/admin/routes.py | 6 + omlx/admin/templates/dashboard/_status.html | 8 +- omlx/cache/prefix_cache.py | 103 +++++++++- omlx/cache/stats.py | 8 + tests/test_admin_api_key.py | 2 + tests/test_prefix_cache.py | 209 ++++++++++++++++++++ 6 files changed, 334 insertions(+), 2 deletions(-) diff --git a/omlx/admin/routes.py b/omlx/admin/routes.py index f94cbcc8..053012e7 100644 --- a/omlx/admin/routes.py +++ b/omlx/admin/routes.py @@ -3670,6 +3670,12 @@ def _build_runtime_cache_observability( "mru_partial_max_entries": int( prefix_stats.get("mru_partial_max_entries", 0) or 0 ), + # Tri-state: None (unknown / no inference yet), True (eligible), + # False (model uses non-sliceable cache layers — every stash + # refused at the safety gate; dashboard renders 'N/A (see log)'). + "mru_partial_supported": prefix_stats.get( + "mru_partial_supported", None + ), } cache_rates = runtime_stats.get("cache_rates") diff --git a/omlx/admin/templates/dashboard/_status.html b/omlx/admin/templates/dashboard/_status.html index 2365247f..6fb3c04a 100644 --- a/omlx/admin/templates/dashboard/_status.html +++ b/omlx/admin/templates/dashboard/_status.html @@ -450,7 +450,13 @@

{{ t('status.head

- + diff --git a/omlx/cache/prefix_cache.py b/omlx/cache/prefix_cache.py index 931dc824..fb87ba7b 100644 --- a/omlx/cache/prefix_cache.py +++ b/omlx/cache/prefix_cache.py @@ -199,6 +199,97 @@ def __init__( self._mru_partial_evictions = 0 self._mru_partial_tokens_saved = 0 + # Tri-state MRU eligibility flag (see PrefixCacheStats.mru_partial_supported). + # Latched True once a sliceable layer set is observed; latched False + # the first time a non-sliceable set is observed (eager check below + # or lazy fallback inside _update_mru_partial). + self._mru_partial_supported: bool | None = None + self._mru_partial_warn_emitted: bool = False + if self._mru_partial_max_entries > 0: + self._check_mru_eligibility_at_init(model) + + def _check_mru_eligibility_at_init(self, model: Any) -> None: + """Best-effort load-time check: does this model have sliceable layers? + + The MRU partial-block stash safety gate rejects any layer set that + contains a non-sliceable cache type (``RotatingKVCache``, + ``PoolingCache``, ``ArraysCache``, ``CacheList``, etc.) — splicing + a partial into a sliceable subset only would cause per-layer offset + skew at decode (silent generation corruption). For affected + models the entire MRU feature is structurally unavailable, and the + dashboard would otherwise show a confusing "0/N entries" gauge + forever. Warn the operator at load time so they can grep this + exact line for the offending types. + + Best-effort: if ``model.make_cache()`` is absent or raises, the + lazy fallback inside ``_update_mru_partial`` picks the same signal + up at first inference instead. + """ + if not hasattr(model, "make_cache"): + return + try: + sample_cache = model.make_cache() + except Exception as exc: + logger.debug( + "MRU eager eligibility check skipped (make_cache failed: %s)", + exc, + ) + return + try: + mcc = ModelCacheConfig.from_cache_list(sample_cache) + layer_types = mcc.get_type_names() if mcc is not None else None + except Exception as exc: + logger.debug( + "MRU eager eligibility check skipped (ModelCacheConfig failed: %s)", + exc, + ) + return + finally: + # Drop the sample cache refs ASAP — only used for type-name + # introspection. No tensor buffers were allocated yet (those + # arrive on first prefill), but keeping the wrappers around + # serves no purpose. + del sample_cache + if not layer_types: + return + if self._all_layers_sliceable(layer_types): + self._mru_partial_supported = True + else: + self._record_mru_unsupported(layer_types) + + def _record_mru_unsupported( + self, layer_cache_types: list[str] | None + ) -> None: + """Latch ``mru_partial_supported=False`` and emit a one-shot warning. + + Called from both the eager init-time check and the lazy fallback + inside ``_update_mru_partial``. The warning identifies the + offending cache types so an operator can decide whether to disable + the feature (``--mru-partial-max-entries=0``) or accept the gauge + showing 'N/A (see log)' on this model's dashboard row. + """ + already_recorded = self._mru_partial_supported is False + self._mru_partial_supported = False + if already_recorded or self._mru_partial_warn_emitted: + return + self._mru_partial_warn_emitted = True + if layer_cache_types: + offenders = sorted( + set(layer_cache_types) - KNOWN_SLICEABLE_CACHE_TYPES + ) + else: + offenders = [""] + logger.warning( + "MRU tail cache disabled for this model: layer types %s are not in " + "the sliceable whitelist %s. Splicing a partial into a " + "non-sliceable subset would cause per-layer offset skew at decode " + "(silent generation corruption), so every stash attempt will be " + "refused. The admin dashboard's per-model 'MRU Tails' cell will " + "display 'N/A (see log)'.", + offenders, + sorted(KNOWN_SLICEABLE_CACHE_TYPES), + ) + def _get_model_num_layers(self, model: Any) -> int: """ Get the expected number of *cache layers* for validation. @@ -852,11 +943,19 @@ def _update_mru_partial( # Feature disabled. Match the hot_cache_max_size="0" convention: # silent fallback to "no MRU" behavior. return + # Sliceable-layer guard is checked first and recorded so the + # eligibility flag flips at the same instant the warning fires — + # later branches in this chain are per-call ("no eligible tail + # this time") and must not pollute the structural flag. + if not self._all_layers_sliceable(layer_cache_types): + self._record_mru_unsupported(layer_cache_types) + return + if self._mru_partial_supported is None: + self._mru_partial_supported = True if ( trailing_partial_tokens == 0 or not is_tensor_data or not self._can_reconstruct() - or not self._all_layers_sliceable(layer_cache_types) ): return @@ -2780,6 +2879,7 @@ def get_stats(self) -> PrefixCacheStats: mru_partial_tokens_saved=self._mru_partial_tokens_saved, mru_partial_entries=len(self._mru_partials), mru_partial_max_entries=self._mru_partial_max_entries, + mru_partial_supported=self._mru_partial_supported, ) def get_stats_dict(self) -> dict[str, Any]: @@ -2819,6 +2919,7 @@ def get_stats_dict(self) -> dict[str, Any]: "mru_partial_tokens_saved": self._mru_partial_tokens_saved, "mru_partial_entries": len(self._mru_partials), "mru_partial_max_entries": self._mru_partial_max_entries, + "mru_partial_supported": self._mru_partial_supported, **paged_stats, } diff --git a/omlx/cache/stats.py b/omlx/cache/stats.py index 230824d4..47f2749b 100644 --- a/omlx/cache/stats.py +++ b/omlx/cache/stats.py @@ -100,6 +100,14 @@ class PrefixCacheStats(BaseCacheStats): mru_partial_tokens_saved: int = 0 mru_partial_entries: int = 0 mru_partial_max_entries: int = 0 + # Tri-state eligibility flag for the MRU partial-block feature: + # None → unknown (default; no detection has fired yet) + # True → model has only sliceable cache layers + # False → at least one layer type is not in the sliceable whitelist, + # so every stash attempt is refused at the safety gate. + # Surfaces to the admin dashboard so per-model rows can show + # "N/A (see log)" instead of a misleading "0/N entries" gauge. + mru_partial_supported: bool | None = None _total_queries: int = field(default=0, repr=False) @property diff --git a/tests/test_admin_api_key.py b/tests/test_admin_api_key.py index cb6fc4ef..eb03106b 100644 --- a/tests/test_admin_api_key.py +++ b/tests/test_admin_api_key.py @@ -751,6 +751,7 @@ def test_runtime_cache_uses_model_scoped_ssd_stats(self): "hot_cache_entries": 0, "mru_partial_entries": 0, "mru_partial_max_entries": 0, + "mru_partial_supported": None, }, { "id": "model-b", @@ -770,6 +771,7 @@ def test_runtime_cache_uses_model_scoped_ssd_stats(self): "hot_cache_entries": 0, "mru_partial_entries": 0, "mru_partial_max_entries": 0, + "mru_partial_supported": None, }, ] manager_a.get_stats_for_model.assert_called_once_with("/models/model-a") diff --git a/tests/test_prefix_cache.py b/tests/test_prefix_cache.py index ee09a34f..fc4c2a61 100644 --- a/tests/test_prefix_cache.py +++ b/tests/test_prefix_cache.py @@ -6,6 +6,7 @@ PagedCacheManager for block-based storage with SSD persistence. """ +import logging import time from pathlib import Path from typing import Any, Dict, List, Optional @@ -3643,6 +3644,214 @@ def test_get_stats_dict_mirrors_dataclass_after_round_trip( assert stats_dict["mru_partial_max_entries"] == 2 +def _model_with_make_cache(num_layers: int, layer_class_names: list[str]): + """Build a MockModel whose ``make_cache()`` returns objects whose + ``type(obj).__name__`` matches the requested cache class names. + + ``ModelCacheConfig.from_cache_list`` identifies cache types by class + name (with isinstance fallback for SizedArraysCache only), so dynamic + classes are enough to exercise the eager init-time eligibility check + without pulling in real mlx-lm cache implementations. + """ + cache_objs = [ + type(name, (object,), {"max_size": 64})() + for name in layer_class_names + ] + model = MockModel(num_layers=num_layers) + model.make_cache = lambda: cache_objs # type: ignore[attr-defined] + return model + + +class TestMRUPartialEligibility: + """The ``mru_partial_supported`` tri-state flag and its one-shot + warning. Surfaces structurally-incompatible models on the admin + dashboard so operators see ``N/A (see log)`` instead of a misleading + ``0/N entries`` gauge. Mirrors the prior-art Pattern B (real + ``store_cache`` round-trip) for the lazy fallback path, and exercises + the eager init-time path through a model with ``make_cache()``. + """ + + @pytest.fixture + def mx(self): + try: + import mlx.core as mx + return mx + except ImportError: + pytest.skip("MLX not available") + + @pytest.fixture + def paged_cache(self): + return PagedCacheManager( + block_size=4, + max_blocks=100, + model_name="test-model", + initial_blocks=100, + ) + + @pytest.fixture + def mock_ssd(self): + mock = MagicMock() + mock.save_block.return_value = True + mock.load_block.return_value = None + mock.load_block_with_metadata.return_value = (None, None) + mock.has_block.return_value = False + return mock + + def test_supported_is_none_without_make_cache_and_no_inference( + self, paged_cache, mock_ssd + ): + """``MockModel`` has no ``make_cache``; eager check bare-returns and + lazy fallback hasn't fired yet — flag stays ``None``.""" + cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4) + stats = cache.get_stats() + assert stats.mru_partial_supported is None + assert cache.get_stats_dict()["mru_partial_supported"] is None + + def test_supported_latches_true_on_sliceable_observation( + self, paged_cache, mock_ssd, mx + ): + """A successful KVCache stash latches ``supported=True`` via the + lazy path (eager skipped because MockModel has no make_cache).""" + cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4) + _stash_with_prefix(cache, mx, prefix_marker=1, tail_token=99) + assert cache.get_stats().mru_partial_supported is True + + def test_supported_latches_false_lazy_on_non_sliceable( + self, paged_cache, mock_ssd, mx, caplog + ): + """A store_cache with RotatingKVCache layers latches + ``supported=False`` and emits exactly one warning.""" + from omlx.cache.hybrid_cache import ModelCacheConfig + + cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4) + # Eager init skipped (MockModel has no make_cache), so flag is + # None at this point — the lazy path is what we're testing. + assert cache.get_stats().mru_partial_supported is None + + tokens = [10, 20, 30, 40, 50] + cache_data = [_rotating_layer(mx, 5) for _ in range(4)] + config = ModelCacheConfig.from_type_list( + ["RotatingKVCache"] * 4, model_name="test" + ) + with caplog.at_level(logging.WARNING, logger="omlx.cache.prefix_cache"): + cache.store_cache("req-rot", tokens, cache_data, model_cache_config=config) + + stats = cache.get_stats() + assert stats.mru_partial_supported is False + assert stats.mru_partial_stashes == 0 # gate refused, no stash + warns = [r for r in caplog.records if r.levelno == logging.WARNING] + assert len(warns) == 1 + assert "MRU tail cache disabled" in warns[0].getMessage() + assert "RotatingKVCache" in warns[0].getMessage() + + def test_warning_does_not_repeat_on_subsequent_non_sliceable( + self, paged_cache, mock_ssd, mx, caplog + ): + """Once the flag is latched False, further non-sliceable store_cache + calls must NOT re-emit the warning (operator log spam guard).""" + from omlx.cache.hybrid_cache import ModelCacheConfig + + cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4) + config = ModelCacheConfig.from_type_list( + ["RotatingKVCache"] * 4, model_name="test" + ) + with caplog.at_level(logging.WARNING, logger="omlx.cache.prefix_cache"): + for i in range(3): + cache.store_cache( + f"req-rot-{i}", + [10 * i + j for j in range(5)], + [_rotating_layer(mx, 5) for _ in range(4)], + model_cache_config=config, + ) + + warns = [r for r in caplog.records if r.levelno == logging.WARNING] + assert len(warns) == 1 + assert cache.get_stats().mru_partial_supported is False + + def test_eager_check_latches_false_at_init_with_non_sliceable_make_cache( + self, paged_cache, mock_ssd, caplog + ): + """When ``model.make_cache()`` is available and returns non-sliceable + cache instances, the flag latches False at construction and the + warning fires BEFORE any inference — true model-load-time signal.""" + model = _model_with_make_cache( + num_layers=4, + layer_class_names=["RotatingKVCache"] * 4, + ) + with caplog.at_level(logging.WARNING, logger="omlx.cache.prefix_cache"): + cache = BlockAwarePrefixCache( + model=model, + paged_cache_manager=paged_cache, + paged_ssd_cache_manager=mock_ssd, + mru_partial_max_entries=4, + ) + + assert cache.get_stats().mru_partial_supported is False + warns = [r for r in caplog.records if r.levelno == logging.WARNING] + assert len(warns) == 1 + assert "RotatingKVCache" in warns[0].getMessage() + + def test_eager_check_latches_true_at_init_with_sliceable_make_cache( + self, paged_cache, mock_ssd, caplog + ): + """When ``model.make_cache()`` returns only sliceable cache + instances, the flag latches True at construction and no warning + is emitted.""" + model = _model_with_make_cache( + num_layers=4, + layer_class_names=["KVCache"] * 4, + ) + with caplog.at_level(logging.WARNING, logger="omlx.cache.prefix_cache"): + cache = BlockAwarePrefixCache( + model=model, + paged_cache_manager=paged_cache, + paged_ssd_cache_manager=mock_ssd, + mru_partial_max_entries=4, + ) + + assert cache.get_stats().mru_partial_supported is True + assert not [r for r in caplog.records if r.levelno == logging.WARNING] + + def test_eager_check_skipped_when_feature_disabled( + self, paged_cache, mock_ssd, caplog + ): + """``max_entries=0`` disables the feature; no eager check runs even + for an obviously incompatible model. No warning, no flag change — + the operator already opted out by setting the capacity to zero.""" + model = _model_with_make_cache( + num_layers=4, + layer_class_names=["RotatingKVCache"] * 4, + ) + with caplog.at_level(logging.WARNING, logger="omlx.cache.prefix_cache"): + cache = BlockAwarePrefixCache( + model=model, + paged_cache_manager=paged_cache, + paged_ssd_cache_manager=mock_ssd, + mru_partial_max_entries=0, + ) + + assert cache.get_stats().mru_partial_supported is None + assert not [r for r in caplog.records if r.levelno == logging.WARNING] + + def test_eager_check_survives_make_cache_failure( + self, paged_cache, mock_ssd, caplog + ): + """If ``model.make_cache()`` raises, the eager check bare-returns + and the flag stays ``None``. Lazy fallback picks up at first + inference instead — no startup crash.""" + model = MockModel(num_layers=4) + model.make_cache = lambda: (_ for _ in ()).throw( # type: ignore[attr-defined] + RuntimeError("model not fully initialized") + ) + cache = BlockAwarePrefixCache( + model=model, + paged_cache_manager=paged_cache, + paged_ssd_cache_manager=mock_ssd, + mru_partial_max_entries=4, + ) + assert cache.get_stats().mru_partial_supported is None + + class TestHasMRUPartial: """The has_mru_partial() accessor is the public API the scheduler uses to decide whether to suppress the deferred Metal cache clear.""" From b039eb271c1b8fecd5d3e49e31198c29007fbe5b Mon Sep 17 00:00:00 2001 From: Blightbow Date: Thu, 14 May 2026 23:25:44 -0400 Subject: [PATCH 15/18] refactor(cache): plain-language MRU incompatibility warning MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The model-load warning added in 5848c49 read like a developer note: it dumped the internal sliceable whitelist, explained the splice mechanism and the offset-skew failure mode, raised a "silent generation corruption" alarm, and narrated which dashboard cell would change. Rewrite it in the in-tree load-phase warning voice — "condition + consequence, plain words" — matching the ``mtp_enabled`` warning in ``utils/model_loading.py`` and the L2 warning in ``engine/dflash.py``: MRU tail cache enabled but this model is incompatible (cache layers: RotatingKVCache, PoolingCache); MRU tails will be inactive for this model. - Drop the whitelist dump — operators don't tune against it. - Drop the splice-mechanism rationale — that's an engineering explanation, not an operator decision point. It still lives in the ``_all_layers_sliceable`` docstring where developers read it; ``_record_mru_unsupported``'s docstring now points there. - Drop the dashboard self-reference — the dashboard already renders 'N/A (see log)'; the log shouldn't narrate the UI. - ``", ".join(offenders)`` instead of raw list repr, and ``"unknown"`` instead of ``""`` for the fallback. Tests assert on the new wording ("MRU tails will be inactive", "incompatible"). Would have folded into 5848c49 as an amend, but the merge of main now sits between that commit and HEAD. Co-Authored-By: Claude Opus 4.7 (1M context) --- omlx/cache/prefix_cache.py | 24 +++++++++++------------- tests/test_prefix_cache.py | 3 ++- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/omlx/cache/prefix_cache.py b/omlx/cache/prefix_cache.py index fb87ba7b..a914ce1a 100644 --- a/omlx/cache/prefix_cache.py +++ b/omlx/cache/prefix_cache.py @@ -263,10 +263,13 @@ def _record_mru_unsupported( """Latch ``mru_partial_supported=False`` and emit a one-shot warning. Called from both the eager init-time check and the lazy fallback - inside ``_update_mru_partial``. The warning identifies the - offending cache types so an operator can decide whether to disable - the feature (``--mru-partial-max-entries=0``) or accept the gauge - showing 'N/A (see log)' on this model's dashboard row. + inside ``_update_mru_partial``. The warning matches the in-tree + load-phase warning style ("condition + consequence, plain words") + — see ``utils/model_loading.py`` mtp warning and ``engine/dflash`` + L2 warning for the reference voice. Mechanism details + (sliceable whitelist, offset-skew failure mode) live in the + ``_all_layers_sliceable`` docstring where developers read them, + not in the operator log. """ already_recorded = self._mru_partial_supported is False self._mru_partial_supported = False @@ -278,16 +281,11 @@ def _record_mru_unsupported( set(layer_cache_types) - KNOWN_SLICEABLE_CACHE_TYPES ) else: - offenders = [""] + offenders = ["unknown"] logger.warning( - "MRU tail cache disabled for this model: layer types %s are not in " - "the sliceable whitelist %s. Splicing a partial into a " - "non-sliceable subset would cause per-layer offset skew at decode " - "(silent generation corruption), so every stash attempt will be " - "refused. The admin dashboard's per-model 'MRU Tails' cell will " - "display 'N/A (see log)'.", - offenders, - sorted(KNOWN_SLICEABLE_CACHE_TYPES), + "MRU tail cache enabled but this model is incompatible " + "(cache layers: %s); MRU tails will be inactive for this model.", + ", ".join(offenders), ) def _get_model_num_layers(self, model: Any) -> int: diff --git a/tests/test_prefix_cache.py b/tests/test_prefix_cache.py index fc4c2a61..2a77bdfa 100644 --- a/tests/test_prefix_cache.py +++ b/tests/test_prefix_cache.py @@ -3741,7 +3741,8 @@ def test_supported_latches_false_lazy_on_non_sliceable( assert stats.mru_partial_stashes == 0 # gate refused, no stash warns = [r for r in caplog.records if r.levelno == logging.WARNING] assert len(warns) == 1 - assert "MRU tail cache disabled" in warns[0].getMessage() + assert "MRU tails will be inactive" in warns[0].getMessage() + assert "incompatible" in warns[0].getMessage() assert "RotatingKVCache" in warns[0].getMessage() def test_warning_does_not_repeat_on_subsequent_non_sliceable( From a42542f719fd17cb8472164d0abc9ad57fc14d01 Mon Sep 17 00:00:00 2001 From: Blightbow Date: Fri, 15 May 2026 00:45:34 -0400 Subject: [PATCH 16/18] fix(cache): key MRU stash to the prompt boundary MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The MRU partial-block cache stashed the trailing partial of the stored sequence (``prompt + output``) keyed by that sequence's last full block, but a repeat request resubmits the *prompt* only and ``apply_mru_partial`` looks the entry up by the prompt's last full block. Output tokens shift every block boundary past the prompt's tail, so the two keys never coincide: ``_mru_partials.get(last_hash)`` always returned None. The feature produced zero hits for ordinary chat completions — observed live as "MRU tails 3/4 entries, MRU Tail Hit Rate 0.0%" with zero evictions (a key that is never found is never evicted). It only worked for reasoning models whose prompt ends with an open ```` tag, where ``needs_think_prefix`` makes ``store_cache`` persist the prompt alone. Fix --- Thread the prompt token count from the scheduler into ``store_cache`` -> ``_update_mru_partial``. The stash now keys off the prompt's last full block (block index ``prompt_len // block_size - 1``) and slices the prompt's trailing partial, not the stored sequence's. That block's hash is identical whether the sequence is blocked as ``prompt`` or ``prompt + output`` — block hashes are content-chained and the chain is byte-identical up to the prompt's partial tail — so the key the stash writes is exactly the key a prompt-only resubmission's ``apply_mru_partial`` computes. The arithmetic runs in the existing global-coordinate frame and accounts for ``existing_tokens > 0`` (on a resubmission ``store_cache`` appends to the fetched prefix block table and works in ``new_tokens`` space): ``partial_start = prompt_partial_start - existing_tokens``, with a guard for a prompt already fully covered by cached full blocks. Edge cases degrade cleanly — block-aligned prompt: no stash; prompt shorter than one block: ``None`` key (short-prompt path); ``prompt_token_count=None`` (generic ``CacheManager.store`` path): falls back to the whole stored sequence, reproducing the pre-fix behavior for verbatim-repeat callers. ``apply_mru_partial`` is unchanged — only the stash side was wrong. Draft cache ----------- ``_draft_prefix_cache`` (SpecPrefill) is constructed with ``mru_partial_max_entries=0``. ``apply_mru_partial`` is only ever called on the main ``block_aware_cache``, so a draft-cache stash was dead work that never paid off. Tests ----- New ``TestMRUPromptBoundaryStash`` (5 cases): - ``prompt_boundary_stash_hits_on_prompt_only_resubmit`` — the round-trip the original feature shipped without: store ``prompt + output`` with the boundary, resubmit prompt-only, assert a hit. - ``whole_sequence_stash_misses_on_prompt_only_resubmit`` — pins the original bug via the ``prompt_token_count=None`` path (0 hits, 0 evictions). - ``block_aligned_prompt_does_not_stash`` - ``short_prompt_stashes_under_none_key`` - ``prompt_boundary_stash_with_existing_cached_prefix`` — the ``existing_tokens > 0`` resubmission path. Diagnosis and approach were adversarially peer-reviewed; the review caught an ``existing_tokens``-relative indexing error in the first draft of the fix, corrected here. Full unit suite: 4401 passed (5 pre-existing upstream baseline failures unrelated to cache). Co-Authored-By: Claude Opus 4.7 (1M context) --- omlx/cache/prefix_cache.py | 115 +++++++++++++++++------ omlx/scheduler.py | 13 ++- tests/test_prefix_cache.py | 186 +++++++++++++++++++++++++++++++++++++ 3 files changed, 283 insertions(+), 31 deletions(-) diff --git a/omlx/cache/prefix_cache.py b/omlx/cache/prefix_cache.py index a914ce1a..be6476bd 100644 --- a/omlx/cache/prefix_cache.py +++ b/omlx/cache/prefix_cache.py @@ -505,6 +505,7 @@ def store_cache( extra_keys: tuple[Any, ...] | None = None, extra_key_token_start: int | None = None, extra_key_ranges: list[tuple[int, tuple[Any, ...]]] | None = None, + prompt_token_count: int | None = None, ) -> BlockTable | None: """ Store computed cache for future reuse. @@ -525,6 +526,14 @@ def store_cache( boundary_snapshots: Optional mapping of token_count -> extracted cache states for intermediate block boundaries. Used to store per-block ArraysCache state instead of placeholders in hybrid models. + prompt_token_count: Number of prompt tokens at the front of + ``tokens``. ``tokens`` is typically ``prompt + output``; + the MRU partial cache must stash the prompt's trailing + tail (the part a repeat request will resubmit), not the + sequence's. ``None`` means "no prompt boundary known" — + the MRU stash then treats the whole sequence as the + prompt (pre-MRU-fix behavior; correct for verbatim-repeat + callers like the generic CacheManager.store path). Returns: BlockTable for the stored cache, or None on failure @@ -825,18 +834,17 @@ def store_cache( last_access=time.time(), ) - # Stash the trailing sub-block tail in memory so an immediate - # repeat request can splice it back in without a re-prefill. - # The slot is replaced unconditionally — if there's no eligible - # tail (block-aligned, hybrid model, extraction failure) we clear - # it so a stale partial from a previous store cannot survive. + # Stash the prompt's trailing sub-block tail in memory so an + # immediate repeat request can splice it back in without a + # re-prefill. Keyed by the prompt's last full block so the + # lookup in apply_mru_partial (which sees a prompt-only repeat) + # finds it — see _update_mru_partial. self._update_mru_partial( new_tokens=new_tokens, cache_data=cache_data, block_table=block_table, existing_tokens=existing_tokens, - num_new_blocks=num_new_blocks, - trailing_partial_tokens=trailing_partial_tokens, + prompt_token_count=prompt_token_count, is_tensor_data=bool(is_tensor_data), layer_cache_types=layer_cache_types, model_cache_config=model_cache_config, @@ -918,19 +926,28 @@ def _update_mru_partial( cache_data: list[Any], block_table: BlockTable, existing_tokens: int, - num_new_blocks: int, - trailing_partial_tokens: int, + prompt_token_count: int | None, is_tensor_data: bool, layer_cache_types: list[str] | None, model_cache_config: ModelCacheConfig | None, ) -> None: - """Stash the trailing partial from a just-completed ``store_cache``. + """Stash the prompt's trailing partial from a just-completed ``store_cache``. + + ``store_cache`` is given the full ``prompt + output`` sequence, + but a repeat request resubmits the *prompt* only — and + ``apply_mru_partial`` looks the entry up by the prompt's last + full block. So the stash must key off, and slice, the + **prompt's** trailing partial, not the stored sequence's. + ``prompt_token_count`` carries the prompt boundary; ``None`` + falls back to "whole sequence is the prompt" (pre-fix behavior, + correct for verbatim-repeat callers). On success, writes one entry into the LRU map keyed by - ``parent_hash``. If the map is at capacity, the oldest entry is - evicted via ``popitem(last=False)``. + ``parent_hash`` (the prompt's last full block, or ``None`` for a + prompt shorter than one block). If the map is at capacity, the + oldest entry is evicted via ``popitem(last=False)``. - The "no eligible tail" branches (no trailing tokens, non-tensor + The "no eligible tail" branches (block-aligned prompt, non-tensor data, no reconstruct path configured, hybrid model, extraction failure, ambiguous cache layout) bare-return — they signal a local "nothing to stash this time," NOT a global "wipe the map." @@ -950,22 +967,45 @@ def _update_mru_partial( return if self._mru_partial_supported is None: self._mru_partial_supported = True - if ( - trailing_partial_tokens == 0 - or not is_tensor_data - or not self._can_reconstruct() - ): + if not is_tensor_data or not self._can_reconstruct(): return - partial_start = num_new_blocks * self.block_size - partial_global_start = existing_tokens + partial_start - partial_global_end = partial_global_start + trailing_partial_tokens + # Resolve the prompt boundary within the stored sequence. None + # means "no boundary known" — treat the whole stored sequence as + # the prompt, which reproduces the pre-fix whole-sequence stash. + sequence_len = existing_tokens + len(new_tokens) + if prompt_token_count is None: + prompt_token_count = sequence_len + else: + prompt_token_count = min(prompt_token_count, sequence_len) + # Prompt fully covered by already-cached full blocks: its tail (if + # any) is partial and partials are never stored as blocks, so a + # prompt that ends at or before existing_tokens is block-aligned + # and has nothing for the MRU to add. + if prompt_token_count <= existing_tokens: + return + + # Prompt's trailing partial: global range + # [prompt_partial_start, prompt_token_count). Block-aligned + # prompt (partial_len == 0) → every token lands in a full paged + # block, nothing to stash. + prompt_partial_len = prompt_token_count % self.block_size + if prompt_partial_len == 0: + return + prompt_partial_start = prompt_token_count - prompt_partial_len + prompt_full_blocks = prompt_token_count // self.block_size + # new_tokens-relative offset of the prompt's trailing partial. + # existing_tokens is block-aligned and <= prompt_partial_start, + # so this is a valid index into new_tokens. + partial_start = prompt_partial_start - existing_tokens + partial_global_start = prompt_partial_start + partial_global_end = prompt_token_count # Decide which axis-2 index range covers the trailing partial: # - Global: cache_data spans the full sequence (prefix + new tail); # slice at [partial_global_start:partial_global_end]. # - Local: cache_data spans only the newly-processed suffix; - # slice at [partial_start:partial_start + trailing_partial_tokens]. + # slice at [partial_start:partial_start + prompt_partial_len]. # # We classify by exact cache length, not the previous # ``cache_seq_len >= existing_tokens + 1`` heuristic — that one @@ -985,7 +1025,7 @@ def _update_mru_partial( elif existing_tokens == 0 or cache_seq_len == local_len: # Cache covers only the new suffix (or there is no prefix). p_cache_start = partial_start - p_cache_end = partial_start + trailing_partial_tokens + p_cache_end = partial_start + prompt_partial_len else: # Ambiguous: cache_seq_len is short of global_end but does # not equal local_len either. Refuse rather than guess. @@ -1004,12 +1044,27 @@ def _update_mru_partial( if not partial_kv: return + # Key by the PROMPT's last full block (block index + # prompt_full_blocks - 1), so a prompt-only repeat — which is what + # apply_mru_partial sees — finds this entry. block_table.block_ids + # is ordered [fetched-prefix blocks..., newly-built blocks...] in + # sequence order, and the stored sequence has at least + # prompt_full_blocks full blocks (the prompt is a prefix of it), + # so the index is in range. prompt_full_blocks == 0 (prompt + # shorter than one block) keeps parent_hash = None — the + # short-prompt None-key path. parent_hash: bytes | None = None - if block_table.block_ids: - last_bid = block_table.block_ids[-1] - last_block = self.paged_cache.allocated_blocks.get(last_bid) - if last_block is not None and last_block.block_hash: - parent_hash = last_block.block_hash + if prompt_full_blocks >= 1: + if len(block_table.block_ids) < prompt_full_blocks: + # Prompt's last full block isn't in the table (block build + # rolled back on an extraction failure). Skip rather than + # mis-key against the None short-prompt slot. + return + parent_bid = block_table.block_ids[prompt_full_blocks - 1] + parent_block = self.paged_cache.allocated_blocks.get(parent_bid) + if parent_block is None or not parent_block.block_hash: + return + parent_hash = parent_block.block_hash # LRU put: pop any existing entry first so the re-insert lands at # the tail with fresh order; then evict the oldest entry if over @@ -1021,7 +1076,7 @@ def _update_mru_partial( self._mru_partials.pop(parent_hash) self._mru_partials[parent_hash] = _MRUPartialBlock( parent_hash=parent_hash, - tokens=new_tokens[partial_start:], + tokens=new_tokens[partial_start : partial_start + prompt_partial_len], kv_data=partial_kv, ) self._mru_partial_stashes += 1 @@ -1032,7 +1087,7 @@ def _update_mru_partial( logger.debug( "Stashed MRU partial: %d tokens, parent_hash=%s, layers=%d, " "entries=%d/%d", - trailing_partial_tokens, + prompt_partial_len, parent_hash[:8].hex() + "..." if parent_hash else "None", len(partial_kv), len(self._mru_partials), diff --git a/omlx/scheduler.py b/omlx/scheduler.py index cc34c76a..89152ae7 100644 --- a/omlx/scheduler.py +++ b/omlx/scheduler.py @@ -1039,6 +1039,7 @@ def _async_store_cache_worker( extra_keys: tuple[Any, ...] | None, extra_key_token_start: int | None, extra_key_ranges: list[tuple[int, tuple[Any, ...]]] | None, + prompt_token_count: int, ) -> None: """Run store_cache + paged_cache cleanup off the inference thread. @@ -1081,6 +1082,7 @@ def _async_store_cache_worker( extra_keys=extra_keys, extra_key_token_start=extra_key_token_start, extra_key_ranges=extra_key_ranges, + prompt_token_count=prompt_token_count, ) if block_table is None and self.paged_cache_manager is not None: block_table = self.paged_cache_manager.get_block_table(request_id) @@ -3474,7 +3476,10 @@ def set_specprefill_draft_model( model=draft_model, paged_cache_manager=draft_paged, paged_ssd_cache_manager=self.paged_ssd_cache_manager, - mru_partial_max_entries=self.config.mru_partial_max_entries, + # MRU disabled on the draft cache: apply_mru_partial is + # only ever called on the main block_aware_cache, so a + # draft-cache stash is dead work that never pays off. + mru_partial_max_entries=0, ) self._draft_prefix_cache.set_cold_restore_callback( self._restore_block_from_cold @@ -5184,6 +5189,10 @@ def _cleanup_finished(self, finished_ids: set[str]) -> None: if pre_eval_arrays: mx.async_eval(*pre_eval_arrays) + # Prompt boundary for the MRU partial stash: + # token_sequence_to_store is prompt+output, but a + # repeat request resubmits the prompt only. + prompt_token_count = len(request.prompt_token_ids) if self._store_cache_executor is not None: store_future = self._store_cache_executor.submit( self._async_store_cache_worker, @@ -5195,6 +5204,7 @@ def _cleanup_finished(self, finished_ids: set[str]) -> None: request.vlm_extra_keys_for_cache, request.vlm_extra_key_token_start_for_cache, request.vlm_extra_key_ranges_for_cache, + prompt_token_count, ) self._inflight_store_futures[request_id] = store_future else: @@ -5208,6 +5218,7 @@ def _cleanup_finished(self, finished_ids: set[str]) -> None: request.vlm_extra_keys_for_cache, request.vlm_extra_key_token_start_for_cache, request.vlm_extra_key_ranges_for_cache, + prompt_token_count, ) logger.debug( f"Submitted async store_cache for {request_id} " diff --git a/tests/test_prefix_cache.py b/tests/test_prefix_cache.py index 2a77bdfa..7abf4120 100644 --- a/tests/test_prefix_cache.py +++ b/tests/test_prefix_cache.py @@ -3853,6 +3853,192 @@ def test_eager_check_survives_make_cache_failure( assert cache.get_stats().mru_partial_supported is None +def _store_seq(cache, mx, request_id, tokens, *, prompt_token_count=None): + """``store_cache`` a full token sequence; return the BlockTable. + + ``cache_data`` spans the whole sequence so ``_update_mru_partial`` + takes the global index path (the cache covers ``prompt + output``), + matching the production resubmission layout. + """ + cache_data = [_kv_layer(mx, len(tokens)) for _ in range(4)] + return cache.store_cache( + request_id, tokens, cache_data, prompt_token_count=prompt_token_count, + ) + + +class TestMRUPromptBoundaryStash: + """The MRU stash must key off the *prompt's* trailing partial, not the + stored sequence's. + + ``store_cache`` is handed ``prompt + output``, but a repeat request + resubmits the prompt only and ``apply_mru_partial`` looks the entry + up by the prompt's last full block. Before the prompt boundary was + threaded in, the stash keyed off ``prompt + output``'s last full + block — a key a prompt-only resubmit could never compute — so the + feature never produced a hit for ordinary chat completions. + + block_size is 4 in this fixture. + """ + + @pytest.fixture + def mx(self): + try: + import mlx.core as mx + return mx + except ImportError: + pytest.skip("MLX not available") + + @pytest.fixture + def paged_cache(self): + return PagedCacheManager( + block_size=4, + max_blocks=100, + model_name="test-model", + initial_blocks=100, + ) + + @pytest.fixture + def mock_ssd(self): + mock = MagicMock() + mock.save_block.return_value = True + mock.load_block.return_value = None + mock.load_block_with_metadata.return_value = (None, None) + mock.has_block.return_value = False + return mock + + def test_prompt_boundary_stash_hits_on_prompt_only_resubmit( + self, paged_cache, mock_ssd, mx + ): + """The decisive regression test: store ``prompt + output`` with the + prompt boundary, then resubmit the prompt only — apply must HIT. + """ + cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4) + # prompt = 1 full block [10..13] + 2-token tail [14,15]; output = 3. + prompt = [10, 11, 12, 13, 14, 15] + stored = prompt + [90, 91, 92] + store_bt = _store_seq( + cache, mx, "store", stored, prompt_token_count=len(prompt) + ) + + # Keyed by the prompt's last full block (block 0) — NOT the stored + # sequence's last full block (block 1) — and stashing the prompt + # tail [14,15], not the sequence tail [92]. + prompt_block = paged_cache.allocated_blocks[store_bt.block_ids[0]] + assert prompt_block.block_hash in cache._mru_partials + assert cache._mru_partials[prompt_block.block_hash].tokens == [14, 15] + + # Simulate fetch_cache(prompt): a block table with the prompt's + # full blocks only — apply_mru_partial keys off block_ids[-1]. + fetch_bt = BlockTable(request_id="resubmit") + fetch_bt.block_ids = [store_bt.block_ids[0]] + fetch_bt.num_tokens = 4 + reconstructed = _make_reconstructed_cache(mx, n_layers=4, n_tokens=4) + + _, new_remaining, applied = cache.apply_mru_partial( + reconstructed, fetch_bt, [14, 15] + ) + assert applied == 2 + assert new_remaining == [] + stats = cache.get_stats() + assert stats.mru_partial_hits == 1 + assert stats.mru_partial_tokens_saved == 2 + + def test_whole_sequence_stash_misses_on_prompt_only_resubmit( + self, paged_cache, mock_ssd, mx + ): + """Pins the original bug: with no prompt boundary + (``prompt_token_count=None``) the stash keys off the stored + sequence's last full block, which a prompt-only resubmit never + reaches — 0 hits, and 0 evictions because the lookup key is never + found (no entry to evict). + """ + cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4) + prompt = [10, 11, 12, 13, 14, 15] + stored = prompt + [90, 91, 92] + store_bt = _store_seq(cache, mx, "store", stored) # boundary unknown + + # None falls back to whole-sequence: keyed off block 1 (the stored + # sequence's last full block), unreachable by a prompt-only fetch. + seq_block = paged_cache.allocated_blocks[store_bt.block_ids[1]] + assert seq_block.block_hash in cache._mru_partials + + fetch_bt = BlockTable(request_id="resubmit") + fetch_bt.block_ids = [store_bt.block_ids[0]] + fetch_bt.num_tokens = 4 + reconstructed = _make_reconstructed_cache(mx, n_layers=4, n_tokens=4) + + _, _, applied = cache.apply_mru_partial(reconstructed, fetch_bt, [14, 15]) + assert applied == 0 + stats = cache.get_stats() + assert stats.mru_partial_hits == 0 + assert stats.mru_partial_evictions == 0 # key miss — nothing evicted + + def test_block_aligned_prompt_does_not_stash( + self, paged_cache, mock_ssd, mx + ): + """A prompt that is an exact multiple of block_size has no partial + tail — every prompt token lands in a full paged block, so the MRU + has nothing to add.""" + cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4) + prompt = [10, 11, 12, 13, 14, 15, 16, 17] # 8 tokens = 2 full blocks + stored = prompt + [90, 91, 92] + _store_seq(cache, mx, "store", stored, prompt_token_count=len(prompt)) + assert not cache._mru_partials + assert cache.get_stats().mru_partial_stashes == 0 + + def test_short_prompt_stashes_under_none_key( + self, paged_cache, mock_ssd, mx + ): + """A prompt shorter than one block has no last full block; the + stash is keyed by None (the short-prompt path) and holds the + whole prompt as its tail.""" + cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4) + prompt = [10, 11, 12] # 3 tokens < block_size 4 + stored = prompt + [90, 91, 92, 93, 94] + _store_seq(cache, mx, "store", stored, prompt_token_count=len(prompt)) + assert None in cache._mru_partials + assert cache._mru_partials[None].tokens == [10, 11, 12] + + def test_prompt_boundary_stash_with_existing_cached_prefix( + self, paged_cache, mock_ssd, mx + ): + """Resubmission path: store_cache runs with ``existing_tokens > 0`` + (the prompt's leading blocks are already cached) and works in + ``new_tokens`` space. The prompt-boundary arithmetic must still + resolve the prompt's last full block — not an index shifted by + ``existing_tokens``. + """ + cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4) + # prompt = 2 full blocks [10..17] + 2-token tail [18,19]; output = 3. + prompt = [10, 11, 12, 13, 14, 15, 16, 17, 18, 19] + stored = prompt + [90, 91, 92] + # Round 1: store the prompt's first full block only, same + # request_id, so round 2's store sees existing_tokens > 0. + _store_seq(cache, mx, "req", prompt[:8], prompt_token_count=8) + assert not cache._mru_partials # block-aligned round-1, no stash + # Round 2: same request_id — existing_tokens == 8. + store_bt = _store_seq( + cache, mx, "req", stored, prompt_token_count=len(prompt) + ) + + # Prompt's last full block is index 1 ([14..17]); the stash must + # key off it despite existing_tokens=8 and new_tokens-space math. + prompt_block = paged_cache.allocated_blocks[store_bt.block_ids[1]] + assert prompt_block.block_hash in cache._mru_partials + assert cache._mru_partials[prompt_block.block_hash].tokens == [18, 19] + + fetch_bt = BlockTable(request_id="resubmit") + fetch_bt.block_ids = store_bt.block_ids[:2] + fetch_bt.num_tokens = 8 + reconstructed = _make_reconstructed_cache(mx, n_layers=4, n_tokens=8) + _, new_remaining, applied = cache.apply_mru_partial( + reconstructed, fetch_bt, [18, 19] + ) + assert applied == 2 + assert new_remaining == [] + assert cache.get_stats().mru_partial_hits == 1 + + class TestHasMRUPartial: """The has_mru_partial() accessor is the public API the scheduler uses to decide whether to suppress the deferred Metal cache clear.""" From 88f809b79a483181bb64a4ec115eb5def373c17d Mon Sep 17 00:00:00 2001 From: Blightbow Date: Fri, 15 May 2026 01:18:29 -0400 Subject: [PATCH 17/18] fix(cache): eval MRU partial KV at stash time MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ``a42542f`` made ``apply_mru_partial`` produce hits for the first time — and immediately exposed a latent threading bug in the splice path, which had been dead code since the feature shipped. ``_extract_block_tensor_slice`` builds the partial's tensors via ``_clone_tensor`` (``mx.copy``), which is a *lazy* op. That op is created on the ``omlx-store-cache`` worker thread (where ``store_cache`` runs) and bound to that thread's MLX stream. ``apply_mru_partial`` splices the partial into a live cache on the separate ``mlx-global`` inference thread; generation's ``mx.async_eval`` then walks the compute graph back to the worker thread's stream, which the inference thread cannot see: RuntimeError: There is no Stream(gpu, 4) in current thread. The source ``_extracted_cache`` is already materialized before the worker runs (the inference thread batches ``_collect_arrays_from_extracted_cache`` through ``mx.async_eval`` and the worker calls ``mx.synchronize()``), so the fix is just to finalize the freshly-sliced copies: ``_update_mru_partial`` now calls ``_materialize_mru_kv`` on the extracted partial before stashing it. Because the inputs are already resident this is a small memcpy of the tail KV — no recompute — and it collapses the lazy ``mx.copy`` into concrete, stream-free data safe to splice and evaluate from any thread. ``apply_mru_partial`` itself is unchanged: ``add_request`` (the splice) and ``step`` (generation) both run on the single ``mlx-global`` worker, so the splice result never crosses threads — only the stashed input did. Tests ----- New ``TestMRUPartialCrossThreadSafety``: - ``materialize_mru_kv_handles_extract_shapes`` — the helper evaluates array leaves across the plain ``(keys, values)`` and TurboQuant ``(tag, (k, v))`` shapes and tolerates the non-array tag and an empty list. - ``stashed_partial_splices_across_threads`` — extract+stash on a worker thread, splice+evaluate on the main thread. Verified to fail without the fix (``no Stream(gpu, N) in current thread`` at the splice eval) and pass with it. The test pre-materializes ``cache_data`` to mirror production, where the inference thread always hands the worker an already-evaluated extracted cache. Co-Authored-By: Claude Opus 4.7 (1M context) --- omlx/cache/prefix_cache.py | 38 +++++++++++++ tests/test_prefix_cache.py | 107 +++++++++++++++++++++++++++++++++++++ 2 files changed, 145 insertions(+) diff --git a/omlx/cache/prefix_cache.py b/omlx/cache/prefix_cache.py index be6476bd..697c9a9e 100644 --- a/omlx/cache/prefix_cache.py +++ b/omlx/cache/prefix_cache.py @@ -1043,6 +1043,11 @@ def _update_mru_partial( ) if not partial_kv: return + # Materialize now, on this (store-cache worker) thread. The + # tensors are lazy ops bound to this thread's stream; the splice + # in apply_mru_partial runs on the separate inference thread and + # must not inherit a cross-thread stream dependency. + self._materialize_mru_kv(partial_kv) # Key by the PROMPT's last full block (block index # prompt_full_blocks - 1), so a prompt-only repeat — which is what @@ -1776,6 +1781,39 @@ def _clone_tensor(self, tensor: Any) -> Any: return mx.array(tensor) + def _materialize_mru_kv(self, partial_kv: list[Any]) -> None: + """Force-evaluate a freshly-extracted MRU partial's KV tensors. + + ``_extract_block_tensor_slice`` builds the tensors as lazy + ``mx.copy`` ops on whichever thread ``store_cache`` runs on — the + ``omlx-store-cache`` worker. ``apply_mru_partial`` later splices + them into a live cache on the separate ``mlx-global`` inference + thread; while the tensors are still lazy, that thread's + ``mx.async_eval`` would walk the compute graph back to a + per-thread stream the inference thread cannot see, and MLX raises + ``RuntimeError: There is no Stream(gpu, N) in current thread``. + Evaluating here, on the worker, leaves concrete stream-free data + safe to splice and evaluate from any thread. + + Walks the nested ``(keys, values)`` / TurboQuant ``(tag, (k, v))`` + shapes ``_extract_block_tensor_slice`` returns and evaluates every + ``mx.array`` leaf in one batched call. + """ + if not HAS_MLX: + return + leaves: list[Any] = [] + + def _collect(obj: Any) -> None: + if isinstance(obj, mx.array): + leaves.append(obj) + elif isinstance(obj, (list, tuple)): + for item in obj: + _collect(item) + + _collect(partial_kv) + if leaves: + mx.eval(*leaves) + def _apply_window_padding( self, matched_blocks: int, diff --git a/tests/test_prefix_cache.py b/tests/test_prefix_cache.py index 7abf4120..42d4e5d1 100644 --- a/tests/test_prefix_cache.py +++ b/tests/test_prefix_cache.py @@ -4039,6 +4039,113 @@ def test_prompt_boundary_stash_with_existing_cached_prefix( assert cache.get_stats().mru_partial_hits == 1 +class TestMRUPartialCrossThreadSafety: + """An MRU partial is extracted on the store-cache worker thread but + spliced into a live cache on the separate inference thread. + + ``_extract_block_tensor_slice`` builds the partial as lazy + ``mx.copy`` ops; an unevaluated tensor carries a pending op bound to + the worker's per-thread MLX stream, and evaluating the splice on the + inference thread raises ``RuntimeError: There is no Stream(gpu, N) + in current thread``. ``_update_mru_partial`` must materialize the + partial at stash time so the stashed data is concrete and + stream-free. + """ + + @pytest.fixture + def mx(self): + try: + import mlx.core as mx + return mx + except ImportError: + pytest.skip("MLX not available") + + @pytest.fixture + def paged_cache(self): + return PagedCacheManager( + block_size=4, + max_blocks=100, + model_name="test-model", + initial_blocks=100, + ) + + @pytest.fixture + def mock_ssd(self): + mock = MagicMock() + mock.save_block.return_value = True + mock.load_block.return_value = None + mock.load_block_with_metadata.return_value = (None, None) + mock.has_block.return_value = False + return mock + + def test_materialize_mru_kv_handles_extract_shapes( + self, paged_cache, mock_ssd, mx + ): + """``_materialize_mru_kv`` evaluates every ``mx.array`` leaf across + the plain ``(keys, values)`` and TurboQuant ``(tag, (k, v))`` + shapes ``_extract_block_tensor_slice`` returns, and tolerates the + non-array tag string and an empty list.""" + cache = _make_mru_cache(paged_cache, mock_ssd) + plain = [(mx.ones((1, 1, 2, 4)), mx.ones((1, 1, 2, 4)))] + tagged = [ + ("__turboquant_v2__", (mx.ones((1, 1, 2, 4)), mx.ones((1, 1, 2, 4)))) + ] + # None of these should raise. + cache._materialize_mru_kv(plain) + cache._materialize_mru_kv(tagged) + cache._materialize_mru_kv([]) + + def test_stashed_partial_splices_across_threads( + self, paged_cache, mock_ssd, mx + ): + """Extract+stash on a worker thread, splice+evaluate on the main + thread. Without stash-time materialization the final ``mx.eval`` + raises a foreign-stream ``RuntimeError``; with it, the cross- + thread handoff is clean. + """ + import concurrent.futures + + cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4) + prompt = [10, 11, 12, 13, 14, 15] + stored = prompt + [90, 91, 92] + cache_data = [_kv_layer(mx, len(stored)) for _ in range(4)] + # Mirror production: the inference thread materializes the + # extracted cache (mx.async_eval + the worker's mx.synchronize) + # before the store worker runs. Without this the cache_data + # arrays would still be lazy ops bound to THIS thread's stream — + # a different cross-thread failure than the one under test. + for layer in cache_data: + mx.eval(*layer["state"]) + + # Stash on a dedicated worker thread, mirroring the production + # _store_cache_executor (a pool distinct from the inference one). + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool: + store_bt = pool.submit( + cache.store_cache, + "store", + stored, + cache_data, + prompt_token_count=len(prompt), + ).result() + + # Splice on this (the "inference") thread. + fetch_bt = BlockTable(request_id="resubmit") + fetch_bt.block_ids = [store_bt.block_ids[0]] + fetch_bt.num_tokens = 4 + reconstructed = _make_reconstructed_cache(mx, n_layers=4, n_tokens=4) + + spliced, new_remaining, applied = cache.apply_mru_partial( + reconstructed, fetch_bt, [14, 15] + ) + assert applied == 2 + assert new_remaining == [] + + # Force evaluation on this thread — the point a partial still + # carrying the worker thread's stream would fail. + for layer in spliced: + mx.eval(layer.keys, layer.values) + + class TestHasMRUPartial: """The has_mru_partial() accessor is the public API the scheduler uses to decide whether to suppress the deferred Metal cache clear.""" From b04b18a540aed308d76b8fdfd76dc639b820b8bf Mon Sep 17 00:00:00 2001 From: Blightbow Date: Fri, 15 May 2026 14:09:13 -0400 Subject: [PATCH 18/18] refactor(admin): drop global MRU-tails gauge MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The header gauge in RUNTIME CACHE OBSERVABILITY sat alongside the Memory and SSD gauges, but it did not belong there. Memory and SSD each measure one exhaustible budget shared across every loaded model, so an aggregate fill bar is meaningful. MRU tail slots are allocated per-model — ``--mru-partial-max-entries`` applies to each model's own cache — so summing entries and max-entries across models produces a number that corresponds to no real resource. Per-model occupancy is already shown in the "MRU Tails" column of the per-model table, which is the correct granularity. - Remove the header gauge block from ``_status.html``. - Remove the now-dead ``runtimeMruPartialPercent`` getter from ``dashboard.js``. - Drop the payload-level ``mru_partial_entries`` aggregate in ``_build_runtime_cache_observability``. ``mru_partial_max_ entries`` is kept as a sum solely as the ``mruEnabled`` feature-on gate (drives the rate strip and the per-model column); it is no longer surfaced as a gauge value. Per-model ``mru_partial_entries`` / ``mru_partial_max_entries`` on each ``models[]`` entry are unchanged, as are the global "MRU Tail Hit Rate" and "MRU Tokens Saved" rate-strip cells (those are rates/counters, legitimately global). Co-Authored-By: Claude Opus 4.7 (1M context) --- omlx/admin/routes.py | 19 +++++++++---------- omlx/admin/static/js/dashboard.js | 11 +++++------ omlx/admin/templates/dashboard/_status.html | 15 ++++----------- 3 files changed, 18 insertions(+), 27 deletions(-) diff --git a/omlx/admin/routes.py b/omlx/admin/routes.py index 0393c203..ad11bb03 100644 --- a/omlx/admin/routes.py +++ b/omlx/admin/routes.py @@ -3502,10 +3502,11 @@ def _build_runtime_cache_observability( "hot_cache_max_bytes": 0, "hot_cache_size_bytes": 0, "hot_cache_entries": 0, - # MRU partial cache (memory-only, sub-block tail of prior prefills). - # Capacity sums across models; current entries sum across models; - # max-capacity is the highest configured per-model value. - "mru_partial_entries": 0, + # MRU partial cache feature gate. Per-model occupancy lives on each + # models[] entry; there is no payload-level entries aggregate + # because MRU tail slots are per-model, not a shared budget — only + # this max-entries sum is kept, purely so the dashboard can tell + # whether the feature is configured for any loaded model. "mru_partial_max_entries": 0, } @@ -3659,23 +3660,21 @@ def _build_runtime_cache_observability( disk_max = payload["disk_max_bytes"] hot_cache_size_total = 0 hot_cache_entries_total = 0 - mru_entries_total = 0 mru_max_entries_total = 0 for m in payload["models"]: hot_cache_size_total += m.get("hot_cache_size_bytes", 0) hot_cache_entries_total += m.get("hot_cache_entries", 0) hot_cache_max += m.get("hot_cache_max_bytes", 0) disk_max = max(disk_max, m.get("max_size_bytes", 0)) - # MRU partials: each model has its own dict; entries sum across - # models the same way the hot cache does. max_entries sums for - # the same reason (each model reserves its own slice of capacity). - mru_entries_total += m.get("mru_partial_entries", 0) + # MRU: only the max-entries sum is kept, and only as a feature-on + # gate for the dashboard. Per-model occupancy is on each models[] + # entry; an aggregate live count would be meaningless because the + # slots are per-model, not a shared budget. mru_max_entries_total += m.get("mru_partial_max_entries", 0) payload["hot_cache_max_bytes"] = hot_cache_max payload["hot_cache_size_bytes"] = hot_cache_size_total payload["hot_cache_entries"] = hot_cache_entries_total payload["disk_max_bytes"] = disk_max - payload["mru_partial_entries"] = mru_entries_total payload["mru_partial_max_entries"] = mru_max_entries_total # Fallback: if no loaded models contributed stats, scan the cache diff --git a/omlx/admin/static/js/dashboard.js b/omlx/admin/static/js/dashboard.js index 64903ae5..cb3acbf3 100644 --- a/omlx/admin/static/js/dashboard.js +++ b/omlx/admin/static/js/dashboard.js @@ -2296,12 +2296,11 @@ return Math.min(100, (rc.hot_cache_size_bytes / rc.hot_cache_max_bytes) * 100); }, - get runtimeMruPartialPercent() { - const rc = this.stats.runtime_cache; - if (!rc || !rc.mru_partial_max_entries) return 0; - return Math.min(100, (rc.mru_partial_entries / rc.mru_partial_max_entries) * 100); - }, - + // mruEnabled is a feature-on gate (drives the rate strip and the + // per-model MRU Tails column). It reads the payload-level + // mru_partial_max_entries purely as "configured for any loaded + // model" — there is deliberately no aggregate MRU-tails gauge, + // since the slots are per-model, not a shared budget. get mruEnabled() { return (this.stats.runtime_cache?.mru_partial_max_entries || 0) > 0; }, diff --git a/omlx/admin/templates/dashboard/_status.html b/omlx/admin/templates/dashboard/_status.html index aa9b6f9c..f1284044 100644 --- a/omlx/admin/templates/dashboard/_status.html +++ b/omlx/admin/templates/dashboard/_status.html @@ -313,17 +313,10 @@

{{ t('status.head | - -
- MRU tails -
-
-
- -
- | +
SSD

Block Size Indexed Blocks Sub-block CacheCache FilesCache SizeSSD FilesSSD SizeMemory EntriesMemory Size
SSD Size Memory Entries Memory SizeMRU Tails
SSD Size Memory Entries Memory SizeMRU TailsMRU Tails
+ N/A (see log) + +