|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
| 3 | +import collections |
3 | 4 | from contextlib import contextmanager |
4 | 5 | from contextlib import nullcontext |
| 6 | +import logging |
5 | 7 | import math |
6 | 8 | import os |
7 | 9 | from pathlib import Path |
|
18 | 20 |
|
19 | 21 | import helion |
20 | 22 | from helion import _compat |
| 23 | +from helion import exc |
21 | 24 | from helion._testing import DEVICE |
22 | 25 | from helion._testing import RefEagerTestDisabled |
23 | 26 | from helion._testing import TestCase |
24 | 27 | from helion._testing import import_path |
25 | 28 | from helion._testing import skipIfRocm |
26 | 29 | from helion.autotuner import DifferentialEvolutionSearch |
27 | 30 | from helion.autotuner import PatternSearch |
| 31 | +from helion.autotuner.base_search import BaseSearch |
28 | 32 | from helion.autotuner.config_fragment import BooleanFragment |
29 | 33 | from helion.autotuner.config_fragment import EnumFragment |
30 | 34 | from helion.autotuner.config_fragment import IntegerFragment |
31 | 35 | from helion.autotuner.config_fragment import PowerOfTwoFragment |
32 | 36 | from helion.autotuner.config_generation import ConfigGeneration |
33 | 37 | from helion.autotuner.effort_profile import get_effort_profile |
34 | 38 | from helion.autotuner.finite_search import FiniteSearch |
| 39 | +from helion.autotuner.logger import LambdaLogger |
35 | 40 | from helion.autotuner.random_search import RandomSearch |
36 | 41 | import helion.language as hl |
37 | 42 | from helion.language import loops |
| 43 | +from helion.runtime.settings import Settings |
38 | 44 |
|
39 | 45 | datadir = Path(__file__).parent / "data" |
40 | 46 | basic_kernels = import_path(datadir / "basic_kernels.py") |
@@ -63,6 +69,64 @@ def _autotune(self): |
63 | 69 | return super()._autotune() |
64 | 70 |
|
65 | 71 |
|
| 72 | +class TestAutotuneIgnoreErrors(TestCase): |
| 73 | + def _make_search(self, settings: Settings) -> BaseSearch: |
| 74 | + search = BaseSearch.__new__(BaseSearch) |
| 75 | + search.settings = settings |
| 76 | + search.kernel = SimpleNamespace( |
| 77 | + format_kernel_decorator=lambda config, s: "decorator", |
| 78 | + to_triton_code=lambda config: "code", |
| 79 | + ) |
| 80 | + search.args = () |
| 81 | + search.counters = collections.Counter() |
| 82 | + search.log = LambdaLogger(logging.CRITICAL) |
| 83 | + search._kernel_mutates_args = False |
| 84 | + search.best_perf_so_far = float("inf") |
| 85 | + return search |
| 86 | + |
| 87 | + def test_settings_flag_from_env(self): |
| 88 | + with patch.dict( |
| 89 | + os.environ, {"HELION_AUTOTUNE_IGNORE_ERRORS": "1"}, clear=False |
| 90 | + ): |
| 91 | + settings = Settings() |
| 92 | + self.assertTrue(settings.autotune_ignore_errors) |
| 93 | + |
| 94 | + def test_benchmark_raise_includes_hint(self): |
| 95 | + settings = Settings( |
| 96 | + autotune_ignore_errors=False, |
| 97 | + autotune_log_level=logging.CRITICAL, |
| 98 | + ) |
| 99 | + search = self._make_search(settings) |
| 100 | + |
| 101 | + def bad_fn(*_args): |
| 102 | + raise RuntimeError("boom") |
| 103 | + |
| 104 | + with patch("torch.accelerator.synchronize", autospec=True) as sync: |
| 105 | + sync.return_value = None |
| 106 | + with pytest.raises(exc.TritonError) as err: |
| 107 | + search.benchmark_function("cfg", bad_fn) |
| 108 | + |
| 109 | + assert "HELION_AUTOTUNE_IGNORE_ERRORS" in str(err.value) |
| 110 | + |
| 111 | + def test_ignore_errors_skips_logging_and_raise(self): |
| 112 | + settings = Settings( |
| 113 | + autotune_ignore_errors=True, |
| 114 | + autotune_log_level=logging.CRITICAL, |
| 115 | + ) |
| 116 | + search = self._make_search(settings) |
| 117 | + |
| 118 | + def bad_fn(*_args): |
| 119 | + raise RuntimeError("boom") |
| 120 | + |
| 121 | + with patch("torch.accelerator.synchronize", autospec=True) as sync: |
| 122 | + sync.return_value = None |
| 123 | + with patch.object(search.log, "warning") as warn: |
| 124 | + result = search.benchmark_function("cfg", bad_fn) |
| 125 | + |
| 126 | + self.assertEqual(result, float("inf")) |
| 127 | + warn.assert_not_called() |
| 128 | + |
| 129 | + |
66 | 130 | class TestAutotuner(RefEagerTestDisabled, TestCase): |
67 | 131 | def setUp(self): |
68 | 132 | super().setUp() |
|
0 commit comments