 import torch
 import torch.distributed as dist
 from torch import nn, optim
+from torch.distributed.distributed_c10d import Work
 from torch.distributed.tensor import DTensor
 from torch.nn.parameter import Parameter
 from torch.optim.optimizer import Optimizer
@@ -154,7 +155,8 @@ def _average(self) -> list[torch.Tensor]:
         for p in self._model.parameters():
             # Create a new tensor to store the averaged parameter
             avg_param = extract_local_tensor(p)
-            works.append(self._manager.allreduce(avg_param))
+            work = self._manager.allreduce(avg_param)
+            works.append(work)
             averaged_parameters.append(avg_param)
         for work in works:
             work.wait()
@@ -200,7 +202,7 @@ def __init__(
         self._outer_optimizer = outer_optimizer

         # Stores pending all reduce
-        self._allreduce_futures: list[torch.futures.Future[torch.Tensor]] = []
+        self._allreduce_work: list[Work] = []
         self._stream: Optional[torch.cuda.Stream] = (
             torch.cuda.Stream() if torch.cuda.is_available() else None
         )
@@ -368,15 +370,15 @@ def wait(self) -> None:
         """
         Waits for the previously scheduled allreduce to finish
         """
-        if len(self._allreduce_futures) == 0:
+        if len(self._allreduce_work) == 0:
             return

         if self._stream is not None:
             assert self._stop_event is not None
             self._stop_event.synchronize()
             self._stop_event = None

-        self._allreduce_futures = []
+        self._allreduce_work = []

     @torch.profiler.record_function("torchft::local_sgd::prepare_sync")
     def prepare_sync(self) -> None:
@@ -386,7 +388,7 @@ def prepare_sync(self) -> None:
         """
         self._save_grads()

-        assert len(self._allreduce_futures) == 0
+        assert len(self._allreduce_work) == 0

         # Make sure tensors are available to `_stream`
         if self._stream is not None:
@@ -399,8 +401,8 @@ def prepare_sync(self) -> None:
         ):
             self._average_grads()

-            for work in self._allreduce_futures:
-                work.wait()
+            for work in self._allreduce_work:
+                work.get_future().wait()

         if self._stream is not None:
             self._stop_event = torch.cuda.Event()
@@ -413,7 +415,7 @@ def perform_sync(self) -> bool:
         steps using the outer optimizer.
         """
         # Waiting for an allreduce before it has been sent is currently not supported.
-        assert len(self._allreduce_futures) > 0
+        assert len(self._allreduce_work) > 0

         self.wait()
@@ -467,7 +469,8 @@ def _allreduce_per_param(self) -> None:
             work = self._manager.allreduce(
                 self._grads[name], should_quantize=self.should_quantize
             )
-            self._allreduce_futures.append(work)
+
+            self._allreduce_work.append(work)

     def _bucketize_and_allreduce(
         self,
@@ -522,8 +525,10 @@ def callback(fut: torch.futures.Future[torch.Tensor]) -> None:
                     flat_buffer[pack_offset : pack_offset + numel].view_as(t)
                 )

-            work = work.then(callback)
-            self._allreduce_futures.append(work)
+            fut = work.get_future()
+            fut = fut.then(callback)
+
+            self._allreduce_work.append(work)

             offset += chunk_size
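For readers unfamiliar with the c10d types involved: the change above replaces the list of `torch.futures.Future` objects with a list of `Work` handles, the type returned by async collectives in `torch.distributed`. The sketch below is illustrative only and sits outside torchft's `Manager` abstraction; the function name `sketch` is made up, and it assumes a process group has already been initialized on each rank. It shows the two ways a `Work` handle is consumed in this diff: blocking with `wait()`, and chaining a callback through `get_future().then()`.

# Illustrative sketch only, not torchft code. Assumes dist.init_process_group()
# has already been called on every rank (e.g. backend="gloo").
import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import Work


def sketch() -> None:
    t = torch.ones(4)
    # Async collectives return a c10d Work handle rather than a Future.
    work: Work = dist.all_reduce(t, async_op=True)
    work.wait()  # blocking path, as in _average() and prepare_sync()

    t2 = torch.ones(4)
    work2 = dist.all_reduce(t2, async_op=True)
    # The underlying Future is still reachable for callback chaining,
    # as in the bucketized allreduce path.
    fut = work2.get_future()

    def callback(fut: torch.futures.Future) -> None:
        # Runs once the collective behind work2 has completed.
        print("allreduce finished")

    fut.then(callback)
    work2.wait()

Note that, as in the bucketized hunk above, it is the `Work` handle (not the chained future) that is kept around for later `wait()` calls.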