Skip to content

Commit 09bbdea

Browse files
committed
fix stream dependencies in callbacks
Summary: (1) call `future.wait()` in callbacks to ensure the continuation executes only after the future has completed; (2) set the CUDA stream correctly when executing the callback scheduled by the bucketized allreduce.
1 parent fef4abc commit 09bbdea

File tree

3 files changed

+12
-3
lines changed

3 files changed

+12
-3
lines changed

torchft/collectives.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,8 @@ def callback(fut: Future[list[torch.Tensor]]) -> list[torch.Tensor]:
387387
nonlocal tensors, quantized_tensors, world_size, sync_stream
388388

389389
with torch.cuda.stream(sync_stream):
390+
# Setup stream dependency
391+
fut.wait()
390392
# Dequantize the result back to the original precision
391393
fused_dequantize_from_fp8(tensors, quantized_tensors, world_size)
392394
return tensors

torchft/local_sgd.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -513,9 +513,14 @@ def _bucketize_and_allreduce(
513513
)
514514

515515
def callback(fut: torch.futures.Future[torch.Tensor]) -> None:
516-
nonlocal bucket_tensors, flat_buffer
517-
for t, pack_offset, numel in bucket_tensors:
518-
t.copy_(flat_buffer[pack_offset : pack_offset + numel].view_as(t))
516+
with torch.cuda.stream(self._stream) if self._stream else nullcontext():
517+
nonlocal bucket_tensors, flat_buffer
518+
# Setup stream dependency
519+
fut.wait()
520+
for t, pack_offset, numel in bucket_tensors:
521+
t.copy_(
522+
flat_buffer[pack_offset : pack_offset + numel].view_as(t)
523+
)
519524

520525
work = work.then(callback)
521526
self._allreduce_futures.append(work)

torchft/manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,8 @@ def callback(
403403
# change the stream to avoid making the callback stream
404404
# dependent on process group stream running the allreduce
405405
with torch.cuda.stream(stream) if stream is not None else nullcontext():
406+
# Setup stream dependency
407+
fut.wait()
406408
fut.value()
407409
tensor /= num_participants
408410

0 commit comments

Comments (0)