Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add preliminary support for the draft 2024.12 version of the standard #82

Merged
merged 11 commits into from
Nov 11, 2024
12 changes: 8 additions & 4 deletions array_api_strict/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,10 +172,12 @@
minimum,
multiply,
negative,
nextafter,
not_equal,
positive,
pow,
real,
reciprocal,
remainder,
round,
sign,
Expand Down Expand Up @@ -240,10 +242,12 @@
"minimum",
"multiply",
"negative",
"nextafter",
"not_equal",
"positive",
"pow",
"real",
"reciprocal",
"remainder",
"round",
"sign",
Expand All @@ -258,9 +262,9 @@
"trunc",
]

from ._indexing_functions import take
from ._indexing_functions import take, take_along_axis

__all__ += ["take"]
__all__ += ["take", "take_along_axis"]

from ._info import __array_namespace_info__

Expand Down Expand Up @@ -305,9 +309,9 @@

__all__ += ["cumulative_sum", "max", "mean", "min", "prod", "std", "sum", "var"]

from ._utility_functions import all, any
from ._utility_functions import all, any, diff

__all__ += ["all", "any"]
__all__ += ["all", "any", "diff"]

from ._array_object import Device
__all__ += ["Device"]
Expand Down
25 changes: 25 additions & 0 deletions array_api_strict/_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -805,6 +805,20 @@ def negative(x: Array, /) -> Array:
return Array._new(np.negative(x._array), device=x.device)


@requires_api_version('2024.12')
def nextafter(x1: Array, x2: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.nextafter <numpy.nextafter>`.

See its docstring for more information.
"""
if x1.device != x2.device:
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes:
raise TypeError("Only real floating-point dtypes are allowed in nextafter")
x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.nextafter(x1._array, x2._array), device=x1.device)

def not_equal(x1: Array, x2: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.not_equal <numpy.not_equal>`.
Expand Down Expand Up @@ -858,6 +872,17 @@ def real(x: Array, /) -> Array:
return Array._new(np.real(x._array), device=x.device)


@requires_api_version('2024.12')
def reciprocal(x: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.reciprocal <numpy.reciprocal>`.

See its docstring for more information.
"""
if x.dtype not in _floating_dtypes:
raise TypeError("Only floating-point dtypes are allowed in reciprocal")
return Array._new(np.reciprocal(x._array), device=x.device)

def remainder(x1: Array, x2: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.remainder <numpy.remainder>`.
Expand Down
11 changes: 8 additions & 3 deletions array_api_strict/_flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
"2023.12",
)

draft_version = "2024.12"

API_VERSION = default_version = "2023.12"

BOOLEAN_INDEXING = True
Expand Down Expand Up @@ -70,8 +72,8 @@ def set_array_api_strict_flags(
----------
api_version : str, optional
The version of the standard to use. Supported versions are:
``{supported_versions}``. The default version number is
``{default_version!r}``.
``{supported_versions}``, plus the draft version ``{draft_version}``.
The default version number is ``{default_version!r}``.

Note that 2021.12 is supported, but currently gives the same thing as
2022.12 (except that the fft extension will be disabled).
Expand Down Expand Up @@ -134,10 +136,12 @@ def set_array_api_strict_flags(
global API_VERSION, BOOLEAN_INDEXING, DATA_DEPENDENT_SHAPES, ENABLED_EXTENSIONS

if api_version is not None:
if api_version not in supported_versions:
if api_version not in [*supported_versions, draft_version]:
raise ValueError(f"Unsupported standard version {api_version!r}")
if api_version == "2021.12":
warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12", stacklevel=2)
if api_version == draft_version:
warnings.warn(f"The {draft_version} version of the array API specification is in draft status. Not all features are implemented in array_api_strict, some functions may not be fully tested, and behaviors are subject to change before the final standard release.")
API_VERSION = api_version
array_api_strict.__array_api_version__ = API_VERSION

Expand Down Expand Up @@ -169,6 +173,7 @@ def set_array_api_strict_flags(
supported_versions=supported_versions,
default_version=default_version,
default_extensions=default_extensions,
draft_version=draft_version,
)

def get_array_api_strict_flags():
Expand Down
12 changes: 12 additions & 0 deletions array_api_strict/_indexing_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from ._array_object import Array
from ._dtypes import _integer_dtypes
from ._flags import requires_api_version

from typing import TYPE_CHECKING

Expand All @@ -25,3 +26,14 @@ def take(x: Array, indices: Array, /, *, axis: Optional[int] = None) -> Array:
if x.device != indices.device:
raise ValueError(f"Arrays from two different devices ({x.device} and {indices.device}) can not be combined.")
return Array._new(np.take(x._array, indices._array, axis=axis), device=x.device)

@requires_api_version('2024.12')
def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array:
Copy link
Preview

Copilot AI Nov 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The function take_along_axis should check that indices.dtype is in _integer_dtypes to ensure only integer indices are allowed.

Copilot is powered by AI, so mistakes are possible. Review output carefully before use.

Positive Feedback
Negative Feedback

Provide additional feedback

Please help us improve GitHub Copilot by sharing more details about this comment.

Please select one or more of the options
"""
Array API compatible wrapper for :py:func:`np.take_along_axis <numpy.take_along_axis>`.

See its docstring for more information.
"""
if x.device != indices.device:
raise ValueError(f"Arrays from two different devices ({x.device} and {indices.device}) can not be combined.")
return Array._new(np.take_along_axis(x._array, indices._array, axis), device=x.device)
Loading
Loading