From f137f6baa67a6cc1e0ad921f4e19eac5d1e2bfec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristian=20Zar=C4=99bski?= Date: Thu, 13 Mar 2025 14:35:44 +0000 Subject: [PATCH 1/4] Added sorting functionality to low level API --- simvue/api/objects/alert/base.py | 17 ++++++++- simvue/api/objects/alert/fetch.py | 21 ++++++++++- simvue/api/objects/artifact/fetch.py | 30 +++++++++++++--- simvue/api/objects/base.py | 25 +++++++++++-- simvue/api/objects/events.py | 1 + simvue/api/objects/folder.py | 29 ++++++++++++++- simvue/api/objects/run.py | 54 +++++++++++++++++++++++++++- simvue/api/objects/tag.py | 35 ++++++++++++++++-- tests/unit/test_user_alert.py | 2 +- 9 files changed, 199 insertions(+), 15 deletions(-) diff --git a/simvue/api/objects/alert/base.py b/simvue/api/objects/alert/base.py index 0204f6d3..4799cd1c 100644 --- a/simvue/api/objects/alert/base.py +++ b/simvue/api/objects/alert/base.py @@ -8,11 +8,12 @@ import http import pydantic +import datetime import typing from simvue.api.objects.base import SimvueObject, staging_check, write_only from simvue.api.request import get as sv_get, get_json_from_response from simvue.api.url import URL -from simvue.models import NAME_REGEX +from simvue.models import NAME_REGEX, DATETIME_FORMAT class AlertBase(SimvueObject): @@ -125,6 +126,20 @@ def abort(self) -> bool: """Retrieve if alert can abort simulations""" return self._get_attribute("abort") + @property + @staging_check + def delay(self) -> int: + """Retrieve delay value for this alert""" + return self._get_attribute("delay") + + @property + def created(self) -> datetime.datetime | None: + """Retrieve created datetime for the alert""" + _created: str | None = self._get_attribute("created") + return ( + datetime.datetime.strptime(_created, DATETIME_FORMAT) if _created else None + ) + @abort.setter @write_only @pydantic.validate_call diff --git a/simvue/api/objects/alert/fetch.py b/simvue/api/objects/alert/fetch.py index 6b5e089e..05cc132d 100644 --- a/simvue/api/objects/alert/fetch.py +++ b/simvue/api/objects/alert/fetch.py @@ -8,10 +8,12 @@ import typing import http +import json import pydantic from simvue.api.objects.alert.user import UserAlert +from simvue.api.objects.base import Sort from simvue.api.request import get_json_from_response from simvue.api.request import get as sv_get from .events import EventsAlert @@ -21,6 +23,15 @@ AlertType = EventsAlert | UserAlert | MetricsThresholdAlert | MetricsRangeAlert +class AlertSort(Sort): + @pydantic.field_validator("column") + @classmethod + def check_column(cls, column: str) -> str: + if column and column not in ("name", "created"): + raise ValueError(f"Invalid sort column for alerts '{column}'") + return column + + class Alert: """Generic Simvue alert retrieval class""" @@ -50,11 +61,13 @@ def __new__(cls, identifier: str, **kwargs) -> AlertType: raise RuntimeError(f"Unknown source type '{_alert_pre.source}'") @classmethod + @pydantic.validate_call def get( cls, offline: bool = False, count: int | None = None, offset: int | None = None, + sorting: list[AlertSort] | None = None, **kwargs, ) -> typing.Generator[tuple[str, AlertType], None, None]: """Fetch all alerts from the server for the current user. @@ -65,6 +78,8 @@ def get( limit the number of results, default of None returns all. offset : int, optional start index for returned results, default of None starts at 0. + sorting : list[dict] | None, optional + list of sorting definitions in the form {'column': str, 'descending': bool} Yields ------ @@ -80,11 +95,15 @@ def get( _class_instance = AlertBase(_local=True, _read_only=True) _url = f"{_class_instance._base_url}" + _params: dict[str, int | str] = {"start": offset, "count": count} + + if sorting: + _params["sorting"] = json.dumps([sort.to_params() for sort in sorting]) _response = sv_get( _url, headers=_class_instance._headers, - params={"start": offset, "count": count} | kwargs, + params=_params | kwargs, ) _label: str = _class_instance.__class__.__name__.lower() diff --git a/simvue/api/objects/artifact/fetch.py b/simvue/api/objects/artifact/fetch.py index 1d571266..36334226 100644 --- a/simvue/api/objects/artifact/fetch.py +++ b/simvue/api/objects/artifact/fetch.py @@ -1,17 +1,31 @@ +import http +import typing +import pydantic +import json + from simvue.api.objects.artifact.base import ArtifactBase +from simvue.api.objects.base import Sort from .file import FileArtifact from simvue.api.objects.artifact.object import ObjectArtifact from simvue.api.request import get_json_from_response, get as sv_get from simvue.api.url import URL from simvue.exception import ObjectNotFoundError -import http -import typing -import pydantic __all__ = ["Artifact"] +class ArtifactSort(Sort): + @pydantic.field_validator("column") + @classmethod + def check_column(cls, column: str) -> str: + if column and ( + column not in ("name", "created") and not column.startswith("metadata.") + ): + raise ValueError(f"Invalid sort column for artifacts '{column}'") + return column + + class Artifact: """Generic Simvue artifact retrieval class""" @@ -119,6 +133,7 @@ def get( cls, count: int | None = None, offset: int | None = None, + sorting: list[ArtifactSort] | None = None, **kwargs, ) -> typing.Generator[tuple[str, FileArtifact | ObjectArtifact], None, None]: """Returns artifacts associated with the current user. @@ -129,6 +144,8 @@ def get( limit the number of results, default of None returns all. offset : int, optional start index for returned results, default of None starts at 0. + sorting : list[dict] | None, optional + list of sorting definitions in the form {'column': str, 'descending': bool} Yields ------ @@ -139,10 +156,15 @@ def get( _class_instance = ArtifactBase(_local=True, _read_only=True) _url = f"{_class_instance._base_url}" + _params = {"start": offset, "count": count} + + if sorting: + _params["sorting"] = json.dumps([sort.to_params() for sort in sorting]) + _response = sv_get( _url, headers=_class_instance._headers, - params={"start": offset, "count": count} | kwargs, + params=_params | kwargs, ) _label: str = _class_instance.__class__.__name__.lower() _label = _label.replace("base", "") diff --git a/simvue/api/objects/base.py b/simvue/api/objects/base.py index f9098d2d..edafe58a 100644 --- a/simvue/api/objects/base.py +++ b/simvue/api/objects/base.py @@ -128,6 +128,14 @@ def tenant(self, tenant: bool) -> None: self._update_visibility("tenant", tenant) +class Sort(pydantic.BaseModel): + column: str + descending: bool = True + + def to_params(self) -> dict[str, str]: + return {"id": self.column, "desc": self.descending} + + class SimvueObject(abc.ABC): def __init__( self, @@ -323,7 +331,13 @@ def get( **kwargs, ) -> typing.Generator[tuple[str, T | None], None, None]: _class_instance = cls(_read_only=True, _local=True) - if (_data := cls._get_all_objects(count, offset, **kwargs).get("data")) is None: + if ( + _data := cls._get_all_objects( + count=count, + offset=offset, + **kwargs, + ).get("data") + ) is None: raise RuntimeError( f"Expected key 'data' for retrieval of {_class_instance.__class__.__name__.lower()}s" ) @@ -350,14 +364,19 @@ def count(cls, **kwargs) -> int: @classmethod def _get_all_objects( - cls, count: int | None, offset: int | None, **kwargs + cls, + count: int | None, + offset: int | None, + **kwargs, ) -> dict[str, typing.Any]: _class_instance = cls(_read_only=True) _url = f"{_class_instance._base_url}" + _params: dict[str, int | str] = {"start": offset, "count": count} + _response = sv_get( _url, headers=_class_instance._headers, - params={"start": offset, "count": count} | kwargs, + params=_params | kwargs, ) _label = _class_instance.__class__.__name__.lower() diff --git a/simvue/api/objects/events.py b/simvue/api/objects/events.py index b330501b..76ba120b 100644 --- a/simvue/api/objects/events.py +++ b/simvue/api/objects/events.py @@ -44,6 +44,7 @@ def get( **kwargs, ) -> typing.Generator[EventSet, None, None]: _class_instance = cls(_read_only=True, _local=True) + if ( _data := cls._get_all_objects(count, offset, run=run_id, **kwargs).get( "data" diff --git a/simvue/api/objects/folder.py b/simvue/api/objects/folder.py index b0313e42..5ada05f0 100644 --- a/simvue/api/objects/folder.py +++ b/simvue/api/objects/folder.py @@ -16,9 +16,25 @@ from simvue.exception import ObjectNotFoundError -from .base import SimvueObject, staging_check, write_only +from .base import SimvueObject, staging_check, write_only, Sort from simvue.models import FOLDER_REGEX, DATETIME_FORMAT +# Need to use this inside of Generator typing to fix bug present in Python 3.10 - see issue #745 +T = typing.TypeVar("T", bound="Folder") + + +class FolderSort(Sort): + @pydantic.field_validator("column") + @classmethod + def check_column(cls, column: str) -> str: + if ( + column + and column not in ("created", "modified", "path") + and not column.startswith("metadata.") + ): + raise ValueError(f"Invalid sort column for folders '{column}") + return column + class Folder(SimvueObject): """ @@ -60,6 +76,17 @@ def new( """Create a new Folder on the Simvue server with the given path""" return Folder(path=path, _read_only=False, _offline=offline, **kwargs) + @classmethod + @pydantic.validate_call + def get( + cls, + count: pydantic.PositiveInt | None = None, + offset: pydantic.NonNegativeInt | None = None, + sorting: list[FolderSort] | None = None, + **kwargs, + ) -> typing.Generator[tuple[str, T | None], None, None]: + return super().get(count=count, offset=offset, sorting=sorting) + @property @staging_check def tags(self) -> list[str]: diff --git a/simvue/api/objects/run.py b/simvue/api/objects/run.py index 674e0de4..7d076f11 100644 --- a/simvue/api/objects/run.py +++ b/simvue/api/objects/run.py @@ -12,13 +12,14 @@ import pydantic import datetime import time +import json try: from typing import Self except ImportError: from typing_extensions import Self -from .base import SimvueObject, staging_check, Visibility, write_only +from .base import SimvueObject, Sort, staging_check, Visibility, write_only from simvue.api.request import ( get as sv_get, put as sv_put, @@ -31,9 +32,27 @@ "lost", "failed", "completed", "terminated", "running", "created" ] +# Need to use this inside of Generator typing to fix bug present in Python 3.10 - see issue #745 +T = typing.TypeVar("T", bound="Run") + __all__ = ["Run"] +class RunSort(Sort): + @pydantic.field_validator("column") + @classmethod + def check_column(cls, column: str) -> str: + if ( + column + and column != "name" + and not column.startswith("metrics") + and not column.startswith("metadata.") + ): + raise ValueError(f"Invalid sort column for runs '{column}") + + return column + + class Run(SimvueObject): """Class for interacting with/creating runs on the server.""" @@ -243,6 +262,39 @@ def get_alert_details(self) -> typing.Generator[dict[str, typing.Any], None, Non for alert in self._get_attribute("alerts"): yield alert["alert"] + @classmethod + @pydantic.validate_call + def get( + cls, + count: pydantic.PositiveInt | None = None, + offset: pydantic.NonNegativeInt | None = None, + sorting: list[RunSort] | None = None, + **kwargs, + ) -> typing.Generator[tuple[str, T | None], None, None]: + """Get runs from the server. + + Parameters + ---------- + count : int, optional + limit the number of objects returned, default no limit. + offset : int, optional + start index for results, default is 0. + sorting : list[dict] | None, optional + list of sorting definitions in the form {'column': str, 'descending': bool} + + Yields + ------ + tuple[str, Run] + id of run + Run object representing object on server + """ + _params: dict[str, str] = {} + + if sorting: + _params["sorting"] = json.dumps([i.to_params() for i in sorting]) + + return super().get(count=count, offset=offset, **_params) + @alerts.setter @write_only @pydantic.validate_call diff --git a/simvue/api/objects/tag.py b/simvue/api/objects/tag.py index f2ca969a..da9e1903 100644 --- a/simvue/api/objects/tag.py +++ b/simvue/api/objects/tag.py @@ -9,13 +9,24 @@ import pydantic.color import typing +import json import datetime -from .base import SimvueObject, staging_check, write_only + +from simvue.api.objects.base import SimvueObject, Sort, staging_check, write_only from simvue.models import DATETIME_FORMAT __all__ = ["Tag"] +class TagSort(Sort): + @pydantic.field_validator("column") + @classmethod + def check_column(cls, column: str) -> str: + if column and column not in ("created", "name"): + raise ValueError(f"Invalid sort column for tags '{column}") + return column + + class Tag(SimvueObject): """Class for creation/interaction with tag object on server""" @@ -87,8 +98,14 @@ def created(self) -> datetime.datetime | None: ) @classmethod + @pydantic.validate_call def get( - cls, *, count: int | None = None, offset: int | None = None, **kwargs + cls, + *, + count: int | None = None, + offset: int | None = None, + sorting: list[TagSort] | None = None, + **kwargs, ) -> typing.Generator[tuple[str, "SimvueObject"], None, None]: """Get tags from the server. @@ -98,6 +115,8 @@ def get( limit the number of objects returned, default no limit. offset : int, optional start index for results, default is 0. + sorting : list[dict] | None, optional + list of sorting definitions in the form {'column': str, 'descending': bool} Yields ------ @@ -108,4 +127,14 @@ def get( # There are currently no tag filters kwargs.pop("filters", None) - return super().get(count=count, offset=offset, **kwargs) + _params: dict[str, str] = {} + + if sorting: + _params["sorting"] = json.dumps([i.to_params() for i in sorting]) + + return super().get( + count=count, + offset=offset, + **_params, + **kwargs, + ) diff --git a/tests/unit/test_user_alert.py b/tests/unit/test_user_alert.py index a819df15..b1625e37 100644 --- a/tests/unit/test_user_alert.py +++ b/tests/unit/test_user_alert.py @@ -20,7 +20,7 @@ def test_user_alert_creation_online() -> None: assert _alert.source == "user" assert _alert.name == f"users_alert_{_uuid}" assert _alert.notification == "none" - assert dict(Alert.get()) + assert dict(Alert.get(count=10)) _alert.delete() From 4597ee817e0d5dd46dc10e78c3926c2823ef4e76 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristian=20Zar=C4=99bski?= Date: Thu, 13 Mar 2025 14:49:15 +0000 Subject: [PATCH 2/4] Added fetch tests --- tests/unit/test_fetch.py | 75 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 75 insertions(+) create mode 100644 tests/unit/test_fetch.py diff --git a/tests/unit/test_fetch.py b/tests/unit/test_fetch.py new file mode 100644 index 00000000..f8779b6d --- /dev/null +++ b/tests/unit/test_fetch.py @@ -0,0 +1,75 @@ +import pytest + +from simvue.api.objects import Alert, Artifact, Tag, Run + + +@pytest.mark.api +@pytest.mark.online +@pytest.mark.parametrize( + "sort_column,sort_descending", + [ + ("name", True), + ("created", False), + (None, None) + ], + ids=("name-desc", "created-asc", "no-sorting") +) +def test_alerts_fetch(sort_column: str | None, sort_descending: bool | None) -> None: + if sort_column: + assert dict(Alert.get(sorting=[{"column": sort_column, "descending": sort_descending}], count=10)) + else: + assert dict(Alert.get(count=10)) + + +@pytest.mark.api +@pytest.mark.online +@pytest.mark.parametrize( + "sort_column,sort_descending", + [ + ("name", True), + ("created", False), + (None, None) + ], + ids=("name-desc", "created-asc", "no-sorting") +) +def test_artifacts_fetch(sort_column: str | None, sort_descending: bool | None) -> None: + if sort_column: + assert dict(Artifact.get(sorting=[{"column": sort_column, "descending": sort_descending}], count=10)) + else: + assert dict(Artifact.get(count=10)) + + +@pytest.mark.api +@pytest.mark.online +@pytest.mark.parametrize( + "sort_column,sort_descending", + [ + ("name", True), + ("created", False), + (None, None) + ], + ids=("name-desc", "created-asc", "no-sorting") +) +def test_tags_fetch(sort_column: str | None, sort_descending: bool | None) -> None: + if sort_column: + assert dict(Tag.get(sorting=[{"column": sort_column, "descending": sort_descending}], count=10)) + else: + assert dict(Tag.get(count=10)) + + +@pytest.mark.api +@pytest.mark.online +@pytest.mark.parametrize( + "sort_column,sort_descending", + [ + ("name", True), + ("name", False), + (None, None) + ], + ids=("name-desc", "created-asc", "no-sorting") +) +def test_runs_fetch(sort_column: str | None, sort_descending: bool | None) -> None: + if sort_column: + assert dict(Run.get(sorting=[{"column": sort_column, "descending": sort_descending}], count=10)) + else: + assert dict(Run.get(count=10)) From d955cd436ecdb0a04f75ab0303d9297f26567244 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristian=20Zar=C4=99bski?= Date: Tue, 25 Mar 2025 08:52:54 +0000 Subject: [PATCH 3/4] Added runs sorting test --- simvue/api/objects/run.py | 3 ++- simvue/client.py | 8 ++++++++ tests/functional/test_client.py | 14 +++++++++++--- 3 files changed, 21 insertions(+), 4 deletions(-) diff --git a/simvue/api/objects/run.py b/simvue/api/objects/run.py index 50858581..ff5ef7a9 100644 --- a/simvue/api/objects/run.py +++ b/simvue/api/objects/run.py @@ -47,8 +47,9 @@ def check_column(cls, column: str) -> str: and column != "name" and not column.startswith("metrics") and not column.startswith("metadata.") + and column not in ("created", "started", "endtime", "modified") ): - raise ValueError(f"Invalid sort column for runs '{column}") + raise ValueError(f"Invalid sort column for runs '{column}'") return column diff --git a/simvue/client.py b/simvue/client.py index 76b5c025..24ea0e1c 100644 --- a/simvue/client.py +++ b/simvue/client.py @@ -182,6 +182,7 @@ def get_runs( count_limit: pydantic.PositiveInt | None = 100, start_index: pydantic.NonNegativeInt = 0, show_shared: bool = False, + sort_by_columns: list[tuple[str, bool]] | None = None, ) -> DataFrame | typing.Generator[tuple[str, Run], None, None] | None: """Retrieve all runs matching filters. @@ -210,6 +211,10 @@ def get_runs( the index from which to count entries. Default is 0. show_shared : bool, optional whether to include runs shared with the current user. Default is False. + sort_by_columns : list[tuple[str, bool]], optional + sort by columns in the order given, + list of tuples in the form (column_name: str, sort_descending: bool), + default is None. Returns ------- @@ -236,6 +241,9 @@ def get_runs( return_alerts=alerts, return_system=system, return_metadata=metadata, + sorting=[dict(zip(("column", "descending"), a)) for a in sort_by_columns] + if sort_by_columns + else None, ) if output_format == "objects": diff --git a/tests/functional/test_client.py b/tests/functional/test_client.py index 85407b6c..21cea843 100644 --- a/tests/functional/test_client.py +++ b/tests/functional/test_client.py @@ -200,11 +200,19 @@ def test_get_artifacts_as_files( @pytest.mark.dependency @pytest.mark.client -@pytest.mark.parametrize("output_format", ("dict", "dataframe", "objects")) -def test_get_runs(create_test_run: tuple[sv_run.Run, dict], output_format: str) -> None: +@pytest.mark.parametrize( + "output_format,sorting", + [ + ("dict", None), + ("dataframe", [("created", True), ("started", True)]), + ("objects", [("metadata.test_identifier", True)]), + ], + ids=("dict-unsorted", "dataframe-datesorted", "objects-metasorted") +) +def test_get_runs(create_test_run: tuple[sv_run.Run, dict], output_format: str, sorting: list[tuple[str, bool]] | None) -> None: client = svc.Client() - _result = client.get_runs(filters=None, output_format=output_format, count_limit=10) + _result = client.get_runs(filters=None, output_format=output_format, count_limit=10, sort_by_columns=sorting) if output_format == "dataframe": assert not _result.empty From 33bf821f6009f2b83d32afdc0638b57029d1ab36 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristian=20Zar=C4=99bski?= Date: Tue, 25 Mar 2025 10:09:24 +0000 Subject: [PATCH 4/4] Added additional sorting tests --- simvue/api/objects/folder.py | 7 +++- simvue/api/objects/run.py | 2 +- simvue/client.py | 74 +++++++++++++++++++++++++++++---- tests/functional/test_client.py | 41 ++++++++++++++---- 4 files changed, 106 insertions(+), 18 deletions(-) diff --git a/simvue/api/objects/folder.py b/simvue/api/objects/folder.py index 2b3f78ea..178d9c1a 100644 --- a/simvue/api/objects/folder.py +++ b/simvue/api/objects/folder.py @@ -94,7 +94,12 @@ def get( sorting: list[FolderSort] | None = None, **kwargs, ) -> typing.Generator[tuple[str, T | None], None, None]: - return super().get(count=count, offset=offset, sorting=sorting) + _params: dict[str, str] = kwargs + + if sorting: + _params["sorting"] = json.dumps([i.to_params() for i in sorting]) + + return super().get(count=count, offset=offset, **_params) @property @staging_check diff --git a/simvue/api/objects/run.py b/simvue/api/objects/run.py index ff5ef7a9..f883788c 100644 --- a/simvue/api/objects/run.py +++ b/simvue/api/objects/run.py @@ -338,7 +338,7 @@ def get( id of run Run object representing object on server """ - _params: dict[str, str] = {} + _params: dict[str, str] = kwargs if sorting: _params["sorting"] = json.dumps([i.to_params() for i in sorting]) diff --git a/simvue/client.py b/simvue/client.py index 24ea0e1c..f5845edc 100644 --- a/simvue/client.py +++ b/simvue/client.py @@ -222,6 +222,11 @@ def get_runs( either the JSON response from the runs request or the results in the form of a Pandas DataFrame + Yields + ------ + tuple[str, Run] + identifier and Run object + Raises ------ ValueError @@ -434,13 +439,19 @@ def delete_alert(self, alert_id: str) -> None: @prettify_pydantic @pydantic.validate_call - def list_artifacts(self, run_id: str) -> typing.Generator[Artifact, None, None]: + def list_artifacts( + self, run_id: str, sort_by_columns: list[tuple[str, bool]] | None = None + ) -> typing.Generator[Artifact, None, None]: """Retrieve artifacts for a given run Parameters ---------- run_id : str unique identifier for the run + sort_by_columns : list[tuple[str, bool]], optional + sort by columns in the order given, + list of tuples in the form (column_name: str, sort_descending: bool), + default is None. Yields ------ @@ -452,7 +463,12 @@ def list_artifacts(self, run_id: str) -> typing.Generator[Artifact, None, None]: RuntimeError if retrieval of artifacts failed when communicating with the server """ - return Artifact.get(runs=json.dumps([run_id])) # type: ignore + return Artifact.get( + runs=json.dumps([run_id]), + sorting=[dict(zip(("column", "descending"), a)) for a in sort_by_columns] + if sort_by_columns + else None, + ) # type: ignore def _retrieve_artifacts_from_server( self, run_id: str, name: str, count: int | None = None @@ -582,7 +598,7 @@ def get_artifacts_as_files( * input - this file is an input file. * output - this file is created by the run. * code - this file represents an executed script - output_dir : str | None, optional + output_dir : str | None, optTODOional location to download files to, the default of None will download them to the current working directory @@ -649,6 +665,7 @@ def get_folders( filters: list[str] | None = None, count: pydantic.PositiveInt = 100, start_index: pydantic.NonNegativeInt = 0, + sort_by_columns: list[tuple[str, bool]] | None = None, ) -> typing.Generator[tuple[str, Folder], None, None]: """Retrieve folders from the server @@ -660,6 +677,10 @@ def get_folders( maximum number of entries to return. Default is 100. start_index : int, optional the index from which to count entries. Default is 0. + sort_by_columns : list[tuple[str, bool]], optional + sort by columns in the order given, + list of tuples in the form (column_name: str, sort_descending: bool), + default is None. Returns ------- @@ -672,7 +693,12 @@ def get_folders( if there was a failure retrieving data from the server """ return Folder.get( - filters=json.dumps(filters or []), count=count, offset=start_index + filters=json.dumps(filters or []), + count=count, + offset=start_index, + sorting=[dict(zip(("column", "descending"), a)) for a in sort_by_columns] + if sort_by_columns + else None, ) # type: ignore @prettify_pydantic @@ -981,6 +1007,7 @@ def get_alerts( names_only: bool = True, start_index: pydantic.NonNegativeInt | None = None, count_limit: pydantic.PositiveInt | None = None, + sort_by_columns: list[tuple[str, bool]] | None = None, ) -> list[AlertBase] | list[str | None]: """Retrieve alerts for a given run @@ -996,6 +1023,10 @@ def get_alerts( slice results returning only those above this index, by default None count_limit : typing.int, optional limit number of returned results, by default None + sort_by_columns : list[tuple[str, bool]], optional + sort by columns in the order given, + list of tuples in the form (column_name: str, sort_descending: bool), + default is None. Returns ------- @@ -1012,7 +1043,22 @@ def get_alerts( raise RuntimeError( "critical_only is ambiguous when returning alerts with no run ID specified." ) - return [alert.name if names_only else alert for _, alert in Alert.get()] # type: ignore + return [ + alert.name if names_only else alert + for _, alert in Alert.get( + sorting=[ + dict(zip(("column", "descending"), a)) for a in sort_by_columns + ] + if sort_by_columns + else None, + ) + ] # type: ignore + + if sort_by_columns: + logger.warning( + "Run identifier specified for alert retrieval," + " argument 'sort_by_columns' will be ignored" + ) _alerts = [ Alert(identifier=alert.get("id"), **alert) @@ -1032,6 +1078,7 @@ def get_tags( *, start_index: pydantic.NonNegativeInt | None = None, count_limit: pydantic.PositiveInt | None = None, + sort_by_columns: list[tuple[str, bool]] | None = None, ) -> typing.Generator[Tag, None, None]: """Retrieve tags @@ -1041,18 +1088,29 @@ def get_tags( slice results returning only those above this index, by default None count_limit : typing.int, optional limit number of returned results, by default None + sort_by_columns : list[tuple[str, bool]], optional + sort by columns in the order given, + list of tuples in the form (column_name: str, sort_descending: bool), + default is None. Returns ------- - list[Tag] - a list of all tags for this run + yields + tag identifier + tag object Raises ------ RuntimeError if there was a failure retrieving data from the server """ - return Tag.get(count=count_limit, offset=start_index) + return Tag.get( + count=count_limit, + offset=start_index, + sorting=[dict(zip(("column", "descending"), a)) for a in sort_by_columns] + if sort_by_columns + else None, + ) @prettify_pydantic @pydantic.validate_call diff --git a/tests/functional/test_client.py b/tests/functional/test_client.py index 21cea843..390d90d6 100644 --- a/tests/functional/test_client.py +++ b/tests/functional/test_client.py @@ -33,7 +33,12 @@ def test_get_events(create_test_run: tuple[sv_run.Run, dict]) -> None: @pytest.mark.parametrize( "critical_only", (True, False), ids=("critical_only", "all_states") ) -def test_get_alerts(create_plain_run: tuple[sv_run.Run, dict], from_run: bool, names_only: bool, critical_only: bool) -> None: +def test_get_alerts( + create_plain_run: tuple[sv_run.Run, dict], + from_run: bool, + names_only: bool, + critical_only: bool, +) -> None: run, run_data = create_plain_run run_id = run.id unique_id = f"{uuid.uuid4()}".split("-")[0] @@ -52,13 +57,19 @@ def test_get_alerts(create_plain_run: tuple[sv_run.Run, dict], from_run: bool, n run.close() client = svc.Client() - + if critical_only and not from_run: with pytest.raises(RuntimeError) as e: - _alerts = client.get_alerts(run_id=run_id if from_run else None, critical_only=critical_only, names_only=names_only) + _alerts = client.get_alerts(critical_only=critical_only, names_only=names_only) assert "critical_only is ambiguous when returning alerts with no run ID specified." in str(e.value) else: - _alerts = client.get_alerts(run_id=run_id if from_run else None, critical_only=critical_only, names_only=names_only) + sorting = None if run_id else [("name", True), ("created", True)] + _alerts = client.get_alerts( + run_id=run_id if from_run else None, + critical_only=critical_only, + names_only=names_only, + sort_by_columns=sorting + ) if names_only: assert all(isinstance(item, str) for item in _alerts) @@ -145,9 +156,16 @@ def test_plot_metrics(create_test_run: tuple[sv_run.Run, dict]) -> None: @pytest.mark.dependency @pytest.mark.client -def test_get_artifacts_entries(create_test_run: tuple[sv_run.Run, dict]) -> None: +@pytest.mark.parametrize( + "sorting", ([("metadata.test_identifier", True)], [("name", True), ("created", True)], None), + ids=("sorted-metadata", "sorted-name-created", None) +) +def test_get_artifacts_entries(create_test_run: tuple[sv_run.Run, dict], sorting: list[tuple[str, bool]] | None) -> None: + # TODO: Reinstate this test once server bug fixed + if any("metadata" in a[0] for a in sorting or []): + pytest.skip(reason="Server bug fix required for metadata sorting.") client = svc.Client() - assert dict(client.list_artifacts(create_test_run[1]["run_id"])) + assert dict(client.list_artifacts(create_test_run[1]["run_id"], sort_by_columns=sorting)) assert client.get_artifact(create_test_run[1]["run_id"], name="test_attributes") @@ -229,9 +247,16 @@ def test_get_run(create_test_run: tuple[sv_run.Run, dict]) -> None: @pytest.mark.dependency @pytest.mark.client -def test_get_folder(create_test_run: tuple[sv_run.Run, dict]) -> None: +@pytest.mark.parametrize( + "sorting", (None, [("metadata.test_identifier", True), ("path", True)], [("modified", False)]), + ids=("no-sort", "sort-path-metadata", "sort-modified") +) +def test_get_folders(create_test_run: tuple[sv_run.Run, dict], sorting: list[tuple[str, bool]] | None) -> None: + #TODO: Once server is fixed reinstate this test + if "modified" in (a[0] for a in sorting or []): + pytest.skip(reason="Server bug when sorting by 'modified'") client = svc.Client() - assert (folders := client.get_folders()) + assert (folders := client.get_folders(sort_by_columns=sorting)) _id, _folder = next(folders) assert _folder.path assert client.get_folder(_folder.path)