|
28 | 28 |
|
29 | 29 | from typing import TYPE_CHECKING, Any
|
30 | 30 |
|
| 31 | +from typing_extensions import Self |
| 32 | + |
31 | 33 | from loopy.tools import LoopyKeyBuilder
|
32 | 34 | from pymbolic.mapper.optimize import optimize_mapper
|
33 |
| -from pytools import memoize_method |
34 | 35 |
|
35 | 36 | from pytato.array import (
|
36 | 37 | Array,
|
|
76 | 77 |
|
77 | 78 | # {{{ NUserCollector
|
78 | 79 |
|
79 |
| -class NUserCollector(Mapper[None, []]): |
| 80 | +class NUserCollector(Mapper[None, None, []]): |
80 | 81 | """
|
81 | 82 | A :class:`pytato.transform.CachedWalkMapper` that records the number of
|
82 | 83 | times an array expression is a direct dependency of other nodes.
|
@@ -317,7 +318,7 @@ def is_einsum_similar_to_subscript(expr: Einsum, subscripts: str) -> bool:
|
317 | 318 |
|
318 | 319 | # {{{ DirectPredecessorsGetter
|
319 | 320 |
|
320 |
| -class DirectPredecessorsGetter(Mapper[frozenset[ArrayOrNames], []]): |
| 321 | +class DirectPredecessorsGetter(Mapper[frozenset[ArrayOrNames], None, []]): |
321 | 322 | """
|
322 | 323 | Mapper to get the
|
323 | 324 | `direct predecessors
|
@@ -413,16 +414,31 @@ class NodeCountMapper(CachedWalkMapper[[]]):
|
413 | 414 | Dictionary mapping node types to number of nodes of that type.
|
414 | 415 | """
|
415 | 416 |
|
416 |
| - def __init__(self, count_duplicates: bool = False) -> None: |
| 417 | + def __init__( |
| 418 | + self, |
| 419 | + count_duplicates: bool = False, |
| 420 | + _visited_functions: set[Any] | None = None, |
| 421 | + ) -> None: |
| 422 | + super().__init__(_visited_functions=_visited_functions) |
| 423 | + |
417 | 424 | from collections import defaultdict
|
418 |
| - super().__init__() |
419 | 425 | self.expr_type_counts: dict[type[Any], int] = defaultdict(int)
|
420 | 426 | self.count_duplicates = count_duplicates
|
421 | 427 |
|
422 | 428 | def get_cache_key(self, expr: ArrayOrNames) -> int | ArrayOrNames:
|
423 | 429 | # Returns unique nodes only if count_duplicates is False
|
424 | 430 | return id(expr) if self.count_duplicates else expr
|
425 | 431 |
|
| 432 | + def get_function_definition_cache_key( |
| 433 | + self, expr: FunctionDefinition) -> int | FunctionDefinition: |
| 434 | + # Returns unique nodes only if count_duplicates is False |
| 435 | + return id(expr) if self.count_duplicates else expr |
| 436 | + |
| 437 | + def clone_for_callee(self, function: FunctionDefinition) -> Self: |
| 438 | + return type(self)( |
| 439 | + count_duplicates=self.count_duplicates, |
| 440 | + _visited_functions=self._visited_functions) |
| 441 | + |
426 | 442 | def post_visit(self, expr: Any) -> None:
|
427 | 443 | if not isinstance(expr, DictOfNamedArrays):
|
428 | 444 | self.expr_type_counts[type(expr)] += 1
|
@@ -488,15 +504,20 @@ class NodeMultiplicityMapper(CachedWalkMapper[[]]):
|
488 | 504 |
|
489 | 505 | .. autoattribute:: expr_multiplicity_counts
|
490 | 506 | """
|
491 |
| - def __init__(self) -> None: |
| 507 | + def __init__(self, _visited_functions: set[Any] | None = None) -> None: |
| 508 | + super().__init__(_visited_functions=_visited_functions) |
| 509 | + |
492 | 510 | from collections import defaultdict
|
493 |
| - super().__init__() |
494 | 511 | self.expr_multiplicity_counts: dict[Array, int] = defaultdict(int)
|
495 | 512 |
|
496 | 513 | def get_cache_key(self, expr: ArrayOrNames) -> int:
|
497 | 514 | # Returns each node, including nodes that are duplicates
|
498 | 515 | return id(expr)
|
499 | 516 |
|
| 517 | + def get_function_definition_cache_key(self, expr: FunctionDefinition) -> int: |
| 518 | + # Returns each node, including nodes that are duplicates |
| 519 | + return id(expr) |
| 520 | + |
500 | 521 | def post_visit(self, expr: Any) -> None:
|
501 | 522 | if not isinstance(expr, DictOfNamedArrays):
|
502 | 523 | self.expr_multiplicity_counts[expr] += 1
|
@@ -530,14 +551,16 @@ class CallSiteCountMapper(CachedWalkMapper[[]]):
|
530 | 551 | The number of nodes.
|
531 | 552 | """
|
532 | 553 |
|
533 |
| - def __init__(self) -> None: |
534 |
| - super().__init__() |
| 554 | + def __init__(self, _visited_functions: set[Any] | None = None) -> None: |
| 555 | + super().__init__(_visited_functions=_visited_functions) |
535 | 556 | self.count = 0
|
536 | 557 |
|
537 | 558 | def get_cache_key(self, expr: ArrayOrNames) -> int:
|
538 | 559 | return id(expr)
|
539 | 560 |
|
540 |
| - @memoize_method |
| 561 | + def get_function_definition_cache_key(self, expr: FunctionDefinition) -> int: |
| 562 | + return id(expr) |
| 563 | + |
541 | 564 | def map_function_definition(self, expr: FunctionDefinition) -> None:
|
542 | 565 | if not self.visit(expr):
|
543 | 566 | return
|
|
0 commit comments