
Commit 5b7f668

add cute_inductor example (#157)
Signed-off-by: Mayank Mishra <[email protected]>
1 parent 71af063 commit 5b7f668

File tree

4 files changed (+75, -27 lines)


cute_kernels/cute_inductor/compiler.py (+21, -24)

@@ -3,7 +3,7 @@
 import torch
 from torch._dynamo import lookup_backend
 
-from ..utils import get_boolean_env_variable, set_cute_tracing
+from ..utils import enable_cute_tracing, get_boolean_env_variable
 from .rmsnorm import replace_rmsnorm
 from .swiglu_unchunked import replace_swiglu_unchunked
 
@@ -21,26 +21,23 @@ def __init__(
         self.replace_functions = replace_functions
 
     def compiler(self, gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]) -> Callable:
-        set_cute_tracing(True)
-
-        if _DEBUG_CUTEINDUCTOR:
-            print("graph before cute inductor")
-            gm.print_readable()
-
-        for replace_function in self.replace_functions:
-            for node in gm.graph.nodes:
-                replace_function(gm, node)
-
-        if _DEBUG_CUTEINDUCTOR:
-            print("graph after cute inductor")
-            gm.print_readable()
-
-        if self.use_torch_inductor_after_cute_inductor:
-            inductor = lookup_backend("inductor")
-            compiled = inductor(gm, example_inputs)
-        else:
-            compiled = gm.forward
-
-        set_cute_tracing(False)
-
-        return compiled
+        with enable_cute_tracing():
+            if _DEBUG_CUTEINDUCTOR:
+                print("graph before cute inductor")
+                gm.print_readable()
+
+            for replace_function in self.replace_functions:
+                for node in gm.graph.nodes:
+                    replace_function(gm, node)
+
+            if _DEBUG_CUTEINDUCTOR:
+                print("graph after cute inductor")
+                gm.print_readable()
+
+            if self.use_torch_inductor_after_cute_inductor:
+                inductor = lookup_backend("inductor")
+                compiled = inductor(gm, example_inputs)
+            else:
+                compiled = gm.forward
+
+        return compiled
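The refactor scopes the tracing flag with a context manager instead of paired set_cute_tracing(True/False) calls. Each entry in replace_functions is invoked once per FX node as replace_function(gm, node), which is the extension point for custom fusion passes. A hedged sketch of a pass matching that call signature (the function name and body here are illustrative, not from the repo):

    import torch


    def replace_noop(gm: torch.fx.GraphModule, node: torch.fx.Node) -> None:
        # Illustrative replacement pass: inspect one node and, if it matches
        # a target pattern, rewrite gm.graph in place (as replace_rmsnorm
        # does for RMSNorm subgraphs). The stub only mirrors the signature
        # used by the loop in compiler().
        if node.op == "call_function":
            pass  # pattern matching and graph surgery would go here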

cute_kernels/utils/__init__.py (+1, -1)

@@ -1,5 +1,5 @@
 from .contiguous import ensure_contiguous, ensure_same_strides
-from .custom_op import cute_op, set_cute_tracing
+from .custom_op import cute_op, enable_cute_tracing
 from .device import device_synchronize, get_sm_count, is_hip
 from .env import get_boolean_env_variable
 from .ptx import get_ptx_from_triton_kernel

cute_kernels/utils/custom_op.py (+8, -2)

@@ -1,4 +1,5 @@
 import inspect
+from contextlib import contextmanager
 from typing import Callable, Iterable, Sequence
 
 import torch
@@ -7,9 +8,14 @@
 _IS_CUTE_TRACING = False
 
 
-def set_cute_tracing(enable: bool) -> None:
+@contextmanager
+def enable_cute_tracing():
     global _IS_CUTE_TRACING
-    _IS_CUTE_TRACING = enable
+    _IS_CUTE_TRACING = True
+
+    yield
+
+    _IS_CUTE_TRACING = False
 
 
 def _dispatch(func: Callable, compileable_fn: Callable, *args, **kwargs):
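Since enable_cute_tracing is a generator-based context manager with no try/finally around the yield, the flag is cleared on a normal exit from the with block (an exception raised inside the block would propagate before the reset runs). A small sketch of the resulting behavior, reading the module-private flag purely for illustration:

    import cute_kernels.utils.custom_op as custom_op

    assert custom_op._IS_CUTE_TRACING is False
    with custom_op.enable_cute_tracing():
        assert custom_op._IS_CUTE_TRACING is True  # flag is set inside the block
    assert custom_op._IS_CUTE_TRACING is False  # and cleared again on normal exit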

examples/cute_inductor.py (+45, new file)

@@ -0,0 +1,45 @@
+import torch
+import torch.nn as nn
+
+from cute_kernels import CuteInductor
+from cute_kernels.cute_inductor.rmsnorm import replace_rmsnorm
+from cute_kernels.cute_inductor.swiglu_unchunked import replace_swiglu_unchunked
+
+
+# NOTE swiglu unchunked computes:
+# ------------------------------------------------------------------------------
+# def swiglu_unchunked_torch(x: torch.Tensor) -> torch.Tensor:
+#     x = x.chunk(2, dim=-1)
+#     return x[0] * F.silu(x[1])
+# ------------------------------------------------------------------------------
+
+
+class Model(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+
+        self.norm1 = nn.RMSNorm(4)
+        self.linear = nn.Linear(4, 4)
+        self.norm2 = nn.RMSNorm(4)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.norm1(x)
+        x = self.linear(x)
+        x = self.norm2(x)
+        return x
+
+
+model = Model().to(torch.cuda.current_device())
+
+use_torch_inductor_after_cute_inductor = True  # to use torch's compiler optimizations as well
+replace_functions = [replace_rmsnorm]  # add other replacing functions
+
+cute_inductor = CuteInductor(
+    use_torch_inductor_after_cute_inductor=use_torch_inductor_after_cute_inductor, replace_functions=replace_functions
+)
+
+compiled_model = torch.compile(model, backend=cute_inductor.compiler)
+
+# trigger JIT compilation
+x = torch.randn(4, 4, device=torch.cuda.current_device())
+y = compiled_model(x)
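The example stops after the first forward pass, which is what triggers compilation through the CuteInductor backend. A natural follow-up, not part of the commit, is to check the compiled output against eager execution; a sketch, reusing the model and input above:

    # Sanity check (sketch): the compiled model, with RMSNorm subgraphs
    # replaced, should match eager execution up to floating-point tolerance.
    y_eager = model(x)
    torch.testing.assert_close(y, y_eager, rtol=1e-3, atol=1e-3)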
