Skip to content

Commit eb133e9

Browse files
Merge pull request #982 from IntelPython/fix_full_array_dtype
Added support for arrays for fill_value for full() function
2 parents 29bc97d + a84c615 commit eb133e9

File tree

2 files changed

+58
-2
lines changed

2 files changed

+58
-2
lines changed

dpctl/tensor/_ctors.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -716,7 +716,7 @@ def full(
716716
dtype=None,
717717
order="C",
718718
device=None,
719-
usm_type="device",
719+
usm_type=None,
720720
sycl_queue=None,
721721
):
722722
"""
@@ -750,8 +750,29 @@ def full(
750750
"Unrecognized order keyword value, expecting 'F' or 'C'."
751751
)
752752
order = order[0].upper()
753-
dpctl.utils.validate_usm_type(usm_type, allow_none=False)
753+
dpctl.utils.validate_usm_type(usm_type, allow_none=True)
754+
755+
if isinstance(fill_value, (dpt.usm_ndarray, np.ndarray, tuple, list)):
756+
if (
757+
isinstance(fill_value, dpt.usm_ndarray)
758+
and sycl_queue is None
759+
and device is None
760+
):
761+
sycl_queue = fill_value.sycl_queue
762+
else:
763+
sycl_queue = normalize_queue_device(
764+
sycl_queue=sycl_queue, device=device
765+
)
766+
X = dpt.asarray(
767+
fill_value,
768+
dtype=dtype,
769+
usm_type=usm_type,
770+
sycl_queue=sycl_queue,
771+
)
772+
return dpt.broadcast_to(X, sh)
773+
754774
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
775+
usm_type = usm_type if usm_type is not None else "device"
755776
dtype = _get_dtype(dtype, sycl_queue, ref_type=type(fill_value))
756777
res = dpt.usm_ndarray(
757778
sh,

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -992,6 +992,41 @@ def test_full_dtype_inference():
992992
assert np.issubdtype(dpt.full(10, 0.3 - 2j).dtype, np.complexfloating)
993993

994994

995+
def test_full_fill_array():
996+
q = get_queue_or_skip()
997+
998+
Xnp = np.array([1, 2, 3], dtype="i4")
999+
X = dpt.asarray(Xnp, sycl_queue=q)
1000+
1001+
shape = (3, 3)
1002+
Y = dpt.full(shape, X)
1003+
Ynp = np.full(shape, Xnp)
1004+
1005+
assert Y.dtype == Ynp.dtype
1006+
assert Y.usm_type == "device"
1007+
assert np.array_equal(dpt.asnumpy(Y), Ynp)
1008+
1009+
1010+
def test_full_compute_follows_data():
1011+
q1 = get_queue_or_skip()
1012+
q2 = get_queue_or_skip()
1013+
1014+
X = dpt.arange(10, dtype="i4", sycl_queue=q1, usm_type="shared")
1015+
Y = dpt.full(10, X[3])
1016+
1017+
assert Y.dtype == X.dtype
1018+
assert Y.usm_type == X.usm_type
1019+
assert dpctl.utils.get_execution_queue((Y.sycl_queue, X.sycl_queue))
1020+
assert np.array_equal(dpt.asnumpy(Y), np.full(10, 3, dtype="i4"))
1021+
1022+
Y = dpt.full(10, X[3], dtype="f4", sycl_queue=q2, usm_type="host")
1023+
1024+
assert Y.dtype == dpt.dtype("f4")
1025+
assert Y.usm_type == "host"
1026+
assert dpctl.utils.get_execution_queue((Y.sycl_queue, q2))
1027+
assert np.array_equal(dpt.asnumpy(Y), np.full(10, 3, dtype="f4"))
1028+
1029+
9951030
@pytest.mark.parametrize(
9961031
"dt",
9971032
_all_dtypes[1:],

0 commit comments

Comments
 (0)