Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/array_api_extra/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
cov,
expand_dims,
isclose,
isin,
nan_to_num,
one_hot,
pad,
Expand Down Expand Up @@ -39,6 +40,7 @@
"default_dtype",
"expand_dims",
"isclose",
"isin",
"kron",
"lazy_apply",
"nan_to_num",
Expand Down
59 changes: 59 additions & 0 deletions src/array_api_extra/_delegation.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,3 +836,62 @@ def argpartition(
# kth is not small compared to x.size

return _funcs.argpartition(a, kth, axis=axis, xp=xp)


def isin(
a: Array,
b: Array,
/,
*,
assume_unique: bool = False,
invert: bool = False,
kind: str | None = None,
xp: ModuleType | None = None,
) -> Array:
"""
Determine whether each element in `a` is present in `b`.

Return a boolean array of the same shape as `a` that is True for elements
that are in `b` and False otherwise.

Parameters
----------
a : array
Input elements.
b : array
The elements against which to test each element of `a`.
assume_unique : bool, optional
If True, the input arrays are both assumed to be unique which can speed
up the calculation. Default: False.
invert : bool, optional
If True, the values in the returned array are inverted. Default: False.
kind : str | None, optional
The algorithm or method to use. This will not affect the final result,
but will affect the speed and memory use.
For NumPy the options are {None, "sort", "table"}.
For Jax the mapped parameter is instead `method` and the options are
{"compare_all", "binary_search", "sort", and "auto" (default)}
For CuPy, Dask, Torch and the default case this parameter is not present and
thus ignored. Default: None.
xp : array_namespace, optional
The standard-compatible namespace for `a` and `b`. Default: infer.

Returns
-------
array
An array having the same shape as that of `a` that is True for elements
that are in `b` and False otherwise.
"""
if xp is None:
xp = array_namespace(a, b)

if is_numpy_namespace(xp):
return xp.isin(a, b, assume_unique=assume_unique, invert=invert, kind=kind)
if is_jax_namespace(xp):
if kind is None:
kind = "auto"
return xp.isin(a, b, assume_unique=assume_unique, invert=invert, method=kind)
if is_cupy_namespace(xp) or is_torch_namespace(xp) or is_dask_namespace(xp):
return xp.isin(a, b, assume_unique=assume_unique, invert=invert)

return _funcs.isin(a, b, assume_unique=assume_unique, invert=invert, xp=xp)
19 changes: 19 additions & 0 deletions src/array_api_extra/_lib/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -801,3 +801,22 @@ def argpartition( # numpydoc ignore=PR01,RT01
) -> Array:
"""See docstring in `array_api_extra._delegation.py`."""
return xp.argsort(x, axis=axis, stable=False)


def isin( # numpydoc ignore=PR01,RT01
a: Array,
b: Array,
/,
*,
assume_unique: bool = False,
invert: bool = False,
xp: ModuleType,
) -> Array:
"""See docstring in `array_api_extra._delegation.py`."""
original_a_shape = a.shape
a = xp.reshape(a, (-1,))
b = xp.reshape(b, (-1,))
return xp.reshape(
_helpers.in1d(a, b, assume_unique=assume_unique, invert=invert, xp=xp),
original_a_shape,
)
55 changes: 54 additions & 1 deletion tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
default_dtype,
expand_dims,
isclose,
isin,
kron,
nan_to_num,
nunique,
Expand Down Expand Up @@ -888,7 +889,7 @@ def test_device(self, xp: ModuleType, device: Device, equal_nan: bool):
b = xp.asarray([1e-9, 1e-4, xp.nan], device=device)
res = isclose(a, b, equal_nan=equal_nan)
assert get_device(res) == device

def test_array_on_device_with_scalar(self, xp: ModuleType, device: Device):
a = xp.asarray([0.01, 0.5, 0.8, 0.9, 1.00001], device=device)
b = 1
Expand Down Expand Up @@ -1476,3 +1477,55 @@ def test_nd(self, xp: ModuleType, ndim: int):
@override
def test_input_validation(self, xp: ModuleType):
self._test_input_validation(xp)


@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no unique_inverse")
class TestIsIn:
def test_simple(self, xp: ModuleType, library: Backend):
if library.like(Backend.NUMPY) and NUMPY_VERSION < (1, 24):
pytest.xfail("NumPy <1.24 has no kind kwarg in isin")

b = xp.asarray([1, 2, 3, 4])

# `a` with 1 dimension
a = xp.asarray([1, 3, 6, 10])
expected = xp.asarray([True, True, False, False])
res = isin(a, b)
xp_assert_equal(res, expected)

# `a` with 2 dimensions
a = xp.asarray([[0, 2], [4, 6]])
expected = xp.asarray([[False, True], [True, False]])
res = isin(a, b)
xp_assert_equal(res, expected)

def test_device(self, xp: ModuleType, device: Device, library: Backend):
if library.like(Backend.NUMPY) and NUMPY_VERSION < (1, 24):
pytest.xfail("NumPy <1.24 has no kind kwarg in isin")

a = xp.asarray([1, 3, 6], device=device)
b = xp.asarray([1, 2, 3], device=device)
assert get_device(isin(a, b)) == device

def test_assume_unique_and_invert(
self, xp: ModuleType, device: Device, library: Backend
):
if library.like(Backend.NUMPY) and NUMPY_VERSION < (1, 24):
pytest.xfail("NumPy <1.24 has no kind kwarg in isin")

a = xp.asarray([0, 3, 6, 10], device=device)
b = xp.asarray([1, 2, 3, 10], device=device)
expected = xp.asarray([True, False, True, False])
res = isin(a, b, assume_unique=True, invert=True)
assert get_device(res) == device
xp_assert_equal(res, expected)

def test_kind(self, xp: ModuleType, library: Backend):
if library.like(Backend.NUMPY) and NUMPY_VERSION < (1, 24):
pytest.xfail("NumPy <1.24 has no kind kwarg in isin")

a = xp.asarray([0, 3, 6, 10])
b = xp.asarray([1, 2, 3, 10])
expected = xp.asarray([False, True, False, True])
res = isin(a, b, kind="sort")
xp_assert_equal(res, expected)