|
29 | 29 | from typing import TYPE_CHECKING, Any, Mapping
|
30 | 30 |
|
31 | 31 | from pymbolic.mapper.optimize import optimize_mapper
|
32 |
| -from pytools import memoize_method |
33 | 32 |
|
34 | 33 | from pytato.array import (
|
35 | 34 | Array,
|
|
46 | 45 | )
|
47 | 46 | from pytato.function import Call, FunctionDefinition, NamedCallResult
|
48 | 47 | from pytato.loopy import LoopyCall
|
49 |
| -from pytato.transform import ArrayOrNames, CachedWalkMapper, Mapper |
| 48 | +from pytato.transform import ArrayOrNames, CachedWalkMapper, Mapper, _SelfMapper |
50 | 49 |
|
51 | 50 |
|
52 | 51 | if TYPE_CHECKING:
|
@@ -410,16 +409,34 @@ class NodeCountMapper(CachedWalkMapper):
|
410 | 409 | Dictionary mapping node types to number of nodes of that type.
|
411 | 410 | """
|
412 | 411 |
|
413 |
| - def __init__(self, count_duplicates: bool = False) -> None: |
| 412 | + def __init__( |
| 413 | + self, |
| 414 | + count_duplicates: bool = False, |
| 415 | + _visited_functions: set[Any] | None = None, |
| 416 | + ) -> None: |
| 417 | + super().__init__(_visited_functions=_visited_functions) |
| 418 | + |
414 | 419 | from collections import defaultdict
|
415 |
| - super().__init__() |
416 | 420 | self.expr_type_counts: dict[type[Any], int] = defaultdict(int)
|
417 | 421 | self.count_duplicates = count_duplicates
|
418 | 422 |
|
419 | 423 | def get_cache_key(self, expr: ArrayOrNames) -> int | ArrayOrNames:
|
420 | 424 | # Returns unique nodes only if count_duplicates is False
|
421 | 425 | return id(expr) if self.count_duplicates else expr
|
422 | 426 |
|
| 427 | + def get_function_definition_cache_key( |
| 428 | + self, expr: FunctionDefinition) -> int | FunctionDefinition: |
| 429 | + # Returns unique nodes only if count_duplicates is False |
| 430 | + return id(expr) if self.count_duplicates else expr |
| 431 | + |
| 432 | + def clone_for_callee( |
| 433 | + self: _SelfMapper, function: FunctionDefinition) -> _SelfMapper: |
| 434 | + # type-ignore-reason: self.__init__ has a different function signature |
| 435 | + # than Mapper.__init__ |
| 436 | + return type(self)( |
| 437 | + count_duplicates=self.count_duplicates, # type: ignore[attr-defined] |
| 438 | + _visited_functions=self._visited_functions) # type: ignore[call-arg,attr-defined] |
| 439 | + |
423 | 440 | def post_visit(self, expr: Any) -> None:
|
424 | 441 | if not isinstance(expr, DictOfNamedArrays):
|
425 | 442 | self.expr_type_counts[type(expr)] += 1
|
@@ -485,15 +502,20 @@ class NodeMultiplicityMapper(CachedWalkMapper):
|
485 | 502 |
|
486 | 503 | .. autoattribute:: expr_multiplicity_counts
|
487 | 504 | """
|
488 |
| - def __init__(self) -> None: |
| 505 | + def __init__(self, _visited_functions: set[Any] | None = None) -> None: |
| 506 | + super().__init__(_visited_functions=_visited_functions) |
| 507 | + |
489 | 508 | from collections import defaultdict
|
490 |
| - super().__init__() |
491 | 509 | self.expr_multiplicity_counts: dict[Array, int] = defaultdict(int)
|
492 | 510 |
|
493 | 511 | def get_cache_key(self, expr: ArrayOrNames) -> int:
|
494 | 512 | # Returns each node, including nodes that are duplicates
|
495 | 513 | return id(expr)
|
496 | 514 |
|
| 515 | + def get_function_definition_cache_key(self, expr: FunctionDefinition) -> int: |
| 516 | + # Returns each node, including nodes that are duplicates |
| 517 | + return id(expr) |
| 518 | + |
497 | 519 | def post_visit(self, expr: Any) -> None:
|
498 | 520 | if not isinstance(expr, DictOfNamedArrays):
|
499 | 521 | self.expr_multiplicity_counts[expr] += 1
|
@@ -527,14 +549,16 @@ class CallSiteCountMapper(CachedWalkMapper):
|
527 | 549 | The number of nodes.
|
528 | 550 | """
|
529 | 551 |
|
530 |
| - def __init__(self) -> None: |
531 |
| - super().__init__() |
| 552 | + def __init__(self, _visited_functions: set[Any] | None = None) -> None: |
| 553 | + super().__init__(_visited_functions=_visited_functions) |
532 | 554 | self.count = 0
|
533 | 555 |
|
534 | 556 | def get_cache_key(self, expr: ArrayOrNames) -> int:
|
535 | 557 | return id(expr)
|
536 | 558 |
|
537 |
| - @memoize_method |
| 559 | + def get_function_definition_cache_key(self, expr: FunctionDefinition) -> int: |
| 560 | + return id(expr) |
| 561 | + |
538 | 562 | def map_function_definition(self, expr: FunctionDefinition) -> None:
|
539 | 563 | if not self.visit(expr):
|
540 | 564 | return
|
|
0 commit comments