Skip to content

Commit c95d461

Browse files
committed
Generalize tile_with_offset pass
stack-info: PR: #949, branch: jansel/stack/195
1 parent 63771cf commit c95d461

File tree

4 files changed

+136
-48
lines changed

4 files changed

+136
-48
lines changed

helion/_compiler/device_ir.py

Lines changed: 66 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import contextlib
66
import dataclasses
77
import functools
8-
import itertools
98
import operator
109
import re
1110
import textwrap
@@ -1216,55 +1215,82 @@ def add_tile_with_offset_metadata(graph_info: GraphInfo) -> None:
12161215
"""
12171216
graph = graph_info.graph
12181217
env = CompileEnvironment.current()
1219-
1220-
for node in itertools.chain(
1221-
graph.find_nodes(op="call_function", target=operator.add),
1222-
graph.find_nodes(op="call_function", target=torch.ops.aten.add.Tensor),
1223-
):
1224-
# Check if this is tile.index + offset pattern
1225-
# args[0] should be tile_index result, args[1] should be int/SymInt
1226-
if len(node.args) != 2 and not node.kwargs:
1227-
continue
1228-
left_arg, right_arg = node.args
1229-
1230-
# Check if left argument is a tile_index call
1218+
add_targets = (operator.add, torch.ops.aten.add.Tensor)
1219+
offset_types = (int, torch.SymInt)
1220+
for node in graph.nodes:
12311221
if (
1232-
not isinstance(left_arg, torch.fx.Node)
1233-
or left_arg.op != "call_function"
1234-
or left_arg.target != hl.tile_index
1222+
node.op != "call_function"
1223+
or node.target not in add_targets
1224+
or node.kwargs
1225+
or len(node.args) != 2
12351226
):
12361227
continue
12371228

1238-
# Check if right argument is an integer offset
1239-
# It could be a constant, SymInt node, or another value
1240-
# We accept int, SymInt, or nodes that represent them
1241-
offset = None
1242-
if isinstance(right_arg, (int, torch.SymInt)):
1243-
offset = right_arg
1244-
elif isinstance(right_arg, torch.fx.Node):
1245-
# Check the node's metadata for the value
1246-
val = right_arg.meta.get("val")
1247-
if isinstance(val, (int, torch.SymInt)):
1248-
offset = val
1249-
1250-
if offset is None:
1251-
continue
1229+
block_id: int | None = None
1230+
total_offset: int | torch.SymInt = 0
1231+
valid = True
12521232

1253-
# Extract the block_id from the tile_index call
1254-
tile_arg = left_arg.args[0]
1255-
block_id = None
1256-
if isinstance(tile_arg, torch.fx.Node) and isinstance(
1257-
tile_arg.meta["val"], torch.SymInt
1258-
):
1259-
block_id = env.get_block_id(tile_arg.meta["val"])
1233+
for arg in node.args:
1234+
tile_offset_value: int | torch.SymInt | None = None
1235+
arg_block_id: int | None = None
1236+
1237+
if isinstance(arg, torch.fx.Node):
1238+
meta_tile = arg.meta.get("tile_with_offset")
1239+
if meta_tile is not None:
1240+
arg_block_id = meta_tile.get("block_id")
1241+
if arg_block_id is None:
1242+
valid = False
1243+
break
1244+
tile_offset_value = meta_tile.get("offset", 0)
1245+
elif (
1246+
arg.op == "call_function"
1247+
and arg.target == hl.tile_index
1248+
and arg.args
1249+
and isinstance(arg.args[0], torch.fx.Node)
1250+
):
1251+
tile_val = arg.args[0].meta.get("val")
1252+
if isinstance(tile_val, torch.SymInt):
1253+
arg_block_id = env.get_block_id(tile_val)
1254+
if arg_block_id is None:
1255+
valid = False
1256+
break
1257+
tile_offset_value = 0
1258+
else:
1259+
val = arg.meta.get("val")
1260+
if isinstance(val, offset_types):
1261+
total_offset = total_offset + val
1262+
continue
1263+
1264+
if arg_block_id is not None:
1265+
if block_id is not None:
1266+
valid = False
1267+
break
1268+
if tile_offset_value is None:
1269+
tile_offset_value = 0
1270+
block_id = arg_block_id
1271+
total_offset = total_offset + tile_offset_value
1272+
continue
1273+
1274+
val = arg.meta.get("val")
1275+
if isinstance(val, offset_types):
1276+
total_offset = total_offset + val
1277+
continue
1278+
1279+
valid = False
1280+
break
1281+
1282+
if isinstance(arg, offset_types):
1283+
total_offset = total_offset + arg
1284+
continue
1285+
valid = False
1286+
break
12601287

1261-
if block_id is None:
1288+
if not valid or block_id is None:
12621289
continue
12631290

1264-
# Add metadata to mark this as a tile+offset node
12651291
node.meta["tile_with_offset"] = {
12661292
"block_id": block_id,
1267-
"offset": offset,
1293+
"offset": total_offset,
12681294
}
12691295

12701296

test/test_indexing.expected

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,36 @@ def pairwise_add(x: torch.Tensor, *, _launcher=_default_launcher):
347347
_launcher(_helion_pairwise_add, (triton.cdiv(499, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, num_warps=4, num_stages=2)
348348
return out
349349

350+
--- assertExpectedJournal(TestIndexing.test_pairwise_add_commuted_and_multi_offset)
351+
from __future__ import annotations
352+
353+
import torch
354+
import triton
355+
import triton.language as tl
356+
from helion.runtime import default_launcher as _default_launcher
357+
358+
@triton.jit
359+
def _helion_pairwise_add_variants(x, out, _BLOCK_SIZE_0: tl.constexpr):
360+
pid_0 = tl.program_id(0)
361+
offset_0 = pid_0 * _BLOCK_SIZE_0
362+
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
363+
v_0 = tl.full([], 1, tl.int32)
364+
v_1 = indices_0 + v_0
365+
left = tl.load(tl.make_block_ptr(x, [256], [1], [offset_0 + 1], [_BLOCK_SIZE_0], [0]), boundary_check=[0], padding_option='zero')
366+
v_2 = tl.full([], 1, tl.int32)
367+
v_3 = indices_0 + v_2
368+
v_4 = tl.full([], 2, tl.int32)
369+
v_5 = v_3 + v_4
370+
right = tl.load(tl.make_block_ptr(x, [256], [1], [offset_0 + 3], [_BLOCK_SIZE_0], [0]), boundary_check=[0], padding_option='zero')
371+
v_6 = left + right
372+
tl.store(tl.make_block_ptr(out, [253], [1], [offset_0], [_BLOCK_SIZE_0], [0]), v_6, boundary_check=[0])
373+
374+
def pairwise_add_variants(x: torch.Tensor, *, _launcher=_default_launcher):
375+
out = x.new_empty([x.size(0) - 3])
376+
_BLOCK_SIZE_0 = 32
377+
_launcher(_helion_pairwise_add_variants, (triton.cdiv(253, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, num_warps=4, num_stages=2)
378+
return out
379+
350380
--- assertExpectedJournal(TestIndexing.test_reduction_tensor_descriptor_indexing_block_size)
351381
from __future__ import annotations
352382

test/test_indexing.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,27 @@ def pairwise_add(x: torch.Tensor) -> torch.Tensor:
168168
torch.testing.assert_close(result, x[:-1] + x[1:])
169169
self.assertExpectedJournal(code)
170170

171+
def test_pairwise_add_commuted_and_multi_offset(self):
172+
@helion.kernel()
173+
def pairwise_add_variants(x: torch.Tensor) -> torch.Tensor:
174+
out = x.new_empty([x.size(0) - 3])
175+
for tile in hl.tile(out.size(0)):
176+
left = x[1 + tile.index]
177+
right = x[tile.index + 1 + 2]
178+
out[tile] = left + right
179+
return out
180+
181+
x = torch.randn([256], device=DEVICE)
182+
code, result = code_and_output(
183+
pairwise_add_variants,
184+
(x,),
185+
block_size=32,
186+
indexing="block_ptr",
187+
)
188+
expected = x[1:-2] + x[3:]
189+
torch.testing.assert_close(result, expected)
190+
self.assertExpectedJournal(code)
191+
171192
def test_mask_store(self):
172193
@helion.kernel
173194
def masked_store(x: torch.Tensor) -> torch.Tensor:
@@ -398,6 +419,19 @@ def run_case(
398419
small_shape = (128, 128)
399420
large_shape = (51200, 51200)
400421

422+
if DEVICE.type == "cuda":
423+
free_bytes, _ = torch.cuda.mem_get_info()
424+
element_size = 2 # torch.bfloat16 element size in bytes
425+
# Worst case: inputs, kernel output, reference output, and temporary buffers.
426+
# Give ourselves margin by budgeting for 5 tensors of this shape.
427+
required_bytes = 5 * math.prod(large_shape) * element_size
428+
if free_bytes < required_bytes:
429+
required_gib = required_bytes / (1024**3)
430+
available_gib = free_bytes / (1024**3)
431+
self.skipTest(
432+
f"Large BF16 add needs ~{required_gib:.1f} GiB free, only {available_gib:.1f} GiB available"
433+
)
434+
401435
run_case(
402436
small_shape,
403437
index_dtype=torch.int32,

test/test_persistent_kernels.expected

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1598,32 +1598,30 @@ import triton.language as tl
15981598
from helion.runtime import default_launcher as _default_launcher
15991599

16001600
@triton.jit
1601-
def _helion_test_kernel(x, result, x_size_0, x_size_1, result_stride_0, result_stride_1, x_stride_0, x_stride_1, _NUM_SM: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
1602-
total_pids = tl.cdiv(x_size_0, _BLOCK_SIZE_0) * tl.cdiv(x_size_1, _BLOCK_SIZE_1)
1601+
def _helion_test_kernel(x, result, _NUM_SM: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
1602+
total_pids = tl.cdiv(64, _BLOCK_SIZE_0) * tl.cdiv(96, _BLOCK_SIZE_1)
16031603
block_size = tl.cdiv(total_pids, _NUM_SM)
16041604
start_pid = tl.program_id(0) * block_size
16051605
end_pid = tl.minimum(start_pid + block_size, total_pids)
16061606
for virtual_pid in tl.range(start_pid, end_pid, warp_specialize=True):
1607-
num_blocks_0 = tl.cdiv(x_size_0, _BLOCK_SIZE_0)
1607+
num_blocks_0 = tl.cdiv(64, _BLOCK_SIZE_0)
16081608
pid_0 = virtual_pid % num_blocks_0
16091609
pid_1 = virtual_pid // num_blocks_0
16101610
offset_0 = pid_0 * _BLOCK_SIZE_0
16111611
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
1612-
mask_0 = indices_0 < x_size_0
16131612
offset_1 = pid_1 * _BLOCK_SIZE_1
16141613
indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
1615-
mask_1 = indices_1 < x_size_1
1616-
load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
1614+
load = tl.load(x + (indices_0[:, None] * 96 + indices_1[None, :] * 1), None)
16171615
v_0 = 1.0
16181616
v_1 = load + v_0
1619-
tl.store(result + (indices_0[:, None] * result_stride_0 + indices_1[None, :] * result_stride_1), v_1, mask_0[:, None] & mask_1[None, :])
1617+
tl.store(result + (indices_0[:, None] * 96 + indices_1[None, :] * 1), v_1, None)
16201618

16211619
def test_kernel(x: torch.Tensor, *, _launcher=_default_launcher):
16221620
result = x.new_empty(x.size())
16231621
_NUM_SM = helion.runtime.get_num_sm(x.device)
16241622
_BLOCK_SIZE_0 = 32
16251623
_BLOCK_SIZE_1 = 16
1626-
_launcher(_helion_test_kernel, (_NUM_SM,), x, result, x.size(0), x.size(1), result.stride(0), result.stride(1), x.stride(0), x.stride(1), _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2)
1624+
_launcher(_helion_test_kernel, (_NUM_SM,), x, result, _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2)
16271625
return result
16281626

16291627
--- assertExpectedJournal(TestPersistentKernels.test_persistent_loop_variable_names)

0 commit comments

Comments (0)