@@ -512,20 +512,18 @@ def _format_constexpr_value(self, value: object) -> str:
512512
513513 # Handle sympy expressions (sanitize by replacing triton_helpers functions)
514514 if isinstance (value , sympy .Expr ):
515- expr = cast (
516- "sympy.Expr" ,
517- value .replace (
518- lambda node : isinstance (node , sympy .Function )
519- and getattr (node .func , "__name__" , "" )
520- == "triton_helpers.div_floor_integer" ,
521- lambda node : sympy .floor (node .args [0 ] / node .args [1 ]), # pyright: ignore[reportAttributeAccessIssue]
522- ).replace (
523- lambda node : isinstance (node , sympy .Function )
524- and getattr (node .func , "__name__" , "" )
525- == "triton_helpers.remainder_integer" ,
526- lambda node : sympy .Mod (node .args [0 ], node .args [1 ]), # pyright: ignore[reportAttributeAccessIssue]
527- ),
515+ sanitized = value .replace ( # pyright: ignore[reportAttributeAccessIssue]
516+ lambda node : isinstance (node , sympy .Function )
517+ and getattr (node .func , "__name__" , "" )
518+ == "triton_helpers.div_floor_integer" ,
519+ lambda node : sympy .floor (node .args [0 ] / node .args [1 ]), # pyright: ignore[reportAttributeAccessIssue]
520+ ).replace ( # pyright: ignore[reportAttributeAccessIssue]
521+ lambda node : isinstance (node , sympy .Function )
522+ and getattr (node .func , "__name__" , "" )
523+ == "triton_helpers.remainder_integer" ,
524+ lambda node : sympy .Mod (node .args [0 ], node .args [1 ]), # pyright: ignore[reportAttributeAccessIssue]
528525 )
526+ expr = cast ("sympy.Expr" , sanitized )
529527 return HostFunction .current ().sympy_expr (expr )
530528
531529 return HostFunction .current ().literal_expr (value )
0 commit comments