Skip to content

Commit dc5ca06

Browse files
authored
Merge pull request #366 from simvue-io/hotfix/add-client-validation
Add validation to client function calls
2 parents 5145d1d + f12658e commit dc5ca06

File tree

4 files changed

+139
-54
lines changed

4 files changed

+139
-54
lines changed

simvue/client.py

Lines changed: 80 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@
1010
import logging
1111
import os
1212
import typing
13+
import pydantic
1314
from concurrent.futures import ThreadPoolExecutor, as_completed
15+
from pandas import DataFrame
1416

1517
import requests
1618

@@ -21,11 +23,11 @@
2123
)
2224
from .serialization import deserialize_data
2325
from .types import DeserializedContent
24-
from .utilities import check_extra, get_auth
26+
from .utilities import check_extra, get_auth, prettify_pydantic
27+
from .models import FOLDER_REGEX, NAME_REGEX
2528

2629
if typing.TYPE_CHECKING:
27-
from matplotlib.figure import Figure
28-
from pandas import DataFrame
30+
pass
2931

3032
CONCURRENT_DOWNLOADS = 10
3133
DOWNLOAD_CHUNK_SIZE = 8192
@@ -133,7 +135,11 @@ def _get_json_from_response(
133135

134136
raise RuntimeError(error_str)
135137

136-
def get_run_id_from_name(self, name: str) -> str:
138+
@prettify_pydantic
139+
@pydantic.validate_call
140+
def get_run_id_from_name(
141+
self, name: typing.Annotated[str, pydantic.Field(pattern=NAME_REGEX)]
142+
) -> str:
137143
"""Get Run ID from the server matching the specified name
138144
139145
Assumes a unique name for this run. If multiple results are found this
@@ -186,6 +192,8 @@ def get_run_id_from_name(self, name: str) -> str:
186192
raise RuntimeError("Failed to retrieve identifier for run.")
187193
return first_id
188194

195+
@prettify_pydantic
196+
@pydantic.validate_call
189197
def get_run(self, run_id: str) -> typing.Optional[dict[str, typing.Any]]:
190198
"""Retrieve a single run
191199
@@ -225,6 +233,8 @@ def get_run(self, run_id: str) -> typing.Optional[dict[str, typing.Any]]:
225233
)
226234
return json_response
227235

236+
@prettify_pydantic
237+
@pydantic.validate_call
228238
def get_run_name_from_id(self, run_id: str) -> str:
229239
"""Retrieve the name of a run from its identifier
230240
@@ -250,6 +260,8 @@ def get_run_name_from_id(self, run_id: str) -> str:
250260
raise RuntimeError("Expected key 'name' in server response")
251261
return _name
252262

263+
@prettify_pydantic
264+
@pydantic.validate_call
253265
def get_runs(
254266
self,
255267
filters: typing.Optional[list[str]],
@@ -261,7 +273,7 @@ def get_runs(
261273
count: int = 100,
262274
start_index: int = 0,
263275
) -> typing.Union[
264-
"DataFrame", list[dict[str, typing.Union[int, str, float, None]]], None
276+
DataFrame, list[dict[str, typing.Union[int, str, float, None]]], None
265277
]:
266278
"""Retrieve all runs matching filters.
267279
@@ -337,6 +349,8 @@ def get_runs(
337349
else:
338350
raise RuntimeError("Failed to retrieve runs data")
339351

352+
@prettify_pydantic
353+
@pydantic.validate_call
340354
def delete_run(self, run_identifier: str) -> typing.Optional[dict]:
341355
"""Delete run by identifier
342356
@@ -404,13 +418,18 @@ def _get_folder_id_from_path(self, path: str) -> typing.Optional[str]:
404418

405419
return None
406420

407-
def delete_runs(self, folder_name: str) -> typing.Optional[list]:
421+
@prettify_pydantic
422+
@pydantic.validate_call
423+
def delete_runs(
424+
self, folder_path: typing.Annotated[str, pydantic.Field(pattern=FOLDER_REGEX)]
425+
) -> typing.Optional[list]:
408426
"""Delete runs in a named folder
409427
410428
Parameters
411429
----------
412-
folder_name : str
413-
the name of the folder on which to perform deletion
430+
folder_path : str
431+
the path of the folder on which to perform deletion. All folder
432+
paths are prefixed with `/`
414433
415434
Returns
416435
-------
@@ -422,10 +441,10 @@ def delete_runs(self, folder_name: str) -> typing.Optional[list]:
422441
RuntimeError
423442
if deletion fails due to server request error
424443
"""
425-
folder_id = self._get_folder_id_from_path(folder_name)
444+
folder_id = self._get_folder_id_from_path(folder_path)
426445

427446
if not folder_id:
428-
raise ValueError(f"Could not find a folder matching '{folder_name}'")
447+
raise ValueError(f"Could not find a folder matching '{folder_path}'")
429448

430449
params: dict[str, bool] = {"runs_only": True, "runs": True}
431450

@@ -435,19 +454,21 @@ def delete_runs(self, folder_name: str) -> typing.Optional[list]:
435454

436455
if response.status_code == 200:
437456
if runs := response.json().get("runs", []):
438-
logger.debug(f"Runs from '{folder_name}' deleted successfully: {runs}")
457+
logger.debug(f"Runs from '{folder_path}' deleted successfully: {runs}")
439458
else:
440459
logger.debug("Folder empty, no runs deleted.")
441460
return runs
442461

443462
raise RuntimeError(
444-
f"Deletion of runs from folder '{folder_name}' failed"
463+
f"Deletion of runs from folder '{folder_path}' failed"
445464
f"with code {response.status_code}: {response.text}"
446465
)
447466

467+
@prettify_pydantic
468+
@pydantic.validate_call
448469
def delete_folder(
449470
self,
450-
folder_name: str,
471+
folder_path: typing.Annotated[str, pydantic.Field(pattern=FOLDER_REGEX)],
451472
recursive: bool = False,
452473
remove_runs: bool = False,
453474
allow_missing: bool = False,
@@ -456,8 +477,8 @@ def delete_folder(
456477
457478
Parameters
458479
----------
459-
folder_name : str
460-
name of the folder to delete
480+
folder_path : str
481+
name of the folder to delete. All paths are prefixed with `/`
461482
recursive : bool, optional
462483
if folder contains additional folders remove these, else return an
463484
error. Default False.
@@ -477,14 +498,14 @@ def delete_folder(
477498
RuntimeError
478499
if deletion of the folder from the server failed
479500
"""
480-
folder_id = self._get_folder_id_from_path(folder_name)
501+
folder_id = self._get_folder_id_from_path(folder_path)
481502

482503
if not folder_id:
483504
if allow_missing:
484505
return None
485506
else:
486507
raise RuntimeError(
487-
f"Deletion of folder '{folder_name}' failed, "
508+
f"Deletion of folder '{folder_path}' failed, "
488509
"folder does not exist."
489510
)
490511

@@ -497,7 +518,7 @@ def delete_folder(
497518

498519
json_response = self._get_json_from_response(
499520
expected_status=[200, 404],
500-
scenario=f"Deletion of folder '{folder_name}'",
521+
scenario=f"Deletion of folder '{folder_path}'",
501522
response=response,
502523
)
503524

@@ -510,6 +531,8 @@ def delete_folder(
510531
runs: list[dict] = json_response.get("runs", [])
511532
return runs
512533

534+
@prettify_pydantic
535+
@pydantic.validate_call
513536
def list_artifacts(self, run_id: str) -> list[dict[str, typing.Any]]:
514537
"""Retrieve artifacts for a given run
515538
@@ -574,9 +597,11 @@ def _retrieve_artifact_from_server(
574597

575598
return json_response
576599

600+
@prettify_pydantic
601+
@pydantic.validate_call
577602
def get_artifact(
578603
self, run_id: str, name: str, allow_pickle: bool = False
579-
) -> typing.Optional[DeserializedContent]:
604+
) -> typing.Any:
580605
"""Return the contents of a specified artifact
581606
582607
Parameters
@@ -618,6 +643,8 @@ def get_artifact(
618643

619644
return content or response.content
620645

646+
@prettify_pydantic
647+
@pydantic.validate_call
621648
def get_artifact_as_file(
622649
self, run_id: str, name: str, path: typing.Optional[str] = None
623650
) -> None:
@@ -708,6 +735,8 @@ def _assemble_artifact_downloads(
708735

709736
return downloads
710737

738+
@prettify_pydantic
739+
@pydantic.validate_call
711740
def get_artifacts_as_files(
712741
self,
713742
run_id: str,
@@ -771,13 +800,18 @@ def get_artifacts_as_files(
771800
f"failed with exception: {e}"
772801
)
773802

774-
def get_folder(self, folder_id: str) -> typing.Optional[dict[str, typing.Any]]:
803+
@prettify_pydantic
804+
@pydantic.validate_call
805+
def get_folder(
806+
self, folder_path: typing.Annotated[str, pydantic.Field(pattern=FOLDER_REGEX)]
807+
) -> typing.Optional[dict[str, typing.Any]]:
775808
"""Retrieve a folder by identifier
776809
777810
Parameters
778811
----------
779-
folder_id : str
780-
unique identifier for the folder
812+
folder_path : str
813+
the path of the folder to retrieve on the server.
814+
Paths are prefixed with `/`
781815
782816
Returns
783817
-------
@@ -789,15 +823,16 @@ def get_folder(self, folder_id: str) -> typing.Optional[dict[str, typing.Any]]:
789823
RuntimeError
790824
if there was a failure when retrieving information from the server
791825
"""
792-
if not (_folders := self.get_folders(filters=[f"path == {folder_id}"])):
826+
if not (_folders := self.get_folders(filters=[f"path == {folder_path}"])):
793827
return None
794828
return _folders[0]
795829

830+
@pydantic.validate_call
796831
def get_folders(
797832
self,
798833
filters: typing.Optional[list[str]] = None,
799-
count: int = 100,
800-
start_index: int = 0,
834+
count: pydantic.PositiveInt = 100,
835+
start_index: pydantic.NonNegativeInt = 0,
801836
) -> list[dict[str, typing.Any]]:
802837
"""Retrieve folders from the server
803838
@@ -847,6 +882,8 @@ def get_folders(
847882

848883
return data
849884

885+
@prettify_pydantic
886+
@pydantic.validate_call
850887
def get_metrics_names(self, run_id: str) -> list[str]:
851888
"""Return information on all metrics within a run
852889
@@ -918,6 +955,8 @@ def _get_run_metrics_from_server(
918955

919956
return json_response
920957

958+
@prettify_pydantic
959+
@pydantic.validate_call
921960
def get_metric_values(
922961
self,
923962
metric_names: list[str],
@@ -927,8 +966,8 @@ def get_metric_values(
927966
run_filters: typing.Optional[list[str]] = None,
928967
use_run_names: bool = False,
929968
aggregate: bool = False,
930-
max_points: int = -1,
931-
) -> typing.Union[dict, "DataFrame", None]:
969+
max_points: typing.Optional[pydantic.PositiveInt] = None,
970+
) -> typing.Union[dict, DataFrame, None]:
932971
"""Retrieve the values for a given metric across multiple runs
933972
934973
Uses filters to specify which runs should be retrieved.
@@ -955,7 +994,7 @@ def get_metric_values(
955994
return results as averages (not compatible with xaxis=timestamp),
956995
default is False
957996
max_points : int, optional
958-
maximum number of data points, by default -1 (all)
997+
maximum number of data points, by default None (all)
959998
960999
Returns
9611000
-------
@@ -1010,7 +1049,7 @@ def get_metric_values(
10101049
run_ids=run_ids,
10111050
xaxis=xaxis,
10121051
aggregate=aggregate,
1013-
max_points=max_points,
1052+
max_points=max_points or -1,
10141053
)
10151054

10161055
if aggregate:
@@ -1023,13 +1062,15 @@ def get_metric_values(
10231062
)
10241063

10251064
@check_extra("plot")
1065+
@prettify_pydantic
1066+
@pydantic.validate_call
10261067
def plot_metrics(
10271068
self,
10281069
run_ids: list[str],
10291070
metric_names: list[str],
10301071
xaxis: typing.Literal["step", "time"],
1031-
max_points: int = -1,
1032-
) -> "Figure":
1072+
max_points: typing.Optional[int] = None,
1073+
) -> typing.Any:
10331074
"""Plt the time series values for multiple metrics/runs
10341075
10351076
Parameters
@@ -1041,7 +1082,7 @@ def plot_metrics(
10411082
xaxis : str, ('step' | 'time' | 'timestep')
10421083
the x axis to plot against
10431084
max_points : int, optional
1044-
maximum number of data points, by default -1 (all)
1085+
maximum number of data points, by default None (all)
10451086
10461087
Returns
10471088
-------
@@ -1059,7 +1100,7 @@ def plot_metrics(
10591100
if not isinstance(metric_names, list):
10601101
raise ValueError("Invalid names specified, must be a list of metric names.")
10611102

1062-
data: "DataFrame" = self.get_metric_values( # type: ignore
1103+
data: DataFrame = self.get_metric_values( # type: ignore
10631104
run_ids=run_ids,
10641105
metric_names=metric_names,
10651106
xaxis=xaxis,
@@ -1099,12 +1140,14 @@ def plot_metrics(
10991140

11001141
return plt.figure()
11011142

1143+
@prettify_pydantic
1144+
@pydantic.validate_call
11021145
def get_events(
11031146
self,
11041147
run_id: str,
11051148
message_contains: typing.Optional[str] = None,
1106-
start_index: typing.Optional[int] = None,
1107-
count_limit: typing.Optional[int] = None,
1149+
start_index: typing.Optional[pydantic.NonNegativeInt] = None,
1150+
count_limit: typing.Optional[pydantic.PositiveInt] = None,
11081151
) -> list[dict[str, str]]:
11091152
"""Return events for a specified run
11101153
@@ -1160,6 +1203,8 @@ def get_events(
11601203

11611204
return response.json().get("data", [])
11621205

1206+
@prettify_pydantic
1207+
@pydantic.validate_call
11631208
def get_alerts(
11641209
self, run_id: str, critical_only: bool = True, names_only: bool = True
11651210
) -> list[dict[str, typing.Any]]:

0 commit comments

Comments
 (0)