Skip to content

Commit 0c73b20

Browse files
committed
avoid traversing functions multiple times
1 parent 84c279c commit 0c73b20

File tree

12 files changed

+219
-112
lines changed

12 files changed

+219
-112
lines changed

pytato/analysis/__init__.py

+7-10
Original file line numberDiff line numberDiff line change
@@ -404,14 +404,14 @@ class NodeCountMapper(CachedWalkMapper):
404404
The number of nodes.
405405
"""
406406

407-
def __init__(self) -> None:
408-
super().__init__()
407+
def __init__(self, _visited_functions: set[Any] | None = None) -> None:
408+
super().__init__(_visited_functions=_visited_functions)
409409
self.count = 0
410410

411411
def get_cache_key(self, expr: ArrayOrNames) -> int:
412412
return id(expr)
413413

414-
def get_func_def_cache_key(self, expr: FunctionDefinition) -> int:
414+
def get_function_definition_cache_key(self, expr: FunctionDefinition) -> int:
415415
return id(expr)
416416

417417
def post_visit(self, expr: Any) -> None:
@@ -444,28 +444,25 @@ class CallSiteCountMapper(CachedWalkMapper):
444444
The number of nodes.
445445
"""
446446

447-
def __init__(self) -> None:
448-
super().__init__()
447+
def __init__(self, _visited_functions: set[Any] | None = None) -> None:
448+
super().__init__(_visited_functions=_visited_functions)
449449
self.count = 0
450450

451451
def get_cache_key(self, expr: ArrayOrNames) -> int:
452452
return id(expr)
453453

454-
def get_func_def_cache_key(self, expr: FunctionDefinition) -> int:
454+
def get_function_definition_cache_key(self, expr: FunctionDefinition) -> int:
455455
return id(expr)
456456

457457
def map_function_definition(self, expr: FunctionDefinition) -> None:
458-
cache_key = self.get_func_def_cache_key(expr)
459-
if not self.visit(expr) or cache_key in self._visited_functions:
458+
if not self.visit(expr):
460459
return
461460

462461
new_mapper = self.clone_for_callee(expr)
463462
for subexpr in expr.returns.values():
464463
new_mapper(subexpr)
465464
self.count += new_mapper.count
466465

467-
self._visited_functions.add(cache_key)
468-
469466
self.post_visit(expr)
470467

471468
def post_visit(self, expr: Any) -> None:

pytato/codegen.py

+11-8
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
"""
2525

2626
import dataclasses
27-
from typing import Any, Mapping, Tuple
27+
from typing import Any, Hashable, Mapping, Tuple
2828

2929
from immutabledict import immutabledict
3030

@@ -118,10 +118,13 @@ class CodeGenPreprocessor(ToIndexLambdaMixin, CopyMapper): # type: ignore[misc]
118118
====================================== =====================================
119119
"""
120120

121-
def __init__(self, target: Target,
122-
kernels_seen: dict[str, lp.LoopKernel] | None = None
123-
) -> None:
124-
super().__init__()
121+
def __init__(
122+
self,
123+
target: Target,
124+
kernels_seen: dict[str, lp.LoopKernel] | None = None,
125+
_function_cache: dict[Hashable, FunctionDefinition] | None = None
126+
) -> None:
127+
super().__init__(_function_cache=_function_cache)
125128
self.bound_arguments: dict[str, DataInterface] = {}
126129
self.var_name_gen: UniqueNameGenerator = UniqueNameGenerator()
127130
self.target = target
@@ -247,14 +250,14 @@ def normalize_outputs(
247250

248251
@optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True)
249252
class NamesValidityChecker(CachedWalkMapper):
250-
def __init__(self) -> None:
253+
def __init__(self, _visited_functions: set[Any] | None = None) -> None:
251254
self.name_to_input: dict[str, InputArgumentBase] = {}
252-
super().__init__()
255+
super().__init__(_visited_functions=_visited_functions)
253256

254257
def get_cache_key(self, expr: ArrayOrNames) -> int:
255258
return id(expr)
256259

257-
def get_func_def_cache_key(self, expr: FunctionDefinition) -> int:
260+
def get_function_definition_cache_key(self, expr: FunctionDefinition) -> int:
258261
return id(expr)
259262

260263
def post_visit(self, expr: Any) -> None:

pytato/distributed/partition.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -292,8 +292,9 @@ def __init__(self,
292292
recvd_ary_to_name: Mapping[Array, str],
293293
sptpo_ary_to_name: Mapping[Array, str],
294294
name_to_output: Mapping[str, Array],
295+
_function_cache: dict[Hashable, FunctionDefinition] | None = None,
295296
) -> None:
296-
super().__init__()
297+
super().__init__(_function_cache=_function_cache)
297298

298299
self.recvd_ary_to_name = recvd_ary_to_name
299300
self.sptpo_ary_to_name = sptpo_ary_to_name
@@ -307,7 +308,7 @@ def clone_for_callee(
307308
self, function: FunctionDefinition) -> _DistributedInputReplacer:
308309
# Function definitions aren't allowed to contain receives,
309310
# stored arrays promoted to part outputs, or part outputs
310-
return type(self)({}, {}, {})
311+
return type(self)({}, {}, {}, _function_cache=self._function_cache)
311312

312313
def map_placeholder(self, expr: Placeholder) -> Placeholder:
313314
self.user_input_names.add(expr.name)

pytato/distributed/verify.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -142,8 +142,8 @@ class MissingRecvError(DistributedPartitionVerificationError):
142142

143143
@optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True)
144144
class _SeenNodesWalkMapper(CachedWalkMapper):
145-
def __init__(self) -> None:
146-
super().__init__()
145+
def __init__(self, _visited_functions: set[Any] | None = None) -> None:
146+
super().__init__(_visited_functions=_visited_functions)
147147
self.seen_nodes: set[ArrayOrNames] = set()
148148

149149
def get_cache_key(self, expr: ArrayOrNames) -> int:

pytato/equality.py

+10-9
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,6 @@
2727

2828
from typing import TYPE_CHECKING, Any, Callable, Union
2929

30-
from pytools import memoize_method
31-
3230
from pytato.array import (
3331
AbstractResultWithNamedArrays,
3432
AdvancedIndexInContiguousAxes,
@@ -83,19 +81,23 @@ class EqualityComparer:
8381
more on this.
8482
"""
8583
def __init__(self) -> None:
84+
# Uses the same cache for both arrays and functions
8685
self._cache: dict[tuple[int, int], bool] = {}
8786

88-
def rec(self, expr1: ArrayOrNames, expr2: Any) -> bool:
87+
def rec(self, expr1: ArrayOrNames | FunctionDefinition, expr2: Any) -> bool:
8988
cache_key = id(expr1), id(expr2)
9089
try:
9190
return self._cache[cache_key]
9291
except KeyError:
93-
94-
method: Callable[[Array | AbstractResultWithNamedArrays, Any],
95-
bool]
92+
method: Callable[
93+
[Array | AbstractResultWithNamedArrays | FunctionDefinition, Any],
94+
bool]
9695

9796
try:
98-
method = getattr(self, expr1._mapper_method)
97+
method = (
98+
getattr(self, expr1._mapper_method)
99+
if isinstance(expr1, (Array, AbstractResultWithNamedArrays))
100+
else self.map_function_definition)
99101
except AttributeError:
100102
if isinstance(expr1, Array):
101103
result = self.handle_unsupported_array(expr1, expr2)
@@ -293,7 +295,6 @@ def map_distributed_recv(self, expr1: DistributedRecv, expr2: Any) -> bool:
293295
and expr1.tags == expr2.tags
294296
)
295297

296-
@memoize_method
297298
def map_function_definition(self, expr1: FunctionDefinition, expr2: Any
298299
) -> bool:
299300
return (expr1.__class__ is expr2.__class__
@@ -307,7 +308,7 @@ def map_function_definition(self, expr1: FunctionDefinition, expr2: Any
307308

308309
def map_call(self, expr1: Call, expr2: Any) -> bool:
309310
return (expr1.__class__ is expr2.__class__
310-
and self.map_function_definition(expr1.function, expr2.function)
311+
and self.rec(expr1.function, expr2.function)
311312
and frozenset(expr1.bindings) == frozenset(expr2.bindings)
312313
and all(self.rec(bnd,
313314
expr2.bindings[name])

pytato/stringifier.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,6 @@
3131
import numpy as np
3232
from immutabledict import immutabledict
3333

34-
from pytools import memoize_method
35-
3634
from pytato.array import (
3735
Array,
3836
Axis,
@@ -68,6 +66,7 @@ def __init__(self,
6866
self.truncation_depth = truncation_depth
6967
self.truncation_string = truncation_string
7068

69+
# Uses the same cache for both arrays and functions
7170
self._cache: dict[tuple[int, int], str] = {}
7271

7372
def rec(self, expr: Any, depth: int) -> str:
@@ -79,6 +78,15 @@ def rec(self, expr: Any, depth: int) -> str:
7978
self._cache[cache_key] = result
8079
return result # type: ignore[no-any-return]
8180

81+
def rec_function_definition(self, expr: FunctionDefinition, depth: int) -> str:
82+
cache_key = (id(expr), depth)
83+
try:
84+
return self._cache[cache_key]
85+
except KeyError:
86+
result = super().rec_function_definition(expr, depth)
87+
self._cache[cache_key] = result
88+
return result # type: ignore[no-any-return]
89+
8290
def __call__(self, expr: Any, depth: int = 0) -> str:
8391
return self.rec(expr, depth)
8492

@@ -168,7 +176,6 @@ def _get_field_val(field: str) -> str:
168176
for field in attrs.fields(type(expr)))
169177
+ ")")
170178

171-
@memoize_method
172179
def map_function_definition(self, expr: FunctionDefinition, depth: int) -> str:
173180
if depth > self.truncation_depth:
174181
return self.truncation_string
@@ -191,7 +198,7 @@ def map_call(self, expr: Call, depth: int) -> str:
191198

192199
def _get_field_val(field: str) -> str:
193200
if field == "function":
194-
return self.map_function_definition(expr.function, depth+1)
201+
return self.rec_function_definition(expr.function, depth+1)
195202
else:
196203
return self.rec(getattr(expr, field), depth+1)
197204

pytato/target/python/numpy_like.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
SizeParam,
6565
Stack,
6666
)
67+
from pytato.function import FunctionDefinition
6768
from pytato.raising import BinaryOpType, C99CallOp
6869
from pytato.reductions import (
6970
AllReductionOperation,
@@ -193,7 +194,7 @@ def _is_slice_trivial(slice_: NormalizedSlice,
193194
}
194195

195196

196-
class NumpyCodegenMapper(CachedMapper[str]):
197+
class NumpyCodegenMapper(CachedMapper[str, FunctionDefinition]):
197198
"""
198199
.. note::
199200

0 commit comments

Comments
 (0)