diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index ff43660..98b0e95 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -172,10 +172,12 @@ minimum, multiply, negative, + nextafter, not_equal, positive, pow, real, + reciprocal, remainder, round, sign, @@ -240,10 +242,12 @@ "minimum", "multiply", "negative", + "nextafter", "not_equal", "positive", "pow", "real", + "reciprocal", "remainder", "round", "sign", @@ -258,9 +262,9 @@ "trunc", ] -from ._indexing_functions import take +from ._indexing_functions import take, take_along_axis -__all__ += ["take"] +__all__ += ["take", "take_along_axis"] from ._info import __array_namespace_info__ @@ -305,9 +309,9 @@ __all__ += ["cumulative_sum", "max", "mean", "min", "prod", "std", "sum", "var"] -from ._utility_functions import all, any +from ._utility_functions import all, any, diff -__all__ += ["all", "any"] +__all__ += ["all", "any", "diff"] from ._array_object import Device __all__ += ["Device"] diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index 8cec86a..7c64f67 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -805,6 +805,20 @@ def negative(x: Array, /) -> Array: return Array._new(np.negative(x._array), device=x.device) +@requires_api_version('2024.12') +def nextafter(x1: Array, x2: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.nextafter `. + + See its docstring for more information. + """ + if x1.device != x2.device: + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes: + raise TypeError("Only real floating-point dtypes are allowed in nextafter") + x1, x2 = Array._normalize_two_args(x1, x2) + return Array._new(np.nextafter(x1._array, x2._array), device=x1.device) + def not_equal(x1: Array, x2: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.not_equal `. @@ -858,6 +872,17 @@ def real(x: Array, /) -> Array: return Array._new(np.real(x._array), device=x.device) +@requires_api_version('2024.12') +def reciprocal(x: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.reciprocal `. + + See its docstring for more information. + """ + if x.dtype not in _floating_dtypes: + raise TypeError("Only floating-point dtypes are allowed in reciprocal") + return Array._new(np.reciprocal(x._array), device=x.device) + def remainder(x1: Array, x2: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.remainder `. diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index c393ad9..2863e5f 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -24,6 +24,8 @@ "2023.12", ) +draft_version = "2024.12" + API_VERSION = default_version = "2023.12" BOOLEAN_INDEXING = True @@ -70,8 +72,8 @@ def set_array_api_strict_flags( ---------- api_version : str, optional The version of the standard to use. Supported versions are: - ``{supported_versions}``. The default version number is - ``{default_version!r}``. + ``{supported_versions}``, plus the draft version ``{draft_version}``. + The default version number is ``{default_version!r}``. Note that 2021.12 is supported, but currently gives the same thing as 2022.12 (except that the fft extension will be disabled). @@ -134,10 +136,12 @@ def set_array_api_strict_flags( global API_VERSION, BOOLEAN_INDEXING, DATA_DEPENDENT_SHAPES, ENABLED_EXTENSIONS if api_version is not None: - if api_version not in supported_versions: + if api_version not in [*supported_versions, draft_version]: raise ValueError(f"Unsupported standard version {api_version!r}") if api_version == "2021.12": warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12", stacklevel=2) + if api_version == draft_version: + warnings.warn(f"The {draft_version} version of the array API specification is in draft status. Not all features are implemented in array_api_strict, some functions may not be fully tested, and behaviors are subject to change before the final standard release.") API_VERSION = api_version array_api_strict.__array_api_version__ = API_VERSION @@ -169,6 +173,7 @@ def set_array_api_strict_flags( supported_versions=supported_versions, default_version=default_version, default_extensions=default_extensions, + draft_version=draft_version, ) def get_array_api_strict_flags(): diff --git a/array_api_strict/_indexing_functions.py b/array_api_strict/_indexing_functions.py index c0f8e26..d7a400e 100644 --- a/array_api_strict/_indexing_functions.py +++ b/array_api_strict/_indexing_functions.py @@ -2,6 +2,7 @@ from ._array_object import Array from ._dtypes import _integer_dtypes +from ._flags import requires_api_version from typing import TYPE_CHECKING @@ -25,3 +26,14 @@ def take(x: Array, indices: Array, /, *, axis: Optional[int] = None) -> Array: if x.device != indices.device: raise ValueError(f"Arrays from two different devices ({x.device} and {indices.device}) can not be combined.") return Array._new(np.take(x._array, indices._array, axis=axis), device=x.device) + +@requires_api_version('2024.12') +def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array: + """ + Array API compatible wrapper for :py:func:`np.take_along_axis `. + + See its docstring for more information. + """ + if x.device != indices.device: + raise ValueError(f"Arrays from two different devices ({x.device} and {indices.device}) can not be combined.") + return Array._new(np.take_along_axis(x._array, indices._array, axis), device=x.device) diff --git a/array_api_strict/_info.py b/array_api_strict/_info.py index 3ed7fb2..a9dbebf 100644 --- a/array_api_strict/_info.py +++ b/array_api_strict/_info.py @@ -2,131 +2,138 @@ from typing import TYPE_CHECKING +import numpy as np + if TYPE_CHECKING: from typing import Optional, Union, Tuple, List - from ._typing import device, DefaultDataTypes, DataTypes, Capabilities, Info + from ._typing import device, DefaultDataTypes, DataTypes, Capabilities from ._array_object import ALL_DEVICES, CPU_DEVICE from ._flags import get_array_api_strict_flags, requires_api_version from ._dtypes import bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64, complex64, complex128 @requires_api_version('2023.12') -def __array_namespace_info__() -> Info: - import array_api_strict._info - return array_api_strict._info - -@requires_api_version('2023.12') -def capabilities() -> Capabilities: - flags = get_array_api_strict_flags() - return {"boolean indexing": flags['boolean_indexing'], - "data-dependent shapes": flags['data_dependent_shapes'], - } - -@requires_api_version('2023.12') -def default_device() -> device: - return CPU_DEVICE +class __array_namespace_info__: + @requires_api_version('2023.12') + def capabilities(self) -> Capabilities: + flags = get_array_api_strict_flags() + res = {"boolean indexing": flags['boolean_indexing'], + "data-dependent shapes": flags['data_dependent_shapes'], + } + if flags['api_version'] >= '2024.12': + # maxdims is 32 for NumPy 1.x and 64 for NumPy 2.0. Eventually we will + # drop support for NumPy 1 but for now, just compute the number + # directly + for i in range(1, 100): + try: + np.zeros((1,)*i) + except ValueError: + maxdims = i - 1 + break + else: + raise RuntimeError("Could not get max dimensions (this is a bug in array-api-strict)") + res['max dimensions'] = maxdims + return res -@requires_api_version('2023.12') -def default_dtypes( - *, - device: Optional[device] = None, -) -> DefaultDataTypes: - return { - "real floating": float64, - "complex floating": complex128, - "integral": int64, - "indexing": int64, - } + @requires_api_version('2023.12') + def default_device(self) -> device: + return CPU_DEVICE -@requires_api_version('2023.12') -def dtypes( - *, - device: Optional[device] = None, - kind: Optional[Union[str, Tuple[str, ...]]] = None, -) -> DataTypes: - if kind is None: - return { - "bool": bool, - "int8": int8, - "int16": int16, - "int32": int32, - "int64": int64, - "uint8": uint8, - "uint16": uint16, - "uint32": uint32, - "uint64": uint64, - "float32": float32, - "float64": float64, - "complex64": complex64, - "complex128": complex128, - } - if kind == "bool": - return {"bool": bool} - if kind == "signed integer": - return { - "int8": int8, - "int16": int16, - "int32": int32, - "int64": int64, - } - if kind == "unsigned integer": - return { - "uint8": uint8, - "uint16": uint16, - "uint32": uint32, - "uint64": uint64, - } - if kind == "integral": + @requires_api_version('2023.12') + def default_dtypes( + self, + *, + device: Optional[device] = None, + ) -> DefaultDataTypes: return { - "int8": int8, - "int16": int16, - "int32": int32, - "int64": int64, - "uint8": uint8, - "uint16": uint16, - "uint32": uint32, - "uint64": uint64, + "real floating": float64, + "complex floating": complex128, + "integral": int64, + "indexing": int64, } - if kind == "real floating": - return { - "float32": float32, - "float64": float64, - } - if kind == "complex floating": - return { - "complex64": complex64, - "complex128": complex128, - } - if kind == "numeric": - return { - "int8": int8, - "int16": int16, - "int32": int32, - "int64": int64, - "uint8": uint8, - "uint16": uint16, - "uint32": uint32, - "uint64": uint64, - "float32": float32, - "float64": float64, - "complex64": complex64, - "complex128": complex128, - } - if isinstance(kind, tuple): - res = {} - for k in kind: - res.update(dtypes(kind=k)) - return res - raise ValueError(f"unsupported kind: {kind!r}") -@requires_api_version('2023.12') -def devices() -> List[device]: - return list(ALL_DEVICES) + @requires_api_version('2023.12') + def dtypes( + self, + *, + device: Optional[device] = None, + kind: Optional[Union[str, Tuple[str, ...]]] = None, + ) -> DataTypes: + if kind is None: + return { + "bool": bool, + "int8": int8, + "int16": int16, + "int32": int32, + "int64": int64, + "uint8": uint8, + "uint16": uint16, + "uint32": uint32, + "uint64": uint64, + "float32": float32, + "float64": float64, + "complex64": complex64, + "complex128": complex128, + } + if kind == "bool": + return {"bool": bool} + if kind == "signed integer": + return { + "int8": int8, + "int16": int16, + "int32": int32, + "int64": int64, + } + if kind == "unsigned integer": + return { + "uint8": uint8, + "uint16": uint16, + "uint32": uint32, + "uint64": uint64, + } + if kind == "integral": + return { + "int8": int8, + "int16": int16, + "int32": int32, + "int64": int64, + "uint8": uint8, + "uint16": uint16, + "uint32": uint32, + "uint64": uint64, + } + if kind == "real floating": + return { + "float32": float32, + "float64": float64, + } + if kind == "complex floating": + return { + "complex64": complex64, + "complex128": complex128, + } + if kind == "numeric": + return { + "int8": int8, + "int16": int16, + "int32": int32, + "int64": int64, + "uint8": uint8, + "uint16": uint16, + "uint32": uint32, + "uint64": uint64, + "float32": float32, + "float64": float64, + "complex64": complex64, + "complex128": complex128, + } + if isinstance(kind, tuple): + res = {} + for k in kind: + res.update(self.dtypes(kind=k)) + return res + raise ValueError(f"unsupported kind: {kind!r}") -__all__ = [ - "capabilities", - "default_device", - "default_dtypes", - "devices", - "dtypes", -] + @requires_api_version('2023.12') + def devices(self) -> List[device]: + return list(ALL_DEVICES) diff --git a/array_api_strict/_typing.py b/array_api_strict/_typing.py index 05a479c..94c4975 100644 --- a/array_api_strict/_typing.py +++ b/array_api_strict/_typing.py @@ -21,7 +21,6 @@ from typing import ( Any, - ModuleType, TypedDict, TypeVar, Protocol, @@ -29,6 +28,7 @@ from ._array_object import Array, _device from ._dtypes import _DType +from ._info import __array_namespace_info__ _T_co = TypeVar("_T_co", covariant=True) @@ -41,7 +41,7 @@ def __len__(self, /) -> int: ... Dtype = _DType -Info = ModuleType +Info = __array_namespace_info__ if sys.version_info >= (3, 12): from collections.abc import Buffer as SupportsBufferProtocol @@ -54,7 +54,8 @@ class SupportsDLPack(Protocol): def __dlpack__(self, /, *, stream: None = ...) -> PyCapsule: ... Capabilities = TypedDict( - "Capabilities", {"boolean indexing": bool, "data-dependent shapes": bool} + "Capabilities", {"boolean indexing": bool, "data-dependent shapes": bool, + "max dimensions": int} ) DefaultDataTypes = TypedDict( diff --git a/array_api_strict/_utility_functions.py b/array_api_strict/_utility_functions.py index 0d44ecb..f75f36f 100644 --- a/array_api_strict/_utility_functions.py +++ b/array_api_strict/_utility_functions.py @@ -1,6 +1,8 @@ from __future__ import annotations from ._array_object import Array +from ._flags import requires_api_version +from ._dtypes import _numeric_dtypes from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -37,3 +39,31 @@ def any( See its docstring for more information. """ return Array._new(np.asarray(np.any(x._array, axis=axis, keepdims=keepdims)), device=x.device) + +@requires_api_version('2024.12') +def diff( + x: Array, + /, + *, + axis: int = -1, + n: int = 1, + prepend: Optional[Array] = None, + append: Optional[Array] = None, +) -> Array: + if x.dtype not in _numeric_dtypes: + raise TypeError("Only numeric dtypes are allowed in diff") + + # TODO: The type promotion behavior for prepend and append is not + # currently specified. + + # NumPy does not support prepend=None or append=None + kwargs = dict(axis=axis, n=n) + if prepend is not None: + if prepend.device != x.device: + raise ValueError(f"Arrays from two different devices ({prepend.device} and {x.device}) can not be combined.") + kwargs['prepend'] = prepend._array + if append is not None: + if append.device != x.device: + raise ValueError(f"Arrays from two different devices ({append.device} and {x.device}) can not be combined.") + kwargs['append'] = append._array + return Array._new(np.diff(x._array, **kwargs), device=x.device) diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index a9ea26d..5b8dbff 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -440,8 +440,12 @@ def test_array_namespace(): assert a.__array_namespace__(api_version="2021.12") is array_api_strict assert array_api_strict.__array_api_version__ == "2021.12" + with pytest.warns(UserWarning): + assert a.__array_namespace__(api_version="2024.12") is array_api_strict + assert array_api_strict.__array_api_version__ == "2024.12" + pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2021.11")) - pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2024.12")) + pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2025.12")) def test_iter(): pytest.raises(TypeError, lambda: iter(asarray(3))) diff --git a/array_api_strict/tests/test_elementwise_functions.py b/array_api_strict/tests/test_elementwise_functions.py index de11edf..4e1b9cc 100644 --- a/array_api_strict/tests/test_elementwise_functions.py +++ b/array_api_strict/tests/test_elementwise_functions.py @@ -2,6 +2,8 @@ from numpy.testing import assert_raises +import pytest + from .. import asarray, _elementwise_functions from .._elementwise_functions import bitwise_left_shift, bitwise_right_shift from .._dtypes import ( @@ -79,10 +81,12 @@ def nargs(func): "minimum": "real numeric", "multiply": "numeric", "negative": "numeric", + "nextafter": "real floating-point", "not_equal": "all", "positive": "numeric", "pow": "numeric", "real": "complex floating-point", + "reciprocal": "floating-point", "remainder": "real numeric", "round": "numeric", "sign": "numeric", @@ -126,7 +130,8 @@ def _array_vals(dtypes): yield asarray(1., dtype=d) # Use the latest version of the standard so all functions are included - set_array_api_strict_flags(api_version="2023.12") + with pytest.warns(UserWarning): + set_array_api_strict_flags(api_version="2024.12") for func_name, types in elementwise_function_input_types.items(): dtypes = _dtype_categories[types] @@ -162,7 +167,8 @@ def _array_vals(): yield asarray(1.0, dtype=d) # Use the latest version of the standard so all functions are included - set_array_api_strict_flags(api_version="2023.12") + with pytest.warns(UserWarning): + set_array_api_strict_flags(api_version="2024.12") for x in _array_vals(): for func_name, types in elementwise_function_input_types.items(): diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py index 2603f35..e0b004b 100644 --- a/array_api_strict/tests/test_flags.py +++ b/array_api_strict/tests/test_flags.py @@ -3,8 +3,6 @@ from .._flags import (set_array_api_strict_flags, get_array_api_strict_flags, reset_array_api_strict_flags) -from .._info import (capabilities, default_device, default_dtypes, devices, - dtypes) from .._fft import (fft, ifft, fftn, ifftn, rfft, irfft, rfftn, irfftn, hfft, ihfft, fftfreq, rfftfreq, fftshift, ifftshift) from .._linalg import (cholesky, cross, det, diagonal, eigh, eigvalsh, inv, @@ -99,6 +97,21 @@ def test_flags_api_version_2023_12(): 'enabled_extensions': ('linalg', 'fft'), } +def test_flags_api_version_2024_12(): + # Make sure setting the version to 2024.12 issues a warning. + with pytest.warns(UserWarning) as record: + set_array_api_strict_flags(api_version='2024.12') + assert len(record) == 1 + assert '2024.12' in str(record[0].message) + assert 'draft' in str(record[0].message) + flags = get_array_api_strict_flags() + assert flags == { + 'api_version': '2024.12', + 'boolean_indexing': True, + 'data_dependent_shapes': True, + 'enabled_extensions': ('linalg', 'fft'), + } + def test_setting_flags_invalid(): # Test setting flags with invalid values pytest.raises(ValueError, lambda: @@ -245,14 +258,17 @@ def test_fft(func_name): set_array_api_strict_flags(enabled_extensions=('fft',)) func() +# Test functionality even if the info object is already created +_info = xp.__array_namespace_info__() + api_version_2023_12_examples = { '__array_namespace_info__': lambda: xp.__array_namespace_info__(), # Test these functions directly to ensure they are properly decorated - 'capabilities': capabilities, - 'default_device': default_device, - 'default_dtypes': default_dtypes, - 'devices': devices, - 'dtypes': dtypes, + 'capabilities': _info.capabilities, + 'default_device': _info.default_device, + 'default_dtypes': _info.default_dtypes, + 'devices': _info.devices, + 'dtypes': _info.dtypes, 'clip': lambda: xp.clip(xp.asarray([1, 2, 3]), 1, 2), 'copysign': lambda: xp.copysign(xp.asarray([1., 2., 3.]), xp.asarray([-1., -1., -1.])), 'cumulative_sum': lambda: xp.cumulative_sum(xp.asarray([1, 2, 3])), @@ -285,6 +301,36 @@ def test_api_version_2023_12(func_name): set_array_api_strict_flags(api_version='2022.12') pytest.raises(RuntimeError, func) +api_version_2024_12_examples = { + 'diff': lambda: xp.diff(xp.asarray([0, 1, 2])), + 'nextafter': lambda: xp.nextafter(xp.asarray(0.), xp.asarray(1.)), + 'reciprocal': lambda: xp.reciprocal(xp.asarray([2.])), + 'take_along_axis': lambda: xp.take_along_axis(xp.zeros((2, 3)), + xp.zeros((1, 4), dtype=xp.int64)), +} + +@pytest.mark.parametrize('func_name', api_version_2024_12_examples.keys()) +def test_api_version_2024_12(func_name): + func = api_version_2024_12_examples[func_name] + + # By default, these functions should error + pytest.raises(RuntimeError, func) + + # In 2022.12 and 2023.12, these functions should error + set_array_api_strict_flags(api_version='2022.12') + pytest.raises(RuntimeError, func) + set_array_api_strict_flags(api_version='2023.12') + pytest.raises(RuntimeError, func) + + # They should not error in 2024.12 + with pytest.warns(UserWarning): + set_array_api_strict_flags(api_version='2024.12') + func() + + # Test the behavior gets updated properly + set_array_api_strict_flags(api_version='2023.12') + pytest.raises(RuntimeError, func) + def test_disabled_extensions(): # Test that xp.extension errors when an extension is disabled, and that # xp.__all__ is updated properly.