43
43
CSECachingMapperMixin ,
44
44
)
45
45
import immutables
46
+ from pymbolic .mapper .equality import (
47
+ EqualityMapper as EqualityMapperBase )
46
48
from pymbolic .mapper .evaluator import \
47
49
CachedEvaluationMapper as EvaluationMapperBase
48
50
from pymbolic .mapper .substitutor import \
60
62
61
63
from pymbolic .parser import Parser as ParserBase
62
64
from loopy .diagnostic import LoopyError
63
- from loopy .diagnostic import (ExpressionToAffineConversionError ,
64
- UnableToDetermineAccessRangeError )
65
+ from loopy .diagnostic import (
66
+ ExpressionToAffineConversionError ,
67
+ UnableToDetermineAccessRangeError )
65
68
66
69
67
70
import islpy as isl
@@ -117,8 +120,11 @@ def map_literal(self, expr, *args, **kwargs):
117
120
return expr
118
121
119
122
def map_array_literal (self , expr , * args , ** kwargs ):
120
- return type (expr )(tuple (self .rec (ch , * args , ** kwargs )
121
- for ch in expr .children ))
123
+ children = [self .rec (ch , * args , ** kwargs ) for ch in expr .children ]
124
+ if all (ch is orig for ch , orig in zip (children , expr .children )):
125
+ return expr
126
+
127
+ return type (expr )(tuple (children ))
122
128
123
129
def map_group_hw_index (self , expr , * args , ** kwargs ):
124
130
return expr
@@ -501,6 +507,60 @@ def map_substitution(self, name, rule, arguments):
501
507
502
508
return self .rec (expr )
503
509
510
+
511
+ class EqualityMapper (EqualityMapperBase ):
512
+ def map_loopy_function_identifier (self , expr , other ) -> bool :
513
+ return True
514
+
515
+ def map_reduction (self , expr , other ) -> bool :
516
+ return (
517
+ expr .operation == other .operation
518
+ and expr .allow_simultaneous == other .allow_simultaneous
519
+ and self .rec (expr .expr , other .expr )
520
+ and all (iname == other_iname
521
+ for iname , other_iname in zip (expr .inames , other .inames )))
522
+
523
+ def map_group_hw_index (self , expr , other ) -> bool :
524
+ return expr .axis == other .axis
525
+
526
+ map_local_hw_index = map_group_hw_index
527
+
528
+ def map_linear_subscript (self , expr , other ) -> bool :
529
+ return (
530
+ self .rec (expr .index , other .index )
531
+ and self .rec (expr .aggregate , other .aggregate ))
532
+
533
+ def map_rule_argument (self , expr , other ) -> bool :
534
+ return expr .index == other .index
535
+
536
+ def map_resolved_function (self , expr , other ) -> bool :
537
+ return self .rec (expr .function , other .function )
538
+
539
+ def map_sub_array_ref (self , expr , other ) -> bool :
540
+ return (
541
+ len (expr .swept_inames ) == len (other .swept_inames )
542
+ and self .rec (expr .subscript , other .subscript )
543
+ and all (self .rec (iname , other_iname )
544
+ for iname , other_iname in zip (
545
+ expr .swept_inames ,
546
+ other .swept_inames ))
547
+ )
548
+
549
+ def map_tagged_variable (self , expr , other ) -> bool :
550
+ return (
551
+ expr .name == other .name
552
+ and all (tag == other_tag
553
+ for tag , other_tag in zip (expr .tags , other .tags ))
554
+ )
555
+
556
+ def map_type_cast (self , expr , other ) -> bool :
557
+ return (
558
+ expr .type == other .type
559
+ and self .rec (expr .child , other .child ))
560
+
561
+ def map_fortran_division (self , expr , other ) -> bool :
562
+ return self .map_quotient (expr , other )
563
+
504
564
# }}}
505
565
506
566
@@ -514,15 +574,18 @@ def stringifier(self):
514
574
def make_stringifier (self , originating_stringifier = None ):
515
575
return StringifyMapper ()
516
576
577
+ def make_equality_mapper (self ):
578
+ return EqualityMapper ()
579
+
517
580
518
581
class Literal (LoopyExpressionBase ):
519
582
"""A literal to be used during code generation.
520
583
521
584
.. note::
522
585
523
586
Only used in the output of
524
- :mod :`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper` (and
525
- similar mappers). Not for use in Loopy source representation.
587
+ :class :`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper`
588
+ (and similar mappers). Not for use in :mod:`loopy` source representation.
526
589
"""
527
590
528
591
def __init__ (self , s ):
@@ -542,8 +605,8 @@ class ArrayLiteral(LoopyExpressionBase):
542
605
.. note::
543
606
544
607
Only used in the output of
545
- :mod :`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper` (and
546
- similar mappers). Not for use in Loopy source representation.
608
+ :class :`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper`
609
+ (and similar mappers). Not for use in :mod:`loopy` source representation.
547
610
"""
548
611
549
612
def __init__ (self , children ):
@@ -572,8 +635,8 @@ class GroupHardwareAxisIndex(HardwareAxisIndex):
572
635
.. note::
573
636
574
637
Only used in the output of
575
- :mod :`loopy.target.c.expression.ExpressionToCExpressionMapper` (and
576
- similar mappers). Not for use in Loopy source representation.
638
+ :class :`loopy.target.c.codegen. expression.ExpressionToCExpressionMapper`
639
+ (and similar mappers). Not for use in :mod:`loopy` source representation.
577
640
"""
578
641
mapper_method = "map_group_hw_index"
579
642
@@ -583,8 +646,8 @@ class LocalHardwareAxisIndex(HardwareAxisIndex):
583
646
.. note::
584
647
585
648
Only used in the output of
586
- :mod :`loopy.target.c.expression.ExpressionToCExpressionMapper` (and
587
- similar mappers). Not for use in Loopy source representation.
649
+ :class :`loopy.target.c.expression.ExpressionToCExpressionMapper` (and
650
+ similar mappers). Not for use in :mod:`loopy` source representation.
588
651
"""
589
652
mapper_method = "map_local_hw_index"
590
653
@@ -791,12 +854,6 @@ def __getinitargs__(self):
791
854
def get_hash (self ):
792
855
return hash ((self .__class__ , self .operation , self .inames , self .expr ))
793
856
794
- def is_equal (self , other ):
795
- return (other .__class__ == self .__class__
796
- and other .operation == self .operation
797
- and other .inames == self .inames
798
- and other .expr == self .expr )
799
-
800
857
@property
801
858
def is_tuple_typed (self ):
802
859
return self .operation .arg_count > 1
@@ -994,14 +1051,6 @@ def __getinitargs__(self):
994
1051
def get_hash (self ):
995
1052
return hash ((self .__class__ , self .swept_inames , self .subscript ))
996
1053
997
- def is_equal (self , other ):
998
- """
999
- Returns *True* iff the sub-array refs have identical expressions.
1000
- """
1001
- return (other .__class__ == self .__class__
1002
- and other .subscript == self .subscript
1003
- and other .swept_inames == self .swept_inames )
1004
-
1005
1054
def make_stringifier (self , originating_stringifier = None ):
1006
1055
return StringifyMapper ()
1007
1056
0 commit comments