2424from typing import TYPE_CHECKING
2525from typing import Any
2626from typing import Callable
27+ from typing import Iterable
2728from typing import Literal
2829from typing import NoReturn
2930from typing import cast
@@ -362,6 +363,7 @@ def start_precompile_and_check_for_hangs(
362363 args = (fn_spec , self ._precompile_args_path , child_conn , decorator ),
363364 ),
364365 )
366+ process .daemon = True
365367 else :
366368 ctx = mp .get_context ("fork" )
367369 parent_conn , child_conn = ctx .Pipe ()
@@ -372,6 +374,7 @@ def start_precompile_and_check_for_hangs(
372374 args = (fn , device_args , config , self .kernel , child_conn , decorator ),
373375 ),
374376 )
377+ process .daemon = True
375378 return PrecompileFuture (
376379 search = self ,
377380 config = config ,
@@ -937,11 +940,8 @@ def wait_for_all(
937940 while progress_left > len (remaining ):
938941 next (progress , None )
939942 progress_left -= 1
940- except Exception :
941- for f in remaining :
942- if (p := f .process ) is not None :
943- with contextlib .suppress (Exception ):
944- p .terminate ()
943+ except BaseException :
944+ PrecompileFuture ._cancel_all (futures )
945945 raise
946946 result = []
947947 for f in futures :
@@ -983,6 +983,47 @@ def _wait_for_all_step(
983983 remaining .append (f )
984984 return remaining
985985
986+ @staticmethod
987+ def _cancel_all (futures : Iterable [PrecompileFuture ]) -> None :
988+ """Cancel any futures that have not completed."""
989+ active = [future for future in futures if future .ok is None ]
990+ for future in active :
991+ with contextlib .suppress (Exception ):
992+ future ._kill_without_wait ()
993+ for future in active :
994+ with contextlib .suppress (Exception ):
995+ future .cancel ()
996+
997+ def _kill_without_wait (self ) -> None :
998+ """Issue a hard kill to the underlying process without waiting for exit."""
999+ process = self .process
1000+ if process is None or not self .started :
1001+ return
1002+ if process .is_alive ():
1003+ with contextlib .suppress (Exception ):
1004+ process .kill ()
1005+
1006+ def cancel (self ) -> None :
1007+ """Terminate the underlying process (if any) without waiting for success."""
1008+ self .end_time = time .time ()
1009+ process = self .process
1010+ if process is not None :
1011+ if self .started :
1012+ with contextlib .suppress (Exception ):
1013+ if process .is_alive ():
1014+ process .kill ()
1015+ process .join ()
1016+ if self .child_conn is not None :
1017+ with contextlib .suppress (Exception ):
1018+ self .child_conn .close ()
1019+ self .child_conn = None
1020+ if self .ok is None :
1021+ self .ok = False
1022+ if self .failure_reason is None :
1023+ self .failure_reason = "error"
1024+ self ._recv_result (block = False )
1025+ self ._handle_remote_error (raise_on_raise = False )
1026+
9861027 def _mark_complete (self ) -> bool :
9871028 """
9881029 Mark the precompile future as complete and kill the process if needed.
0 commit comments