Skip to content

Add duplication checks #550

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 26 additions & 5 deletions pytato/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,11 @@ def is_einsum_similar_to_subscript(expr: Einsum, subscripts: str) -> bool:

# {{{ DirectPredecessorsGetter

class DirectPredecessorsGetter(Mapper[frozenset[ArrayOrNames], Never, []]):
class DirectPredecessorsGetter(
Mapper[
FrozenOrderedSet[ArrayOrNames | FunctionDefinition],
FrozenOrderedSet[ArrayOrNames],
[]]):
"""
Mapper to get the
`direct predecessors
Expand All @@ -334,9 +338,17 @@ class DirectPredecessorsGetter(Mapper[frozenset[ArrayOrNames], Never, []]):

We only consider the predecessors of a nodes in a data-flow sense.
"""
def __init__(self, *, include_functions: bool = False) -> None:
super().__init__()
self.include_functions = include_functions

def _get_preds_from_shape(self, shape: ShapeType) -> FrozenOrderedSet[ArrayOrNames]:
return FrozenOrderedSet(dim for dim in shape if isinstance(dim, Array))

def map_dict_of_named_arrays(
self, expr: DictOfNamedArrays) -> FrozenOrderedSet[ArrayOrNames]:
return FrozenOrderedSet(expr._data.values())

def map_index_lambda(self, expr: IndexLambda) -> FrozenOrderedSet[ArrayOrNames]:
return (FrozenOrderedSet(expr.bindings.values())
| self._get_preds_from_shape(expr.shape))
Expand Down Expand Up @@ -397,8 +409,17 @@ def map_distributed_send_ref_holder(self,
) -> FrozenOrderedSet[ArrayOrNames]:
return FrozenOrderedSet([expr.passthrough_data])

def map_call(self, expr: Call) -> FrozenOrderedSet[ArrayOrNames]:
return FrozenOrderedSet(expr.bindings.values())
def map_call(
self, expr: Call) -> FrozenOrderedSet[ArrayOrNames | FunctionDefinition]:
result: FrozenOrderedSet[ArrayOrNames | FunctionDefinition] = \
FrozenOrderedSet(expr.bindings.values())
if self.include_functions:
result = result | FrozenOrderedSet([expr.function])
return result

def map_function_definition(
self, expr: FunctionDefinition) -> FrozenOrderedSet[ArrayOrNames]:
return FrozenOrderedSet(expr.returns.values())

def map_named_call_result(
self, expr: NamedCallResult) -> FrozenOrderedSet[ArrayOrNames]:
Expand Down Expand Up @@ -624,7 +645,7 @@ def combine(self, *args: int) -> int:
def rec(self, expr: ArrayOrNames) -> int:
inputs = self._make_cache_inputs(expr)
try:
return self._cache.retrieve(inputs)
return self._cache_retrieve(inputs)
except KeyError:
# Intentionally going to Mapper instead of super() to avoid
# double caching when subclasses of CachedMapper override rec,
Expand All @@ -639,7 +660,7 @@ def rec(self, expr: ArrayOrNames) -> int:
else:
result = 0 + s

self._cache.add(inputs, 0)
self._cache_add(inputs, 0)
return result


Expand Down
2 changes: 1 addition & 1 deletion pytato/distributed/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ def map_distributed_send(self, expr: DistributedSend) -> DistributedSend:
def rec(self, expr: ArrayOrNames) -> ArrayOrNames:
inputs = self._make_cache_inputs(expr)
try:
return self._cache.retrieve(inputs)
return self._cache_retrieve(inputs)
except KeyError:
pass

Expand Down
6 changes: 4 additions & 2 deletions pytato/loopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@

import loopy as lp
import pymbolic.primitives as prim
from pymbolic.typing import ArithmeticExpression, Expression, Integer, not_none
from loopy.typing import assert_tuple
from pytools import memoize_method

from pytato.array import (
Expand All @@ -61,6 +61,8 @@
if TYPE_CHECKING:
from collections.abc import Iterable, Iterator, Mapping, Sequence

from pymbolic.typing import ArithmeticExpression, Expression, Integer


__doc__ = r"""
.. currentmodule:: pytato.loopy
Expand Down Expand Up @@ -423,7 +425,7 @@ def extend_bindings_with_shape_inference(knl: lp.LoopKernel,
get_size_param_deps = SizeParamGatherer()

lp_size_params: frozenset[str] = reduce(frozenset.union,
(lpy_get_deps(not_none(arg.shape))
(lpy_get_deps(assert_tuple(arg.shape))
for arg in knl.args
if isinstance(arg, ArrayBase)
and is_expression(arg.shape)
Expand Down
Loading
Loading