@@ -56,20 +56,22 @@ macro_rules! impl_binary_op(
5656/// between `self` and `rhs`,
5757/// and return the result.
5858///
59+ /// `self` must be an `Array` or `ArcArray`.
60+ ///
5961/// If their shapes disagree, `self` is broadcast to their broadcast shape,
6062/// cloning the data if needed.
6163///
6264/// **Panics** if broadcasting isn’t possible.
6365impl <A , B , S , S2 , D , E > $trt<ArrayBase <S2 , E >> for ArrayBase <S , D >
6466where
65- A : Clone + $trt<B , Output =A >,
67+ A : Copy + $trt<B , Output =A >,
6668 B : Clone ,
67- S : Data <Elem =A >,
69+ S : DataOwned <Elem =A > + DataMut ,
6870 S2 : Data <Elem =B >,
6971 D : Dimension + BroadcastShape <E >,
7072 E : Dimension ,
7173{
72- type Output = Array < A , <D as BroadcastShape <E >>:: BroadcastOutput >;
74+ type Output = ArrayBase < S , <D as BroadcastShape <E >>:: BroadcastOutput >;
7375 fn $mth( self , rhs: ArrayBase <S2 , E >) -> Self :: Output
7476 {
7577 self . $mth( & rhs)
@@ -79,25 +81,46 @@ where
7981/// Perform elementwise
8082 #[ doc=$doc]
8183/// between reference `self` and `rhs`,
82- /// and return the result as a new `Array`.
84+ /// and return the result.
85+ ///
86+ /// `rhs` must be an `Array` or `ArcArray`.
8387///
8488/// If their shapes disagree, `self` is broadcast to their broadcast shape,
8589/// cloning the data if needed.
8690///
8791/// **Panics** if broadcasting isn’t possible.
8892impl <' a, A , B , S , S2 , D , E > $trt<ArrayBase <S2 , E >> for & ' a ArrayBase <S , D >
8993where
90- A : Clone + $trt<B , Output =A >,
91- B : Clone ,
94+ A : Clone + $trt<B , Output =B >,
95+ B : Copy ,
9296 S : Data <Elem =A >,
93- S2 : Data <Elem =B >,
94- D : Dimension + BroadcastShape < E > ,
95- E : Dimension ,
97+ S2 : DataOwned <Elem =B > + DataMut ,
98+ D : Dimension ,
99+ E : Dimension + BroadcastShape < D > ,
96100{
97- type Output = Array < A , <D as BroadcastShape <E >>:: BroadcastOutput >;
101+ type Output = ArrayBase < S2 , <E as BroadcastShape <D >>:: BroadcastOutput >;
98102 fn $mth( self , rhs: ArrayBase <S2 , E >) -> Self :: Output
99103 {
100- self . $mth( & rhs)
104+ let shape = rhs. dim. broadcast_shape( & self . dim) . unwrap( ) ;
105+ if shape. slice( ) == rhs. dim. slice( ) {
106+ let mut out = rhs. into_dimensionality:: <<E as BroadcastShape <D >>:: BroadcastOutput >( ) . unwrap( ) ;
107+ out. zip_mut_with( self , |x, y| {
108+ * x = y. clone( ) $operator x. clone( ) ;
109+ } ) ;
110+ out
111+ } else {
112+ // SAFETY: Overwrite all the elements in the array after
113+ // it is created via `zip_mut_from_pair`.
114+ let mut out = unsafe {
115+ Self :: Output :: uninitialized( shape. clone( ) . into_pattern( ) )
116+ } ;
117+ let lhs = self . broadcast( shape. clone( ) ) . unwrap( ) ;
118+ let rhs = rhs. broadcast( shape) . unwrap( ) ;
119+ out. zip_mut_from_pair( & lhs, & rhs, |x, y| {
120+ x. clone( ) $operator y. clone( )
121+ } ) ;
122+ out
123+ }
101124 }
102125}
103126
@@ -106,32 +129,44 @@ where
106129/// between `self` and reference `rhs`,
107130/// and return the result.
108131///
132+ /// `rhs` must be an `Array` or `ArcArray`.
133+ ///
109134/// If their shapes disagree, `self` is broadcast to their broadcast shape,
110135/// cloning the data if needed.
111136///
112137/// **Panics** if broadcasting isn’t possible.
113138impl <' a, A , B , S , S2 , D , E > $trt<& ' a ArrayBase <S2 , E >> for ArrayBase <S , D >
114139where
115- A : Clone + $trt<B , Output =A >,
140+ A : Copy + $trt<B , Output =A >,
116141 B : Clone ,
117- S : Data <Elem =A >,
142+ S : DataOwned <Elem =A > + DataMut ,
118143 S2 : Data <Elem =B >,
119144 D : Dimension + BroadcastShape <E >,
120145 E : Dimension ,
121146{
122- type Output = Array < A , <D as BroadcastShape <E >>:: BroadcastOutput >;
147+ type Output = ArrayBase < S , <D as BroadcastShape <E >>:: BroadcastOutput >;
123148 fn $mth( self , rhs: & ArrayBase <S2 , E >) -> Self :: Output
124149 {
125150 let shape = self . dim. broadcast_shape( & rhs. dim) . unwrap( ) ;
126- let mut self_ = if shape. slice( ) == self . dim. slice( ) {
127- self . into_owned( ) . into_dimensionality:: <<D as BroadcastShape <E >>:: BroadcastOutput >( ) . unwrap( )
151+ if shape. slice( ) == self . dim. slice( ) {
152+ let mut out = self . into_dimensionality:: <<D as BroadcastShape <E >>:: BroadcastOutput >( ) . unwrap( ) ;
153+ out. zip_mut_with( rhs, |x, y| {
154+ * x = x. clone( ) $operator y. clone( ) ;
155+ } ) ;
156+ out
128157 } else {
129- self . broadcast( shape) . unwrap( ) . to_owned( )
130- } ;
131- self_. zip_mut_with( rhs, |x, y| {
132- * x = x. clone( ) $operator y. clone( ) ;
133- } ) ;
134- self_
158+ // SAFETY: Overwrite all the elements in the array after
159+ // it is created via `zip_mut_from_pair`.
160+ let mut out = unsafe {
161+ Self :: Output :: uninitialized( shape. clone( ) . into_pattern( ) )
162+ } ;
163+ let lhs = self . broadcast( shape. clone( ) ) . unwrap( ) ;
164+ let rhs = rhs. broadcast( shape) . unwrap( ) ;
165+ out. zip_mut_from_pair( & lhs, & rhs, |x, y| {
166+ x. clone( ) $operator y. clone( )
167+ } ) ;
168+ out
169+ }
135170 }
136171}
137172
@@ -140,13 +175,13 @@ where
140175/// between references `self` and `rhs`,
141176/// and return the result as a new `Array`.
142177///
143- /// If their shapes disagree, `self` is broadcast to their broadcast shape,
178+ /// If their shapes disagree, `self` and `rhs` is broadcast to their broadcast shape,
144179/// cloning the data if needed.
145180///
146181/// **Panics** if broadcasting isn’t possible.
147182impl <' a, A , B , S , S2 , D , E > $trt<& ' a ArrayBase <S2 , E >> for & ' a ArrayBase <S , D >
148183where
149- A : Clone + $trt<B , Output =A >,
184+ A : Copy + $trt<B , Output =A >,
150185 B : Clone ,
151186 S : Data <Elem =A >,
152187 S2 : Data <Elem =B >,
@@ -156,15 +191,17 @@ where
156191 type Output = Array <A , <D as BroadcastShape <E >>:: BroadcastOutput >;
157192 fn $mth( self , rhs: & ' a ArrayBase <S2 , E >) -> Self :: Output {
158193 let shape = self . dim. broadcast_shape( & rhs. dim) . unwrap( ) ;
159- let mut self_ = if shape . slice ( ) == self . dim . slice ( ) {
160- self . to_owned ( ) . into_dimensionality :: << D as BroadcastShape < E >> :: BroadcastOutput > ( ) . unwrap ( )
161- } else {
162- self . broadcast ( shape) . unwrap ( ) . to_owned ( )
194+ // SAFETY: Overwrite all the elements in the array after
195+ // it is created via `zip_mut_from_pair`.
196+ let mut out = unsafe {
197+ Self :: Output :: uninitialized ( shape. clone ( ) . into_pattern ( ) )
163198 } ;
164- self_. zip_mut_with( rhs, |x, y| {
165- * x = x. clone( ) $operator y. clone( ) ;
199+ let lhs = self . broadcast( shape. clone( ) ) . unwrap( ) ;
200+ let rhs = rhs. broadcast( shape) . unwrap( ) ;
201+ out. zip_mut_from_pair( & lhs, & rhs, |x, y| {
202+ x. clone( ) $operator y. clone( )
166203 } ) ;
167- self_
204+ out
168205 }
169206}
170207
0 commit comments