From d9a6b73fb58b0d882d1ab3c7b007e4fa417fe71f Mon Sep 17 00:00:00 2001 From: Tim Head Date: Thu, 31 Oct 2024 16:18:34 +0100 Subject: [PATCH 1/4] Add support for scalar arguments to xp.where --- array_api_strict/_searching_functions.py | 11 ++++++++-- .../tests/test_searching_functions.py | 21 +++++++++++++++++++ 2 files changed, 30 insertions(+), 2 deletions(-) create mode 100644 array_api_strict/tests/test_searching_functions.py diff --git a/array_api_strict/_searching_functions.py b/array_api_strict/_searching_functions.py index 0d7c0c8..05c7f2a 100644 --- a/array_api_strict/_searching_functions.py +++ b/array_api_strict/_searching_functions.py @@ -2,7 +2,7 @@ from ._array_object import Array from ._dtypes import _result_type, _real_numeric_dtypes -from ._flags import requires_data_dependent_shapes, requires_api_version +from ._flags import requires_data_dependent_shapes, requires_api_version, get_array_api_strict_flags from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -72,12 +72,19 @@ def searchsorted( # x1 must be 1-D, but NumPy already requires this. return Array._new(np.searchsorted(x1._array, x2._array, side=side, sorter=sorter), device=x1.device) -def where(condition: Array, x1: Array, x2: Array, /) -> Array: +def where(condition: Array, x1: bool | int | float | Array, x2: bool | int | float | Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.where `. See its docstring for more information. """ + if get_array_api_strict_flags()['api_version'] > '2023.12': + if isinstance(x1, (bool, float, int)): + x1 = Array._new(np.asarray(x1), device=condition.device) + + if isinstance(x2, (bool, float, int)): + x2 = Array._new(np.asarray(x2), device=condition.device) + # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) diff --git a/array_api_strict/tests/test_searching_functions.py b/array_api_strict/tests/test_searching_functions.py new file mode 100644 index 0000000..7d7ae1e --- /dev/null +++ b/array_api_strict/tests/test_searching_functions.py @@ -0,0 +1,21 @@ +import pytest + +import array_api_strict as xp + +from array_api_strict import ArrayAPIStrictFlags +from array_api_strict._flags import next_supported_version + + +def test_where_with_scalars(): + x = xp.asarray([1, 2, 3, 1]) + + # Versions up to and including 2023.12 don't support scalar arguments + with pytest.raises(AttributeError, match="object has no attribute 'dtype'"): + xp.where(x == 1, 42, 44) + + # Versions after 2023.12 support scalar arguments + with ArrayAPIStrictFlags(api_version=next_supported_version): + x_where = xp.where(x == 1, 42, 44) + + expected = xp.asarray([42, 44, 44, 42]) + assert xp.all(x_where == expected) From f11c85a4458cd57a6b949117bb9bee50ab9690b4 Mon Sep 17 00:00:00 2001 From: Tim Head Date: Mon, 25 Nov 2024 14:31:08 +0100 Subject: [PATCH 2/4] Adapt to draft_version --- array_api_strict/_searching_functions.py | 2 +- array_api_strict/tests/test_searching_functions.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/array_api_strict/_searching_functions.py b/array_api_strict/_searching_functions.py index 05c7f2a..7296b14 100644 --- a/array_api_strict/_searching_functions.py +++ b/array_api_strict/_searching_functions.py @@ -89,7 +89,7 @@ def where(condition: Array, x1: bool | int | float | Array, x2: bool | int | flo _result_type(x1.dtype, x2.dtype) if len({a.device for a in (condition, x1, x2)}) > 1: - raise ValueError("where inputs must all be on the same device") + raise ValueError("Inputs to `where` must all use the same device") x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.where(condition._array, x1._array, x2._array), device=x1.device) diff --git a/array_api_strict/tests/test_searching_functions.py b/array_api_strict/tests/test_searching_functions.py index 7d7ae1e..d29e08c 100644 --- a/array_api_strict/tests/test_searching_functions.py +++ b/array_api_strict/tests/test_searching_functions.py @@ -3,7 +3,7 @@ import array_api_strict as xp from array_api_strict import ArrayAPIStrictFlags -from array_api_strict._flags import next_supported_version +from array_api_strict._flags import draft_version def test_where_with_scalars(): @@ -14,7 +14,7 @@ def test_where_with_scalars(): xp.where(x == 1, 42, 44) # Versions after 2023.12 support scalar arguments - with ArrayAPIStrictFlags(api_version=next_supported_version): + with ArrayAPIStrictFlags(api_version=draft_version): x_where = xp.where(x == 1, 42, 44) expected = xp.asarray([42, 44, 44, 42]) From 9feef0789f7d1a11c1289f6336784722fd9b8565 Mon Sep 17 00:00:00 2001 From: Tim Head Date: Mon, 25 Nov 2024 15:02:23 +0100 Subject: [PATCH 3/4] Expect warnings --- array_api_strict/tests/test_searching_functions.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/array_api_strict/tests/test_searching_functions.py b/array_api_strict/tests/test_searching_functions.py index d29e08c..8ccd1dd 100644 --- a/array_api_strict/tests/test_searching_functions.py +++ b/array_api_strict/tests/test_searching_functions.py @@ -14,7 +14,12 @@ def test_where_with_scalars(): xp.where(x == 1, 42, 44) # Versions after 2023.12 support scalar arguments - with ArrayAPIStrictFlags(api_version=draft_version): + with (pytest.warns( + UserWarning, + match="The 2024.12 version of the array API specification is in draft status" + ), + ArrayAPIStrictFlags(api_version=draft_version), + ): x_where = xp.where(x == 1, 42, 44) expected = xp.asarray([42, 44, 44, 42]) From e0a8d7a8a193fe9b7c60f20dc0fcf0bb470ed8b9 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sat, 1 Feb 2025 15:01:55 +0100 Subject: [PATCH 4/4] Require that one of x1,x2 arguments to where is an array --- array_api_strict/_searching_functions.py | 17 ++++++++++++++--- .../tests/test_searching_functions.py | 6 +++++- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/array_api_strict/_searching_functions.py b/array_api_strict/_searching_functions.py index 046b569..ad32aaa 100644 --- a/array_api_strict/_searching_functions.py +++ b/array_api_strict/_searching_functions.py @@ -90,18 +90,29 @@ def searchsorted( # x1 must be 1-D, but NumPy already requires this. return Array._new(np.searchsorted(x1._array, x2._array, side=side, sorter=sorter), device=x1.device) -def where(condition: Array, x1: bool | int | float | Array, x2: bool | int | float | Array, /) -> Array: +def where( + condition: Array, + x1: bool | int | float | complex | Array, + x2: bool | int | float | complex | Array, / +) -> Array: """ Array API compatible wrapper for :py:func:`np.where `. See its docstring for more information. """ if get_array_api_strict_flags()['api_version'] > '2023.12': - if isinstance(x1, (bool, float, int)): + num_scalars = 0 + + if isinstance(x1, (bool, float, complex, int)): x1 = Array._new(np.asarray(x1), device=condition.device) + num_scalars += 1 - if isinstance(x2, (bool, float, int)): + if isinstance(x2, (bool, float, complex, int)): x2 = Array._new(np.asarray(x2), device=condition.device) + num_scalars += 1 + + if num_scalars == 2: + raise ValueError("One of x1, x2 arguments must be an array.") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) diff --git a/array_api_strict/tests/test_searching_functions.py b/array_api_strict/tests/test_searching_functions.py index 8ccd1dd..dfb3fe7 100644 --- a/array_api_strict/tests/test_searching_functions.py +++ b/array_api_strict/tests/test_searching_functions.py @@ -20,7 +20,11 @@ def test_where_with_scalars(): ), ArrayAPIStrictFlags(api_version=draft_version), ): - x_where = xp.where(x == 1, 42, 44) + x_where = xp.where(x == 1, xp.asarray(42), 44) expected = xp.asarray([42, 44, 44, 42]) assert xp.all(x_where == expected) + + # The spec does not allow both x1 and x2 to be scalars + with pytest.raises(ValueError, match="One of"): + xp.where(x == 1, 42, 44)