Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion llama_stack/core/routers/vector_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ async def openai_list_vector_stores(
logger.debug(f"VectorIORouter.openai_list_vector_stores: limit={limit}")
# Route to default provider for now - could aggregate from all providers in the future
# call retrieve on each vector dbs to get list of vector stores
vector_dbs = await self.routing_table.get_all_with_type("vector_db")
vector_dbs = await self.routing_table.get_all_with_type_filtered("vector_db")
all_stores = []
for vector_db in vector_dbs:
try:
Expand Down
2 changes: 1 addition & 1 deletion llama_stack/core/routing_tables/benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
async def list_benchmarks(self) -> ListBenchmarksResponse:
return ListBenchmarksResponse(data=await self.get_all_with_type("benchmark"))
return ListBenchmarksResponse(data=await self.get_all_with_type_filtered("benchmark"))

async def get_benchmark(self, benchmark_id: str) -> Benchmark:
benchmark = await self.get_object_by_identifier("benchmark", benchmark_id)
Expand Down
18 changes: 12 additions & 6 deletions llama_stack/core/routing_tables/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,15 +232,21 @@ async def assert_action_allowed(

async def get_all_with_type(self, type: str) -> list[RoutableObjectWithProvider]:
objs = await self.dist_registry.get_all()
filtered_objs = [obj for obj in objs if obj.type == type]
return [obj for obj in objs if obj.type == type]

async def get_all_with_type_filtered(self, type: str) -> list[RoutableObjectWithProvider]:
all_objs = await self.get_all_with_type(type=type)

# Apply attribute-based access control filtering
if filtered_objs:
filtered_objs = [
obj for obj in filtered_objs if is_action_allowed(self.policy, "read", obj, get_authenticated_user())
if all_objs:
all_objs = [
obj
for obj in all_objs
if is_action_allowed(self.policy, "read", obj, get_authenticated_user())
and obj.provider_id in self.impls_by_provider_id
]

return filtered_objs
return all_objs


async def lookup_model(routing_table: CommonRoutingTableImpl, model_id: str) -> Model:
Expand All @@ -257,7 +263,7 @@ async def lookup_model(routing_table: CommonRoutingTableImpl, model_id: str) ->
)
# if not found, this means model_id is an unscoped provider_model_id, we need
# to iterate (given a lack of an efficient index on the KVStore)
models = await routing_table.get_all_with_type("model")
models = await routing_table.get_all_with_type_filtered("model")
matching_models = [m for m in models if m.provider_resource_id == model_id]
if len(matching_models) == 0:
raise ModelNotFoundError(model_id)
Expand Down
2 changes: 1 addition & 1 deletion llama_stack/core/routing_tables/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
async def list_datasets(self) -> ListDatasetsResponse:
return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value))
return ListDatasetsResponse(data=await self.get_all_with_type_filtered(ResourceType.dataset.value))

async def get_dataset(self, dataset_id: str) -> Dataset:
dataset = await self.get_object_by_identifier("dataset", dataset_id)
Expand Down
6 changes: 3 additions & 3 deletions llama_stack/core/routing_tables/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@ async def refresh(self) -> None:
await self.update_registered_models(provider_id, models)

async def list_models(self) -> ListModelsResponse:
return ListModelsResponse(data=await self.get_all_with_type("model"))
return ListModelsResponse(data=await self.get_all_with_type_filtered("model"))

async def openai_list_models(self) -> OpenAIListModelsResponse:
models = await self.get_all_with_type("model")
models = await self.get_all_with_type_filtered("model")
openai_models = [
OpenAIModel(
id=model.identifier,
Expand Down Expand Up @@ -122,7 +122,7 @@ async def update_registered_models(
provider_id: str,
models: list[Model],
) -> None:
existing_models = await self.get_all_with_type("model")
existing_models = await self.get_all_with_type_filtered("model")

# we may have an alias for the model registered by the user (or during initialization
# from run.yaml) that we need to keep track of
Expand Down
4 changes: 3 additions & 1 deletion llama_stack/core/routing_tables/scoring_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@

class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value))
return ListScoringFunctionsResponse(
data=await self.get_all_with_type_filtered(ResourceType.scoring_function.value)
)

async def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn:
scoring_fn = await self.get_object_by_identifier("scoring_function", scoring_fn_id)
Expand Down
2 changes: 1 addition & 1 deletion llama_stack/core/routing_tables/shields.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
async def list_shields(self) -> ListShieldsResponse:
return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value))
return ListShieldsResponse(data=await self.get_all_with_type_filtered(ResourceType.shield.value))

async def get_shield(self, identifier: str) -> Shield:
shield = await self.get_object_by_identifier("shield", identifier)
Expand Down
4 changes: 2 additions & 2 deletions llama_stack/core/routing_tables/toolgroups.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse
toolgroup_id = group_id
toolgroups = [await self.get_tool_group(toolgroup_id)]
else:
toolgroups = await self.get_all_with_type("tool_group")
toolgroups = await self.get_all_with_type_filtered("tool_group")

all_tools = []
for toolgroup in toolgroups:
Expand Down Expand Up @@ -83,7 +83,7 @@ async def _index_tools(self, toolgroup: ToolGroup):
self.tool_to_toolgroup[tool.identifier] = toolgroup.identifier

async def list_tool_groups(self) -> ListToolGroupsResponse:
return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group"))
return ListToolGroupsResponse(data=await self.get_all_with_type_filtered("tool_group"))

async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
tool_group = await self.get_object_by_identifier("tool_group", toolgroup_id)
Expand Down
2 changes: 1 addition & 1 deletion llama_stack/core/routing_tables/vector_dbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@

class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
async def list_vector_dbs(self) -> ListVectorDBsResponse:
return ListVectorDBsResponse(data=await self.get_all_with_type("vector_db"))
return ListVectorDBsResponse(data=await self.get_all_with_type_filtered("vector_db"))

async def get_vector_db(self, vector_db_id: str) -> VectorDB:
vector_db = await self.get_object_by_identifier("vector_db", vector_db_id)
Expand Down
Loading