Skip to content

Commit

Permalink
Merge pull request #73 from betatim/inherit-device-scalar-promotion
Browse files Browse the repository at this point in the history
FIX Use array's device when promoting scalars
  • Loading branch information
asmeurer authored Oct 21, 2024
2 parents f34fa52 + f7152ff commit e775240
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
2 changes: 1 addition & 1 deletion array_api_strict/_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ def _promote_scalar(self, scalar):
# behavior for integers within the bounds of the integer dtype.
# Outside of those bounds we use the default NumPy behavior (either
# cast or raise OverflowError).
return Array._new(np.array(scalar, dtype=self.dtype._np_dtype), device=CPU_DEVICE)
return Array._new(np.array(scalar, dtype=self.dtype._np_dtype), device=self.device)

@staticmethod
def _normalize_two_args(x1, x2) -> Tuple[Array, Array]:
Expand Down
10 changes: 9 additions & 1 deletion array_api_strict/tests/test_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pytest

from .. import ones, asarray, result_type, all, equal
from .._array_object import Array, CPU_DEVICE
from .._array_object import Array, CPU_DEVICE, Device
from .._dtypes import (
_all_dtypes,
_boolean_dtypes,
Expand Down Expand Up @@ -88,6 +88,14 @@ def test_validate_index():
assert_raises(IndexError, lambda: a[0])
assert_raises(IndexError, lambda: a[:])

def test_promoted_scalar_inherits_device():
device1 = Device("device1")
x = asarray([1., 2, 3], device=device1)

y = x ** 2

assert y.device == device1

def test_operators():
# For every operator, we test that it works for the required type
# combinations and raises TypeError otherwise
Expand Down

0 comments on commit e775240

Please sign in to comment.