27
27
28
28
from typing import TYPE_CHECKING , Any , Callable , Union
29
29
30
- from pytools import memoize_method
31
-
32
30
from pytato .array import (
33
31
AbstractResultWithNamedArrays ,
34
32
AdvancedIndexInContiguousAxes ,
@@ -83,19 +81,23 @@ class EqualityComparer:
83
81
more on this.
84
82
"""
85
83
def __init__ (self ) -> None :
84
+ # Uses the same cache for both arrays and functions
86
85
self ._cache : dict [tuple [int , int ], bool ] = {}
87
86
88
- def rec (self , expr1 : ArrayOrNames , expr2 : Any ) -> bool :
87
+ def rec (self , expr1 : ArrayOrNames | FunctionDefinition , expr2 : Any ) -> bool :
89
88
cache_key = id (expr1 ), id (expr2 )
90
89
try :
91
90
return self ._cache [cache_key ]
92
91
except KeyError :
93
-
94
- method : Callable [[ Array | AbstractResultWithNamedArrays , Any ],
95
- bool ]
92
+ method : Callable [
93
+ [ Array | AbstractResultWithNamedArrays | FunctionDefinition , Any ],
94
+ bool ]
96
95
97
96
try :
98
- method = getattr (self , expr1 ._mapper_method )
97
+ method = (
98
+ getattr (self , expr1 ._mapper_method )
99
+ if isinstance (expr1 , (Array , AbstractResultWithNamedArrays ))
100
+ else self .map_function_definition )
99
101
except AttributeError :
100
102
if isinstance (expr1 , Array ):
101
103
result = self .handle_unsupported_array (expr1 , expr2 )
@@ -293,7 +295,6 @@ def map_distributed_recv(self, expr1: DistributedRecv, expr2: Any) -> bool:
293
295
and expr1 .tags == expr2 .tags
294
296
)
295
297
296
- @memoize_method
297
298
def map_function_definition (self , expr1 : FunctionDefinition , expr2 : Any
298
299
) -> bool :
299
300
return (expr1 .__class__ is expr2 .__class__
@@ -307,7 +308,7 @@ def map_function_definition(self, expr1: FunctionDefinition, expr2: Any
307
308
308
309
def map_call (self , expr1 : Call , expr2 : Any ) -> bool :
309
310
return (expr1 .__class__ is expr2 .__class__
310
- and self .map_function_definition (expr1 .function , expr2 .function )
311
+ and self .rec (expr1 .function , expr2 .function )
311
312
and frozenset (expr1 .bindings ) == frozenset (expr2 .bindings )
312
313
and all (self .rec (bnd ,
313
314
expr2 .bindings [name ])
0 commit comments