77import time
88from typing import TYPE_CHECKING
99from typing import Callable
10+ from typing import Collection
1011from typing import Literal
1112from typing import Protocol
1213from typing import Sequence
@@ -36,6 +37,45 @@ def __call__(
3637 ) -> BaseAutotuner : ...
3738
3839
40+ def _validate_enum_setting (
41+ value : object ,
42+ * ,
43+ name : str ,
44+ valid : Collection [str ],
45+ allow_none : bool = True ,
46+ ) -> str | None :
47+ """Normalize and validate an enum setting.
48+
49+ Args:
50+ value: The value to normalize and validate
51+ name: Name of the setting
52+ valid: Collection of valid settings
53+ allow_none: If True, None and _NONE_VALUES strings return None. If False, they raise an error.
54+ """
55+ # String values that should be treated as None
56+ _NONE_VALUES = frozenset ({"" , "0" , "false" , "none" })
57+
58+ # Normalize values
59+ normalized : str | None
60+ if isinstance (value , str ):
61+ normalized = value .strip ().lower ()
62+ else :
63+ normalized = None
64+
65+ is_none_value = normalized is None or normalized in _NONE_VALUES
66+ is_valid = normalized in valid if normalized else False
67+
68+ # Valid value (none or valid setting)
69+ if is_none_value and allow_none :
70+ return None
71+ if is_valid :
72+ return normalized
73+
74+ # Invalid value, raise error
75+ valid_list = "', '" .join (sorted (valid ))
76+ raise ValueError (f"{ name } must be one of '{ valid_list } ', got { value !r} " )
77+
78+
3979_tls : _TLS = cast ("_TLS" , threading .local ())
4080
4181
@@ -108,63 +148,6 @@ def default_autotuner_fn(
108148 return LocalAutotuneCache (autotuner_cls (bound_kernel , args , ** kwargs )) # pyright: ignore[reportArgumentType]
109149
110150
111- def _get_autotune_random_seed () -> int :
112- value = os .environ .get ("HELION_AUTOTUNE_RANDOM_SEED" )
113- if value is not None :
114- return int (value )
115- return int (time .time () * 1000 ) % 2 ** 32
116-
117-
118- def _get_autotune_max_generations () -> int | None :
119- value = os .environ .get ("HELION_AUTOTUNE_MAX_GENERATIONS" )
120- if value is not None :
121- return int (value )
122- return None
123-
124-
125- def _get_autotune_rebenchmark_threshold () -> float | None :
126- value = os .environ .get ("HELION_REBENCHMARK_THRESHOLD" )
127- if value is not None :
128- return float (value )
129- return None # Will use effort profile default
130-
131-
132- def _normalize_autotune_effort (value : object ) -> AutotuneEffort :
133- if isinstance (value , str ):
134- normalized = value .lower ()
135- if normalized in _PROFILES :
136- return cast ("AutotuneEffort" , normalized )
137- raise ValueError ("autotune_effort must be one of 'none', 'quick', or 'full'" )
138-
139-
140- def _get_autotune_effort () -> AutotuneEffort :
141- return _normalize_autotune_effort (os .environ .get ("HELION_AUTOTUNE_EFFORT" , "full" ))
142-
143-
144- def _get_autotune_precompile () -> str | None :
145- value = os .environ .get ("HELION_AUTOTUNE_PRECOMPILE" )
146- if value is None :
147- return "spawn"
148- mode = value .strip ().lower ()
149- if mode in {"" , "0" , "false" , "none" }:
150- return None
151- if mode in {"spawn" , "fork" }:
152- return mode
153- raise ValueError (
154- "HELION_AUTOTUNE_PRECOMPILE must be 'spawn', 'fork', or empty to disable precompile"
155- )
156-
157-
158- def _get_autotune_precompile_jobs () -> int | None :
159- value = os .environ .get ("HELION_AUTOTUNE_PRECOMPILE_JOBS" )
160- if value is None or value .strip () == "" :
161- return None
162- jobs = int (value )
163- if jobs <= 0 :
164- raise ValueError ("HELION_AUTOTUNE_PRECOMPILE_JOBS must be a positive integer" )
165- return jobs
166-
167-
168151@dataclasses .dataclass
169152class _Settings :
170153 # see __slots__ below for the doc strings that show up in help(Settings)
@@ -182,33 +165,45 @@ class _Settings:
182165 os .environ .get ("HELION_AUTOTUNE_COMPILE_TIMEOUT" , "60" )
183166 )
184167 autotune_precompile : str | None = dataclasses .field (
185- default_factory = _get_autotune_precompile
168+ default_factory = lambda : os . environ . get ( "HELION_AUTOTUNE_PRECOMPILE" , "spawn" )
186169 )
187170 autotune_precompile_jobs : int | None = dataclasses .field (
188- default_factory = _get_autotune_precompile_jobs
171+ default_factory = lambda : int (v )
172+ if (v := os .environ .get ("HELION_AUTOTUNE_PRECOMPILE_JOBS" ))
173+ else None
189174 )
190175 autotune_random_seed : int = dataclasses .field (
191- default_factory = _get_autotune_random_seed
176+ default_factory = lambda : (
177+ int (v )
178+ if (v := os .environ .get ("HELION_AUTOTUNE_RANDOM_SEED" ))
179+ else int (time .time () * 1000 ) % 2 ** 32
180+ )
192181 )
193182 autotune_accuracy_check : bool = (
194183 os .environ .get ("HELION_AUTOTUNE_ACCURACY_CHECK" , "1" ) == "1"
195184 )
196185 autotune_rebenchmark_threshold : float | None = dataclasses .field (
197- default_factory = _get_autotune_rebenchmark_threshold
186+ default_factory = lambda : float (v )
187+ if (v := os .environ .get ("HELION_REBENCHMARK_THRESHOLD" ))
188+ else None
198189 )
199190 autotune_progress_bar : bool = (
200191 os .environ .get ("HELION_AUTOTUNE_PROGRESS_BAR" , "1" ) == "1"
201192 )
202193 autotune_max_generations : int | None = dataclasses .field (
203- default_factory = _get_autotune_max_generations
194+ default_factory = lambda : int (v )
195+ if (v := os .environ .get ("HELION_AUTOTUNE_MAX_GENERATIONS" ))
196+ else None
204197 )
205198 print_output_code : bool = os .environ .get ("HELION_PRINT_OUTPUT_CODE" , "0" ) == "1"
206199 force_autotune : bool = os .environ .get ("HELION_FORCE_AUTOTUNE" , "0" ) == "1"
207200 autotune_config_overrides : dict [str , object ] = dataclasses .field (
208201 default_factory = dict
209202 )
210203 autotune_effort : AutotuneEffort = dataclasses .field (
211- default_factory = _get_autotune_effort
204+ default_factory = lambda : cast (
205+ "AutotuneEffort" , os .environ .get ("HELION_AUTOTUNE_EFFORT" , "full" )
206+ )
212207 )
213208 allow_warp_specialize : bool = (
214209 os .environ .get ("HELION_ALLOW_WARP_SPECIALIZE" , "1" ) == "1"
@@ -220,35 +215,43 @@ class _Settings:
220215 autotuner_fn : AutotunerFunction = default_autotuner_fn
221216
222217 def __post_init__ (self ) -> None :
223- def _is_bool (val : object ) -> bool :
224- return isinstance (val , bool )
225-
226- def _is_non_negative_int (val : object ) -> bool :
227- return isinstance (val , int ) and val >= 0
218+ # Validate all user settings
219+
220+ self .autotune_effort = cast (
221+ "AutotuneEffort" ,
222+ _validate_enum_setting (
223+ self .autotune_effort ,
224+ name = "autotune_effort" ,
225+ valid = _PROFILES .keys (),
226+ allow_none = False , # do not allow None as "none" is a non-default setting
227+ ),
228+ )
229+ self .autotune_precompile = _validate_enum_setting (
230+ self .autotune_precompile ,
231+ name = "autotune_precompile" ,
232+ valid = {"spawn" , "fork" },
233+ )
228234
229- # Validate user settings
230235 validators : dict [str , Callable [[object ], bool ]] = {
231- "autotune_log_level" : _is_non_negative_int ,
232- "autotune_compile_timeout" : _is_non_negative_int ,
233- "autotune_precompile" : lambda v : v in (None , "spawn" , "fork" ),
234- "autotune_precompile_jobs" : lambda v : v is None or _is_non_negative_int (v ),
235- "autotune_accuracy_check" : _is_bool ,
236- "autotune_progress_bar" : _is_bool ,
237- "autotune_max_generations" : lambda v : v is None or _is_non_negative_int (v ),
238- "print_output_code" : _is_bool ,
239- "force_autotune" : _is_bool ,
240- "allow_warp_specialize" : _is_bool ,
241- "debug_dtype_asserts" : _is_bool ,
236+ "autotune_log_level" : lambda v : isinstance (v , int ) and v >= 0 ,
237+ "autotune_compile_timeout" : lambda v : isinstance (v , int ) and v > 0 ,
238+ "autotune_precompile_jobs" : lambda v : v is None
239+ or (isinstance (v , int ) and v > 0 ),
240+ "autotune_accuracy_check" : lambda v : isinstance (v , bool ),
241+ "autotune_progress_bar" : lambda v : isinstance (v , bool ),
242+ "autotune_max_generations" : lambda v : v is None
243+ or (isinstance (v , int ) and v >= 0 ),
244+ "print_output_code" : lambda v : isinstance (v , bool ),
245+ "force_autotune" : lambda v : isinstance (v , bool ),
246+ "allow_warp_specialize" : lambda v : isinstance (v , bool ),
247+ "debug_dtype_asserts" : lambda v : isinstance (v , bool ),
242248 "autotune_rebenchmark_threshold" : lambda v : v is None
243249 or (isinstance (v , (int , float )) and v >= 0 ),
244250 }
245251
246- normalized_effort = _normalize_autotune_effort (self .autotune_effort )
247- object .__setattr__ (self , "autotune_effort" , normalized_effort )
248-
249- for field_name , checker in validators .items ():
252+ for field_name , validator in validators .items ():
250253 value = getattr (self , field_name )
251- if not checker (value ):
254+ if not validator (value ):
252255 raise ValueError (f"Invalid value for { field_name } : { value !r} " )
253256
254257
0 commit comments