From af7d57661d4daf2aff3fbe77613a69b904d64210 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristian=20Zar=C4=99bski?= Date: Tue, 21 May 2024 09:31:06 +0100 Subject: [PATCH 1/9] Refactor serialization and fix tuple without comma bug --- simvue/client.py | 4 +- simvue/run.py | 184 +++++++++--------- simvue/serialization.py | 118 ++++++----- tests/functional/common.py | 7 +- tests/functional/test_artifacts_code.py | 5 +- .../functional/test_artifacts_code_created.py | 2 +- tests/functional/test_artifacts_input.py | 2 +- .../test_artifacts_input_created.py | 2 +- tests/functional/test_artifacts_output.py | 2 +- .../test_artifacts_output_created.py | 2 +- .../functional/test_offline_artifacts_code.py | 2 +- .../test_offline_artifacts_code_created.py | 2 +- .../test_offline_artifacts_input.py | 2 +- .../test_offline_artifacts_input_created.py | 2 +- .../test_offline_artifacts_output.py | 2 +- tests/refactor/conftest.py | 8 +- tests/refactor/test_run_class.py | 62 +++++- .../unit/test_matplotlib_figure_mime_type.py | 4 +- tests/unit/test_numpy_array_mime_type.py | 4 +- tests/unit/test_numpy_array_serialization.py | 6 +- tests/unit/test_pandas_dataframe_mimetype.py | 4 +- .../test_pandas_dataframe_serialization.py | 6 +- tests/unit/test_pickle_serialization.py | 7 +- tests/unit/test_plotly_figure_mime_type.py | 4 +- tests/unit/test_pytorch_tensor_mime_type.py | 4 +- .../unit/test_pytorch_tensor_serialization.py | 6 +- tests/unit/test_run_init_folder.py | 1 - tests/unit/test_run_init_metadata.py | 1 - tests/unit/test_run_init_tags.py | 1 - 29 files changed, 261 insertions(+), 195 deletions(-) diff --git a/simvue/client.py b/simvue/client.py index b4581814..d9ca220d 100644 --- a/simvue/client.py +++ b/simvue/client.py @@ -19,7 +19,7 @@ to_dataframe, parse_run_set_metrics, ) -from .serialization import Deserializer +from .serialization import deserialize_data from .types import DeserializedContent from .utilities import check_extra, get_auth @@ -608,7 +608,7 @@ def get_artifact( response = requests.get(url, timeout=DOWNLOAD_TIMEOUT) response.raise_for_status() - content: typing.Optional[DeserializedContent] = Deserializer().deserialize( + content: typing.Optional[DeserializedContent] = deserialize_data( response.content, mimetype, allow_pickle ) diff --git a/simvue/run.py b/simvue/run.py index b9ad2ed4..59142d9f 100644 --- a/simvue/run.py +++ b/simvue/run.py @@ -37,7 +37,7 @@ from .factory.proxy import Simvue from .metrics import get_gpu_metrics, get_process_cpu, get_process_memory from .models import RunInput -from .serialization import Serializer +from .serialization import serialize_object from .system import get_system from .utilities import ( calculate_sha256, @@ -51,6 +51,7 @@ if typing.TYPE_CHECKING: from .factory.proxy import SimvueBaseClass from .factory.dispatch import DispatcherBaseClass + from .types import DeserializedContent UPLOAD_TIMEOUT: int = 30 HEARTBEAT_INTERVAL: int = 60 @@ -141,7 +142,7 @@ def __exit__( else: if self._active: self.log_event(f"{exc_type.__name__}: {value}") - if exc_type.__name__ in ("KeyboardInterrupt") and self._active: + if exc_type.__name__ in ("KeyboardInterrupt",) and self._active: self.set_status("terminated") else: if traceback and self._active: @@ -930,17 +931,70 @@ def log_metrics( @skip_if_failed("_aborted", "_suppress_errors", False) @pydantic.validate_call - def save( + def save_object( self, - filename: str, + obj: typing.Any, + category: typing.Literal["input", "output", "code"], + name: typing.Optional[str] = None, + allow_pickle: bool = False, + ) -> bool: + obj: DeserializedContent + serialized = serialize_object(obj, allow_pickle) + + if not serialized or not (pickled := serialized[0]): + self._error(f"Failed to serialize '{obj}'") + return False + + data_type = serialized[1] + + if not data_type and not allow_pickle: + self._error("Unable to save Python object, set allow_pickle to True") + return False + + data: dict[str, typing.Any] = { + "pickled": pickled, + "type": data_type, + "checksum": calculate_sha256(pickled, False), + "originalPath": "", + "size": sys.getsizeof(pickled), + "name": name, + "run": self._name, + "category": category, + "storage": self._storage_id, + } + + # Register file + return self._simvue is not None and self._simvue.save_file(data) is not None + + @skip_if_failed("_aborted", "_suppress_errors", False) + @pydantic.validate_call + def save_file( + self, + filename: pydantic.FilePath, category: typing.Literal["input", "output", "code"], filetype: typing.Optional[str] = None, preserve_path: bool = False, name: typing.Optional[str] = None, - allow_pickle: bool = False, ) -> bool: - """ - Upload file or object + """Upload file to the server + + Parameters + ---------- + filename : pydantic.FilePath + path to the file to upload + category : Literal['input', 'output', 'code'] + category of file with respect to this run + filetype : str, optional + the MIME file type else this is deduced, by default None + preserve_path : bool, optional + whether to preserve the path during storage, by default False + name : str, optional + name to associate with this file, by default None + + Returns + ------- + bool + whether the upload was successful """ if self._mode == "disabled": return True @@ -953,96 +1007,48 @@ def save( self._error("Cannot upload output files for runs in the created state") return False - is_file: bool = False - - if isinstance(filename, str): - if not os.path.isfile(filename): - self._error(f"File {filename} does not exist") - return False - else: - is_file = True - - if filetype: - mimetypes_valid = ["application/vnd.plotly.v1+json"] - mimetypes.init() - for _, value in mimetypes.types_map.items(): - mimetypes_valid.append(value) - - if filetype not in mimetypes_valid: - self._error("Invalid MIME type specified") - return False - - data: dict[str, typing.Any] = {} - - if preserve_path: - data["name"] = filename - if data["name"].startswith("./"): - data["name"] = data["name"][2:] - elif is_file: - data["name"] = os.path.basename(filename) + mimetypes.init() + mimetypes_valid = ["application/vnd.plotly.v1+json"] + mimetypes_valid += list(mimetypes.types_map.values()) - if name: - data["name"] = name - - data["run"] = self._name - data["category"] = category + if filetype and filetype not in mimetypes_valid: + self._error(f"Invalid MIME type '{filetype}' specified") + return False - if is_file: - data["size"] = os.path.getsize(filename) - data["originalPath"] = os.path.abspath( - os.path.expanduser(os.path.expandvars(filename)) - ) - data["checksum"] = calculate_sha256(filename, is_file) + stored_file_name: str = f"{filename}" - if data["size"] == 0: - click.secho( - "WARNING: saving zero-sized files not currently supported", - bold=True, - fg="yellow", - ) - return True + if preserve_path and stored_file_name.startswith("./"): + stored_file_name = stored_file_name[2:] + elif not preserve_path: + stored_file_name = os.path.basename(filename) # Determine mimetype - mimetype = None - if not filetype and is_file: - mimetypes.init() - mimetype = mimetypes.guess_type(filename)[0] - if not mimetype: - mimetype = "application/octet-stream" - elif is_file: - mimetype = filetype - - if mimetype: - data["type"] = mimetype + if not (mimetype := filetype): + mimetype = mimetypes.guess_type(filename)[0] or "application/octet-stream" - if not is_file: - serialized = Serializer().serialize(filename, allow_pickle) - - if not serialized or not (pickled := serialized[0]): - self._error(f"Failed to serialize '{filename}'") - return False - - data_type = serialized[1] - - data["pickled"] = pickled - data["type"] = data_type - - if not data["type"] and not allow_pickle: - self._error("Unable to save Python object, set allow_pickle to True") - return False - - data["checksum"] = calculate_sha256(pickled, False) - data["originalPath"] = "" - data["size"] = sys.getsizeof(pickled) + data: dict[str, typing.Any] = { + "name": name or stored_file_name, + "run": self._name, + "type": mimetype, + "storage": self._storage_id, + "category": category, + "size": (file_size := os.path.getsize(filename)), + "originalPath": os.path.abspath( + os.path.expanduser(os.path.expandvars(filename)) + ), + "checksum": calculate_sha256(f"{filename}", True), + } - if self._storage_id: - data["storage"] = self._storage_id + if not file_size: + click.secho( + "WARNING: saving zero-sized files not currently supported", + bold=True, + fg="yellow", + ) + return True # Register file - if not self._simvue.save_file(data): - return False - - return True + return self._simvue.save_file(data) is not None @skip_if_failed("_aborted", "_suppress_errors", False) @pydantic.validate_call @@ -1076,7 +1082,7 @@ def save_directory( for dirpath, _, filenames in directory.walk(): for filename in filenames: if (full_path := dirpath.joinpath(filename)).is_file(): - self.save(f"{full_path}", category, filetype, preserve_path) + self.save_file(full_path, category, filetype, preserve_path) return True diff --git a/simvue/serialization.py b/simvue/serialization.py index 74c4621e..80e14fa3 100644 --- a/simvue/serialization.py +++ b/simvue/serialization.py @@ -1,20 +1,27 @@ +""" +Object Serialization +==================== + +Contains serializers for storage of objects on the Simvue server +""" + +import typing import pickle from io import BytesIO -from .utilities import check_extra - +if typing.TYPE_CHECKING: + from pandas import DataFrame + from plotly.graph_objects import Figure + from torch import Tensor + from typing_extensions import Buffer + from .types import DeserializedContent -class Serializer: - def serialize(self, data, allow_pickle=False): - serializer = get_serializer(data, allow_pickle) - if serializer: - return serializer(data) - return None, None +from .utilities import check_extra -def _is_torch_tensor(data): +def _is_torch_tensor(data: typing.Any) -> bool: """ - Check if a dictionary is a PyTorch tensor or state dict + Check if value is a PyTorch tensor or state dict """ module_name = data.__class__.__module__ class_name = data.__class__.__name__ @@ -34,50 +41,63 @@ def _is_torch_tensor(data): return False -def get_serializer(data, allow_pickle): - """ - Determine which serializer to use +def serialize_object( + data: typing.Any, allow_pickle: bool +) -> typing.Optional[tuple[str, str]]: + """Determine which serializer to use for the given object + + Parameters + ---------- + data : typing.Any + object to serialize + allow_pickle : bool + whether pickling is allowed + + Returns + ------- + Callable[[typing.Any], tuple[str, str]] + the serializer to user """ module_name = data.__class__.__module__ class_name = data.__class__.__name__ if module_name == "plotly.graph_objs._figure" and class_name == "Figure": - return _serialize_plotly_figure + return _serialize_plotly_figure(data) elif module_name == "matplotlib.figure" and class_name == "Figure": - return _serialize_matplotlib_figure + return _serialize_matplotlib_figure(data) elif module_name == "numpy" and class_name == "ndarray": - return _serialize_numpy_array + return _serialize_numpy_array(data) elif module_name == "pandas.core.frame" and class_name == "DataFrame": - return _serialize_dataframe + return _serialize_dataframe(data) elif _is_torch_tensor(data): - return _serialize_torch_tensor + return _serialize_torch_tensor(data) elif module_name == "builtins" and class_name == "module" and not allow_pickle: try: - import matplotlib + import matplotlib.pyplot if data == matplotlib.pyplot: - return _serialize_matplotlib + return _serialize_matplotlib(data) except ImportError: pass if allow_pickle: - return _serialize_pickle + return _serialize_pickle(data) return None @check_extra("plot") -def _serialize_plotly_figure(data): +def _serialize_plotly_figure(data: typing.Any) -> typing.Optional[tuple[str, str]]: try: import plotly except ImportError: - return + return None mimetype = "application/vnd.plotly.v1+json" data = plotly.io.to_json(data, "json") return data, mimetype @check_extra("plot") -def _serialize_matplotlib(data): +def _serialize_matplotlib(data: typing.Any) -> typing.Optional[tuple[str, str]]: try: import plotly except ImportError: @@ -88,7 +108,7 @@ def _serialize_matplotlib(data): @check_extra("plot") -def _serialize_matplotlib_figure(data): +def _serialize_matplotlib_figure(data: typing.Any) -> typing.Optional[tuple[str, str]]: try: import plotly except ImportError: @@ -99,7 +119,7 @@ def _serialize_matplotlib_figure(data): @check_extra("dataset") -def _serialize_numpy_array(data): +def _serialize_numpy_array(data: typing.Any) -> typing.Optional[tuple[str, str]]: try: import numpy as np except ImportError: @@ -115,7 +135,7 @@ def _serialize_numpy_array(data): @check_extra("dataset") -def _serialize_dataframe(data): +def _serialize_dataframe(data: typing.Any) -> typing.Optional[tuple[str, str]]: mimetype = "application/vnd.simvue.df.v1" mfile = BytesIO() data.to_csv(mfile) @@ -125,7 +145,7 @@ def _serialize_dataframe(data): @check_extra("torch") -def _serialize_torch_tensor(data): +def _serialize_torch_tensor(data: typing.Any) -> typing.Optional[tuple[str, str]]: try: import torch except ImportError: @@ -140,41 +160,35 @@ def _serialize_torch_tensor(data): return data, mimetype -def _serialize_pickle(data): +def _serialize_pickle(data: typing.Any) -> typing.Optional[tuple[str, str]]: mimetype = "application/octet-stream" data = pickle.dumps(data) return data, mimetype -class Deserializer: - def deserialize(self, data, mimetype, allow_pickle=False): - deserializer = get_deserializer(mimetype, allow_pickle) - if deserializer: - return deserializer(data) - return None - - -def get_deserializer(mimetype, allow_pickle): +def deserialize_data( + data: "Buffer", mimetype: str, allow_pickle: bool +) -> typing.Optional["DeserializedContent"]: """ Determine which deserializer to use """ if mimetype == "application/vnd.plotly.v1+json": - return _deserialize_plotly_figure + return _deserialize_plotly_figure(data) elif mimetype == "application/vnd.plotly.v1+json": - return _deserialize_matplotlib_figure + return _deserialize_matplotlib_figure(data) elif mimetype == "application/vnd.simvue.numpy.v1": - return _deserialize_numpy_array + return _deserialize_numpy_array(data) elif mimetype == "application/vnd.simvue.df.v1": - return _deserialize_dataframe + return _deserialize_dataframe(data) elif mimetype == "application/vnd.simvue.torch.v1": - return _deserialize_torch_tensor + return _deserialize_torch_tensor(data) elif mimetype == "application/octet-stream" and allow_pickle: - return _deserialize_pickle + return _deserialize_pickle(data) return None @check_extra("plot") -def _deserialize_plotly_figure(data): +def _deserialize_plotly_figure(data: "Buffer") -> typing.Optional["Figure"]: try: import plotly except ImportError: @@ -184,7 +198,7 @@ def _deserialize_plotly_figure(data): @check_extra("plot") -def _deserialize_matplotlib_figure(data): +def _deserialize_matplotlib_figure(data: "Buffer") -> typing.Optional["Figure"]: try: import plotly except ImportError: @@ -194,7 +208,7 @@ def _deserialize_matplotlib_figure(data): @check_extra("dataset") -def _deserialize_numpy_array(data): +def _deserialize_numpy_array(data: "Buffer") -> typing.Optional[typing.Any]: try: import numpy as np except ImportError: @@ -208,7 +222,7 @@ def _deserialize_numpy_array(data): @check_extra("dataset") -def _deserialize_dataframe(data): +def _deserialize_dataframe(data: "Buffer") -> typing.Optional["DataFrame"]: try: import pandas as pd except ImportError: @@ -217,12 +231,11 @@ def _deserialize_dataframe(data): mfile = BytesIO(data) mfile.seek(0) - data = pd.read_csv(mfile, index_col=0) - return data + return pd.read_csv(mfile, index_col=0) @check_extra("torch") -def _deserialize_torch_tensor(data): +def _deserialize_torch_tensor(data: "Buffer") -> typing.Optional["Tensor"]: try: import torch except ImportError: @@ -231,8 +244,7 @@ def _deserialize_torch_tensor(data): mfile = BytesIO(data) mfile.seek(0) - data = torch.load(mfile) - return data + return torch.load(mfile) def _deserialize_pickle(data): diff --git a/tests/functional/common.py b/tests/functional/common.py index 5420189d..4fb50ae0 100644 --- a/tests/functional/common.py +++ b/tests/functional/common.py @@ -1,4 +1,5 @@ import configparser +import pathlib import os import uuid @@ -20,9 +21,9 @@ def update_config(): config.write(configfile) FOLDER = '/test-%s' % str(uuid.uuid4()) -FILENAME1 = str(uuid.uuid4()) -FILENAME2 = str(uuid.uuid4()) -FILENAME3 = str(uuid.uuid4()) +FILENAME1 = pathlib.Path(str(uuid.uuid4())) +FILENAME2 = pathlib.Path(str(uuid.uuid4())) +FILENAME3 = pathlib.Path(str(uuid.uuid4())) RUNNAME1 = 'test-%s' % str(uuid.uuid4()) RUNNAME2 = 'test-%s' % str(uuid.uuid4()) RUNNAME3 = 'test-%s' % str(uuid.uuid4()) diff --git a/tests/functional/test_artifacts_code.py b/tests/functional/test_artifacts_code.py index 04f6a3e0..6ea0f220 100644 --- a/tests/functional/test_artifacts_code.py +++ b/tests/functional/test_artifacts_code.py @@ -1,12 +1,9 @@ -import configparser import filecmp import os import shutil -import time import unittest import uuid from simvue import Run, Client -from simvue.sender import sender import common @@ -22,7 +19,7 @@ def test_artifact_code(self): content = str(uuid.uuid4()) with open(common.FILENAME1, 'w') as fh: fh.write(content) - run.save(common.FILENAME1, 'code') + run.save_file(common.FILENAME1, 'code') run.close() diff --git a/tests/functional/test_artifacts_code_created.py b/tests/functional/test_artifacts_code_created.py index 0a03e3cb..667e2820 100644 --- a/tests/functional/test_artifacts_code_created.py +++ b/tests/functional/test_artifacts_code_created.py @@ -23,7 +23,7 @@ def test_artifact_code_created(self): content = str(uuid.uuid4()) with open(common.FILENAME1, 'w') as fh: fh.write(content) - run.save(common.FILENAME1, 'code') + run.save_file(common.FILENAME1, 'code') shutil.rmtree('./test', ignore_errors=True) os.mkdir('./test') diff --git a/tests/functional/test_artifacts_input.py b/tests/functional/test_artifacts_input.py index 841c0321..bde1a9e9 100644 --- a/tests/functional/test_artifacts_input.py +++ b/tests/functional/test_artifacts_input.py @@ -22,7 +22,7 @@ def test_artifact_input(self): content = str(uuid.uuid4()) with open(common.FILENAME2, 'w') as fh: fh.write(content) - run.save(common.FILENAME2, 'input') + run.save_file(common.FILENAME2, 'input') run.close() diff --git a/tests/functional/test_artifacts_input_created.py b/tests/functional/test_artifacts_input_created.py index e990ca0c..09dc0510 100644 --- a/tests/functional/test_artifacts_input_created.py +++ b/tests/functional/test_artifacts_input_created.py @@ -23,7 +23,7 @@ def test_artifact_input_created(self): content = str(uuid.uuid4()) with open(common.FILENAME2, 'w') as fh: fh.write(content) - run.save(common.FILENAME2, 'input') + run.save_file(common.FILENAME2, 'input') shutil.rmtree('./test', ignore_errors=True) os.mkdir('./test') diff --git a/tests/functional/test_artifacts_output.py b/tests/functional/test_artifacts_output.py index 7899d474..f41a8c2c 100644 --- a/tests/functional/test_artifacts_output.py +++ b/tests/functional/test_artifacts_output.py @@ -22,7 +22,7 @@ def test_artifact_output(self): content = str(uuid.uuid4()) with open(common.FILENAME3, 'w') as fh: fh.write(content) - run.save(common.FILENAME3, 'output') + run.save_file(common.FILENAME3, 'output') run.close() diff --git a/tests/functional/test_artifacts_output_created.py b/tests/functional/test_artifacts_output_created.py index 38517698..fc5349a8 100644 --- a/tests/functional/test_artifacts_output_created.py +++ b/tests/functional/test_artifacts_output_created.py @@ -24,7 +24,7 @@ def test_artifact_output_created(self): fh.write(content) with self.assertRaises(Exception) as context: - run.save(common.FILENAME3, 'output') + run.save_file(common.FILENAME3, 'output') self.assertTrue('Cannot upload output files for runs in the created state' in str(context.exception)) diff --git a/tests/functional/test_offline_artifacts_code.py b/tests/functional/test_offline_artifacts_code.py index 06be6daf..6cbc247b 100644 --- a/tests/functional/test_offline_artifacts_code.py +++ b/tests/functional/test_offline_artifacts_code.py @@ -29,7 +29,7 @@ def test_artifact_code_offline(self): content = str(uuid.uuid4()) with open(common.FILENAME1, 'w') as fh: fh.write(content) - run.save(common.FILENAME1, 'code') + run.save_file(common.FILENAME1, 'code') run.close() diff --git a/tests/functional/test_offline_artifacts_code_created.py b/tests/functional/test_offline_artifacts_code_created.py index cc2b5ef1..51a7e700 100644 --- a/tests/functional/test_offline_artifacts_code_created.py +++ b/tests/functional/test_offline_artifacts_code_created.py @@ -30,7 +30,7 @@ def test_artifact_code_offline(self): content = str(uuid.uuid4()) with open(common.FILENAME1, "w") as fh: fh.write(content) - run.save(common.FILENAME1, "code") + run.save_file(common.FILENAME1, "code") sender() diff --git a/tests/functional/test_offline_artifacts_input.py b/tests/functional/test_offline_artifacts_input.py index faa14450..0c19dc8c 100644 --- a/tests/functional/test_offline_artifacts_input.py +++ b/tests/functional/test_offline_artifacts_input.py @@ -30,7 +30,7 @@ def test_artifact_input_offline(self): content = str(uuid.uuid4()) with open(common.FILENAME2, "w") as fh: fh.write(content) - run.save(common.FILENAME2, "input") + run.save_file(common.FILENAME2, "input") run.close() diff --git a/tests/functional/test_offline_artifacts_input_created.py b/tests/functional/test_offline_artifacts_input_created.py index d5a12f55..1f29ad2f 100644 --- a/tests/functional/test_offline_artifacts_input_created.py +++ b/tests/functional/test_offline_artifacts_input_created.py @@ -30,7 +30,7 @@ def test_artifact_input_offline(self): content = str(uuid.uuid4()) with open(common.FILENAME2, "w") as fh: fh.write(content) - run.save(common.FILENAME2, "input") + run.save_file(common.FILENAME2, "input") sender() diff --git a/tests/functional/test_offline_artifacts_output.py b/tests/functional/test_offline_artifacts_output.py index c5808494..4cfe747c 100644 --- a/tests/functional/test_offline_artifacts_output.py +++ b/tests/functional/test_offline_artifacts_output.py @@ -30,7 +30,7 @@ def test_artifact_output_offline(self): content = str(uuid.uuid4()) with open(common.FILENAME3, "w") as fh: fh.write(content) - run.save(common.FILENAME3, "output") + run.save_file(common.FILENAME3, "output") run.close() diff --git a/tests/refactor/conftest.py b/tests/refactor/conftest.py index 81741d05..63c46202 100644 --- a/tests/refactor/conftest.py +++ b/tests/refactor/conftest.py @@ -114,19 +114,19 @@ def setup_test_run(run: sv_run.Run, create_objects: bool): with tempfile.TemporaryDirectory() as tempd: with open((test_file := os.path.join(tempd, "test_file.txt")), "w") as out_f: out_f.write("This is a test file") - run.save(test_file, category="input", name="test_file") + run.save_file(test_file, category="input", name="test_file") TEST_DATA["file_1"] = "test_file" with open((test_json := os.path.join(tempd, f"test_attrs_{fix_use_id}.json")), "w") as out_f: json.dump(TEST_DATA, out_f, indent=2) - run.save(test_json, category="output", name="test_attributes") + run.save_file(test_json, category="output", name="test_attributes") TEST_DATA["file_2"] = "test_attributes" - with open((test_script := os.path.join(tempd, f"test_script.py")), "w") as out_f: + with open((test_script := os.path.join(tempd, "test_script.py")), "w") as out_f: out_f.write( "print('Hello World!')" ) - run.save(test_script, category="code", name="test_empty_file") + run.save_file(test_script, category="code", name="test_empty_file") TEST_DATA["file_3"] = "test_empty_file" time.sleep(1.) diff --git a/tests/refactor/test_run_class.py b/tests/refactor/test_run_class.py index 58df8d97..b058ee65 100644 --- a/tests/refactor/test_run_class.py +++ b/tests/refactor/test_run_class.py @@ -3,7 +3,9 @@ import typing import contextlib import inspect +import tempfile import uuid +import pathlib import concurrent.futures import random @@ -273,7 +275,7 @@ def test_suppressed_errors( name="test_suppressed_errors", folder="/simvue_unit_testing", tags=["simvue_client_unit_tests"], - retention_period="1 hour" + retention_period="1 hour", ) run.config(suppress_errors=True) @@ -286,17 +288,69 @@ def test_suppressed_errors( assert setup_logging.counts[0] == len(decorated_funcs) + 1 else: assert setup_logging.counts[0] == len(decorated_funcs) - + @pytest.mark.run def test_set_folder_details() -> None: with sv_run.Run() as run: - folder_name: str ="/simvue_unit_test_folder" + folder_name: str = "/simvue_unit_test_folder" description: str = "test description" tags: list[str] = ["simvue_client_unit_tests", "test_set_folder_details"] run.init(folder=folder_name) run.set_folder_details(path=folder_name, tags=tags, description=description) - + client = sv_cl.Client() assert (folder := client.get_folders([f"path == {folder_name}"])[0])["tags"] == tags assert folder["description"] == description + + +@pytest.mark.run +@pytest.mark.parametrize("valid_mimetype", (True, False), ids=("valid_mime", "invalid_mime")) +@pytest.mark.parametrize("preserve_path", (True, False), ids=("preserve_path", "modified_path")) +@pytest.mark.parametrize("name", ("test_file", None), ids=("named", "nameless")) +@pytest.mark.parametrize("allow_pickle", (True, False), ids=("pickled", "unpickled")) +@pytest.mark.parametrize("empty_file", (True, False), ids=("empty", "content")) +def test_save_file( + create_plain_run: typing.Tuple[sv_run.Run, dict], + valid_mimetype: bool, + preserve_path: bool, + name: typing.Optional[str], + allow_pickle: bool, + empty_file: bool, + capfd +) -> None: + simvue_run, _ = create_plain_run + file_type: str = 'text/plain' if valid_mimetype else 'text/text' + with tempfile.TemporaryDirectory() as tempd: + with open( + ( + out_name := pathlib.Path(tempd).joinpath("test_file.txt") + ), + "w", + ) as out_f: + out_f.write("test data entry" if not empty_file else "") + + if valid_mimetype: + simvue_run.save_file( + out_name, + category="input", + filetype=file_type, + preserve_path=preserve_path, + name=name, + ) + else: + with pytest.raises(RuntimeError): + simvue_run.save_file( + out_name, + category="input", + filetype=file_type, + preserve_path=preserve_path + ) + return + + variable = capfd.readouterr() + with capfd.disabled(): + if empty_file: + assert variable.out == "WARNING: saving zero-sized files not currently supported\n" + + diff --git a/tests/unit/test_matplotlib_figure_mime_type.py b/tests/unit/test_matplotlib_figure_mime_type.py index fba52d64..a3ae9bcc 100644 --- a/tests/unit/test_matplotlib_figure_mime_type.py +++ b/tests/unit/test_matplotlib_figure_mime_type.py @@ -1,4 +1,4 @@ -from simvue.serialization import Serializer, Deserializer +from simvue.serialization import serialize_object import matplotlib.pyplot as plt def test_matplotlib_figure_mime_type(): @@ -8,6 +8,6 @@ def test_matplotlib_figure_mime_type(): plt.plot([1, 2, 3, 4]) figure = plt.gcf() - _, mime_type = Serializer().serialize(figure) + _, mime_type = serialize_object(figure, False) assert (mime_type == 'application/vnd.plotly.v1+json') diff --git a/tests/unit/test_numpy_array_mime_type.py b/tests/unit/test_numpy_array_mime_type.py index ca16806a..7523d30b 100644 --- a/tests/unit/test_numpy_array_mime_type.py +++ b/tests/unit/test_numpy_array_mime_type.py @@ -1,4 +1,4 @@ -from simvue.serialization import Serializer, Deserializer +from simvue.serialization import serialize_object import numpy as np def test_numpy_array_mime_type(): @@ -6,6 +6,6 @@ def test_numpy_array_mime_type(): Check that the mimetype for numpy arrays is correct """ array = np.array([1, 2, 3, 4, 5]) - _, mime_type = Serializer().serialize(array) + _, mime_type = serialize_object(array, False) assert (mime_type == 'application/vnd.simvue.numpy.v1') diff --git a/tests/unit/test_numpy_array_serialization.py b/tests/unit/test_numpy_array_serialization.py index d7c952a1..0f713cdd 100644 --- a/tests/unit/test_numpy_array_serialization.py +++ b/tests/unit/test_numpy_array_serialization.py @@ -1,4 +1,4 @@ -from simvue.serialization import Serializer, Deserializer +from simvue.serialization import serialize_object, deserialize_data import numpy as np def test_numpy_array_serialization(): @@ -7,7 +7,7 @@ def test_numpy_array_serialization(): """ array = np.array([1, 2, 3, 4, 5]) - serialized, mime_type = Serializer().serialize(array) - array_out = Deserializer().deserialize(serialized, mime_type) + serialized, mime_type = serialize_object(array, False) + array_out = deserialize_data(serialized, mime_type, False) assert (array == array_out).all() diff --git a/tests/unit/test_pandas_dataframe_mimetype.py b/tests/unit/test_pandas_dataframe_mimetype.py index e72a5491..2a1923b9 100644 --- a/tests/unit/test_pandas_dataframe_mimetype.py +++ b/tests/unit/test_pandas_dataframe_mimetype.py @@ -1,5 +1,5 @@ import pandas as pd -from simvue.serialization import Serializer, Deserializer +from simvue.serialization import serialize_object def test_pandas_dataframe_mimetype(): """ @@ -8,6 +8,6 @@ def test_pandas_dataframe_mimetype(): data = {'col1': [1, 2], 'col2': [3, 4]} df = pd.DataFrame(data=data) - _, mime_type = Serializer().serialize(df) + _, mime_type = serialize_object(df, False) assert (mime_type == 'application/vnd.simvue.df.v1') diff --git a/tests/unit/test_pandas_dataframe_serialization.py b/tests/unit/test_pandas_dataframe_serialization.py index e1676ae1..52d60285 100644 --- a/tests/unit/test_pandas_dataframe_serialization.py +++ b/tests/unit/test_pandas_dataframe_serialization.py @@ -1,4 +1,4 @@ -from simvue.serialization import Serializer, Deserializer +from simvue.serialization import serialize_object, deserialize_data import pandas as pd def test_pandas_dataframe_serialization(): @@ -8,7 +8,7 @@ def test_pandas_dataframe_serialization(): data = {'col1': [1, 2], 'col2': [3, 4]} df = pd.DataFrame(data=data) - serialized, mime_type = Serializer().serialize(df) - df_out = Deserializer().deserialize(serialized, mime_type) + serialized, mime_type = serialize_object(df, False) + df_out = deserialize_data(serialized, mime_type, False) assert (df.equals(df_out)) diff --git a/tests/unit/test_pickle_serialization.py b/tests/unit/test_pickle_serialization.py index aac4ed55..60833665 100644 --- a/tests/unit/test_pickle_serialization.py +++ b/tests/unit/test_pickle_serialization.py @@ -1,5 +1,4 @@ -import pandas as pd -from simvue.serialization import Serializer, Deserializer +from simvue.serialization import deserialize_data, serialize_object def test_pickle_serialization(): """ @@ -7,7 +6,7 @@ def test_pickle_serialization(): """ data = {'a': 1.0, 'b': 'test'} - serialized, mime_type = Serializer().serialize(data, allow_pickle=True) - data_out = Deserializer().deserialize(serialized, mime_type, allow_pickle=True) + serialized, mime_type = serialize_object(data, allow_pickle=True) + data_out = deserialize_data(serialized, mime_type, allow_pickle=True) assert (data == data_out) diff --git a/tests/unit/test_plotly_figure_mime_type.py b/tests/unit/test_plotly_figure_mime_type.py index 3acb14e6..b7f6e62c 100644 --- a/tests/unit/test_plotly_figure_mime_type.py +++ b/tests/unit/test_plotly_figure_mime_type.py @@ -1,4 +1,4 @@ -from simvue.serialization import Serializer, Deserializer +from simvue.serialization import serialize_object import matplotlib.pyplot as plt import plotly @@ -10,6 +10,6 @@ def test_plotly_figure_mime_type(): figure = plt.gcf() plotly_figure = plotly.tools.mpl_to_plotly(figure) - _, mime_type = Serializer().serialize(plotly_figure) + _, mime_type = serialize_object(plotly_figure, False) assert (mime_type == 'application/vnd.plotly.v1+json') diff --git a/tests/unit/test_pytorch_tensor_mime_type.py b/tests/unit/test_pytorch_tensor_mime_type.py index d7a53ff0..013a326d 100644 --- a/tests/unit/test_pytorch_tensor_mime_type.py +++ b/tests/unit/test_pytorch_tensor_mime_type.py @@ -1,4 +1,4 @@ -from simvue.serialization import Serializer, Deserializer +from simvue.serialization import serialize_object import torch def test_pytorch_tensor_mime_type(): @@ -7,6 +7,6 @@ def test_pytorch_tensor_mime_type(): """ torch.manual_seed(1724) array = torch.rand(2, 3) - _, mime_type = Serializer().serialize(array) + _, mime_type = serialize_object(array, False) assert (mime_type == 'application/vnd.simvue.torch.v1') diff --git a/tests/unit/test_pytorch_tensor_serialization.py b/tests/unit/test_pytorch_tensor_serialization.py index f6011ec8..9fd3365f 100644 --- a/tests/unit/test_pytorch_tensor_serialization.py +++ b/tests/unit/test_pytorch_tensor_serialization.py @@ -1,5 +1,5 @@ import torch -from simvue.serialization import Serializer, Deserializer +from simvue.serialization import serialize_object, deserialize_data def test_pytorch_tensor_serialization(): """ @@ -8,7 +8,7 @@ def test_pytorch_tensor_serialization(): torch.manual_seed(1724) array = torch.rand(2, 3) - serialized, mime_type = Serializer().serialize(array) - array_out = Deserializer().deserialize(serialized, mime_type) + serialized, mime_type = serialize_object(array, False) + array_out = deserialize_data(serialized, mime_type, False) assert (array == array_out).all() diff --git a/tests/unit/test_run_init_folder.py b/tests/unit/test_run_init_folder.py index 6561c13b..33da8773 100644 --- a/tests/unit/test_run_init_folder.py +++ b/tests/unit/test_run_init_folder.py @@ -1,4 +1,3 @@ -import os from simvue import Run import pytest diff --git a/tests/unit/test_run_init_metadata.py b/tests/unit/test_run_init_metadata.py index fcf2e3dc..517dc645 100644 --- a/tests/unit/test_run_init_metadata.py +++ b/tests/unit/test_run_init_metadata.py @@ -1,4 +1,3 @@ -import os from simvue import Run import pytest diff --git a/tests/unit/test_run_init_tags.py b/tests/unit/test_run_init_tags.py index b11247b3..a352b5b5 100644 --- a/tests/unit/test_run_init_tags.py +++ b/tests/unit/test_run_init_tags.py @@ -1,4 +1,3 @@ -import os from simvue import Run import pytest From 8c78d594737ca05e9e9c22fb14efeeb7f1f633c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristian=20Zar=C4=99bski?= Date: Fri, 24 May 2024 09:34:31 +0100 Subject: [PATCH 2/9] remove MR remnant --- tests/unit/test_pytorch_tensor_mime_type.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/unit/test_pytorch_tensor_mime_type.py b/tests/unit/test_pytorch_tensor_mime_type.py index ed48008c..c240cd5b 100644 --- a/tests/unit/test_pytorch_tensor_mime_type.py +++ b/tests/unit/test_pytorch_tensor_mime_type.py @@ -1,10 +1,7 @@ -<<<<<<< HEAD from simvue.serialization import serialize_object import torch -======= import pytest -from simvue.serialization import Serializer, Deserializer ->>>>>>> dev + try: import torch From 8a199769bc03829aad17bc9f090e65a79cb856f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristian=20Zar=C4=99bski?= Date: Fri, 24 May 2024 09:38:32 +0100 Subject: [PATCH 3/9] Added save object test --- tests/refactor/test_run_class.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/refactor/test_run_class.py b/tests/refactor/test_run_class.py index c9dfb74a..d2b2cb68 100644 --- a/tests/refactor/test_run_class.py +++ b/tests/refactor/test_run_class.py @@ -377,3 +377,23 @@ def test_save_file( assert variable.out == "WARNING: saving zero-sized files not currently supported\n" +@pytest.mark.run +@pytest.mark.parametrize("object_type", ("DataFrame", "ndarray")) +def test_save_object( + create_plain_run: typing.Tuple[sv_run.Run, dict], object_type: str +) -> None: + simvue_run, _ = create_plain_run + + if object_type == "DataFrame": + try: + from pandas import DataFrame + except ImportError: + pytest.skip("Pandas is not installed") + save_obj = DataFrame({"x": [1, 2, 3, 4], "y": [2, 4, 6, 8]}) + elif object_type == "ndarray": + try: + from numpy import array + except ImportError: + pytest.skip("Numpy is not installed") + save_obj = array([1, 2, 3, 4]) + simvue_run.save_object(save_obj, "input", f"test_object_{object_type}") From 604395e0f70847454c76bf664e6e53d9b3f9a423 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristian=20Zar=C4=99bski?= Date: Fri, 24 May 2024 10:24:58 +0100 Subject: [PATCH 4/9] Fix bad import in test --- tests/unit/test_plotly_figure_mime_type.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/unit/test_plotly_figure_mime_type.py b/tests/unit/test_plotly_figure_mime_type.py index e806ccc8..8cf8a479 100644 --- a/tests/unit/test_plotly_figure_mime_type.py +++ b/tests/unit/test_plotly_figure_mime_type.py @@ -3,8 +3,6 @@ import plotly import pytest -from simvue.serialization import Serializer, Deserializer - try: import matplotlib.pyplot as plt except ImportError: From 0f19954470e6a80adb8d70834cb16d8b9a5e927d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristian=20Zar=C4=99bski?= Date: Fri, 24 May 2024 11:23:54 +0100 Subject: [PATCH 5/9] Attempt JSON serialization before pickling and fix bug with save_all --- README.md | 8 ++++---- .../bluemira_simvue_geometry_optimisation.py | 2 +- examples/PyTorch/main.py | 2 +- examples/SU2/SU2.py | 4 ++-- examples/Tensorflow/dynamic_rnn.py | 2 +- simvue/executor.py | 8 ++++---- simvue/run.py | 2 +- simvue/serialization.py | 12 ++++++++++++ 8 files changed, 26 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index 5addbbd9..7b9877b2 100644 --- a/README.md +++ b/README.md @@ -69,13 +69,13 @@ if __name__ == "__main__": description='This is part 1 of a test') # Description # Upload the code - run.save('training.py', 'code') + run.save_file('training.py', 'code') # Upload an input file - run.save('params.in', 'input') + run.save_file('params.in', 'input') # Add an alert (the alert definition will be created if necessary) - run.add_alert(name='loss-too-high', # Name + run.create_alert(name='loss-too-high', # Name source='metrics', # Source rule='is above', # Rule metric='loss', # Metric @@ -96,7 +96,7 @@ if __name__ == "__main__": ... # Upload an output file - run.save('output.cdf', 'output') + run.save_file('output.cdf', 'output') # If we weren't using a context manager we'd need to end the run # run.close() diff --git a/examples/GeometryOptimisation/bluemira_simvue_geometry_optimisation.py b/examples/GeometryOptimisation/bluemira_simvue_geometry_optimisation.py index 4868cf46..264ed03a 100644 --- a/examples/GeometryOptimisation/bluemira_simvue_geometry_optimisation.py +++ b/examples/GeometryOptimisation/bluemira_simvue_geometry_optimisation.py @@ -171,5 +171,5 @@ def my_minimise_length(vector, grad, parameterisation, ad_args=None): # Here we're minimising the length, within the bounds of our PrincetonD parameterisation, # so we'd expect that x1 goes to its upper bound, and x2 goes to its lower bound. -run.save("bluemira_simvue_geometry_optimisation.py", "code") +run.save_file("bluemira_simvue_geometry_optimisation.py", "code") run.close() diff --git a/examples/PyTorch/main.py b/examples/PyTorch/main.py index ed324d86..2fd55bf4 100644 --- a/examples/PyTorch/main.py +++ b/examples/PyTorch/main.py @@ -205,7 +205,7 @@ def main(): scheduler.step() if args.save_model: - run.save(model.state_dict(), "output", name="mnist_cnn.pt") + run.save_file(model.state_dict(), "output", name="mnist_cnn.pt") run.close() diff --git a/examples/SU2/SU2.py b/examples/SU2/SU2.py index 9e189095..51740867 100644 --- a/examples/SU2/SU2.py +++ b/examples/SU2/SU2.py @@ -56,7 +56,7 @@ filetype = None if input_file.endswith(".cfg"): filetype = "text/plain" - run.save(input_file, "input", filetype) + run.save_file(input_file, "input", filetype) running = True latest = [] @@ -106,6 +106,6 @@ # Save output files for output_file in OUTPUT_FILES: - run.save(output_file, "output") + run.save_file(output_file, "output") run.close() diff --git a/examples/Tensorflow/dynamic_rnn.py b/examples/Tensorflow/dynamic_rnn.py index 1eefb709..0a5339c7 100644 --- a/examples/Tensorflow/dynamic_rnn.py +++ b/examples/Tensorflow/dynamic_rnn.py @@ -45,7 +45,7 @@ "computation over sequences with variable length. This example is using a toy dataset to " "classify linear sequences. The generated sequences have variable length.", ) - run.save("dynamic_rnn.py", "code") + run.save_file("dynamic_rnn.py", "code") # ==================== # TOY DATA GENERATOR diff --git a/simvue/executor.py b/simvue/executor.py index d4013179..f8c3e4f6 100644 --- a/simvue/executor.py +++ b/simvue/executor.py @@ -167,10 +167,10 @@ def callback_function(status_code: int, std_out: str, std_err: str) -> None: ) if script: - self._runner.save(filename=script, category="code") + self._runner.save_file(filename=script, category="code") if input_file: - self._runner.save(filename=input_file, category="input") + self._runner.save_file(filename=input_file, category="input") _command: typing.List[str] = [] @@ -284,11 +284,11 @@ def _save_output(self) -> None: for proc_id in self._exit_codes.keys(): # Only save the file if the contents are not empty if self._std_err[proc_id]: - self._runner.save( + self._runner.save_file( f"{self._runner.name}_{proc_id}.err", category="output" ) if self._std_out[proc_id]: - self._runner.save( + self._runner.save_file( f"{self._runner.name}_{proc_id}.out", category="output" ) diff --git a/simvue/run.py b/simvue/run.py index 6cea06cf..1b48e5b8 100644 --- a/simvue/run.py +++ b/simvue/run.py @@ -1142,7 +1142,7 @@ def save_all( for item in items: if item.is_file(): - save_file = self.save(f"{item}", category, filetype, preserve_path) + save_file = self.save_file(item, category, filetype, preserve_path) elif item.is_dir(): save_file = self.save_directory(item, category, filetype, preserve_path) else: diff --git a/simvue/serialization.py b/simvue/serialization.py index f9e4aa0b..e5c518c8 100644 --- a/simvue/serialization.py +++ b/simvue/serialization.py @@ -8,6 +8,7 @@ import typing import pickle import pandas +import json import numpy from io import BytesIO @@ -82,6 +83,8 @@ def serialize_object( return _serialize_matplotlib(data) except ImportError: pass + elif serialized := _serialize_json(data): + return serialized if allow_pickle: return _serialize_pickle(data) @@ -155,6 +158,15 @@ def _serialize_torch_tensor(data: typing.Any) -> typing.Optional[tuple[str, str] return data, mimetype +def _serialize_json(data: typing.Any) -> typing.Optional[tuple[str, str]]: + mimetype = "application/json" + try: + data = json.dumps(data) + except TypeError: + return None + return data, mimetype + + def _serialize_pickle(data: typing.Any) -> typing.Optional[tuple[str, str]]: mimetype = "application/octet-stream" data = pickle.dumps(data) From 75b794fed12844f7db8e156dd9fdc7a15777731f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristian=20Zar=C4=99bski?= Date: Fri, 24 May 2024 11:28:48 +0100 Subject: [PATCH 6/9] Added missing JSON deserializer --- simvue/serialization.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/simvue/serialization.py b/simvue/serialization.py index e5c518c8..96f7d067 100644 --- a/simvue/serialization.py +++ b/simvue/serialization.py @@ -189,6 +189,8 @@ def deserialize_data( return _deserialize_dataframe(data) elif mimetype == "application/vnd.simvue.torch.v1": return _deserialize_torch_tensor(data) + elif mimetype == "application/json": + return _deserialize_json(data) elif mimetype == "application/octet-stream" and allow_pickle: return _deserialize_pickle(data) return None @@ -240,6 +242,11 @@ def _deserialize_torch_tensor(data: "Buffer") -> typing.Optional["Tensor"]: return torch.load(mfile) -def _deserialize_pickle(data): +def _deserialize_pickle(data) -> typing.Optional[typing.Any]: data = pickle.loads(data) return data + + +def _deserialize_json(data) -> typing.Optional[typing.Any]: + data = json.loads(data) + return data From 221c1588b110ea1bbaa3b3f82c362c6568fb24f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristian=20Zar=C4=99bski?= Date: Fri, 24 May 2024 11:30:51 +0100 Subject: [PATCH 7/9] Fix bug with 'engine' parameter in JSON decoding for plotly --- simvue/serialization.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/simvue/serialization.py b/simvue/serialization.py index 96f7d067..b2d48a97 100644 --- a/simvue/serialization.py +++ b/simvue/serialization.py @@ -98,7 +98,7 @@ def _serialize_plotly_figure(data: typing.Any) -> typing.Optional[tuple[str, str except ImportError: return None mimetype = "application/vnd.plotly.v1+json" - data = plotly.io.to_json(data, "json") + data = plotly.io.to_json(data, engine="json") return data, mimetype @@ -109,7 +109,7 @@ def _serialize_matplotlib(data: typing.Any) -> typing.Optional[tuple[str, str]]: except ImportError: return None mimetype = "application/vnd.plotly.v1+json" - data = plotly.io.to_json(plotly.tools.mpl_to_plotly(data.gcf()), "json") + data = plotly.io.to_json(plotly.tools.mpl_to_plotly(data.gcf()), engine="json") return data, mimetype @@ -120,7 +120,7 @@ def _serialize_matplotlib_figure(data: typing.Any) -> typing.Optional[tuple[str, except ImportError: return None mimetype = "application/vnd.plotly.v1+json" - data = plotly.io.to_json(plotly.tools.mpl_to_plotly(data), "json") + data = plotly.io.to_json(plotly.tools.mpl_to_plotly(data), engine="json") return data, mimetype From 66c648d92773e81052b8e2066ba5a7f3e10aaccd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristian=20Zar=C4=99bski?= Date: Fri, 24 May 2024 15:26:16 +0100 Subject: [PATCH 8/9] Fix tag in test and address review comments --- simvue/run.py | 36 ++++++++++++++++++++++++++---------- simvue/serialization.py | 2 +- tests/refactor/conftest.py | 2 +- 3 files changed, 28 insertions(+), 12 deletions(-) diff --git a/simvue/run.py b/simvue/run.py index ced0f519..77e6dcd9 100644 --- a/simvue/run.py +++ b/simvue/run.py @@ -52,7 +52,6 @@ if typing.TYPE_CHECKING: from .factory.proxy import SimvueBaseClass from .factory.dispatch import DispatcherBaseClass - from .types import DeserializedContent UPLOAD_TIMEOUT: int = 30 HEARTBEAT_INTERVAL: int = 60 @@ -990,7 +989,24 @@ def save_object( name: typing.Optional[str] = None, allow_pickle: bool = False, ) -> bool: - obj: DeserializedContent + """Save an object to the Simvue server + + Parameters + ---------- + obj : typing.Any + object to serialize and send to the server + category : Literal['input', 'output', 'code'] + category of file with respect to this run + name : str, optional + name to associate with this object, by default None + allow_pickle : bool, optional + whether to allow pickling if all other serialization types fail, by default False + + Returns + ------- + bool + whether object upload was successful + """ serialized = serialize_object(obj, allow_pickle) if not serialized or not (pickled := serialized[0]): @@ -1022,7 +1038,7 @@ def save_object( @pydantic.validate_call def save_file( self, - filename: pydantic.FilePath, + file_path: pydantic.FilePath, category: typing.Literal["input", "output", "code"], filetype: typing.Optional[str] = None, preserve_path: bool = False, @@ -1032,7 +1048,7 @@ def save_file( Parameters ---------- - filename : pydantic.FilePath + file_path : pydantic.FilePath path to the file to upload category : Literal['input', 'output', 'code'] category of file with respect to this run @@ -1067,16 +1083,16 @@ def save_file( self._error(f"Invalid MIME type '{filetype}' specified") return False - stored_file_name: str = f"{filename}" + stored_file_name: str = f"{file_path}" if preserve_path and stored_file_name.startswith("./"): stored_file_name = stored_file_name[2:] elif not preserve_path: - stored_file_name = os.path.basename(filename) + stored_file_name = os.path.basename(file_path) # Determine mimetype if not (mimetype := filetype): - mimetype = mimetypes.guess_type(filename)[0] or "application/octet-stream" + mimetype = mimetypes.guess_type(file_path)[0] or "application/octet-stream" data: dict[str, typing.Any] = { "name": name or stored_file_name, @@ -1084,11 +1100,11 @@ def save_file( "type": mimetype, "storage": self._storage_id, "category": category, - "size": (file_size := os.path.getsize(filename)), + "size": (file_size := os.path.getsize(file_path)), "originalPath": os.path.abspath( - os.path.expanduser(os.path.expandvars(filename)) + os.path.expanduser(os.path.expandvars(file_path)) ), - "checksum": calculate_sha256(f"{filename}", True), + "checksum": calculate_sha256(f"{file_path}", True), } if not file_size: diff --git a/simvue/serialization.py b/simvue/serialization.py index b2d48a97..c51847e6 100644 --- a/simvue/serialization.py +++ b/simvue/serialization.py @@ -60,7 +60,7 @@ def serialize_object( Returns ------- Callable[[typing.Any], tuple[str, str]] - the serializer to user + the serializer to use """ module_name = data.__class__.__module__ class_name = data.__class__.__name__ diff --git a/tests/refactor/conftest.py b/tests/refactor/conftest.py index ddb8c121..bc2b31ae 100644 --- a/tests/refactor/conftest.py +++ b/tests/refactor/conftest.py @@ -79,7 +79,7 @@ def setup_test_run(run: sv_run.Run, create_objects: bool, request: pytest.Fixtur "test_identifier": fix_use_id }, "folder": f"/simvue_unit_testing/{fix_use_id}", - "tags": ["simvue_client_unit_tests", request.node.name] + "tags": ["simvue_client_unit_tests", request.node.name.replace("[", "_").replace("]", "_")] } if os.environ.get("CI"): From e60f19cb07ab30b321794b7fda18d34ddbfa496b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristian=20Zar=C4=99bski?= Date: Fri, 24 May 2024 15:28:28 +0100 Subject: [PATCH 9/9] Fix bad argument name --- simvue/executor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/simvue/executor.py b/simvue/executor.py index f8c3e4f6..d156c2a1 100644 --- a/simvue/executor.py +++ b/simvue/executor.py @@ -167,10 +167,10 @@ def callback_function(status_code: int, std_out: str, std_err: str) -> None: ) if script: - self._runner.save_file(filename=script, category="code") + self._runner.save_file(file_path=script, category="code") if input_file: - self._runner.save_file(filename=input_file, category="input") + self._runner.save_file(file_path=input_file, category="input") _command: typing.List[str] = []