@@ -6,11 +6,11 @@ use crate::{Dimension, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn};
66/// 
77/// Uses the [NumPy broadcasting rules] 
88//  (https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules). 
9- fn  co_broadcast < D1 ,  D2 ,  Output > ( shape1 :  & D1 ,  shape2 :  & D2 )  -> Result < Output ,  ShapeError > 
10-      where 
11-          D1 :  Dimension , 
12-          D2 :  Dimension , 
13-          Output :  Dimension , 
9+ pub ( crate )   fn  co_broadcast < D1 ,  D2 ,  Output > ( shape1 :  & D1 ,  shape2 :  & D2 )  -> Result < Output ,  ShapeError > 
10+ where 
11+     D1 :  Dimension , 
12+     D2 :  Dimension , 
13+     Output :  Dimension , 
1414{ 
1515    let  ( k,  overflow)  = shape1. ndim ( ) . overflowing_sub ( shape2. ndim ( ) ) ; 
1616    // Swap the order if d2 is longer. 
@@ -37,40 +37,23 @@ fn co_broadcast<D1, D2, Output>(shape1: &D1, shape2: &D2) -> Result<Output, Shap
3737pub  trait  DimMax < Other :  Dimension >  { 
3838    /// The resulting dimension type after broadcasting. 
3939type  Output :  Dimension ; 
40- 
41-     /// Determines the shape after broadcasting the shapes together. 
42- /// 
43- /// If the shapes are not compatible, returns `Err`. 
44- fn  broadcast_shape ( & self ,  other :  & Other )  -> Result < Self :: Output ,  ShapeError > ; 
4540} 
4641
4742/// Dimensions of the same type remain unchanged when co_broadcast. 
4843/// So you can directly use D as the resulting type. 
4944/// (Instead of <D as DimMax<D>>::BroadcastOutput) 
5045impl < D :  Dimension >  DimMax < D >  for  D  { 
5146    type  Output  = D ; 
52- 
53-     fn  broadcast_shape ( & self ,  other :  & D )  -> Result < Self :: Output ,  ShapeError >  { 
54-         co_broadcast :: < D ,  D ,  Self :: Output > ( self ,  other) 
55-     } 
5647} 
5748
5849macro_rules!  impl_broadcast_distinct_fixed { 
5950    ( $smaller: ty,  $larger: ty)  => { 
6051        impl  DimMax <$larger> for  $smaller { 
6152            type  Output  = $larger; 
62- 
63-             fn  broadcast_shape( & self ,  other:  & $larger)  -> Result <Self :: Output ,  ShapeError > { 
64-                 co_broadcast:: <Self ,  $larger,  Self :: Output >( self ,  other) 
65-             } 
6653        } 
6754
6855        impl  DimMax <$smaller> for  $larger { 
6956            type  Output  = $larger; 
70- 
71-             fn  broadcast_shape( & self ,  other:  & $smaller)  -> Result <Self :: Output ,  ShapeError > { 
72-                 co_broadcast:: <Self ,  $smaller,  Self :: Output >( self ,  other) 
73-             } 
7457        } 
7558    } ; 
7659} 
@@ -103,3 +86,57 @@ impl_broadcast_distinct_fixed!(Ix3, IxDyn);
10386impl_broadcast_distinct_fixed ! ( Ix4 ,  IxDyn ) ; 
10487impl_broadcast_distinct_fixed ! ( Ix5 ,  IxDyn ) ; 
10588impl_broadcast_distinct_fixed ! ( Ix6 ,  IxDyn ) ; 
89+ 
90+ 
91+ #[ cfg( test) ]  
92+ mod  tests { 
93+     use  super :: co_broadcast; 
94+     use  crate :: { Dimension ,  Dim ,  DimMax ,  ShapeError ,  Ix0 ,  IxDynImpl ,  ErrorKind } ; 
95+ 
96+     #[ test]  
97+     fn  test_broadcast_shape ( )  { 
98+         fn  test_co < D1 ,  D2 > ( 
99+             d1 :  & D1 , 
100+             d2 :  & D2 , 
101+             r :  Result < <D1  as  DimMax < D2 > >:: Output ,  ShapeError > , 
102+         )  where 
103+             D1 :  Dimension  + DimMax < D2 > , 
104+             D2 :  Dimension , 
105+         { 
106+             let  d = co_broadcast :: < D1 ,  D2 ,  <D1  as  DimMax < D2 > >:: Output > ( & d1,  d2) ; 
107+             assert_eq ! ( d,  r) ; 
108+         } 
109+         test_co ( & Dim ( [ 2 ,  3 ] ) ,  & Dim ( [ 4 ,  1 ,  3 ] ) ,  Ok ( Dim ( [ 4 ,  2 ,  3 ] ) ) ) ; 
110+         test_co ( 
111+             & Dim ( [ 1 ,  2 ,  2 ] ) , 
112+             & Dim ( [ 1 ,  3 ,  4 ] ) , 
113+             Err ( ShapeError :: from_kind ( ErrorKind :: IncompatibleShape ) ) , 
114+         ) ; 
115+         test_co ( & Dim ( [ 3 ,  4 ,  5 ] ) ,  & Ix0 ( ) ,  Ok ( Dim ( [ 3 ,  4 ,  5 ] ) ) ) ; 
116+         let  v = vec ! [ 1 ,  2 ,  3 ,  4 ,  5 ,  6 ,  7 ] ; 
117+         test_co ( 
118+             & Dim ( vec ! [ 1 ,  1 ,  3 ,  1 ,  5 ,  1 ,  7 ] ) , 
119+             & Dim ( [ 2 ,  1 ,  4 ,  1 ,  6 ,  1 ] ) , 
120+             Ok ( Dim ( IxDynImpl :: from ( v. as_slice ( ) ) ) ) , 
121+         ) ; 
122+         let  d = Dim ( [ 1 ,  2 ,  1 ,  3 ] ) ; 
123+         test_co ( & d,  & d,  Ok ( d) ) ; 
124+         test_co ( 
125+             & Dim ( [ 2 ,  1 ,  2 ] ) . into_dyn ( ) , 
126+             & Dim ( 0 ) , 
127+             Err ( ShapeError :: from_kind ( ErrorKind :: IncompatibleShape ) ) , 
128+         ) ; 
129+         test_co ( 
130+             & Dim ( [ 2 ,  1 ,  1 ] ) , 
131+             & Dim ( [ 0 ,  0 ,  1 ,  3 ,  4 ] ) , 
132+             Ok ( Dim ( [ 0 ,  0 ,  2 ,  3 ,  4 ] ) ) , 
133+         ) ; 
134+         test_co ( & Dim ( [ 0 ] ) ,  & Dim ( [ 0 ,  0 ,  0 ] ) ,  Ok ( Dim ( [ 0 ,  0 ,  0 ] ) ) ) ; 
135+         test_co ( & Dim ( 1 ) ,  & Dim ( [ 1 ,  0 ,  0 ] ) ,  Ok ( Dim ( [ 1 ,  0 ,  0 ] ) ) ) ; 
136+         test_co ( 
137+             & Dim ( [ 1 ,  3 ,  0 ,  1 ,  1 ] ) , 
138+             & Dim ( [ 1 ,  2 ,  3 ,  1 ] ) , 
139+             Err ( ShapeError :: from_kind ( ErrorKind :: IncompatibleShape ) ) , 
140+         ) ; 
141+     } 
142+ } 
0 commit comments