Skip to content

Commit 20d6466

Browse files
committed
Update typing issues and fix tests for checking scenarios for abort
1 parent 0a1c4a0 commit 20d6466

File tree

5 files changed

+68
-31
lines changed

5 files changed

+68
-31
lines changed

simvue/api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def set_json_header(headers: dict[str, str]) -> dict[str, str]:
4444
reraise=True,
4545
)
4646
def post(
47-
url: str, headers: dict[str, str], data: dict[str, typing.Any], is_json: bool = True
47+
url: str, headers: dict[str, str], data: typing.Any, is_json: bool = True
4848
) -> requests.Response:
4949
"""HTTP POST with retries
5050

simvue/executor.py

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@
2727
logger = logging.getLogger(__name__)
2828

2929

30+
class CompletionCallback(typing.Protocol):
31+
def __call__(self, *, status_code: int, std_out: str, std_err: str) -> None: ...
32+
33+
3034
def _execute_process(
3135
proc_id: str,
3236
command: typing.List[str],
@@ -67,11 +71,13 @@ def __init__(self, simvue_runner: "simvue.Run", keep_logs: bool = True) -> None:
6771
"""
6872
self._runner = simvue_runner
6973
self._keep_logs = keep_logs
70-
self._completion_callbacks = {}
71-
self._completion_triggers = {}
72-
self._exit_codes = {}
73-
self._std_err = {}
74-
self._std_out = {}
74+
self._completion_callbacks: dict[str, typing.Optional[CompletionCallback]] = {}
75+
self._completion_triggers: dict[
76+
str, typing.Optional[multiprocessing.synchronize.Event]
77+
] = {}
78+
self._exit_codes: dict[str, int] = {}
79+
self._std_err: dict[str, str] = {}
80+
self._std_out: dict[str, str] = {}
7581
self._alert_ids: dict[str, str] = {}
7682
self._command_str: dict[str, str] = {}
7783
self._processes: dict[str, subprocess.Popen] = {}
@@ -84,9 +90,7 @@ def add_process(
8490
script: typing.Optional[pathlib.Path] = None,
8591
input_file: typing.Optional[pathlib.Path] = None,
8692
env: typing.Optional[typing.Dict[str, str]] = None,
87-
completion_callback: typing.Optional[
88-
typing.Callable[[int, str, str], None]
89-
] = None,
93+
completion_callback: typing.Optional[CompletionCallback] = None,
9094
completion_trigger: typing.Optional[multiprocessing.synchronize.Event] = None,
9195
**kwargs,
9296
) -> None:
@@ -141,6 +145,9 @@ def callback_function(status_code: int, std_out: str, std_err: str) -> None:
141145
"""
142146
_pos_args = list(args)
143147

148+
if not self._runner.name:
149+
raise RuntimeError("Cannot add process, expected Run instance to have name")
150+
144151
if sys.platform == "win32" and completion_callback:
145152
logger.warning(
146153
"Completion callback for 'add_process' may fail on Windows due to "
@@ -289,20 +296,20 @@ def _save_output(self) -> None:
289296

290297
def kill_process(self, process_id: str) -> None:
291298
"""Kill a running process by ID"""
292-
if not (_process := self._processes.get(process_id)):
299+
if not (process := self._processes.get(process_id)):
293300
logger.error(
294301
f"Failed to terminate process '{process_id}', no such identifier."
295302
)
296303
return
297304

298-
_parent = psutil.Process(_process.pid)
305+
parent = psutil.Process(process.pid)
299306

300-
for child in _parent.children(recursive=True):
307+
for child in parent.children(recursive=True):
301308
logger.debug(f"Terminating child process {child.pid}: {child.name()}")
302309
child.kill()
303310

304-
logger.debug(f"Terminating child process {_process.pid}: {_process.args}")
305-
_process.terminate()
311+
logger.debug(f"Terminating child process {process.pid}: {process.args}")
312+
process.terminate()
306313

307314
self._execute_callback(process_id)
308315

@@ -325,14 +332,14 @@ def _execute_callback(self, identifier: str) -> None:
325332
with open(f"{self._runner.name}_{identifier}.out") as out:
326333
std_out = out.read()
327334

328-
if self._completion_callbacks[identifier]:
329-
self._completion_callbacks[identifier](
335+
if callback := self._completion_callbacks.get(identifier):
336+
callback(
330337
status_code=self._processes[identifier].returncode,
331338
std_out=std_out,
332339
std_err=std_err,
333340
)
334-
if self._completion_triggers[identifier]:
335-
self._completion_triggers[identifier].set()
341+
if completion_trigger := self._completion_triggers.get(identifier):
342+
completion_trigger.set()
336343

337344
def wait_for_completion(self) -> None:
338345
"""Wait for all processes to finish then perform tidy up and upload"""

simvue/run.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import datetime
1111
import json
1212
import logging
13+
import pathlib
1314
import mimetypes
1415
import multiprocessing.synchronize
1516
import threading
@@ -50,6 +51,7 @@
5051
validate_timestamp,
5152
)
5253

54+
5355
if typing.TYPE_CHECKING:
5456
from .factory.proxy import SimvueBaseClass
5557
from .factory.dispatch import DispatcherBaseClass
@@ -101,6 +103,7 @@ def __init__(
101103
self._uuid: str = f"{uuid.uuid4()}"
102104
self._mode: typing.Literal["online", "offline", "disabled"] = mode
103105
self._name: typing.Optional[str] = None
106+
self._testing: bool = False
104107
self._abort_on_alert: bool = True
105108
self._dispatch_mode: typing.Literal["direct", "queued"] = "queued"
106109
self._executor = Executor(self)
@@ -246,7 +249,7 @@ def _get_sysinfo(self) -> dict[str, typing.Any]:
246249

247250
def _create_heartbeat_callback(
248251
self,
249-
) -> typing.Callable[[str, dict, str, bool], None]:
252+
) -> typing.Callable[[threading.Event], None]:
250253
if (
251254
self._mode == "online" and (not self._url or not self._id)
252255
) or not self._heartbeat_termination_trigger:
@@ -354,7 +357,7 @@ def _online_dispatch_callback(
354357
buffer: list[typing.Any],
355358
category: str,
356359
url: str = self._url,
357-
run_id: str = self._id,
360+
run_id: typing.Optional[str] = self._id,
358361
headers: dict[str, str] = self._headers,
359362
) -> None:
360363
if not buffer:
@@ -636,7 +639,7 @@ def add_process(
636639
self,
637640
identifier: str,
638641
*cmd_args,
639-
executable: typing.Optional[typing.Union[str]] = None,
642+
executable: typing.Optional[typing.Union[str, pathlib.Path]] = None,
640643
script: typing.Optional[pydantic.FilePath] = None,
641644
input_file: typing.Optional[pydantic.FilePath] = None,
642645
completion_callback: typing.Optional[
@@ -708,12 +711,20 @@ def callback_function(status_code: int, std_out: str, std_err: str) -> None:
708711
"due to function pickling restrictions for multiprocessing"
709712
)
710713

714+
if isinstance(executable, pathlib.Path):
715+
if not executable.is_file():
716+
raise FileNotFoundError(
717+
f"Executable '{executable}' is not a valid file"
718+
)
719+
720+
executable_str = f"{executable}"
721+
711722
_cmd_list: typing.List[str] = []
712723
_pos_args = list(cmd_args)
713724

714725
# Assemble the command for saving to metadata as string
715726
if executable:
716-
_cmd_list += [executable]
727+
_cmd_list += [executable_str]
717728
else:
718729
_cmd_list += [_pos_args[0]]
719730
executable = _pos_args[0]
@@ -742,10 +753,10 @@ def callback_function(status_code: int, std_out: str, std_err: str) -> None:
742753
self._executor.add_process(
743754
identifier,
744755
*_pos_args,
745-
executable=executable,
756+
executable=executable_str,
746757
script=script,
747758
input_file=input_file,
748-
completion_callback=completion_callback,
759+
completion_callback=completion_callback, # type: ignore
749760
completion_trigger=completion_trigger,
750761
env=env,
751762
**cmd_kwargs,
@@ -1368,14 +1379,14 @@ def set_status(
13681379
if self._mode == "disabled":
13691380
return True
13701381

1371-
if not self._active:
1382+
if not self._active or not self._name:
13721383
self._error("Run is not active")
13731384
return False
13741385

13751386
data: dict[str, str] = {"name": self._name, "status": status}
13761387
self._status = status
13771388

1378-
if self._simvue.update(data):
1389+
if self._simvue and self._simvue.update(data):
13791390
return True
13801391

13811392
return False

simvue/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
except ImportError:
66
from typing_extensions import TypeAlias
77

8+
89
if typing.TYPE_CHECKING:
910
from numpy import ndarray
1011
from pandas import DataFrame

tests/refactor/test_run_class.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
import pytest
2+
import pytest_mock
23
import time
34
import typing
45
import contextlib
56
import inspect
67
import tempfile
8+
import threading
79
import uuid
10+
import psutil
811
import pathlib
912
import concurrent.futures
1013
import random
@@ -461,22 +464,37 @@ def test_save_object(
461464

462465

463466
@pytest.mark.run
464-
def test_abort_on_alert_process(create_plain_run: typing.Tuple[sv_run.Run, dict]) -> None:
467+
def test_abort_on_alert_process(create_plain_run: typing.Tuple[sv_run.Run, dict], mocker: pytest_mock.MockerFixture) -> None:
468+
def testing_exit(status: int) -> None:
469+
raise SystemExit(status)
470+
mocker.patch("os._exit", testing_exit)
471+
N_PROCESSES: int = 3
465472
run, _ = create_plain_run
466473
run.config(resources_metrics_interval=1)
467474
run._heartbeat_interval = 1
468-
run.add_process(identifier="forever_long", executable="bash", c="sleep 10000")
475+
run._testing = True
476+
run.add_process(identifier="forever_long", executable="bash", c="&".join(["sleep 10"] * N_PROCESSES))
477+
process_id = list(run._executor._processes.values())[0].pid
478+
process = psutil.Process(process_id)
479+
assert len(child_processes := process.children(recursive=True)) == 3
469480
time.sleep(2)
470481
client = sv_cl.Client()
471482
client.abort_run(run._id, reason="testing abort")
472483
time.sleep(4)
484+
for child in child_processes:
485+
assert not child.is_running()
473486
if not run._status == "terminated":
474487
run.kill_all_processes()
475488
raise AssertionError("Run was not terminated")
476489

477490

478491
@pytest.mark.run
479-
def test_abort_on_alert_python(create_plain_run: typing.Tuple[sv_run.Run, dict]) -> None:
492+
def test_abort_on_alert_python(create_plain_run: typing.Tuple[sv_run.Run, dict], mocker: pytest_mock.MockerFixture) -> None:
493+
abort_set = threading.Event()
494+
def testing_exit(status: int) -> None:
495+
abort_set.set()
496+
raise SystemExit(status)
497+
mocker.patch("os._exit", testing_exit)
480498
run, _ = create_plain_run
481499
run.config(resources_metrics_interval=1)
482500
run._heartbeat_interval = 1
@@ -488,10 +506,10 @@ def test_abort_on_alert_python(create_plain_run: typing.Tuple[sv_run.Run, dict])
488506
if i == 4:
489507
client.abort_run(run._id, reason="testing abort")
490508
i += 1
491-
if i == 10:
509+
if abort_set.is_set() or i > 9:
492510
break
493511

494-
assert i == 4
512+
assert i < 7
495513
assert run._status == "terminated"
496514

497515

0 commit comments

Comments
 (0)