Skip to content

Commit 9d9255b

Browse files
committed
add support for pymbolic.EqualityMapper
1 parent c96138c commit 9d9255b

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

pytato/scalar_expr.py

+19
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@
4141
DistributeMapperBase)
4242
from pymbolic.mapper.stringifier import (StringifyMapper as
4343
StringifyMapperBase)
44+
from pymbolic.mapper.equality import (EqualityMapper as
45+
EqualityMapperBase)
4446
from pymbolic.mapper import CombineMapper as CombineMapperBase
4547
from pymbolic.mapper.collector import TermCollector as TermCollectorBase
4648
from immutables import Map
@@ -184,6 +186,20 @@ def map_reduce(self, expr: Any, enclosing_prec: Any, *args: Any) -> str:
184186
bounds_expr = "{" + bounds_expr + "}"
185187
return (f"{expr.op}({bounds_expr}, {self.rec(expr.inner_expr, PN)})")
186188

189+
190+
class EqualityMapper(EqualityMapperBase):
191+
def map_reduce(self, expr: Reduce, other: Reduce) -> bool:
192+
return (
193+
len(expr.bounds) == len(other.bounds)
194+
and all(k == other_k
195+
and self.rec(lb, other_lb) and self.rec(ub, other_ub)
196+
for (k, (lb, ub)), (other_k, (other_lb, other_ub)) in zip(
197+
sorted(expr.bounds.items()),
198+
sorted(other.bounds.items())))
199+
and expr.op == other.op
200+
and self.rec(expr.inner_expr, other.inner_expr)
201+
)
202+
187203
# }}}
188204

189205

@@ -240,6 +256,9 @@ def distribute(expr: Any, parameters: FrozenSet[Any] = frozenset(),
240256
# {{{ custom scalar expression nodes
241257

242258
class ExpressionBase(prim.Expression):
259+
def make_equality_mapper(self) -> EqualityMapper:
260+
return EqualityMapper()
261+
243262
def make_stringifier(self, originating_stringifier: Any = None) -> str:
244263
return StringifyMapper()
245264

test/test_pytato.py

+2
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,8 @@ def test_userscollector():
355355

356356

357357
def test_asciidag():
358+
pytest.importorskip("asciidag")
359+
358360
n = pt.make_size_param("n")
359361
array = pt.make_placeholder(name="array", shape=n, dtype=np.float64)
360362
stack = pt.stack([array, 2*array, array + 6])

0 commit comments

Comments
 (0)