From f7152ff2ee1dbdc39fff856ff25f7825942719f4 Mon Sep 17 00:00:00 2001 From: Tim Head Date: Mon, 21 Oct 2024 15:14:26 +0200 Subject: [PATCH] Use array's device when promoting scalars --- array_api_strict/_array_object.py | 2 +- array_api_strict/tests/test_array_object.py | 10 +++++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index cd9a360..34caff7 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -274,7 +274,7 @@ def _promote_scalar(self, scalar): # behavior for integers within the bounds of the integer dtype. # Outside of those bounds we use the default NumPy behavior (either # cast or raise OverflowError). - return Array._new(np.array(scalar, dtype=self.dtype._np_dtype), device=CPU_DEVICE) + return Array._new(np.array(scalar, dtype=self.dtype._np_dtype), device=self.device) @staticmethod def _normalize_two_args(x1, x2) -> Tuple[Array, Array]: diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index 902c398..a9ea26d 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -6,7 +6,7 @@ import pytest from .. import ones, asarray, result_type, all, equal -from .._array_object import Array, CPU_DEVICE +from .._array_object import Array, CPU_DEVICE, Device from .._dtypes import ( _all_dtypes, _boolean_dtypes, @@ -88,6 +88,14 @@ def test_validate_index(): assert_raises(IndexError, lambda: a[0]) assert_raises(IndexError, lambda: a[:]) +def test_promoted_scalar_inherits_device(): + device1 = Device("device1") + x = asarray([1., 2, 3], device=device1) + + y = x ** 2 + + assert y.device == device1 + def test_operators(): # For every operator, we test that it works for the required type # combinations and raises TypeError otherwise