@@ -765,25 +765,26 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
765765 # after being checked against x1
766766 out = dpt .empty_like (out )
767767
768+ if order == "A" :
769+ order = (
770+ "F"
771+ if all (
772+ arr .flags .f_contiguous
773+ for arr in (
774+ x1 ,
775+ x2 ,
776+ )
777+ )
778+ else "C"
779+ )
780+
768781 if buf1_dt is None and buf2_dt is None :
769782 if out is None :
770783 if order == "K" :
771784 out = _empty_like_pair_orderK (
772785 x1 , x2 , res_dt , res_shape , res_usm_type , exec_q
773786 )
774787 else :
775- if order == "A" :
776- order = (
777- "F"
778- if all (
779- arr .flags .f_contiguous
780- for arr in (
781- x1 ,
782- x2 ,
783- )
784- )
785- else "C"
786- )
787788 out = dpt .empty (
788789 res_shape ,
789790 dtype = res_dt ,
@@ -823,8 +824,6 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
823824 if order == "K" :
824825 buf2 = _empty_like_orderK (x2 , buf2_dt )
825826 else :
826- if order == "A" :
827- order = "F" if x1 .flags .f_contiguous else "C"
828827 buf2 = dpt .empty_like (x2 , dtype = buf2_dt , order = order )
829828 ht_copy_ev , copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
830829 src = x2 , dst = buf2 , sycl_queue = exec_q
@@ -878,8 +877,6 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
878877 if order == "K" :
879878 buf1 = _empty_like_orderK (x1 , buf1_dt )
880879 else :
881- if order == "A" :
882- order = "F" if x1 .flags .f_contiguous else "C"
883880 buf1 = dpt .empty_like (x1 , dtype = buf1_dt , order = order )
884881 ht_copy_ev , copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
885882 src = x1 , dst = buf1 , sycl_queue = exec_q
@@ -929,13 +926,11 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
929926 out = dpt .squeeze (out , tuple (appended_axes ))
930927 return out
931928
932- if order in [ "K" , "A" ] :
929+ if order == "K" :
933930 if x1 .flags .f_contiguous and x2 .flags .f_contiguous :
934931 order = "F"
935932 elif x1 .flags .c_contiguous and x2 .flags .c_contiguous :
936933 order = "C"
937- else :
938- order = "C" if order == "A" else "K"
939934 if order == "K" :
940935 buf1 = _empty_like_orderK (x1 , buf1_dt )
941936 else :
0 commit comments