Skip to content

Commit

Permalink
Merge pull request #166 from asmeurer/selected_indices
Browse files Browse the repository at this point in the history
Implement selected_indices()
  • Loading branch information
asmeurer authored Feb 8, 2024
2 parents e5629f5 + 12d9321 commit ff2ef22
Show file tree
Hide file tree
Showing 11 changed files with 178 additions and 8 deletions.
8 changes: 8 additions & 0 deletions ndindex/integer.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,13 @@ def isempty(self, shape=None):

return False

def selected_indices(self, shape, axis=None):
if axis is None:
yield from self.expand(shape).selected_indices(shape)
else:
shape = asshape(shape, axis=axis)
yield self

def __eq__(self, other):
if isinstance(other, Integer):
return self.args == other.args
Expand All @@ -173,6 +180,7 @@ def __eq__(self, other):
def __hash__(self):
return super().__hash__()


# Imports at the bottom to avoid circular import issues
from .ndindex import ndindex
from .slice import Slice
Expand Down
8 changes: 8 additions & 0 deletions ndindex/integerarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,14 @@ def as_subindex(self, index):

raise NotImplementedError("IntegerArray.as_subindex is only implemented for slices")

def selected_indices(self, shape, axis=None):
if axis is None:
yield from self.expand(shape).selected_indices(shape)
else:
shape = asshape(shape, axis=axis)
for i in self.array.flat:
yield Integer(i)

def __eq__(self, other):
from numpy import ndarray

Expand Down
50 changes: 50 additions & 0 deletions ndindex/ndindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,56 @@ def broadcast_arrays(self):
"""
return self

def selected_indices(self, shape, axis=0):
"""
Return an iterator over all indices that are selected by `self` on an
array of shape `shape`.
The result is a set of indices `i` such that `[a[i] for i in
idx.selected_indices(a.shape)]` is all the elements of `a[idx]`. The
indices are all iterated over in C (i.e., row major) order.
>>> from ndindex import Slice, Tuple
>>> idx = Slice(5, 10)
>>> list(idx.selected_indices(20))
[Integer(5), Integer(6), Integer(7), Integer(8), Integer(9)]
>>> idx = Tuple(Slice(5, 10), Slice(0, 2))
>>> list(idx.selected_indices((20, 3)))
[Tuple(5, 0), Tuple(5, 1),
Tuple(6, 0), Tuple(6, 1),
Tuple(7, 0), Tuple(7, 1),
Tuple(8, 0), Tuple(8, 1),
Tuple(9, 0), Tuple(9, 1)]
To correspond these indices to the elements of `a[idx]`, you can use
`iter_indices(idx.newshape(shape))`, since both iterators iterate the
indices in C order.
>>> from ndindex import iter_indices
>>> idx = Tuple(Slice(3, 5), Slice(0, 2))
>>> shape = (5, 5)
>>> import numpy as np
>>> a = np.arange(25).reshape(shape)
>>> for a_idx, (new_idx,) in zip(
... idx.selected_indices(shape),
... iter_indices(idx.newshape(shape))):
... print(a_idx, new_idx, a[a_idx.raw], a[idx.raw][new_idx.raw])
Tuple(3, 0) Tuple(0, 0) 15 15
Tuple(3, 1) Tuple(0, 1) 16 16
Tuple(4, 0) Tuple(1, 0) 20 20
Tuple(4, 1) Tuple(1, 1) 21 21
See Also
========
ndindex.iter_indices:
An iterator of indices to select every element for arrays of a given shape.
ndindex.ChunkSize.as_subchunks:
A high-level iterator that efficiently gives only those chunks
that intersect with a given index
"""
return self.expand(shape).selected_indices(shape)

def operator_index(idx):
"""
Convert `idx` into an integer index using `__index__()` or raise
Expand Down
8 changes: 8 additions & 0 deletions ndindex/slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,6 +583,14 @@ def __eq__(self, other):
return self.args == other.args
return False

def selected_indices(self, shape, axis=None):
if axis is None:
yield from self.expand(shape).selected_indices(shape)
else:
shape = asshape(shape, axis=axis)
for i in range(shape[axis])[self.raw]:
yield Integer(i)

# Imports at the bottom to avoid circular import issues
from .ndindex import ndindex
from .tuple import Tuple
Expand Down
4 changes: 2 additions & 2 deletions ndindex/tests/test_as_subindex.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from pytest import raises

from numpy import array, arange, isin, prod, unique, intp
from numpy import array, arange, isin, unique, intp

from hypothesis import given, assume, example
from hypothesis.strategies import integers, one_of

from ..ndindex import ndindex
from ..integerarray import IntegerArray
from ..tuple import Tuple
from .helpers import ndindices, short_shapes, assert_equal, warnings_are_errors
from .helpers import ndindices, prod, short_shapes, assert_equal, warnings_are_errors

@example((slice(0, 8), slice(0, 9), slice(0, 10)),
([2, 5, 6, 7], slice(1, 9, 1), slice(5, 10, 1)),
Expand Down
4 changes: 2 additions & 2 deletions ndindex/tests/test_expand.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from numpy import arange, prod, array, intp, empty
from numpy import arange, array, intp, empty

from hypothesis import given, example
from hypothesis.strategies import integers, one_of
Expand All @@ -9,7 +9,7 @@
from ..integerarray import IntegerArray
from ..integer import Integer
from ..tuple import Tuple
from .helpers import ndindices, check_same, short_shapes
from .helpers import ndindices, check_same, short_shapes, prod

@example(True, (1,))
@example((Ellipsis, array([[ True, True]])), (1, 2))
Expand Down
5 changes: 3 additions & 2 deletions ndindex/tests/test_integerarray.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from numpy import prod, arange, array, int8, intp, empty
from numpy import arange, array, int8, intp, empty

from hypothesis import given, example
from hypothesis.strategies import one_of, integers

from pytest import raises

from .helpers import integer_arrays, short_shapes, check_same, assert_equal, reduce_kwargs
from .helpers import (integer_arrays, short_shapes, check_same, assert_equal,
reduce_kwargs, prod)

from ..integer import Integer
from ..integerarray import IntegerArray
Expand Down
1 change: 1 addition & 0 deletions ndindex/tests/test_ndindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .helpers import ndindices, check_same, assert_equal


@example(None)
@example([1, 2])
@given(ndindices)
def test_eq(idx):
Expand Down
4 changes: 2 additions & 2 deletions ndindex/tests/test_newshape.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from pytest import raises

from numpy import arange, prod, array, full
from numpy import arange, array, full

from hypothesis import given, example
from hypothesis.strategies import integers, one_of

from ..ndindex import ndindex
from ..tuple import Tuple
from ..integer import Integer
from .helpers import ndindices, check_same, short_shapes
from .helpers import ndindices, check_same, short_shapes, prod

@example(..., 0)
@example((True,), ())
Expand Down
45 changes: 45 additions & 0 deletions ndindex/tests/test_selected_indices.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from numpy import arange

from hypothesis import given, example
from hypothesis.strategies import integers, one_of

from ..ndindex import ndindex
from ..tuple import Tuple
from ..integer import Integer
from .helpers import ndindices, check_same, short_shapes, prod

@example(([False], None), (1,))
@example((False, slice(0, 10)), (5, 2))
@example((None, True, 0), (5, 2))
@example((slice(0, 10), [0, -1]), (5, 2))
@example(slice(0, 10), 5)
@example([0, 1, 2], 3)
@example(([0, 1, 2],), 3)
@given(ndindices, one_of(short_shapes, integers(0, 10)))
def test_selected_indices_hypothesis(idx, shape):
if isinstance(shape, int):
a = arange(shape)
else:
a = arange(prod(shape)).reshape(shape)

ndindex(idx)

def raw_func(a, idx):
return list(a[idx].flat)

def ndindex_func(a, index):
indices = list(index.selected_indices(shape))
for i in indices:
if len(a.shape) == 1:
assert isinstance(i, Integer)
else:
assert isinstance(i, Tuple)
assert all(isinstance(j, Integer) for j in i.args)

return [a[i.raw] for i in indices]

def assert_equal(a, b):
assert a == b

check_same(a, idx, raw_func=raw_func, ndindex_func=ndindex_func,
assert_equal=assert_equal, same_exception=False)
49 changes: 49 additions & 0 deletions ndindex/tuple.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import sys
import itertools

from .ndindex import NDIndex, ndindex
from .subindex_helpers import subindex_slice
Expand Down Expand Up @@ -741,6 +742,54 @@ def isempty(self, shape=None):

return any(i.isempty() for i in self.args)

def selected_indices(self, shape):
shape = asshape(shape)
idx = self.expand(shape)

def _zipped_array_indices(array_indices, shape, axis=0):
return zip(*[i.selected_indices(shape, axis=axis+j)
for j, i in enumerate(array_indices)])

def _flatten(l):
for element in l:
if isinstance(element, tuple):
yield from element
else:
yield element

# We need to zip all array indices into a single iterator.
iterators = []
array_indices = []
axis = 0
for i in idx.args:
if i == False:
return
elif i == True:
pass
elif isinstance(i, IntegerArray):
array_indices.append(i)
else:
# Tuples do not support array indices separated by slices,
# newaxes, or ellipses. Furthermore, if there are (non-scalar
# boolean) array indices, any Integer and BooleanArray indices
# are converted to IntegerArray. So we can assume all array
# indices are together in a single block, and this is the end
# of it.
if array_indices:
iterators.append(_zipped_array_indices(array_indices,
shape, axis=axis))
axis += len(array_indices)
array_indices.clear()
if i != None:
iterators.append(i.selected_indices(shape, axis=axis))
axis += 1
if idx.args and isinstance(idx.args[-1], IntegerArray):
iterators.append(_zipped_array_indices(array_indices,
shape, axis=axis))

for i in itertools.product(*iterators):
yield Tuple(*_flatten(i)).reduce()

# Imports at the bottom to avoid circular import issues
from .array import ArrayIndex
from .ellipsis import ellipsis
Expand Down

0 comments on commit ff2ef22

Please sign in to comment.