diff --git a/simvue/api.py b/simvue/api.py index 626961df..022aaa95 100644 --- a/simvue/api.py +++ b/simvue/api.py @@ -44,7 +44,7 @@ def set_json_header(headers: dict[str, str]) -> dict[str, str]: reraise=True, ) def post( - url: str, headers: dict[str, str], data: dict[str, typing.Any], is_json: bool = True + url: str, headers: dict[str, str], data: typing.Any, is_json: bool = True ) -> requests.Response: """HTTP POST with retries diff --git a/simvue/client.py b/simvue/client.py index bebb7b7a..cf1f0be3 100644 --- a/simvue/client.py +++ b/simvue/client.py @@ -603,6 +603,45 @@ def _retrieve_artifact_from_server( return json_response + @prettify_pydantic + @pydantic.validate_call + def abort_run(self, run_id: str, reason: str) -> typing.Union[dict, list]: + """Abort a currently active run on the server + + Parameters + ---------- + run_id : str + the unique identifier for the run + reason : str + reason for abort + + Returns + ------- + dict | list + response from server + """ + body: dict[str, str | None] = {"id": run_id, "reason": reason} + + response = requests.put( + f"{self._url}/api/runs/abort", + headers=self._headers, + json=body, + ) + + json_response = self._get_json_from_response( + expected_status=[200, 400], + scenario=f"Abort of run '{run_id}'", + response=response, + ) + + if not isinstance(json_response, dict): + raise RuntimeError( + "Expected list from JSON response during retrieval of " + f"artifact but got '{type(json_response)}'" + ) + + return json_response + @prettify_pydantic @pydantic.validate_call def get_artifact( diff --git a/simvue/executor.py b/simvue/executor.py index 4deab048..b9682454 100644 --- a/simvue/executor.py +++ b/simvue/executor.py @@ -15,6 +15,7 @@ import sys import multiprocessing import os +import psutil import subprocess import pathlib import time @@ -26,17 +27,16 @@ logger = logging.getLogger(__name__) +class CompletionCallback(typing.Protocol): + def __call__(self, *, status_code: int, std_out: str, std_err: str) -> None: ... + + def _execute_process( proc_id: str, command: typing.List[str], runner_name: str, - exit_status_dict: typing.Dict[str, int], - std_err: typing.Dict[str, str], - std_out: typing.Dict[str, str], - run_on_exit: typing.Optional[typing.Callable[[int, int, str], None]], - trigger: typing.Optional[multiprocessing.synchronize.Event], environment: typing.Optional[typing.Dict[str, str]], -) -> None: +) -> subprocess.Popen: with open(f"{runner_name}_{proc_id}.err", "w") as err: with open(f"{runner_name}_{proc_id}.out", "w") as out: _result = subprocess.Popen( @@ -47,24 +47,7 @@ def _execute_process( env=environment, ) - _status_code = _result.wait() - with open(f"{runner_name}_{proc_id}.err") as err: - std_err[proc_id] = err.read() - - with open(f"{runner_name}_{proc_id}.out") as out: - std_out[proc_id] = out.read() - - exit_status_dict[proc_id] = _status_code - - if run_on_exit: - run_on_exit( - status_code=exit_status_dict[proc_id], - std_out=std_out[proc_id], - std_err=std_err[proc_id], - ) - - if trigger: - trigger.set() + return _result class Executor: @@ -88,13 +71,16 @@ def __init__(self, simvue_runner: "simvue.Run", keep_logs: bool = True) -> None: """ self._runner = simvue_runner self._keep_logs = keep_logs - self._manager = multiprocessing.Manager() - self._exit_codes = self._manager.dict() - self._std_err = self._manager.dict() - self._std_out = self._manager.dict() + self._completion_callbacks: dict[str, typing.Optional[CompletionCallback]] = {} + self._completion_triggers: dict[ + str, typing.Optional[multiprocessing.synchronize.Event] + ] = {} + self._exit_codes: dict[str, int] = {} + self._std_err: dict[str, str] = {} + self._std_out: dict[str, str] = {} self._alert_ids: dict[str, str] = {} - self._command_str: typing.Dict[str, str] = {} - self._processes: typing.Dict[str, multiprocessing.Process] = {} + self._command_str: dict[str, str] = {} + self._processes: dict[str, subprocess.Popen] = {} def add_process( self, @@ -104,9 +90,7 @@ def add_process( script: typing.Optional[pathlib.Path] = None, input_file: typing.Optional[pathlib.Path] = None, env: typing.Optional[typing.Dict[str, str]] = None, - completion_callback: typing.Optional[ - typing.Callable[[int, str, str], None] - ] = None, + completion_callback: typing.Optional[CompletionCallback] = None, completion_trigger: typing.Optional[multiprocessing.synchronize.Event] = None, **kwargs, ) -> None: @@ -161,6 +145,9 @@ def callback_function(status_code: int, std_out: str, std_err: str) -> None: """ _pos_args = list(args) + if not self._runner.name: + raise RuntimeError("Cannot add process, expected Run instance to have name") + if sys.platform == "win32" and completion_callback: logger.warning( "Completion callback for 'add_process' may fail on Windows due to " @@ -207,26 +194,16 @@ def callback_function(status_code: int, std_out: str, std_err: str) -> None: _command += _pos_args self._command_str[identifier] = " ".join(_command) + self._completion_callbacks[identifier] = completion_callback + self._completion_triggers[identifier] = completion_trigger - self._processes[identifier] = multiprocessing.Process( - target=_execute_process, - args=( - identifier, - _command, - self._runner.name, - self._exit_codes, - self._std_err, - self._std_out, - completion_callback, - completion_trigger, - env, - ), + self._processes[identifier] = _execute_process( + identifier, _command, self._runner.name, env ) + self._alert_ids[identifier] = self._runner.create_alert( name=f"{identifier}_exit_status", source="user" ) - logger.debug(f"Executing process: {' '.join(_command)}") - self._processes[identifier].start() @property def success(self) -> int: @@ -272,8 +249,8 @@ def _get_error_status(self, process_id: str) -> typing.Optional[str]: err_msg: typing.Optional[str] = None # Return last 10 lines of stdout if stderr empty - if not (err_msg := self._std_err[process_id]) and ( - std_out := self._std_out[process_id] + if not (err_msg := self._std_err.get(process_id)) and ( + std_out := self._std_out.get(process_id) ): err_msg = " Tail STDOUT:\n\n" start_index = -10 if len(lines := std_out.split("\n")) > 10 else 0 @@ -308,28 +285,42 @@ def _save_output(self) -> None: """Save the output to Simvue""" for proc_id in self._exit_codes.keys(): # Only save the file if the contents are not empty - if self._std_err[proc_id]: + if self._std_err.get(proc_id): self._runner.save_file( f"{self._runner.name}_{proc_id}.err", category="output" ) - if self._std_out[proc_id]: + if self._std_out.get(proc_id): self._runner.save_file( f"{self._runner.name}_{proc_id}.out", category="output" ) def kill_process(self, process_id: str) -> None: """Kill a running process by ID""" - if not (_process := self._processes.get(process_id)): + if not (process := self._processes.get(process_id)): logger.error( f"Failed to terminate process '{process_id}', no such identifier." ) return - _process.kill() + + parent = psutil.Process(process.pid) + + for child in parent.children(recursive=True): + logger.debug(f"Terminating child process {child.pid}: {child.name()}") + child.kill() + + for child in parent.children(recursive=True): + child.wait() + + logger.debug(f"Terminating child process {process.pid}: {process.args}") + process.kill() + process.wait() + + self._execute_callback(process_id) def kill_all(self) -> None: """Kill all running processes""" - for process in self._processes.values(): - process.kill() + for process in self._processes.keys(): + self.kill_process(process) def _clear_cache_files(self) -> None: """Clear local log files if required""" @@ -338,11 +329,28 @@ def _clear_cache_files(self) -> None: os.remove(f"{self._runner.name}_{proc_id}.err") os.remove(f"{self._runner.name}_{proc_id}.out") + def _execute_callback(self, identifier: str) -> None: + with open(f"{self._runner.name}_{identifier}.err") as err: + std_err = err.read() + + with open(f"{self._runner.name}_{identifier}.out") as out: + std_out = out.read() + + if callback := self._completion_callbacks.get(identifier): + callback( + status_code=self._processes[identifier].returncode, + std_out=std_out, + std_err=std_err, + ) + if completion_trigger := self._completion_triggers.get(identifier): + completion_trigger.set() + def wait_for_completion(self) -> None: """Wait for all processes to finish then perform tidy up and upload""" - for process in self._processes.values(): - if process.is_alive(): - process.join() + for identifier, process in self._processes.items(): + process.wait() + self._execute_callback(identifier) + self._update_alerts() self._save_output() diff --git a/simvue/factory/proxy/base.py b/simvue/factory/proxy/base.py index 6f95efa9..1ca9060c 100644 --- a/simvue/factory/proxy/base.py +++ b/simvue/factory/proxy/base.py @@ -89,3 +89,7 @@ def send_heartbeat(self) -> typing.Optional[dict[str, typing.Any]]: @abc.abstractmethod def check_token(self) -> bool: pass + + @abc.abstractmethod + def get_abort_status(self) -> bool: + pass diff --git a/simvue/factory/proxy/offline.py b/simvue/factory/proxy/offline.py index 4bbd736b..f716a5b0 100644 --- a/simvue/factory/proxy/offline.py +++ b/simvue/factory/proxy/offline.py @@ -176,9 +176,15 @@ def set_alert_state( @skip_if_failed("_aborted", "_suppress_errors", []) def list_tags(self) -> list[dict[str, typing.Any]]: + #TODO: Tag retrieval not implemented for offline running raise NotImplementedError( "Retrieval of current tags is not implemented for offline running" ) + + @skip_if_failed("_aborted", "_suppress_errors", True) + def get_abort_status(self) -> bool: + #TODO: Abort on failure not implemented for offline running + return True @skip_if_failed("_aborted", "_suppress_errors", []) def list_alerts(self) -> list[dict[str, typing.Any]]: diff --git a/simvue/factory/proxy/remote.py b/simvue/factory/proxy/remote.py index 888d6054..8d9d45cc 100644 --- a/simvue/factory/proxy/remote.py +++ b/simvue/factory/proxy/remote.py @@ -466,3 +466,26 @@ def check_token(self) -> bool: self._error("Token has expired") return False return True + + @skip_if_failed("_aborted", "_suppress_errors", False) + def get_abort_status(self) -> bool: + logger.debug("Retrieving alert status") + + try: + response = get( + f"{self._url}/api/runs/{self._id}/abort", self._headers_mp + ) + except Exception as err: + self._error(f"Exception retrieving abort status: {str(err)}") + return False + + logger.debug("Got status code %d when checking abort status", response.status_code) + + if response.status_code == 200: + if (status := response.json().get("status")) is None: + self._error(f"Expected key 'status' when retrieving abort status {response.json()}") + return False + return status + + self._error(f"Got status code {response.status_code} when checking abort status") + return False diff --git a/simvue/run.py b/simvue/run.py index 8b2e0a48..5d835c3b 100644 --- a/simvue/run.py +++ b/simvue/run.py @@ -10,6 +10,7 @@ import datetime import json import logging +import pathlib import mimetypes import multiprocessing.synchronize import threading @@ -50,6 +51,7 @@ validate_timestamp, ) + if typing.TYPE_CHECKING: from .factory.proxy import SimvueBaseClass from .factory.dispatch import DispatcherBaseClass @@ -101,6 +103,8 @@ def __init__( self._uuid: str = f"{uuid.uuid4()}" self._mode: typing.Literal["online", "offline", "disabled"] = mode self._name: typing.Optional[str] = None + self._testing: bool = False + self._abort_on_alert: bool = True self._dispatch_mode: typing.Literal["direct", "queued"] = "queued" self._executor = Executor(self) self._dispatcher: typing.Optional[DispatcherBaseClass] = None @@ -127,6 +131,7 @@ def __init__( self._heartbeat_termination_trigger: typing.Optional[threading.Event] = None self._storage_id: typing.Optional[str] = None self._heartbeat_thread: typing.Optional[threading.Thread] = None + self._heartbeat_interval: int = HEARTBEAT_INTERVAL def __enter__(self) -> "Run": return self @@ -139,9 +144,6 @@ def __exit__( typing.Union[typing.Type[BaseException], BaseException] ], ) -> None: - # Wait for the executor to finish with currently running processes - self._executor.wait_for_completion() - identifier = self._id logger.debug( "Automatically closing run '%s' in status %s", @@ -149,6 +151,8 @@ def __exit__( self._status, ) + self._executor.wait_for_completion() + # Stop the run heartbeat if self._heartbeat_thread and self._heartbeat_termination_trigger: self._heartbeat_termination_trigger.set() @@ -245,7 +249,7 @@ def _get_sysinfo(self) -> dict[str, typing.Any]: def _create_heartbeat_callback( self, - ) -> typing.Callable[[str, dict, str, bool], None]: + ) -> typing.Callable[[threading.Event], None]: if ( self._mode == "online" and (not self._url or not self._id) ) or not self._heartbeat_termination_trigger: @@ -274,11 +278,32 @@ def _heartbeat( ) last_res_metric_call = res_time - if time.time() - last_heartbeat < HEARTBEAT_INTERVAL: + if time.time() - last_heartbeat < self._heartbeat_interval: continue last_heartbeat = time.time() + # Check if the user has aborted the run + with self._configuration_lock: + if ( + self._simvue + and self._abort_on_alert + and self._simvue.get_abort_status() + ): + self._alert_raised_trigger.set() + self.kill_all_processes() + if self._dispatcher and self._shutdown_event: + self._shutdown_event.set() + self._dispatcher.purge() + self._dispatcher.join() + self.set_status("terminated") + click.secho( + "[simvue] Run was aborted.", + fg="red" if self._term_color else None, + bold=self._term_color, + ) + os._exit(1) + if self._simvue: self._simvue.send_heartbeat() @@ -332,7 +357,7 @@ def _online_dispatch_callback( buffer: list[typing.Any], category: str, url: str = self._url, - run_id: str = self._id, + run_id: typing.Optional[str] = self._id, headers: dict[str, str] = self._headers, ) -> None: if not buffer: @@ -394,6 +419,7 @@ def _start(self, reconnect: bool = False) -> bool: self._shutdown_event = threading.Event() self._heartbeat_termination_trigger = threading.Event() + self._alert_raised_trigger = threading.Event() try: self._dispatcher = Dispatcher( @@ -475,7 +501,6 @@ def init( folder: typing.Annotated[str, pydantic.Field(pattern=FOLDER_REGEX)] = "/", running: bool = True, retention_period: typing.Optional[str] = None, - resources_metrics_interval: typing.Optional[int] = HEARTBEAT_INTERVAL, visibility: typing.Union[ typing.Literal["public", "tenant"], list[str], None ] = None, @@ -502,8 +527,6 @@ def init( retention_period : str, optional describer for time period to retain run, the default of None removes this constraint. - resources_metrics_interval : int, optional - how often to publish resource metrics, if None these will not be published visibility : Literal['public', 'tenant'] | list[str], optional set visibility options for this run, either: * public - run viewable to all. @@ -541,8 +564,6 @@ def init( self._error("specified name is invalid") return False - self._resources_metrics_interval = resources_metrics_interval - self._name = name self._status = "running" if running else "created" @@ -618,7 +639,7 @@ def add_process( self, identifier: str, *cmd_args, - executable: typing.Optional[typing.Union[str]] = None, + executable: typing.Optional[typing.Union[str, pathlib.Path]] = None, script: typing.Optional[pydantic.FilePath] = None, input_file: typing.Optional[pydantic.FilePath] = None, completion_callback: typing.Optional[ @@ -690,12 +711,20 @@ def callback_function(status_code: int, std_out: str, std_err: str) -> None: "due to function pickling restrictions for multiprocessing" ) + if isinstance(executable, pathlib.Path): + if not executable.is_file(): + raise FileNotFoundError( + f"Executable '{executable}' is not a valid file" + ) + + executable_str = f"{executable}" + _cmd_list: typing.List[str] = [] _pos_args = list(cmd_args) # Assemble the command for saving to metadata as string if executable: - _cmd_list += [executable] + _cmd_list += [executable_str] else: _cmd_list += [_pos_args[0]] executable = _pos_args[0] @@ -724,10 +753,10 @@ def callback_function(status_code: int, std_out: str, std_err: str) -> None: self._executor.add_process( identifier, *_pos_args, - executable=executable, + executable=executable_str, script=script, input_file=input_file, - completion_callback=completion_callback, + completion_callback=completion_callback, # type: ignore completion_trigger=completion_trigger, env=env, **cmd_kwargs, @@ -816,6 +845,7 @@ def config( resources_metrics_interval: typing.Optional[int] = None, disable_resources_metrics: typing.Optional[bool] = None, storage_id: typing.Optional[str] = None, + abort_on_alert: typing.Optional[bool] = None, ) -> bool: """Optional configuration @@ -832,6 +862,8 @@ def config( disable monitoring of resource metrics storage_id : str, optional identifier of storage to use, by default None + abort_on_alert : bool, optional + whether to abort the run if an alert is triggered Returns ------- @@ -859,6 +891,9 @@ def config( if resources_metrics_interval: self._resources_metrics_interval = resources_metrics_interval + if abort_on_alert: + self._abort_on_alert = abort_on_alert + if storage_id: self._storage_id = storage_id @@ -1344,14 +1379,14 @@ def set_status( if self._mode == "disabled": return True - if not self._active: + if not self._active or not self._name: self._error("Run is not active") return False data: dict[str, str] = {"name": self._name, "status": status} self._status = status - if self._simvue.update(data): + if self._simvue and self._simvue.update(data): return True return False diff --git a/simvue/types.py b/simvue/types.py index ec194d73..95b3c46c 100644 --- a/simvue/types.py +++ b/simvue/types.py @@ -5,6 +5,7 @@ except ImportError: from typing_extensions import TypeAlias + if typing.TYPE_CHECKING: from numpy import ndarray from pandas import DataFrame diff --git a/tests/refactor/test_run_class.py b/tests/refactor/test_run_class.py index d7bc0ca9..16fad8ad 100644 --- a/tests/refactor/test_run_class.py +++ b/tests/refactor/test_run_class.py @@ -1,14 +1,16 @@ import pytest +import pytest_mock import time import typing import contextlib import inspect import tempfile +import threading import uuid +import psutil import pathlib import concurrent.futures import random -import inspect import simvue.run as sv_run import simvue.client as sv_cl @@ -459,3 +461,71 @@ def test_save_object( pytest.skip("Numpy is not installed") save_obj = array([1, 2, 3, 4]) simvue_run.save_object(save_obj, "input", f"test_object_{object_type}") + + +@pytest.mark.run +def test_abort_on_alert_process(create_plain_run: typing.Tuple[sv_run.Run, dict], mocker: pytest_mock.MockerFixture) -> None: + def testing_exit(status: int) -> None: + raise SystemExit(status) + mocker.patch("os._exit", testing_exit) + N_PROCESSES: int = 3 + run, _ = create_plain_run + run.config(resources_metrics_interval=1) + run._heartbeat_interval = 1 + run._testing = True + run.add_process(identifier="forever_long", executable="bash", c="&".join(["sleep 10"] * N_PROCESSES)) + process_id = list(run._executor._processes.values())[0].pid + process = psutil.Process(process_id) + assert len(child_processes := process.children(recursive=True)) == 3 + time.sleep(2) + client = sv_cl.Client() + client.abort_run(run._id, reason="testing abort") + time.sleep(4) + for child in child_processes: + assert not child.is_running() + if not run._status == "terminated": + run.kill_all_processes() + raise AssertionError("Run was not terminated") + + +@pytest.mark.run +def test_abort_on_alert_python(create_plain_run: typing.Tuple[sv_run.Run, dict], mocker: pytest_mock.MockerFixture) -> None: + abort_set = threading.Event() + def testing_exit(status: int) -> None: + abort_set.set() + raise SystemExit(status) + mocker.patch("os._exit", testing_exit) + run, _ = create_plain_run + run.config(resources_metrics_interval=1) + run._heartbeat_interval = 1 + client = sv_cl.Client() + i = 0 + + while True: + time.sleep(1) + if i == 4: + client.abort_run(run._id, reason="testing abort") + i += 1 + if abort_set.is_set() or i > 9: + break + + assert i < 7 + assert run._status == "terminated" + + +@pytest.mark.run +def test_kill_all_processes(create_plain_run: typing.Tuple[sv_run.Run, dict]) -> None: + run, _ = create_plain_run + run.config(resources_metrics_interval=1) + run.add_process(identifier="forever_long_1", executable="bash", c="sleep 10000") + run.add_process(identifier="forever_long_2", executable="bash", c="sleep 10000") + processes = [ + psutil.Process(process.pid) + for process in run._executor._processes.values() + ] + time.sleep(2) + run.kill_all_processes() + time.sleep(4) + for process in processes: + assert not process.is_running() + assert all(not child.is_running() for child in process.children(recursive=True))