Skip to content

Commit 33bf821

Browse files
committed
Added additional sorting tests
1 parent d955cd4 commit 33bf821

File tree

4 files changed

+106
-18
lines changed

4 files changed

+106
-18
lines changed

simvue/api/objects/folder.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,12 @@ def get(
9494
sorting: list[FolderSort] | None = None,
9595
**kwargs,
9696
) -> typing.Generator[tuple[str, T | None], None, None]:
97-
return super().get(count=count, offset=offset, sorting=sorting)
97+
_params: dict[str, str] = kwargs
98+
99+
if sorting:
100+
_params["sorting"] = json.dumps([i.to_params() for i in sorting])
101+
102+
return super().get(count=count, offset=offset, **_params)
98103

99104
@property
100105
@staging_check

simvue/api/objects/run.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,7 @@ def get(
338338
id of run
339339
Run object representing object on server
340340
"""
341-
_params: dict[str, str] = {}
341+
_params: dict[str, str] = kwargs
342342

343343
if sorting:
344344
_params["sorting"] = json.dumps([i.to_params() for i in sorting])

simvue/client.py

Lines changed: 66 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,11 @@ def get_runs(
222222
either the JSON response from the runs request or the results in the
223223
form of a Pandas DataFrame
224224
225+
Yields
226+
------
227+
tuple[str, Run]
228+
identifier and Run object
229+
225230
Raises
226231
------
227232
ValueError
@@ -434,13 +439,19 @@ def delete_alert(self, alert_id: str) -> None:
434439

435440
@prettify_pydantic
436441
@pydantic.validate_call
437-
def list_artifacts(self, run_id: str) -> typing.Generator[Artifact, None, None]:
442+
def list_artifacts(
443+
self, run_id: str, sort_by_columns: list[tuple[str, bool]] | None = None
444+
) -> typing.Generator[Artifact, None, None]:
438445
"""Retrieve artifacts for a given run
439446
440447
Parameters
441448
----------
442449
run_id : str
443450
unique identifier for the run
451+
sort_by_columns : list[tuple[str, bool]], optional
452+
sort by columns in the order given,
453+
list of tuples in the form (column_name: str, sort_descending: bool),
454+
default is None.
444455
445456
Yields
446457
------
@@ -452,7 +463,12 @@ def list_artifacts(self, run_id: str) -> typing.Generator[Artifact, None, None]:
452463
RuntimeError
453464
if retrieval of artifacts failed when communicating with the server
454465
"""
455-
return Artifact.get(runs=json.dumps([run_id])) # type: ignore
466+
return Artifact.get(
467+
runs=json.dumps([run_id]),
468+
sorting=[dict(zip(("column", "descending"), a)) for a in sort_by_columns]
469+
if sort_by_columns
470+
else None,
471+
) # type: ignore
456472

457473
def _retrieve_artifacts_from_server(
458474
self, run_id: str, name: str, count: int | None = None
@@ -582,7 +598,7 @@ def get_artifacts_as_files(
582598
* input - this file is an input file.
583599
* output - this file is created by the run.
584600
* code - this file represents an executed script
585-
output_dir : str | None, optional
601+
output_dir : str | None, optTODOional
586602
location to download files to, the default of None will download
587603
them to the current working directory
588604
@@ -649,6 +665,7 @@ def get_folders(
649665
filters: list[str] | None = None,
650666
count: pydantic.PositiveInt = 100,
651667
start_index: pydantic.NonNegativeInt = 0,
668+
sort_by_columns: list[tuple[str, bool]] | None = None,
652669
) -> typing.Generator[tuple[str, Folder], None, None]:
653670
"""Retrieve folders from the server
654671
@@ -660,6 +677,10 @@ def get_folders(
660677
maximum number of entries to return. Default is 100.
661678
start_index : int, optional
662679
the index from which to count entries. Default is 0.
680+
sort_by_columns : list[tuple[str, bool]], optional
681+
sort by columns in the order given,
682+
list of tuples in the form (column_name: str, sort_descending: bool),
683+
default is None.
663684
664685
Returns
665686
-------
@@ -672,7 +693,12 @@ def get_folders(
672693
if there was a failure retrieving data from the server
673694
"""
674695
return Folder.get(
675-
filters=json.dumps(filters or []), count=count, offset=start_index
696+
filters=json.dumps(filters or []),
697+
count=count,
698+
offset=start_index,
699+
sorting=[dict(zip(("column", "descending"), a)) for a in sort_by_columns]
700+
if sort_by_columns
701+
else None,
676702
) # type: ignore
677703

678704
@prettify_pydantic
@@ -981,6 +1007,7 @@ def get_alerts(
9811007
names_only: bool = True,
9821008
start_index: pydantic.NonNegativeInt | None = None,
9831009
count_limit: pydantic.PositiveInt | None = None,
1010+
sort_by_columns: list[tuple[str, bool]] | None = None,
9841011
) -> list[AlertBase] | list[str | None]:
9851012
"""Retrieve alerts for a given run
9861013
@@ -996,6 +1023,10 @@ def get_alerts(
9961023
slice results returning only those above this index, by default None
9971024
count_limit : typing.int, optional
9981025
limit number of returned results, by default None
1026+
sort_by_columns : list[tuple[str, bool]], optional
1027+
sort by columns in the order given,
1028+
list of tuples in the form (column_name: str, sort_descending: bool),
1029+
default is None.
9991030
10001031
Returns
10011032
-------
@@ -1012,7 +1043,22 @@ def get_alerts(
10121043
raise RuntimeError(
10131044
"critical_only is ambiguous when returning alerts with no run ID specified."
10141045
)
1015-
return [alert.name if names_only else alert for _, alert in Alert.get()] # type: ignore
1046+
return [
1047+
alert.name if names_only else alert
1048+
for _, alert in Alert.get(
1049+
sorting=[
1050+
dict(zip(("column", "descending"), a)) for a in sort_by_columns
1051+
]
1052+
if sort_by_columns
1053+
else None,
1054+
)
1055+
] # type: ignore
1056+
1057+
if sort_by_columns:
1058+
logger.warning(
1059+
"Run identifier specified for alert retrieval,"
1060+
" argument 'sort_by_columns' will be ignored"
1061+
)
10161062

10171063
_alerts = [
10181064
Alert(identifier=alert.get("id"), **alert)
@@ -1032,6 +1078,7 @@ def get_tags(
10321078
*,
10331079
start_index: pydantic.NonNegativeInt | None = None,
10341080
count_limit: pydantic.PositiveInt | None = None,
1081+
sort_by_columns: list[tuple[str, bool]] | None = None,
10351082
) -> typing.Generator[Tag, None, None]:
10361083
"""Retrieve tags
10371084
@@ -1041,18 +1088,29 @@ def get_tags(
10411088
slice results returning only those above this index, by default None
10421089
count_limit : typing.int, optional
10431090
limit number of returned results, by default None
1091+
sort_by_columns : list[tuple[str, bool]], optional
1092+
sort by columns in the order given,
1093+
list of tuples in the form (column_name: str, sort_descending: bool),
1094+
default is None.
10441095
10451096
Returns
10461097
-------
1047-
list[Tag]
1048-
a list of all tags for this run
1098+
yields
1099+
tag identifier
1100+
tag object
10491101
10501102
Raises
10511103
------
10521104
RuntimeError
10531105
if there was a failure retrieving data from the server
10541106
"""
1055-
return Tag.get(count=count_limit, offset=start_index)
1107+
return Tag.get(
1108+
count=count_limit,
1109+
offset=start_index,
1110+
sorting=[dict(zip(("column", "descending"), a)) for a in sort_by_columns]
1111+
if sort_by_columns
1112+
else None,
1113+
)
10561114

10571115
@prettify_pydantic
10581116
@pydantic.validate_call

tests/functional/test_client.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,12 @@ def test_get_events(create_test_run: tuple[sv_run.Run, dict]) -> None:
3333
@pytest.mark.parametrize(
3434
"critical_only", (True, False), ids=("critical_only", "all_states")
3535
)
36-
def test_get_alerts(create_plain_run: tuple[sv_run.Run, dict], from_run: bool, names_only: bool, critical_only: bool) -> None:
36+
def test_get_alerts(
37+
create_plain_run: tuple[sv_run.Run, dict],
38+
from_run: bool,
39+
names_only: bool,
40+
critical_only: bool,
41+
) -> None:
3742
run, run_data = create_plain_run
3843
run_id = run.id
3944
unique_id = f"{uuid.uuid4()}".split("-")[0]
@@ -52,13 +57,19 @@ def test_get_alerts(create_plain_run: tuple[sv_run.Run, dict], from_run: bool, n
5257
run.close()
5358

5459
client = svc.Client()
55-
60+
5661
if critical_only and not from_run:
5762
with pytest.raises(RuntimeError) as e:
58-
_alerts = client.get_alerts(run_id=run_id if from_run else None, critical_only=critical_only, names_only=names_only)
63+
_alerts = client.get_alerts(critical_only=critical_only, names_only=names_only)
5964
assert "critical_only is ambiguous when returning alerts with no run ID specified." in str(e.value)
6065
else:
61-
_alerts = client.get_alerts(run_id=run_id if from_run else None, critical_only=critical_only, names_only=names_only)
66+
sorting = None if run_id else [("name", True), ("created", True)]
67+
_alerts = client.get_alerts(
68+
run_id=run_id if from_run else None,
69+
critical_only=critical_only,
70+
names_only=names_only,
71+
sort_by_columns=sorting
72+
)
6273

6374
if names_only:
6475
assert all(isinstance(item, str) for item in _alerts)
@@ -145,9 +156,16 @@ def test_plot_metrics(create_test_run: tuple[sv_run.Run, dict]) -> None:
145156

146157
@pytest.mark.dependency
147158
@pytest.mark.client
148-
def test_get_artifacts_entries(create_test_run: tuple[sv_run.Run, dict]) -> None:
159+
@pytest.mark.parametrize(
160+
"sorting", ([("metadata.test_identifier", True)], [("name", True), ("created", True)], None),
161+
ids=("sorted-metadata", "sorted-name-created", None)
162+
)
163+
def test_get_artifacts_entries(create_test_run: tuple[sv_run.Run, dict], sorting: list[tuple[str, bool]] | None) -> None:
164+
# TODO: Reinstate this test once server bug fixed
165+
if any("metadata" in a[0] for a in sorting or []):
166+
pytest.skip(reason="Server bug fix required for metadata sorting.")
149167
client = svc.Client()
150-
assert dict(client.list_artifacts(create_test_run[1]["run_id"]))
168+
assert dict(client.list_artifacts(create_test_run[1]["run_id"], sort_by_columns=sorting))
151169
assert client.get_artifact(create_test_run[1]["run_id"], name="test_attributes")
152170

153171

@@ -229,9 +247,16 @@ def test_get_run(create_test_run: tuple[sv_run.Run, dict]) -> None:
229247

230248
@pytest.mark.dependency
231249
@pytest.mark.client
232-
def test_get_folder(create_test_run: tuple[sv_run.Run, dict]) -> None:
250+
@pytest.mark.parametrize(
251+
"sorting", (None, [("metadata.test_identifier", True), ("path", True)], [("modified", False)]),
252+
ids=("no-sort", "sort-path-metadata", "sort-modified")
253+
)
254+
def test_get_folders(create_test_run: tuple[sv_run.Run, dict], sorting: list[tuple[str, bool]] | None) -> None:
255+
#TODO: Once server is fixed reinstate this test
256+
if "modified" in (a[0] for a in sorting or []):
257+
pytest.skip(reason="Server bug when sorting by 'modified'")
233258
client = svc.Client()
234-
assert (folders := client.get_folders())
259+
assert (folders := client.get_folders(sort_by_columns=sorting))
235260
_id, _folder = next(folders)
236261
assert _folder.path
237262
assert client.get_folder(_folder.path)

0 commit comments

Comments
 (0)