Skip to content

Commit 36b5e68

Browse files
committed
implement an EqualityMapper with caching
1 parent 1d23d56 commit 36b5e68

File tree

4 files changed

+266
-28
lines changed

4 files changed

+266
-28
lines changed

pymbolic/mapper/__init__.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,8 @@ def handle_unsupported_expression(self, expr, *args, **kwargs):
105105
"""
106106

107107
raise UnsupportedExpressionError(
108-
"{} cannot handle expressions of type {}".format(
109-
type(self), type(expr)))
108+
"'{}' cannot handle expressions of type '{}'".format(
109+
type(self).__name__, type(expr).__name__))
110110

111111
def __call__(self, expr, *args, **kwargs):
112112
"""Dispatch *expr* to its corresponding mapper method. Pass on
@@ -140,7 +140,7 @@ def __call__(self, expr, *args, **kwargs):
140140
rec = __call__
141141

142142
def map_algebraic_leaf(self, expr, *args, **kwargs):
143-
raise NotImplementedError
143+
raise NotImplementedError(type(expr).__name__)
144144

145145
def map_variable(self, expr, *args, **kwargs):
146146
return self.map_algebraic_leaf(expr, *args, **kwargs)
@@ -425,6 +425,7 @@ def map_subscript(self, expr, *args, **kwargs):
425425
index = self.rec(expr.index, *args, **kwargs)
426426
if aggregate is expr.aggregate and index is expr.index:
427427
return expr
428+
428429
return type(expr)(aggregate, index)
429430

430431
def map_lookup(self, expr, *args, **kwargs):

pymbolic/mapper/equality.py

+247
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,247 @@
1+
__copyright__ = "Copyright (C) 2021 Alexandru Fikl"
2+
3+
__license__ = """
4+
Permission is hereby granted, free of charge, to any person obtaining a copy
5+
of this software and associated documentation files (the "Software"), to deal
6+
in the Software without restriction, including without limitation the rights
7+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8+
copies of the Software, and to permit persons to whom the Software is
9+
furnished to do so, subject to the following conditions:
10+
11+
The above copyright notice and this permission notice shall be included in
12+
all copies or substantial portions of the Software.
13+
14+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20+
THE SOFTWARE.
21+
"""
22+
23+
from typing import Any, Dict, Tuple
24+
25+
from pymbolic.mapper import Mapper, UnsupportedExpressionError
26+
from pymbolic.primitives import Expression
27+
28+
29+
class EqualityMapper(Mapper):
30+
__slots__ = ["_ids_to_result"]
31+
32+
def __init__(self) -> None:
33+
self._ids_to_result: Dict[Tuple[int, int], bool] = {}
34+
35+
def rec(self, expr: Any, other: Any) -> bool:
36+
key = (id(expr), id(other))
37+
if key in self._ids_to_result:
38+
return self._ids_to_result[key]
39+
40+
if expr is other:
41+
result = True
42+
elif expr.__class__ != other.__class__:
43+
result = False
44+
else:
45+
try:
46+
method = getattr(self, expr.mapper_method)
47+
except AttributeError:
48+
if isinstance(expr, Expression):
49+
result = self.handle_unsupported_expression(expr, other)
50+
else:
51+
result = self.map_foreign(expr, other)
52+
else:
53+
result = method(expr, other)
54+
55+
self._ids_to_result[key] = result
56+
return result
57+
58+
def __call__(self, expr: Any, other: Any) -> bool:
59+
return self.rec(expr, other)
60+
61+
# {{{ handle_unsupported_expression
62+
63+
def handle_unsupported_expression(self, expr, other) -> bool:
64+
eq = expr.make_equality_mapper()
65+
if type(self) == type(eq):
66+
raise UnsupportedExpressionError(
67+
"'{}' cannot handle expressions of type '{}'".format(
68+
type(self).__name__, type(expr).__name__))
69+
70+
# NOTE: this may look fishy, but we want to preserve the cache as we
71+
# go through the expression tree, so that it does not do
72+
# unnecessary checks when we change the mapper for some subclass
73+
eq._ids_to_result = self._ids_to_result
74+
75+
return eq(expr, other)
76+
77+
# }}}
78+
79+
# {{{ foreign
80+
81+
def map_tuple(self, expr, other) -> bool:
82+
return (
83+
len(expr) == len(other)
84+
and all(self.rec(el, other_el)
85+
for el, other_el in zip(expr, other)))
86+
87+
def map_foreign(self, expr, other) -> bool:
88+
from pymbolic.primitives import VALID_CONSTANT_CLASSES
89+
90+
if isinstance(expr, VALID_CONSTANT_CLASSES):
91+
return expr == other
92+
elif isinstance(expr, tuple):
93+
return self.map_tuple(expr, other)
94+
else:
95+
raise ValueError(
96+
f"{type(self).__name__} encountered invalid foreign object: "
97+
f"{expr!r}")
98+
99+
# }}}
100+
101+
# {{{
102+
103+
# NOTE: `type(expr) == type(other)` is checked in `__call__`, so the
104+
# checks below can assume that the two operands always have the same type
105+
106+
# NOTE: as much as possible, these should try to put the "cheap" checks
107+
# first so that the shortcircuiting removes the need to to extra work
108+
109+
# NOTE: `all` is also shortcircuiting, so should be better to use a
110+
# generator there to avoid extra work
111+
112+
def map_nan(self, expr, other) -> bool:
113+
return True
114+
115+
def map_wildcard(self, expr, other) -> bool:
116+
return True
117+
118+
def map_function_symbol(self, expr, other) -> bool:
119+
return True
120+
121+
def map_variable(self, expr, other) -> bool:
122+
return expr.name == other.name
123+
124+
def map_subscript(self, expr, other) -> bool:
125+
return (
126+
self.rec(expr.index, other.index)
127+
and self.rec(expr.aggregate, other.aggregate))
128+
129+
def map_lookup(self, expr, other) -> bool:
130+
return (
131+
expr.name == other.name
132+
and self.rec(expr.aggregate, other.aggregate))
133+
134+
def map_call(self, expr, other) -> bool:
135+
return (
136+
len(expr.parameters) == len(other.parameters)
137+
and self.rec(expr.function, other.function)
138+
and all(self.rec(p, other_p)
139+
for p, other_p in zip(expr.parameters, other.parameters)))
140+
141+
def map_call_with_kwargs(self, expr, other) -> bool:
142+
return (
143+
len(expr.parameters) == len(other.parameters)
144+
and len(expr.kw_parameters) == len(other.kw_parameters)
145+
and self.rec(expr.function, other.function)
146+
and all(self.rec(p, other_p)
147+
for p, other_p in zip(expr.parameters, other.parameters))
148+
and all(k == other_k and self.rec(v, other_v)
149+
for (k, v), (other_k, other_v) in zip(
150+
expr.kw_parameters.items(),
151+
other.kw_parameters.items())))
152+
153+
def map_sum(self, expr, other) -> bool:
154+
return (
155+
len(expr.children) == len(other.children)
156+
and all(self.rec(child, other_child)
157+
for child, other_child in zip(expr.children, other.children))
158+
)
159+
160+
map_slice = map_sum
161+
map_product = map_sum
162+
map_min = map_sum
163+
map_max = map_sum
164+
165+
def map_bitwise_not(self, expr, other) -> bool:
166+
return self.rec(expr.child, other.child)
167+
168+
map_bitwise_and = map_sum
169+
map_bitwise_or = map_sum
170+
map_bitwise_xor = map_sum
171+
map_logical_and = map_sum
172+
map_logical_or = map_sum
173+
map_logical_not = map_bitwise_not
174+
175+
def map_quotient(self, expr, other) -> bool:
176+
return (
177+
self.rec(expr.numerator, other.numerator)
178+
and self.rec(expr.denominator, other.denominator)
179+
)
180+
181+
map_floor_div = map_quotient
182+
map_remainder = map_quotient
183+
184+
def map_power(self, expr, other) -> bool:
185+
return (
186+
self.rec(expr.base, other.base)
187+
and self.rec(expr.exponent, other.exponent)
188+
)
189+
190+
def map_left_shift(self, expr, other) -> bool:
191+
return (
192+
self.rec(expr.shift, other.shift)
193+
and self.rec(expr.shiftee, other.shiftee))
194+
195+
map_right_shift = map_left_shift
196+
197+
def map_comparison(self, expr, other) -> bool:
198+
return (
199+
expr.operator == other.operator
200+
and self.rec(expr.left, other.left)
201+
and self.rec(expr.right, other.right))
202+
203+
def map_if(self, expr, other) -> bool:
204+
return (
205+
self.rec(expr.condition, other.condition)
206+
and self.rec(expr.then, other.then)
207+
and self.rec(expr.else_, other.else_))
208+
209+
def map_if_positive(self, expr, other) -> bool:
210+
return (
211+
self.rec(expr.criterion, other.criterion)
212+
and self.rec(expr.then, other.then)
213+
and self.rec(expr.else_, other.else_))
214+
215+
def map_common_subexpression(self, expr, other) -> bool:
216+
return (
217+
expr.prefix == other.prefix
218+
and expr.scope == other.scope
219+
and self.rec(expr.child, other.child)
220+
and all(k == other_k and v == other_v
221+
for (k, v), (other_k, other_v) in zip(
222+
expr.get_extra_properties(),
223+
other.get_extra_properties())))
224+
225+
def map_substitution(self, expr, other) -> bool:
226+
return (
227+
len(expr.variables) == len(other.variables)
228+
and len(expr.values) == len(other.values)
229+
and expr.variables == other.variables
230+
and self.rec(expr.child, other.child)
231+
and all(self.rec(v, other_v)
232+
for v, other_v in zip(expr.values, other.values))
233+
)
234+
235+
def map_derivative(self, expr, other) -> bool:
236+
return (
237+
len(expr.variables) == len(other.variables)
238+
and self.rec(expr.child, other.child)
239+
and all(self.rec(v, other_v)
240+
for v, other_v in zip(expr.variables, other.variables)))
241+
242+
def map_polynomial(self, expr, other) -> bool:
243+
return (
244+
self.rec(expr.Base, other.Data)
245+
and self.rec(expr.Data, other.Data))
246+
247+
# }}}

pymbolic/polynomial.py

-7
Original file line numberDiff line numberDiff line change
@@ -93,13 +93,6 @@ def traits(self):
9393
def __nonzero__(self):
9494
return len(self.Data) != 0
9595

96-
def __eq__(self, other):
97-
return isinstance(other, Polynomial) \
98-
and (self.Base == other.Base) \
99-
and (self.Data == other.Data)
100-
def __ne__(self, other):
101-
return not self.__eq__(other)
102-
10396
def __neg__(self):
10497
return Polynomial(self.Base,
10598
[(exp, -coeff)

pymbolic/primitives.py

+15-18
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ class Expression:
201201
.. automethod:: make_stringifier
202202
203203
.. automethod:: __eq__
204-
.. automethod:: is_equal
204+
.. automethod:: make_equality_mapper
205205
.. automethod:: __hash__
206206
.. automethod:: get_hash
207207
.. automethod:: __str__
@@ -507,18 +507,12 @@ def __repr__(self):
507507
# {{{ hash/equality interface
508508

509509
def __eq__(self, other):
510-
"""Provides equality testing with quick positive and negative paths
511-
based on :func:`id` and :meth:`__hash__`.
510+
"""Provides equality testing with quick positive and negative paths.
512511
513512
Subclasses should generally not override this method, but instead
514-
provide an implementation of :meth:`is_equal`.
513+
provide an implementation of :meth:`make_equality_mapper`.
515514
"""
516-
if self is other:
517-
return True
518-
elif hash(self) != hash(other):
519-
return False
520-
else:
521-
return self.is_equal(other)
515+
return self.make_equality_mapper()(self, other)
522516

523517
def __ne__(self, other):
524518
return not self.__eq__(other)
@@ -551,9 +545,18 @@ def __setstate__(self, state):
551545

552546
# {{{ hash/equality backend
553547

548+
def make_equality_mapper(self):
549+
from pymbolic.mapper.equality import EqualityMapper
550+
return EqualityMapper()
551+
554552
def is_equal(self, other):
555-
return (type(other) == type(self)
556-
and self.__getinitargs__() == other.__getinitargs__())
553+
from warnings import warn
554+
warn("'Expression.is_equal' is deprecated and will be removed in 2023. "
555+
"To customize the equality check, subclass 'EqualityMapper' "
556+
"and overwrite 'Expression.make_equality_mapper'",
557+
DeprecationWarning, stacklevel=2)
558+
559+
return self.make_equality_mapper()(self, other)
557560

558561
def get_hash(self):
559562
return hash((type(self).__name__,) + self.__getinitargs__())
@@ -1034,12 +1037,6 @@ class Quotient(QuotientBase):
10341037
.. attribute:: denominator
10351038
"""
10361039

1037-
def is_equal(self, other):
1038-
from pymbolic.rational import Rational
1039-
return isinstance(other, (Rational, Quotient)) \
1040-
and (self.numerator == other.numerator) \
1041-
and (self.denominator == other.denominator)
1042-
10431040
mapper_method = intern("map_quotient")
10441041

10451042

0 commit comments

Comments
 (0)