|
41 | 41 | DistributeMapperBase)
|
42 | 42 | from pymbolic.mapper.stringifier import (StringifyMapper as
|
43 | 43 | StringifyMapperBase)
|
| 44 | +from pymbolic.mapper.equality import (EqualityMapper as |
| 45 | + EqualityMapperBase) |
44 | 46 | from pymbolic.mapper import CombineMapper as CombineMapperBase
|
45 | 47 | from pymbolic.mapper.collector import TermCollector as TermCollectorBase
|
46 | 48 | from immutables import Map
|
@@ -184,6 +186,20 @@ def map_reduce(self, expr: Any, enclosing_prec: Any, *args: Any) -> str:
|
184 | 186 | bounds_expr = "{" + bounds_expr + "}"
|
185 | 187 | return (f"{expr.op}({bounds_expr}, {self.rec(expr.inner_expr, PN)})")
|
186 | 188 |
|
| 189 | + |
| 190 | +class EqualityMapper(EqualityMapperBase): |
| 191 | + def map_reduce(self, expr: Reduce, other: Reduce) -> bool: |
| 192 | + return ( |
| 193 | + len(expr.bounds) == len(other.bounds) |
| 194 | + and all(k == other_k |
| 195 | + and self.rec(lb, other_lb) and self.rec(ub, other_ub) |
| 196 | + for (k, (lb, ub)), (other_k, (other_lb, other_ub)) in zip( |
| 197 | + sorted(expr.bounds.items()), |
| 198 | + sorted(other.bounds.items()))) |
| 199 | + and expr.op == other.op |
| 200 | + and self.rec(expr.inner_expr, other.inner_expr) |
| 201 | + ) |
| 202 | + |
187 | 203 | # }}}
|
188 | 204 |
|
189 | 205 |
|
@@ -240,6 +256,9 @@ def distribute(expr: Any, parameters: FrozenSet[Any] = frozenset(),
|
240 | 256 | # {{{ custom scalar expression nodes
|
241 | 257 |
|
242 | 258 | class ExpressionBase(prim.Expression):
|
| 259 | + def make_equality_mapper(self) -> EqualityMapper: |
| 260 | + return EqualityMapper() |
| 261 | + |
243 | 262 | def make_stringifier(self, originating_stringifier: Any = None) -> str:
|
244 | 263 | return StringifyMapper()
|
245 | 264 |
|
|
0 commit comments