Commit 750b20e

Auto-shrink autotune_precompile_jobs based on free memory
stack-info: PR: #940, branch: jansel/stack/192
1 parent b5403a4 commit 750b20e

File tree: 3 files changed, +102 −13 lines changed

docs/api/settings.md

Lines changed: 2 additions & 0 deletions

@@ -139,6 +139,8 @@ with helion.set_default_settings(
 .. autoattribute:: Settings.autotune_precompile_jobs

     Cap the number of concurrent Triton precompile subprocesses. ``None`` (default) uses the machine CPU count.
+    Controlled by ``HELION_AUTOTUNE_PRECOMPILE_JOBS``.
+    When using ``"spawn"`` precompile mode, Helion may automatically lower this cap if free GPU memory is limited.

 .. autoattribute:: Settings.autotune_max_generations
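For reference, the new knob can be set either through the environment or directly in code. A minimal sketch, assuming `helion.Settings` accepts the field as a keyword argument (as the dataclass definition in this commit suggests):

import os

# Cap spawned precompile workers before Helion constructs its Settings;
# this is the environment variable wired up in this commit.
os.environ["HELION_AUTOTUNE_PRECOMPILE_JOBS"] = "4"

import helion

# Equivalent explicit form; an explicit value takes precedence because the
# environment variable is only consulted as the dataclass default.
settings = helion.Settings(autotune_precompile_jobs=4)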

helion/autotuner/base_search.py

Lines changed: 87 additions & 12 deletions

@@ -32,6 +32,7 @@
 import torch
 from torch.utils._pytree import tree_flatten
 from torch.utils._pytree import tree_map
+from torch.utils._pytree import tree_map_only
 from triton.testing import do_bench

 from .. import exc
@@ -81,6 +82,10 @@ class BaseSearch(BaseAutotuner):
         counters (collections.Counter): A counter to track various metrics during the search.
     """

+    _baseline_output: object
+    _kernel_mutates_args: bool
+    _baseline_post_args: Sequence[object] | None
+
     def __init__(self, kernel: BoundKernel, args: Sequence[object]) -> None:
         """
         Initialize the BaseSearch object.
@@ -101,17 +106,14 @@ def __init__(self, kernel: BoundKernel, args: Sequence[object]) -> None:
         random.seed(seed)
         self.log(f"Autotune random seed: {seed}")
         self._original_args: Sequence[object] = self._clone_args(self.args)
-        self._baseline_output: object | None = None
-        self._baseline_post_args: Sequence[object] | None = None
-        self._kernel_mutates_args: bool = False
         self._precompile_tmpdir: tempfile.TemporaryDirectory[str] | None = None
         self._precompile_args_path: str | None = None
-        if self.settings.autotune_accuracy_check:
-            (
-                self._baseline_output,
-                self._kernel_mutates_args,
-                self._baseline_post_args,
-            ) = self._compute_baseline()
+        (
+            self._baseline_output,
+            self._kernel_mutates_args,
+            self._baseline_post_args,
+        ) = self._compute_baseline()
+        self._jobs = self._decide_num_jobs()

     def cleanup(self) -> None:
         if self._precompile_tmpdir is not None:
@@ -165,6 +167,55 @@ def _compute_baseline(self) -> tuple[object, bool, Sequence[object] | None]:
         baseline_post_args = self._clone_args(new_args)
         return baseline_output, mutated, baseline_post_args

+    def _decide_num_jobs(self) -> int:
+        if not self.settings.autotune_precompile:
+            return 1
+
+        jobs = self.settings.autotune_precompile_jobs
+        if not jobs:
+            jobs = os.cpu_count() or 1
+
+        if self.settings.autotune_precompile != "spawn":
+            return jobs
+
+        memory_per_job = _estimate_tree_bytes(self.args) + _estimate_tree_bytes(
+            self._baseline_output
+        )
+        memory_per_job *= 2  # safety factor
+        if memory_per_job <= 0:
+            return jobs
+
+        device = self.kernel.env.device
+        if device.type != "cuda":
+            # TODO(jansel): support non-cuda devices
+            return jobs
+
+        available_memory, _ = torch.cuda.mem_get_info(device)
+        jobs_by_memory = available_memory // memory_per_job
+        if jobs_by_memory < jobs:
+            gib_per_job = memory_per_job / (1024**3)
+            available_gib = available_memory / (1024**3)
+            if jobs_by_memory > 0:
+                self.log.warning(
+                    f"Reducing autotune precompile spawn jobs from {jobs} to {jobs_by_memory} "
+                    f"due to limited GPU memory (estimated {gib_per_job:.2f} GiB per job, "
+                    f"{available_gib:.2f} GiB free). "
+                    f"Set HELION_AUTOTUNE_PRECOMPILE_JOBS={jobs_by_memory} "
+                    "to make this lower cap persistent, "
+                    'set HELION_AUTOTUNE_PRECOMPILE="fork" to disable spawning, or reduce GPU memory usage.'
+                )
+            else:
+                raise exc.AutotuneError(
+                    "Autotune precompile spawn mode requires at least one job, but estimated "
+                    "memory usage exceeds available GPU memory. "
+                    f"Estimated {gib_per_job:.2f} GiB per job, but only "
+                    f"{available_gib:.2f} GiB free. "
+                    'Set HELION_AUTOTUNE_PRECOMPILE="fork" to disable spawning, or reduce GPU memory usage.'
+                )
+            jobs = jobs_by_memory
+
+        return jobs
+
     def _validate_against_baseline(
         self, config: Config, output: object, args: Sequence[object]
     ) -> bool:
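To make the new cap concrete, here is the arithmetic `_decide_num_jobs` performs, rewritten as a standalone sketch; the 3 GiB estimate, 40 GiB free, and 16-core figures are invented for illustration:

GiB = 1024**3

# _estimate_tree_bytes(self.args) + _estimate_tree_bytes(self._baseline_output)
estimated_bytes = 3 * GiB
memory_per_job = estimated_bytes * 2          # 2x safety factor -> 6 GiB per job

available_memory = 40 * GiB                   # free bytes from torch.cuda.mem_get_info
requested_jobs = 16                           # e.g. os.cpu_count() with no env override

jobs_by_memory = available_memory // memory_per_job   # 40 // 6 = 6
effective_jobs = min(requested_jobs, jobs_by_memory)  # capped at 6 spawn jobs
print(effective_jobs)                                 # 6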
@@ -179,7 +230,7 @@ def _validate_against_baseline(
         except AssertionError as e:
             self.counters["accuracy_mismatch"] += 1
             self.log.warning(
-                f"Skipping config with accuracy mismatch: {config!r}{e!s}\nUse HELION_AUTOTUNE_ACCURACY_CHECK=0 to disable this check.\n"
+                f"Skipping config with accuracy mismatch: {config!r}\n{e!s}\nUse HELION_AUTOTUNE_ACCURACY_CHECK=0 to disable this check.\n"
             )
             return False
         return True
@@ -454,6 +505,31 @@ def performance(member: PopulationMember) -> float:
     return member.perf


+def _estimate_tree_bytes(obj: object) -> int:
+    """Estimate the memory usage of a pytree of objects, counting shared storage only once."""
+    total = 0
+    seen_ptrs: set[int] = set()
+
+    def _accumulate(tensor: torch.Tensor) -> torch.Tensor:
+        nonlocal total
+        size = tensor.element_size() * tensor.numel()
+        try:
+            storage = tensor.untyped_storage()
+        except RuntimeError:
+            pass
+        else:
+            ptr = storage.data_ptr()
+            if ptr in seen_ptrs:
+                return tensor
+            seen_ptrs.add(ptr)
+            size = storage.nbytes()
+        total += size
+        return tensor
+
+    tree_map_only(torch.Tensor, _accumulate, obj)
+    return total
+
+
 class PopulationBasedSearch(BaseSearch):
     """
     Base class for search algorithms that use a population of configurations.
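A quick sanity sketch of how the new `_estimate_tree_bytes` helper deduplicates shared storage: a view aliases its base tensor's untyped storage, so its bytes are counted once, while an independent allocation is counted separately.

import torch

base = torch.zeros(1024, 1024)   # 4 MiB of float32
view = base[:512]                # aliases base's storage; no new allocation
fresh = torch.zeros(1024, 1024)  # independent 4 MiB buffer

print(_estimate_tree_bytes([base, view]))   # ~4 MiB: shared storage counted once
print(_estimate_tree_bytes([base, fresh]))  # ~8 MiB: two distinct storages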
@@ -823,8 +899,7 @@ def _wait_for_all_step(
     futures: list[PrecompileFuture],
 ) -> list[PrecompileFuture]:
     """Start up to the concurrency cap, wait for progress, and return remaining futures."""
-    # Concurrency cap from the settings of the first future's search
-    cap = futures[0].search.settings.autotune_precompile_jobs or os.cpu_count() or 1
+    cap = futures[0].search._jobs if futures else 1
    running = [f for f in futures if f.started and f.ok is None and f.is_alive()]

     # Start queued futures up to the cap
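One design note: the cap is now frozen once per search in `_decide_num_jobs` rather than re-derived from settings on every scheduling step, so a mid-run drop in free GPU memory does not shrink the pool further. The loop itself is the usual bounded-concurrency pattern; a generic, self-contained sketch (these names are illustrative, not Helion's `PrecompileFuture` API):

from multiprocessing import Process

def schedule_step(queued: list[Process], running: list[Process], cap: int) -> None:
    # Drop workers that have finished, then refill the pool up to the cap.
    running[:] = [p for p in running if p.is_alive()]
    while queued and len(running) < cap:
        worker = queued.pop(0)
        worker.start()
        running.append(worker)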

helion/runtime/settings.py

Lines changed: 13 additions & 1 deletion

@@ -145,6 +145,16 @@ def _get_autotune_precompile() -> str | None:
     )


+def _get_autotune_precompile_jobs() -> int | None:
+    value = os.environ.get("HELION_AUTOTUNE_PRECOMPILE_JOBS")
+    if value is None or value.strip() == "":
+        return None
+    jobs = int(value)
+    if jobs <= 0:
+        raise ValueError("HELION_AUTOTUNE_PRECOMPILE_JOBS must be a positive integer")
+    return jobs
+
+
 @dataclasses.dataclass
 class _Settings:
     # see __slots__ below for the doc strings that show up in help(Settings)
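The parser treats an unset or blank variable as "no explicit cap" and rejects non-positive values (a non-numeric value fails in `int()` on its own). A quick sketch of the resulting behavior, assuming the helper is importable as defined above:

import os

os.environ.pop("HELION_AUTOTUNE_PRECOMPILE_JOBS", None)
assert _get_autotune_precompile_jobs() is None    # unset -> fall back to CPU count later

os.environ["HELION_AUTOTUNE_PRECOMPILE_JOBS"] = " "
assert _get_autotune_precompile_jobs() is None    # blank -> same as unset

os.environ["HELION_AUTOTUNE_PRECOMPILE_JOBS"] = "4"
assert _get_autotune_precompile_jobs() == 4       # explicit cap of four jobs

os.environ["HELION_AUTOTUNE_PRECOMPILE_JOBS"] = "0"
# raises ValueError: HELION_AUTOTUNE_PRECOMPILE_JOBS must be a positive integer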
@@ -164,7 +174,9 @@ class _Settings:
     autotune_precompile: str | None = dataclasses.field(
         default_factory=_get_autotune_precompile
     )
-    autotune_precompile_jobs: int | None = None
+    autotune_precompile_jobs: int | None = dataclasses.field(
+        default_factory=_get_autotune_precompile_jobs
+    )
     autotune_random_seed: int = dataclasses.field(
         default_factory=_get_autotune_random_seed
     )
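Using `default_factory` instead of a plain `None` default means the environment variable is read each time a Settings object is constructed rather than once at import. A minimal, generic illustration of that dataclass behavior (the `MY_JOBS` name is hypothetical, not Helion's):

import dataclasses
import os

def _jobs_from_env() -> int | None:
    value = os.environ.get("MY_JOBS")  # hypothetical variable, for illustration
    return int(value) if value else None

@dataclasses.dataclass
class Cfg:
    jobs: int | None = dataclasses.field(default_factory=_jobs_from_env)

os.environ["MY_JOBS"] = "8"
print(Cfg().jobs)  # 8 -- factory runs at construction time
del os.environ["MY_JOBS"]
print(Cfg().jobs)  # None -- reflects the current environment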
