|
5 | 5 | import contextlib |
6 | 6 | import dataclasses |
7 | 7 | import functools |
| 8 | +import itertools |
8 | 9 | import operator |
9 | 10 | import re |
10 | 11 | import textwrap |
@@ -450,7 +451,18 @@ def _body(self, body: list[ast.stmt]) -> None: |
450 | 451 | self.visit(stmt) |
451 | 452 |
|
def visit_BinOp(self, node: ast.BinOp) -> object:
    """Evaluate a binary operation, rewriting `Tile + offset` to `tile.index + offset`.

    Both operands are visited first (left, then right). When the operator is
    `+`, the left value is a `Tile`, and the right value is an `int` or
    `torch.SymInt`, the tile is replaced by `hl.tile_index(tile)` before the
    operation is evaluated. Every other combination is passed to
    `_eval_binary` unchanged.
    """
    lhs = self.visit(node.left)
    rhs = self.visit(node.right)
    is_tile_plus_offset = (
        isinstance(node.op, ast.Add)
        and isinstance(lhs, Tile)
        and isinstance(rhs, (int, torch.SymInt))
    )
    if is_tile_plus_offset:
        # Implicit expansion only; no metadata is attached to the result here.
        lhs = hl.tile_index(lhs)
    return _eval_binary(node.op, lhs, rhs)
454 | 466 |
|
def visit_UnaryOp(self, node: ast.UnaryOp) -> object:
    """Visit the operand and pass it, with the operator node, to `_eval_unary`."""
    operand = self.visit(node.operand)
    return _eval_unary(node.op, operand)
@@ -1128,6 +1140,7 @@ def lower_to_device_ir(func: HostFunction) -> DeviceIR: |
1128 | 1140 | prepare_graph_lowerings(graph.graph) |
1129 | 1141 | for graph in device_ir.graphs: |
1130 | 1142 | validate_host_tensor_usage(graph.graph) |
| 1143 | + add_tile_with_offset_metadata(graph) |
1131 | 1144 | remove_unnecessary_tile_index(graph.graph) |
1132 | 1145 | remove_unnecessary_masking(graph.graph) |
1133 | 1146 | device_ir.build_rolled_reductions() |
@@ -1193,6 +1206,68 @@ def validate_host_tensor_usage(graph: torch.fx.Graph) -> None: |
1193 | 1206 | raise exc.HostTensorDirectUsage(scalar_tensor_name, op_name) |
1194 | 1207 |
|
1195 | 1208 |
|
def add_tile_with_offset_metadata(graph_info: GraphInfo) -> None:
    """
    Recognize tile.index + offset patterns and add metadata to enable tensor descriptor indexing.

    This pass identifies FX nodes that represent `tile.index + offset` (where offset is an
    integer or SymInt), and adds the `tile_with_offset` metadata to those nodes so that
    indexing strategies can generate efficient code (e.g., tensor descriptors) for them.

    A node is tagged only when all of the following hold:

    * it is a `call_function` node targeting `operator.add` or
      `torch.ops.aten.add.Tensor` with exactly two positional arguments and
      no keyword arguments;
    * its first argument is a `call_function` node targeting `hl.tile_index`;
    * its second argument is an `int`/`torch.SymInt`, or an FX node whose
      `meta["val"]` is one;
    * the first argument of the `tile_index` call is an FX node whose
      `meta["val"]` is a `torch.SymInt` for which `env.get_block_id` returns
      a non-None block id.

    Nodes that do not match are left untouched. The graph structure itself is
    never modified; only `node.meta["tile_with_offset"]` is written, as
    `{"block_id": ..., "offset": ...}`.
    """
    graph = graph_info.graph
    env = CompileEnvironment.current()

    for node in itertools.chain(
        graph.find_nodes(op="call_function", target=operator.add),
        graph.find_nodes(op="call_function", target=torch.ops.aten.add.Tensor),
    ):
        # Only a plain two-operand add qualifies. Anything with a different
        # arity cannot be unpacked below, and keyword arguments (e.g. `alpha=`
        # on aten.add.Tensor) mean the result is not simply
        # `args[0] + args[1]`, so the recorded offset would be wrong.
        if len(node.args) != 2 or node.kwargs:
            continue
        left_arg, right_arg = node.args

        # Check if left argument is a tile_index call
        if (
            not isinstance(left_arg, torch.fx.Node)
            or left_arg.op != "call_function"
            or left_arg.target != hl.tile_index
        ):
            continue

        # The offset is either an inline int/SymInt, or an FX node whose
        # recorded value (`meta["val"]`) is one.
        offset = None
        if isinstance(right_arg, (int, torch.SymInt)):
            offset = right_arg
        elif isinstance(right_arg, torch.fx.Node):
            val = right_arg.meta.get("val")
            if isinstance(val, (int, torch.SymInt)):
                offset = val

        if offset is None:
            continue

        # Extract the block_id from the tile_index call. A tile_index node
        # without positional args, or whose tile has no recorded value, is
        # skipped rather than raising IndexError/KeyError.
        if not left_arg.args:
            continue
        tile_arg = left_arg.args[0]
        block_id = None
        if isinstance(tile_arg, torch.fx.Node):
            tile_val = tile_arg.meta.get("val")
            if isinstance(tile_val, torch.SymInt):
                block_id = env.get_block_id(tile_val)

        if block_id is None:
            continue

        # Add metadata to mark this as a tile+offset node
        node.meta["tile_with_offset"] = {
            "block_id": block_id,
            "offset": offset,
        }
| 1269 | + |
| 1270 | + |
1196 | 1271 | def remove_unnecessary_tile_index(graph: torch.fx.Graph) -> None: |
1197 | 1272 | """ |
1198 | 1273 | Remove unnecessary tile_index nodes from the graph. |
|
0 commit comments