Skip to content

Add validation to client function calls #366

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 80 additions & 35 deletions simvue/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
import logging
import os
import typing
import pydantic
from concurrent.futures import ThreadPoolExecutor, as_completed
from pandas import DataFrame

import requests

Expand All @@ -21,11 +23,11 @@
)
from .serialization import deserialize_data
from .types import DeserializedContent
from .utilities import check_extra, get_auth
from .utilities import check_extra, get_auth, prettify_pydantic
from .models import FOLDER_REGEX, NAME_REGEX

if typing.TYPE_CHECKING:
from matplotlib.figure import Figure
from pandas import DataFrame
pass

CONCURRENT_DOWNLOADS = 10
DOWNLOAD_CHUNK_SIZE = 8192
Expand Down Expand Up @@ -133,7 +135,11 @@ def _get_json_from_response(

raise RuntimeError(error_str)

def get_run_id_from_name(self, name: str) -> str:
@prettify_pydantic
@pydantic.validate_call
def get_run_id_from_name(
self, name: typing.Annotated[str, pydantic.Field(pattern=NAME_REGEX)]
) -> str:
"""Get Run ID from the server matching the specified name

Assumes a unique name for this run. If multiple results are found this
Expand Down Expand Up @@ -186,6 +192,8 @@ def get_run_id_from_name(self, name: str) -> str:
raise RuntimeError("Failed to retrieve identifier for run.")
return first_id

@prettify_pydantic
@pydantic.validate_call
def get_run(self, run_id: str) -> typing.Optional[dict[str, typing.Any]]:
"""Retrieve a single run

Expand Down Expand Up @@ -225,6 +233,8 @@ def get_run(self, run_id: str) -> typing.Optional[dict[str, typing.Any]]:
)
return json_response

@prettify_pydantic
@pydantic.validate_call
def get_run_name_from_id(self, run_id: str) -> str:
"""Retrieve the name of a run from its identifier

Expand All @@ -250,6 +260,8 @@ def get_run_name_from_id(self, run_id: str) -> str:
raise RuntimeError("Expected key 'name' in server response")
return _name

@prettify_pydantic
@pydantic.validate_call
def get_runs(
self,
filters: typing.Optional[list[str]],
Expand All @@ -261,7 +273,7 @@ def get_runs(
count: int = 100,
start_index: int = 0,
) -> typing.Union[
"DataFrame", list[dict[str, typing.Union[int, str, float, None]]], None
DataFrame, list[dict[str, typing.Union[int, str, float, None]]], None
]:
"""Retrieve all runs matching filters.

Expand Down Expand Up @@ -337,6 +349,8 @@ def get_runs(
else:
raise RuntimeError("Failed to retrieve runs data")

@prettify_pydantic
@pydantic.validate_call
def delete_run(self, run_identifier: str) -> typing.Optional[dict]:
"""Delete run by identifier

Expand Down Expand Up @@ -404,13 +418,18 @@ def _get_folder_id_from_path(self, path: str) -> typing.Optional[str]:

return None

def delete_runs(self, folder_name: str) -> typing.Optional[list]:
@prettify_pydantic
@pydantic.validate_call
def delete_runs(
self, folder_path: typing.Annotated[str, pydantic.Field(pattern=FOLDER_REGEX)]
) -> typing.Optional[list]:
"""Delete runs in a named folder

Parameters
----------
folder_name : str
the name of the folder on which to perform deletion
folder_path : str
the path of the folder on which to perform deletion. All folder
paths are prefixed with `/`

Returns
-------
Expand All @@ -422,10 +441,10 @@ def delete_runs(self, folder_name: str) -> typing.Optional[list]:
RuntimeError
if deletion fails due to server request error
"""
folder_id = self._get_folder_id_from_path(folder_name)
folder_id = self._get_folder_id_from_path(folder_path)

if not folder_id:
raise ValueError(f"Could not find a folder matching '{folder_name}'")
raise ValueError(f"Could not find a folder matching '{folder_path}'")

params: dict[str, bool] = {"runs_only": True, "runs": True}

Expand All @@ -435,19 +454,21 @@ def delete_runs(self, folder_name: str) -> typing.Optional[list]:

if response.status_code == 200:
if runs := response.json().get("runs", []):
logger.debug(f"Runs from '{folder_name}' deleted successfully: {runs}")
logger.debug(f"Runs from '{folder_path}' deleted successfully: {runs}")
else:
logger.debug("Folder empty, no runs deleted.")
return runs

raise RuntimeError(
f"Deletion of runs from folder '{folder_name}' failed"
f"Deletion of runs from folder '{folder_path}' failed"
f"with code {response.status_code}: {response.text}"
)

@prettify_pydantic
@pydantic.validate_call
def delete_folder(
self,
folder_name: str,
folder_path: typing.Annotated[str, pydantic.Field(pattern=FOLDER_REGEX)],
recursive: bool = False,
remove_runs: bool = False,
allow_missing: bool = False,
Expand All @@ -456,8 +477,8 @@ def delete_folder(

Parameters
----------
folder_name : str
name of the folder to delete
folder_path : str
name of the folder to delete. All paths are prefixed with `/`
recursive : bool, optional
if folder contains additional folders remove these, else return an
error. Default False.
Expand All @@ -477,14 +498,14 @@ def delete_folder(
RuntimeError
if deletion of the folder from the server failed
"""
folder_id = self._get_folder_id_from_path(folder_name)
folder_id = self._get_folder_id_from_path(folder_path)

if not folder_id:
if allow_missing:
return None
else:
raise RuntimeError(
f"Deletion of folder '{folder_name}' failed, "
f"Deletion of folder '{folder_path}' failed, "
"folder does not exist."
)

Expand All @@ -497,7 +518,7 @@ def delete_folder(

json_response = self._get_json_from_response(
expected_status=[200, 404],
scenario=f"Deletion of folder '{folder_name}'",
scenario=f"Deletion of folder '{folder_path}'",
response=response,
)

Expand All @@ -510,6 +531,8 @@ def delete_folder(
runs: list[dict] = json_response.get("runs", [])
return runs

@prettify_pydantic
@pydantic.validate_call
def list_artifacts(self, run_id: str) -> list[dict[str, typing.Any]]:
"""Retrieve artifacts for a given run

Expand Down Expand Up @@ -574,9 +597,11 @@ def _retrieve_artifact_from_server(

return json_response

@prettify_pydantic
@pydantic.validate_call
def get_artifact(
self, run_id: str, name: str, allow_pickle: bool = False
) -> typing.Optional[DeserializedContent]:
) -> typing.Any:
"""Return the contents of a specified artifact

Parameters
Expand Down Expand Up @@ -618,6 +643,8 @@ def get_artifact(

return content or response.content

@prettify_pydantic
@pydantic.validate_call
def get_artifact_as_file(
self, run_id: str, name: str, path: typing.Optional[str] = None
) -> None:
Expand Down Expand Up @@ -708,6 +735,8 @@ def _assemble_artifact_downloads(

return downloads

@prettify_pydantic
@pydantic.validate_call
def get_artifacts_as_files(
self,
run_id: str,
Expand Down Expand Up @@ -771,13 +800,18 @@ def get_artifacts_as_files(
f"failed with exception: {e}"
)

def get_folder(self, folder_id: str) -> typing.Optional[dict[str, typing.Any]]:
@prettify_pydantic
@pydantic.validate_call
def get_folder(
self, folder_path: typing.Annotated[str, pydantic.Field(pattern=FOLDER_REGEX)]
) -> typing.Optional[dict[str, typing.Any]]:
"""Retrieve a folder by identifier

Parameters
----------
folder_id : str
unique identifier for the folder
folder_path : str
the path of the folder to retrieve on the server.
Paths are prefixed with `/`

Returns
-------
Expand All @@ -789,15 +823,16 @@ def get_folder(self, folder_id: str) -> typing.Optional[dict[str, typing.Any]]:
RuntimeError
if there was a failure when retrieving information from the server
"""
if not (_folders := self.get_folders(filters=[f"path == {folder_id}"])):
if not (_folders := self.get_folders(filters=[f"path == {folder_path}"])):
return None
return _folders[0]

@pydantic.validate_call
def get_folders(
self,
filters: typing.Optional[list[str]] = None,
count: int = 100,
start_index: int = 0,
count: pydantic.PositiveInt = 100,
start_index: pydantic.NonNegativeInt = 0,
) -> list[dict[str, typing.Any]]:
"""Retrieve folders from the server

Expand Down Expand Up @@ -847,6 +882,8 @@ def get_folders(

return data

@prettify_pydantic
@pydantic.validate_call
def get_metrics_names(self, run_id: str) -> list[str]:
"""Return information on all metrics within a run

Expand Down Expand Up @@ -918,6 +955,8 @@ def _get_run_metrics_from_server(

return json_response

@prettify_pydantic
@pydantic.validate_call
def get_metric_values(
self,
metric_names: list[str],
Expand All @@ -927,8 +966,8 @@ def get_metric_values(
run_filters: typing.Optional[list[str]] = None,
use_run_names: bool = False,
aggregate: bool = False,
max_points: int = -1,
) -> typing.Union[dict, "DataFrame", None]:
max_points: typing.Optional[pydantic.PositiveInt] = None,
) -> typing.Union[dict, DataFrame, None]:
"""Retrieve the values for a given metric across multiple runs

Uses filters to specify which runs should be retrieved.
Expand All @@ -955,7 +994,7 @@ def get_metric_values(
return results as averages (not compatible with xaxis=timestamp),
default is False
max_points : int, optional
maximum number of data points, by default -1 (all)
maximum number of data points, by default None (all)

Returns
-------
Expand Down Expand Up @@ -1010,7 +1049,7 @@ def get_metric_values(
run_ids=run_ids,
xaxis=xaxis,
aggregate=aggregate,
max_points=max_points,
max_points=max_points or -1,
)

if aggregate:
Expand All @@ -1023,13 +1062,15 @@ def get_metric_values(
)

@check_extra("plot")
@prettify_pydantic
@pydantic.validate_call
def plot_metrics(
self,
run_ids: list[str],
metric_names: list[str],
xaxis: typing.Literal["step", "time"],
max_points: int = -1,
) -> "Figure":
max_points: typing.Optional[int] = None,
) -> typing.Any:
"""Plt the time series values for multiple metrics/runs

Parameters
Expand All @@ -1041,7 +1082,7 @@ def plot_metrics(
xaxis : str, ('step' | 'time' | 'timestep')
the x axis to plot against
max_points : int, optional
maximum number of data points, by default -1 (all)
maximum number of data points, by default None (all)

Returns
-------
Expand All @@ -1059,7 +1100,7 @@ def plot_metrics(
if not isinstance(metric_names, list):
raise ValueError("Invalid names specified, must be a list of metric names.")

data: "DataFrame" = self.get_metric_values( # type: ignore
data: DataFrame = self.get_metric_values( # type: ignore
run_ids=run_ids,
metric_names=metric_names,
xaxis=xaxis,
Expand Down Expand Up @@ -1099,12 +1140,14 @@ def plot_metrics(

return plt.figure()

@prettify_pydantic
@pydantic.validate_call
def get_events(
self,
run_id: str,
message_contains: typing.Optional[str] = None,
start_index: typing.Optional[int] = None,
count_limit: typing.Optional[int] = None,
start_index: typing.Optional[pydantic.NonNegativeInt] = None,
count_limit: typing.Optional[pydantic.PositiveInt] = None,
) -> list[dict[str, str]]:
"""Return events for a specified run

Expand Down Expand Up @@ -1160,6 +1203,8 @@ def get_events(

return response.json().get("data", [])

@prettify_pydantic
@pydantic.validate_call
def get_alerts(
self, run_id: str, critical_only: bool = True, names_only: bool = True
) -> list[dict[str, typing.Any]]:
Expand Down
Loading
Loading