18
18
import torch
19
19
import torch .distributed as dist
20
20
from torch import nn , optim
21
+ from torch .distributed .distributed_c10d import Work
21
22
from torch .distributed .tensor import DTensor
22
23
from torch .nn .parameter import Parameter
23
24
from torch .optim .optimizer import Optimizer
@@ -211,7 +212,7 @@ def __init__(
211
212
self ._outer_optimizer = outer_optimizer
212
213
213
214
# Stores pending all reduce
214
- self ._allreduce_futures : list [torch . futures . Future [ torch . Tensor ] ] = []
215
+ self ._allreduce_work : list [Work ] = []
215
216
self ._stream : Optional [torch .cuda .Stream ] = (
216
217
torch .cuda .Stream () if torch .cuda .is_available () else None
217
218
)
@@ -379,15 +380,15 @@ def wait(self) -> None:
379
380
"""
380
381
Waits for the previously scheduled allreduce to finish
381
382
"""
382
- if len (self ._allreduce_futures ) == 0 :
383
+ if len (self ._allreduce_work ) == 0 :
383
384
return
384
385
385
386
if self ._stream is not None :
386
387
assert self ._stop_event is not None
387
388
self ._stop_event .synchronize ()
388
389
self ._stop_event = None
389
390
390
- self ._allreduce_futures = []
391
+ self ._allreduce_work = []
391
392
392
393
@torch .profiler .record_function ("torchft::local_sgd::prepare_sync" )
393
394
def prepare_sync (self ) -> None :
@@ -397,7 +398,7 @@ def prepare_sync(self) -> None:
397
398
"""
398
399
self ._save_grads ()
399
400
400
- assert len (self ._allreduce_futures ) == 0
401
+ assert len (self ._allreduce_work ) == 0
401
402
402
403
# Make sure tensors are available to `_stream`
403
404
if self ._stream is not None :
@@ -410,7 +411,7 @@ def prepare_sync(self) -> None:
410
411
):
411
412
self ._average_grads ()
412
413
413
- for work in self ._allreduce_futures :
414
+ for work in self ._allreduce_work :
414
415
work .wait ()
415
416
416
417
if self ._stream is not None :
@@ -424,7 +425,7 @@ def perform_sync(self) -> bool:
424
425
steps using the outer optimizer.
425
426
"""
426
427
# Waiting for an allreduce before it has been sent is currently not supported.
427
- assert len (self ._allreduce_futures ) > 0
428
+ assert len (self ._allreduce_work ) > 0
428
429
429
430
self .wait ()
430
431
@@ -478,7 +479,8 @@ def _allreduce_per_param(self) -> None:
478
479
work = self ._manager .allreduce (
479
480
self ._grads [name ], should_quantize = self .should_quantize
480
481
)
481
- self ._allreduce_futures .append (work )
482
+
483
+ self ._allreduce_work .append (work )
482
484
483
485
def _bucketize_and_allreduce (
484
486
self ,
@@ -533,8 +535,10 @@ def callback(fut: torch.futures.Future[torch.Tensor]) -> None:
533
535
flat_buffer [pack_offset : pack_offset + numel ].view_as (t )
534
536
)
535
537
536
- work = work .then (callback )
537
- self ._allreduce_futures .append (work )
538
+ fut = work .get_future ()
539
+ fut = fut .then (callback )
540
+
541
+ self ._allreduce_work .append (work )
538
542
539
543
offset += chunk_size
540
544
0 commit comments