Skip to content

Commit 1d82b34

Browse files
committed
Fixed missing copy for full() function if fill_value is array
1 parent 05c358e commit 1d82b34

File tree

2 files changed

+27
-1
lines changed

2 files changed

+27
-1
lines changed

dpctl/tensor/_ctors.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -766,10 +766,11 @@ def full(
766766
X = dpt.asarray(
767767
fill_value,
768768
dtype=dtype,
769+
order=order,
769770
usm_type=usm_type,
770771
sycl_queue=sycl_queue,
771772
)
772-
return dpt.broadcast_to(X, sh)
773+
return dpt.copy(dpt.broadcast_to(X, sh))
773774

774775
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
775776
usm_type = usm_type if usm_type is not None else "device"

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1027,6 +1027,31 @@ def test_full_compute_follows_data():
10271027
assert np.array_equal(dpt.asnumpy(Y), np.full(10, 3, dtype="f4"))
10281028

10291029

1030+
@pytest.mark.parametrize("order1", ["F", "C"])
1031+
@pytest.mark.parametrize("order2", ["F", "C"])
1032+
def test_full_order(order1, order2):
1033+
q = get_queue_or_skip()
1034+
Xnp = np.array([1,2,3], order=order1)
1035+
Ynp = np.full((3), Xnp, order=order2)
1036+
Y = dpt.full((3), Xnp, order=order2, sycl_queue=q)
1037+
assert Y.flags.f_contiguous == Ynp.flags.f_contiguous
1038+
assert Y.flags.c_contiguous == Ynp.flags.c_contiguous
1039+
assert np.array_equal(dpt.asnumpy(Y), Ynp)
1040+
1041+
1042+
def test_full_strides():
1043+
q = get_queue_or_skip()
1044+
X = dpt.full((3,3), dpt.arange(3, dtype='i4'), sycl_queue=q)
1045+
Xnp = np.full((3,3), np.arange(3, dtype='i4'))
1046+
assert X.strides == tuple( el // Xnp.itemsize for el in Xnp.strides)
1047+
assert np.array_equal(dpt.asnumpy(X), Xnp)
1048+
1049+
X = dpt.full((3,3), dpt.arange(6, dtype='i4')[::2], sycl_queue=q)
1050+
Xnp = np.full((3,3), np.arange(6, dtype='i4')[::2])
1051+
assert X.strides == tuple( el // Xnp.itemsize for el in Xnp.strides)
1052+
assert np.array_equal(dpt.asnumpy(X), Xnp)
1053+
1054+
10301055
@pytest.mark.parametrize(
10311056
"dt",
10321057
_all_dtypes[1:],

0 commit comments

Comments
 (0)