Skip to content

Commit 1531a63

Browse files
Coverity report fix for _linear_agebra_functions
Made changes similar to those made in _elementwise_common.py file, where processing of order="A" keyword is done based on flags of both src1 and src2 irrespective of type promotion steps necessary.
1 parent 71b1640 commit 1531a63

File tree

1 file changed

+14
-19
lines changed

1 file changed

+14
-19
lines changed

dpctl/tensor/_linear_algebra_functions.py

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -765,25 +765,26 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
765765
# after being checked against x1
766766
out = dpt.empty_like(out)
767767

768+
if order == "A":
769+
order = (
770+
"F"
771+
if all(
772+
arr.flags.f_contiguous
773+
for arr in (
774+
x1,
775+
x2,
776+
)
777+
)
778+
else "C"
779+
)
780+
768781
if buf1_dt is None and buf2_dt is None:
769782
if out is None:
770783
if order == "K":
771784
out = _empty_like_pair_orderK(
772785
x1, x2, res_dt, res_shape, res_usm_type, exec_q
773786
)
774787
else:
775-
if order == "A":
776-
order = (
777-
"F"
778-
if all(
779-
arr.flags.f_contiguous
780-
for arr in (
781-
x1,
782-
x2,
783-
)
784-
)
785-
else "C"
786-
)
787788
out = dpt.empty(
788789
res_shape,
789790
dtype=res_dt,
@@ -823,8 +824,6 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
823824
if order == "K":
824825
buf2 = _empty_like_orderK(x2, buf2_dt)
825826
else:
826-
if order == "A":
827-
order = "F" if x1.flags.f_contiguous else "C"
828827
buf2 = dpt.empty_like(x2, dtype=buf2_dt, order=order)
829828
ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
830829
src=x2, dst=buf2, sycl_queue=exec_q
@@ -878,8 +877,6 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
878877
if order == "K":
879878
buf1 = _empty_like_orderK(x1, buf1_dt)
880879
else:
881-
if order == "A":
882-
order = "F" if x1.flags.f_contiguous else "C"
883880
buf1 = dpt.empty_like(x1, dtype=buf1_dt, order=order)
884881
ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
885882
src=x1, dst=buf1, sycl_queue=exec_q
@@ -929,13 +926,11 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
929926
out = dpt.squeeze(out, tuple(appended_axes))
930927
return out
931928

932-
if order in ["K", "A"]:
929+
if order == "K":
933930
if x1.flags.f_contiguous and x2.flags.f_contiguous:
934931
order = "F"
935932
elif x1.flags.c_contiguous and x2.flags.c_contiguous:
936933
order = "C"
937-
else:
938-
order = "C" if order == "A" else "K"
939934
if order == "K":
940935
buf1 = _empty_like_orderK(x1, buf1_dt)
941936
else:

0 commit comments

Comments
 (0)