Skip to content

Commit 41b5546

Browse files
committed
fix: clear model cache when run.yaml model list changes
1 parent 521865c commit 41b5546

File tree

4 files changed

+80
-5
lines changed

4 files changed

+80
-5
lines changed

llama_stack/core/datatypes.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
class RegistryEntrySource(StrEnum):
4141
via_register_api = "via_register_api"
4242
listed_from_provider = "listed_from_provider"
43+
from_config = "from_config"
4344

4445

4546
class User(BaseModel):

llama_stack/core/routing_tables/common.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,13 @@ async def lookup_model(routing_table: CommonRoutingTableImpl, model_id: str) ->
254254
if model is not None:
255255
return model
256256

257+
# Check from_config models if this is a ModelsRoutingTable
258+
if hasattr(routing_table, "_generate_from_config_models"):
259+
from_config_models = routing_table._generate_from_config_models()
260+
for from_config_model in from_config_models:
261+
if from_config_model.identifier == model_id:
262+
return from_config_model
263+
257264
logger.warning(
258265
f"WARNING: model identifier '{model_id}' not found in routing table. Falling back to "
259266
"searching in all providers. This is only for backwards compatibility and will stop working "

llama_stack/core/routing_tables/models.py

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from llama_stack.core.datatypes import (
1313
ModelWithOwner,
1414
RegistryEntrySource,
15+
StackRunConfig,
1516
)
1617
from llama_stack.log import get_logger
1718

@@ -22,6 +23,7 @@
2223

2324
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
2425
listed_providers: set[str] = set()
26+
current_run_config: "StackRunConfig | None" = None
2527

2628
async def refresh(self) -> None:
2729
for provider_id, provider in self.impls_by_provider_id.items():
@@ -43,10 +45,26 @@ async def refresh(self) -> None:
4345
await self.update_registered_models(provider_id, models)
4446

4547
async def list_models(self) -> ListModelsResponse:
46-
return ListModelsResponse(data=await self.get_all_with_type("model"))
48+
# Get persistent models from registry
49+
persistent_models = await self.get_all_with_type("model")
50+
51+
# Generate from_config models dynamically
52+
from_config_models = self._generate_from_config_models()
53+
54+
# Combine both lists
55+
all_models = persistent_models + from_config_models
56+
57+
return ListModelsResponse(data=all_models)
4758

4859
async def openai_list_models(self) -> OpenAIListModelsResponse:
49-
models = await self.get_all_with_type("model")
60+
# Get persistent models from registry
61+
persistent_models = await self.get_all_with_type("model")
62+
63+
# Generate from_config models dynamically
64+
from_config_models = self._generate_from_config_models()
65+
66+
# Combine both lists
67+
models = persistent_models + from_config_models
5068
openai_models = [
5169
OpenAIModel(
5270
id=model.identifier,
@@ -74,6 +92,7 @@ async def register_model(
7492
provider_id: str | None = None,
7593
metadata: dict[str, Any] | None = None,
7694
model_type: ModelType | None = None,
95+
source: RegistryEntrySource = RegistryEntrySource.via_register_api,
7796
) -> Model:
7897
if provider_id is None:
7998
# If provider_id not specified, use the only provider if it supports this model
@@ -106,7 +125,7 @@ async def register_model(
106125
provider_id=provider_id,
107126
metadata=metadata,
108127
model_type=model_type,
109-
source=RegistryEntrySource.via_register_api,
128+
source=source,
110129
)
111130
registered_model = await self.register_object(model)
112131
return registered_model
@@ -117,6 +136,39 @@ async def unregister_model(self, model_id: str) -> None:
117136
raise ModelNotFoundError(model_id)
118137
await self.unregister_object(existing_model)
119138

139+
def set_run_config(self, run_config: "StackRunConfig") -> None:
140+
"""Set the current run configuration for generating from_config models."""
141+
self.current_run_config = run_config
142+
143+
def _generate_from_config_models(self) -> list[ModelWithOwner]:
144+
"""Generate from_config models from the current run configuration."""
145+
if not self.current_run_config:
146+
return []
147+
148+
from_config_models = []
149+
for model_input in self.current_run_config.models:
150+
# Skip models with disabled providers
151+
if not model_input.provider_id or model_input.provider_id == "__disabled__":
152+
continue
153+
154+
# Generate identifier
155+
if model_input.model_id != (model_input.provider_model_id or model_input.model_id):
156+
identifier = model_input.model_id
157+
else:
158+
identifier = f"{model_input.provider_id}/{model_input.provider_model_id or model_input.model_id}"
159+
160+
model = ModelWithOwner(
161+
identifier=identifier,
162+
provider_resource_id=model_input.provider_model_id or model_input.model_id,
163+
provider_id=model_input.provider_id,
164+
metadata=model_input.metadata,
165+
model_type=model_input.model_type or ModelType.llm,
166+
source=RegistryEntrySource.from_config,
167+
)
168+
from_config_models.append(model)
169+
170+
return from_config_models
171+
120172
async def update_registered_models(
121173
self,
122174
provider_id: str,
@@ -133,6 +185,10 @@ async def update_registered_models(
133185
if model.source == RegistryEntrySource.via_register_api:
134186
model_ids[model.provider_resource_id] = model.identifier
135187
continue
188+
# Also preserve from_config models - they should not be unregistered during refresh
189+
if model.source == RegistryEntrySource.from_config:
190+
model_ids[model.provider_resource_id] = model.identifier
191+
continue
136192

137193
logger.debug(f"unregistering model {model.identifier}")
138194
await self.unregister_object(model)

llama_stack/core/stack.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
3636
from llama_stack.apis.vector_dbs import VectorDBs
3737
from llama_stack.apis.vector_io import VectorIO
38-
from llama_stack.core.datatypes import Provider, StackRunConfig
38+
from llama_stack.core.datatypes import Provider, RegistryEntrySource, StackRunConfig
3939
from llama_stack.core.distribution import get_provider_registry
4040
from llama_stack.core.inspect import DistributionInspectConfig, DistributionInspectImpl
4141
from llama_stack.core.prompts.prompts import PromptServiceConfig, PromptServiceImpl
@@ -101,6 +101,11 @@ class LlamaStack(
101101

102102

103103
async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
104+
# Set the run config on the models routing table for generating from_config models
105+
if Api.models in impls:
106+
models_impl = impls[Api.models]
107+
models_impl.set_run_config(run_config)
108+
104109
for rsrc, api, register_method, list_method in RESOURCES:
105110
objects = getattr(run_config, rsrc)
106111
if api not in impls:
@@ -118,7 +123,13 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
118123
# we want to maintain the type information in arguments to method.
119124
# instead of method(**obj.model_dump()), which may convert a typed attr to a dict,
120125
# we use model_dump() to find all the attrs and then getattr to get the still typed value.
121-
await method(**{k: getattr(obj, k) for k in obj.model_dump().keys()})
126+
kwargs = {k: getattr(obj, k) for k in obj.model_dump().keys()}
127+
128+
# For models, add source=from_config to indicate they come from run.yaml
129+
if rsrc == "models":
130+
kwargs["source"] = RegistryEntrySource.from_config
131+
132+
await method(**kwargs)
122133

123134
method = getattr(impls[api], list_method)
124135
response = await method()

0 commit comments

Comments
 (0)