Skip to content

Commit fff5af6

Browse files
committed
add support for pymbolic.EqualityMapper
1 parent c80701f commit fff5af6

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

pytato/scalar_expr.py

+19
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@
4040
DistributeMapperBase)
4141
from pymbolic.mapper.stringifier import (StringifyMapper as
4242
StringifyMapperBase)
43+
from pymbolic.mapper.equality import (EqualityMapper as
44+
EqualityMapperBase)
4345
from pymbolic.mapper.collector import TermCollector as TermCollectorBase
4446
import pymbolic.primitives as prim
4547
import numpy as np
@@ -169,6 +171,20 @@ def map_reduce(self, expr: Any, enclosing_prec: Any, *args: Any) -> str:
169171
bounds_expr = "{" + bounds_expr + "}"
170172
return (f"{expr.op}({bounds_expr}, {self.rec(expr.inner_expr, PN)})")
171173

174+
175+
class EqualityMapper(EqualityMapperBase):
176+
def map_reduce(self, expr: Reduce, other: Reduce) -> bool:
177+
return (
178+
len(expr.bounds) == len(other.bounds)
179+
and all(k == other_k
180+
and self.rec(lb, other_lb) and self.rec(ub, other_ub)
181+
for (k, (lb, ub)), (other_k, (other_lb, other_ub)) in zip(
182+
sorted(expr.bounds.items()),
183+
sorted(other.bounds.items())))
184+
and expr.op == other.op
185+
and self.rec(expr.inner_expr, other.inner_expr)
186+
)
187+
172188
# }}}
173189

174190

@@ -225,6 +241,9 @@ def distribute(expr: Any, parameters: FrozenSet[Any] = frozenset(),
225241
# {{{ custom scalar expression nodes
226242

227243
class ExpressionBase(prim.Expression):
244+
def make_equality_mapper(self) -> EqualityMapper:
245+
return EqualityMapper()
246+
228247
def make_stringifier(self, originating_stringifier: Any = None) -> str:
229248
return StringifyMapper()
230249

test/test_pytato.py

+2
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,8 @@ def test_userscollector():
355355

356356

357357
def test_asciidag():
358+
pytest.importorskip("asciidag")
359+
358360
n = pt.make_size_param("n")
359361
array = pt.make_placeholder(name="array", shape=n, dtype=np.float64)
360362
stack = pt.stack([array, 2*array, array + 6])

0 commit comments

Comments
 (0)