Skip to content

Commit 6e7b147

Browse files
authored
Exit autotuning faster on KeyboardInterrupt (#963)
1 parent 4486f7b commit 6e7b147

File tree

1 file changed

+46
-5
lines changed

1 file changed

+46
-5
lines changed

helion/autotuner/base_search.py

Lines changed: 46 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,7 @@
2424
from typing import TYPE_CHECKING
2525
from typing import Any
2626
from typing import Callable
27+
from typing import Iterable
2728
from typing import Literal
2829
from typing import NoReturn
2930
from typing import cast
@@ -362,6 +363,7 @@ def start_precompile_and_check_for_hangs(
362363
args=(fn_spec, self._precompile_args_path, child_conn, decorator),
363364
),
364365
)
366+
process.daemon = True
365367
else:
366368
ctx = mp.get_context("fork")
367369
parent_conn, child_conn = ctx.Pipe()
@@ -372,6 +374,7 @@ def start_precompile_and_check_for_hangs(
372374
args=(fn, device_args, config, self.kernel, child_conn, decorator),
373375
),
374376
)
377+
process.daemon = True
375378
return PrecompileFuture(
376379
search=self,
377380
config=config,
@@ -937,11 +940,8 @@ def wait_for_all(
937940
while progress_left > len(remaining):
938941
next(progress, None)
939942
progress_left -= 1
940-
except Exception:
941-
for f in remaining:
942-
if (p := f.process) is not None:
943-
with contextlib.suppress(Exception):
944-
p.terminate()
943+
except BaseException:
944+
PrecompileFuture._cancel_all(futures)
945945
raise
946946
result = []
947947
for f in futures:
@@ -983,6 +983,47 @@ def _wait_for_all_step(
983983
remaining.append(f)
984984
return remaining
985985

986+
@staticmethod
987+
def _cancel_all(futures: Iterable[PrecompileFuture]) -> None:
988+
"""Cancel any futures that have not completed."""
989+
active = [future for future in futures if future.ok is None]
990+
for future in active:
991+
with contextlib.suppress(Exception):
992+
future._kill_without_wait()
993+
for future in active:
994+
with contextlib.suppress(Exception):
995+
future.cancel()
996+
997+
def _kill_without_wait(self) -> None:
998+
"""Issue a hard kill to the underlying process without waiting for exit."""
999+
process = self.process
1000+
if process is None or not self.started:
1001+
return
1002+
if process.is_alive():
1003+
with contextlib.suppress(Exception):
1004+
process.kill()
1005+
1006+
def cancel(self) -> None:
1007+
"""
Kill the underlying process (if any) instead of waiting for it to succeed.

Records the end time, kills and joins a started process, closes
``child_conn``, and marks a future that has no recorded outcome as
failed (``ok = False``, ``failure_reason = "error"`` when no reason was
set).  Finally calls ``_recv_result(block=False)`` and
``_handle_remote_error(raise_on_raise=False)``.

NOTE(review): this view of the code has lost its indentation, so the
exact nesting of the statements below must be confirmed against the file.
"""
1008+
# Stamp the end time with the current wall-clock time.
self.end_time = time.time()
1009+
process = self.process
1010+
# Only an existing process that was started is killed/joined.
if process is not None:
1011+
if self.started:
1012+
# Best effort: an Exception raised while killing/joining is suppressed.
with contextlib.suppress(Exception):
1013+
if process.is_alive():
1014+
process.kill()
1015+
# NOTE(review): presumably join() runs whether or not kill() was needed,
# i.e. it sits at the `with` level rather than under is_alive() - confirm.
process.join()
1016+
# Close this future's child_conn handle (ignoring close errors) and drop
# the reference so it is not closed twice.
# NOTE(review): may be nested under `if process is not None` - confirm.
if self.child_conn is not None:
1017+
with contextlib.suppress(Exception):
1018+
self.child_conn.close()
1019+
self.child_conn = None
1020+
# A future with no recorded outcome is marked as failed.
if self.ok is None:
1021+
self.ok = False
1022+
# Default the failure reason only when none was recorded.
# NOTE(review): presumably nested under `if self.ok is None` - confirm.
if self.failure_reason is None:
1023+
self.failure_reason = "error"
1024+
# Non-blocking call (block=False); presumably picks up any result the
# child already sent - the helper's body is not shown here.
self._recv_result(block=False)
1025+
# Called with raise_on_raise=False; presumably handles a remote error
# without re-raising it - confirm against the helper.
self._handle_remote_error(raise_on_raise=False)
1026+
9861027
def _mark_complete(self) -> bool:
9871028
"""
9881029
Mark the precompile future as complete and kill the process if needed.

0 commit comments

Comments
 (0)