@@ -735,6 +735,13 @@ def redistribute(self, axis: int) -> "MPIArray":
         if self.axis == axis or self.comm is None:
             return self

+        # Avoid repeat mpi property calls
+        csize = self.comm.size
+        crank = self.comm.rank
+
+        if csize == 1:
+            return MPIArray.wrap(self.local_array, axis, self.comm)
+
         # Check to make sure there is enough memory to perform the redistribution.
         # Must be able to allocate the target array and 2 buffers. We allocate
         # slightly more space than needed to be safe
@@ -753,76 +760,63 @@ def redistribute(self, axis: int) -> "MPIArray":
         # Get views into local and target arrays
         arr = self.local_array
         target_arr = dist_arr.local_array
-        # Avoid repeat mpi property calls
-        csize = self.comm.size
-        crank = self.comm.rank

-        if csize == 1:
-            if arr.shape[self.axis] == self.global_shape[self.axis]:
-                # We are working on a single node.
-                target_arr[:] = arr
-            else:
-                raise ValueError(
-                    f"Global shape {self.global_shape} is incompatible with local "
-                    f"array shape {self.shape}"
-                )
-        else:
-            # Get the start and end of each subrange of interest
-            _, sac, eac = mpiutil.split_all(self.global_shape[axis], self.comm)
-            _, sar, ear = mpiutil.split_all(self.global_shape[self.axis], self.comm)
-            # Split the source array into properly sized blocks for sending
-            blocks = np.array_split(arr, np.insert(eac, 0, sac[0]), axis)[1:]
-            # Create fixed-size contiguous buffers for sending and receiving
-            buffer_shape = list(target_arr.shape)
-            buffer_shape[self.axis] = max(ear - sar)
-            buffer_shape[axis] = max(eac - sac)
-            # Pre-allocate buffers and buffer type
-            recv_buffer = np.empty(buffer_shape, dtype=self.dtype)
-            send_buffer = np.empty_like(recv_buffer)
-            buf_type = self._prep_buf(send_buffer)[1]
-
-            # Empty slices for target, send buf, recv buf
-            targetsl = [slice(None)] * len(buffer_shape)
-            sendsl = [slice(None)] * len(buffer_shape)
-            recvsl = [slice(None)] * len(buffer_shape)
-            # Send and recv buf have some fixed axis slices per rank
-            sendsl[self.axis] = slice(ear[crank] - sar[crank])
-            recvsl[axis] = slice(eac[crank] - sac[crank])
-
-            mpistatus = mpiutil.MPI.Status()
-
-            # Cyclically pass messages forward to i adjacent rank
-            for i in range(csize):
-                send_to = (crank + i) % csize
-                recv_from = (crank - i) % csize
-
-                # Write send data into send buffer location
-                sendsl[axis] = slice(eac[send_to] - sac[send_to])
-                send_buffer[tuple(sendsl)] = blocks[send_to]
-
-                self.comm.Sendrecv(
-                    sendbuf=(send_buffer, buf_type),
-                    dest=send_to,
-                    sendtag=(csize * crank + send_to),
-                    recvbuf=(recv_buffer, buf_type),
-                    source=recv_from,
-                    recvtag=(csize * recv_from + crank),
-                    status=mpistatus,
-                )
+        # Get the start and end of each subrange of interest
+        _, sac, eac = mpiutil.split_all(self.global_shape[axis], self.comm)
+        _, sar, ear = mpiutil.split_all(self.global_shape[self.axis], self.comm)
+        # Split the source array into properly sized blocks for sending
+        blocks = np.array_split(arr, np.insert(eac, 0, sac[0]), axis)[1:]
+        # Create fixed-size contiguous buffers for sending and receiving
+        buffer_shape = list(target_arr.shape)
+        buffer_shape[self.axis] = max(ear - sar)
+        buffer_shape[axis] = max(eac - sac)
+        # Pre-allocate buffers and buffer type
+        recv_buffer = np.empty(buffer_shape, dtype=self.dtype)
+        send_buffer = np.empty_like(recv_buffer)
+        buf_type = self._prep_buf(send_buffer)[1]
+
+        # Empty slices for target, send buf, recv buf
+        targetsl = [slice(None)] * len(buffer_shape)
+        sendsl = [slice(None)] * len(buffer_shape)
+        recvsl = [slice(None)] * len(buffer_shape)
+        # Send and recv buf have some fixed axis slices per rank
+        sendsl[self.axis] = slice(ear[crank] - sar[crank])
+        recvsl[axis] = slice(eac[crank] - sac[crank])
+
+        mpistatus = mpiutil.MPI.Status()
+
+        # Cyclically pass and receive array chunks across ranks
+        for i in range(csize):
+            send_to = (crank + i) % csize
+            recv_from = (crank - i) % csize
+
+            # Write send data into send buffer location
+            sendsl[axis] = slice(eac[send_to] - sac[send_to])
+            send_buffer[tuple(sendsl)] = blocks[send_to]
+
+            self.comm.Sendrecv(
+                sendbuf=(send_buffer, buf_type),
+                dest=send_to,
+                sendtag=(csize * crank + send_to),
+                recvbuf=(recv_buffer, buf_type),
+                source=recv_from,
+                recvtag=(csize * recv_from + crank),
+                status=mpistatus,
+            )

-                if mpistatus.error != mpiutil.MPI.SUCCESS:
-                    logger.error(
-                        f"**** ERROR in MPI SEND/RECV "
-                        f"(rank={crank}, "
-                        f"target={send_to}, "
-                        f"receive={recv_from}) ****"
-                    )
+            if mpistatus.error != mpiutil.MPI.SUCCESS:
+                logger.error(
+                    f"**** ERROR in MPI SEND/RECV "
+                    f"(rank={crank}, "
+                    f"target={send_to}, "
+                    f"receive={recv_from}) ****"
+                )

-                # Write buffer into target location
-                targetsl[self.axis] = slice(sar[recv_from], ear[recv_from])
-                recvsl[self.axis] = slice(ear[recv_from] - sar[recv_from])
+            # Write buffer into target location
+            targetsl[self.axis] = slice(sar[recv_from], ear[recv_from])
+            recvsl[self.axis] = slice(ear[recv_from] - sar[recv_from])

-                target_arr[tuple(targetsl)] = recv_buffer[tuple(recvsl)]
+            target_arr[tuple(targetsl)] = recv_buffer[tuple(recvsl)]

         return dist_arr
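The loop added above exchanges blocks with a cyclic schedule: at step i, rank r sends to (r + i) % size and receives from (r - i) % size, so every pair of ranks exchanges exactly once and the paired send/receive in Sendrecv cannot deadlock (including the self-exchange at i = 0). The following is a minimal standalone sketch of that schedule using mpi4py and NumPy; it is an illustration only, not part of MPIArray, and the toy payloads, buffer sizes, and variable names are made up for the example.

    # Run with e.g.: mpirun -n 4 python cyclic_sendrecv_demo.py
    import numpy as np
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    rank, size = comm.rank, comm.size

    # Each rank owns one small block destined for every rank (toy payload).
    blocks = [np.full(3, 10 * rank + dest, dtype=np.int64) for dest in range(size)]
    received = [None] * size

    # Fixed-size buffers reused every iteration, mirroring the pre-allocated
    # send/recv buffers in redistribute().
    send_buffer = np.empty(3, dtype=np.int64)
    recv_buffer = np.empty(3, dtype=np.int64)

    for i in range(size):
        send_to = (rank + i) % size
        recv_from = (rank - i) % size

        # Stage the outgoing block, then do the paired send/receive.
        send_buffer[:] = blocks[send_to]
        comm.Sendrecv(
            sendbuf=send_buffer,
            dest=send_to,
            sendtag=size * rank + send_to,
            recvbuf=recv_buffer,
            source=recv_from,
            recvtag=size * recv_from + rank,
        )
        received[recv_from] = recv_buffer.copy()

    print(rank, [int(b[0]) for b in received])

The tags follow the same convention as the diff (sender rank times communicator size plus destination), so the receive tag computed from recv_from on the receiving side matches the send tag computed on the sending side.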