31
31
)
32
32
from pymbolic .mapper .dependency import (
33
33
DependencyMapper as DependencyMapperBase )
34
+ from pymbolic .mapper .equality import (
35
+ EqualityMapper as EqualityMapperBase )
34
36
from pymbolic .geometric_algebra .mapper import (
35
37
CombineMapper as CombineMapperBase ,
36
38
IdentityMapper as IdentityMapperBase ,
51
53
import pytential .symbolic .primitives as prim
52
54
53
55
56
+ # {{{ IdentityMapper
57
+
54
58
def rec_int_g_arguments (mapper , expr ):
55
59
densities = mapper .rec (expr .densities )
56
60
kernel_arguments = {
@@ -138,6 +142,11 @@ def map_interpolation(self, expr):
138
142
return type (expr )(expr .from_dd , expr .to_dd , operand )
139
143
140
144
145
+ # }}}
146
+
147
+
148
+ # {{{ CombineMapper
149
+
141
150
class CombineMapper (CombineMapperBase ):
142
151
def map_node_sum (self , expr ):
143
152
return self .rec (expr .operand )
@@ -168,6 +177,10 @@ def map_is_shape_class(self, expr):
168
177
169
178
map_error_expression = map_is_shape_class
170
179
180
+ # }}}
181
+
182
+
183
+ # {{{ Collector
171
184
172
185
class Collector (CollectorBase , CombineMapper ):
173
186
def map_ones (self , expr ):
@@ -186,6 +199,10 @@ def map_int_g(self, expr):
186
199
class DependencyMapper (DependencyMapperBase , Collector ):
187
200
pass
188
201
202
+ # }}}
203
+
204
+
205
+ # {{{ EvaluationMapper
189
206
190
207
class EvaluationMapper (EvaluationMapperBase ):
191
208
"""Unlike :mod:`pymbolic.mapper.evaluation.EvaluationMapper`, this class
@@ -249,8 +266,10 @@ def map_common_subexpression(self, expr):
249
266
expr .prefix ,
250
267
expr .scope )
251
268
269
+ # }}}
270
+
252
271
253
- # {{{ dofdesc tagging
272
+ # {{{ dofdesc tagging: LocationTagger, ToTargetTagger
254
273
255
274
class LocationTagger (CSECachingMapperMixin , IdentityMapper ):
256
275
"""Used internally by :class:`ToTargetTagger`."""
@@ -655,6 +674,88 @@ def map_int_g(self, expr):
655
674
# }}}
656
675
657
676
677
+ # {{{ EqualityMapper
678
+
679
+ class EqualityMapper (EqualityMapperBase ):
680
+ def map_ones (self , expr , other ) -> bool :
681
+ return expr .dofdesc == other .dofdesc
682
+
683
+ map_q_weight = map_ones
684
+
685
+ def map_node_coordinate_component (self , expr , other ) -> bool :
686
+ return (
687
+ expr .ambient_axis == other .ambient_axis
688
+ and expr .dofdesc == other .dofdesc )
689
+
690
+ def map_num_reference_derivative (self , expr , other ) -> bool :
691
+ return (
692
+ expr .ref_axes == other .ref_axes
693
+ and expr .dofdesc == other .dofdesc
694
+ and self .rec (expr .operand , other .operand )
695
+ )
696
+
697
+ def map_node_sum (self , expr , other ) -> bool :
698
+ return self .rec (expr .operand , other .operand )
699
+
700
+ map_node_max = map_node_sum
701
+ map_node_min = map_node_sum
702
+
703
+ def map_elementwise_sum (self , expr , other ) -> bool :
704
+ return (
705
+ expr .dofdesc == other .dofdesc
706
+ and self .rec (expr .operand ) == other .operand )
707
+
708
+ map_elementwise_max = map_elementwise_sum
709
+ map_elementwise_min = map_elementwise_sum
710
+
711
+ def map_int_g (self , expr , other ) -> bool :
712
+ import numpy as np
713
+
714
+ def as_hashable (kernel_arg_value ):
715
+ # FIXME: this is here to match the fact that pickled IntGs get
716
+ # restored as tuples, not ndarray, so they don't equal anymore
717
+ if isinstance (kernel_arg_value , np .ndarray ):
718
+ return tuple (kernel_arg_value )
719
+ return kernel_arg_value
720
+
721
+ return (
722
+ expr .qbx_forced_limit == other .qbx_forced_limit
723
+ and expr .source == other .source
724
+ and expr .target == other .target
725
+ and len (expr .kernel_arguments ) == len (other .kernel_arguments )
726
+ and len (expr .source_kernels ) == len (other .source_kernels )
727
+ and len (expr .densities ) == len (other .densities )
728
+ and expr .target_kernel == other .target_kernel
729
+ and all (knl == other_knl for knl , other_knl in zip (
730
+ expr .source_kernels , other .source_kernels )
731
+ )
732
+ and all (d == other_d for d , other_d in zip (
733
+ expr .densities , other .densities ))
734
+ and all (k == other_k
735
+ and self .rec (as_hashable (v ), as_hashable (other_v ))
736
+ for (k , v ), (other_k , other_v ) in zip (
737
+ sorted (expr .kernel_arguments .items ()),
738
+ sorted (other .kernel_arguments .items ())))
739
+ )
740
+
741
+ def map_interpolation (self , expr , other ) -> bool :
742
+ return (
743
+ expr .from_dd == other .from_dd
744
+ and expr .to_dd == other .to_dd
745
+ and self .rec (expr .operand , other .operand ))
746
+
747
+ def map_is_shape_class (self , expr , other ) -> bool :
748
+ return (
749
+ expr .shape is other .shape ,
750
+ expr .dofdesc == other .dofdesc
751
+ )
752
+
753
+ def map_error_expression (self , expr , other ) -> bool :
754
+ return expr .message == other .message
755
+
756
+ # }}}
757
+
758
+
658
759
# {{{ stringifier
659
760
660
761
def stringify_where (where ):
@@ -768,13 +869,13 @@ def map_is_shape_class(self, expr, enclosing_prec):
768
869
return "IsShape[{}]({})" .format (stringify_where (expr .dofdesc ),
769
870
expr .shape .__name__ )
770
871
771
- # }}}
772
-
773
872
774
873
class PrettyStringifyMapper (
775
874
CSESplittingStringifyMapperMixin , StringifyMapper ):
776
875
pass
777
876
877
+ # }}}
878
+
778
879
779
880
# {{{ graphviz
780
881
0 commit comments