50
50
CSECachingMapperMixin ,
51
51
)
52
52
import immutables
53
+ from pymbolic .mapper .equality import (
54
+ EqualityMapper as EqualityMapperBase )
53
55
from pymbolic .mapper .evaluator import \
54
56
CachedEvaluationMapper as EvaluationMapperBase
55
57
from pymbolic .mapper .substitutor import \
@@ -502,6 +504,60 @@ def map_substitution(self, name, rule, arguments):
502
504
503
505
return self .rec (expr )
504
506
507
+
508
+ class EqualityMapper (EqualityMapperBase ):
509
+ def map_loopy_function_identifier (self , expr , other ) -> bool :
510
+ return True
511
+
512
+ def map_reduction (self , expr , other ) -> bool :
513
+ return (
514
+ expr .operation == other .operation
515
+ and expr .allow_simultaneous == other .allow_simultaneous
516
+ and self .rec (expr .expr , other .expr )
517
+ and all (iname == other_iname
518
+ for iname , other_iname in zip (expr .inames , other .inames )))
519
+
520
+ def map_group_hw_index (self , expr , other ) -> bool :
521
+ return expr .axis == other .axis
522
+
523
+ map_local_hw_index = map_group_hw_index
524
+
525
+ def map_linear_subscript (self , expr , other ) -> bool :
526
+ return (
527
+ self .rec (expr .index , other .index )
528
+ and self .rec (expr .aggregate , other .aggregate ))
529
+
530
+ def map_rule_argument (self , expr , other ) -> bool :
531
+ return expr .index == other .index
532
+
533
+ def map_resolved_function (self , expr , other ) -> bool :
534
+ return self .rec (expr .function , other .function )
535
+
536
+ def map_sub_array_ref (self , expr , other ) -> bool :
537
+ return (
538
+ len (expr .swept_inames ) == len (other .swept_inames )
539
+ and self .rec (expr .subscript , other .subscript )
540
+ and all (self .rec (iname , other_iname )
541
+ for iname , other_iname in zip (
542
+ expr .swept_inames ,
543
+ other .swept_inames ))
544
+ )
545
+
546
+ def map_tagged_variable (self , expr , other ) -> bool :
547
+ return (
548
+ expr .name == other .name
549
+ and all (tag == other_tag
550
+ for tag , other_tag in zip (expr .tags , other .tags ))
551
+ )
552
+
553
+ def map_type_cast (self , expr , other ) -> bool :
554
+ return (
555
+ expr .type == other .type
556
+ and self .rec (expr .child , other .child ))
557
+
558
+ def map_fortran_division (self , expr , other ) -> bool :
559
+ return self .map_quotient (expr , other )
560
+
505
561
# }}}
506
562
507
563
@@ -515,15 +571,18 @@ def stringifier(self):
515
571
def make_stringifier (self , originating_stringifier = None ):
516
572
return StringifyMapper ()
517
573
574
+ def make_equality_mapper (self ):
575
+ return EqualityMapper ()
576
+
518
577
519
578
class Literal (LoopyExpressionBase ):
520
579
"""A literal to be used during code generation.
521
580
522
581
.. note::
523
582
524
583
Only used in the output of
525
- :mod :`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper` (and
526
- similar mappers). Not for use in Loopy source representation.
584
+ :class :`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper`
585
+ (and similar mappers). Not for use in :mod:`loopy` source representation.
527
586
"""
528
587
529
588
def __init__ (self , s ):
@@ -543,8 +602,8 @@ class ArrayLiteral(LoopyExpressionBase):
543
602
.. note::
544
603
545
604
Only used in the output of
546
- :mod :`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper` (and
547
- similar mappers). Not for use in Loopy source representation.
605
+ :class :`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper`
606
+ (and similar mappers). Not for use in :mod:`loopy` source representation.
548
607
"""
549
608
550
609
def __init__ (self , children ):
@@ -573,8 +632,8 @@ class GroupHardwareAxisIndex(HardwareAxisIndex):
573
632
.. note::
574
633
575
634
Only used in the output of
576
- :mod :`loopy.target.c.expression.ExpressionToCExpressionMapper` (and
577
- similar mappers). Not for use in Loopy source representation.
635
+ :class :`loopy.target.c.codegen. expression.ExpressionToCExpressionMapper`
636
+ (and similar mappers). Not for use in :mod:`loopy` source representation.
578
637
"""
579
638
mapper_method = "map_group_hw_index"
580
639
@@ -584,8 +643,8 @@ class LocalHardwareAxisIndex(HardwareAxisIndex):
584
643
.. note::
585
644
586
645
Only used in the output of
587
- :mod :`loopy.target.c.expression.ExpressionToCExpressionMapper` (and
588
- similar mappers). Not for use in Loopy source representation.
646
+ :class :`loopy.target.c.expression.ExpressionToCExpressionMapper` (and
647
+ similar mappers). Not for use in :mod:`loopy` source representation.
589
648
"""
590
649
mapper_method = "map_local_hw_index"
591
650
@@ -792,12 +851,6 @@ def __getinitargs__(self):
792
851
def get_hash (self ):
793
852
return hash ((self .__class__ , self .operation , self .inames , self .expr ))
794
853
795
- def is_equal (self , other ):
796
- return (other .__class__ == self .__class__
797
- and other .operation == self .operation
798
- and other .inames == self .inames
799
- and other .expr == self .expr )
800
-
801
854
@property
802
855
def is_tuple_typed (self ):
803
856
return self .operation .arg_count > 1
@@ -994,14 +1047,6 @@ def __getinitargs__(self):
994
1047
def get_hash (self ):
995
1048
return hash ((self .__class__ , self .swept_inames , self .subscript ))
996
1049
997
- def is_equal (self , other ):
998
- """
999
- Returns *True* iff the sub-array refs have identical expressions.
1000
- """
1001
- return (other .__class__ == self .__class__
1002
- and other .subscript == self .subscript
1003
- and other .swept_inames == self .swept_inames )
1004
-
1005
1050
def make_stringifier (self , originating_stringifier = None ):
1006
1051
return StringifyMapper ()
1007
1052
0 commit comments