@@ -765,25 +765,26 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
765
765
# after being checked against x1
766
766
out = dpt .empty_like (out )
767
767
768
+ if order == "A" :
769
+ order = (
770
+ "F"
771
+ if all (
772
+ arr .flags .f_contiguous
773
+ for arr in (
774
+ x1 ,
775
+ x2 ,
776
+ )
777
+ )
778
+ else "C"
779
+ )
780
+
768
781
if buf1_dt is None and buf2_dt is None :
769
782
if out is None :
770
783
if order == "K" :
771
784
out = _empty_like_pair_orderK (
772
785
x1 , x2 , res_dt , res_shape , res_usm_type , exec_q
773
786
)
774
787
else :
775
- if order == "A" :
776
- order = (
777
- "F"
778
- if all (
779
- arr .flags .f_contiguous
780
- for arr in (
781
- x1 ,
782
- x2 ,
783
- )
784
- )
785
- else "C"
786
- )
787
788
out = dpt .empty (
788
789
res_shape ,
789
790
dtype = res_dt ,
@@ -823,8 +824,6 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
823
824
if order == "K" :
824
825
buf2 = _empty_like_orderK (x2 , buf2_dt )
825
826
else :
826
- if order == "A" :
827
- order = "F" if x1 .flags .f_contiguous else "C"
828
827
buf2 = dpt .empty_like (x2 , dtype = buf2_dt , order = order )
829
828
ht_copy_ev , copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
830
829
src = x2 , dst = buf2 , sycl_queue = exec_q
@@ -878,8 +877,6 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
878
877
if order == "K" :
879
878
buf1 = _empty_like_orderK (x1 , buf1_dt )
880
879
else :
881
- if order == "A" :
882
- order = "F" if x1 .flags .f_contiguous else "C"
883
880
buf1 = dpt .empty_like (x1 , dtype = buf1_dt , order = order )
884
881
ht_copy_ev , copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
885
882
src = x1 , dst = buf1 , sycl_queue = exec_q
@@ -929,13 +926,11 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
929
926
out = dpt .squeeze (out , tuple (appended_axes ))
930
927
return out
931
928
932
- if order in [ "K" , "A" ] :
929
+ if order == "K" :
933
930
if x1 .flags .f_contiguous and x2 .flags .f_contiguous :
934
931
order = "F"
935
932
elif x1 .flags .c_contiguous and x2 .flags .c_contiguous :
936
933
order = "C"
937
- else :
938
- order = "C" if order == "A" else "K"
939
934
if order == "K" :
940
935
buf1 = _empty_like_orderK (x1 , buf1_dt )
941
936
else :
0 commit comments