Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions ads/aqua/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1111,6 +1111,9 @@ def list(
aqua_models = []
inference_containers = self.get_container_config().to_dict().get("inference")
for model in models:
freeform_tags = model.freeform_tags or {}
if Tags.AQUA_TAG not in freeform_tags:
continue
aqua_models.append(
AquaModelSummary(
**self._process_model(
Expand Down
41 changes: 41 additions & 0 deletions tests/unitary/with_extras/aqua/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1654,3 +1654,44 @@ def test_build_search_text(self, description, tags, expected_output):
self.app._build_search_text(tags=tags, description=description)
== expected_output
)

@pytest.mark.parametrize(
"remove_indices, expected_len",
[
([], 2), # All models have AQUA_TAG -> include both
([1], 1), # Second model missing AQUA_TAG -> include first only
([0, 1], 0), # Both missing AQUA_TAG -> include none
],
)
@patch.object(AquaApp, "get_container_config")
def test_list_service_models_filters_missing_aqua_tag(
self,
mock_get_container_config,
remove_indices,
expected_len,
):
"""Ensure list() excludes models that do not have AQUA_TAG in freeform_tags."""
mock_get_container_config.return_value = get_container_config()

import copy

items = copy.deepcopy(TestDataset.model_summary_objects)
for idx in remove_indices:
# remove AQUA tag entirely to validate filter behavior
items[idx]["freeform_tags"].pop("OCI_AQUA", None)

self.app.list_resource = MagicMock(
return_value=[
oci.data_science.models.ModelSummary(**item) for item in items
]
)

# Clear service models cache
self.app.clear_model_list_cache()

results = self.app.list(
compartment_id=TestDataset.SERVICE_COMPARTMENT_ID,
category=ads.config.SERVICE,
)

assert len(results) == expected_len
Loading