Skip to content

Commit 006d604

Browse files
committed
fix compute/communication overlap for gloo
Summary: - we currently wait for the pg work's future when preparing for a fragment - if we use gloo, this blocks the cpu - move the wait call to when we perform the actual sync of the fragment
1 parent b16e0cb commit 006d604

File tree

1 file changed

+12
-7
lines changed

1 file changed

+12
-7
lines changed

torchft/local_sgd.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -411,13 +411,6 @@ def prepare_sync(self) -> None:
411411
):
412412
self._average_grads()
413413

414-
for work in self._allreduce_work:
415-
work.wait()
416-
417-
if self._stream is not None:
418-
self._stop_event = torch.cuda.Event()
419-
self._stop_event.record()
420-
421414
@torch.profiler.record_function("torchft::local_sgd::perform_sync")
422415
def perform_sync(self) -> bool:
423416
"""
@@ -427,6 +420,18 @@ def perform_sync(self) -> bool:
427420
# Waiting for an allreduce before it has been sent is currently not supported.
428421
assert len(self._allreduce_work) > 0
429422

423+
with (
424+
torch.cuda.stream(self._stream)
425+
if self._stream is not None
426+
else nullcontext()
427+
):
428+
for work in self._allreduce_work:
429+
work.wait()
430+
431+
if self._stream is not None:
432+
self._stop_event = torch.cuda.Event()
433+
self._stop_event.record()
434+
430435
self.wait()
431436

432437
# save the parameters so they can be used for merging

0 commit comments

Comments
 (0)