12
12
import math
13
13
import sys
14
14
import warnings
15
- from collections .abc import Collection
15
+ from collections .abc import Collection , Hashable
16
16
from functools import lru_cache
17
17
from typing import (
18
18
TYPE_CHECKING ,
@@ -83,7 +83,8 @@ def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]:
83
83
dtype = x .dtype # type: ignore[attr-defined]
84
84
except AttributeError :
85
85
return False
86
- if not _issubclass_fast (type (dtype ), "numpy.dtypes" , "VoidDType" ):
86
+ cls = cast (Hashable , type (dtype ))
87
+ if not _issubclass_fast (cls , "numpy.dtypes" , "VoidDType" ):
87
88
return False
88
89
89
90
if "jax" not in sys .modules :
@@ -116,7 +117,7 @@ def is_numpy_array(x: object) -> TypeGuard[npt.NDArray[Any]]:
116
117
is_pydata_sparse_array
117
118
"""
118
119
# TODO: Should we reject ndarray subclasses?
119
- cls = type (x )
120
+ cls = cast ( Hashable , type (x ) )
120
121
return (
121
122
_issubclass_fast (cls , "numpy" , "ndarray" )
122
123
or _issubclass_fast (cls , "numpy" , "generic" )
@@ -144,7 +145,8 @@ def is_cupy_array(x: object) -> bool:
144
145
is_jax_array
145
146
is_pydata_sparse_array
146
147
"""
147
- return _issubclass_fast (type (x ), "cupy" , "ndarray" )
148
+ cls = cast (Hashable , type (x ))
149
+ return _issubclass_fast (cls , "cupy" , "ndarray" )
148
150
149
151
150
152
def is_torch_array (x : object ) -> TypeIs [torch .Tensor ]:
@@ -165,7 +167,8 @@ def is_torch_array(x: object) -> TypeIs[torch.Tensor]:
165
167
is_jax_array
166
168
is_pydata_sparse_array
167
169
"""
168
- return _issubclass_fast (type (x ), "torch" , "Tensor" )
170
+ cls = cast (Hashable , type (x ))
171
+ return _issubclass_fast (cls , "torch" , "Tensor" )
169
172
170
173
171
174
def is_ndonnx_array (x : object ) -> TypeIs [ndx .Array ]:
@@ -187,7 +190,8 @@ def is_ndonnx_array(x: object) -> TypeIs[ndx.Array]:
187
190
is_jax_array
188
191
is_pydata_sparse_array
189
192
"""
190
- return _issubclass_fast (type (x ), "ndonnx" , "Array" )
193
+ cls = cast (Hashable , type (x ))
194
+ return _issubclass_fast (cls , "ndonnx" , "Array" )
191
195
192
196
193
197
def is_dask_array (x : object ) -> TypeIs [da .Array ]:
@@ -209,7 +213,8 @@ def is_dask_array(x: object) -> TypeIs[da.Array]:
209
213
is_jax_array
210
214
is_pydata_sparse_array
211
215
"""
212
- return _issubclass_fast (type (x ), "dask.array" , "Array" )
216
+ cls = cast (Hashable , type (x ))
217
+ return _issubclass_fast (cls , "dask.array" , "Array" )
213
218
214
219
215
220
def is_jax_array (x : object ) -> TypeIs [jax .Array ]:
@@ -232,7 +237,8 @@ def is_jax_array(x: object) -> TypeIs[jax.Array]:
232
237
is_dask_array
233
238
is_pydata_sparse_array
234
239
"""
235
- return _issubclass_fast (type (x ), "jax" , "Array" ) or _is_jax_zero_gradient_array (x )
240
+ cls = cast (Hashable , type (x ))
241
+ return _issubclass_fast (cls , "jax" , "Array" ) or _is_jax_zero_gradient_array (x )
236
242
237
243
238
244
def is_pydata_sparse_array (x : object ) -> TypeIs [sparse .SparseArray ]:
@@ -256,7 +262,8 @@ def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]:
256
262
is_jax_array
257
263
"""
258
264
# TODO: Account for other backends.
259
- return _issubclass_fast (type (x ), "sparse" , "SparseArray" )
265
+ cls = cast (Hashable , type (x ))
266
+ return _issubclass_fast (cls , "sparse" , "SparseArray" )
260
267
261
268
262
269
def is_array_api_obj (x : object ) -> TypeIs [_ArrayApiObj ]: # pyright: ignore[reportUnknownParameterType]
@@ -274,7 +281,10 @@ def is_array_api_obj(x: object) -> TypeIs[_ArrayApiObj]: # pyright: ignore[repo
274
281
is_dask_array
275
282
is_jax_array
276
283
"""
277
- return hasattr (x , '__array_namespace__' ) or _is_array_api_cls (type (x ))
284
+ return (
285
+ hasattr (x , '__array_namespace__' )
286
+ or _is_array_api_cls (cast (Hashable , type (x )))
287
+ )
278
288
279
289
280
290
@lru_cache (100 )
@@ -946,9 +956,9 @@ def is_writeable_array(x: object) -> bool:
946
956
As there is no standard way to check if an array is writeable without actually
947
957
writing to it, this function blindly returns True for all unknown array types.
948
958
"""
949
- cls = type (x )
959
+ cls = cast ( Hashable , type (x ) )
950
960
if _issubclass_fast (cls , "numpy" , "ndarray" ):
951
- return x .flags .writeable
961
+ return cast ( npt . NDArray , x ) .flags .writeable
952
962
res = _is_writeable_cls (cls )
953
963
if res is not None :
954
964
return res
@@ -998,7 +1008,8 @@ def is_lazy_array(x: object) -> bool:
998
1008
999
1009
# Note: skipping reclassification of JAX zero gradient arrays, as one will
1000
1010
# exclusively get them once they leave a jax.grad JIT context.
1001
- res = _is_lazy_cls (type (x ))
1011
+ cls = cast (Hashable , type (x ))
1012
+ res = _is_lazy_cls (cls )
1002
1013
if res is not None :
1003
1014
return res
1004
1015
0 commit comments