From ab8de38e97422f0c5be93f645b2ebf70b0cabc1c Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 15 Nov 2023 01:04:52 -0700 Subject: [PATCH 1/8] Initial implementation of selected_indices() It's currently wrong for array indices because it doesn't zip them together. --- ndindex/integer.py | 8 +++++ ndindex/integerarray.py | 8 +++++ ndindex/ndindex.py | 46 ++++++++++++++++++++++++++ ndindex/slice.py | 8 +++++ ndindex/tests/test_as_subindex.py | 4 +-- ndindex/tests/test_expand.py | 4 +-- ndindex/tests/test_integerarray.py | 5 +-- ndindex/tests/test_newshape.py | 4 +-- ndindex/tests/test_selected_indices.py | 43 ++++++++++++++++++++++++ ndindex/tuple.py | 12 +++++++ 10 files changed, 134 insertions(+), 8 deletions(-) create mode 100644 ndindex/tests/test_selected_indices.py diff --git a/ndindex/integer.py b/ndindex/integer.py index 18899934..10a1a548 100644 --- a/ndindex/integer.py +++ b/ndindex/integer.py @@ -154,6 +154,13 @@ def isempty(self, shape=None): return False + def selected_indices(self, shape, axis=None): + if axis is None: + yield from self.expand(shape).selected_indices(shape) + else: + shape = asshape(shape, axis=axis) + yield self + def __eq__(self, other): if isinstance(other, Integer): return self.args == other.args @@ -166,6 +173,7 @@ def __eq__(self, other): def __hash__(self): return super().__hash__() + # Imports at the bottom to avoid circular import issues from .ndindex import ndindex from .slice import Slice diff --git a/ndindex/integerarray.py b/ndindex/integerarray.py index 67a6094e..a76f8d74 100644 --- a/ndindex/integerarray.py +++ b/ndindex/integerarray.py @@ -175,6 +175,14 @@ def as_subindex(self, index): raise NotImplementedError("IntegerArray.as_subindex is only implemented for slices") + def selected_indices(self, shape, axis=None): + if axis is None: + yield from self.expand(shape).selected_indices(shape) + else: + shape = asshape(shape, axis=axis) + for i in self.array.flat: + yield Integer(i) + def __eq__(self, other): from numpy import ndarray diff --git a/ndindex/ndindex.py b/ndindex/ndindex.py index c9e04e03..843fecad 100644 --- a/ndindex/ndindex.py +++ b/ndindex/ndindex.py @@ -598,6 +598,52 @@ def broadcast_arrays(self): """ return self + def selected_indices(self, shape): + """ + Return an iterator over all indices that are selected by `self` on an + array of shape `shape`. + + The result is a set of indices `i` such that `[a[i] for i in + idx.selected_indices(a.shape)]` is all the elements of `a[idx]`. The + indices are all iterated over in C (i.e., row major) order. + + >>> from ndindex import Slice, Tuple + >>> idx = Slice(5, 10) + >>> list(idx.selected_indices(20)) + [5, 6, 7, 8, 9] + >>> idx = Tuple(Slice(5, 10), Slice(0, 2)) + >>> list(idx.selected_indices((20, 3))) + [(5, 0), (5, 1), (6, 0), (6, 1), (7, 0), (7, 1), (8, 0), (8, 1), (9, 0), (9, 1)] + + To correspond these indices to the elements of `a[idx]`, you can use + `iter_indices(idx.newshape(shape))`, since both iterators iterate the + indices in C order. + + >>> from ndindex import iter_indices + >>> idx = Tuple(Slice(3, 5), Slice(0, 2)) + >>> shape = (5, 5) + >>> import numpy as np + >>> a = np.arange(25).reshape(shape) + >>> for a_idx, (new_idx,) in zip( + ... idx.selected_indices(shape), + ... iter_indices(idx.newshape(shape))): + ... print(a_idx, new_idx, a[a_idx.raw], a[idx.raw][new_idx.raw]) + Tuple(3, 0) Tuple(0, 0) 15 15 + Tuple(3, 1) Tuple(0, 1) 16 16 + Tuple(4, 0) Tuple(1, 0) 20 20 + Tuple(4, 1) Tuple(1, 1) 21 21 + + See Also + ======== + + ndindex.iter_indices: + An iterator of indices to select every element for arrays of a given shape. + ndindex.ChunkSize.as_subchunks: + A high-level iterator that efficiently gives only those chunks + that intersect with a given index + """ + return self.expand(shape).selected_indices(shape) + def operator_index(idx): """ Convert `idx` into an integer index using `__index__()` or raise diff --git a/ndindex/slice.py b/ndindex/slice.py index c73aca43..15397615 100644 --- a/ndindex/slice.py +++ b/ndindex/slice.py @@ -576,6 +576,14 @@ def __eq__(self, other): return self.args == other.args return False + def selected_indices(self, shape, axis=None): + if axis is None: + yield from self.expand(shape).selected_indices(shape) + else: + shape = asshape(shape, axis=axis) + for i in range(shape[axis])[self.raw]: + yield Integer(i) + # Imports at the bottom to avoid circular import issues from .ndindex import ndindex from .tuple import Tuple diff --git a/ndindex/tests/test_as_subindex.py b/ndindex/tests/test_as_subindex.py index f6a466a4..c600c654 100644 --- a/ndindex/tests/test_as_subindex.py +++ b/ndindex/tests/test_as_subindex.py @@ -1,6 +1,6 @@ from pytest import raises -from numpy import array, arange, isin, prod, unique, intp +from numpy import array, arange, isin, unique, intp from hypothesis import given, assume, example from hypothesis.strategies import integers, one_of @@ -8,7 +8,7 @@ from ..ndindex import ndindex from ..integerarray import IntegerArray from ..tuple import Tuple -from .helpers import ndindices, short_shapes, assert_equal, warnings_are_errors +from .helpers import ndindices, prod, short_shapes, assert_equal, warnings_are_errors @example((slice(0, 8), slice(0, 9), slice(0, 10)), ([2, 5, 6, 7], slice(1, 9, 1), slice(5, 10, 1)), diff --git a/ndindex/tests/test_expand.py b/ndindex/tests/test_expand.py index 88ae207b..169348f4 100644 --- a/ndindex/tests/test_expand.py +++ b/ndindex/tests/test_expand.py @@ -1,4 +1,4 @@ -from numpy import arange, prod, array, intp, empty +from numpy import arange, array, intp, empty from hypothesis import given, example from hypothesis.strategies import integers, one_of @@ -9,7 +9,7 @@ from ..integerarray import IntegerArray from ..integer import Integer from ..tuple import Tuple -from .helpers import ndindices, check_same, short_shapes +from .helpers import ndindices, check_same, short_shapes, prod @example(True, (1,)) @example((Ellipsis, array([[ True, True]])), (1, 2)) diff --git a/ndindex/tests/test_integerarray.py b/ndindex/tests/test_integerarray.py index 27ae73cd..bc89bc21 100644 --- a/ndindex/tests/test_integerarray.py +++ b/ndindex/tests/test_integerarray.py @@ -1,11 +1,12 @@ -from numpy import prod, arange, array, int8, intp, empty +from numpy import arange, array, int8, intp, empty from hypothesis import given, example from hypothesis.strategies import one_of, integers from pytest import raises -from .helpers import integer_arrays, short_shapes, check_same, assert_equal, reduce_kwargs +from .helpers import (integer_arrays, short_shapes, check_same, assert_equal, + reduce_kwargs, prod) from ..integer import Integer from ..integerarray import IntegerArray diff --git a/ndindex/tests/test_newshape.py b/ndindex/tests/test_newshape.py index bd6022a6..eacc01ea 100644 --- a/ndindex/tests/test_newshape.py +++ b/ndindex/tests/test_newshape.py @@ -1,6 +1,6 @@ from pytest import raises -from numpy import arange, prod, array, full +from numpy import arange, array, full from hypothesis import given, example from hypothesis.strategies import integers, one_of @@ -8,7 +8,7 @@ from ..ndindex import ndindex from ..tuple import Tuple from ..integer import Integer -from .helpers import ndindices, check_same, short_shapes +from .helpers import ndindices, check_same, short_shapes, prod @example(..., 0) @example((True,), ()) diff --git a/ndindex/tests/test_selected_indices.py b/ndindex/tests/test_selected_indices.py new file mode 100644 index 00000000..dbe22df9 --- /dev/null +++ b/ndindex/tests/test_selected_indices.py @@ -0,0 +1,43 @@ +from pytest import raises + +from numpy import arange + +from hypothesis import given +from hypothesis.strategies import integers, one_of + +from ..ndindex import ndindex +from ..tuple import Tuple +from ..integer import Integer +from .helpers import ndindices, check_same, short_shapes, prod + +@given(ndindices, one_of(short_shapes, integers(0, 10))) +def test_selected_indices_hypothesis(idx, shape): + if isinstance(shape, int): + a = arange(shape) + else: + a = arange(prod(shape)).reshape(shape) + + try: + ndindex(idx) + except IndexError: + pass + + def raw_func(a, idx): + return list(a[idx].flat) + + def ndindex_func(a, index): + indices = list(index.selected_indices(shape)) + for i in indices: + if len(a.shape) == 1: + assert isinstance(i, Integer) + else: + assert isinstance(i, Tuple) + assert all(isinstance(j, Integer) for j in i.args) + + return [a[i.raw] for i in indices] + + def assert_equal(a, b): + assert a == b + + check_same(a, idx, raw_func=raw_func, ndindex_func=ndindex_func, + assert_equal=assert_equal) diff --git a/ndindex/tuple.py b/ndindex/tuple.py index b9438e5d..d205b0e2 100644 --- a/ndindex/tuple.py +++ b/ndindex/tuple.py @@ -1,4 +1,5 @@ import sys +import itertools from .ndindex import NDIndex, ndindex from .subindex_helpers import subindex_slice @@ -741,6 +742,17 @@ def isempty(self, shape=None): return any(i.isempty() for i in self.args) + def selected_indices(self, shape): + shape = asshape(shape) + idx = self.expand(shape) + args = [i for i in idx.args if i not in [None, True]] + # boolean scalar + if False in args: + return + for i in itertools.product(*[arg.selected_indices(shape, axis=axis) + for axis, arg in enumerate(args)]): + yield Tuple(*i).reduce() + # Imports at the bottom to avoid circular import issues from .array import ArrayIndex from .ellipsis import ellipsis From abe5e225ba1ac8af1275eb98baf6f807b644f802 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 15 Nov 2023 17:41:33 -0700 Subject: [PATCH 2/8] Handle integer arrays correctly in selected_indices() --- ndindex/ndindex.py | 2 +- ndindex/tests/test_selected_indices.py | 2 -- ndindex/tuple.py | 50 ++++++++++++++++++++++---- 3 files changed, 44 insertions(+), 10 deletions(-) diff --git a/ndindex/ndindex.py b/ndindex/ndindex.py index 843fecad..a0af4b4c 100644 --- a/ndindex/ndindex.py +++ b/ndindex/ndindex.py @@ -598,7 +598,7 @@ def broadcast_arrays(self): """ return self - def selected_indices(self, shape): + def selected_indices(self, shape, axis=0): """ Return an iterator over all indices that are selected by `self` on an array of shape `shape`. diff --git a/ndindex/tests/test_selected_indices.py b/ndindex/tests/test_selected_indices.py index dbe22df9..fbe3e4fc 100644 --- a/ndindex/tests/test_selected_indices.py +++ b/ndindex/tests/test_selected_indices.py @@ -1,5 +1,3 @@ -from pytest import raises - from numpy import arange from hypothesis import given diff --git a/ndindex/tuple.py b/ndindex/tuple.py index d205b0e2..45dafc48 100644 --- a/ndindex/tuple.py +++ b/ndindex/tuple.py @@ -745,13 +745,49 @@ def isempty(self, shape=None): def selected_indices(self, shape): shape = asshape(shape) idx = self.expand(shape) - args = [i for i in idx.args if i not in [None, True]] - # boolean scalar - if False in args: - return - for i in itertools.product(*[arg.selected_indices(shape, axis=axis) - for axis, arg in enumerate(args)]): - yield Tuple(*i).reduce() + + # We need to zip all array indices into a single iterator. + iterators = [] + array_indices = [] + axis = 0 + for i in idx.args: + if i in [None, True]: + continue + if i == False: + return + if isinstance(i, IntegerArray): + array_indices.append(i) + else: + # Tuples do not support array indices separated by slices, + # newaxes, or ellipses. Furthermore, if there are (non-scalar + # boolean) array indices, any Integer and BooleanArray indices + # are converted to IntegerArray. So we can assume all array + # indices are together in a single block, and this is the end + # of it. + if array_indices: + iterators.append(_zipped_array_indices(array_indices, + shape, axis=axis)) + axis += len(array_indices) + array_indices.clear() + iterators.append(i.selected_indices(shape, axis=axis)) + axis += 1 + if idx.args and isinstance(idx.args[-1], IntegerArray): + iterators.append(_zipped_array_indices(array_indices, + shape, axis=axis)) + + for i in itertools.product(*iterators): + yield Tuple(*flatten(i)).reduce() + +def flatten(l): + for element in l: + if isinstance(element, tuple): + yield from element + else: + yield element + +def _zipped_array_indices(array_indices, shape, axis=0): + return zip(*[i.selected_indices(shape, axis=axis+j) + for j, i in enumerate(array_indices)]) # Imports at the bottom to avoid circular import issues from .array import ArrayIndex From eebf3ad332ac0ccc29ff60e4b50a37c8a651a384 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 15 Nov 2023 17:44:46 -0700 Subject: [PATCH 3/8] Move some private functions inside of a function definition --- ndindex/tuple.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/ndindex/tuple.py b/ndindex/tuple.py index 45dafc48..31fb5687 100644 --- a/ndindex/tuple.py +++ b/ndindex/tuple.py @@ -746,6 +746,17 @@ def selected_indices(self, shape): shape = asshape(shape) idx = self.expand(shape) + def _zipped_array_indices(array_indices, shape, axis=0): + return zip(*[i.selected_indices(shape, axis=axis+j) + for j, i in enumerate(array_indices)]) + + def _flatten(l): + for element in l: + if isinstance(element, tuple): + yield from element + else: + yield element + # We need to zip all array indices into a single iterator. iterators = [] array_indices = [] @@ -776,18 +787,7 @@ def selected_indices(self, shape): shape, axis=axis)) for i in itertools.product(*iterators): - yield Tuple(*flatten(i)).reduce() - -def flatten(l): - for element in l: - if isinstance(element, tuple): - yield from element - else: - yield element - -def _zipped_array_indices(array_indices, shape, axis=0): - return zip(*[i.selected_indices(shape, axis=axis+j) - for j, i in enumerate(array_indices)]) + yield Tuple(*_flatten(i)).reduce() # Imports at the bottom to avoid circular import issues from .array import ArrayIndex From 63870bdad1b562023eb3fbec6492af05b6fc6936 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 21 Nov 2023 00:17:38 -0700 Subject: [PATCH 4/8] Set same_exception=False in test_selected_indices This is because of https://github.com/Quansight-Labs/ndindex/issues/167. When we fix that issue, we should set this back to True. --- ndindex/tests/test_selected_indices.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ndindex/tests/test_selected_indices.py b/ndindex/tests/test_selected_indices.py index fbe3e4fc..e4563343 100644 --- a/ndindex/tests/test_selected_indices.py +++ b/ndindex/tests/test_selected_indices.py @@ -38,4 +38,4 @@ def assert_equal(a, b): assert a == b check_same(a, idx, raw_func=raw_func, ndindex_func=ndindex_func, - assert_equal=assert_equal) + assert_equal=assert_equal, same_exception=False) From 558dc3df55bee9e5da1a7b990f6f6dc9ca3804e6 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 7 Feb 2024 16:59:56 -0700 Subject: [PATCH 5/8] Fix doctests --- ndindex/ndindex.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/ndindex/ndindex.py b/ndindex/ndindex.py index a0af4b4c..ad43721a 100644 --- a/ndindex/ndindex.py +++ b/ndindex/ndindex.py @@ -610,10 +610,14 @@ def selected_indices(self, shape, axis=0): >>> from ndindex import Slice, Tuple >>> idx = Slice(5, 10) >>> list(idx.selected_indices(20)) - [5, 6, 7, 8, 9] + [Integer(5), Integer(6), Integer(7), Integer(8), Integer(9)] >>> idx = Tuple(Slice(5, 10), Slice(0, 2)) >>> list(idx.selected_indices((20, 3))) - [(5, 0), (5, 1), (6, 0), (6, 1), (7, 0), (7, 1), (8, 0), (8, 1), (9, 0), (9, 1)] + [Tuple(5, 0), Tuple(5, 1), + Tuple(6, 0), Tuple(6, 1), + Tuple(7, 0), Tuple(7, 1), + Tuple(8, 0), Tuple(8, 1), + Tuple(9, 0), Tuple(9, 1)] To correspond these indices to the elements of `a[idx]`, you can use `iter_indices(idx.newshape(shape))`, since both iterators iterate the From e21f4aabb4beb866c1ec5b4191af74213b002cb3 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 7 Feb 2024 17:34:50 -0700 Subject: [PATCH 6/8] The ndindices strategy never generates invalid indices --- ndindex/tests/test_selected_indices.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/ndindex/tests/test_selected_indices.py b/ndindex/tests/test_selected_indices.py index e4563343..4b139bf8 100644 --- a/ndindex/tests/test_selected_indices.py +++ b/ndindex/tests/test_selected_indices.py @@ -15,10 +15,7 @@ def test_selected_indices_hypothesis(idx, shape): else: a = arange(prod(shape)).reshape(shape) - try: - ndindex(idx) - except IndexError: - pass + ndindex(idx) def raw_func(a, idx): return list(a[idx].flat) From f607d9e038546ba16f9381c449bd9dc59bbe2828 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 7 Feb 2024 17:37:30 -0700 Subject: [PATCH 7/8] Improve selected_indices test coverage --- ndindex/tests/test_ndindex.py | 1 + ndindex/tests/test_selected_indices.py | 8 +++++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/ndindex/tests/test_ndindex.py b/ndindex/tests/test_ndindex.py index b13cccec..6b111a4b 100644 --- a/ndindex/tests/test_ndindex.py +++ b/ndindex/tests/test_ndindex.py @@ -15,6 +15,7 @@ from .helpers import ndindices, check_same, assert_equal +@example(None) @example([1, 2]) @given(ndindices) def test_eq(idx): diff --git a/ndindex/tests/test_selected_indices.py b/ndindex/tests/test_selected_indices.py index 4b139bf8..af60b9f0 100644 --- a/ndindex/tests/test_selected_indices.py +++ b/ndindex/tests/test_selected_indices.py @@ -1,6 +1,6 @@ from numpy import arange -from hypothesis import given +from hypothesis import given, example from hypothesis.strategies import integers, one_of from ..ndindex import ndindex @@ -8,6 +8,12 @@ from ..integer import Integer from .helpers import ndindices, check_same, short_shapes, prod +@example((False, slice(0, 10)), (5, 2)) +@example((None, True, 0), (5, 2)) +@example((slice(0, 10), [0, -1]), (5, 2)) +@example(slice(0, 10), 5) +@example([0, 1, 2], 3) +@example(([0, 1, 2],), 3) @given(ndindices, one_of(short_shapes, integers(0, 10))) def test_selected_indices_hypothesis(idx, shape): if isinstance(shape, int): From 12d932172757cf3a59265c55f9fd37be87db851c Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 7 Feb 2024 18:19:06 -0700 Subject: [PATCH 8/8] Fix selected_indices logic when newaxis comes after the array indices --- ndindex/tests/test_selected_indices.py | 1 + ndindex/tuple.py | 11 ++++++----- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/ndindex/tests/test_selected_indices.py b/ndindex/tests/test_selected_indices.py index af60b9f0..7ebd2ca6 100644 --- a/ndindex/tests/test_selected_indices.py +++ b/ndindex/tests/test_selected_indices.py @@ -8,6 +8,7 @@ from ..integer import Integer from .helpers import ndindices, check_same, short_shapes, prod +@example(([False], None), (1,)) @example((False, slice(0, 10)), (5, 2)) @example((None, True, 0), (5, 2)) @example((slice(0, 10), [0, -1]), (5, 2)) diff --git a/ndindex/tuple.py b/ndindex/tuple.py index 029f70e5..cdf6969e 100644 --- a/ndindex/tuple.py +++ b/ndindex/tuple.py @@ -762,11 +762,11 @@ def _flatten(l): array_indices = [] axis = 0 for i in idx.args: - if i in [None, True]: - continue if i == False: return - if isinstance(i, IntegerArray): + elif i == True: + pass + elif isinstance(i, IntegerArray): array_indices.append(i) else: # Tuples do not support array indices separated by slices, @@ -780,8 +780,9 @@ def _flatten(l): shape, axis=axis)) axis += len(array_indices) array_indices.clear() - iterators.append(i.selected_indices(shape, axis=axis)) - axis += 1 + if i != None: + iterators.append(i.selected_indices(shape, axis=axis)) + axis += 1 if idx.args and isinstance(idx.args[-1], IntegerArray): iterators.append(_zipped_array_indices(array_indices, shape, axis=axis))