From a2f9ee103387ada74749af049648f81de19f34b0 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 2 Jul 2025 09:13:08 -0500 Subject: [PATCH 01/10] More precisely type actx.compile --- arraycontext/context.py | 4 ++- arraycontext/impl/pytato/__init__.py | 39 +++++++++++++++++++++++++--- arraycontext/impl/pytato/compile.py | 5 +++- 3 files changed, 42 insertions(+), 6 deletions(-) diff --git a/arraycontext/context.py b/arraycontext/context.py index b1e44d0d..439c0656 100644 --- a/arraycontext/context.py +++ b/arraycontext/context.py @@ -638,7 +638,9 @@ def clone(self) -> Self: "setup-only" array context "leaks" into the application. """ - def compile(self, f: Callable[..., Any]) -> Callable[..., Any]: + def compile(self, + f: Callable[P, ArrayOrArithContainerOrScalarT] + ) -> Callable[P, ArrayOrArithContainerOrScalarT]: """Compiles *f* for repeated use on this array context. *f* is expected to be a `pure function `__ performing an array computation. diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 35bbdcc0..3654c215 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -69,6 +69,7 @@ from arraycontext.context import ( Array, ArrayContext, + ArrayOrArithContainerOrScalarT, ArrayOrContainerOrScalarT, ArrayOrContainerT, ArrayOrScalar, @@ -781,9 +782,24 @@ def call_loopy(self, program, **kwargs): return call_loopy(program, processed_kwargs, entrypoint) - def compile(self, f: Callable[..., Any]) -> Callable[..., Any]: + @override + def compile(self, + f: Callable[P, ArrayOrArithContainerOrScalarT] + ) -> Callable[P, ArrayOrArithContainerOrScalarT]: + # FIXME Ideally, the ParamSpec P should be bounded by ArrayOrContainerOrScalar, + # but this is not currently possible: + # https://github.com/python/typing/issues/1027 + + # FIXME An aspect of this that's a bit of a lie is that the types + # coming out of the outlined function are not guaranteed to be the same + # as the ones that the un-outlined function would return. That said, + # if f is written only in terms of the array context types (Array, ScalarLike, + # containers), this is close enough to being true that I'm willing + # to take responsibility. -AK, 2025-06-30 + from .compile import LazilyPyOpenCLCompilingFunctionCaller - return LazilyPyOpenCLCompilingFunctionCaller(self, f) + return cast("Callable[P, ArrayOrArithContainerOrScalarT]", + LazilyPyOpenCLCompilingFunctionCaller(self, f)) def transform_dag(self, dag: pytato.AbstractResultWithNamedArrays ) -> pytato.AbstractResultWithNamedArrays: @@ -983,9 +999,24 @@ def _thaw(ary: jnp.ndarray) -> pt.Array: self._rec_map_container(_thaw, array, self._frozen_array_types), actx=self) - def compile(self, f: Callable[..., Any]) -> Callable[..., Any]: + @override + def compile(self, + f: Callable[P, ArrayOrArithContainerOrScalarT] + ) -> Callable[P, ArrayOrArithContainerOrScalarT]: + # FIXME Ideally, the ParamSpec P should be bounded by ArrayOrContainerOrScalar, + # but this is not currently possible: + # https://github.com/python/typing/issues/1027 + + # FIXME An aspect of this that's a bit of a lie is that the types + # coming out of the outlined function are not guaranteed to be the same + # as the ones that the un-outlined function would return. That said, + # if f is written only in terms of the array context types (Array, ScalarLike, + # containers), this is close enough to being true that I'm willing + # to take responsibility. -AK, 2025-06-30 + from .compile import LazilyJAXCompilingFunctionCaller - return LazilyJAXCompilingFunctionCaller(self, f) + return cast("Callable[P, ArrayOrArithContainerOrScalarT]", + LazilyJAXCompilingFunctionCaller(self, f)) @override def transform_dag(self, dag: pytato.AbstractResultWithNamedArrays diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index 0b2cd715..ee3b04bf 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -315,7 +315,10 @@ def _dag_to_compiled_func(self, ary_or_dict_of_named_arrays, else: raise NotImplementedError(type(ary_or_dict_of_named_arrays)) - def __call__(self, *args: Any, **kwargs: Any) -> Any: + def __call__(self, + *args: ArrayOrContainerOrScalar, + **kwargs: ArrayOrContainerOrScalar + ) -> ArrayOrContainerOrScalar: """ Returns the result of :attr:`~BaseLazilyCompilingFunctionCaller.f`'s function application on *args*. From a8a48ed49d9a55f3cfe91ed55688868fe4cd6bf1 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 2 Jul 2025 13:58:02 -0500 Subject: [PATCH 02/10] Make use of pytools ObjectArray types where appropriate --- arraycontext/container/__init__.py | 29 +++++++++++++-------- arraycontext/container/arithmetic.py | 23 ++++++++--------- arraycontext/container/dataclass.py | 11 ++++++++ arraycontext/container/traversal.py | 38 ++++++++++++++++++---------- arraycontext/context.py | 4 +-- arraycontext/impl/numpy/__init__.py | 16 ++++++------ examples/how_to_outline.py | 4 +-- pyproject.toml | 13 +++++++--- test/test_arraycontext.py | 24 ++++++++++++------ test/testlib.py | 13 +++++----- 10 files changed, 109 insertions(+), 66 deletions(-) diff --git a/arraycontext/container/__init__.py b/arraycontext/container/__init__.py index 43ae0560..844b599f 100644 --- a/arraycontext/container/__init__.py +++ b/arraycontext/container/__init__.py @@ -95,12 +95,6 @@ from __future__ import annotations -from types import GenericAlias, UnionType - -from numpy.typing import NDArray - -from arraycontext.context import ArrayOrArithContainer, ArrayOrContainerOrScalar - __copyright__ = """ Copyright (C) 2020-1 University of Illinois Board of Trustees @@ -128,9 +122,9 @@ from collections.abc import Hashable, Sequence from functools import singledispatch +from types import GenericAlias, UnionType from typing import ( TYPE_CHECKING, - Any, ClassVar, Protocol, TypeAlias, @@ -144,6 +138,14 @@ import numpy as np from typing_extensions import Self, TypeIs +from pytools.obj_array import ObjectArrayND + +from arraycontext.context import ( + ArrayOrArithContainer, + ArrayOrArithContainerOrScalar, + ArrayOrContainerOrScalar, +) + if TYPE_CHECKING: from pymbolic.geometric_algebra import CoeffT, MultiVector @@ -163,7 +165,10 @@ class _UserDefinedArrayContainer(Protocol): __array_ufunc__: ClassVar[None] -ArrayContainer: TypeAlias = NDArray[Any] | _UserDefinedArrayContainer +ArrayContainer: TypeAlias = ( + ObjectArrayND[ArrayOrContainerOrScalar] + | _UserDefinedArrayContainer + ) class _UserDefinedArithArrayContainer(_UserDefinedArrayContainer, Protocol): @@ -187,7 +192,9 @@ def __pow__(self, other: ArrayOrScalar | Self) -> Self: ... def __rpow__(self, other: ArrayOrScalar | Self) -> Self: ... -ArithArrayContainer: TypeAlias = NDArray[Any] | _UserDefinedArithArrayContainer +ArithArrayContainer: TypeAlias = ( + ObjectArrayND[ArrayOrArithContainerOrScalar] + | _UserDefinedArithArrayContainer) ArrayContainerT = TypeVar("ArrayContainerT", bound=ArrayContainer) @@ -307,9 +314,9 @@ def get_container_context_opt(ary: ArrayContainer) -> ArrayContext | None: # {{{ object arrays as array containers +# Sadly, ObjectArray is not usable here. @serialize_container.register(np.ndarray) -def _serialize_ndarray_container( - ary: numpy.ndarray) -> SerializedContainer: +def _serialize_ndarray_container(ary: numpy.ndarray) -> SerializedContainer: if ary.dtype.char != "O": raise NotAnArrayContainerError( f"cannot serialize '{type(ary).__name__}' with dtype '{ary.dtype}'") diff --git a/arraycontext/container/arithmetic.py b/arraycontext/container/arithmetic.py index 8d230813..3becbfc3 100644 --- a/arraycontext/container/arithmetic.py +++ b/arraycontext/container/arithmetic.py @@ -50,6 +50,13 @@ from warnings import warn import numpy as np +from typing_extensions import override + +from pytools.obj_array import ( + ObjectArray, + # for backward compatibility + ObjectArray as NumpyObjectArray, # noqa: F401 # pyright: ignore[reportUnusedImport] +) from arraycontext.container import ( NotAnArrayContainerError, @@ -152,17 +159,9 @@ def _format_binary_op_str(op_str: str, return op_str.format(arg1, arg2) -class NumpyObjectArrayMetaclass(type): - def __instancecheck__(cls, instance: Any) -> bool: - return isinstance(instance, np.ndarray) and instance.dtype == object - - -class NumpyObjectArray(metaclass=NumpyObjectArrayMetaclass): - pass - - class ComplainingNumpyNonObjectArrayMetaclass(type): - def __instancecheck__(cls, instance: Any) -> bool: + @override + def __instancecheck__(cls, instance: object) -> bool: if isinstance(instance, np.ndarray) and instance.dtype != object: # Example usage site: # https://github.com/illinois-ceesd/mirgecom/blob/f5d0d97c41e8c8a05546b1d1a6a2979ec8ea3554/mirgecom/inviscid.py#L148-L149 @@ -272,7 +271,7 @@ def _deserialize_init_arrays_code(cls, tmpl_instance_name, args): # - Anything that special-cases np.ndarray by type is broken by design because: # - np.ndarray is an array context array. # - numpy object arrays can be array containers. - # Using NumpyObjectArray and NumpyNonObjectArray *may* be better? + # Using ObjectArray and NumpyNonObjectArray *may* be better? # They're new, so there is no operational experience with them. # # - Broadcast rules are hard to change once established, particularly @@ -374,7 +373,7 @@ def numpy_pred(name: str) -> str: raise ValueError("If numpy.ndarray is part of bcast_container_types, " "bcast_obj_array must be False.") - numpy_check_types: list[type] = [NumpyObjectArray, ComplainingNumpyNonObjectArray] + numpy_check_types: list[type] = [ObjectArray, ComplainingNumpyNonObjectArray] container_types_bcast_across = tuple( new_ct for old_ct in container_types_bcast_across diff --git a/arraycontext/container/dataclass.py b/arraycontext/container/dataclass.py index f36905e6..05210ba7 100644 --- a/arraycontext/container/dataclass.py +++ b/arraycontext/container/dataclass.py @@ -12,6 +12,10 @@ """ from __future__ import annotations +from pytools.obj_array import ObjectArray + +from arraycontext.context import ArrayOrContainer, ArrayOrContainerOrScalar + __copyright__ = """ Copyright (C) 2020-1 University of Illinois Board of Trustees @@ -133,9 +137,16 @@ def is_array_field(f: _Field) -> bool: # pyright has no idea what we're up to. :) if field_type is ArrayContainer: # pyright: ignore[reportUnnecessaryComparison] return True + if field_type is ArrayOrContainer: # pyright: ignore[reportUnnecessaryComparison] + return True + if field_type is ArrayOrContainerOrScalar: # pyright: ignore[reportUnnecessaryComparison] + return True origin = get_origin(field_type) + if origin is ObjectArray: + return True + # NOTE: `UnionType` is returned when using `Type1 | Type2` if origin in (Union, UnionType): # pyright: ignore[reportDeprecated] for arg in get_args(field_type): # pyright: ignore[reportAny] diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index 7e419d10..f9646005 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -71,12 +71,18 @@ """ from functools import partial, singledispatch, update_wrapper -from typing import TYPE_CHECKING, Any, cast, overload +from typing import TYPE_CHECKING, Any, ClassVar, Protocol, cast, overload from warnings import warn import numpy as np from typing_extensions import deprecated +from pytools.obj_array import ( + ObjectArray, + from_numpy as obj_array_from_numpy, + to_numpy as obj_array_to_numpy, +) + from arraycontext.container import ( ArrayContainer, ArrayContainerT, @@ -87,12 +93,11 @@ is_array_container, serialize_container, ) -from arraycontext.container.arithmetic import NumpyObjectArray from arraycontext.context import is_scalar_like, shape_is_int_only if TYPE_CHECKING: - from collections.abc import Callable, Iterable + from collections.abc import Callable, Collection, Iterable from arraycontext.context import ( Array, @@ -102,7 +107,6 @@ ArrayOrContainerOrScalarT, ArrayOrContainerT, ArrayOrScalar, - ScalarLike, ) @@ -1089,6 +1093,8 @@ def _flat_size(subary: ArrayOrContainerOrScalar) -> Array | int | np.integer: # }}} +class _HasOuterBcastTypes(Protocol): + _outer_bcast_types: ClassVar[Collection[type]] # {{{ numpy conversion @@ -1124,7 +1130,10 @@ def to_numpy(ary: ArrayOrContainer, actx: ArrayContext) -> ArrayOrContainer: # {{{ algebraic operations -def outer(a: Any, b: Any) -> Any: +def outer( + a: ArrayOrContainerOrScalar, + b: ArrayOrContainerOrScalar + ) -> ArrayOrContainerOrScalar: """ Compute the outer product of *a* and *b* while allowing either of them to be an :class:`ArrayContainer`. @@ -1139,7 +1148,7 @@ def outer(a: Any, b: Any) -> Any: have the same type. """ - def treat_as_scalar(x: Any) -> bool: + def treat_as_scalar(x: ArrayOrContainerOrScalar) -> bool: try: serialize_container(x) except NotAnArrayContainerError: @@ -1148,20 +1157,23 @@ def treat_as_scalar(x: Any) -> bool: return ( not isinstance(x, np.ndarray) # This condition is whether "ndarrays should broadcast inside x". - and NumpyObjectArray not in x.__class__._outer_bcast_types) + and ObjectArray not in cast( + "type[_HasOuterBcastTypes]", x.__class__)._outer_bcast_types) - a_is_ndarray = isinstance(a, np.ndarray) - b_is_ndarray = isinstance(b, np.ndarray) + a_is_ndarray = isinstance(a, ObjectArray) + b_is_ndarray = isinstance(b, ObjectArray) - if a_is_ndarray and a.dtype != object: + if isinstance(a, np.ndarray) and a.dtype != object: raise TypeError("passing a non-object numpy array is not allowed") - if b_is_ndarray and b.dtype != object: + if isinstance(b, np.ndarray) and b.dtype != object: raise TypeError("passing a non-object numpy array is not allowed") if treat_as_scalar(a) or treat_as_scalar(b): - return a*b + return a*b # pyright: ignore[reportOperatorIssue,reportReturnType] elif a_is_ndarray and b_is_ndarray: - return np.outer(a, b) + return obj_array_from_numpy(np.outer( + obj_array_to_numpy(a), + obj_array_to_numpy(b))) elif a_is_ndarray or b_is_ndarray: return map_array_container(lambda x: outer(x, b), a) else: diff --git a/arraycontext/context.py b/arraycontext/context.py index 439c0656..c57d08a4 100644 --- a/arraycontext/context.py +++ b/arraycontext/context.py @@ -427,9 +427,7 @@ def zeros(self, return self.np.zeros(shape, dtype) @overload - # FIXME: object arrays are containers, so pyright has a point. - # Maybe introduce a separate (type-check-only) NumpyObjectArray type? - def from_numpy(self, array: np.ndarray) -> Array: # pyright: ignore[reportOverlappingOverload] + def from_numpy(self, array: np.ndarray) -> Array: ... @overload diff --git a/arraycontext/impl/numpy/__init__.py b/arraycontext/impl/numpy/__init__.py index e1736a41..fd05a75f 100644 --- a/arraycontext/impl/numpy/__init__.py +++ b/arraycontext/impl/numpy/__init__.py @@ -8,8 +8,6 @@ from __future__ import annotations -from typing_extensions import override - __copyright__ = """ Copyright (C) 2021 University of Illinois Board of Trustees @@ -38,6 +36,7 @@ from typing import TYPE_CHECKING, Any, cast, overload import numpy as np +from typing_extensions import override import loopy as lp @@ -54,6 +53,7 @@ ContainerOrScalarT, NumpyOrContainerOrScalar, UntransformedCodeWarning, + is_scalar_like, ) @@ -61,6 +61,8 @@ from pymbolic import Scalar from pytools.tag import ToTagSetConvertible + from arraycontext.container import ArrayContainerT + class NumpyNonObjectArrayMetaclass(type): def __instancecheck__(cls, instance: Any) -> bool: @@ -97,9 +99,7 @@ def clone(self): return type(self)() @overload - # FIXME: object arrays are containers, so pyright has a point. - # Maybe introduce a separate (type-check-only) NumpyObjectArray type? - def from_numpy(self, array: np.ndarray) -> Array: # pyright: ignore[reportOverlappingOverload] + def from_numpy(self, array: np.ndarray) -> Array: ... @overload @@ -107,15 +107,15 @@ def from_numpy(self, array: Scalar) -> Array: ... @overload - def from_numpy(self, array: ContainerOrScalarT) -> ContainerOrScalarT: + def from_numpy(self, array: ArrayContainerT) -> ArrayContainerT: ... @override def from_numpy(self, array: NumpyOrContainerOrScalar ) -> ArrayOrContainerOrScalar: - if np.isscalar(array): - return np.array(array) + if isinstance(array, np.ndarray) or is_scalar_like(array): + return cast("Array", cast("object", np.array(array))) return array @overload diff --git a/examples/how_to_outline.py b/examples/how_to_outline.py index 3564fb44..5592b741 100644 --- a/examples/how_to_outline.py +++ b/examples/how_to_outline.py @@ -7,7 +7,7 @@ from typing_extensions import override import pytato as pt -from pytools.obj_array import make_obj_array +from pytools.obj_array import ObjectArray1D, make_obj_array from arraycontext import ( Array, @@ -67,7 +67,7 @@ def transform_dag(self, @dc.dataclass(frozen=True) class State: mass: Array | np.ndarray - vel: np.ndarray # np array of Arrays or numpy arrays + vel: ObjectArray1D[Array] @actx.outline diff --git a/pyproject.toml b/pyproject.toml index bbb8f44e..1994af6a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,9 +29,9 @@ classifiers = [ dependencies = [ "immutabledict>=4.1", "numpy", - "pytools>=2024.1.3", - # for Self - "typing_extensions>=4", + "pytools>=2025.2", + # for TypeIs + "typing_extensions>=4.10", ] [project.optional-dependencies] @@ -127,6 +127,9 @@ extend-ignore-re = [ "(?Rm)^.*(#|//)\\s*spellchecker:\\s*disable-line$" ] +[tool.typos.default.extend-words] +"nd" = "nd" + [tool.basedpyright] reportImplicitStringConcatenation = "none" reportUnnecessaryIsInstance = "none" @@ -159,6 +162,10 @@ reportIndexIssue = "hint" reportOperatorIssue = "hint" reportAttributeAccessIssue = "hint" +# so much numpy-or-not abuse in the tests *facepalm* +reportCallIssue = "hint" +reportArgumentType = "hint" + [[tool.basedpyright.executionEnvironments]] root = "examples" reportUnknownArgumentType = "hint" diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index acde4212..b1ea3572 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -31,7 +31,7 @@ import numpy as np import pytest -from pytools.obj_array import make_obj_array +from pytools.obj_array import ObjectArray1D, make_obj_array from pytools.tag import Tag from arraycontext import ( @@ -1194,7 +1194,10 @@ def outlined_scale_and_orthogonalize(alpha: float, vel: Velocity2D) -> Velocity2 return scale_and_orthogonalize(alpha, vel) def multi_scale_and_orthogonalize( - alpha: float, vel1: Velocity2D, vel2: Velocity2D) -> np.ndarray: + alpha: float, + vel1: Velocity2D, + vel2: Velocity2D + ) -> ObjectArray1D[Velocity2D]: return make_obj_array([ outlined_scale_and_orthogonalize(alpha, vel1), outlined_scale_and_orthogonalize(alpha, vel2)]) @@ -1206,8 +1209,13 @@ def multi_scale_and_orthogonalize( v2_x = rng.uniform(size=10) v2_y = rng.uniform(size=10) - vel1 = actx.from_numpy(Velocity2D(v1_x, v1_y, actx)) - vel2 = actx.from_numpy(Velocity2D(v2_x, v2_y, actx)) + v1_x_actx = actx.from_numpy(v1_x) + v1_y_actx = actx.from_numpy(v1_y) + v2_x_actx = actx.from_numpy(v2_x) + v2_y_actx = actx.from_numpy(v2_y) + + vel1 = Velocity2D(v1_x_actx, v1_y_actx, actx) + vel2 = Velocity2D(v2_x_actx, v2_y_actx, actx) scaled_speed1, scaled_speed2 = compiled_rhs(np.float64(3.14), vel1, vel2) @@ -1233,9 +1241,9 @@ def test_container_equality(actx_factory: ArrayContextFactory): # MyContainer sets eq_comparison to False, so equality comparison should # not succeed. - # type-ignore because pyright is right and I'm sorry. - dc = MyContainer(name="yoink", mass=ary_dof, momentum=None, enthalpy=None) # pyright: ignore[reportArgumentType] - dc2 = MyContainer(name="yoink", mass=ary_dof, momentum=None, enthalpy=None) # pyright: ignore[reportArgumentType] + # (formerly) type-ignored because pyright is right and I'm sorry. + dc = MyContainer(name="yoink", mass=ary_dof, momentum=None, enthalpy=None) + dc2 = MyContainer(name="yoink", mass=ary_dof, momentum=None, enthalpy=None) assert dc != dc2 assert isinstance(actx.np.equal(bcast_dc_of_dofs, bcast_dc_of_dofs_2), @@ -1436,7 +1444,7 @@ def test_array_container_with_numpy(actx_factory: ArrayContextFactory): # FIXME: Possibly, rec_map_container's types could be taught that numpy # arrays can happen, but life's too short. - rec_map_container(lambda x: x, mystate) # pyright: ignore[reportCallIssue, reportArgumentType] + rec_map_container(lambda x: x, mystate) # }}} diff --git a/test/testlib.py b/test/testlib.py index 808fc1b5..22bf35be 100644 --- a/test/testlib.py +++ b/test/testlib.py @@ -27,8 +27,9 @@ import numpy as np +from pytools.obj_array import ObjectArray1D # noqa: TC001 + from arraycontext import ( - ArrayContainer, ArrayContext, dataclass_array_container, deserialize_container, @@ -36,7 +37,7 @@ with_array_context, with_container_arithmetic, ) -from arraycontext.context import ScalarLike # noqa: TC001 +from arraycontext.context import ArrayOrContainer, ScalarLike # noqa: TC001 # Containers live here, because in order for get_annotations to work, they must @@ -147,7 +148,7 @@ def _with_actx_dofarray(ary: DOFArray, actx: ArrayContext) -> DOFArray: # type: class MyContainer: name: str mass: DOFArray | np.ndarray | ScalarLike - momentum: np.ndarray + momentum: ObjectArray1D[DOFArray] enthalpy: DOFArray | np.ndarray | ScalarLike __array_ufunc__: ClassVar[None] = None @@ -172,7 +173,7 @@ def array_context(self): class MyContainerDOFBcast: name: str mass: DOFArray | np.ndarray - momentum: np.ndarray + momentum: ObjectArray1D[DOFArray] enthalpy: DOFArray | np.ndarray __array_ufunc__: ClassVar[None] = None @@ -209,8 +210,8 @@ def array_context(self): @dataclass_array_container @dataclass(frozen=True) class Velocity2D: - u: ArrayContainer - v: ArrayContainer + u: ArrayOrContainer + v: ArrayOrContainer array_context: ArrayContext __array_ufunc__: ClassVar[None] = None From 9441691b0538c80171fb44e4e2862c57ad600f97 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 2 Jul 2025 11:29:07 -0500 Subject: [PATCH 03/10] Use pytools.ndindex in test --- test/test_arraycontext.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index b1ea3572..2dc47cc5 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -31,6 +31,7 @@ import numpy as np import pytest +from pytools import ndindex from pytools.obj_array import ObjectArray1D, make_obj_array from pytools.tag import Tag @@ -148,7 +149,7 @@ def _get_test_containers(actx, ambient_dim=2, shapes=50_000): ary_dof = x ary_of_dofs = make_obj_array([x] * ambient_dim) mat_of_dofs = np.empty((ambient_dim, ambient_dim), dtype=object) - for i in np.ndindex(mat_of_dofs.shape): + for i in ndindex(mat_of_dofs.shape): mat_of_dofs[i] = x return (ary_dof, ary_of_dofs, mat_of_dofs, dataclass_of_dofs, From 79e2caf0deee1e3a6ceb459667aabbb78179df4f Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 2 Jul 2025 13:54:20 -0500 Subject: [PATCH 04/10] Typing annotation -> type annotation --- arraycontext/container/dataclass.py | 2 +- test/test_utils.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/arraycontext/container/dataclass.py b/arraycontext/container/dataclass.py index 05210ba7..780ff6d4 100644 --- a/arraycontext/container/dataclass.py +++ b/arraycontext/container/dataclass.py @@ -180,7 +180,7 @@ def is_array_field(f: _Field) -> bool: if isinstance(field_type, GenericAlias | _BaseGenericAlias | _SpecialForm): # NOTE: anything except a Union is not allowed raise TypeError( - f"Typing annotation not supported on field '{f.name}': " + f"Type annotation not supported on field '{f.name}': " f"'{field_type!r}'") if not isinstance(field_type, type): diff --git a/test/test_utils.py b/test/test_utils.py index 3ecc32e2..b39c6300 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -88,7 +88,7 @@ class ArrayContainerWithTuple: # Deliberately left as Tuple to test compatibility. y: Tuple[Array, Array] # noqa: UP006 - with pytest.raises(TypeError, match="Typing annotation not supported on field 'y'"): + with pytest.raises(TypeError, match="Type annotation not supported on field 'y'"): dataclass_array_container(ArrayContainerWithTuple) @dataclass @@ -96,7 +96,7 @@ class ArrayContainerWithTupleAlt: x: Array y: tuple[Array, Array] - with pytest.raises(TypeError, match="Typing annotation not supported on field 'y'"): + with pytest.raises(TypeError, match="Type annotation not supported on field 'y'"): dataclass_array_container(ArrayContainerWithTupleAlt) # }}} From da79de8d90e4efe0ef3336f0d16121f529f33aa3 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 2 Jul 2025 13:55:11 -0500 Subject: [PATCH 05/10] Some typing improvements in pyopencl actx --- arraycontext/impl/pyopencl/__init__.py | 47 ++++++++++++++++++++---- arraycontext/impl/pyopencl/fake_numpy.py | 9 +++-- 2 files changed, 45 insertions(+), 11 deletions(-) diff --git a/arraycontext/impl/pyopencl/__init__.py b/arraycontext/impl/pyopencl/__init__.py index 0e1c2894..1275d30b 100644 --- a/arraycontext/impl/pyopencl/__init__.py +++ b/arraycontext/impl/pyopencl/__init__.py @@ -31,7 +31,7 @@ THE SOFTWARE. """ -from typing import TYPE_CHECKING, Any, Literal, cast +from typing import TYPE_CHECKING, Any, Literal, cast, overload from warnings import warn import numpy as np @@ -49,6 +49,8 @@ ArrayOrContainerOrScalarT, ArrayOrContainerT as ArrayOrContainerT, ArrayOrScalar, + ContainerOrScalarT, + NumpyOrContainerOrScalar, ScalarLike, UntransformedCodeWarning, is_scalar_like, @@ -58,12 +60,15 @@ if TYPE_CHECKING: from collections.abc import Callable, Mapping + from numpy.typing import NDArray + import loopy as lp import pyopencl as cl import pyopencl.array as cl_array from loopy import TranslationUnit from pytools.tag import ToTagSetConvertible + from arraycontext.container import ArrayContainerT from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray @@ -223,17 +228,43 @@ def _wrapper(ary: ArrayOrContainerOrScalar) -> ArrayOrContainerOrScalar: # {{{ ArrayContext interface - def from_numpy(self, array): + @overload + def from_numpy(self, array: NDArray[Any]) -> Array: + ... + + @overload + def from_numpy(self, array: ScalarLike) -> Array: + ... + + @overload + def from_numpy(self, array: ArrayContainerT) -> ArrayContainerT: + ... + + @override + def from_numpy(self, + array: NumpyOrContainerOrScalar + ) -> ArrayOrContainerOrScalar: import arraycontext.impl.pyopencl.taggable_cl_array as tga - def _from_numpy(ary): + def _from_numpy(ary: NDArray[Any]): return tga.to_device(self.queue, ary, allocator=self.allocator) return with_array_context( - self._rec_map_container(_from_numpy, array, (np.ndarray,), strict=True), + self._rec_map_container(_from_numpy, array, (np.ndarray,), strict=True), # pyright: ignore[reportArgumentType] actx=self) - def to_numpy(self, array): + @overload + def to_numpy(self, array: Array) -> np.ndarray: + ... + + @overload + def to_numpy(self, array: ContainerOrScalarT) -> ContainerOrScalarT: + ... + + @override + def to_numpy(self, + array: ArrayOrContainerOrScalar + ) -> NumpyOrContainerOrScalar: def _to_numpy(ary): return ary.get(queue=self.queue) @@ -241,14 +272,16 @@ def _to_numpy(ary): self._rec_map_container(_to_numpy, array), actx=None) - def freeze(self, array): + @override + def freeze(self, array: ArrayOrContainerOrScalarT) -> ArrayOrContainerOrScalarT: def _freeze(ary): ary.finish() return ary.with_queue(None) return with_array_context(self._rec_map_container(_freeze, array), actx=None) - def thaw(self, array): + @override + def thaw(self, array: ArrayOrContainerOrScalarT) -> ArrayOrContainerOrScalarT: def _thaw(ary): return ary.with_queue(self.queue) diff --git a/arraycontext/impl/pyopencl/fake_numpy.py b/arraycontext/impl/pyopencl/fake_numpy.py index 9a32c627..7ccbae3b 100644 --- a/arraycontext/impl/pyopencl/fake_numpy.py +++ b/arraycontext/impl/pyopencl/fake_numpy.py @@ -31,7 +31,7 @@ import operator from functools import partial, reduce -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast from warnings import warn import numpy as np @@ -223,6 +223,7 @@ def _any(ary): _any, a) + @override def array_equal(self, a: ArrayOrContainerOrScalar, b: ArrayOrContainerOrScalar @@ -237,7 +238,7 @@ def array_equal(self, def rec_equal( x: ArrayOrContainerOrScalar, y: ArrayOrContainerOrScalar, - ) -> cl_array.Array: + ) -> Array: if type(x) is not type(y): return false_ary @@ -256,13 +257,13 @@ def rec_equal( if len(serialized_x) != len(serialized_y): return false_ary - return reduce( + return cast("Array", reduce( partial(cl_array.minimum, queue=queue), [(true_ary if kx_i == ky_i else false_ary) and rec_equal(x_i, y_i) for (kx_i, x_i), (ky_i, y_i) in zip(serialized_x, serialized_y, strict=True)], - true_ary) + true_ary)) return rec_equal(a, b) From 77770fb5103b3e640342c31a441ae2120f018c03 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 2 Jul 2025 13:55:55 -0500 Subject: [PATCH 06/10] Deprecate np.ndarray annotations in dataclass_array_container --- arraycontext/container/dataclass.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/arraycontext/container/dataclass.py b/arraycontext/container/dataclass.py index 780ff6d4..84343a90 100644 --- a/arraycontext/container/dataclass.py +++ b/arraycontext/container/dataclass.py @@ -12,10 +12,6 @@ """ from __future__ import annotations -from pytools.obj_array import ObjectArray - -from arraycontext.context import ArrayOrContainer, ArrayOrContainerOrScalar - __copyright__ = """ Copyright (C) 2020-1 University of Illinois Board of Trustees @@ -54,10 +50,14 @@ get_args, get_origin, ) +from warnings import warn import numpy as np +from pytools.obj_array import ObjectArray + from arraycontext.container import ArrayContainer, is_array_container_type +from arraycontext.context import ArrayOrContainer, ArrayOrContainerOrScalar if TYPE_CHECKING: @@ -77,7 +77,15 @@ class _Field(NamedTuple): type: type -def is_array_type(tp: type, /) -> bool: +def _is_array_or_container_type(tp: type, /) -> bool: + if tp is np.ndarray: + warn("Encountered 'numpy.ndarray' in a dataclass_array_container. " + "This is deprecated and will stop working in 2026. " + "If you meant an object array, use pytools.obj_array.ObjectArray. " + "For other uses, file an issue to discuss.", + DeprecationWarning, stacklevel=3) + return True + from arraycontext import Array return tp is Array or is_array_container_type(tp) @@ -151,7 +159,7 @@ def is_array_field(f: _Field) -> bool: if origin in (Union, UnionType): # pyright: ignore[reportDeprecated] for arg in get_args(field_type): # pyright: ignore[reportAny] if not ( - is_array_type(cast("type", arg)) + _is_array_or_container_type(cast("type", arg)) or is_scalar_type(cast("type", arg))): raise TypeError( f"Field '{f.name}' union contains non-array container " @@ -188,7 +196,7 @@ def is_array_field(f: _Field) -> bool: f"Field '{f.name}' not an instance of 'type': " f"'{field_type!r}'") - return is_array_type(field_type) + return _is_array_or_container_type(field_type) from pytools import partition From 28ef3583d285314e3d7efe55e391e321e10db129 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 2 Jul 2025 15:03:41 -0500 Subject: [PATCH 07/10] Better type instance check in NumpyNonObjectArrayMetaclass --- arraycontext/impl/numpy/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/arraycontext/impl/numpy/__init__.py b/arraycontext/impl/numpy/__init__.py index fd05a75f..aee1df0e 100644 --- a/arraycontext/impl/numpy/__init__.py +++ b/arraycontext/impl/numpy/__init__.py @@ -65,7 +65,8 @@ class NumpyNonObjectArrayMetaclass(type): - def __instancecheck__(cls, instance: Any) -> bool: + @override + def __instancecheck__(cls, instance: object) -> bool: return isinstance(instance, np.ndarray) and instance.dtype != object From ab517556ba7f60d7015449fc54399bf3052ae834 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 2 Jul 2025 15:04:54 -0500 Subject: [PATCH 08/10] Centralize typing stuff in arraycontext.typing --- arraycontext/__init__.py | 36 ++- arraycontext/container/__init__.py | 67 +---- arraycontext/container/arithmetic.py | 4 +- arraycontext/container/dataclass.py | 8 +- arraycontext/container/traversal.py | 35 ++- arraycontext/context.py | 227 ++------------ arraycontext/fake_numpy.py | 6 +- arraycontext/impl/jax/__init__.py | 4 +- arraycontext/impl/jax/fake_numpy.py | 7 +- arraycontext/impl/numpy/__init__.py | 8 +- arraycontext/impl/numpy/fake_numpy.py | 4 +- arraycontext/impl/pyopencl/__init__.py | 8 +- arraycontext/impl/pyopencl/fake_numpy.py | 6 +- .../impl/pyopencl/taggable_cl_array.py | 2 +- arraycontext/impl/pytato/__init__.py | 10 +- arraycontext/impl/pytato/compile.py | 5 +- arraycontext/impl/pytato/fake_numpy.py | 4 +- arraycontext/impl/pytato/outline.py | 8 +- arraycontext/typing.py | 278 ++++++++++++++++++ doc/conf.py | 7 +- doc/other.rst | 2 + examples/how_to_outline.py | 2 +- test/testlib.py | 2 +- 23 files changed, 404 insertions(+), 336 deletions(-) create mode 100644 arraycontext/typing.py diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index fb0b948c..39a2bffc 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -30,9 +30,6 @@ """ from .container import ( - ArithArrayContainer, - ArrayContainer, - ArrayContainerT, NotAnArrayContainerError, SerializationKey, SerializedContainer, @@ -74,9 +71,26 @@ with_array_context, ) from .context import ( - Array, ArrayContext, ArrayContextFactory, + tag_axes, +) +from .impl.jax import EagerJAXArrayContext +from .impl.numpy import NumpyArrayContext +from .impl.pyopencl import PyOpenCLArrayContext +from .impl.pytato import PytatoJAXArrayContext, PytatoPyOpenCLArrayContext +from .loopy import make_loopy_program +from .pytest import ( + PytestArrayContextFactory, + PytestPyOpenCLArrayContextFactory, + pytest_generate_tests_for_array_contexts, +) +from .transform_metadata import CommonSubexpressionTag, ElementwiseMapKernelTag +from .typing import ( + ArithArrayContainer, + Array, + ArrayContainer, + ArrayContainerT, ArrayOrArithContainer, ArrayOrArithContainerOrScalar, ArrayOrArithContainerOrScalarT, @@ -88,21 +102,10 @@ ArrayOrScalar, ArrayOrScalarT, ArrayT, + ContainerOrScalarT, Scalar, ScalarLike, - tag_axes, -) -from .impl.jax import EagerJAXArrayContext -from .impl.numpy import NumpyArrayContext -from .impl.pyopencl import PyOpenCLArrayContext -from .impl.pytato import PytatoJAXArrayContext, PytatoPyOpenCLArrayContext -from .loopy import make_loopy_program -from .pytest import ( - PytestArrayContextFactory, - PytestPyOpenCLArrayContextFactory, - pytest_generate_tests_for_array_contexts, ) -from .transform_metadata import CommonSubexpressionTag, ElementwiseMapKernelTag __all__ = ( @@ -125,6 +128,7 @@ "ArrayT", "BcastUntilActxArray", "CommonSubexpressionTag", + "ContainerOrScalarT", "EagerJAXArrayContext", "ElementwiseMapKernelTag", "NotAnArrayContainerError", diff --git a/arraycontext/container/__init__.py b/arraycontext/container/__init__.py index 844b599f..620978ff 100644 --- a/arraycontext/container/__init__.py +++ b/arraycontext/container/__init__.py @@ -49,9 +49,7 @@ This should be considered experimental for now, and it may well change. .. autoclass:: ArithArrayContainer -.. class:: ArrayContainerT - - A type variable with a lower bound of :class:`ArrayContainer`. +.. autoclass:: ArrayContainerT .. autoexception:: NotAnArrayContainerError @@ -125,10 +123,7 @@ from types import GenericAlias, UnionType from typing import ( TYPE_CHECKING, - ClassVar, - Protocol, TypeAlias, - TypeVar, get_origin, ) @@ -136,13 +131,15 @@ # what 'np' is. import numpy import numpy as np -from typing_extensions import Self, TypeIs +from typing_extensions import TypeIs -from pytools.obj_array import ObjectArrayND +from pytools.obj_array import ObjectArrayND as ObjectArrayND -from arraycontext.context import ( +from arraycontext.typing import ( + ArrayContainer, + ArrayContainerT, ArrayOrArithContainer, - ArrayOrArithContainerOrScalar, + ArrayOrArithContainerOrScalar as ArrayOrArithContainerOrScalar, ArrayOrContainerOrScalar, ) @@ -150,55 +147,11 @@ if TYPE_CHECKING: from pymbolic.geometric_algebra import CoeffT, MultiVector - from arraycontext.context import ArrayContext, ArrayOrScalar - - -# {{{ ArrayContainer - -class _UserDefinedArrayContainer(Protocol): - # This is used as a type annotation in dataclasses that are processed - # by dataclass_array_container, where it's used to recognize attributes - # that are container-typed. - - # This method prevents ArrayContainer from matching any object, while - # matching numpy object arrays and many array containers. - __array_ufunc__: ClassVar[None] - - -ArrayContainer: TypeAlias = ( - ObjectArrayND[ArrayOrContainerOrScalar] - | _UserDefinedArrayContainer - ) - - -class _UserDefinedArithArrayContainer(_UserDefinedArrayContainer, Protocol): - # This is loose and permissive, assuming that any array can be added - # to any container. The alternative would be to plaster type-ignores - # on all those uses. Achieving typing precision on what broadcasting is - # allowable seems like a huge endeavor and is likely not feasible without - # a mypy plugin. Maybe some day? -AK, November 2024 - - def __neg__(self) -> Self: ... - def __abs__(self) -> Self: ... - def __add__(self, other: ArrayOrScalar | Self) -> Self: ... - def __radd__(self, other: ArrayOrScalar | Self) -> Self: ... - def __sub__(self, other: ArrayOrScalar | Self) -> Self: ... - def __rsub__(self, other: ArrayOrScalar | Self) -> Self: ... - def __mul__(self, other: ArrayOrScalar | Self) -> Self: ... - def __rmul__(self, other: ArrayOrScalar | Self) -> Self: ... - def __truediv__(self, other: ArrayOrScalar | Self) -> Self: ... - def __rtruediv__(self, other: ArrayOrScalar | Self) -> Self: ... - def __pow__(self, other: ArrayOrScalar | Self) -> Self: ... - def __rpow__(self, other: ArrayOrScalar | Self) -> Self: ... - - -ArithArrayContainer: TypeAlias = ( - ObjectArrayND[ArrayOrArithContainerOrScalar] - | _UserDefinedArithArrayContainer) - + from arraycontext.context import ArrayContext + from arraycontext.typing import ArrayOrScalar as ArrayOrScalar -ArrayContainerT = TypeVar("ArrayContainerT", bound=ArrayContainer) +# {{{ ArrayContainer traversals class NotAnArrayContainerError(TypeError): """:class:`TypeError` subclass raised when an array container is expected.""" diff --git a/arraycontext/container/arithmetic.py b/arraycontext/container/arithmetic.py index 3becbfc3..f245fed7 100644 --- a/arraycontext/container/arithmetic.py +++ b/arraycontext/container/arithmetic.py @@ -68,8 +68,8 @@ if TYPE_CHECKING: from collections.abc import Callable - from arraycontext.context import ( - ArrayContext, + from arraycontext.context import ArrayContext + from arraycontext.typing import ( ArrayOrContainer, ArrayOrContainerOrScalar, ) diff --git a/arraycontext/container/dataclass.py b/arraycontext/container/dataclass.py index 84343a90..d2487567 100644 --- a/arraycontext/container/dataclass.py +++ b/arraycontext/container/dataclass.py @@ -56,8 +56,12 @@ from pytools.obj_array import ObjectArray -from arraycontext.container import ArrayContainer, is_array_container_type -from arraycontext.context import ArrayOrContainer, ArrayOrContainerOrScalar +from arraycontext.container import is_array_container_type +from arraycontext.typing import ( + ArrayContainer, + ArrayOrContainer, + ArrayOrContainerOrScalar, +) if TYPE_CHECKING: diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index f9646005..40fd694b 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -33,14 +33,26 @@ .. autofunction:: unflatten .. autofunction:: flat_size_and_dtype -Numpy conversion -~~~~~~~~~~~~~~~~ -.. autofunction:: from_numpy -.. autofunction:: to_numpy - Algebraic operations ~~~~~~~~~~~~~~~~~~~~ .. autofunction:: outer + +.. currentmodule:: arraycontext.traversal + +References +---------- + +.. class:: ArrayOrScalar + + See :class:`arraycontext.ArrayOrScalar`. + +.. class:: ArrayOrContainer + + See :class:`arraycontext.ArrayOrContainer`. + +.. class:: ArrayContainerT + + See :class:`arraycontext.ArrayContainerT`. """ from __future__ import annotations @@ -84,8 +96,6 @@ ) from arraycontext.container import ( - ArrayContainer, - ArrayContainerT, NotAnArrayContainerError, SerializationKey, deserialize_container, @@ -93,15 +103,20 @@ is_array_container, serialize_container, ) -from arraycontext.context import is_scalar_like, shape_is_int_only +from arraycontext.typing import ( + ArrayContainer, + ArrayContainerT, + is_scalar_like, + shape_is_int_only, +) if TYPE_CHECKING: from collections.abc import Callable, Collection, Iterable - from arraycontext.context import ( + from arraycontext.context import ArrayContext + from arraycontext.typing import ( Array, - ArrayContext, ArrayOrContainer, ArrayOrContainerOrScalar, ArrayOrContainerOrScalarT, diff --git a/arraycontext/context.py b/arraycontext/context.py index c57d08a4..0625a752 100644 --- a/arraycontext/context.py +++ b/arraycontext/context.py @@ -85,69 +85,9 @@ A :class:`typing.ParamSpec` representing the arguments of a function being :meth:`ArrayContext.outline`\ d. -Types and Type Variables for Arrays and Containers --------------------------------------------------- +References +---------- -.. autodata:: ScalarLike - :noindex: - - A type alias of :data:`pymbolic.Scalar`. - -.. autoclass:: Array - -.. autodata:: ArrayT - - A type variable with a lower bound of :class:`Array`. - -.. autodata:: ScalarLike - - A type annotation for scalar types commonly usable with arrays. - -See also :class:`ArrayContainer` and :class:`ArrayOrContainerT`. - -.. autoclass:: ArrayOrScalar -.. autodata:: ArrayOrScalarT - -.. autodata:: ArrayOrContainer - -.. autodata:: ArrayOrContainerT - - A type variable with a bound of :class:`ArrayOrContainer`. - -.. autodata:: ArrayOrArithContainer - -.. autodata:: ArrayOrArithContainerT - - A type variable with a bound of :class:`ArrayOrArithContainer`. - -.. autodata:: ArrayOrArithContainerOrScalar - -.. autodata:: ArrayOrArithContainerOrScalarT - - A type variable with a bound of :class:`ArrayOrContainerOrScalar`. - -.. autodata:: ArrayOrContainerOrScalar - -.. autodata:: ArrayOrContainerOrScalarT - - A type variable with a bound of :class:`ArrayOrContainerOrScalar`. - -.. currentmodule:: arraycontext.context - -Canonical locations for type annotations ----------------------------------------- - -.. class:: ArrayT - - :canonical: arraycontext.ArrayT - -.. class:: ArrayOrContainerT - - :canonical: arraycontext.ArrayOrContainerT - -.. class:: ArrayOrContainerOrScalarT - - :canonical: arraycontext.ArrayOrContainerOrScalarT """ from __future__ import annotations @@ -183,176 +123,43 @@ from typing import ( TYPE_CHECKING, Any, - Literal, ParamSpec, - Protocol, - SupportsInt, TypeAlias, - TypeVar, - cast, overload, ) from warnings import warn -import numpy as np -from typing_extensions import Self, TypeIs +from typing_extensions import Self -from pymbolic.typing import Scalar as _Scalar from pytools import memoize_method +# FIXME: remove sometime, this import was used in grudge in July 2025. +from .typing import ArrayOrArithContainerTc as ArrayOrArithContainerTc + if TYPE_CHECKING: - from numpy.typing import DTypeLike + import numpy as np import loopy - from pymbolic.typing import Integer from pytools.tag import ToTagSetConvertible - from arraycontext.container import ( - ArithArrayContainer, - ArrayContainer, + from .fake_numpy import BaseFakeNumpyNamespace + from .typing import ( + Array, ArrayContainerT, + ArrayOrArithContainerOrScalarT, + ArrayOrContainerOrScalar, + ArrayOrContainerOrScalarT, + ArrayOrContainerT, + ContainerOrScalarT, + NumpyOrContainerOrScalar, + ScalarLike, ) - from arraycontext.fake_numpy import BaseFakeNumpyNamespace -# {{{ typing - P = ParamSpec("P") -# We won't support 'A' and 'K', since they depend on in-memory order; that is -# not intended to be a meaningful concept for actx arrays. -OrderCF: TypeAlias = Literal["C"] | Literal["F"] - - -class Array(Protocol): - """A :class:`~typing.Protocol` for the array type supported by - :class:`ArrayContext`. - - This is meant to aid in typing annotations. For a explicit list of - supported types see :attr:`ArrayContext.array_types`. - - .. attribute:: shape - .. attribute:: size - .. attribute:: dtype - .. attribute:: __getitem__ - - In addition, arrays are expected to support basic arithmetic. - """ - - @property - def shape(self) -> tuple[Array | Integer, ...]: - ... - - @property - def size(self) -> Array | Integer: - ... - - def __len__(self) -> int: ... - - @property - def dtype(self) -> np.dtype[Any]: - ... - - # Covering all the possible index variations is hard and (kind of) futile. - # If you'd like to see how, try changing the Any to - # AxisIndex = slice | int | "Array" - # Index = AxisIndex |tuple[AxisIndex] - def __getitem__(self, index: Any) -> Array: - ... - - # Some basic arithmetic that's supposed to work - # Need to return Array instead of Self because for some array types, arithmetic - # operations on one subtype may result in a different subtype. - # For example, pytato arrays: + 1 -> - def __neg__(self) -> Array: ... - def __abs__(self) -> Array: ... - def __add__(self, other: Self | ScalarLike) -> Array: ... - def __radd__(self, other: Self | ScalarLike) -> Array: ... - def __sub__(self, other: Self | ScalarLike) -> Array: ... - def __rsub__(self, other: Self | ScalarLike) -> Array: ... - def __mul__(self, other: Self | ScalarLike) -> Array: ... - def __rmul__(self, other: Self | ScalarLike) -> Array: ... - def __pow__(self, other: Self | ScalarLike) -> Array: ... - def __rpow__(self, other: Self | ScalarLike) -> Array: ... - def __truediv__(self, other: Self | ScalarLike) -> Array: ... - def __rtruediv__(self, other: Self | ScalarLike) -> Array: ... - - def copy(self) -> Self: ... - - @property - def real(self) -> Array: ... - @property - def imag(self) -> Array: ... - def conj(self) -> Array: ... - - def astype(self, dtype: DTypeLike) -> Array: ... - - # Annoyingly, numpy 2.3.1 (and likely earlier) treats these differently when - # reshaping to the empty shape (), so we need to expose both. - @overload - def reshape(self, *shape: int, order: OrderCF = "C") -> Array: ... - - @overload - def reshape(self, shape: tuple[int, ...], /, *, order: OrderCF = "C") -> Array: ... - - @property - def T(self) -> Array: ... # noqa: N802 - - def transpose(self, axes: tuple[int, ...]) -> Array: ... - - -# deprecated, use ScalarLike instead -Scalar: TypeAlias = _Scalar -ScalarLike = Scalar -ScalarLikeT = TypeVar("ScalarLikeT", bound=ScalarLike) - -ArrayT = TypeVar("ArrayT", bound=Array) -ArrayOrScalar: TypeAlias = Array | _Scalar -ArrayOrScalarT = TypeVar("ArrayOrScalarT", bound=ArrayOrScalar) -ArrayOrContainer: TypeAlias = "Array | ArrayContainer" -ArrayOrArithContainer: TypeAlias = "Array | ArithArrayContainer" -ArrayOrArithContainerTc = TypeVar("ArrayOrArithContainerTc", - Array, "ArithArrayContainer") -ArrayOrContainerT = TypeVar("ArrayOrContainerT", bound=ArrayOrContainer) -ArrayOrArithContainerT = TypeVar("ArrayOrArithContainerT", bound=ArrayOrArithContainer) -ArrayOrContainerOrScalar: TypeAlias = "Array | ArrayContainer | ScalarLike" -ArrayOrArithContainerOrScalar: TypeAlias = "Array | ArithArrayContainer | ScalarLike" -ArrayOrContainerOrScalarT = TypeVar( - "ArrayOrContainerOrScalarT", - bound=ArrayOrContainerOrScalar) -ArrayOrArithContainerOrScalarT = TypeVar( - "ArrayOrArithContainerOrScalarT", - bound=ArrayOrArithContainerOrScalar) - - -ContainerOrScalarT = TypeVar("ContainerOrScalarT", bound="ArrayContainer | ScalarLike") - - -NumpyOrContainerOrScalar: TypeAlias = "np.ndarray | ArrayContainer | ScalarLike" - - -def is_scalar_like(x: object, /) -> TypeIs[Scalar]: - return np.isscalar(x) - - -def shape_is_int_only(shape: tuple[Array | Integer, ...], /) -> tuple[int, ...]: - res: list[int] = [] - for i, s in enumerate(shape): - try: - res.append(int(cast("SupportsInt", s))) - except TypeError: - raise TypeError( - "only non-parametric shapes are allowed in this context, " - f"axis {i+1} is {type(s)}" - ) from None - - return tuple(res) - -# }}} - - # {{{ ArrayContext class ArrayContext(ABC): diff --git a/arraycontext/fake_numpy.py b/arraycontext/fake_numpy.py index 7bd7a960..c3395e0e 100644 --- a/arraycontext/fake_numpy.py +++ b/arraycontext/fake_numpy.py @@ -40,7 +40,7 @@ serialize_container, ) from arraycontext.container.traversal import rec_map_container -from arraycontext.context import ArrayOrContainer, ArrayOrContainerT, is_scalar_like +from arraycontext.typing import ArrayOrContainer, ArrayOrContainerT, is_scalar_like if TYPE_CHECKING: @@ -50,9 +50,9 @@ from pymbolic import Scalar - from arraycontext.context import ( + from arraycontext.context import ArrayContext + from arraycontext.typing import ( Array, - ArrayContext, ArrayOrContainerOrScalar, ArrayOrContainerOrScalarT, ArrayOrScalar, diff --git a/arraycontext/impl/jax/__init__.py b/arraycontext/impl/jax/__init__.py index e22311de..73f29cf5 100644 --- a/arraycontext/impl/jax/__init__.py +++ b/arraycontext/impl/jax/__init__.py @@ -38,9 +38,9 @@ rec_map_container, with_array_context, ) -from arraycontext.context import ( +from arraycontext.context import ArrayContext +from arraycontext.typing import ( Array, - ArrayContext, ArrayOrContainerOrScalar, ArrayOrScalar, ScalarLike, diff --git a/arraycontext/impl/jax/fake_numpy.py b/arraycontext/impl/jax/fake_numpy.py index 843b9eab..48d8a4af 100644 --- a/arraycontext/impl/jax/fake_numpy.py +++ b/arraycontext/impl/jax/fake_numpy.py @@ -1,7 +1,5 @@ from __future__ import annotations -from arraycontext.context import is_scalar_like - __copyright__ = """ Copyright (C) 2021 University of Illinois Board of Trustees @@ -44,6 +42,7 @@ rec_multimap_array_container, ) from arraycontext.fake_numpy import BaseFakeNumpyLinalgNamespace, BaseFakeNumpyNamespace +from arraycontext.typing import is_scalar_like if TYPE_CHECKING: @@ -51,12 +50,12 @@ from pymbolic import Scalar - from arraycontext.context import ( + from arraycontext.impl.jax import EagerJAXArrayContext + from arraycontext.typing import ( Array, ArrayOrContainerOrScalar, ArrayOrScalar, ) - from arraycontext.impl.jax import EagerJAXArrayContext class EagerJAXFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace): diff --git a/arraycontext/impl/numpy/__init__.py b/arraycontext/impl/numpy/__init__.py index aee1df0e..3aac1a9e 100644 --- a/arraycontext/impl/numpy/__init__.py +++ b/arraycontext/impl/numpy/__init__.py @@ -46,13 +46,15 @@ with_array_context, ) from arraycontext.context import ( - Array, ArrayContext, + UntransformedCodeWarning, +) +from arraycontext.typing import ( + Array, ArrayOrContainerOrScalar, ArrayOrContainerOrScalarT, ContainerOrScalarT, NumpyOrContainerOrScalar, - UntransformedCodeWarning, is_scalar_like, ) @@ -61,7 +63,7 @@ from pymbolic import Scalar from pytools.tag import ToTagSetConvertible - from arraycontext.container import ArrayContainerT + from arraycontext.typing import ArrayContainerT class NumpyNonObjectArrayMetaclass(type): diff --git a/arraycontext/impl/numpy/fake_numpy.py b/arraycontext/impl/numpy/fake_numpy.py index 163e3037..ccb96de7 100644 --- a/arraycontext/impl/numpy/fake_numpy.py +++ b/arraycontext/impl/numpy/fake_numpy.py @@ -37,11 +37,11 @@ rec_multimap_array_container, rec_multimap_reduce_array_container, ) -from arraycontext.context import OrderCF, is_scalar_like from arraycontext.fake_numpy import ( BaseFakeNumpyLinalgNamespace, BaseFakeNumpyNamespace, ) +from arraycontext.typing import OrderCF, is_scalar_like if TYPE_CHECKING: @@ -51,7 +51,7 @@ from pymbolic import Scalar - from arraycontext.context import ( + from arraycontext.typing import ( Array, ArrayOrContainerOrScalar, ArrayOrScalar, diff --git a/arraycontext/impl/pyopencl/__init__.py b/arraycontext/impl/pyopencl/__init__.py index 1275d30b..c3db39da 100644 --- a/arraycontext/impl/pyopencl/__init__.py +++ b/arraycontext/impl/pyopencl/__init__.py @@ -43,8 +43,11 @@ with_array_context, ) from arraycontext.context import ( - Array, ArrayContext, + UntransformedCodeWarning, +) +from arraycontext.typing import ( + Array, ArrayOrContainerOrScalar, ArrayOrContainerOrScalarT, ArrayOrContainerT as ArrayOrContainerT, @@ -52,7 +55,6 @@ ContainerOrScalarT, NumpyOrContainerOrScalar, ScalarLike, - UntransformedCodeWarning, is_scalar_like, ) @@ -68,8 +70,8 @@ from loopy import TranslationUnit from pytools.tag import ToTagSetConvertible - from arraycontext.container import ArrayContainerT from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray + from arraycontext.typing import ArrayContainerT # {{{ PyOpenCLArrayContext diff --git a/arraycontext/impl/pyopencl/fake_numpy.py b/arraycontext/impl/pyopencl/fake_numpy.py index 7ccbae3b..8accb584 100644 --- a/arraycontext/impl/pyopencl/fake_numpy.py +++ b/arraycontext/impl/pyopencl/fake_numpy.py @@ -46,10 +46,10 @@ rec_multimap_array_container, rec_multimap_reduce_array_container, ) -from arraycontext.context import OrderCF, is_scalar_like from arraycontext.fake_numpy import BaseFakeNumpyLinalgNamespace from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray from arraycontext.loopy import LoopyBasedFakeNumpyNamespace +from arraycontext.typing import OrderCF, is_scalar_like if TYPE_CHECKING: @@ -58,12 +58,12 @@ from pymbolic import Scalar from pytools.tag import Tag - from arraycontext.context import ( + from arraycontext.impl.pyopencl import PyOpenCLArrayContext + from arraycontext.typing import ( Array, ArrayOrContainerOrScalar, ArrayOrScalar, ) - from arraycontext.impl.pyopencl import PyOpenCLArrayContext # {{{ fake numpy diff --git a/arraycontext/impl/pyopencl/taggable_cl_array.py b/arraycontext/impl/pyopencl/taggable_cl_array.py index 3a2a8493..6402ba66 100644 --- a/arraycontext/impl/pyopencl/taggable_cl_array.py +++ b/arraycontext/impl/pyopencl/taggable_cl_array.py @@ -21,7 +21,7 @@ if TYPE_CHECKING: from numpy.typing import DTypeLike - from arraycontext.context import Array + from arraycontext.typing import Array _EMPTY_TAG_SET: frozenset[Tag] = frozenset() diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 3654c215..d770c7f7 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -67,18 +67,20 @@ with_array_context, ) from arraycontext.context import ( - Array, ArrayContext, + P, + UntransformedCodeWarning, +) +from arraycontext.metadata import NameHint +from arraycontext.typing import ( + Array, ArrayOrArithContainerOrScalarT, ArrayOrContainerOrScalarT, ArrayOrContainerT, ArrayOrScalar, - P, ScalarLike, - UntransformedCodeWarning, is_scalar_like, ) -from arraycontext.metadata import NameHint if TYPE_CHECKING: diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index ee3b04bf..323eb791 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -46,12 +46,10 @@ from pytools.tag import Tag from arraycontext.container import ( - ArrayContainer, SerializationKey, is_array_container_type, ) from arraycontext.container.traversal import rec_keyed_map_array_container -from arraycontext.context import ArrayOrContainerOrScalar, ArrayOrScalar, is_scalar_like from arraycontext.impl.pyopencl.taggable_cl_array import ( TaggableCLArray, ) @@ -60,6 +58,7 @@ PytatoPyOpenCLArrayContext, _BasePytatoArrayContext, ) +from arraycontext.typing import ArrayOrContainerOrScalar, ArrayOrScalar, is_scalar_like if TYPE_CHECKING: @@ -68,6 +67,8 @@ import pyopencl.array as cla from pytato.array import AxesT + from arraycontext.typing import ArrayContainer + AllowedArray: TypeAlias = "pt.Array | TaggableCLArray | cla.Array" AllowedArrayTc = TypeVar("AllowedArrayTc", pt.Array, TaggableCLArray, "cla.Array") diff --git a/arraycontext/impl/pytato/fake_numpy.py b/arraycontext/impl/pytato/fake_numpy.py index 0ee2b97d..75274580 100644 --- a/arraycontext/impl/pytato/fake_numpy.py +++ b/arraycontext/impl/pytato/fake_numpy.py @@ -39,9 +39,9 @@ rec_map_reduce_array_container, rec_multimap_array_container, ) -from arraycontext.context import ArrayOrScalar, OrderCF, is_scalar_like from arraycontext.fake_numpy import BaseFakeNumpyLinalgNamespace from arraycontext.loopy import LoopyBasedFakeNumpyNamespace +from arraycontext.typing import ArrayOrScalar, OrderCF, is_scalar_like if TYPE_CHECKING: @@ -51,8 +51,8 @@ from pymbolic import Scalar - from arraycontext.context import Array, ArrayOrContainerOrScalar from arraycontext.impl.pytato import _BasePytatoArrayContext + from arraycontext.typing import Array, ArrayOrContainerOrScalar class PytatoFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace): diff --git a/arraycontext/impl/pytato/outline.py b/arraycontext/impl/pytato/outline.py index 9d116081..3ea7f68d 100644 --- a/arraycontext/impl/pytato/outline.py +++ b/arraycontext/impl/pytato/outline.py @@ -39,9 +39,9 @@ from arraycontext.container import SerializationKey, is_array_container_type from arraycontext.container.traversal import rec_keyed_map_array_container -from arraycontext.context import ( +from arraycontext.context import P +from arraycontext.typing import ( ArrayOrContainerOrScalar, - P, is_scalar_like, ) @@ -52,11 +52,11 @@ from pymbolic import Scalar from pytools.tag import Tag - from arraycontext.context import ( + from arraycontext.impl.pytato import _BasePytatoArrayContext + from arraycontext.typing import ( Array, ArrayOrScalar, ) - from arraycontext.impl.pytato import _BasePytatoArrayContext def _get_arg_id_to_arg( diff --git a/arraycontext/typing.py b/arraycontext/typing.py new file mode 100644 index 00000000..ac86e424 --- /dev/null +++ b/arraycontext/typing.py @@ -0,0 +1,278 @@ +""" +.. currentmodule:: arraycontext + +Types and Type Variables for Arrays and Containers +-------------------------------------------------- + +.. autoclass:: ScalarLike + + A type alias of :data:`pymbolic.Scalar`. + +.. autoclass:: Array + +.. autoclass:: ArrayT + + A type variable with a lower bound of :class:`Array`. + +See also :class:`ArrayContainer` and :class:`ArrayOrContainerT`. + +.. autoclass:: ArrayOrScalar +.. autoclass:: ArrayOrScalarT +.. autoclass:: ArrayOrContainer +.. autoclass:: ArrayOrContainerT + + A type variable with a bound of :class:`ArrayOrContainer`. + +.. autoclass:: ArrayOrArithContainer +.. autoclass:: ArrayOrArithContainerT +.. autoclass:: ContainerOrScalarT +.. autoclass:: ArrayOrArithContainerOrScalar +.. autoclass:: ArrayOrArithContainerOrScalarT + + A type variable with a bound of :class:`ArrayOrContainerOrScalar`. + +.. autoclass:: ArrayOrContainerOrScalar + +.. autoclass:: ArrayOrContainerOrScalarT + + A type variable with a bound of :class:`ArrayOrContainerOrScalar`. + +Other locations +--------------- +.. currentmodule:: arraycontext.typing + +.. class:: ArrayContainerT + + :canonical: :class:`arraycontext.ArrayContainerT`. +""" +from __future__ import annotations + + +__copyright__ = """ +Copyright (C) 2025 University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Literal, + Protocol, + SupportsInt, + TypeAlias, + TypeVar, + cast, + overload, +) + +import numpy as np +from typing_extensions import Self, TypeIs + +from pymbolic.typing import Integer, Scalar as _Scalar +from pytools.obj_array import ObjectArrayND + + +if TYPE_CHECKING: + from numpy.typing import DTypeLike + + +# deprecated, use ScalarLike instead +Scalar: TypeAlias = _Scalar +ScalarLike = Scalar +ScalarLikeT = TypeVar("ScalarLikeT", bound=ScalarLike) + +# {{{ array + +# We won't support 'A' and 'K', since they depend on in-memory order; that is +# not intended to be a meaningful concept for actx arrays. +OrderCF: TypeAlias = Literal["C"] | Literal["F"] + + +class Array(Protocol): + """A :class:`~typing.Protocol` for the array type supported by + :class:`ArrayContext`. + + This is meant to aid in typing annotations. For a explicit list of + supported types see :attr:`ArrayContext.array_types`. + + .. attribute:: shape + .. attribute:: size + .. attribute:: dtype + .. attribute:: __getitem__ + + In addition, arrays are expected to support basic arithmetic. + """ + + @property + def shape(self) -> tuple[Array | Integer, ...]: + ... + + @property + def size(self) -> Array | Integer: + ... + + def __len__(self) -> int: ... + + @property + def dtype(self) -> np.dtype[Any]: + ... + + # Covering all the possible index variations is hard and (kind of) futile. + # If you'd like to see how, try changing the Any to + # AxisIndex = slice | int | "Array" + # Index = AxisIndex |tuple[AxisIndex] + def __getitem__(self, index: Any) -> Array: # pyright: ignore[reportAny] + ... + + # Some basic arithmetic that's supposed to work + # Need to return Array instead of Self because for some array types, arithmetic + # operations on one subtype may result in a different subtype. + # For example, pytato arrays: + 1 -> + def __neg__(self) -> Array: ... + def __abs__(self) -> Array: ... + def __add__(self, other: Self | ScalarLike) -> Array: ... + def __radd__(self, other: Self | ScalarLike) -> Array: ... + def __sub__(self, other: Self | ScalarLike) -> Array: ... + def __rsub__(self, other: Self | ScalarLike) -> Array: ... + def __mul__(self, other: Self | ScalarLike) -> Array: ... + def __rmul__(self, other: Self | ScalarLike) -> Array: ... + def __pow__(self, other: Self | ScalarLike) -> Array: ... + def __rpow__(self, other: Self | ScalarLike) -> Array: ... + def __truediv__(self, other: Self | ScalarLike) -> Array: ... + def __rtruediv__(self, other: Self | ScalarLike) -> Array: ... + + def copy(self) -> Self: ... + + @property + def real(self) -> Array: ... + @property + def imag(self) -> Array: ... + def conj(self) -> Array: ... + + def astype(self, dtype: DTypeLike) -> Array: ... + + # Annoyingly, numpy 2.3.1 (and likely earlier) treats these differently when + # reshaping to the empty shape (), so we need to expose both. + @overload + def reshape(self, *shape: int, order: OrderCF = "C") -> Array: ... + + @overload + def reshape(self, shape: tuple[int, ...], /, *, order: OrderCF = "C") -> Array: ... + + @property + def T(self) -> Array: ... # noqa: N802 + + def transpose(self, axes: tuple[int, ...]) -> Array: ... + +# }}} + + +# {{{ array container + +class _UserDefinedArrayContainer(Protocol): + # This is used as a type annotation in dataclasses that are processed + # by dataclass_array_container, where it's used to recognize attributes + # that are container-typed. + + # This method prevents ArrayContainer from matching any object, while + # matching numpy object arrays and many array containers. + __array_ufunc__: ClassVar[None] + + +ArrayContainer: TypeAlias = ( + ObjectArrayND["ArrayOrContainerOrScalar"] + | _UserDefinedArrayContainer + ) + + +class _UserDefinedArithArrayContainer(_UserDefinedArrayContainer, Protocol): + # This is loose and permissive, assuming that any array can be added + # to any container. The alternative would be to plaster type-ignores + # on all those uses. Achieving typing precision on what broadcasting is + # allowable seems like a huge endeavor and is likely not feasible without + # a mypy plugin. Maybe some day? -AK, November 2024 + + def __neg__(self) -> Self: ... + def __abs__(self) -> Self: ... + def __add__(self, other: ArrayOrScalar | Self) -> Self: ... + def __radd__(self, other: ArrayOrScalar | Self) -> Self: ... + def __sub__(self, other: ArrayOrScalar | Self) -> Self: ... + def __rsub__(self, other: ArrayOrScalar | Self) -> Self: ... + def __mul__(self, other: ArrayOrScalar | Self) -> Self: ... + def __rmul__(self, other: ArrayOrScalar | Self) -> Self: ... + def __truediv__(self, other: ArrayOrScalar | Self) -> Self: ... + def __rtruediv__(self, other: ArrayOrScalar | Self) -> Self: ... + def __pow__(self, other: ArrayOrScalar | Self) -> Self: ... + def __rpow__(self, other: ArrayOrScalar | Self) -> Self: ... + + +ArithArrayContainer: TypeAlias = ( + ObjectArrayND["ArrayOrArithContainerOrScalar"] + | _UserDefinedArithArrayContainer) + + +ArrayContainerT = TypeVar("ArrayContainerT", bound=ArrayContainer) + +# }}} + + +ArrayT = TypeVar("ArrayT", bound=Array) +ArrayOrScalar: TypeAlias = Array | _Scalar +ArrayOrScalarT = TypeVar("ArrayOrScalarT", bound=ArrayOrScalar) +ArrayOrContainer: TypeAlias = Array | ArrayContainer +ArrayOrArithContainer: TypeAlias = Array | ArithArrayContainer +ArrayOrArithContainerTc = TypeVar("ArrayOrArithContainerTc", + Array, "ArithArrayContainer") +ArrayOrContainerT = TypeVar("ArrayOrContainerT", bound=ArrayOrContainer) +ArrayOrArithContainerT = TypeVar("ArrayOrArithContainerT", bound=ArrayOrArithContainer) +ArrayOrContainerOrScalar: TypeAlias = Array | ArrayContainer | ScalarLike +ArrayOrArithContainerOrScalar: TypeAlias = Array | ArithArrayContainer | ScalarLike +ArrayOrContainerOrScalarT = TypeVar( + "ArrayOrContainerOrScalarT", + bound=ArrayOrContainerOrScalar) +ArrayOrArithContainerOrScalarT = TypeVar( + "ArrayOrArithContainerOrScalarT", + bound=ArrayOrArithContainerOrScalar) + + +ContainerOrScalarT = TypeVar("ContainerOrScalarT", bound="ArrayContainer | ScalarLike") + + +NumpyOrContainerOrScalar: TypeAlias = "np.ndarray | ArrayContainer | ScalarLike" + + +def is_scalar_like(x: object, /) -> TypeIs[Scalar]: + return np.isscalar(x) + + +def shape_is_int_only(shape: tuple[Array | Integer, ...], /) -> tuple[int, ...]: + res: list[int] = [] + for i, s in enumerate(shape): + try: + res.append(int(cast("SupportsInt", s))) + except TypeError: + raise TypeError( + "only non-parametric shapes are allowed in this context, " + f"axis {i+1} is {type(s)}" + ) from None + + return tuple(res) diff --git a/doc/conf.py b/doc/conf.py index 6e64d9f4..f2acda4f 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -41,9 +41,8 @@ nitpick_ignore_regex = [ - ["py:class", r"arraycontext\.context\.ContainerOrScalarT"], - ["py:class", r"ArrayOrContainer"], - ["py:class", r"ArrayOrScalar"], - ["py:class", r"arraycontext.container._UserDefinedArithArrayContainer"], + ["py:class", r"arraycontext.typing._UserDefinedArrayContainer"], + ["py:class", r"arraycontext.typing._UserDefinedArithArrayContainer"], ["py:class", r"np.integer"], + ["py:class", r".*\|.*"], ] diff --git a/doc/other.rst b/doc/other.rst index fc13932f..4b0cca75 100644 --- a/doc/other.rst +++ b/doc/other.rst @@ -1,6 +1,8 @@ Other functionality =================== +.. automodule:: arraycontext.typing + .. _metadata: Metadata ("tags") for Arrays and Array Axes diff --git a/examples/how_to_outline.py b/examples/how_to_outline.py index 5592b741..6dddd600 100644 --- a/examples/how_to_outline.py +++ b/examples/how_to_outline.py @@ -18,7 +18,7 @@ if TYPE_CHECKING: - from arraycontext.context import ( + from arraycontext import ( ArrayOrArithContainer, ) diff --git a/test/testlib.py b/test/testlib.py index 22bf35be..81e7b1b4 100644 --- a/test/testlib.py +++ b/test/testlib.py @@ -37,7 +37,7 @@ with_array_context, with_container_arithmetic, ) -from arraycontext.context import ArrayOrContainer, ScalarLike # noqa: TC001 +from arraycontext.typing import ArrayOrContainer, ScalarLike # noqa: TC001 # Containers live here, because in order for get_annotations to work, they must From fa4b94284dabe3bbebeb3cf02a00dd65ab7dc408 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 2 Jul 2025 09:16:49 -0500 Subject: [PATCH 09/10] Update baseline --- .basedpyright/baseline.json | 498 ++---------------------------------- 1 file changed, 21 insertions(+), 477 deletions(-) diff --git a/.basedpyright/baseline.json b/.basedpyright/baseline.json index 3804760f..a7c4183e 100644 --- a/.basedpyright/baseline.json +++ b/.basedpyright/baseline.json @@ -133,38 +133,6 @@ } ], "./arraycontext/container/arithmetic.py": [ - { - "code": "reportImplicitOverride", - "range": { - "startColumn": 8, - "endColumn": 25, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 31, - "endColumn": 39, - "lineCount": 1 - } - }, - { - "code": "reportImplicitOverride", - "range": { - "startColumn": 8, - "endColumn": 25, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 31, - "endColumn": 39, - "lineCount": 1 - } - }, { "code": "reportUnusedParameter", "range": { @@ -847,14 +815,6 @@ } ], "./arraycontext/container/dataclass.py": [ - { - "code": "reportDeprecated", - "range": { - "startColumn": 22, - "endColumn": 27, - "lineCount": 1 - } - }, { "code": "reportAttributeAccessIssue", "range": { @@ -1473,83 +1433,35 @@ "lineCount": 3 } }, - { - "code": "reportAny", - "range": { - "startColumn": 4, - "endColumn": 9, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 10, - "endColumn": 11, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 18, - "endColumn": 19, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 24, - "endColumn": 25, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 44, - "endColumn": 55, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 44, - "endColumn": 74, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { - "startColumn": 15, - "endColumn": 23, + "startColumn": 37, + "endColumn": 44, "lineCount": 1 } }, { - "code": "reportAny", + "code": "reportUnknownMemberType", "range": { - "startColumn": 45, - "endColumn": 56, + "startColumn": 37, + "endColumn": 44, "lineCount": 1 } }, { - "code": "reportAny", + "code": "reportUnknownVariableType", "range": { - "startColumn": 16, - "endColumn": 17, + "startColumn": 15, + "endColumn": 18, "lineCount": 1 } }, { - "code": "reportAny", + "code": "reportUnknownMemberType", "range": { - "startColumn": 31, - "endColumn": 32, + "startColumn": 36, + "endColumn": 44, "lineCount": 1 } }, @@ -1586,23 +1498,23 @@ } }, { - "code": "reportAny", + "code": "reportUnknownArgumentType", "range": { - "startColumn": 53, - "endColumn": 64, + "startColumn": 59, + "endColumn": 60, "lineCount": 1 } - } - ], - "./arraycontext/context.py": [ + }, { - "code": "reportAny", + "code": "reportUnknownArgumentType", "range": { - "startColumn": 26, - "endColumn": 31, + "startColumn": 62, + "endColumn": 63, "lineCount": 1 } - }, + } + ], + "./arraycontext/context.py": [ { "code": "reportImplicitOverride", "range": { @@ -2723,22 +2635,6 @@ } ], "./arraycontext/impl/numpy/__init__.py": [ - { - "code": "reportImplicitOverride", - "range": { - "startColumn": 8, - "endColumn": 25, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 31, - "endColumn": 39, - "lineCount": 1 - } - }, { "code": "reportUnannotatedClassAttribute", "range": { @@ -3639,62 +3535,6 @@ "lineCount": 1 } }, - { - "code": "reportAny", - "range": { - "startColumn": 8, - "endColumn": 18, - "lineCount": 1 - } - }, - { - "code": "reportImplicitOverride", - "range": { - "startColumn": 8, - "endColumn": 18, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 25, - "endColumn": 30, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 25, - "endColumn": 30, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 24, - "endColumn": 27, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 24, - "endColumn": 27, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 45, - "endColumn": 48, - "lineCount": 1 - } - }, { "code": "reportAny", "range": { @@ -3703,54 +3543,6 @@ "lineCount": 3 } }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 36, - "endColumn": 47, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 49, - "endColumn": 54, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 8, - "endColumn": 16, - "lineCount": 1 - } - }, - { - "code": "reportImplicitOverride", - "range": { - "startColumn": 8, - "endColumn": 16, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 23, - "endColumn": 28, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 23, - "endColumn": 28, - "lineCount": 1 - } - }, { "code": "reportUnknownParameterType", "range": { @@ -3807,46 +3599,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 47, - "endColumn": 52, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 8, - "endColumn": 14, - "lineCount": 1 - } - }, - { - "code": "reportImplicitOverride", - "range": { - "startColumn": 8, - "endColumn": 14, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 21, - "endColumn": 26, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 21, - "endColumn": 26, - "lineCount": 1 - } - }, { "code": "reportUnknownParameterType", "range": { @@ -3911,46 +3663,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 67, - "endColumn": 72, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 8, - "endColumn": 12, - "lineCount": 1 - } - }, - { - "code": "reportImplicitOverride", - "range": { - "startColumn": 8, - "endColumn": 12, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 19, - "endColumn": 24, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 19, - "endColumn": 24, - "lineCount": 1 - } - }, { "code": "reportUnknownParameterType", "range": { @@ -4007,14 +3719,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 65, - "endColumn": 70, - "lineCount": 1 - } - }, { "code": "reportUnknownParameterType", "range": { @@ -4793,70 +4497,6 @@ "lineCount": 1 } }, - { - "code": "reportAny", - "range": { - "startColumn": 8, - "endColumn": 16, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 19, - "endColumn": 34, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 8, - "endColumn": 17, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 20, - "endColumn": 35, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 23, - "endColumn": 32, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 27, - "endColumn": 36, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 27, - "endColumn": 36, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 23, - "endColumn": 33, - "lineCount": 7 - } - }, { "code": "reportUnknownArgumentType", "range": { @@ -4873,14 +4513,6 @@ "lineCount": 1 } }, - { - "code": "reportAny", - "range": { - "startColumn": 24, - "endColumn": 32, - "lineCount": 1 - } - }, { "code": "reportAny", "range": { @@ -7421,14 +7053,6 @@ "lineCount": 1 } }, - { - "code": "reportImplicitOverride", - "range": { - "startColumn": 8, - "endColumn": 15, - "lineCount": 1 - } - }, { "code": "reportImplicitOverride", "range": { @@ -7869,14 +7493,6 @@ "lineCount": 1 } }, - { - "code": "reportImplicitOverride", - "range": { - "startColumn": 8, - "endColumn": 15, - "lineCount": 1 - } - }, { "code": "reportAny", "range": { @@ -8335,30 +7951,6 @@ "lineCount": 1 } }, - { - "code": "reportAny", - "range": { - "startColumn": 8, - "endColumn": 16, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 24, - "endColumn": 28, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 37, - "endColumn": 43, - "lineCount": 1 - } - }, { "code": "reportAny", "range": { @@ -8375,38 +7967,6 @@ "lineCount": 1 } }, - { - "code": "reportAny", - "range": { - "startColumn": 42, - "endColumn": 45, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 30, - "endColumn": 33, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 47, - "endColumn": 50, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 28, - "endColumn": 31, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -12751,14 +12311,6 @@ "lineCount": 1 } }, - { - "code": "reportOperatorIssue", - "range": { - "startColumn": 8, - "endColumn": 53, - "lineCount": 1 - } - }, { "code": "reportOperatorIssue", "range": { @@ -12839,14 +12391,6 @@ "lineCount": 1 } }, - { - "code": "reportOperatorIssue", - "range": { - "startColumn": 36, - "endColumn": 55, - "lineCount": 1 - } - }, { "code": "reportOperatorIssue", "range": { From 006dd3ecb695753dd1f2d1fb6637e286a979b858 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 2 Jul 2025 09:15:09 -0500 Subject: [PATCH 10/10] Remove deprecated traversal.{to,from}_numpy --- arraycontext/__init__.py | 4 ---- arraycontext/container/traversal.py | 30 ----------------------------- 2 files changed, 34 deletions(-) diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index 39a2bffc..e728684b 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -51,7 +51,6 @@ flat_size_and_dtype, flatten, freeze, - from_numpy, map_array_container, map_reduce_array_container, mapped_over_array_containers, @@ -66,7 +65,6 @@ rec_multimap_reduce_array_container, stringify_array_container_tree, thaw, - to_numpy, unflatten, with_array_context, ) @@ -147,7 +145,6 @@ "flat_size_and_dtype", "flatten", "freeze", - "from_numpy", "get_container_context_opt", "get_container_context_recursively", "get_container_context_recursively_opt", @@ -172,7 +169,6 @@ "stringify_array_container_tree", "tag_axes", "thaw", - "to_numpy", "unflatten", "with_array_context", "with_container_arithmetic", diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index 40fd694b..d8ebbd4e 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -1106,40 +1106,10 @@ def _flat_size(subary: ArrayOrContainerOrScalar) -> Array | int | np.integer: size = _flat_size(ary) return size, common_dtype -# }}} class _HasOuterBcastTypes(Protocol): _outer_bcast_types: ClassVar[Collection[type]] -# {{{ numpy conversion - -def from_numpy( - ary: np.ndarray | ScalarLike, - actx: ArrayContext) -> ArrayOrContainerOrScalar: - """Convert all :mod:`numpy` arrays in the :class:`~arraycontext.ArrayContainer` - to the base array type of :class:`~arraycontext.ArrayContext`. - - The conversion is done using :meth:`arraycontext.ArrayContext.from_numpy`. - """ - warn("Calling from_numpy(ary, actx) is deprecated, call actx.from_numpy(ary)" - " instead. This will stop working in 2023.", - DeprecationWarning, stacklevel=2) - - return actx.from_numpy(ary) - - -def to_numpy(ary: ArrayOrContainer, actx: ArrayContext) -> ArrayOrContainer: - """Convert all arrays in the :class:`~arraycontext.ArrayContainer` to - :mod:`numpy` using the provided :class:`~arraycontext.ArrayContext` *actx*. - - The conversion is done using :meth:`arraycontext.ArrayContext.to_numpy`. - """ - warn("Calling to_numpy(ary, actx) is deprecated, call actx.to_numpy(ary)" - " instead. This will stop working in 2023.", - DeprecationWarning, stacklevel=2) - - return actx.to_numpy(ary) - # }}}