Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use a class for CachedMapper-derived mappers instead of a dict #549

Merged
merged 5 commits into from
Jan 28, 2025
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
make CombineMapper inherit from CachedMapper
majosm committed Jan 24, 2025
commit 1e30fd131a56f7e8a448db99aba0c6c2cefdf531
33 changes: 6 additions & 27 deletions pytato/transform/__init__.py
Original file line number Diff line number Diff line change
@@ -768,44 +768,23 @@ def map_named_call_result(self, expr: NamedCallResult,

# {{{ CombineMapper

class CombineMapper(Mapper[ResultT, FunctionResultT, []]):
class CombineMapper(CachedMapper[ResultT, FunctionResultT, []]):
"""
Abstract mapper that recursively combines the results of user nodes
of a given expression.
.. automethod:: combine
"""
def __init__(
self,
_function_cache: dict[FunctionDefinition, FunctionResultT] | None = None
) -> None:
super().__init__()
self.cache: dict[ArrayOrNames, ResultT] = {}
self.function_cache: dict[FunctionDefinition, FunctionResultT] = \
_function_cache if _function_cache is not None else {}
def get_cache_key(self, expr: ArrayOrNames) -> Hashable:
return expr

def get_function_definition_cache_key(self, expr: FunctionDefinition) -> Hashable:
return expr

def rec_idx_or_size_tuple(self, situp: tuple[IndexOrShapeExpr, ...]
) -> tuple[ResultT, ...]:
return tuple(self.rec(s) for s in situp if isinstance(s, Array))

def rec(self, expr: ArrayOrNames) -> ResultT:
if expr in self.cache:
return self.cache[expr]
result: ResultT = super().rec(expr)
self.cache[expr] = result
return result

def rec_function_definition(
self, expr: FunctionDefinition) -> FunctionResultT:
if expr in self.function_cache:
return self.function_cache[expr]
result: FunctionResultT = super().rec_function_definition(expr)
self.function_cache[expr] = result
return result

def __call__(self, expr: ArrayOrNames) -> ResultT:
return self.rec(expr)

def combine(self, *args: ResultT) -> ResultT:
"""Combine the arguments."""
raise NotImplementedError