Skip to content

Commit 182b6ca

Browse files
committed
add EqualityMapper to follow pymbolic
1 parent 9c1e37d commit 182b6ca

File tree

2 files changed

+76
-27
lines changed

2 files changed

+76
-27
lines changed

loopy/symbolic.py

+75-26
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@
4343
CSECachingMapperMixin,
4444
)
4545
import immutables
46+
from pymbolic.mapper.equality import (
47+
EqualityMapper as EqualityMapperBase)
4648
from pymbolic.mapper.evaluator import \
4749
CachedEvaluationMapper as EvaluationMapperBase
4850
from pymbolic.mapper.substitutor import \
@@ -60,8 +62,9 @@
6062

6163
from pymbolic.parser import Parser as ParserBase
6264
from loopy.diagnostic import LoopyError
63-
from loopy.diagnostic import (ExpressionToAffineConversionError,
64-
UnableToDetermineAccessRangeError)
65+
from loopy.diagnostic import (
66+
ExpressionToAffineConversionError,
67+
UnableToDetermineAccessRangeError)
6568

6669

6770
import islpy as isl
@@ -117,8 +120,11 @@ def map_literal(self, expr, *args, **kwargs):
117120
return expr
118121

119122
def map_array_literal(self, expr, *args, **kwargs):
120-
return type(expr)(tuple(self.rec(ch, *args, **kwargs)
121-
for ch in expr.children))
123+
children = [self.rec(ch, *args, **kwargs) for ch in expr.children]
124+
if all(ch is orig for ch, orig in zip(children, expr.children)):
125+
return expr
126+
127+
return type(expr)(tuple(children))
122128

123129
def map_group_hw_index(self, expr, *args, **kwargs):
124130
return expr
@@ -501,6 +507,60 @@ def map_substitution(self, name, rule, arguments):
501507

502508
return self.rec(expr)
503509

510+
511+
class EqualityMapper(EqualityMapperBase):
512+
def map_loopy_function_identifier(self, expr, other) -> bool:
513+
return True
514+
515+
def map_reduction(self, expr, other) -> bool:
516+
return (
517+
expr.operation == other.operation
518+
and expr.allow_simultaneous == other.allow_simultaneous
519+
and self.rec(expr.expr, other.expr)
520+
and all(iname == other_iname
521+
for iname, other_iname in zip(expr.inames, other.inames)))
522+
523+
def map_group_hw_index(self, expr, other) -> bool:
524+
return expr.axis == other.axis
525+
526+
map_local_hw_index = map_group_hw_index
527+
528+
def map_linear_subscript(self, expr, other) -> bool:
529+
return (
530+
self.rec(expr.index, other.index)
531+
and self.rec(expr.aggregate, other.aggregate))
532+
533+
def map_rule_argument(self, expr, other) -> bool:
534+
return expr.index == other.index
535+
536+
def map_resolved_function(self, expr, other) -> bool:
537+
return self.rec(expr.function, other.function)
538+
539+
def map_sub_array_ref(self, expr, other) -> bool:
540+
return (
541+
len(expr.swept_inames) == len(other.swept_inames)
542+
and self.rec(expr.subscript, other.subscript)
543+
and all(self.rec(iname, other_iname)
544+
for iname, other_iname in zip(
545+
expr.swept_inames,
546+
other.swept_inames))
547+
)
548+
549+
def map_tagged_variable(self, expr, other) -> bool:
550+
return (
551+
expr.name == other.name
552+
and all(tag == other_tag
553+
for tag, other_tag in zip(expr.tags, other.tags))
554+
)
555+
556+
def map_type_cast(self, expr, other) -> bool:
557+
return (
558+
expr.type == other.type
559+
and self.rec(expr.child, other.child))
560+
561+
def map_fortran_division(self, expr, other) -> bool:
562+
return self.map_quotient(expr, other)
563+
504564
# }}}
505565

506566

@@ -514,15 +574,18 @@ def stringifier(self):
514574
def make_stringifier(self, originating_stringifier=None):
515575
return StringifyMapper()
516576

577+
def make_equality_mapper(self):
578+
return EqualityMapper()
579+
517580

518581
class Literal(LoopyExpressionBase):
519582
"""A literal to be used during code generation.
520583
521584
.. note::
522585
523586
Only used in the output of
524-
:mod:`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper` (and
525-
similar mappers). Not for use in Loopy source representation.
587+
:class:`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper`
588+
(and similar mappers). Not for use in :mod:`loopy` source representation.
526589
"""
527590

528591
def __init__(self, s):
@@ -542,8 +605,8 @@ class ArrayLiteral(LoopyExpressionBase):
542605
.. note::
543606
544607
Only used in the output of
545-
:mod:`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper` (and
546-
similar mappers). Not for use in Loopy source representation.
608+
:class:`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper`
609+
(and similar mappers). Not for use in :mod:`loopy` source representation.
547610
"""
548611

549612
def __init__(self, children):
@@ -572,8 +635,8 @@ class GroupHardwareAxisIndex(HardwareAxisIndex):
572635
.. note::
573636
574637
Only used in the output of
575-
:mod:`loopy.target.c.expression.ExpressionToCExpressionMapper` (and
576-
similar mappers). Not for use in Loopy source representation.
638+
:class:`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper`
639+
(and similar mappers). Not for use in :mod:`loopy` source representation.
577640
"""
578641
mapper_method = "map_group_hw_index"
579642

@@ -583,8 +646,8 @@ class LocalHardwareAxisIndex(HardwareAxisIndex):
583646
.. note::
584647
585648
Only used in the output of
586-
:mod:`loopy.target.c.expression.ExpressionToCExpressionMapper` (and
587-
similar mappers). Not for use in Loopy source representation.
649+
:class:`loopy.target.c.expression.ExpressionToCExpressionMapper` (and
650+
similar mappers). Not for use in :mod:`loopy` source representation.
588651
"""
589652
mapper_method = "map_local_hw_index"
590653

@@ -791,12 +854,6 @@ def __getinitargs__(self):
791854
def get_hash(self):
792855
return hash((self.__class__, self.operation, self.inames, self.expr))
793856

794-
def is_equal(self, other):
795-
return (other.__class__ == self.__class__
796-
and other.operation == self.operation
797-
and other.inames == self.inames
798-
and other.expr == self.expr)
799-
800857
@property
801858
def is_tuple_typed(self):
802859
return self.operation.arg_count > 1
@@ -994,14 +1051,6 @@ def __getinitargs__(self):
9941051
def get_hash(self):
9951052
return hash((self.__class__, self.swept_inames, self.subscript))
9961053

997-
def is_equal(self, other):
998-
"""
999-
Returns *True* iff the sub-array refs have identical expressions.
1000-
"""
1001-
return (other.__class__ == self.__class__
1002-
and other.subscript == self.subscript
1003-
and other.swept_inames == self.swept_inames)
1004-
10051054
def make_stringifier(self, originating_stringifier=None):
10061055
return StringifyMapper()
10071056

requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ git+https://github.com/inducer/pytools.git#egg=pytools >= 2021.1
22
git+https://github.com/inducer/islpy.git#egg=islpy
33
git+https://github.com/inducer/cgen.git#egg=cgen
44
git+https://github.com/inducer/pyopencl.git#egg=pyopencl
5-
git+https://github.com/inducer/pymbolic.git#egg=pymbolic
5+
git+https://github.com/alexfikl/pymbolic.git@equality-mapper#egg=pymbolic
66
git+https://github.com/inducer/genpy.git#egg=genpy
77
git+https://github.com/inducer/codepy.git#egg=codepy
88

0 commit comments

Comments
 (0)