|
17 | 17 |
|
18 | 18 | import asyncio
|
19 | 19 | from asyncio.subprocess import PIPE
|
| 20 | +from contextlib import contextmanager |
20 | 21 | import os
|
21 | 22 | import re
|
22 | 23 | import subprocess
|
23 | 24 | import sys
|
24 |
| - |
25 | 25 | import six
|
| 26 | +import signal |
26 | 27 |
|
27 | 28 | from sagemaker_training import (
|
28 | 29 | _entry_point_type,
|
|
36 | 37 | _DEFAULT_BUF_SIZE = 1024 * 64
|
37 | 38 |
|
38 | 39 |
|
@contextmanager
def capture_signal(signalnum, callback):
    """
    Temporarily install a handler for a signal.

    The handler that was active on entry is put back when the ``with``
    block exits, whether it exits normally or through an exception.

    Args:
        signalnum: signal to capture
        callback: handler invoked as ``callback(signalnum, frame)`` if the
            signal occurs while the ``with`` block is active

    """
    # signal.signal() returns the handler it replaces, so the previous handler
    # is captured in the same call that installs the new one.
    original_handler = signal.signal(signalnum, callback)
    try:
        yield
    finally:
        # A previous handler of None means it was not installed from Python.
        # It cannot be handed back to signal.signal() (that raises TypeError,
        # which would also mask any exception raised by the with-body), so
        # fall back to the default disposition rather than leaving `callback`
        # installed after the block has finished.
        if original_handler is None:
            original_handler = signal.SIG_DFL
        signal.signal(signalnum, original_handler)
| 56 | + |
| 57 | + |
39 | 58 | async def watch(stream, proc_per_host):
|
40 | 59 | """Process the stdout and stderr streams on the fly.
|
41 | 60 | Decode the output lines
|
@@ -118,9 +137,10 @@ async def run_async(cmd, processes_per_host, env, cwd, stderr, **kwargs):
|
118 | 137 | cmd, env=env, cwd=cwd, stdout=PIPE, stderr=stderr, **kwargs
|
119 | 138 | )
|
120 | 139 |
|
121 |
| - output = await asyncio.gather( |
122 |
| - watch(proc.stdout, processes_per_host), watch(proc.stderr, processes_per_host) |
123 |
| - ) |
| 140 | + with capture_signal(signal.SIGTERM, lambda signalnum, frame: proc.send_signal(signalnum)): |
| 141 | + output = await asyncio.gather( |
| 142 | + watch(proc.stdout, processes_per_host), watch(proc.stderr, processes_per_host) |
| 143 | + ) |
124 | 144 | return_code = proc.returncode
|
125 | 145 | return return_code, output, proc
|
126 | 146 |
|
@@ -198,7 +218,8 @@ def check_error(cmd, error_class, processes_per_host, cwd=None, capture_error=Tr
|
198 | 218 | process = subprocess.Popen(
|
199 | 219 | cmd, env=os.environ, cwd=cwd or environment.code_dir, stderr=stderr, **kwargs
|
200 | 220 | )
|
201 |
| - return_code = process.wait() |
| 221 | + with capture_signal(signal.SIGTERM, lambda signalnum, frame: process.send_signal(signalnum)): |
| 222 | + return_code = process.wait() |
202 | 223 | if return_code:
|
203 | 224 | extra_info = None
|
204 | 225 | if return_code == 137:
|
|
0 commit comments