Skip to content

Commit 62d1d59

Browse files
committed
fix: clear model cache when run.yaml model list changes
1 parent eb07a0f commit 62d1d59

File tree

3 files changed

+78
-7
lines changed

3 files changed

+78
-7
lines changed

llama_stack/core/datatypes.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
class RegistryEntrySource(StrEnum):
4040
via_register_api = "via_register_api"
4141
listed_from_provider = "listed_from_provider"
42+
from_config = "from_config"
4243

4344

4445
class User(BaseModel):

llama_stack/core/routing_tables/models.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ async def register_model(
7474
provider_id: str | None = None,
7575
metadata: dict[str, Any] | None = None,
7676
model_type: ModelType | None = None,
77+
source: RegistryEntrySource | None = None,
7778
) -> Model:
7879
if provider_id is None:
7980
# If provider_id not specified, use the only provider if it supports this model
@@ -100,13 +101,15 @@ async def register_model(
100101
else:
101102
identifier = f"{provider_id}/{provider_model_id}"
102103

104+
source = source or RegistryEntrySource.via_register_api
105+
103106
model = ModelWithOwner(
104107
identifier=identifier,
105108
provider_resource_id=provider_model_id,
106109
provider_id=provider_id,
107110
metadata=metadata,
108111
model_type=model_type,
109-
source=RegistryEntrySource.via_register_api,
112+
source=source,
110113
)
111114
registered_model = await self.register_object(model)
112115
return registered_model
@@ -130,7 +133,7 @@ async def update_registered_models(
130133
for model in existing_models:
131134
if model.provider_id != provider_id:
132135
continue
133-
if model.source == RegistryEntrySource.via_register_api:
136+
if model.source in [RegistryEntrySource.via_register_api, RegistryEntrySource.from_config]:
134137
model_ids[model.provider_resource_id] = model.identifier
135138
continue
136139

@@ -156,3 +159,45 @@ async def update_registered_models(
156159
source=RegistryEntrySource.listed_from_provider,
157160
)
158161
)
162+
163+
async def cleanup_ephemeral_models(self) -> None:
164+
"""Clean up models that should not persist across sessions."""
165+
existing_models = await self.get_all_with_type("model")
166+
167+
for model in existing_models:
168+
if model.source == RegistryEntrySource.listed_from_provider:
169+
logger.debug(f"Cleaning up ephemeral provider model: {model.identifier}")
170+
await self.unregister_object(model)
171+
continue
172+
173+
async def cleanup_config_models(self) -> None:
174+
"""Clean up models that came from configuration (run.yaml)."""
175+
existing_models = await self.get_all_with_type("model")
176+
177+
for model in existing_models:
178+
if model.source == RegistryEntrySource.from_config:
179+
logger.debug(f"Cleaning up config model: {model.identifier}")
180+
await self.unregister_object(model)
181+
continue
182+
183+
async def initialize(self) -> None:
184+
"""Initialize the models routing table with cleanup."""
185+
# Clean up provider models from previous sessions
186+
await self.cleanup_ephemeral_models()
187+
188+
# Also clean up config models from previous sessions
189+
# This ensures we start with a clean slate for config models
190+
await self.cleanup_config_models()
191+
192+
await super().initialize()
193+
194+
async def shutdown(self) -> None:
195+
"""Shutdown with cleanup of ephemeral models."""
196+
# Clean up provider models before shutdown
197+
existing_models = await self.get_all_with_type("model")
198+
199+
for model in existing_models:
200+
if model.source == RegistryEntrySource.listed_from_provider:
201+
await self.unregister_object(model)
202+
203+
await super().shutdown()

llama_stack/core/stack.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
3535
from llama_stack.apis.vector_dbs import VectorDBs
3636
from llama_stack.apis.vector_io import VectorIO
37-
from llama_stack.core.datatypes import Provider, StackRunConfig
37+
from llama_stack.core.datatypes import Provider, RegistryEntrySource, StackRunConfig
3838
from llama_stack.core.distribution import get_provider_registry
3939
from llama_stack.core.inspect import DistributionInspectConfig, DistributionInspectImpl
4040
from llama_stack.core.providers import ProviderImpl, ProviderImplConfig
@@ -112,10 +112,25 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
112112
logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled provider.")
113113
continue
114114

115-
# we want to maintain the type information in arguments to method.
116-
# instead of method(**obj.model_dump()), which may convert a typed attr to a dict,
117-
# we use model_dump() to find all the attrs and then getattr to get the still typed value.
118-
await method(**{k: getattr(obj, k) for k in obj.model_dump().keys()})
115+
# For models, use the register_model method with config source
116+
if rsrc == "models":
117+
logger.debug(
118+
f"Registering model from config: {obj.model_id} -> {obj.provider_model_id} via {obj.provider_id}"
119+
)
120+
await impls[api].register_model(
121+
model_id=obj.model_id,
122+
provider_model_id=obj.provider_model_id,
123+
provider_id=obj.provider_id,
124+
metadata=obj.metadata,
125+
model_type=obj.model_type,
126+
source=RegistryEntrySource.from_config,
127+
)
128+
logger.debug(f"Model registration completed for: {obj.model_id}")
129+
else:
130+
# we want to maintain the type information in arguments to method.
131+
# instead of method(**obj.model_dump()), which may convert a typed attr to a dict,
132+
# we use model_dump() to find all the attrs and then getattr to get the still typed value.
133+
await method(**{k: getattr(obj, k) for k in obj.model_dump().keys()})
119134

120135
method = getattr(impls[api], list_method)
121136
response = await method()
@@ -303,6 +318,14 @@ def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConf
303318
impls[Api.providers] = providers_impl
304319

305320

321+
async def cleanup_provider_models_on_startup(impls: dict[Api, Any]) -> None:
322+
"""Clean up provider models from previous sessions on startup."""
323+
routing_tables = [v for v in impls.values() if isinstance(v, CommonRoutingTableImpl)]
324+
for routing_table in routing_tables:
325+
if hasattr(routing_table, "cleanup_ephemeral_models"):
326+
await routing_table.cleanup_ephemeral_models()
327+
328+
306329
# Produces a stack of providers for the given run config. Not all APIs may be
307330
# asked for in the run config.
308331
async def construct_stack(
@@ -328,6 +351,8 @@ async def construct_stack(
328351

329352
await register_resources(run_config, impls)
330353

354+
# Clean up ephemeral models from previous sessions before first refresh
355+
await cleanup_provider_models_on_startup(impls)
331356
await refresh_registry_once(impls)
332357

333358
global REGISTRY_REFRESH_TASK

0 commit comments

Comments
 (0)