Skip to content

Commit 5d0ebcb

Browse files
committed
Pass SIGTERM to training subprocess
feature: Pass SIGTERM to training subprocess fix: aws#125
1 parent 22a170a commit 5d0ebcb

File tree

1 file changed

+26
-5
lines changed

1 file changed

+26
-5
lines changed

src/sagemaker_training/process.py

+26-5
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,13 @@
1717

1818
import asyncio
1919
from asyncio.subprocess import PIPE
20+
from contextlib import contextmanager
2021
import os
2122
import re
2223
import subprocess
2324
import sys
24-
2525
import six
26+
import signal
2627

2728
from sagemaker_training import (
2829
_entry_point_type,
@@ -36,6 +37,24 @@
3637
_DEFAULT_BUF_SIZE = 1024 * 64
3738

3839

40+
@contextmanager
def capture_signal(signalnum, callback):
    """Temporarily install ``callback`` as the handler for ``signalnum``.

    The handler that was active on entry is put back when the ``with`` block
    exits, whether it exits normally or by raising.

    Args:
        signalnum: signal to capture (e.g. ``signal.SIGTERM``).
        callback: handler invoked as ``callback(signalnum, frame)`` if the
            signal occurs while the block is active.

    Yields:
        None. The context manager is used only for its install/restore
        side effect.
    """
    original_handler = signal.getsignal(signalnum)
    signal.signal(signalnum, callback)
    try:
        yield
    finally:
        # getsignal() returns None when the previous handler was not installed
        # from Python. signal.signal() rejects None with a TypeError, which
        # would mask any exception raised inside the block and leave
        # ``callback`` installed. That handler cannot be recovered from
        # Python, so fall back to the default disposition.
        if original_handler is None:
            original_handler = signal.SIG_DFL
        signal.signal(signalnum, original_handler)
56+
57+
3958
async def watch(stream, proc_per_host):
4059
"""Process the stdout and stderr streams on the fly.
4160
Decode the output lines
@@ -118,9 +137,10 @@ async def run_async(cmd, processes_per_host, env, cwd, stderr, **kwargs):
118137
cmd, env=env, cwd=cwd, stdout=PIPE, stderr=stderr, **kwargs
119138
)
120139

121-
output = await asyncio.gather(
122-
watch(proc.stdout, processes_per_host), watch(proc.stderr, processes_per_host)
123-
)
140+
with capture_signal(signal.SIGTERM, lambda signalnum, frame: proc.send_signal(signalnum)):
141+
output = await asyncio.gather(
142+
watch(proc.stdout, processes_per_host), watch(proc.stderr, processes_per_host)
143+
)
124144
return_code = proc.returncode
125145
return return_code, output, proc
126146

@@ -198,7 +218,8 @@ def check_error(cmd, error_class, processes_per_host, cwd=None, capture_error=Tr
198218
process = subprocess.Popen(
199219
cmd, env=os.environ, cwd=cwd or environment.code_dir, stderr=stderr, **kwargs
200220
)
201-
return_code = process.wait()
221+
with capture_signal(signal.SIGTERM, lambda signalnum, frame: process.send_signal(signalnum)):
222+
return_code = process.wait()
202223
if return_code:
203224
extra_info = None
204225
if return_code == 137:

0 commit comments

Comments
 (0)