@@ -739,6 +739,10 @@ def redistribute(self, axis: int) -> "MPIArray":
         if self.axis == axis or self.comm is None:
             return self
 
+        if self.comm.size == 1:
+            self._axis = axis
+            return self
+
         # Check to make sure there is enough memory to perform the redistribution.
         # Must be able to allocate the target array and 2 buffers. We allocate
         # slightly more space than needed to be safe
@@ -757,76 +761,67 @@ def redistribute(self, axis: int) -> "MPIArray":
         # Get views into local and target arrays
         arr = self.local_array
         target_arr = dist_arr.local_array
+
         # Avoid repeat mpi property calls
         csize = self.comm.size
         crank = self.comm.rank
 
-        if csize == 1:
-            if arr.shape[self.axis] == self.global_shape[self.axis]:
-                # We are working on a single node.
-                target_arr[:] = arr
-            else:
-                raise ValueError(
-                    f"Global shape {self.global_shape} is incompatible with local "
-                    f"array shape {self.shape}"
-                )
-        else:
-            # Get the start and end of each subrange of interest
-            _, sac, eac = mpiutil.split_all(self.global_shape[axis], self.comm)
-            _, sar, ear = mpiutil.split_all(self.global_shape[self.axis], self.comm)
-            # Split the soruce array into properly sized blocks for sending
-            blocks = np.array_split(arr, np.insert(eac, 0, sac[0]), axis)[1:]
-            # Create fixed-size contiguous buffers for sending and receiving
-            buffer_shape = list(target_arr.shape)
-            buffer_shape[self.axis] = max(ear - sar)
-            buffer_shape[axis] = max(eac - sac)
-            # Pre-allocate buffers and buffer type
-            recv_buffer = np.empty(buffer_shape, dtype=self.dtype)
-            send_buffer = np.empty_like(recv_buffer)
-            buf_type = self._prep_buf(send_buffer)[1]
-
-            # Empty slices for target, send buf, recv buf
-            targetsl = [slice(None)] * len(buffer_shape)
-            sendsl = [slice(None)] * len(buffer_shape)
-            recvsl = [slice(None)] * len(buffer_shape)
-            # Send and recv buf have some fixed axis slices per rank
-            sendsl[self.axis] = slice(ear[crank] - sar[crank])
-            recvsl[axis] = slice(eac[crank] - sac[crank])
-
-            mpistatus = mpiutil.MPI.Status()
-
-            # Cyclically pass messages forward to i adjacent rank
-            for i in range(csize):
-                send_to = (crank + i) % csize
-                recv_from = (crank - i) % csize
-
-                # Write send data into send buffer location
-                sendsl[axis] = slice(eac[send_to] - sac[send_to])
-                send_buffer[tuple(sendsl)] = blocks[send_to]
-
-                self.comm.Sendrecv(
-                    sendbuf=(send_buffer, buf_type),
-                    dest=send_to,
-                    sendtag=(csize * crank + send_to),
-                    recvbuf=(recv_buffer, buf_type),
-                    source=recv_from,
-                    recvtag=(csize * recv_from + crank),
-                    status=mpistatus,
-                )
+        # Get the start and end of each subrange of interest
+        _, sac, eac = mpiutil.split_all(self.global_shape[axis], self.comm)
+        _, sar, ear = mpiutil.split_all(self.global_shape[self.axis], self.comm)
+        # Split the source array into properly sized blocks for sending
+        blocks = np.array_split(arr, np.insert(eac, 0, sac[0]), axis)[1:]
+        # Create fixed-size contiguous buffers for sending and receiving
+        buffer_shape = list(target_arr.shape)
+        buffer_shape[self.axis] = max(ear - sar)
+        buffer_shape[axis] = max(eac - sac)
+        # Pre-allocate buffers and buffer type
+        recv_buffer = np.empty(buffer_shape, dtype=self.dtype)
+        send_buffer = np.empty_like(recv_buffer)
+        buf_type = self._prep_buf(send_buffer)[1]
+
+        # Empty slices for target, send buf, recv buf
+        targetsl = [slice(None)] * len(buffer_shape)
+        sendsl = [slice(None)] * len(buffer_shape)
+        recvsl = [slice(None)] * len(buffer_shape)
+        # Send and recv buf have some fixed axis slices per rank
+        sendsl[self.axis] = slice(ear[crank] - sar[crank])
+        recvsl[axis] = slice(eac[crank] - sac[crank])
+
+        mpistatus = mpiutil.MPI.Status()
+
+        # Cyclically pass messages forward to the i-th adjacent rank
+        for i in range(csize):
+            send_to = (crank + i) % csize
+            recv_from = (crank - i) % csize
+
+            # Write send data into send buffer location
+            sendsl[axis] = slice(eac[send_to] - sac[send_to])
+            send_buffer[tuple(sendsl)] = blocks[send_to]
+
+            self.comm.Sendrecv(
+                sendbuf=(send_buffer, buf_type),
+                dest=send_to,
+                sendtag=(csize * crank + send_to),
+                recvbuf=(recv_buffer, buf_type),
+                source=recv_from,
+                recvtag=(csize * recv_from + crank),
+                status=mpistatus,
+            )
 
-                if mpistatus.error != mpiutil.MPI.SUCCESS:
-                    logger.error(
-                        f"**** ERROR in MPI SEND/RECV "
-                        f"(rank={crank}, "
-                        f"target={send_to}, "
-                        f"receive={recv_from}) ****"
-                    )
+            if mpistatus.error != mpiutil.MPI.SUCCESS:
+                logger.error(
+                    f"**** ERROR in MPI SEND/RECV "
+                    f"(rank={crank}, "
+                    f"target={send_to}, "
+                    f"receive={recv_from}) ****"
+                )
 
-                # Write buffer into target location
-                targetsl[self.axis] = slice(sar[recv_from], ear[recv_from])
-                recvsl[self.axis] = slice(ear[recv_from] - sar[recv_from])
+            # Write buffer into target location
+            targetsl[self.axis] = slice(sar[recv_from], ear[recv_from])
+            recvsl[self.axis] = slice(ear[recv_from] - sar[recv_from])
 
-                target_arr[tuple(targetsl)] = recv_buffer[tuple(recvsl)]
+            target_arr[tuple(targetsl)] = recv_buffer[tuple(recvsl)]
 
         return dist_arr
 
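The rewritten hunk keeps the same cyclic Sendrecv exchange as before, just one indentation level shallower now that the single-rank case returns early. For reference, a minimal standalone sketch of that ring pattern with plain mpi4py follows; the buffer size and payload values are toy choices for illustration, not the library's own code.

import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
csize = comm.size
crank = comm.rank

# Toy fixed-size buffers standing in for the pre-allocated send/recv buffers above
nitems = 4
send_buffer = np.empty(nitems, dtype=np.float64)
recv_buffer = np.empty_like(send_buffer)
status = MPI.Status()

# Cyclic pairing: on step i, send to the rank i ahead and receive from the rank i behind
for i in range(csize):
    send_to = (crank + i) % csize
    recv_from = (crank - i) % csize

    # Fill the send buffer with a value identifying the ordered (sender, receiver) pair
    send_buffer[:] = csize * crank + send_to

    # Tags follow the same scheme as in the diff: unique per ordered rank pair
    comm.Sendrecv(
        sendbuf=send_buffer,
        dest=send_to,
        sendtag=(csize * crank + send_to),
        recvbuf=recv_buffer,
        source=recv_from,
        recvtag=(csize * recv_from + crank),
        status=status,
    )

    if status.error != MPI.SUCCESS:
        print(f"Sendrecv failed on rank {crank} (sent to {send_to}, received from {recv_from})")

Run under mpirun with any number of ranks, each rank ends up having exchanged one fixed-size block with every other rank, which is the same all-to-all structure the redistribution loop relies on.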