Skip to content

Commit 8da6fb5

Browse files
committed
FlattenMapper: guard simplifications that only hold for integers
1 parent fa79c6c commit 8da6fb5

File tree

2 files changed

+47
-3
lines changed

2 files changed

+47
-3
lines changed

pymbolic/mapper/flattener.py

+18-3
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,19 @@ class FlattenMapper(IdentityMapper[[]]):
4848
4949
This parallels what was done implicitly in the expression node
5050
constructors.
51+
52+
.. automethod:: is_expr_integral
5153
"""
54+
55+
def is_expr_integer_valued(self, expr: ExpressionT) -> bool:
56+
"""A user-supplied method to indicate whether a given *expr* is integer-
57+
valued. This enables additional simplifications that are not valid in
58+
general. The default implementation simply returns *False*.
59+
60+
.. versionadded :: 2024.1
61+
"""
62+
return False
63+
5264
def map_sum(self, expr: p.Sum) -> ExpressionT:
5365
from pymbolic.primitives import flattened_sum
5466
return flattened_sum([
@@ -77,7 +89,9 @@ def map_floor_div(self, expr: p.FloorDiv) -> ExpressionT:
7789
if p.is_zero(r_num):
7890
return 0
7991
if p.is_zero(r_den - 1):
80-
return r_num
92+
# It's the floor function in this case.
93+
if self.is_expr_integer_valued(r_num):
94+
return r_num
8195

8296
return expr.__class__(r_num, r_den)
8397

@@ -88,8 +102,9 @@ def map_remainder(self, expr: p.Remainder) -> ExpressionT:
88102
if p.is_zero(r_num):
89103
return 0
90104
if p.is_zero(r_den - 1):
91-
# mod 1 is zero
92-
return 0
105+
# mod 1 is zero for integers, however 3.1 % 1 == .1
106+
if self.is_expr_integer_valued(r_num):
107+
return 0
93108

94109
return expr.__class__(r_num, r_den)
95110

test/test_pymbolic.py

+29
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
from pymbolic.mapper.evaluator import evaluate_kw
4+
from pymbolic.mapper.flattener import FlattenMapper
35
from pymbolic.mapper.stringifier import StringifyMapper
46
from pymbolic.typing import ExpressionT
57

@@ -1053,6 +1055,33 @@ def test_derived_stringifier() -> None:
10531055
# }}}
10541056

10551057

1058+
# {{{ test_flatten
1059+
1060+
class IntegerFlattenMapper(FlattenMapper):
1061+
def is_expr_integer_valued(self, expr: ExpressionT) -> bool:
1062+
return True
1063+
1064+
1065+
def test_flatten():
1066+
expr = parse("(3 + x) % 1")
1067+
1068+
assert IntegerFlattenMapper()(expr) != expr
1069+
assert FlattenMapper()(expr) == expr
1070+
1071+
assert evaluate_kw(IntegerFlattenMapper()(expr), x=1) == 0
1072+
assert abs(evaluate_kw(FlattenMapper()(expr), x=1.1) - 0.1) < 1e-12
1073+
1074+
expr = parse("(3 + x) // 1")
1075+
1076+
assert IntegerFlattenMapper()(expr) != expr
1077+
assert FlattenMapper()(expr) == expr
1078+
1079+
assert evaluate_kw(IntegerFlattenMapper()(expr), x=1) == 4
1080+
assert abs(evaluate_kw(FlattenMapper()(expr), x=1.1) - 4) < 1e-12
1081+
1082+
# }}}
1083+
1084+
10561085
if __name__ == "__main__":
10571086
import sys
10581087
if len(sys.argv) > 1:

0 commit comments

Comments
 (0)