77// except according to those terms.
88
99use crate :: dimension:: BroadcastShape ;
10+ use crate :: data_traits:: MaybeUninitSubst ;
11+ use crate :: Zip ;
1012use num_complex:: Complex ;
1113
1214/// Elements that can be used as direct operands in arithmetic with arrays.
@@ -64,14 +66,15 @@ macro_rules! impl_binary_op(
6466/// **Panics** if broadcasting isn’t possible.
6567impl <A , B , S , S2 , D , E > $trt<ArrayBase <S2 , E >> for ArrayBase <S , D >
6668where
67- A : Copy + $trt<B , Output =A >,
69+ A : Clone + $trt<B , Output =A >,
6870 B : Clone ,
69- S : DataOwned <Elem =A > + DataMut ,
71+ S : DataOwned <Elem =A > + DataMut + MaybeUninitSubst <A >,
72+ <S as MaybeUninitSubst <A >>:: Output : DataMut ,
7073 S2 : Data <Elem =B >,
7174 D : Dimension + BroadcastShape <E >,
7275 E : Dimension ,
7376{
74- type Output = ArrayBase <S , <D as BroadcastShape <E >>:: BroadcastOutput >;
77+ type Output = ArrayBase <S , <D as BroadcastShape <E >>:: Output >;
7578 fn $mth( self , rhs: ArrayBase <S2 , E >) -> Self :: Output
7679 {
7780 self . $mth( & rhs)
8083
8184/// Perform elementwise
8285 #[ doc=$doc]
83- /// between reference `self` and `rhs`,
86+ /// between `self` and reference `rhs`,
8487/// and return the result.
8588///
8689/// `rhs` must be an `Array` or `ArcArray`.
@@ -89,44 +92,49 @@ where
8992/// cloning the data if needed.
9093///
9194/// **Panics** if broadcasting isn’t possible.
92- impl <' a, A , B , S , S2 , D , E > $trt<ArrayBase <S2 , E >> for & ' a ArrayBase <S , D >
95+ impl <' a, A , B , S , S2 , D , E > $trt<& ' a ArrayBase <S2 , E >> for ArrayBase <S , D >
9396where
94- A : Clone + $trt<B , Output =B >,
95- B : Copy ,
96- S : Data <Elem =A >,
97- S2 : DataOwned <Elem =B > + DataMut ,
98- D : Dimension ,
99- E : Dimension + BroadcastShape <D >,
97+ A : Clone + $trt<B , Output =A >,
98+ B : Clone ,
99+ S : DataOwned <Elem =A > + DataMut + MaybeUninitSubst <A >,
100+ <S as MaybeUninitSubst <A >>:: Output : DataMut ,
101+ S2 : Data <Elem =B >,
102+ D : Dimension + BroadcastShape <E >,
103+ E : Dimension ,
100104{
101- type Output = ArrayBase <S2 , <E as BroadcastShape <D >>:: BroadcastOutput >;
102- fn $mth( self , rhs: ArrayBase <S2 , E >) -> Self :: Output
105+ type Output = ArrayBase <S , <D as BroadcastShape <E >>:: Output >;
106+ fn $mth( self , rhs: & ArrayBase <S2 , E >) -> Self :: Output
103107 {
104- let shape = rhs . dim. broadcast_shape( & self . dim) . unwrap( ) ;
105- if shape. slice( ) == rhs . dim. slice( ) {
106- let mut out = rhs . into_dimensionality:: <<E as BroadcastShape <D >>:: BroadcastOutput >( ) . unwrap( ) ;
107- out. zip_mut_with( self , |x, y| {
108- * x = y . clone( ) $operator x . clone( ) ;
108+ let shape = self . dim. broadcast_shape( & rhs . dim) . unwrap( ) ;
109+ if shape. slice( ) == self . dim. slice( ) {
110+ let mut out = self . into_dimensionality:: <<D as BroadcastShape <E >>:: Output >( ) . unwrap( ) ;
111+ out. zip_mut_with( rhs , |x, y| {
112+ * x = x . clone( ) $operator y . clone( ) ;
109113 } ) ;
110114 out
111115 } else {
112- // SAFETY: Overwrite all the elements in the array after
113- // it is created via `zip_mut_from_pair`.
114- let mut out = unsafe {
115- Self :: Output :: uninitialized( shape. clone( ) . into_pattern( ) )
116- } ;
117116 let lhs = self . broadcast( shape. clone( ) ) . unwrap( ) ;
118- let rhs = rhs. broadcast( shape) . unwrap( ) ;
119- out. zip_mut_from_pair( & lhs, & rhs, |x, y| {
120- x. clone( ) $operator y. clone( )
121- } ) ;
122- out
117+ let rhs = rhs. broadcast( shape. clone( ) ) . unwrap( ) ;
118+ // SAFETY: Overwrite all the elements in the array after
119+ // it is created via `raw_view_mut`.
120+ unsafe {
121+ let mut out =ArrayBase :: <<S as MaybeUninitSubst <A >>:: Output , <D as BroadcastShape <E >>:: Output >:: maybe_uninit( shape. into_pattern( ) ) ;
122+ let output_view = out. raw_view_mut( ) . cast:: <A >( ) ;
123+ Zip :: from( & lhs) . and( & rhs)
124+ . and( output_view)
125+ . collect_with_partial( |x, y| {
126+ x. clone( ) $operator y. clone( )
127+ } )
128+ . release_ownership( ) ;
129+ out. assume_init( )
130+ }
123131 }
124132 }
125133}
126134
127135/// Perform elementwise
128136 #[ doc=$doc]
129- /// between `self` and reference `rhs`,
137+ /// between reference `self` and `rhs`,
130138/// and return the result.
131139///
132140/// `rhs` must be an `Array` or `ArcArray`.
@@ -135,37 +143,43 @@ where
135143/// cloning the data if needed.
136144///
137145/// **Panics** if broadcasting isn’t possible.
138- impl <' a, A , B , S , S2 , D , E > $trt<& ' a ArrayBase <S2 , E >> for ArrayBase <S , D >
146+ impl <' a, A , B , S , S2 , D , E > $trt<ArrayBase <S2 , E >> for & ' a ArrayBase <S , D >
139147where
140- A : Copy + $trt<B , Output =A >,
148+ A : Clone + $trt<B , Output =B >,
141149 B : Clone ,
142- S : DataOwned <Elem =A > + DataMut ,
143- S2 : Data <Elem =B >,
144- D : Dimension + BroadcastShape <E >,
145- E : Dimension ,
150+ S : Data <Elem =A >,
151+ S2 : DataOwned <Elem =B > + DataMut + MaybeUninitSubst <B >,
152+ <S2 as MaybeUninitSubst <B >>:: Output : DataMut ,
153+ D : Dimension ,
154+ E : Dimension + BroadcastShape <D >,
146155{
147- type Output = ArrayBase <S , <D as BroadcastShape <E >>:: BroadcastOutput >;
148- fn $mth( self , rhs: & ArrayBase <S2 , E >) -> Self :: Output
156+ type Output = ArrayBase <S2 , <E as BroadcastShape <D >>:: Output >;
157+ fn $mth( self , rhs: ArrayBase <S2 , E >) -> Self :: Output
158+ where
149159 {
150- let shape = self . dim. broadcast_shape( & rhs . dim) . unwrap( ) ;
151- if shape. slice( ) == self . dim. slice( ) {
152- let mut out = self . into_dimensionality:: <<D as BroadcastShape <E >>:: BroadcastOutput >( ) . unwrap( ) ;
153- out. zip_mut_with( rhs , |x, y| {
154- * x = x . clone( ) $operator y . clone( ) ;
160+ let shape = rhs . dim. broadcast_shape( & self . dim) . unwrap( ) ;
161+ if shape. slice( ) == rhs . dim. slice( ) {
162+ let mut out = rhs . into_dimensionality:: <<E as BroadcastShape <D >>:: Output >( ) . unwrap( ) ;
163+ out. zip_mut_with( self , |x, y| {
164+ * x = y . clone( ) $operator x . clone( ) ;
155165 } ) ;
156166 out
157167 } else {
158- // SAFETY: Overwrite all the elements in the array after
159- // it is created via `zip_mut_from_pair`.
160- let mut out = unsafe {
161- Self :: Output :: uninitialized( shape. clone( ) . into_pattern( ) )
162- } ;
163168 let lhs = self . broadcast( shape. clone( ) ) . unwrap( ) ;
164- let rhs = rhs. broadcast( shape) . unwrap( ) ;
165- out. zip_mut_from_pair( & lhs, & rhs, |x, y| {
166- x. clone( ) $operator y. clone( )
167- } ) ;
168- out
169+ let rhs = rhs. broadcast( shape. clone( ) ) . unwrap( ) ;
170+ // SAFETY: Overwrite all the elements in the array after
171+ // it is created via `raw_view_mut`.
172+ unsafe {
173+ let mut out =ArrayBase :: <<S2 as MaybeUninitSubst <B >>:: Output , <E as BroadcastShape <D >>:: Output >:: maybe_uninit( shape. into_pattern( ) ) ;
174+ let output_view = out. raw_view_mut( ) . cast:: <B >( ) ;
175+ Zip :: from( & lhs) . and( & rhs)
176+ . and( output_view)
177+ . collect_with_partial( |x, y| {
178+ x. clone( ) $operator y. clone( )
179+ } )
180+ . release_ownership( ) ;
181+ out. assume_init( )
182+ }
169183 }
170184 }
171185}
@@ -188,19 +202,12 @@ where
188202 D : Dimension + BroadcastShape <E >,
189203 E : Dimension ,
190204{
191- type Output = Array <A , <D as BroadcastShape <E >>:: BroadcastOutput >;
205+ type Output = Array <A , <D as BroadcastShape <E >>:: Output >;
192206 fn $mth( self , rhs: & ' a ArrayBase <S2 , E >) -> Self :: Output {
193207 let shape = self . dim. broadcast_shape( & rhs. dim) . unwrap( ) ;
194- // SAFETY: Overwrite all the elements in the array after
195- // it is created via `zip_mut_from_pair`.
196- let mut out = unsafe {
197- Self :: Output :: uninitialized( shape. clone( ) . into_pattern( ) )
198- } ;
199208 let lhs = self . broadcast( shape. clone( ) ) . unwrap( ) ;
200209 let rhs = rhs. broadcast( shape) . unwrap( ) ;
201- out. zip_mut_from_pair( & lhs, & rhs, |x, y| {
202- x. clone( ) $operator y. clone( )
203- } ) ;
210+ let out = Zip :: from( & lhs) . and( & rhs) . map_collect( |x, y| x. clone( ) $operator y. clone( ) ) ;
204211 out
205212 }
206213}
0 commit comments