Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CacheInputs class to simplify passing of expr + optional args/kwargs/key #583

Merged
merged 5 commits into from
Mar 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pytato/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,9 +622,9 @@ def combine(self, *args: int) -> int:
return sum(args)

def rec(self, expr: ArrayOrNames) -> int:
key = self._cache.get_key(expr)
inputs = self._make_cache_inputs(expr)
try:
return self._cache.retrieve(expr, key=key)
return self._cache.retrieve(inputs)
except KeyError:
# Intentionally going to Mapper instead of super() to avoid
# double caching when subclasses of CachedMapper override rec,
Expand All @@ -639,7 +639,7 @@ def rec(self, expr: ArrayOrNames) -> int:
else:
result = 0 + s

self._cache.add(expr, 0, key=key)
self._cache.add(inputs, 0)
return result


Expand Down
4 changes: 2 additions & 2 deletions pytato/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,8 @@ def __init__(
self,
target: Target,
kernels_seen: dict[str, lp.LoopKernel] | None = None,
_cache: TransformMapperCache[ArrayOrNames] | None = None,
_function_cache: TransformMapperCache[FunctionDefinition] | None = None
_cache: TransformMapperCache[ArrayOrNames, []] | None = None,
_function_cache: TransformMapperCache[FunctionDefinition, []] | None = None
) -> None:
super().__init__(_cache=_cache, _function_cache=_function_cache)
self.bound_arguments: dict[str, DataInterface] = {}
Expand Down
10 changes: 5 additions & 5 deletions pytato/distributed/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,9 +240,9 @@ def __init__(self,
recvd_ary_to_name: Mapping[Array, str],
sptpo_ary_to_name: Mapping[Array, str],
name_to_output: Mapping[str, Array],
_cache: TransformMapperCache[ArrayOrNames] | None = None,
_cache: TransformMapperCache[ArrayOrNames, []] | None = None,
_function_cache:
TransformMapperCache[FunctionDefinition] | None = None,
TransformMapperCache[FunctionDefinition, []] | None = None,
) -> None:
super().__init__(_cache=_cache, _function_cache=_function_cache)

Expand All @@ -261,7 +261,7 @@ def clone_for_callee(
return type(self)(
{}, {}, {},
_function_cache=cast(
"TransformMapperCache[FunctionDefinition]", self._function_cache))
"TransformMapperCache[FunctionDefinition, []]", self._function_cache))

def map_placeholder(self, expr: Placeholder) -> Placeholder:
self.user_input_names.add(expr.name)
Expand Down Expand Up @@ -294,9 +294,9 @@ def map_distributed_send(self, expr: DistributedSend) -> DistributedSend:
return new_send

def rec(self, expr: ArrayOrNames) -> ArrayOrNames:
key = self._cache.get_key(expr)
inputs = self._make_cache_inputs(expr)
try:
return self._cache.retrieve(expr, key=key)
return self._cache.retrieve(inputs)
except KeyError:
pass

Expand Down
197 changes: 109 additions & 88 deletions pytato/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@

__doc__ = """
.. autoclass:: Mapper
.. autoclass:: CacheInputsWithKey
.. autoclass:: CachedMapperCache
.. autoclass:: CachedMapper
.. autoclass:: TransformMapperCache
Expand Down Expand Up @@ -304,77 +305,77 @@ def __call__(
CacheKeyT: TypeAlias = Hashable


class CachedMapperCache(Generic[CacheExprT, CacheResultT]):
class CacheInputsWithKey(Generic[CacheExprT, P]):
"""
Data structure for inputs to :class:`CachedMapperCache`.

.. attribute:: expr

The input expression being mapped.

.. attribute:: args

A :class:`tuple` of extra positional arguments.

.. attribute:: kwargs

A :class:`dict` of extra keyword arguments.

.. attribute:: key

The cache key corresponding to *expr* and any additional inputs that were
passed.

"""
def __init__(
self,
expr: CacheExprT,
key: CacheKeyT,
*args: P.args,
**kwargs: P.kwargs):
self.expr: CacheExprT = expr
self.args: tuple[Any, ...] = args
self.kwargs: dict[str, Any] = kwargs
self.key: CacheKeyT = key


class CachedMapperCache(Generic[CacheExprT, CacheResultT, P]):
"""
Cache for mappers.

.. automethod:: __init__
.. method:: get_key

Compute the key for an input expression.

.. automethod:: add
.. automethod:: retrieve
.. automethod:: clear
"""
def __init__(
self,
key_func: Callable[..., CacheKeyT]) -> None:
"""
Initialize the cache.

:arg key_func: Function to compute a hashable cache key from an input
expression and any extra arguments.
"""
self.get_key = key_func

self._expr_key_to_result: dict[CacheKeyT, CacheResultT] = {}
def __init__(self) -> None:
"""Initialize the cache."""
self._input_key_to_result: dict[CacheKeyT, CacheResultT] = {}

def add(
self,
key_inputs:
CacheExprT
# Currently, Python's type system doesn't have a way to annotate
# containers of args/kwargs (ParamSpec won't work here). So we have
# to fall back to using Any. More details here:
# https://github.com/python/typing/issues/1252
| tuple[CacheExprT, tuple[Any, ...], dict[str, Any]],
result: CacheResultT,
key: CacheKeyT | None = None) -> CacheResultT:
inputs: CacheInputsWithKey[CacheExprT, P],
result: CacheResultT) -> CacheResultT:
"""Cache a mapping result."""
if key is None:
if isinstance(key_inputs, tuple):
expr, key_args, key_kwargs = key_inputs
key = self.get_key(expr, *key_args, **key_kwargs)
else:
key = self.get_key(key_inputs)
key = inputs.key

assert key not in self._expr_key_to_result, \
assert key not in self._input_key_to_result, \
f"Cache entry is already present for key '{key}'."

self._expr_key_to_result[key] = result

self._input_key_to_result[key] = result
return result

def retrieve(
self,
key_inputs:
CacheExprT
| tuple[CacheExprT, tuple[Any, ...], dict[str, Any]],
key: CacheKeyT | None = None) -> CacheResultT:
def retrieve(self, inputs: CacheInputsWithKey[CacheExprT, P]) -> CacheResultT:
"""Retrieve the cached mapping result."""
if key is None:
if isinstance(key_inputs, tuple):
expr, key_args, key_kwargs = key_inputs
key = self.get_key(expr, *key_args, **key_kwargs)
else:
key = self.get_key(key_inputs)

return self._expr_key_to_result[key]
key = inputs.key
return self._input_key_to_result[key]

def clear(self) -> None:
"""Reset the cache."""
self._expr_key_to_result = {}
self._input_key_to_result = {}


class CachedMapper(Mapper[ResultT, FunctionResultT, P]):
Expand All @@ -389,58 +390,79 @@ class CachedMapper(Mapper[ResultT, FunctionResultT, P]):
def __init__(
self,
_cache:
CachedMapperCache[ArrayOrNames, ResultT] | None = None,
CachedMapperCache[ArrayOrNames, ResultT, P] | None = None,
_function_cache:
CachedMapperCache[FunctionDefinition, FunctionResultT] | None = None
CachedMapperCache[FunctionDefinition, FunctionResultT, P] | None = None
) -> None:
super().__init__()

self._cache: CachedMapperCache[ArrayOrNames, ResultT] = (
self._cache: CachedMapperCache[ArrayOrNames, ResultT, P] = (
_cache if _cache is not None
else CachedMapperCache(self.get_cache_key))
else CachedMapperCache())

self._function_cache: CachedMapperCache[
FunctionDefinition, FunctionResultT] = (
FunctionDefinition, FunctionResultT, P] = (
_function_cache if _function_cache is not None
else CachedMapperCache(self.get_function_definition_cache_key))
else CachedMapperCache())

def get_cache_key(
self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs
) -> Hashable:
return (expr, *args, tuple(sorted(kwargs.items())))
) -> CacheKeyT:
if args or kwargs:
# Depending on whether extra arguments are passed by position or by
# keyword, they can end up in either args or kwargs; hence key is not
# uniquely defined in the general case
raise NotImplementedError(
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Explain that cache could be different based on by-kw or by-position passing.

"Derived classes must override get_cache_key if using extra inputs.")
return expr

def get_function_definition_cache_key(
self, expr: FunctionDefinition, *args: P.args, **kwargs: P.kwargs
) -> Hashable:
return (expr, *args, tuple(sorted(kwargs.items())))
) -> CacheKeyT:
if args or kwargs:
# Depending on whether extra arguments are passed by position or by
# keyword, they can end up in either args or kwargs; hence key is not
# uniquely defined in the general case
raise NotImplementedError(
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Explain that cache could be different based on by-kw or by-position passing.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added.

"Derived classes must override get_function_definition_cache_key if "
"using extra inputs.")
return expr

def _make_cache_inputs(
self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs
) -> CacheInputsWithKey[ArrayOrNames, P]:
return CacheInputsWithKey(
expr, self.get_cache_key(expr, *args, **kwargs), *args, **kwargs)

def _make_function_definition_cache_inputs(
self, expr: FunctionDefinition, *args: P.args, **kwargs: P.kwargs
) -> CacheInputsWithKey[FunctionDefinition, P]:
return CacheInputsWithKey(
expr, self.get_function_definition_cache_key(expr, *args, **kwargs),
*args, **kwargs)

def rec(self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs) -> ResultT:
key = self._cache.get_key(expr, *args, **kwargs)
inputs = self._make_cache_inputs(expr, *args, **kwargs)
try:
return self._cache.retrieve((expr, args, kwargs), key=key)
return self._cache.retrieve(inputs)
except KeyError:
return self._cache.add(
(expr, args, kwargs),
# Intentionally going to Mapper instead of super() to avoid
# double caching when subclasses of CachedMapper override rec,
# see https://github.com/inducer/pytato/pull/585
Mapper.rec(self, expr, *args, **kwargs),
key=key)
# Intentionally going to Mapper instead of super() to avoid
# double caching when subclasses of CachedMapper override rec,
# see https://github.com/inducer/pytato/pull/585
return self._cache.add(inputs, Mapper.rec(self, expr, *args, **kwargs))

def rec_function_definition(
self, expr: FunctionDefinition, *args: P.args, **kwargs: P.kwargs
) -> FunctionResultT:
key = self._function_cache.get_key(expr, *args, **kwargs)
inputs = self._make_function_definition_cache_inputs(expr, *args, **kwargs)
try:
return self._function_cache.retrieve((expr, args, kwargs), key=key)
return self._function_cache.retrieve(inputs)
except KeyError:
return self._function_cache.add(
(expr, args, kwargs),
# Intentionally going to Mapper instead of super() to avoid
# double caching when subclasses of CachedMapper override rec,
# see https://github.com/inducer/pytato/pull/585
Mapper.rec_function_definition(self, expr, *args, **kwargs),
key=key)
inputs, Mapper.rec_function_definition(self, expr, *args, **kwargs))

def clone_for_callee(
self, function: FunctionDefinition) -> Self:
Expand All @@ -456,7 +478,7 @@ def clone_for_callee(

# {{{ TransformMapper

class TransformMapperCache(CachedMapperCache[CacheExprT, CacheExprT]):
class TransformMapperCache(CachedMapperCache[CacheExprT, CacheExprT, P]):
pass


Expand All @@ -470,8 +492,8 @@ class TransformMapper(CachedMapper[ArrayOrNames, FunctionDefinition, []]):
"""
def __init__(
self,
_cache: TransformMapperCache[ArrayOrNames] | None = None,
_function_cache: TransformMapperCache[FunctionDefinition] | None = None
_cache: TransformMapperCache[ArrayOrNames, []] | None = None,
_function_cache: TransformMapperCache[FunctionDefinition, []] | None = None
) -> None:
super().__init__(_cache=_cache, _function_cache=_function_cache)

Expand All @@ -492,9 +514,9 @@ class TransformMapperWithExtraArgs(
"""
def __init__(
self,
_cache: TransformMapperCache[ArrayOrNames] | None = None,
_cache: TransformMapperCache[ArrayOrNames, P] | None = None,
_function_cache:
TransformMapperCache[FunctionDefinition] | None = None
TransformMapperCache[FunctionDefinition, P] | None = None
) -> None:
super().__init__(_cache=_cache, _function_cache=_function_cache)

Expand Down Expand Up @@ -1522,8 +1544,8 @@ class CachedMapAndCopyMapper(CopyMapper):
def __init__(
self,
map_fn: Callable[[ArrayOrNames], ArrayOrNames],
_cache: TransformMapperCache[ArrayOrNames] | None = None,
_function_cache: TransformMapperCache[FunctionDefinition] | None = None
_cache: TransformMapperCache[ArrayOrNames, []] | None = None,
_function_cache: TransformMapperCache[FunctionDefinition, []] | None = None
) -> None:
super().__init__(_cache=_cache, _function_cache=_function_cache)
self.map_fn: Callable[[ArrayOrNames], ArrayOrNames] = map_fn
Expand All @@ -1533,18 +1555,17 @@ def clone_for_callee(
return type(self)(
self.map_fn,
_function_cache=cast(
"TransformMapperCache[FunctionDefinition]", self._function_cache))
"TransformMapperCache[FunctionDefinition, []]", self._function_cache))

def rec(self, expr: ArrayOrNames) -> ArrayOrNames:
key = self._cache.get_key(expr)
inputs = self._make_cache_inputs(expr)
try:
return self._cache.retrieve(expr, key=key)
return self._cache.retrieve(inputs)
except KeyError:
return self._cache.add(
# Intentionally going to Mapper instead of super() to avoid
# double caching when subclasses of CachedMapper override rec,
# see https://github.com/inducer/pytato/pull/585
expr, Mapper.rec(self, self.map_fn(expr)), key=key)
# Intentionally going to Mapper instead of super() to avoid
# double caching when subclasses of CachedMapper override rec,
# see https://github.com/inducer/pytato/pull/585
return self._cache.add(inputs, Mapper.rec(self, self.map_fn(expr)))

# }}}

Expand Down Expand Up @@ -2069,8 +2090,8 @@ class DataWrapperDeduplicator(CopyMapper):
"""
def __init__(
self,
_cache: TransformMapperCache[ArrayOrNames] | None = None,
_function_cache: TransformMapperCache[FunctionDefinition] | None = None
_cache: TransformMapperCache[ArrayOrNames, []] | None = None,
_function_cache: TransformMapperCache[FunctionDefinition, []] | None = None
) -> None:
super().__init__(_cache=_cache, _function_cache=_function_cache)
self.data_wrapper_cache: dict[CacheKeyT, DataWrapper] = {}
Expand Down
Loading
Loading