|
22 | 22 |
|
23 | 23 | from typing import Any, Dict, Tuple
|
24 | 24 |
|
25 |
| -from pymbolic.mapper import Mapper |
| 25 | +from pymbolic.mapper import Mapper, UnsupportedExpressionError |
26 | 26 | from pymbolic.primitives import Expression
|
27 | 27 |
|
28 | 28 |
|
29 | 29 | class EqualityMapper(Mapper):
|
30 | 30 | def __init__(self) -> None:
|
31 | 31 | self._ids_to_result: Dict[Tuple[int, int], bool] = {}
|
32 | 32 |
|
33 |
| - def __call__(self, expr: Any, other: Any) -> bool: |
| 33 | + def rec(self, expr: Any, other: Any) -> bool: |
34 | 34 | key = (id(expr), id(other))
|
35 |
| - if key in self._ids_to_result: |
36 |
| - return self._ids_to_result[key] |
37 |
| - |
38 |
| - if expr is other: |
39 |
| - result = True |
40 |
| - elif expr.__class__ != other.__class__: |
41 |
| - result = False |
42 |
| - else: |
43 |
| - try: |
44 |
| - method = getattr(self, expr.mapper_method) |
45 |
| - except AttributeError: |
46 |
| - if isinstance(expr, Expression): |
47 |
| - return self.handle_unsupported_expression(expr, other) |
48 |
| - else: |
49 |
| - return self.map_foreign(expr, other) |
| 35 | + |
| 36 | + try: |
| 37 | + result = self._ids_to_result[key] |
| 38 | + except KeyError: |
| 39 | + if expr is other: |
| 40 | + result = True |
| 41 | + elif expr.__class__ != other.__class__: |
| 42 | + result = False |
| 43 | + elif hash(expr) != hash(other): |
| 44 | + result = False |
50 | 45 | else:
|
51 |
| - result = method(expr, other) |
| 46 | + try: |
| 47 | + method = getattr(self, expr.mapper_method) |
| 48 | + except AttributeError: |
| 49 | + if isinstance(expr, Expression): |
| 50 | + result = self.handle_unsupported_expression(expr, other) |
| 51 | + else: |
| 52 | + result = self.map_foreign(expr, other) |
| 53 | + else: |
| 54 | + result = method(expr, other) |
| 55 | + |
| 56 | + self._ids_to_result[key] = result |
52 | 57 |
|
53 |
| - self._ids_to_result[key] = result |
54 | 58 | return result
|
55 | 59 |
|
56 |
| - rec = __call__ |
| 60 | + def __call__(self, expr: Any, other: Any) -> bool: |
| 61 | + return self.rec(expr, other) |
| 62 | + |
| 63 | + def handle_unsupported_expression(self, expr, other) -> bool: |
| 64 | + eq = expr.make_equality_mapper() |
| 65 | + if type(self) == type(eq): |
| 66 | + raise UnsupportedExpressionError( |
| 67 | + "'{}' cannot handle expressions of type '{}'".format( |
| 68 | + type(self).__name__, type(expr).__name__)) |
| 69 | + |
| 70 | + eq._ids_to_result = self._ids_to_result |
| 71 | + return eq(expr, other) |
| 72 | + |
| 73 | + def map_constant(self, expr, other) -> bool: |
| 74 | + return expr == other |
57 | 75 |
|
58 | 76 | def map_variable(self, expr, other) -> bool:
|
59 | 77 | return expr.name == other.name
|
|
0 commit comments