Skip to content

Commit a25a35e

Browse files
committed
Fix lint error
stack-info: PR: #926, branch: jansel/stack/187
1 parent 3f6d43d commit a25a35e

File tree

1 file changed

+15
-13
lines changed

1 file changed

+15
-13
lines changed

helion/_compiler/device_function.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -509,23 +509,25 @@ def _format_constexpr_value(self, value: object) -> str:
509509
# Extract sympy expression from torch symbolic types
510510
if isinstance(value, (torch.SymInt, torch.SymFloat, torch.SymBool)):
511511
value = value._sympy_()
512+
if isinstance(value, tuple):
513+
# torch symbolic values may return (expr, replacements); use the expr
514+
value = value[0]
512515

513516
# Handle sympy expressions (sanitize by replacing triton_helpers functions)
514517
if isinstance(value, sympy.Expr):
515-
expr = cast(
516-
"sympy.Expr",
517-
value.replace(
518-
lambda node: isinstance(node, sympy.Function)
519-
and getattr(node.func, "__name__", "")
520-
== "triton_helpers.div_floor_integer",
521-
lambda node: sympy.floor(node.args[0] / node.args[1]), # pyright: ignore[reportAttributeAccessIssue]
522-
).replace(
523-
lambda node: isinstance(node, sympy.Function)
524-
and getattr(node.func, "__name__", "")
525-
== "triton_helpers.remainder_integer",
526-
lambda node: sympy.Mod(node.args[0], node.args[1]), # pyright: ignore[reportAttributeAccessIssue]
527-
),
518+
sympy_expr = cast("sympy.Expr", value)
519+
sanitized = sympy_expr.replace( # pyright: ignore[reportAttributeAccessIssue]
520+
lambda node: isinstance(node, sympy.Function)
521+
and getattr(node.func, "__name__", "")
522+
== "triton_helpers.div_floor_integer",
523+
lambda node: sympy.floor(node.args[0] / node.args[1]), # pyright: ignore[reportAttributeAccessIssue]
524+
).replace( # pyright: ignore[reportAttributeAccessIssue]
525+
lambda node: isinstance(node, sympy.Function)
526+
and getattr(node.func, "__name__", "")
527+
== "triton_helpers.remainder_integer",
528+
lambda node: sympy.Mod(node.args[0], node.args[1]), # pyright: ignore[reportAttributeAccessIssue]
528529
)
530+
expr = cast("sympy.Expr", sanitized)
529531
return HostFunction.current().sympy_expr(expr)
530532

531533
return HostFunction.current().literal_expr(value)

0 commit comments

Comments
 (0)