Skip to content

Commit 5f82196

Browse files
Merge pull request #995 from IntelPython/add_full_strides
Fixed missing copy for dpctl.tensor.full() function if fill_value is array
2 parents 05c358e + 6c9a8cd commit 5f82196

File tree

2 files changed

+27
-1
lines changed

2 files changed

+27
-1
lines changed

dpctl/tensor/_ctors.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -766,10 +766,11 @@ def full(
766766
X = dpt.asarray(
767767
fill_value,
768768
dtype=dtype,
769+
order=order,
769770
usm_type=usm_type,
770771
sycl_queue=sycl_queue,
771772
)
772-
return dpt.broadcast_to(X, sh)
773+
return dpt.copy(dpt.broadcast_to(X, sh), order=order)
773774

774775
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
775776
usm_type = usm_type if usm_type is not None else "device"

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1027,6 +1027,31 @@ def test_full_compute_follows_data():
10271027
assert np.array_equal(dpt.asnumpy(Y), np.full(10, 3, dtype="f4"))
10281028

10291029

1030+
@pytest.mark.parametrize("order1", ["F", "C"])
1031+
@pytest.mark.parametrize("order2", ["F", "C"])
1032+
def test_full_order(order1, order2):
1033+
q = get_queue_or_skip()
1034+
Xnp = np.array([1, 2, 3], order=order1)
1035+
Ynp = np.full((3, 3), Xnp, order=order2)
1036+
Y = dpt.full((3, 3), Xnp, order=order2, sycl_queue=q)
1037+
assert Y.flags.c_contiguous == Ynp.flags.c_contiguous
1038+
assert Y.flags.f_contiguous == Ynp.flags.f_contiguous
1039+
assert np.array_equal(dpt.asnumpy(Y), Ynp)
1040+
1041+
1042+
def test_full_strides():
1043+
q = get_queue_or_skip()
1044+
X = dpt.full((3, 3), dpt.arange(3, dtype="i4"), sycl_queue=q)
1045+
Xnp = np.full((3, 3), np.arange(3, dtype="i4"))
1046+
assert X.strides == tuple(el // Xnp.itemsize for el in Xnp.strides)
1047+
assert np.array_equal(dpt.asnumpy(X), Xnp)
1048+
1049+
X = dpt.full((3, 3), dpt.arange(6, dtype="i4")[::2], sycl_queue=q)
1050+
Xnp = np.full((3, 3), np.arange(6, dtype="i4")[::2])
1051+
assert X.strides == tuple(el // Xnp.itemsize for el in Xnp.strides)
1052+
assert np.array_equal(dpt.asnumpy(X), Xnp)
1053+
1054+
10301055
@pytest.mark.parametrize(
10311056
"dt",
10321057
_all_dtypes[1:],

0 commit comments

Comments
 (0)