Skip to content

Commit a9773e5

Browse files
committed
allow defining a different EqualityMapper in Expression
1 parent 26f55f1 commit a9773e5

File tree

4 files changed

+53
-27
lines changed

4 files changed

+53
-27
lines changed

pymbolic/geometric_algebra/mapper.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def map_nabla_component(self, expr):
179179
return {expr}
180180

181181
def map_derivative_source(self, expr):
182-
return {expr, self.rec(expr.operand)}
182+
return {expr} | self.rec(expr.operand)
183183

184184

185185
class NablaComponentToUnitVector(EvaluationMapper):

pymbolic/mapper/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,8 @@ def handle_unsupported_expression(self, expr, *args, **kwargs):
104104
"""
105105

106106
raise UnsupportedExpressionError(
107-
"{} cannot handle expressions of type {}".format(
108-
type(self), type(expr)))
107+
"'{}' cannot handle expressions of type '{}'".format(
108+
type(self).__name__, type(expr).__name__))
109109

110110
def __call__(self, expr, *args, **kwargs):
111111
"""Dispatch *expr* to its corresponding mapper method. Pass on

pymbolic/mapper/equality.py

+38-20
Original file line numberDiff line numberDiff line change
@@ -22,38 +22,56 @@
2222

2323
from typing import Any, Dict, Tuple
2424

25-
from pymbolic.mapper import Mapper
25+
from pymbolic.mapper import Mapper, UnsupportedExpressionError
2626
from pymbolic.primitives import Expression
2727

2828

2929
class EqualityMapper(Mapper):
3030
def __init__(self) -> None:
3131
self._ids_to_result: Dict[Tuple[int, int], bool] = {}
3232

33-
def __call__(self, expr: Any, other: Any) -> bool:
33+
def rec(self, expr: Any, other: Any) -> bool:
3434
key = (id(expr), id(other))
35-
if key in self._ids_to_result:
36-
return self._ids_to_result[key]
37-
38-
if expr is other:
39-
result = True
40-
elif expr.__class__ != other.__class__:
41-
result = False
42-
else:
43-
try:
44-
method = getattr(self, expr.mapper_method)
45-
except AttributeError:
46-
if isinstance(expr, Expression):
47-
return self.handle_unsupported_expression(expr, other)
48-
else:
49-
return self.map_foreign(expr, other)
35+
36+
try:
37+
result = self._ids_to_result[key]
38+
except KeyError:
39+
if expr is other:
40+
result = True
41+
elif expr.__class__ != other.__class__:
42+
result = False
43+
elif hash(expr) != hash(other):
44+
result = False
5045
else:
51-
result = method(expr, other)
46+
try:
47+
method = getattr(self, expr.mapper_method)
48+
except AttributeError:
49+
if isinstance(expr, Expression):
50+
result = self.handle_unsupported_expression(expr, other)
51+
else:
52+
result = self.map_foreign(expr, other)
53+
else:
54+
result = method(expr, other)
55+
56+
self._ids_to_result[key] = result
5257

53-
self._ids_to_result[key] = result
5458
return result
5559

56-
rec = __call__
60+
def __call__(self, expr: Any, other: Any) -> bool:
61+
return self.rec(expr, other)
62+
63+
def handle_unsupported_expression(self, expr, other) -> bool:
64+
eq = expr.make_equality_mapper()
65+
if type(self) == type(eq):
66+
raise UnsupportedExpressionError(
67+
"'{}' cannot handle expressions of type '{}'".format(
68+
type(self).__name__, type(expr).__name__))
69+
70+
eq._ids_to_result = self._ids_to_result
71+
return eq(expr, other)
72+
73+
def map_constant(self, expr, other) -> bool:
74+
return expr == other
5775

5876
def map_variable(self, expr, other) -> bool:
5977
return expr.name == other.name

pymbolic/primitives.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -505,8 +505,7 @@ def __eq__(self, other):
505505
Subclasses should generally not override this method, but instead
506506
provide an implementation of :meth:`is_equal`.
507507
"""
508-
from pymbolic.mapper.equality import EqualityMapper
509-
return EqualityMapper()(self, other)
508+
return self.make_equality_mapper()(self, other)
510509

511510
def __ne__(self, other):
512511
return not self.__eq__(other)
@@ -539,9 +538,18 @@ def __setstate__(self, state):
539538

540539
# {{{ hash/equality backend
541540

541+
def make_equality_mapper(self):
542+
from pymbolic.mapper.equality import EqualityMapper
543+
return EqualityMapper()
544+
542545
def is_equal(self, other):
543-
return (type(other) == type(self)
544-
and self.__getinitargs__() == other.__getinitargs__())
546+
from warnings import warn
547+
warn("'Expression.is_equal' is deprecated and will be removed in 2023. "
548+
"To customize the equality check, subclass 'EqualityMapper' "
549+
"and overwrite 'Expression.make_equality_mapper'",
550+
DeprecationWarning, stacklevel=2)
551+
552+
return self.make_equality_mapper()(self, other)
545553

546554
def get_hash(self):
547555
return hash((type(self).__name__,) + self.__getinitargs__())

0 commit comments

Comments
 (0)