@@ -462,8 +462,9 @@ def tensor_descriptor_arg(
462462 self , fake_value : torch .Tensor , block_size : list [int | torch .SymInt ]
463463 ) -> TensorDescriptorArg :
464464 host_function = HostFunction .current ()
465- block_size_expr = ", " .join (map ( self .literal_expr , block_size ) )
465+ block_size_expr = ", " .join (self .literal_expr ( dim ) for dim in block_size )
466466 key = (fake_value , block_size_expr )
467+
467468 if key not in self ._tensor_descriptor_args :
468469 origin = host_function .tensor_to_origin [fake_value ]
469470 desc_name = self .new_var (origin .suggest_var_name () + "_desc" )
@@ -556,22 +557,6 @@ def _format_constexpr_value(self, value: object) -> str:
556557 if isinstance (value , (torch .SymInt , torch .SymFloat , torch .SymBool )):
557558 value = value ._sympy_ ()
558559
559- # Handle sympy expressions (sanitize by replacing triton_helpers functions)
560- if isinstance (value , sympy .Expr ):
561- sanitized = value .replace ( # pyright: ignore[reportAttributeAccessIssue]
562- lambda node : isinstance (node , sympy .Function )
563- and getattr (node .func , "__name__" , "" )
564- == "triton_helpers.div_floor_integer" ,
565- lambda node : sympy .floor (node .args [0 ] / node .args [1 ]), # pyright: ignore[reportAttributeAccessIssue]
566- ).replace ( # pyright: ignore[reportAttributeAccessIssue]
567- lambda node : isinstance (node , sympy .Function )
568- and getattr (node .func , "__name__" , "" )
569- == "triton_helpers.remainder_integer" ,
570- lambda node : sympy .Mod (node .args [0 ], node .args [1 ]), # pyright: ignore[reportAttributeAccessIssue]
571- )
572- expr = cast ("sympy.Expr" , sanitized )
573- return HostFunction .current ().sympy_expr (expr )
574-
575560 return HostFunction .current ().literal_expr (value )
576561
577562 def _tensor_property (
@@ -749,11 +734,19 @@ def current() -> DeviceFunction:
749734
750735
751736class HelionTritonPrinter (TritonPrinter ):
752- """Custom Triton printer that avoids wrapping float literals in tl.full().
753-
754- Inductor's default TritonPrinter prints SymPy Float as a 0-D Triton value
755- via tl.full([], <val>, tl.float64). We override this to emit the raw numeric
756- literal, letting downstream type promotion and casts handle dtype.
737+ """Custom Triton printer that does the following:
738+
739+ - Avoids wrapping float literals in tl.full().
740+ Inductor's default TritonPrinter prints SymPy Float as a 0-D Triton value
741+ via tl.full([], <val>, tl.float64). We override this to emit the raw numeric
742+ literal, letting downstream type promotion and casts handle dtype.
743+
744+ - Avoids triton_helpers.div_floor_integer(...) calls when both operands are
745+ provably non-negative integers. TritonPrinter by default converts
746+ floor(u1/2) to triton_helpers.div_floor_integer(...). We override this to
747+ emit u1 // 2 only when the numerator is known to be non-negative and the
748+ denominator is a positive integer, so that we keep helper calls for cases
749+ that rely on floor semantics with mixed signs.
757750 """
758751
759752 def _print_Float (self , expr : sympy .Expr ) -> str :
@@ -762,6 +755,53 @@ def _print_Float(self, expr: sympy.Expr) -> str:
762755 def _print_ToFloat (self , expr : sympy .Expr ) -> str :
763756 return f"{ expr } + 0.0"
764757
758+ def _is_nonnegative (self , expr : sympy .Expr ) -> bool :
759+ if expr .is_nonnegative is True or expr .is_zero is True :
760+ return True
761+ if expr .is_positive is True :
762+ return True
763+ try :
764+ host_fn = HostFunction .current ()
765+ except NoCurrentFunction :
766+ host_fn = None
767+ if host_fn is not None :
768+ origin_info = host_fn .expr_to_origin .get (expr )
769+ if origin_info and isinstance (
770+ origin_info .origin , (BlockSizeOrigin , TensorSizeOrigin )
771+ ):
772+ return True
773+ if isinstance (expr , sympy .Symbol ) and expr .name .startswith ("_BLOCK_SIZE_" ):
774+ return True
775+ if isinstance (expr , sympy .Number ):
776+ return bool (expr >= 0 )
777+ return False
778+
779+ def _format_trunc_div (self , lhs : sympy .Expr , rhs : sympy .Expr ) -> str :
780+ lhs_str = self ._print (lhs )
781+ rhs_str = self ._print (rhs )
782+ if not (lhs .is_Integer or lhs .is_Symbol ):
783+ lhs_str = f"({ lhs_str } )"
784+ if not (rhs .is_Integer or rhs .is_Symbol ):
785+ rhs_str = f"({ rhs_str } )"
786+ return f"{ lhs_str } // { rhs_str } "
787+
788+ def _print_floor (self , expr : sympy .Expr ) -> str :
789+ inner = expr .args [0 ]
790+ numer , denom = inner .as_numer_denom ()
791+ if (
792+ isinstance (denom , sympy .Integer )
793+ and denom > 1
794+ and self ._is_nonnegative (numer )
795+ ):
796+ return self ._format_trunc_div (numer , denom )
797+ return super ()._print_floor (expr )
798+
799+ def _print_FloorDiv (self , expr : sympy .Expr ) -> str :
800+ lhs , rhs = expr .args
801+ if isinstance (rhs , sympy .Integer ) and rhs > 0 and self ._is_nonnegative (lhs ):
802+ return self ._format_trunc_div (lhs , rhs )
803+ return super ()._print_FloorDiv (expr )
804+
765805
766806def texpr (expr : sympy .Expr ) -> str :
767807 return HelionTritonPrinter ().doprint (expr )
0 commit comments