Skip to content

Commit 0826b24

Browse files
authored
Merge branch 'main' into PaulZhang12/stack/14
2 parents ad6ca82 + 2ff186c commit 0826b24

File tree

8 files changed

+258
-483
lines changed

8 files changed

+258
-483
lines changed

docs/api/settings.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,7 @@ Built-in values for ``HELION_AUTOTUNER`` include ``"PatternSearch"``, ``"Differe
257257
| ``HELION_AUTOTUNE_CONFIG_OVERRIDES`` | ``autotune_config_overrides`` | Supply JSON forcing particular autotuner config key/value pairs. |
258258
| ``HELION_CACHE_DIR`` | ``LocalAutotuneCache`` | Override the on-disk directory used for cached autotuning artifacts. |
259259
| ``HELION_SKIP_CACHE`` | ``LocalAutotuneCache`` | When set to ``1``, ignore cached autotuning entries and rerun searches. |
260+
| ``HELION_ASSERT_CACHE_HIT`` | ``AutotuneCacheBase`` | When set to ``1``, require a cache hit; raises ``CacheAssertionError`` on cache miss with detailed diagnostics. |
260261
| ``HELION_PRINT_OUTPUT_CODE`` | ``print_output_code`` | Print generated Triton code to stderr for inspection. |
261262
| ``HELION_OUTPUT_ORIGIN_LINES`` | ``output_origin_lines`` | Include ``# src[...]`` comments in generated Triton code; set to ``0`` to disable. |
262263
| ``HELION_IGNORE_WARNINGS`` | ``ignore_warnings`` | Comma-separated warning names defined in ``helion.exc`` to suppress. |

helion/_compiler/type_propagation.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from ..language.stack_tensor import StackTensor
3131
from ..language.tile_proxy import Tile
3232
from ..language.tile_proxy import _CheckForIndexCalls
33+
from ..runtime.kernel import Kernel
3334
from .ast_extension import ExtendedAST
3435
from .ast_extension import LoopType
3536
from .ast_extension import create
@@ -105,6 +106,8 @@ def _get(self, name: str) -> TypeInfo:
105106
return TypeInfo.from_example(value, origin)
106107

107108
origin = self.function.global_scope_origin(name)
109+
if isinstance(value, Kernel):
110+
return TypeInfo.from_example(value, origin)
108111
if not isinstance(
109112
value,
110113
(types.ModuleType, types.FunctionType, types.BuiltinFunctionType),
@@ -1973,9 +1976,12 @@ def visit_Compare(self, node: ast.Compare) -> TypeInfo:
19731976

19741977
def visit_Call(self, node: ast.Call) -> TypeInfo:
19751978
# TODO(jansel): test handling if *args and **kwargs
1976-
# TODO(jansel): check for calling a Kernel here
19771979
func = self.visit(node.func)
19781980

1981+
# Check for calling a Helion kernel from within another Helion kernel
1982+
if isinstance(func, CallableType) and isinstance(func.value, Kernel):
1983+
raise exc.NestedKernelCallsNotSupported
1984+
19791985
if (
19801986
isinstance(func, CallableType)
19811987
and self.origin().is_device()

helion/autotuner/base_cache.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
from __future__ import annotations
22

33
import abc
4+
from collections.abc import Sequence
45
import dataclasses
56
import functools
67
import hashlib
78
import logging
89
import os
10+
import sys
911
from typing import TYPE_CHECKING
1012
from typing import Any
1113
from typing import Callable
@@ -14,6 +16,7 @@
1416
from torch._inductor.codecache import build_code_hash
1517
from torch._inductor.codecache import torch_key
1618

19+
from .. import exc
1720
from .._utils import counters
1821
from .base_search import BaseAutotuner
1922

@@ -67,7 +70,8 @@ def torch_key_wrapper() -> str:
6770
def triton_key_wrapper() -> str:
6871
from torch._inductor.runtime.triton_compat import triton_key
6972

70-
return triton_key()
73+
full_key = triton_key()
74+
return hashlib.sha256(full_key.encode("utf-8")).hexdigest()
7175

7276

7377
class CacheKeyBase:
@@ -157,6 +161,16 @@ def _get_cache_info_message(self) -> str:
157161
"""Return a message describing where the cache is and how to clear it."""
158162
return ""
159163

164+
@abc.abstractmethod
165+
def _get_cache_key(self) -> CacheKeyBase:
166+
"""Return the cache key for this cache instance."""
167+
raise NotImplementedError
168+
169+
@abc.abstractmethod
170+
def _list_cache_entries(self) -> Sequence[tuple[str, CacheKeyBase]]:
171+
"""Return a sequence of (description, key) tuples for all cache entries."""
172+
raise NotImplementedError
173+
160174
def autotune(self, *, skip_cache: bool = False) -> Config:
161175
if skip_cache or os.environ.get("HELION_SKIP_CACHE", "") not in {
162176
"",
@@ -178,6 +192,43 @@ def autotune(self, *, skip_cache: bool = False) -> Config:
178192
counters["autotune"]["cache_miss"] += 1
179193
log.debug("cache miss")
180194

195+
if os.environ.get("HELION_ASSERT_CACHE_HIT") == "1":
196+
current_key = self._get_cache_key()
197+
print("\n" + "=" * 80, file=sys.stderr)
198+
print("HELION_ASSERT_CACHE_HIT: Cache miss detected!", file=sys.stderr)
199+
print("=" * 80, file=sys.stderr)
200+
print(f"\nKernel: {self.kernel.kernel.name}", file=sys.stderr)
201+
print(f"\nCurrent cache key:\n{current_key}", file=sys.stderr)
202+
203+
cache_entries = self._list_cache_entries()
204+
if cache_entries:
205+
print(
206+
f"\n{len(cache_entries)} other cache entries exist (but don't match):",
207+
file=sys.stderr,
208+
)
209+
for i, (desc, cached_key) in enumerate(cache_entries, 1):
210+
print(f"\n[Entry {i}] {desc}", file=sys.stderr)
211+
print(" Key differences:", file=sys.stderr)
212+
has_diff = False
213+
for field_name in vars(current_key):
214+
current_val = str(getattr(current_key, field_name))
215+
cached_val = str(getattr(cached_key, field_name, "<missing>"))
216+
if current_val != cached_val:
217+
has_diff = True
218+
print(f" {field_name}:", file=sys.stderr)
219+
print(f" Current: {current_val}", file=sys.stderr)
220+
print(f" Cached: {cached_val}", file=sys.stderr)
221+
if not has_diff:
222+
print(
223+
" (no differences found, likely a hash collision)",
224+
file=sys.stderr,
225+
)
226+
else:
227+
print("\nNo existing cache entries found.", file=sys.stderr)
228+
229+
print("=" * 80 + "\n", file=sys.stderr)
230+
raise exc.CacheAssertionError(self.kernel.kernel.name)
231+
181232
self.autotuner.log("Starting autotuning process, this may take a while...")
182233

183234
config = self.autotuner.autotune()

helion/autotuner/local_cache.py

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@
22

33
import hashlib
44
import inspect
5+
import json
56
import logging
67
import os
78
from pathlib import Path
89
import textwrap
910
from typing import TYPE_CHECKING
11+
import uuid
1012

1113
import torch
1214
from torch._inductor.runtime.cache_dir_utils import (
@@ -19,6 +21,8 @@
1921
from .base_cache import StrictAutotuneCacheKey
2022

2123
if TYPE_CHECKING:
24+
from collections.abc import Sequence
25+
2226
from .base_search import BaseSearch
2327

2428
log: logging.Logger = logging.getLogger(__name__)
@@ -86,18 +90,71 @@ def _get_local_cache_path(self) -> Path:
8690
def get(self) -> Config | None:
8791
path = self._get_local_cache_path()
8892
try:
89-
return Config.load(path)
93+
data = json.loads(path.read_text())
94+
return Config.from_json(data["config"])
9095
except Exception:
9196
return None
9297

9398
def put(self, config: Config) -> None:
9499
path = self._get_local_cache_path()
95-
config.save(path)
100+
path.parent.mkdir(parents=True, exist_ok=True)
101+
102+
# Save both config and key for better debugging
103+
# Store key as dict for safer reconstruction (avoids eval)
104+
key_dict = {
105+
"type": type(self.key).__name__,
106+
"fields": {k: str(v) for k, v in vars(self.key).items()},
107+
}
108+
109+
data = {
110+
"config": config.to_json(),
111+
"key": key_dict,
112+
}
113+
114+
# Atomic write
115+
tmp = path.parent / f"tmp.{uuid.uuid4()!s}"
116+
tmp.write_text(json.dumps(data, indent=2))
117+
os.rename(str(tmp), str(path))
96118

97119
def _get_cache_info_message(self) -> str:
98120
cache_dir = self._get_local_cache_path().parent
99121
return f"Cache directory: {cache_dir}. To run autotuning again, delete the cache directory or set HELION_SKIP_CACHE=1."
100122

123+
def _get_cache_key(self) -> LooseAutotuneCacheKey:
124+
return self.key
125+
126+
def _list_cache_entries(self) -> Sequence[tuple[str, LooseAutotuneCacheKey]]:
127+
"""List all cache entries in the cache directory."""
128+
cache_dir = self._get_local_cache_path().parent
129+
if not cache_dir.exists():
130+
return []
131+
132+
current_key_hash = self.key.stable_hash()
133+
entries: list[tuple[str, LooseAutotuneCacheKey]] = []
134+
for cache_file in cache_dir.glob("*.best_config"):
135+
try:
136+
data = json.loads(cache_file.read_text())
137+
file_hash = cache_file.stem
138+
139+
if file_hash == current_key_hash:
140+
continue
141+
142+
key_data = data["key"]
143+
144+
# Create a simple namespace object that has the same attributes
145+
# for comparison purposes (we don't need the full key object)
146+
class CachedKey:
147+
def __init__(self, fields: dict[str, str]) -> None:
148+
for name, value in fields.items():
149+
setattr(self, name, value)
150+
151+
cached_key = CachedKey(key_data["fields"])
152+
entries.append((cache_file.name, cached_key)) # type: ignore[arg-type]
153+
except Exception:
154+
pass
155+
156+
return entries
157+
101158

102159
class StrictLocalAutotuneCache(LocalAutotuneCache):
103160
"""

helion/exc.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,10 @@ class AutotuneError(BaseError):
5252
message = "{0}"
5353

5454

55+
class CacheAssertionError(BaseError):
56+
message = "Expected cache hit for kernel '{0}', but got cache miss. See stderr for diagnostic information."
57+
58+
5559
class ClosureMutation(BaseError):
5660
message = "Closure mutation (of {0}) is not allowed in a function arg."
5761

@@ -448,3 +452,12 @@ class NoDeviceLoopsInKernel(BaseError):
448452
"Kernel contains no device loops. Add an hl.tile(...) or hl.grid(...) loop "
449453
"around your device computations."
450454
)
455+
456+
457+
class NestedKernelCallsNotSupported(BaseError):
458+
message = (
459+
"Calling a Helion kernel from within another Helion kernel is not supported. "
460+
"Helion kernels can only be called from outside of @helion.kernel functions. "
461+
"If you need to share code between kernels, consider extracting the shared logic "
462+
"into a regular Python function that can be called from within both kernels."
463+
)

0 commit comments

Comments
 (0)