From d9a6b73fb58b0d882d1ab3c7b007e4fa417fe71f Mon Sep 17 00:00:00 2001 From: Tim Head Date: Thu, 31 Oct 2024 16:18:34 +0100 Subject: [PATCH] Add support for scalar arguments to xp.where --- array_api_strict/_searching_functions.py | 11 ++++++++-- .../tests/test_searching_functions.py | 21 +++++++++++++++++++ 2 files changed, 30 insertions(+), 2 deletions(-) create mode 100644 array_api_strict/tests/test_searching_functions.py diff --git a/array_api_strict/_searching_functions.py b/array_api_strict/_searching_functions.py index 0d7c0c8..05c7f2a 100644 --- a/array_api_strict/_searching_functions.py +++ b/array_api_strict/_searching_functions.py @@ -2,7 +2,7 @@ from ._array_object import Array from ._dtypes import _result_type, _real_numeric_dtypes -from ._flags import requires_data_dependent_shapes, requires_api_version +from ._flags import requires_data_dependent_shapes, requires_api_version, get_array_api_strict_flags from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -72,12 +72,19 @@ def searchsorted( # x1 must be 1-D, but NumPy already requires this. return Array._new(np.searchsorted(x1._array, x2._array, side=side, sorter=sorter), device=x1.device) -def where(condition: Array, x1: Array, x2: Array, /) -> Array: +def where(condition: Array, x1: bool | int | float | Array, x2: bool | int | float | Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.where `. See its docstring for more information. """ + if get_array_api_strict_flags()['api_version'] > '2023.12': + if isinstance(x1, (bool, float, int)): + x1 = Array._new(np.asarray(x1), device=condition.device) + + if isinstance(x2, (bool, float, int)): + x2 = Array._new(np.asarray(x2), device=condition.device) + # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) diff --git a/array_api_strict/tests/test_searching_functions.py b/array_api_strict/tests/test_searching_functions.py new file mode 100644 index 0000000..7d7ae1e --- /dev/null +++ b/array_api_strict/tests/test_searching_functions.py @@ -0,0 +1,21 @@ +import pytest + +import array_api_strict as xp + +from array_api_strict import ArrayAPIStrictFlags +from array_api_strict._flags import next_supported_version + + +def test_where_with_scalars(): + x = xp.asarray([1, 2, 3, 1]) + + # Versions up to and including 2023.12 don't support scalar arguments + with pytest.raises(AttributeError, match="object has no attribute 'dtype'"): + xp.where(x == 1, 42, 44) + + # Versions after 2023.12 support scalar arguments + with ArrayAPIStrictFlags(api_version=next_supported_version): + x_where = xp.where(x == 1, 42, 44) + + expected = xp.asarray([42, 44, 44, 42]) + assert xp.all(x_where == expected)