Skip to content

Commit 931faae

Browse files
committed
Merge
1 parent fe75549 commit 931faae

File tree

1 file changed

+11
-10
lines changed

1 file changed

+11
-10
lines changed

array_api_compat/common/_helpers.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import sys
1414
import warnings
1515
from collections.abc import Collection
16+
from functools import lru_cache
1617
from typing import (
1718
TYPE_CHECKING,
1819
Any,
@@ -61,8 +62,7 @@
6162
_API_VERSIONS: Final = _API_VERSIONS_OLD | frozenset({"2024.12"})
6263

6364

64-
def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]:
65-
@cache
65+
@lru_cache(100)
6666
def _issubclass_fast(cls: type, modname: str, clsname: str) -> bool:
6767
try:
6868
mod = sys.modules[modname]
@@ -72,6 +72,7 @@ def _issubclass_fast(cls: type, modname: str, clsname: str) -> bool:
7272
return issubclass(cls, parent_cls)
7373

7474

75+
def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]:
7576
"""Return True if `x` is a zero-gradient array.
7677
7778
These arrays are a design quirk of Jax that may one day be removed.
@@ -276,7 +277,7 @@ def is_array_api_obj(x: object) -> TypeIs[_ArrayApiObj]: # pyright: ignore[repo
276277
return hasattr(x, '__array_namespace__') or _is_array_api_cls(type(x))
277278

278279

279-
@cache
280+
@lru_cache(100)
280281
def _is_array_api_cls(cls: type) -> bool:
281282
return (
282283
# TODO: drop support for numpy<2 which didn't have __array_namespace__
@@ -296,7 +297,7 @@ def _compat_module_name() -> str:
296297
return __name__.removesuffix(".common._helpers")
297298

298299

299-
@cache
300+
@lru_cache(100)
300301
def is_numpy_namespace(xp: Namespace) -> bool:
301302
"""
302303
Returns True if `xp` is a NumPy namespace.
@@ -318,7 +319,7 @@ def is_numpy_namespace(xp: Namespace) -> bool:
318319
return xp.__name__ in {"numpy", _compat_module_name() + ".numpy"}
319320

320321

321-
@cache
322+
@lru_cache(100)
322323
def is_cupy_namespace(xp: Namespace) -> bool:
323324
"""
324325
Returns True if `xp` is a CuPy namespace.
@@ -340,7 +341,7 @@ def is_cupy_namespace(xp: Namespace) -> bool:
340341
return xp.__name__ in {"cupy", _compat_module_name() + ".cupy"}
341342

342343

343-
@cache
344+
@lru_cache(100)
344345
def is_torch_namespace(xp: Namespace) -> bool:
345346
"""
346347
Returns True if `xp` is a PyTorch namespace.
@@ -381,7 +382,7 @@ def is_ndonnx_namespace(xp: Namespace) -> bool:
381382
return xp.__name__ == "ndonnx"
382383

383384

384-
@cache
385+
@lru_cache(100)
385386
def is_dask_namespace(xp: Namespace) -> bool:
386387
"""
387388
Returns True if `xp` is a Dask namespace.
@@ -922,7 +923,7 @@ def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None:
922923
return None if math.isnan(out) else out
923924

924925

925-
@cache
926+
@lru_cache(100)
926927
def _is_writeable_cls(cls: type) -> bool | None:
927928
if (
928929
_issubclass_fast(cls, "numpy", "generic")
@@ -954,7 +955,7 @@ def is_writeable_array(x: object) -> bool:
954955
return hasattr(x, '__array_namespace__')
955956

956957

957-
@cache
958+
@lru_cache(100)
958959
def _is_lazy_cls(cls: type) -> bool | None:
959960
if (
960961
_issubclass_fast(cls, "numpy", "ndarray")
@@ -1054,7 +1055,7 @@ def is_lazy_array(x: object) -> bool:
10541055
"to_device",
10551056
]
10561057

1057-
_all_ignore = ['cache', 'sys', 'math', 'inspect', 'warnings']
1058+
_all_ignore = ['lru_cache', 'sys', 'math', 'inspect', 'warnings']
10581059

10591060
def __dir__() -> list[str]:
10601061
return __all__

0 commit comments

Comments
 (0)