43
43
CSECachingMapperMixin ,
44
44
)
45
45
import immutables
46
+ from pymbolic .mapper .equality import (
47
+ EqualityMapper as EqualityMapperBase )
46
48
from pymbolic .mapper .evaluator import \
47
49
CachedEvaluationMapper as EvaluationMapperBase
48
50
from pymbolic .mapper .substitutor import \
@@ -501,6 +503,60 @@ def map_substitution(self, name, rule, arguments):
501
503
502
504
return self .rec (expr )
503
505
506
+
507
+ class EqualityMapper (EqualityMapperBase ):
508
+ def map_loopy_function_identifier (self , expr , other ) -> bool :
509
+ return True
510
+
511
+ def map_reduction (self , expr , other ) -> bool :
512
+ return (
513
+ expr .operation == other .operation
514
+ and expr .allow_simultaneous == other .allow_simultaneous
515
+ and self .rec (expr .expr , other .expr )
516
+ and all (iname == other_iname
517
+ for iname , other_iname in zip (expr .inames , other .inames )))
518
+
519
+ def map_group_hw_index (self , expr , other ) -> bool :
520
+ return expr .axis == other .axis
521
+
522
+ map_local_hw_index = map_group_hw_index
523
+
524
+ def map_linear_subscript (self , expr , other ) -> bool :
525
+ return (
526
+ self .rec (expr .index , other .index )
527
+ and self .rec (expr .aggregate , other .aggregate ))
528
+
529
+ def map_rule_argument (self , expr , other ) -> bool :
530
+ return expr .index == other .index
531
+
532
+ def map_resolved_function (self , expr , other ) -> bool :
533
+ return self .rec (expr .function , other .function )
534
+
535
+ def map_sub_array_ref (self , expr , other ) -> bool :
536
+ return (
537
+ len (expr .swept_inames ) == len (other .swept_inames )
538
+ and self .rec (expr .subscript , other .subscript )
539
+ and all (self .rec (iname , other_iname )
540
+ for iname , other_iname in zip (
541
+ expr .swept_inames ,
542
+ other .swept_inames ))
543
+ )
544
+
545
+ def map_tagged_variable (self , expr , other ) -> bool :
546
+ return (
547
+ expr .name == other .name
548
+ and all (tag == other_tag
549
+ for tag , other_tag in zip (expr .tags , other .tags ))
550
+ )
551
+
552
+ def map_type_cast (self , expr , other ) -> bool :
553
+ return (
554
+ expr .type == other .type
555
+ and self .rec (expr .child , other .child ))
556
+
557
+ def map_fortran_division (self , expr , other ) -> bool :
558
+ return self .map_quotient (expr , other )
559
+
504
560
# }}}
505
561
506
562
@@ -514,15 +570,18 @@ def stringifier(self):
514
570
def make_stringifier (self , originating_stringifier = None ):
515
571
return StringifyMapper ()
516
572
573
+ def make_equality_mapper (self ):
574
+ return EqualityMapper ()
575
+
517
576
518
577
class Literal (LoopyExpressionBase ):
519
578
"""A literal to be used during code generation.
520
579
521
580
.. note::
522
581
523
582
Only used in the output of
524
- :mod :`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper` (and
525
- similar mappers). Not for use in Loopy source representation.
583
+ :class :`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper`
584
+ (and similar mappers). Not for use in :mod:`loopy` source representation.
526
585
"""
527
586
528
587
def __init__ (self , s ):
@@ -542,8 +601,8 @@ class ArrayLiteral(LoopyExpressionBase):
542
601
.. note::
543
602
544
603
Only used in the output of
545
- :mod :`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper` (and
546
- similar mappers). Not for use in Loopy source representation.
604
+ :class :`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper`
605
+ (and similar mappers). Not for use in :mod:`loopy` source representation.
547
606
"""
548
607
549
608
def __init__ (self , children ):
@@ -572,8 +631,8 @@ class GroupHardwareAxisIndex(HardwareAxisIndex):
572
631
.. note::
573
632
574
633
Only used in the output of
575
- :mod :`loopy.target.c.expression.ExpressionToCExpressionMapper` (and
576
- similar mappers). Not for use in Loopy source representation.
634
+ :class :`loopy.target.c.codegen. expression.ExpressionToCExpressionMapper`
635
+ (and similar mappers). Not for use in :mod:`loopy` source representation.
577
636
"""
578
637
mapper_method = "map_group_hw_index"
579
638
@@ -583,8 +642,8 @@ class LocalHardwareAxisIndex(HardwareAxisIndex):
583
642
.. note::
584
643
585
644
Only used in the output of
586
- :mod :`loopy.target.c.expression.ExpressionToCExpressionMapper` (and
587
- similar mappers). Not for use in Loopy source representation.
645
+ :class :`loopy.target.c.expression.ExpressionToCExpressionMapper` (and
646
+ similar mappers). Not for use in :mod:`loopy` source representation.
588
647
"""
589
648
mapper_method = "map_local_hw_index"
590
649
@@ -791,12 +850,6 @@ def __getinitargs__(self):
791
850
def get_hash (self ):
792
851
return hash ((self .__class__ , self .operation , self .inames , self .expr ))
793
852
794
- def is_equal (self , other ):
795
- return (other .__class__ == self .__class__
796
- and other .operation == self .operation
797
- and other .inames == self .inames
798
- and other .expr == self .expr )
799
-
800
853
@property
801
854
def is_tuple_typed (self ):
802
855
return self .operation .arg_count > 1
@@ -994,14 +1047,6 @@ def __getinitargs__(self):
994
1047
def get_hash (self ):
995
1048
return hash ((self .__class__ , self .swept_inames , self .subscript ))
996
1049
997
- def is_equal (self , other ):
998
- """
999
- Returns *True* iff the sub-array refs have identical expressions.
1000
- """
1001
- return (other .__class__ == self .__class__
1002
- and other .subscript == self .subscript
1003
- and other .swept_inames == self .swept_inames )
1004
-
1005
1050
def make_stringifier (self , originating_stringifier = None ):
1006
1051
return StringifyMapper ()
1007
1052
0 commit comments