2424from typing import TYPE_CHECKING
2525from typing import Any
2626from typing import Callable
27+ from typing import Literal
2728from typing import NoReturn
2829from typing import cast
2930from unittest .mock import patch
@@ -377,7 +378,14 @@ def start_precompile_and_check_for_hangs(
377378
def parallel_benchmark(
    self, configs: list[Config], *, desc: str = "Benchmarking"
) -> list[
    tuple[
        Config,
        Callable[..., object],
        float,
        Literal["ok", "error", "timeout"],
    ]
]:
    """
    Benchmark multiple configurations in parallel.

    Every config is compiled first. When ``settings.autotune_precompile``
    is set, each compiled function is handed to
    ``start_precompile_and_check_for_hangs`` and only the ones whose
    precompile future reports success are benchmarked; otherwise every
    config is benchmarked.

    Args:
        configs: The configurations to compile and benchmark.
        desc: Prefix used for the progress-bar descriptions.

    Returns:
        One ``(config, fn, perf, status)`` tuple per input config, in
        input order. ``perf`` is the value returned by
        ``benchmark_function``, or ``inf`` when precompilation did not
        succeed. ``status`` is ``"ok"`` when ``perf`` is finite,
        ``"timeout"`` when the precompile future's ``failure_reason`` is
        ``"timeout"``, and ``"error"`` in every other case.
    """
    compiled = [self.kernel.compile_config(c, allow_print=False) for c in configs]

    precompile_status: list[Literal["ok", "error", "timeout"]]
    if not self.settings.autotune_precompile:
        # No precompile step: treat every config as benchmarkable.
        succeeded = [True] * len(configs)
        precompile_status = ["ok"] * len(configs)
    else:
        futures = list(
            starmap(
                self.start_precompile_and_check_for_hangs,
                zip(configs, compiled, strict=True),
            )
        )
        succeeded = PrecompileFuture.wait_for_all(
            futures,
            desc=f"{desc} precompiling"
            if self.settings.autotune_progress_bar
            else None,
        )
        # Anything that failed without an explicit "timeout" reason is
        # reported as a generic "error".
        precompile_status = [
            "ok"
            if passed
            else ("timeout" if future.failure_reason == "timeout" else "error")
            for future, passed in zip(futures, succeeded, strict=True)
        ]

    results: list[
        tuple[
            Config, Callable[..., object], float, Literal["ok", "error", "timeout"]
        ]
    ] = []

    # Render a progress bar only when the user requested it.
    iterator = iter_with_progress(
        zip(configs, compiled, succeeded, precompile_status, strict=True),
        total=len(configs),
        description=f"{desc} exploring neighbors",
        enabled=self.settings.autotune_progress_bar,
    )
    for config, fn, passed, pre_status in iterator:
        if not passed:
            # Precompile failed: skip the benchmark and carry the
            # precompile verdict ("timeout" or "error") through.
            failed: Literal["ok", "error", "timeout"] = (
                "timeout" if pre_status == "timeout" else "error"
            )
            results.append((config, fn, inf, failed))
            continue
        # benchmark one-by-one to avoid noisy results
        perf = self.benchmark_function(config, fn)
        outcome: Literal["ok", "error", "timeout"] = (
            "ok" if math.isfinite(perf) else "error"
        )
        results.append((config, fn, perf, outcome))
    return results
422450
423451 def autotune (self , * , skip_cache : bool = False ) -> Config :
@@ -486,6 +514,7 @@ class PopulationMember:
486514 perfs : list [float ]
487515 flat_values : FlatConfig
488516 config : Config
517+ status : Literal ["ok" , "error" , "timeout" , "unknown" ] = "unknown"
489518
490519 @property
491520 def perf (self ) -> float :
@@ -581,7 +610,8 @@ def benchmark_flat(self, flat_values: FlatConfig) -> PopulationMember:
581610 """
582611 config = self .config_gen .unflatten (flat_values )
583612 fn , perf = self .benchmark (config )
584- return PopulationMember (fn , [perf ], flat_values , config )
613+ status : Literal ["ok" , "error" ] = "ok" if math .isfinite (perf ) else "error"
614+ return PopulationMember (fn , [perf ], flat_values , config , status = status )
585615
586616 def parallel_benchmark_flat (
587617 self , to_check : list [FlatConfig ]
@@ -622,14 +652,15 @@ def parallel_benchmark_population(
622652 members: The list of population members to benchmark.
623653 desc: Description for the progress bar.
624654 """
625- for member , (config_out , fn , perf ) in zip (
655+ for member , (config_out , fn , perf , status ) in zip (
626656 members ,
627657 self .parallel_benchmark ([m .config for m in members ], desc = desc ),
628658 strict = True ,
629659 ):
630660 assert config_out is member .config
631661 member .perfs .append (perf )
632662 member .fn = fn
663+ member .status = status
633664 return members
634665
635666 def compare (self , a : PopulationMember , b : PopulationMember ) -> int :
@@ -730,23 +761,39 @@ def population_statistics(population: list[PopulationMember]) -> str:
730761 A string summarizing the performance of the population.
731762 """
732763 population = sorted (population , key = performance )
733- if math .isinf (population [- 1 ].perf ):
734- working = [x for x in population if not math .isinf (x .perf )]
735- if len (working ) == 0 :
736- raise exc .NoConfigFound
737- return (
738- f"failed={ len (population ) - len (working )} "
739- f"min={ working [0 ].perf :.4f} "
740- f"mid={ working [len (working ) // 2 ].perf :.4f} "
741- f"max={ working [- 1 ].perf :.4f} "
742- f"best={ population [0 ].config !s} "
764+ status_counts : collections .Counter [str ] = collections .Counter ()
765+ working : list [PopulationMember ] = []
766+ for member in population :
767+ status = member .status
768+ if math .isfinite (member .perf ):
769+ working .append (member )
770+ if status not in {"ok" , "error" , "timeout" }:
771+ status = "ok"
772+ else :
773+ if status not in {"error" , "timeout" }:
774+ status = "error"
775+ if status == "timeout" :
776+ status_counts ["timeout" ] += 1
777+ elif status == "error" :
778+ status_counts ["error" ] += 1
779+ else :
780+ status_counts ["ok" ] += 1
781+ if len (working ) == 0 :
782+ raise exc .NoConfigFound
783+ parts : list [str ] = []
784+ for label in ("error" , "timeout" , "ok" ):
785+ count = status_counts .get (label , 0 )
786+ if count :
787+ parts .append (f"{ label } ={ count } " )
788+ parts .extend (
789+ (
790+ f"min={ working [0 ].perf :.4f} " ,
791+ f"mid={ working [len (working ) // 2 ].perf :.4f} " ,
792+ f"max={ working [- 1 ].perf :.4f} " ,
793+ f"best={ population [0 ].config !s} " ,
743794 )
744- return (
745- f"min={ population [0 ].perf :.4f} "
746- f"mid={ population [len (population ) // 2 ].perf :.4f} "
747- f"max={ population [- 1 ].perf :.4f} "
748- f"best={ population [0 ].config !s} "
749795 )
796+ return " " .join (parts )
750797
751798
752799@dataclasses .dataclass
@@ -777,6 +824,7 @@ class PrecompileFuture:
777824 _result_received : bool = False
778825 remote_error : RemoteError | None = None
779826 _remote_error_handled : bool = False
827+ failure_reason : Literal ["ok" , "error" , "timeout" ] | None = None
780828
781829 @property
782830 def elapsed (self ) -> float :
@@ -834,6 +882,7 @@ def skip(search: BaseSearch, config: Config, ok: bool) -> PrecompileFuture:
834882 _result_received = True ,
835883 remote_error = None ,
836884 _remote_error_handled = True ,
885+ failure_reason = "ok" if ok else "error" ,
837886 )
838887
839888 def __call__ (self ) -> bool :
@@ -892,6 +941,8 @@ def wait_for_all(
892941 result = []
893942 for f in futures :
894943 assert f .ok is not None
944+ if f .failure_reason is None :
945+ f .failure_reason = "ok" if f .ok else "error"
895946 result .append (f .ok )
896947 return result
897948
@@ -945,6 +996,10 @@ def _mark_complete(self) -> bool:
945996 self .ok = process .exitcode == 0
946997 self ._recv_result (block = True )
947998 self ._handle_remote_error (raise_on_raise = False )
999+ if self .ok :
1000+ self .failure_reason = "ok"
1001+ elif self .failure_reason is None :
1002+ self .failure_reason = "error"
9481003 return self .ok
9491004 process .terminate ()
9501005 process .join (10 )
@@ -960,6 +1015,7 @@ def _mark_complete(self) -> bool:
9601015 self .search .log .warning (msg )
9611016
9621017 self .ok = False
1018+ self .failure_reason = "timeout"
9631019 self ._recv_result (block = False )
9641020 self ._handle_remote_error (raise_on_raise = False )
9651021 return False
0 commit comments