
Commit 91207a2

fix compute/communication overlap for gloo
Summary:
- we currently wait for the pg work's future when preparing a fragment
- with gloo, this blocks the CPU
- move the wait call to when we perform the actual sync of the fragment
- since we still call `work.wait()` in the allreduce call itself, this doesn't completely fix the problem
1 parent: 843854d
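
For context, here is a minimal sketch of the compute/communication overlap pattern this commit moves toward, using only the public torch.distributed API with the gloo backend. The tensor names and the standalone setup are illustrative assumptions, not code from torchft:

# Minimal sketch of compute/communication overlap with gloo (illustrative,
# not torchft code). With gloo, waiting on a work's future blocks the CPU,
# so the wait should be deferred until the result is actually needed.
import torch
import torch.distributed as dist


def sketch() -> None:
    # Assumes the usual env vars (MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE).
    dist.init_process_group("gloo")

    grads = torch.randn(1024)
    unrelated = torch.randn(1024)

    # Launch the allreduce asynchronously and do NOT wait on it yet.
    work = dist.all_reduce(grads, op=dist.ReduceOp.SUM, async_op=True)

    # Independent CPU compute can overlap with the gloo communication,
    # because nothing here blocks on the work's future.
    unrelated = unrelated * 2.0

    # Block only once the averaged gradients are actually needed; this is
    # the wait the commit moves from prepare_sync into perform_sync.
    work.get_future().wait()
    grads /= dist.get_world_size()

    dist.destroy_process_group()


if __name__ == "__main__":
    sketch()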

1 file changed: +12 -7 lines

torchft/local_sgd.py

Lines changed: 12 additions & 7 deletions
@@ -401,13 +401,6 @@ def prepare_sync(self) -> None:
         ):
             self._average_grads()

-            for work in self._allreduce_work:
-                work.get_future().wait()
-
-            if self._stream is not None:
-                self._stop_event = torch.cuda.Event()
-                self._stop_event.record()
-
     @torch.profiler.record_function("torchft::local_sgd::perform_sync")
     def perform_sync(self) -> bool:
         """
@@ -417,6 +410,18 @@ def perform_sync(self) -> bool:
         # Waiting for an allreduce before it has been sent is currently not supported.
         assert len(self._allreduce_work) > 0

+        with (
+            torch.cuda.stream(self._stream)
+            if self._stream is not None
+            else nullcontext()
+        ):
+            for work in self._allreduce_work:
+                work.get_future().wait()
+
+            if self._stream is not None:
+                self._stop_event = torch.cuda.Event()
+                self._stop_event.record()
+
         self.wait()

         # save the parameters so they can be used for merging
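
The deferred wait is wrapped in a conditional stream context so the same code path works with and without a CUDA side stream. A brief illustrative sketch of that idiom on its own (the variable names here are assumptions, not torchft's):

# Illustrative sketch of the conditional stream-context idiom added above:
# wait on the futures inside a side CUDA stream when one exists, otherwise
# fall back to plain CPU execution via contextlib.nullcontext.
from contextlib import nullcontext

import torch

stream = torch.cuda.Stream() if torch.cuda.is_available() else None

with torch.cuda.stream(stream) if stream is not None else nullcontext():
    # ... wait on the collected allreduce futures here, as perform_sync now does ...
    if stream is not None:
        # Record an event on the side stream so later code can synchronize
        # with the completed waits without blocking the default stream.
        stop_event = torch.cuda.Event()
        stop_event.record()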
