Skip to content

Commit

Permalink
Merge branch 'main' into release
Browse files Browse the repository at this point in the history
  • Loading branch information
asmeurer committed Nov 8, 2024
2 parents 1fe03a5 + 6405e80 commit 310bb62
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 8 deletions.
21 changes: 13 additions & 8 deletions array_api_strict/_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ def __hash__(self):

_default = object()

_allow_array = False
# See https://github.com/data-apis/array-api-strict/issues/67 and the comment
# on __array__ below.
_allow_array = True

class Array:
"""
Expand Down Expand Up @@ -147,15 +149,18 @@ def __repr__(self: Array, /) -> str:
mid = np.array2string(self._array, separator=', ', prefix=prefix, suffix=suffix)
return prefix + mid + suffix

# Disallow __array__, meaning calling `np.func()` on an array_api_strict
# array will give an error. If we don't explicitly disallow it, NumPy
# defaults to creating an object dtype array, which would lead to
# confusing error messages at best and surprising bugs at worst.
#
# The alternative of course is to just support __array__, which is what we
# used to do. But this isn't actually supported by the standard, so it can
# In the future, _allow_array will be set to False, which will disallow
# __array__. This means calling `np.func()` on an array_api_strict array
# will give an error. If we don't explicitly disallow it, NumPy defaults
# to creating an object dtype array, which would lead to confusing error
# messages at best and surprising bugs at worst. The reason for doing this
# is that __array__ is not actually supported by the standard, so it can
# lead to code assuming np.asarray(other_array) would always work in the
# standard.
#
# This was implemented historically for compatibility, and removing it has
# caused issues for some libraries (see
# https://github.com/data-apis/array-api-strict/issues/67).
def __array__(self, dtype: None | np.dtype[Any] = None, copy: None | bool = None) -> npt.NDArray[Any]:
# We have to allow this to be internally enabled as there's no other
# easy way to parse a list of Array objects in asarray().
Expand Down
17 changes: 17 additions & 0 deletions array_api_strict/tests/test_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,23 @@ def test_array_conversion():
with pytest.raises(RuntimeError, match="Can not convert array"):
asarray([a])

def test__array__():
# __array__ should work for now
a = ones((2, 3))
np.array(a)

# Test the _allow_array private global flag for disabling it in the
# future.
from .. import _array_object
original_value = _array_object._allow_array
try:
_array_object._allow_array = False
a = ones((2, 3))
with pytest.raises(ValueError, match="Conversion from an array_api_strict array to a NumPy ndarray is not supported"):
np.array(a)
finally:
_array_object._allow_array = original_value

def test_allow_newaxis():
a = ones(5)
indexed_a = a[None, :]
Expand Down

0 comments on commit 310bb62

Please sign in to comment.