Skip to content

Commit 8f8047b

Browse files
authored
BUG: preserve device for scalars with array input (#479)
1 parent d3c0fe6 commit 8f8047b

File tree

2 files changed

+15
-2
lines changed

2 files changed

+15
-2
lines changed

src/array_api_extra/_lib/_utils/_helpers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -212,10 +212,10 @@ def asarrays(
212212
}
213213
kind = same_dtype[type(cast(complex, b))]
214214
if xp.isdtype(a.dtype, kind):
215-
xb = xp.asarray(b, dtype=a.dtype)
215+
xb = xp.asarray(b, dtype=a.dtype, device=_compat.device(a))
216216
else:
217217
# Undefined behaviour. Let the function deal with it, if it can.
218-
xb = xp.asarray(b)
218+
xb = xp.asarray(b, device=_compat.device(a))
219219

220220
else:
221221
# Neither a nor b are Array API objects.

tests/test_funcs.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -888,6 +888,19 @@ def test_device(self, xp: ModuleType, device: Device, equal_nan: bool):
888888
b = xp.asarray([1e-9, 1e-4, xp.nan], device=device)
889889
res = isclose(a, b, equal_nan=equal_nan)
890890
assert get_device(res) == device
891+
892+
def test_array_on_device_with_scalar(self, xp: ModuleType, device: Device):
893+
a = xp.asarray([0.01, 0.5, 0.8, 0.9, 1.00001], device=device)
894+
b = 1
895+
res = isclose(a, b)
896+
assert get_device(res) == device
897+
xp_assert_equal(res, xp.asarray([False, False, False, False, True]))
898+
899+
a = 0.1
900+
b = xp.asarray([0.01, 0.5, 0.8, 0.9, 0.100001], device=device)
901+
res = isclose(a, b)
902+
assert get_device(res) == device
903+
xp_assert_equal(res, xp.asarray([False, False, False, False, True]))
891904

892905

893906
class TestKron:

0 commit comments

Comments
 (0)