Skip to content

Commit 694c4ea

Browse files
committed
Add HELION_AUTOTUNE_IGNORE_ERRORS
stack-info: PR: #961, branch: jansel/stack/201
1 parent c0ef109 commit 694c4ea

File tree

5 files changed

+114
-18
lines changed

5 files changed

+114
-18
lines changed

docs/api/settings.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,10 @@ with helion.set_default_settings(
148148
149149
Lower values result in faster autotuning but may find less optimal configurations.
150150
151+
.. autoattribute:: Settings.autotune_ignore_errors
152+
153+
Continue autotuning when a candidate configuration raises a runtime error (for example, GPU out-of-memory) instead of aborting the search; the error is neither raised as ``TritonError`` nor logged as a warning. Default is ``False``. Controlled by ``HELION_AUTOTUNE_IGNORE_ERRORS``.
154+
151155
.. autoattribute:: Settings.autotune_accuracy_check
152156
153157
Validate each candidate configuration against a baseline output before accepting it. Default is ``True``. Controlled by ``HELION_AUTOTUNE_ACCURACY_CHECK``.
@@ -248,6 +252,7 @@ Built-in values for ``HELION_AUTOTUNER`` include ``"PatternSearch"``, ``"Differe
248252
| ``HELION_AUTOTUNE_EFFORT`` | ``autotune_effort`` | Select autotuning preset (``"none"``, ``"quick"``, ``"full"``). |
249253
| ``HELION_REBENCHMARK_THRESHOLD`` | ``autotune_rebenchmark_threshold`` | Re-run configs whose performance is within a multiplier of the current best. |
250254
| ``HELION_AUTOTUNE_PROGRESS_BAR`` | ``autotune_progress_bar`` | Enable or disable the progress bar UI during autotuning. |
255+
| ``HELION_AUTOTUNE_IGNORE_ERRORS`` | ``autotune_ignore_errors`` | When set to ``1``, continue autotuning when candidate configurations fail, without raising or logging the errors. |
251256
| ``HELION_CACHE_DIR`` | ``LocalAutotuneCache`` | Override the on-disk directory used for cached autotuning artifacts. |
252257
| ``HELION_SKIP_CACHE`` | ``LocalAutotuneCache`` | When set to ``1``, ignore cached autotuning entries and rerun searches. |
253258
| ``HELION_PRINT_OUTPUT_CODE`` | ``print_output_code`` | Print generated Triton code to stderr for inspection. |

helion/autotuner/base_search.py

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -230,9 +230,10 @@ def _validate_against_baseline(
230230
)
231231
except AssertionError as e:
232232
self.counters["accuracy_mismatch"] += 1
233-
self.log.warning(
234-
f"Skipping config with accuracy mismatch: {config!r}\n{e!s}\nUse HELION_AUTOTUNE_ACCURACY_CHECK=0 to disable this check.\n"
235-
)
233+
if not self.settings.autotune_ignore_errors:
234+
self.log.warning(
235+
f"Skipping config with accuracy mismatch: {config!r}\n{e!s}\nUse HELION_AUTOTUNE_ACCURACY_CHECK=0 to disable this check.\n"
236+
)
236237
return False
237238
return True
238239

@@ -299,13 +300,17 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
299300
return res
300301
except Exception as e:
301302
action = classify_triton_exception(e)
302-
if action == "raise":
303+
if self.settings.autotune_ignore_errors:
304+
pass
305+
elif action == "raise":
303306
raise exc.TritonError(
304-
f"{type(e).__qualname__}: {e}",
305-
self.kernel.format_kernel_decorator(config, self.settings),
306-
self.kernel.to_triton_code(config),
307+
error=f"{type(e).__qualname__}: {e}",
308+
decorator=self.kernel.format_kernel_decorator(
309+
config, self.settings
310+
),
311+
code=self.kernel.to_triton_code(config),
307312
) from e
308-
if action == "warn":
313+
elif action == "warn":
309314
self.log.warning(format_triton_compile_failure(config, e, self.kernel))
310315
else:
311316
self.log.debug(f"Benchmarking failed: {type(e).__name__}: {e}")
@@ -1005,14 +1010,16 @@ def _mark_complete(self) -> bool:
10051010
process.join(10)
10061011
msg = f"Timeout after {self.elapsed:.0f}s compiling {self.config}"
10071012
if process.is_alive():
1008-
self.search.log.warning(
1009-
msg,
1010-
"(SIGKILL required)",
1011-
)
1013+
if not self.search.settings.autotune_ignore_errors:
1014+
self.search.log.warning(
1015+
msg,
1016+
"(SIGKILL required)",
1017+
)
10121018
process.kill()
10131019
process.join()
10141020
else:
1015-
self.search.log.warning(msg)
1021+
if not self.search.settings.autotune_ignore_errors:
1022+
self.search.log.warning(msg)
10161023

10171024
self.ok = False
10181025
self.failure_reason = "timeout"
@@ -1071,15 +1078,17 @@ def _handle_remote_error(self, *, raise_on_raise: bool) -> None:
10711078
return
10721079
exc_obj = error.to_exception()
10731080
classification = error.classification or classify_triton_exception(exc_obj)
1081+
if ignore_errors := self.search.settings.autotune_ignore_errors:
1082+
classification = "debug"
10741083
if classification == "raise":
10751084
if raise_on_raise:
10761085
self._remote_error_handled = True
10771086
raise exc.TritonError(
1078-
f"{type(exc_obj).__qualname__}: {exc_obj}",
1079-
self.search.kernel.format_kernel_decorator(
1087+
error=f"{type(exc_obj).__qualname__}: {exc_obj}",
1088+
decorator=self.search.kernel.format_kernel_decorator(
10801089
self.config, self.search.settings
10811090
),
1082-
self.search.kernel.to_triton_code(self.config),
1091+
code=self.search.kernel.to_triton_code(self.config),
10831092
) from exc_obj
10841093
return
10851094

@@ -1092,7 +1101,7 @@ def _handle_remote_error(self, *, raise_on_raise: bool) -> None:
10921101
)
10931102
if classification == "warn":
10941103
self.search.log.warning(message)
1095-
else:
1104+
elif not ignore_errors:
10961105
self.search.log.debug(message)
10971106
self._remote_error_handled = True
10981107

helion/exc.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,14 @@ class TorchOpTracingError(_WrapException):
335335

336336

337337
class TritonError(BaseError):
338-
message = "Error running generated Triton program:\n{1}\n{0}\n\nGenerated Triton code:\n{2}"
338+
message = """\
339+
Error from Triton code:
340+
{code}
341+
342+
Error running generated Triton program:
343+
{error}
344+
{decorator}
345+
Set autotune_ignore_errors=True or HELION_AUTOTUNE_IGNORE_ERRORS=1 to ignore Triton errors in autotuning."""
339346

340347

341348
class BaseWarning(_FixedMessage):

helion/runtime/settings.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,10 @@ def _get_autotune_precompile_jobs() -> int | None:
155155
return jobs
156156

157157

158+
def _get_autotune_ignore_errors() -> bool:
159+
return os.environ.get("HELION_AUTOTUNE_IGNORE_ERRORS", "0") == "1"
160+
161+
158162
@dataclasses.dataclass
159163
class _Settings:
160164
# see __slots__ below for the doc strings that show up in help(Settings)
@@ -192,6 +196,9 @@ class _Settings:
192196
autotune_max_generations: int | None = dataclasses.field(
193197
default_factory=_get_autotune_max_generations
194198
)
199+
autotune_ignore_errors: bool = dataclasses.field(
200+
default_factory=_get_autotune_ignore_errors
201+
)
195202
print_output_code: bool = os.environ.get("HELION_PRINT_OUTPUT_CODE", "0") == "1"
196203
force_autotune: bool = os.environ.get("HELION_FORCE_AUTOTUNE", "0") == "1"
197204
autotune_config_overrides: dict[str, object] = dataclasses.field(
@@ -230,6 +237,10 @@ class Settings(_Settings):
230237
"autotune_rebenchmark_threshold": "If a config is within threshold*best_perf, re-benchmark it to avoid outliers. Defaults to effort profile value. Set HELION_REBENCHMARK_THRESHOLD to override.",
231238
"autotune_progress_bar": "If True, show progress bar during autotuning. Default is True. Set HELION_AUTOTUNE_PROGRESS_BAR=0 to disable.",
232239
"autotune_max_generations": "Override the maximum number of generations for Pattern Search and Differential Evolution Search autotuning algorithms with HELION_AUTOTUNE_MAX_GENERATIONS=N or @helion.kernel(autotune_max_generations=N).",
240+
"autotune_ignore_errors": (
241+
"If True, skip logging and raising autotune errors. "
242+
"Set HELION_AUTOTUNE_IGNORE_ERRORS=1 to enable globally."
243+
),
233244
"print_output_code": "If True, print the output code of the kernel to stderr.",
234245
"force_autotune": "If True, force autotuning even if a config is provided.",
235246
"autotune_config_overrides": "Dictionary of config key/value pairs forced during autotuning.",

test/test_autotuner.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from __future__ import annotations
22

3+
import collections
34
from contextlib import contextmanager
45
from contextlib import nullcontext
6+
import logging
57
import math
68
import os
79
from pathlib import Path
@@ -18,23 +20,27 @@
1820

1921
import helion
2022
from helion import _compat
23+
from helion import exc
2124
from helion._testing import DEVICE
2225
from helion._testing import RefEagerTestDisabled
2326
from helion._testing import TestCase
2427
from helion._testing import import_path
2528
from helion._testing import skipIfRocm
2629
from helion.autotuner import DifferentialEvolutionSearch
2730
from helion.autotuner import PatternSearch
31+
from helion.autotuner.base_search import BaseSearch
2832
from helion.autotuner.config_fragment import BooleanFragment
2933
from helion.autotuner.config_fragment import EnumFragment
3034
from helion.autotuner.config_fragment import IntegerFragment
3135
from helion.autotuner.config_fragment import PowerOfTwoFragment
3236
from helion.autotuner.config_generation import ConfigGeneration
3337
from helion.autotuner.effort_profile import get_effort_profile
3438
from helion.autotuner.finite_search import FiniteSearch
39+
from helion.autotuner.logger import LambdaLogger
3540
from helion.autotuner.random_search import RandomSearch
3641
import helion.language as hl
3742
from helion.language import loops
43+
from helion.runtime.settings import Settings
3844

3945
datadir = Path(__file__).parent / "data"
4046
basic_kernels = import_path(datadir / "basic_kernels.py")
@@ -63,6 +69,64 @@ def _autotune(self):
6369
return super()._autotune()
6470

6571

72+
class TestAutotuneIgnoreErrors(TestCase):
73+
def _make_search(self, settings: Settings) -> BaseSearch:
74+
search = BaseSearch.__new__(BaseSearch)
75+
search.settings = settings
76+
search.kernel = SimpleNamespace(
77+
format_kernel_decorator=lambda config, s: "decorator",
78+
to_triton_code=lambda config: "code",
79+
)
80+
search.args = ()
81+
search.counters = collections.Counter()
82+
search.log = LambdaLogger(logging.CRITICAL)
83+
search._kernel_mutates_args = False
84+
search.best_perf_so_far = float("inf")
85+
return search
86+
87+
def test_settings_flag_from_env(self):
88+
with patch.dict(
89+
os.environ, {"HELION_AUTOTUNE_IGNORE_ERRORS": "1"}, clear=False
90+
):
91+
settings = Settings()
92+
self.assertTrue(settings.autotune_ignore_errors)
93+
94+
def test_benchmark_raise_includes_hint(self):
95+
settings = Settings(
96+
autotune_ignore_errors=False,
97+
autotune_log_level=logging.CRITICAL,
98+
)
99+
search = self._make_search(settings)
100+
101+
def bad_fn(*_args):
102+
raise RuntimeError("boom")
103+
104+
with patch("torch.accelerator.synchronize", autospec=True) as sync:
105+
sync.return_value = None
106+
with pytest.raises(exc.TritonError) as err:
107+
search.benchmark_function("cfg", bad_fn)
108+
109+
assert "HELION_AUTOTUNE_IGNORE_ERRORS" in str(err.value)
110+
111+
def test_ignore_errors_skips_logging_and_raise(self):
112+
settings = Settings(
113+
autotune_ignore_errors=True,
114+
autotune_log_level=logging.CRITICAL,
115+
)
116+
search = self._make_search(settings)
117+
118+
def bad_fn(*_args):
119+
raise RuntimeError("boom")
120+
121+
with patch("torch.accelerator.synchronize", autospec=True) as sync:
122+
sync.return_value = None
123+
with patch.object(search.log, "warning") as warn:
124+
result = search.benchmark_function("cfg", bad_fn)
125+
126+
self.assertEqual(result, float("inf"))
127+
warn.assert_not_called()
128+
129+
66130
class TestAutotuner(RefEagerTestDisabled, TestCase):
67131
def setUp(self):
68132
super().setUp()

0 commit comments

Comments
 (0)