
Commit 6c2ba2b

perf(mpiarray): avoid redistribute copy with single process
1 parent f2e6fa2 commit 6c2ba2b

1 file changed: +60 -66 lines changed

caput/mpiarray.py

@@ -735,6 +735,13 @@ def redistribute(self, axis: int) -> "MPIArray":
         if self.axis == axis or self.comm is None:
             return self

+        # Avoid repeat mpi property calls
+        csize = self.comm.size
+        crank = self.comm.rank
+
+        if csize == 1:
+            return MPIArray.wrap(self.local_array, axis, self.comm)
+
         # Check to make sure there is enough memory to perform the redistribution.
         # Must be able to allocate the target array and 2 buffers. We allocate
         # slightly more space than needed to be safe
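The early return added above is what the commit title refers to: with a single-rank communicator the full array already lives on this process, so MPIArray.wrap simply re-labels the distributed axis of the existing local buffer instead of allocating a target array and copying into it (the old single-process branch is removed in the next hunk). Below is a minimal sketch of the intended effect, not part of the diff; the shape, dtype, and script name are arbitrary, and it assumes a single MPI process (e.g. mpirun -n 1 python sketch.py).

# Sketch only: redistributing on one rank should no longer copy the data.
import numpy as np

from caput.mpiarray import MPIArray

arr = MPIArray((16, 8), axis=0, dtype=np.float64)
arr[:] = 1.0

# On one rank, changing the distributed axis requires no data movement.
redist = arr.redistribute(axis=1)

# If wrap() returns a view of the local buffer, as this change relies on,
# the two arrays share memory; previously a fresh target array was filled.
print(np.shares_memory(arr.local_array, redist.local_array))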
@@ -753,76 +760,63 @@ def redistribute(self, axis: int) -> "MPIArray":
         # Get views into local and target arrays
         arr = self.local_array
         target_arr = dist_arr.local_array
-        # Avoid repeat mpi property calls
-        csize = self.comm.size
-        crank = self.comm.rank

-        if csize == 1:
-            if arr.shape[self.axis] == self.global_shape[self.axis]:
-                # We are working on a single node.
-                target_arr[:] = arr
-            else:
-                raise ValueError(
-                    f"Global shape {self.global_shape} is incompatible with local "
-                    f"array shape {self.shape}"
-                )
-        else:
-            # Get the start and end of each subrange of interest
-            _, sac, eac = mpiutil.split_all(self.global_shape[axis], self.comm)
-            _, sar, ear = mpiutil.split_all(self.global_shape[self.axis], self.comm)
-            # Split the soruce array into properly sized blocks for sending
-            blocks = np.array_split(arr, np.insert(eac, 0, sac[0]), axis)[1:]
-            # Create fixed-size contiguous buffers for sending and receiving
-            buffer_shape = list(target_arr.shape)
-            buffer_shape[self.axis] = max(ear - sar)
-            buffer_shape[axis] = max(eac - sac)
-            # Pre-allocate buffers and buffer type
-            recv_buffer = np.empty(buffer_shape, dtype=self.dtype)
-            send_buffer = np.empty_like(recv_buffer)
-            buf_type = self._prep_buf(send_buffer)[1]
-
-            # Empty slices for target, send buf, recv buf
-            targetsl = [slice(None)] * len(buffer_shape)
-            sendsl = [slice(None)] * len(buffer_shape)
-            recvsl = [slice(None)] * len(buffer_shape)
-            # Send and recv buf have some fixed axis slices per rank
-            sendsl[self.axis] = slice(ear[crank] - sar[crank])
-            recvsl[axis] = slice(eac[crank] - sac[crank])
-
-            mpistatus = mpiutil.MPI.Status()
-
-            # Cyclically pass messages forward to i adjacent rank
-            for i in range(csize):
-                send_to = (crank + i) % csize
-                recv_from = (crank - i) % csize
-
-                # Write send data into send buffer location
-                sendsl[axis] = slice(eac[send_to] - sac[send_to])
-                send_buffer[tuple(sendsl)] = blocks[send_to]
-
-                self.comm.Sendrecv(
-                    sendbuf=(send_buffer, buf_type),
-                    dest=send_to,
-                    sendtag=(csize * crank + send_to),
-                    recvbuf=(recv_buffer, buf_type),
-                    source=recv_from,
-                    recvtag=(csize * recv_from + crank),
-                    status=mpistatus,
-                )
+        # Get the start and end of each subrange of interest
+        _, sac, eac = mpiutil.split_all(self.global_shape[axis], self.comm)
+        _, sar, ear = mpiutil.split_all(self.global_shape[self.axis], self.comm)
+        # Split the soruce array into properly sized blocks for sending
+        blocks = np.array_split(arr, np.insert(eac, 0, sac[0]), axis)[1:]
+        # Create fixed-size contiguous buffers for sending and receiving
+        buffer_shape = list(target_arr.shape)
+        buffer_shape[self.axis] = max(ear - sar)
+        buffer_shape[axis] = max(eac - sac)
+        # Pre-allocate buffers and buffer type
+        recv_buffer = np.empty(buffer_shape, dtype=self.dtype)
+        send_buffer = np.empty_like(recv_buffer)
+        buf_type = self._prep_buf(send_buffer)[1]
+
+        # Empty slices for target, send buf, recv buf
+        targetsl = [slice(None)] * len(buffer_shape)
+        sendsl = [slice(None)] * len(buffer_shape)
+        recvsl = [slice(None)] * len(buffer_shape)
+        # Send and recv buf have some fixed axis slices per rank
+        sendsl[self.axis] = slice(ear[crank] - sar[crank])
+        recvsl[axis] = slice(eac[crank] - sac[crank])
+
+        mpistatus = mpiutil.MPI.Status()
+
+        # Cyclically pass and receive array chunks across ranks
+        for i in range(csize):
+            send_to = (crank + i) % csize
+            recv_from = (crank - i) % csize
+
+            # Write send data into send buffer location
+            sendsl[axis] = slice(eac[send_to] - sac[send_to])
+            send_buffer[tuple(sendsl)] = blocks[send_to]
+
+            self.comm.Sendrecv(
+                sendbuf=(send_buffer, buf_type),
+                dest=send_to,
+                sendtag=(csize * crank + send_to),
+                recvbuf=(recv_buffer, buf_type),
+                source=recv_from,
+                recvtag=(csize * recv_from + crank),
+                status=mpistatus,
+            )

-                if mpistatus.error != mpiutil.MPI.SUCCESS:
-                    logger.error(
-                        f"**** ERROR in MPI SEND/RECV "
-                        f"(rank={crank}, "
-                        f"target={send_to}, "
-                        f"receive={recv_from}) ****"
-                    )
+            if mpistatus.error != mpiutil.MPI.SUCCESS:
+                logger.error(
+                    f"**** ERROR in MPI SEND/RECV "
+                    f"(rank={crank}, "
+                    f"target={send_to}, "
+                    f"receive={recv_from}) ****"
+                )

-                # Write buffer into target location
-                targetsl[self.axis] = slice(sar[recv_from], ear[recv_from])
-                recvsl[self.axis] = slice(ear[recv_from] - sar[recv_from])
+            # Write buffer into target location
+            targetsl[self.axis] = slice(sar[recv_from], ear[recv_from])
+            recvsl[self.axis] = slice(ear[recv_from] - sar[recv_from])

-                target_arr[tuple(targetsl)] = recv_buffer[tuple(recvsl)]
+            target_arr[tuple(targetsl)] = recv_buffer[tuple(recvsl)]

         return dist_arr
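For reference, the communication loop kept above (only de-indented by this commit) pairs ranks cyclically: at step i each rank sends the block destined for (crank + i) % csize and receives from (crank - i) % csize, so over csize steps every rank sends exactly one block to each rank (itself included at step 0) and receives one from each. A standalone sketch of just that schedule, pure Python with no MPI, reusing the names from the diff; the communicator size of 4 is arbitrary.

# Sketch of the cyclic exchange schedule used by the Sendrecv loop above.
csize = 4  # pretend communicator size

for crank in range(csize):
    steps = []
    for i in range(csize):
        send_to = (crank + i) % csize
        recv_from = (crank - i) % csize
        # The sender tags its message csize * sender + receiver, which matches
        # the recvtag the receiver computes: csize * recv_from + crank.
        steps.append((i, send_to, recv_from))
    print(f"rank {crank}: (step, send_to, recv_from) = {steps}")

Because every send at a given step has a matching receive posted on its destination rank, the blocking Sendrecv calls complete the whole redistribution in csize steps while only the two fixed-size buffers are live at any time.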
