Skip to content

Commit 312e8ff

Browse files
committed
avoid traversing functions multiple times
1 parent 5f61931 commit 312e8ff

File tree

12 files changed

+252
-103
lines changed

12 files changed

+252
-103
lines changed

pytato/analysis/__init__.py

+33-9
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
from typing import TYPE_CHECKING, Any, Mapping
3030

3131
from pymbolic.mapper.optimize import optimize_mapper
32-
from pytools import memoize_method
3332

3433
from pytato.array import (
3534
Array,
@@ -46,7 +45,7 @@
4645
)
4746
from pytato.function import Call, FunctionDefinition, NamedCallResult
4847
from pytato.loopy import LoopyCall
49-
from pytato.transform import ArrayOrNames, CachedWalkMapper, Mapper
48+
from pytato.transform import ArrayOrNames, CachedWalkMapper, Mapper, _SelfMapper
5049

5150

5251
if TYPE_CHECKING:
@@ -410,16 +409,34 @@ class NodeCountMapper(CachedWalkMapper):
410409
Dictionary mapping node types to number of nodes of that type.
411410
"""
412411

413-
def __init__(self, count_duplicates: bool = False) -> None:
412+
def __init__(
413+
self,
414+
count_duplicates: bool = False,
415+
_visited_functions: set[Any] | None = None,
416+
) -> None:
417+
super().__init__(_visited_functions=_visited_functions)
418+
414419
from collections import defaultdict
415-
super().__init__()
416420
self.expr_type_counts: dict[type[Any], int] = defaultdict(int)
417421
self.count_duplicates = count_duplicates
418422

419423
def get_cache_key(self, expr: ArrayOrNames) -> int | ArrayOrNames:
420424
# Returns unique nodes only if count_duplicates is False
421425
return id(expr) if self.count_duplicates else expr
422426

427+
def get_function_definition_cache_key(
428+
self, expr: FunctionDefinition) -> int | FunctionDefinition:
429+
# Returns unique nodes only if count_duplicates is False
430+
return id(expr) if self.count_duplicates else expr
431+
432+
def clone_for_callee(
433+
self: _SelfMapper, function: FunctionDefinition) -> _SelfMapper:
434+
# type-ignore-reason: self.__init__ has a different function signature
435+
# than Mapper.__init__
436+
return type(self)(
437+
count_duplicates=self.count_duplicates, # type: ignore[attr-defined]
438+
_visited_functions=self._visited_functions) # type: ignore[call-arg,attr-defined]
439+
423440
def post_visit(self, expr: Any) -> None:
424441
if not isinstance(expr, DictOfNamedArrays):
425442
self.expr_type_counts[type(expr)] += 1
@@ -485,15 +502,20 @@ class NodeMultiplicityMapper(CachedWalkMapper):
485502
486503
.. autoattribute:: expr_multiplicity_counts
487504
"""
488-
def __init__(self) -> None:
505+
def __init__(self, _visited_functions: set[Any] | None = None) -> None:
506+
super().__init__(_visited_functions=_visited_functions)
507+
489508
from collections import defaultdict
490-
super().__init__()
491509
self.expr_multiplicity_counts: dict[Array, int] = defaultdict(int)
492510

493511
def get_cache_key(self, expr: ArrayOrNames) -> int:
494512
# Returns each node, including nodes that are duplicates
495513
return id(expr)
496514

515+
def get_function_definition_cache_key(self, expr: FunctionDefinition) -> int:
516+
# Returns each node, including nodes that are duplicates
517+
return id(expr)
518+
497519
def post_visit(self, expr: Any) -> None:
498520
if not isinstance(expr, DictOfNamedArrays):
499521
self.expr_multiplicity_counts[expr] += 1
@@ -527,14 +549,16 @@ class CallSiteCountMapper(CachedWalkMapper):
527549
The number of nodes.
528550
"""
529551

530-
def __init__(self) -> None:
531-
super().__init__()
552+
def __init__(self, _visited_functions: set[Any] | None = None) -> None:
553+
super().__init__(_visited_functions=_visited_functions)
532554
self.count = 0
533555

534556
def get_cache_key(self, expr: ArrayOrNames) -> int:
535557
return id(expr)
536558

537-
@memoize_method
559+
def get_function_definition_cache_key(self, expr: FunctionDefinition) -> int:
560+
return id(expr)
561+
538562
def map_function_definition(self, expr: FunctionDefinition) -> None:
539563
if not self.visit(expr):
540564
return

pytato/codegen.py

+14-8
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
"""
2525

2626
import dataclasses
27-
from typing import Any, Mapping, Tuple
27+
from typing import Any, Hashable, Mapping, Tuple
2828

2929
from immutabledict import immutabledict
3030

@@ -42,7 +42,7 @@
4242
SizeParam,
4343
make_dict_of_named_arrays,
4444
)
45-
from pytato.function import NamedCallResult
45+
from pytato.function import FunctionDefinition, NamedCallResult
4646
from pytato.loopy import LoopyCall
4747
from pytato.scalar_expr import IntegralScalarExpression
4848
from pytato.target import Target
@@ -118,10 +118,13 @@ class CodeGenPreprocessor(ToIndexLambdaMixin, CopyMapper): # type: ignore[misc]
118118
====================================== =====================================
119119
"""
120120

121-
def __init__(self, target: Target,
122-
kernels_seen: dict[str, lp.LoopKernel] | None = None
123-
) -> None:
124-
super().__init__()
121+
def __init__(
122+
self,
123+
target: Target,
124+
kernels_seen: dict[str, lp.LoopKernel] | None = None,
125+
_function_cache: dict[Hashable, FunctionDefinition] | None = None
126+
) -> None:
127+
super().__init__(_function_cache=_function_cache)
125128
self.bound_arguments: dict[str, DataInterface] = {}
126129
self.var_name_gen: UniqueNameGenerator = UniqueNameGenerator()
127130
self.target = target
@@ -247,13 +250,16 @@ def normalize_outputs(
247250

248251
@optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True)
249252
class NamesValidityChecker(CachedWalkMapper):
250-
def __init__(self) -> None:
253+
def __init__(self, _visited_functions: set[Any] | None = None) -> None:
251254
self.name_to_input: dict[str, InputArgumentBase] = {}
252-
super().__init__()
255+
super().__init__(_visited_functions=_visited_functions)
253256

254257
def get_cache_key(self, expr: ArrayOrNames) -> int:
255258
return id(expr)
256259

260+
def get_function_definition_cache_key(self, expr: FunctionDefinition) -> int:
261+
return id(expr)
262+
257263
def post_visit(self, expr: Any) -> None:
258264
if isinstance(expr, (Placeholder, SizeParam, DataWrapper)):
259265
if expr.name is not None:

pytato/distributed/partition.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -292,8 +292,9 @@ def __init__(self,
292292
recvd_ary_to_name: Mapping[Array, str],
293293
sptpo_ary_to_name: Mapping[Array, str],
294294
name_to_output: Mapping[str, Array],
295+
_function_cache: dict[Hashable, FunctionDefinition] | None = None,
295296
) -> None:
296-
super().__init__()
297+
super().__init__(_function_cache=_function_cache)
297298

298299
self.recvd_ary_to_name = recvd_ary_to_name
299300
self.sptpo_ary_to_name = sptpo_ary_to_name
@@ -307,7 +308,7 @@ def clone_for_callee(
307308
self, function: FunctionDefinition) -> _DistributedInputReplacer:
308309
# Function definitions aren't allowed to contain receives,
309310
# stored arrays promoted to part outputs, or part outputs
310-
return type(self)({}, {}, {})
311+
return type(self)({}, {}, {}, _function_cache=self._function_cache)
311312

312313
def map_placeholder(self, expr: Placeholder) -> Placeholder:
313314
self.user_input_names.add(expr.name)

pytato/distributed/verify.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -142,8 +142,8 @@ class MissingRecvError(DistributedPartitionVerificationError):
142142

143143
@optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True)
144144
class _SeenNodesWalkMapper(CachedWalkMapper):
145-
def __init__(self) -> None:
146-
super().__init__()
145+
def __init__(self, _visited_functions: set[Any] | None = None) -> None:
146+
super().__init__(_visited_functions=_visited_functions)
147147
self.seen_nodes: set[ArrayOrNames] = set()
148148

149149
def get_cache_key(self, expr: ArrayOrNames) -> int:

pytato/equality.py

+10-9
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,6 @@
2727

2828
from typing import TYPE_CHECKING, Any, Callable, Union
2929

30-
from pytools import memoize_method
31-
3230
from pytato.array import (
3331
AbstractResultWithNamedArrays,
3432
AdvancedIndexInContiguousAxes,
@@ -83,19 +81,23 @@ class EqualityComparer:
8381
more on this.
8482
"""
8583
def __init__(self) -> None:
84+
# Uses the same cache for both arrays and functions
8685
self._cache: dict[tuple[int, int], bool] = {}
8786

88-
def rec(self, expr1: ArrayOrNames, expr2: Any) -> bool:
87+
def rec(self, expr1: ArrayOrNames | FunctionDefinition, expr2: Any) -> bool:
8988
cache_key = id(expr1), id(expr2)
9089
try:
9190
return self._cache[cache_key]
9291
except KeyError:
93-
94-
method: Callable[[Array | AbstractResultWithNamedArrays, Any],
95-
bool]
92+
method: Callable[
93+
[Array | AbstractResultWithNamedArrays | FunctionDefinition, Any],
94+
bool]
9695

9796
try:
98-
method = getattr(self, expr1._mapper_method)
97+
method = (
98+
getattr(self, expr1._mapper_method)
99+
if isinstance(expr1, (Array, AbstractResultWithNamedArrays))
100+
else self.map_function_definition)
99101
except AttributeError:
100102
if isinstance(expr1, Array):
101103
result = self.handle_unsupported_array(expr1, expr2)
@@ -293,7 +295,6 @@ def map_distributed_recv(self, expr1: DistributedRecv, expr2: Any) -> bool:
293295
and expr1.tags == expr2.tags
294296
)
295297

296-
@memoize_method
297298
def map_function_definition(self, expr1: FunctionDefinition, expr2: Any
298299
) -> bool:
299300
return (expr1.__class__ is expr2.__class__
@@ -307,7 +308,7 @@ def map_function_definition(self, expr1: FunctionDefinition, expr2: Any
307308

308309
def map_call(self, expr1: Call, expr2: Any) -> bool:
309310
return (expr1.__class__ is expr2.__class__
310-
and self.map_function_definition(expr1.function, expr2.function)
311+
and self.rec(expr1.function, expr2.function)
311312
and frozenset(expr1.bindings) == frozenset(expr2.bindings)
312313
and all(self.rec(bnd,
313314
expr2.bindings[name])

pytato/stringifier.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,6 @@
3131
import numpy as np
3232
from immutabledict import immutabledict
3333

34-
from pytools import memoize_method
35-
3634
from pytato.array import (
3735
Array,
3836
Axis,
@@ -68,6 +66,7 @@ def __init__(self,
6866
self.truncation_depth = truncation_depth
6967
self.truncation_string = truncation_string
7068

69+
# Uses the same cache for both arrays and functions
7170
self._cache: dict[tuple[int, int], str] = {}
7271

7372
def rec(self, expr: Any, depth: int) -> str:
@@ -79,6 +78,15 @@ def rec(self, expr: Any, depth: int) -> str:
7978
self._cache[cache_key] = result
8079
return result # type: ignore[no-any-return]
8180

81+
def rec_function_definition(self, expr: FunctionDefinition, depth: int) -> str:
82+
cache_key = (id(expr), depth)
83+
try:
84+
return self._cache[cache_key]
85+
except KeyError:
86+
result = super().rec_function_definition(expr, depth)
87+
self._cache[cache_key] = result
88+
return result # type: ignore[no-any-return]
89+
8290
def __call__(self, expr: Any, depth: int = 0) -> str:
8391
return self.rec(expr, depth)
8492

@@ -168,7 +176,6 @@ def _get_field_val(field: str) -> str:
168176
for field in attrs.fields(type(expr)))
169177
+ ")")
170178

171-
@memoize_method
172179
def map_function_definition(self, expr: FunctionDefinition, depth: int) -> str:
173180
if depth > self.truncation_depth:
174181
return self.truncation_string
@@ -191,7 +198,7 @@ def map_call(self, expr: Call, depth: int) -> str:
191198

192199
def _get_field_val(field: str) -> str:
193200
if field == "function":
194-
return self.map_function_definition(expr.function, depth+1)
201+
return self.rec_function_definition(expr.function, depth+1)
195202
else:
196203
return self.rec(getattr(expr, field), depth+1)
197204

pytato/target/python/numpy_like.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
SizeParam,
6565
Stack,
6666
)
67+
from pytato.function import FunctionDefinition
6768
from pytato.raising import BinaryOpType, C99CallOp
6869
from pytato.reductions import (
6970
AllReductionOperation,
@@ -169,7 +170,7 @@ def _is_slice_trivial(slice_: NormalizedSlice,
169170
}
170171

171172

172-
class NumpyCodegenMapper(CachedMapper[str]):
173+
class NumpyCodegenMapper(CachedMapper[str, FunctionDefinition]):
173174
"""
174175
.. note::
175176

0 commit comments

Comments
 (0)