From 27cdec9b4a8d47ae7275ccadba6505e1492c1d9d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristian=20Zar=C4=99bski?= Date: Fri, 31 May 2024 08:56:55 +0100 Subject: [PATCH] Added missing 'category' option to get_artifacts_as_files --- simvue/client.py | 14 +++++++++++--- tests/refactor/test_client.py | 20 +++++++++++++++----- 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/simvue/client.py b/simvue/client.py index 650a2b35..e9cbff3a 100644 --- a/simvue/client.py +++ b/simvue/client.py @@ -547,8 +547,12 @@ def list_artifacts(self, run_id: str) -> list[dict[str, typing.Any]]: ) return json_response - def _retrieve_artifact_from_server(self, run_id: str, name: str): - params: dict[str, str] = {"name": name} + def _retrieve_artifact_from_server( + self, + run_id: str, + name: str, + ) -> typing.Union[dict, list]: + params: dict[str, str | None] = {"name": name} response = requests.get( f"{self._url}/api/runs/{run_id}/artifacts", @@ -707,6 +711,7 @@ def _assemble_artifact_downloads( def get_artifacts_as_files( self, run_id: str, + category: typing.Optional[typing.Literal["input", "output", "code"]] = None, path: typing.Optional[str] = None, startswith: typing.Optional[str] = None, contains: typing.Optional[str] = None, @@ -733,9 +738,12 @@ def get_artifacts_as_files( RuntimeError if there was a failure retrieving artifacts from the server """ + params: dict[str, typing.Optional[str]] = {"category": category} response: requests.Response = requests.get( - f"{self._url}/api/runs/{run_id}/artifacts", headers=self._headers + f"{self._url}/api/runs/{run_id}/artifacts", + headers=self._headers, + params=params, ) self._get_json_from_response( diff --git a/tests/refactor/test_client.py b/tests/refactor/test_client.py index 3622a0a8..fcf943cd 100644 --- a/tests/refactor/test_client.py +++ b/tests/refactor/test_client.py @@ -107,13 +107,23 @@ def test_get_artifact_as_file( @pytest.mark.dependency @pytest.mark.client -def test_get_artifacts_as_files(create_test_run: tuple[sv_run.Run, dict]) -> None: +@pytest.mark.parametrize("category", (None, "code", "input", "output")) +def test_get_artifacts_as_files( + create_test_run: tuple[sv_run.Run, dict], + category: typing.Literal["code", "input", "output"], +) -> None: with tempfile.TemporaryDirectory() as tempd: client = svc.Client() - client.get_artifacts_as_files(create_test_run[1]["run_id"], path=tempd) + client.get_artifacts_as_files( + create_test_run[1]["run_id"], category=category, path=tempd + ) files = [os.path.basename(i) for i in glob.glob(os.path.join(tempd, "*"))] - assert create_test_run[1]["file_1"] in files - assert create_test_run[1]["file_2"] in files + if not category or category == "input": + assert create_test_run[1]["file_1"] in files + if not category or category == "output": + assert create_test_run[1]["file_2"] in files + if not category or category == "code": + assert create_test_run[1]["file_3"] in files @pytest.mark.dependency @@ -140,7 +150,7 @@ def test_get_folder(create_test_run: tuple[sv_run.Run, dict]) -> None: @pytest.mark.dependency -@pytest.mark.client +@pytest.mark.client def test_get_metrics_names(create_test_run: tuple[sv_run.Run, dict]) -> None: client = svc.Client() time.sleep(1)