Skip to content

Commit 15a4721

Browse files
committed
up
1 parent f60f36a commit 15a4721

34 files changed

+443
-666
lines changed

helion/_compiler/device_function.py

Lines changed: 48 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,7 @@ def __init__(self, name: str, config: Config, codegen: GenerateAST) -> None:
207207
] = {}
208208
self._expr_args: dict[sympy.Expr, SymbolArgument] = {}
209209
self._constexpr_args: dict[str, ConstExprArg] = {}
210+
self._constexpr_host_defs: set[str] = set()
210211
self._tensor_properties: dict[
211212
tuple[type[TensorPropertyArg], torch.Tensor, int], TensorPropertyArg
212213
] = {}
@@ -282,11 +283,7 @@ def block_size_var(self, block_id: int) -> str | None:
282283

283284
var_name = self.new_var(f"_BLOCK_SIZE_{block_id}")
284285
self.block_size_var_cache[key] = var_name
285-
host_expr = HostFunction.current().literal_expr(block_value)
286-
if self.constexpr_arg(var_name, host_expr):
287-
self.codegen.host_statements.append(
288-
statement_from_string(f"{var_name} = {host_expr}")
289-
)
286+
self.constexpr_arg_with_host_def(var_name, block_value)
290287

291288
return self.block_size_var_cache[key]
292289

@@ -484,14 +481,52 @@ def expr_arg(self, sym: sympy.Expr, origin: Origin) -> SymbolArgument:
484481
self._expr_args[sym] = arg
485482
return self._expr_args[sym]
486483

487-
def constexpr_arg(self, name: str, host_str: str | None = None) -> bool:
484+
def constexpr_arg(self, name: str, value: object | None = None) -> bool:
488485
"""Create a constexpr argument, returns True if created, False if already exists."""
489486
if name in self._constexpr_args:
490487
return False
491-
self._constexpr_args[name] = rv = ConstExprArg(name, host_str or name)
488+
host_str = name if value is None else self._format_constexpr_value(value)
489+
self._constexpr_args[name] = rv = ConstExprArg(name, host_str)
492490
self.arguments.append(rv)
493491
return True
494492

493+
def constexpr_arg_with_host_def(self, name: str, value: object) -> None:
494+
"""Create a constexpr argument and add its host-side definition if needed."""
495+
created = self.constexpr_arg(name, value)
496+
host_expr = self._constexpr_args[name].host_str()
497+
if created or name not in self._constexpr_host_defs:
498+
self.codegen.host_statements.append(
499+
statement_from_string(f"{name} = {host_expr}")
500+
)
501+
self._constexpr_host_defs.add(name)
502+
503+
def _format_constexpr_value(self, value: object) -> str:
504+
if isinstance(value, str):
505+
return value
506+
if isinstance(value, (int, float, bool)):
507+
return repr(value)
508+
509+
# Extract sympy expression from torch symbolic types
510+
if isinstance(value, (torch.SymInt, torch.SymFloat, torch.SymBool)):
511+
value = value._sympy_()
512+
513+
# Handle sympy expressions (sanitize by replacing triton_helpers functions)
514+
if isinstance(value, sympy.Expr):
515+
expr = value.replace(
516+
lambda node: isinstance(node, sympy.Function)
517+
and getattr(node.func, "__name__", "")
518+
== "triton_helpers.div_floor_integer",
519+
lambda node: sympy.floor(node.args[0] / node.args[1]), # pyright: ignore[reportAttributeAccessIssue]
520+
).replace(
521+
lambda node: isinstance(node, sympy.Function)
522+
and getattr(node.func, "__name__", "")
523+
== "triton_helpers.remainder_integer",
524+
lambda node: sympy.Mod(node.args[0], node.args[1]), # pyright: ignore[reportAttributeAccessIssue]
525+
)
526+
return HostFunction.current().sympy_expr(expr)
527+
528+
return HostFunction.current().literal_expr(value)
529+
495530
def _tensor_property(
496531
self,
497532
prop_cls: type[_P],
@@ -556,7 +591,12 @@ def codegen_function_def(self) -> list[ast.stmt]:
556591
]
557592

558593
def codegen_function_call(self) -> ast.AST:
559-
args = [arg.host_str() for arg in self.sorted_args()]
594+
args = []
595+
for arg in self.sorted_args():
596+
if isinstance(arg, ConstExprArg) and arg.name in self._constexpr_host_defs:
597+
args.append(arg.name)
598+
else:
599+
args.append(arg.host_str())
560600

561601
if self.has_rng_ops():
562602
# Pass the host-side seed buffer variable to the kernel

helion/_compiler/inductor_lowering.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1241,8 +1241,7 @@ def _create_named_result(self, node: Node, result: ast.expr) -> str:
12411241
):
12421242
# This expression is used in tl.arange, make it a constexpr
12431243
name = self.cg.device_function.new_var(node.name)
1244-
host_expr = self.cg.device_function.sympy_expr(val._sympy_())
1245-
self.cg.device_function.constexpr_arg(name, host_expr)
1244+
self.cg.device_function.constexpr_arg(name, val._sympy_())
12461245
return name
12471246

12481247
# If the lowering produced a named value that is already defined elsewhere

helion/_compiler/tile_strategy.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -244,12 +244,7 @@ def _setup_block_size_constexpr(
244244
self, state: CodegenState, block_size_var: str, block_size: SymIntLike
245245
) -> None:
246246
"""Helper to setup constexpr block size variable on host."""
247-
if state.device_function.constexpr_arg(block_size_var):
248-
state.codegen.host_statements.append(
249-
statement_from_string(
250-
f"{block_size_var} = {HostFunction.current().literal_expr(block_size)}"
251-
)
252-
)
247+
state.device_function.constexpr_arg_with_host_def(block_size_var, block_size)
253248

254249

255250
class BlockSizeTileStrategy(TileStrategy):

0 commit comments

Comments
 (0)