Skip to content

Commit fda3bdc

Browse files
committed
add EqualityMapper to follow pymbolic
1 parent befa5cb commit fda3bdc

File tree

2 files changed

+85
-41
lines changed

2 files changed

+85
-41
lines changed

loopy/symbolic.py

+84-40
Original file line numberDiff line numberDiff line change
@@ -40,25 +40,28 @@
4040
CallbackMapper as CallbackMapperBase,
4141
CSECachingMapperMixin,
4242
)
43-
from pymbolic.mapper.evaluator import \
44-
EvaluationMapper as EvaluationMapperBase
45-
from pymbolic.mapper.substitutor import \
46-
SubstitutionMapper as SubstitutionMapperBase
47-
from pymbolic.mapper.stringifier import \
48-
StringifyMapper as StringifyMapperBase
49-
from pymbolic.mapper.dependency import \
50-
DependencyMapper as DependencyMapperBase
51-
from pymbolic.mapper.coefficient import \
52-
CoefficientCollector as CoefficientCollectorBase
53-
from pymbolic.mapper.unifier import UnidirectionalUnifier \
54-
as UnidirectionalUnifierBase
55-
from pymbolic.mapper.constant_folder import \
56-
ConstantFoldingMapper as ConstantFoldingMapperBase
43+
from pymbolic.mapper.equality import (
44+
EqualityMapper as EqualityMapperBase)
45+
from pymbolic.mapper.evaluator import (
46+
EvaluationMapper as EvaluationMapperBase)
47+
from pymbolic.mapper.substitutor import (
48+
SubstitutionMapper as SubstitutionMapperBase)
49+
from pymbolic.mapper.stringifier import (
50+
StringifyMapper as StringifyMapperBase)
51+
from pymbolic.mapper.dependency import (
52+
DependencyMapper as DependencyMapperBase)
53+
from pymbolic.mapper.coefficient import (
54+
CoefficientCollector as CoefficientCollectorBase)
55+
from pymbolic.mapper.unifier import (
56+
UnidirectionalUnifier as UnidirectionalUnifierBase)
57+
from pymbolic.mapper.constant_folder import (
58+
ConstantFoldingMapper as ConstantFoldingMapperBase)
5759

5860
from pymbolic.parser import Parser as ParserBase
5961
from loopy.diagnostic import LoopyError
60-
from loopy.diagnostic import (ExpressionToAffineConversionError,
61-
UnableToDetermineAccessRangeError)
62+
from loopy.diagnostic import (
63+
ExpressionToAffineConversionError,
64+
UnableToDetermineAccessRangeError)
6265

6366

6467
import islpy as isl
@@ -114,8 +117,11 @@ def map_literal(self, expr, *args, **kwargs):
114117
return expr
115118

116119
def map_array_literal(self, expr, *args, **kwargs):
117-
return type(expr)(tuple(self.rec(ch, *args, **kwargs)
118-
for ch in expr.children))
120+
children = [self.rec(ch, *args, **kwargs) for ch in expr.children]
121+
if all(ch is orig for ch, orig in zip(children, expr.children)):
122+
return expr
123+
124+
return type(expr)(tuple(children))
119125

120126
def map_group_hw_index(self, expr, *args, **kwargs):
121127
return expr
@@ -484,6 +490,55 @@ def map_substitution(self, name, rule, arguments):
484490

485491
return self.rec(expr)
486492

493+
494+
class EqualityMapper(EqualityMapperBase):
495+
def map_loopy_function_identifier(self, expr, other) -> bool:
496+
return True
497+
498+
def map_reduction(self, expr, other) -> bool:
499+
return (
500+
expr.operation == other.operation
501+
and expr.allow_simultaneous == other.allow_simultaneous
502+
and self.rec(expr.expr, other.expr)
503+
and all(iname == other_iname
504+
for iname, other_iname in zip(expr.inames, other.inames)))
505+
506+
def map_group_hw_index(self, expr, other) -> bool:
507+
return expr.axis == other.axis
508+
509+
map_local_hw_index = map_group_hw_index
510+
511+
def map_rule_argument(self, expr, other) -> bool:
512+
return expr.index == other.index
513+
514+
def map_resolved_function(self, expr, other) -> bool:
515+
return self.rec(expr.function, other.function)
516+
517+
def map_sub_array_ref(self, expr, other) -> bool:
518+
return (
519+
len(expr.swept_inames) == len(other.swept_inames)
520+
and self.rec(expr.subscript, other.subscript)
521+
and all(self.rec(iname, other_iname)
522+
for iname, other_iname in zip(
523+
expr.swept_inames,
524+
other.swept_inames))
525+
)
526+
527+
def map_tagged_variable(self, expr, other) -> bool:
528+
return (
529+
expr.name == other.name
530+
and all(tag == other_tag
531+
for tag, other_tag in zip(expr.tags, other.tags))
532+
)
533+
534+
def map_type_cast(self, expr, other) -> bool:
535+
return (
536+
expr.type == other.type
537+
and self.rec(expr.child, other.child))
538+
539+
def map_fortran_division(self, expr, other) -> bool:
540+
return self.map_quotient(expr, other)
541+
487542
# }}}
488543

489544

@@ -497,15 +552,18 @@ def stringifier(self):
497552
def make_stringifier(self, originating_stringifier=None):
498553
return StringifyMapper()
499554

555+
def make_equality_mapper(self):
556+
return EqualityMapper()
557+
500558

501559
class Literal(LoopyExpressionBase):
502560
"""A literal to be used during code generation.
503561
504562
.. note::
505563
506564
Only used in the output of
507-
:mod:`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper` (and
508-
similar mappers). Not for use in Loopy source representation.
565+
:class:`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper`
566+
(and similar mappers). Not for use in :mod:`loopy` source representation.
509567
"""
510568

511569
def __init__(self, s):
@@ -525,8 +583,8 @@ class ArrayLiteral(LoopyExpressionBase):
525583
.. note::
526584
527585
Only used in the output of
528-
:mod:`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper` (and
529-
similar mappers). Not for use in Loopy source representation.
586+
:class:`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper`
587+
(and similar mappers). Not for use in :mod:`loopy` source representation.
530588
"""
531589

532590
def __init__(self, children):
@@ -555,8 +613,8 @@ class GroupHardwareAxisIndex(HardwareAxisIndex):
555613
.. note::
556614
557615
Only used in the output of
558-
:mod:`loopy.target.c.expression.ExpressionToCExpressionMapper` (and
559-
similar mappers). Not for use in Loopy source representation.
616+
:class:`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper`
617+
(and similar mappers). Not for use in :mod:`loopy` source representation.
560618
"""
561619
mapper_method = "map_group_hw_index"
562620

@@ -566,8 +624,8 @@ class LocalHardwareAxisIndex(HardwareAxisIndex):
566624
.. note::
567625
568626
Only used in the output of
569-
:mod:`loopy.target.c.expression.ExpressionToCExpressionMapper` (and
570-
similar mappers). Not for use in Loopy source representation.
627+
:class:`loopy.target.c.expression.ExpressionToCExpressionMapper` (and
628+
similar mappers). Not for use in :mod:`loopy` source representation.
571629
"""
572630
mapper_method = "map_local_hw_index"
573631

@@ -774,12 +832,6 @@ def __getinitargs__(self):
774832
def get_hash(self):
775833
return hash((self.__class__, self.operation, self.inames, self.expr))
776834

777-
def is_equal(self, other):
778-
return (other.__class__ == self.__class__
779-
and other.operation == self.operation
780-
and other.inames == self.inames
781-
and other.expr == self.expr)
782-
783835
@property
784836
def is_tuple_typed(self):
785837
return self.operation.arg_count > 1
@@ -977,14 +1029,6 @@ def __getinitargs__(self):
9771029
def get_hash(self):
9781030
return hash((self.__class__, self.swept_inames, self.subscript))
9791031

980-
def is_equal(self, other):
981-
"""
982-
Returns *True* iff the sub-array refs have identical expressions.
983-
"""
984-
return (other.__class__ == self.__class__
985-
and other.subscript == self.subscript
986-
and other.swept_inames == self.swept_inames)
987-
9881032
def make_stringifier(self, originating_stringifier=None):
9891033
return StringifyMapper()
9901034

requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ git+https://github.com/inducer/pytools.git#egg=pytools >= 2021.1
22
git+https://github.com/inducer/islpy.git#egg=islpy
33
git+https://github.com/inducer/cgen.git#egg=cgen
44
git+https://github.com/inducer/pyopencl.git#egg=pyopencl
5-
git+https://github.com/inducer/pymbolic.git#egg=pymbolic
5+
git+https://github.com/alexfikl/pymbolic.git@equality-mapper#egg=pymbolic
66
git+https://github.com/inducer/genpy.git#egg=genpy
77
git+https://github.com/inducer/codepy.git#egg=codepy
88

0 commit comments

Comments
 (0)