Commit cf85fe9

fix: list models only for active providers
There is a recurring error where we can retrieve a model when doing something like a chat completion, but then hit issues when trying to associate that model with an active provider. This commonly happens when:

1. you run the stack with, say, `remote::ollama`
2. you register a model, say `llama3.2:3b`
3. you do some completions, etc.
4. you kill the server
5. you `unset OLLAMA_URL`
6. you re-start the stack
7. you run `llama-stack-client models list`

```
├───────────┼─────────────────────────┼──────────────────┼─────────────────────────────────────────────────────────┼────────┤
│ embedding │ all-minilm              │ all-minilm:l6-v2 │ {'embedding_dimension': 384.0, 'context_length': 512.0} │ ollama │
├───────────┼─────────────────────────┼──────────────────┼─────────────────────────────────────────────────────────┼────────┤
│ llm       │ llama3.2:3b             │ llama3.2:3b      │                                                         │ ollama │
├───────────┼─────────────────────────┼──────────────────┼─────────────────────────────────────────────────────────┼────────┤
│ embedding │ ollama/all-minilm:l6-v2 │ all-minilm:l6-v2 │ {'embedding_dimension': 384.0, 'context_length': 512.0} │ ollama │
├───────────┼─────────────────────────┼──────────────────┼─────────────────────────────────────────────────────────┼────────┤
│ llm       │ ollama/llama3.2:3b      │ llama3.2:3b      │                                                         │ ollama │
├───────────┼─────────────────────────┼──────────────────┼─────────────────────────────────────────────────────────┼────────┤
```

This shouldn't be happening: `ollama` is not a running provider, and the only reason the model shows up is that it is in the dist_registry (on disk). It is nice to have this static store so that the models can be read back if I `export OLLAMA_URL=..` again, but the store's contents should not _always_ be read and returned. With this change, `llama-stack-client models list` no longer shows `llama3.2:3b`.

Signed-off-by: Charlie Doern <[email protected]>
1 parent: 46ff302
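To make the failure mode concrete, here is a toy sketch of the mismatch the commit message describes. The `persisted_registry` dict and `active_provider_impls` mapping are hypothetical stand-ins for the on-disk dist_registry and the set of configured providers, not actual Llama Stack APIs:

```python
# Toy illustration (hypothetical names, not the real Llama Stack data structures).

# The on-disk dist_registry still remembers the model from the earlier run...
persisted_registry = {
    "model:llama3.2:3b": {"identifier": "llama3.2:3b", "provider_id": "ollama"},
}

# ...but after restarting with OLLAMA_URL unset, no ollama provider is configured.
active_provider_impls: dict[str, object] = {}

model = persisted_registry["model:llama3.2:3b"]  # listing/lookup still succeeds
provider = active_provider_impls.get(model["provider_id"])
print(provider)  # None -> routing a completion to this model fails
```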

9 files changed: 25 additions, 17 deletions


llama_stack/core/routers/vector_io.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -171,7 +171,7 @@ async def openai_list_vector_stores(
         logger.debug(f"VectorIORouter.openai_list_vector_stores: limit={limit}")
         # Route to default provider for now - could aggregate from all providers in the future
         # call retrieve on each vector dbs to get list of vector stores
-        vector_dbs = await self.routing_table.get_all_with_type("vector_db")
+        vector_dbs = await self.routing_table.get_all_with_type_filtered("vector_db")
         all_stores = []
         for vector_db in vector_dbs:
             try:
```

llama_stack/core/routing_tables/benchmarks.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -19,7 +19,7 @@
 
 class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
     async def list_benchmarks(self) -> ListBenchmarksResponse:
-        return ListBenchmarksResponse(data=await self.get_all_with_type("benchmark"))
+        return ListBenchmarksResponse(data=await self.get_all_with_type_filtered("benchmark"))
 
     async def get_benchmark(self, benchmark_id: str) -> Benchmark:
         benchmark = await self.get_object_by_identifier("benchmark", benchmark_id)
```

llama_stack/core/routing_tables/common.py

Lines changed: 12 additions & 6 deletions
```diff
@@ -232,15 +232,21 @@ async def assert_action_allowed(
 
     async def get_all_with_type(self, type: str) -> list[RoutableObjectWithProvider]:
         objs = await self.dist_registry.get_all()
-        filtered_objs = [obj for obj in objs if obj.type == type]
+        return [obj for obj in objs if obj.type == type]
+
+    async def get_all_with_type_filtered(self, type: str) -> list[RoutableObjectWithProvider]:
+        all_objs = await self.get_all_with_type(type=type)
 
         # Apply attribute-based access control filtering
-        if filtered_objs:
-            filtered_objs = [
-                obj for obj in filtered_objs if is_action_allowed(self.policy, "read", obj, get_authenticated_user())
+        if all_objs:
+            all_objs = [
+                obj
+                for obj in all_objs
+                if is_action_allowed(self.policy, "read", obj, get_authenticated_user())
+                and obj.provider_id in self.impls_by_provider_id
             ]
 
-        return filtered_objs
+        return all_objs
 
 
 async def lookup_model(routing_table: CommonRoutingTableImpl, model_id: str) -> Model:
@@ -257,7 +263,7 @@ async def lookup_model(routing_table: CommonRoutingTableImpl, model_id: str) ->
     )
     # if not found, this means model_id is an unscoped provider_model_id, we need
     # to iterate (given a lack of an efficient index on the KVStore)
-    models = await routing_table.get_all_with_type("model")
+    models = await routing_table.get_all_with_type_filtered("model")
     matching_models = [m for m in models if m.provider_resource_id == model_id]
     if len(matching_models) == 0:
         raise ModelNotFoundError(model_id)
```
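To show the behavioral difference introduced here, the sketch below re-implements the two helpers on a minimal stand-in routing table. `FakeRegistry`, `Obj`, `ToyRoutingTable`, and their constructors are invented for illustration; only the attribute names `dist_registry` and `impls_by_provider_id` come from the hunk above, and the access-control check is omitted:

```python
import asyncio
from dataclasses import dataclass, field


# Hypothetical stand-ins for illustration only -- not the real Llama Stack classes.
@dataclass
class Obj:
    type: str
    identifier: str
    provider_id: str


@dataclass
class FakeRegistry:
    objects: list[Obj] = field(default_factory=list)

    async def get_all(self) -> list[Obj]:
        return list(self.objects)


class ToyRoutingTable:
    def __init__(self, registry: FakeRegistry, impls_by_provider_id: dict[str, object]):
        self.dist_registry = registry
        self.impls_by_provider_id = impls_by_provider_id  # only active providers end up here

    async def get_all_with_type(self, type: str) -> list[Obj]:
        # Returns everything of the requested type that the persisted registry knows about.
        objs = await self.dist_registry.get_all()
        return [obj for obj in objs if obj.type == type]

    async def get_all_with_type_filtered(self, type: str) -> list[Obj]:
        # Additionally drops objects whose provider is not currently configured
        # (the access-control filtering from the real helper is omitted in this sketch).
        all_objs = await self.get_all_with_type(type=type)
        return [obj for obj in all_objs if obj.provider_id in self.impls_by_provider_id]


async def main() -> None:
    registry = FakeRegistry([Obj("model", "llama3.2:3b", "ollama")])
    table = ToyRoutingTable(registry, impls_by_provider_id={})  # ollama no longer active

    print(await table.get_all_with_type("model"))           # stale model is still listed
    print(await table.get_all_with_type_filtered("model"))  # []


asyncio.run(main())
```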

llama_stack/core/routing_tables/datasets.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -31,7 +31,7 @@
 
 class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
     async def list_datasets(self) -> ListDatasetsResponse:
-        return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value))
+        return ListDatasetsResponse(data=await self.get_all_with_type_filtered(ResourceType.dataset.value))
 
     async def get_dataset(self, dataset_id: str) -> Dataset:
         dataset = await self.get_object_by_identifier("dataset", dataset_id)
```

llama_stack/core/routing_tables/models.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -43,10 +43,10 @@ async def refresh(self) -> None:
             await self.update_registered_models(provider_id, models)
 
     async def list_models(self) -> ListModelsResponse:
-        return ListModelsResponse(data=await self.get_all_with_type("model"))
+        return ListModelsResponse(data=await self.get_all_with_type_filtered("model"))
 
     async def openai_list_models(self) -> OpenAIListModelsResponse:
-        models = await self.get_all_with_type("model")
+        models = await self.get_all_with_type_filtered("model")
         openai_models = [
             OpenAIModel(
                 id=model.identifier,
@@ -122,7 +122,7 @@ async def update_registered_models(
         provider_id: str,
         models: list[Model],
    ) -> None:
-        existing_models = await self.get_all_with_type("model")
+        existing_models = await self.get_all_with_type_filtered("model")
 
        # we may have an alias for the model registered by the user (or during initialization
        # from run.yaml) that we need to keep track of
```

llama_stack/core/routing_tables/scoring_functions.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -24,7 +24,9 @@
 
 class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
     async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
-        return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value))
+        return ListScoringFunctionsResponse(
+            data=await self.get_all_with_type_filtered(ResourceType.scoring_function.value)
+        )
 
     async def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn:
         scoring_fn = await self.get_object_by_identifier("scoring_function", scoring_fn_id)
```

llama_stack/core/routing_tables/shields.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -20,7 +20,7 @@
 
 class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
     async def list_shields(self) -> ListShieldsResponse:
-        return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value))
+        return ListShieldsResponse(data=await self.get_all_with_type_filtered(ResourceType.shield.value))
 
     async def get_shield(self, identifier: str) -> Shield:
         shield = await self.get_object_by_identifier("shield", identifier)
```

llama_stack/core/routing_tables/toolgroups.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -49,7 +49,7 @@ async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse
             toolgroup_id = group_id
             toolgroups = [await self.get_tool_group(toolgroup_id)]
         else:
-            toolgroups = await self.get_all_with_type("tool_group")
+            toolgroups = await self.get_all_with_type_filtered("tool_group")
 
         all_tools = []
         for toolgroup in toolgroups:
@@ -83,7 +83,7 @@ async def _index_tools(self, toolgroup: ToolGroup):
             self.tool_to_toolgroup[tool.identifier] = toolgroup.identifier
 
     async def list_tool_groups(self) -> ListToolGroupsResponse:
-        return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group"))
+        return ListToolGroupsResponse(data=await self.get_all_with_type_filtered("tool_group"))
 
     async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
         tool_group = await self.get_object_by_identifier("tool_group", toolgroup_id)
```

llama_stack/core/routing_tables/vector_dbs.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -35,7 +35,7 @@
 
 class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
     async def list_vector_dbs(self) -> ListVectorDBsResponse:
-        return ListVectorDBsResponse(data=await self.get_all_with_type("vector_db"))
+        return ListVectorDBsResponse(data=await self.get_all_with_type_filtered("vector_db"))
 
     async def get_vector_db(self, vector_db_id: str) -> VectorDB:
         vector_db = await self.get_object_by_identifier("vector_db", vector_db_id)
```
