diff --git a/tests/test_ensure_model_loaded.py b/tests/test_ensure_model_loaded.py
new file mode 100644
index 0000000..5fa8db4
--- /dev/null
+++ b/tests/test_ensure_model_loaded.py
@@ -0,0 +1,199 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for `_is_model_loaded` and `ensure_model_loaded` in service.helpers.
+
+These helpers unify the previously inconsistent "is the request model name
+something we can serve?" check across /v1/chat/completions, /v1/completions,
+and /v1/messages routes.
+
+Future on-demand auto-loading (see follow-up PR) will hook into
+`ensure_model_loaded`; today it strictly raises 404 for unloaded models.
+"""
+
+from unittest.mock import MagicMock
+
+import pytest
+from fastapi import HTTPException
+
+from vllm_mlx.config import get_config
+from vllm_mlx.runtime.model_registry import ModelRegistry
+from vllm_mlx.service.helpers import _is_model_loaded, ensure_model_loaded
+
+
+@pytest.fixture
+def reset_config():
+    """Blank the model-selection config fields for one test, then restore them."""
+    cfg = get_config()
+    fields = ("model_name", "model_alias", "model_path", "model_registry")
+    snap = {f: getattr(cfg, f) for f in fields}
+    # Start from a clean slate so tests that set only some fields (the
+    # registry-mode cases) cannot be affected by a leftover alias or path.
+    for f in fields:
+        setattr(cfg, f, None)
+    yield cfg
+    for f, v in snap.items():
+        setattr(cfg, f, v)
+
+
+def _make_registry(*names: str) -> MagicMock:
+    """Spec'd registry mock — catches interface drift if ModelRegistry changes."""
+    reg = MagicMock(spec=ModelRegistry)
+    reg.__contains__ = lambda self, x: x in names
+    reg.list_model_names.return_value = list(names)
+    return reg
+
+
+# ---------------------------------------------------------------------------
+# _is_model_loaded — single-model mode
+# ---------------------------------------------------------------------------
+
+
+class TestIsModelLoadedSingleMode:
+    def test_matches_model_name(self, reset_config):
+        reset_config.model_name = "qwen3.5-4b"
+        reset_config.model_alias = None
+        reset_config.model_path = None
+        reset_config.model_registry = None
+        assert _is_model_loaded("qwen3.5-4b") is True
+
+    def test_matches_model_alias(self, reset_config):
+        reset_config.model_name = "qwen3.5-4b"
+        reset_config.model_alias = "qwen"
+        reset_config.model_path = None
+        reset_config.model_registry = None
+        assert _is_model_loaded("qwen") is True
+
+    def test_matches_model_path(self, reset_config):
+        reset_config.model_name = "qwen3.5-4b"
+        reset_config.model_alias = None
+        reset_config.model_path = "mlx-community/Qwen2.5-4B-Instruct-4bit"
+        reset_config.model_registry = None
+        assert _is_model_loaded("mlx-community/Qwen2.5-4B-Instruct-4bit") is True
+
+    def test_rejects_unknown_model(self, reset_config):
+        reset_config.model_name = "qwen3.5-4b"
+        reset_config.model_alias = None
+        reset_config.model_path = None
+        reset_config.model_registry = None
+        assert _is_model_loaded("kimi-48b") is False
+
+    def test_default_accepted_in_single_mode(self, reset_config):
+        """P2-1 fix: 'default' is loaded in BOTH single-model and registry mode.
+
+        Pre-fix, `_validate_model_name` accepted 'default' only when a registry
+        was configured — single-model servers 404'd on `model: "default"` even
+        though the request unambiguously targets the one model that's loaded.
+        """
+        reset_config.model_name = "qwen3.5-4b"
+        reset_config.model_alias = None
+        reset_config.model_path = None
+        reset_config.model_registry = None
+        assert _is_model_loaded("default") is True
+
+    def test_returns_false_when_no_model_configured(self, reset_config):
+        reset_config.model_name = None
+        reset_config.model_alias = None
+        reset_config.model_path = None
+        reset_config.model_registry = None
+        assert _is_model_loaded("anything") is False
+
+
+# ---------------------------------------------------------------------------
+# _is_model_loaded — registry (multi-model) mode
+# ---------------------------------------------------------------------------
+
+
+class TestIsModelLoadedRegistryMode:
+    def test_matches_registry_entry(self, reset_config):
+        reset_config.model_registry = _make_registry("qwen3.5-4b", "phi4-14b")
+        reset_config.model_name = "qwen3.5-4b"
+        assert _is_model_loaded("phi4-14b") is True
+
+    def test_rejects_non_registry_entry(self, reset_config):
+        reset_config.model_registry = _make_registry("qwen3.5-4b", "phi4-14b")
+        reset_config.model_name = "qwen3.5-4b"
+        assert _is_model_loaded("kimi-48b") is False
+
+    def test_default_accepted_in_registry_mode(self, reset_config):
+        reset_config.model_registry = _make_registry("qwen3.5-4b")
+        reset_config.model_name = "qwen3.5-4b"
+        assert _is_model_loaded("default") is True
+
+
+# ---------------------------------------------------------------------------
+# ensure_model_loaded — strict-404 semantics (PR #1 baseline)
+# ---------------------------------------------------------------------------
+
+
+class TestEnsureModelLoaded:
+    @pytest.mark.asyncio
+    async def test_noop_when_model_is_loaded(self, reset_config):
+        reset_config.model_name = "qwen3.5-4b"
+        reset_config.model_alias = None
+        reset_config.model_path = None
+        reset_config.model_registry = None
+        # Should not raise.
+        await ensure_model_loaded("qwen3.5-4b")
+
+    @pytest.mark.asyncio
+    async def test_noop_for_empty_or_default(self, reset_config):
+        reset_config.model_name = "qwen3.5-4b"
+        reset_config.model_alias = None
+        reset_config.model_path = None
+        reset_config.model_registry = None
+        await ensure_model_loaded(None)
+        await ensure_model_loaded("")
+        await ensure_model_loaded("default")
+
+    @pytest.mark.asyncio
+    async def test_raises_404_when_model_not_loaded(self, reset_config):
+        reset_config.model_name = "qwen3.5-4b"
+        reset_config.model_alias = None
+        reset_config.model_path = None
+        reset_config.model_registry = None
+        with pytest.raises(HTTPException) as exc:
+            await ensure_model_loaded("kimi-48b")
+        assert exc.value.status_code == 404
+        assert "kimi-48b" in exc.value.detail
+
+    @pytest.mark.asyncio
+    async def test_404_detail_lists_available_single_mode(self, reset_config):
+        """Locks the 404 contract: detail must include the `Available:` hint.
+
+        Mirrors `_validate_model_name`'s message shape so the two helpers are
+        interchangeable on the strict-404 path — and so #319's swap logic can
+        replace this `raise` without changing what clients see on a miss.
+        """
+        reset_config.model_name = "qwen3.5-4b"
+        reset_config.model_alias = None
+        reset_config.model_path = None
+        reset_config.model_registry = None
+        with pytest.raises(HTTPException) as exc:
+            await ensure_model_loaded("kimi-48b")
+        assert exc.value.status_code == 404
+        assert "Available:" in exc.value.detail
+        assert "qwen3.5-4b" in exc.value.detail
+
+    @pytest.mark.asyncio
+    async def test_404_detail_lists_available_registry_mode(self, reset_config):
+        reset_config.model_registry = _make_registry("qwen3.5-4b", "phi4-14b")
+        reset_config.model_name = "qwen3.5-4b"
+        reset_config.model_alias = None
+        reset_config.model_path = None
+        with pytest.raises(HTTPException) as exc:
+            await ensure_model_loaded("kimi-48b")
+        assert exc.value.status_code == 404
+        assert "Available:" in exc.value.detail
+        assert "phi4-14b" in exc.value.detail
+        assert "qwen3.5-4b" in exc.value.detail
+
+    @pytest.mark.asyncio
+    async def test_noop_when_server_unconfigured(self, reset_config):
+        """Theoretical unconfigured-server case — preserves parity with
+        `_validate_model_name`, which silently returns when neither
+        `model_name` nor `model_registry` is set."""
+        reset_config.model_name = None
+        reset_config.model_alias = None
+        reset_config.model_path = None
+        reset_config.model_registry = None
+        # Should not raise even for a clearly-unknown name.
+        await ensure_model_loaded("anything")
diff --git a/vllm_mlx/routes/chat.py b/vllm_mlx/routes/chat.py
index 621c5b7..5f4d798 100644
--- a/vllm_mlx/routes/chat.py
+++ b/vllm_mlx/routes/chat.py
@@ -54,10 +54,10 @@
     _resolve_model_name,
     _resolve_temperature,
     _resolve_top_p,
-    _validate_model_name,
     _validate_tool_call_params,
     _wait_with_disconnect,
     build_extended_sampling_kwargs,
+    ensure_model_loaded,
     get_engine,
     get_usage,
 )
@@ -203,7 +203,7 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re
     }
     ```
     """
-    _validate_model_name(request.model)
+    await ensure_model_loaded(request.model)
     engine = get_engine(request.model)
 
     # Validate messages is non-empty
diff --git a/vllm_mlx/routes/completions.py b/vllm_mlx/routes/completions.py
index da82c0a..8a67f8d 100644
--- a/vllm_mlx/routes/completions.py
+++ b/vllm_mlx/routes/completions.py
@@ -24,9 +24,9 @@
     _resolve_model_name,
     _resolve_temperature,
     _resolve_top_p,
-    _validate_model_name,
     _wait_with_disconnect,
     build_extended_sampling_kwargs,
+    ensure_model_loaded,
     get_engine,
     get_usage,
 )
@@ -42,7 +42,7 @@
 )
 async def create_completion(request: CompletionRequest, raw_request: Request):
     """Create a text completion."""
-    _validate_model_name(request.model)
+    await ensure_model_loaded(request.model)
     engine = get_engine(request.model)
 
     # Handle single prompt or list of prompts
diff --git a/vllm_mlx/service/helpers.py b/vllm_mlx/service/helpers.py
index 4b0cbd6..76833f4 100644
--- a/vllm_mlx/service/helpers.py
+++ b/vllm_mlx/service/helpers.py
@@ -367,33 +367,68 @@ def get_engine(model_name: str | None = None) -> BaseEngine:
     return cfg.engine
 
 
+def _is_model_loaded(model_name: str | None) -> bool:
+    """Return True when model_name refers to a currently served model.
+
+    Treats `None`, empty string, and the literal "default" as "loaded": the
+    request does not target a specific model, so the currently configured
+    default serves it. This unifies single-model and registry modes — the
+    pre-fix behavior accepted "default" only in registry mode, which surfaced
+    as a confusing 404 for clients sending `model: "default"` against a
+    single-model server.
+    """
+    if not model_name or model_name == "default":
+        return True
+    cfg = get_config()
+    if cfg.model_registry is not None and model_name in cfg.model_registry:
+        return True
+    if not cfg.model_name:
+        return False
+    accepted = {cfg.model_name}
+    if cfg.model_alias:
+        accepted.add(cfg.model_alias)
+    if cfg.model_path:
+        accepted.add(cfg.model_path)
+    return model_name in accepted
+
+
-def _validate_model_name(request_model: str) -> None:
+def _validate_model_name(request_model: str | None) -> None:
     """Validate that the request model name matches a served model."""
-    if not request_model:
+    if _is_model_loaded(request_model):
         return
     cfg = get_config()
-    if cfg.model_registry and request_model in cfg.model_registry:
-        return
-    if cfg.model_registry and request_model == "default":
+    if not cfg.model_name and cfg.model_registry is None:
         return
-    if not cfg.model_name:
-        return
-    accepted = {cfg.model_name}
-    if cfg.model_alias:
-        accepted.add(cfg.model_alias)
-    if cfg.model_path:
-        accepted.add(cfg.model_path)
-    if request_model not in accepted:
-        available = (
-            ", ".join(cfg.model_registry.list_model_names())
-            if cfg.model_registry
-            else cfg.model_name
-        )
-        raise HTTPException(
-            status_code=404,
-            detail=f"The model `{request_model}` does not exist. "
-            f"Available: {available}",
-        )
+    available = (
+        ", ".join(cfg.model_registry.list_model_names())
+        if cfg.model_registry is not None
+        else cfg.model_name
+    )
+    raise HTTPException(
+        status_code=404,
+        detail=f"The model `{request_model}` does not exist. Available: {available}",
+    )
+
+
+async def ensure_model_loaded(model_name: str | None) -> None:
+    """Validate that `model_name` is currently loaded; otherwise raise 404.
+
+    Canonical strict-404 check for /v1/chat/completions and /v1/completions.
+    Delegates to `_validate_model_name` so both helpers share one 404 path
+    and cannot drift apart in status code or message shape. The sync
+    counterpart is kept for /v1/messages (anthropic adapter is
+    model-name-agnostic by design).
+
+    Extension point for on-demand auto-loading: a follow-up PR will let an
+    operator-set flag (`enable_on_demand_loading`) trigger a hot swap to the
+    requested model here instead of raising 404. The flag plumbing, swap
+    machinery, and security/concurrency model are intentionally out of scope
+    for this PR; today the function is strict-404.
+
+    Raises:
+        HTTPException: 404 when `model_name` names a model that is not served.
+    """
+    _validate_model_name(model_name)
 
 
 # ── Tool call parsing ──────────────────────────────────────────────