Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 31 additions & 5 deletions api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -990,7 +990,17 @@ def resolve_model_provider(model_id: str) -> tuple:
entry_model = (entry.get("model") or "").strip()
entry_name = (entry.get("name") or "").strip()
entry_base_url = (entry.get("base_url") or "").strip()
if entry_model and entry_name and model_id == entry_model:
entry_model_ids = set()
if entry_model:
entry_model_ids.add(entry_model)
entry_models = entry.get("models")
if isinstance(entry_models, dict):
entry_model_ids.update(
key.strip()
for key in entry_models.keys()
if isinstance(key, str) and key.strip()
)
if entry_name and model_id in entry_model_ids:
provider_hint = "custom:" + entry_name.lower().replace(" ", "-")
return model_id, provider_hint, entry_base_url or None

Expand Down Expand Up @@ -1265,16 +1275,30 @@ def _is_valid_models_cache(cache: object) -> bool:
)


def _load_models_cache_from_disk() -> dict | None:
"""Load /api/models cache from disk if it exists and has current metadata."""
def _load_models_cache_from_disk(config_mtime: float | None = None) -> dict | None:
"""Load /api/models cache from disk if it matches the current config."""
try:
import json as _j

if not _models_cache_path.exists():
return None
with open(_models_cache_path, encoding="utf-8") as f:
cache = _j.load(f)
return cache if _is_valid_models_cache(cache) else None
if not _is_valid_models_cache(cache):
return None
expected_mtime = _cfg_mtime if config_mtime is None else config_mtime
try:
cached_mtime = float(cache.get("_config_mtime"))
except (TypeError, ValueError):
return None
if cached_mtime != float(expected_mtime):
return None
return {
"active_provider": cache["active_provider"],
"default_model": cache["default_model"],
"configured_model_badges": cache["configured_model_badges"],
"groups": cache["groups"],
}
except Exception:
return None

Expand All @@ -1288,6 +1312,7 @@ def _save_models_cache_to_disk(cache: dict) -> None:
with open(tmp, "w", encoding="utf-8") as f:
json.dump(
{
"_config_mtime": _cfg_mtime,
"active_provider": cache["active_provider"],
"default_model": cache["default_model"],
"configured_model_badges": cache["configured_model_badges"],
Expand Down Expand Up @@ -1428,6 +1453,7 @@ def get_available_models() -> dict:
_current_mtime = 0.0
if _current_mtime != _cfg_mtime:
reload_config()
invalidate_models_cache()
# ── COLD PATH helper ─────────────────────────────────────────────────────
# Extracted so it runs inside _available_models_cache_lock (RLock) to
# prevent thundering-herd: only one thread rebuilds while others wait.
Expand Down Expand Up @@ -2044,7 +2070,7 @@ def _build_configured_model_badges() -> dict[str, dict[str, str]]:
# so only one thread rebuilds while others wait.
disk_groups = None
if _available_models_cache is None:
disk_groups = _load_models_cache_from_disk()
disk_groups = _load_models_cache_from_disk(_current_mtime)

with _available_models_cache_lock:
# If another thread is already building, wait for its result instead
Expand Down
18 changes: 16 additions & 2 deletions api/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -1701,8 +1701,22 @@ def on_tool(*cb_args, **cb_kwargs):
_toolsets = _resolve_cli_toolsets(_cfg)

# Fallback model from profile config (e.g. for rate-limit recovery)
_fallback = _cfg.get('fallback_model') or None
if _fallback:
_fallback = _cfg.get('fallback_providers') or _cfg.get('fallback_model') or None
if isinstance(_fallback, list):
_fallback_resolved = [
{
'model': str(entry.get('model') or '').strip(),
'provider': str(entry.get('provider') or '').strip(),
'base_url': entry.get('base_url'),
}
for entry in _fallback
if isinstance(entry, dict)
and str(entry.get('model') or '').strip()
and str(entry.get('provider') or '').strip()
]
if not _fallback_resolved:
_fallback_resolved = None
elif isinstance(_fallback, dict):
# Resolve the fallback through our provider logic too
fb_model = _fallback.get('model', '')
fb_provider = _fallback.get('provider', '')
Expand Down
20 changes: 20 additions & 0 deletions tests/test_model_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,26 @@ def test_custom_provider_model_with_slash_routes_to_named_custom_provider():
assert base_url == 'http://lmstudio.local:1234/v1'


def test_custom_provider_models_dict_routes_to_named_custom_provider():
"""Models listed only under custom_providers[].models still route to that endpoint."""
model, provider, base_url = _resolve_with_config(
'sensenova-6.7-flash-lite',
provider='xiaomi',
custom_providers=[{
'name': 'LiteLLM Proxy',
'base_url': 'http://127.0.0.1:8080/v1',
'model': 'deepseek-v4-flash',
'models': {
'deepseek-v4-flash': {},
'sensenova-6.7-flash-lite': {},
},
}],
)
assert model == 'sensenova-6.7-flash-lite'
assert provider == 'custom:litellm-proxy'
assert base_url == 'http://127.0.0.1:8080/v1'


# ── get_available_models() @provider: hint behaviour ──────────────────────


Expand Down
114 changes: 88 additions & 26 deletions tests/test_ttl_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
- copy.deepcopy() isolation (mutating returned dict doesn't pollute cache)
- invalidate_models_cache() direct invalidation
"""
import json
import time
from unittest.mock import patch

Expand Down Expand Up @@ -108,45 +109,106 @@ def test_ttl_expiry():

# ── 3. test_mtime_invalidation ───────────────────────────────────────────

def test_mtime_invalidation():
def test_mtime_invalidation(tmp_path, monkeypatch):
"""Populate the cache, then change _cfg_mtime to simulate a config file
change on disk. The next call should invalidate the cache and re-scan.
"""
_reset_cache()
config_path = tmp_path / "config.yaml"
cache_path = tmp_path / "models_cache.json"

# Ensure _cfg_mtime matches file so first call doesn't re-scan due to mtime
try:
real_mtime = config.Path(config._get_config_path()).stat().st_mtime
except OSError:
real_mtime = 0.0
config._cfg_mtime = real_mtime
with monkeypatch.context() as m:
m.setattr(config, "_get_config_path", lambda: config_path)
m.setattr(config, "_models_cache_path", cache_path)

# First call populates cache
result1 = config.get_available_models()
assert config._available_models_cache is not None
config_path.write_text(
"model:\n provider: openai\n default: old-test-model\n",
encoding="utf-8",
)
config.reload_config()
config.invalidate_models_cache()

# Simulate config.yaml changed on disk by setting _cfg_mtime to 0
# (which won't match the actual file mtime)
config._cfg_mtime = 0.0
result1 = config.get_available_models()
assert result1["default_model"] == "old-test-model"
assert config._available_models_cache is not None

# The next call should detect mtime mismatch, reload, and invalidate cache
old_cache = config._available_models_cache
old_ts = config._available_models_cache_ts
old_cache = config._available_models_cache

result2 = config.get_available_models()
# Simulate an external edit to config.yaml. This is the path that the
# WebUI hits when a user edits the file outside /api/default-model.
config_path.write_text(
"model:\n provider: openai\n default: new-test-model\n",
encoding="utf-8",
)
new_mtime = config_path.stat().st_mtime + 2.0
config_path.touch()
config.os.utime(config_path, (new_mtime, new_mtime))

# Cache must have been refreshed — timestamp advanced since we reset it
# to 0.0 on invalidation.
assert config._available_models_cache_ts > 0.0, (
"Cache timestamp should be updated after invalidation + rebuild"
)
result2 = config.get_available_models()

assert result2["default_model"] == "new-test-model"
assert config._available_models_cache is not old_cache, (
"Cache object should be replaced after config mtime invalidation"
)
assert config._available_models_cache_ts > 0.0, (
"Cache timestamp should be updated after invalidation + rebuild"
)

config.reload_config()
_reset_cache()


# ── 4. test_stale_disk_cache_after_restart_ignored ─────────────────────────

def test_stale_disk_cache_after_restart_ignored(tmp_path, monkeypatch):
"""A stale disk cache from before a config change must not survive restart."""
_reset_cache()
config_path = tmp_path / "config.yaml"
cache_path = tmp_path / "models_cache.json"

with monkeypatch.context() as m:
m.setattr(config, "_get_config_path", lambda: config_path)
m.setattr(config, "_models_cache_path", cache_path)

config_path.write_text(
"model:\n provider: xiaomi\n default: old-test-model\n",
encoding="utf-8",
)
config.reload_config()
old_mtime = config._cfg_mtime

stale_cache = {
"_config_mtime": old_mtime,
"active_provider": "xiaomi",
"default_model": "old-test-model",
"configured_model_badges": {},
"groups": [],
}
cache_path.write_text(json.dumps(stale_cache), encoding="utf-8")

config_path.write_text(
"model:\n provider: custom:litellm-proxy\n default: new-test-model\n",
encoding="utf-8",
)
new_mtime = config_path.stat().st_mtime + 2.0
config_path.touch()
config.os.utime(config_path, (new_mtime, new_mtime))

# Simulate a fresh server process: config was reloaded before the first
# /api/models request, so _cfg_changed is false on entry.
config.reload_config()
_reset_cache()

result = config.get_available_models()

assert result["active_provider"] == "custom:litellm-proxy"
assert result["default_model"] == "new-test-model"

# Restore
config._cfg_mtime = real_mtime
config.reload_config()
_reset_cache()


# ── 4. test_deepcopy_isolation ────────────────────────────────────────────
# ── 5. test_deepcopy_isolation ────────────────────────────────────────────

def test_deepcopy_isolation():
"""Mutating the returned dict from get_available_models() must not
Expand Down Expand Up @@ -189,7 +251,7 @@ def test_deepcopy_isolation():
_reset_cache()


# ── 5. test_invalidate_models_cache_direct ───────────────────────────────
# ── 6. test_invalidate_models_cache_direct ───────────────────────────────

def test_invalidate_models_cache_direct():
"""Call invalidate_models_cache() after populating the cache.
Expand Down