diff --git a/api/config.py b/api/config.py index 406261e2d7..6e193827e8 100644 --- a/api/config.py +++ b/api/config.py @@ -990,7 +990,17 @@ def resolve_model_provider(model_id: str) -> tuple: entry_model = (entry.get("model") or "").strip() entry_name = (entry.get("name") or "").strip() entry_base_url = (entry.get("base_url") or "").strip() - if entry_model and entry_name and model_id == entry_model: + entry_model_ids = set() + if entry_model: + entry_model_ids.add(entry_model) + entry_models = entry.get("models") + if isinstance(entry_models, dict): + entry_model_ids.update( + key.strip() + for key in entry_models.keys() + if isinstance(key, str) and key.strip() + ) + if entry_name and model_id in entry_model_ids: provider_hint = "custom:" + entry_name.lower().replace(" ", "-") return model_id, provider_hint, entry_base_url or None @@ -1265,8 +1275,8 @@ def _is_valid_models_cache(cache: object) -> bool: ) -def _load_models_cache_from_disk() -> dict | None: - """Load /api/models cache from disk if it exists and has current metadata.""" +def _load_models_cache_from_disk(config_mtime: float | None = None) -> dict | None: + """Load /api/models cache from disk if it matches the current config.""" try: import json as _j @@ -1274,7 +1284,21 @@ def _load_models_cache_from_disk() -> dict | None: return None with open(_models_cache_path, encoding="utf-8") as f: cache = _j.load(f) - return cache if _is_valid_models_cache(cache) else None + if not _is_valid_models_cache(cache): + return None + expected_mtime = _cfg_mtime if config_mtime is None else config_mtime + try: + cached_mtime = float(cache.get("_config_mtime")) + except (TypeError, ValueError): + return None + if cached_mtime != float(expected_mtime): + return None + return { + "active_provider": cache["active_provider"], + "default_model": cache["default_model"], + "configured_model_badges": cache["configured_model_badges"], + "groups": cache["groups"], + } except Exception: return None @@ -1288,6 +1312,7 @@ def _save_models_cache_to_disk(cache: dict) -> None: with open(tmp, "w", encoding="utf-8") as f: json.dump( { + "_config_mtime": _cfg_mtime, "active_provider": cache["active_provider"], "default_model": cache["default_model"], "configured_model_badges": cache["configured_model_badges"], @@ -1428,6 +1453,7 @@ def get_available_models() -> dict: _current_mtime = 0.0 if _current_mtime != _cfg_mtime: reload_config() + invalidate_models_cache() # ── COLD PATH helper ───────────────────────────────────────────────────── # Extracted so it runs inside _available_models_cache_lock (RLock) to # prevent thundering-herd: only one thread rebuilds while others wait. @@ -2044,7 +2070,7 @@ def _build_configured_model_badges() -> dict[str, dict[str, str]]: # so only one thread rebuilds while others wait. disk_groups = None if _available_models_cache is None: - disk_groups = _load_models_cache_from_disk() + disk_groups = _load_models_cache_from_disk(_current_mtime) with _available_models_cache_lock: # If another thread is already building, wait for its result instead diff --git a/api/streaming.py b/api/streaming.py index c7cdcc5217..4491764e1a 100644 --- a/api/streaming.py +++ b/api/streaming.py @@ -1701,8 +1701,22 @@ def on_tool(*cb_args, **cb_kwargs): _toolsets = _resolve_cli_toolsets(_cfg) # Fallback model from profile config (e.g. for rate-limit recovery) - _fallback = _cfg.get('fallback_model') or None - if _fallback: + _fallback = _cfg.get('fallback_providers') or _cfg.get('fallback_model') or None + if isinstance(_fallback, list): + _fallback_resolved = [ + { + 'model': str(entry.get('model') or '').strip(), + 'provider': str(entry.get('provider') or '').strip(), + 'base_url': entry.get('base_url'), + } + for entry in _fallback + if isinstance(entry, dict) + and str(entry.get('model') or '').strip() + and str(entry.get('provider') or '').strip() + ] + if not _fallback_resolved: + _fallback_resolved = None + elif isinstance(_fallback, dict): # Resolve the fallback through our provider logic too fb_model = _fallback.get('model', '') fb_provider = _fallback.get('provider', '') diff --git a/tests/test_model_resolver.py b/tests/test_model_resolver.py index 4ceb1138bf..a30dac7307 100644 --- a/tests/test_model_resolver.py +++ b/tests/test_model_resolver.py @@ -159,6 +159,26 @@ def test_custom_provider_model_with_slash_routes_to_named_custom_provider(): assert base_url == 'http://lmstudio.local:1234/v1' +def test_custom_provider_models_dict_routes_to_named_custom_provider(): + """Models listed only under custom_providers[].models still route to that endpoint.""" + model, provider, base_url = _resolve_with_config( + 'sensenova-6.7-flash-lite', + provider='xiaomi', + custom_providers=[{ + 'name': 'LiteLLM Proxy', + 'base_url': 'http://127.0.0.1:8080/v1', + 'model': 'deepseek-v4-flash', + 'models': { + 'deepseek-v4-flash': {}, + 'sensenova-6.7-flash-lite': {}, + }, + }], + ) + assert model == 'sensenova-6.7-flash-lite' + assert provider == 'custom:litellm-proxy' + assert base_url == 'http://127.0.0.1:8080/v1' + + # ── get_available_models() @provider: hint behaviour ────────────────────── diff --git a/tests/test_ttl_cache.py b/tests/test_ttl_cache.py index 42a7d99c12..fb5cb65ad3 100644 --- a/tests/test_ttl_cache.py +++ b/tests/test_ttl_cache.py @@ -8,6 +8,7 @@ - copy.deepcopy() isolation (mutating returned dict doesn't pollute cache) - invalidate_models_cache() direct invalidation """ +import json import time from unittest.mock import patch @@ -108,45 +109,106 @@ def test_ttl_expiry(): # ── 3. test_mtime_invalidation ─────────────────────────────────────────── -def test_mtime_invalidation(): +def test_mtime_invalidation(tmp_path, monkeypatch): """Populate the cache, then change _cfg_mtime to simulate a config file change on disk. The next call should invalidate the cache and re-scan. """ _reset_cache() + config_path = tmp_path / "config.yaml" + cache_path = tmp_path / "models_cache.json" - # Ensure _cfg_mtime matches file so first call doesn't re-scan due to mtime - try: - real_mtime = config.Path(config._get_config_path()).stat().st_mtime - except OSError: - real_mtime = 0.0 - config._cfg_mtime = real_mtime + with monkeypatch.context() as m: + m.setattr(config, "_get_config_path", lambda: config_path) + m.setattr(config, "_models_cache_path", cache_path) - # First call populates cache - result1 = config.get_available_models() - assert config._available_models_cache is not None + config_path.write_text( + "model:\n provider: openai\n default: old-test-model\n", + encoding="utf-8", + ) + config.reload_config() + config.invalidate_models_cache() - # Simulate config.yaml changed on disk by setting _cfg_mtime to 0 - # (which won't match the actual file mtime) - config._cfg_mtime = 0.0 + result1 = config.get_available_models() + assert result1["default_model"] == "old-test-model" + assert config._available_models_cache is not None - # The next call should detect mtime mismatch, reload, and invalidate cache - old_cache = config._available_models_cache - old_ts = config._available_models_cache_ts + old_cache = config._available_models_cache - result2 = config.get_available_models() + # Simulate an external edit to config.yaml. This is the path that the + # WebUI hits when a user edits the file outside /api/default-model. + config_path.write_text( + "model:\n provider: openai\n default: new-test-model\n", + encoding="utf-8", + ) + new_mtime = config_path.stat().st_mtime + 2.0 + config_path.touch() + config.os.utime(config_path, (new_mtime, new_mtime)) - # Cache must have been refreshed — timestamp advanced since we reset it - # to 0.0 on invalidation. - assert config._available_models_cache_ts > 0.0, ( - "Cache timestamp should be updated after invalidation + rebuild" - ) + result2 = config.get_available_models() + + assert result2["default_model"] == "new-test-model" + assert config._available_models_cache is not old_cache, ( + "Cache object should be replaced after config mtime invalidation" + ) + assert config._available_models_cache_ts > 0.0, ( + "Cache timestamp should be updated after invalidation + rebuild" + ) + + config.reload_config() + _reset_cache() + + +# ── 4. test_stale_disk_cache_after_restart_ignored ───────────────────────── + +def test_stale_disk_cache_after_restart_ignored(tmp_path, monkeypatch): + """A stale disk cache from before a config change must not survive restart.""" + _reset_cache() + config_path = tmp_path / "config.yaml" + cache_path = tmp_path / "models_cache.json" + + with monkeypatch.context() as m: + m.setattr(config, "_get_config_path", lambda: config_path) + m.setattr(config, "_models_cache_path", cache_path) + + config_path.write_text( + "model:\n provider: xiaomi\n default: old-test-model\n", + encoding="utf-8", + ) + config.reload_config() + old_mtime = config._cfg_mtime + + stale_cache = { + "_config_mtime": old_mtime, + "active_provider": "xiaomi", + "default_model": "old-test-model", + "configured_model_badges": {}, + "groups": [], + } + cache_path.write_text(json.dumps(stale_cache), encoding="utf-8") + + config_path.write_text( + "model:\n provider: custom:litellm-proxy\n default: new-test-model\n", + encoding="utf-8", + ) + new_mtime = config_path.stat().st_mtime + 2.0 + config_path.touch() + config.os.utime(config_path, (new_mtime, new_mtime)) + + # Simulate a fresh server process: config was reloaded before the first + # /api/models request, so _cfg_changed is false on entry. + config.reload_config() + _reset_cache() + + result = config.get_available_models() + + assert result["active_provider"] == "custom:litellm-proxy" + assert result["default_model"] == "new-test-model" - # Restore - config._cfg_mtime = real_mtime + config.reload_config() _reset_cache() -# ── 4. test_deepcopy_isolation ──────────────────────────────────────────── +# ── 5. test_deepcopy_isolation ──────────────────────────────────────────── def test_deepcopy_isolation(): """Mutating the returned dict from get_available_models() must not @@ -189,7 +251,7 @@ def test_deepcopy_isolation(): _reset_cache() -# ── 5. test_invalidate_models_cache_direct ─────────────────────────────── +# ── 6. test_invalidate_models_cache_direct ─────────────────────────────── def test_invalidate_models_cache_direct(): """Call invalidate_models_cache() after populating the cache.