Skip to content

Commit 467550f

Browse files
committed
add EqualityMapper to follow pymbolic
1 parent a24370e commit 467550f

File tree

2 files changed

+68
-23
lines changed

2 files changed

+68
-23
lines changed

loopy/symbolic.py

+67-22
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@
4343
CSECachingMapperMixin,
4444
)
4545
import immutables
46+
from pymbolic.mapper.equality import (
47+
EqualityMapper as EqualityMapperBase)
4648
from pymbolic.mapper.evaluator import \
4749
CachedEvaluationMapper as EvaluationMapperBase
4850
from pymbolic.mapper.substitutor import \
@@ -501,6 +503,60 @@ def map_substitution(self, name, rule, arguments):
501503

502504
return self.rec(expr)
503505

506+
507+
class EqualityMapper(EqualityMapperBase):
508+
def map_loopy_function_identifier(self, expr, other) -> bool:
509+
return True
510+
511+
def map_reduction(self, expr, other) -> bool:
512+
return (
513+
expr.operation == other.operation
514+
and expr.allow_simultaneous == other.allow_simultaneous
515+
and self.rec(expr.expr, other.expr)
516+
and all(iname == other_iname
517+
for iname, other_iname in zip(expr.inames, other.inames)))
518+
519+
def map_group_hw_index(self, expr, other) -> bool:
520+
return expr.axis == other.axis
521+
522+
map_local_hw_index = map_group_hw_index
523+
524+
def map_linear_subscript(self, expr, other) -> bool:
525+
return (
526+
self.rec(expr.index, other.index)
527+
and self.rec(expr.aggregate, other.aggregate))
528+
529+
def map_rule_argument(self, expr, other) -> bool:
530+
return expr.index == other.index
531+
532+
def map_resolved_function(self, expr, other) -> bool:
533+
return self.rec(expr.function, other.function)
534+
535+
def map_sub_array_ref(self, expr, other) -> bool:
536+
return (
537+
len(expr.swept_inames) == len(other.swept_inames)
538+
and self.rec(expr.subscript, other.subscript)
539+
and all(self.rec(iname, other_iname)
540+
for iname, other_iname in zip(
541+
expr.swept_inames,
542+
other.swept_inames))
543+
)
544+
545+
def map_tagged_variable(self, expr, other) -> bool:
546+
return (
547+
expr.name == other.name
548+
and all(tag == other_tag
549+
for tag, other_tag in zip(expr.tags, other.tags))
550+
)
551+
552+
def map_type_cast(self, expr, other) -> bool:
553+
return (
554+
expr.type == other.type
555+
and self.rec(expr.child, other.child))
556+
557+
def map_fortran_division(self, expr, other) -> bool:
558+
return self.map_quotient(expr, other)
559+
504560
# }}}
505561

506562

@@ -514,15 +570,18 @@ def stringifier(self):
514570
def make_stringifier(self, originating_stringifier=None):
515571
return StringifyMapper()
516572

573+
def make_equality_mapper(self):
574+
return EqualityMapper()
575+
517576

518577
class Literal(LoopyExpressionBase):
519578
"""A literal to be used during code generation.
520579
521580
.. note::
522581
523582
Only used in the output of
524-
:mod:`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper` (and
525-
similar mappers). Not for use in Loopy source representation.
583+
:class:`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper`
584+
(and similar mappers). Not for use in :mod:`loopy` source representation.
526585
"""
527586

528587
def __init__(self, s):
@@ -542,8 +601,8 @@ class ArrayLiteral(LoopyExpressionBase):
542601
.. note::
543602
544603
Only used in the output of
545-
:mod:`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper` (and
546-
similar mappers). Not for use in Loopy source representation.
604+
:class:`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper`
605+
(and similar mappers). Not for use in :mod:`loopy` source representation.
547606
"""
548607

549608
def __init__(self, children):
@@ -572,8 +631,8 @@ class GroupHardwareAxisIndex(HardwareAxisIndex):
572631
.. note::
573632
574633
Only used in the output of
575-
:mod:`loopy.target.c.expression.ExpressionToCExpressionMapper` (and
576-
similar mappers). Not for use in Loopy source representation.
634+
:class:`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper`
635+
(and similar mappers). Not for use in :mod:`loopy` source representation.
577636
"""
578637
mapper_method = "map_group_hw_index"
579638

@@ -583,8 +642,8 @@ class LocalHardwareAxisIndex(HardwareAxisIndex):
583642
.. note::
584643
585644
Only used in the output of
586-
:mod:`loopy.target.c.expression.ExpressionToCExpressionMapper` (and
587-
similar mappers). Not for use in Loopy source representation.
645+
:class:`loopy.target.c.expression.ExpressionToCExpressionMapper` (and
646+
similar mappers). Not for use in :mod:`loopy` source representation.
588647
"""
589648
mapper_method = "map_local_hw_index"
590649

@@ -791,12 +850,6 @@ def __getinitargs__(self):
791850
def get_hash(self):
792851
return hash((self.__class__, self.operation, self.inames, self.expr))
793852

794-
def is_equal(self, other):
795-
return (other.__class__ == self.__class__
796-
and other.operation == self.operation
797-
and other.inames == self.inames
798-
and other.expr == self.expr)
799-
800853
@property
801854
def is_tuple_typed(self):
802855
return self.operation.arg_count > 1
@@ -994,14 +1047,6 @@ def __getinitargs__(self):
9941047
def get_hash(self):
9951048
return hash((self.__class__, self.swept_inames, self.subscript))
9961049

997-
def is_equal(self, other):
998-
"""
999-
Returns *True* iff the sub-array refs have identical expressions.
1000-
"""
1001-
return (other.__class__ == self.__class__
1002-
and other.subscript == self.subscript
1003-
and other.swept_inames == self.swept_inames)
1004-
10051050
def make_stringifier(self, originating_stringifier=None):
10061051
return StringifyMapper()
10071052

requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ git+https://github.com/inducer/pytools.git#egg=pytools >= 2021.1
22
git+https://github.com/inducer/islpy.git#egg=islpy
33
git+https://github.com/inducer/cgen.git#egg=cgen
44
git+https://github.com/inducer/pyopencl.git#egg=pyopencl
5-
git+https://github.com/inducer/pymbolic.git#egg=pymbolic
5+
git+https://github.com/alexfikl/pymbolic.git@equality-mapper#egg=pymbolic
66
git+https://github.com/inducer/genpy.git#egg=genpy
77
git+https://github.com/inducer/codepy.git#egg=codepy
88

0 commit comments

Comments
 (0)