Skip to content

Commit 4e46030

Browse files
committed
Print out errors vs timeouts in autotuning status
stack-info: PR: #960, branch: jansel/stack/200
1 parent 35387fa commit 4e46030

File tree

2 files changed

+86
-30
lines changed

2 files changed

+86
-30
lines changed

helion/autotuner/base_search.py

Lines changed: 85 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from typing import TYPE_CHECKING
2525
from typing import Any
2626
from typing import Callable
27+
from typing import Literal
2728
from typing import NoReturn
2829
from typing import cast
2930
from unittest.mock import patch
@@ -377,7 +378,14 @@ def start_precompile_and_check_for_hangs(
377378

378379
def parallel_benchmark(
379380
self, configs: list[Config], *, desc: str = "Benchmarking"
380-
) -> list[tuple[Config, Callable[..., object], float]]:
381+
) -> list[
382+
tuple[
383+
Config,
384+
Callable[..., object],
385+
float,
386+
Literal["ok", "error", "timeout"],
387+
]
388+
]:
381389
"""
382390
Benchmark multiple configurations in parallel.
383391
@@ -389,35 +397,55 @@ def parallel_benchmark(
389397
A list of tuples containing configurations and their performance.
390398
"""
391399
fns = [self.kernel.compile_config(c, allow_print=False) for c in configs]
400+
precompile_status: list[Literal["ok", "error", "timeout"]]
392401
if self.settings.autotune_precompile:
402+
futures = [
403+
*starmap(
404+
self.start_precompile_and_check_for_hangs,
405+
zip(configs, fns, strict=True),
406+
)
407+
]
393408
is_workings = PrecompileFuture.wait_for_all(
394-
[
395-
*starmap(
396-
self.start_precompile_and_check_for_hangs,
397-
zip(configs, fns, strict=True),
398-
)
399-
],
409+
futures,
400410
desc=f"{desc} precompiling"
401411
if self.settings.autotune_progress_bar
402412
else None,
403413
)
414+
precompile_status = []
415+
for future, ok in zip(futures, is_workings, strict=True):
416+
reason = future.failure_reason
417+
if ok:
418+
precompile_status.append("ok")
419+
elif reason == "timeout":
420+
precompile_status.append("timeout")
421+
else:
422+
precompile_status.append("error")
404423
else:
405424
is_workings = [True] * len(configs)
406-
results = []
425+
precompile_status = ["ok"] * len(configs)
426+
results: list[
427+
tuple[
428+
Config, Callable[..., object], float, Literal["ok", "error", "timeout"]
429+
]
430+
] = []
407431

408432
# Render a progress bar only when the user requested it.
409433
iterator = iter_with_progress(
410-
zip(configs, fns, is_workings, strict=True),
434+
zip(configs, fns, is_workings, precompile_status, strict=True),
411435
total=len(configs),
412436
description=f"{desc} exploring neighbors",
413437
enabled=self.settings.autotune_progress_bar,
414438
)
415-
for config, fn, is_working in iterator:
439+
for config, fn, is_working, reason in iterator:
440+
status: Literal["ok", "error", "timeout"]
416441
if is_working:
417442
# benchmark one-by-one to avoid noisy results
418-
results.append((config, fn, self.benchmark_function(config, fn)))
443+
perf = self.benchmark_function(config, fn)
444+
status = "ok" if math.isfinite(perf) else "error"
445+
results.append((config, fn, perf, status))
419446
else:
420-
results.append((config, fn, inf))
447+
status = "timeout" if reason == "timeout" else "error"
448+
results.append((config, fn, inf, status))
421449
return results
422450

423451
def autotune(self, *, skip_cache: bool = False) -> Config:
@@ -486,6 +514,7 @@ class PopulationMember:
486514
perfs: list[float]
487515
flat_values: FlatConfig
488516
config: Config
517+
status: Literal["ok", "error", "timeout", "unknown"] = "unknown"
489518

490519
@property
491520
def perf(self) -> float:
@@ -581,7 +610,8 @@ def benchmark_flat(self, flat_values: FlatConfig) -> PopulationMember:
581610
"""
582611
config = self.config_gen.unflatten(flat_values)
583612
fn, perf = self.benchmark(config)
584-
return PopulationMember(fn, [perf], flat_values, config)
613+
status: Literal["ok", "error"] = "ok" if math.isfinite(perf) else "error"
614+
return PopulationMember(fn, [perf], flat_values, config, status=status)
585615

586616
def parallel_benchmark_flat(
587617
self, to_check: list[FlatConfig]
@@ -622,14 +652,15 @@ def parallel_benchmark_population(
622652
members: The list of population members to benchmark.
623653
desc: Description for the progress bar.
624654
"""
625-
for member, (config_out, fn, perf) in zip(
655+
for member, (config_out, fn, perf, status) in zip(
626656
members,
627657
self.parallel_benchmark([m.config for m in members], desc=desc),
628658
strict=True,
629659
):
630660
assert config_out is member.config
631661
member.perfs.append(perf)
632662
member.fn = fn
663+
member.status = status
633664
return members
634665

635666
def compare(self, a: PopulationMember, b: PopulationMember) -> int:
@@ -730,23 +761,39 @@ def population_statistics(population: list[PopulationMember]) -> str:
730761
A string summarizing the performance of the population.
731762
"""
732763
population = sorted(population, key=performance)
733-
if math.isinf(population[-1].perf):
734-
working = [x for x in population if not math.isinf(x.perf)]
735-
if len(working) == 0:
736-
raise exc.NoConfigFound
737-
return (
738-
f"failed={len(population) - len(working)} "
739-
f"min={working[0].perf:.4f} "
740-
f"mid={working[len(working) // 2].perf:.4f} "
741-
f"max={working[-1].perf:.4f} "
742-
f"best={population[0].config!s}"
764+
status_counts: collections.Counter[str] = collections.Counter()
765+
working: list[PopulationMember] = []
766+
for member in population:
767+
status = member.status
768+
if math.isfinite(member.perf):
769+
working.append(member)
770+
if status not in {"ok", "error", "timeout"}:
771+
status = "ok"
772+
else:
773+
if status not in {"error", "timeout"}:
774+
status = "error"
775+
if status == "timeout":
776+
status_counts["timeout"] += 1
777+
elif status == "error":
778+
status_counts["error"] += 1
779+
else:
780+
status_counts["ok"] += 1
781+
if len(working) == 0:
782+
raise exc.NoConfigFound
783+
parts: list[str] = []
784+
for label in ("error", "timeout", "ok"):
785+
count = status_counts.get(label, 0)
786+
if count:
787+
parts.append(f"{label}={count}")
788+
parts.extend(
789+
(
790+
f"min={working[0].perf:.4f}",
791+
f"mid={working[len(working) // 2].perf:.4f}",
792+
f"max={working[-1].perf:.4f}",
793+
f"best={population[0].config!s}",
743794
)
744-
return (
745-
f"min={population[0].perf:.4f} "
746-
f"mid={population[len(population) // 2].perf:.4f} "
747-
f"max={population[-1].perf:.4f} "
748-
f"best={population[0].config!s}"
749795
)
796+
return " ".join(parts)
750797

751798

752799
@dataclasses.dataclass
@@ -777,6 +824,7 @@ class PrecompileFuture:
777824
_result_received: bool = False
778825
remote_error: RemoteError | None = None
779826
_remote_error_handled: bool = False
827+
failure_reason: Literal["ok", "error", "timeout"] | None = None
780828

781829
@property
782830
def elapsed(self) -> float:
@@ -834,6 +882,7 @@ def skip(search: BaseSearch, config: Config, ok: bool) -> PrecompileFuture:
834882
_result_received=True,
835883
remote_error=None,
836884
_remote_error_handled=True,
885+
failure_reason="ok" if ok else "error",
837886
)
838887

839888
def __call__(self) -> bool:
@@ -892,6 +941,8 @@ def wait_for_all(
892941
result = []
893942
for f in futures:
894943
assert f.ok is not None
944+
if f.failure_reason is None:
945+
f.failure_reason = "ok" if f.ok else "error"
895946
result.append(f.ok)
896947
return result
897948

@@ -945,6 +996,10 @@ def _mark_complete(self) -> bool:
945996
self.ok = process.exitcode == 0
946997
self._recv_result(block=True)
947998
self._handle_remote_error(raise_on_raise=False)
999+
if self.ok:
1000+
self.failure_reason = "ok"
1001+
elif self.failure_reason is None:
1002+
self.failure_reason = "error"
9481003
return self.ok
9491004
process.terminate()
9501005
process.join(10)
@@ -960,6 +1015,7 @@ def _mark_complete(self) -> bool:
9601015
self.search.log.warning(msg)
9611016

9621017
self.ok = False
1018+
self.failure_reason = "timeout"
9631019
self._recv_result(block=False)
9641020
self._handle_remote_error(raise_on_raise=False)
9651021
return False

helion/autotuner/finite_search.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def __init__(
3535
def _autotune(self) -> Config:
3636
best_config = None
3737
best_time = float("inf")
38-
for config, _fn, time in self.parallel_benchmark(
38+
for config, _fn, time, _status in self.parallel_benchmark(
3939
self.configs, desc="Benchmarking"
4040
):
4141
if time < best_time:

0 commit comments

Comments
 (0)