13
13
import sys
14
14
import warnings
15
15
from collections .abc import Collection
16
+ from functools import lru_cache
16
17
from typing import (
17
18
TYPE_CHECKING ,
18
19
Any ,
61
62
_API_VERSIONS : Final = _API_VERSIONS_OLD | frozenset ({"2024.12" })
62
63
63
64
64
- def _is_jax_zero_gradient_array (x : object ) -> TypeGuard [_ZeroGradientArray ]:
65
- @cache
65
+ @lru_cache (100 )
66
66
def _issubclass_fast (cls : type , modname : str , clsname : str ) -> bool :
67
67
try :
68
68
mod = sys .modules [modname ]
@@ -72,6 +72,7 @@ def _issubclass_fast(cls: type, modname: str, clsname: str) -> bool:
72
72
return issubclass (cls , parent_cls )
73
73
74
74
75
+ def _is_jax_zero_gradient_array (x : object ) -> TypeGuard [_ZeroGradientArray ]:
75
76
"""Return True if `x` is a zero-gradient array.
76
77
77
78
These arrays are a design quirk of Jax that may one day be removed.
@@ -276,7 +277,7 @@ def is_array_api_obj(x: object) -> TypeIs[_ArrayApiObj]: # pyright: ignore[repo
276
277
return hasattr (x , '__array_namespace__' ) or _is_array_api_cls (type (x ))
277
278
278
279
279
- @cache
280
+ @lru_cache ( 100 )
280
281
def _is_array_api_cls (cls : type ) -> bool :
281
282
return (
282
283
# TODO: drop support for numpy<2 which didn't have __array_namespace__
@@ -296,7 +297,7 @@ def _compat_module_name() -> str:
296
297
return __name__ .removesuffix (".common._helpers" )
297
298
298
299
299
- @cache
300
+ @lru_cache ( 100 )
300
301
def is_numpy_namespace (xp : Namespace ) -> bool :
301
302
"""
302
303
Returns True if `xp` is a NumPy namespace.
@@ -318,7 +319,7 @@ def is_numpy_namespace(xp: Namespace) -> bool:
318
319
return xp .__name__ in {"numpy" , _compat_module_name () + ".numpy" }
319
320
320
321
321
- @cache
322
+ @lru_cache ( 100 )
322
323
def is_cupy_namespace (xp : Namespace ) -> bool :
323
324
"""
324
325
Returns True if `xp` is a CuPy namespace.
@@ -340,7 +341,7 @@ def is_cupy_namespace(xp: Namespace) -> bool:
340
341
return xp .__name__ in {"cupy" , _compat_module_name () + ".cupy" }
341
342
342
343
343
- @cache
344
+ @lru_cache ( 100 )
344
345
def is_torch_namespace (xp : Namespace ) -> bool :
345
346
"""
346
347
Returns True if `xp` is a PyTorch namespace.
@@ -381,7 +382,7 @@ def is_ndonnx_namespace(xp: Namespace) -> bool:
381
382
return xp .__name__ == "ndonnx"
382
383
383
384
384
- @cache
385
+ @lru_cache ( 100 )
385
386
def is_dask_namespace (xp : Namespace ) -> bool :
386
387
"""
387
388
Returns True if `xp` is a Dask namespace.
@@ -922,7 +923,7 @@ def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None:
922
923
return None if math .isnan (out ) else out
923
924
924
925
925
- @cache
926
+ @lru_cache ( 100 )
926
927
def _is_writeable_cls (cls : type ) -> bool | None :
927
928
if (
928
929
_issubclass_fast (cls , "numpy" , "generic" )
@@ -954,7 +955,7 @@ def is_writeable_array(x: object) -> bool:
954
955
return hasattr (x , '__array_namespace__' )
955
956
956
957
957
- @cache
958
+ @lru_cache ( 100 )
958
959
def _is_lazy_cls (cls : type ) -> bool | None :
959
960
if (
960
961
_issubclass_fast (cls , "numpy" , "ndarray" )
@@ -1054,7 +1055,7 @@ def is_lazy_array(x: object) -> bool:
1054
1055
"to_device" ,
1055
1056
]
1056
1057
1057
- _all_ignore = ['cache ' , 'sys' , 'math' , 'inspect' , 'warnings' ]
1058
+ _all_ignore = ['lru_cache ' , 'sys' , 'math' , 'inspect' , 'warnings' ]
1058
1059
1059
1060
def __dir__ () -> list [str ]:
1060
1061
return __all__
0 commit comments