40
40
CallbackMapper as CallbackMapperBase ,
41
41
CSECachingMapperMixin ,
42
42
)
43
- from pymbolic .mapper .evaluator import \
44
- EvaluationMapper as EvaluationMapperBase
45
- from pymbolic .mapper .substitutor import \
46
- SubstitutionMapper as SubstitutionMapperBase
47
- from pymbolic .mapper .stringifier import \
48
- StringifyMapper as StringifyMapperBase
49
- from pymbolic .mapper .dependency import \
50
- DependencyMapper as DependencyMapperBase
51
- from pymbolic .mapper .coefficient import \
52
- CoefficientCollector as CoefficientCollectorBase
53
- from pymbolic .mapper .unifier import UnidirectionalUnifier \
54
- as UnidirectionalUnifierBase
55
- from pymbolic .mapper .constant_folder import \
56
- ConstantFoldingMapper as ConstantFoldingMapperBase
43
+ from pymbolic .mapper .equality import (
44
+ EqualityMapper as EqualityMapperBase )
45
+ from pymbolic .mapper .evaluator import (
46
+ EvaluationMapper as EvaluationMapperBase )
47
+ from pymbolic .mapper .substitutor import (
48
+ SubstitutionMapper as SubstitutionMapperBase )
49
+ from pymbolic .mapper .stringifier import (
50
+ StringifyMapper as StringifyMapperBase )
51
+ from pymbolic .mapper .dependency import (
52
+ DependencyMapper as DependencyMapperBase )
53
+ from pymbolic .mapper .coefficient import (
54
+ CoefficientCollector as CoefficientCollectorBase )
55
+ from pymbolic .mapper .unifier import (
56
+ UnidirectionalUnifier as UnidirectionalUnifierBase )
57
+ from pymbolic .mapper .constant_folder import (
58
+ ConstantFoldingMapper as ConstantFoldingMapperBase )
57
59
58
60
from pymbolic .parser import Parser as ParserBase
59
61
from loopy .diagnostic import LoopyError
60
- from loopy .diagnostic import (ExpressionToAffineConversionError ,
61
- UnableToDetermineAccessRangeError )
62
+ from loopy .diagnostic import (
63
+ ExpressionToAffineConversionError ,
64
+ UnableToDetermineAccessRangeError )
62
65
63
66
64
67
import islpy as isl
@@ -114,8 +117,11 @@ def map_literal(self, expr, *args, **kwargs):
114
117
return expr
115
118
116
119
def map_array_literal (self , expr , * args , ** kwargs ):
117
- return type (expr )(tuple (self .rec (ch , * args , ** kwargs )
118
- for ch in expr .children ))
120
+ children = [self .rec (ch , * args , ** kwargs ) for ch in expr .children ]
121
+ if all (ch is orig for ch , orig in zip (children , expr .children )):
122
+ return expr
123
+
124
+ return type (expr )(tuple (children ))
119
125
120
126
def map_group_hw_index (self , expr , * args , ** kwargs ):
121
127
return expr
@@ -484,6 +490,55 @@ def map_substitution(self, name, rule, arguments):
484
490
485
491
return self .rec (expr )
486
492
493
+
494
+ class EqualityMapper (EqualityMapperBase ):
495
+ def map_loopy_function_identifier (self , expr , other ) -> bool :
496
+ return True
497
+
498
+ def map_reduction (self , expr , other ) -> bool :
499
+ return (
500
+ expr .operation == other .operation
501
+ and expr .allow_simultaneous == other .allow_simultaneous
502
+ and self .rec (expr .expr , other .expr )
503
+ and all (iname == other_iname
504
+ for iname , other_iname in zip (expr .inames , other .inames )))
505
+
506
+ def map_group_hw_index (self , expr , other ) -> bool :
507
+ return expr .axis == other .axis
508
+
509
+ map_local_hw_index = map_group_hw_index
510
+
511
+ def map_rule_argument (self , expr , other ) -> bool :
512
+ return expr .index == other .index
513
+
514
+ def map_resolved_function (self , expr , other ) -> bool :
515
+ return self .rec (expr .function , other .function )
516
+
517
+ def map_sub_array_ref (self , expr , other ) -> bool :
518
+ return (
519
+ len (expr .swept_inames ) == len (other .swept_inames )
520
+ and self .rec (expr .subscript , other .subscript )
521
+ and all (self .rec (iname , other_iname )
522
+ for iname , other_iname in zip (
523
+ expr .swept_inames ,
524
+ other .swept_inames ))
525
+ )
526
+
527
+ def map_tagged_variable (self , expr , other ) -> bool :
528
+ return (
529
+ expr .name == other .name
530
+ and all (tag == other_tag
531
+ for tag , other_tag in zip (expr .tags , other .tags ))
532
+ )
533
+
534
+ def map_type_cast (self , expr , other ) -> bool :
535
+ return (
536
+ expr .type == other .type
537
+ and self .rec (expr .child , other .child ))
538
+
539
+ def map_fortran_division (self , expr , other ) -> bool :
540
+ return self .map_quotient (expr , other )
541
+
487
542
# }}}
488
543
489
544
@@ -497,15 +552,18 @@ def stringifier(self):
497
552
def make_stringifier (self , originating_stringifier = None ):
498
553
return StringifyMapper ()
499
554
555
+ def make_equality_mapper (self ):
556
+ return EqualityMapper ()
557
+
500
558
501
559
class Literal (LoopyExpressionBase ):
502
560
"""A literal to be used during code generation.
503
561
504
562
.. note::
505
563
506
564
Only used in the output of
507
- :mod :`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper` (and
508
- similar mappers). Not for use in Loopy source representation.
565
+ :class :`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper`
566
+ (and similar mappers). Not for use in :mod:`loopy` source representation.
509
567
"""
510
568
511
569
def __init__ (self , s ):
@@ -525,8 +583,8 @@ class ArrayLiteral(LoopyExpressionBase):
525
583
.. note::
526
584
527
585
Only used in the output of
528
- :mod :`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper` (and
529
- similar mappers). Not for use in Loopy source representation.
586
+ :class :`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper`
587
+ (and similar mappers). Not for use in :mod:`loopy` source representation.
530
588
"""
531
589
532
590
def __init__ (self , children ):
@@ -555,8 +613,8 @@ class GroupHardwareAxisIndex(HardwareAxisIndex):
555
613
.. note::
556
614
557
615
Only used in the output of
558
- :mod :`loopy.target.c.expression.ExpressionToCExpressionMapper` (and
559
- similar mappers). Not for use in Loopy source representation.
616
+ :class :`loopy.target.c.codegen. expression.ExpressionToCExpressionMapper`
617
+ (and similar mappers). Not for use in :mod:`loopy` source representation.
560
618
"""
561
619
mapper_method = "map_group_hw_index"
562
620
@@ -566,8 +624,8 @@ class LocalHardwareAxisIndex(HardwareAxisIndex):
566
624
.. note::
567
625
568
626
Only used in the output of
569
- :mod :`loopy.target.c.expression.ExpressionToCExpressionMapper` (and
570
- similar mappers). Not for use in Loopy source representation.
627
+ :class :`loopy.target.c.expression.ExpressionToCExpressionMapper` (and
628
+ similar mappers). Not for use in :mod:`loopy` source representation.
571
629
"""
572
630
mapper_method = "map_local_hw_index"
573
631
@@ -774,12 +832,6 @@ def __getinitargs__(self):
774
832
def get_hash (self ):
775
833
return hash ((self .__class__ , self .operation , self .inames , self .expr ))
776
834
777
- def is_equal (self , other ):
778
- return (other .__class__ == self .__class__
779
- and other .operation == self .operation
780
- and other .inames == self .inames
781
- and other .expr == self .expr )
782
-
783
835
@property
784
836
def is_tuple_typed (self ):
785
837
return self .operation .arg_count > 1
@@ -977,14 +1029,6 @@ def __getinitargs__(self):
977
1029
def get_hash (self ):
978
1030
return hash ((self .__class__ , self .swept_inames , self .subscript ))
979
1031
980
- def is_equal (self , other ):
981
- """
982
- Returns *True* iff the sub-array refs have identical expressions.
983
- """
984
- return (other .__class__ == self .__class__
985
- and other .subscript == self .subscript
986
- and other .swept_inames == self .swept_inames )
987
-
988
1032
def make_stringifier (self , originating_stringifier = None ):
989
1033
return StringifyMapper ()
990
1034
0 commit comments