 import torch
 from torch.utils._pytree import tree_flatten
 from torch.utils._pytree import tree_map
+from torch.utils._pytree import tree_map_only
 from triton.testing import do_bench

 from .. import exc
@@ -81,6 +82,10 @@ class BaseSearch(BaseAutotuner):
         counters (collections.Counter): A counter to track various metrics during the search.
     """

+    _baseline_output: object
+    _kernel_mutates_args: bool
+    _baseline_post_args: Sequence[object] | None
+
     def __init__(self, kernel: BoundKernel, args: Sequence[object]) -> None:
         """
         Initialize the BaseSearch object.
@@ -101,17 +106,14 @@ def __init__(self, kernel: BoundKernel, args: Sequence[object]) -> None:
         random.seed(seed)
         self.log(f"Autotune random seed: {seed}")
         self._original_args: Sequence[object] = self._clone_args(self.args)
-        self._baseline_output: object | None = None
-        self._baseline_post_args: Sequence[object] | None = None
-        self._kernel_mutates_args: bool = False
         self._precompile_tmpdir: tempfile.TemporaryDirectory[str] | None = None
         self._precompile_args_path: str | None = None
-        if self.settings.autotune_accuracy_check:
-            (
-                self._baseline_output,
-                self._kernel_mutates_args,
-                self._baseline_post_args,
-            ) = self._compute_baseline()
+        (
+            self._baseline_output,
+            self._kernel_mutates_args,
+            self._baseline_post_args,
+        ) = self._compute_baseline()
+        self._jobs = self._decide_num_jobs()

     def cleanup(self) -> None:
         if self._precompile_tmpdir is not None:
@@ -165,6 +167,55 @@ def _compute_baseline(self) -> tuple[object, bool, Sequence[object] | None]:
         baseline_post_args = self._clone_args(new_args)
         return baseline_output, mutated, baseline_post_args

+    def _decide_num_jobs(self) -> int:
+        if not self.settings.autotune_precompile:
+            return 1
+
+        jobs = self.settings.autotune_precompile_jobs
+        if not jobs:
+            jobs = os.cpu_count() or 1
+
+        if self.settings.autotune_precompile != "spawn":
+            return jobs
+
+        memory_per_job = _estimate_tree_bytes(self.args) + _estimate_tree_bytes(
+            self._baseline_output
+        )
+        memory_per_job *= 2  # safety factor
+        if memory_per_job <= 0:
+            return jobs
+
+        device = self.kernel.env.device
+        if device.type != "cuda":
+            # TODO(jansel): support non-cuda devices
+            return jobs
+
+        available_memory, _ = torch.cuda.mem_get_info(device)
+        jobs_by_memory = available_memory // memory_per_job
+        if jobs_by_memory < jobs:
+            gib_per_job = memory_per_job / (1024**3)
+            available_gib = available_memory / (1024**3)
+            if jobs_by_memory > 0:
+                self.log.warning(
+                    f"Reducing autotune precompile spawn jobs from {jobs} to {jobs_by_memory} "
+                    f"due to limited GPU memory (estimated {gib_per_job:.2f} GiB per job, "
+                    f"{available_gib:.2f} GiB free). "
+                    f"Set HELION_AUTOTUNE_PRECOMPILE_JOBS={jobs_by_memory} "
+                    "to make this lower cap persistent, "
+                    'set HELION_AUTOTUNE_PRECOMPILE="fork" to disable spawning, or reduce GPU memory usage.'
+                )
+            else:
+                raise exc.AutotuneError(
+                    "Autotune precompile spawn mode requires at least one job, but the "
+                    "estimated memory usage exceeds available GPU memory. "
+                    f"Estimated {gib_per_job:.2f} GiB per job, but only "
+                    f"{available_gib:.2f} GiB free. "
+                    'Set HELION_AUTOTUNE_PRECOMPILE="fork" to disable spawning, or reduce GPU memory usage.'
+                )
+            jobs = jobs_by_memory
+
+        return jobs
+
     def _validate_against_baseline(
         self, config: Config, output: object, args: Sequence[object]
     ) -> bool:
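
To make the capping arithmetic in `_decide_num_jobs` concrete, here is a minimal standalone sketch. The byte counts and job counts are hypothetical; in the real code they come from `_estimate_tree_bytes` and `torch.cuda.mem_get_info`:

    # Hypothetical numbers, mirroring the arithmetic in _decide_num_jobs.
    args_bytes = 1 * 1024**3          # pretend the kernel args hold ~1 GiB
    baseline_bytes = 512 * 1024**2    # pretend the baseline output holds ~0.5 GiB
    memory_per_job = 2 * (args_bytes + baseline_bytes)  # 2x safety factor -> 3 GiB

    available_memory = 40 * 1024**3   # pretend mem_get_info reports 40 GiB free
    requested_jobs = 32               # e.g. HELION_AUTOTUNE_PRECOMPILE_JOBS=32

    jobs_by_memory = available_memory // memory_per_job  # 40 GiB // 3 GiB == 13
    jobs = min(requested_jobs, jobs_by_memory)           # capped to 13 spawn jobs

The actual method warns (or raises when `jobs_by_memory` is zero) instead of silently taking the minimum, but the resulting job count is the same.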
@@ -179,7 +230,7 @@ def _validate_against_baseline(
         except AssertionError as e:
             self.counters["accuracy_mismatch"] += 1
             self.log.warning(
-                f"Skipping config with accuracy mismatch: {config!r} {e!s}\nUse HELION_AUTOTUNE_ACCURACY_CHECK=0 to disable this check.\n"
+                f"Skipping config with accuracy mismatch: {config!r}\n{e!s}\nUse HELION_AUTOTUNE_ACCURACY_CHECK=0 to disable this check.\n"
             )
             return False
         return True
@@ -454,6 +505,31 @@ def performance(member: PopulationMember) -> float:
     return member.perf


+def _estimate_tree_bytes(obj: object) -> int:
+    """Estimate the memory usage of a pytree of objects, counting shared storage only once."""
+    total = 0
+    seen_ptrs: set[int] = set()
+
+    def _accumulate(tensor: torch.Tensor) -> torch.Tensor:
+        nonlocal total
+        size = tensor.element_size() * tensor.numel()
+        try:
+            storage = tensor.untyped_storage()
+        except RuntimeError:
+            pass
+        else:
+            ptr = storage.data_ptr()
+            if ptr in seen_ptrs:
+                return tensor
+            seen_ptrs.add(ptr)
+            size = storage.nbytes()
+        total += size
+        return tensor
+
+    tree_map_only(torch.Tensor, _accumulate, obj)
+    return total
+
+
 class PopulationBasedSearch(BaseSearch):
     """
     Base class for search algorithms that use a population of configurations.
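
A minimal usage sketch of `_estimate_tree_bytes` showing the shared-storage deduplication; the shapes here are hypothetical, and the example assumes the helper is importable:

    import torch

    base = torch.zeros(1024, 1024)  # float32 -> 4 MiB of storage
    view = base[:16]                # a view sharing base's untyped storage

    # Both tensors resolve to the same storage.data_ptr(), so the storage is
    # counted once at its full nbytes() rather than as 4 MiB + 64 KiB.
    assert _estimate_tree_bytes({"base": base, "view": view}) == 4 * 1024**2

Tensors whose storage cannot be accessed (e.g. when `untyped_storage()` raises `RuntimeError`) fall back to `element_size() * numel()`.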
@@ -823,8 +899,7 @@ def _wait_for_all_step(
         futures: list[PrecompileFuture],
     ) -> list[PrecompileFuture]:
         """Start up to the concurrency cap, wait for progress, and return remaining futures."""
-        # Concurrency cap from the settings of the first future's search
-        cap = futures[0].search.settings.autotune_precompile_jobs or os.cpu_count() or 1
+        cap = futures[0].search._jobs if futures else 1
         running = [f for f in futures if f.started and f.ok is None and f.is_alive()]

         # Start queued futures up to the cap
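
For context, `_wait_for_all_step` performs one scheduling step against the precomputed per-search `_jobs` cap; a driver along these lines (a sketch, not the actual Helion code) would drain the futures by calling it repeatedly:

    def wait_for_all(futures: list[PrecompileFuture]) -> None:
        # Hypothetical driver loop: each step starts queued futures up to the
        # per-search `_jobs` cap and returns whatever is still outstanding.
        while futures:
            futures = _wait_for_all_step(futures)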