diff --git a/llama_stack/core/routers/vector_io.py b/llama_stack/core/routers/vector_io.py index 3d0996c491..7e2fc456c5 100644 --- a/llama_stack/core/routers/vector_io.py +++ b/llama_stack/core/routers/vector_io.py @@ -171,7 +171,7 @@ async def openai_list_vector_stores( logger.debug(f"VectorIORouter.openai_list_vector_stores: limit={limit}") # Route to default provider for now - could aggregate from all providers in the future # call retrieve on each vector dbs to get list of vector stores - vector_dbs = await self.routing_table.get_all_with_type("vector_db") + vector_dbs = await self.routing_table.get_all_with_type_filtered("vector_db") all_stores = [] for vector_db in vector_dbs: try: diff --git a/llama_stack/core/routing_tables/benchmarks.py b/llama_stack/core/routing_tables/benchmarks.py index 74bee80402..59e3025be6 100644 --- a/llama_stack/core/routing_tables/benchmarks.py +++ b/llama_stack/core/routing_tables/benchmarks.py @@ -19,7 +19,7 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks): async def list_benchmarks(self) -> ListBenchmarksResponse: - return ListBenchmarksResponse(data=await self.get_all_with_type("benchmark")) + return ListBenchmarksResponse(data=await self.get_all_with_type_filtered("benchmark")) async def get_benchmark(self, benchmark_id: str) -> Benchmark: benchmark = await self.get_object_by_identifier("benchmark", benchmark_id) diff --git a/llama_stack/core/routing_tables/common.py b/llama_stack/core/routing_tables/common.py index 339ff6da43..6131a15738 100644 --- a/llama_stack/core/routing_tables/common.py +++ b/llama_stack/core/routing_tables/common.py @@ -232,15 +232,21 @@ async def assert_action_allowed( async def get_all_with_type(self, type: str) -> list[RoutableObjectWithProvider]: objs = await self.dist_registry.get_all() - filtered_objs = [obj for obj in objs if obj.type == type] + return [obj for obj in objs if obj.type == type] + + async def get_all_with_type_filtered(self, type: str) -> list[RoutableObjectWithProvider]: + all_objs = await self.get_all_with_type(type=type) # Apply attribute-based access control filtering - if filtered_objs: - filtered_objs = [ - obj for obj in filtered_objs if is_action_allowed(self.policy, "read", obj, get_authenticated_user()) + if all_objs: + all_objs = [ + obj + for obj in all_objs + if is_action_allowed(self.policy, "read", obj, get_authenticated_user()) + and obj.provider_id in self.impls_by_provider_id ] - return filtered_objs + return all_objs async def lookup_model(routing_table: CommonRoutingTableImpl, model_id: str) -> Model: @@ -257,7 +263,7 @@ async def lookup_model(routing_table: CommonRoutingTableImpl, model_id: str) -> ) # if not found, this means model_id is an unscoped provider_model_id, we need # to iterate (given a lack of an efficient index on the KVStore) - models = await routing_table.get_all_with_type("model") + models = await routing_table.get_all_with_type_filtered("model") matching_models = [m for m in models if m.provider_resource_id == model_id] if len(matching_models) == 0: raise ModelNotFoundError(model_id) diff --git a/llama_stack/core/routing_tables/datasets.py b/llama_stack/core/routing_tables/datasets.py index fc6a75df41..51fa2bec6f 100644 --- a/llama_stack/core/routing_tables/datasets.py +++ b/llama_stack/core/routing_tables/datasets.py @@ -31,7 +31,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): async def list_datasets(self) -> ListDatasetsResponse: - return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value)) + return ListDatasetsResponse(data=await self.get_all_with_type_filtered(ResourceType.dataset.value)) async def get_dataset(self, dataset_id: str) -> Dataset: dataset = await self.get_object_by_identifier("dataset", dataset_id) diff --git a/llama_stack/core/routing_tables/models.py b/llama_stack/core/routing_tables/models.py index 34c431e007..8155ccc332 100644 --- a/llama_stack/core/routing_tables/models.py +++ b/llama_stack/core/routing_tables/models.py @@ -43,10 +43,10 @@ async def refresh(self) -> None: await self.update_registered_models(provider_id, models) async def list_models(self) -> ListModelsResponse: - return ListModelsResponse(data=await self.get_all_with_type("model")) + return ListModelsResponse(data=await self.get_all_with_type_filtered("model")) async def openai_list_models(self) -> OpenAIListModelsResponse: - models = await self.get_all_with_type("model") + models = await self.get_all_with_type_filtered("model") openai_models = [ OpenAIModel( id=model.identifier, @@ -122,7 +122,7 @@ async def update_registered_models( provider_id: str, models: list[Model], ) -> None: - existing_models = await self.get_all_with_type("model") + existing_models = await self.get_all_with_type_filtered("model") # we may have an alias for the model registered by the user (or during initialization # from run.yaml) that we need to keep track of diff --git a/llama_stack/core/routing_tables/scoring_functions.py b/llama_stack/core/routing_tables/scoring_functions.py index 5874ba9411..9907522c42 100644 --- a/llama_stack/core/routing_tables/scoring_functions.py +++ b/llama_stack/core/routing_tables/scoring_functions.py @@ -24,7 +24,9 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions): async def list_scoring_functions(self) -> ListScoringFunctionsResponse: - return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value)) + return ListScoringFunctionsResponse( + data=await self.get_all_with_type_filtered(ResourceType.scoring_function.value) + ) async def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn: scoring_fn = await self.get_object_by_identifier("scoring_function", scoring_fn_id) diff --git a/llama_stack/core/routing_tables/shields.py b/llama_stack/core/routing_tables/shields.py index e08f35bfc4..0e9d78a310 100644 --- a/llama_stack/core/routing_tables/shields.py +++ b/llama_stack/core/routing_tables/shields.py @@ -20,7 +20,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): async def list_shields(self) -> ListShieldsResponse: - return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value)) + return ListShieldsResponse(data=await self.get_all_with_type_filtered(ResourceType.shield.value)) async def get_shield(self, identifier: str) -> Shield: shield = await self.get_object_by_identifier("shield", identifier) diff --git a/llama_stack/core/routing_tables/toolgroups.py b/llama_stack/core/routing_tables/toolgroups.py index 6910b39060..82f7cb8197 100644 --- a/llama_stack/core/routing_tables/toolgroups.py +++ b/llama_stack/core/routing_tables/toolgroups.py @@ -49,7 +49,7 @@ async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse toolgroup_id = group_id toolgroups = [await self.get_tool_group(toolgroup_id)] else: - toolgroups = await self.get_all_with_type("tool_group") + toolgroups = await self.get_all_with_type_filtered("tool_group") all_tools = [] for toolgroup in toolgroups: @@ -83,7 +83,7 @@ async def _index_tools(self, toolgroup: ToolGroup): self.tool_to_toolgroup[tool.identifier] = toolgroup.identifier async def list_tool_groups(self) -> ListToolGroupsResponse: - return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group")) + return ListToolGroupsResponse(data=await self.get_all_with_type_filtered("tool_group")) async def get_tool_group(self, toolgroup_id: str) -> ToolGroup: tool_group = await self.get_object_by_identifier("tool_group", toolgroup_id) diff --git a/llama_stack/core/routing_tables/vector_dbs.py b/llama_stack/core/routing_tables/vector_dbs.py index e8dc469978..b2f252dcb6 100644 --- a/llama_stack/core/routing_tables/vector_dbs.py +++ b/llama_stack/core/routing_tables/vector_dbs.py @@ -35,7 +35,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): async def list_vector_dbs(self) -> ListVectorDBsResponse: - return ListVectorDBsResponse(data=await self.get_all_with_type("vector_db")) + return ListVectorDBsResponse(data=await self.get_all_with_type_filtered("vector_db")) async def get_vector_db(self, vector_db_id: str) -> VectorDB: vector_db = await self.get_object_by_identifier("vector_db", vector_db_id)