@@ -207,6 +207,7 @@ def __init__(self, name: str, config: Config, codegen: GenerateAST) -> None:
207207 ] = {}
208208 self ._expr_args : dict [sympy .Expr , SymbolArgument ] = {}
209209 self ._constexpr_args : dict [str , ConstExprArg ] = {}
210+ self ._constexpr_host_defs : set [str ] = set ()
210211 self ._tensor_properties : dict [
211212 tuple [type [TensorPropertyArg ], torch .Tensor , int ], TensorPropertyArg
212213 ] = {}
@@ -282,11 +283,7 @@ def block_size_var(self, block_id: int) -> str | None:
282283
283284 var_name = self .new_var (f"_BLOCK_SIZE_{ block_id } " )
284285 self .block_size_var_cache [key ] = var_name
285- host_expr = HostFunction .current ().literal_expr (block_value )
286- if self .constexpr_arg (var_name , host_expr ):
287- self .codegen .host_statements .append (
288- statement_from_string (f"{ var_name } = { host_expr } " )
289- )
286+ self .constexpr_arg_with_host_def (var_name , block_value )
290287
291288 return self .block_size_var_cache [key ]
292289
@@ -484,14 +481,55 @@ def expr_arg(self, sym: sympy.Expr, origin: Origin) -> SymbolArgument:
484481 self ._expr_args [sym ] = arg
485482 return self ._expr_args [sym ]
486483
487- def constexpr_arg (self , name : str , host_str : str | None = None ) -> bool :
484+ def constexpr_arg (self , name : str , value : object | None = None ) -> bool :
488485 """Create a constexpr argument, returns True if created, False if already exists."""
489486 if name in self ._constexpr_args :
490487 return False
491- self ._constexpr_args [name ] = rv = ConstExprArg (name , host_str or name )
488+ host_str = name if value is None else self ._format_constexpr_value (value )
489+ self ._constexpr_args [name ] = rv = ConstExprArg (name , host_str )
492490 self .arguments .append (rv )
493491 return True
494492
493+ def constexpr_arg_with_host_def (self , name : str , value : object ) -> None :
494+ """Create a constexpr argument and add its host-side definition if needed."""
495+ created = self .constexpr_arg (name , value )
496+ host_expr = self ._constexpr_args [name ].host_str ()
497+ if created or name not in self ._constexpr_host_defs :
498+ self .codegen .host_statements .append (
499+ statement_from_string (f"{ name } = { host_expr } " )
500+ )
501+ self ._constexpr_host_defs .add (name )
502+
503+ def _format_constexpr_value (self , value : object ) -> str :
504+ if isinstance (value , str ):
505+ return value
506+ if isinstance (value , (int , float , bool )):
507+ return repr (value )
508+
509+ # Extract sympy expression from torch symbolic types
510+ if isinstance (value , (torch .SymInt , torch .SymFloat , torch .SymBool )):
511+ value = value ._sympy_ ()
512+
513+ # Handle sympy expressions (sanitize by replacing triton_helpers functions)
514+ if isinstance (value , sympy .Expr ):
515+ expr = cast (
516+ "sympy.Expr" ,
517+ value .replace (
518+ lambda node : isinstance (node , sympy .Function )
519+ and getattr (node .func , "__name__" , "" )
520+ == "triton_helpers.div_floor_integer" ,
521+ lambda node : sympy .floor (node .args [0 ] / node .args [1 ]), # pyright: ignore[reportAttributeAccessIssue]
522+ ).replace (
523+ lambda node : isinstance (node , sympy .Function )
524+ and getattr (node .func , "__name__" , "" )
525+ == "triton_helpers.remainder_integer" ,
526+ lambda node : sympy .Mod (node .args [0 ], node .args [1 ]), # pyright: ignore[reportAttributeAccessIssue]
527+ ),
528+ )
529+ return HostFunction .current ().sympy_expr (expr )
530+
531+ return HostFunction .current ().literal_expr (value )
532+
495533 def _tensor_property (
496534 self ,
497535 prop_cls : type [_P ],
@@ -556,7 +594,12 @@ def codegen_function_def(self) -> list[ast.stmt]:
556594 ]
557595
558596 def codegen_function_call (self ) -> ast .AST :
559- args = [arg .host_str () for arg in self .sorted_args ()]
597+ args = []
598+ for arg in self .sorted_args ():
599+ if isinstance (arg , ConstExprArg ) and arg .name in self ._constexpr_host_defs :
600+ args .append (arg .name )
601+ else :
602+ args .append (arg .host_str ())
560603
561604 if self .has_rng_ops ():
562605 # Pass the host-side seed buffer variable to the kernel
0 commit comments