Skip to content

Commit e1cf378

Browse files
committed
add support for pymbolic.EqualityMapper
1 parent b29a492 commit e1cf378

File tree

4 files changed

+88
-3
lines changed

4 files changed

+88
-3
lines changed

pytential/symbolic/mappers.py

+76
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
)
3232
from pymbolic.mapper.dependency import (
3333
DependencyMapper as DependencyMapperBase)
34+
from pymbolic.mapper.equality import (
35+
EqualityMapper as EqualityMapperBase)
3436
from pymbolic.geometric_algebra.mapper import (
3537
CombineMapper as CombineMapperBase,
3638
IdentityMapper as IdentityMapperBase,
@@ -676,6 +678,80 @@ def map_int_g(self, expr):
676678
# }}}
677679

678680

681+
# {{{ EqualityMapper
682+
683+
class EqualityMapper(EqualityMapperBase):
684+
def map_ones(self, expr, other) -> bool:
685+
return expr.dofdesc == other.dofdesc
686+
687+
map_q_weight = map_ones
688+
689+
def map_node_coordinate_component(self, expr, other) -> bool:
690+
return (
691+
expr.ambient_axis == other.ambient_axis
692+
and expr.dofdesc == other.dofdesc)
693+
694+
def map_num_reference_derivative(self, expr, other) -> bool:
695+
return (
696+
expr.ref_axes == other.ref_axes
697+
and expr.dofdesc == other.dofdesc
698+
and self.rec(expr.operand, other.operand)
699+
)
700+
701+
def map_node_sum(self, expr, other) -> bool:
702+
return self.rec(expr.operand, other.operand)
703+
704+
map_node_max = map_node_sum
705+
map_node_min = map_node_sum
706+
707+
def map_elementwise_sum(self, expr, other) -> bool:
708+
return (
709+
expr.dofdesc == other.dofdesc
710+
and self.rec(expr.operand, other.operand))
711+
712+
map_elementwise_max = map_elementwise_sum
713+
map_elementwise_min = map_elementwise_sum
714+
715+
def map_int_g(self, expr, other) -> bool:
716+
from pytential.symbolic.primitives import hashable_kernel_args
717+
return (
718+
expr.qbx_forced_limit == other.qbx_forced_limit
719+
and expr.source == other.source
720+
and expr.target == other.target
721+
and len(expr.kernel_arguments) == len(other.kernel_arguments)
722+
and len(expr.source_kernels) == len(other.source_kernels)
723+
and len(expr.densities) == len(other.densities)
724+
and expr.target_kernel == other.target_kernel
725+
and all(knl == other_knl for knl, other_knl in zip(
726+
expr.source_kernels, other.source_kernels)
727+
)
728+
and all(d == other_d for d, other_d in zip(
729+
expr.densities, other.densities))
730+
and all(k == other_k
731+
and self.rec(v, other_v)
732+
for (k, v), (other_k, other_v) in zip(
733+
sorted(hashable_kernel_args(expr.kernel_arguments)),
734+
sorted(hashable_kernel_args(other.kernel_arguments))))
735+
)
736+
737+
def map_interpolation(self, expr, other) -> bool:
738+
return (
739+
expr.from_dd == other.from_dd
740+
and expr.to_dd == other.to_dd
741+
and self.rec(expr.operand, other.operand))
742+
743+
def map_is_shape_class(self, expr, other) -> bool:
744+
return (
745+
expr.shape is other.shape,
746+
expr.dofdesc == other.dofdesc
747+
)
748+
749+
def map_error_expression(self, expr, other) -> bool:
750+
return expr.message == other.message
751+
752+
# }}}
753+
754+
679755
# {{{ StringifyMapper
680756

681757
def stringify_where(where):

pytential/symbolic/primitives.py

+4
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,10 @@ def array_to_tuple(ary):
248248

249249

250250
class Expression(ExpressionBase):
251+
def make_equality_mapper(self):
252+
from pytential.symbolic.mappers import EqualityMapper
253+
return EqualityMapper()
254+
251255
def make_stringifier(self, originating_stringifier=None):
252256
from pytential.symbolic.mappers import StringifyMapper
253257
return StringifyMapper()

requirements.txt

+2-2
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22
numpy != 1.22.0
33

44
git+https://github.com/inducer/pytools.git#egg=pytools
5-
git+https://github.com/inducer/pymbolic.git#egg=pymbolic
5+
git+https://github.com/alexfikl/pymbolic.git@equality-mapper#egg=pymbolic
66
sympy
77
git+https://github.com/inducer/modepy.git#egg=modepy
88
git+https://github.com/inducer/pyopencl.git#egg=pyopencl
99
git+https://github.com/inducer/islpy.git#egg=islpy
10-
git+https://github.com/inducer/loopy.git#egg=loopy
10+
git+https://github.com/alexfikl/loopy.git@equality-mapper#egg=loopy
1111
git+https://github.com/inducer/boxtree.git#egg=boxtree
1212
git+https://github.com/inducer/arraycontext.git#egg=arraycontext
1313
git+https://github.com/inducer/meshmode.git#egg=meshmode

test/test_symbolic.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,7 @@ def test_derivative_binder_expr():
380380
d1, d2 = principal_directions(ambient_dim, dim=dim)
381381
expr = (d1 @ d2 + d1 @ d1) / (d2 @ d2)
382382

383-
nruns = 4
383+
nruns = 1
384384
for i in range(nruns):
385385
from pytools import ProcessTimer
386386
with ProcessTimer() as pd:
@@ -446,6 +446,8 @@ def is_base_kernel(knl):
446446

447447
@pytest.mark.parametrize("op_name", ["dirichlet", "neumann"])
448448
def test_mapper_int_g_term_collector(op_name, k=0):
449+
logging.basicConfig(level=logging.INFO)
450+
449451
ambient_dim = 3
450452
op = _make_operator(ambient_dim, op_name, k)
451453
expr = op.operator(op.get_density_var("sigma"))
@@ -463,6 +465,9 @@ def test_mapper_int_g_term_collector(op_name, k=0):
463465
else:
464466
raise ValueError(f"unknown operator name: {op_name}")
465467

468+
print(sym.pretty(expr_only_intgs))
469+
print(sym.pretty(expected_expr))
470+
466471
assert expr_only_intgs == expected_expr
467472

468473
# }}}

0 commit comments

Comments
 (0)