@@ -282,8 +282,8 @@ def block_size_var(self, block_id: int) -> str | None:
282282
283283 var_name = self .new_var (f"_BLOCK_SIZE_{ block_id } " )
284284 self .block_size_var_cache [key ] = var_name
285- host_expr = HostFunction . current (). literal_expr ( block_value )
286- if self .constexpr_arg ( var_name , host_expr ):
285+ if self . constexpr_arg ( var_name , block_value ):
286+ host_expr = self ._constexpr_args [ var_name ]. host_str ()
287287 self .codegen .host_statements .append (
288288 statement_from_string (f"{ var_name } = { host_expr } " )
289289 )
@@ -484,14 +484,42 @@ def expr_arg(self, sym: sympy.Expr, origin: Origin) -> SymbolArgument:
484484 self ._expr_args [sym ] = arg
485485 return self ._expr_args [sym ]
486486
487- def constexpr_arg (self , name : str , host_str : str | None = None ) -> bool :
487+ def constexpr_arg (self , name : str , value : object | None = None ) -> bool :
488488 """Create a constexpr argument, returns True if created, False if already exists."""
489489 if name in self ._constexpr_args :
490490 return False
491- self ._constexpr_args [name ] = rv = ConstExprArg (name , host_str or name )
491+ host_str = name if value is None else self ._format_constexpr_value (value )
492+ self ._constexpr_args [name ] = rv = ConstExprArg (name , host_str )
492493 self .arguments .append (rv )
493494 return True
494495
496+ def _format_constexpr_value (self , value : object ) -> str :
497+ if isinstance (value , str ):
498+ return value
499+ if isinstance (value , (int , float , bool )):
500+ return repr (value )
501+
502+ # Extract sympy expression from torch symbolic types
503+ if isinstance (value , (torch .SymInt , torch .SymFloat , torch .SymBool )):
504+ value = value ._sympy_ ()
505+
506+ # Handle sympy expressions (sanitize by replacing triton_helpers functions)
507+ if isinstance (value , sympy .Expr ):
508+ expr = value .replace (
509+ lambda node : isinstance (node , sympy .Function )
510+ and getattr (node .func , "__name__" , "" )
511+ == "triton_helpers.div_floor_integer" ,
512+ lambda node : sympy .floor (node .args [0 ] / node .args [1 ]), # pyright: ignore[reportAttributeAccessIssue]
513+ ).replace (
514+ lambda node : isinstance (node , sympy .Function )
515+ and getattr (node .func , "__name__" , "" )
516+ == "triton_helpers.remainder_integer" ,
517+ lambda node : sympy .Mod (node .args [0 ], node .args [1 ]), # pyright: ignore[reportAttributeAccessIssue]
518+ )
519+ return HostFunction .current ().sympy_expr (expr )
520+
521+ return HostFunction .current ().literal_expr (value )
522+
495523 def _tensor_property (
496524 self ,
497525 prop_cls : type [_P ],
0 commit comments