From a04aa33b0e0bddb61a341d3e34c61a3d1754c515 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 23 Aug 2024 15:06:16 -0600 Subject: [PATCH] Fix issues with test_normalize_skip_axes It was incorrectly testing a "skip_len" which makes no sense when the skip axes are a list. Also fix normalize_skip_axes to raise AxisError before ValueError (for non-unique axes) as that matches the test and is easier to check. --- ndindex/shapetools.py | 4 +++- ndindex/tests/test_shapetools.py | 7 +++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/ndindex/shapetools.py b/ndindex/shapetools.py index e854cdef..4da386a3 100644 --- a/ndindex/shapetools.py +++ b/ndindex/shapetools.py @@ -482,6 +482,7 @@ def normalize_skip_axes(shapes, skip_axes): raise ValueError("skip_axes must be empty if there are no shapes") new_skip_axes = [] + err = None for shape in shapes: s = tuple(sorted(ndindex(i).reduce(len(shape), negative_int=True, axiserror=True).raw for i in skip_axes)) if len(s) != len(set(s)): @@ -489,6 +490,7 @@ def normalize_skip_axes(shapes, skip_axes): # For testing err.skip_axes = skip_axes err.shape = shape - raise err new_skip_axes.append(s) + if err: + raise err return new_skip_axes diff --git a/ndindex/tests/test_shapetools.py b/ndindex/tests/test_shapetools.py index d47501c9..5b5f9ba5 100644 --- a/ndindex/tests/test_shapetools.py +++ b/ndindex/tests/test_shapetools.py @@ -496,6 +496,8 @@ def test_asshape(): raises(TypeError, lambda: asshape(np.int64(1), allow_int=False)) raises(IndexError, lambda: asshape((2, 3), 3)) +@example([(0,), ()], (0, 0)) +@example([(0, 1), (0,), ()], [(-1,), (0,), ()]) @example([(5,)], (10,)) @example([], []) @example([()], []) @@ -521,13 +523,11 @@ def test_normalize_skip_axes(shapes, skip_axes): raises(AxisError, lambda: normalize_skip_axes(shapes, skip_axes)) return _skip_axes = [(skip_axes,)]*len(shapes) - skip_len = 1 elif isinstance(skip_axes, tuple): if not all(-min_dim <= s < min_dim for s in skip_axes): raises(AxisError, lambda: normalize_skip_axes(shapes, skip_axes)) return _skip_axes = [skip_axes]*len(shapes) - skip_len = len(skip_axes) elif not skip_axes: # empty list will be interpreted as a single skip_axes tuple assert normalize_skip_axes(shapes, skip_axes) == [()]*len(shapes) @@ -537,7 +537,6 @@ def test_normalize_skip_axes(shapes, skip_axes): raises(ValueError, lambda: normalize_skip_axes(shapes, skip_axes)) return _skip_axes = skip_axes - skip_len = len(skip_axes[0]) try: res = normalize_skip_axes(shapes, skip_axes) @@ -566,7 +565,7 @@ def test_normalize_skip_axes(shapes, skip_axes): assert len(res) == len(shapes) for shape, new_skip_axes in zip(shapes, res): - assert len(new_skip_axes) == len(set(new_skip_axes)) == skip_len + assert len(new_skip_axes) == len(set(new_skip_axes)) assert new_skip_axes == tuple(sorted(new_skip_axes)) for i in new_skip_axes: assert i < 0