@@ -6,6 +6,7 @@
 import threading
 import time
 from typing import TYPE_CHECKING
+from typing import Callable
 from typing import Literal
 from typing import Protocol
 from typing import Sequence
@@ -15,6 +16,7 @@
 from torch._environment import is_fbcode
 
 from helion import exc
+from helion.autotuner.effort_profile import _PROFILES
 from helion.autotuner.effort_profile import AutotuneEffort
 from helion.autotuner.effort_profile import get_effort_profile
 from helion.runtime.ref_mode import RefMode
@@ -127,8 +129,16 @@ def _get_autotune_rebenchmark_threshold() -> float | None: |
     return None  # Will use effort profile default
 
 
+def _normalize_autotune_effort(value: object) -> AutotuneEffort:
+    if isinstance(value, str):
+        normalized = value.lower()
+        if normalized in _PROFILES:
+            return cast("AutotuneEffort", normalized)
+    raise ValueError("autotune_effort must be one of 'none', 'quick', or 'full'")
+
+
 def _get_autotune_effort() -> AutotuneEffort:
-    return cast("AutotuneEffort", os.environ.get("HELION_AUTOTUNE_EFFORT", "full"))
+    return _normalize_autotune_effort(os.environ.get("HELION_AUTOTUNE_EFFORT", "full"))
 
 
 def _get_autotune_precompile() -> str | None:
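
A quick sanity sketch of the new normalization path (hypothetical usage, not part of the change; it assumes the `_PROFILES` registry is keyed by "none", "quick", and "full", as the error message implies):

    import os

    # The env var is read at call time and lowercased before the lookup.
    os.environ["HELION_AUTOTUNE_EFFORT"] = "QUICK"
    assert _get_autotune_effort() == "quick"

    try:
        _normalize_autotune_effort("fast")  # not a registered profile
    except ValueError as e:
        print(e)  # autotune_effort must be one of 'none', 'quick', or 'full'

Non-string inputs fall through both `if` branches and hit the same `ValueError`, so the helper also guards the dataclass field below, not just the env var.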
@@ -209,6 +219,38 @@ class _Settings: |
     )
     autotuner_fn: AutotunerFunction = default_autotuner_fn
 
+    def __post_init__(self) -> None:
+        def _is_bool(val: object) -> bool:
+            return isinstance(val, bool)
+
+        def _is_non_negative_int(val: object) -> bool:
+            return isinstance(val, int) and val >= 0
+
+        # Validate user settings
+        validators: dict[str, Callable[[object], bool]] = {
+            "autotune_log_level": _is_non_negative_int,
+            "autotune_compile_timeout": _is_non_negative_int,
+            "autotune_precompile": _is_bool,
+            "autotune_precompile_jobs": lambda v: v is None or _is_non_negative_int(v),
+            "autotune_accuracy_check": _is_bool,
+            "autotune_progress_bar": _is_bool,
+            "autotune_max_generations": lambda v: v is None or _is_non_negative_int(v),
+            "print_output_code": _is_bool,
+            "force_autotune": _is_bool,
+            "allow_warp_specialize": _is_bool,
+            "debug_dtype_asserts": _is_bool,
+            "autotune_rebenchmark_threshold": lambda v: v is None
+            or (isinstance(v, (int, float)) and v >= 0),
+        }
+
+        normalized_effort = _normalize_autotune_effort(self.autotune_effort)
+        object.__setattr__(self, "autotune_effort", normalized_effort)
+
+        for field_name, checker in validators.items():
+            value = getattr(self, field_name)
+            if not checker(value):
+                raise ValueError(f"Invalid value for {field_name}: {value!r}")
+
 
 class Settings(_Settings):
     """
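
A minimal sketch of how the validation surfaces at construction time (hypothetical usage; it assumes `Settings` is a frozen dataclass whose constructor accepts these fields as keyword arguments, which is why `object.__setattr__` is used above):

    try:
        Settings(autotune_compile_timeout=-5)  # rejected by _is_non_negative_int
    except ValueError as e:
        print(e)  # Invalid value for autotune_compile_timeout: -5

    # __post_init__ also normalizes the effort string, so casing is forgiving:
    settings = Settings(autotune_effort="Quick")
    assert settings.autotune_effort == "quick"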