Skip to content

Add auto-abort feature to client #399

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Jun 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion simvue/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def set_json_header(headers: dict[str, str]) -> dict[str, str]:
reraise=True,
)
def post(
url: str, headers: dict[str, str], data: dict[str, typing.Any], is_json: bool = True
url: str, headers: dict[str, str], data: typing.Any, is_json: bool = True
) -> requests.Response:
"""HTTP POST with retries

Expand Down
39 changes: 39 additions & 0 deletions simvue/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,6 +603,45 @@ def _retrieve_artifact_from_server(

return json_response

@prettify_pydantic
@pydantic.validate_call
def abort_run(self, run_id: str, reason: str) -> typing.Union[dict, list]:
    """Abort a currently active run on the server

    Parameters
    ----------
    run_id : str
        the unique identifier for the run
    reason : str
        reason for abort

    Returns
    -------
    dict | list
        response from server

    Raises
    ------
    RuntimeError
        if the server response was not a JSON dictionary
    """
    body: dict[str, str | None] = {"id": run_id, "reason": reason}

    response = requests.put(
        f"{self._url}/api/runs/abort",
        headers=self._headers,
        json=body,
    )

    # 400 is accepted so the server's own error payload can be surfaced
    json_response = self._get_json_from_response(
        expected_status=[200, 400],
        scenario=f"Abort of run '{run_id}'",
        response=response,
    )

    # Fix: previous message was copy-pasted from artifact retrieval and
    # claimed a 'list' was expected during 'retrieval of artifact'
    if not isinstance(json_response, dict):
        raise RuntimeError(
            "Expected dictionary from JSON response during abort of run "
            f"but got '{type(json_response)}'"
        )

    return json_response

@prettify_pydantic
@pydantic.validate_call
def get_artifact(
Expand Down
126 changes: 67 additions & 59 deletions simvue/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import sys
import multiprocessing
import os
import psutil
import subprocess
import pathlib
import time
Expand All @@ -26,17 +27,16 @@
logger = logging.getLogger(__name__)


class CompletionCallback(typing.Protocol):
    """Structural type for process-completion callbacks.

    Implementations are invoked with keyword-only arguments carrying the
    process exit code and the captured stdout/stderr text.
    """

    def __call__(self, *, status_code: int, std_out: str, std_err: str) -> None: ...


def _execute_process(
proc_id: str,
command: typing.List[str],
runner_name: str,
exit_status_dict: typing.Dict[str, int],
std_err: typing.Dict[str, str],
std_out: typing.Dict[str, str],
run_on_exit: typing.Optional[typing.Callable[[int, int, str], None]],
trigger: typing.Optional[multiprocessing.synchronize.Event],
environment: typing.Optional[typing.Dict[str, str]],
) -> None:
) -> subprocess.Popen:
with open(f"{runner_name}_{proc_id}.err", "w") as err:
with open(f"{runner_name}_{proc_id}.out", "w") as out:
_result = subprocess.Popen(
Expand All @@ -47,24 +47,7 @@ def _execute_process(
env=environment,
)

_status_code = _result.wait()
with open(f"{runner_name}_{proc_id}.err") as err:
std_err[proc_id] = err.read()

with open(f"{runner_name}_{proc_id}.out") as out:
std_out[proc_id] = out.read()

exit_status_dict[proc_id] = _status_code

if run_on_exit:
run_on_exit(
status_code=exit_status_dict[proc_id],
std_out=std_out[proc_id],
std_err=std_err[proc_id],
)

if trigger:
trigger.set()
return _result


class Executor:
Expand All @@ -88,13 +71,16 @@ def __init__(self, simvue_runner: "simvue.Run", keep_logs: bool = True) -> None:
"""
self._runner = simvue_runner
self._keep_logs = keep_logs
self._manager = multiprocessing.Manager()
self._exit_codes = self._manager.dict()
self._std_err = self._manager.dict()
self._std_out = self._manager.dict()
self._completion_callbacks: dict[str, typing.Optional[CompletionCallback]] = {}
self._completion_triggers: dict[
str, typing.Optional[multiprocessing.synchronize.Event]
] = {}
self._exit_codes: dict[str, int] = {}
self._std_err: dict[str, str] = {}
self._std_out: dict[str, str] = {}
self._alert_ids: dict[str, str] = {}
self._command_str: typing.Dict[str, str] = {}
self._processes: typing.Dict[str, multiprocessing.Process] = {}
self._command_str: dict[str, str] = {}
self._processes: dict[str, subprocess.Popen] = {}

def add_process(
self,
Expand All @@ -104,9 +90,7 @@ def add_process(
script: typing.Optional[pathlib.Path] = None,
input_file: typing.Optional[pathlib.Path] = None,
env: typing.Optional[typing.Dict[str, str]] = None,
completion_callback: typing.Optional[
typing.Callable[[int, str, str], None]
] = None,
completion_callback: typing.Optional[CompletionCallback] = None,
completion_trigger: typing.Optional[multiprocessing.synchronize.Event] = None,
**kwargs,
) -> None:
Expand Down Expand Up @@ -161,6 +145,9 @@ def callback_function(status_code: int, std_out: str, std_err: str) -> None:
"""
_pos_args = list(args)

if not self._runner.name:
raise RuntimeError("Cannot add process, expected Run instance to have name")

if sys.platform == "win32" and completion_callback:
logger.warning(
"Completion callback for 'add_process' may fail on Windows due to "
Expand Down Expand Up @@ -207,26 +194,16 @@ def callback_function(status_code: int, std_out: str, std_err: str) -> None:
_command += _pos_args

self._command_str[identifier] = " ".join(_command)
self._completion_callbacks[identifier] = completion_callback
self._completion_triggers[identifier] = completion_trigger

self._processes[identifier] = multiprocessing.Process(
target=_execute_process,
args=(
identifier,
_command,
self._runner.name,
self._exit_codes,
self._std_err,
self._std_out,
completion_callback,
completion_trigger,
env,
),
self._processes[identifier] = _execute_process(
identifier, _command, self._runner.name, env
)

self._alert_ids[identifier] = self._runner.create_alert(
name=f"{identifier}_exit_status", source="user"
)
logger.debug(f"Executing process: {' '.join(_command)}")
self._processes[identifier].start()

@property
def success(self) -> int:
Expand Down Expand Up @@ -272,8 +249,8 @@ def _get_error_status(self, process_id: str) -> typing.Optional[str]:
err_msg: typing.Optional[str] = None

# Return last 10 lines of stdout if stderr empty
if not (err_msg := self._std_err[process_id]) and (
std_out := self._std_out[process_id]
if not (err_msg := self._std_err.get(process_id)) and (
std_out := self._std_out.get(process_id)
):
err_msg = " Tail STDOUT:\n\n"
start_index = -10 if len(lines := std_out.split("\n")) > 10 else 0
Expand Down Expand Up @@ -308,28 +285,42 @@ def _save_output(self) -> None:
"""Save the output to Simvue"""
for proc_id in self._exit_codes.keys():
# Only save the file if the contents are not empty
if self._std_err[proc_id]:
if self._std_err.get(proc_id):
self._runner.save_file(
f"{self._runner.name}_{proc_id}.err", category="output"
)
if self._std_out[proc_id]:
if self._std_out.get(proc_id):
self._runner.save_file(
f"{self._runner.name}_{proc_id}.out", category="output"
)

def kill_process(self, process_id: str) -> None:
"""Kill a running process by ID"""
if not (_process := self._processes.get(process_id)):
if not (process := self._processes.get(process_id)):
logger.error(
f"Failed to terminate process '{process_id}', no such identifier."
)
return
_process.kill()

parent = psutil.Process(process.pid)

for child in parent.children(recursive=True):
logger.debug(f"Terminating child process {child.pid}: {child.name()}")
child.kill()

for child in parent.children(recursive=True):
child.wait()

logger.debug(f"Terminating child process {process.pid}: {process.args}")
process.kill()
process.wait()

self._execute_callback(process_id)

def kill_all(self) -> None:
"""Kill all running processes"""
for process in self._processes.values():
process.kill()
for process in self._processes.keys():
self.kill_process(process)

def _clear_cache_files(self) -> None:
"""Clear local log files if required"""
Expand All @@ -338,11 +329,28 @@ def _clear_cache_files(self) -> None:
os.remove(f"{self._runner.name}_{proc_id}.err")
os.remove(f"{self._runner.name}_{proc_id}.out")

def _execute_callback(self, identifier: str) -> None:
    """Fire the completion callback and trigger for a finished process.

    If a completion callback was registered via 'add_process' it is called
    with the process exit code and the contents of the process's stderr and
    stdout log files. Any registered completion trigger event is then set.

    Parameters
    ----------
    identifier : str
        identifier of the process to finalize
    """
    if callback := self._completion_callbacks.get(identifier):
        # Only read the log files when a callback will actually consume
        # them, avoiding unnecessary I/O for processes without callbacks
        with open(f"{self._runner.name}_{identifier}.err") as err:
            std_err = err.read()

        with open(f"{self._runner.name}_{identifier}.out") as out:
            std_out = out.read()

        callback(
            status_code=self._processes[identifier].returncode,
            std_out=std_out,
            std_err=std_err,
        )

    if completion_trigger := self._completion_triggers.get(identifier):
        completion_trigger.set()

def wait_for_completion(self) -> None:
"""Wait for all processes to finish then perform tidy up and upload"""
for process in self._processes.values():
if process.is_alive():
process.join()
for identifier, process in self._processes.items():
process.wait()
self._execute_callback(identifier)

self._update_alerts()
self._save_output()

Expand Down
4 changes: 4 additions & 0 deletions simvue/factory/proxy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,7 @@ def send_heartbeat(self) -> typing.Optional[dict[str, typing.Any]]:
@abc.abstractmethod
def check_token(self) -> bool:
    """Return True if the stored API token is valid (e.g. not expired)."""
    pass

@abc.abstractmethod
def get_abort_status(self) -> bool:
    """Return the abort status for the current run.

    See concrete proxy implementations for the exact semantics of the
    returned flag.
    """
    pass
6 changes: 6 additions & 0 deletions simvue/factory/proxy/offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,15 @@ def set_alert_state(

@skip_if_failed("_aborted", "_suppress_errors", [])
def list_tags(self) -> list[dict[str, typing.Any]]:
    """Tag listing is unsupported in offline mode.

    Raises
    ------
    NotImplementedError
        always, since there is no local tag store to query
    """
    # TODO: Tag retrieval not implemented for offline running
    _message = "Retrieval of current tags is not implemented for offline running"
    raise NotImplementedError(_message)

@skip_if_failed("_aborted", "_suppress_errors", True)
def get_abort_status(self) -> bool:
    """Offline stand-in for the server abort-status check.

    NOTE(review): always returns True; presumably True here means the run
    may continue rather than that an abort was requested — confirm against
    callers before relying on this.
    """
    # TODO: Abort on failure not implemented for offline running
    return True

@skip_if_failed("_aborted", "_suppress_errors", [])
def list_alerts(self) -> list[dict[str, typing.Any]]:
Expand Down
23 changes: 23 additions & 0 deletions simvue/factory/proxy/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,3 +466,26 @@ def check_token(self) -> bool:
self._error("Token has expired")
return False
return True

@skip_if_failed("_aborted", "_suppress_errors", False)
def get_abort_status(self) -> bool:
    """Query the server for this run's abort status.

    Returns
    -------
    bool
        the 'status' flag from the abort endpoint response, or False if
        the request failed or the response was malformed
    """
    # Fix: previous message said "alert status" (copy-paste error)
    logger.debug("Retrieving abort status")

    try:
        response = get(
            f"{self._url}/api/runs/{self._id}/abort", self._headers_mp
        )
    except Exception as err:
        self._error(f"Exception retrieving abort status: {str(err)}")
        return False

    logger.debug("Got status code %d when checking abort status", response.status_code)

    if response.status_code == 200:
        # A 200 response is expected to carry a JSON body with a 'status' key
        if (status := response.json().get("status")) is None:
            self._error(f"Expected key 'status' when retrieving abort status {response.json()}")
            return False
        return status

    self._error(f"Got status code {response.status_code} when checking abort status")
    return False
Loading
Loading