Skip to content

Commit 53a69ee

Browse files
committed
avoid traversing functions multiple times
1 parent 179cf09 commit 53a69ee

16 files changed

+235
-123
lines changed

pytato/analysis/__init__.py

+33-10
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,10 @@
2828

2929
from typing import TYPE_CHECKING, Any
3030

31+
from typing_extensions import Self
32+
3133
from loopy.tools import LoopyKeyBuilder
3234
from pymbolic.mapper.optimize import optimize_mapper
33-
from pytools import memoize_method
3435

3536
from pytato.array import (
3637
Array,
@@ -76,7 +77,7 @@
7677

7778
# {{{ NUserCollector
7879

79-
class NUserCollector(Mapper[None, []]):
80+
class NUserCollector(Mapper[None, None, []]):
8081
"""
8182
A :class:`pytato.transform.CachedWalkMapper` that records the number of
8283
times an array expression is a direct dependency of other nodes.
@@ -317,7 +318,7 @@ def is_einsum_similar_to_subscript(expr: Einsum, subscripts: str) -> bool:
317318

318319
# {{{ DirectPredecessorsGetter
319320

320-
class DirectPredecessorsGetter(Mapper[frozenset[ArrayOrNames], []]):
321+
class DirectPredecessorsGetter(Mapper[frozenset[ArrayOrNames], None, []]):
321322
"""
322323
Mapper to get the
323324
`direct predecessors
@@ -413,16 +414,31 @@ class NodeCountMapper(CachedWalkMapper[[]]):
413414
Dictionary mapping node types to number of nodes of that type.
414415
"""
415416

416-
def __init__(self, count_duplicates: bool = False) -> None:
417+
def __init__(
418+
self,
419+
count_duplicates: bool = False,
420+
_visited_functions: set[Any] | None = None,
421+
) -> None:
422+
super().__init__(_visited_functions=_visited_functions)
423+
417424
from collections import defaultdict
418-
super().__init__()
419425
self.expr_type_counts: dict[type[Any], int] = defaultdict(int)
420426
self.count_duplicates = count_duplicates
421427

422428
def get_cache_key(self, expr: ArrayOrNames) -> int | ArrayOrNames:
423429
# Returns unique nodes only if count_duplicates is False
424430
return id(expr) if self.count_duplicates else expr
425431

432+
def get_function_definition_cache_key(
433+
self, expr: FunctionDefinition) -> int | FunctionDefinition:
434+
# Returns unique nodes only if count_duplicates is False
435+
return id(expr) if self.count_duplicates else expr
436+
437+
def clone_for_callee(self, function: FunctionDefinition) -> Self:
438+
return type(self)(
439+
count_duplicates=self.count_duplicates,
440+
_visited_functions=self._visited_functions)
441+
426442
def post_visit(self, expr: Any) -> None:
427443
if not isinstance(expr, DictOfNamedArrays):
428444
self.expr_type_counts[type(expr)] += 1
@@ -488,15 +504,20 @@ class NodeMultiplicityMapper(CachedWalkMapper[[]]):
488504
489505
.. autoattribute:: expr_multiplicity_counts
490506
"""
491-
def __init__(self) -> None:
507+
def __init__(self, _visited_functions: set[Any] | None = None) -> None:
508+
super().__init__(_visited_functions=_visited_functions)
509+
492510
from collections import defaultdict
493-
super().__init__()
494511
self.expr_multiplicity_counts: dict[Array, int] = defaultdict(int)
495512

496513
def get_cache_key(self, expr: ArrayOrNames) -> int:
497514
# Returns each node, including nodes that are duplicates
498515
return id(expr)
499516

517+
def get_function_definition_cache_key(self, expr: FunctionDefinition) -> int:
518+
# Returns each node, including nodes that are duplicates
519+
return id(expr)
520+
500521
def post_visit(self, expr: Any) -> None:
501522
if not isinstance(expr, DictOfNamedArrays):
502523
self.expr_multiplicity_counts[expr] += 1
@@ -530,14 +551,16 @@ class CallSiteCountMapper(CachedWalkMapper[[]]):
530551
The number of nodes.
531552
"""
532553

533-
def __init__(self) -> None:
534-
super().__init__()
554+
def __init__(self, _visited_functions: set[Any] | None = None) -> None:
555+
super().__init__(_visited_functions=_visited_functions)
535556
self.count = 0
536557

537558
def get_cache_key(self, expr: ArrayOrNames) -> int:
538559
return id(expr)
539560

540-
@memoize_method
561+
def get_function_definition_cache_key(self, expr: FunctionDefinition) -> int:
562+
return id(expr)
563+
541564
def map_function_definition(self, expr: FunctionDefinition) -> None:
542565
if not self.visit(expr):
543566
return

pytato/codegen.py

+14-8
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,9 @@
6868

6969

7070
if TYPE_CHECKING:
71-
from collections.abc import Mapping
71+
from collections.abc import Hashable, Mapping
7272

73-
from pytato.function import NamedCallResult
73+
from pytato.function import FunctionDefinition, NamedCallResult
7474
from pytato.target import Target
7575

7676

@@ -136,10 +136,13 @@ class CodeGenPreprocessor(ToIndexLambdaMixin, CopyMapper): # type: ignore[misc]
136136
====================================== =====================================
137137
"""
138138

139-
def __init__(self, target: Target,
140-
kernels_seen: dict[str, lp.LoopKernel] | None = None
141-
) -> None:
142-
super().__init__()
139+
def __init__(
140+
self,
141+
target: Target,
142+
kernels_seen: dict[str, lp.LoopKernel] | None = None,
143+
_function_cache: dict[Hashable, FunctionDefinition] | None = None
144+
) -> None:
145+
super().__init__(_function_cache=_function_cache)
143146
self.bound_arguments: dict[str, DataInterface] = {}
144147
self.var_name_gen: UniqueNameGenerator = UniqueNameGenerator()
145148
self.target = target
@@ -266,13 +269,16 @@ def normalize_outputs(
266269

267270
@optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True)
268271
class NamesValidityChecker(CachedWalkMapper[[]]):
269-
def __init__(self) -> None:
272+
def __init__(self, _visited_functions: set[Any] | None = None) -> None:
270273
self.name_to_input: dict[str, InputArgumentBase] = {}
271-
super().__init__()
274+
super().__init__(_visited_functions=_visited_functions)
272275

273276
def get_cache_key(self, expr: ArrayOrNames) -> int:
274277
return id(expr)
275278

279+
def get_function_definition_cache_key(self, expr: FunctionDefinition) -> int:
280+
return id(expr)
281+
276282
def post_visit(self, expr: Any) -> None:
277283
if isinstance(expr, Placeholder | SizeParam | DataWrapper):
278284
if expr.name is not None:

pytato/distributed/partition.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -288,8 +288,9 @@ def __init__(self,
288288
recvd_ary_to_name: Mapping[Array, str],
289289
sptpo_ary_to_name: Mapping[Array, str],
290290
name_to_output: Mapping[str, Array],
291+
_function_cache: dict[Hashable, FunctionDefinition] | None = None,
291292
) -> None:
292-
super().__init__()
293+
super().__init__(_function_cache=_function_cache)
293294

294295
self.recvd_ary_to_name = recvd_ary_to_name
295296
self.sptpo_ary_to_name = sptpo_ary_to_name
@@ -303,7 +304,7 @@ def clone_for_callee(
303304
self, function: FunctionDefinition) -> _DistributedInputReplacer:
304305
# Function definitions aren't allowed to contain receives,
305306
# stored arrays promoted to part outputs, or part outputs
306-
return type(self)({}, {}, {})
307+
return type(self)({}, {}, {}, _function_cache=self._function_cache)
307308

308309
def map_placeholder(self, expr: Placeholder) -> Placeholder:
309310
self.user_input_names.add(expr.name)
@@ -456,7 +457,7 @@ def _recv_to_comm_id(
456457

457458

458459
class _LocalSendRecvDepGatherer(
459-
CombineMapper[frozenset[CommunicationOpIdentifier]]):
460+
CombineMapper[frozenset[CommunicationOpIdentifier], None]):
460461
def __init__(self, local_rank: int) -> None:
461462
super().__init__()
462463
self.local_comm_ids_to_needed_comm_ids: \

pytato/distributed/verify.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -144,8 +144,8 @@ class MissingRecvError(DistributedPartitionVerificationError):
144144

145145
@optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True)
146146
class _SeenNodesWalkMapper(CachedWalkMapper[[]]):
147-
def __init__(self) -> None:
148-
super().__init__()
147+
def __init__(self, _visited_functions: set[Any] | None = None) -> None:
148+
super().__init__(_visited_functions=_visited_functions)
149149
self.seen_nodes: set[ArrayOrNames] = set()
150150

151151
def get_cache_key(self, expr: ArrayOrNames) -> int:

pytato/equality.py

+20-17
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,6 @@
2727

2828
from typing import TYPE_CHECKING, Any
2929

30-
from pytools import memoize_method
31-
3230
from pytato.array import (
3331
AbstractResultWithNamedArrays,
3432
AdvancedIndexInContiguousAxes,
@@ -49,13 +47,14 @@
4947
SizeParam,
5048
Stack,
5149
)
50+
from pytato.function import FunctionDefinition
5251

5352

5453
if TYPE_CHECKING:
5554
from collections.abc import Callable
5655

5756
from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder
58-
from pytato.function import Call, FunctionDefinition, NamedCallResult
57+
from pytato.function import Call, NamedCallResult
5958
from pytato.loopy import LoopyCall, LoopyCallResult
6059

6160
__doc__ = """
@@ -85,26 +84,31 @@ class EqualityComparer:
8584
more on this.
8685
"""
8786
def __init__(self) -> None:
87+
# Uses the same cache for both arrays and functions
8888
self._cache: dict[tuple[int, int], bool] = {}
8989

90-
def rec(self, expr1: ArrayOrNames, expr2: Any) -> bool:
90+
def rec(self, expr1: ArrayOrNames | FunctionDefinition, expr2: Any) -> bool:
9191
cache_key = id(expr1), id(expr2)
9292
try:
9393
return self._cache[cache_key]
9494
except KeyError:
95-
96-
method: Callable[[Array | AbstractResultWithNamedArrays, Any],
97-
bool]
98-
99-
try:
100-
method = getattr(self, expr1._mapper_method)
101-
except AttributeError:
102-
if isinstance(expr1, Array):
103-
result = self.handle_unsupported_array(expr1, expr2)
95+
if expr1 is expr2:
96+
result = True
97+
elif isinstance(expr1, ArrayOrNames):
98+
method: Callable[[ArrayOrNames, Any], bool]
99+
try:
100+
method = getattr(self, expr1._mapper_method)
101+
except AttributeError:
102+
if isinstance(expr1, Array):
103+
result = self.handle_unsupported_array(expr1, expr2)
104+
else:
105+
result = self.map_foreign(expr1, expr2)
104106
else:
105-
result = self.map_foreign(expr1, expr2)
107+
result = method(expr1, expr2)
108+
elif isinstance(expr1, FunctionDefinition):
109+
result = self.map_function_definition(expr1, expr2)
106110
else:
107-
result = (expr1 is expr2) or method(expr1, expr2)
111+
result = self.map_foreign(expr1, expr2)
108112

109113
self._cache[cache_key] = result
110114
return result
@@ -296,7 +300,6 @@ def map_distributed_recv(self, expr1: DistributedRecv, expr2: Any) -> bool:
296300
and expr1.tags == expr2.tags
297301
)
298302

299-
@memoize_method
300303
def map_function_definition(self, expr1: FunctionDefinition, expr2: Any
301304
) -> bool:
302305
return (expr1.__class__ is expr2.__class__
@@ -310,7 +313,7 @@ def map_function_definition(self, expr1: FunctionDefinition, expr2: Any
310313

311314
def map_call(self, expr1: Call, expr2: Any) -> bool:
312315
return (expr1.__class__ is expr2.__class__
313-
and self.map_function_definition(expr1.function, expr2.function)
316+
and self.rec(expr1.function, expr2.function)
314317
and frozenset(expr1.bindings) == frozenset(expr2.bindings)
315318
and all(self.rec(bnd,
316319
expr2.bindings[name])

pytato/stringifier.py

+12-5
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,6 @@
3131
import numpy as np
3232
from immutabledict import immutabledict
3333

34-
from pytools import memoize_method
35-
3634
from pytato.array import (
3735
Array,
3836
Axis,
@@ -58,7 +56,7 @@
5856

5957
# {{{ Reprifier
6058

61-
class Reprifier(Mapper[str, [int]]):
59+
class Reprifier(Mapper[str, str, [int]]):
6260
"""
6361
Stringifies :mod:`pytato`-types to closely resemble CPython's implementation
6462
of :func:`repr` for its builtin datatypes.
@@ -71,6 +69,7 @@ def __init__(self,
7169
self.truncation_depth = truncation_depth
7270
self.truncation_string = truncation_string
7371

72+
# Uses the same cache for both arrays and functions
7473
self._cache: dict[tuple[int, int], str] = {}
7574

7675
def rec(self, expr: Any, depth: int) -> str:
@@ -82,6 +81,15 @@ def rec(self, expr: Any, depth: int) -> str:
8281
self._cache[cache_key] = result
8382
return result
8483

84+
def rec_function_definition(self, expr: FunctionDefinition, depth: int) -> str:
85+
cache_key = (id(expr), depth)
86+
try:
87+
return self._cache[cache_key]
88+
except KeyError:
89+
result = super().rec_function_definition(expr, depth)
90+
self._cache[cache_key] = result
91+
return result
92+
8593
def __call__(self, expr: Any, depth: int = 0) -> str:
8694
return self.rec(expr, depth)
8795

@@ -171,7 +179,6 @@ def _get_field_val(field: str) -> str:
171179
for field in dataclasses.fields(type(expr)))
172180
+ ")")
173181

174-
@memoize_method
175182
def map_function_definition(self, expr: FunctionDefinition, depth: int) -> str:
176183
if depth > self.truncation_depth:
177184
return self.truncation_string
@@ -194,7 +201,7 @@ def map_call(self, expr: Call, depth: int) -> str:
194201

195202
def _get_field_val(field: str) -> str:
196203
if field == "function":
197-
return self.map_function_definition(expr.function, depth+1)
204+
return self.rec_function_definition(expr.function, depth+1)
198205
else:
199206
return self.rec(getattr(expr, field), depth+1)
200207

pytato/target/loopy/codegen.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ def update_t_unit(self, t_unit: lp.TranslationUnit) -> None:
384384

385385
# {{{ codegen mapper
386386

387-
class CodeGenMapper(Mapper[ImplementedResult, [CodeGenState]]):
387+
class CodeGenMapper(Mapper[ImplementedResult, None, [CodeGenState]]):
388388
"""A mapper for generating code for nodes in the computation graph.
389389
"""
390390
exprgen_mapper: InlinedExpressionGenMapper

pytato/target/python/numpy_like.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
SizeParam,
6363
Stack,
6464
)
65+
from pytato.function import FunctionDefinition
6566
from pytato.raising import BinaryOpType, C99CallOp
6667
from pytato.reductions import (
6768
AllReductionOperation,
@@ -171,7 +172,7 @@ def _is_slice_trivial(slice_: NormalizedSlice,
171172
}
172173

173174

174-
class NumpyCodegenMapper(CachedMapper[str, []]):
175+
class NumpyCodegenMapper(CachedMapper[str, FunctionDefinition, []]):
175176
"""
176177
.. note::
177178

0 commit comments

Comments
 (0)