
Commit a2f8ba9

perf(mpiarray): avoid redistribute copy on single process
1 parent ea85ecb commit a2f8ba9
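
With one MPI process the full array already lives on the single rank, so changing the distributed axis needs no data movement; the new fast path just relabels the axis and returns the same object instead of allocating a target array and copying into it. A minimal toy sketch of that idea (the class below is a hypothetical stand-in, not caput's MPIArray):

# Toy sketch only -- "TinyDistArray" is a hypothetical stand-in, not caput's MPIArray.
import numpy as np

class TinyDistArray:
    """A local numpy block plus a label recording which axis is distributed."""

    def __init__(self, local, axis, comm_size=1):
        self.local = local
        self.axis = axis
        self.comm_size = comm_size

    def redistribute(self, axis):
        # With a single process the local block already holds the whole global
        # array, so switching the distributed axis is pure relabelling:
        # no new allocation, no copy.
        if self.comm_size == 1:
            self.axis = axis
            return self
        raise NotImplementedError("multiple ranks need the cyclic exchange shown in the diff")

arr = TinyDistArray(np.arange(12).reshape(3, 4), axis=0)
out = arr.redistribute(axis=1)
assert out is arr  # same object, same buffer: nothing was copied

This mirrors the added `if self.comm.size == 1` branch below, which sets `self._axis` and returns `self`.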

File tree: 1 file changed, +58 -63 lines


caput/mpiarray.py (+58 -63)
@@ -739,6 +739,10 @@ def redistribute(self, axis: int) -> "MPIArray":
         if self.axis == axis or self.comm is None:
             return self
 
+        if self.comm.size == 1:
+            self._axis = axis
+            return self
+
         # Check to make sure there is enough memory to perform the redistribution.
         # Must be able to allocate the target array and 2 buffers. We allocate
         # slightly more space than needed to be safe
@@ -757,76 +761,67 @@ def redistribute(self, axis: int) -> "MPIArray":
         # Get views into local and target arrays
         arr = self.local_array
         target_arr = dist_arr.local_array
+
         # Avoid repeat mpi property calls
         csize = self.comm.size
         crank = self.comm.rank
 
-        if csize == 1:
-            if arr.shape[self.axis] == self.global_shape[self.axis]:
-                # We are working on a single node.
-                target_arr[:] = arr
-            else:
-                raise ValueError(
-                    f"Global shape {self.global_shape} is incompatible with local "
-                    f"array shape {self.shape}"
-                )
-        else:
-            # Get the start and end of each subrange of interest
-            _, sac, eac = mpiutil.split_all(self.global_shape[axis], self.comm)
-            _, sar, ear = mpiutil.split_all(self.global_shape[self.axis], self.comm)
-            # Split the soruce array into properly sized blocks for sending
-            blocks = np.array_split(arr, np.insert(eac, 0, sac[0]), axis)[1:]
-            # Create fixed-size contiguous buffers for sending and receiving
-            buffer_shape = list(target_arr.shape)
-            buffer_shape[self.axis] = max(ear - sar)
-            buffer_shape[axis] = max(eac - sac)
-            # Pre-allocate buffers and buffer type
-            recv_buffer = np.empty(buffer_shape, dtype=self.dtype)
-            send_buffer = np.empty_like(recv_buffer)
-            buf_type = self._prep_buf(send_buffer)[1]
-
-            # Empty slices for target, send buf, recv buf
-            targetsl = [slice(None)] * len(buffer_shape)
-            sendsl = [slice(None)] * len(buffer_shape)
-            recvsl = [slice(None)] * len(buffer_shape)
-            # Send and recv buf have some fixed axis slices per rank
-            sendsl[self.axis] = slice(ear[crank] - sar[crank])
-            recvsl[axis] = slice(eac[crank] - sac[crank])
-
-            mpistatus = mpiutil.MPI.Status()
-
-            # Cyclically pass messages forward to i adjacent rank
-            for i in range(csize):
-                send_to = (crank + i) % csize
-                recv_from = (crank - i) % csize
-
-                # Write send data into send buffer location
-                sendsl[axis] = slice(eac[send_to] - sac[send_to])
-                send_buffer[tuple(sendsl)] = blocks[send_to]
-
-                self.comm.Sendrecv(
-                    sendbuf=(send_buffer, buf_type),
-                    dest=send_to,
-                    sendtag=(csize * crank + send_to),
-                    recvbuf=(recv_buffer, buf_type),
-                    source=recv_from,
-                    recvtag=(csize * recv_from + crank),
-                    status=mpistatus,
-                )
+        # Get the start and end of each subrange of interest
+        _, sac, eac = mpiutil.split_all(self.global_shape[axis], self.comm)
+        _, sar, ear = mpiutil.split_all(self.global_shape[self.axis], self.comm)
+        # Split the soruce array into properly sized blocks for sending
+        blocks = np.array_split(arr, np.insert(eac, 0, sac[0]), axis)[1:]
+        # Create fixed-size contiguous buffers for sending and receiving
+        buffer_shape = list(target_arr.shape)
+        buffer_shape[self.axis] = max(ear - sar)
+        buffer_shape[axis] = max(eac - sac)
+        # Pre-allocate buffers and buffer type
+        recv_buffer = np.empty(buffer_shape, dtype=self.dtype)
+        send_buffer = np.empty_like(recv_buffer)
+        buf_type = self._prep_buf(send_buffer)[1]
+
+        # Empty slices for target, send buf, recv buf
+        targetsl = [slice(None)] * len(buffer_shape)
+        sendsl = [slice(None)] * len(buffer_shape)
+        recvsl = [slice(None)] * len(buffer_shape)
+        # Send and recv buf have some fixed axis slices per rank
+        sendsl[self.axis] = slice(ear[crank] - sar[crank])
+        recvsl[axis] = slice(eac[crank] - sac[crank])
+
+        mpistatus = mpiutil.MPI.Status()
+
+        # Cyclically pass messages forward to i adjacent rank
+        for i in range(csize):
+            send_to = (crank + i) % csize
+            recv_from = (crank - i) % csize
+
+            # Write send data into send buffer location
+            sendsl[axis] = slice(eac[send_to] - sac[send_to])
+            send_buffer[tuple(sendsl)] = blocks[send_to]
+
+            self.comm.Sendrecv(
+                sendbuf=(send_buffer, buf_type),
+                dest=send_to,
+                sendtag=(csize * crank + send_to),
+                recvbuf=(recv_buffer, buf_type),
+                source=recv_from,
+                recvtag=(csize * recv_from + crank),
+                status=mpistatus,
+            )
 
-                if mpistatus.error != mpiutil.MPI.SUCCESS:
-                    logger.error(
-                        f"**** ERROR in MPI SEND/RECV "
-                        f"(rank={crank}, "
-                        f"target={send_to}, "
-                        f"receive={recv_from}) ****"
-                    )
+            if mpistatus.error != mpiutil.MPI.SUCCESS:
+                logger.error(
+                    f"**** ERROR in MPI SEND/RECV "
+                    f"(rank={crank}, "
+                    f"target={send_to}, "
+                    f"receive={recv_from}) ****"
+                )
 
-                # Write buffer into target location
-                targetsl[self.axis] = slice(sar[recv_from], ear[recv_from])
-                recvsl[self.axis] = slice(ear[recv_from] - sar[recv_from])
+            # Write buffer into target location
+            targetsl[self.axis] = slice(sar[recv_from], ear[recv_from])
+            recvsl[self.axis] = slice(ear[recv_from] - sar[recv_from])
 
-                target_arr[tuple(targetsl)] = recv_buffer[tuple(recvsl)]
+            target_arr[tuple(targetsl)] = recv_buffer[tuple(recvsl)]
 
         return dist_arr
 
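For reference, the multi-process path kept by this diff is a cyclic pairwise exchange: on step i each rank sends its block to (rank + i) % size and receives from (rank - i) % size, so after size steps every rank has exchanged exactly one block with every other rank (including itself on step 0). Below is a standalone sketch of that pattern using mpi4py with a one-element payload per rank; it redistributes a (size x size) matrix from rows to columns and omits caput's buffer sizing, derived datatypes, and block splitting.

# Standalone sketch of the cyclic Sendrecv pattern (assumes mpi4py and numpy).
from mpi4py import MPI
import numpy as np

comm = MPI.COMM_WORLD
rank, size = comm.rank, comm.size

# Global matrix M[r, c] = size * r + c; rank r starts with row r (distributed
# over axis 0) and should finish with column r (distributed over axis 1).
row = np.arange(size, dtype=np.float64) + size * rank
col = np.empty(size, dtype=np.float64)

send_buf = np.empty(1, dtype=np.float64)
recv_buf = np.empty(1, dtype=np.float64)
status = MPI.Status()

for i in range(size):
    send_to = (rank + i) % size      # rank we forward to on this step
    recv_from = (rank - i) % size    # rank we receive from on this step

    # The element of our row destined for `send_to` (it owns column `send_to`).
    send_buf[0] = row[send_to]

    # Matched send/receive, tagged the same way as in MPIArray.redistribute.
    comm.Sendrecv(
        sendbuf=send_buf,
        dest=send_to,
        sendtag=size * rank + send_to,
        recvbuf=recv_buf,
        source=recv_from,
        recvtag=size * recv_from + rank,
        status=status,
    )

    # The received element is row `recv_from` of the column we are assembling.
    col[recv_from] = recv_buf[0]

assert np.array_equal(col, size * np.arange(size) + rank)  # column `rank` of M

Run with, e.g., mpiexec -n 4 python on the script above; the Sendrecv pairing avoids deadlock because every send on step i has a matching receive posted on the partner rank in the same step.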