Skip to content

Commit 0486388

Browse files
authored
Add isin (#832)
1 parent 32a0b67 commit 0486388

File tree

5 files changed

+70
-0
lines changed

5 files changed

+70
-0
lines changed

cubed/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,10 @@
341341

342342
__all__ += ["argmax", "argmin", "count_nonzero", "searchsorted", "where"]
343343

344+
from .array_api.set_functions import isin
345+
346+
__all__ += ["isin"]
347+
344348
from .array_api.statistical_functions import (
345349
cumulative_prod,
346350
cumulative_sum,

cubed/array_api/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,10 @@
272272

273273
__all__ += ["argmax", "argmin", "count_nonzero", "searchsorted", "where"]
274274

275+
from .set_functions import isin
276+
277+
__all__ += ["isin"]
278+
275279
from .statistical_functions import (
276280
cumulative_prod,
277281
cumulative_sum,

cubed/array_api/set_functions.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from cubed.array_api.utility_functions import any as cubed_any
2+
from cubed.backend_array_api import namespace as nxp
3+
from cubed.core import blockwise
4+
5+
6+
def isin(x1, x2, /, *, invert=False):
    """Return a boolean array flagging the elements of ``x1`` found in ``x2``.

    Every block of ``x1`` is paired with every block of ``x2`` through
    ``blockwise``; ``_isin`` evaluates membership for each pair, and the
    per-pair answers are then OR-reduced over the axes contributed by ``x2``.

    Parameters
    ----------
    x1 : array
        Array whose elements are tested.
    x2 : array
        Array of values to test against.
    invert : bool, optional
        If true, the result is logically negated before being returned.

    Returns
    -------
    array
        Boolean array obtained after reducing over all axes of ``x2``.
    """
    # based on dask isin

    ndim1, ndim2 = x1.ndim, x2.ndim
    element_axes = tuple(range(ndim1))
    # x2's axes are numbered after x1's so the two index sets never collide.
    test_axes = tuple(range(ndim1, ndim1 + ndim2))

    # _isin emits a length-1 extent along every x2 axis, so declare the
    # output chunks along those axes as size 1.
    unit_chunks = {axis: lambda _: 1 for axis in test_axes}

    per_block = blockwise(
        _isin,
        element_axes + test_axes,
        x1,
        element_axes,
        x2,
        test_axes,
        dtype=nxp.bool,
        adjust_chunks=unit_chunks,
    )

    # An element is "in" x2 if it was found in any of x2's blocks.
    found = cubed_any(per_block, axis=test_axes)
    return ~found if invert else found
26+
27+
28+
def _isin(a1, a2):
    """Block-level kernel for ``isin``.

    Flattens ``a1``, calls the backend ``isin`` against ``a2``, and reshapes
    the result to ``a1.shape`` followed by one length-1 axis per dimension of
    ``a2``, so that the caller can reduce over those trailing axes.
    """
    out_shape = a1.shape + (1,) * a2.ndim
    membership = nxp.isin(nxp.reshape(a1, (-1,)), a2)
    return nxp.reshape(membership, out_shape)

cubed/tests/test_array_api.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -847,6 +847,36 @@ def test_where_scalars():
847847
xp.where(condition, 0, 1)
848848

849849

850+
# Set functions
851+
852+
@pytest.mark.parametrize(("low", "high"), [(0, 10)])
@pytest.mark.parametrize(
    ("elements_shape", "elements_chunks"),
    [((10,), (5,)), ((10,), (3,)), ((4, 5), (3, 2)), ((20, 20), (4, 5))],
)
@pytest.mark.parametrize(
    ("test_shape", "test_chunks"),
    [((10,), (5,)), ((10,), (3,)), ((4, 5), (3, 2)), ((20, 20), (4, 5))],
)
@pytest.mark.parametrize("invert", [True, False])
def test_isin(
    low, high, elements_shape, elements_chunks, test_shape, test_chunks, invert
):
    """Check ``xp.isin`` against ``np.isin`` for several shapes and chunkings."""
    # based on dask test
    # Seed the generator so that any failure is reproducible run-to-run.
    rng = np.random.default_rng(0)

    a1 = rng.integers(low, high, size=elements_shape)
    c1 = cubed.from_array(a1, chunks=elements_chunks)

    # Shift the test values by 5 so their range only partly overlaps a1's.
    a2 = rng.integers(low, high, size=test_shape) - 5
    c2 = cubed.from_array(a2, chunks=test_chunks)

    r_a = np.isin(a1, a2, invert=invert)
    r_c = xp.isin(c1, c2, invert=invert)

    assert_array_equal(r_c, r_a)
878+
879+
850880
# Statistical functions
851881

852882

docs/api.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ These are functions that have not (yet) been included in the Python Array API St
5858
:nosignatures:
5959
:toctree: generated/
6060

61+
isin
6162
nanmean
6263
nansum
6364
pad

0 commit comments

Comments
 (0)