diff --git a/simvue/client.py b/simvue/client.py
index e9cbff3a..019d5459 100644
--- a/simvue/client.py
+++ b/simvue/client.py
@@ -10,7 +10,9 @@
 import logging
 import os
 import typing
+import pydantic
 
 from concurrent.futures import ThreadPoolExecutor, as_completed
+from pandas import DataFrame
 
 import requests
@@ -21,11 +23,11 @@
 )
 from .serialization import deserialize_data
 from .types import DeserializedContent
-from .utilities import check_extra, get_auth
+from .utilities import check_extra, get_auth, prettify_pydantic
+from .models import FOLDER_REGEX, NAME_REGEX
 
 if typing.TYPE_CHECKING:
-    from matplotlib.figure import Figure
-    from pandas import DataFrame
+    pass
 
 CONCURRENT_DOWNLOADS = 10
 DOWNLOAD_CHUNK_SIZE = 8192
@@ -133,7 +135,11 @@ def _get_json_from_response(
         raise RuntimeError(error_str)
 
-    def get_run_id_from_name(self, name: str) -> str:
+    @prettify_pydantic
+    @pydantic.validate_call
+    def get_run_id_from_name(
+        self, name: typing.Annotated[str, pydantic.Field(pattern=NAME_REGEX)]
+    ) -> str:
         """Get Run ID from the server matching the specified name
 
         Assumes a unique name for this run. If multiple results are found this
@@ -186,6 +192,8 @@ def get_run_id_from_name(self, name: str) -> str:
             raise RuntimeError("Failed to retrieve identifier for run.")
         return first_id
 
+    @prettify_pydantic
+    @pydantic.validate_call
     def get_run(self, run_id: str) -> typing.Optional[dict[str, typing.Any]]:
         """Retrieve a single run
 
@@ -225,6 +233,8 @@ def get_run(self, run_id: str) -> typing.Optional[dict[str, typing.Any]]:
         )
         return json_response
 
+    @prettify_pydantic
+    @pydantic.validate_call
     def get_run_name_from_id(self, run_id: str) -> str:
         """Retrieve the name of a run from its identifier
 
@@ -250,6 +260,8 @@ def get_run_name_from_id(self, run_id: str) -> str:
             raise RuntimeError("Expected key 'name' in server response")
         return _name
 
+    @prettify_pydantic
+    @pydantic.validate_call
     def get_runs(
         self,
         filters: typing.Optional[list[str]],
@@ -261,7 +273,7 @@ def get_runs(
         count: int = 100,
         start_index: int = 0,
     ) -> typing.Union[
-        "DataFrame", list[dict[str, typing.Union[int, str, float, None]]], None
+        DataFrame, list[dict[str, typing.Union[int, str, float, None]]], None
     ]:
         """Retrieve all runs matching filters.
 
@@ -337,6 +349,8 @@ def get_runs(
         else:
             raise RuntimeError("Failed to retrieve runs data")
 
+    @prettify_pydantic
+    @pydantic.validate_call
     def delete_run(self, run_identifier: str) -> typing.Optional[dict]:
         """Delete run by identifier
 
@@ -404,13 +418,18 @@ def _get_folder_id_from_path(self, path: str) -> typing.Optional[str]:
 
         return None
 
-    def delete_runs(self, folder_name: str) -> typing.Optional[list]:
+    @prettify_pydantic
+    @pydantic.validate_call
+    def delete_runs(
+        self, folder_path: typing.Annotated[str, pydantic.Field(pattern=FOLDER_REGEX)]
+    ) -> typing.Optional[list]:
         """Delete runs in a named folder
 
         Parameters
         ----------
-        folder_name : str
-            the name of the folder on which to perform deletion
+        folder_path : str
+            the path of the folder on which to perform deletion. All folder
+            paths are prefixed with `/`
 
         Returns
         -------
@@ -422,10 +441,10 @@ def delete_runs(self, folder_name: str) -> typing.Optional[list]:
         RuntimeError
             if deletion fails due to server request error
         """
-        folder_id = self._get_folder_id_from_path(folder_name)
+        folder_id = self._get_folder_id_from_path(folder_path)
 
         if not folder_id:
-            raise ValueError(f"Could not find a folder matching '{folder_name}'")
+            raise ValueError(f"Could not find a folder matching '{folder_path}'")
 
         params: dict[str, bool] = {"runs_only": True, "runs": True}
 
@@ -435,19 +454,21 @@ def delete_runs(self, folder_name: str) -> typing.Optional[list]:
 
         if response.status_code == 200:
             if runs := response.json().get("runs", []):
-                logger.debug(f"Runs from '{folder_name}' deleted successfully: {runs}")
+                logger.debug(f"Runs from '{folder_path}' deleted successfully: {runs}")
             else:
                 logger.debug("Folder empty, no runs deleted.")
             return runs
 
         raise RuntimeError(
-            f"Deletion of runs from folder '{folder_name}' failed"
+            f"Deletion of runs from folder '{folder_path}' failed "
             f"with code {response.status_code}: {response.text}"
         )
 
+    @prettify_pydantic
+    @pydantic.validate_call
     def delete_folder(
         self,
-        folder_name: str,
+        folder_path: typing.Annotated[str, pydantic.Field(pattern=FOLDER_REGEX)],
         recursive: bool = False,
         remove_runs: bool = False,
         allow_missing: bool = False,
@@ -456,8 +477,8 @@ def delete_folder(
 
         Parameters
         ----------
-        folder_name : str
-            name of the folder to delete
+        folder_path : str
+            path of the folder to delete. All paths are prefixed with `/`
         recursive : bool, optional
             if folder contains additional folders remove these, else return an
             error. Default False.
@@ -477,14 +498,14 @@ def delete_folder(
         RuntimeError
             if deletion of the folder from the server failed
         """
-        folder_id = self._get_folder_id_from_path(folder_name)
+        folder_id = self._get_folder_id_from_path(folder_path)
 
         if not folder_id:
             if allow_missing:
                 return None
             else:
                 raise RuntimeError(
-                    f"Deletion of folder '{folder_name}' failed, "
+                    f"Deletion of folder '{folder_path}' failed, "
                     "folder does not exist."
                 )
 
@@ -497,7 +518,7 @@ def delete_folder(
         json_response = self._get_json_from_response(
             expected_status=[200, 404],
-            scenario=f"Deletion of folder '{folder_name}'",
+            scenario=f"Deletion of folder '{folder_path}'",
             response=response,
         )
 
@@ -510,6 +531,8 @@ def delete_folder(
         runs: list[dict] = json_response.get("runs", [])
         return runs
 
+    @prettify_pydantic
+    @pydantic.validate_call
     def list_artifacts(self, run_id: str) -> list[dict[str, typing.Any]]:
         """Retrieve artifacts for a given run
 
@@ -574,9 +597,11 @@ def _retrieve_artifact_from_server(
 
         return json_response
 
+    @prettify_pydantic
+    @pydantic.validate_call
     def get_artifact(
         self, run_id: str, name: str, allow_pickle: bool = False
-    ) -> typing.Optional[DeserializedContent]:
+    ) -> typing.Any:
         """Return the contents of a specified artifact
 
         Parameters
@@ -618,6 +643,8 @@ def get_artifact(
 
         return content or response.content
 
+    @prettify_pydantic
+    @pydantic.validate_call
     def get_artifact_as_file(
         self, run_id: str, name: str, path: typing.Optional[str] = None
     ) -> None:
@@ -708,6 +735,8 @@ def _assemble_artifact_downloads(
 
         return downloads
 
+    @prettify_pydantic
+    @pydantic.validate_call
     def get_artifacts_as_files(
         self,
         run_id: str,
@@ -771,13 +800,18 @@ def get_artifacts_as_files(
                 f"failed with exception: {e}"
             )
 
-    def get_folder(self, folder_id: str) -> typing.Optional[dict[str, typing.Any]]:
+    @prettify_pydantic
+    @pydantic.validate_call
+    def get_folder(
+        self, folder_path: typing.Annotated[str, pydantic.Field(pattern=FOLDER_REGEX)]
+    ) -> typing.Optional[dict[str, typing.Any]]:
         """Retrieve a folder by identifier
 
         Parameters
         ----------
-        folder_id : str
-            unique identifier for the folder
+        folder_path : str
+            the path of the folder to retrieve on the server.
+            Paths are prefixed with `/`
 
         Returns
         -------
@@ -789,15 +823,16 @@ def get_folder(self, folder_id: str) -> typing.Optional[dict[str, typing.Any]]:
         RuntimeError
             if there was a failure when retrieving information from the server
         """
-        if not (_folders := self.get_folders(filters=[f"path == {folder_id}"])):
+        if not (_folders := self.get_folders(filters=[f"path == {folder_path}"])):
             return None
         return _folders[0]
 
+    @pydantic.validate_call
     def get_folders(
         self,
         filters: typing.Optional[list[str]] = None,
-        count: int = 100,
-        start_index: int = 0,
+        count: pydantic.PositiveInt = 100,
+        start_index: pydantic.NonNegativeInt = 0,
     ) -> list[dict[str, typing.Any]]:
         """Retrieve folders from the server
 
@@ -847,6 +882,8 @@ def get_folders(
 
         return data
 
+    @prettify_pydantic
+    @pydantic.validate_call
     def get_metrics_names(self, run_id: str) -> list[str]:
         """Return information on all metrics within a run
 
@@ -918,6 +955,8 @@ def _get_run_metrics_from_server(
 
         return json_response
 
+    @prettify_pydantic
+    @pydantic.validate_call
     def get_metric_values(
         self,
         metric_names: list[str],
@@ -927,8 +966,8 @@ def get_metric_values(
         run_filters: typing.Optional[list[str]] = None,
         use_run_names: bool = False,
         aggregate: bool = False,
-        max_points: int = -1,
-    ) -> typing.Union[dict, "DataFrame", None]:
+        max_points: typing.Optional[pydantic.PositiveInt] = None,
+    ) -> typing.Union[dict, DataFrame, None]:
         """Retrieve the values for a given metric across multiple runs
 
         Uses filters to specify which runs should be retrieved.
@@ -955,7 +994,7 @@ def get_metric_values(
             return results as averages (not compatible with xaxis=timestamp),
             default is False
         max_points : int, optional
-            maximum number of data points, by default -1 (all)
+            maximum number of data points, by default None (all)
 
         Returns
         -------
@@ -1010,7 +1049,7 @@ def get_metric_values(
             run_ids=run_ids,
             xaxis=xaxis,
             aggregate=aggregate,
-            max_points=max_points,
+            max_points=max_points or -1,
         )
 
         if aggregate:
@@ -1023,13 +1062,15 @@ def get_metric_values(
         )
 
     @check_extra("plot")
+    @prettify_pydantic
+    @pydantic.validate_call
     def plot_metrics(
         self,
         run_ids: list[str],
         metric_names: list[str],
         xaxis: typing.Literal["step", "time"],
-        max_points: int = -1,
-    ) -> "Figure":
+        max_points: typing.Optional[int] = None,
+    ) -> typing.Any:
         """Plt the time series values for multiple metrics/runs
 
         Parameters
@@ -1041,7 +1082,7 @@ def plot_metrics(
         xaxis : str, ('step' | 'time' | 'timestep')
             the x axis to plot against
         max_points : int, optional
-            maximum number of data points, by default -1 (all)
+            maximum number of data points, by default None (all)
 
         Returns
         -------
@@ -1059,7 +1100,7 @@ def plot_metrics(
         if not isinstance(metric_names, list):
             raise ValueError("Invalid names specified, must be a list of metric names.")
 
-        data: "DataFrame" = self.get_metric_values(  # type: ignore
+        data: DataFrame = self.get_metric_values(  # type: ignore
             run_ids=run_ids,
             metric_names=metric_names,
             xaxis=xaxis,
@@ -1099,12 +1140,14 @@ def plot_metrics(
 
         return plt.figure()
 
+    @prettify_pydantic
+    @pydantic.validate_call
     def get_events(
         self,
         run_id: str,
         message_contains: typing.Optional[str] = None,
-        start_index: typing.Optional[int] = None,
-        count_limit: typing.Optional[int] = None,
+        start_index: typing.Optional[pydantic.NonNegativeInt] = None,
+        count_limit: typing.Optional[pydantic.PositiveInt] = None,
     ) -> list[dict[str, str]]:
         """Return events for a specified run
 
@@ -1160,6 +1203,8 @@ def get_events(
 
         return response.json().get("data", [])
 
+    @prettify_pydantic
+    @pydantic.validate_call
     def get_alerts(
         self, run_id: str, critical_only: bool = True, names_only: bool = True
     ) -> list[dict[str, typing.Any]]:
diff --git a/simvue/run.py b/simvue/run.py
index 52ac4d81..ef6722cc 100644
--- a/simvue/run.py
+++ b/simvue/run.py
@@ -37,7 +37,7 @@
 from .executor import Executor
 from .factory.proxy import Simvue
 from .metrics import get_gpu_metrics, get_process_cpu, get_process_memory
-from .models import RunInput
+from .models import RunInput, FOLDER_REGEX, NAME_REGEX
 from .serialization import serialize_object
 from .system import get_system
 from .metadata import git_info
@@ -449,11 +449,13 @@ def _error(self, message: str, join_threads: bool = True) -> None:
     @pydantic.validate_call
     def init(
         self,
-        name: typing.Optional[str] = None,
+        name: typing.Optional[
+            typing.Annotated[str, pydantic.Field(pattern=NAME_REGEX)]
+        ] = None,
         metadata: typing.Optional[dict[str, typing.Any]] = None,
         tags: typing.Optional[list[str]] = None,
         description: typing.Optional[str] = None,
-        folder: str = "/",
+        folder: typing.Annotated[str, pydantic.Field(pattern=FOLDER_REGEX)] = "/",
         running: bool = True,
         retention_period: typing.Optional[str] = None,
         resources_metrics_interval: typing.Optional[int] = HEARTBEAT_INTERVAL,
@@ -1027,7 +1029,9 @@ def save_object(
         self,
         obj: typing.Any,
         category: typing.Literal["input", "output", "code"],
-        name: typing.Optional[str] = None,
+        name: typing.Optional[
+            typing.Annotated[str, pydantic.Field(pattern=NAME_REGEX)]
+        ] = None,
         allow_pickle: bool = False,
     ) -> bool:
         """Save an object to the Simvue server
@@ -1084,7 +1088,9 @@ def save_file(
         category: typing.Literal["input", "output", "code"],
         filetype: typing.Optional[str] = None,
         preserve_path: bool = False,
-        name: typing.Optional[str] = None,
+        name: typing.Optional[
+            typing.Annotated[str, pydantic.Field(pattern=NAME_REGEX)]
+        ] = None,
     ) -> bool:
         """Upload file to the server
 
@@ -1345,7 +1351,7 @@ def close(self) -> bool:
     @pydantic.validate_call
     def set_folder_details(
         self,
-        path: str,
+        path: typing.Annotated[str, pydantic.Field(pattern=FOLDER_REGEX)],
         metadata: typing.Optional[dict[str, typing.Union[int, str, float]]] = None,
         tags: typing.Optional[list[str]] = None,
         description: typing.Optional[str] = None,
@@ -1448,7 +1454,7 @@ def add_alerts(
     @pydantic.validate_call
     def create_alert(
         self,
-        name: str,
+        name: typing.Annotated[str, pydantic.Field(pattern=NAME_REGEX)],
         source: typing.Literal["events", "metrics", "user"] = "metrics",
         description: typing.Optional[str] = None,
         frequency: typing.Optional[pydantic.PositiveInt] = None,
diff --git a/simvue/utilities.py b/simvue/utilities.py
index b3c3fe57..c0b60b96 100644
--- a/simvue/utilities.py
+++ b/simvue/utilities.py
@@ -102,6 +102,18 @@ def wrapper(self, *args, **kwargs) -> typing.Any:
     return decorator
 
 
+def parse_pydantic_error(class_name: str, error: pydantic.ValidationError) -> str:
+    out_table: list[str] = []
+    for data in json.loads(error.json()):
+        out_table.append([data["loc"], data["type"], data["msg"]])
+    err_table = tabulate.tabulate(
+        out_table,
+        headers=["Location", "Type", "Message"],
+        tablefmt="fancy_grid",
+    )
+    return f"`{class_name}` Validation:\n{err_table}"
+
+
 def skip_if_failed(
     failure_attr: str,
     ignore_exc_attr: str,
@@ -146,20 +158,12 @@ def wrapper(self, *args, **kwargs) -> typing.Any:
             try:
                 return class_func(self, *args, **kwargs)
             except pydantic.ValidationError as e:
-                out_table: list[str] = []
-                for data in json.loads(e.json()):
-                    out_table.append([data["loc"], data["type"], data["msg"]])
-                err_table = tabulate.tabulate(
-                    out_table,
-                    headers=["Location", "Type", "Message"],
-                    tablefmt="fancy_grid",
-                )
-                err_str = f"`{class_func.__name__}` Validation:\n{err_table}"
+                error_str = parse_pydantic_error(class_func.__name__, e)
                 if getattr(self, ignore_exc_attr, True):
                     setattr(self, failure_attr, True)
-                    logger.error(err_str)
+                    logger.error(error_str)
                     return on_failure_return
-                raise RuntimeError(err_str)
+                raise RuntimeError(error_str)
 
         setattr(wrapper, "__fail_safe", True)
         return wrapper
@@ -167,6 +171,36 @@ def wrapper(self, *args, **kwargs) -> typing.Any:
     return decorator
 
 
+def prettify_pydantic(class_func: typing.Callable) -> typing.Callable:
+    """Converts pydantic validation errors to a table
+
+    Parameters
+    ----------
+    class_func : typing.Callable
+        function to wrap
+
+    Returns
+    -------
+    typing.Callable
+        wrapped function
+
+    Raises
+    ------
+    RuntimeError
+        the formatted validation error
+    """
+
+    @functools.wraps(class_func)
+    def wrapper(self, *args, **kwargs) -> typing.Any:
+        try:
+            return class_func(self, *args, **kwargs)
+        except pydantic.ValidationError as e:
+            error_str = parse_pydantic_error(class_func.__name__, e)
+            raise RuntimeError(error_str)
+
+    return wrapper
+
+
 def get_auth():
     """
     Get the URL and access token
diff --git a/tests/refactor/test_client.py b/tests/refactor/test_client.py
index fcf943cd..76e7a1ac 100644
--- a/tests/refactor/test_client.py
+++ b/tests/refactor/test_client.py
@@ -145,7 +145,7 @@ def test_get_run(create_test_run: tuple[sv_run.Run, dict]) -> None:
 def test_get_folder(create_test_run: tuple[sv_run.Run, dict]) -> None:
     client = svc.Client()
     assert (folders := client.get_folders())
-    assert (folder_id := folders[0].get("id"))
+    assert (folder_id := folders[0].get("path"))
     assert client.get_folder(folder_id)
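
Note (illustrative, not part of the patch): with `prettify_pydantic` stacked on top of `pydantic.validate_call`, invalid arguments passed to the decorated `Client` methods should surface as a `RuntimeError` carrying the tabulated validation report rather than a raw `pydantic.ValidationError`. A minimal sketch of the expected behaviour, assuming a configured Simvue server and that a folder path without the leading `/` violates FOLDER_REGEX:

    # Hypothetical usage sketch -- not part of this diff.
    from simvue.client import Client

    client = Client()
    try:
        # FOLDER_REGEX expects paths such as "/my/folder", so validation
        # should fail here before any request reaches the server.
        client.delete_runs("missing-leading-slash")
    except RuntimeError as err:
        # prettify_pydantic converts the pydantic.ValidationError raised by
        # validate_call into a RuntimeError whose message is the
        # tabulate-formatted table built by parse_pydantic_error.
        print(err)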