File tree 2 files changed +47
-3
lines changed
2 files changed +47
-3
lines changed Original file line number Diff line number Diff line change @@ -48,7 +48,19 @@ class FlattenMapper(IdentityMapper[[]]):
48
48
49
49
This parallels what was done implicitly in the expression node
50
50
constructors.
51
+
52
+ .. automethod:: is_expr_integral
51
53
"""
54
+
55
+ def is_expr_integer_valued (self , expr : ExpressionT ) -> bool :
56
+ """A user-supplied method to indicate whether a given *expr* is integer-
57
+ valued. This enables additional simplifications that are not valid in
58
+ general. The default implementation simply returns *False*.
59
+
60
+ .. versionadded :: 2024.1
61
+ """
62
+ return False
63
+
52
64
def map_sum (self , expr : p .Sum ) -> ExpressionT :
53
65
from pymbolic .primitives import flattened_sum
54
66
return flattened_sum ([
@@ -77,7 +89,9 @@ def map_floor_div(self, expr: p.FloorDiv) -> ExpressionT:
77
89
if p .is_zero (r_num ):
78
90
return 0
79
91
if p .is_zero (r_den - 1 ):
80
- return r_num
92
+ # It's the floor function in this case.
93
+ if self .is_expr_integer_valued (r_num ):
94
+ return r_num
81
95
82
96
return expr .__class__ (r_num , r_den )
83
97
@@ -88,8 +102,9 @@ def map_remainder(self, expr: p.Remainder) -> ExpressionT:
88
102
if p .is_zero (r_num ):
89
103
return 0
90
104
if p .is_zero (r_den - 1 ):
91
- # mod 1 is zero
92
- return 0
105
+ # mod 1 is zero for integers, however 3.1 % 1 == .1
106
+ if self .is_expr_integer_valued (r_num ):
107
+ return 0
93
108
94
109
return expr .__class__ (r_num , r_den )
95
110
Original file line number Diff line number Diff line change 1
1
from __future__ import annotations
2
2
3
+ from pymbolic .mapper .evaluator import evaluate_kw
4
+ from pymbolic .mapper .flattener import FlattenMapper
3
5
from pymbolic .mapper .stringifier import StringifyMapper
4
6
from pymbolic .typing import ExpressionT
5
7
@@ -1053,6 +1055,33 @@ def test_derived_stringifier() -> None:
1053
1055
# }}}
1054
1056
1055
1057
1058
+ # {{{ test_flatten
1059
+
1060
+ class IntegerFlattenMapper (FlattenMapper ):
1061
+ def is_expr_integer_valued (self , expr : ExpressionT ) -> bool :
1062
+ return True
1063
+
1064
+
1065
+ def test_flatten ():
1066
+ expr = parse ("(3 + x) % 1" )
1067
+
1068
+ assert IntegerFlattenMapper ()(expr ) != expr
1069
+ assert FlattenMapper ()(expr ) == expr
1070
+
1071
+ assert evaluate_kw (IntegerFlattenMapper ()(expr ), x = 1 ) == 0
1072
+ assert abs (evaluate_kw (FlattenMapper ()(expr ), x = 1.1 ) - 0.1 ) < 1e-12
1073
+
1074
+ expr = parse ("(3 + x) // 1" )
1075
+
1076
+ assert IntegerFlattenMapper ()(expr ) != expr
1077
+ assert FlattenMapper ()(expr ) == expr
1078
+
1079
+ assert evaluate_kw (IntegerFlattenMapper ()(expr ), x = 1 ) == 4
1080
+ assert abs (evaluate_kw (FlattenMapper ()(expr ), x = 1.1 ) - 4 ) < 1e-12
1081
+
1082
+ # }}}
1083
+
1084
+
1056
1085
if __name__ == "__main__" :
1057
1086
import sys
1058
1087
if len (sys .argv ) > 1 :
You can’t perform that action at this time.
0 commit comments