diff --git a/omlx/admin/routes.py b/omlx/admin/routes.py index 06ceb808b..ad11bb030 100644 --- a/omlx/admin/routes.py +++ b/omlx/admin/routes.py @@ -3483,6 +3483,13 @@ def _build_runtime_cache_observability( } cache_dir = global_settings.cache.get_ssd_cache_dir(global_settings.base_path) + cache_cfg = global_settings.cache + try: + cfg_disk_max = cache_cfg.get_ssd_cache_max_size_bytes(global_settings.base_path) + except (ValueError, OSError, TypeError) as exc: + logger.warning("Could not read SSD cache max size from config: %s", exc) + cfg_disk_max = 0 + payload = { "base_path": str(global_settings.base_path), "ssd_cache_dir": str(cache_dir), @@ -3491,6 +3498,16 @@ def _build_runtime_cache_observability( "total_num_files": 0, "total_size_bytes": 0, "effective_block_sizes": [], + "disk_max_bytes": cfg_disk_max, + "hot_cache_max_bytes": 0, + "hot_cache_size_bytes": 0, + "hot_cache_entries": 0, + # MRU partial cache feature gate. Per-model occupancy lives on each + # models[] entry; there is no payload-level entries aggregate + # because MRU tail slots are per-model, not a shared budget — only + # this max-entries sum is kept, purely so the dashboard can tell + # whether the feature is configured for any loaded model. 
+ "mru_partial_max_entries": 0, } engine_pool = _get_engine_pool() @@ -3602,11 +3619,28 @@ def _build_runtime_cache_observability( "last_tokens_to_next_block": last_tokens_to_next_block, "num_files": int(ssd_stats.get("num_files", 0) or 0), "total_size_bytes": int(ssd_stats.get("total_size_bytes", 0) or 0), + "max_size_bytes": int(ssd_stats.get("max_size_bytes", 0) or 0), "hot_cache_max_bytes": int(ssd_stats.get("hot_cache_max_bytes", 0) or 0), "hot_cache_size_bytes": int(ssd_stats.get("hot_cache_size_bytes", 0) or 0), "hot_cache_entries": int(ssd_stats.get("hot_cache_entries", 0) or 0), + "mru_partial_entries": int( + prefix_stats.get("mru_partial_entries", 0) or 0 + ), + "mru_partial_max_entries": int( + prefix_stats.get("mru_partial_max_entries", 0) or 0 + ), + # Tri-state: None (unknown / no inference yet), True (eligible), + # False (model uses non-sliceable cache layers — every stash + # refused at the safety gate; dashboard renders 'N/A (see log)'). + "mru_partial_supported": prefix_stats.get( + "mru_partial_supported", None + ), } + cache_rates = runtime_stats.get("cache_rates") + if cache_rates: + model_payload["cache_rates"] = cache_rates + payload["models"].append(model_payload) payload["total_num_files"] += model_payload["num_files"] payload["total_size_bytes"] += model_payload["total_size_bytes"] @@ -3616,6 +3650,33 @@ def _build_runtime_cache_observability( payload["effective_block_sizes"] = sorted(block_sizes) + # Aggregate hot-cache and disk-max across models. + # hot_cache_max sums across models (each model reserves its own slice of + # the same process-wide hot cache budget) so the gauge denominator matches + # the summed numerator. disk_max keeps the config fallback via max() + # because a single SSD cache directory is shared — the effective cap is + # the largest configured limit, not a per-model sum. 
+ hot_cache_max = 0 + disk_max = payload["disk_max_bytes"] + hot_cache_size_total = 0 + hot_cache_entries_total = 0 + mru_max_entries_total = 0 + for m in payload["models"]: + hot_cache_size_total += m.get("hot_cache_size_bytes", 0) + hot_cache_entries_total += m.get("hot_cache_entries", 0) + hot_cache_max += m.get("hot_cache_max_bytes", 0) + disk_max = max(disk_max, m.get("max_size_bytes", 0)) + # MRU: only the max-entries sum is kept, and only as a feature-on + # gate for the dashboard. Per-model occupancy is on each models[] + # entry; an aggregate live count would be meaningless because the + # slots are per-model, not a shared budget. + mru_max_entries_total += m.get("mru_partial_max_entries", 0) + payload["hot_cache_max_bytes"] = hot_cache_max + payload["hot_cache_size_bytes"] = hot_cache_size_total + payload["hot_cache_entries"] = hot_cache_entries_total + payload["disk_max_bytes"] = disk_max + payload["mru_partial_max_entries"] = mru_max_entries_total + # Fallback: if no loaded models contributed stats, scan the cache # directory directly so the dashboard still shows real disk usage. if payload["total_num_files"] == 0 and cache_dir.exists(): @@ -3870,6 +3931,30 @@ async def clear_alltime_stats(is_admin: bool = Depends(require_admin)): return {"status": "ok"} +def _iter_loaded_schedulers(): + """Yield (model_id, scheduler) for each loaded model. + + Traverses the internal engine hierarchy: pool entry → async engine → + core engine → scheduler. Both ``clear_ssd_cache`` and + ``clear_hot_cache`` share this traversal. 
+ """ + engine_pool = _get_engine_pool() + if engine_pool is None: + return + for model_info in engine_pool.get_status().get("models", []): + model_id = model_info.get("id") + if not model_id or not model_info.get("loaded"): + continue + entry = engine_pool._entries.get(model_id) + if entry is None or entry.engine is None: + continue + async_core = getattr(entry.engine, "_engine", None) + core = getattr(async_core, "engine", None) if async_core is not None else None + scheduler = getattr(core, "scheduler", None) if core is not None else None + if scheduler is not None: + yield model_id, scheduler + + @router.post("/api/ssd-cache/clear") async def clear_ssd_cache(is_admin: bool = Depends(require_admin)): """Clear all SSD cache files for all loaded models. @@ -3880,38 +3965,33 @@ async def clear_ssd_cache(is_admin: bool = Depends(require_admin)): """ total_deleted = 0 - # Phase 1: clear via loaded models' cache managers (updates in-memory index) - engine_pool = _get_engine_pool() - if engine_pool is not None: - for model_info in engine_pool.get_status().get("models", []): - model_id = model_info.get("id") - if not model_id or not model_info.get("loaded"): - continue - - entry = engine_pool._entries.get(model_id) - if entry is None or entry.engine is None: - continue - - async_core = getattr(entry.engine, "_engine", None) - core = ( - getattr(async_core, "engine", None) if async_core is not None else None - ) - scheduler = ( - getattr(core, "scheduler", None) if core is not None else None - ) + for model_id, scheduler in _iter_loaded_schedulers(): + ssd_manager = getattr(scheduler, "paged_ssd_cache_manager", None) + if ssd_manager is not None: + try: + total_deleted += ssd_manager.clear() + except Exception as exc: + logger.warning( + "Failed to clear SSD cache for model '%s': %s", + model_id, + exc, + ) - if scheduler is not None: - ssd_manager = getattr(scheduler, "paged_ssd_cache_manager", None) - if ssd_manager is not None: - try: - deleted = ssd_manager.clear() - 
total_deleted += deleted - except Exception as exc: - logger.warning( - "Failed to clear SSD cache for model '%s': %s", - model_id, - exc, - ) + # MRU partials chain from paged-block hashes whose KV bytes are + # gone after the ssd_manager.clear() above. Drop them so the + # admin "clear all warm caches" intent is honoured symmetrically. + # Single-tier behaviour (no clear) is the surviving-stash hazard + # the peer review caught for this endpoint. + block_aware_cache = getattr(scheduler, "block_aware_cache", None) + if block_aware_cache is not None: + try: + block_aware_cache.clear_mru_partials() + except Exception as exc: + logger.warning( + "Failed to clear MRU partials for model '%s': %s", + model_id, + exc, + ) # Phase 2: remove any remaining files on disk (covers unloaded models) global_settings = _get_global_settings() @@ -3937,6 +4017,31 @@ async def clear_ssd_cache(is_admin: bool = Depends(require_admin)): return {"status": "ok", "total_deleted": total_deleted} +@router.post("/api/hot-cache/clear") +async def clear_hot_cache(is_admin: bool = Depends(require_admin)): + """Clear the in-memory (hot) cache for all loaded models. + + No filesystem fallback needed — hot cache is in-memory only and does + not survive process restart. 
+ """ + total_cleared = 0 + for model_id, scheduler in _iter_loaded_schedulers(): + ssd_manager = getattr(scheduler, "paged_ssd_cache_manager", None) + if ssd_manager is not None and hasattr(ssd_manager, "clear_hot_cache"): + try: + total_cleared += ssd_manager.clear_hot_cache() + except Exception as exc: + logger.warning( + "Failed to clear hot cache for model '%s': %s", + model_id, + exc, + ) + rate_tracker = getattr(scheduler, "_cache_rate_tracker", None) + if rate_tracker is not None: + rate_tracker.clear() + return {"status": "ok", "total_cleared": total_cleared} + + @router.post("/api/cache/probe") async def probe_cache( request: CacheProbeRequest, diff --git a/omlx/admin/static/css/dashboard.css b/omlx/admin/static/css/dashboard.css index a95382df8..6af38022b 100644 --- a/omlx/admin/static/css/dashboard.css +++ b/omlx/admin/static/css/dashboard.css @@ -63,6 +63,10 @@ [data-theme="dark"] .hover\:text-neutral-700:hover { color: var(--text-primary) !important; } [data-theme="dark"] .hover\:text-neutral-600:hover { color: var(--text-secondary) !important; } + /* === Gauge track (visible in both themes) === */ + .gauge-track { background-color: #e5e5e5; } + [data-theme="dark"] .gauge-track { background-color: #3f3f46 !important; } + /* === Active nav tab (bg-white with shadow inside dark nav) === */ [data-theme="dark"] .shadow-sm { box-shadow: 0 1px 2px 0 rgba(0, 0, 0, 0.3) !important; } diff --git a/omlx/admin/static/js/dashboard.js b/omlx/admin/static/js/dashboard.js index ec8ddbf90..cb3acbf3f 100644 --- a/omlx/admin/static/js/dashboard.js +++ b/omlx/admin/static/js/dashboard.js @@ -160,6 +160,10 @@ total_num_files: 0, total_size_bytes: 0, effective_block_sizes: [], + hot_cache_size_bytes: 0, + hot_cache_entries: 0, + hot_cache_max_bytes: 0, + disk_max_bytes: 0, }, }, alltimeStats: { @@ -190,6 +194,7 @@ showClearStatsConfirm: false, showClearAlltimeConfirm: false, showClearSsdCacheConfirm: false, + showClearHotCacheConfirm: false, _statsRefreshTimer: null, // 
Log viewer state @@ -2149,7 +2154,8 @@ async clearSsdCache() { try { - await fetch('/admin/api/ssd-cache/clear', { method: 'POST' }); + const resp = await fetch('/admin/api/ssd-cache/clear', { method: 'POST' }); + if (!resp.ok) console.error('SSD cache clear failed:', resp.status); this.showClearSsdCacheConfirm = false; await this.loadStats(); } catch (err) { @@ -2158,6 +2164,18 @@ } }, + async clearHotCache() { + try { + const resp = await fetch('/admin/api/hot-cache/clear', { method: 'POST' }); + if (!resp.ok) console.error('Hot cache clear failed:', resp.status); + this.showClearHotCacheConfirm = false; + await this.loadStats(); + } catch (err) { + console.error('Failed to clear hot cache:', err); + this.showClearHotCacheConfirm = false; + } + }, + startStatsRefresh() { this.stopStatsRefresh(); this._statsRefreshTimer = setInterval(() => { @@ -2178,6 +2196,39 @@ return num.toLocaleString(); }, + cacheObsCumulative(stats, selectedModel) { + const entries = stats.runtime_cache?.models || []; + if (entries.length === 0) return {}; + + if (selectedModel) { + const entry = entries.find(m => m.id === selectedModel); + return entry?.cache_rates?.cumulative || {}; + } + + const sumKeys = ['prefix_hits', 'prefix_misses', 'evictions', 'ssd_hot_hits', 'ssd_disk_loads', 'ssd_saves', 'hot_cache_evictions', 'hot_cache_promotions', 'mru_partial_stashes', 'mru_partial_hits', 'mru_partial_evictions', 'mru_partial_tokens_saved']; + let agg = {}; + + for (const m of entries) { + const c = m.cache_rates?.cumulative; + if (!c || Object.keys(c).length === 0) continue; + for (const k of sumKeys) { + agg[k] = (agg[k] || 0) + (c[k] || 0); + } + } + + const ph = agg.prefix_hits || 0; + const pm = agg.prefix_misses || 0; + const sh = agg.ssd_hot_hits || 0; + const sd = agg.ssd_disk_loads || 0; + const ms = agg.mru_partial_stashes || 0; + const mh = agg.mru_partial_hits || 0; + agg.prefix_hit_rate = (ph + pm) > 0 ? ph / (ph + pm) : 0; + agg.ssd_hot_rate = (sh + sd) > 0 ? 
sh / (sh + sd) : 0; + agg.mru_partial_hit_rate = ms > 0 ? mh / ms : 0; + + return agg; + }, + getStatFontClass(value) { if (value >= 1000000000) return 'text-2xl'; if (value >= 1000000) return 'text-3xl'; @@ -2239,6 +2290,38 @@ return 'bg-red-400'; }, + get runtimeHotCachePercent() { + const rc = this.stats.runtime_cache; + if (!rc || !rc.hot_cache_max_bytes) return 0; + return Math.min(100, (rc.hot_cache_size_bytes / rc.hot_cache_max_bytes) * 100); + }, + + // mruEnabled is a feature-on gate (drives the rate strip and the + // per-model MRU Tails column). It reads the payload-level + // mru_partial_max_entries purely as "configured for any loaded + // model" — there is deliberately no aggregate MRU-tails gauge, + // since the slots are per-model, not a shared budget. + get mruEnabled() { + return (this.stats.runtime_cache?.mru_partial_max_entries || 0) > 0; + }, + + get hotCacheEnabled() { + return (this.stats.runtime_cache?.hot_cache_max_bytes || 0) > 0; + }, + + get cacheRatesGridCols() { + const both = this.hotCacheEnabled && this.mruEnabled; + if (both) return 'grid-cols-2 sm:grid-cols-6'; + if (this.hotCacheEnabled || this.mruEnabled) return 'grid-cols-2 sm:grid-cols-4'; + return 'grid-cols-2'; + }, + + get runtimeSsdCachePercent() { + const rc = this.stats.runtime_cache; + if (!rc || !rc.disk_max_bytes) return 0; + return Math.min(100, (rc.total_size_bytes / rc.disk_max_bytes) * 100); + }, + get activeModelsMemoryPercent() { const am = this.stats.active_models; if (!am || !am.model_memory_max) return 0; diff --git a/omlx/admin/templates/dashboard/_status.html b/omlx/admin/templates/dashboard/_status.html index 1d6e04f40..f1284044c 100644 --- a/omlx/admin/templates/dashboard/_status.html +++ b/omlx/admin/templates/dashboard/_status.html @@ -282,8 +282,51 @@

{{ t('status.head Runtime Cache Observability
- + +
+ Memory +
+
+
+ +
+ +
+ Clear memory cache? + + +
+
+
+ + | + + +
+ SSD +
+
+
+ +
+ +
+
+

Prefix Hit Rate

+

+
+
+

Memory Hit Rate

+

+
+
+

MRU Tail Hit Rate

+

+
+
+

Prefix Evictions

+

+
+
+

Memory Evictions

+

+
+
+

MRU Tokens Saved

+

+
+
+
@@ -336,8 +416,11 @@

{{ t('status.head

- - + + + + + @@ -358,6 +441,15 @@

{{ t('status.head

+ + + diff --git a/omlx/cache/observability.py b/omlx/cache/observability.py new file mode 100644 index 000000000..c53c8a2d2 --- /dev/null +++ b/omlx/cache/observability.py @@ -0,0 +1,169 @@ +# SPDX-License-Identifier: Apache-2.0 +import threading +import time +from collections import deque +from typing import Any + + +_DEFAULT_WINDOWS = (60, 300, 900) +_MAX_SNAPSHOTS = 90 +_MIN_INTERVAL = 10.0 + + +class CacheRateTracker: + + def __init__( + self, + max_snapshots: int = _MAX_SNAPSHOTS, + min_interval: float = _MIN_INTERVAL, + ): + self._snapshots: deque[tuple[float, dict[str, int]]] = deque( + maxlen=max_snapshots + ) + self._min_interval = min_interval + self._lock = threading.Lock() + + def maybe_snapshot(self, counters: dict[str, int]) -> bool: + with self._lock: + now = time.monotonic() + if self._snapshots and (now - self._snapshots[-1][0]) < self._min_interval: + return False + self._snapshots.append((now, dict(counters))) + return True + + def get_rates( + self, windows: tuple[int, ...] = _DEFAULT_WINDOWS + ) -> dict[str, Any]: + with self._lock: + if not self._snapshots: + return {"windows": {}, "cumulative": {}} + + now = self._snapshots[-1][0] + newest = self._snapshots[-1][1] + + window_rates = {} + for w in windows: + label = _window_label(w) + baseline_ts = None + baseline_counters = None + for ts, counters in self._snapshots: + if (now - ts) <= w: + baseline_ts, baseline_counters = ts, counters + break + if baseline_ts is None: + baseline_ts, baseline_counters = self._snapshots[0] + elapsed = now - baseline_ts + if elapsed < 1.0: + window_rates[label] = {} + continue + window_rates[label] = _compute_window( + baseline_counters, newest, elapsed + ) + + cumulative = _compute_cumulative(newest) + return {"windows": window_rates, "cumulative": cumulative} + + def snapshot_and_get_rates( + self, + counters: dict[str, int], + windows: tuple[int, ...] 
= _DEFAULT_WINDOWS, + ) -> dict[str, Any]: + self.maybe_snapshot(counters) + return self.get_rates(windows) + + def clear(self) -> None: + with self._lock: + self._snapshots.clear() + + +def _window_label(seconds: int) -> str: + if seconds < 60: + return f"{seconds}s" + return f"{seconds // 60}m" + + +def _safe_ratio(numerator: int, denominator: int) -> float: + if denominator == 0: + return 0.0 + return numerator / denominator + + +def _compute_window( + old: dict[str, int], new: dict[str, int], elapsed: float +) -> dict[str, Any]: + def delta(key: str) -> int: + return max(0, new.get(key, 0) - old.get(key, 0)) + + d_prefix_hits = delta("prefix_hits") + d_prefix_misses = delta("prefix_misses") + d_evictions = delta("evictions") + d_ssd_hot = delta("ssd_hot_hits") + d_ssd_disk = delta("ssd_disk_loads") + d_tokens_matched = delta("prefix_tokens_matched") + d_tokens_requested = delta("prefix_tokens_requested") + d_mru_stashes = delta("mru_partial_stashes") + d_mru_hits = delta("mru_partial_hits") + d_mru_evictions = delta("mru_partial_evictions") + d_mru_tokens_saved = delta("mru_partial_tokens_saved") + + minutes = elapsed / 60.0 + + return { + "prefix_hit_rate": round( + _safe_ratio(d_prefix_hits, d_prefix_hits + d_prefix_misses), 4 + ), + "prefix_hits": d_prefix_hits, + "prefix_misses": d_prefix_misses, + "prefix_match_efficiency": round( + _safe_ratio(d_tokens_matched, d_tokens_requested), 4 + ), + "evictions": d_evictions, + "eviction_rate_per_min": round(d_evictions / minutes, 2) if minutes > 0 else 0.0, + "ssd_hot_hits": d_ssd_hot, + "ssd_disk_loads": d_ssd_disk, + "ssd_hot_rate": round( + _safe_ratio(d_ssd_hot, d_ssd_hot + d_ssd_disk), 4 + ), + # MRU partial stash payoff: fraction of stashes that paid off as + # an apply-time splice. Workload with rate ≈ 0 means stashes + # are mostly wasted (uniform/unrepeated prompts); rate near 1 + # means almost every stash got reused. 
+ "mru_partial_stashes": d_mru_stashes, + "mru_partial_hits": d_mru_hits, + "mru_partial_evictions": d_mru_evictions, + "mru_partial_tokens_saved": d_mru_tokens_saved, + "mru_partial_hit_rate": round(_safe_ratio(d_mru_hits, d_mru_stashes), 4), + } + + +def _compute_cumulative(counters: dict[str, int]) -> dict[str, Any]: + prefix_hits = counters.get("prefix_hits", 0) + prefix_misses = counters.get("prefix_misses", 0) + ssd_hot = counters.get("ssd_hot_hits", 0) + ssd_disk = counters.get("ssd_disk_loads", 0) + tokens_matched = counters.get("prefix_tokens_matched", 0) + tokens_requested = counters.get("prefix_tokens_requested", 0) + mru_stashes = counters.get("mru_partial_stashes", 0) + mru_hits = counters.get("mru_partial_hits", 0) + + return { + "prefix_hits": prefix_hits, + "prefix_misses": prefix_misses, + "prefix_hit_rate": round(_safe_ratio(prefix_hits, prefix_hits + prefix_misses), 4), + "prefix_tokens_saved": counters.get("prefix_tokens_saved", 0), + "prefix_match_efficiency": round( + _safe_ratio(tokens_matched, tokens_requested), 4 + ), + "evictions": counters.get("evictions", 0), + "ssd_hot_hits": ssd_hot, + "ssd_disk_loads": ssd_disk, + "ssd_saves": counters.get("ssd_saves", 0), + "hot_cache_evictions": counters.get("hot_cache_evictions", 0), + "hot_cache_promotions": counters.get("hot_cache_promotions", 0), + "ssd_hot_rate": round(_safe_ratio(ssd_hot, ssd_hot + ssd_disk), 4), + "mru_partial_stashes": mru_stashes, + "mru_partial_hits": mru_hits, + "mru_partial_evictions": counters.get("mru_partial_evictions", 0), + "mru_partial_tokens_saved": counters.get("mru_partial_tokens_saved", 0), + "mru_partial_hit_rate": round(_safe_ratio(mru_hits, mru_stashes), 4), + } diff --git a/omlx/cache/paged_ssd_cache.py b/omlx/cache/paged_ssd_cache.py index 7d5c0d6c7..be52bc8e5 100644 --- a/omlx/cache/paged_ssd_cache.py +++ b/omlx/cache/paged_ssd_cache.py @@ -2035,6 +2035,20 @@ def enforce_size_limit(self) -> int: ) return freed + def clear_hot_cache(self) -> int: + 
"""Clear all in-memory (hot) cache entries. + + Returns: + Number of entries cleared. + """ + with self._hot_cache_lock: + count = len(self._hot_cache) + self._hot_cache.clear() + self._hot_cache_total_bytes = 0 + if count: + logger.info("Cleared %d hot cache entries", count) + return count + def clear(self) -> int: """ Clear all SSD cache files. diff --git a/omlx/cache/prefix_cache.py b/omlx/cache/prefix_cache.py index 4f2bd1d32..697c9a9e9 100644 --- a/omlx/cache/prefix_cache.py +++ b/omlx/cache/prefix_cache.py @@ -9,6 +9,7 @@ import logging import math import time +from collections import OrderedDict from collections.abc import Callable from dataclasses import dataclass from typing import Any @@ -31,7 +32,7 @@ ) from .paged_ssd_cache import PagedSSDCacheManager from .stats import PrefixCacheStats -from .type_registry import CacheTypeRegistry +from .type_registry import KNOWN_SLICEABLE_CACHE_TYPES, CacheTypeRegistry logger = logging.getLogger(__name__) @@ -44,6 +45,59 @@ class BlockCacheEntry: last_access: float +@dataclass +class _MRUPartialBlock: + """One entry in the bounded LRU stash for trailing sub-block partials. + + The paged SSD cache only persists full ``block_size`` blocks; trailing + sub-block tokens (e.g. 139 out of 256) are otherwise discarded and + re-prefilled on every repeat request. The MRU stash keeps the partial + in memory so an immediate repeat skips re-prefilling those tail tokens. + + Entries are stored in ``BlockAwarePrefixCache._mru_partials`` as an + ``OrderedDict`` keyed by ``parent_hash``. New entries land at the tail; + LRU eviction pops the oldest when the dict exceeds + ``mru_partial_max_entries``. A successful apply promotes the matched + entry to the tail (move_to_end). Apply-time miss for a found entry + pops only that key, leaving siblings intact. Same-key replacement on + stash is correct LRU put behavior. 
+ + Stash and apply are gated on **uniform layer sliceability**: if any + layer in the model is non-sliceable (RotatingKVCache, ArraysCache, + etc.) no entries are ever written. Splicing into only the sliceable + layers would create per-layer offset skew at decode time. + + Memory accounting + ----------------- + ``kv_data`` holds real ``mx.array`` allocations (produced by + ``_clone_tensor`` → ``mx.copy``), so each entry's memory cost is + automatically counted by ``mx.get_active_memory()`` and is therefore + visible to every runtime memory enforcement and telemetry path in + this codebase (process enforcer, scheduler limit checks, + periodic-clear threshold, prefill pre-flight peak check). + + Upstream of those, the engine pool reserves a fraction of each + model's weight size as KV headroom before admission. MRU partials + are one tenant of that headroom alongside in-flight prompt caches; + they are not separately reserved because each entry is bounded at + one ``block_size``-worth of KV and the default cap is small. Under + ``hot_cache_only=True``, the hot cache and the MRU dict share that + envelope — tune ``--mru-partial-max-entries`` and + ``--hot-cache-max-size`` together in that mode rather than treating + them as independent dials. + + Future maintainers: do **not** add a separate accounting hook for + these entries. The invariant that ``kv_data`` holds ``mx.array`` + instances (not numpy/CPU copies) is what makes the implicit + accounting work; ``test_kv_data_holds_mlx_arrays_for_active_memory_accounting`` + pins it. + """ + + parent_hash: bytes | None + tokens: list[int] + kv_data: list[tuple[Any, Any]] + + class BlockAwarePrefixCache(CacheManager): """ Prefix cache that uses PagedCacheManager for block-based storage. @@ -82,6 +136,7 @@ def __init__( model: Any, paged_cache_manager: PagedCacheManager, paged_ssd_cache_manager: PagedSSDCacheManager | None = None, + mru_partial_max_entries: int = 4, ): """ Initialize block-aware prefix cache. 
@@ -90,6 +145,12 @@ def __init__( model: The MLX model (used for identification) paged_cache_manager: The PagedCacheManager instance for block management paged_ssd_cache_manager: The PagedSSDCacheManager for SSD storage (required for paged SSD-only mode) + mru_partial_max_entries: Maximum simultaneous MRU partial-block + stashes. Each entry is bounded at one ``block_size`` of KV + memory. ``0`` disables the MRU feature entirely (silent + fallback to "no MRU" behavior, mirroring + ``hot_cache_max_size="0"`` convention). Default of 4 matches + the dflash ``max_entries`` precedent (PR #1120). """ self.model = model self.model_key = id(model) @@ -111,14 +172,121 @@ def __init__( # Kept for API compatibility self._cold_restore_callback: Callable[[int, bytes], bool] | None = None + # Bounded LRU map for trailing sub-block tails (see _MRUPartialBlock). + # Keyed by parent_hash (block_hash of the last full block in the + # prefix, or None for short prompts whose prefix is < block_size). + # Populated by store_cache via _update_mru_partial, consumed by + # apply_mru_partial. Mirrors the OrderedDict + LRU pattern from + # PagedSSDCacheManager._hot_cache. Empty dict for hybrid models. + self._mru_partials: OrderedDict[ + bytes | None, _MRUPartialBlock + ] = OrderedDict() + self._mru_partial_max_entries: int = mru_partial_max_entries + # Statistics self._hits = 0 self._misses = 0 self._tokens_saved = 0 self._partial_block_skips = 0 self._partial_tokens_skipped = 0 + self._tokens_matched_total = 0 + self._tokens_requested_total = 0 self._last_partial_tokens_skipped = 0 self._last_tokens_to_next_block = 0 + # MRU partial cache cumulative counters (see PrefixCacheStats). + self._mru_partial_stashes = 0 + self._mru_partial_hits = 0 + self._mru_partial_evictions = 0 + self._mru_partial_tokens_saved = 0 + + # Tri-state MRU eligibility flag (see PrefixCacheStats.mru_partial_supported). 
+ # Latched True once a sliceable layer set is observed; latched False + # the first time a non-sliceable set is observed (eager check below + # or lazy fallback inside _update_mru_partial). + self._mru_partial_supported: bool | None = None + self._mru_partial_warn_emitted: bool = False + if self._mru_partial_max_entries > 0: + self._check_mru_eligibility_at_init(model) + + def _check_mru_eligibility_at_init(self, model: Any) -> None: + """Best-effort load-time check: does this model have sliceable layers? + + The MRU partial-block stash safety gate rejects any layer set that + contains a non-sliceable cache type (``RotatingKVCache``, + ``PoolingCache``, ``ArraysCache``, ``CacheList``, etc.) — splicing + a partial into a sliceable subset only would cause per-layer offset + skew at decode (silent generation corruption). For affected + models the entire MRU feature is structurally unavailable, and the + dashboard would otherwise show a confusing "0/N entries" gauge + forever. Warn the operator at load time so they can grep this + exact line for the offending types. + + Best-effort: if ``model.make_cache()`` is absent or raises, the + lazy fallback inside ``_update_mru_partial`` picks the same signal + up at first inference instead. + """ + if not hasattr(model, "make_cache"): + return + try: + sample_cache = model.make_cache() + except Exception as exc: + logger.debug( + "MRU eager eligibility check skipped (make_cache failed: %s)", + exc, + ) + return + try: + mcc = ModelCacheConfig.from_cache_list(sample_cache) + layer_types = mcc.get_type_names() if mcc is not None else None + except Exception as exc: + logger.debug( + "MRU eager eligibility check skipped (ModelCacheConfig failed: %s)", + exc, + ) + return + finally: + # Drop the sample cache refs ASAP — only used for type-name + # introspection. No tensor buffers were allocated yet (those + # arrive on first prefill), but keeping the wrappers around + # serves no purpose. 
+ del sample_cache + if not layer_types: + return + if self._all_layers_sliceable(layer_types): + self._mru_partial_supported = True + else: + self._record_mru_unsupported(layer_types) + + def _record_mru_unsupported( + self, layer_cache_types: list[str] | None + ) -> None: + """Latch ``mru_partial_supported=False`` and emit a one-shot warning. + + Called from both the eager init-time check and the lazy fallback + inside ``_update_mru_partial``. The warning matches the in-tree + load-phase warning style ("condition + consequence, plain words") + — see ``utils/model_loading.py`` mtp warning and ``engine/dflash`` + L2 warning for the reference voice. Mechanism details + (sliceable whitelist, offset-skew failure mode) live in the + ``_all_layers_sliceable`` docstring where developers read them, + not in the operator log. + """ + already_recorded = self._mru_partial_supported is False + self._mru_partial_supported = False + if already_recorded or self._mru_partial_warn_emitted: + return + self._mru_partial_warn_emitted = True + if layer_cache_types: + offenders = sorted( + set(layer_cache_types) - KNOWN_SLICEABLE_CACHE_TYPES + ) + else: + offenders = ["unknown"] + logger.warning( + "MRU tail cache enabled but this model is incompatible " + "(cache layers: %s); MRU tails will be inactive for this model.", + ", ".join(offenders), + ) def _get_model_num_layers(self, model: Any) -> int: """ @@ -285,6 +453,8 @@ def fetch_cache( num_prefix_tokens = len(tokens) - len(remaining) self._hits += 1 self._tokens_saved += num_prefix_tokens + self._tokens_matched_total += num_prefix_tokens + self._tokens_requested_total += len(tokens) logger.debug( f"Cache hit for {request_id}: " @@ -310,6 +480,8 @@ def fetch_cache( remaining = tokens[prefix_len:] self._hits += 1 self._tokens_saved += prefix_len + self._tokens_matched_total += prefix_len + self._tokens_requested_total += len(tokens) logger.debug( f"Prefix index hit for {request_id}: " f"{prefix_len} tokens matched" @@ -319,6 +491,7 
@@ def fetch_cache( # No cache hit self._misses += 1 + self._tokens_requested_total += len(tokens) logger.debug(f"Cache miss for {request_id}") return None, tokens @@ -332,6 +505,7 @@ def store_cache( extra_keys: tuple[Any, ...] | None = None, extra_key_token_start: int | None = None, extra_key_ranges: list[tuple[int, tuple[Any, ...]]] | None = None, + prompt_token_count: int | None = None, ) -> BlockTable | None: """ Store computed cache for future reuse. @@ -352,6 +526,14 @@ def store_cache( boundary_snapshots: Optional mapping of token_count -> extracted cache states for intermediate block boundaries. Used to store per-block ArraysCache state instead of placeholders in hybrid models. + prompt_token_count: Number of prompt tokens at the front of + ``tokens``. ``tokens`` is typically ``prompt + output``; + the MRU partial cache must stash the prompt's trailing + tail (the part a repeat request will resubmit), not the + sequence's. ``None`` means "no prompt boundary known" — + the MRU stash then treats the whole sequence as the + prompt (pre-MRU-fix behavior; correct for verbatim-repeat + callers like the generic CacheManager.store path). Returns: BlockTable for the stored cache, or None on failure @@ -652,6 +834,22 @@ def store_cache( last_access=time.time(), ) + # Stash the prompt's trailing sub-block tail in memory so an + # immediate repeat request can splice it back in without a + # re-prefill. Keyed by the prompt's last full block so the + # lookup in apply_mru_partial (which sees a prompt-only repeat) + # finds it — see _update_mru_partial. 
+ self._update_mru_partial( + new_tokens=new_tokens, + cache_data=cache_data, + block_table=block_table, + existing_tokens=existing_tokens, + prompt_token_count=prompt_token_count, + is_tensor_data=bool(is_tensor_data), + layer_cache_types=layer_cache_types, + model_cache_config=model_cache_config, + ) + logger.debug( f"Stored cache for {request_id}: " f"{len(block_table.block_ids)} blocks ({blocks_saved_to_ssd} saved to tiered cache), " @@ -660,6 +858,369 @@ def store_cache( return block_table + def has_mru_partial(self) -> bool: + """Whether any trailing-tail partial is currently stashed. + + Used by the scheduler to decide whether to suppress a deferred + Metal cache clear (a fresh stash is a strong signal that the + same prompt may return immediately). The "any entry present" + semantic is the right predicate regardless of multi-slot count + — one warm prompt is enough to merit suppression. + """ + return bool(self._mru_partials) + + def _can_reconstruct(self) -> bool: + """Whether ``reconstruct_cache`` has a path to return non-``None``. + + Used as a guard at both the MRU stash site (``_update_mru_partial``) + and the canonical reconstruct entry (``reconstruct_cache``). Co- + locating the two checks keeps them in lockstep — a future fetch + path that bypasses ``PagedSSDCacheManager`` (memory-only mode, + alternate backends) updates exactly one predicate. + + Note the predicate is ``manager is not None``, not "SSD writes are + enabled." ``hot_cache_only=True`` (omlx setting / OMLX_HOT_CACHE_ONLY + env) keeps the manager but skips the disk writer thread; reconstruct + still works because ``load_block_with_metadata`` short-circuits to + the hot tier. The gate fires only when no manager is configured + at all — typically a test/dev scenario. + """ + return HAS_MLX and self.paged_ssd_cache is not None + + def _all_layers_sliceable( + self, layer_cache_types: list[str] | None + ) -> bool: + """True iff every layer's cache type is in the known-sliceable + whitelist. 
+ + Why a whitelist and not ``CacheTypeRegistry.supports_block_slicing``: + the registry uses ``DefaultCacheHandler`` (a ``KVCacheHandler`` + subclass that reports ``supports_block_slicing=True``) as the + fallback for class names with no registered handler. Several real + non-sliceable types — ``BatchRotatingKVCache``, ``BatchPoolingCache``, + ``PoolingCache`` (without the deepseek_v4 patch applied) — are + mapped in ``_class_name_map`` but have no registered handler, so + the registry would silently classify them as sliceable. The + whitelist is the same one the rest of the scheduler trusts for + snapshot-skip and partial-extraction decisions, so MRU now agrees + with the surrounding code rather than getting its own answer. + + Hybrid models (e.g. Gemma 3, Mistral) mix sliceable layers with + non-sliceable ones (``RotatingKVCache``, ``ArraysCache``). Splicing + the partial only into the sliceable layers would create per-layer + offset skew at decode — undefined behaviour at the model level — + so we refuse the stash entirely whenever any layer is non-sliceable. + """ + if not layer_cache_types: + # Default fallback path assumes all layers are KVCache; safe to stash. + return True + for class_name in layer_cache_types: + if class_name not in KNOWN_SLICEABLE_CACHE_TYPES: + return False + return True + + def _update_mru_partial( + self, + *, + new_tokens: list[int], + cache_data: list[Any], + block_table: BlockTable, + existing_tokens: int, + prompt_token_count: int | None, + is_tensor_data: bool, + layer_cache_types: list[str] | None, + model_cache_config: ModelCacheConfig | None, + ) -> None: + """Stash the prompt's trailing partial from a just-completed ``store_cache``. + + ``store_cache`` is given the full ``prompt + output`` sequence, + but a repeat request resubmits the *prompt* only — and + ``apply_mru_partial`` looks the entry up by the prompt's last + full block. 
So the stash must key off, and slice, the + **prompt's** trailing partial, not the stored sequence's. + ``prompt_token_count`` carries the prompt boundary; ``None`` + falls back to "whole sequence is the prompt" (pre-fix behavior, + correct for verbatim-repeat callers). + + On success, writes one entry into the LRU map keyed by + ``parent_hash`` (the prompt's last full block, or ``None`` for a + prompt shorter than one block). If the map is at capacity, the + oldest entry is evicted via ``popitem(last=False)``. + + The "no eligible tail" branches (block-aligned prompt, non-tensor + data, no reconstruct path configured, hybrid model, extraction + failure, ambiguous cache layout) bare-return — they signal a + local "nothing to stash this time," NOT a global "wipe the map." + That distinction is what lets multiple distinct-prefix entries + coexist under interleaving. + """ + if self._mru_partial_max_entries <= 0: + # Feature disabled. Match the hot_cache_max_size="0" convention: + # silent fallback to "no MRU" behavior. + return + # Sliceable-layer guard is checked first and recorded so the + # eligibility flag flips at the same instant the warning fires — + # later branches in this chain are per-call ("no eligible tail + # this time") and must not pollute the structural flag. + if not self._all_layers_sliceable(layer_cache_types): + self._record_mru_unsupported(layer_cache_types) + return + if self._mru_partial_supported is None: + self._mru_partial_supported = True + if not is_tensor_data or not self._can_reconstruct(): + return + + # Resolve the prompt boundary within the stored sequence. None + # means "no boundary known" — treat the whole stored sequence as + # the prompt, which reproduces the pre-fix whole-sequence stash. 
+ sequence_len = existing_tokens + len(new_tokens) + if prompt_token_count is None: + prompt_token_count = sequence_len + else: + prompt_token_count = min(prompt_token_count, sequence_len) + # Prompt fully covered by already-cached full blocks: its tail (if + # any) is partial and partials are never stored as blocks, so a + # prompt that ends at or before existing_tokens is block-aligned + # and has nothing for the MRU to add. + if prompt_token_count <= existing_tokens: + return + + # Prompt's trailing partial: global range + # [prompt_partial_start, prompt_token_count). Block-aligned + # prompt (partial_len == 0) → every token lands in a full paged + # block, nothing to stash. + prompt_partial_len = prompt_token_count % self.block_size + if prompt_partial_len == 0: + return + prompt_partial_start = prompt_token_count - prompt_partial_len + prompt_full_blocks = prompt_token_count // self.block_size + # new_tokens-relative offset of the prompt's trailing partial. + # existing_tokens is block-aligned and <= prompt_partial_start, + # so this is a valid index into new_tokens. + partial_start = prompt_partial_start - existing_tokens + partial_global_start = prompt_partial_start + partial_global_end = prompt_token_count + + # Decide which axis-2 index range covers the trailing partial: + # - Global: cache_data spans the full sequence (prefix + new tail); + # slice at [partial_global_start:partial_global_end]. + # - Local: cache_data spans only the newly-processed suffix; + # slice at [partial_start:partial_start + prompt_partial_len]. + # + # We classify by exact cache length, not the previous + # ``cache_seq_len >= existing_tokens + 1`` heuristic — that one + # silently picked "local" when ``cache_seq_len == existing_tokens``, + # which can happen on multi-turn requests where the cache was + # extracted at a boundary that happens to equal the prior turn's + # length. 
Slicing local indices on a global cache there sampled + # tokens from the prefix, parent_hash still matched, and a future + # apply spliced wrong KV — silent generation corruption. Now: if + # the layout cannot be unambiguously classified, refuse to stash. + cache_seq_len = self._get_cache_seq_len(cache_data) + local_len = len(new_tokens) + if cache_seq_len >= partial_global_end: + # Cache is long enough to contain the global partial range. + p_cache_start = partial_global_start + p_cache_end = partial_global_end + elif existing_tokens == 0 or cache_seq_len == local_len: + # Cache covers only the new suffix (or there is no prefix). + p_cache_start = partial_start + p_cache_end = partial_start + prompt_partial_len + else: + # Ambiguous: cache_seq_len is short of global_end but does + # not equal local_len either. Refuse rather than guess. + return + + # is_last_block=False is the correct intent for partial extraction: + # we want a token slice, not the full state of any non-sliceable + # layer. The non-sliceable case is already excluded above. + partial_kv = self._extract_block_tensor_slice( + cache_data, + p_cache_start, + p_cache_end, + model_cache_config, + is_last_block=False, + ) + if not partial_kv: + return + # Materialize now, on this (store-cache worker) thread. The + # tensors are lazy ops bound to this thread's stream; the splice + # in apply_mru_partial runs on the separate inference thread and + # must not inherit a cross-thread stream dependency. + self._materialize_mru_kv(partial_kv) + + # Key by the PROMPT's last full block (block index + # prompt_full_blocks - 1), so a prompt-only repeat — which is what + # apply_mru_partial sees — finds this entry. block_table.block_ids + # is ordered [fetched-prefix blocks..., newly-built blocks...] in + # sequence order, and the stored sequence has at least + # prompt_full_blocks full blocks (the prompt is a prefix of it), + # so the index is in range. 
prompt_full_blocks == 0 (prompt + # shorter than one block) keeps parent_hash = None — the + # short-prompt None-key path. + parent_hash: bytes | None = None + if prompt_full_blocks >= 1: + if len(block_table.block_ids) < prompt_full_blocks: + # Prompt's last full block isn't in the table (block build + # rolled back on an extraction failure). Skip rather than + # mis-key against the None short-prompt slot. + return + parent_bid = block_table.block_ids[prompt_full_blocks - 1] + parent_block = self.paged_cache.allocated_blocks.get(parent_bid) + if parent_block is None or not parent_block.block_hash: + return + parent_hash = parent_block.block_hash + + # LRU put: pop any existing entry first so the re-insert lands at + # the tail with fresh order; then evict the oldest entry if over + # capacity. Mirrors PagedSSDCacheManager._hot_cache_put. + # Note: same-key replacement does NOT count as an eviction — + # the operator sees stash payoff via the stashes/hits ratio, not + # via replacement churn. + if parent_hash in self._mru_partials: + self._mru_partials.pop(parent_hash) + self._mru_partials[parent_hash] = _MRUPartialBlock( + parent_hash=parent_hash, + tokens=new_tokens[partial_start : partial_start + prompt_partial_len], + kv_data=partial_kv, + ) + self._mru_partial_stashes += 1 + while len(self._mru_partials) > self._mru_partial_max_entries: + self._mru_partials.popitem(last=False) + self._mru_partial_evictions += 1 + + logger.debug( + "Stashed MRU partial: %d tokens, parent_hash=%s, layers=%d, " + "entries=%d/%d", + prompt_partial_len, + parent_hash[:8].hex() + "..." if parent_hash else "None", + len(partial_kv), + len(self._mru_partials), + self._mru_partial_max_entries, + ) + + def apply_mru_partial( + self, + cache: list[Any], + block_table: BlockTable, + remaining_tokens: list[int], + ) -> tuple[list[Any], list[int], int]: + """Splice an MRU partial entry into a reconstructed cache, atomically. 
+ + Looks up an entry in the multi-slot LRU map keyed by the block + hash of the last full block in ``block_table`` (``None`` for + short prompts). On a match (entry exists, partial tokens are a + prefix of ``remaining_tokens``, layer count matches, every layer + accepts the concatenate), every layer's keys/values/offset are + advanced by ``len(partial.tokens)``, the partial tokens are + consumed from the front of ``remaining_tokens``, and the entry + is moved to the LRU tail (most-recently-used). + + On any miss, only the matching entry is evicted; sibling entries + for other prefixes remain. + + The splice is **transactional**: replacement keys/values are + materialized for every layer before any layer is mutated. A + failure on layer N rolls everything back; the caller never sees + a half-mutated cache. + + Args: + cache: Reconstructed per-layer cache objects from + ``reconstruct_cache``. Mutated in place on success. + block_table: Block table from the same fetch_cache call. + Used to verify the partial chains from the right prefix. + remaining_tokens: Tokens still needing prefill. + + Returns: + ``(cache, remaining_tokens, tokens_applied)``. On miss, + ``tokens_applied == 0`` and the inputs are returned unchanged. + """ + if not self._mru_partials or not remaining_tokens: + return cache, remaining_tokens, 0 + + # Compute the parent_hash key. Multi-slot freed-block guard: + # if block_table is non-empty but the parent paged block has + # been freed (allocated_blocks.get returns None), do NOT fall + # through to a None-keyed dict lookup — that would falsely match + # a short-prompt entry against a request whose parent is just + # gone. Return no-op instead. This race is new in multi-slot + # mode; single-slot tolerated it because there was only ever + # one entry to match against. 
+ last_hash: bytes | None = None + if block_table and block_table.block_ids: + last_bid = block_table.block_ids[-1] + last_block = self.paged_cache.allocated_blocks.get(last_bid) + if last_block is None: + return cache, remaining_tokens, 0 + last_hash = last_block.block_hash + + partial = self._mru_partials.get(last_hash) + if partial is None: + return cache, remaining_tokens, 0 + + def _evict_miss() -> tuple[list[Any], list[int], int]: + """Pop the matched entry and return the no-op tuple. + + Used by every apply-time mismatch arm. Inlines to one line + at each call site and keeps the eviction-counter bookkeeping + in one place. + """ + self._mru_partials.pop(last_hash, None) + self._mru_partial_evictions += 1 + return cache, remaining_tokens, 0 + + n_partial = len(partial.tokens) + if len(remaining_tokens) < n_partial: + return _evict_miss() + if remaining_tokens[:n_partial] != partial.tokens: + return _evict_miss() + + if len(partial.kv_data) != len(cache): + logger.debug( + "MRU partial layer count mismatch: %d vs %d, evicting entry", + len(partial.kv_data), len(cache), + ) + return _evict_miss() + + if not HAS_MLX: + return _evict_miss() + + # Phase 1: build per-layer replacements without touching the cache. + # Any failure here is a clean rollback — no layer has been mutated. + try: + replacements: list[tuple[int, Any, Any]] = [] + for layer_idx, (p_keys, p_values) in enumerate(partial.kv_data): + cache_obj = cache[layer_idx] + new_keys = mx.concatenate([cache_obj.keys, p_keys], axis=2) + new_values = mx.concatenate([cache_obj.values, p_values], axis=2) + replacements.append((layer_idx, new_keys, new_values)) + except Exception as e: + logger.debug( + "MRU partial splice build failed: %s, evicting entry", e + ) + return _evict_miss() + + # Phase 2: commit. All concatenates have already succeeded; the + # only operations remaining are attribute writes and an integer + # add, which cannot raise on a well-formed cache object. 
+ for layer_idx, new_keys, new_values in replacements: + cache_obj = cache[layer_idx] + cache_obj.keys = new_keys + cache_obj.values = new_values + cache_obj.offset += n_partial + + # Promote this entry to the LRU tail (most-recently-used). + self._mru_partials.move_to_end(last_hash) + self._mru_partial_hits += 1 + self._mru_partial_tokens_saved += n_partial + + new_remaining = remaining_tokens[n_partial:] + logger.debug( + "Applied MRU partial: %d tokens, %d remaining, entries=%d", + n_partial, len(new_remaining), len(self._mru_partials), + ) + return cache, new_remaining, n_partial + def _get_cache_seq_len(self, cache_data: list[dict[str, Any]]) -> int: """ Get the sequence length from cache data. @@ -1220,6 +1781,39 @@ def _clone_tensor(self, tensor: Any) -> Any: return mx.array(tensor) + def _materialize_mru_kv(self, partial_kv: list[Any]) -> None: + """Force-evaluate a freshly-extracted MRU partial's KV tensors. + + ``_extract_block_tensor_slice`` builds the tensors as lazy + ``mx.copy`` ops on whichever thread ``store_cache`` runs on — the + ``omlx-store-cache`` worker. ``apply_mru_partial`` later splices + them into a live cache on the separate ``mlx-global`` inference + thread; while the tensors are still lazy, that thread's + ``mx.async_eval`` would walk the compute graph back to a + per-thread stream the inference thread cannot see, and MLX raises + ``RuntimeError: There is no Stream(gpu, N) in current thread``. + Evaluating here, on the worker, leaves concrete stream-free data + safe to splice and evaluate from any thread. + + Walks the nested ``(keys, values)`` / TurboQuant ``(tag, (k, v))`` + shapes ``_extract_block_tensor_slice`` returns and evaluates every + ``mx.array`` leaf in one batched call. 
+ """ + if not HAS_MLX: + return + leaves: list[Any] = [] + + def _collect(obj: Any) -> None: + if isinstance(obj, mx.array): + leaves.append(obj) + elif isinstance(obj, (list, tuple)): + for item in obj: + _collect(item) + + _collect(partial_kv) + if leaves: + mx.eval(*leaves) + def _apply_window_padding( self, matched_blocks: int, @@ -1388,15 +1982,16 @@ def reconstruct_cache( if not block_table or not block_table.block_ids: return None - if not HAS_MLX: - logger.warning("Cannot reconstruct cache: MLX not available") - return None - - if self.paged_ssd_cache is None: - logger.warning( - "Cannot reconstruct cache: PagedSSDCacheManager not configured" - ) + if not self._can_reconstruct(): + if not HAS_MLX: + logger.warning("Cannot reconstruct cache: MLX not available") + else: + logger.warning( + "Cannot reconstruct cache: PagedSSDCacheManager not configured" + ) return None + # _can_reconstruct() guarantees this; narrow for the type checker. + assert self.paged_ssd_cache is not None try: # Collect cache data from valid blocks (stop at first invalid) @@ -2367,6 +2962,15 @@ def get_stats(self) -> PrefixCacheStats: block_size=self.block_size, last_partial_tokens_skipped=self._last_partial_tokens_skipped, last_tokens_to_next_block=self._last_tokens_to_next_block, + tokens_matched_total=self._tokens_matched_total, + tokens_requested_total=self._tokens_requested_total, + mru_partial_stashes=self._mru_partial_stashes, + mru_partial_hits=self._mru_partial_hits, + mru_partial_evictions=self._mru_partial_evictions, + mru_partial_tokens_saved=self._mru_partial_tokens_saved, + mru_partial_entries=len(self._mru_partials), + mru_partial_max_entries=self._mru_partial_max_entries, + mru_partial_supported=self._mru_partial_supported, ) def get_stats_dict(self) -> dict[str, Any]: @@ -2393,19 +2997,44 @@ def get_stats_dict(self) -> dict[str, Any]: "block_size": self.block_size, "last_partial_tokens_skipped": self._last_partial_tokens_skipped, "last_tokens_to_next_block": 
self._last_tokens_to_next_block, + "tokens_matched_total": self._tokens_matched_total, + "tokens_requested_total": self._tokens_requested_total, "active_requests": len(self._request_tables), + # MRU partial cache: counters mirror the dataclass surface so the + # admin dashboard's `mruEnabled` gate (sourced from this dict + # via Scheduler.get_ssd_cache_stats) can see the configured + # capacity. Omitting them silently hides every MRU panel. + "mru_partial_stashes": self._mru_partial_stashes, + "mru_partial_hits": self._mru_partial_hits, + "mru_partial_evictions": self._mru_partial_evictions, + "mru_partial_tokens_saved": self._mru_partial_tokens_saved, + "mru_partial_entries": len(self._mru_partials), + "mru_partial_max_entries": self._mru_partial_max_entries, + "mru_partial_supported": self._mru_partial_supported, **paged_stats, } def reset_stats(self) -> None: - """Reset statistics.""" + """Reset statistics. + + Note: ``_mru_partials`` is live state (the cache itself), not a + counter — its size is reported as a gauge by ``get_stats()`` and + is not affected by stats reset. Use ``clear_mru_partials()`` if + the entries themselves should be wiped. + """ self._hits = 0 self._misses = 0 self._tokens_saved = 0 self._partial_block_skips = 0 self._partial_tokens_skipped = 0 + self._tokens_matched_total = 0 + self._tokens_requested_total = 0 self._last_partial_tokens_skipped = 0 self._last_tokens_to_next_block = 0 + self._mru_partial_stashes = 0 + self._mru_partial_hits = 0 + self._mru_partial_evictions = 0 + self._mru_partial_tokens_saved = 0 self.paged_cache.reset_stats() def clear(self) -> int: @@ -2419,9 +3048,43 @@ def clear(self) -> int: self._request_tables.clear() self._prefix_index.clear() self.paged_cache.clear() + # MRU partials chain from paged-block hashes; once the paged + # cache is wiped, no entry can be safely applied. 
Cache-corruption + # recovery (Scheduler._recover_from_cache_error) routes through + # here, so stale partials would otherwise survive exactly the + # recovery path that exists because something was wrong. + # No eviction-counter bump here: clear() also calls reset_stats() + # which zeros every counter (this is the "restart everything" + # path). Operators tracking partial wipes specifically should + # use clear_mru_partials() instead — that path leaves stats alone. + self._mru_partials.clear() self.reset_stats() return cleared_count + def clear_mru_partials(self) -> int: + """Wipe only the MRU partial cache, leaving paged blocks intact. + + Intended consumer: admin-triggered cache-tier clears that drop + the backing block storage (``clear_ssd_cache`` admin endpoint, + and the future ``clear_hot_cache`` endpoint once PR #1183 + lands). Without this hook, a stash whose ``parent_hash`` chains + from a paged block whose underlying KV bytes were just flushed + from the hot/SSD tier would survive in memory and waste a + reconstruct attempt on the next request before being naturally + evicted by LRU. + + Distinct from ``clear()``: this method only drops MRU entries + and does not touch the paged cache, prefix index, or stats + (other than incrementing ``mru_partial_evictions``). + + Returns: + Number of MRU entries that were wiped. + """ + n = len(self._mru_partials) + self._mru_partials.clear() + self._mru_partial_evictions += n + return n + def set_cold_restore_callback( self, callback: Callable[[int, bytes], bool] | None, diff --git a/omlx/cache/stats.py b/omlx/cache/stats.py index 412074fc7..47f2749be 100644 --- a/omlx/cache/stats.py +++ b/omlx/cache/stats.py @@ -88,6 +88,26 @@ class PrefixCacheStats(BaseCacheStats): block_size: int = 0 last_partial_tokens_skipped: int = 0 last_tokens_to_next_block: int = 0 + tokens_matched_total: int = 0 + tokens_requested_total: int = 0 + # MRU partial cache observability. 
Cumulative counters that pair + # naturally — hits/stashes gives the "stash payoff" rate, evictions + # tracks churn, tokens_saved is the direct compute-saved measure. + # Gauges (entries / max_entries) describe live capacity utilisation. + mru_partial_stashes: int = 0 + mru_partial_hits: int = 0 + mru_partial_evictions: int = 0 + mru_partial_tokens_saved: int = 0 + mru_partial_entries: int = 0 + mru_partial_max_entries: int = 0 + # Tri-state eligibility flag for the MRU partial-block feature: + # None → unknown (default; no detection has fired yet) + # True → model has only sliceable cache layers + # False → at least one layer type is not in the sliceable whitelist, + # so every stash attempt is refused at the safety gate. + # Surfaces to the admin dashboard so per-model rows can show + # "N/A (see log)" instead of a misleading "0/N entries" gauge. + mru_partial_supported: bool | None = None _total_queries: int = field(default=0, repr=False) @property @@ -111,6 +131,14 @@ def reset(self) -> None: self.partial_tokens_skipped = 0 self.last_partial_tokens_skipped = 0 self.last_tokens_to_next_block = 0 + self.tokens_matched_total = 0 + self.tokens_requested_total = 0 + self.mru_partial_stashes = 0 + self.mru_partial_hits = 0 + self.mru_partial_evictions = 0 + self.mru_partial_tokens_saved = 0 + # mru_partial_entries and mru_partial_max_entries are gauges + # populated by get_stats() from live state — not reset here. self._total_queries = 0 diff --git a/omlx/cache/type_registry.py b/omlx/cache/type_registry.py index 832c4a71e..add800abc 100644 --- a/omlx/cache/type_registry.py +++ b/omlx/cache/type_registry.py @@ -23,6 +23,33 @@ logger = logging.getLogger(__name__) +# Authoritative whitelist of cache class names whose KV state is safe to +# slice along the sequence axis (axis=2). 
Used by: +# - scheduler.py: snapshot-skip gating, partial extraction +# - prefix_cache.py: MRU partial stash/apply gate +# +# Important: this is NOT derivable from CacheTypeHandler.supports_block_slicing +# alone, because DefaultCacheHandler (used as a fallback for class names with +# no registered handler) inherits from KVCacheHandler and reports +# supports_block_slicing=True. That makes Batch* and Pool* class names +# silently sliceable via the registry, which is wrong. Code that needs to +# decide "may I slice this layer's KV" MUST consult this whitelist by +# class-name string, not the registry. +KNOWN_SLICEABLE_CACHE_TYPES = frozenset( + { + "KVCache", + "BatchKVCache", + "QuantizedKVCache", + "TurboQuantKVCache", + "BatchTurboQuantKVCache", + # ChunkedKVCache is included once the batch=1 patch in scheduler.py + # installs its extract/filter/size pass-throughs (PR #1152); without + # that patch, Llama-4 requests fall back to the snapshot path. + "ChunkedKVCache", + } +) + + class CacheTypeRegistry: """Registry for cache type handlers. diff --git a/omlx/cli.py b/omlx/cli.py index cff502809..5fdec8905 100644 --- a/omlx/cli.py +++ b/omlx/cli.py @@ -222,6 +222,15 @@ def serve_command(args): else: scheduler_config.hot_cache_max_size = 0 + # MRU partial cache: CLI arg > settings. Independent of SSD/hot + # configuration — the MRU stash works in any mode where + # reconstruct_cache has a path (which is gated by manager presence, + # not by SSD writes; see BlockAwarePrefixCache._can_reconstruct). + if args.mru_partial_max_entries is not None: + scheduler_config.mru_partial_max_entries = int(args.mru_partial_max_entries) + else: + scheduler_config.mru_partial_max_entries = settings.cache.mru_partial_max_entries + if args.no_cache: print("Mode: Multi-model serving (no oMLX cache, mlx-lm BatchGenerator only)") elif paged_ssd_cache_dir: @@ -595,6 +604,17 @@ def main(): default=None, help="Maximum in-memory hot cache size (e.g., '8GB', '4GB'). 
Default: 0 (disabled)", ) + serve_parser.add_argument( + "--mru-partial-max-entries", + type=int, + default=None, + help=( + "Maximum simultaneous MRU partial-block stashes. Each entry is " + "bounded at one block_size of KV memory. 0 disables the feature. " + "Default: 4. Under --hot-cache-only this shares the in-memory KV " + "headroom envelope with --hot-cache-max-size; tune both together." + ), + ) serve_parser.add_argument( "--no-cache", action="store_true", diff --git a/omlx/scheduler.py b/omlx/scheduler.py index fab20ed7b..89152ae78 100644 --- a/omlx/scheduler.py +++ b/omlx/scheduler.py @@ -37,6 +37,7 @@ from mlx_lm.models.cache import make_prompt_cache from mlx_lm.sample_utils import make_logits_processors +from .cache.observability import CacheRateTracker from .cache.paged_cache import PagedCacheManager from .cache.prefix_cache import BlockAwarePrefixCache from .exceptions import is_cache_corruption_error @@ -411,19 +412,15 @@ def _patched_ppb_prompt(self, tokens): # Cache class names known to be sliceable (no boundary snapshots needed). -# ChunkedKVCache is included once the batch=1 patch above installs its -# extract/filter/size pass-throughs; without it Llama-4 requests fall -# back to the snapshot path unnecessarily. -_KNOWN_SLICEABLE_CACHE_TYPES = frozenset( - { - "KVCache", - "BatchKVCache", - "QuantizedKVCache", - "TurboQuantKVCache", - "BatchTurboQuantKVCache", - "ChunkedKVCache", - } -) +# Canonical home: omlx/cache/type_registry.py. ChunkedKVCache is included +# there once the batch=1 patch above installs its extract/filter/size +# pass-throughs (PR #1152); without it Llama-4 requests fall back to the +# snapshot path unnecessarily. +from omlx.cache.type_registry import KNOWN_SLICEABLE_CACHE_TYPES + +# Module-local alias kept for backwards compatibility with existing +# call sites in this file. 
+_KNOWN_SLICEABLE_CACHE_TYPES = KNOWN_SLICEABLE_CACHE_TYPES def _prompt_cache_needs_snapshots(prompt_cache: list[Any]) -> bool: @@ -552,6 +549,11 @@ class SchedulerConfig: hot_cache_only: bool = False paged_ssd_cache_max_size: int = 100 * 1024 * 1024 * 1024 # 100GB default hot_cache_max_size: int = 0 # In-memory hot cache size in bytes (0 = disabled) + # Bounded LRU stash for trailing sub-block partials of previous + # prefills, keyed by parent-block hash. Each entry holds at most + # one block_size of KV memory. ``0`` disables the feature; default + # of 4 matches the dflash max_entries precedent (PR #1120). + mru_partial_max_entries: int = 4 # Model identification (for cache isolation between different models) model_name: str = "" # OpenAI API model name (e.g., "mlx-community/Llama-3.2-3B") @@ -781,6 +783,7 @@ def __init__( self.paged_cache_manager: PagedCacheManager | None = None self.block_aware_cache: BlockAwarePrefixCache | None = None self.paged_ssd_cache_manager: PagedSSDCacheManager | None = None + self._cache_rate_tracker = CacheRateTracker() self.memory_monitor: MemoryMonitor | None = None # Initialize paged SSD cache if paged_ssd_cache_dir is specified @@ -801,6 +804,7 @@ def __init__( self.block_aware_cache = BlockAwarePrefixCache( model=model, paged_cache_manager=self.paged_cache_manager, + mru_partial_max_entries=self.config.mru_partial_max_entries, ) # Initialize paged SSD cache @@ -910,6 +914,16 @@ def __init__( # None = no deferred clear pending; int = step at which to fire. self._deferred_clear_at: int | None = None + # Per-completion budget for suppressing the deferred Metal cache + # clear when the prefix cache has a warm MRU partial (a strong + # signal that the same prompt may return immediately and would + # benefit from the still-resident lazy KV tensors). + # The budget is reset to True on every _cleanup_finished() that + # arms _deferred_clear_at. 
Once spent (suppression has happened + # once), the deferred clear fires at the next deadline regardless + # of MRU state, bounding total deferral at 2x _DEFERRED_CLEAR_DELAY. + self._mru_clear_suppression_available: bool = False + # Cache XTC special tokens (newline + EOS) — stable per tokenizer. # Must be after _is_harmony_model / _generation_config_eos init # since _get_xtc_special_tokens() delegates to _get_stop_tokens(). @@ -1025,6 +1039,7 @@ def _async_store_cache_worker( extra_keys: tuple[Any, ...] | None, extra_key_token_start: int | None, extra_key_ranges: list[tuple[int, tuple[Any, ...]]] | None, + prompt_token_count: int, ) -> None: """Run store_cache + paged_cache cleanup off the inference thread. @@ -1067,6 +1082,7 @@ def _async_store_cache_worker( extra_keys=extra_keys, extra_key_token_start=extra_key_token_start, extra_key_ranges=extra_key_ranges, + prompt_token_count=prompt_token_count, ) if block_table is None and self.paged_cache_manager is not None: block_table = self.paged_cache_manager.get_block_table(request_id) @@ -3305,6 +3321,38 @@ def add_request(self, request: Request) -> None: request.remaining_tokens = request.prompt_token_ids[ block_table.num_tokens : ] + # Splice the in-memory MRU partial onto the reconstructed + # cache when the trailing tokens match. Saves the + # re-prefill of the sub-block tail on exact-repeat + # prompts. The splice is a no-op when no partial is + # stashed or the trailing tokens differ. + # + # Accounting note: the partial advances cached_tokens + # but is NOT a stored paged block, so shared_prefix_blocks + # stays at the count of paged blocks reused. After a + # successful splice the invariant relaxes from + # cached_tokens == shared_prefix_blocks * block_size + # to + # cached_tokens >= shared_prefix_blocks * block_size + # with cached_tokens - shared_prefix_blocks * block_size + # ∈ [0, block_size) representing the partial. 
Current + # readers (the scheduler's prefill-completion log lines + # downstream) tolerate the relaxed form; future readers + # that index block_table.block_ids by + # shared_prefix_blocks must NOT use cached_tokens to + # bound the loop. + if request.remaining_tokens: + ( + request.prompt_cache, + request.remaining_tokens, + partial_applied, + ) = self.block_aware_cache.apply_mru_partial( + request.prompt_cache, + block_table, + request.remaining_tokens, + ) + if partial_applied > 0: + request.cached_tokens += partial_applied # For exact prefix hits we need cache state at (N-1) and the # last prompt token as input to produce the first decode logit. # Reusing cache state at N and feeding the last token again @@ -3428,6 +3476,10 @@ def set_specprefill_draft_model( model=draft_model, paged_cache_manager=draft_paged, paged_ssd_cache_manager=self.paged_ssd_cache_manager, + # MRU disabled on the draft cache: apply_mru_partial is + # only ever called on the main block_aware_cache, so a + # draft-cache stash is dead work that never pays off. + mru_partial_max_entries=0, ) self._draft_prefix_cache.set_cold_restore_callback( self._restore_block_from_cold @@ -5137,6 +5189,10 @@ def _cleanup_finished(self, finished_ids: set[str]) -> None: if pre_eval_arrays: mx.async_eval(*pre_eval_arrays) + # Prompt boundary for the MRU partial stash: + # token_sequence_to_store is prompt+output, but a + # repeat request resubmits the prompt only. 
+ prompt_token_count = len(request.prompt_token_ids) if self._store_cache_executor is not None: store_future = self._store_cache_executor.submit( self._async_store_cache_worker, @@ -5148,6 +5204,7 @@ def _cleanup_finished(self, finished_ids: set[str]) -> None: request.vlm_extra_keys_for_cache, request.vlm_extra_key_token_start_for_cache, request.vlm_extra_key_ranges_for_cache, + prompt_token_count, ) self._inflight_store_futures[request_id] = store_future else: @@ -5161,6 +5218,7 @@ def _cleanup_finished(self, finished_ids: set[str]) -> None: request.vlm_extra_keys_for_cache, request.vlm_extra_key_token_start_for_cache, request.vlm_extra_key_ranges_for_cache, + prompt_token_count, ) logger.debug( f"Submitted async store_cache for {request_id} " @@ -5295,6 +5353,17 @@ def _cleanup_finished(self, finished_ids: set[str]) -> None: # finished their completeMemory() callbacks (#557). target = self._step_counter + self._DEFERRED_CLEAR_DELAY if self._deferred_clear_at is None or target > self._deferred_clear_at: + # Arm the suppression budget only when STARTING a new + # deferral epoch (transition from None). Subsequent + # completions in the same epoch may legitimately push the + # deadline out (for IOKit safety, #557) but must NOT + # refresh the budget — otherwise a hot-prompt workload + # whose completions arrive faster than _DEFERRED_CLEAR_DELAY + # could keep re-arming after the budget was spent and + # defer the clear forever, defeating the pool-bloat + # mitigation (#411). One suppression per epoch, total. 
+ if self._deferred_clear_at is None: + self._mru_clear_suppression_available = True self._deferred_clear_at = target def _is_cache_corruption_error(self, error: Exception) -> bool: @@ -5322,6 +5391,7 @@ def _recover_from_cache_error(self) -> None: # Clear caches if self.block_aware_cache is not None: self.block_aware_cache.clear() + self._cache_rate_tracker.clear() # Clear UID mappings self.request_id_to_uid.clear() @@ -5329,6 +5399,7 @@ def _recover_from_cache_error(self) -> None: # Cancel any pending deferred Metal cache clear self._deferred_clear_at = None + self._mru_clear_suppression_available = False # Clear detokenizer state to prevent contamination after recovery self._request_detokenizers.clear() @@ -5560,8 +5631,26 @@ def step(self) -> SchedulerOutput: self._deferred_clear_at is not None and self._step_counter >= self._deferred_clear_at ): - should_clear = True - self._deferred_clear_at = None + # If the prefix cache is holding a warm MRU partial and we + # haven't yet spent the one-shot per-epoch suppression budget, + # defer the clear by one more _DEFERRED_CLEAR_DELAY window. + # The MRU partial is a strong predictor that the next + # request will reuse the still-resident lazy KV tensors. + # The budget is one-shot, so total deferral is bounded at + # 2x _DEFERRED_CLEAR_DELAY even under hot-prompt repeats. 
+ if ( + self._mru_clear_suppression_available + and self.block_aware_cache is not None + and self.block_aware_cache.has_mru_partial() + ): + self._deferred_clear_at = ( + self._step_counter + self._DEFERRED_CLEAR_DELAY + ) + self._mru_clear_suppression_available = False + else: + should_clear = True + self._deferred_clear_at = None + self._mru_clear_suppression_available = False if should_clear: _sync_and_clear_cache() if ( @@ -5651,6 +5740,7 @@ def reset(self) -> None: # Clear caches if self.block_aware_cache is not None: self.block_aware_cache.clear() + self._cache_rate_tracker.clear() # Clear detokenizers self._request_detokenizers.clear() @@ -5660,6 +5750,7 @@ def reset(self) -> None: # Cancel any pending deferred Metal cache clear self._deferred_clear_at = None + self._mru_clear_suppression_available = False def deep_reset(self) -> None: """ @@ -6083,6 +6174,41 @@ def restore_cold_blocks_for_request(self, request_id: str) -> int: return verified + def _collect_cache_counters(self) -> dict[str, int] | None: + if self.block_aware_cache is None: + return None + + prefix_stats = self.block_aware_cache.get_stats() + counters = { + "prefix_hits": prefix_stats.hits, + "prefix_misses": prefix_stats.misses, + "prefix_tokens_matched": prefix_stats.tokens_matched_total, + "prefix_tokens_requested": prefix_stats.tokens_requested_total, + "prefix_tokens_saved": prefix_stats.tokens_saved, + "evictions": prefix_stats.evictions, + "mru_partial_stashes": prefix_stats.mru_partial_stashes, + "mru_partial_hits": prefix_stats.mru_partial_hits, + "mru_partial_evictions": prefix_stats.mru_partial_evictions, + "mru_partial_tokens_saved": prefix_stats.mru_partial_tokens_saved, + "mru_partial_entries": prefix_stats.mru_partial_entries, + "mru_partial_max_entries": prefix_stats.mru_partial_max_entries, + } + + if self.paged_ssd_cache_manager is not None: + ssd = self.paged_ssd_cache_manager.get_stats() + hot_hits = ssd.hot_cache_hits + total_loads = ssd.loads + counters.update({ + 
"ssd_hot_hits": hot_hits, + "ssd_disk_loads": max(0, total_loads - hot_hits), + "ssd_saves": ssd.saves, + "ssd_errors": ssd.errors, + "hot_cache_evictions": ssd.hot_cache_evictions, + "hot_cache_promotions": ssd.hot_cache_promotions, + }) + + return counters + def get_ssd_cache_stats(self) -> dict[str, Any] | None: """Get paged SSD + prefix cache observability statistics.""" stats = {} @@ -6091,15 +6217,18 @@ def get_ssd_cache_stats(self) -> dict[str, Any] | None: stats["ssd_cache"] = self.paged_ssd_cache_manager.get_stats() if self.paged_cache_manager is not None: - # In paged SSD-only mode, all cache data is on paged SSD stats["indexed_blocks"] = self.paged_cache_manager.cold_block_count stats["block_size"] = self.config.paged_cache_block_size if self.block_aware_cache is not None: - # Expose prefix-cache observability so UI can distinguish - # "0 indexed blocks" from "sub-block cached ( Path: """ @@ -295,6 +300,7 @@ def to_dict(self) -> dict[str, Any]: "ssd_cache_max_size": self.ssd_cache_max_size, "hot_cache_max_size": self.hot_cache_max_size, "initial_cache_blocks": self.initial_cache_blocks, + "mru_partial_max_entries": self.mru_partial_max_entries, } @classmethod @@ -307,6 +313,7 @@ def from_dict(cls, data: dict[str, Any]) -> CacheSettings: ssd_cache_max_size=data.get("ssd_cache_max_size", "auto"), hot_cache_max_size=data.get("hot_cache_max_size", "0"), initial_cache_blocks=data.get("initial_cache_blocks", 256), + mru_partial_max_entries=data.get("mru_partial_max_entries", 4), ) diff --git a/tests/test_admin_api_key.py b/tests/test_admin_api_key.py index 8766b39ce..eb03106b6 100644 --- a/tests/test_admin_api_key.py +++ b/tests/test_admin_api_key.py @@ -655,6 +655,7 @@ def test_runtime_cache_uses_model_scoped_ssd_stats(self): mock_settings = MagicMock() mock_settings.base_path = Path("/tmp/omlx-base") mock_settings.cache.get_ssd_cache_dir.return_value = cache_dir + mock_settings.cache.get_ssd_cache_max_size_bytes.return_value = 0 shared_ssd_stats = { 
"num_files": 999, @@ -744,9 +745,13 @@ def test_runtime_cache_uses_model_scoped_ssd_stats(self): "last_tokens_to_next_block": 0, "num_files": 3, "total_size_bytes": 4096, + "max_size_bytes": 0, "hot_cache_max_bytes": 0, "hot_cache_size_bytes": 0, "hot_cache_entries": 0, + "mru_partial_entries": 0, + "mru_partial_max_entries": 0, + "mru_partial_supported": None, }, { "id": "model-b", @@ -760,9 +765,13 @@ def test_runtime_cache_uses_model_scoped_ssd_stats(self): "last_tokens_to_next_block": 0, "num_files": 7, "total_size_bytes": 8192, + "max_size_bytes": 0, "hot_cache_max_bytes": 0, "hot_cache_size_bytes": 0, "hot_cache_entries": 0, + "mru_partial_entries": 0, + "mru_partial_max_entries": 0, + "mru_partial_supported": None, }, ] manager_a.get_stats_for_model.assert_called_once_with("/models/model-a") @@ -775,6 +784,7 @@ def test_runtime_cache_ignores_single_model_stats_failure(self): mock_settings = MagicMock() mock_settings.base_path = Path("/tmp/omlx-base") mock_settings.cache.get_ssd_cache_dir.return_value = cache_dir + mock_settings.cache.get_ssd_cache_max_size_bytes.return_value = 0 bad_scheduler = MagicMock() bad_scheduler.get_ssd_cache_stats.side_effect = RuntimeError("boom") @@ -833,6 +843,7 @@ def test_runtime_cache_marks_sub_block_cached_when_indexed_blocks_zero(self): mock_settings = MagicMock() mock_settings.base_path = Path("/tmp/omlx-base") mock_settings.cache.get_ssd_cache_dir.return_value = cache_dir + mock_settings.cache.get_ssd_cache_max_size_bytes.return_value = 0 scheduler = MagicMock() scheduler.get_ssd_cache_stats.return_value = { @@ -876,6 +887,61 @@ def test_runtime_cache_marks_sub_block_cached_when_indexed_blocks_zero(self): assert model_payload["last_partial_tokens_skipped"] == 577 +class TestClearSSDCacheAlsoWipesMRUPartials: + """The ``/api/ssd-cache/clear`` admin endpoint must also clear the + MRU partial cache. 
Without this, partials chain from paged-block + hashes whose backing KV bytes were just flushed by ssd_manager.clear() + and the operator's "drop all warm caches" intent is only half-honored. + """ + + def _scheduler_with_mocks(self): + """Build a SimpleNamespace scheduler with mock ssd manager and + block_aware_cache that we can introspect after the endpoint runs.""" + ssd_manager = MagicMock() + ssd_manager.clear.return_value = 5 # arbitrary deleted count + block_aware_cache = MagicMock() + block_aware_cache.clear_mru_partials.return_value = 3 + return SimpleNamespace( + paged_ssd_cache_manager=ssd_manager, + block_aware_cache=block_aware_cache, + ) + + def test_endpoint_calls_clear_mru_partials_on_each_scheduler(self): + scheduler_a = self._scheduler_with_mocks() + scheduler_b = self._scheduler_with_mocks() + + with patch.object( + admin_routes, + "_iter_loaded_schedulers", + return_value=iter([("model-a", scheduler_a), ("model-b", scheduler_b)]), + ), patch.object(admin_routes, "_get_global_settings", return_value=None): + asyncio.run(admin_routes.clear_ssd_cache(is_admin=True)) + + scheduler_a.paged_ssd_cache_manager.clear.assert_called_once() + scheduler_a.block_aware_cache.clear_mru_partials.assert_called_once() + scheduler_b.paged_ssd_cache_manager.clear.assert_called_once() + scheduler_b.block_aware_cache.clear_mru_partials.assert_called_once() + + def test_mru_clear_failure_does_not_block_other_scheduler(self): + """A failure in one scheduler's clear_mru_partials must be logged + but not prevent other schedulers from being cleared.""" + scheduler_a = self._scheduler_with_mocks() + scheduler_a.block_aware_cache.clear_mru_partials.side_effect = RuntimeError( + "boom" + ) + scheduler_b = self._scheduler_with_mocks() + + with patch.object( + admin_routes, + "_iter_loaded_schedulers", + return_value=iter([("model-a", scheduler_a), ("model-b", scheduler_b)]), + ), patch.object(admin_routes, "_get_global_settings", return_value=None): + 
asyncio.run(admin_routes.clear_ssd_cache(is_admin=True)) + + # Scheduler B must still have been cleared despite A's failure. + scheduler_b.block_aware_cache.clear_mru_partials.assert_called_once() + + class TestGlobalSettingsValidation: """Tests for stricter GlobalSettingsRequest validation.""" diff --git a/tests/test_cache_observability.py b/tests/test_cache_observability.py new file mode 100644 index 000000000..d0da07dca --- /dev/null +++ b/tests/test_cache_observability.py @@ -0,0 +1,241 @@ +# tests/test_cache_observability.py +# SPDX-License-Identifier: Apache-2.0 +"""Tests for cache observability module.""" + +import threading +import time +from unittest.mock import patch + +import pytest + +from omlx.cache.observability import CacheRateTracker + + +# Counter keys produced by Scheduler._collect_cache_counters. Adding a new +# observability counter means adding it here and in the production code +# in lockstep; the explicit-kwargs alternative duplicated the schema once +# in the function signature and once in the dict construction. +_COUNTER_KEYS = ( + "prefix_hits", + "prefix_misses", + "prefix_tokens_matched", + "prefix_tokens_requested", + "prefix_tokens_saved", + "evictions", + "ssd_hot_hits", + "ssd_disk_loads", + "ssd_saves", + "ssd_errors", + "hot_cache_evictions", + "hot_cache_promotions", + "mru_partial_stashes", + "mru_partial_hits", + "mru_partial_evictions", + "mru_partial_tokens_saved", +) + + +def _make_counters(**overrides): + """Build a counter dict for snapshot testing. + + All keys default to ``0``; override individual values via kwargs. + Unknown keys raise — the typo-catching the explicit-signature form + used to provide. + """ + unknown = set(overrides) - set(_COUNTER_KEYS) + if unknown: + raise ValueError( + f"Unknown counter keys: {sorted(unknown)}. 
" + f"Known: {sorted(_COUNTER_KEYS)}" + ) + counters = {k: 0 for k in _COUNTER_KEYS} + counters.update(overrides) + return counters + + +class TestCacheRateTrackerSnapshot: + + def test_empty_tracker_returns_empty_rates(self): + tracker = CacheRateTracker() + result = tracker.get_rates() + assert result == {"windows": {}, "cumulative": {}} + + def test_first_snapshot_always_accepted(self): + tracker = CacheRateTracker(min_interval=10.0) + assert tracker.maybe_snapshot(_make_counters()) is True + + def test_snapshot_rejected_within_min_interval(self): + tracker = CacheRateTracker(min_interval=10.0) + tracker.maybe_snapshot(_make_counters()) + assert tracker.maybe_snapshot(_make_counters()) is False + + def test_snapshot_accepted_after_min_interval(self): + tracker = CacheRateTracker(min_interval=0.0) + tracker.maybe_snapshot(_make_counters()) + assert tracker.maybe_snapshot(_make_counters()) is True + + def test_deque_overflow_evicts_oldest(self): + tracker = CacheRateTracker(max_snapshots=3, min_interval=0.0) + for i in range(5): + tracker.maybe_snapshot(_make_counters(prefix_hits=i)) + result = tracker.get_rates() + assert result["cumulative"]["prefix_hits"] == 4 + + +class TestCacheRateTrackerRates: + + def _tracker_with_two_snapshots(self, old_counters, new_counters, elapsed=60.0): + tracker = CacheRateTracker(min_interval=0.0) + fake_time = [1000.0] + + def mock_monotonic(): + return fake_time[0] + + with patch("omlx.cache.observability.time.monotonic", side_effect=mock_monotonic): + tracker.maybe_snapshot(old_counters) + + fake_time[0] = 1000.0 + elapsed + with patch("omlx.cache.observability.time.monotonic", side_effect=mock_monotonic): + tracker.maybe_snapshot(new_counters) + + with patch("omlx.cache.observability.time.monotonic", return_value=fake_time[0]): + return tracker.get_rates(windows=(60, 300, 900)) + + def test_steady_state_prefix_hit_rate(self): + old = _make_counters(prefix_hits=100, prefix_misses=50) + new = _make_counters(prefix_hits=200, 
prefix_misses=75) + result = self._tracker_with_two_snapshots(old, new, elapsed=60.0) + assert result["windows"]["1m"]["prefix_hit_rate"] == 0.8 + + def test_zero_activity_window_no_nan(self): + counters = _make_counters(prefix_hits=50, prefix_misses=10) + result = self._tracker_with_two_snapshots(counters, counters, elapsed=60.0) + assert result["windows"]["1m"]["prefix_hit_rate"] == 0.0 + assert result["windows"]["1m"]["prefix_match_efficiency"] == 0.0 + assert result["windows"]["1m"]["eviction_rate_per_min"] == 0.0 + + def test_eviction_rate_per_min(self): + old = _make_counters(evictions=10) + new = _make_counters(evictions=40) + result = self._tracker_with_two_snapshots(old, new, elapsed=300.0) + assert result["windows"]["5m"]["eviction_rate_per_min"] == 6.0 + + def test_prefix_match_efficiency(self): + old = _make_counters(prefix_tokens_matched=0, prefix_tokens_requested=0) + new = _make_counters(prefix_tokens_matched=600, prefix_tokens_requested=1000) + result = self._tracker_with_two_snapshots(old, new, elapsed=60.0) + assert result["windows"]["1m"]["prefix_match_efficiency"] == 0.6 + + def test_ssd_hot_rate(self): + old = _make_counters(ssd_hot_hits=0, ssd_disk_loads=0) + new = _make_counters(ssd_hot_hits=80, ssd_disk_loads=20) + result = self._tracker_with_two_snapshots(old, new, elapsed=60.0) + assert result["windows"]["1m"]["ssd_hot_rate"] == 0.8 + + def test_mru_partial_hit_rate(self): + """Stash payoff = hits / stashes. 
Workload that stashed 100 and + only got 75 hits has rate 0.75; high enough to justify the + feature, low enough that a smaller capacity might do.""" + old = _make_counters(mru_partial_stashes=0, mru_partial_hits=0) + new = _make_counters(mru_partial_stashes=100, mru_partial_hits=75) + result = self._tracker_with_two_snapshots(old, new, elapsed=60.0) + assert result["windows"]["1m"]["mru_partial_hit_rate"] == 0.75 + assert result["cumulative"]["mru_partial_hit_rate"] == 0.75 + + def test_mru_partial_zero_stashes_no_nan(self): + """If no stashes happened in the window, hit_rate must be 0.0 + not NaN. Mirrors the prefix_hit_rate empty-window guard.""" + counters = _make_counters(mru_partial_stashes=0, mru_partial_hits=0) + result = self._tracker_with_two_snapshots(counters, counters, elapsed=60.0) + assert result["windows"]["1m"]["mru_partial_hit_rate"] == 0.0 + assert result["cumulative"]["mru_partial_hit_rate"] == 0.0 + + def test_mru_partial_tokens_saved_delta(self): + """Tokens saved is the direct compute-saved measure; it + accumulates regardless of hit rate.""" + old = _make_counters(mru_partial_tokens_saved=0) + new = _make_counters(mru_partial_tokens_saved=12345) + result = self._tracker_with_two_snapshots(old, new, elapsed=60.0) + assert result["windows"]["1m"]["mru_partial_tokens_saved"] == 12345 + assert result["cumulative"]["mru_partial_tokens_saved"] == 12345 + + def test_insufficient_data_returns_empty_window(self): + tracker = CacheRateTracker(min_interval=0.0) + + with patch("omlx.cache.observability.time.monotonic", return_value=1000.0): + tracker.maybe_snapshot(_make_counters(prefix_hits=10)) + + with patch("omlx.cache.observability.time.monotonic", return_value=1000.5): + tracker.maybe_snapshot(_make_counters(prefix_hits=20)) + + with patch("omlx.cache.observability.time.monotonic", return_value=1000.5): + result = tracker.get_rates(windows=(60,)) + assert result["windows"]["1m"] == {} + + def test_cumulative_uses_latest_snapshot(self): + old = 
_make_counters(prefix_hits=10, prefix_misses=5) + new = _make_counters(prefix_hits=100, prefix_misses=20) + result = self._tracker_with_two_snapshots(old, new, elapsed=60.0) + assert result["cumulative"]["prefix_hits"] == 100 + assert result["cumulative"]["prefix_misses"] == 20 + assert abs(result["cumulative"]["prefix_hit_rate"] - 0.8333) < 0.001 + + +class TestCacheRateTrackerSnapshotAndGetRates: + + def test_combines_snapshot_and_rates(self): + tracker = CacheRateTracker(min_interval=0.0) + + with patch("omlx.cache.observability.time.monotonic", return_value=1000.0): + tracker.maybe_snapshot(_make_counters(prefix_hits=0)) + + with patch("omlx.cache.observability.time.monotonic", return_value=1060.0): + result = tracker.snapshot_and_get_rates( + _make_counters(prefix_hits=80, prefix_misses=20) + ) + + assert result["windows"]["1m"]["prefix_hit_rate"] == 0.8 + assert result["cumulative"]["prefix_hits"] == 80 + + +class TestCacheRateTrackerThreadSafety: + + def test_concurrent_snapshot_and_read(self): + tracker = CacheRateTracker(min_interval=0.0) + errors = [] + stop = threading.Event() + + def writer(): + i = 0 + while not stop.is_set(): + try: + tracker.maybe_snapshot(_make_counters(prefix_hits=i)) + i += 1 + except Exception as e: + errors.append(e) + + def reader(): + while not stop.is_set(): + try: + tracker.get_rates() + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=writer), threading.Thread(target=reader)] + for t in threads: + t.start() + time.sleep(0.2) + stop.set() + for t in threads: + t.join(timeout=2.0) + + assert errors == [], f"Thread errors: {errors}" + + +class TestCacheRateTrackerClear: + + def test_clear_resets_state(self): + tracker = CacheRateTracker(min_interval=0.0) + tracker.maybe_snapshot(_make_counters(prefix_hits=100)) + tracker.clear() + assert tracker.get_rates() == {"windows": {}, "cumulative": {}} diff --git a/tests/test_prefix_cache.py b/tests/test_prefix_cache.py index ded2395f5..42d4e5d19 100644 
--- a/tests/test_prefix_cache.py +++ b/tests/test_prefix_cache.py @@ -6,6 +6,7 @@ PagedCacheManager for block-based storage with SSD persistence. """ +import logging import time from pathlib import Path from typing import Any, Dict, List, Optional @@ -20,7 +21,11 @@ PagedCacheManager, compute_block_hash, ) -from omlx.cache.prefix_cache import BlockAwarePrefixCache, BlockCacheEntry +from omlx.cache.prefix_cache import ( + BlockAwarePrefixCache, + BlockCacheEntry, + _MRUPartialBlock, +) from omlx.cache.stats import PrefixCacheStats @@ -2413,3 +2418,1752 @@ def test_store_cache_last_block_with_snapshot_uses_snapshot_meta(self, mx): f"Last block should use snapshot offset=8, not shared offset=11, " f"got {b2_meta[1]}" ) + + +def _get_mru_partial(cache, parent_hash): + """Test-only accessor for one MRU partial entry by parent_hash key. + + Lives in the test module rather than on ``BlockAwarePrefixCache`` + itself: production code has no consumer that looks up entries by + arbitrary key (the scheduler only needs the boolean predicate via + ``has_mru_partial()`` and the dict lookup inside ``apply_mru_partial``). + Tests use this helper to assert on individual entries without + coupling to the internal ``_mru_partials`` container shape. + """ + return cache._mru_partials.get(parent_hash) + + +def _layer(mx, n_tokens, *, class_name="KVCache", head_dim=4, n_kv_heads=1, fill=1.0): + """Build a layer-state dict for store_cache. + + ``class_name`` selects the cache type (e.g. ``"KVCache"``, + ``"RotatingKVCache"``, ``"BatchRotatingKVCache"``). Both + ``cache_type`` and ``class_name`` keys are populated with the same + string — ``store_cache`` consults whichever the layer provides. 
+ """ + return { + "state": ( + mx.full((1, n_kv_heads, n_tokens, head_dim), fill), + mx.full((1, n_kv_heads, n_tokens, head_dim), fill), + ), + "cache_type": class_name, + "class_name": class_name, + } + + +def _kv_layer(mx, n_tokens, head_dim=4, n_kv_heads=1, fill=1.0): + return _layer( + mx, n_tokens, + class_name="KVCache", + head_dim=head_dim, n_kv_heads=n_kv_heads, fill=fill, + ) + + +def _rotating_layer(mx, n_tokens, head_dim=4, n_kv_heads=1): + return _layer( + mx, n_tokens, + class_name="RotatingKVCache", + head_dim=head_dim, n_kv_heads=n_kv_heads, + ) + + +def _make_reconstructed_cache(mx, n_layers, n_tokens, head_dim=4): + """Build a list of MockKVCache objects matching what reconstruct_cache + would produce: keys.shape[2] == offset, valid region only.""" + class MockKVCache: + def __init__(self, k, v, offset): + self.keys = k + self.values = v + self.offset = offset + + return [ + MockKVCache( + mx.ones((1, 1, n_tokens, head_dim)), + mx.ones((1, 1, n_tokens, head_dim)), + n_tokens, + ) + for _ in range(n_layers) + ] + + +def _make_mru_cache(paged_cache, mock_ssd, max_entries=4, num_layers=4): + """Construct a ``BlockAwarePrefixCache`` with a custom MRU capacity.""" + return BlockAwarePrefixCache( + model=MockModel(num_layers=num_layers), + paged_cache_manager=paged_cache, + paged_ssd_cache_manager=mock_ssd, + mru_partial_max_entries=max_entries, + ) + + +def _stash_with_prefix(cache, mx, prefix_marker, tail_token): + """Stash a partial under a distinct parent_hash for multi-slot tests. + + Builds a prompt whose first 4 tokens are unique to ``prefix_marker`` + (forcing a unique parent block hash) and whose 5th token is the + partial tail. Returns ``(block_table, parent_hash)``. 
+ """ + tokens = [prefix_marker * 10 + i for i in range(4)] + [tail_token] + cache_data = [_kv_layer(mx, 5) for _ in range(4)] + block_table = cache.store_cache(f"req-{prefix_marker}", tokens, cache_data) + parent_hash = cache.paged_cache.allocated_blocks[ + block_table.block_ids[-1] + ].block_hash + return block_table, parent_hash + + +class TestMRUPartialBlockCache: + """Tests for the MRU partial block cache. + + The MRU is a bounded LRU dict of trailing sub-block tails keyed by + ``parent_hash``. It lets exact-repeat requests skip re-prefilling + those tail tokens, and tolerates interleaving (multi-user / multi- + conversation workloads) by keeping multiple distinct-prefix entries + coexistent up to ``mru_partial_max_entries``. + + Threat-model coverage these tests enforce: + + - **Hybrid refusal:** when any layer is non-sliceable (RotatingKVCache, + ArraysCache, etc.), the stash is suppressed entirely. Splicing into + only the sliceable layers would create per-layer offset skew at + decode time. + - **Transactional splice:** if any layer's concatenate fails, no layer + sees a mutated keys/values/offset. Half-mutated caches are silent + generation corruption. + - **Real round-trip:** ``store_cache`` populates entries via the + production extraction path; ``apply_mru_partial`` then splices. + Tests do not hand-build ``_MRUPartialBlock`` objects for splice + cases — that hides the extraction-vs-apply boundary the original + single-slot branch's tests missed. + """ + + @pytest.fixture + def mx(self): + try: + import mlx.core as mx + return mx + except ImportError: + pytest.skip("MLX not available") + + @pytest.fixture + def paged_cache(self): + return PagedCacheManager( + block_size=4, + max_blocks=100, + model_name="test-model", + initial_blocks=100, + ) + + @pytest.fixture + def mock_ssd(self): + """An SSD manager mock — present, not used. 
+ + The MRU stash gates on ``paged_ssd_cache is not None`` because in + the no-SSD configuration ``reconstruct_cache`` returns ``None`` and + ``apply_mru_partial`` is unreachable; stashing then would only + produce dead memory. Tests that exercise stash/apply directly + need an SSD instance present even though the mocked save/load + paths are not exercised. + """ + mock = MagicMock() + mock.save_block.return_value = True + mock.load_block.return_value = None + mock.load_block_with_metadata.return_value = (None, None) + mock.has_block.return_value = False + return mock + + @pytest.fixture + def prefix_cache(self, paged_cache, mock_ssd): + return BlockAwarePrefixCache( + model=MockModel(num_layers=4), + paged_cache_manager=paged_cache, + paged_ssd_cache_manager=mock_ssd, + ) + + # --- initial state --- + + def test_init_state_empty(self, prefix_cache): + assert not prefix_cache._mru_partials + assert prefix_cache.has_mru_partial() is False + + # --- stash semantics on uniformly sliceable layers --- + + def test_stash_after_store_with_trailing_tokens(self, prefix_cache, mx): + """6 tokens, block_size=4 → 1 full block + 2 trailing → stash captured.""" + tokens = [10, 20, 30, 40, 50, 60] + cache_data = [_kv_layer(mx, 6) for _ in range(4)] + + block_table = prefix_cache.store_cache("req-stash", tokens, cache_data) + + parent_hash = prefix_cache.paged_cache.allocated_blocks[ + block_table.block_ids[-1] + ].block_hash + partial = _get_mru_partial(prefix_cache, parent_hash) + assert partial is not None + assert partial.tokens == [50, 60] + assert len(partial.kv_data) == 4 + assert prefix_cache.has_mru_partial() is True + + def test_no_stash_when_block_aligned(self, prefix_cache, mx): + """Block-aligned tokens leave no trailing partial → no entry written.""" + tokens = [10, 20, 30, 40] + cache_data = [_kv_layer(mx, 4) for _ in range(4)] + + prefix_cache.store_cache("req-aligned", tokens, cache_data) + + assert not prefix_cache._mru_partials + assert 
prefix_cache.has_mru_partial() is False + + def test_same_prefix_store_replaces_entry(self, prefix_cache, mx): + """Same prefix → same parent_hash → same dict key → replace. + + Two stores with identical prefix tokens but different tails + collide on the same key (parent_hash chains from identical + prefix blocks). The newer tail replaces the older one in the + single dict entry — that is correct LRU put behavior. + """ + for tail in (50, 99): + tokens = [10, 20, 30, 40, tail] + cache_data = [_kv_layer(mx, 5) for _ in range(4)] + prefix_cache.store_cache(f"req-{tail}", tokens, cache_data) + + # Exactly one entry; its tokens are the latest tail. + assert len(prefix_cache._mru_partials) == 1 + partial = next(iter(prefix_cache._mru_partials.values())) + assert partial.tokens == [99] + + def test_no_eligible_tail_does_not_evict_siblings( + self, prefix_cache, mx + ): + """Behavioral change vs single-slot: a block-aligned store (no + trailing tail) MUST NOT wipe sibling entries from other prefixes. + + Single-slot mode used to clear the lone slot in this branch. + Multi-slot mode treats "nothing eligible to stash this time" + as a local signal — sibling entries for distinct prefixes are + unrelated and stay. + """ + # First: stash a partial via prefix A. + prefix_cache.store_cache( + "req-a", [10, 20, 30, 40, 50], [_kv_layer(mx, 5) for _ in range(4)] + ) + assert len(prefix_cache._mru_partials) == 1 + before_key = next(iter(prefix_cache._mru_partials.keys())) + + # Second: block-aligned store on a DIFFERENT prefix — no tail to + # stash, but must not evict the existing sibling. 
+ prefix_cache.store_cache( + "req-b", [11, 22, 33, 44], [_kv_layer(mx, 4) for _ in range(4)] + ) + assert len(prefix_cache._mru_partials) == 1 + assert next(iter(prefix_cache._mru_partials.keys())) == before_key + + def test_stash_records_parent_hash_from_last_block(self, prefix_cache, mx): + """Stashed entry is keyed by the hash of the last full block.""" + tokens = [10, 20, 30, 40, 50, 60] + cache_data = [_kv_layer(mx, 6) for _ in range(4)] + + block_table = prefix_cache.store_cache("req-hash", tokens, cache_data) + + # The last (and only) block's hash should be the dict key AND + # the partial's stored parent_hash. + last_block = prefix_cache.paged_cache.allocated_blocks[ + block_table.block_ids[-1] + ] + assert last_block.block_hash is not None + partial = _get_mru_partial(prefix_cache, last_block.block_hash) + assert partial is not None + assert partial.parent_hash == last_block.block_hash + + # --- threat model: hybrid refusal (B1, B2) --- + + def test_refuse_stash_when_any_layer_non_sliceable_hybrid( + self, paged_cache, mock_ssd, mx + ): + """Hybrid model (KVCache + RotatingKVCache): no stash. + + Splicing into only the sliceable layers produces per-layer offset + skew at decode time (review B2). The only correct behavior is to + refuse the partial entirely for hybrid models. 
+ """ + from omlx.cache.hybrid_cache import ModelCacheConfig + + cache = BlockAwarePrefixCache( + model=MockModel(num_layers=2), + paged_cache_manager=paged_cache, + paged_ssd_cache_manager=mock_ssd, + ) + tokens = [10, 20, 30, 40, 50, 60] + cache_data = [ + _kv_layer(mx, 6), + _rotating_layer(mx, 6), + ] + config = ModelCacheConfig.from_type_list( + ["KVCache", "RotatingKVCache"], model_name="test" + ) + + cache.store_cache("req-hybrid", tokens, cache_data, model_cache_config=config) + + assert not cache._mru_partials + assert cache.has_mru_partial() is False + + def test_refuse_stash_when_all_layers_non_sliceable(self, prefix_cache, mx): + """Pure RotatingKVCache model also refuses stash.""" + from omlx.cache.hybrid_cache import ModelCacheConfig + + tokens = [10, 20, 30, 40, 50] + cache_data = [_rotating_layer(mx, 5) for _ in range(4)] + config = ModelCacheConfig.from_type_list( + ["RotatingKVCache"] * 4, model_name="test" + ) + + prefix_cache.store_cache( + "req-rotating", tokens, cache_data, model_cache_config=config + ) + + assert not prefix_cache._mru_partials + + def test_refuse_stash_when_layer_falls_through_to_default_handler( + self, prefix_cache, mx + ): + """Non-sliceable types whose handler is unregistered (fall through + to ``DefaultCacheHandler``, which inherits ``KVCacheHandler``'s + ``supports_block_slicing=True``) must still be refused. + + Concrete case: ``BatchRotatingKVCache`` is mapped in + ``_class_name_map`` to ``BATCH_ROTATING_KVCACHE`` but no handler + is registered for that enum. The original rewrite's + registry-based gate would have classified it as sliceable, + recreating exactly the silent-corruption hazard the rewrite was + supposed to close, just from a different angle. The fix uses an + explicit class-name whitelist (``KNOWN_SLICEABLE_CACHE_TYPES``) + instead of the registry. 
+ """ + from omlx.cache.hybrid_cache import ModelCacheConfig + from omlx.cache.type_registry import ( + KNOWN_SLICEABLE_CACHE_TYPES, + CacheTypeRegistry, + ) + + # Sanity: the registry would lie about this class name. + handler = CacheTypeRegistry.get_handler_by_class_name("BatchRotatingKVCache") + assert handler.supports_block_slicing is True + # And the whitelist correctly excludes it. + assert "BatchRotatingKVCache" not in KNOWN_SLICEABLE_CACHE_TYPES + + tokens = [10, 20, 30, 40, 50, 60] + # cache_data shape doesn't matter — store_cache must refuse before + # any extraction is attempted. + cache_data = [_kv_layer(mx, 6) for _ in range(4)] + config = ModelCacheConfig.from_type_list( + ["BatchRotatingKVCache"] * 4, model_name="test" + ) + + prefix_cache.store_cache( + "req-batch-rotating", tokens, cache_data, model_cache_config=config + ) + + assert not prefix_cache._mru_partials + assert prefix_cache.has_mru_partial() is False + + # --- threat model: stale-slot eviction at clear() (C2) --- + + def test_clear_wipes_mru_partials(self, prefix_cache, mx): + """``BlockAwarePrefixCache.clear()`` must drop the entire MRU dict. + + The scheduler's cache-corruption recovery routes through + ``clear()``. Surviving partials chain from paged-block hashes + whose backing blocks were just freed; the dict is wiped so no + entry can survive into the recovery path that exists because + something was wrong. + """ + prefix_cache.store_cache( + "req-clear", + [10, 20, 30, 40, 50, 60], + [_kv_layer(mx, 6) for _ in range(4)], + ) + assert bool(prefix_cache._mru_partials) + + prefix_cache.clear() + + assert not prefix_cache._mru_partials + assert prefix_cache.has_mru_partial() is False + + # --- threat model: H2 ambiguous cache layout --- + + def test_refuse_stash_on_ambiguous_cache_layout( + self, prefix_cache, mx + ): + """Cache lengths that don't unambiguously map to global or local + indexing must refuse the stash. 
+ + Multi-turn requests can produce ``cache_seq_len == + existing_tokens`` or shapes between local and global. The + previous heuristic (``cache_seq_len >= existing_tokens + 1``) + silently picked "local" on the boundary, slicing local indices + out of a global-indexed cache and capturing tokens from the + prefix instead of the trailing tail. parent_hash still matched, + and a future apply spliced wrong KV — silent generation + corruption. + + Drive that boundary directly: cache_seq_len falls strictly + between global_end and local_len. + """ + # First turn: cache 4 tokens. + prefix_cache.store_cache( + "req-turn-1", + [1, 2, 3, 4], + [_kv_layer(mx, 4) for _ in range(4)], + ) + + # Second turn: 8 prefix-aligned tokens (1 full block + 1 partial-block). + # Hand a cache_data whose cache_seq_len is 6 — strictly between: + # - local_len = len(new_tokens) = len(tokens) - existing_tokens = 4 + # - global_end = existing_tokens + new_count = 4 + 4 = 8 + # global_end (8) > cache_seq_len (6) > local_len (4): ambiguous. + full_tokens = [1, 2, 3, 4, 5, 6, 7, 8] + cache_data = [_kv_layer(mx, 6) for _ in range(4)] + + prefix_cache.store_cache( + "req-turn-2-ambiguous", full_tokens, cache_data + ) + + # Refuse rather than guess. Stash must be the previous turn's + # state cleared (block-aligned turn 1 has no stash anyway), not + # a guessed-wrong turn 2. + assert not prefix_cache._mru_partials + + # --- accounting invariant --- + + def test_kv_data_holds_mlx_arrays_for_active_memory_accounting( + self, prefix_cache, mx + ): + """The MRU slot's memory must flow through ``mx.get_active_memory()``. + + The codebase enforces all KV-memory limits via ``mx.get_active_memory()`` + (process_memory_enforcer, the three scheduler memory checkpoints, + the periodic-clear threshold, telemetry). The MRU slot has no + separate accounting hook — it relies on the invariant that + ``kv_data`` holds real ``mx.array`` allocations, which MLX counts + in active memory automatically. 
+ + A "helpful" future change that stored CPU-side copies (e.g. + ``np.ndarray`` to dodge a perceived GPU-memory cost) would silently + escape every existing memory limit and only manifest as system OOM + under load. Pin the invariant so that change is caught at test + time, not in production. + """ + block_table = prefix_cache.store_cache( + "req-accounting", + [10, 20, 30, 40, 50, 60], + [_kv_layer(mx, 6) for _ in range(4)], + ) + + parent_hash = prefix_cache.paged_cache.allocated_blocks[ + block_table.block_ids[-1] + ].block_hash + partial = _get_mru_partial(prefix_cache, parent_hash) + assert partial is not None + assert len(partial.kv_data) == 4 + for layer_idx, (keys, values) in enumerate(partial.kv_data): + assert isinstance(keys, mx.array), ( + f"layer {layer_idx} keys is {type(keys).__name__}, not mx.array. " + f"MRU memory accounting depends on mx.array storage so the " + f"slot is visible to mx.get_active_memory()." + ) + assert isinstance(values, mx.array), ( + f"layer {layer_idx} values is {type(values).__name__}, " + f"not mx.array. See above." + ) + + # --- threat model: no-reconstruct-path config --- + + def test_no_stash_when_paged_ssd_cache_is_none(self, paged_cache, mx): + """Without a ``PagedSSDCacheManager`` instance, ``reconstruct_cache`` + returns ``None`` (``_can_reconstruct() is False``) and + ``apply_mru_partial`` is unreachable from the scheduler. Stashing + in this configuration would only produce dead memory. + + Note: this is distinct from ``hot_cache_only=True``, where the + manager IS present (the disk writer thread is what's disabled, + not the manager itself). In that mode the MRU stash IS expected + to populate — ``load_block_with_metadata`` short-circuits to the + hot tier and reconstruct still works. The gate keys on manager + presence, not on whether SSD writes are happening. 
+ """ + cache = BlockAwarePrefixCache( + model=MockModel(num_layers=4), + paged_cache_manager=paged_cache, + paged_ssd_cache_manager=None, + ) + + cache.store_cache( + "req-no-ssd", + [10, 20, 30, 40, 50, 60], + [_kv_layer(mx, 6) for _ in range(4)], + ) + + assert not cache._mru_partials + assert cache.has_mru_partial() is False + + def test_can_reconstruct_helper_reflects_manager_presence( + self, paged_cache, mock_ssd + ): + """``_can_reconstruct`` is the canonical predicate keeping the + MRU stash gate and the ``reconstruct_cache`` guard in lockstep. + + It returns False only when no manager is configured at all. + ``hot_cache_only=True`` configurations (manager present, disk + writer disabled) return True because reconstruct still works + via the hot-tier short-circuit in + ``PagedSSDCacheManager.load_block_with_metadata``. + """ + cache_with = BlockAwarePrefixCache( + model=MockModel(num_layers=2), + paged_cache_manager=paged_cache, + paged_ssd_cache_manager=mock_ssd, + ) + assert cache_with._can_reconstruct() is True + + cache_without = BlockAwarePrefixCache( + model=MockModel(num_layers=2), + paged_cache_manager=paged_cache, + paged_ssd_cache_manager=None, + ) + assert cache_without._can_reconstruct() is False + + # --- apply: real round-trip --- + + def test_apply_round_trip_exact_match(self, prefix_cache, mx): + """Real store → apply round-trip: partial produced by extraction + is consumed by the splice path, no hand-built _MRUPartialBlock.""" + tokens = [10, 20, 30, 40, 50, 60] + cache_data = [_kv_layer(mx, 6) for _ in range(4)] + block_table = prefix_cache.store_cache("req-rt", tokens, cache_data) + + # Reconstructed cache: 4 layers × 4 tokens (the prefix only). 
+ reconstructed = _make_reconstructed_cache(mx, n_layers=4, n_tokens=4) + remaining = [50, 60] + + result, new_remaining, applied = prefix_cache.apply_mru_partial( + reconstructed, block_table, remaining, + ) + + assert applied == 2 + assert new_remaining == [] + assert all(layer.offset == 6 for layer in result) + assert all(layer.keys.shape[2] == 6 for layer in result) + assert all(layer.values.shape[2] == 6 for layer in result) + + def test_apply_round_trip_prefix_match_leaves_extra_tokens( + self, prefix_cache, mx + ): + """When remaining is longer than the partial, the partial covers + its prefix and the rest is left for normal prefill.""" + tokens = [10, 20, 30, 40, 50, 60] + cache_data = [_kv_layer(mx, 6) for _ in range(4)] + block_table = prefix_cache.store_cache("req-rt-prefix", tokens, cache_data) + + reconstructed = _make_reconstructed_cache(mx, n_layers=4, n_tokens=4) + remaining = [50, 60, 70, 80] # partial is [50, 60]; [70, 80] left over + + _, new_remaining, applied = prefix_cache.apply_mru_partial( + reconstructed, block_table, remaining, + ) + + assert applied == 2 + assert new_remaining == [70, 80] + + # --- apply: eviction reasons --- + + def test_apply_noop_on_parent_hash_mismatch_preserves_sibling( + self, prefix_cache, mx, paged_cache + ): + """A request keyed by a parent_hash that isn't in the dict + returns no-op WITHOUT evicting unrelated sibling entries. + + Behavioral change from single-slot: the single slot used to be + evicted whenever the lookup key didn't match. In multi-slot, + the lookup simply misses and other entries are preserved. + """ + # Stash a partial under prefix A. 
+ tokens = [10, 20, 30, 40, 50, 60] + block_table_a = prefix_cache.store_cache( + "req-a", tokens, [_kv_layer(mx, 6) for _ in range(4)] + ) + before = dict(prefix_cache._mru_partials) + assert len(before) == 1 + + # Construct a synthetic block_table pointing at a block whose + # hash is NOT a key in the dict (simulate "request for a + # different prefix that has its own paged block"). + other_block = paged_cache.allocate_block() + other_block.block_hash = b"\x00" * 32 # not in the MRU dict + synthetic_bt = BlockTable(request_id="req-other") + synthetic_bt.block_ids.append(other_block.block_id) + synthetic_bt.num_tokens = 4 + + reconstructed = _make_reconstructed_cache(mx, n_layers=4, n_tokens=4) + _, new_remaining, applied = prefix_cache.apply_mru_partial( + reconstructed, synthetic_bt, [50, 60], + ) + + assert applied == 0 + assert new_remaining == [50, 60] + # Prefix A's entry must still be present — no false eviction. + assert dict(prefix_cache._mru_partials) == before + + def test_apply_evicts_on_token_mismatch(self, prefix_cache, mx): + """Different trailing tokens → partial cannot apply, evict.""" + tokens = [10, 20, 30, 40, 50, 60] + cache_data = [_kv_layer(mx, 6) for _ in range(4)] + block_table = prefix_cache.store_cache("req-evict-t", tokens, cache_data) + + reconstructed = _make_reconstructed_cache(mx, n_layers=4, n_tokens=4) + _, new_remaining, applied = prefix_cache.apply_mru_partial( + reconstructed, block_table, [99, 60], # first token doesn't match + ) + + assert applied == 0 + assert new_remaining == [99, 60] + assert not prefix_cache._mru_partials + + def test_apply_evicts_on_remaining_shorter_than_partial( + self, prefix_cache, mx + ): + """If remaining_tokens is shorter than the partial it cannot match.""" + tokens = [10, 20, 30, 40, 50, 60, 70] + cache_data = [_kv_layer(mx, 7) for _ in range(4)] + block_table = prefix_cache.store_cache("req-evict-s", tokens, cache_data) + # Partial is [50, 60, 70]; remaining is shorter → must evict. 
+ + reconstructed = _make_reconstructed_cache(mx, n_layers=4, n_tokens=4) + _, new_remaining, applied = prefix_cache.apply_mru_partial( + reconstructed, block_table, [50, 60], + ) + + assert applied == 0 + assert not prefix_cache._mru_partials + + def test_apply_evicts_on_layer_count_mismatch(self, prefix_cache, mx): + """If the reconstructed cache layer count differs from the + stashed partial, evict — likely a model swap or bug, not safe.""" + tokens = [10, 20, 30, 40, 50, 60] + cache_data = [_kv_layer(mx, 6) for _ in range(4)] + block_table = prefix_cache.store_cache("req-evict-lc", tokens, cache_data) + + # Reconstructed has only 2 layers, partial has 4 → mismatch. + reconstructed = _make_reconstructed_cache(mx, n_layers=2, n_tokens=4) + _, _, applied = prefix_cache.apply_mru_partial( + reconstructed, block_table, [50, 60], + ) + + assert applied == 0 + assert not prefix_cache._mru_partials + + def test_apply_noop_when_no_stash(self, prefix_cache, paged_cache, mx): + """No partial → no-op, remaining unchanged.""" + block_table = paged_cache.create_block_table("req-noop") + reconstructed = _make_reconstructed_cache(mx, n_layers=4, n_tokens=0) + + result, remaining, applied = prefix_cache.apply_mru_partial( + reconstructed, block_table, [10, 20], + ) + + assert applied == 0 + assert remaining == [10, 20] + assert result is reconstructed + + def test_apply_noop_when_remaining_empty(self, prefix_cache, mx): + """Empty remaining → exact prefix hit already; no MRU work.""" + tokens = [10, 20, 30, 40, 50, 60] + cache_data = [_kv_layer(mx, 6) for _ in range(4)] + block_table = prefix_cache.store_cache("req-noop-empty", tokens, cache_data) + + reconstructed = _make_reconstructed_cache(mx, n_layers=4, n_tokens=4) + _, _, applied = prefix_cache.apply_mru_partial( + reconstructed, block_table, [], + ) + + assert applied == 0 + # Stash must NOT be evicted on empty-remaining no-op — it could + # still match a *future* request that does have tail tokens. 
+ assert bool(prefix_cache._mru_partials) + + # --- threat model: transactional splice rollback (B3) --- + + def test_splice_failure_does_not_mutate_any_layer( + self, prefix_cache, mx + ): + """If any layer's concatenate fails, NO layer is mutated. + + Review B3: the original implementation's try/except wrapped the + whole loop, so failure on layer N>0 left layers 0..N-1 mutated + with cache.offset += n_partial while the caller was told nothing + was applied. The rewrite must build replacements first and commit + atomically. + """ + tokens = [10, 20, 30, 40, 50, 60] + cache_data = [_kv_layer(mx, 6) for _ in range(4)] + block_table = prefix_cache.store_cache("req-rollback", tokens, cache_data) + + reconstructed = _make_reconstructed_cache(mx, n_layers=4, n_tokens=4) + # Snapshot the pre-splice state of every layer for an after-comparison. + before_offsets = [layer.offset for layer in reconstructed] + before_key_shapes = [layer.keys.shape for layer in reconstructed] + + # Make mx.concatenate explode on the third call (layer 1's keys). + # Calls go: layer0 keys, layer0 values, layer1 keys (boom). + from omlx.cache import prefix_cache as pc_mod + real_concatenate = pc_mod.mx.concatenate + call_count = {"n": 0} + + def flaky_concatenate(*args, **kwargs): + call_count["n"] += 1 + if call_count["n"] == 3: + raise RuntimeError("synthetic concatenate failure") + return real_concatenate(*args, **kwargs) + + with patch.object(pc_mod.mx, "concatenate", side_effect=flaky_concatenate): + _, new_remaining, applied = prefix_cache.apply_mru_partial( + reconstructed, block_table, [50, 60], + ) + + assert applied == 0 + assert new_remaining == [50, 60] + # No layer's offset advanced. + assert [layer.offset for layer in reconstructed] == before_offsets + # No layer's keys shape changed. + assert [layer.keys.shape for layer in reconstructed] == before_key_shapes + # Slot is evicted on splice failure (don't retry a failing partial). 
+ assert not prefix_cache._mru_partials + + # --- threat model: multi-turn (existing_tokens > 0) --- + + def test_stash_correct_indices_when_existing_tokens_present( + self, prefix_cache, mx + ): + """When store_cache is called with existing_tokens > 0 (multi-turn), + the stash slices the partial from the correct cache region. + + cache_data is full-sequence (system prompt + new turn), so the + partial extraction must use global indices, not relative ones. + """ + # Pretend a previous turn already cached 4 tokens. + prev_tokens = [1, 2, 3, 4] + prev_cache = [_kv_layer(mx, 4) for _ in range(4)] + prefix_cache.store_cache("req-turn-1", prev_tokens, prev_cache) + + # Second turn: 4 prev + 4 new = 8 prefix block tokens, then 2 trailing. + # Distinct fill values let us verify the stash sliced the *right* + # tokens — the partial should contain the trailing region's data + # (fill=2.0), not the prefix region (fill=1.0). + full_tokens = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + full_cache = [] + for _ in range(4): + # First 8 positions = old (1.0), last 2 positions = new (2.0) + keys = mx.concatenate( + [ + mx.full((1, 1, 8, 4), 1.0), + mx.full((1, 1, 2, 4), 2.0), + ], + axis=2, + ) + values = mx.concatenate( + [ + mx.full((1, 1, 8, 4), 1.0), + mx.full((1, 1, 2, 4), 2.0), + ], + axis=2, + ) + full_cache.append({ + "state": (keys, values), + "cache_type": "KVCache", + "class_name": "KVCache", + }) + + block_table = prefix_cache.store_cache( + "req-turn-2", full_tokens, full_cache + ) + + parent_hash = prefix_cache.paged_cache.allocated_blocks[ + block_table.block_ids[-1] + ].block_hash + partial = _get_mru_partial(prefix_cache, parent_hash) + assert partial is not None + assert partial.tokens == [9, 10] + # Each layer's stashed slice must be the trailing region (fill=2.0). 
+ for keys, values in partial.kv_data: + assert keys.shape[2] == 2 + assert mx.allclose(keys, mx.full((1, 1, 2, 4), 2.0)) + assert mx.allclose(values, mx.full((1, 1, 2, 4), 2.0)) + + +class TestMRUPartialMultiSlot: + """Multi-slot LRU semantics: coexistence, capacity, eviction discipline. + + These tests cover the mechanics that single-slot mode could not + exercise — multiple entries keyed by distinct ``parent_hash`` values, + LRU promotion on apply success, sibling preservation on apply miss, + capacity-bounded eviction, ``max_entries=0`` feature disable, and + the freed-paged-block guard introduced by the multi-slot design. + """ + + @pytest.fixture + def mx(self): + try: + import mlx.core as mx + return mx + except ImportError: + pytest.skip("MLX not available") + + @pytest.fixture + def paged_cache(self): + return PagedCacheManager( + block_size=4, + max_blocks=100, + model_name="test-model", + initial_blocks=100, + ) + + @pytest.fixture + def mock_ssd(self): + mock = MagicMock() + mock.save_block.return_value = True + mock.load_block.return_value = None + mock.load_block_with_metadata.return_value = (None, None) + mock.has_block.return_value = False + return mock + + # --- multi-entry coexistence --- + + def test_distinct_prefixes_coexist_as_separate_entries( + self, paged_cache, mock_ssd, mx + ): + """Two stashes with different parent prefixes produce two entries.""" + cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4) + _, hash_a = _stash_with_prefix(cache, mx, prefix_marker=1, tail_token=99) + _, hash_b = _stash_with_prefix(cache, mx, prefix_marker=2, tail_token=88) + + assert hash_a != hash_b + assert len(cache._mru_partials) == 2 + assert hash_a in cache._mru_partials + assert hash_b in cache._mru_partials + + # --- LRU mechanics (parameterized) --- + + @pytest.mark.parametrize( + "scenario", + [ + # (capacity, n_stashes_in_order, expected_dict_keys_in_order) + # Capacity respected; oldest evicted on overflow. 
+ ("evict_oldest_at_capacity", 2, [1, 2, 3], [2, 3]), + # Below capacity, all retained, insertion order preserved. + ("under_capacity_keeps_all", 4, [1, 2, 3], [1, 2, 3]), + ], + ids=lambda s: s[0] if isinstance(s, tuple) else str(s), + ) + def test_lru_capacity_bounds( + self, paged_cache, mock_ssd, mx, scenario + ): + _, capacity, order, expected = scenario + cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=capacity) + hashes = {} + for marker in order: + _, h = _stash_with_prefix( + cache, mx, prefix_marker=marker, tail_token=900 + marker + ) + hashes[marker] = h + + expected_keys = [hashes[m] for m in expected] + assert list(cache._mru_partials.keys()) == expected_keys + + def test_apply_success_promotes_entry_to_lru_tail( + self, paged_cache, mock_ssd, mx + ): + """Applying an entry moves it to the LRU tail; a subsequent + capacity-eviction drops a now-older sibling, not the just-used + entry.""" + cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=2) + bt_a, hash_a = _stash_with_prefix( + cache, mx, prefix_marker=1, tail_token=901 + ) + _, hash_b = _stash_with_prefix( + cache, mx, prefix_marker=2, tail_token=902 + ) + assert list(cache._mru_partials.keys()) == [hash_a, hash_b] + + # Apply A → A promoted to tail. + reconstructed = _make_reconstructed_cache(mx, n_layers=4, n_tokens=4) + # Remaining must equal A's stashed tokens. Tail token was 901, + # placed at index 4 of A's prompt. + _, _, applied = cache.apply_mru_partial(reconstructed, bt_a, [901]) + assert applied == 1 + assert list(cache._mru_partials.keys()) == [hash_b, hash_a] + + # Stash C at capacity 2 → B evicted (oldest after promote), A kept. 
+ _, hash_c = _stash_with_prefix( + cache, mx, prefix_marker=3, tail_token=903 + ) + assert list(cache._mru_partials.keys()) == [hash_a, hash_c] + assert hash_b not in cache._mru_partials + + # --- max_entries=0 disables --- + + def test_max_entries_zero_disables_stashing( + self, paged_cache, mock_ssd, mx + ): + cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=0) + _stash_with_prefix(cache, mx, prefix_marker=1, tail_token=99) + + assert len(cache._mru_partials) == 0 + assert cache.has_mru_partial() is False + + # --- clear_mru_partials() leaves siblings alone --- + + def test_clear_mru_partials_wipes_only_partials( + self, paged_cache, mock_ssd, mx + ): + """``clear_mru_partials()`` is the admin-clear hook. It wipes the + MRU dict but must not touch ``paged_cache``, the prefix index, + or stats — those have their own clear paths.""" + cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4) + bt, _ = _stash_with_prefix(cache, mx, prefix_marker=1, tail_token=99) + + prefix_index_before = dict(cache._prefix_index) + request_tables_before = dict(cache._request_tables) + assert len(cache._mru_partials) == 1 + assert prefix_index_before # was populated by store_cache + + n_wiped = cache.clear_mru_partials() + + assert n_wiped == 1 + assert len(cache._mru_partials) == 0 + # Paged blocks, prefix index, request tables all unchanged. + assert cache._prefix_index == prefix_index_before + assert cache._request_tables == request_tables_before + assert bt.block_ids[-1] in cache.paged_cache.allocated_blocks + + # --- freed-block guard (new in multi-slot) --- + + def test_apply_noop_when_parent_block_freed( + self, paged_cache, mock_ssd, mx + ): + """If the parent paged block is freed between stash and apply, + the apply path must not fall through to a None-keyed lookup + (which could falsely match a short-prompt entry). + + This race is new in multi-slot: single-slot tolerated it because + there was only ever one slot to match against. 
+ """ + cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4) + # Stash a short-prompt entry (parent_hash=None) — the false-match + # bait for the freed-block scenario. + short_tokens = [99, 100, 101] # < block_size=4 + cache.store_cache( + "req-short", short_tokens, [_kv_layer(mx, 3) for _ in range(4)] + ) + # Confirm the short-prompt entry landed under None. + assert None in cache._mru_partials + + # Construct a block_table whose last block has been freed. + freed_block = paged_cache.allocate_block() + freed_block_id = freed_block.block_id + paged_cache.free_block(freed_block_id) + bt = BlockTable(request_id="req-freed") + bt.block_ids.append(freed_block_id) + bt.num_tokens = 4 + + reconstructed = _make_reconstructed_cache(mx, n_layers=4, n_tokens=4) + _, new_remaining, applied = cache.apply_mru_partial( + reconstructed, bt, [99, 100, 101] + ) + + # Must NOT splice the short-prompt entry even though the + # remaining tokens happen to match. + assert applied == 0 + assert new_remaining == [99, 100, 101] + # Short-prompt entry preserved (not falsely evicted by the guard). + assert None in cache._mru_partials + + # --- short-prompt None-key coexists with hash-keyed entry --- + + def test_short_prompt_none_key_coexists_with_block_aligned_entry( + self, paged_cache, mock_ssd, mx + ): + cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4) + # Short prompt (< block_size) → parent_hash=None + cache.store_cache( + "req-short", [10, 20, 30], + [_kv_layer(mx, 3) for _ in range(4)], + ) + # Longer prompt → distinct parent_hash + _, hash_long = _stash_with_prefix( + cache, mx, prefix_marker=1, tail_token=99 + ) + + assert None in cache._mru_partials + assert hash_long in cache._mru_partials + assert len(cache._mru_partials) == 2 + + +class TestMRUPartialCounters: + """The observability counters mirror PR #1183's pattern so operators + can answer "is the MRU cache paying off" with the same dashboard + surface they use for prefix-hit and memory-hit rates. 
+ + Counters: ``mru_partial_stashes``, ``mru_partial_hits``, + ``mru_partial_evictions``, ``mru_partial_tokens_saved``. + Gauges: ``mru_partial_entries``, ``mru_partial_max_entries``. + """ + + @pytest.fixture + def mx(self): + try: + import mlx.core as mx + return mx + except ImportError: + pytest.skip("MLX not available") + + @pytest.fixture + def paged_cache(self): + return PagedCacheManager( + block_size=4, + max_blocks=100, + model_name="test-model", + initial_blocks=100, + ) + + @pytest.fixture + def mock_ssd(self): + mock = MagicMock() + mock.save_block.return_value = True + mock.load_block.return_value = None + mock.load_block_with_metadata.return_value = (None, None) + mock.has_block.return_value = False + return mock + + def test_initial_counters_are_zero(self, paged_cache, mock_ssd): + cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4) + stats = cache.get_stats() + assert stats.mru_partial_stashes == 0 + assert stats.mru_partial_hits == 0 + assert stats.mru_partial_evictions == 0 + assert stats.mru_partial_tokens_saved == 0 + assert stats.mru_partial_entries == 0 + assert stats.mru_partial_max_entries == 4 + + def test_stash_increments_stash_counter(self, paged_cache, mock_ssd, mx): + cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4) + _stash_with_prefix(cache, mx, prefix_marker=1, tail_token=99) + _stash_with_prefix(cache, mx, prefix_marker=2, tail_token=88) + + stats = cache.get_stats() + assert stats.mru_partial_stashes == 2 + assert stats.mru_partial_entries == 2 + + def test_same_key_replacement_counts_as_stash_not_eviction( + self, paged_cache, mock_ssd, mx + ): + """Replacing an existing entry under the same key counts as a + stash but NOT as an eviction. 
Eviction is reserved for entries + that leave the dict (capacity overflow, apply-miss, clear).""" + cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4) + _stash_with_prefix(cache, mx, prefix_marker=1, tail_token=99) + _stash_with_prefix(cache, mx, prefix_marker=1, tail_token=77) + + stats = cache.get_stats() + assert stats.mru_partial_stashes == 2 + assert stats.mru_partial_evictions == 0 + assert stats.mru_partial_entries == 1 + + def test_capacity_overflow_increments_eviction_counter( + self, paged_cache, mock_ssd, mx + ): + cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=2) + for i in (1, 2, 3): + _stash_with_prefix(cache, mx, prefix_marker=i, tail_token=100 + i) + + stats = cache.get_stats() + assert stats.mru_partial_stashes == 3 + assert stats.mru_partial_evictions == 1 # one entry pushed out + assert stats.mru_partial_entries == 2 + + def test_apply_success_increments_hits_and_tokens_saved( + self, paged_cache, mock_ssd, mx + ): + cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4) + bt, _ = _stash_with_prefix( + cache, mx, prefix_marker=1, tail_token=901 + ) + + reconstructed = _make_reconstructed_cache(mx, n_layers=4, n_tokens=4) + _, _, applied = cache.apply_mru_partial(reconstructed, bt, [901]) + assert applied == 1 + + stats = cache.get_stats() + assert stats.mru_partial_hits == 1 + assert stats.mru_partial_tokens_saved == 1 + assert stats.mru_partial_evictions == 0 # success, not eviction + + def test_apply_miss_on_found_key_increments_eviction( + self, paged_cache, mock_ssd, mx + ): + """Token-mismatch eviction pops the matched key and counts.""" + cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4) + bt, _ = _stash_with_prefix( + cache, mx, prefix_marker=1, tail_token=99 + ) + + reconstructed = _make_reconstructed_cache(mx, n_layers=4, n_tokens=4) + _, _, applied = cache.apply_mru_partial(reconstructed, bt, [77]) # wrong tail + assert applied == 0 + + stats = cache.get_stats() + assert stats.mru_partial_hits 
== 0 + assert stats.mru_partial_evictions == 1 + + def test_clear_mru_partials_counts_all_wiped_entries( + self, paged_cache, mock_ssd, mx + ): + cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4) + _stash_with_prefix(cache, mx, prefix_marker=1, tail_token=99) + _stash_with_prefix(cache, mx, prefix_marker=2, tail_token=88) + _stash_with_prefix(cache, mx, prefix_marker=3, tail_token=77) + + n = cache.clear_mru_partials() + assert n == 3 + + stats = cache.get_stats() + assert stats.mru_partial_evictions == 3 + assert stats.mru_partial_entries == 0 + + def test_clear_wipes_partials_and_resets_counters( + self, paged_cache, mock_ssd, mx + ): + """clear() is the "restart everything" path (cache-corruption + recovery). It wipes the dict AND resets every counter, + including mru_partial_evictions — incrementing evictions just + to have them zeroed by the same call would be incoherent. + Operators tracking partial wipes specifically use + clear_mru_partials() instead. + """ + cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4) + _stash_with_prefix(cache, mx, prefix_marker=1, tail_token=99) + _stash_with_prefix(cache, mx, prefix_marker=2, tail_token=88) + + cache.clear() + + stats = cache.get_stats() + assert stats.mru_partial_entries == 0 # dict wiped + assert stats.mru_partial_stashes == 0 # counters reset + assert stats.mru_partial_evictions == 0 + assert stats.mru_partial_hits == 0 + + def test_reset_stats_zeros_mru_counters_but_keeps_live_state( + self, paged_cache, mock_ssd, mx + ): + """reset_stats() is the analyst's reset — it zeros cumulative + counters but leaves the live cache state alone. 
Use + clear_mru_partials() if entries should be dropped too.""" + cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4) + _stash_with_prefix(cache, mx, prefix_marker=1, tail_token=99) + + cache.reset_stats() + + stats = cache.get_stats() + assert stats.mru_partial_stashes == 0 + assert stats.mru_partial_hits == 0 + assert stats.mru_partial_evictions == 0 + assert stats.mru_partial_tokens_saved == 0 + # But the live entry is still there. + assert stats.mru_partial_entries == 1 + + def test_get_stats_dict_mirrors_dataclass_after_round_trip( + self, paged_cache, mock_ssd, mx + ): + """``get_stats_dict`` must surface every MRU field that ``get_stats`` + (the dataclass) does. The admin dashboard reads MRU state via the + dict path (``Scheduler.get_ssd_cache_stats`` -> ``get_stats_dict``); + when the dict drops any of these keys, the admin payload's + ``mru_partial_max_entries`` aggregates to 0 and the dashboard's + ``mruEnabled`` gate hides every MRU panel even when the feature + is enabled. + + Uses the production round-trip (real stashes via ``store_cache``, + real apply via ``apply_mru_partial``, real capacity overflow) so + the live gauge ``mru_partial_entries`` and every counter reach the + dict via the same path the scheduler exercises. + """ + cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=2) + # Stash two distinct prefixes (both fit in the cap). The first + # one (``bt_kept``) will be applied below to bump hits; the LRU + # touch from a successful apply leaves it at the MRU end, so the + # later capacity-overflow stash evicts the *other* survivor and + # not the one we just touched. 
+ bt_kept, _ = _stash_with_prefix( + cache, mx, prefix_marker=2, tail_token=88 + ) + _stash_with_prefix(cache, mx, prefix_marker=3, tail_token=77) + + reconstructed = _make_reconstructed_cache(mx, n_layers=4, n_tokens=4) + _, _, applied = cache.apply_mru_partial(reconstructed, bt_kept, [88]) + assert applied == 1 # guard: the apply path actually fired + + # Third stash forces capacity-overflow eviction of the un-touched + # entry. After this: 3 stashes, 1 hit, 1 token saved, 1 eviction, + # 2 live entries. + _stash_with_prefix(cache, mx, prefix_marker=4, tail_token=66) + + stats = cache.get_stats() + stats_dict = cache.get_stats_dict() + + # Every MRU field on the dataclass must surface in the dict with + # the same value — this is the contract the admin route depends on. + for field in ( + "mru_partial_stashes", + "mru_partial_hits", + "mru_partial_evictions", + "mru_partial_tokens_saved", + "mru_partial_entries", + "mru_partial_max_entries", + ): + assert field in stats_dict, f"{field} missing from get_stats_dict()" + assert stats_dict[field] == getattr(stats, field), ( + f"{field}: dict={stats_dict[field]} dataclass={getattr(stats, field)}" + ) + + # Sanity: the round-trip actually moved every counter off its + # initial zero, so a future regression that hardwires zeros into + # the dict would still fail this test. + assert stats_dict["mru_partial_stashes"] == 3 + assert stats_dict["mru_partial_hits"] == 1 + assert stats_dict["mru_partial_evictions"] == 1 # capacity overflow + assert stats_dict["mru_partial_tokens_saved"] == 1 + assert stats_dict["mru_partial_entries"] == 2 + assert stats_dict["mru_partial_max_entries"] == 2 + + +def _model_with_make_cache(num_layers: int, layer_class_names: list[str]): + """Build a MockModel whose ``make_cache()`` returns objects whose + ``type(obj).__name__`` matches the requested cache class names. 
+ + ``ModelCacheConfig.from_cache_list`` identifies cache types by class + name (with isinstance fallback for SizedArraysCache only), so dynamic + classes are enough to exercise the eager init-time eligibility check + without pulling in real mlx-lm cache implementations. + """ + cache_objs = [ + type(name, (object,), {"max_size": 64})() + for name in layer_class_names + ] + model = MockModel(num_layers=num_layers) + model.make_cache = lambda: cache_objs # type: ignore[attr-defined] + return model + + +class TestMRUPartialEligibility: + """The ``mru_partial_supported`` tri-state flag and its one-shot + warning. Surfaces structurally-incompatible models on the admin + dashboard so operators see ``N/A (see log)`` instead of a misleading + ``0/N entries`` gauge. Mirrors the prior-art Pattern B (real + ``store_cache`` round-trip) for the lazy fallback path, and exercises + the eager init-time path through a model with ``make_cache()``. + """ + + @pytest.fixture + def mx(self): + try: + import mlx.core as mx + return mx + except ImportError: + pytest.skip("MLX not available") + + @pytest.fixture + def paged_cache(self): + return PagedCacheManager( + block_size=4, + max_blocks=100, + model_name="test-model", + initial_blocks=100, + ) + + @pytest.fixture + def mock_ssd(self): + mock = MagicMock() + mock.save_block.return_value = True + mock.load_block.return_value = None + mock.load_block_with_metadata.return_value = (None, None) + mock.has_block.return_value = False + return mock + + def test_supported_is_none_without_make_cache_and_no_inference( + self, paged_cache, mock_ssd + ): + """``MockModel`` has no ``make_cache``; eager check bare-returns and + lazy fallback hasn't fired yet — flag stays ``None``.""" + cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4) + stats = cache.get_stats() + assert stats.mru_partial_supported is None + assert cache.get_stats_dict()["mru_partial_supported"] is None + + def test_supported_latches_true_on_sliceable_observation( + 
self, paged_cache, mock_ssd, mx + ): + """A successful KVCache stash latches ``supported=True`` via the + lazy path (eager skipped because MockModel has no make_cache).""" + cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4) + _stash_with_prefix(cache, mx, prefix_marker=1, tail_token=99) + assert cache.get_stats().mru_partial_supported is True + + def test_supported_latches_false_lazy_on_non_sliceable( + self, paged_cache, mock_ssd, mx, caplog + ): + """A store_cache with RotatingKVCache layers latches + ``supported=False`` and emits exactly one warning.""" + from omlx.cache.hybrid_cache import ModelCacheConfig + + cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4) + # Eager init skipped (MockModel has no make_cache), so flag is + # None at this point — the lazy path is what we're testing. + assert cache.get_stats().mru_partial_supported is None + + tokens = [10, 20, 30, 40, 50] + cache_data = [_rotating_layer(mx, 5) for _ in range(4)] + config = ModelCacheConfig.from_type_list( + ["RotatingKVCache"] * 4, model_name="test" + ) + with caplog.at_level(logging.WARNING, logger="omlx.cache.prefix_cache"): + cache.store_cache("req-rot", tokens, cache_data, model_cache_config=config) + + stats = cache.get_stats() + assert stats.mru_partial_supported is False + assert stats.mru_partial_stashes == 0 # gate refused, no stash + warns = [r for r in caplog.records if r.levelno == logging.WARNING] + assert len(warns) == 1 + assert "MRU tails will be inactive" in warns[0].getMessage() + assert "incompatible" in warns[0].getMessage() + assert "RotatingKVCache" in warns[0].getMessage() + + def test_warning_does_not_repeat_on_subsequent_non_sliceable( + self, paged_cache, mock_ssd, mx, caplog + ): + """Once the flag is latched False, further non-sliceable store_cache + calls must NOT re-emit the warning (operator log spam guard).""" + from omlx.cache.hybrid_cache import ModelCacheConfig + + cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4) + config 
= ModelCacheConfig.from_type_list( + ["RotatingKVCache"] * 4, model_name="test" + ) + with caplog.at_level(logging.WARNING, logger="omlx.cache.prefix_cache"): + for i in range(3): + cache.store_cache( + f"req-rot-{i}", + [10 * i + j for j in range(5)], + [_rotating_layer(mx, 5) for _ in range(4)], + model_cache_config=config, + ) + + warns = [r for r in caplog.records if r.levelno == logging.WARNING] + assert len(warns) == 1 + assert cache.get_stats().mru_partial_supported is False + + def test_eager_check_latches_false_at_init_with_non_sliceable_make_cache( + self, paged_cache, mock_ssd, caplog + ): + """When ``model.make_cache()`` is available and returns non-sliceable + cache instances, the flag latches False at construction and the + warning fires BEFORE any inference — true model-load-time signal.""" + model = _model_with_make_cache( + num_layers=4, + layer_class_names=["RotatingKVCache"] * 4, + ) + with caplog.at_level(logging.WARNING, logger="omlx.cache.prefix_cache"): + cache = BlockAwarePrefixCache( + model=model, + paged_cache_manager=paged_cache, + paged_ssd_cache_manager=mock_ssd, + mru_partial_max_entries=4, + ) + + assert cache.get_stats().mru_partial_supported is False + warns = [r for r in caplog.records if r.levelno == logging.WARNING] + assert len(warns) == 1 + assert "RotatingKVCache" in warns[0].getMessage() + + def test_eager_check_latches_true_at_init_with_sliceable_make_cache( + self, paged_cache, mock_ssd, caplog + ): + """When ``model.make_cache()`` returns only sliceable cache + instances, the flag latches True at construction and no warning + is emitted.""" + model = _model_with_make_cache( + num_layers=4, + layer_class_names=["KVCache"] * 4, + ) + with caplog.at_level(logging.WARNING, logger="omlx.cache.prefix_cache"): + cache = BlockAwarePrefixCache( + model=model, + paged_cache_manager=paged_cache, + paged_ssd_cache_manager=mock_ssd, + mru_partial_max_entries=4, + ) + + assert cache.get_stats().mru_partial_supported is True + assert 
not [r for r in caplog.records if r.levelno == logging.WARNING] + + def test_eager_check_skipped_when_feature_disabled( + self, paged_cache, mock_ssd, caplog + ): + """``max_entries=0`` disables the feature; no eager check runs even + for an obviously incompatible model. No warning, no flag change — + the operator already opted out by setting the capacity to zero.""" + model = _model_with_make_cache( + num_layers=4, + layer_class_names=["RotatingKVCache"] * 4, + ) + with caplog.at_level(logging.WARNING, logger="omlx.cache.prefix_cache"): + cache = BlockAwarePrefixCache( + model=model, + paged_cache_manager=paged_cache, + paged_ssd_cache_manager=mock_ssd, + mru_partial_max_entries=0, + ) + + assert cache.get_stats().mru_partial_supported is None + assert not [r for r in caplog.records if r.levelno == logging.WARNING] + + def test_eager_check_survives_make_cache_failure( + self, paged_cache, mock_ssd, caplog + ): + """If ``model.make_cache()`` raises, the eager check bare-returns + and the flag stays ``None``. Lazy fallback picks up at first + inference instead — no startup crash.""" + model = MockModel(num_layers=4) + model.make_cache = lambda: (_ for _ in ()).throw( # type: ignore[attr-defined] + RuntimeError("model not fully initialized") + ) + cache = BlockAwarePrefixCache( + model=model, + paged_cache_manager=paged_cache, + paged_ssd_cache_manager=mock_ssd, + mru_partial_max_entries=4, + ) + assert cache.get_stats().mru_partial_supported is None + + +def _store_seq(cache, mx, request_id, tokens, *, prompt_token_count=None): + """``store_cache`` a full token sequence; return the BlockTable. + + ``cache_data`` spans the whole sequence so ``_update_mru_partial`` + takes the global index path (the cache covers ``prompt + output``), + matching the production resubmission layout. 
+ """ + cache_data = [_kv_layer(mx, len(tokens)) for _ in range(4)] + return cache.store_cache( + request_id, tokens, cache_data, prompt_token_count=prompt_token_count, + ) + + +class TestMRUPromptBoundaryStash: + """The MRU stash must key off the *prompt's* trailing partial, not the + stored sequence's. + + ``store_cache`` is handed ``prompt + output``, but a repeat request + resubmits the prompt only and ``apply_mru_partial`` looks the entry + up by the prompt's last full block. Before the prompt boundary was + threaded in, the stash keyed off ``prompt + output``'s last full + block — a key a prompt-only resubmit could never compute — so the + feature never produced a hit for ordinary chat completions. + + block_size is 4 in this fixture. + """ + + @pytest.fixture + def mx(self): + try: + import mlx.core as mx + return mx + except ImportError: + pytest.skip("MLX not available") + + @pytest.fixture + def paged_cache(self): + return PagedCacheManager( + block_size=4, + max_blocks=100, + model_name="test-model", + initial_blocks=100, + ) + + @pytest.fixture + def mock_ssd(self): + mock = MagicMock() + mock.save_block.return_value = True + mock.load_block.return_value = None + mock.load_block_with_metadata.return_value = (None, None) + mock.has_block.return_value = False + return mock + + def test_prompt_boundary_stash_hits_on_prompt_only_resubmit( + self, paged_cache, mock_ssd, mx + ): + """The decisive regression test: store ``prompt + output`` with the + prompt boundary, then resubmit the prompt only — apply must HIT. + """ + cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4) + # prompt = 1 full block [10..13] + 2-token tail [14,15]; output = 3. 
+ prompt = [10, 11, 12, 13, 14, 15] + stored = prompt + [90, 91, 92] + store_bt = _store_seq( + cache, mx, "store", stored, prompt_token_count=len(prompt) + ) + + # Keyed by the prompt's last full block (block 0) — NOT the stored + # sequence's last full block (block 1) — and stashing the prompt + # tail [14,15], not the sequence tail [92]. + prompt_block = paged_cache.allocated_blocks[store_bt.block_ids[0]] + assert prompt_block.block_hash in cache._mru_partials + assert cache._mru_partials[prompt_block.block_hash].tokens == [14, 15] + + # Simulate fetch_cache(prompt): a block table with the prompt's + # full blocks only — apply_mru_partial keys off block_ids[-1]. + fetch_bt = BlockTable(request_id="resubmit") + fetch_bt.block_ids = [store_bt.block_ids[0]] + fetch_bt.num_tokens = 4 + reconstructed = _make_reconstructed_cache(mx, n_layers=4, n_tokens=4) + + _, new_remaining, applied = cache.apply_mru_partial( + reconstructed, fetch_bt, [14, 15] + ) + assert applied == 2 + assert new_remaining == [] + stats = cache.get_stats() + assert stats.mru_partial_hits == 1 + assert stats.mru_partial_tokens_saved == 2 + + def test_whole_sequence_stash_misses_on_prompt_only_resubmit( + self, paged_cache, mock_ssd, mx + ): + """Pins the original bug: with no prompt boundary + (``prompt_token_count=None``) the stash keys off the stored + sequence's last full block, which a prompt-only resubmit never + reaches — 0 hits, and 0 evictions because the lookup key is never + found (no entry to evict). + """ + cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4) + prompt = [10, 11, 12, 13, 14, 15] + stored = prompt + [90, 91, 92] + store_bt = _store_seq(cache, mx, "store", stored) # boundary unknown + + # None falls back to whole-sequence: keyed off block 1 (the stored + # sequence's last full block), unreachable by a prompt-only fetch. 
+ seq_block = paged_cache.allocated_blocks[store_bt.block_ids[1]] + assert seq_block.block_hash in cache._mru_partials + + fetch_bt = BlockTable(request_id="resubmit") + fetch_bt.block_ids = [store_bt.block_ids[0]] + fetch_bt.num_tokens = 4 + reconstructed = _make_reconstructed_cache(mx, n_layers=4, n_tokens=4) + + _, _, applied = cache.apply_mru_partial(reconstructed, fetch_bt, [14, 15]) + assert applied == 0 + stats = cache.get_stats() + assert stats.mru_partial_hits == 0 + assert stats.mru_partial_evictions == 0 # key miss — nothing evicted + + def test_block_aligned_prompt_does_not_stash( + self, paged_cache, mock_ssd, mx + ): + """A prompt that is an exact multiple of block_size has no partial + tail — every prompt token lands in a full paged block, so the MRU + has nothing to add.""" + cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4) + prompt = [10, 11, 12, 13, 14, 15, 16, 17] # 8 tokens = 2 full blocks + stored = prompt + [90, 91, 92] + _store_seq(cache, mx, "store", stored, prompt_token_count=len(prompt)) + assert not cache._mru_partials + assert cache.get_stats().mru_partial_stashes == 0 + + def test_short_prompt_stashes_under_none_key( + self, paged_cache, mock_ssd, mx + ): + """A prompt shorter than one block has no last full block; the + stash is keyed by None (the short-prompt path) and holds the + whole prompt as its tail.""" + cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4) + prompt = [10, 11, 12] # 3 tokens < block_size 4 + stored = prompt + [90, 91, 92, 93, 94] + _store_seq(cache, mx, "store", stored, prompt_token_count=len(prompt)) + assert None in cache._mru_partials + assert cache._mru_partials[None].tokens == [10, 11, 12] + + def test_prompt_boundary_stash_with_existing_cached_prefix( + self, paged_cache, mock_ssd, mx + ): + """Resubmission path: store_cache runs with ``existing_tokens > 0`` + (the prompt's leading blocks are already cached) and works in + ``new_tokens`` space. 
The prompt-boundary arithmetic must still + resolve the prompt's last full block — not an index shifted by + ``existing_tokens``. + """ + cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4) + # prompt = 2 full blocks [10..17] + 2-token tail [18,19]; output = 3. + prompt = [10, 11, 12, 13, 14, 15, 16, 17, 18, 19] + stored = prompt + [90, 91, 92] + # Round 1: store the prompt's first full block only, same + # request_id, so round 2's store sees existing_tokens > 0. + _store_seq(cache, mx, "req", prompt[:8], prompt_token_count=8) + assert not cache._mru_partials # block-aligned round-1, no stash + # Round 2: same request_id — existing_tokens == 8. + store_bt = _store_seq( + cache, mx, "req", stored, prompt_token_count=len(prompt) + ) + + # Prompt's last full block is index 1 ([14..17]); the stash must + # key off it despite existing_tokens=8 and new_tokens-space math. + prompt_block = paged_cache.allocated_blocks[store_bt.block_ids[1]] + assert prompt_block.block_hash in cache._mru_partials + assert cache._mru_partials[prompt_block.block_hash].tokens == [18, 19] + + fetch_bt = BlockTable(request_id="resubmit") + fetch_bt.block_ids = store_bt.block_ids[:2] + fetch_bt.num_tokens = 8 + reconstructed = _make_reconstructed_cache(mx, n_layers=4, n_tokens=8) + _, new_remaining, applied = cache.apply_mru_partial( + reconstructed, fetch_bt, [18, 19] + ) + assert applied == 2 + assert new_remaining == [] + assert cache.get_stats().mru_partial_hits == 1 + + +class TestMRUPartialCrossThreadSafety: + """An MRU partial is extracted on the store-cache worker thread but + spliced into a live cache on the separate inference thread. + + ``_extract_block_tensor_slice`` builds the partial as lazy + ``mx.copy`` ops; an unevaluated tensor carries a pending op bound to + the worker's per-thread MLX stream, and evaluating the splice on the + inference thread raises ``RuntimeError: There is no Stream(gpu, N) + in current thread``. 
``_update_mru_partial`` must materialize the + partial at stash time so the stashed data is concrete and + stream-free. + """ + + @pytest.fixture + def mx(self): + try: + import mlx.core as mx + return mx + except ImportError: + pytest.skip("MLX not available") + + @pytest.fixture + def paged_cache(self): + return PagedCacheManager( + block_size=4, + max_blocks=100, + model_name="test-model", + initial_blocks=100, + ) + + @pytest.fixture + def mock_ssd(self): + mock = MagicMock() + mock.save_block.return_value = True + mock.load_block.return_value = None + mock.load_block_with_metadata.return_value = (None, None) + mock.has_block.return_value = False + return mock + + def test_materialize_mru_kv_handles_extract_shapes( + self, paged_cache, mock_ssd, mx + ): + """``_materialize_mru_kv`` evaluates every ``mx.array`` leaf across + the plain ``(keys, values)`` and TurboQuant ``(tag, (k, v))`` + shapes ``_extract_block_tensor_slice`` returns, and tolerates the + non-array tag string and an empty list.""" + cache = _make_mru_cache(paged_cache, mock_ssd) + plain = [(mx.ones((1, 1, 2, 4)), mx.ones((1, 1, 2, 4)))] + tagged = [ + ("__turboquant_v2__", (mx.ones((1, 1, 2, 4)), mx.ones((1, 1, 2, 4)))) + ] + # None of these should raise. + cache._materialize_mru_kv(plain) + cache._materialize_mru_kv(tagged) + cache._materialize_mru_kv([]) + + def test_stashed_partial_splices_across_threads( + self, paged_cache, mock_ssd, mx + ): + """Extract+stash on a worker thread, splice+evaluate on the main + thread. Without stash-time materialization the final ``mx.eval`` + raises a foreign-stream ``RuntimeError``; with it, the cross- + thread handoff is clean. 
+ """ + import concurrent.futures + + cache = _make_mru_cache(paged_cache, mock_ssd, max_entries=4) + prompt = [10, 11, 12, 13, 14, 15] + stored = prompt + [90, 91, 92] + cache_data = [_kv_layer(mx, len(stored)) for _ in range(4)] + # Mirror production: the inference thread materializes the + # extracted cache (mx.async_eval + the worker's mx.synchronize) + # before the store worker runs. Without this the cache_data + # arrays would still be lazy ops bound to THIS thread's stream — + # a different cross-thread failure than the one under test. + for layer in cache_data: + mx.eval(*layer["state"]) + + # Stash on a dedicated worker thread, mirroring the production + # _store_cache_executor (a pool distinct from the inference one). + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool: + store_bt = pool.submit( + cache.store_cache, + "store", + stored, + cache_data, + prompt_token_count=len(prompt), + ).result() + + # Splice on this (the "inference") thread. + fetch_bt = BlockTable(request_id="resubmit") + fetch_bt.block_ids = [store_bt.block_ids[0]] + fetch_bt.num_tokens = 4 + reconstructed = _make_reconstructed_cache(mx, n_layers=4, n_tokens=4) + + spliced, new_remaining, applied = cache.apply_mru_partial( + reconstructed, fetch_bt, [14, 15] + ) + assert applied == 2 + assert new_remaining == [] + + # Force evaluation on this thread — the point a partial still + # carrying the worker thread's stream would fail. 
+ for layer in spliced: + mx.eval(layer.keys, layer.values) + + +class TestHasMRUPartial: + """The has_mru_partial() accessor is the public API the scheduler + uses to decide whether to suppress the deferred Metal cache clear.""" + + def test_has_mru_partial_reflects_dict_emptiness(self): + cache = BlockAwarePrefixCache( + model=MockModel(num_layers=2), + paged_cache_manager=PagedCacheManager( + block_size=4, max_blocks=10, model_name="t", initial_blocks=10, + ), + paged_ssd_cache_manager=None, + ) + assert cache.has_mru_partial() is False + + cache._mru_partials[b"x"] = _MRUPartialBlock( + parent_hash=b"x", tokens=[1], kv_data=[], + ) + assert cache.has_mru_partial() is True + + cache._mru_partials.clear() + assert cache.has_mru_partial() is False diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 5e87be0f3..892db445c 100644 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -1360,6 +1360,253 @@ def test_cleanup_finished_extends_deferred_clear_for_concurrent_completions( assert scheduler._deferred_clear_at > first_target +class TestMRUDeferredClearSuppression: + """Tests for the MRU-aware deferred-clear suppression. + + When the prefix cache holds a warm MRU partial (the previous request + just left a sub-block tail in memory), the next deferred clear is + deferred by one more ``_DEFERRED_CLEAR_DELAY`` window so the still- + resident lazy KV tensors aren't dropped before a likely repeat. + + The suppression is **budgeted at one-shot per completion** so it + cannot indefinitely defer the clear under hot-prompt repeats — total + deferral is bounded at 2x ``_DEFERRED_CLEAR_DELAY`` even if the MRU + is replenished on every iteration. + """ + + def _drive_clear_gate(self, scheduler): + """Run the step()-tail gate that decides whether to clear. + + Mirrors the structure of step()'s end-of-step block — the + smallest slice of step() needed to exercise the suppression + logic without standing up a full BatchGenerator. 
+ """ + from omlx import scheduler as sched_mod + + with patch.object(sched_mod, "_sync_and_clear_cache") as mock_clear: + scheduler._step_counter += 1 + should_clear = scheduler._should_periodic_clear_cache() + if ( + scheduler._deferred_clear_at is not None + and scheduler._step_counter >= scheduler._deferred_clear_at + ): + if ( + scheduler._mru_clear_suppression_available + and scheduler.block_aware_cache is not None + and scheduler.block_aware_cache.has_mru_partial() + ): + scheduler._deferred_clear_at = ( + scheduler._step_counter + + scheduler._DEFERRED_CLEAR_DELAY + ) + scheduler._mru_clear_suppression_available = False + else: + should_clear = True + scheduler._deferred_clear_at = None + scheduler._mru_clear_suppression_available = False + if should_clear: + sched_mod._sync_and_clear_cache() + return mock_clear.called + + def test_suppression_budget_armed_on_completion( + self, mock_model, mock_tokenizer + ): + """Each completion that arms _deferred_clear_at also arms the budget.""" + scheduler = Scheduler(model=mock_model, tokenizer=mock_tokenizer) + + request = Request( + request_id="req-arm", + prompt="hello", + sampling_params=SamplingParams(), + ) + request.prompt_token_ids = [1, 2] + request.num_prompt_tokens = 2 + request.output_token_ids = [3] + scheduler.running["req-arm"] = request + scheduler.requests["req-arm"] = request + + with patch("omlx.scheduler.mx"): + scheduler._cleanup_finished({"req-arm"}) + + assert scheduler._deferred_clear_at is not None + assert scheduler._mru_clear_suppression_available is True + + def test_clear_fires_at_deadline_when_no_mru( + self, mock_model, mock_tokenizer + ): + """Without a warm MRU, the deferred clear fires at its deadline.""" + scheduler = Scheduler(model=mock_model, tokenizer=mock_tokenizer) + scheduler.block_aware_cache = MagicMock() + scheduler.block_aware_cache.has_mru_partial.return_value = False + + scheduler._deferred_clear_at = scheduler._step_counter + 1 + 
scheduler._mru_clear_suppression_available = True + + cleared = self._drive_clear_gate(scheduler) + + assert cleared is True + assert scheduler._deferred_clear_at is None + assert scheduler._mru_clear_suppression_available is False + + def test_clear_deferred_once_when_mru_warm( + self, mock_model, mock_tokenizer + ): + """A warm MRU at the deadline defers the clear by one more window.""" + scheduler = Scheduler(model=mock_model, tokenizer=mock_tokenizer) + scheduler.block_aware_cache = MagicMock() + scheduler.block_aware_cache.has_mru_partial.return_value = True + + scheduler._deferred_clear_at = scheduler._step_counter + 1 + scheduler._mru_clear_suppression_available = True + + cleared = self._drive_clear_gate(scheduler) + + assert cleared is False + # Deadline pushed out by one more DELAY window + assert scheduler._deferred_clear_at == ( + scheduler._step_counter + Scheduler._DEFERRED_CLEAR_DELAY + ) + # Budget spent — next deadline cannot suppress again + assert scheduler._mru_clear_suppression_available is False + + def test_clear_fires_at_second_deadline_even_if_mru_still_warm( + self, mock_model, mock_tokenizer + ): + """Budget is one-shot: the second deadline fires regardless of MRU. + + This is the bound that protects against infinite deferral under + hot-prompt repeats (review B6). + """ + scheduler = Scheduler(model=mock_model, tokenizer=mock_tokenizer) + scheduler.block_aware_cache = MagicMock() + scheduler.block_aware_cache.has_mru_partial.return_value = True + + # First deadline → suppressed + scheduler._deferred_clear_at = scheduler._step_counter + 1 + scheduler._mru_clear_suppression_available = True + first_cleared = self._drive_clear_gate(scheduler) + assert first_cleared is False + + # Advance to the second deadline. MRU still warm, but no budget. 
+ scheduler._step_counter = scheduler._deferred_clear_at - 1 + second_cleared = self._drive_clear_gate(scheduler) + + assert second_cleared is True + assert scheduler._deferred_clear_at is None + + def test_clear_fires_immediately_after_mru_evicted( + self, mock_model, mock_tokenizer + ): + """If the MRU evicts before the deadline, the clear fires normally.""" + scheduler = Scheduler(model=mock_model, tokenizer=mock_tokenizer) + scheduler.block_aware_cache = MagicMock() + # MRU was warm at completion but evicted by the time we reach deadline + scheduler.block_aware_cache.has_mru_partial.return_value = False + + scheduler._deferred_clear_at = scheduler._step_counter + 1 + scheduler._mru_clear_suppression_available = True + + cleared = self._drive_clear_gate(scheduler) + + assert cleared is True + # Budget left untouched as a side effect doesn't matter — it's + # reset on the next completion either way. + assert scheduler._deferred_clear_at is None + + def test_new_epoch_completion_arms_budget( + self, mock_model, mock_tokenizer + ): + """A completion that STARTS a new deferral epoch (transition from + ``_deferred_clear_at is None``) arms the budget. 
Together with + the next test, this pins the contract: budget is one-shot per + epoch.""" + scheduler = Scheduler(model=mock_model, tokenizer=mock_tokenizer) + scheduler._mru_clear_suppression_available = False # spent + assert scheduler._deferred_clear_at is None # no epoch in flight + + request = Request( + request_id="req-new-epoch", + prompt="hello", + sampling_params=SamplingParams(), + ) + request.prompt_token_ids = [1, 2] + request.num_prompt_tokens = 2 + request.output_token_ids = [3] + scheduler.running["req-new-epoch"] = request + scheduler.requests["req-new-epoch"] = request + + with patch("omlx.scheduler.mx"): + scheduler._cleanup_finished({"req-new-epoch"}) + + assert scheduler._mru_clear_suppression_available is True + + def test_completion_within_open_epoch_does_not_refresh_budget( + self, mock_model, mock_tokenizer + ): + """**The hot-prompt invariant.** A completion that lands while a + deferral is already pending must NOT refresh the budget. + + Pre-fix behaviour: every completion re-armed + ``_mru_clear_suppression_available = True``. Under hot-prompt + repeats — the very workload this feature targets — completions + arrive faster than ``_DEFERRED_CLEAR_DELAY``, so the budget + kept refreshing after being spent and the deferred clear could + be pushed forever, defeating the pool-bloat mitigation (#411). + + Post-fix: budget is armed only on the transition from ``None``, + spent at the deadline if MRU is still warm, and stays spent for + the rest of the epoch regardless of how many further completions + land. + """ + scheduler = Scheduler(model=mock_model, tokenizer=mock_tokenizer) + + # First completion opens the epoch and arms the budget. 
+ req1 = Request( + request_id="req-open", + prompt="a", + sampling_params=SamplingParams(), + ) + req1.prompt_token_ids = [1, 2] + req1.num_prompt_tokens = 2 + req1.output_token_ids = [3] + scheduler.running["req-open"] = req1 + scheduler.requests["req-open"] = req1 + + with patch("omlx.scheduler.mx"): + scheduler._cleanup_finished({"req-open"}) + + assert scheduler._deferred_clear_at is not None + assert scheduler._mru_clear_suppression_available is True + + # Simulate the budget already being spent (suppression fired + # at first attempted deadline). + scheduler._mru_clear_suppression_available = False + + # A second completion arrives while the same epoch is still + # open. It may legitimately push the deadline out — that's the + # #557 invariant — but it must NOT refresh the spent budget. + scheduler._step_counter += 3 + req2 = Request( + request_id="req-mid-epoch", + prompt="b", + sampling_params=SamplingParams(), + ) + req2.prompt_token_ids = [4, 5] + req2.num_prompt_tokens = 2 + req2.output_token_ids = [6] + scheduler.running["req-mid-epoch"] = req2 + scheduler.requests["req-mid-epoch"] = req2 + + with patch("omlx.scheduler.mx"): + scheduler._cleanup_finished({"req-mid-epoch"}) + + # Deadline pushed (per #557), budget unchanged (per the C3 fix). + assert scheduler._deferred_clear_at == ( + scheduler._step_counter + Scheduler._DEFERRED_CLEAR_DELAY + ) + assert scheduler._mru_clear_suppression_available is False + + class TestPeriodicClearGating: """Tests for the conditional periodic clear (#978/#1040 mitigation).""" diff --git a/tests/test_settings.py b/tests/test_settings.py index f3a1f54b2..d3ccc4c68 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -342,6 +342,7 @@ def test_to_dict(self): "ssd_cache_max_size": "50GB", "hot_cache_max_size": "0", "initial_cache_blocks": 256, + "mru_partial_max_entries": 4, } def test_from_dict(self):
Block Size Indexed Blocks Sub-block CacheCache FilesCache SizeSSD FilesSSD SizeMemory EntriesMemory SizeMRU Tails
+ N/A (see log) + +