@@ -739,6 +739,9 @@ def redistribute(self, axis: int) -> "MPIArray":
         if self.axis == axis or self.comm is None:
             return self
 
+        if self.comm.size == 1:
+            return MPIArray.wrap(self.local_array, axis, self.comm)
+
         # Check to make sure there is enough memory to perform the redistribution.
         # Must be able to allocate the target array and 2 buffers. We allocate
         # slightly more space than needed to be safe
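The new comm.size == 1 shortcut above re-wraps the local data on the requested axis instead of running the block-exchange machinery in the hunk below. A minimal usage sketch, assuming mpi4py is installed, that this module is importable as caput.mpiarray (an assumption; the import path is not shown in the diff), and that the script is launched under MPI (e.g. mpirun -n 1 python demo.py):

import numpy as np
from mpi4py import MPI

from caput.mpiarray import MPIArray  # assumed import path, not shown in the diff

comm = MPI.COMM_WORLD

# Wrap a purely local array; on a single rank the whole axis lives here.
local = np.arange(12, dtype=np.float64).reshape(4, 3)
arr = MPIArray.wrap(local, 0, comm)

# With comm.size == 1, redistribute() now short-circuits through
# MPIArray.wrap rather than entering the Sendrecv loop in the next hunk.
redist = arr.redistribute(axis=1)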
@@ -761,72 +764,62 @@ def redistribute(self, axis: int) -> "MPIArray":
         csize = self.comm.size
         crank = self.comm.rank
 
-        if csize == 1:
-            if arr.shape[self.axis] == self.global_shape[self.axis]:
-                # We are working on a single node.
-                target_arr[:] = arr
-            else:
-                raise ValueError(
-                    f"Global shape {self.global_shape} is incompatible with local "
-                    f"array shape {self.shape}"
-                )
-        else:
-            # Get the start and end of each subrange of interest
-            _, sac, eac = mpiutil.split_all(self.global_shape[axis], self.comm)
-            _, sar, ear = mpiutil.split_all(self.global_shape[self.axis], self.comm)
-            # Split the soruce array into properly sized blocks for sending
-            blocks = np.array_split(arr, np.insert(eac, 0, sac[0]), axis)[1:]
-            # Create fixed-size contiguous buffers for sending and receiving
-            buffer_shape = list(target_arr.shape)
-            buffer_shape[self.axis] = max(ear - sar)
-            buffer_shape[axis] = max(eac - sac)
-            # Pre-allocate buffers and buffer type
-            recv_buffer = np.empty(buffer_shape, dtype=self.dtype)
-            send_buffer = np.empty_like(recv_buffer)
-            buf_type = self._prep_buf(send_buffer)[1]
-
-            # Empty slices for target, send buf, recv buf
-            targetsl = [slice(None)] * len(buffer_shape)
-            sendsl = [slice(None)] * len(buffer_shape)
-            recvsl = [slice(None)] * len(buffer_shape)
-            # Send and recv buf have some fixed axis slices per rank
-            sendsl[self.axis] = slice(ear[crank] - sar[crank])
-            recvsl[axis] = slice(eac[crank] - sac[crank])
-
-            mpistatus = mpiutil.MPI.Status()
-
-            # Cyclically pass messages forward to i adjacent rank
-            for i in range(csize):
-                send_to = (crank + i) % csize
-                recv_from = (crank - i) % csize
-
-                # Write send data into send buffer location
-                sendsl[axis] = slice(eac[send_to] - sac[send_to])
-                send_buffer[tuple(sendsl)] = blocks[send_to]
-
-                self.comm.Sendrecv(
-                    sendbuf=(send_buffer, buf_type),
-                    dest=send_to,
-                    sendtag=(csize * crank + send_to),
-                    recvbuf=(recv_buffer, buf_type),
-                    source=recv_from,
-                    recvtag=(csize * recv_from + crank),
-                    status=mpistatus,
-                )
+        # Get the start and end of each subrange of interest
+        _, sac, eac = mpiutil.split_all(self.global_shape[axis], self.comm)
+        _, sar, ear = mpiutil.split_all(self.global_shape[self.axis], self.comm)
+        # Split the source array into properly sized blocks for sending
+        blocks = np.array_split(arr, np.insert(eac, 0, sac[0]), axis)[1:]
+        # Create fixed-size contiguous buffers for sending and receiving
+        buffer_shape = list(target_arr.shape)
+        buffer_shape[self.axis] = max(ear - sar)
+        buffer_shape[axis] = max(eac - sac)
+        # Pre-allocate buffers and buffer type
+        recv_buffer = np.empty(buffer_shape, dtype=self.dtype)
+        send_buffer = np.empty_like(recv_buffer)
+        buf_type = self._prep_buf(send_buffer)[1]
+
+        # Empty slices for target, send buf, recv buf
+        targetsl = [slice(None)] * len(buffer_shape)
+        sendsl = [slice(None)] * len(buffer_shape)
+        recvsl = [slice(None)] * len(buffer_shape)
+        # Send and recv buf have some fixed axis slices per rank
+        sendsl[self.axis] = slice(ear[crank] - sar[crank])
+        recvsl[axis] = slice(eac[crank] - sac[crank])
+
+        mpistatus = mpiutil.MPI.Status()
+
+        # Cyclically pass and receive array chunks across ranks
+        for i in range(csize):
+            send_to = (crank + i) % csize
+            recv_from = (crank - i) % csize
+
+            # Write send data into send buffer location
+            sendsl[axis] = slice(eac[send_to] - sac[send_to])
+            send_buffer[tuple(sendsl)] = blocks[send_to]
+
+            self.comm.Sendrecv(
+                sendbuf=(send_buffer, buf_type),
+                dest=send_to,
+                sendtag=(csize * crank + send_to),
+                recvbuf=(recv_buffer, buf_type),
+                source=recv_from,
+                recvtag=(csize * recv_from + crank),
+                status=mpistatus,
+            )
 
-                if mpistatus.error != mpiutil.MPI.SUCCESS:
-                    logger.error(
-                        f"**** ERROR in MPI SEND/RECV "
-                        f"(rank={crank}, "
-                        f"target={send_to}, "
-                        f"receive={recv_from}) ****"
-                    )
+            if mpistatus.error != mpiutil.MPI.SUCCESS:
+                logger.error(
+                    f"**** ERROR in MPI SEND/RECV "
+                    f"(rank={crank}, "
+                    f"target={send_to}, "
+                    f"receive={recv_from}) ****"
+                )
 
-                # Write buffer into target location
-                targetsl[self.axis] = slice(sar[recv_from], ear[recv_from])
-                recvsl[self.axis] = slice(ear[recv_from] - sar[recv_from])
+            # Write buffer into target location
+            targetsl[self.axis] = slice(sar[recv_from], ear[recv_from])
+            recvsl[self.axis] = slice(ear[recv_from] - sar[recv_from])
 
-                target_arr[tuple(targetsl)] = recv_buffer[tuple(recvsl)]
+            target_arr[tuple(targetsl)] = recv_buffer[tuple(recvsl)]
 
         return dist_arr
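The exchange loop itself is a standard cyclic pattern: at step i each rank sends to (rank + i) % size and receives from (rank - i) % size, with tags chosen so every send matches exactly one receive, and Sendrecv keeps the pairing deadlock-free. A standalone mpi4py sketch of that pattern, using toy integer payloads rather than the MPIArray buffers (run under MPI, e.g. mpirun -n 4 python cyclic_demo.py):

import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
csize, crank = comm.size, comm.rank

received = [None] * csize
for i in range(csize):
    send_to = (crank + i) % csize
    recv_from = (crank - i) % csize

    send_buf = np.full(3, crank, dtype=np.int64)   # payload tagged with our rank
    recv_buf = np.empty(3, dtype=np.int64)

    # Same tag scheme as the diff: sendtag = csize * sender + receiver,
    # so each message pairs with exactly one counterpart.
    comm.Sendrecv(
        sendbuf=send_buf,
        dest=send_to,
        sendtag=csize * crank + send_to,
        recvbuf=recv_buf,
        source=recv_from,
        recvtag=csize * recv_from + crank,
    )
    received[recv_from] = recv_buf

# After csize steps every rank holds one chunk from every rank (itself included),
# which is the invariant the redistribute loop builds on.
assert all(int(received[r][0]) == r for r in range(csize))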