@@ -207,6 +207,7 @@ def __init__(self, name: str, config: Config, codegen: GenerateAST) -> None:
207207 ] = {}
208208 self ._expr_args : dict [sympy .Expr , SymbolArgument ] = {}
209209 self ._constexpr_args : dict [str , ConstExprArg ] = {}
210+ self ._constexpr_host_defs : set [str ] = set ()
210211 self ._tensor_properties : dict [
211212 tuple [type [TensorPropertyArg ], torch .Tensor , int ], TensorPropertyArg
212213 ] = {}
@@ -282,11 +283,7 @@ def block_size_var(self, block_id: int) -> str | None:
282283
283284 var_name = self .new_var (f"_BLOCK_SIZE_{ block_id } " )
284285 self .block_size_var_cache [key ] = var_name
285- host_expr = HostFunction .current ().literal_expr (block_value )
286- if self .constexpr_arg (var_name , host_expr ):
287- self .codegen .host_statements .append (
288- statement_from_string (f"{ var_name } = { host_expr } " )
289- )
286+ self .constexpr_arg_with_host_def (var_name , block_value )
290287
291288 return self .block_size_var_cache [key ]
292289
@@ -484,14 +481,52 @@ def expr_arg(self, sym: sympy.Expr, origin: Origin) -> SymbolArgument:
484481 self ._expr_args [sym ] = arg
485482 return self ._expr_args [sym ]
486483
487- def constexpr_arg (self , name : str , host_str : str | None = None ) -> bool :
484+ def constexpr_arg (self , name : str , value : object | None = None ) -> bool :
488485 """Create a constexpr argument, returns True if created, False if already exists."""
489486 if name in self ._constexpr_args :
490487 return False
491- self ._constexpr_args [name ] = rv = ConstExprArg (name , host_str or name )
488+ host_str = name if value is None else self ._format_constexpr_value (value )
489+ self ._constexpr_args [name ] = rv = ConstExprArg (name , host_str )
492490 self .arguments .append (rv )
493491 return True
494492
493+ def constexpr_arg_with_host_def (self , name : str , value : object ) -> None :
494+ """Create a constexpr argument and add its host-side definition if needed."""
495+ created = self .constexpr_arg (name , value )
496+ host_expr = self ._constexpr_args [name ].host_str ()
497+ if created or name not in self ._constexpr_host_defs :
498+ self .codegen .host_statements .append (
499+ statement_from_string (f"{ name } = { host_expr } " )
500+ )
501+ self ._constexpr_host_defs .add (name )
502+
503+ def _format_constexpr_value (self , value : object ) -> str :
504+ if isinstance (value , str ):
505+ return value
506+ if isinstance (value , (int , float , bool )):
507+ return repr (value )
508+
509+ # Extract sympy expression from torch symbolic types
510+ if isinstance (value , (torch .SymInt , torch .SymFloat , torch .SymBool )):
511+ value = value ._sympy_ ()
512+
513+ # Handle sympy expressions (sanitize by replacing triton_helpers functions)
514+ if isinstance (value , sympy .Expr ):
515+ expr = value .replace (
516+ lambda node : isinstance (node , sympy .Function )
517+ and getattr (node .func , "__name__" , "" )
518+ == "triton_helpers.div_floor_integer" ,
519+ lambda node : sympy .floor (node .args [0 ] / node .args [1 ]), # pyright: ignore[reportAttributeAccessIssue]
520+ ).replace (
521+ lambda node : isinstance (node , sympy .Function )
522+ and getattr (node .func , "__name__" , "" )
523+ == "triton_helpers.remainder_integer" ,
524+ lambda node : sympy .Mod (node .args [0 ], node .args [1 ]), # pyright: ignore[reportAttributeAccessIssue]
525+ )
526+ return HostFunction .current ().sympy_expr (expr )
527+
528+ return HostFunction .current ().literal_expr (value )
529+
495530 def _tensor_property (
496531 self ,
497532 prop_cls : type [_P ],
@@ -556,7 +591,12 @@ def codegen_function_def(self) -> list[ast.stmt]:
556591 ]
557592
558593 def codegen_function_call (self ) -> ast .AST :
559- args = [arg .host_str () for arg in self .sorted_args ()]
594+ args = []
595+ for arg in self .sorted_args ():
596+ if isinstance (arg , ConstExprArg ) and arg .name in self ._constexpr_host_defs :
597+ args .append (arg .name )
598+ else :
599+ args .append (arg .host_str ())
560600
561601 if self .has_rng_ops ():
562602 # Pass the host-side seed buffer variable to the kernel
0 commit comments