Skip to content

Commit

Permalink
Fix assert_equal tests in cases where ndindex would turn a scalar int…
Browse files Browse the repository at this point in the history
…o a 0-D array
  • Loading branch information
asmeurer committed Sep 25, 2024
1 parent 4232472 commit 394dc1e
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 18 deletions.
20 changes: 14 additions & 6 deletions ndindex/tests/helpers.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import sys
from itertools import chain
import warnings
from functools import wraps
from functools import wraps, partial

from numpy import ndarray, intp, bool_, asarray, broadcast_shapes
from numpy import ndarray, generic, intp, bool_, asarray, broadcast_shapes
import numpy.testing

from pytest import fail
Expand Down Expand Up @@ -373,7 +373,7 @@ def matmul_arrays_st(draw):

reduce_kwargs = sampled_from([{}, {'negative_int': False}, {'negative_int': True}])

def assert_equal(actual, desired, err_msg='', verbose=True):
def assert_equal(actual, desired, allow_scalar_0d=False, err_msg='', verbose=True):
"""
Assert that two objects are equal.
Expand All @@ -384,12 +384,18 @@ def assert_equal(actual, desired, err_msg='', verbose=True):
- If the objects are tuples, recursively call assert_equal to support
tuples of arrays.
- Require the types of actual and desired to be exactly the same.
- If allow_scalar_0d=True, scalars will be considered equal to equivalent
0-D arrays.
- Require the types of actual and desired to be exactly the same
(excepting for scalars when allow_scalar_0d=True).
"""
assert type(actual) is type(desired), err_msg or f"{type(actual)} != {type(desired)}"
if not (allow_scalar_0d and (isinstance(actual, generic)
or isinstance(desired, generic))):
assert type(actual) is type(desired), err_msg or f"{type(actual)} != {type(desired)}"

if isinstance(actual, ndarray):
if isinstance(actual, (ndarray, generic)):
numpy.testing.assert_equal(actual, desired, err_msg=err_msg,
verbose=verbose)
assert actual.shape == desired.shape, err_msg or f"{actual.shape} != {desired.shape}"
Expand All @@ -401,6 +407,8 @@ def assert_equal(actual, desired, err_msg='', verbose=True):
else:
assert actual == desired, err_msg

assert_equal_allow_scalar_0d = partial(assert_equal, allow_scalar_0d=True)

def warnings_are_errors(f):
@wraps(f)
def inner(*args, **kwargs):
Expand Down
10 changes: 7 additions & 3 deletions ndindex/tests/test_ellipsis.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from hypothesis.strategies import one_of, integers

from ..ndindex import ndindex
from .helpers import check_same, prod, shapes, ellipses, reduce_kwargs
from .helpers import (check_same, prod, shapes, ellipses, reduce_kwargs,
assert_equal_allow_scalar_0d)

def test_ellipsis_exhaustive():
for n in range(10):
Expand All @@ -24,7 +25,9 @@ def test_ellipsis_reduce_exhaustive():
@given(ellipses(), shapes, reduce_kwargs)
def test_ellipsis_reduce_hypothesis(idx, shape, kwargs):
a = arange(prod(shape)).reshape(shape)
check_same(a, idx, ndindex_func=lambda a, x: a[x.reduce(shape, **kwargs).raw])
check_same(a, idx,
ndindex_func=lambda a, x: a[x.reduce(shape, **kwargs).raw],
assert_equal=assert_equal_allow_scalar_0d)

def test_ellipsis_reduce_no_shape_exhaustive():
for n in range(10):
Expand All @@ -34,7 +37,8 @@ def test_ellipsis_reduce_no_shape_exhaustive():
@given(ellipses(), shapes, reduce_kwargs)
def test_ellipsis_reduce_no_shape_hypothesis(idx, shape, kwargs):
a = arange(prod(shape)).reshape(shape)
check_same(a, idx, ndindex_func=lambda a, x: a[x.reduce(**kwargs).raw])
check_same(a, idx, ndindex_func=lambda a, x: a[x.reduce(**kwargs).raw],
assert_equal=assert_equal_allow_scalar_0d)

@given(ellipses(), one_of(shapes, integers(0, 10)))
def test_ellipsis_isempty_hypothesis(idx, shape):
Expand Down
5 changes: 3 additions & 2 deletions ndindex/tests/test_expand.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from ..integerarray import IntegerArray
from ..integer import Integer
from ..tuple import Tuple
from .helpers import ndindices, check_same, short_shapes, prod
from .helpers import (ndindices, check_same, short_shapes, prod,
assert_equal_allow_scalar_0d)

@example(True, (1,))
@example((Ellipsis, array([[ True, True]])), (1, 2))
Expand Down Expand Up @@ -41,7 +42,7 @@ def test_expand_hypothesis(idx, shape):
index = ndindex(idx)

check_same(a, index.raw, ndindex_func=lambda a, x: a[x.expand(shape).raw],
same_exception=False)
same_exception=False, assert_equal=assert_equal_allow_scalar_0d)

try:
expanded = index.expand(shape)
Expand Down
18 changes: 11 additions & 7 deletions ndindex/tests/test_tuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from ..booleanarray import BooleanArray
from ..integer import Integer
from ..integerarray import IntegerArray
from .helpers import check_same, Tuples, prod, short_shapes, iterslice, reduce_kwargs
from .helpers import (check_same, Tuples, prod, short_shapes, iterslice,
reduce_kwargs, assert_equal_allow_scalar_0d)

def test_tuple_constructor():
# Nested tuples are not allowed
Expand Down Expand Up @@ -96,7 +97,7 @@ def ndindex_func(a, index):
return a[ndindex((*index.raw[:index.ellipsis_index], ...,
*index.raw[index.ellipsis_index+1:])).raw]

check_same(a, t, ndindex_func=ndindex_func)
check_same(a, t, ndindex_func=ndindex_func, assert_equal=assert_equal_allow_scalar_0d)

@example((True, 0, False), 1, {})
@example((..., None), (), {})
Expand All @@ -109,8 +110,9 @@ def test_tuple_reduce_no_shape_hypothesis(t, shape, kwargs):

index = Tuple(*t)

check_same(a, index.raw, ndindex_func=lambda a, x: a[x.reduce(**kwargs).raw],
same_exception=False)
check_same(a, index.raw, ndindex_func=lambda a, x:
a[x.reduce(**kwargs).raw], same_exception=False,
assert_equal=assert_equal_allow_scalar_0d)

reduced = index.reduce(**kwargs)
if isinstance(reduced, Tuple):
Expand Down Expand Up @@ -143,8 +145,9 @@ def test_tuple_reduce_hypothesis(t, shape, kwargs):

index = Tuple(*t)

check_same(a, index.raw, ndindex_func=lambda a, x: a[x.reduce(shape, **kwargs).raw],
same_exception=False)
check_same(a, index.raw, ndindex_func=lambda a, x: a[x.reduce(shape,
**kwargs).raw], same_exception=False,
assert_equal=assert_equal_allow_scalar_0d)

negative_int = kwargs.get('negative_int', False)

Expand Down Expand Up @@ -197,7 +200,8 @@ def test_tuple_reduce_explicit():

a = arange(prod(shape)).reshape(shape)
check_same(a, before.raw, ndindex_func=lambda a, x:
a[x.reduce(shape).raw])
a[x.reduce(shape).raw],
assert_equal=assert_equal_allow_scalar_0d)

# Idempotency
assert reduced.reduce() == reduced
Expand Down

0 comments on commit 394dc1e

Please sign in to comment.