
Commit d8ef363

return work from manager allreduce

Summary: return the Work object from allreduce so callers have more flexibility in how they use it.

1 parent 94ed227 commit d8ef363

12 files changed: +103 -67 lines
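The caller-facing change: Manager.allreduce and allreduce_quantized now return a torch.distributed Work object instead of a torch.futures.Future. A minimal sketch of the new call pattern, assuming an already-configured torchft Manager named manager (setup not shown, names are stand-ins):

# Sketch of the new API after this commit; `manager` is an assumed,
# already-configured torchft.manager.Manager instance.
import torch

tensor = torch.ones(3)
work = manager.allreduce(tensor)

# Block (or, on NCCL, set up the stream dependency) until the allreduce
# and its chained callbacks complete.
work.wait()

# Or recover the Future-based interface where a Future is required.
fut = work.get_future()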

torchft/collectives.py

Lines changed: 11 additions & 2 deletions

@@ -18,6 +18,7 @@
     AllreduceOptions,
     AllToAllOptions,
     ReduceScatterOptions,
+    Work,
 )
 from torch.futures import Future

@@ -288,7 +289,7 @@ def allreduce_quantized(
     opts: AllreduceOptions | ReduceOp,
     process_group: "ProcessGroup",
     sync_stream: cuda.Stream | None = None,
-) -> Future[list[torch.Tensor]]:
+) -> Work:
     """
     Performs a quantized all-reduce operation on a list of tensors.

@@ -379,6 +380,14 @@ def allreduce_quantized(
         [torch.split(quantized_tensors_out.view(world_size, -1), 1)[rank]],
         _to_allgather_options(allreduce_opts),
     )
+
+    # NOTE: This is not supposed to be used with gloo, only with NCCL.
+    # So we setup the stream dependency here by calling work.wait(),
+    # which doesn't block the CPU.
+    #
+    # The future callback below will run after the work has been
+    # completed.
+
     work.wait()
     fut = work.get_future()

@@ -394,4 +403,4 @@ def callback(fut: Future[list[torch.Tensor]]) -> list[torch.Tensor]:
         return tensors

     fut = fut.then(callback)
-    return fut
+    return work
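The NOTE added above leans on NCCL semantics: Work.wait() on a NCCL process group enqueues a stream dependency rather than blocking the host, and the dequantization callback chained on the Work's Future runs once the collective finishes. A short sketch of the caller side, assuming tensors, reduce_op, and pg are constructed as in the test below:

# Sketch of consuming allreduce_quantized through the returned Work (intended
# for NCCL). `tensors`, `reduce_op`, and `pg` are assumed to exist already.
from torchft.collectives import allreduce_quantized

work = allreduce_quantized(tensors, reduce_op, pg)

# Sets up the stream dependency without blocking the CPU on NCCL; the
# dequantization callback chained inside allreduce_quantized fires once the
# underlying collective completes.
work.wait()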

torchft/collectives_test.py

Lines changed: 2 additions & 2 deletions

@@ -94,8 +94,8 @@ def _run_all_reduce_collective(
             )
         ]

-        fut = allreduce_quantized(tensors, reduce_op, pg)
-        fut.wait()
+        work = allreduce_quantized(tensors, reduce_op, pg)
+        work.wait()

         work = pg.allreduce([expected], reduce_op)
         work.get_future().wait()

torchft/ddp.py

Lines changed: 2 additions & 1 deletion

@@ -68,7 +68,8 @@ def __init__(self, manager: "Manager", module: nn.Module, **kwargs: object) -> None:
         def _comm_hook(
             state: "Manager", bucket: dist.GradBucket
         ) -> torch.futures.Future[torch.Tensor]:
-            return state.allreduce(bucket.buffer())
+            work = state.allreduce(bucket.buffer())
+            return work.get_future()


 class PureDistributedDataParallel(nn.Module):
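DDP's register_comm_hook API requires the hook to return a torch.futures.Future holding the reduced bucket tensor, which is why the hook now bridges the Work back to a Future via get_future(). A small sketch of wiring such a hook, with manager and ddp_model as hypothetical stand-ins (the actual registration code is outside this hunk):

# Sketch: bridging the Work-returning allreduce back to the Future that DDP
# comm hooks must return. `manager` and `ddp_model` are hypothetical names.
import torch
import torch.distributed as dist

def comm_hook(state, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]:
    work = state.allreduce(bucket.buffer())
    return work.get_future()

ddp_model.register_comm_hook(state=manager, hook=comm_hook)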

torchft/ddp_test.py

Lines changed: 6 additions & 4 deletions

@@ -10,11 +10,13 @@
 import torch
 import torch.distributed as dist
 from torch import nn
+from torch.distributed.distributed_c10d import Work
 from torch.futures import Future

 from torchft.ddp import DistributedDataParallel, PureDistributedDataParallel
 from torchft.manager import Manager
 from torchft.process_group import ProcessGroupBabyGloo, ProcessGroupGloo
+from torchft.work import _DummyWork


 class TestDDP(TestCase):

@@ -39,14 +41,14 @@ def test_ddp(self) -> None:

         call_count = 0

-        def allreduce(tensor: torch.Tensor) -> Future[torch.Tensor]:
+        def allreduce(
+            tensor: torch.Tensor,
+        ) -> Work:
             nonlocal call_count

             call_count += 1

-            fut = Future()  # pyre-fixme[29]: not a function
-            fut.set_result(tensor)
-            return fut
+            return _DummyWork(tensor)

         manager.allreduce = allreduce

torchft/local_sgd.py

Lines changed: 13 additions & 9 deletions

@@ -18,6 +18,7 @@
 import torch
 import torch.distributed as dist
 from torch import nn, optim
+from torch.distributed.distributed_c10d import Work
 from torch.distributed.tensor import DTensor
 from torch.nn.parameter import Parameter
 from torch.optim.optimizer import Optimizer

@@ -200,7 +201,7 @@ def __init__(
         self._outer_optimizer = outer_optimizer

         # Stores pending all reduce
-        self._allreduce_futures: list[torch.futures.Future[torch.Tensor]] = []
+        self._allreduce_work: list[Work] = []
         self._stream: Optional[torch.cuda.Stream] = (
             torch.cuda.Stream() if torch.cuda.is_available() else None
         )

@@ -368,15 +369,15 @@ def wait(self) -> None:
         """
         Waits for the previously scheduled allreduce to finish
         """
-        if len(self._allreduce_futures) == 0:
+        if len(self._allreduce_work) == 0:
             return

         if self._stream is not None:
             assert self._stop_event is not None
             self._stop_event.synchronize()
             self._stop_event = None

-        self._allreduce_futures = []
+        self._allreduce_work = []

     @torch.profiler.record_function("torchft::local_sgd::prepare_sync")
     def prepare_sync(self) -> None:

@@ -386,7 +387,7 @@ def prepare_sync(self) -> None:
         """
         self._save_grads()

-        assert len(self._allreduce_futures) == 0
+        assert len(self._allreduce_work) == 0

         # Make sure tensors are available to `_stream`
         if self._stream is not None:

@@ -399,7 +400,7 @@ def prepare_sync(self) -> None:
         ):
             self._average_grads()

-            for work in self._allreduce_futures:
+            for work in self._allreduce_work:
                 work.wait()

             if self._stream is not None:

@@ -413,7 +414,7 @@ def perform_sync(self) -> bool:
        steps using the outer optimizer.
        """
        # Waiting for an allreduce before it has been sent is currently not supported.
-        assert len(self._allreduce_futures) > 0
+        assert len(self._allreduce_work) > 0

        self.wait()

@@ -467,7 +468,8 @@ def _allreduce_per_param(self) -> None:
             work = self._manager.allreduce(
                 self._grads[name], should_quantize=self.should_quantize
             )
-            self._allreduce_futures.append(work)
+
+            self._allreduce_work.append(work)

     def _bucketize_and_allreduce(
         self,

@@ -522,8 +524,10 @@ def callback(fut: torch.futures.Future[torch.Tensor]) -> None:
                     flat_buffer[pack_offset : pack_offset + numel].view_as(t)
                 )

-            work = work.then(callback)
-            self._allreduce_futures.append(work)
+            fut = work.get_future()
+            fut = fut.then(callback)
+
+            self._allreduce_work.append(work)

             offset += chunk_size
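The bucketized path now keeps both handles: the Work object is stored so wait() can synchronize on it later, while the unpack-into-parameters callback is chained on the Work's Future. A minimal sketch of that pattern, with manager, flat_buffer, unpack, and pending_work as hypothetical stand-ins for the real names:

# Sketch: keep the Work for synchronization, chain post-processing on its Future.
# `manager`, `flat_buffer`, `unpack`, and `pending_work` are stand-in names.
work = manager.allreduce(flat_buffer)

fut = work.get_future()
fut = fut.then(lambda f: unpack(flat_buffer))  # runs after the collective completes

pending_work.append(work)
# later, in wait():
#     for w in pending_work:
#         w.wait()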

torchft/local_sgd_test.py

Lines changed: 12 additions & 13 deletions

@@ -11,10 +11,12 @@
 import torch
 from parameterized import parameterized
 from torch import Tensor, nn, optim
+from torch.distributed.distributed_c10d import Work
 from torch.distributed.tensor import DTensor

 from torchft.local_sgd import DiLoCo, LocalSGD, extract_local_tensor
 from torchft.manager import Manager
+from torchft.work import _DummyWork


 def create_manager() -> MagicMock:

@@ -26,6 +28,11 @@ def create_manager() -> MagicMock:

     manager.errored.return_value = None

+    def mock_allreduce(tensor: torch.Tensor, should_quantize: bool = False) -> Work:
+        return _DummyWork(tensor)
+
+    manager.allreduce.side_effect = mock_allreduce
+
     return manager

@@ -66,7 +73,7 @@ class LocalSGDTest(TestCase):
     def test_local_sgd_healthy(self) -> None:
         model = SimpleModel()
         optimizer = optim.SGD(model.parameters())
-        manager = create_autospec(Manager)
+        manager = create_manager()
         with LocalSGD(manager, model, optimizer, sync_every=2) as local_sgd:
             self.assertEqual(local_sgd._local_step, 0)
             inp = torch.rand(2, 3)

@@ -240,13 +247,9 @@ def test_bucketization_correctness(self) -> None:
         manager.should_commit.return_value = True

         # Define fake allreduce: multiplies buffer by 2
-        def fake_allreduce(
-            tensor: Tensor, should_quantize: bool
-        ) -> torch.futures.Future[Tensor]:
+        def fake_allreduce(tensor: Tensor, should_quantize: bool) -> Work:
             tensor.mul_(2)
-            fut = torch.futures.Future()  # pyre-fixme[29]: not a function
-            fut.set_result(tensor)
-            return fut
+            return _DummyWork(tensor)

         manager.allreduce.side_effect = fake_allreduce

@@ -284,13 +287,9 @@ def test_gradient_correctness(self) -> None:
         manager.should_commit.return_value = True

         # Define fake allreduce: multiplies buffer by 2
-        def fake_allreduce(
-            tensor: Tensor, should_quantize: bool
-        ) -> torch.futures.Future[Tensor]:
+        def fake_allreduce(tensor: Tensor, should_quantize: bool) -> Work:
             tensor.mul_(2)
-            fut = torch.futures.Future()  # pyre-fixme[29]: not a function
-            fut.set_result(tensor)
-            return fut
+            return _DummyWork(tensor)

         manager.allreduce.side_effect = fake_allreduce

torchft/manager.py

Lines changed: 10 additions & 13 deletions

@@ -39,11 +39,12 @@

 import torch
 from torch.distributed import ReduceOp, TCPStore
-from torch.distributed.distributed_c10d import AllreduceOptions, ReduceOp
+from torch.distributed.distributed_c10d import AllreduceOptions, ReduceOp, Work

 from torchft._torchft import ManagerClient, ManagerServer
 from torchft.checkpointing import CheckpointTransport, HTTPTransport
 from torchft.futures import future_timeout
+from torchft.work import _DummyWork, _WorkWrapper

 if TYPE_CHECKING:
     from torchft.process_group import ProcessGroup

@@ -343,9 +344,7 @@ def shutdown(self, wait: bool = True) -> None:
         self._executor.shutdown(wait=wait)

     @torch.profiler.record_function("torchft::manager::allreduce")
-    def allreduce(
-        self, tensor: torch.Tensor, should_quantize: bool = False
-    ) -> torch.futures.Future[torch.Tensor]:
+    def allreduce(self, tensor: torch.Tensor, should_quantize: bool = False) -> Work:
         """
         Fault tolerant allreduce the tensor and return a Future that will be completed when
         the tensor is ready.

@@ -365,9 +364,7 @@ def allreduce(
             a Future that will be completed with the allreduced tensor
         """
         if self.errored():
-            fut = torch.futures.Future()  # pyre-fixme[29]: not a function
-            fut.set_result(tensor)
-            return fut
+            return _DummyWork(tensor)

         self.wait_quorum()
         num_participants: int = self.num_participants()

@@ -380,13 +377,14 @@ def allreduce(
             # Run the allreduce async and save the work object so we can wait on
             # it later.
             if should_quantize and IS_TRITON_AVAILABLE:
-                fut = allreduce_quantized(
+                work = allreduce_quantized(
                     [tensor], ReduceOp.SUM, self._pg, torch.cuda.current_stream()
                 )
             else:
                 work = self._pg.allreduce([tensor], ReduceOp.SUM)
                 work.wait()
-                fut = work.get_future()
+
+            fut = work.get_future()

             stream: Optional[torch.cuda.Stream] = (
                 torch.cuda.current_stream() if torch.cuda.is_available() else None

@@ -413,17 +411,16 @@ def callback(
             fut = fut.then(callback)

             fut = self.wrap_future(fut, tensor)
-            return fut
+
+            return _WorkWrapper(work, fut)

         except Exception as e:
             self._logger.exception(
                 f"got exception in all reduce -- skipping remaining: {e}"
             )
             self.report_error(e)

-            fut = torch.futures.Future()  # pyre-fixme[29]: not a function
-            fut.set_result(tensor)
-            return fut
+            return _DummyWork(tensor)

     def report_error(self, e: Exception) -> None:
         """

torchft/manager_integ_test.py

Lines changed: 2 additions & 2 deletions

@@ -634,7 +634,7 @@ def all_reduce_callback(

         manager.start_quorum()
         t1 = torch.ones((1, 3), device=device)
-        fut = manager.allreduce(t1)
-        fut.wait()
+        work = manager.allreduce(t1)
+        work.wait()
         return t1
     return None

torchft/manager_test.py

Lines changed: 6 additions & 5 deletions

@@ -16,7 +16,8 @@
 from torchft._torchft import QuorumResult
 from torchft.checkpointing.transport import CheckpointTransport
 from torchft.manager import MANAGER_ADDR_KEY, REPLICA_ID_KEY, Manager, WorldSizeMode
-from torchft.process_group import ProcessGroup, _DummyWork
+from torchft.process_group import ProcessGroup
+from torchft.work import _DummyWork


 def mock_should_commit(

@@ -586,16 +587,16 @@ def test_manager_numerics(self, client_mock: MagicMock) -> None:
         manager._pg.allreduce.return_value = _DummyWork(None)

         self.assertTrue(manager.is_participating())
-        fut = torch.futures.Future()  # pyre-fixme[29]: not a function
-        fut = manager.allreduce(torch.tensor([1.0]))
+        work = manager.allreduce(torch.tensor([1.0]))
+        fut = work.get_future()
         result = fut.value()
         torch.testing.assert_close(result, torch.tensor([1.0 / 5]))

         # check healing numerics
         manager._healing = True
         self.assertFalse(manager.is_participating())
-        fut = torch.futures.Future()  # pyre-fixme[29]: not a function
-        fut = manager.allreduce(torch.tensor([1.0]))
+        work = manager.allreduce(torch.tensor([1.0]))
+        fut = work.get_future()
         result = fut.value()
         torch.testing.assert_close(result, torch.tensor([0.0]))

torchft/process_group.py

Lines changed: 1 addition & 15 deletions

@@ -69,6 +69,7 @@
 from torchft.device_mesh import *  # noqa: F401
 from torchft.futures import context_timeout, stream_timeout
 from torchft.multiprocessing import _MonitoredPipe
+from torchft.work import _DummyWork

 if TYPE_CHECKING:
     from torchft.manager import Manager

@@ -790,21 +791,6 @@ def getBackendName(self) -> str:
         return "torchft-nccl"


-class _DummyWork(dist._Work):
-    def __init__(self, result: object) -> None:
-        super().__init__()
-        self.result_ = result
-        # pyre-fixme[29]: Future is not a function
-        self.future_: torch.futures.Future[object] = torch.futures.Future()
-        self.future_.set_result(result)
-
-    def wait(self, timeout: Optional[timedelta] = None) -> bool:
-        return True
-
-    def get_future(self) -> torch.futures.Future[object]:
-        return self.future_
-
-
 class ProcessGroupDummy(ProcessGroup):
     """
     This process group discards all data passed to it and returns success. This
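_DummyWork now lives in a new torchft/work.py module (its diff is not shown on this page), alongside the _WorkWrapper that Manager.allreduce returns. A plausible sketch of that module: the _DummyWork half matches the class removed above verbatim, while _WorkWrapper is inferred from how it is constructed and consumed in this diff, so treat it as an assumption rather than the actual implementation:

# Hypothetical sketch of torchft/work.py. _DummyWork mirrors the class removed
# from process_group.py above; _WorkWrapper is inferred from its usage
# (_WorkWrapper(work, fut), work.wait(), work.get_future()) in manager.py.
from datetime import timedelta
from typing import Optional

import torch
import torch.distributed as dist


class _DummyWork(dist._Work):
    def __init__(self, result: object) -> None:
        super().__init__()
        self.result_ = result
        # pyre-fixme[29]: Future is not a function
        self.future_: torch.futures.Future[object] = torch.futures.Future()
        self.future_.set_result(result)

    def wait(self, timeout: Optional[timedelta] = None) -> bool:
        return True

    def get_future(self) -> torch.futures.Future[object]:
        return self.future_


class _WorkWrapper(dist._Work):
    """Pairs the underlying collective Work with a Future that already has the
    post-processing callbacks chained onto it."""

    def __init__(
        self, work: dist._Work, fut: torch.futures.Future[torch.Tensor]
    ) -> None:
        super().__init__()
        self._work = work
        self._fut = fut

    def wait(self, timeout: Optional[timedelta] = None) -> bool:
        # Waiting on the chained Future also waits on the underlying Work.
        self._fut.wait()
        return True

    def get_future(self) -> torch.futures.Future[torch.Tensor]:
        return self._fut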
