
Commit cdbedf6

Add epilogue subtiling
stack-info: PR: #948, branch: PaulZhang12/stack/14
1 parent b77301f commit cdbedf6

File tree

3 files changed: +215 -98 lines

examples/matmul.py

Lines changed: 105 additions & 91 deletions
@@ -28,6 +28,19 @@
 @helion.kernel(
     # static_shapes=True gives a performance boost for matmuls
     static_shapes=True,
+    config=helion.Config(
+        block_sizes=[64, 64, 64],
+        loop_orders=[[0, 1]],
+        l2_groupings=[4],
+        range_unroll_factors=[0, 1],
+        range_num_stages=[0, 3],
+        range_multi_buffers=[None, False],
+        range_flattens=[None, None],
+        num_warps=8,
+        num_stages=6,
+        indexing='tensor_descriptor',
+        pid_type='flat'
+    )
 )
 def matmul(
     x: Tensor,
@@ -44,6 +57,7 @@ def matmul(
     Returns:
         Tensor: Resulting matrix of shape [m, n].
     """
+
     m, k = x.size()
     k2, n = y.size()
     assert k == k2, f"size mismatch {k} != {k2}"
@@ -298,97 +312,97 @@ def check(m: int, k: int, n: int) -> None:
     # Test without bias
     run_example(matmul, torch.matmul, (x, y))

-    # Test for addmm with scalar bias
-    def addmm(bias: Tensor, mat1: Tensor, mat2: Tensor) -> Tensor:
-        m, k = mat1.size()
-        k2, n = mat2.size()
-        bias = torch.broadcast_to(bias, [m, n])
-        return matmul(mat1, mat2, lambda acc, tile: acc + bias[tile[0], tile[1]])
-
-    run_example(addmm, torch.addmm, (bias_scalar, x, y))
-
-    # Test with bias
-    def helion_linear(x: Tensor, y: Tensor, bias: Tensor) -> Tensor:
-        return matmul(x, y, lambda acc, tile: acc + bias[tile[1]])
-
-    def baseline_linear(x: Tensor, y: Tensor, bias: Tensor) -> Tensor:
-        return torch.nn.functional.linear(x, y.T, bias)
-
-    run_example(helion_linear, baseline_linear, (x, y, bias))
-
-    # Test more complex epilogue
-    def epilogue(acc: Tensor, tile: tuple[Tensor, ...]) -> Tensor:
-        # The epilogue can use the captured bias tensor that is implicitly lifted to a kernel arg
-        return torch.relu(acc + bias[tile[1]])
-
-    def kernel_wrapper(x: Tensor, y: Tensor) -> Tensor:
-        return matmul(x, y, epilogue)
-
-    def baseline_wrapper(x: Tensor, y: Tensor) -> Tensor:
-        return torch.relu(x @ y + bias)
-
-    run_example(
-        kernel_wrapper,
-        baseline_wrapper,
-        (x, y),
-    )
-
-    # Test matmul forward + backward pass
-    print("\n\n=== MatMul Forward + Backward Pass Test ===")
-    x_grad = torch.randn([m, k], device=DEVICE, dtype=torch.float16, requires_grad=True)
-    y_grad = torch.randn([k, n], device=DEVICE, dtype=torch.float16, requires_grad=True)
-
-    run_example(
-        matmul_autograd,
-        torch.matmul,
-        (x_grad, y_grad),
-        kernel_name="helion_matmul_autograd",
-        baseline_name="torch",
-        rtol=1e-2,
-        atol=1e-2,
-        bwd=True,
-    )
-
-    # Test addmm forward + backward pass
-    print("\n\n=== AddMM Forward + Backward Pass Test ===")
-    input_grad = torch.randn(
-        [m, n], device=DEVICE, dtype=torch.float16, requires_grad=True
-    )
-    mat1_grad = torch.randn(
-        [m, k], device=DEVICE, dtype=torch.float16, requires_grad=True
-    )
-    mat2_grad = torch.randn(
-        [k, n], device=DEVICE, dtype=torch.float16, requires_grad=True
-    )
-
-    # Use lambda to handle the keyword argument format for torch.addmm
-    run_example(
-        addmm_autograd,
-        lambda bias, mat1, mat2, alpha, beta: torch.addmm(
-            bias, mat1, mat2, alpha=alpha, beta=beta
-        ),
-        (input_grad, mat1_grad, mat2_grad, 1.0, 1.0),
-        kernel_name="helion_addmm_autograd",
-        baseline_name="torch",
-        rtol=1e-2,
-        atol=1e-2,
-        bwd=True,
-    )
-
-    # Test addmm forward + backward with different alpha/beta values
-    print("\n\n=== AddMM Forward + Backward Test (Alpha=2.0, Beta=0.5) ===")
-    run_example(
-        addmm_autograd,
-        lambda bias, mat1, mat2, alpha, beta: torch.addmm(
-            bias, mat1, mat2, alpha=alpha, beta=beta
-        ),
-        (input_grad, mat1_grad, mat2_grad, 2.0, 0.5),
-        kernel_name="helion_addmm_autograd_scaled",
-        baseline_name="torch",
-        rtol=1e-2,
-        atol=1e-2,
-        bwd=True,
-    )
+    # # Test for addmm with scalar bias
+    # def addmm(bias: Tensor, mat1: Tensor, mat2: Tensor) -> Tensor:
+    #     m, k = mat1.size()
+    #     k2, n = mat2.size()
+    #     bias = torch.broadcast_to(bias, [m, n])
+    #     return matmul(mat1, mat2, lambda acc, tile: acc + bias[tile[0], tile[1]])
+
+    # run_example(addmm, torch.addmm, (bias_scalar, x, y))
+
+    # # Test with bias
+    # def helion_linear(x: Tensor, y: Tensor, bias: Tensor) -> Tensor:
+    #     return matmul(x, y, lambda acc, tile: acc + bias[tile[1]])
+
+    # def baseline_linear(x: Tensor, y: Tensor, bias: Tensor) -> Tensor:
+    #     return torch.nn.functional.linear(x, y.T, bias)
+
+    # run_example(helion_linear, baseline_linear, (x, y, bias))
+
+    # # Test more complex epilogue
+    # def epilogue(acc: Tensor, tile: tuple[Tensor, ...]) -> Tensor:
+    #     # The epilogue can use the captured bias tensor that is implicitly lifted to a kernel arg
+    #     return torch.relu(acc + bias[tile[1]])
+
+    # def kernel_wrapper(x: Tensor, y: Tensor) -> Tensor:
+    #     return matmul(x, y, epilogue)
+
+    # def baseline_wrapper(x: Tensor, y: Tensor) -> Tensor:
+    #     return torch.relu(x @ y + bias)
+
+    # run_example(
+    #     kernel_wrapper,
+    #     baseline_wrapper,
+    #     (x, y),
+    # )
+
+    # # Test matmul forward + backward pass
+    # print("\n\n=== MatMul Forward + Backward Pass Test ===")
+    # x_grad = torch.randn([m, k], device=DEVICE, dtype=torch.float16, requires_grad=True)
+    # y_grad = torch.randn([k, n], device=DEVICE, dtype=torch.float16, requires_grad=True)
+
+    # run_example(
+    #     matmul_autograd,
+    #     torch.matmul,
+    #     (x_grad, y_grad),
+    #     kernel_name="helion_matmul_autograd",
+    #     baseline_name="torch",
+    #     rtol=1e-2,
+    #     atol=1e-2,
+    #     bwd=True,
+    # )
+
+    # # Test addmm forward + backward pass
+    # print("\n\n=== AddMM Forward + Backward Pass Test ===")
+    # input_grad = torch.randn(
+    #     [m, n], device=DEVICE, dtype=torch.float16, requires_grad=True
+    # )
+    # mat1_grad = torch.randn(
+    #     [m, k], device=DEVICE, dtype=torch.float16, requires_grad=True
+    # )
+    # mat2_grad = torch.randn(
+    #     [k, n], device=DEVICE, dtype=torch.float16, requires_grad=True
+    # )
+
+    # # Use lambda to handle the keyword argument format for torch.addmm
+    # run_example(
+    #     addmm_autograd,
+    #     lambda bias, mat1, mat2, alpha, beta: torch.addmm(
+    #         bias, mat1, mat2, alpha=alpha, beta=beta
+    #     ),
+    #     (input_grad, mat1_grad, mat2_grad, 1.0, 1.0),
+    #     kernel_name="helion_addmm_autograd",
+    #     baseline_name="torch",
+    #     rtol=1e-2,
+    #     atol=1e-2,
+    #     bwd=True,
+    # )
+
+    # # Test addmm forward + backward with different alpha/beta values
+    # print("\n\n=== AddMM Forward + Backward Test (Alpha=2.0, Beta=0.5) ===")
+    # run_example(
+    #     addmm_autograd,
+    #     lambda bias, mat1, mat2, alpha, beta: torch.addmm(
+    #         bias, mat1, mat2, alpha=alpha, beta=beta
+    #     ),
+    #     (input_grad, mat1_grad, mat2_grad, 2.0, 0.5),
+    #     kernel_name="helion_addmm_autograd_scaled",
+    #     baseline_name="torch",
+    #     rtol=1e-2,
+    #     atol=1e-2,
+    #     bwd=True,
+    # )


 # %%
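The config hunk above pins an explicit helion.Config (block sizes, warps, pipeline stages, tensor-descriptor indexing) on the kernel so it runs without autotuning. As a rough smoke test, not part of this commit, the kernel could be called directly and checked against torch.matmul; the sizes and the CUDA device below are assumptions for illustration:

import torch

# Hypothetical problem sizes; any shapes that fit in memory would do.
m, k, n = 1024, 1024, 1024
x = torch.randn([m, k], device="cuda", dtype=torch.float16)
y = torch.randn([k, n], device="cuda", dtype=torch.float16)

# With config= set on @helion.kernel, this call uses the hardcoded
# configuration above instead of searching for one.
out = matmul(x, y)
torch.testing.assert_close(out, x @ y, rtol=1e-2, atol=1e-2)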

helion/_compiler/device_function.py

Lines changed: 5 additions & 0 deletions
@@ -415,9 +415,14 @@ def tensor_arg(
     def tensor_descriptor_arg(
         self, fake_value: torch.Tensor, block_size: list[int | torch.SymInt]
     ) -> TensorDescriptorArg:
+        import re
         host_function = HostFunction.current()
         block_size_expr = ", ".join(map(self.literal_expr, block_size))
+        pattern = r'triton_helpers\.div_floor_integer\(([^,]+),\s*(\d+)\)'
+        replacement = r'\1 // \2'
+        block_size_expr = re.sub(pattern, replacement, block_size_expr)
         key = (fake_value, block_size_expr)
+
         if key not in self._tensor_descriptor_args:
             origin = host_function.tensor_to_origin[fake_value]
             desc_name = self.new_var(origin.suggest_var_name() + "_desc")
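The re.sub added above rewrites triton_helpers.div_floor_integer(expr, N) in the block-size expression into plain floor division (expr // N), presumably so the halved block size introduced by epilogue subtiling renders as a simple integer expression. A standalone sketch of the substitution; the input string is a made-up example of what the expression might look like, not captured compiler output:

import re

pattern = r'triton_helpers\.div_floor_integer\(([^,]+),\s*(\d+)\)'
replacement = r'\1 // \2'

# Hypothetical block-size expression string, for illustration only.
expr = "_BLOCK_SIZE_0, triton_helpers.div_floor_integer(_BLOCK_SIZE_1, 2)"
print(re.sub(pattern, replacement, expr))
# -> _BLOCK_SIZE_0, _BLOCK_SIZE_1 // 2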

helion/_compiler/indexing_strategy.py

Lines changed: 105 additions & 7 deletions
@@ -15,6 +15,7 @@
 from .. import exc
 from .._compat import get_tensor_descriptor_fn_name
 from .ast_extension import expr_from_string
+from .ast_extension import statement_from_string
 from .compile_environment import CompileEnvironment
 from .device_function import DeviceFunction
 from .host_function import HostFunction
@@ -385,21 +386,118 @@ def codegen_store(
         indexing = BlockedSubscriptIndexing.create(state, fake_tensor, subscript)

         # Apply permutation to the value being stored if needed
-        desc_arg = indexing.tensor_descriptor_arg(state)
+        # desc_arg = indexing.tensor_descriptor_arg(state, subtile=True)
         store_value = indexing.reshape_store(state, value)

-        if desc_arg.permutation is not None:
-            # Apply permutation to the value
-            store_value = expr_from_string(
-                f"tl.permute({{store_val}}, {desc_arg.permutation!r})",
-                store_val=store_value,
+        # if desc_arg.permutation is not None:
+        #     # Apply permutation to the value
+        #     store_value = expr_from_string(
+        #         f"tl.permute({{store_val}}, {desc_arg.permutation!r})",
+        #         store_val=store_value,
+        #     )
+
+        if (
+            subtile_store := self._codegen_epilogue_subtile_store(
+                state, fake_tensor, indexing, store_value
             )
-
+        ) is not None:
+            return subtile_store
+
         return expr_from_string(
             f"{indexing.tensor_descriptor(state)}.store({indexing.offsets_str_permuted(state)}, {{value}})",
             value=store_value,
         )

+    def _codegen_epilogue_subtile_store(
+        self,
+        state: CodegenState,
+        fake_tensor: torch.Tensor,
+        indexing: BlockedSubscriptIndexing,
+        store_value: ast.AST,
+    ) -> ast.AST | None:
+        # Currently support 2D tiles without permutations
+        if len(indexing.block_shape) != 2 or len(indexing.offsets) != 2:
+            return None
+
+        env = CompileEnvironment.current()
+        block_m, block_n = indexing.block_shape
+        try:
+            block_n_hint = env.size_hint(block_n)
+        except Exception:
+            return None
+
+        if block_n_hint % 2 != 0:
+            return None
+
+        device_fn = state.device_function
+        codegen = state.codegen
+
+        block_m_str = device_fn.literal_expr(block_m)
+        block_n_str = device_fn.literal_expr(block_n)
+        indexing.block_shape[1] //= 2
+        desc_arg = indexing.tensor_descriptor_arg(state)
+
+        if desc_arg.permutation is not None:
+            return None
+
+
+        block_n_half_str = f"({block_n_str} // 2)"
+
+        # Lift the store value into a temporary variable for reuse
+        acc_var = codegen.lift(store_value, prefix="acc")
+
+        reshape_expr = expr_from_string(
+            "tl.reshape({acc}, [{dim_m}, 2, {dim_half}])",
+            acc=acc_var,
+            dim_m=expr_from_string(block_m_str),
+            dim_half=expr_from_string(block_n_half_str),
+        )
+        reshape_var = codegen.lift(reshape_expr, prefix="acc")
+
+        permute_expr = expr_from_string(
+            "tl.permute({acc}, [0, 2, 1])",
+            acc=reshape_var,
+        )
+        permute_var = codegen.lift(permute_expr, prefix="acc")
+
+        acc0_name = codegen.tmpvar(prefix="acc")
+        acc1_name = codegen.tmpvar(prefix="acc")
+        codegen.add_statement(
+            statement_from_string(
+                f"{acc0_name}, {acc1_name} = tl.split({{acc}})",
+                acc=permute_var,
+            )
+        )
+        acc0 = expr_from_string(acc0_name)
+        acc1 = expr_from_string(acc1_name)
+
+        desc_name = indexing.tensor_descriptor(state)
+        offset0 = expr_from_string(indexing.offsets[0])
+        offset1 = expr_from_string(indexing.offsets[1])
+
+        # First subtile store
+        codegen.add_statement(
+            statement_from_string(
+                f"{desc_name}.store([{{off0}}, {{off1}}], {{value}})",
+                off0=offset0,
+                off1=offset1,
+                value=acc0,
+            )
+        )
+
+        offset1_shifted = expr_from_string(
+            "({offset} + {half})",
+            offset=expr_from_string(indexing.offsets[1]),
+            half=expr_from_string(block_n_half_str),
+        )
+
+        # Emit second subtile store as the expression returned to the caller
+        return expr_from_string(
+            f"{desc_name}.store([{{off0}}, {{off1}}], {{value}})",
+            off0=offset0,
+            off1=offset1_shifted,
+            value=acc1,
+        )

 class StackIndexingStrategy:
     """
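For a 2D [BLOCK_M, BLOCK_N] store with an even BLOCK_N and no permutation, the new _codegen_epilogue_subtile_store path halves the descriptor's block shape and emits two half-width stores instead of one full-tile store. A rough sketch of the Triton code this generates; names like acc, out_desc, off_m and off_n stand in for the lifted temporaries and descriptor variables the compiler actually names:

# Reshape so the two column halves land on a trailing axis of size 2,
# split them apart, and store each half at its own column offset.
acc_3d = tl.reshape(acc, [BLOCK_M, 2, BLOCK_N // 2])
acc_pairs = tl.permute(acc_3d, [0, 2, 1])
acc_lo, acc_hi = tl.split(acc_pairs)
out_desc.store([off_m, off_n], acc_lo)
out_desc.store([off_m, off_n + BLOCK_N // 2], acc_hi)

The subtiling only applies when the size hint for the second block dimension is even; otherwise, and whenever a permutation would be required, the method returns None and codegen_store falls back to the original single descriptor store.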