Skip to content

Commit 43dd4fb

Browse files
committed
fix stream dependencies in callbacks
Summary:
- Call `future.wait()` in callbacks to ensure the continuation executes only after the future has completed.
- Set the stream correctly when executing the callback scheduled by the bucketized allreduce.
1 parent f99b8dd commit 43dd4fb

File tree

3 files changed

+12
-3
lines changed

3 files changed

+12
-3
lines changed

torchft/collectives.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,8 @@ def callback(fut: Future[list[torch.Tensor]]) -> list[torch.Tensor]:
387387
nonlocal tensors, quantized_tensors, world_size, sync_stream
388388

389389
with torch.cuda.stream(sync_stream):
390+
# Setup stream dependency
391+
fut.wait()
390392
# Dequantize the result back to the original precision
391393
fused_dequantize_from_fp8(tensors, quantized_tensors, world_size)
392394
return tensors

torchft/local_sgd.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -524,9 +524,14 @@ def _bucketize_and_allreduce(
524524
)
525525

526526
def callback(fut: torch.futures.Future[torch.Tensor]) -> None:
527-
nonlocal bucket_tensors, flat_buffer
528-
for t, pack_offset, numel in bucket_tensors:
529-
t.copy_(flat_buffer[pack_offset : pack_offset + numel].view_as(t))
527+
with torch.cuda.stream(self._stream) if self._stream else nullcontext():
528+
nonlocal bucket_tensors, flat_buffer
529+
# Setup stream dependency
530+
fut.wait()
531+
for t, pack_offset, numel in bucket_tensors:
532+
t.copy_(
533+
flat_buffer[pack_offset : pack_offset + numel].view_as(t)
534+
)
530535

531536
work = work.then(callback)
532537
self._allreduce_futures.append(work)

torchft/manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,8 @@ def callback(
430430
# change the stream to avoid making the callback stream
431431
# dependent on process group stream running the allreduce
432432
with torch.cuda.stream(stream) if stream is not None else nullcontext():
433+
# Setup stream dependency
434+
fut.wait()
433435
fut.value()
434436
tensor /= num_participants
435437

0 commit comments

Comments (0)