Skip to content

Commit 63a69fe

Browse files
committed
add support for pymbolic.EqualityMapper
1 parent 2dd9746 commit 63a69fe

File tree

4 files changed

+111
-6
lines changed

4 files changed

+111
-6
lines changed

pytential/symbolic/mappers.py

+104-3
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
)
3232
from pymbolic.mapper.dependency import (
3333
DependencyMapper as DependencyMapperBase)
34+
from pymbolic.mapper.equality import (
35+
EqualityMapper as EqualityMapperBase)
3436
from pymbolic.geometric_algebra.mapper import (
3537
CombineMapper as CombineMapperBase,
3638
IdentityMapper as IdentityMapperBase,
@@ -51,6 +53,8 @@
5153
import pytential.symbolic.primitives as prim
5254

5355

56+
# {{{ IdentityMapper
57+
5458
def rec_int_g_arguments(mapper, expr):
5559
densities = mapper.rec(expr.densities)
5660
kernel_arguments = {
@@ -138,6 +142,11 @@ def map_interpolation(self, expr):
138142
return type(expr)(expr.from_dd, expr.to_dd, operand)
139143

140144

145+
# }}}
146+
147+
148+
# {{{ CombineMapper
149+
141150
class CombineMapper(CombineMapperBase):
142151
def map_node_sum(self, expr):
143152
return self.rec(expr.operand)
@@ -168,6 +177,10 @@ def map_is_shape_class(self, expr):
168177

169178
map_error_expression = map_is_shape_class
170179

180+
# }}}
181+
182+
183+
# {{{ Collector
171184

172185
class Collector(CollectorBase, CombineMapper):
173186
def map_ones(self, expr):
@@ -186,6 +199,10 @@ def map_int_g(self, expr):
186199
class DependencyMapper(DependencyMapperBase, Collector):
187200
pass
188201

202+
# }}}
203+
204+
205+
# {{{ EvaluationMapper
189206

190207
class EvaluationMapper(EvaluationMapperBase):
191208
"""Unlike :mod:`pymbolic.mapper.evaluation.EvaluationMapper`, this class
@@ -249,8 +266,10 @@ def map_common_subexpression(self, expr):
249266
expr.prefix,
250267
expr.scope)
251268

269+
# }}}
270+
252271

253-
# {{{ dofdesc tagging
272+
# {{{ dofdesc tagging: LocationTagger, ToTargetTagger
254273

255274
class LocationTagger(CSECachingMapperMixin, IdentityMapper):
256275
"""Used internally by :class:`ToTargetTagger`."""
@@ -655,6 +674,88 @@ def map_int_g(self, expr):
655674
# }}}
656675

657676

677+
# {{{ EqualityMapper
678+
679+
class EqualityMapper(EqualityMapperBase):
680+
def map_ones(self, expr, other) -> bool:
681+
return expr.dofdesc == other.dofdesc
682+
683+
map_q_weight = map_ones
684+
685+
def map_node_coordinate_component(self, expr, other) -> bool:
686+
return (
687+
expr.ambient_axis == other.ambient_axis
688+
and expr.dofdesc == other.dofdesc)
689+
690+
def map_num_reference_derivative(self, expr, other) -> bool:
691+
return (
692+
expr.ref_axes == other.ref_axes
693+
and expr.dofdesc == other.dofdesc
694+
and self.rec(expr.operand, other.operand)
695+
)
696+
697+
def map_node_sum(self, expr, other) -> bool:
698+
return self.rec(expr.operand, other.operand)
699+
700+
map_node_max = map_node_sum
701+
map_node_min = map_node_sum
702+
703+
def map_elementwise_sum(self, expr, other) -> bool:
704+
return (
705+
expr.dofdesc == other.dofdesc
706+
and self.rec(expr.operand) == other.operand)
707+
708+
map_elementwise_max = map_elementwise_sum
709+
map_elementwise_min = map_elementwise_sum
710+
711+
def map_int_g(self, expr, other) -> bool:
712+
import numpy as np
713+
714+
def as_hashable(kernel_arg_value):
715+
# FIXME: this is here to match the fact that pickled IntGs get
716+
# restored as tuples, not ndarray, so they don't equal anymore
717+
if isinstance(kernel_arg_value, np.ndarray):
718+
return tuple(kernel_arg_value)
719+
return kernel_arg_value
720+
721+
return (
722+
expr.qbx_forced_limit == other.qbx_forced_limit
723+
and expr.source == other.source
724+
and expr.target == other.target
725+
and len(expr.kernel_arguments) == len(other.kernel_arguments)
726+
and len(expr.source_kernels) == len(other.source_kernels)
727+
and len(expr.densities) == len(other.densities)
728+
and expr.target_kernel == other.target_kernel
729+
and all(knl == other_knl for knl, other_knl in zip(
730+
expr.source_kernels, other.source_kernels)
731+
)
732+
and all(d == other_d for d, other_d in zip(
733+
expr.densities, other.densities))
734+
and all(k == other_k
735+
and self.rec(as_hashable(v), as_hashable(other_v))
736+
for (k, v), (other_k, other_v) in zip(
737+
sorted(expr.kernel_arguments.items()),
738+
sorted(other.kernel_arguments.items())))
739+
)
740+
741+
def map_interpolation(self, expr, other) -> bool:
742+
return (
743+
expr.from_dd == other.from_dd
744+
and expr.to_dd == other.to_dd
745+
and self.rec(expr.operand, other.operand))
746+
747+
def map_is_shape_class(self, expr, other) -> bool:
748+
return (
749+
expr.shape is other.shape,
750+
expr.dofdesc == other.dofdesc
751+
)
752+
753+
def map_error_expression(self, expr, other) -> bool:
754+
return expr.message == other.message
755+
756+
# }}}
757+
758+
658759
# {{{ stringifier
659760

660761
def stringify_where(where):
@@ -768,13 +869,13 @@ def map_is_shape_class(self, expr, enclosing_prec):
768869
return "IsShape[{}]({})".format(stringify_where(expr.dofdesc),
769870
expr.shape.__name__)
770871

771-
# }}}
772-
773872

774873
class PrettyStringifyMapper(
775874
CSESplittingStringifyMapperMixin, StringifyMapper):
776875
pass
777876

877+
# }}}
878+
778879

779880
# {{{ graphviz
780881

pytential/symbolic/primitives.py

+4
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,10 @@ def array_to_tuple(ary):
248248

249249

250250
class Expression(ExpressionBase):
251+
def make_equality_mapper(self):
252+
from pytential.symbolic.mappers import EqualityMapper
253+
return EqualityMapper()
254+
251255
def make_stringifier(self, originating_stringifier=None):
252256
from pytential.symbolic.mappers import StringifyMapper
253257
return StringifyMapper()

requirements.txt

+2-2
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22
numpy != 1.22.0
33

44
git+https://github.com/inducer/pytools.git#egg=pytools
5-
git+https://github.com/inducer/pymbolic.git#egg=pymbolic
5+
git+https://github.com/alexfikl/pymbolic.git@equality-mapper#egg=pymbolic
66
sympy
77
git+https://github.com/inducer/modepy.git#egg=modepy
88
git+https://github.com/inducer/pyopencl.git#egg=pyopencl
99
git+https://github.com/inducer/islpy.git#egg=islpy
10-
git+https://github.com/inducer/loopy.git#egg=loopy
10+
git+https://github.com/alexfikl/loopy.git@equality-mapper#egg=loopy
1111
git+https://github.com/inducer/boxtree.git#egg=boxtree
1212
git+https://github.com/inducer/arraycontext.git#egg=arraycontext
1313
git+https://github.com/inducer/meshmode.git#egg=meshmode

test/test_symbolic.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,7 @@ def test_derivative_binder_expr():
380380
d1, d2 = principal_directions(ambient_dim, dim=dim)
381381
expr = (d1 @ d2 + d1 @ d1) / (d2 @ d2)
382382

383-
nruns = 4
383+
nruns = 1
384384
for i in range(nruns):
385385
from pytools import ProcessTimer
386386
with ProcessTimer() as pd:

0 commit comments

Comments
 (0)