Skip to content

Commit 8ea4f38

Browse files
committed
add EqualityMapper to follow pymbolic
1 parent ccc8cbf commit 8ea4f38

File tree

2 files changed

+68
-23
lines changed

2 files changed

+68
-23
lines changed

loopy/symbolic.py

+67-22
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@
5050
CSECachingMapperMixin,
5151
)
5252
import immutables
53+
from pymbolic.mapper.equality import (
54+
EqualityMapper as EqualityMapperBase)
5355
from pymbolic.mapper.evaluator import \
5456
CachedEvaluationMapper as EvaluationMapperBase
5557
from pymbolic.mapper.substitutor import \
@@ -502,6 +504,60 @@ def map_substitution(self, name, rule, arguments):
502504

503505
return self.rec(expr)
504506

507+
508+
class EqualityMapper(EqualityMapperBase):
509+
def map_loopy_function_identifier(self, expr, other) -> bool:
510+
return True
511+
512+
def map_reduction(self, expr, other) -> bool:
513+
return (
514+
expr.operation == other.operation
515+
and expr.allow_simultaneous == other.allow_simultaneous
516+
and self.rec(expr.expr, other.expr)
517+
and all(iname == other_iname
518+
for iname, other_iname in zip(expr.inames, other.inames)))
519+
520+
def map_group_hw_index(self, expr, other) -> bool:
521+
return expr.axis == other.axis
522+
523+
map_local_hw_index = map_group_hw_index
524+
525+
def map_linear_subscript(self, expr, other) -> bool:
526+
return (
527+
self.rec(expr.index, other.index)
528+
and self.rec(expr.aggregate, other.aggregate))
529+
530+
def map_rule_argument(self, expr, other) -> bool:
531+
return expr.index == other.index
532+
533+
def map_resolved_function(self, expr, other) -> bool:
534+
return self.rec(expr.function, other.function)
535+
536+
def map_sub_array_ref(self, expr, other) -> bool:
537+
return (
538+
len(expr.swept_inames) == len(other.swept_inames)
539+
and self.rec(expr.subscript, other.subscript)
540+
and all(self.rec(iname, other_iname)
541+
for iname, other_iname in zip(
542+
expr.swept_inames,
543+
other.swept_inames))
544+
)
545+
546+
def map_tagged_variable(self, expr, other) -> bool:
547+
return (
548+
expr.name == other.name
549+
and all(tag == other_tag
550+
for tag, other_tag in zip(expr.tags, other.tags))
551+
)
552+
553+
def map_type_cast(self, expr, other) -> bool:
554+
return (
555+
expr.type == other.type
556+
and self.rec(expr.child, other.child))
557+
558+
def map_fortran_division(self, expr, other) -> bool:
559+
return self.map_quotient(expr, other)
560+
505561
# }}}
506562

507563

@@ -515,15 +571,18 @@ def stringifier(self):
515571
def make_stringifier(self, originating_stringifier=None):
516572
return StringifyMapper()
517573

574+
def make_equality_mapper(self):
575+
return EqualityMapper()
576+
518577

519578
class Literal(LoopyExpressionBase):
520579
"""A literal to be used during code generation.
521580
522581
.. note::
523582
524583
Only used in the output of
525-
:mod:`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper` (and
526-
similar mappers). Not for use in Loopy source representation.
584+
:class:`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper`
585+
(and similar mappers). Not for use in :mod:`loopy` source representation.
527586
"""
528587

529588
def __init__(self, s):
@@ -543,8 +602,8 @@ class ArrayLiteral(LoopyExpressionBase):
543602
.. note::
544603
545604
Only used in the output of
546-
:mod:`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper` (and
547-
similar mappers). Not for use in Loopy source representation.
605+
:class:`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper`
606+
(and similar mappers). Not for use in :mod:`loopy` source representation.
548607
"""
549608

550609
def __init__(self, children):
@@ -573,8 +632,8 @@ class GroupHardwareAxisIndex(HardwareAxisIndex):
573632
.. note::
574633
575634
Only used in the output of
576-
:mod:`loopy.target.c.expression.ExpressionToCExpressionMapper` (and
577-
similar mappers). Not for use in Loopy source representation.
635+
:class:`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper`
636+
(and similar mappers). Not for use in :mod:`loopy` source representation.
578637
"""
579638
mapper_method = "map_group_hw_index"
580639

@@ -584,8 +643,8 @@ class LocalHardwareAxisIndex(HardwareAxisIndex):
584643
.. note::
585644
586645
Only used in the output of
587-
:mod:`loopy.target.c.expression.ExpressionToCExpressionMapper` (and
588-
similar mappers). Not for use in Loopy source representation.
646+
:class:`loopy.target.c.expression.ExpressionToCExpressionMapper` (and
647+
similar mappers). Not for use in :mod:`loopy` source representation.
589648
"""
590649
mapper_method = "map_local_hw_index"
591650

@@ -792,12 +851,6 @@ def __getinitargs__(self):
792851
def get_hash(self):
793852
return hash((self.__class__, self.operation, self.inames, self.expr))
794853

795-
def is_equal(self, other):
796-
return (other.__class__ == self.__class__
797-
and other.operation == self.operation
798-
and other.inames == self.inames
799-
and other.expr == self.expr)
800-
801854
@property
802855
def is_tuple_typed(self):
803856
return self.operation.arg_count > 1
@@ -994,14 +1047,6 @@ def __getinitargs__(self):
9941047
def get_hash(self):
9951048
return hash((self.__class__, self.swept_inames, self.subscript))
9961049

997-
def is_equal(self, other):
998-
"""
999-
Returns *True* iff the sub-array refs have identical expressions.
1000-
"""
1001-
return (other.__class__ == self.__class__
1002-
and other.subscript == self.subscript
1003-
and other.swept_inames == self.swept_inames)
1004-
10051050
def make_stringifier(self, originating_stringifier=None):
10061051
return StringifyMapper()
10071052

requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ git+https://github.com/inducer/pytools.git#egg=pytools >= 2021.1
22
git+https://github.com/inducer/islpy.git#egg=islpy
33
git+https://github.com/inducer/cgen.git#egg=cgen
44
git+https://github.com/inducer/pyopencl.git#egg=pyopencl
5-
git+https://github.com/inducer/pymbolic.git#egg=pymbolic
5+
git+https://github.com/alexfikl/pymbolic.git@equality-mapper#egg=pymbolic
66
git+https://github.com/inducer/genpy.git#egg=genpy
77
git+https://github.com/inducer/codepy.git#egg=codepy
88

0 commit comments

Comments
 (0)