Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions tests/test_batched_engine_mllm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,85 @@ async def start(self):
assert captured["config_kwargs"]["use_memory_aware_cache"] is False
assert captured["config_kwargs"]["cache_memory_mb"] == 123
assert captured["config_kwargs"]["prefix_cache_memory_mb"] == 123
# SSD fields default to (None, 10.0) when absent from SchedulerConfig.
assert captured["config_kwargs"]["ssd_cache_dir"] is None
assert captured["config_kwargs"]["ssd_cache_max_gb"] == 10.0


def _run_start_mllm(monkeypatch, scheduler_config):
    """Run BatchedEngine._start_mllm with fakes, return captured kwargs.

    The real multimodal model and MLLM scheduler modules are replaced in
    ``sys.modules`` by lightweight stand-ins, so no model is loaded.

    Args:
        monkeypatch: pytest ``monkeypatch`` fixture; used to install the fake
            modules and to stub ``BatchedEngine._inject_mtp_mllm``.
        scheduler_config: Object handed to ``BatchedEngine`` as its
            ``scheduler_config``.

    Returns:
        The ``captured`` dict. Its ``"config_kwargs"`` entry holds the keyword
        arguments the fake ``MLLMSchedulerConfig`` was constructed with.
    """
    from vllm_mlx.engine.batched import BatchedEngine

    captured = {}

    class FakeMLXMultimodalLM:
        def __init__(self, model_name, trust_remote_code=True, **kwargs):
            self.model = object()
            self.processor = object()

        def load(self):
            return None

    class FakeMLLMSchedulerConfig:
        def __init__(self, **kwargs):
            # Record exactly what the engine forwarded so tests can inspect it.
            captured["config_kwargs"] = kwargs
            self.__dict__.update(kwargs)

    class FakeMLLMScheduler:
        def __init__(self, model, processor, config):
            pass

        async def start(self):
            return None

    import vllm_mlx.engine.batched as batched_mod

    fake_mllm_scheduler = types.ModuleType("vllm_mlx.mllm_scheduler")
    fake_mllm_scheduler.MLLMScheduler = FakeMLLMScheduler
    fake_mllm_scheduler.MLLMSchedulerConfig = FakeMLLMSchedulerConfig
    fake_mllm_model = types.ModuleType("vllm_mlx.models.mllm")
    fake_mllm_model.MLXMultimodalLM = FakeMLXMultimodalLM
    monkeypatch.setitem(sys.modules, "vllm_mlx.mllm_scheduler", fake_mllm_scheduler)
    monkeypatch.setitem(sys.modules, "vllm_mlx.models.mllm", fake_mllm_model)
    monkeypatch.setattr(
        batched_mod.BatchedEngine, "_inject_mtp_mllm", lambda self: None
    )

    engine = BatchedEngine(
        model_name="fake-qwen",
        scheduler_config=scheduler_config,
        force_mllm=True,
    )
    asyncio.run(engine._start_mllm())
    return captured


def _base_scheduler_config(**overrides):
cfg = dict(
max_num_seqs=16,
prefill_batch_size=4,
completion_batch_size=8,
prefill_step_size=256,
mllm_prefill_step_size=None,
enable_prefix_cache=True,
use_memory_aware_cache=True,
cache_memory_mb=None,
enable_mtp=False,
mtp_num_draft_tokens=1,
kv_cache_quantization=False,
kv_cache_quantization_bits=8,
kv_cache_quantization_group_size=64,
chunked_prefill_tokens=0,
max_kv_size=0,
)
cfg.update(overrides)
return SimpleNamespace(**cfg)


def test_start_mllm_forwards_ssd_cache_fields(monkeypatch):
    """SSD cache fields set on the scheduler config reach MLLMSchedulerConfig."""
    captured = _run_start_mllm(
        monkeypatch,
        _base_scheduler_config(ssd_cache_dir="/tmp/ssd-kv", ssd_cache_max_gb=42.0),
    )
    assert captured["config_kwargs"]["ssd_cache_dir"] == "/tmp/ssd-kv"
    assert captured["config_kwargs"]["ssd_cache_max_gb"] == 42.0
123 changes: 123 additions & 0 deletions tests/test_mllm_ssd_spill.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# SPDX-License-Identifier: Apache-2.0
"""Tests that the SSD cold tier is wired onto the MLLM prefix cache.

`--ssd-cache-dir` used to be a silent no-op on the MLLM path (Qwen3.5 et al.)
because the SSD tier was only attached to the standard Scheduler's
MemoryAwarePrefixCache. These tests assert the MLLM scheduler now builds an
SSDCacheTier and calls set_ssd_tier() on the generator's prefix cache when the
flag is set, and does NOT when it is unset. Model-free: the batch generator,
sampler and SSD tier are all faked.
"""

import sys
import types
from types import SimpleNamespace

from vllm_mlx.mllm_scheduler import MLLMScheduler, MLLMSchedulerConfig


def _install_fakes(monkeypatch, prefix_cache_obj):
    """Patch the heavy collaborators _ensure_batch_generator pulls in.

    Args:
        monkeypatch: pytest ``monkeypatch`` fixture.
        prefix_cache_obj: Object the fake generator exposes as its
            ``prefix_cache`` attribute, or ``None`` for no prefix cache.

    Returns:
        A ``(set_calls, tier_instances)`` tuple of lists. ``set_calls``
        collects every tier passed to the patched ``set_ssd_tier``;
        ``tier_instances`` collects every ``FakeTier`` that gets constructed.
    """
    import vllm_mlx.mllm_scheduler as sched_mod

    set_calls = []
    tier_instances = []

    class FakeGenerator:
        def __init__(self, *a, **kw):
            self.prefix_cache = prefix_cache_obj
            self.language_model = SimpleNamespace(mtp=None)

    class FakeTier:
        def __init__(self, config):
            self.config = config
            self.started = False
            self.reconciled = False
            tier_instances.append(self)

        def start_writer(self):
            self.started = True

        def reconcile(self):
            self.reconciled = True

    class FakeTierConfig:
        def __init__(self, cache_dir=None, max_size_gb=10.0):
            self.cache_dir = cache_dir
            self.max_size_gb = max_size_gb

    if prefix_cache_obj is not None:
        # raising=False: the stand-in cache object has no set_ssd_tier yet.
        monkeypatch.setattr(
            prefix_cache_obj,
            "set_ssd_tier",
            lambda tier: set_calls.append(tier),
            raising=False,
        )

    monkeypatch.setattr(sched_mod, "MLLMBatchGenerator", FakeGenerator)

    # make_sampler and MemoryCacheConfig are imported lazily inside the method.
    fake_sample_utils = types.ModuleType("mlx_lm.sample_utils")
    fake_sample_utils.make_sampler = lambda **kw: (lambda x: x)
    monkeypatch.setitem(sys.modules, "mlx_lm.sample_utils", fake_sample_utils)

    fake_ssd = types.ModuleType("vllm_mlx.ssd_cache")
    fake_ssd.SSDCacheTier = FakeTier
    fake_ssd.SSDCacheConfig = FakeTierConfig
    monkeypatch.setitem(sys.modules, "vllm_mlx.ssd_cache", fake_ssd)

    return set_calls, tier_instances


def _bare_scheduler(config):
    """Construct an MLLMScheduler without running its heavy __init__.

    Args:
        config: Object assigned to the scheduler's ``config`` attribute.

    Returns:
        An ``MLLMScheduler`` created via ``__new__`` with only the attributes
        assigned below; ``batch_generator`` starts as ``None``.
    """
    sched = MLLMScheduler.__new__(MLLMScheduler)
    sched.config = config
    sched.model = SimpleNamespace()
    sched.processor = SimpleNamespace()
    sched.mm_processor = SimpleNamespace()
    sched.stop_tokens = set()
    sched.batch_generator = None
    return sched


def test_ssd_tier_attached_when_dir_set(monkeypatch, tmp_path):
    """With ssd_cache_dir set, one tier is built, started and attached."""
    prefix_cache = SimpleNamespace()
    set_calls, tiers = _install_fakes(monkeypatch, prefix_cache)

    cfg = MLLMSchedulerConfig(ssd_cache_dir=str(tmp_path), ssd_cache_max_gb=7.0)
    sched = _bare_scheduler(cfg)
    sched._ensure_batch_generator()

    assert len(tiers) == 1
    assert tiers[0].started and tiers[0].reconciled
    assert tiers[0].config.cache_dir == str(tmp_path)
    assert tiers[0].config.max_size_gb == 7.0
    assert len(set_calls) == 1
    assert set_calls[0] is tiers[0]
    assert sched._ssd_tier is tiers[0]


def test_no_tier_when_dir_unset(monkeypatch):
    """With ssd_cache_dir=None, no tier is built and none is attached."""
    prefix_cache = SimpleNamespace()
    set_calls, tiers = _install_fakes(monkeypatch, prefix_cache)

    cfg = MLLMSchedulerConfig(ssd_cache_dir=None)
    sched = _bare_scheduler(cfg)
    sched._ensure_batch_generator()

    assert tiers == []
    assert set_calls == []
    assert sched._ssd_tier is None


def test_no_tier_when_prefix_cache_absent(monkeypatch, tmp_path):
    """Without a prefix cache, no tier is built even if ssd_cache_dir is set."""
    # Prefix caching disabled → generator.prefix_cache is None → no SSD tier.
    set_calls, tiers = _install_fakes(monkeypatch, None)

    cfg = MLLMSchedulerConfig(ssd_cache_dir=str(tmp_path))
    sched = _bare_scheduler(cfg)
    sched._ensure_batch_generator()

    assert tiers == []
    # Same check the dir-unset test makes: nothing was attached either.
    assert set_calls == []
    assert sched._ssd_tier is None
6 changes: 6 additions & 0 deletions vllm_mlx/engine/batched.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,10 @@ async def _start_mllm(self) -> None:
mtp_num_draft = getattr(self._scheduler_config, "mtp_num_draft_tokens", 1)
kv_quant = getattr(self._scheduler_config, "kv_cache_quantization", False)
kv_bits = getattr(self._scheduler_config, "kv_cache_quantization_bits", 8)
# SSD cold tier — same SchedulerConfig fields the standard path reads
# (cli.py populates ssd_cache_dir/ssd_cache_max_gb). None = disabled.
ssd_cache_dir = getattr(self._scheduler_config, "ssd_cache_dir", None)
ssd_cache_max_gb = getattr(self._scheduler_config, "ssd_cache_max_gb", 10.0)
kv_group_size = getattr(
self._scheduler_config, "kv_cache_quantization_group_size", 64
)
Expand Down Expand Up @@ -372,6 +376,8 @@ async def _start_mllm(self) -> None:
kv_cache_quantization_group_size=kv_group_size,
chunked_prefill_tokens=chunked_prefill_tokens,
max_kv_size=max_kv_size,
ssd_cache_dir=ssd_cache_dir,
ssd_cache_max_gb=ssd_cache_max_gb,
**mllm_extra,
)

Expand Down
30 changes: 30 additions & 0 deletions vllm_mlx/mllm_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,11 @@ class MLLMSchedulerConfig:
chunked_prefill_tokens: int = 0
# Maximum KV cache size per sequence (0 = unbounded; >0 enables RotatingKVCache)
max_kv_size: int = 0
# SSD cold tier for the prefix cache (mirrors SchedulerConfig).
# None = disabled. When set, the MLLM MemoryAwarePrefixCache spills
# evicted entries to disk and promotes them back on hit.
ssd_cache_dir: Optional[str] = None
ssd_cache_max_gb: float = 10.0


@dataclass
Expand Down Expand Up @@ -321,6 +326,31 @@ def _ensure_batch_generator(self) -> None:
max_kv_size=self.config.max_kv_size,
)

# Wire the SSD cold tier onto the MLLM prefix cache, mirroring the
# standard Scheduler path (see scheduler.py ~1226). Without this
# --ssd-cache-dir is a silent no-op for MLLM models (Qwen3.5 et al.)
# because the SSD tier was only ever attached to the standard
# Scheduler's MemoryAwarePrefixCache. No-op when the flag is unset.
self._ssd_tier = None
prefix_cache = getattr(self.batch_generator, "prefix_cache", None)
if self.config.ssd_cache_dir is not None and prefix_cache is not None:
from .ssd_cache import SSDCacheConfig, SSDCacheTier

ssd_config = SSDCacheConfig(
cache_dir=self.config.ssd_cache_dir,
max_size_gb=self.config.ssd_cache_max_gb,
)
self._ssd_tier = SSDCacheTier(ssd_config)
self._ssd_tier.start_writer()
self._ssd_tier.reconcile()
prefix_cache.set_ssd_tier(self._ssd_tier)
logger.info(
"[mllm] SSD cache tier enabled on MLLM prefix cache: "
"dir=%s, max=%sGB",
self.config.ssd_cache_dir,
self.config.ssd_cache_max_gb,
)

# Install chunked prefill BEFORE MTP (MTP wraps _next,
# chunked replaces it — MTP then wraps the chunked version)
if self.config.chunked_prefill_tokens > 0:
Expand Down
Loading