Skip to content

Commit 4386b19

Browse files
attempt to resolve bpr
1 parent e060f1e commit 4386b19

File tree

3 files changed

+4
-52
lines changed

3 files changed

+4
-52
lines changed

.basedpyright/baseline.json

Lines changed: 0 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -3117,30 +3117,6 @@
31173117
"lineCount": 1
31183118
}
31193119
},
3120-
{
3121-
"code": "reportUnknownVariableType",
3122-
"range": {
3123-
"startColumn": 53,
3124-
"endColumn": 62,
3125-
"lineCount": 1
3126-
}
3127-
},
3128-
{
3129-
"code": "reportAny",
3130-
"range": {
3131-
"startColumn": 23,
3132-
"endColumn": 46,
3133-
"lineCount": 1
3134-
}
3135-
},
3136-
{
3137-
"code": "reportUnknownArgumentType",
3138-
"range": {
3139-
"startColumn": 33,
3140-
"endColumn": 36,
3141-
"lineCount": 1
3142-
}
3143-
},
31443120
{
31453121
"code": "reportUnknownVariableType",
31463122
"range": {
@@ -3639,14 +3615,6 @@
36393615
"lineCount": 1
36403616
}
36413617
},
3642-
{
3643-
"code": "reportAny",
3644-
"range": {
3645-
"startColumn": 12,
3646-
"endColumn": 21,
3647-
"lineCount": 1
3648-
}
3649-
},
36503618
{
36513619
"code": "reportUnknownParameterType",
36523620
"range": {
@@ -3663,14 +3631,6 @@
36633631
"lineCount": 1
36643632
}
36653633
},
3666-
{
3667-
"code": "reportAny",
3668-
"range": {
3669-
"startColumn": 19,
3670-
"endColumn": 38,
3671-
"lineCount": 1
3672-
}
3673-
},
36743634
{
36753635
"code": "reportAny",
36763636
"range": {
@@ -5337,14 +5297,6 @@
53375297
"lineCount": 1
53385298
}
53395299
},
5340-
{
5341-
"code": "reportArgumentType",
5342-
"range": {
5343-
"startColumn": 46,
5344-
"endColumn": 51,
5345-
"lineCount": 1
5346-
}
5347-
},
53485300
{
53495301
"code": "reportUnknownArgumentType",
53505302
"range": {

arraycontext/impl/jax/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
from pytools.tag import ToTagSetConvertible
3737

3838
from arraycontext.container.traversal import rec_map_array_container, with_array_context
39-
from arraycontext.context import Array, ArrayContext, ArrayOrContainer, ScalarLike
39+
from arraycontext.context import ArrayContext, ArrayOrContainer, ScalarLike
4040

4141

4242
class EagerJAXArrayContext(ArrayContext):
@@ -64,7 +64,7 @@ def _get_fake_numpy_namespace(self):
6464
return EagerJAXFakeNumpyNamespace(self)
6565

6666
def _rec_map_container(
67-
self, func: Callable[[Array], Array], array: ArrayOrContainer,
67+
self, func: Callable[[object], object], array: ArrayOrContainer,
6868
allowed_types: tuple[type, ...] | None = None, *,
6969
default_scalar: ScalarLike | None = None,
7070
strict: bool = False) -> ArrayOrContainer:
@@ -101,7 +101,7 @@ def _from_numpy(ary):
101101
def to_numpy(self, array):
102102
def _to_numpy(ary):
103103
import jax
104-
return np.copy(jax.device_get(ary))
104+
return np.copy(jax.device_get(ary)) # pyright: ignore[reportAny]
105105

106106
return with_array_context(
107107
self._rec_map_container(_to_numpy, array),

arraycontext/impl/jax/fake_numpy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def rec_equal(x, y):
202202
def sum(self, a, axis=None, dtype=None):
203203
return rec_map_reduce_array_container(
204204
sum,
205-
partial(jnp.sum, axis=axis, dtype=dtype),
205+
partial(jnp.sum, axis=axis, dtype=dtype), # pyright: ignore[reportArgumentType]
206206
a)
207207

208208
def amin(self, a, axis=None):

0 commit comments

Comments
 (0)