Skip to content

Commit df0c802

Browse files
authored
Merge pull request #399 from simvue-io/feature/auto-abort
Add auto-abort feature to client
2 parents 659df9b + ef5f2fe commit df0c802

File tree

9 files changed

+264
-78
lines changed

9 files changed

+264
-78
lines changed

simvue/api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def set_json_header(headers: dict[str, str]) -> dict[str, str]:
4444
reraise=True,
4545
)
4646
def post(
47-
url: str, headers: dict[str, str], data: dict[str, typing.Any], is_json: bool = True
47+
url: str, headers: dict[str, str], data: typing.Any, is_json: bool = True
4848
) -> requests.Response:
4949
"""HTTP POST with retries
5050

simvue/client.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -603,6 +603,45 @@ def _retrieve_artifact_from_server(
603603

604604
return json_response
605605

606+
@prettify_pydantic
@pydantic.validate_call
def abort_run(self, run_id: str, reason: str) -> typing.Union[dict, list]:
    """Abort a currently active run on the server

    Sends a PUT request to the ``/api/runs/abort`` endpoint with the run
    identifier and the reason for the abort, then decodes the JSON reply.

    Parameters
    ----------
    run_id : str
        the unique identifier for the run
    reason : str
        reason for abort

    Returns
    -------
    dict | list
        response from server (a non-dict reply is rejected, see Raises)

    Raises
    ------
    RuntimeError
        if the decoded server response is not a dictionary
    """
    body: dict[str, str] = {"id": run_id, "reason": reason}

    response = requests.put(
        f"{self._url}/api/runs/abort",
        headers=self._headers,
        json=body,
    )

    # Both 200 and 400 are passed to the JSON decoder as expected statuses
    json_response = self._get_json_from_response(
        expected_status=[200, 400],
        scenario=f"Abort of run '{run_id}'",
        response=response,
    )

    if not isinstance(json_response, dict):
        # Message previously referred to a list and to artifact retrieval,
        # neither of which matches the check being performed here
        raise RuntimeError(
            "Expected dict from JSON response during abort of "
            f"run '{run_id}' but got '{type(json_response)}'"
        )

    return json_response
644+
606645
@prettify_pydantic
607646
@pydantic.validate_call
608647
def get_artifact(

simvue/executor.py

Lines changed: 67 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import sys
1616
import multiprocessing
1717
import os
18+
import psutil
1819
import subprocess
1920
import pathlib
2021
import time
@@ -26,17 +27,16 @@
2627
logger = logging.getLogger(__name__)
2728

2829

30+
class CompletionCallback(typing.Protocol):
    """Structural type for callables invoked when a process has finished.

    Any callable accepting the keyword-only arguments ``status_code``,
    ``std_out`` and ``std_err`` and returning nothing satisfies this protocol.
    """

    def __call__(self, *, status_code: int, std_out: str, std_err: str) -> None:
        """Handle the exit code and the captured output streams of a process."""
32+
33+
2934
def _execute_process(
3035
proc_id: str,
3136
command: typing.List[str],
3237
runner_name: str,
33-
exit_status_dict: typing.Dict[str, int],
34-
std_err: typing.Dict[str, str],
35-
std_out: typing.Dict[str, str],
36-
run_on_exit: typing.Optional[typing.Callable[[int, int, str], None]],
37-
trigger: typing.Optional[multiprocessing.synchronize.Event],
3838
environment: typing.Optional[typing.Dict[str, str]],
39-
) -> None:
39+
) -> subprocess.Popen:
4040
with open(f"{runner_name}_{proc_id}.err", "w") as err:
4141
with open(f"{runner_name}_{proc_id}.out", "w") as out:
4242
_result = subprocess.Popen(
@@ -47,24 +47,7 @@ def _execute_process(
4747
env=environment,
4848
)
4949

50-
_status_code = _result.wait()
51-
with open(f"{runner_name}_{proc_id}.err") as err:
52-
std_err[proc_id] = err.read()
53-
54-
with open(f"{runner_name}_{proc_id}.out") as out:
55-
std_out[proc_id] = out.read()
56-
57-
exit_status_dict[proc_id] = _status_code
58-
59-
if run_on_exit:
60-
run_on_exit(
61-
status_code=exit_status_dict[proc_id],
62-
std_out=std_out[proc_id],
63-
std_err=std_err[proc_id],
64-
)
65-
66-
if trigger:
67-
trigger.set()
50+
return _result
6851

6952

7053
class Executor:
@@ -88,13 +71,16 @@ def __init__(self, simvue_runner: "simvue.Run", keep_logs: bool = True) -> None:
8871
"""
8972
self._runner = simvue_runner
9073
self._keep_logs = keep_logs
91-
self._manager = multiprocessing.Manager()
92-
self._exit_codes = self._manager.dict()
93-
self._std_err = self._manager.dict()
94-
self._std_out = self._manager.dict()
74+
self._completion_callbacks: dict[str, typing.Optional[CompletionCallback]] = {}
75+
self._completion_triggers: dict[
76+
str, typing.Optional[multiprocessing.synchronize.Event]
77+
] = {}
78+
self._exit_codes: dict[str, int] = {}
79+
self._std_err: dict[str, str] = {}
80+
self._std_out: dict[str, str] = {}
9581
self._alert_ids: dict[str, str] = {}
96-
self._command_str: typing.Dict[str, str] = {}
97-
self._processes: typing.Dict[str, multiprocessing.Process] = {}
82+
self._command_str: dict[str, str] = {}
83+
self._processes: dict[str, subprocess.Popen] = {}
9884

9985
def add_process(
10086
self,
@@ -104,9 +90,7 @@ def add_process(
10490
script: typing.Optional[pathlib.Path] = None,
10591
input_file: typing.Optional[pathlib.Path] = None,
10692
env: typing.Optional[typing.Dict[str, str]] = None,
107-
completion_callback: typing.Optional[
108-
typing.Callable[[int, str, str], None]
109-
] = None,
93+
completion_callback: typing.Optional[CompletionCallback] = None,
11094
completion_trigger: typing.Optional[multiprocessing.synchronize.Event] = None,
11195
**kwargs,
11296
) -> None:
@@ -161,6 +145,9 @@ def callback_function(status_code: int, std_out: str, std_err: str) -> None:
161145
"""
162146
_pos_args = list(args)
163147

148+
if not self._runner.name:
149+
raise RuntimeError("Cannot add process, expected Run instance to have name")
150+
164151
if sys.platform == "win32" and completion_callback:
165152
logger.warning(
166153
"Completion callback for 'add_process' may fail on Windows due to "
@@ -207,26 +194,16 @@ def callback_function(status_code: int, std_out: str, std_err: str) -> None:
207194
_command += _pos_args
208195

209196
self._command_str[identifier] = " ".join(_command)
197+
self._completion_callbacks[identifier] = completion_callback
198+
self._completion_triggers[identifier] = completion_trigger
210199

211-
self._processes[identifier] = multiprocessing.Process(
212-
target=_execute_process,
213-
args=(
214-
identifier,
215-
_command,
216-
self._runner.name,
217-
self._exit_codes,
218-
self._std_err,
219-
self._std_out,
220-
completion_callback,
221-
completion_trigger,
222-
env,
223-
),
200+
self._processes[identifier] = _execute_process(
201+
identifier, _command, self._runner.name, env
224202
)
203+
225204
self._alert_ids[identifier] = self._runner.create_alert(
226205
name=f"{identifier}_exit_status", source="user"
227206
)
228-
logger.debug(f"Executing process: {' '.join(_command)}")
229-
self._processes[identifier].start()
230207

231208
@property
232209
def success(self) -> int:
@@ -272,8 +249,8 @@ def _get_error_status(self, process_id: str) -> typing.Optional[str]:
272249
err_msg: typing.Optional[str] = None
273250

274251
# Return last 10 lines of stdout if stderr empty
275-
if not (err_msg := self._std_err[process_id]) and (
276-
std_out := self._std_out[process_id]
252+
if not (err_msg := self._std_err.get(process_id)) and (
253+
std_out := self._std_out.get(process_id)
277254
):
278255
err_msg = " Tail STDOUT:\n\n"
279256
start_index = -10 if len(lines := std_out.split("\n")) > 10 else 0
@@ -308,28 +285,42 @@ def _save_output(self) -> None:
308285
"""Save the output to Simvue"""
309286
for proc_id in self._exit_codes.keys():
310287
# Only save the file if the contents are not empty
311-
if self._std_err[proc_id]:
288+
if self._std_err.get(proc_id):
312289
self._runner.save_file(
313290
f"{self._runner.name}_{proc_id}.err", category="output"
314291
)
315-
if self._std_out[proc_id]:
292+
if self._std_out.get(proc_id):
316293
self._runner.save_file(
317294
f"{self._runner.name}_{proc_id}.out", category="output"
318295
)
319296

320297
def kill_process(self, process_id: str) -> None:
    """Kill a running process, and all of its descendants, by ID

    Parameters
    ----------
    process_id : str
        identifier the process was registered under

    Notes
    -----
    If no process is registered under ``process_id`` an error is logged
    and nothing else happens. Otherwise the completion callback handling
    for the process is run once the process has been killed.
    """
    if not (process := self._processes.get(process_id)):
        logger.error(
            f"Failed to terminate process '{process_id}', no such identifier."
        )
        return

    # Take a single snapshot of the descendants: listing them again after
    # they have been killed may return a different (or empty) set, so the
    # processes that were killed would not be the ones waited on.
    try:
        children = psutil.Process(process.pid).children(recursive=True)
    except psutil.NoSuchProcess:
        # Process has already exited and been reaped, nothing to traverse
        children = []

    for child in children:
        try:
            logger.debug(f"Terminating child process {child.pid}: {child.name()}")
            child.kill()
        except psutil.NoSuchProcess:
            # Child exited between the snapshot and the kill request
            continue

    # Block until every child from the snapshot has actually terminated
    psutil.wait_procs(children)

    logger.debug(f"Terminating process {process.pid}: {process.args}")
    process.kill()
    process.wait()

    self._execute_callback(process_id)
328319

329320
def kill_all(self) -> None:
    """Kill all running processes

    Every identifier registered at the time of the call is passed to
    ``kill_process`` in registration order.
    """
    # Iterate over a snapshot of the identifiers: kill_process goes on to run
    # completion handling, and if that registers another process the
    # dictionary would change size while being iterated.
    for process_id in list(self._processes):
        self.kill_process(process_id)
333324

334325
def _clear_cache_files(self) -> None:
335326
"""Clear local log files if required"""
@@ -338,11 +329,28 @@ def _clear_cache_files(self) -> None:
338329
os.remove(f"{self._runner.name}_{proc_id}.err")
339330
os.remove(f"{self._runner.name}_{proc_id}.out")
340331

332+
def _execute_callback(self, identifier: str) -> None:
333+
with open(f"{self._runner.name}_{identifier}.err") as err:
334+
std_err = err.read()
335+
336+
with open(f"{self._runner.name}_{identifier}.out") as out:
337+
std_out = out.read()
338+
339+
if callback := self._completion_callbacks.get(identifier):
340+
callback(
341+
status_code=self._processes[identifier].returncode,
342+
std_out=std_out,
343+
std_err=std_err,
344+
)
345+
if completion_trigger := self._completion_triggers.get(identifier):
346+
completion_trigger.set()
347+
341348
def wait_for_completion(self) -> None:
342349
"""Wait for all processes to finish then perform tidy up and upload"""
343-
for process in self._processes.values():
344-
if process.is_alive():
345-
process.join()
350+
for identifier, process in self._processes.items():
351+
process.wait()
352+
self._execute_callback(identifier)
353+
346354
self._update_alerts()
347355
self._save_output()
348356

simvue/factory/proxy/base.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,3 +89,7 @@ def send_heartbeat(self) -> typing.Optional[dict[str, typing.Any]]:
8989
@abc.abstractmethod
9090
def check_token(self) -> bool:
9191
pass
92+
93+
@abc.abstractmethod
def get_abort_status(self) -> bool:
    """Retrieve the abort status for the current run.

    Concrete proxies must implement this and return a boolean.
    NOTE(review): presumably ``True`` means an abort has been requested
    for the run - confirm against the implementations.
    """

simvue/factory/proxy/offline.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,9 +176,15 @@ def set_alert_state(
176176

177177
@skip_if_failed("_aborted", "_suppress_errors", [])
def list_tags(self) -> list[dict[str, typing.Any]]:
    """Retrieve the tags of the current run (unsupported when offline).

    Raises
    ------
    NotImplementedError
        always, tag retrieval is not available for offline running
    """
    # TODO: Tag retrieval not implemented for offline running
    message = "Retrieval of current tags is not implemented for offline running"
    raise NotImplementedError(message)
183+
184+
@skip_if_failed("_aborted", "_suppress_errors", False)
def get_abort_status(self) -> bool:
    """Retrieve the abort status for the current run (unsupported when offline).

    Returns
    -------
    bool
        always ``False``, as no abort request can be received offline
    """
    # TODO: Abort on failure not implemented for offline running
    # NOTE(review): previously returned True. Assuming True means "an abort
    # has been requested", an unimplemented check should never report one -
    # confirm against the caller that polls this status.
    return False
182188

183189
@skip_if_failed("_aborted", "_suppress_errors", [])
184190
def list_alerts(self) -> list[dict[str, typing.Any]]:

simvue/factory/proxy/remote.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -466,3 +466,26 @@ def check_token(self) -> bool:
466466
self._error("Token has expired")
467467
return False
468468
return True
469+
470+
@skip_if_failed("_aborted", "_suppress_errors", False)
def get_abort_status(self) -> bool:
    """Retrieve the abort status for the current run from the server.

    Queries the ``/api/runs/<id>/abort`` endpoint and returns the value of
    the ``status`` key of the JSON reply.

    Returns
    -------
    bool
        value of ``status`` reported by the server, or ``False`` if the
        request failed, returned a status code other than 200, or the
        reply could not be decoded or did not contain ``status``
    """
    logger.debug("Retrieving abort status")

    try:
        response = get(
            f"{self._url}/api/runs/{self._id}/abort", self._headers_mp
        )
    except Exception as err:
        self._error(f"Exception retrieving abort status: {str(err)}")
        return False

    logger.debug("Got status code %d when checking abort status", response.status_code)

    if response.status_code != 200:
        self._error(f"Got status code {response.status_code} when checking abort status")
        return False

    # Decode the body once; an undecodable body is reported like any other
    # failure instead of propagating out of the status check
    try:
        payload = response.json()
    except ValueError as err:
        self._error(f"Failed to decode abort status response: {str(err)}")
        return False

    # Guard against a JSON body that is not an object (e.g. a list)
    status = payload.get("status") if isinstance(payload, dict) else None

    if status is None:
        self._error(f"Expected key 'status' when retrieving abort status {payload}")
        return False

    return status

0 commit comments

Comments
 (0)