
Commit a8c83b6

Add epilogue subtiling

stack-info: PR: #948, branch: PaulZhang12/stack/14
1 parent dee9f57

17 files changed: +1453 −129 lines

examples/matmul.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -28,12 +28,13 @@
 @helion.kernel(
     # static_shapes=True gives a performance boost for matmuls
     static_shapes=True,
-    # Disable autotung over unrolling/range_num_stages
+    # Disable autotuning over range_num_stages
     # tl.dot is pipelined with num_stages
     autotune_config_overrides={
        "range_unroll_factors": [0, 0],
        "range_num_stages": [0, 0],
     },
+    allow_epilogue_subtiling=True,
 )
 def matmul(
     x: Tensor,
```
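
The new `allow_epilogue_subtiling=True` flag lets the autotuner split the store of the accumulator tile into smaller subtiles, so the epilogue (the dtype cast and any fused pointwise ops) runs on one subtile at a time rather than on the full `[BLOCK_M, BLOCK_N]` tile. A minimal host-side sketch of the idea in plain PyTorch; `store_with_subtile`, the shapes, and the subtile size are illustrative stand-ins, not Helion's API:

```python
# Illustrative sketch only: the real feature operates on Triton tiles inside
# the generated kernel, not on host tensors. The point is the shape of the
# transformation: the epilogue is applied per column-subtile of the
# accumulator instead of to the whole tile at once.
import torch

def store_with_subtile(out: torch.Tensor, acc: torch.Tensor, subtile_n: int) -> None:
    block_m, block_n = acc.shape
    for n0 in range(0, block_n, subtile_n):
        sub = acc[:, n0 : n0 + subtile_n]
        # epilogue (here just a cast) runs per-subtile rather than per-tile
        out[:, n0 : n0 + subtile_n] = sub.to(out.dtype)

acc = torch.randn(128, 128, dtype=torch.float32)  # fp32 accumulator tile
out = torch.empty(128, 128, dtype=torch.float16)  # fp16 output tile
store_with_subtile(out, acc, subtile_n=64)
assert torch.equal(out, acc.to(torch.float16))  # same result, smaller pieces
```
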

helion/_compiler/device_function.py

Lines changed: 62 additions & 22 deletions

```diff
@@ -462,8 +462,9 @@ def tensor_descriptor_arg(
         self, fake_value: torch.Tensor, block_size: list[int | torch.SymInt]
     ) -> TensorDescriptorArg:
         host_function = HostFunction.current()
-        block_size_expr = ", ".join(map(self.literal_expr, block_size))
+        block_size_expr = ", ".join(self.literal_expr(dim) for dim in block_size)
         key = (fake_value, block_size_expr)
+
         if key not in self._tensor_descriptor_args:
             origin = host_function.tensor_to_origin[fake_value]
             desc_name = self.new_var(origin.suggest_var_name() + "_desc")
@@ -556,22 +557,6 @@ def _format_constexpr_value(self, value: object) -> str:
         if isinstance(value, (torch.SymInt, torch.SymFloat, torch.SymBool)):
             value = value._sympy_()
 
-        # Handle sympy expressions (sanitize by replacing triton_helpers functions)
-        if isinstance(value, sympy.Expr):
-            sanitized = value.replace(  # pyright: ignore[reportAttributeAccessIssue]
-                lambda node: isinstance(node, sympy.Function)
-                and getattr(node.func, "__name__", "")
-                == "triton_helpers.div_floor_integer",
-                lambda node: sympy.floor(node.args[0] / node.args[1]),  # pyright: ignore[reportAttributeAccessIssue]
-            ).replace(  # pyright: ignore[reportAttributeAccessIssue]
-                lambda node: isinstance(node, sympy.Function)
-                and getattr(node.func, "__name__", "")
-                == "triton_helpers.remainder_integer",
-                lambda node: sympy.Mod(node.args[0], node.args[1]),  # pyright: ignore[reportAttributeAccessIssue]
-            )
-            expr = cast("sympy.Expr", sanitized)
-            return HostFunction.current().sympy_expr(expr)
-
         return HostFunction.current().literal_expr(value)
 
     def _tensor_property(
```
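
The deleted sanitizer relied on `sympy.Expr.replace(query, value)` with callable arguments; that floor/div handling now lives in the printer overrides in the next hunk. For readers unfamiliar with the pattern, a minimal sketch of how such a replace works; `div_floor` here is a stand-in for the `triton_helpers.div_floor_integer` wrapper, not a real Helion symbol:

```python
# Sketch of sympy's replace(query, value) with callable arguments: every
# application of a named opaque function is rewritten into a plain sympy
# floor expression, mirroring what the removed sanitizer did.
import sympy

div_floor = sympy.Function("div_floor")  # stand-in for the helper wrapper
x = sympy.Symbol("x", nonnegative=True)
expr = div_floor(x + 3, 2) + 1

sanitized = expr.replace(
    lambda node: isinstance(node, sympy.Function)
    and getattr(node.func, "__name__", "") == "div_floor",
    lambda node: sympy.floor(node.args[0] / node.args[1]),
)
print(sanitized)  # -> floor(x/2 + 3/2) + 1
```
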
```diff
@@ -749,11 +734,19 @@ def current() -> DeviceFunction:
 
 
 class HelionTritonPrinter(TritonPrinter):
-    """Custom Triton printer that avoids wrapping float literals in tl.full().
-
-    Inductor's default TritonPrinter prints SymPy Float as a 0-D Triton value
-    via tl.full([], <val>, tl.float64). We override this to emit the raw numeric
-    literal, letting downstream type promotion and casts handle dtype.
+    """Custom Triton printer that does the following:
+
+    - Avoids wrapping float literals in tl.full().
+      Inductor's default TritonPrinter prints SymPy Float as a 0-D Triton value
+      via tl.full([], <val>, tl.float64). We override this to emit the raw numeric
+      literal, letting downstream type promotion and casts handle dtype.
+
+    - Avoids triton_helpers.div_floor_integer(...) calls when both operands are
+      provably non-negative integers. TritonPrinter by default converts
+      floor(u1/2) to triton_helpers.div_floor_integer(...). We override this to
+      emit u1 // 2 only when the numerator is known to be non-negative and the
+      denominator is a positive integer, so that we keep helper calls for cases
+      that rely on floor semantics with mixed signs.
     """
 
     def _print_Float(self, expr: sympy.Expr) -> str:
@@ -762,6 +755,53 @@ def _print_Float(self, expr: sympy.Expr) -> str:
     def _print_ToFloat(self, expr: sympy.Expr) -> str:
         return f"{expr} + 0.0"
 
+    def _is_nonnegative(self, expr: sympy.Expr) -> bool:
+        if expr.is_nonnegative is True or expr.is_zero is True:
+            return True
+        if expr.is_positive is True:
+            return True
+        try:
+            host_fn = HostFunction.current()
+        except NoCurrentFunction:
+            host_fn = None
+        if host_fn is not None:
+            origin_info = host_fn.expr_to_origin.get(expr)
+            if origin_info and isinstance(
+                origin_info.origin, (BlockSizeOrigin, TensorSizeOrigin)
+            ):
+                return True
+        if isinstance(expr, sympy.Symbol) and expr.name.startswith("_BLOCK_SIZE_"):
+            return True
+        if isinstance(expr, sympy.Number):
+            return bool(expr >= 0)
+        return False
+
+    def _format_trunc_div(self, lhs: sympy.Expr, rhs: sympy.Expr) -> str:
+        lhs_str = self._print(lhs)
+        rhs_str = self._print(rhs)
+        if not (lhs.is_Integer or lhs.is_Symbol):
+            lhs_str = f"({lhs_str})"
+        if not (rhs.is_Integer or rhs.is_Symbol):
+            rhs_str = f"({rhs_str})"
+        return f"{lhs_str} // {rhs_str}"
+
+    def _print_floor(self, expr: sympy.Expr) -> str:
+        inner = expr.args[0]
+        numer, denom = inner.as_numer_denom()
+        if (
+            isinstance(denom, sympy.Integer)
+            and denom > 1
+            and self._is_nonnegative(numer)
+        ):
+            return self._format_trunc_div(numer, denom)
+        return super()._print_floor(expr)
+
+    def _print_FloorDiv(self, expr: sympy.Expr) -> str:
+        lhs, rhs = expr.args
+        if isinstance(rhs, sympy.Integer) and rhs > 0 and self._is_nonnegative(lhs):
+            return self._format_trunc_div(lhs, rhs)
+        return super()._print_FloorDiv(expr)
+
 
 def texpr(expr: sympy.Expr) -> str:
     return HelionTritonPrinter().doprint(expr)
```
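
A standalone sketch of the same printing technique, using sympy's `StrPrinter` in place of Inductor's `TritonPrinter` (so the origin-based non-negativity checks above are reduced to plain sympy assumptions here): emit `n // d` only when the numerator is provably non-negative and the denominator is a positive integer constant, and fall back to `floor(...)` otherwise.

```python
# Minimal stand-in for the _print_floor override; the fallback prints a
# plain floor(...) string rather than calling a triton_helpers wrapper.
import sympy
from sympy.printing.str import StrPrinter

class FloorDivPrinter(StrPrinter):
    def _print_floor(self, expr: sympy.Expr) -> str:
        numer, denom = expr.args[0].as_numer_denom()
        if (
            isinstance(denom, sympy.Integer)
            and denom > 1
            and numer.is_nonnegative is True  # provably >= 0 only
        ):
            return f"{self._print(numer)} // {self._print(denom)}"
        return f"floor({self._print(expr.args[0])})"

u = sympy.Symbol("u1", integer=True, nonnegative=True)
s = sympy.Symbol("s", integer=True)  # sign unknown
print(FloorDivPrinter().doprint(sympy.floor(u / 2)))  # u1 // 2
print(FloorDivPrinter().doprint(sympy.floor(s / 2)))  # floor(s/2)
```

Unlike `_format_trunc_div` above, this sketch skips the parenthesization of compound operands, which matters once the numerator is a sum or product.
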

helion/_compiler/device_ir.py

Lines changed: 72 additions & 0 deletions

```diff
@@ -63,6 +63,7 @@
 from .type_propagation import _eval_binary
 from .type_propagation import _eval_compare
 from .type_propagation import _eval_unary
+from .utils import _use_epilogue_subtile
 
 if TYPE_CHECKING:
     from collections.abc import Callable
@@ -1191,6 +1192,11 @@ def lower_to_device_ir(func: HostFunction) -> DeviceIR:
         total_load_count, loads_without_eviction_policy, store_count
     )
 
+    if _use_epilogue_subtile():
+        for graph in device_ir.graphs:
+            # Epilogue subtiling only for Blackwell
+            epilogue_subtiling_pass(graph.graph, store_count)
+
     return device_ir
 
 
```
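`_use_epilogue_subtile` is Helion's own gate (imported from `.utils` above), and the in-code comment says the pass targets Blackwell. As a loose illustration only, a hypothetical stand-in for such a check might look like the following; the `major >= 10` test (Blackwell parts report CUDA compute capability 10.x) is an assumption about the gating criterion, not Helion's actual logic:

```python
# Hypothetical stand-in for _use_epilogue_subtile(); the real helper lives
# in helion/_compiler/utils.py and may check more than device capability.
import torch

def use_epilogue_subtile_guess() -> bool:
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major >= 10  # assumption: Blackwell (sm_100) and newer
```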

```diff
@@ -1348,3 +1354,69 @@ def remove_unnecessary_tile_index(graph: torch.fx.Graph) -> None:
             user.args = tuple(new_args)
         if len(node.users) == 0:
             graph.erase_node(node)
+
+
+def epilogue_subtiling_pass(graph: torch.fx.Graph, store_count: int) -> None:
+    """
+    Replace epilogue subtile with a tunable value.
+    """
+    if store_count == 0:
+        return
+
+    from ..autotuner.config_fragment import EnumFragment
+    from ..autotuner.config_fragment import ListOf
+    from ..autotuner.config_spec import VALID_EPILOGUE_SUBTILE_SIZES
+    from .inductor_lowering import PointwiseLowering
+
+    env = CompileEnvironment.current()
+    # Register a tunable for epilogue subtile for all device stores
+    fragment = ListOf(
+        EnumFragment(choices=VALID_EPILOGUE_SUBTILE_SIZES), length=store_count
+    )
+    env.config_spec.epilogue_subtiling = fragment
+
+    def collect_pointwise_epilogue_nodes(
+        store_node: torch.fx.Node,
+    ) -> dict[torch.fx.Node, None]:
+        """Recursively collect all pointwise nodes that can be subtiled in the epilogue.
+
+        Starting from a store node, traverse backwards through all input nodes,
+        collecting pointwise operations until we hit non-pointwise nodes.
+        Only include pointwise nodes that have a single user to ensure they can be fused.
+        """
+        # dict to preserve order
+        pointwise_nodes = {}
+        visited = set()
+        stack = [store_node.args[2]]  # Start with the value being stored
+
+        while stack:
+            current = stack.pop()
+            if current in visited or not isinstance(current, torch.fx.Node):
+                continue
+
+            visited.add(current)
+
+            lowering = current.meta.get("lowering")
+            # Check if this is a pointwise operation with only one user
+            if isinstance(lowering, PointwiseLowering) and len(current.users) == 1:
+                if current not in pointwise_nodes:
+                    pointwise_nodes[current] = None
+                stack.extend(current.all_input_nodes)
+
+        return pointwise_nodes
+
+    from ..language import store as store_api
+
+    stores = set()
+
+    for node in graph.nodes:
+        if node.op == "call_function" and node.target == store_api:
+            stores.add(node)
+            # Collect all pointwise nodes that can be subtiled in the epilogue
+            pointwise_nodes = collect_pointwise_epilogue_nodes(node)
+            if pointwise_nodes:
+                # Mark all collected pointwise nodes for epilogue subtiling
+                for pw_node in pointwise_nodes:
+                    pw_node.meta["epilogue_subtile"] = True
+                # Store the set of pointwise nodes in the store node's metadata
+                node.meta["pointwise_epilogue_nodes"] = pointwise_nodes
```
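
A standalone sketch of the backward traversal in `epilogue_subtiling_pass`, adapted to a plain `torch.fx` graph: here the walk starts from the graph output rather than a Helion store node (`store_node.args[2]`), and an explicit allow-list of targets stands in for the `PointwiseLowering` metadata check:

```python
# Walk input edges backwards from a "store" (here, the graph output),
# collecting single-user pointwise ops and stopping at anything else
# (here, the matmul). This mirrors collect_pointwise_epilogue_nodes.
import operator
import torch
import torch.fx

POINTWISE = {torch.relu, operator.mul}  # stand-in for PointwiseLowering

def f(x, y):
    acc = torch.matmul(x, y)    # not pointwise: traversal stops here
    return torch.relu(acc * 2)  # pointwise epilogue: mul, then relu

graph = torch.fx.symbolic_trace(f).graph
output = next(n for n in graph.nodes if n.op == "output")

epilogue, stack = [], [output.args[0]]  # start from the stored value
while stack:
    node = stack.pop()
    if (
        isinstance(node, torch.fx.Node)
        and node.op == "call_function"
        and node.target in POINTWISE
        and len(node.users) == 1  # single user, so it can be fused
    ):
        epilogue.append(node)
        stack.extend(node.all_input_nodes)

print([n.target for n in epilogue])  # relu first, then mul; matmul excluded
```

Note that `stores` is populated in the pass but not consumed within this hunk; the metadata written to `node.meta` is what downstream lowering reads.
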
