3434from llama_stack .apis .tools import RAGToolRuntime , ToolGroups , ToolRuntime
3535from llama_stack .apis .vector_dbs import VectorDBs
3636from llama_stack .apis .vector_io import VectorIO
37- from llama_stack .core .datatypes import Provider , StackRunConfig
37+ from llama_stack .core .datatypes import Provider , RegistryEntrySource , StackRunConfig
3838from llama_stack .core .distribution import get_provider_registry
3939from llama_stack .core .inspect import DistributionInspectConfig , DistributionInspectImpl
4040from llama_stack .core .providers import ProviderImpl , ProviderImplConfig
@@ -112,10 +112,25 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
112112 logger .debug (f"Skipping { rsrc .capitalize ()} registration for disabled provider." )
113113 continue
114114
115- # we want to maintain the type information in arguments to method.
116- # instead of method(**obj.model_dump()), which may convert a typed attr to a dict,
117- # we use model_dump() to find all the attrs and then getattr to get the still typed value.
118- await method (** {k : getattr (obj , k ) for k in obj .model_dump ().keys ()})
115+ # For models, use the register_model method with config source
116+ if rsrc == "models" :
117+ logger .debug (
118+ f"Registering model from config: { obj .model_id } -> { obj .provider_model_id } via { obj .provider_id } "
119+ )
120+ await impls [api ].register_model (
121+ model_id = obj .model_id ,
122+ provider_model_id = obj .provider_model_id ,
123+ provider_id = obj .provider_id ,
124+ metadata = obj .metadata ,
125+ model_type = obj .model_type ,
126+ source = RegistryEntrySource .from_config ,
127+ )
128+ logger .debug (f"Model registration completed for: { obj .model_id } " )
129+ else :
130+ # we want to maintain the type information in arguments to method.
131+ # instead of method(**obj.model_dump()), which may convert a typed attr to a dict,
132+ # we use model_dump() to find all the attrs and then getattr to get the still typed value.
133+ await method (** {k : getattr (obj , k ) for k in obj .model_dump ().keys ()})
119134
120135 method = getattr (impls [api ], list_method )
121136 response = await method ()
@@ -303,6 +318,14 @@ def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConf
303318 impls [Api .providers ] = providers_impl
304319
305320
321+ async def cleanup_provider_models_on_startup (impls : dict [Api , Any ]) -> None :
322+ """Clean up provider models from previous sessions on startup."""
323+ routing_tables = [v for v in impls .values () if isinstance (v , CommonRoutingTableImpl )]
324+ for routing_table in routing_tables :
325+ if hasattr (routing_table , "cleanup_ephemeral_models" ):
326+ await routing_table .cleanup_ephemeral_models ()
327+
328+
306329# Produces a stack of providers for the given run config. Not all APIs may be
307330# asked for in the run config.
308331async def construct_stack (
@@ -328,6 +351,8 @@ async def construct_stack(
328351
329352 await register_resources (run_config , impls )
330353
354+ # Clean up ephemeral models from previous sessions before first refresh
355+ await cleanup_provider_models_on_startup (impls )
331356 await refresh_registry_once (impls )
332357
333358 global REGISTRY_REFRESH_TASK
0 commit comments