
Commit 63771cf

Use interleaved_bench for run_example

stack-info: PR: #945, branch: jansel/stack/194

1 parent: 843d962

File tree: 2 files changed, +45 -6 lines

  helion/_testing.py
  helion/autotuner/benchmarking.py

helion/_testing.py (7 additions, 6 deletions)

@@ -19,12 +19,13 @@
 import torch
 from torch.utils._pytree import tree_map
 import triton
-from triton.testing import do_bench

 from ._utils import counters
 from .runtime.config import Config
 import helion
 from helion._compat import get_tensor_descriptor_fn_name
+from helion.autotuner.benchmarking import compute_repeat
+from helion.autotuner.benchmarking import interleaved_bench
 from helion.runtime.ref_mode import is_ref_mode_enabled

 if TYPE_CHECKING:

@@ -560,11 +561,11 @@ def run_example(
             t.grad = None

     # Benchmark all functions
-    all_times = {
-        name: do_bench(lambda fn=fn: fn(*args))
-        for name, fn in {**kernels, **baselines}.items()
-    }
-
+    all_benchmarks = {**kernels, **baselines}
+    bench_fns = [functools.partial(fn, *args) for fn in all_benchmarks.values()]
+    repeat = compute_repeat(bench_fns[0])
+    timings = interleaved_bench(bench_fns, repeat=repeat, desc="Benchmarking")
+    all_times = dict(zip(all_benchmarks.keys(), timings, strict=True))
     best_baseline_time = min(all_times[name] for name in baselines)  # pyright: ignore[reportArgumentType]

     # Print results
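In effect, run_example no longer times each entry with its own do_bench call; it builds one list of zero-argument callables and hands all of them to interleaved_bench, so every kernel and baseline is measured under the same conditions. A minimal sketch of that flow, assuming a CUDA device and two illustrative matmul callables standing in for the kernels/baselines that run_example receives:

import functools

import torch

from helion.autotuner.benchmarking import compute_repeat
from helion.autotuner.benchmarking import interleaved_bench

# Illustrative stand-ins for the `kernels` and `baselines` dicts in run_example.
x = torch.randn(4096, 4096, device="cuda")

def candidate(a: torch.Tensor) -> torch.Tensor:
    return a @ a

def baseline(a: torch.Tensor) -> torch.Tensor:
    return torch.matmul(a, a)

all_benchmarks = {"candidate": candidate, "baseline": baseline}
# Freeze the arguments so each entry becomes a zero-argument callable.
bench_fns = [functools.partial(fn, x) for fn in all_benchmarks.values()]

# One repeat count, estimated from the first callable, is reused for all of
# them, and the callables are timed interleaved rather than back to back.
repeat = compute_repeat(bench_fns[0])
timings = interleaved_bench(bench_fns, repeat=repeat, desc="Benchmarking")
all_times = dict(zip(all_benchmarks.keys(), timings, strict=True))
print(all_times)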

helion/autotuner/benchmarking.py (38 additions, 0 deletions)

@@ -1,6 +1,7 @@
 from __future__ import annotations

 import functools
+import math
 import statistics
 from typing import Callable

@@ -9,6 +10,43 @@
 from .progress_bar import iter_with_progress


+def compute_repeat(
+    fn: Callable[[], object],
+    *,
+    target_ms: float = 100.0,
+    min_repeat: int = 10,
+    max_repeat: int = 1000,
+    estimate_runs: int = 5,
+) -> int:
+    """
+    Estimate how many repetitions are needed to collect a stable benchmark for a
+    single function call, mirroring Triton's ``do_bench`` heuristic while
+    clamping the result between ``min_repeat`` and ``max_repeat``.
+    """
+    di = runtime.driver.active.get_device_interface()  # type: ignore[attr-defined]
+    cache = runtime.driver.active.get_empty_cache_for_benchmark()  # type: ignore[attr-defined]
+
+    # Warm the pipeline once before collecting timing samples.
+    fn()
+    di.synchronize()
+
+    start_event = di.Event(enable_timing=True)
+    end_event = di.Event(enable_timing=True)
+    start_event.record()
+    for _ in range(estimate_runs):
+        runtime.driver.active.clear_cache(cache)  # type: ignore[attr-defined]
+        fn()
+    end_event.record()
+    di.synchronize()
+
+    estimate_ms = start_event.elapsed_time(end_event) / max(estimate_runs, 1)
+    if not math.isfinite(estimate_ms) or estimate_ms <= 0:
+        return max_repeat
+
+    repeat = int(target_ms / estimate_ms)
+    return max(min_repeat, min(max_repeat, max(1, repeat)))
+
+
 def interleaved_bench(
     fns: list[Callable[[], object]], *, repeat: int, desc: str | None = None
 ) -> list[float]:
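The clamping at the end of compute_repeat reduces to repeat = clamp(target_ms / estimate_ms, min_repeat, max_repeat): the faster a single call is, the more repetitions get packed into the roughly 100 ms measurement budget. A small worked example of that arithmetic with the default parameters (the estimate_ms values here are made up for illustration):

target_ms, min_repeat, max_repeat = 100.0, 10, 1000

# Suppose the five estimation runs averaged 0.8 ms per call.
estimate_ms = 0.8
repeat = max(min_repeat, min(max_repeat, max(1, int(target_ms / estimate_ms))))
print(repeat)  # 125 repetitions, roughly a 100 ms measurement window

# A very fast call (0.01 ms) would ask for 10000 repeats but is clamped to 1000;
# a very slow call (50 ms) would ask for 2 repeats but is raised to min_repeat=10.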
