|
31 | 31 | )
|
32 | 32 | from pymbolic.mapper.dependency import (
|
33 | 33 | DependencyMapper as DependencyMapperBase)
|
| 34 | +from pymbolic.mapper.equality import ( |
| 35 | + EqualityMapper as EqualityMapperBase) |
34 | 36 | from pymbolic.geometric_algebra.mapper import (
|
35 | 37 | CombineMapper as CombineMapperBase,
|
36 | 38 | IdentityMapper as IdentityMapperBase,
|
@@ -676,6 +678,80 @@ def map_int_g(self, expr):
|
676 | 678 | # }}}
|
677 | 679 |
|
678 | 680 |
|
| 681 | +# {{{ EqualityMapper |
| 682 | + |
| 683 | +class EqualityMapper(EqualityMapperBase): |
| 684 | + def map_ones(self, expr, other) -> bool: |
| 685 | + return expr.dofdesc == other.dofdesc |
| 686 | + |
| 687 | + map_q_weight = map_ones |
| 688 | + |
| 689 | + def map_node_coordinate_component(self, expr, other) -> bool: |
| 690 | + return ( |
| 691 | + expr.ambient_axis == other.ambient_axis |
| 692 | + and expr.dofdesc == other.dofdesc) |
| 693 | + |
| 694 | + def map_num_reference_derivative(self, expr, other) -> bool: |
| 695 | + return ( |
| 696 | + expr.ref_axes == other.ref_axes |
| 697 | + and expr.dofdesc == other.dofdesc |
| 698 | + and self.rec(expr.operand, other.operand) |
| 699 | + ) |
| 700 | + |
| 701 | + def map_node_sum(self, expr, other) -> bool: |
| 702 | + return self.rec(expr.operand, other.operand) |
| 703 | + |
| 704 | + map_node_max = map_node_sum |
| 705 | + map_node_min = map_node_sum |
| 706 | + |
| 707 | + def map_elementwise_sum(self, expr, other) -> bool: |
| 708 | + return ( |
| 709 | + expr.dofdesc == other.dofdesc |
| 710 | + and self.rec(expr.operand, other.operand)) |
| 711 | + |
| 712 | + map_elementwise_max = map_elementwise_sum |
| 713 | + map_elementwise_min = map_elementwise_sum |
| 714 | + |
| 715 | + def map_int_g(self, expr, other) -> bool: |
| 716 | + from pytential.symbolic.primitives import hashable_kernel_args |
| 717 | + return ( |
| 718 | + expr.qbx_forced_limit == other.qbx_forced_limit |
| 719 | + and expr.source == other.source |
| 720 | + and expr.target == other.target |
| 721 | + and len(expr.kernel_arguments) == len(other.kernel_arguments) |
| 722 | + and len(expr.source_kernels) == len(other.source_kernels) |
| 723 | + and len(expr.densities) == len(other.densities) |
| 724 | + and expr.target_kernel == other.target_kernel |
| 725 | + and all(knl == other_knl for knl, other_knl in zip( |
| 726 | + expr.source_kernels, other.source_kernels) |
| 727 | + ) |
| 728 | + and all(d == other_d for d, other_d in zip( |
| 729 | + expr.densities, other.densities)) |
| 730 | + and all(k == other_k |
| 731 | + and self.rec(v, other_v) |
| 732 | + for (k, v), (other_k, other_v) in zip( |
| 733 | + sorted(hashable_kernel_args(expr.kernel_arguments)), |
| 734 | + sorted(hashable_kernel_args(other.kernel_arguments)))) |
| 735 | + ) |
| 736 | + |
| 737 | + def map_interpolation(self, expr, other) -> bool: |
| 738 | + return ( |
| 739 | + expr.from_dd == other.from_dd |
| 740 | + and expr.to_dd == other.to_dd |
| 741 | + and self.rec(expr.operand, other.operand)) |
| 742 | + |
| 743 | + def map_is_shape_class(self, expr, other) -> bool: |
| 744 | + return ( |
| 745 | + expr.shape is other.shape, |
| 746 | + expr.dofdesc == other.dofdesc |
| 747 | + ) |
| 748 | + |
| 749 | + def map_error_expression(self, expr, other) -> bool: |
| 750 | + return expr.message == other.message |
| 751 | + |
| 752 | +# }}} |
| 753 | + |
| 754 | + |
679 | 755 | # {{{ StringifyMapper
|
680 | 756 |
|
681 | 757 | def stringify_where(where):
|
|
0 commit comments