From 974e0c058f3266295323d51f762be8f7c2ae62d3 Mon Sep 17 00:00:00 2001 From: Carlos Bribiescas Date: Thu, 11 Jun 2026 19:05:38 +0800 Subject: [PATCH 1/3] feat(mllm): wire SSD cold tier onto the MLLM prefix cache --ssd-cache-dir was a silent no-op on the MLLM path used by Qwen3.5 and other VLM/hybrid models: the SSD tier was only attached to the standard Scheduler's MemoryAwarePrefixCache (scheduler.py ~1226). MLLMSchedulerConfig had no SSD fields, batched._start_mllm passed none through, and the MLLM generator's MemoryAwarePrefixCache never got .set_ssd_tier(). Plumbing (additive, no-op when --ssd-cache-dir is unset): - MLLMSchedulerConfig gains ssd_cache_dir / ssd_cache_max_gb. - batched._start_mllm reads them off the SchedulerConfig (same fields cli.py populates) and forwards them. - mllm_scheduler builds SSDCacheTier(SSDCacheConfig(...)), runs start_writer() + reconcile() on it, and calls set_ssd_tier() on the generator's prefix cache, mirroring the standard path. One-line startup log makes it visible. The SSD serializer already supports ArraysCache (ssd_cache.py:544), so hybrid layers spill/promote correctly once the tier is attached. Tests: model-free wiring tests for the batched bridge and the scheduler attach logic. 
Co-Authored-By: Claude Opus 4.8 (1M context) --- tests/test_batched_engine_mllm_config.py | 84 +++++++++++++++ tests/test_mllm_ssd_spill.py | 125 +++++++++++++++++++++++ vllm_mlx/engine/batched.py | 6 ++ vllm_mlx/mllm_scheduler.py | 30 ++++++ 4 files changed, 245 insertions(+) create mode 100644 tests/test_mllm_ssd_spill.py diff --git a/tests/test_batched_engine_mllm_config.py b/tests/test_batched_engine_mllm_config.py index 19e998dc3..e16d161ef 100644 --- a/tests/test_batched_engine_mllm_config.py +++ b/tests/test_batched_engine_mllm_config.py @@ -77,3 +77,87 @@ async def start(self): assert captured["config_kwargs"]["use_memory_aware_cache"] is False assert captured["config_kwargs"]["cache_memory_mb"] == 123 assert captured["config_kwargs"]["prefix_cache_memory_mb"] == 123 + # SSD fields default to (None, 10.0) when absent from SchedulerConfig. + assert captured["config_kwargs"]["ssd_cache_dir"] is None + assert captured["config_kwargs"]["ssd_cache_max_gb"] == 10.0 + + +def _run_start_mllm(monkeypatch, scheduler_config): + """Run BatchedEngine._start_mllm with fakes, return captured kwargs.""" + from vllm_mlx.engine.batched import BatchedEngine + + captured = {} + + class FakeMLXMultimodalLM: + def __init__(self, model_name, trust_remote_code=True, **kwargs): + self.model = object() + self.processor = object() + + def load(self): + return None + + class FakeMLLMSchedulerConfig: + def __init__(self, **kwargs): + captured["config_kwargs"] = kwargs + self.__dict__.update(kwargs) + + class FakeMLLMScheduler: + def __init__(self, model, processor, config): + pass + + async def start(self): + return None + + import vllm_mlx.engine.batched as batched_mod + + fake_mllm_scheduler = types.ModuleType("vllm_mlx.mllm_scheduler") + fake_mllm_scheduler.MLLMScheduler = FakeMLLMScheduler + fake_mllm_scheduler.MLLMSchedulerConfig = FakeMLLMSchedulerConfig + fake_mllm_model = types.ModuleType("vllm_mlx.models.mllm") + fake_mllm_model.MLXMultimodalLM = FakeMLXMultimodalLM + 
monkeypatch.setitem(sys.modules, "vllm_mlx.mllm_scheduler", fake_mllm_scheduler) + monkeypatch.setitem(sys.modules, "vllm_mlx.models.mllm", fake_mllm_model) + monkeypatch.setattr( + batched_mod.BatchedEngine, "_inject_mtp_mllm", lambda self: None + ) + + engine = BatchedEngine( + model_name="fake-qwen", + scheduler_config=scheduler_config, + force_mllm=True, + ) + asyncio.run(engine._start_mllm()) + return captured + + +def _base_scheduler_config(**overrides): + cfg = dict( + max_num_seqs=16, + prefill_batch_size=4, + completion_batch_size=8, + prefill_step_size=256, + mllm_prefill_step_size=None, + enable_prefix_cache=True, + use_memory_aware_cache=True, + cache_memory_mb=None, + enable_mtp=False, + mtp_num_draft_tokens=1, + kv_cache_quantization=False, + kv_cache_quantization_bits=8, + kv_cache_quantization_group_size=64, + chunked_prefill_tokens=0, + max_kv_size=0, + ) + cfg.update(overrides) + return SimpleNamespace(**cfg) + + +def test_start_mllm_forwards_ssd_cache_fields(monkeypatch): + captured = _run_start_mllm( + monkeypatch, + _base_scheduler_config( + ssd_cache_dir="/tmp/ssd-kv", ssd_cache_max_gb=42.0 + ), + ) + assert captured["config_kwargs"]["ssd_cache_dir"] == "/tmp/ssd-kv" + assert captured["config_kwargs"]["ssd_cache_max_gb"] == 42.0 diff --git a/tests/test_mllm_ssd_spill.py b/tests/test_mllm_ssd_spill.py new file mode 100644 index 000000000..ec7d45481 --- /dev/null +++ b/tests/test_mllm_ssd_spill.py @@ -0,0 +1,125 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests that the SSD cold tier is wired onto the MLLM prefix cache. + +`--ssd-cache-dir` used to be a silent no-op on the MLLM path (Qwen3.5 et al.) +because the SSD tier was only attached to the standard Scheduler's +MemoryAwarePrefixCache. These tests assert the MLLM scheduler now builds an +SSDCacheTier and calls set_ssd_tier() on the generator's prefix cache when the +flag is set, and does NOT when it is unset. Model-free: the batch generator, +sampler and SSD tier are all faked. 
+""" + +import sys +import types +from types import SimpleNamespace + +from vllm_mlx.mllm_scheduler import MLLMScheduler, MLLMSchedulerConfig + + +def _install_fakes(monkeypatch, prefix_cache_obj): + """Patch the heavy collaborators _ensure_batch_generator pulls in.""" + import vllm_mlx.mllm_scheduler as sched_mod + + set_calls = [] + tier_instances = [] + + class FakeGenerator: + def __init__(self, *a, **kw): + self.prefix_cache = prefix_cache_obj + self.language_model = SimpleNamespace(mtp=None) + + class FakeTier: + def __init__(self, config): + self.config = config + self.started = False + self.reconciled = False + tier_instances.append(self) + + def start_writer(self): + self.started = True + + def reconcile(self): + self.reconciled = True + + class FakeTierConfig: + def __init__(self, cache_dir=None, max_size_gb=10.0): + self.cache_dir = cache_dir + self.max_size_gb = max_size_gb + + if prefix_cache_obj is not None: + monkeypatch.setattr( + prefix_cache_obj, + "set_ssd_tier", + lambda tier: set_calls.append(tier), + raising=False, + ) + + monkeypatch.setattr(sched_mod, "MLLMBatchGenerator", FakeGenerator) + + # make_sampler and MemoryCacheConfig are imported lazily inside the method. 
+ fake_sample_utils = types.ModuleType("mlx_lm.sample_utils") + fake_sample_utils.make_sampler = lambda **kw: (lambda x: x) + monkeypatch.setitem(sys.modules, "mlx_lm.sample_utils", fake_sample_utils) + + fake_ssd = types.ModuleType("vllm_mlx.ssd_cache") + fake_ssd.SSDCacheTier = FakeTier + fake_ssd.SSDCacheConfig = FakeTierConfig + monkeypatch.setitem(sys.modules, "vllm_mlx.ssd_cache", fake_ssd) + + return set_calls, tier_instances + + +def _bare_scheduler(config): + """Construct an MLLMScheduler without running its heavy __init__.""" + sched = MLLMScheduler.__new__(MLLMScheduler) + sched.config = config + sched.model = SimpleNamespace() + sched.processor = SimpleNamespace() + sched.mm_processor = SimpleNamespace() + sched.stop_tokens = set() + sched.batch_generator = None + return sched + + +def test_ssd_tier_attached_when_dir_set(monkeypatch, tmp_path): + prefix_cache = SimpleNamespace() + set_calls, tiers = _install_fakes(monkeypatch, prefix_cache) + + cfg = MLLMSchedulerConfig( + ssd_cache_dir=str(tmp_path), ssd_cache_max_gb=7.0 + ) + sched = _bare_scheduler(cfg) + sched._ensure_batch_generator() + + assert len(tiers) == 1 + assert tiers[0].started and tiers[0].reconciled + assert tiers[0].config.cache_dir == str(tmp_path) + assert tiers[0].config.max_size_gb == 7.0 + assert len(set_calls) == 1 + assert set_calls[0] is tiers[0] + assert sched._ssd_tier is tiers[0] + + +def test_no_tier_when_dir_unset(monkeypatch): + prefix_cache = SimpleNamespace() + set_calls, tiers = _install_fakes(monkeypatch, prefix_cache) + + cfg = MLLMSchedulerConfig(ssd_cache_dir=None) + sched = _bare_scheduler(cfg) + sched._ensure_batch_generator() + + assert tiers == [] + assert set_calls == [] + assert sched._ssd_tier is None + + +def test_no_tier_when_prefix_cache_absent(monkeypatch, tmp_path): + # Prefix caching disabled → generator.prefix_cache is None → no SSD tier. 
+ set_calls, tiers = _install_fakes(monkeypatch, None) + + cfg = MLLMSchedulerConfig(ssd_cache_dir=str(tmp_path)) + sched = _bare_scheduler(cfg) + sched._ensure_batch_generator() + + assert tiers == [] + assert sched._ssd_tier is None diff --git a/vllm_mlx/engine/batched.py b/vllm_mlx/engine/batched.py index cc09136ac..8aaf25814 100644 --- a/vllm_mlx/engine/batched.py +++ b/vllm_mlx/engine/batched.py @@ -337,6 +337,10 @@ async def _start_mllm(self) -> None: mtp_num_draft = getattr(self._scheduler_config, "mtp_num_draft_tokens", 1) kv_quant = getattr(self._scheduler_config, "kv_cache_quantization", False) kv_bits = getattr(self._scheduler_config, "kv_cache_quantization_bits", 8) + # SSD cold tier — same SchedulerConfig fields the standard path reads + # (cli.py populates ssd_cache_dir/ssd_cache_max_gb). None = disabled. + ssd_cache_dir = getattr(self._scheduler_config, "ssd_cache_dir", None) + ssd_cache_max_gb = getattr(self._scheduler_config, "ssd_cache_max_gb", 10.0) kv_group_size = getattr( self._scheduler_config, "kv_cache_quantization_group_size", 64 ) @@ -372,6 +376,8 @@ async def _start_mllm(self) -> None: kv_cache_quantization_group_size=kv_group_size, chunked_prefill_tokens=chunked_prefill_tokens, max_kv_size=max_kv_size, + ssd_cache_dir=ssd_cache_dir, + ssd_cache_max_gb=ssd_cache_max_gb, **mllm_extra, ) diff --git a/vllm_mlx/mllm_scheduler.py b/vllm_mlx/mllm_scheduler.py index cb35fb391..bf2a085b3 100644 --- a/vllm_mlx/mllm_scheduler.py +++ b/vllm_mlx/mllm_scheduler.py @@ -85,6 +85,11 @@ class MLLMSchedulerConfig: chunked_prefill_tokens: int = 0 # Maximum KV cache size per sequence (0 = unbounded; >0 enables RotatingKVCache) max_kv_size: int = 0 + # SSD cold tier for the prefix cache (mirrors SchedulerConfig). + # None = disabled. When set, the MLLM MemoryAwarePrefixCache spills + # evicted entries to disk and promotes them back on hit. 
+ ssd_cache_dir: Optional[str] = None + ssd_cache_max_gb: float = 10.0 @dataclass @@ -321,6 +326,31 @@ def _ensure_batch_generator(self) -> None: max_kv_size=self.config.max_kv_size, ) + # Wire the SSD cold tier onto the MLLM prefix cache, mirroring the + # standard Scheduler path (see scheduler.py ~1226). Without this + # --ssd-cache-dir is a silent no-op for MLLM models (Qwen3.5 et al.) + # because the SSD tier was only ever attached to the standard + # Scheduler's MemoryAwarePrefixCache. No-op when the flag is unset. + self._ssd_tier = None + prefix_cache = getattr(self.batch_generator, "prefix_cache", None) + if self.config.ssd_cache_dir is not None and prefix_cache is not None: + from .ssd_cache import SSDCacheConfig, SSDCacheTier + + ssd_config = SSDCacheConfig( + cache_dir=self.config.ssd_cache_dir, + max_size_gb=self.config.ssd_cache_max_gb, + ) + self._ssd_tier = SSDCacheTier(ssd_config) + self._ssd_tier.start_writer() + self._ssd_tier.reconcile() + prefix_cache.set_ssd_tier(self._ssd_tier) + logger.info( + "[mllm] SSD cache tier enabled on MLLM prefix cache: " + "dir=%s, max=%sGB", + self.config.ssd_cache_dir, + self.config.ssd_cache_max_gb, + ) + # Install chunked prefill BEFORE MTP (MTP wraps _next, # chunked replaces it — MTP then wraps the chunked version) if self.config.chunked_prefill_tokens > 0: From 7bbd89a2ae9636f5713c5eb2402ea335aa3fe8e6 Mon Sep 17 00:00:00 2001 From: Carlos Bribiescas Date: Sun, 14 Jun 2026 17:09:19 +0800 Subject: [PATCH 2/3] lint: black-format mllm SSD wiring + test file Resolves CI lint failure on the first run of fix/mllm-ssd-wire-cold-tier by reformatting tests/test_mllm_ssd_spill.py to match the project's black config. 
--- tests/test_mllm_ssd_spill.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test_mllm_ssd_spill.py b/tests/test_mllm_ssd_spill.py index ec7d45481..34aa1652f 100644 --- a/tests/test_mllm_ssd_spill.py +++ b/tests/test_mllm_ssd_spill.py @@ -85,9 +85,7 @@ def test_ssd_tier_attached_when_dir_set(monkeypatch, tmp_path): prefix_cache = SimpleNamespace() set_calls, tiers = _install_fakes(monkeypatch, prefix_cache) - cfg = MLLMSchedulerConfig( - ssd_cache_dir=str(tmp_path), ssd_cache_max_gb=7.0 - ) + cfg = MLLMSchedulerConfig(ssd_cache_dir=str(tmp_path), ssd_cache_max_gb=7.0) sched = _bare_scheduler(cfg) sched._ensure_batch_generator() From 5a80859182ea6a9a1f90a849bee2be5f8908ce1f Mon Sep 17 00:00:00 2001 From: Carlos Bribiescas Date: Mon, 15 Jun 2026 19:34:34 +0800 Subject: [PATCH 3/3] lint: black-format test_batched_engine_mllm_config.py Missed the fourth file in the earlier black sweep on 7bbd89a; CI lint still failed because tests/test_batched_engine_mllm_config.py had non-black formatting. Single-line whitespace tweak. --- tests/test_batched_engine_mllm_config.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test_batched_engine_mllm_config.py b/tests/test_batched_engine_mllm_config.py index e16d161ef..69a6b7e19 100644 --- a/tests/test_batched_engine_mllm_config.py +++ b/tests/test_batched_engine_mllm_config.py @@ -155,9 +155,7 @@ def _base_scheduler_config(**overrides): def test_start_mllm_forwards_ssd_cache_fields(monkeypatch): captured = _run_start_mllm( monkeypatch, - _base_scheduler_config( - ssd_cache_dir="/tmp/ssd-kv", ssd_cache_max_gb=42.0 - ), + _base_scheduler_config(ssd_cache_dir="/tmp/ssd-kv", ssd_cache_max_gb=42.0), ) assert captured["config_kwargs"]["ssd_cache_dir"] == "/tmp/ssd-kv" assert captured["config_kwargs"]["ssd_cache_max_gb"] == 42.0