Skip to content

Commit cd4391e

Browse files
committed
Fix triton/torch.compile compatibility issue
An upstream Triton change seems to be breaking things. stack-info: PR: #927, branch: jansel/stack/188
1 parent 34352e7 commit cd4391e

File tree

2 files changed

+181
-0
lines changed

2 files changed

+181
-0
lines changed

helion/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from triton import cdiv
44
from triton import next_power_of_2
55

6+
from . import _compat as _compat_module # noqa: F401 # side-effect import
67
from . import _logging
78
from . import exc
89
from . import language

helion/_compat.py

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,193 @@
11
from __future__ import annotations
22

3+
import contextlib
34
import functools
5+
from typing import Any
6+
from typing import Callable
7+
from typing import cast
48

59
import torch
610
from torch._inductor.runtime.hints import DeviceProperties
711
from torch._inductor.utils import triton_type
812
import triton
13+
from triton.backends.compiler import BaseBackend
914
from triton.backends.compiler import GPUTarget
1015
import triton.language as tl
16+
import triton.runtime.jit as triton_jit
17+
18+
# Signature of Triton's newer `native_specialize_impl` hook:
# (backend class, arg, is_const, specialize_value, align) -> specialization tuple.
# NOTE(review): parameter meanings inferred from how `native_impl` is invoked
# below — confirm against the installed Triton version.
NativeSpecializeImpl = Callable[
    [type[BaseBackend], object, bool, bool, bool], tuple[object, ...]
]
# Signature of Triton's older `create_specialize_impl` factory: takes a
# `specialize_extra` callback and returns the concrete specialize function.
CreateSpecializeImpl = Callable[
    [Callable[..., object]], Callable[..., tuple[object, ...]]
]
24+
25+
26+
def _make_specialize_impl_wrapper(
    *,
    native_impl: NativeSpecializeImpl | None = None,
    create_factory: CreateSpecializeImpl | None = None,
) -> Callable[..., object]:
    """Build a `specialize_impl`-compatible callable bridging Triton API variants.

    Newer Triton exposes ``native_specialize_impl`` (positional backend/arg/flag
    signature) while older Triton exposes a ``create_specialize_impl`` factory.
    The returned wrapper accepts either calling convention (positional or
    keyword, with or without a backend argument) and dispatches to whichever
    implementation is available.

    Args:
        native_impl: explicit native implementation; when ``None``, looked up
            on ``triton.runtime.jit`` at call time of this factory.
        create_factory: explicit factory for the legacy code path.

    Raises:
        AttributeError: if neither implementation can be located.
    """
    if native_impl is None:
        # Fall back to whatever the installed Triton provides, if anything.
        native_impl = cast(
            "NativeSpecializeImpl | None",
            getattr(triton_jit, "native_specialize_impl", None),
        )
    if native_impl is None and create_factory is None:
        raise AttributeError("native_specialize_impl unavailable")

    def specialize_impl_wrapper(
        *args: object,
        **kwargs: object,
    ) -> object:
        """Normalize caller arguments, then delegate to the real implementation."""
        # Optional per-argument extra-specialization callback; removed from
        # kwargs so it is never forwarded positionally.
        specialize_extra = cast(
            "Callable[..., object] | None",
            kwargs.pop("specialize_extra", None),
        )
        # Silently dropped: the implementations below take no such parameter.
        kwargs.pop("specialize_zero_one", None)
        backend_param = kwargs.pop("backend", None)
        args_list: list[object] = list(args)
        backend_type: type[BaseBackend]
        # Resolve the backend class from either the keyword or the leading
        # positional argument; accept a class, an instance, or nothing.
        if backend_param is None and args_list:
            first = args_list[0]
            if isinstance(first, type) and issubclass(first, BaseBackend):
                backend_type = first
                args_list.pop(0)
            elif isinstance(first, BaseBackend):
                backend_type = type(first)
                args_list.pop(0)
            else:
                # First positional is the value to specialize, not a backend.
                backend_type = BaseBackend
        elif isinstance(backend_param, type) and issubclass(backend_param, BaseBackend):
            backend_type = backend_param
        elif isinstance(backend_param, BaseBackend):
            backend_type = type(backend_param)
        else:
            backend_type = BaseBackend

        # The value being specialized: keyword `arg` wins, else next positional.
        arg = kwargs.pop("arg", None)
        if arg is None:
            if args_list:
                arg = args_list.pop(0)
            else:
                raise TypeError("specialize_impl() missing positional argument 'arg'")

        def _pop_flag(
            key: str,
            *,
            alt_keys: tuple[str, ...] = (),
            default: bool | None = None,
        ) -> bool:
            """Extract a boolean flag by keyword (primary or alternate name),
            falling back to the next positional argument, then to `default`."""
            value = kwargs.pop(key, None)
            if value is None:
                for alt in alt_keys:
                    value = kwargs.pop(alt, None)
                    if value is not None:
                        break
            if value is None:
                if args_list:
                    # Positional order must match the native signature:
                    # is_const, specialize_value, align.
                    value = args_list.pop(0)
                elif default is not None:
                    value = default
                else:
                    raise TypeError(f"specialize_impl() missing argument '{key}'")
            return bool(value)

        is_const = _pop_flag("is_const")
        specialize_value = _pop_flag(
            "specialize_value",
            alt_keys=("specialize",),  # older callers use `specialize=`
            default=True,
        )
        align = _pop_flag("align", default=True)

        if native_impl is not None:
            result = native_impl(
                backend_type,
                arg,
                is_const,
                specialize_value,
                align,
            )
            if specialize_extra is not None:
                # Best-effort: the native path has no hook for the callback,
                # so invoke it for side effects only and ignore failures.
                with contextlib.suppress(Exception):
                    specialize_extra(arg)
        else:
            assert create_factory is not None

            def _call_specialize_extra(
                extra_arg: object,
                kind: object,
                *,
                align: bool = True,
            ) -> object:
                """Adapter passed to the legacy factory; tolerates both the
                one-argument and (arg, kind, align=...) callback signatures."""
                if specialize_extra is None:
                    return None
                try:
                    return specialize_extra(extra_arg)
                except TypeError:
                    # Retry with the richer signature before giving up.
                    try:
                        return specialize_extra(extra_arg, kind, align=align)
                    except Exception:
                        return None
                except Exception:
                    return None

            impl = create_factory(_call_specialize_extra)
            result = impl(
                arg,
                is_const=is_const,
                specialize_value=specialize_value,
                align=align,
            )
        return result

    return specialize_impl_wrapper
146+
147+
148+
def _ensure_triton_specialize_impl_alias() -> None:
    """Ensure ``triton.runtime.jit.specialize_impl`` exists.

    Newer Triton renamed/refactored this entry point; when it is missing,
    install a compatibility wrapper built from whichever implementation
    (``native_specialize_impl`` or ``create_specialize_impl``) is present.
    """
    module: Any = triton_jit
    if hasattr(module, "specialize_impl"):
        return  # nothing to patch
    if hasattr(module, "native_specialize_impl"):
        module.specialize_impl = _make_specialize_impl_wrapper()  # type: ignore[assignment]
    elif hasattr(module, "create_specialize_impl"):
        module.specialize_impl = _make_specialize_impl_wrapper(
            create_factory=module.create_specialize_impl,
        )  # type: ignore[assignment]


_ensure_triton_specialize_impl_alias()
163+
164+
165+
def _ensure_backend_specialization_alias() -> None:
    """Alias ``BaseBackend.get_arg_specialization`` to the older
    ``get_tensor_specialization`` name when only the latter exists."""
    missing_new = not hasattr(BaseBackend, "get_arg_specialization")
    if missing_new and hasattr(BaseBackend, "get_tensor_specialization"):
        BaseBackend.get_arg_specialization = BaseBackend.get_tensor_specialization  # type: ignore[attr-defined]


_ensure_backend_specialization_alias()
173+
174+
175+
@functools.cache
def get_triton_find_paths_if() -> Callable[..., object]:
    """Return Triton's ``find_paths_if`` helper, whichever name it carries.

    The helper was private (``_find_paths_if``) in some Triton releases and
    public in others; probe both spellings, newest first.

    Raises:
        AttributeError: if neither spelling is present on ``triton.runtime.jit``.
    """
    for candidate in ("find_paths_if", "_find_paths_if"):
        if hasattr(triton_jit, candidate):
            return getattr(triton_jit, candidate)  # type: ignore[attr-defined]
    raise AttributeError("Unable to locate Triton find_paths_if helper")
182+
183+
184+
@functools.cache
def get_triton_iterable_path() -> Callable[..., object]:
    """Return Triton's ``get_iterable_path`` helper, whichever name it carries.

    Probes the public name first, then the legacy private spelling.

    Raises:
        AttributeError: if neither spelling is present on ``triton.runtime.jit``.
    """
    for candidate in ("get_iterable_path", "_get_iterable_path"):
        if hasattr(triton_jit, candidate):
            return getattr(triton_jit, candidate)  # type: ignore[attr-defined]
    raise AttributeError("Unable to locate Triton get_iterable_path helper")
11191

12192

13193
def supports_tensor_descriptor() -> bool:

0 commit comments

Comments
 (0)