Skip to content

Commit f1d9a53

Browse files
authored
Beef up caching tests (#1001)
1 parent 2735d0a commit f1d9a53

File tree

1 file changed

+67
-23
lines changed

1 file changed

+67
-23
lines changed

test/test_cache.py

Lines changed: 67 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -3,59 +3,100 @@
33
import unittest
44

55
import torch
6+
from torch.testing._internal.common_utils import instantiate_parametrized_tests
7+
from torch.testing._internal.common_utils import parametrize
68

79
import helion
810
from helion._testing import DEVICE
11+
from helion._testing import EXAMPLES_DIR
912
from helion._testing import RefEagerTestDisabled
1013
from helion._testing import TestCase
14+
from helion._testing import import_path
1115
from helion._utils import counters
1216
from helion.autotuner import StrictLocalAutotuneCache
1317
from helion.autotuner.base_search import BaseSearch
1418
import helion.language as hl
1519

1620

1721
class BasicSearch(BaseSearch):
    """Search stub whose ``autotune`` simply hands back the default config."""

    def autotune(self, *, skip_cache: bool = False):
        """Return the default config from ``self.config_spec``.

        ``skip_cache`` is accepted as a keyword-only argument but is not
        read by this stub.
        """
        default_config = self.config_spec.default_config()
        return default_config
2024

2125

26+
def get_add_kernel():
    """Build the test case for the ``add`` example kernel.

    Returns:
        A 5-tuple ``(kernel, args_a, expected_a, args_b, expected_b)``.
        ``args_a`` pairs a 16-element bfloat16 tensor with itself and
        ``args_b`` does the same with a float16 tensor; each expected value
        is the corresponding tensor added to itself.
    """
    add_kernel = import_path(EXAMPLES_DIR / "add.py").add

    # Same random-draw order as before: bfloat16 first, then float16.
    bf16_input = torch.randn(16, device=DEVICE, dtype=torch.bfloat16)
    fp16_input = torch.randn(16, device=DEVICE, dtype=torch.float16)

    return (
        add_kernel,
        (bf16_input, bf16_input),
        bf16_input + bf16_input,
        (fp16_input, fp16_input),
        fp16_input + fp16_input,
    )
33+
34+
35+
def get_matmul_kernel():
    """Build the test case for the ``matmul`` example kernel.

    Returns:
        A 5-tuple ``(kernel, args_a, expected_a, args_b, expected_b)``.
        Both argument sets pass the same 16x16 bfloat16 matrix twice plus a
        two-argument callable: the first applies ``torch.relu`` to its first
        argument, the second applies ``torch.sigmoid``. The expected values
        are ``relu(mat @ mat)`` and ``sigmoid(mat @ mat)``.
    """
    matmul_kernel = import_path(EXAMPLES_DIR / "matmul.py").matmul
    mat = torch.randn(16, 16, device=DEVICE, dtype=torch.bfloat16)

    # The two argument sets differ only in the callable passed third.
    relu_args = (mat, mat, lambda acc, tile: torch.relu(acc))
    sigmoid_args = (mat, mat, lambda acc, tile: torch.sigmoid(acc))

    relu_expected = torch.relu(mat @ mat)
    sigmoid_expected = torch.sigmoid(mat @ mat)
    return matmul_kernel, relu_args, relu_expected, sigmoid_args, sigmoid_expected
41+
42+
43+
def get_welford_kernel():
    """Build the test case for the ``welford`` example kernel.

    Returns:
        A 5-tuple ``(kernel, args_a, result_a, args_b, result_b)``. Each
        argument set is ``(weight, bias, x)`` of float32 tensors with
        ``weight``/``bias`` of shape ``(d,)`` and ``x`` of shape ``(1024, d)``;
        ``d`` is 16 for the first set and 64 for the second. Each result is
        what the example's ``eager_layer_norm`` returns for those arguments.
    """
    kernel = import_path(EXAMPLES_DIR / "welford.py").welford
    eager = import_path(EXAMPLES_DIR / "welford.py").eager_layer_norm

    def build_case(rows, cols):
        # Tensors are drawn in the order weight, bias, x.
        weight = torch.rand((cols,), device=DEVICE, dtype=torch.float32)
        bias = torch.rand((cols,), device=DEVICE, dtype=torch.float32)
        x = torch.rand((rows, cols), device=DEVICE, dtype=torch.float32)
        case_args = (weight, bias, x)
        return case_args, eager(*case_args)

    args_a, result_a = build_case(2**10, 2**4)
    args_b, result_b = build_case(2**10, 2**6)

    return kernel, args_a, result_a, args_b, result_b
62+
63+
64+
# Maps each test-case name to the zero-argument factory that builds it.
KERNELS = dict(
    add=get_add_kernel,
    matmul=get_matmul_kernel,
    welford=get_welford_kernel,
)
69+
70+
2271
class TestCache(RefEagerTestDisabled, TestCase):
23-
def test_basic(self):
24-
@helion.kernel(
25-
autotuner_fn=StrictLocalAutotuneCache[BasicSearch], autotune_effort="full"
26-
)
27-
def add(x, y):
28-
x, y = torch.broadcast_tensors(x, y)
29-
out = torch.empty_like(x)
30-
for tile in hl.tile(out.size()):
31-
out[tile] = x[tile] + y[tile]
32-
return out
72+
@parametrize("name", ("add", "matmul", "welford"))
73+
def test_kernel(self, name):
74+
kernel, args_a, result_a, args_b, result_b = KERNELS[name]()
3375

34-
a = torch.randn(16, device=DEVICE, dtype=torch.bfloat16)
35-
args_a = (a, a)
36-
b = torch.randn(16, device=DEVICE, dtype=torch.float16)
37-
args_b = (b, b)
76+
kernel.reset()
77+
kernel.settings.autotuner_fn = StrictLocalAutotuneCache[BasicSearch]
78+
kernel.settings.autotune_effort = "full"
3879

39-
result = add(*args_a)
40-
torch.testing.assert_close(result, a + a)
80+
result = kernel(*args_a)
81+
torch.testing.assert_close(result, result_a, rtol=1e-2, atol=5e-2)
4182

4283
self.assertEqual(counters["autotune"]["cache_miss"], 1)
4384
self.assertEqual(counters["autotune"]["cache_hit"], 0)
4485
self.assertEqual(counters["autotune"]["cache_put"], 1)
4586

46-
add.reset()
87+
kernel.reset()
4788

48-
result = add(*args_a)
49-
torch.testing.assert_close(result, a + a)
89+
result = kernel(*args_a)
90+
torch.testing.assert_close(result, result_a, rtol=1e-2, atol=5e-2)
5091

5192
self.assertEqual(counters["autotune"]["cache_miss"], 1)
5293
self.assertEqual(counters["autotune"]["cache_hit"], 1)
5394
self.assertEqual(counters["autotune"]["cache_put"], 1)
5495

55-
add.reset()
96+
kernel.reset()
5697

57-
result = add(*args_b)
58-
torch.testing.assert_close(result, b + b)
98+
result = kernel(*args_b)
99+
torch.testing.assert_close(result, result_b, rtol=1e-2, atol=5e-2)
59100

60101
self.assertEqual(counters["autotune"]["cache_miss"], 2)
61102
self.assertEqual(counters["autotune"]["cache_hit"], 1)
@@ -107,5 +148,8 @@ def add_one(x: torch.Tensor):
107148
self.assertEqual(counters["autotune"]["cache_put"], 2)
108149

109150

151+
instantiate_parametrized_tests(TestCache)
152+
153+
110154
if __name__ == "__main__":
111155
unittest.main()

0 commit comments

Comments
 (0)