|
91 | 91 | CombineT = TypeVar("CombineT") # used in CombineMapper
|
92 | 92 | TransformMapperResultT = TypeVar("TransformMapperResultT", # used in TransformMapper
|
93 | 93 | Array, AbstractResultWithNamedArrays, ArrayOrNames)
|
94 |
| -CopyMapperResultT = TypeVar("CopyMapperResultT", # used in CopyMapper |
95 |
| - Array, AbstractResultWithNamedArrays, ArrayOrNames) |
96 | 94 | CachedMapperT = TypeVar("CachedMapperT") # used in CachedMapper
|
97 | 95 | IndexOrShapeExpr = TypeVar("IndexOrShapeExpr")
|
98 | 96 | R = FrozenSet[Array]
|
@@ -344,32 +342,15 @@ def clone_for_callee(
|
344 | 342 |
|
345 | 343 | # {{{ CopyMapper
|
346 | 344 |
|
347 |
| -class CopyMapper(CachedMapper[ArrayOrNames]): |
| 345 | +class CopyMapper(TransformMapper): |
348 | 346 | """Performs a deep copy of a :class:`pytato.array.Array`.
|
349 | 347 | The typical use of this mapper is to override individual ``map_`` methods
|
350 | 348 | in subclasses to permit term rewriting on an expression graph.
|
351 | 349 |
|
352 |
| - .. automethod:: clone_for_callee |
353 |
| -
|
354 | 350 | .. note::
|
355 | 351 |
|
356 | 352 | This does not copy the data of a :class:`pytato.array.DataWrapper`.
|
357 | 353 | """
|
358 |
| - if TYPE_CHECKING: |
359 |
| - def rec(self, expr: CopyMapperResultT) -> CopyMapperResultT: |
360 |
| - return cast(CopyMapperResultT, super().rec(expr)) |
361 |
| - |
362 |
| - def __call__(self, expr: CopyMapperResultT) -> CopyMapperResultT: |
363 |
| - return self.rec(expr) |
364 |
| - |
365 |
| - def clone_for_callee( |
366 |
| - self: _SelfMapper, function: FunctionDefinition) -> _SelfMapper: |
367 |
| - """ |
368 |
| - Called to clone *self* before starting traversal of a |
369 |
| - :class:`pytato.function.FunctionDefinition`. |
370 |
| - """ |
371 |
| - return type(self)() |
372 |
| - |
373 | 354 | def rec_idx_or_size_tuple(self, situp: tuple[IndexOrShapeExpr, ...]
|
374 | 355 | ) -> tuple[IndexOrShapeExpr, ...]:
|
375 | 356 | # type-ignore-reason: apparently mypy cannot substitute typevars
|
@@ -554,57 +535,14 @@ def map_named_call_result(self, expr: NamedCallResult) -> Array:
|
554 | 535 | return call[expr.name]
|
555 | 536 |
|
556 | 537 |
|
557 |
| -class CopyMapperWithExtraArgs(CachedMapper[ArrayOrNames]): |
| 538 | +class CopyMapperWithExtraArgs(TransformMapperWithExtraArgs): |
558 | 539 | """
|
559 | 540 | Similar to :class:`CopyMapper`, but each mapper method takes extra
|
560 | 541 | ``*args``, ``**kwargs`` that are propagated along a path by default.
|
561 | 542 |
|
562 | 543 | The logic in :class:`CopyMapper` purposely does not take the extra
|
563 | 544 | arguments to keep the cost of its each call frame low.
|
564 |
| -
|
565 |
| - .. automethod:: clone_for_callee |
566 | 545 | """
|
567 |
| - def __init__(self) -> None: |
568 |
| - super().__init__() |
569 |
| - # type-ignored as '._cache' attribute is not coherent with the base |
570 |
| - # class |
571 |
| - self._cache: dict[tuple[ArrayOrNames, |
572 |
| - tuple[Any, ...], |
573 |
| - tuple[tuple[str, Any], ...] |
574 |
| - ], |
575 |
| - ArrayOrNames] = {} # type: ignore[assignment] |
576 |
| - |
577 |
| - def get_cache_key(self, |
578 |
| - expr: ArrayOrNames, |
579 |
| - *args: Any, **kwargs: Any) -> tuple[ArrayOrNames, |
580 |
| - tuple[Any, ...], |
581 |
| - tuple[tuple[str, Any], ...] |
582 |
| - ]: |
583 |
| - return (expr, args, tuple(sorted(kwargs.items()))) |
584 |
| - |
585 |
| - def rec(self, |
586 |
| - expr: CopyMapperResultT, |
587 |
| - *args: Any, **kwargs: Any) -> CopyMapperResultT: |
588 |
| - key = self.get_cache_key(expr, *args, **kwargs) |
589 |
| - try: |
590 |
| - # type-ignore-reason: self._cache has ArrayOrNames as its values |
591 |
| - return self._cache[key] # type: ignore[return-value] |
592 |
| - except KeyError: |
593 |
| - result = Mapper.rec(self, expr, |
594 |
| - *args, |
595 |
| - **kwargs) |
596 |
| - self._cache[key] = result |
597 |
| - # type-ignore-reason: Mapper.rec is imprecise |
598 |
| - return result # type: ignore[no-any-return] |
599 |
| - |
600 |
| - def clone_for_callee( |
601 |
| - self: _SelfMapper, function: FunctionDefinition) -> _SelfMapper: |
602 |
| - """ |
603 |
| - Called to clone *self* before starting traversal of a |
604 |
| - :class:`pytato.function.FunctionDefinition`. |
605 |
| - """ |
606 |
| - return type(self)() |
607 |
| - |
608 | 546 | def rec_idx_or_size_tuple(self, situp: tuple[IndexOrShapeExpr, ...],
|
609 | 547 | *args: Any, **kwargs: Any
|
610 | 548 | ) -> tuple[IndexOrShapeExpr, ...]:
|
|
0 commit comments