Skip to content

Commit 2ff186c

Browse files
authored
Better error message for calling Helion kernel from another kernel (#1008)
1 parent 4571892 commit 2ff186c

File tree

3 files changed

+41
-1
lines changed

3 files changed

+41
-1
lines changed

helion/_compiler/type_propagation.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from ..language.stack_tensor import StackTensor
3131
from ..language.tile_proxy import Tile
3232
from ..language.tile_proxy import _CheckForIndexCalls
33+
from ..runtime.kernel import Kernel
3334
from .ast_extension import ExtendedAST
3435
from .ast_extension import LoopType
3536
from .ast_extension import create
@@ -105,6 +106,8 @@ def _get(self, name: str) -> TypeInfo:
105106
return TypeInfo.from_example(value, origin)
106107

107108
origin = self.function.global_scope_origin(name)
109+
if isinstance(value, Kernel):
110+
return TypeInfo.from_example(value, origin)
108111
if not isinstance(
109112
value,
110113
(types.ModuleType, types.FunctionType, types.BuiltinFunctionType),
@@ -1973,9 +1976,12 @@ def visit_Compare(self, node: ast.Compare) -> TypeInfo:
19731976

19741977
def visit_Call(self, node: ast.Call) -> TypeInfo:
19751978
# TODO(jansel): test handling if *args and **kwargs
1976-
# TODO(jansel): check for calling a Kernel here
19771979
func = self.visit(node.func)
19781980

1981+
# Check for calling a Helion kernel from within another Helion kernel
1982+
if isinstance(func, CallableType) and isinstance(func.value, Kernel):
1983+
raise exc.NestedKernelCallsNotSupported
1984+
19791985
if (
19801986
isinstance(func, CallableType)
19811987
and self.origin().is_device()

helion/exc.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -452,3 +452,12 @@ class NoDeviceLoopsInKernel(BaseError):
452452
"Kernel contains no device loops. Add an hl.tile(...) or hl.grid(...) loop "
453453
"around your device computations."
454454
)
455+
456+
457+
class NestedKernelCallsNotSupported(BaseError):
458+
message = (
459+
"Calling a Helion kernel from within another Helion kernel is not supported. "
460+
"Helion kernels can only be called from outside of @helion.kernel functions. "
461+
"If you need to share code between kernels, consider extracting the shared logic "
462+
"into a regular Python function that can be called from within both kernels."
463+
)

test/test_errors.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,22 @@
1616
import helion.language as hl
1717

1818

19+
@helion.kernel()
20+
def _test_inner_kernel(x: torch.Tensor) -> torch.Tensor:
21+
out = torch.empty_like(x)
22+
for tile in hl.tile(x.shape):
23+
out[tile] = x[tile] * 2
24+
return out
25+
26+
27+
@helion.kernel()
28+
def _test_outer_kernel_calling_inner(x: torch.Tensor) -> torch.Tensor:
29+
out = torch.empty_like(x)
30+
for tile in hl.tile(x.shape):
31+
out[tile] = _test_inner_kernel(x[tile])
32+
return out
33+
34+
1935
class TestErrors(RefEagerTestDisabled, TestCase):
2036
def test_autotune_no_valid_configs(self):
2137
class FakeKernel:
@@ -452,6 +468,15 @@ def fn(x: torch.Tensor) -> torch.Tensor:
452468
with self.assertRaises(helion.exc.TileOfTile):
453469
code_and_output(fn, (torch.randn(8, device=DEVICE),))
454470

471+
def test_nested_kernel_calls(self):
472+
with self.assertRaisesRegex(
473+
helion.exc.NestedKernelCallsNotSupported,
474+
r"Calling a Helion kernel from within another Helion kernel is not supported",
475+
):
476+
code_and_output(
477+
_test_outer_kernel_calling_inner, (torch.randn(8, device=DEVICE),)
478+
)
479+
455480

456481
if __name__ == "__main__":
457482
unittest.main()

0 commit comments

Comments
 (0)