Commit cb47477

Add a pass to keep cond predicate on CPU memory

1 parent 355a7a6
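
For context, torch.cond selects a branch from a bool or a scalar boolean tensor; if that tensor lives on CUDA, reading its value from the host forces a device-to-host transfer. A minimal sketch of the pattern involved (standard torch.cond API; the function below is illustrative, not part of this commit):

# Illustrative sketch, not part of the commit: torch.cond branches on a
# bool or scalar boolean tensor. Keeping that tensor in CPU memory lets
# the host read it without a device-to-host transfer.
import torch


def gated_add(pred: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # pred is a scalar bool tensor; ideally resident on CPU
    return torch.cond(pred, lambda x: x + 1, lambda x: x - 1, [x])


print(gated_add(torch.tensor(True), torch.zeros(3)))  # tensor([1., 1., 1.])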

File tree: 5 files changed, +143 −5 lines changed

backends/aoti/aoti_backend.py

Lines changed: 8 additions & 4 deletions
@@ -166,7 +166,10 @@ def preprocess(
         # Apply custom backend-specific passes
         custom_passes = cls.get_custom_passes()
         for custom_pass in custom_passes:
-            custom_pass(device_edge_program.graph_module)
+            if getattr(custom_pass, "requires_exported_program", False):
+                custom_pass(device_edge_program)
+            else:
+                custom_pass(device_edge_program.graph_module)
 
         # Run decompositions if any
         if decomposition_table:
@@ -187,9 +190,10 @@ def preprocess(
         missing_fallback_kernels: Set[str] = set()
 
         # Compile with fallback kernel collection
-        with cls.collect_unsupported_fallback_kernels(
-            missing_fallback_kernels
-        ), torch.no_grad():
+        with (
+            cls.collect_unsupported_fallback_kernels(missing_fallback_kernels),
+            torch.no_grad(),
+        ):
             paths = torch._inductor.aot_compile(
                 edge_program_module, tuple(user_input_placeholders), options=options
             )
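
With this change, a pass opts into receiving the whole ExportedProgram (rather than just its graph_module) by exposing a truthy requires_exported_program attribute. A minimal sketch of the protocol (hypothetical pass, shown only to illustrate the dispatch above):

from torch.export import ExportedProgram


class MyProgramLevelPass:
    # preprocess() checks this attribute via getattr(); passes without it
    # keep receiving the GraphModule as before.
    requires_exported_program = True

    def __call__(self, exported_program: ExportedProgram) -> None:
        # Program-level passes can inspect program state, not just the graph.
        sig = exported_program.graph_signature
        for placeholder, buffer in sig.inputs_to_buffers.items():
            print(f"{placeholder} is backed by buffer {buffer}")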

backends/cuda/cuda_backend.py

Lines changed: 4 additions & 1 deletion
@@ -10,6 +10,9 @@
 
 import torch
 from executorch.backends.aoti.aoti_backend import AotiBackend
+from executorch.backends.cuda.passes.keep_cond_predicate_on_cpu import (
+    KeepCondPredicateOnCpuPass,
+)
 from executorch.backends.cuda.triton.replacement_pass import (
     ReplaceEdgeOpWithTritonOpPass,
 )
@@ -49,7 +52,7 @@ def get_decomposition_table(cls) -> Dict[Any, Any]:
     @classmethod
     def get_custom_passes(cls) -> List[typing.Any]:
         """Return CUDA-specific passes: ReplaceEdgeOpWithTritonOpPass"""
-        return [ReplaceEdgeOpWithTritonOpPass()]
+        return [KeepCondPredicateOnCpuPass(), ReplaceEdgeOpWithTritonOpPass()]
 
     @classmethod
     def get_aoti_compile_options(

backends/cuda/passes/__init__.py

Whitespace-only changes.

backends/cuda/passes/keep_cond_predicate_on_cpu.py

Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
+import torch
+from torch.export import ExportedProgram
+
+
+class KeepCondPredicateOnCpuPass:
+    """
+    A pass that locates torch.cond nodes and keeps each predicate on CPU: any
+    buffer backing the predicate is moved to CPU and its metadata pinned there.
+    """
+
+    requires_exported_program = True
+
+    def __call__(self, exported_program: ExportedProgram):
+        graph_module = exported_program.graph_module
+        state_dict = exported_program.state_dict
+
+        # Map input (placeholder) names to buffer names
+        inputs_to_buffers = exported_program.graph_signature.inputs_to_buffers
+
+        for node in graph_module.graph.nodes:
+            if (
+                node.op == "call_function"
+                and node.target == torch.ops.higher_order.cond
+            ):
+                pred_node = node.args[0]
+                if pred_node.op == "placeholder":
+                    # Found a placeholder used as the predicate;
+                    # check whether it corresponds to a buffer.
+                    if pred_node.name in inputs_to_buffers:
+                        buffer_name = inputs_to_buffers[pred_node.name]
+
+                        # Move the buffer in state_dict to CPU.
+                        if buffer_name in state_dict:
+                            # Replace the tensor rather than mutating it
+                            # in place; replacing is safer.
+                            tensor = exported_program.state_dict[buffer_name]
+                            if tensor.device.type != "cpu":
+                                if isinstance(tensor, torch.nn.Parameter):
+                                    exported_program._state_dict[buffer_name] = (
+                                        torch.nn.Parameter(
+                                            tensor.to("cpu"),
+                                            tensor.requires_grad,
+                                        )
+                                    )
+                                else:
+                                    exported_program._state_dict[buffer_name] = (
+                                        tensor.to("cpu")
+                                    )
+
+                        if buffer_name in exported_program.constants:
+                            tensor = exported_program._constants[buffer_name]
+                            if tensor.device.type != "cpu":
+                                exported_program._constants[buffer_name] = tensor.to(
+                                    "cpu"
+                                )
+
+                    # Also pin the placeholder metadata to CPU.
+                    if "val" in pred_node.meta:
+                        fake_tensor = pred_node.meta["val"]
+                        if isinstance(fake_tensor, torch.Tensor):
+                            pred_node.meta["val"] = fake_tensor.to("cpu")
+        exported_program.validate()
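
The buffer-backed case the pass targets can be seen with plain torch.export: a predicate kept as module state is lifted to a buffer input, and it shows up in graph_signature.inputs_to_buffers, the mapping read above. A small sketch (hypothetical Gate module, not part of the commit; the exact placeholder name is version-dependent):

import torch
from torch.export import export


class Gate(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # The predicate is module state, so export lifts it to a buffer input.
        self.register_buffer("use_fast_path", torch.tensor(True))

    def forward(self, x):
        return torch.cond(self.use_fast_path, lambda x: x * 2, lambda x: x / 2, [x])


ep = export(Gate(), (torch.randn(4),))
# Maps the lifted placeholder to the buffer name,
# e.g. {'b_use_fast_path': 'use_fast_path'}
print(ep.graph_signature.inputs_to_buffers)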
Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
+import unittest
+
+import torch
+from executorch.backends.cuda.passes.keep_cond_predicate_on_cpu import (
+    KeepCondPredicateOnCpuPass,
+)
+from torch.export import export
+
+
+class TestKeepCondPredicateOnCpuPass(unittest.TestCase):
+    def test_keep_cond_predicate_on_cpu(self):
+        # Define a simple model using torch.cond
+        class Model(torch.nn.Module):
+            def forward(self, pred, x, y):
+                def true_fn(x, y):
+                    return x + y
+
+                def false_fn(x, y):
+                    return x - y
+
+                return torch.cond(pred, true_fn, false_fn, [x, y])
+
+        model = Model()
+        pred = torch.tensor(True)
+        x = torch.randn(2, 2)
+        y = torch.randn(2, 2)
+
+        # Export the model
+        ep = export(model, (pred, x, y))
+        gm = ep.graph_module
+
+        # Simulate move_to_device_pass by marking every placeholder's metadata
+        # as CUDA; mock tensors are used so CUDA is never actually initialized.
+        from unittest.mock import MagicMock
+
+        for node in gm.graph.nodes:
+            if node.op == "placeholder":
+                if "val" in node.meta:
+                    # Use MagicMock to simulate a tensor on cuda
+                    val = MagicMock(spec=torch.Tensor)
+                    val.device = torch.device("cuda")
+
+                    def to_side_effect(device):
+                        new_val = MagicMock(spec=torch.Tensor)
+                        new_val.device = torch.device(device)
+                        return new_val
+
+                    val.to.side_effect = to_side_effect
+                    node.meta["val"] = val
+
+        # Verify that pred is on cuda
+        pred_node = list(gm.graph.nodes)[0]
+        self.assertEqual(pred_node.meta["val"].device.type, "cuda")
+
+        # Run the pass on the exported program
+        pass_instance = KeepCondPredicateOnCpuPass()
+        pass_instance(ep)
+
+        # Verify that pred is back on cpu
+        self.assertEqual(pred_node.meta["val"].device.type, "cpu")
+
+        # Verify the other placeholders are still on cuda;
+        # the second node is x.
+        x_node = list(gm.graph.nodes)[1]
+        self.assertEqual(x_node.meta["val"].device.type, "cuda")
+
+
+if __name__ == "__main__":
+    unittest.main()
