Skip to content

Commit 4c319e4

Browse files
committed
implement an EqualityMapper with caching
1 parent 324ada6 commit 4c319e4

File tree

4 files changed

+185
-14
lines changed

4 files changed

+185
-14
lines changed

pymbolic/mapper/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def __call__(self, expr, *args, **kwargs):
139139
rec = __call__
140140

141141
def map_algebraic_leaf(self, expr, *args, **kwargs):
142-
raise NotImplementedError
142+
raise NotImplementedError(type(expr).__name__)
143143

144144
def map_variable(self, expr, *args, **kwargs):
145145
return self.map_algebraic_leaf(expr, *args, **kwargs)
@@ -413,6 +413,7 @@ def map_subscript(self, expr, *args, **kwargs):
413413
index = self.rec(expr.index, *args, **kwargs)
414414
if aggregate is expr.aggregate and index is expr.index:
415415
return expr
416+
416417
return type(expr)(aggregate, index)
417418

418419
def map_lookup(self, expr, *args, **kwargs):

pymbolic/mapper/equality.py

+181
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
__copyright__ = "Copyright (C) 2021 Alexandru Fikl"
2+
3+
__license__ = """
4+
Permission is hereby granted, free of charge, to any person obtaining a copy
5+
of this software and associated documentation files (the "Software"), to deal
6+
in the Software without restriction, including without limitation the rights
7+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8+
copies of the Software, and to permit persons to whom the Software is
9+
furnished to do so, subject to the following conditions:
10+
11+
The above copyright notice and this permission notice shall be included in
12+
all copies or substantial portions of the Software.
13+
14+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20+
THE SOFTWARE.
21+
"""
22+
23+
from typing import Any, Dict, Tuple
24+
25+
from pymbolic.mapper import Mapper
26+
from pymbolic.primitives import Expression
27+
28+
29+
class EqualityMapper(Mapper):
30+
def __init__(self) -> None:
31+
self._ids_to_result: Dict[Tuple[int, int], bool] = {}
32+
33+
def __call__(self, expr: Any, other: Any) -> bool:
34+
key = (id(expr), id(other))
35+
if key in self._ids_to_result:
36+
return self._ids_to_result[key]
37+
38+
if expr is other:
39+
result = True
40+
elif expr.__class__ != other.__class__:
41+
result = False
42+
else:
43+
try:
44+
method = getattr(self, expr.mapper_method)
45+
except AttributeError:
46+
if isinstance(expr, Expression):
47+
return self.handle_unsupported_expression(expr, other)
48+
else:
49+
return self.map_foreign(expr, other)
50+
else:
51+
result = method(expr, other)
52+
53+
self._ids_to_result[key] = result
54+
return result
55+
56+
rec = __call__
57+
58+
def map_variable(self, expr, other) -> bool:
59+
return expr.name == other.name
60+
61+
def map_subscript(self, expr, other) -> bool:
62+
return (
63+
self.rec(expr.index, other.index)
64+
and self.rec(expr.aggregate, other.aggregate))
65+
66+
def map_lookup(self, expr, other) -> bool:
67+
return (
68+
expr.name == other.name
69+
and self.rec(expr.aggregate, other.aggregate))
70+
71+
def map_call(self, expr, other) -> bool:
72+
return (
73+
len(expr.parameters) == len(other.parameters)
74+
and self.rec(expr.function, other.function)
75+
and all(self.rec(p, other_p)
76+
for p, other_p in zip(expr.parameters, other.parameters)))
77+
78+
def map_call_with_kwargs(self, expr, other) -> bool:
79+
return (
80+
len(expr.parameters) == len(other.parameters)
81+
and len(expr.kw_parameters) == len(other.kw_parameters)
82+
and self.rec(expr.function, other.function)
83+
and all(self.rec(p, other_p)
84+
for p, other_p in zip(expr.parameters, other.parameters))
85+
and all(k == other_k and self.rec(v, other_v)
86+
for (k, v), (other_k, other_v) in zip(
87+
expr.kw_parameters.items(),
88+
other.kw_parameters.items())))
89+
90+
def map_sum(self, expr, other) -> bool:
91+
return (
92+
len(expr.children) == len(other.children)
93+
and all(self.rec(child, other_child)
94+
for child, other_child in zip(expr.children, other.children))
95+
)
96+
97+
map_product = map_sum
98+
map_min = map_sum
99+
map_max = map_sum
100+
101+
def map_bitwise_not(self, expr, other) -> bool:
102+
return self.rec(expr.child, other.child)
103+
104+
map_bitwise_and = map_sum
105+
map_bitwise_or = map_sum
106+
map_bitwise_xor = map_sum
107+
map_logical_and = map_sum
108+
map_logical_or = map_sum
109+
map_logical_not = map_bitwise_not
110+
111+
def map_quotient(self, expr, other) -> bool:
112+
return (
113+
self.rec(expr.numerator, other.numerator)
114+
and self.rec(expr.denominator, other.denominator)
115+
)
116+
117+
map_floor_div = map_quotient
118+
map_remainder = map_quotient
119+
120+
def map_power(self, expr, other) -> bool:
121+
return (
122+
self.rec(expr.base, other.base)
123+
and self.rec(expr.exponent, other.exponent)
124+
)
125+
126+
def map_left_shift(self, expr, other) -> bool:
127+
return (
128+
self.rec(expr.shiftee, other.shiftee)
129+
and self.rec(expr.shift, other.shift))
130+
131+
map_right_shift = map_left_shift
132+
133+
def map_comparison(self, expr, other) -> bool:
134+
return (
135+
expr.operator == other.operator
136+
and self.rec(expr.left, other.left)
137+
and self.rec(expr.right, other.right))
138+
139+
def map_if(self, expr, other) -> bool:
140+
return (
141+
self.rec(expr.condition, other.condition)
142+
and self.rec(expr.then, other.then)
143+
and self.rec(expr.else_, other.else_))
144+
145+
def map_common_subexpression(self, expr, other) -> bool:
146+
return (
147+
expr.prefix == other.prefix
148+
and expr.scope == other.scope
149+
and self.rec(expr.child, other.child))
150+
151+
def map_substitution(self, expr, other) -> bool:
152+
return (
153+
len(expr.variables) == len(other.variables)
154+
and len(expr.values) == len(other.values)
155+
and expr.variables == other.variables
156+
and self.rec(expr.child, other.child)
157+
and all(self.rec(v, other_v)
158+
for v, other_v in zip(expr.values, other.values))
159+
)
160+
161+
def map_derivative(self, expr, other) -> bool:
162+
return (
163+
len(expr.variables) == len(other.variables)
164+
and self.rec(expr.child, other.child)
165+
and all(self.rec(v, other_v)
166+
for v, other_v in zip(expr.variables, other.variables)))
167+
168+
def map_polynomial(self, expr, other) -> bool:
169+
return (
170+
self.rec(expr.Base, other.Data)
171+
and self.rec(expr.Data, other.Data))
172+
173+
# {{{ foreign
174+
175+
def map_tuple(self, expr, other) -> bool:
176+
return (
177+
len(expr) == len(other)
178+
and all(self.rec(el, other_el)
179+
for el, other_el in zip(expr, other)))
180+
181+
# }}}

pymbolic/polynomial.py

-7
Original file line numberDiff line numberDiff line change
@@ -93,13 +93,6 @@ def traits(self):
9393
def __nonzero__(self):
9494
return len(self.Data) != 0
9595

96-
def __eq__(self, other):
97-
return isinstance(other, Polynomial) \
98-
and (self.Base == other.Base) \
99-
and (self.Data == other.Data)
100-
def __ne__(self, other):
101-
return not self.__eq__(other)
102-
10396
def __neg__(self):
10497
return Polynomial(self.Base,
10598
[(exp, -coeff)

pymbolic/primitives.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -505,12 +505,8 @@ def __eq__(self, other):
505505
Subclasses should generally not override this method, but instead
506506
provide an implementation of :meth:`is_equal`.
507507
"""
508-
if self is other:
509-
return True
510-
elif hash(self) != hash(other):
511-
return False
512-
else:
513-
return self.is_equal(other)
508+
from pymbolic.mapper.equality import EqualityMapper
509+
return EqualityMapper()(self, other)
514510

515511
def __ne__(self, other):
516512
return not self.__eq__(other)

0 commit comments

Comments
 (0)