Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
195 changes: 195 additions & 0 deletions tests/test_ensure_model_loaded.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
# SPDX-License-Identifier: Apache-2.0
"""Tests for `_is_model_loaded` and `ensure_model_loaded` in service.helpers.

These helpers unify the previously inconsistent "is the request model name
something we can serve?" check across /v1/chat/completions, /v1/completions,
and /v1/messages routes.

Future on-demand auto-loading (see follow-up PR) will hook into
`ensure_model_loaded`; today it strictly raises 404 for unloaded models.
"""

from unittest.mock import MagicMock

import pytest
from fastapi import HTTPException

from vllm_mlx.config import get_config
from vllm_mlx.runtime.model_registry import ModelRegistry
from vllm_mlx.service.helpers import _is_model_loaded, ensure_model_loaded


@pytest.fixture
def reset_config():
    """Yield the shared config, then put back the fields the tests overwrite.

    The values of the model-identity fields are captured before yielding and
    written back after the test body has run, so a test may assign to them
    freely without affecting later tests.
    """
    config = get_config()
    saved = [
        (attr, getattr(config, attr))
        for attr in ("model_name", "model_alias", "model_path", "model_registry")
    ]
    yield config
    for attr, original in saved:
        setattr(config, attr, original)


def _make_registry(*names: str) -> MagicMock:
    """Build a `ModelRegistry`-spec'd mock that contains exactly `names`.

    `x in registry` is true only for the supplied names, and
    `list_model_names()` returns them in the order given. Spec'ing the mock
    against `ModelRegistry` makes the tests fail if that interface changes.
    """

    def _contains(self, candidate):
        return candidate in names

    registry = MagicMock(spec=ModelRegistry)
    # A plain function assigned to a dunder is bound like a real method by
    # the mock, which is why `_contains` takes an explicit `self`.
    registry.__contains__ = _contains
    registry.list_model_names.return_value = [*names]
    return registry


# ---------------------------------------------------------------------------
# _is_model_loaded — single-model mode
# ---------------------------------------------------------------------------


class TestIsModelLoadedSingleMode:
    """`_is_model_loaded` when no model registry is configured."""

    @staticmethod
    def _serve_single(cfg, name, *, alias=None, path=None):
        """Put `cfg` into single-model mode with the given identity fields."""
        cfg.model_name = name
        cfg.model_alias = alias
        cfg.model_path = path
        cfg.model_registry = None

    def test_matches_model_name(self, reset_config):
        self._serve_single(reset_config, "qwen3.5-4b")
        assert _is_model_loaded("qwen3.5-4b") is True

    def test_matches_model_alias(self, reset_config):
        self._serve_single(reset_config, "qwen3.5-4b", alias="qwen")
        assert _is_model_loaded("qwen") is True

    def test_matches_model_path(self, reset_config):
        self._serve_single(
            reset_config,
            "qwen3.5-4b",
            path="mlx-community/Qwen2.5-4B-Instruct-4bit",
        )
        assert _is_model_loaded("mlx-community/Qwen2.5-4B-Instruct-4bit") is True

    def test_rejects_unknown_model(self, reset_config):
        self._serve_single(reset_config, "qwen3.5-4b")
        assert _is_model_loaded("kimi-48b") is False

    def test_default_accepted_in_single_mode(self, reset_config):
        """P2-1 fix: 'default' counts as loaded with or without a registry.

        Before the fix, `_validate_model_name` accepted 'default' only when a
        registry was configured, so a single-model server answered
        `model: "default"` with a 404 even though the request could only mean
        the one model being served.
        """
        self._serve_single(reset_config, "qwen3.5-4b")
        assert _is_model_loaded("default") is True

    def test_returns_false_when_no_model_configured(self, reset_config):
        self._serve_single(reset_config, None)
        assert _is_model_loaded("anything") is False


# ---------------------------------------------------------------------------
# _is_model_loaded — registry (multi-model) mode
# ---------------------------------------------------------------------------


class TestIsModelLoadedRegistryMode:
    """`_is_model_loaded` when a multi-model registry is configured."""

    @staticmethod
    def _serve_registry(cfg, *names):
        """Configure `cfg` with a registry of `names`; the first is the default.

        Alias and path are cleared explicitly. A name that misses the registry
        falls through to the single-model comparison against model_name,
        model_alias and model_path, so leaving those two fields at whatever
        the ambient config held would let unrelated state decide the result.
        """
        cfg.model_registry = _make_registry(*names)
        cfg.model_name = names[0]
        cfg.model_alias = None
        cfg.model_path = None

    def test_matches_registry_entry(self, reset_config):
        self._serve_registry(reset_config, "qwen3.5-4b", "phi4-14b")
        assert _is_model_loaded("phi4-14b") is True

    def test_rejects_non_registry_entry(self, reset_config):
        self._serve_registry(reset_config, "qwen3.5-4b", "phi4-14b")
        assert _is_model_loaded("kimi-48b") is False

    def test_default_accepted_in_registry_mode(self, reset_config):
        self._serve_registry(reset_config, "qwen3.5-4b")
        assert _is_model_loaded("default") is True


# ---------------------------------------------------------------------------
# ensure_model_loaded — strict-404 semantics (PR #1 baseline)
# ---------------------------------------------------------------------------


class TestEnsureModelLoaded:
    """Strict-404 contract of `ensure_model_loaded`."""

    @staticmethod
    def _serve_single(cfg, name):
        """Put `cfg` into single-model mode serving `name` (no alias or path)."""
        cfg.model_name = name
        cfg.model_alias = None
        cfg.model_path = None
        cfg.model_registry = None

    @staticmethod
    async def _detail_of_404(requested):
        """Await the helper, require a 404, and hand back its detail text."""
        with pytest.raises(HTTPException) as caught:
            await ensure_model_loaded(requested)
        # Checked after the `with` block: statements placed after the raising
        # call inside it would never execute.
        assert caught.value.status_code == 404
        return caught.value.detail

    @pytest.mark.asyncio
    async def test_noop_when_model_is_loaded(self, reset_config):
        self._serve_single(reset_config, "qwen3.5-4b")
        # Completing without an exception is the assertion.
        await ensure_model_loaded("qwen3.5-4b")

    @pytest.mark.asyncio
    async def test_noop_for_empty_or_default(self, reset_config):
        self._serve_single(reset_config, "qwen3.5-4b")
        for requested in (None, "", "default"):
            await ensure_model_loaded(requested)

    @pytest.mark.asyncio
    async def test_raises_404_when_model_not_loaded(self, reset_config):
        self._serve_single(reset_config, "qwen3.5-4b")
        detail = await self._detail_of_404("kimi-48b")
        assert "kimi-48b" in detail

    @pytest.mark.asyncio
    async def test_404_detail_lists_available_single_mode(self, reset_config):
        """Pin the 404 contract: the detail carries the `Available:` hint.

        The message shape matches `_validate_model_name`, keeping the two
        helpers interchangeable on the strict-404 path, so that #319's swap
        logic can replace this `raise` without changing what clients see on
        a miss.
        """
        self._serve_single(reset_config, "qwen3.5-4b")
        detail = await self._detail_of_404("kimi-48b")
        assert "Available:" in detail
        assert "qwen3.5-4b" in detail

    @pytest.mark.asyncio
    async def test_404_detail_lists_available_registry_mode(self, reset_config):
        reset_config.model_registry = _make_registry("qwen3.5-4b", "phi4-14b")
        reset_config.model_name = "qwen3.5-4b"
        reset_config.model_alias = None
        reset_config.model_path = None
        detail = await self._detail_of_404("kimi-48b")
        assert "Available:" in detail
        assert "phi4-14b" in detail
        assert "qwen3.5-4b" in detail

    @pytest.mark.asyncio
    async def test_noop_when_server_unconfigured(self, reset_config):
        """Unconfigured-server case, kept in parity with `_validate_model_name`.

        That helper returns silently when neither `model_name` nor
        `model_registry` is set, and this one must do the same.
        """
        self._serve_single(reset_config, None)
        # Even a clearly unknown name must not raise here.
        await ensure_model_loaded("anything")
4 changes: 2 additions & 2 deletions vllm_mlx/routes/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,10 @@
_resolve_model_name,
_resolve_temperature,
_resolve_top_p,
_validate_model_name,
_validate_tool_call_params,
_wait_with_disconnect,
build_extended_sampling_kwargs,
ensure_model_loaded,
get_engine,
get_usage,
)
Expand Down Expand Up @@ -203,7 +203,7 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re
}
```
"""
_validate_model_name(request.model)
await ensure_model_loaded(request.model)
engine = get_engine(request.model)

# Validate messages is non-empty
Expand Down
4 changes: 2 additions & 2 deletions vllm_mlx/routes/completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
_resolve_model_name,
_resolve_temperature,
_resolve_top_p,
_validate_model_name,
_wait_with_disconnect,
build_extended_sampling_kwargs,
ensure_model_loaded,
get_engine,
get_usage,
)
Expand All @@ -42,7 +42,7 @@
)
async def create_completion(request: CompletionRequest, raw_request: Request):
"""Create a text completion."""
_validate_model_name(request.model)
await ensure_model_loaded(request.model)
engine = get_engine(request.model)

# Handle single prompt or list of prompts
Expand Down
86 changes: 65 additions & 21 deletions vllm_mlx/service/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,35 +367,79 @@ def get_engine(model_name: str | None = None) -> BaseEngine:
return cfg.engine


def _is_model_loaded(model_name: str | None) -> bool:
"""Return True when model_name refers to a currently served model.

Treats `None`, empty string, and the literal "default" as "loaded": the
request does not target a specific model, so the currently configured
default serves it. This unifies single-model and registry modes — the
pre-fix behavior accepted "default" only in registry mode, which surfaced
as a confusing 404 for clients sending `model: "default"` against a
single-model server.
"""
if not model_name or model_name == "default":
return True
cfg = get_config()
if cfg.model_registry is not None and model_name in cfg.model_registry:
return True
if not cfg.model_name:
return False
accepted = {cfg.model_name}
if cfg.model_alias:
accepted.add(cfg.model_alias)
if cfg.model_path:
accepted.add(cfg.model_path)
return model_name in accepted


def _raise_model_not_found(model_name: str | None) -> None:
    """Raise the shared 404 for a model name that is not being served.

    Returns without raising when the server has neither a `model_name` nor a
    `model_registry` configured, so an unconfigured server never rejects a
    request on the model name alone.

    Raises:
        HTTPException: status 404; the detail names the requested model and
            lists what is available (registry names in registry mode,
            otherwise the single configured model name).
    """
    cfg = get_config()
    registry = cfg.model_registry
    if not cfg.model_name and registry is None:
        return
    if registry is not None:
        available = ", ".join(registry.list_model_names())
    else:
        available = cfg.model_name
    raise HTTPException(
        status_code=404,
        detail=f"The model `{model_name}` does not exist. Available: {available}",
    )


def _validate_model_name(request_model: str) -> None:
    """Validate that the request model name matches a served model.

    Synchronous strict-404 check. A name accepted by `_is_model_loaded`
    passes; anything else raises `HTTPException` (404) unless the server is
    unconfigured, in which case the call returns silently.
    """
    if _is_model_loaded(request_model):
        return
    _raise_model_not_found(request_model)


async def ensure_model_loaded(model_name: str | None) -> None:
    """Validate that `model_name` is currently loaded; otherwise raise 404.

    Canonical strict-404 check for /v1/chat/completions and /v1/completions.
    It shares its rejection path with `_validate_model_name`, so the two are
    interchangeable on the strict-404 path. The sync counterpart is kept for
    /v1/messages (anthropic adapter is model-name-agnostic by design).

    Extension point for on-demand auto-loading: a follow-up PR will let an
    operator-set flag (`enable_on_demand_loading`) trigger a hot swap to the
    requested model here instead of raising 404. The flag plumbing, swap
    machinery, and security/concurrency model are intentionally out of scope
    for this PR; today the function is strict-404.

    Raises:
        HTTPException: status 404 when the model is not loaded and the server
            has a model or registry configured.
    """
    if _is_model_loaded(model_name):
        return
    _raise_model_not_found(model_name)


# ── Tool call parsing ──────────────────────────────────────────────
Expand Down