Commit c455136

Verify compiled kernels in subprocess
This is to handle configs that hang when we run them. It should also fix the IMA (illegal memory access) issues we have been seeing.

stack-info: PR: #914, branch: jansel/stack/173
1 parent: cbdea70
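The summary above does not show the mechanics, but the idea maps onto standard multiprocessing: run each candidate config's warm-up in a child process so a hang or crash kills the child, not the autotuner. A minimal sketch of that idea, not the PR's actual implementation (verify_in_subprocess and the timeout value are illustrative):

import multiprocessing as mp
from typing import Callable


def verify_in_subprocess(fn: Callable[[], None], timeout_s: float = 60.0) -> bool:
    # "spawn" starts a fresh interpreter with no inherited CUDA context, so a
    # crash or hang in the child cannot corrupt the parent autotuner process.
    ctx = mp.get_context("spawn")
    proc = ctx.Process(target=fn)  # fn must be picklable, i.e. module-level
    proc.start()
    proc.join(timeout_s)
    if proc.is_alive():  # config hung: kill the child and reject the config
        proc.kill()
        proc.join()
        return False
    # An illegal memory access aborts the child, surfacing as a nonzero exit
    # code here instead of poisoning the parent's CUDA context.
    return proc.exitcode == 0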

11 files changed: +688 -143 lines
docs/api/settings.md

Lines changed: 6 additions & 1 deletion
@@ -129,7 +129,12 @@ with helion.set_default_settings(
 
 .. autoattribute:: Settings.autotune_precompile
 
-    Whether to precompile kernels before autotuning. Default is ``True`` on non-Windows systems, ``False`` on Windows.
+    Select the autotuner precompile mode, which adds parallelism and
+    checks for errors/timeouts. ``"spawn"`` (default) runs kernel
+    warm-up in a fresh process, including a run to check for errors;
+    ``"fork"`` is faster but omits the error-check run; ``None``
+    disables precompile checks altogether. Controlled by
+    ``HELION_AUTOTUNE_PRECOMPILE``.
 
 .. autoattribute:: Settings.autotune_random_seed
 
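For reference, a hedged usage sketch of the new setting; the call shape is inferred from the "with helion.set_default_settings(" context line above and may not match the exact API:

import helion

# "fork" skips the spawn-mode error-check run, trading safety for speed.
with helion.set_default_settings(helion.Settings(autotune_precompile="fork")):
    ...  # autotune kernels here

# Per the docs above, the same choice can be made via the environment:
#   HELION_AUTOTUNE_PRECOMPILE=fork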

helion/_compat.py

Lines changed: 138 additions & 0 deletions
@@ -1,13 +1,151 @@
 from __future__ import annotations
 
+import contextlib
 import functools
+from typing import Callable
+from typing import cast
 
 import torch
 from torch._inductor.runtime.hints import DeviceProperties
 from torch._inductor.utils import triton_type
 import triton
+from triton.backends.compiler import BaseBackend
 from triton.backends.compiler import GPUTarget
 import triton.language as tl
+import triton.runtime.jit as triton_jit
+
+NativeSpecializeImpl = Callable[
+    [type[BaseBackend], object, bool, bool, bool], tuple[object, ...]
+]
+
+
+def _make_specialize_impl_wrapper() -> Callable[..., object]:
+    native_impl = cast(
+        "NativeSpecializeImpl | None",
+        getattr(triton_jit, "native_specialize_impl", None),
+    )
+    if native_impl is None:
+        raise AttributeError("native_specialize_impl unavailable")
+
+    def specialize_impl_wrapper(
+        *args: object,
+        **kwargs: object,
+    ) -> object:
+        specialize_extra = cast(
+            "Callable[[object], object] | None",
+            kwargs.pop("specialize_extra", None),
+        )
+        kwargs.pop("specialize_zero_one", None)
+        backend_param = kwargs.pop("backend", None)
+        args_list: list[object] = list(args)
+        backend_type: type[BaseBackend]
+        if backend_param is None and args_list:
+            first = args_list[0]
+            if isinstance(first, type) and issubclass(first, BaseBackend):
+                backend_type = first
+                args_list.pop(0)
+            elif isinstance(first, BaseBackend):
+                backend_type = type(first)
+                args_list.pop(0)
+            else:
+                backend_type = BaseBackend
+        elif isinstance(backend_param, type) and issubclass(backend_param, BaseBackend):
+            backend_type = backend_param
+        elif isinstance(backend_param, BaseBackend):
+            backend_type = type(backend_param)
+        else:
+            backend_type = BaseBackend
+
+        arg = kwargs.pop("arg", None)
+        if arg is None:
+            if args_list:
+                arg = args_list.pop(0)
+            else:
+                raise TypeError("specialize_impl() missing positional argument 'arg'")
+
+        def _pop_flag(
+            key: str,
+            *,
+            alt_keys: tuple[str, ...] = (),
+            default: bool | None = None,
+        ) -> bool:
+            value = kwargs.pop(key, None)
+            if value is None:
+                for alt in alt_keys:
+                    value = kwargs.pop(alt, None)
+                    if value is not None:
+                        break
+            if value is None:
+                if args_list:
+                    value = args_list.pop(0)
+                elif default is not None:
+                    value = default
+                else:
+                    raise TypeError(f"specialize_impl() missing argument '{key}'")
+            return bool(value)
+
+        is_const = _pop_flag("is_const")
+        specialize_value = _pop_flag(
+            "specialize_value",
+            alt_keys=("specialize",),
+            default=True,
+        )
+        align = _pop_flag("align", default=True)
+
+        result = native_impl(
+            backend_type,
+            arg,
+            is_const,
+            specialize_value,
+            align,
+        )
+        if specialize_extra is not None:
+            with contextlib.suppress(Exception):
+                specialize_extra(arg)
+        return result
+
+    return specialize_impl_wrapper
+
+
+def _ensure_triton_specialize_impl_alias() -> None:
+    if hasattr(triton_jit, "specialize_impl"):
+        return
+    if hasattr(triton_jit, "native_specialize_impl"):
+        triton_jit.specialize_impl = _make_specialize_impl_wrapper()  # type: ignore[attr-defined]
+        return
+    if hasattr(triton_jit, "create_specialize_impl"):
+        triton_jit.specialize_impl = triton_jit.create_specialize_impl()  # type: ignore[attr-defined]
+
+
+_ensure_triton_specialize_impl_alias()
+
+
+def _ensure_backend_specialization_alias() -> None:
+    if hasattr(BaseBackend, "get_arg_specialization"):
+        return
+    if hasattr(BaseBackend, "get_tensor_specialization"):
+        BaseBackend.get_arg_specialization = BaseBackend.get_tensor_specialization  # type: ignore[attr-defined]
+
+
+_ensure_backend_specialization_alias()
+
+
+@functools.cache
+def get_triton_find_paths_if() -> Callable[..., object]:
+    if hasattr(triton_jit, "find_paths_if"):
+        return triton_jit.find_paths_if
+    if hasattr(triton_jit, "_find_paths_if"):
+        return triton_jit._find_paths_if  # type: ignore[attr-defined]
+    raise AttributeError("Unable to locate Triton find_paths_if helper")
+
+
+@functools.cache
+def get_triton_iterable_path() -> Callable[..., object]:
+    if hasattr(triton_jit, "get_iterable_path"):
+        return triton_jit.get_iterable_path
+    if hasattr(triton_jit, "_get_iterable_path"):
+        return triton_jit._get_iterable_path  # type: ignore[attr-defined]
+    raise AttributeError("Unable to locate Triton get_iterable_path helper")
 
 
 def supports_tensor_descriptor() -> bool:
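Everything added to this file follows one version-compatibility pattern: probe the installed Triton for whichever name it exposes (specialize_impl vs. native_specialize_impl, find_paths_if vs. _find_paths_if), wrap it if the signature changed, and alias it under the name the rest of Helion expects. A distilled, self-contained sketch of that pattern, with hypothetical names (mylib, new_helper, _old_helper) standing in for the Triton internals:

import types
from typing import Callable

# Stand-in for a third-party module that renamed a private helper between
# releases; only one of the two attributes exists at runtime.
mylib = types.SimpleNamespace(_old_helper=lambda x: x * 2)


def get_helper() -> Callable[[int], int]:
    # Prefer the new public name; fall back to the old private one.
    if hasattr(mylib, "new_helper"):
        return mylib.new_helper
    if hasattr(mylib, "_old_helper"):
        return mylib._old_helper
    raise AttributeError("Unable to locate helper")


assert get_helper()(21) == 42  # resolves to _old_helper on this "version"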

0 commit comments
