Skip to content

Commit 4571892

Browse files
authored
Add HELION_ASSERT_CACHE_HIT to debug/explain cache miss (#1006)
1 parent f1d9a53 commit 4571892

File tree

5 files changed

+150
-3
lines changed

5 files changed

+150
-3
lines changed

docs/api/settings.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,7 @@ Built-in values for ``HELION_AUTOTUNER`` include ``"PatternSearch"``, ``"Differe
257257
| ``HELION_AUTOTUNE_CONFIG_OVERRIDES`` | ``autotune_config_overrides`` | Supply JSON forcing particular autotuner config key/value pairs. |
258258
| ``HELION_CACHE_DIR`` | ``LocalAutotuneCache`` | Override the on-disk directory used for cached autotuning artifacts. |
259259
| ``HELION_SKIP_CACHE`` | ``LocalAutotuneCache`` | When set to ``1``, ignore cached autotuning entries and rerun searches. |
260+
| ``HELION_ASSERT_CACHE_HIT`` | ``AutotuneCacheBase`` | When set to ``1``, require a cache hit; raises ``CacheAssertionError`` on cache miss with detailed diagnostics. |
260261
| ``HELION_PRINT_OUTPUT_CODE`` | ``print_output_code`` | Print generated Triton code to stderr for inspection. |
261262
| ``HELION_OUTPUT_ORIGIN_LINES`` | ``output_origin_lines`` | Include ``# src[...]`` comments in generated Triton code; set to ``0`` to disable. |
262263
| ``HELION_IGNORE_WARNINGS`` | ``ignore_warnings`` | Comma-separated warning names defined in ``helion.exc`` to suppress. |

helion/autotuner/base_cache.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
from __future__ import annotations
22

33
import abc
4+
from collections.abc import Sequence
45
import dataclasses
56
import functools
67
import hashlib
78
import logging
89
import os
10+
import sys
911
from typing import TYPE_CHECKING
1012
from typing import Any
1113
from typing import Callable
@@ -14,6 +16,7 @@
1416
from torch._inductor.codecache import build_code_hash
1517
from torch._inductor.codecache import torch_key
1618

19+
from .. import exc
1720
from .._utils import counters
1821
from .base_search import BaseAutotuner
1922

@@ -67,7 +70,8 @@ def torch_key_wrapper() -> str:
6770
def triton_key_wrapper() -> str:
6871
from torch._inductor.runtime.triton_compat import triton_key
6972

70-
return triton_key()
73+
full_key = triton_key()
74+
return hashlib.sha256(full_key.encode("utf-8")).hexdigest()
7175

7276

7377
class CacheKeyBase:
@@ -157,6 +161,16 @@ def _get_cache_info_message(self) -> str:
157161
"""Return a message describing where the cache is and how to clear it."""
158162
return ""
159163

164+
@abc.abstractmethod
165+
def _get_cache_key(self) -> CacheKeyBase:
166+
"""Return the cache key for this cache instance."""
167+
raise NotImplementedError
168+
169+
@abc.abstractmethod
170+
def _list_cache_entries(self) -> Sequence[tuple[str, CacheKeyBase]]:
171+
"""Return a sequence of (description, key) tuples for all cache entries."""
172+
raise NotImplementedError
173+
160174
def autotune(self, *, skip_cache: bool = False) -> Config:
161175
if skip_cache or os.environ.get("HELION_SKIP_CACHE", "") not in {
162176
"",
@@ -178,6 +192,43 @@ def autotune(self, *, skip_cache: bool = False) -> Config:
178192
counters["autotune"]["cache_miss"] += 1
179193
log.debug("cache miss")
180194

195+
if os.environ.get("HELION_ASSERT_CACHE_HIT") == "1":
196+
current_key = self._get_cache_key()
197+
print("\n" + "=" * 80, file=sys.stderr)
198+
print("HELION_ASSERT_CACHE_HIT: Cache miss detected!", file=sys.stderr)
199+
print("=" * 80, file=sys.stderr)
200+
print(f"\nKernel: {self.kernel.kernel.name}", file=sys.stderr)
201+
print(f"\nCurrent cache key:\n{current_key}", file=sys.stderr)
202+
203+
cache_entries = self._list_cache_entries()
204+
if cache_entries:
205+
print(
206+
f"\n{len(cache_entries)} other cache entries exist (but don't match):",
207+
file=sys.stderr,
208+
)
209+
for i, (desc, cached_key) in enumerate(cache_entries, 1):
210+
print(f"\n[Entry {i}] {desc}", file=sys.stderr)
211+
print(" Key differences:", file=sys.stderr)
212+
has_diff = False
213+
for field_name in vars(current_key):
214+
current_val = str(getattr(current_key, field_name))
215+
cached_val = str(getattr(cached_key, field_name, "<missing>"))
216+
if current_val != cached_val:
217+
has_diff = True
218+
print(f" {field_name}:", file=sys.stderr)
219+
print(f" Current: {current_val}", file=sys.stderr)
220+
print(f" Cached: {cached_val}", file=sys.stderr)
221+
if not has_diff:
222+
print(
223+
" (no differences found, likely a hash collision)",
224+
file=sys.stderr,
225+
)
226+
else:
227+
print("\nNo existing cache entries found.", file=sys.stderr)
228+
229+
print("=" * 80 + "\n", file=sys.stderr)
230+
raise exc.CacheAssertionError(self.kernel.kernel.name)
231+
181232
self.autotuner.log("Starting autotuning process, this may take a while...")
182233

183234
config = self.autotuner.autotune()

helion/autotuner/local_cache.py

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@
22

33
import hashlib
44
import inspect
5+
import json
56
import logging
67
import os
78
from pathlib import Path
89
import textwrap
910
from typing import TYPE_CHECKING
11+
import uuid
1012

1113
import torch
1214
from torch._inductor.runtime.cache_dir_utils import (
@@ -19,6 +21,8 @@
1921
from .base_cache import StrictAutotuneCacheKey
2022

2123
if TYPE_CHECKING:
24+
from collections.abc import Sequence
25+
2226
from .base_search import BaseSearch
2327

2428
log: logging.Logger = logging.getLogger(__name__)
@@ -86,18 +90,71 @@ def _get_local_cache_path(self) -> Path:
8690
def get(self) -> Config | None:
8791
path = self._get_local_cache_path()
8892
try:
89-
return Config.load(path)
93+
data = json.loads(path.read_text())
94+
return Config.from_json(data["config"])
9095
except Exception:
9196
return None
9297

9398
def put(self, config: Config) -> None:
    """Persist ``config`` together with the current cache key as JSON.

    The stored object has two entries: ``"config"`` holds the result of
    ``config.to_json()`` and ``"key"`` holds the key's class name plus its
    attributes rendered with ``str``. Keeping the key next to the config
    lets a later cache miss be explained by comparing stored keys.

    The payload is first written to a uniquely named temporary file in the
    destination directory and then moved over the final path, so a reader
    never observes a partially written entry.

    Args:
        config: Configuration to store at ``self._get_local_cache_path()``.

    Raises:
        OSError: If the directory cannot be created, or the file cannot be
            written or moved into place.
    """
    path = self._get_local_cache_path()
    path.parent.mkdir(parents=True, exist_ok=True)

    # Save both config and key for better debugging
    # Store key as dict for safer reconstruction (avoids eval)
    key_dict = {
        "type": type(self.key).__name__,
        "fields": {k: str(v) for k, v in vars(self.key).items()},
    }
    payload = json.dumps({"config": config.to_json(), "key": key_dict}, indent=2)

    # Atomic write
    tmp = path.parent / f"tmp.{uuid.uuid4()!s}"
    try:
        tmp.write_text(payload)
        # os.replace overwrites an existing entry on every platform, whereas
        # os.rename raises on Windows when the destination already exists.
        os.replace(tmp, path)
    finally:
        # After a successful move the temporary name no longer exists; after
        # a failure this removes the stray file instead of leaving it behind.
        tmp.unlink(missing_ok=True)
96118

97119
def _get_cache_info_message(self) -> str:
98120
cache_dir = self._get_local_cache_path().parent
99121
return f"Cache directory: {cache_dir}. To run autotuning again, delete the cache directory or set HELION_SKIP_CACHE=1."
100122

123+
def _get_cache_key(self) -> LooseAutotuneCacheKey:
124+
return self.key
125+
126+
def _list_cache_entries(self) -> Sequence[tuple[str, LooseAutotuneCacheKey]]:
    """List the other entries stored in this cache's directory.

    Looks at every ``*.best_config`` file in the parent directory of
    ``self._get_local_cache_path()``. The file whose stem equals
    ``self.key.stable_hash()`` is skipped without being read. Files that
    cannot be read, are not valid JSON, or lack a ``["key"]["fields"]``
    mapping are skipped as well.

    Returns:
        ``(file name, key)`` pairs sorted by file name, so repeated runs
        report entries in the same order. Each key is a plain attribute
        container (not a real ``LooseAutotuneCacheKey``) exposing the
        stored, already stringified fields as attributes, which is enough
        for attribute-by-attribute comparison. The result is empty when the
        cache directory does not exist.
    """
    from types import SimpleNamespace

    cache_dir = self._get_local_cache_path().parent
    if not cache_dir.exists():
        return []

    current_key_hash = self.key.stable_hash()
    entries: list[tuple[str, LooseAutotuneCacheKey]] = []
    for cache_file in sorted(cache_dir.glob("*.best_config")):
        # Decide from the file name alone, before paying for a read and parse.
        if cache_file.stem == current_key_hash:
            continue

        try:
            fields = json.loads(cache_file.read_text())["key"]["fields"]
            # A simple namespace with the same attributes is sufficient for
            # comparison purposes (the full key object is not needed).
            cached_key = SimpleNamespace(**fields)
        except (OSError, ValueError, KeyError, TypeError):
            # Best effort: an unreadable or malformed entry (including files
            # that carry no key information) is left out of the listing.
            continue

        entries.append((cache_file.name, cached_key))  # type: ignore[arg-type]

    return entries
157+
101158

102159
class StrictLocalAutotuneCache(LocalAutotuneCache):
103160
"""

helion/exc.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,10 @@ class AutotuneError(BaseError):
5252
message = "{0}"
5353

5454

55+
class CacheAssertionError(BaseError):
56+
message = "Expected cache hit for kernel '{0}', but got cache miss. See stderr for diagnostic information."
57+
58+
5559
class ClosureMutation(BaseError):
5660
message = "Closure mutation (of {0}) is not allowed in a function arg."
5761

test/test_cache.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
from __future__ import annotations
22

3+
import os
34
import unittest
5+
from unittest.mock import patch
46

57
import torch
68
from torch.testing._internal.common_utils import instantiate_parametrized_tests
79
from torch.testing._internal.common_utils import parametrize
810

911
import helion
12+
from helion import exc
1013
from helion._testing import DEVICE
1114
from helion._testing import EXAMPLES_DIR
1215
from helion._testing import RefEagerTestDisabled
@@ -147,6 +150,37 @@ def add_one(x: torch.Tensor):
147150
self.assertEqual(counters["autotune"]["cache_hit"], 1)
148151
self.assertEqual(counters["autotune"]["cache_put"], 2)
149152

153+
def test_assert_cache_hit(self):
154+
counters["autotune"].clear()
155+
self.addCleanup(counters["autotune"].clear)
156+
157+
kernel, args_a, result_a, args_b, result_b = KERNELS["add"]()
158+
kernel.reset()
159+
kernel.settings.autotuner_fn = StrictLocalAutotuneCache[BasicSearch]
160+
kernel.settings.autotune_effort = "full"
161+
162+
result = kernel(*args_a)
163+
torch.testing.assert_close(result, result_a)
164+
self.assertEqual(counters["autotune"]["cache_miss"], 1)
165+
self.assertEqual(counters["autotune"]["cache_hit"], 0)
166+
167+
kernel.reset()
168+
with patch.dict(os.environ, {"HELION_ASSERT_CACHE_HIT": "1"}):
169+
result = kernel(*args_a)
170+
torch.testing.assert_close(result, result_a)
171+
self.assertEqual(counters["autotune"]["cache_miss"], 1)
172+
self.assertEqual(counters["autotune"]["cache_hit"], 1)
173+
174+
kernel.reset()
175+
with patch.dict(os.environ, {"HELION_ASSERT_CACHE_HIT": "1"}):
176+
with self.assertRaises(exc.CacheAssertionError) as cm:
177+
kernel(*args_b)
178+
179+
self.assertIn("add", str(cm.exception))
180+
# cache_miss incremented before error, but cache_put not (autotuning prevented)
181+
self.assertEqual(counters["autotune"]["cache_miss"], 2)
182+
self.assertEqual(counters["autotune"]["cache_put"], 1)
183+
150184

151185
instantiate_parametrized_tests(TestCache)
152186

0 commit comments

Comments
 (0)