Skip to content

Commit 09ccf1a

Browse files
committed
Adding validation for user provided settings.
Some invalid settings could make the compiler hang.
1 parent 1aaba3f commit 09ccf1a

File tree

2 files changed

+78
-1
lines changed

2 files changed

+78
-1
lines changed

helion/runtime/settings.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import threading
77
import time
88
from typing import TYPE_CHECKING
9+
from typing import Callable
910
from typing import Literal
1011
from typing import Protocol
1112
from typing import Sequence
@@ -15,6 +16,7 @@
1516
from torch._environment import is_fbcode
1617

1718
from helion import exc
19+
from helion.autotuner.effort_profile import _PROFILES
1820
from helion.autotuner.effort_profile import AutotuneEffort
1921
from helion.autotuner.effort_profile import get_effort_profile
2022
from helion.runtime.ref_mode import RefMode
@@ -127,8 +129,16 @@ def _get_autotune_rebenchmark_threshold() -> float | None:
127129
return None # Will use effort profile default
128130

129131

132+
def _normalize_autotune_effort(value: object) -> AutotuneEffort:
133+
if isinstance(value, str):
134+
normalized = value.lower()
135+
if normalized in _PROFILES:
136+
return cast("AutotuneEffort", normalized)
137+
raise ValueError("autotune_effort must be one of 'none', 'quick', or 'full'")
138+
139+
130140
def _get_autotune_effort() -> AutotuneEffort:
131-
return cast("AutotuneEffort", os.environ.get("HELION_AUTOTUNE_EFFORT", "full"))
141+
return _normalize_autotune_effort(os.environ.get("HELION_AUTOTUNE_EFFORT", "full"))
132142

133143

134144
def _get_autotune_precompile() -> str | None:
@@ -209,6 +219,38 @@ class _Settings:
209219
)
210220
autotuner_fn: AutotunerFunction = default_autotuner_fn
211221

222+
def __post_init__(self) -> None:
223+
def _is_bool(val: object) -> bool:
224+
return isinstance(val, bool)
225+
226+
def _is_non_negative_int(val: object) -> bool:
227+
return isinstance(val, int) and val >= 0
228+
229+
# Validate user settings
230+
validators: dict[str, Callable[[object], bool]] = {
231+
"autotune_log_level": _is_non_negative_int,
232+
"autotune_compile_timeout": _is_non_negative_int,
233+
"autotune_precompile": _is_bool,
234+
"autotune_precompile_jobs": lambda v: v is None or _is_non_negative_int(v),
235+
"autotune_accuracy_check": _is_bool,
236+
"autotune_progress_bar": _is_bool,
237+
"autotune_max_generations": lambda v: v is None or _is_non_negative_int(v),
238+
"print_output_code": _is_bool,
239+
"force_autotune": _is_bool,
240+
"allow_warp_specialize": _is_bool,
241+
"debug_dtype_asserts": _is_bool,
242+
"autotune_rebenchmark_threshold": lambda v: v is None
243+
or (isinstance(v, (int, float)) and v >= 0),
244+
}
245+
246+
normalized_effort = _normalize_autotune_effort(self.autotune_effort)
247+
object.__setattr__(self, "autotune_effort", normalized_effort)
248+
249+
for field_name, checker in validators.items():
250+
value = getattr(self, field_name)
251+
if not checker(value):
252+
raise ValueError(f"Invalid value for {field_name}: {value!r}")
253+
212254

213255
class Settings(_Settings):
214256
"""

test/test_settings.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from __future__ import annotations
2+
3+
import unittest
4+
5+
import helion
6+
7+
8+
class TestSettingsValidation(unittest.TestCase):
9+
def test_autotune_effort_none_raises(self) -> None:
10+
with self.assertRaisesRegex(
11+
ValueError, "autotune_effort must be one of 'none', 'quick', or 'full'"
12+
):
13+
helion.Settings(autotune_effort=None)
14+
15+
def test_autotune_effort_quick_normalized(self) -> None:
16+
settings = helion.Settings(autotune_effort="Quick")
17+
self.assertEqual(settings.autotune_effort, "quick")
18+
19+
def test_negative_compile_timeout_raises(self) -> None:
20+
with self.assertRaisesRegex(
21+
ValueError, r"Invalid value for autotune_compile_timeout: -1"
22+
):
23+
helion.Settings(autotune_compile_timeout=-1)
24+
25+
def test_autotune_precompile_jobs_negative_raises(self) -> None:
26+
with self.assertRaisesRegex(
27+
ValueError, r"Invalid value for autotune_precompile_jobs: -1"
28+
):
29+
helion.Settings(autotune_precompile_jobs=-1)
30+
31+
def test_autotune_max_generations_negative_raises(self) -> None:
32+
with self.assertRaisesRegex(
33+
ValueError, r"Invalid value for autotune_max_generations: -1"
34+
):
35+
helion.Settings(autotune_max_generations=-1)

0 commit comments

Comments
 (0)