
Commit def663a

perf(mpiarray): avoid redistribute copy with single process
1 parent f2e6fa2 commit def663a

File tree

1 file changed (+56, -63 lines)

caput/mpiarray.py (+56, -63)
@@ -735,6 +735,9 @@ def redistribute(self, axis: int) -> "MPIArray":
         if self.axis == axis or self.comm is None:
             return self
 
+        if self.comm.size == 1:
+            return MPIArray.wrap(self.local_array, axis, self.comm)
+
         # Check to make sure there is enough memory to perform the redistribution.
         # Must be able to allocate the target array and 2 buffers. We allocate
         # slightly more space than needed to be safe
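The early return above is the whole optimisation: on a single process the local array already holds the full global array, so redistributing to a new axis can simply re-wrap the existing data instead of allocating a target array plus send/receive buffers and copying block by block. A minimal sketch of the expected effect, run under one process; the shape, values, and the shares_memory check below are illustrative, not part of this commit:

# Sketch: single-process redistribute after this commit (illustrative shapes).
# Run with plain `python`, i.e. an MPI world of size 1.
import numpy as np
from mpi4py import MPI
from caput.mpiarray import MPIArray

comm = MPI.COMM_WORLD  # size 1 when launched without mpirun

# With one process the local section is the whole (16, 8) global array.
arr = MPIArray((16, 8), axis=0, dtype=np.float64, comm=comm)
arr.local_array[:] = 1.0

# Previously this allocated a target array and copied into it; now it hits
# the comm.size == 1 early return and re-wraps the local data along axis 1.
redist = arr.redistribute(axis=1)

if comm.size == 1:
    # No copy should have been made: both arrays view the same memory.
    assert np.shares_memory(arr.local_array, redist.local_array)
    assert redist.axis == 1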
@@ -757,72 +760,62 @@ def redistribute(self, axis: int) -> "MPIArray":
         csize = self.comm.size
         crank = self.comm.rank
 
-        if csize == 1:
-            if arr.shape[self.axis] == self.global_shape[self.axis]:
-                # We are working on a single node.
-                target_arr[:] = arr
-            else:
-                raise ValueError(
-                    f"Global shape {self.global_shape} is incompatible with local "
-                    f"array shape {self.shape}"
-                )
-        else:
-            # Get the start and end of each subrange of interest
-            _, sac, eac = mpiutil.split_all(self.global_shape[axis], self.comm)
-            _, sar, ear = mpiutil.split_all(self.global_shape[self.axis], self.comm)
-            # Split the source array into properly sized blocks for sending
-            blocks = np.array_split(arr, np.insert(eac, 0, sac[0]), axis)[1:]
-            # Create fixed-size contiguous buffers for sending and receiving
-            buffer_shape = list(target_arr.shape)
-            buffer_shape[self.axis] = max(ear - sar)
-            buffer_shape[axis] = max(eac - sac)
-            # Pre-allocate buffers and buffer type
-            recv_buffer = np.empty(buffer_shape, dtype=self.dtype)
-            send_buffer = np.empty_like(recv_buffer)
-            buf_type = self._prep_buf(send_buffer)[1]
-
-            # Empty slices for target, send buf, recv buf
-            targetsl = [slice(None)] * len(buffer_shape)
-            sendsl = [slice(None)] * len(buffer_shape)
-            recvsl = [slice(None)] * len(buffer_shape)
-            # Send and recv buf have some fixed axis slices per rank
-            sendsl[self.axis] = slice(ear[crank] - sar[crank])
-            recvsl[axis] = slice(eac[crank] - sac[crank])
-
-            mpistatus = mpiutil.MPI.Status()
-
-            # Cyclically pass messages forward to i adjacent rank
-            for i in range(csize):
-                send_to = (crank + i) % csize
-                recv_from = (crank - i) % csize
-
-                # Write send data into send buffer location
-                sendsl[axis] = slice(eac[send_to] - sac[send_to])
-                send_buffer[tuple(sendsl)] = blocks[send_to]
-
-                self.comm.Sendrecv(
-                    sendbuf=(send_buffer, buf_type),
-                    dest=send_to,
-                    sendtag=(csize * crank + send_to),
-                    recvbuf=(recv_buffer, buf_type),
-                    source=recv_from,
-                    recvtag=(csize * recv_from + crank),
-                    status=mpistatus,
-                )
+        # Get the start and end of each subrange of interest
+        _, sac, eac = mpiutil.split_all(self.global_shape[axis], self.comm)
+        _, sar, ear = mpiutil.split_all(self.global_shape[self.axis], self.comm)
+        # Split the source array into properly sized blocks for sending
+        blocks = np.array_split(arr, np.insert(eac, 0, sac[0]), axis)[1:]
+        # Create fixed-size contiguous buffers for sending and receiving
+        buffer_shape = list(target_arr.shape)
+        buffer_shape[self.axis] = max(ear - sar)
+        buffer_shape[axis] = max(eac - sac)
+        # Pre-allocate buffers and buffer type
+        recv_buffer = np.empty(buffer_shape, dtype=self.dtype)
+        send_buffer = np.empty_like(recv_buffer)
+        buf_type = self._prep_buf(send_buffer)[1]
+
+        # Empty slices for target, send buf, recv buf
+        targetsl = [slice(None)] * len(buffer_shape)
+        sendsl = [slice(None)] * len(buffer_shape)
+        recvsl = [slice(None)] * len(buffer_shape)
+        # Send and recv buf have some fixed axis slices per rank
+        sendsl[self.axis] = slice(ear[crank] - sar[crank])
+        recvsl[axis] = slice(eac[crank] - sac[crank])
+
+        mpistatus = mpiutil.MPI.Status()
+
+        # Cyclically pass and receive array chunks across ranks
+        for i in range(csize):
+            send_to = (crank + i) % csize
+            recv_from = (crank - i) % csize
+
+            # Write send data into send buffer location
+            sendsl[axis] = slice(eac[send_to] - sac[send_to])
+            send_buffer[tuple(sendsl)] = blocks[send_to]
+
+            self.comm.Sendrecv(
+                sendbuf=(send_buffer, buf_type),
+                dest=send_to,
+                sendtag=(csize * crank + send_to),
+                recvbuf=(recv_buffer, buf_type),
+                source=recv_from,
+                recvtag=(csize * recv_from + crank),
+                status=mpistatus,
+            )
 
-            if mpistatus.error != mpiutil.MPI.SUCCESS:
-                logger.error(
-                    f"**** ERROR in MPI SEND/RECV "
-                    f"(rank={crank}, "
-                    f"target={send_to}, "
-                    f"receive={recv_from}) ****"
-                )
+            if mpistatus.error != mpiutil.MPI.SUCCESS:
+                logger.error(
+                    f"**** ERROR in MPI SEND/RECV "
+                    f"(rank={crank}, "
+                    f"target={send_to}, "
+                    f"receive={recv_from}) ****"
+                )
 
-                # Write buffer into target location
-                targetsl[self.axis] = slice(sar[recv_from], ear[recv_from])
-                recvsl[self.axis] = slice(ear[recv_from] - sar[recv_from])
+            # Write buffer into target location
+            targetsl[self.axis] = slice(sar[recv_from], ear[recv_from])
+            recvsl[self.axis] = slice(ear[recv_from] - sar[recv_from])
 
-                target_arr[tuple(targetsl)] = recv_buffer[tuple(recvsl)]
+            target_arr[tuple(targetsl)] = recv_buffer[tuple(recvsl)]
 
         return dist_arr
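For context on the code that was re-indented rather than changed: the retained loop is a cyclic pairwise exchange. At step i each rank sends the block destined for rank (rank + i) % size while receiving from rank (rank - i) % size, so every ordered pair of ranks communicates exactly once and the paired Sendrecv call cannot deadlock. Below is a self-contained sketch of the same pattern with toy payloads; the 4-element blocks and the file name are hypothetical stand-ins for the MPIArray buffers, not code from this commit:

# ring_exchange.py -- run with e.g. `mpirun -n 4 python ring_exchange.py`
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
size = comm.size
rank = comm.rank

# One fixed-size block per destination rank, standing in for the
# np.array_split blocks used in MPIArray.redistribute.
blocks = [np.full(4, rank * size + dest, dtype=np.float64) for dest in range(size)]
recv_buffer = np.empty(4, dtype=np.float64)
received = [None] * size

for i in range(size):
    send_to = (rank + i) % size
    recv_from = (rank - i) % size

    # Send the block owed to `send_to` while receiving the block that
    # `recv_from` owes us; the combined call avoids send/recv deadlock.
    comm.Sendrecv(
        sendbuf=blocks[send_to],
        dest=send_to,
        sendtag=size * rank + send_to,
        recvbuf=recv_buffer,
        source=recv_from,
        recvtag=size * recv_from + rank,
    )
    received[recv_from] = recv_buffer.copy()

# Every rank now holds the block each rank (itself included) prepared for it.
assert all(chunk is not None for chunk in received)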
