Skip to content

Commit 97f6967

Browse files
committed
wip
1 parent 6db8b11 commit 97f6967

File tree

1 file changed

+24
-13
lines changed

1 file changed

+24
-13
lines changed

array_api_compat/common/_helpers.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import math
1313
import sys
1414
import warnings
15-
from collections.abc import Collection
15+
from collections.abc import Collection, Hashable
1616
from functools import lru_cache
1717
from typing import (
1818
TYPE_CHECKING,
@@ -83,7 +83,8 @@ def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]:
8383
dtype = x.dtype # type: ignore[attr-defined]
8484
except AttributeError:
8585
return False
86-
if not _issubclass_fast(type(dtype), "numpy.dtypes", "VoidDType"):
86+
cls = cast(Hashable, type(dtype))
87+
if not _issubclass_fast(cls, "numpy.dtypes", "VoidDType"):
8788
return False
8889

8990
if "jax" not in sys.modules:
@@ -116,7 +117,7 @@ def is_numpy_array(x: object) -> TypeGuard[npt.NDArray[Any]]:
116117
is_pydata_sparse_array
117118
"""
118119
# TODO: Should we reject ndarray subclasses?
119-
cls = type(x)
120+
cls = cast(Hashable, type(x))
120121
return (
121122
_issubclass_fast(cls, "numpy", "ndarray")
122123
or _issubclass_fast(cls, "numpy", "generic")
@@ -144,7 +145,8 @@ def is_cupy_array(x: object) -> bool:
144145
is_jax_array
145146
is_pydata_sparse_array
146147
"""
147-
return _issubclass_fast(type(x), "cupy", "ndarray")
148+
cls = cast(Hashable, type(x))
149+
return _issubclass_fast(cls, "cupy", "ndarray")
148150

149151

150152
def is_torch_array(x: object) -> TypeIs[torch.Tensor]:
@@ -165,7 +167,8 @@ def is_torch_array(x: object) -> TypeIs[torch.Tensor]:
165167
is_jax_array
166168
is_pydata_sparse_array
167169
"""
168-
return _issubclass_fast(type(x), "torch", "Tensor")
170+
cls = cast(Hashable, type(x))
171+
return _issubclass_fast(cls, "torch", "Tensor")
169172

170173

171174
def is_ndonnx_array(x: object) -> TypeIs[ndx.Array]:
@@ -187,7 +190,8 @@ def is_ndonnx_array(x: object) -> TypeIs[ndx.Array]:
187190
is_jax_array
188191
is_pydata_sparse_array
189192
"""
190-
return _issubclass_fast(type(x), "ndonnx", "Array")
193+
cls = cast(Hashable, type(x))
194+
return _issubclass_fast(cls, "ndonnx", "Array")
191195

192196

193197
def is_dask_array(x: object) -> TypeIs[da.Array]:
@@ -209,7 +213,8 @@ def is_dask_array(x: object) -> TypeIs[da.Array]:
209213
is_jax_array
210214
is_pydata_sparse_array
211215
"""
212-
return _issubclass_fast(type(x), "dask.array", "Array")
216+
cls = cast(Hashable, type(x))
217+
return _issubclass_fast(cls, "dask.array", "Array")
213218

214219

215220
def is_jax_array(x: object) -> TypeIs[jax.Array]:
@@ -232,7 +237,8 @@ def is_jax_array(x: object) -> TypeIs[jax.Array]:
232237
is_dask_array
233238
is_pydata_sparse_array
234239
"""
235-
return _issubclass_fast(type(x), "jax", "Array") or _is_jax_zero_gradient_array(x)
240+
cls = cast(Hashable, type(x))
241+
return _issubclass_fast(cls, "jax", "Array") or _is_jax_zero_gradient_array(x)
236242

237243

238244
def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]:
@@ -256,7 +262,8 @@ def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]:
256262
is_jax_array
257263
"""
258264
# TODO: Account for other backends.
259-
return _issubclass_fast(type(x), "sparse", "SparseArray")
265+
cls = cast(Hashable, type(x))
266+
return _issubclass_fast(cls, "sparse", "SparseArray")
260267

261268

262269
def is_array_api_obj(x: object) -> TypeIs[_ArrayApiObj]: # pyright: ignore[reportUnknownParameterType]
@@ -274,7 +281,10 @@ def is_array_api_obj(x: object) -> TypeIs[_ArrayApiObj]: # pyright: ignore[repo
274281
is_dask_array
275282
is_jax_array
276283
"""
277-
return hasattr(x, '__array_namespace__') or _is_array_api_cls(type(x))
284+
return (
285+
hasattr(x, '__array_namespace__')
286+
or _is_array_api_cls(cast(Hashable, type(x)))
287+
)
278288

279289

280290
@lru_cache(100)
@@ -946,9 +956,9 @@ def is_writeable_array(x: object) -> bool:
946956
As there is no standard way to check if an array is writeable without actually
947957
writing to it, this function blindly returns True for all unknown array types.
948958
"""
949-
cls = type(x)
959+
cls = cast(Hashable, type(x))
950960
if _issubclass_fast(cls, "numpy", "ndarray"):
951-
return x.flags.writeable
961+
return cast(npt.NDArray, x).flags.writeable
952962
res = _is_writeable_cls(cls)
953963
if res is not None:
954964
return res
@@ -998,7 +1008,8 @@ def is_lazy_array(x: object) -> bool:
9981008

9991009
# Note: skipping reclassification of JAX zero gradient arrays, as one will
10001010
# exclusively get them once they leave a jax.grad JIT context.
1001-
res = _is_lazy_cls(type(x))
1011+
cls = cast(Hashable, type(x))
1012+
res = _is_lazy_cls(cls)
10021013
if res is not None:
10031014
return res
10041015

0 commit comments

Comments
 (0)