Skip to content

Commit f6109da

Browse files
committed
wip
1 parent 9d8f78f commit f6109da

File tree

3 files changed

+36
-10
lines changed

3 files changed

+36
-10
lines changed

helion/_compiler/device_function.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -282,8 +282,8 @@ def block_size_var(self, block_id: int) -> str | None:
282282

283283
var_name = self.new_var(f"_BLOCK_SIZE_{block_id}")
284284
self.block_size_var_cache[key] = var_name
285-
host_expr = HostFunction.current().literal_expr(block_value)
286-
if self.constexpr_arg(var_name, host_expr):
285+
if self.constexpr_arg(var_name, block_value):
286+
host_expr = self._constexpr_args[var_name].host_str()
287287
self.codegen.host_statements.append(
288288
statement_from_string(f"{var_name} = {host_expr}")
289289
)
@@ -484,14 +484,42 @@ def expr_arg(self, sym: sympy.Expr, origin: Origin) -> SymbolArgument:
484484
self._expr_args[sym] = arg
485485
return self._expr_args[sym]
486486

487-
def constexpr_arg(self, name: str, host_str: str | None = None) -> bool:
487+
def constexpr_arg(self, name: str, value: object | None = None) -> bool:
488488
"""Create a constexpr argument, returns True if created, False if already exists."""
489489
if name in self._constexpr_args:
490490
return False
491-
self._constexpr_args[name] = rv = ConstExprArg(name, host_str or name)
491+
host_str = name if value is None else self._format_constexpr_value(value)
492+
self._constexpr_args[name] = rv = ConstExprArg(name, host_str)
492493
self.arguments.append(rv)
493494
return True
494495

496+
def _format_constexpr_value(self, value: object) -> str:
497+
if isinstance(value, str):
498+
return value
499+
if isinstance(value, (int, float, bool)):
500+
return repr(value)
501+
502+
# Extract sympy expression from torch symbolic types
503+
if isinstance(value, (torch.SymInt, torch.SymFloat, torch.SymBool)):
504+
value = value._sympy_()
505+
506+
# Handle sympy expressions (sanitize by replacing triton_helpers functions)
507+
if isinstance(value, sympy.Expr):
508+
expr = value.replace(
509+
lambda node: isinstance(node, sympy.Function)
510+
and getattr(node.func, "__name__", "")
511+
== "triton_helpers.div_floor_integer",
512+
lambda node: sympy.floor(node.args[0] / node.args[1]), # pyright: ignore[reportAttributeAccessIssue]
513+
).replace(
514+
lambda node: isinstance(node, sympy.Function)
515+
and getattr(node.func, "__name__", "")
516+
== "triton_helpers.remainder_integer",
517+
lambda node: sympy.Mod(node.args[0], node.args[1]), # pyright: ignore[reportAttributeAccessIssue]
518+
)
519+
return HostFunction.current().sympy_expr(expr)
520+
521+
return HostFunction.current().literal_expr(value)
522+
495523
def _tensor_property(
496524
self,
497525
prop_cls: type[_P],

helion/_compiler/inductor_lowering.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1241,8 +1241,7 @@ def _create_named_result(self, node: Node, result: ast.expr) -> str:
12411241
):
12421242
# This expression is used in tl.arange, make it a constexpr
12431243
name = self.cg.device_function.new_var(node.name)
1244-
host_expr = self.cg.device_function.sympy_expr(val._sympy_())
1245-
self.cg.device_function.constexpr_arg(name, host_expr)
1244+
self.cg.device_function.constexpr_arg(name, val._sympy_())
12461245
return name
12471246

12481247
# If the lowering produced a named value that is already defined elsewhere

helion/_compiler/tile_strategy.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -244,11 +244,10 @@ def _setup_block_size_constexpr(
244244
self, state: CodegenState, block_size_var: str, block_size: SymIntLike
245245
) -> None:
246246
"""Helper to setup constexpr block size variable on host."""
247-
if state.device_function.constexpr_arg(block_size_var):
247+
if state.device_function.constexpr_arg(block_size_var, block_size):
248+
host_expr = state.device_function._constexpr_args[block_size_var].host_str()
248249
state.codegen.host_statements.append(
249-
statement_from_string(
250-
f"{block_size_var} = {HostFunction.current().literal_expr(block_size)}"
251-
)
250+
statement_from_string(f"{block_size_var} = {host_expr}")
252251
)
253252

254253

0 commit comments

Comments
 (0)