@@ -250,6 +250,9 @@ def __init__(self, name: str, config: Config, codegen: GenerateAST) -> None:
250250 self .rng_seed_count = 0
251251 self .device_load_index = 0 # Track which load in device code we're generating (for eviction policy tuning)
252252 # Name of the RNG seed buffer parameter in kernel signature
253+ self .device_store_index = (
254+ 0 # Track which store in device code we're generating (for subtiling)
255+ )
253256 self .rng_seed_buffer_param_name = None
254257
255258 def has_rng_ops (self ) -> bool :
@@ -421,8 +424,9 @@ def tensor_descriptor_arg(
421424 self , fake_value : torch .Tensor , block_size : list [int | torch .SymInt ]
422425 ) -> TensorDescriptorArg :
423426 host_function = HostFunction .current ()
424- block_size_expr = ", " .join (map ( self .literal_expr , block_size ) )
427+ block_size_expr = ", " .join (self .literal_expr ( dim ) for dim in block_size )
425428 key = (fake_value , block_size_expr )
429+
426430 if key not in self ._tensor_descriptor_args :
427431 origin = host_function .tensor_to_origin [fake_value ]
428432 desc_name = self .new_var (origin .suggest_var_name () + "_desc" )
@@ -515,22 +519,6 @@ def _format_constexpr_value(self, value: object) -> str:
515519 if isinstance (value , (torch .SymInt , torch .SymFloat , torch .SymBool )):
516520 value = value ._sympy_ ()
517521
518- # Handle sympy expressions (sanitize by replacing triton_helpers functions)
519- if isinstance (value , sympy .Expr ):
520- sanitized = value .replace ( # pyright: ignore[reportAttributeAccessIssue]
521- lambda node : isinstance (node , sympy .Function )
522- and getattr (node .func , "__name__" , "" )
523- == "triton_helpers.div_floor_integer" ,
524- lambda node : sympy .floor (node .args [0 ] / node .args [1 ]), # pyright: ignore[reportAttributeAccessIssue]
525- ).replace ( # pyright: ignore[reportAttributeAccessIssue]
526- lambda node : isinstance (node , sympy .Function )
527- and getattr (node .func , "__name__" , "" )
528- == "triton_helpers.remainder_integer" ,
529- lambda node : sympy .Mod (node .args [0 ], node .args [1 ]), # pyright: ignore[reportAttributeAccessIssue]
530- )
531- expr = cast ("sympy.Expr" , sanitized )
532- return HostFunction .current ().sympy_expr (expr )
533-
534522 return HostFunction .current ().literal_expr (value )
535523
536524 def _tensor_property (
@@ -708,11 +696,19 @@ def current() -> DeviceFunction:
708696
709697
710698class HelionTritonPrinter (TritonPrinter ):
711- """Custom Triton printer that avoids wrapping float literals in tl.full().
712-
713- Inductor's default TritonPrinter prints SymPy Float as a 0-D Triton value
714- via tl.full([], <val>, tl.float64). We override this to emit the raw numeric
715- literal, letting downstream type promotion and casts handle dtype.
699+ """Custom Triton printer that does the following:
700+
701+ - Avoids wrapping float literals in tl.full().
702+ Inductor's default TritonPrinter prints SymPy Float as a 0-D Triton value
703+ via tl.full([], <val>, tl.float64). We override this to emit the raw numeric
704+ literal, letting downstream type promotion and casts handle dtype.
705+
706+ - Avoids triton_helpers.div_floor_integer(...) calls when both operands are
707+ provably non-negative integers. TritonPrinter by default converts
708+ floor(u1/2) to triton_helpers.div_floor_integer(...). We override this to
709+ emit u1 // 2 only when the numerator is known to be non-negative and the
710+ denominator is a positive integer, so that we keep helper calls for cases
711+ that rely on floor semantics with mixed signs.
716712 """
717713
718714 def _print_Float (self , expr : sympy .Expr ) -> str :
@@ -721,6 +717,53 @@ def _print_Float(self, expr: sympy.Expr) -> str:
def _print_ToFloat(self, expr: sympy.Expr) -> str:
    """Print ``ToFloat(x)`` as ``<x> + 0.0`` to promote the operand to float.

    Args:
        expr: The ``ToFloat`` node being printed; its single argument is the
            value to promote.

    Returns:
        Source text of the form ``<operand> + 0.0``.
    """
    # Local import keeps this block self-contained; sympy is already a
    # dependency of this module.
    from sympy.printing.precedence import PRECEDENCE

    # Render the operand through this printer instead of interpolating the
    # ToFloat node itself: str(expr) would emit the SymPy repr
    # ("ToFloat(x)"), which is not valid generated code and skips the
    # printer's overrides. Non-atomic operands are parenthesized so the
    # "+ 0.0" applies to the whole operand.
    operand = self.parenthesize(expr.args[0], PRECEDENCE["Atom"] - 0.5)
    return f"{operand} + 0.0"
723719
def _is_nonnegative(self, expr: sympy.Expr) -> bool:
    """Return True only when ``expr`` can be shown to be >= 0.

    The checks run in this order:
      1. SymPy assumptions (``is_nonnegative``, ``is_zero``, ``is_positive``).
      2. The current HostFunction, if any, records ``expr`` with an origin
         that is a ``BlockSizeOrigin`` or ``TensorSizeOrigin``.
      3. ``expr`` is a symbol whose name starts with ``_BLOCK_SIZE_``.
      4. ``expr`` is a numeric constant that compares ``>= 0``.

    Anything that cannot be established by one of these is reported as
    False.
    """
    proven_by_assumptions = (
        expr.is_nonnegative is True
        or expr.is_zero is True
        or expr.is_positive is True
    )
    if proven_by_assumptions:
        return True

    # There may be no active host function; in that case skip the origin
    # lookup rather than failing.
    try:
        active_host = HostFunction.current()
    except NoCurrentFunction:
        active_host = None
    recorded = None if active_host is None else active_host.expr_to_origin.get(expr)
    nonnegative_origins = (BlockSizeOrigin, TensorSizeOrigin)
    if recorded and isinstance(recorded.origin, nonnegative_origins):
        return True

    if isinstance(expr, sympy.Symbol) and expr.name.startswith("_BLOCK_SIZE_"):
        return True

    return bool(expr >= 0) if isinstance(expr, sympy.Number) else False
740+
def _format_trunc_div(self, lhs: sympy.Expr, rhs: sympy.Expr) -> str:
    """Render ``lhs // rhs`` as source text.

    Each operand is printed with this printer. An operand that is neither
    an integer literal nor a bare symbol is wrapped in parentheses so the
    ``//`` applies to the operand as a whole.
    """
    pieces = []
    for operand in (lhs, rhs):
        text = self._print(operand)
        is_atomic = operand.is_Integer or operand.is_Symbol
        pieces.append(text if is_atomic else f"({text})")
    dividend, divisor = pieces
    return f"{dividend} // {divisor}"
749+
def _print_floor(self, expr: sympy.Expr) -> str:
    """Print ``floor(numer / denom)`` as ``numer // denom`` when that is safe.

    The argument of ``floor`` is split into numerator and denominator. The
    ``//`` form is emitted only if the denominator is a SymPy Integer
    greater than 1 and the numerator passes ``_is_nonnegative``; every
    other case is delegated to the parent printer.
    """
    numerator, denominator = expr.args[0].as_numer_denom()

    # Guard clauses: anything not matching the safe pattern keeps the
    # parent printer's output.
    if not isinstance(denominator, sympy.Integer):
        return super()._print_floor(expr)
    if not (denominator > 1):
        return super()._print_floor(expr)
    if not self._is_nonnegative(numerator):
        return super()._print_floor(expr)

    return self._format_trunc_div(numerator, denominator)
760+
def _print_FloorDiv(self, expr: sympy.Expr) -> str:
    """Print ``FloorDiv(lhs, rhs)`` as ``lhs // rhs`` when that is safe.

    The ``//`` form is used only when ``rhs`` is a positive SymPy Integer
    and ``lhs`` passes ``_is_nonnegative``; otherwise the parent printer's
    output is returned unchanged.
    """
    dividend, divisor = expr.args
    divisor_ok = isinstance(divisor, sympy.Integer) and divisor > 0
    if not (divisor_ok and self._is_nonnegative(dividend)):
        return super()._print_FloorDiv(expr)
    return self._format_trunc_div(dividend, divisor)
766+
724767
def texpr(expr: sympy.Expr) -> str:
    """Render a SymPy expression to source text with a fresh HelionTritonPrinter."""
    printer = HelionTritonPrinter()
    return printer.doprint(expr)
0 commit comments