diff --git a/.basedpyright/baseline.json b/.basedpyright/baseline.json index 7dd8614a..69f8a4ba 100644 --- a/.basedpyright/baseline.json +++ b/.basedpyright/baseline.json @@ -3117,30 +3117,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 53, - "endColumn": 62, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 23, - "endColumn": 46, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 33, - "endColumn": 36, - "lineCount": 1 - } - }, { "code": "reportUnknownVariableType", "range": { @@ -3639,14 +3615,6 @@ "lineCount": 1 } }, - { - "code": "reportAny", - "range": { - "startColumn": 12, - "endColumn": 21, - "lineCount": 1 - } - }, { "code": "reportUnknownParameterType", "range": { @@ -3663,14 +3631,6 @@ "lineCount": 1 } }, - { - "code": "reportAny", - "range": { - "startColumn": 19, - "endColumn": 38, - "lineCount": 1 - } - }, { "code": "reportAny", "range": { @@ -5337,14 +5297,6 @@ "lineCount": 1 } }, - { - "code": "reportArgumentType", - "range": { - "startColumn": 46, - "endColumn": 51, - "lineCount": 1 - } - }, { "code": "reportUnknownArgumentType", "range": { diff --git a/arraycontext/impl/jax/__init__.py b/arraycontext/impl/jax/__init__.py index d767a083..a44be608 100644 --- a/arraycontext/impl/jax/__init__.py +++ b/arraycontext/impl/jax/__init__.py @@ -36,7 +36,7 @@ from pytools.tag import ToTagSetConvertible from arraycontext.container.traversal import rec_map_array_container, with_array_context -from arraycontext.context import Array, ArrayContext, ArrayOrContainer, ScalarLike +from arraycontext.context import ArrayContext, ArrayOrContainer, ScalarLike class EagerJAXArrayContext(ArrayContext): @@ -64,7 +64,7 @@ def _get_fake_numpy_namespace(self): return EagerJAXFakeNumpyNamespace(self) def _rec_map_container( - self, func: Callable[[Array], Array], array: ArrayOrContainer, + self, func: Callable[[object], object], array: ArrayOrContainer, allowed_types: tuple[type, ...] | None = None, *, default_scalar: ScalarLike | None = None, strict: bool = False) -> ArrayOrContainer: @@ -101,7 +101,7 @@ def _from_numpy(ary): def to_numpy(self, array): def _to_numpy(ary): import jax - return jax.device_get(ary) + return np.copy(jax.device_get(ary)) # pyright: ignore[reportAny] return with_array_context( self._rec_map_container(_to_numpy, array), diff --git a/arraycontext/impl/jax/fake_numpy.py b/arraycontext/impl/jax/fake_numpy.py index 7acf4fab..a47fe661 100644 --- a/arraycontext/impl/jax/fake_numpy.py +++ b/arraycontext/impl/jax/fake_numpy.py @@ -202,7 +202,7 @@ def rec_equal(x, y): def sum(self, a, axis=None, dtype=None): return rec_map_reduce_array_container( sum, - partial(jnp.sum, axis=axis, dtype=dtype), + partial(jnp.sum, axis=axis, dtype=dtype), # pyright: ignore[reportArgumentType] a) def amin(self, a, axis=None):