Skip to content

Commit 276e473

Browse files
committed
Support tile+offset and tensor descriptors
stack-info: PR: #928, branch: jansel/stack/189
1 parent dbf666e commit 276e473

File tree

7 files changed

+479
-18
lines changed

7 files changed

+479
-18
lines changed

helion/_compiler/device_ir.py

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import contextlib
66
import dataclasses
77
import functools
8+
import itertools
89
import operator
910
import re
1011
import textwrap
@@ -450,7 +451,18 @@ def _body(self, body: list[ast.stmt]) -> None:
450451
self.visit(stmt)
451452

452453
def visit_BinOp(self, node: ast.BinOp) -> object:
453-
return _eval_binary(node.op, self.visit(node.left), self.visit(node.right))
454+
left = self.visit(node.left)
455+
right = self.visit(node.right)
456+
# Special handling for Tile + offset: expand to tile.index + offset
457+
# and mark with metadata for indexing strategies to recognize
458+
if (
459+
isinstance(node.op, ast.Add)
460+
and isinstance(left, Tile)
461+
and isinstance(right, (int, torch.SymInt))
462+
):
463+
# Implicitly expand to tile.index + offset
464+
left = hl.tile_index(left)
465+
return _eval_binary(node.op, left, right)
454466

455467
def visit_UnaryOp(self, node: ast.UnaryOp) -> object:
456468
return _eval_unary(node.op, self.visit(node.operand))
@@ -1128,6 +1140,7 @@ def lower_to_device_ir(func: HostFunction) -> DeviceIR:
11281140
prepare_graph_lowerings(graph.graph)
11291141
for graph in device_ir.graphs:
11301142
validate_host_tensor_usage(graph.graph)
1143+
add_tile_with_offset_metadata(graph)
11311144
remove_unnecessary_tile_index(graph.graph)
11321145
remove_unnecessary_masking(graph.graph)
11331146
device_ir.build_rolled_reductions()
@@ -1193,6 +1206,68 @@ def validate_host_tensor_usage(graph: torch.fx.Graph) -> None:
11931206
raise exc.HostTensorDirectUsage(scalar_tensor_name, op_name)
11941207

11951208

1209+
def add_tile_with_offset_metadata(graph_info: GraphInfo) -> None:
1210+
"""
1211+
Recognize tile.index + offset patterns and add metadata to enable tensor descriptor indexing.
1212+
1213+
This pass identifies FX nodes that represent `tile.index + offset` (where offset is an
1214+
integer or SymInt), and adds the `tile_with_offset` metadata to those nodes so that
1215+
indexing strategies can generate efficient code (e.g., tensor descriptors) for them.
1216+
"""
1217+
graph = graph_info.graph
1218+
env = CompileEnvironment.current()
1219+
1220+
for node in itertools.chain(
1221+
graph.find_nodes(op="call_function", target=operator.add),
1222+
graph.find_nodes(op="call_function", target=torch.ops.aten.add.Tensor),
1223+
):
1224+
# Check if this is tile.index + offset pattern
1225+
# args[0] should be tile_index result, args[1] should be int/SymInt
1226+
if len(node.args) != 2 and not node.kwargs:
1227+
continue
1228+
left_arg, right_arg = node.args
1229+
1230+
# Check if left argument is a tile_index call
1231+
if (
1232+
not isinstance(left_arg, torch.fx.Node)
1233+
or left_arg.op != "call_function"
1234+
or left_arg.target != hl.tile_index
1235+
):
1236+
continue
1237+
1238+
# Check if right argument is an integer offset
1239+
# It could be a constant, SymInt node, or another value
1240+
# We accept int, SymInt, or nodes that represent them
1241+
offset = None
1242+
if isinstance(right_arg, (int, torch.SymInt)):
1243+
offset = right_arg
1244+
elif isinstance(right_arg, torch.fx.Node):
1245+
# Check the node's metadata for the value
1246+
val = right_arg.meta.get("val")
1247+
if isinstance(val, (int, torch.SymInt)):
1248+
offset = val
1249+
1250+
if offset is None:
1251+
continue
1252+
1253+
# Extract the block_id from the tile_index call
1254+
tile_arg = left_arg.args[0]
1255+
block_id = None
1256+
if isinstance(tile_arg, torch.fx.Node) and isinstance(
1257+
tile_arg.meta["val"], torch.SymInt
1258+
):
1259+
block_id = env.get_block_id(tile_arg.meta["val"])
1260+
1261+
if block_id is None:
1262+
continue
1263+
1264+
# Add metadata to mark this as a tile+offset node
1265+
node.meta["tile_with_offset"] = {
1266+
"block_id": block_id,
1267+
"offset": offset,
1268+
}
1269+
1270+
11961271
def remove_unnecessary_tile_index(graph: torch.fx.Graph) -> None:
11971272
"""
11981273
Remove unnecessary tile_index nodes from the graph.

0 commit comments

Comments
 (0)