Skip to content

Commit

Permalink
Add support for scalar arguments to xp.where
Browse files Browse the repository at this point in the history
  • Loading branch information
betatim committed Nov 25, 2024
1 parent d086c61 commit d9a6b73
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 2 deletions.
11 changes: 9 additions & 2 deletions array_api_strict/_searching_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from ._array_object import Array
from ._dtypes import _result_type, _real_numeric_dtypes
from ._flags import requires_data_dependent_shapes, requires_api_version
from ._flags import requires_data_dependent_shapes, requires_api_version, get_array_api_strict_flags

from typing import TYPE_CHECKING
if TYPE_CHECKING:
Expand Down Expand Up @@ -72,12 +72,19 @@ def searchsorted(
# x1 must be 1-D, but NumPy already requires this.
return Array._new(np.searchsorted(x1._array, x2._array, side=side, sorter=sorter), device=x1.device)

def where(condition: Array, x1: Array, x2: Array, /) -> Array:
def where(condition: Array, x1: bool | int | float | Array, x2: bool | int | float | Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.where <numpy.where>`.
See its docstring for more information.
"""
if get_array_api_strict_flags()['api_version'] > '2023.12':
if isinstance(x1, (bool, float, int)):
x1 = Array._new(np.asarray(x1), device=condition.device)

if isinstance(x2, (bool, float, int)):
x2 = Array._new(np.asarray(x2), device=condition.device)

# Call result type here just to raise on disallowed type combinations
_result_type(x1.dtype, x2.dtype)

Expand Down
21 changes: 21 additions & 0 deletions array_api_strict/tests/test_searching_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import pytest

import array_api_strict as xp

from array_api_strict import ArrayAPIStrictFlags
from array_api_strict._flags import next_supported_version


def test_where_with_scalars():
x = xp.asarray([1, 2, 3, 1])

# Versions up to and including 2023.12 don't support scalar arguments
with pytest.raises(AttributeError, match="object has no attribute 'dtype'"):
xp.where(x == 1, 42, 44)

# Versions after 2023.12 support scalar arguments
with ArrayAPIStrictFlags(api_version=next_supported_version):
x_where = xp.where(x == 1, 42, 44)

expected = xp.asarray([42, 44, 44, 42])
assert xp.all(x_where == expected)

0 comments on commit d9a6b73

Please sign in to comment.