From 8e5e7fa440d8b66ed7cbc5ade05e624d82d7f202 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Tue, 14 Oct 2025 12:22:01 -0700 Subject: [PATCH] Make HELION_FORCE_AUTOTUNE or kernel.autotune() skip the cache stack-info: PR: https://github.com/pytorch/helion/pull/930, branch: jansel/stack/190 --- helion/autotuner/base_cache.py | 9 +++++++-- helion/autotuner/base_search.py | 4 ++-- helion/runtime/kernel.py | 10 ++++++---- 3 files changed, 15 insertions(+), 8 deletions(-) diff --git a/helion/autotuner/base_cache.py b/helion/autotuner/base_cache.py index 574001938..d6f35f8ae 100644 --- a/helion/autotuner/base_cache.py +++ b/helion/autotuner/base_cache.py @@ -157,8 +157,13 @@ def _get_cache_info_message(self) -> str: """Return a message describing where the cache is and how to clear it.""" return "" - def autotune(self) -> Config: - if os.environ.get("HELION_SKIP_CACHE", "") not in {"", "0", "false", "False"}: + def autotune(self, *, skip_cache: bool = False) -> Config: + if skip_cache or os.environ.get("HELION_SKIP_CACHE", "") not in { + "", + "0", + "false", + "False", + }: return self.autotuner.autotune() if (config := self.get()) is not None: diff --git a/helion/autotuner/base_search.py b/helion/autotuner/base_search.py index bd997a9fc..b81692ea3 100644 --- a/helion/autotuner/base_search.py +++ b/helion/autotuner/base_search.py @@ -65,7 +65,7 @@ class BaseAutotuner(abc.ABC): """ @abc.abstractmethod - def autotune(self) -> Config: + def autotune(self, *, skip_cache: bool = False) -> Config: raise NotImplementedError @@ -420,7 +420,7 @@ def parallel_benchmark( results.append((config, fn, inf)) return results - def autotune(self) -> Config: + def autotune(self, *, skip_cache: bool = False) -> Config: """ Perform autotuning to find the best configuration. diff --git a/helion/runtime/kernel.py b/helion/runtime/kernel.py index c638ed129..2ea4224c6 100644 --- a/helion/runtime/kernel.py +++ b/helion/runtime/kernel.py @@ -253,7 +253,7 @@ def autotune( self, args: Sequence[object], *, - force: bool = False, + force: bool = True, **options: object, ) -> Config: """ @@ -475,7 +475,7 @@ def autotune( self, args: Sequence[object], *, - force: bool = False, + force: bool = True, **kwargs: object, ) -> Config: """ @@ -508,7 +508,9 @@ def autotune( config = FiniteSearch(self, args, self.configs).autotune() else: self.settings.check_autotuning_disabled() - config = self.settings.autotuner_fn(self, args, **kwargs).autotune() + config = self.settings.autotuner_fn(self, args, **kwargs).autotune( + skip_cache=force + ) self.set_config(config) return config @@ -623,7 +625,7 @@ def __call__(self, *args: object) -> _R: if (config := self._implicit_config()) is not None: self.set_config(config) else: - self.autotune(args) + self.autotune(args, force=False) assert self._run is not None assert self._config is not None