Skip to content

Commit 46431ee

Browse files
authored
Cache mapped functions (#531)
* avoid traversing functions multiple times
* condense function cache setup
* pass function cache to CombineMapper.__init__ too
* use Never instead of None as the function result type for mappers that don't support functions
* remove multiple inheritance for TransformMapperWithExtraArgs (doesn't appear to be needed)
* add _verify_is_array to avoid the need for rec_ary (the latter inflates recursion depth)
* remove map_foreign from Mapper
* use Never as FunctionResultT for NumpyCodegenMapper and FancyDotWriter too
* make PlaceholderSubstitutor explicit about not supporting functions
* tweak types in rec/rec_function_definition
* use P.args/P.kwargs in CachedWalkMapper
* use a more specific type than Any for cache key in CachedWalkMapper
* remove some more Anys
1 parent 377289c commit 46431ee

18 files changed

+334
-198
lines changed

pytato/analysis/__init__.py

+34-11
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,12 @@
2626
THE SOFTWARE.
2727
"""
2828

29-
from typing import TYPE_CHECKING, Any
29+
from typing import TYPE_CHECKING, Any, Never
30+
31+
from typing_extensions import Self
3032

3133
from loopy.tools import LoopyKeyBuilder
3234
from pymbolic.mapper.optimize import optimize_mapper
33-
from pytools import memoize_method
3435

3536
from pytato.array import (
3637
Array,
@@ -76,7 +77,7 @@
7677

7778
# {{{ NUserCollector
7879

79-
class NUserCollector(Mapper[None, []]):
80+
class NUserCollector(Mapper[None, None, []]):
8081
"""
8182
A :class:`pytato.transform.CachedWalkMapper` that records the number of
8283
times an array expression is a direct dependency of other nodes.
@@ -317,7 +318,7 @@ def is_einsum_similar_to_subscript(expr: Einsum, subscripts: str) -> bool:
317318

318319
# {{{ DirectPredecessorsGetter
319320

320-
class DirectPredecessorsGetter(Mapper[frozenset[ArrayOrNames], []]):
321+
class DirectPredecessorsGetter(Mapper[frozenset[ArrayOrNames], Never, []]):
321322
"""
322323
Mapper to get the
323324
`direct predecessors
@@ -413,16 +414,31 @@ class NodeCountMapper(CachedWalkMapper[[]]):
413414
Dictionary mapping node types to number of nodes of that type.
414415
"""
415416

416-
def __init__(self, count_duplicates: bool = False) -> None:
417+
def __init__(
418+
self,
419+
count_duplicates: bool = False,
420+
_visited_functions: set[Any] | None = None,
421+
) -> None:
422+
super().__init__(_visited_functions=_visited_functions)
423+
417424
from collections import defaultdict
418-
super().__init__()
419425
self.expr_type_counts: dict[type[Any], int] = defaultdict(int)
420426
self.count_duplicates = count_duplicates
421427

422428
def get_cache_key(self, expr: ArrayOrNames) -> int | ArrayOrNames:
423429
# Returns unique nodes only if count_duplicates is False
424430
return id(expr) if self.count_duplicates else expr
425431

432+
def get_function_definition_cache_key(
433+
self, expr: FunctionDefinition) -> int | FunctionDefinition:
434+
# Returns unique nodes only if count_duplicates is False
435+
return id(expr) if self.count_duplicates else expr
436+
437+
def clone_for_callee(self, function: FunctionDefinition) -> Self:
438+
return type(self)(
439+
count_duplicates=self.count_duplicates,
440+
_visited_functions=self._visited_functions)
441+
426442
def post_visit(self, expr: Any) -> None:
427443
if not isinstance(expr, DictOfNamedArrays):
428444
self.expr_type_counts[type(expr)] += 1
@@ -488,15 +504,20 @@ class NodeMultiplicityMapper(CachedWalkMapper[[]]):
488504
489505
.. autoattribute:: expr_multiplicity_counts
490506
"""
491-
def __init__(self) -> None:
507+
def __init__(self, _visited_functions: set[Any] | None = None) -> None:
508+
super().__init__(_visited_functions=_visited_functions)
509+
492510
from collections import defaultdict
493-
super().__init__()
494511
self.expr_multiplicity_counts: dict[Array, int] = defaultdict(int)
495512

496513
def get_cache_key(self, expr: ArrayOrNames) -> int:
497514
# Returns each node, including nodes that are duplicates
498515
return id(expr)
499516

517+
def get_function_definition_cache_key(self, expr: FunctionDefinition) -> int:
518+
# Returns each node, including nodes that are duplicates
519+
return id(expr)
520+
500521
def post_visit(self, expr: Any) -> None:
501522
if not isinstance(expr, DictOfNamedArrays):
502523
self.expr_multiplicity_counts[expr] += 1
@@ -530,14 +551,16 @@ class CallSiteCountMapper(CachedWalkMapper[[]]):
530551
The number of nodes.
531552
"""
532553

533-
def __init__(self) -> None:
534-
super().__init__()
554+
def __init__(self, _visited_functions: set[Any] | None = None) -> None:
555+
super().__init__(_visited_functions=_visited_functions)
535556
self.count = 0
536557

537558
def get_cache_key(self, expr: ArrayOrNames) -> int:
538559
return id(expr)
539560

540-
@memoize_method
561+
def get_function_definition_cache_key(self, expr: FunctionDefinition) -> int:
562+
return id(expr)
563+
541564
def map_function_definition(self, expr: FunctionDefinition) -> None:
542565
if not self.visit(expr):
543566
return

pytato/codegen.py

+14-8
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,9 @@
6868

6969

7070
if TYPE_CHECKING:
71-
from collections.abc import Mapping
71+
from collections.abc import Hashable, Mapping
7272

73-
from pytato.function import NamedCallResult
73+
from pytato.function import FunctionDefinition, NamedCallResult
7474
from pytato.target import Target
7575

7676

@@ -136,10 +136,13 @@ class CodeGenPreprocessor(ToIndexLambdaMixin, CopyMapper): # type: ignore[misc]
136136
====================================== =====================================
137137
"""
138138

139-
def __init__(self, target: Target,
140-
kernels_seen: dict[str, lp.LoopKernel] | None = None
141-
) -> None:
142-
super().__init__()
139+
def __init__(
140+
self,
141+
target: Target,
142+
kernels_seen: dict[str, lp.LoopKernel] | None = None,
143+
_function_cache: dict[Hashable, FunctionDefinition] | None = None
144+
) -> None:
145+
super().__init__(_function_cache=_function_cache)
143146
self.bound_arguments: dict[str, DataInterface] = {}
144147
self.var_name_gen: UniqueNameGenerator = UniqueNameGenerator()
145148
self.target = target
@@ -266,13 +269,16 @@ def normalize_outputs(
266269

267270
@optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True)
268271
class NamesValidityChecker(CachedWalkMapper[[]]):
269-
def __init__(self) -> None:
272+
def __init__(self, _visited_functions: set[Any] | None = None) -> None:
270273
self.name_to_input: dict[str, InputArgumentBase] = {}
271-
super().__init__()
274+
super().__init__(_visited_functions=_visited_functions)
272275

273276
def get_cache_key(self, expr: ArrayOrNames) -> int:
274277
return id(expr)
275278

279+
def get_function_definition_cache_key(self, expr: FunctionDefinition) -> int:
280+
return id(expr)
281+
276282
def post_visit(self, expr: Any) -> None:
277283
if isinstance(expr, Placeholder | SizeParam | DataWrapper):
278284
if expr.name is not None:

pytato/distributed/partition.py

+13-5
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
TYPE_CHECKING,
7171
Any,
7272
Generic,
73+
Never,
7374
TypeVar,
7475
cast,
7576
)
@@ -89,7 +90,13 @@
8990
DistributedSendRefHolder,
9091
)
9192
from pytato.scalar_expr import SCALAR_CLASSES
92-
from pytato.transform import ArrayOrNames, CachedWalkMapper, CombineMapper, CopyMapper
93+
from pytato.transform import (
94+
ArrayOrNames,
95+
CachedWalkMapper,
96+
CombineMapper,
97+
CopyMapper,
98+
_verify_is_array,
99+
)
93100

94101

95102
if TYPE_CHECKING:
@@ -288,8 +295,9 @@ def __init__(self,
288295
recvd_ary_to_name: Mapping[Array, str],
289296
sptpo_ary_to_name: Mapping[Array, str],
290297
name_to_output: Mapping[str, Array],
298+
_function_cache: dict[Hashable, FunctionDefinition] | None = None,
291299
) -> None:
292-
super().__init__()
300+
super().__init__(_function_cache=_function_cache)
293301

294302
self.recvd_ary_to_name = recvd_ary_to_name
295303
self.sptpo_ary_to_name = sptpo_ary_to_name
@@ -303,7 +311,7 @@ def clone_for_callee(
303311
self, function: FunctionDefinition) -> _DistributedInputReplacer:
304312
# Function definitions aren't allowed to contain receives,
305313
# stored arrays promoted to part outputs, or part outputs
306-
return type(self)({}, {}, {})
314+
return type(self)({}, {}, {}, _function_cache=self._function_cache)
307315

308316
def map_placeholder(self, expr: Placeholder) -> Placeholder:
309317
self.user_input_names.add(expr.name)
@@ -394,7 +402,7 @@ def _make_distributed_partition(
394402

395403
for name, val in name_to_part_output.items():
396404
assert name not in name_to_output
397-
name_to_output[name] = comm_replacer.rec_ary(val)
405+
name_to_output[name] = _verify_is_array(comm_replacer.rec(val))
398406

399407
comm_ids = part_comm_ids[part_id]
400408

@@ -456,7 +464,7 @@ def _recv_to_comm_id(
456464

457465

458466
class _LocalSendRecvDepGatherer(
459-
CombineMapper[frozenset[CommunicationOpIdentifier]]):
467+
CombineMapper[frozenset[CommunicationOpIdentifier], Never]):
460468
def __init__(self, local_rank: int) -> None:
461469
super().__init__()
462470
self.local_comm_ids_to_needed_comm_ids: \

pytato/distributed/verify.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -144,8 +144,8 @@ class MissingRecvError(DistributedPartitionVerificationError):
144144

145145
@optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True)
146146
class _SeenNodesWalkMapper(CachedWalkMapper[[]]):
147-
def __init__(self) -> None:
148-
super().__init__()
147+
def __init__(self, _visited_functions: set[Any] | None = None) -> None:
148+
super().__init__(_visited_functions=_visited_functions)
149149
self.seen_nodes: set[ArrayOrNames] = set()
150150

151151
def get_cache_key(self, expr: ArrayOrNames) -> int:

pytato/equality.py

+20-17
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,6 @@
2727

2828
from typing import TYPE_CHECKING, Any
2929

30-
from pytools import memoize_method
31-
3230
from pytato.array import (
3331
AbstractResultWithNamedArrays,
3432
AdvancedIndexInContiguousAxes,
@@ -49,13 +47,14 @@
4947
SizeParam,
5048
Stack,
5149
)
50+
from pytato.function import FunctionDefinition
5251

5352

5453
if TYPE_CHECKING:
5554
from collections.abc import Callable
5655

5756
from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder
58-
from pytato.function import Call, FunctionDefinition, NamedCallResult
57+
from pytato.function import Call, NamedCallResult
5958
from pytato.loopy import LoopyCall, LoopyCallResult
6059

6160
__doc__ = """
@@ -85,26 +84,31 @@ class EqualityComparer:
8584
more on this.
8685
"""
8786
def __init__(self) -> None:
87+
# Uses the same cache for both arrays and functions
8888
self._cache: dict[tuple[int, int], bool] = {}
8989

90-
def rec(self, expr1: ArrayOrNames, expr2: Any) -> bool:
90+
def rec(self, expr1: ArrayOrNames | FunctionDefinition, expr2: Any) -> bool:
9191
cache_key = id(expr1), id(expr2)
9292
try:
9393
return self._cache[cache_key]
9494
except KeyError:
95-
96-
method: Callable[[Array | AbstractResultWithNamedArrays, Any],
97-
bool]
98-
99-
try:
100-
method = getattr(self, expr1._mapper_method)
101-
except AttributeError:
102-
if isinstance(expr1, Array):
103-
result = self.handle_unsupported_array(expr1, expr2)
95+
if expr1 is expr2:
96+
result = True
97+
elif isinstance(expr1, ArrayOrNames):
98+
method: Callable[[ArrayOrNames, Any], bool]
99+
try:
100+
method = getattr(self, expr1._mapper_method)
101+
except AttributeError:
102+
if isinstance(expr1, Array):
103+
result = self.handle_unsupported_array(expr1, expr2)
104+
else:
105+
result = self.map_foreign(expr1, expr2)
104106
else:
105-
result = self.map_foreign(expr1, expr2)
107+
result = method(expr1, expr2)
108+
elif isinstance(expr1, FunctionDefinition):
109+
result = self.map_function_definition(expr1, expr2)
106110
else:
107-
result = (expr1 is expr2) or method(expr1, expr2)
111+
result = self.map_foreign(expr1, expr2)
108112

109113
self._cache[cache_key] = result
110114
return result
@@ -296,7 +300,6 @@ def map_distributed_recv(self, expr1: DistributedRecv, expr2: Any) -> bool:
296300
and expr1.tags == expr2.tags
297301
)
298302

299-
@memoize_method
300303
def map_function_definition(self, expr1: FunctionDefinition, expr2: Any
301304
) -> bool:
302305
return (expr1.__class__ is expr2.__class__
@@ -310,7 +313,7 @@ def map_function_definition(self, expr1: FunctionDefinition, expr2: Any
310313

311314
def map_call(self, expr1: Call, expr2: Any) -> bool:
312315
return (expr1.__class__ is expr2.__class__
313-
and self.map_function_definition(expr1.function, expr2.function)
316+
and self.rec(expr1.function, expr2.function)
314317
and frozenset(expr1.bindings) == frozenset(expr2.bindings)
315318
and all(self.rec(bnd,
316319
expr2.bindings[name])

pytato/stringifier.py

+17-7
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,6 @@
3131
import numpy as np
3232
from immutabledict import immutabledict
3333

34-
from pytools import memoize_method
35-
3634
from pytato.array import (
3735
Array,
3836
Axis,
@@ -41,7 +39,7 @@
4139
IndexLambda,
4240
ReductionDescriptor,
4341
)
44-
from pytato.transform import Mapper
42+
from pytato.transform import ForeignObjectError, Mapper
4543

4644

4745
if TYPE_CHECKING:
@@ -58,7 +56,7 @@
5856

5957
# {{{ Reprifier
6058

61-
class Reprifier(Mapper[str, [int]]):
59+
class Reprifier(Mapper[str, str, [int]]):
6260
"""
6361
Stringifies :mod:`pytato`-types to closely resemble CPython's implementation
6462
of :func:`repr` for its builtin datatypes.
@@ -71,14 +69,27 @@ def __init__(self,
7169
self.truncation_depth = truncation_depth
7270
self.truncation_string = truncation_string
7371

72+
# Uses the same cache for both arrays and functions
7473
self._cache: dict[tuple[int, int], str] = {}
7574

7675
def rec(self, expr: Any, depth: int) -> str:
7776
cache_key = (id(expr), depth)
7877
try:
7978
return self._cache[cache_key]
8079
except KeyError:
81-
result = super().rec(expr, depth)
80+
try:
81+
result = super().rec(expr, depth)
82+
except ForeignObjectError:
83+
result = self.map_foreign(expr, depth)
84+
self._cache[cache_key] = result
85+
return result
86+
87+
def rec_function_definition(self, expr: FunctionDefinition, depth: int) -> str:
88+
cache_key = (id(expr), depth)
89+
try:
90+
return self._cache[cache_key]
91+
except KeyError:
92+
result = super().rec_function_definition(expr, depth)
8293
self._cache[cache_key] = result
8394
return result
8495

@@ -171,7 +182,6 @@ def _get_field_val(field: str) -> str:
171182
for field in dataclasses.fields(type(expr)))
172183
+ ")")
173184

174-
@memoize_method
175185
def map_function_definition(self, expr: FunctionDefinition, depth: int) -> str:
176186
if depth > self.truncation_depth:
177187
return self.truncation_string
@@ -194,7 +204,7 @@ def map_call(self, expr: Call, depth: int) -> str:
194204

195205
def _get_field_val(field: str) -> str:
196206
if field == "function":
197-
return self.map_function_definition(expr.function, depth+1)
207+
return self.rec_function_definition(expr.function, depth+1)
198208
else:
199209
return self.rec(getattr(expr, field), depth+1)
200210

0 commit comments

Comments
 (0)