@@ -735,6 +735,9 @@ def redistribute(self, axis: int) -> "MPIArray":
        if self.axis == axis or self.comm is None:
            return self

+       if self.comm.size == 1:
+           return MPIArray.wrap(self.local_array, axis, self.comm)
+
        # Check to make sure there is enough memory to perform the redistribution.
        # Must be able to allocate the target array and 2 buffers. We allocate
        # slightly more space than needed to be safe
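As context for the new early return: on a one-rank communicator the local array already spans the full global array along every axis, so changing the distributed axis is pure bookkeeping and no data needs to move. A minimal standalone sketch of that idea (plain NumPy, not this project's `MPIArray.wrap`; the function name is hypothetical):

```python
import numpy as np


def redistribute_on_one_rank(local_array: np.ndarray, new_axis: int) -> np.ndarray:
    """Toy stand-in for the ``comm.size == 1`` shortcut added above.

    With a single rank the local array *is* the global array, so switching the
    distributed axis only relabels metadata; the buffer itself is reused as-is.
    """
    if not 0 <= new_axis < local_array.ndim:
        raise ValueError(f"axis {new_axis} out of range for a {local_array.ndim}-d array")
    # The real method returns MPIArray.wrap(self.local_array, axis, self.comm);
    # here we simply hand back the same buffer to show that no copy is required.
    return local_array


arr = np.arange(12.0).reshape(3, 4)
assert redistribute_on_one_rank(arr, 1) is arr  # no copy, no data movement
```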
@@ -757,72 +760,62 @@ def redistribute(self, axis: int) -> "MPIArray":
        csize = self.comm.size
        crank = self.comm.rank

-       if csize == 1:
-           if arr.shape[self.axis] == self.global_shape[self.axis]:
-               # We are working on a single node.
-               target_arr[:] = arr
-           else:
-               raise ValueError(
-                   f"Global shape {self.global_shape} is incompatible with local "
-                   f"array shape {self.shape}"
-               )
-       else:
-           # Get the start and end of each subrange of interest
-           _, sac, eac = mpiutil.split_all(self.global_shape[axis], self.comm)
-           _, sar, ear = mpiutil.split_all(self.global_shape[self.axis], self.comm)
-           # Split the source array into properly sized blocks for sending
-           blocks = np.array_split(arr, np.insert(eac, 0, sac[0]), axis)[1:]
-           # Create fixed-size contiguous buffers for sending and receiving
-           buffer_shape = list(target_arr.shape)
-           buffer_shape[self.axis] = max(ear - sar)
-           buffer_shape[axis] = max(eac - sac)
-           # Pre-allocate buffers and buffer type
-           recv_buffer = np.empty(buffer_shape, dtype=self.dtype)
-           send_buffer = np.empty_like(recv_buffer)
-           buf_type = self._prep_buf(send_buffer)[1]
-
-           # Empty slices for target, send buf, recv buf
-           targetsl = [slice(None)] * len(buffer_shape)
-           sendsl = [slice(None)] * len(buffer_shape)
-           recvsl = [slice(None)] * len(buffer_shape)
-           # Send and recv buf have some fixed axis slices per rank
-           sendsl[self.axis] = slice(ear[crank] - sar[crank])
-           recvsl[axis] = slice(eac[crank] - sac[crank])
-
-           mpistatus = mpiutil.MPI.Status()
-
-           # Cyclically pass messages forward to i adjacent rank
-           for i in range(csize):
-               send_to = (crank + i) % csize
-               recv_from = (crank - i) % csize
-
-               # Write send data into send buffer location
-               sendsl[axis] = slice(eac[send_to] - sac[send_to])
-               send_buffer[tuple(sendsl)] = blocks[send_to]
-
-               self.comm.Sendrecv(
-                   sendbuf=(send_buffer, buf_type),
-                   dest=send_to,
-                   sendtag=(csize * crank + send_to),
-                   recvbuf=(recv_buffer, buf_type),
-                   source=recv_from,
-                   recvtag=(csize * recv_from + crank),
-                   status=mpistatus,
-               )
+       # Get the start and end of each subrange of interest
+       _, sac, eac = mpiutil.split_all(self.global_shape[axis], self.comm)
+       _, sar, ear = mpiutil.split_all(self.global_shape[self.axis], self.comm)
+       # Split the source array into properly sized blocks for sending
+       blocks = np.array_split(arr, np.insert(eac, 0, sac[0]), axis)[1:]
+       # Create fixed-size contiguous buffers for sending and receiving
+       buffer_shape = list(target_arr.shape)
+       buffer_shape[self.axis] = max(ear - sar)
+       buffer_shape[axis] = max(eac - sac)
+       # Pre-allocate buffers and buffer type
+       recv_buffer = np.empty(buffer_shape, dtype=self.dtype)
+       send_buffer = np.empty_like(recv_buffer)
+       buf_type = self._prep_buf(send_buffer)[1]
+
+       # Empty slices for target, send buf, recv buf
+       targetsl = [slice(None)] * len(buffer_shape)
+       sendsl = [slice(None)] * len(buffer_shape)
+       recvsl = [slice(None)] * len(buffer_shape)
+       # Send and recv buf have some fixed axis slices per rank
+       sendsl[self.axis] = slice(ear[crank] - sar[crank])
+       recvsl[axis] = slice(eac[crank] - sac[crank])
+
+       mpistatus = mpiutil.MPI.Status()
+
+       # Cyclically pass and receive array chunks across ranks
+       for i in range(csize):
+           send_to = (crank + i) % csize
+           recv_from = (crank - i) % csize
+
+           # Write send data into send buffer location
+           sendsl[axis] = slice(eac[send_to] - sac[send_to])
+           send_buffer[tuple(sendsl)] = blocks[send_to]
+
+           self.comm.Sendrecv(
+               sendbuf=(send_buffer, buf_type),
+               dest=send_to,
+               sendtag=(csize * crank + send_to),
+               recvbuf=(recv_buffer, buf_type),
+               source=recv_from,
+               recvtag=(csize * recv_from + crank),
+               status=mpistatus,
+           )

-               if mpistatus.error != mpiutil.MPI.SUCCESS:
-                   logger.error(
-                       f"**** ERROR in MPI SEND/RECV "
-                       f"(rank={crank}, "
-                       f"target={send_to}, "
-                       f"receive={recv_from}) ****"
-                   )
+           if mpistatus.error != mpiutil.MPI.SUCCESS:
+               logger.error(
+                   f"**** ERROR in MPI SEND/RECV "
+                   f"(rank={crank}, "
+                   f"target={send_to}, "
+                   f"receive={recv_from}) ****"
+               )

-               # Write buffer into target location
-               targetsl[self.axis] = slice(sar[recv_from], ear[recv_from])
-               recvsl[self.axis] = slice(ear[recv_from] - sar[recv_from])
+           # Write buffer into target location
+           targetsl[self.axis] = slice(sar[recv_from], ear[recv_from])
+           recvsl[self.axis] = slice(ear[recv_from] - sar[recv_from])

-               target_arr[tuple(targetsl)] = recv_buffer[tuple(recvsl)]
+           target_arr[tuple(targetsl)] = recv_buffer[tuple(recvsl)]

        return dist_arr
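For readers unfamiliar with the exchange pattern, here is a self-contained sketch of the cyclic pairwise `Sendrecv` used in the loop above, written against plain mpi4py and NumPy rather than this project's `MPIArray`/`mpiutil` helpers. The global shape, the even-split assumption, and all variable names are illustrative only; the real code handles uneven splits via `mpiutil.split_all` and fixed-size padded buffers.

```python
# Run with e.g.: mpirun -n 4 python redistribute_sketch.py
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank, size = comm.rank, comm.size

nrow, ncol = 8, 8  # toy global shape; assumes both divide evenly by `size`
rows, cols = nrow // size, ncol // size

global_arr = np.arange(nrow * ncol, dtype=np.float64).reshape(nrow, ncol)

# Start distributed over axis 0: each rank owns `rows` full-width rows.
local = global_arr[rank * rows : (rank + 1) * rows].copy()

# Target layout: distributed over axis 1, i.e. all rows but only `cols` columns.
target = np.empty((nrow, cols), dtype=np.float64)

for i in range(size):
    send_to = (rank + i) % size
    recv_from = (rank - i) % size

    # The slice of our rows that `send_to` owns in the new (column) layout.
    send_buf = np.ascontiguousarray(local[:, send_to * cols : (send_to + 1) * cols])
    recv_buf = np.empty((rows, cols), dtype=np.float64)

    # Paired blocking exchange: while we send to `send_to`, `recv_from` is
    # sending its block for us, so each step matches every rank exactly once.
    comm.Sendrecv(sendbuf=send_buf, dest=send_to, recvbuf=recv_buf, source=recv_from)

    # The received block holds the rows that `recv_from` owned in the old layout.
    target[recv_from * rows : (recv_from + 1) * rows] = recv_buf

# Every rank should now hold its own column block of the full global array.
assert np.array_equal(target, global_arr[:, rank * cols : (rank + 1) * cols])
```

Pairing rank `r` with `(r + i) % size` for sending and `(r - i) % size` for receiving makes each step a perfect matching of ranks, so a single blocking `Sendrecv` per step cannot deadlock, and step `i = 0` is just a local self-exchange.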