Skip to content

Commit 84c279c

Browse files
committed
make CopyMapper/CopyMapperWithExtraArgs inherit from TransformMapper/TransformMapperWithExtraArgs
1 parent de083a5 commit 84c279c

File tree

1 file changed

+2
-64
lines changed

1 file changed

+2
-64
lines changed

pytato/transform/__init__.py

+2-64
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,6 @@
9191
CombineT = TypeVar("CombineT") # used in CombineMapper
9292
TransformMapperResultT = TypeVar("TransformMapperResultT", # used in TransformMapper
9393
Array, AbstractResultWithNamedArrays, ArrayOrNames)
94-
CopyMapperResultT = TypeVar("CopyMapperResultT", # used in CopyMapper
95-
Array, AbstractResultWithNamedArrays, ArrayOrNames)
9694
CachedMapperT = TypeVar("CachedMapperT") # used in CachedMapper
9795
IndexOrShapeExpr = TypeVar("IndexOrShapeExpr")
9896
R = FrozenSet[Array]
@@ -344,32 +342,15 @@ def clone_for_callee(
344342

345343
# {{{ CopyMapper
346344

347-
class CopyMapper(CachedMapper[ArrayOrNames]):
345+
class CopyMapper(TransformMapper):
348346
"""Performs a deep copy of a :class:`pytato.array.Array`.
349347
The typical use of this mapper is to override individual ``map_`` methods
350348
in subclasses to permit term rewriting on an expression graph.
351349
352-
.. automethod:: clone_for_callee
353-
354350
.. note::
355351
356352
This does not copy the data of a :class:`pytato.array.DataWrapper`.
357353
"""
358-
if TYPE_CHECKING:
359-
def rec(self, expr: CopyMapperResultT) -> CopyMapperResultT:
360-
return cast(CopyMapperResultT, super().rec(expr))
361-
362-
def __call__(self, expr: CopyMapperResultT) -> CopyMapperResultT:
363-
return self.rec(expr)
364-
365-
def clone_for_callee(
366-
self: _SelfMapper, function: FunctionDefinition) -> _SelfMapper:
367-
"""
368-
Called to clone *self* before starting traversal of a
369-
:class:`pytato.function.FunctionDefinition`.
370-
"""
371-
return type(self)()
372-
373354
def rec_idx_or_size_tuple(self, situp: tuple[IndexOrShapeExpr, ...]
374355
) -> tuple[IndexOrShapeExpr, ...]:
375356
# type-ignore-reason: apparently mypy cannot substitute typevars
@@ -554,57 +535,14 @@ def map_named_call_result(self, expr: NamedCallResult) -> Array:
554535
return call[expr.name]
555536

556537

557-
class CopyMapperWithExtraArgs(CachedMapper[ArrayOrNames]):
538+
class CopyMapperWithExtraArgs(TransformMapperWithExtraArgs):
558539
"""
559540
Similar to :class:`CopyMapper`, but each mapper method takes extra
560541
``*args``, ``**kwargs`` that are propagated along a path by default.
561542
562543
The logic in :class:`CopyMapper` purposely does not take the extra
563544
arguments to keep the cost of its each call frame low.
564-
565-
.. automethod:: clone_for_callee
566545
"""
567-
def __init__(self) -> None:
568-
super().__init__()
569-
# type-ignored as '._cache' attribute is not coherent with the base
570-
# class
571-
self._cache: dict[tuple[ArrayOrNames,
572-
tuple[Any, ...],
573-
tuple[tuple[str, Any], ...]
574-
],
575-
ArrayOrNames] = {} # type: ignore[assignment]
576-
577-
def get_cache_key(self,
578-
expr: ArrayOrNames,
579-
*args: Any, **kwargs: Any) -> tuple[ArrayOrNames,
580-
tuple[Any, ...],
581-
tuple[tuple[str, Any], ...]
582-
]:
583-
return (expr, args, tuple(sorted(kwargs.items())))
584-
585-
def rec(self,
586-
expr: CopyMapperResultT,
587-
*args: Any, **kwargs: Any) -> CopyMapperResultT:
588-
key = self.get_cache_key(expr, *args, **kwargs)
589-
try:
590-
# type-ignore-reason: self._cache has ArrayOrNames as its values
591-
return self._cache[key] # type: ignore[return-value]
592-
except KeyError:
593-
result = Mapper.rec(self, expr,
594-
*args,
595-
**kwargs)
596-
self._cache[key] = result
597-
# type-ignore-reason: Mapper.rec is imprecise
598-
return result # type: ignore[no-any-return]
599-
600-
def clone_for_callee(
601-
self: _SelfMapper, function: FunctionDefinition) -> _SelfMapper:
602-
"""
603-
Called to clone *self* before starting traversal of a
604-
:class:`pytato.function.FunctionDefinition`.
605-
"""
606-
return type(self)()
607-
608546
def rec_idx_or_size_tuple(self, situp: tuple[IndexOrShapeExpr, ...],
609547
*args: Any, **kwargs: Any
610548
) -> tuple[IndexOrShapeExpr, ...]:

0 commit comments

Comments
 (0)