|
3 | 3 | import unittest |
4 | 4 |
|
5 | 5 | import torch |
| 6 | +from torch.testing._internal.common_utils import instantiate_parametrized_tests |
| 7 | +from torch.testing._internal.common_utils import parametrize |
6 | 8 |
|
7 | 9 | import helion |
8 | 10 | from helion._testing import DEVICE |
| 11 | +from helion._testing import EXAMPLES_DIR |
9 | 12 | from helion._testing import RefEagerTestDisabled |
10 | 13 | from helion._testing import TestCase |
| 14 | +from helion._testing import import_path |
11 | 15 | from helion._utils import counters |
12 | 16 | from helion.autotuner import StrictLocalAutotuneCache |
13 | 17 | from helion.autotuner.base_search import BaseSearch |
14 | 18 | import helion.language as hl |
15 | 19 |
|
16 | 20 |
|
class BasicSearch(BaseSearch):
    """Trivial search strategy that never searches: it yields the default config."""

    def autotune(self, *, skip_cache: bool = False):
        """Return the default config from ``self.config_spec``.

        ``skip_cache`` is accepted as a keyword-only argument but is not
        consulted by this stub.
        """
        default_config = self.config_spec.default_config()
        return default_config
20 | 24 |
|
21 | 25 |
|
def get_add_kernel():
    """Build the fixture for the ``add`` example kernel.

    Returns:
        A 5-tuple ``(kernel, args_a, expected_a, args_b, expected_b)``. The
        first argument pair uses a 16-element bfloat16 tensor, the second a
        16-element float16 tensor; each expected value is the tensor added
        to itself.
    """
    add_module = import_path(EXAMPLES_DIR / "add.py")
    bf16_input = torch.randn(16, device=DEVICE, dtype=torch.bfloat16)
    fp16_input = torch.randn(16, device=DEVICE, dtype=torch.float16)
    return (
        add_module.add,
        (bf16_input, bf16_input),
        bf16_input + bf16_input,
        (fp16_input, fp16_input),
        fp16_input + fp16_input,
    )
| 33 | + |
| 34 | + |
def get_matmul_kernel():
    """Build the fixture for the ``matmul`` example kernel.

    Both argument sets multiply the same 16x16 bfloat16 matrix by itself and
    differ only in the third argument, a callable taking ``(acc, tile)``
    (relu in the first set, sigmoid in the second).

    Returns:
        A 5-tuple ``(kernel, args_a, expected_a, args_b, expected_b)``.
    """
    matmul_module = import_path(EXAMPLES_DIR / "matmul.py")
    mat = torch.randn(16, 16, device=DEVICE, dtype=torch.bfloat16)

    relu_args = (mat, mat, lambda acc, tile: torch.relu(acc))
    sigmoid_args = (mat, mat, lambda acc, tile: torch.sigmoid(acc))

    relu_expected = torch.relu(mat @ mat)
    sigmoid_expected = torch.sigmoid(mat @ mat)
    return matmul_module.matmul, relu_args, relu_expected, sigmoid_args, sigmoid_expected
| 41 | + |
| 42 | + |
def get_welford_kernel():
    """Build the fixture for the ``welford`` example kernel.

    Returns:
        A 5-tuple ``(kernel, args_a, result_a, args_b, result_b)``. Each
        ``args_*`` is ``(weight, bias, x)`` made of float32 tensors, and each
        ``result_*`` is the example module's ``eager_layer_norm`` applied to
        those arguments. The two cases share 2**10 rows and differ in the
        feature dimension (2**4 versus 2**6).
    """
    # Resolve the example module once and take both the kernel and its eager
    # reference from it, instead of calling import_path twice on the same file.
    welford_module = import_path(EXAMPLES_DIR / "welford.py")
    kernel = welford_module.welford
    eager = welford_module.eager_layer_norm

    def make_case(s, d):
        # Tensors are drawn in the order weight, bias, x, and the eager
        # reference runs right after, so random-stream consumption is stable.
        weight = torch.rand((d,), device=DEVICE, dtype=torch.float32)
        bias = torch.rand((d,), device=DEVICE, dtype=torch.float32)
        x = torch.rand((s, d), device=DEVICE, dtype=torch.float32)
        args = (weight, bias, x)
        return args, eager(*args)

    args_a, result_a = make_case(2**10, 2**4)
    args_b, result_b = make_case(2**10, 2**6)

    return kernel, args_a, result_a, args_b, result_b
| 62 | + |
| 63 | + |
# Registry of fixture factories, keyed by kernel name. Each factory takes no
# arguments and returns a 5-tuple: the kernel, two argument tuples, and the
# expected result for each argument tuple.
KERNELS = {
    "add": get_add_kernel,
    "matmul": get_matmul_kernel,
    "welford": get_welford_kernel,
}
| 69 | + |
| 70 | + |
22 | 71 | class TestCache(RefEagerTestDisabled, TestCase): |
23 | | - def test_basic(self): |
24 | | - @helion.kernel( |
25 | | - autotuner_fn=StrictLocalAutotuneCache[BasicSearch], autotune_effort="full" |
26 | | - ) |
27 | | - def add(x, y): |
28 | | - x, y = torch.broadcast_tensors(x, y) |
29 | | - out = torch.empty_like(x) |
30 | | - for tile in hl.tile(out.size()): |
31 | | - out[tile] = x[tile] + y[tile] |
32 | | - return out |
| 72 | + @parametrize("name", ("add", "matmul", "welford")) |
| 73 | + def test_kernel(self, name): |
| 74 | + kernel, args_a, result_a, args_b, result_b = KERNELS[name]() |
33 | 75 |
|
34 | | - a = torch.randn(16, device=DEVICE, dtype=torch.bfloat16) |
35 | | - args_a = (a, a) |
36 | | - b = torch.randn(16, device=DEVICE, dtype=torch.float16) |
37 | | - args_b = (b, b) |
| 76 | + kernel.reset() |
| 77 | + kernel.settings.autotuner_fn = StrictLocalAutotuneCache[BasicSearch] |
| 78 | + kernel.settings.autotune_effort = "full" |
38 | 79 |
|
39 | | - result = add(*args_a) |
40 | | - torch.testing.assert_close(result, a + a) |
| 80 | + result = kernel(*args_a) |
| 81 | + torch.testing.assert_close(result, result_a, rtol=1e-2, atol=5e-2) |
41 | 82 |
|
42 | 83 | self.assertEqual(counters["autotune"]["cache_miss"], 1) |
43 | 84 | self.assertEqual(counters["autotune"]["cache_hit"], 0) |
44 | 85 | self.assertEqual(counters["autotune"]["cache_put"], 1) |
45 | 86 |
|
46 | | - add.reset() |
| 87 | + kernel.reset() |
47 | 88 |
|
48 | | - result = add(*args_a) |
49 | | - torch.testing.assert_close(result, a + a) |
| 89 | + result = kernel(*args_a) |
| 90 | + torch.testing.assert_close(result, result_a, rtol=1e-2, atol=5e-2) |
50 | 91 |
|
51 | 92 | self.assertEqual(counters["autotune"]["cache_miss"], 1) |
52 | 93 | self.assertEqual(counters["autotune"]["cache_hit"], 1) |
53 | 94 | self.assertEqual(counters["autotune"]["cache_put"], 1) |
54 | 95 |
|
55 | | - add.reset() |
| 96 | + kernel.reset() |
56 | 97 |
|
57 | | - result = add(*args_b) |
58 | | - torch.testing.assert_close(result, b + b) |
| 98 | + result = kernel(*args_b) |
| 99 | + torch.testing.assert_close(result, result_b, rtol=1e-2, atol=5e-2) |
59 | 100 |
|
60 | 101 | self.assertEqual(counters["autotune"]["cache_miss"], 2) |
61 | 102 | self.assertEqual(counters["autotune"]["cache_hit"], 1) |
@@ -107,5 +148,8 @@ def add_one(x: torch.Tensor): |
107 | 148 | self.assertEqual(counters["autotune"]["cache_put"], 2) |
108 | 149 |
|
109 | 150 |
|
# Must run after the class body so the @parametrize-decorated tests on
# TestCache are turned into concrete test methods before unittest collects them.
instantiate_parametrized_tests(TestCache)


if __name__ == "__main__":
    unittest.main()
0 commit comments