Skip to content

Commit a2bb673

Browse files
authored
Add progress bar for precompiling (#919)
1 parent 2d462e2 commit a2bb673

File tree

2 files changed

+19
-4
lines changed

2 files changed

+19
-4
lines changed

helion/autotuner/base_search.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,10 @@ def parallel_benchmark(
326326
self.start_precompile_and_check_for_hangs,
327327
zip(configs, fns, strict=True),
328328
)
329-
]
329+
],
330+
desc=f"{desc} precompiling"
331+
if self.settings.autotune_progress_bar
332+
else None,
330333
)
331334
else:
332335
is_workings = [True] * len(configs)
@@ -336,7 +339,7 @@ def parallel_benchmark(
336339
iterator = iter_with_progress(
337340
zip(configs, fns, is_workings, strict=True),
338341
total=len(configs),
339-
description=desc,
342+
description=f"{desc}: exploring neighbors",
340343
enabled=self.settings.autotune_progress_bar,
341344
)
342345
for config, fn, is_working in iterator:
@@ -725,6 +728,7 @@ def __call__(self) -> bool:
725728
@staticmethod
726729
def wait_for_all(
727730
futures: list[PrecompileFuture],
731+
desc: str | None = None,
728732
) -> list[bool]:
729733
"""
730734
Wait for all precompile futures to complete.
@@ -735,10 +739,21 @@ def wait_for_all(
735739
Returns:
736740
A list of boolean values indicating completion status.
737741
"""
742+
progress = iter_with_progress(
743+
range(len(futures)),
744+
total=len(futures),
745+
description=desc,
746+
enabled=desc is not None,
747+
)
748+
next(progress, None) # display the progress bar immediately
749+
progress_left = len(futures)
738750
remaining = [f for f in futures if f.ok is None]
739751
try:
740752
while remaining:
741753
remaining = PrecompileFuture._wait_for_all_step(remaining)
754+
while progress_left > len(remaining):
755+
next(progress, None)
756+
progress_left -= 1
742757
except Exception:
743758
for f in remaining:
744759
if (p := f.process) is not None:

helion/autotuner/pattern_search.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,11 +102,11 @@ def _autotune(self) -> Config:
102102
unbenchmarked = [m for m in self.population if len(m.perfs) == 0]
103103
if unbenchmarked:
104104
self.parallel_benchmark_population(
105-
unbenchmarked, desc=f"Generation {generation}: Exploring neighbors"
105+
unbenchmarked, desc=f"Generation {generation}:"
106106
)
107107
# higher-accuracy rebenchmark
108108
self.rebenchmark_population(
109-
self.population, desc=f"Generation {generation}: Verifying top configs"
109+
self.population, desc=f"Generation {generation}: verifying top configs"
110110
)
111111
# Log final statistics for this generation
112112
self.log(f"Generation {generation} complete:", self.statistics)

0 commit comments

Comments
 (0)