Skip to content

Commit 6aab612

Browse files
committed
Fix triton/torch.compile compability issue
An upstream Triton change seems to be breaking things.
1 parent 05dcc55 commit 6aab612

File tree

2 files changed

+139
-0
lines changed

2 files changed

+139
-0
lines changed

helion/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from triton import cdiv
44
from triton import next_power_of_2
55

6+
from . import _compat as _compat_module # noqa: F401 # side-effect import
67
from . import _logging
78
from . import exc
89
from . import language

helion/_compat.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,151 @@
11
from __future__ import annotations
22

3+
import contextlib
34
import functools
5+
from typing import Callable
6+
from typing import cast
47

58
import torch
69
from torch._inductor.runtime.hints import DeviceProperties
710
from torch._inductor.utils import triton_type
811
import triton
12+
from triton.backends.compiler import BaseBackend
913
from triton.backends.compiler import GPUTarget
1014
import triton.language as tl
15+
import triton.runtime.jit as triton_jit
16+
17+
NativeSpecializeImpl = Callable[
18+
[type[BaseBackend], object, bool, bool, bool], tuple[object, ...]
19+
]
20+
21+
22+
def _make_specialize_impl_wrapper() -> Callable[..., object]:
23+
native_impl = cast(
24+
"NativeSpecializeImpl | None",
25+
getattr(triton_jit, "native_specialize_impl", None),
26+
)
27+
if native_impl is None:
28+
raise AttributeError("native_specialize_impl unavailable")
29+
30+
def specialize_impl_wrapper(
31+
*args: object,
32+
**kwargs: object,
33+
) -> object:
34+
specialize_extra = cast(
35+
"Callable[[object], object] | None",
36+
kwargs.pop("specialize_extra", None),
37+
)
38+
kwargs.pop("specialize_zero_one", None)
39+
backend_param = kwargs.pop("backend", None)
40+
args_list: list[object] = list(args)
41+
backend_type: type[BaseBackend]
42+
if backend_param is None and args_list:
43+
first = args_list[0]
44+
if isinstance(first, type) and issubclass(first, BaseBackend):
45+
backend_type = first
46+
args_list.pop(0)
47+
elif isinstance(first, BaseBackend):
48+
backend_type = type(first)
49+
args_list.pop(0)
50+
else:
51+
backend_type = BaseBackend
52+
elif isinstance(backend_param, type) and issubclass(backend_param, BaseBackend):
53+
backend_type = backend_param
54+
elif isinstance(backend_param, BaseBackend):
55+
backend_type = type(backend_param)
56+
else:
57+
backend_type = BaseBackend
58+
59+
arg = kwargs.pop("arg", None)
60+
if arg is None:
61+
if args_list:
62+
arg = args_list.pop(0)
63+
else:
64+
raise TypeError("specialize_impl() missing positional argument 'arg'")
65+
66+
def _pop_flag(
67+
key: str,
68+
*,
69+
alt_keys: tuple[str, ...] = (),
70+
default: bool | None = None,
71+
) -> bool:
72+
value = kwargs.pop(key, None)
73+
if value is None:
74+
for alt in alt_keys:
75+
value = kwargs.pop(alt, None)
76+
if value is not None:
77+
break
78+
if value is None:
79+
if args_list:
80+
value = args_list.pop(0)
81+
elif default is not None:
82+
value = default
83+
else:
84+
raise TypeError(f"specialize_impl() missing argument '{key}'")
85+
return bool(value)
86+
87+
is_const = _pop_flag("is_const")
88+
specialize_value = _pop_flag(
89+
"specialize_value",
90+
alt_keys=("specialize",),
91+
default=True,
92+
)
93+
align = _pop_flag("align", default=True)
94+
95+
result = native_impl(
96+
backend_type,
97+
arg,
98+
is_const,
99+
specialize_value,
100+
align,
101+
)
102+
if specialize_extra is not None:
103+
with contextlib.suppress(Exception):
104+
specialize_extra(arg)
105+
return result
106+
107+
return specialize_impl_wrapper
108+
109+
110+
def _ensure_triton_specialize_impl_alias() -> None:
111+
if hasattr(triton_jit, "specialize_impl"):
112+
return
113+
if hasattr(triton_jit, "native_specialize_impl"):
114+
triton_jit.specialize_impl = _make_specialize_impl_wrapper() # type: ignore[attr-defined]
115+
return
116+
if hasattr(triton_jit, "create_specialize_impl"):
117+
triton_jit.specialize_impl = triton_jit.create_specialize_impl() # type: ignore[attr-defined]
118+
119+
120+
_ensure_triton_specialize_impl_alias()
121+
122+
123+
def _ensure_backend_specialization_alias() -> None:
124+
if hasattr(BaseBackend, "get_arg_specialization"):
125+
return
126+
if hasattr(BaseBackend, "get_tensor_specialization"):
127+
BaseBackend.get_arg_specialization = BaseBackend.get_tensor_specialization # type: ignore[attr-defined]
128+
129+
130+
_ensure_backend_specialization_alias()
131+
132+
133+
@functools.cache
134+
def get_triton_find_paths_if() -> Callable[..., object]:
135+
if hasattr(triton_jit, "find_paths_if"):
136+
return triton_jit.find_paths_if
137+
if hasattr(triton_jit, "_find_paths_if"):
138+
return triton_jit._find_paths_if # type: ignore[attr-defined]
139+
raise AttributeError("Unable to locate Triton find_paths_if helper")
140+
141+
142+
@functools.cache
143+
def get_triton_iterable_path() -> Callable[..., object]:
144+
if hasattr(triton_jit, "get_iterable_path"):
145+
return triton_jit.get_iterable_path
146+
if hasattr(triton_jit, "_get_iterable_path"):
147+
return triton_jit._get_iterable_path # type: ignore[attr-defined]
148+
raise AttributeError("Unable to locate Triton get_iterable_path helper")
11149

12150

13151
def supports_tensor_descriptor() -> bool:

0 commit comments

Comments
 (0)