|
40 | 40 | DistributeMapperBase)
|
41 | 41 | from pymbolic.mapper.stringifier import (StringifyMapper as
|
42 | 42 | StringifyMapperBase)
|
| 43 | +from pymbolic.mapper.equality import (EqualityMapper as |
| 44 | + EqualityMapperBase) |
43 | 45 | from pymbolic.mapper.collector import TermCollector as TermCollectorBase
|
44 | 46 | import pymbolic.primitives as prim
|
45 | 47 | import numpy as np
|
@@ -169,6 +171,20 @@ def map_reduce(self, expr: Any, enclosing_prec: Any, *args: Any) -> str:
|
169 | 171 | bounds_expr = "{" + bounds_expr + "}"
|
170 | 172 | return (f"{expr.op}({bounds_expr}, {self.rec(expr.inner_expr, PN)})")
|
171 | 173 |
|
| 174 | + |
| 175 | +class EqualityMapper(EqualityMapperBase): |
| 176 | + def map_reduce(self, expr: Reduce, other: Reduce) -> bool: |
| 177 | + return ( |
| 178 | + len(expr.bounds) == len(other.bounds) |
| 179 | + and all(k == other_k |
| 180 | + and self.rec(lb, other_lb) and self.rec(ub, other_ub) |
| 181 | + for (k, (lb, ub)), (other_k, (other_lb, other_ub)) in zip( |
| 182 | + sorted(expr.bounds.items()), |
| 183 | + sorted(other.bounds.items()))) |
| 184 | + and expr.op == other.op |
| 185 | + and self.rec(expr.inner_expr, other.inner_expr) |
| 186 | + ) |
| 187 | + |
172 | 188 | # }}}
|
173 | 189 |
|
174 | 190 |
|
@@ -225,6 +241,9 @@ def distribute(expr: Any, parameters: FrozenSet[Any] = frozenset(),
|
225 | 241 | # {{{ custom scalar expression nodes
|
226 | 242 |
|
227 | 243 | class ExpressionBase(prim.Expression):
|
| 244 | + def make_equality_mapper(self) -> EqualityMapper: |
| 245 | + return EqualityMapper() |
| 246 | + |
228 | 247 | def make_stringifier(self, originating_stringifier: Any = None) -> str:
|
229 | 248 | return StringifyMapper()
|
230 | 249 |
|
|
0 commit comments