Skip to content

Commit 58e285f

perf(mpiarray): avoid redistribute copy on single process
1 parent ea85ecb commit 58e285f

1 file changed, +56 -63 lines

caput/mpiarray.py

@@ -739,6 +739,9 @@ def redistribute(self, axis: int) -> "MPIArray":
         if self.axis == axis or self.comm is None:
             return self
 
+        if self.comm.size == 1:
+            return MPIArray.wrap(self.local_array, axis, self.comm)
+
         # Check to make sure there is enough memory to perform the redistribution.
         # Must be able to allocate the target array and 2 buffers. We allocate
         # slightly more space than needed to be safe
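
The new early exit means a size-1 communicator no longer allocates a target array and copies into it. A minimal sketch of the intended behaviour (assuming mpi4py and caput are installed, and that MPIArray.wrap returns a view of its input rather than a copy):

import numpy as np
from mpi4py import MPI
from caput.mpiarray import MPIArray

# Single-process communicator, so comm.size == 1 and the new fast path applies.
comm = MPI.COMM_SELF

# Wrap a local ndarray as an MPIArray distributed over axis 0.
arr = MPIArray.wrap(np.arange(12.0).reshape(3, 4), 0, comm)

# With this change, redistributing over another axis re-wraps the existing
# local buffer instead of allocating buffers and copying.
redist = arr.redistribute(1)

# Expected to print True if wrap() views the same memory (assumption).
print(np.shares_memory(arr.local_array, redist.local_array))
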
@@ -761,72 +764,62 @@ def redistribute(self, axis: int) -> "MPIArray":
         csize = self.comm.size
         crank = self.comm.rank
 
-        if csize == 1:
-            if arr.shape[self.axis] == self.global_shape[self.axis]:
-                # We are working on a single node.
-                target_arr[:] = arr
-            else:
-                raise ValueError(
-                    f"Global shape {self.global_shape} is incompatible with local "
-                    f"array shape {self.shape}"
-                )
-        else:
-            # Get the start and end of each subrange of interest
-            _, sac, eac = mpiutil.split_all(self.global_shape[axis], self.comm)
-            _, sar, ear = mpiutil.split_all(self.global_shape[self.axis], self.comm)
-            # Split the source array into properly sized blocks for sending
-            blocks = np.array_split(arr, np.insert(eac, 0, sac[0]), axis)[1:]
-            # Create fixed-size contiguous buffers for sending and receiving
-            buffer_shape = list(target_arr.shape)
-            buffer_shape[self.axis] = max(ear - sar)
-            buffer_shape[axis] = max(eac - sac)
-            # Pre-allocate buffers and buffer type
-            recv_buffer = np.empty(buffer_shape, dtype=self.dtype)
-            send_buffer = np.empty_like(recv_buffer)
-            buf_type = self._prep_buf(send_buffer)[1]
-
-            # Empty slices for target, send buf, recv buf
-            targetsl = [slice(None)] * len(buffer_shape)
-            sendsl = [slice(None)] * len(buffer_shape)
-            recvsl = [slice(None)] * len(buffer_shape)
-            # Send and recv buf have some fixed axis slices per rank
-            sendsl[self.axis] = slice(ear[crank] - sar[crank])
-            recvsl[axis] = slice(eac[crank] - sac[crank])
-
-            mpistatus = mpiutil.MPI.Status()
-
-            # Cyclically pass messages forward to i adjacent rank
-            for i in range(csize):
-                send_to = (crank + i) % csize
-                recv_from = (crank - i) % csize
-
-                # Write send data into send buffer location
-                sendsl[axis] = slice(eac[send_to] - sac[send_to])
-                send_buffer[tuple(sendsl)] = blocks[send_to]
-
-                self.comm.Sendrecv(
-                    sendbuf=(send_buffer, buf_type),
-                    dest=send_to,
-                    sendtag=(csize * crank + send_to),
-                    recvbuf=(recv_buffer, buf_type),
-                    source=recv_from,
-                    recvtag=(csize * recv_from + crank),
-                    status=mpistatus,
-                )
+        # Get the start and end of each subrange of interest
+        _, sac, eac = mpiutil.split_all(self.global_shape[axis], self.comm)
+        _, sar, ear = mpiutil.split_all(self.global_shape[self.axis], self.comm)
+        # Split the source array into properly sized blocks for sending
+        blocks = np.array_split(arr, np.insert(eac, 0, sac[0]), axis)[1:]
+        # Create fixed-size contiguous buffers for sending and receiving
+        buffer_shape = list(target_arr.shape)
+        buffer_shape[self.axis] = max(ear - sar)
+        buffer_shape[axis] = max(eac - sac)
+        # Pre-allocate buffers and buffer type
+        recv_buffer = np.empty(buffer_shape, dtype=self.dtype)
+        send_buffer = np.empty_like(recv_buffer)
+        buf_type = self._prep_buf(send_buffer)[1]
+
+        # Empty slices for target, send buf, recv buf
+        targetsl = [slice(None)] * len(buffer_shape)
+        sendsl = [slice(None)] * len(buffer_shape)
+        recvsl = [slice(None)] * len(buffer_shape)
+        # Send and recv buf have some fixed axis slices per rank
+        sendsl[self.axis] = slice(ear[crank] - sar[crank])
+        recvsl[axis] = slice(eac[crank] - sac[crank])
+
+        mpistatus = mpiutil.MPI.Status()
+
+        # Cyclically pass and receive array chunks across ranks
+        for i in range(csize):
+            send_to = (crank + i) % csize
+            recv_from = (crank - i) % csize
+
+            # Write send data into send buffer location
+            sendsl[axis] = slice(eac[send_to] - sac[send_to])
+            send_buffer[tuple(sendsl)] = blocks[send_to]
+
+            self.comm.Sendrecv(
+                sendbuf=(send_buffer, buf_type),
+                dest=send_to,
+                sendtag=(csize * crank + send_to),
+                recvbuf=(recv_buffer, buf_type),
+                source=recv_from,
+                recvtag=(csize * recv_from + crank),
+                status=mpistatus,
+            )
 
-                if mpistatus.error != mpiutil.MPI.SUCCESS:
-                    logger.error(
-                        f"**** ERROR in MPI SEND/RECV "
-                        f"(rank={crank}, "
-                        f"target={send_to}, "
-                        f"receive={recv_from}) ****"
-                    )
+            if mpistatus.error != mpiutil.MPI.SUCCESS:
+                logger.error(
+                    f"**** ERROR in MPI SEND/RECV "
+                    f"(rank={crank}, "
+                    f"target={send_to}, "
+                    f"receive={recv_from}) ****"
+                )
 
-                # Write buffer into target location
-                targetsl[self.axis] = slice(sar[recv_from], ear[recv_from])
-                recvsl[self.axis] = slice(ear[recv_from] - sar[recv_from])
+            # Write buffer into target location
+            targetsl[self.axis] = slice(sar[recv_from], ear[recv_from])
+            recvsl[self.axis] = slice(ear[recv_from] - sar[recv_from])
 
-                target_arr[tuple(targetsl)] = recv_buffer[tuple(recvsl)]
+            target_arr[tuple(targetsl)] = recv_buffer[tuple(recvsl)]
 
         return dist_arr
 
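
For reference, a standalone sketch (not part of caput; names and buffer sizes are illustrative) of the cyclic Sendrecv schedule the loop above relies on: at step i every rank sends to (rank + i) % size and receives from (rank - i) % size, so each pair of ranks exchanges exactly one message per pass without deadlock.

import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank, size = comm.rank, comm.size

# Fixed-size buffers, mirroring the pre-allocated send/recv buffers above.
send_buffer = np.empty(4, dtype=np.float64)
recv_buffer = np.empty_like(send_buffer)

for i in range(size):
    send_to = (rank + i) % size
    recv_from = (rank - i) % size

    # Stand-in for copying the block destined for send_to into the buffer.
    send_buffer[:] = rank

    comm.Sendrecv(
        sendbuf=send_buffer,
        dest=send_to,
        sendtag=size * rank + send_to,
        recvbuf=recv_buffer,
        source=recv_from,
        recvtag=size * recv_from + rank,
    )

    # recv_buffer now holds recv_from's block; the real code copies it into
    # the matching slice of the pre-allocated target array at this point.
    assert recv_buffer[0] == recv_from

Run under MPI (e.g. mpirun -n 4 python demo.py, with a hypothetical filename); with a single rank it degenerates to a self-exchange, which Sendrecv handles without deadlock.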