Commit 94528f5

fix managed pg allreduce

Summary: managed pg allreduce should just call manager's allreduce

Parent: 70545d6

File tree: 3 files changed (+11, -47 lines)

torchft/manager.py

Lines changed: 5 additions & 3 deletions
@@ -380,6 +380,7 @@ def allreduce(
         self,
         tensor: torch.Tensor,
         should_quantize: bool = False,
+        reduce_op: ReduceOp = ReduceOp.SUM,
     ) -> Work:
         """
         Fault tolerant allreduce the tensor and return a Future that will be completed when
@@ -414,10 +415,10 @@ def allreduce(
         # it later.
         if should_quantize and IS_TRITON_AVAILABLE:
             work = allreduce_quantized(
-                [tensor], ReduceOp.SUM, self._pg, torch.cuda.current_stream()
+                [tensor], reduce_op, self._pg, torch.cuda.current_stream()
             )
         else:
-            work = self._pg.allreduce([tensor], ReduceOp.SUM)
+            work = self._pg.allreduce([tensor], reduce_op)

         # schedule grad normalization as a continuation
         # on the Future
@@ -426,7 +427,8 @@ def callback(
             fut: torch.futures.Future[list[torch.Tensor]],
         ) -> torch.Tensor:
             nonlocal tensor
-            tensor /= num_participants
+            if reduce_op == ReduceOp.SUM:
+                tensor /= num_participants
             return tensor

         managed_work = _ManagedWork(self, work, tensor)
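
To make the effect of the new reduce_op parameter concrete, here is a minimal standalone sketch of the post-change normalization step. The finalize_allreduce helper is hypothetical (it stands in for the Manager/Work machinery and is not torchft API); the point is that the mean is only taken for SUM, so other reductions pass through unchanged.

import torch
from torch.distributed import ReduceOp


def finalize_allreduce(
    tensor: torch.Tensor, reduce_op: ReduceOp, num_participants: int
) -> torch.Tensor:
    # Hypothetical helper, not torchft API: mirrors the callback above.
    # Only SUM results are normalized into a mean; MAX/MIN/PRODUCT results
    # are returned exactly as produced by the process group.
    if reduce_op == ReduceOp.SUM:
        tensor /= num_participants
    return tensor


# With 4 participants, a summed tensor of 8s becomes an average of 2s,
# while a MAX reduction is left untouched.
print(finalize_allreduce(torch.full((2,), 8.0), ReduceOp.SUM, 4))  # tensor([2., 2.])
print(finalize_allreduce(torch.full((2,), 8.0), ReduceOp.MAX, 4))  # tensor([8., 8.])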

torchft/process_group.py

Lines changed: 5 additions & 39 deletions
@@ -1062,30 +1062,6 @@ def callback(
         return work


-class _ManagedWork(Work):
-    def __init__(self, manager: "Manager", work: Work, default_result: object) -> None:
-        super().__init__()
-
-        self._manager = manager
-        self._work = work
-        self._default_result = default_result
-
-    def wait(self, timeout: Optional[timedelta] = None) -> bool:
-        try:
-            if self._work is not None:
-                if timeout is not None:
-                    self._work.wait(timeout)
-                else:
-                    self._work.wait()
-        except Exception as e:
-            self._manager.report_error(e)
-
-        return True
-
-    def get_future(self) -> Future[object]:
-        return self._manager.wrap_future(self._work.get_future(), self._default_result)
-
-
 class ManagedProcessGroup(ProcessGroupWrapper):
     """
     This is a wrapper around any ProcessGroup that is managed by a torchft
@@ -1105,23 +1081,13 @@ def __init__(self, manager: "Manager") -> None:
         self._manager = manager

     def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
-        # Ensure we have a valid quorum and are configured before trying to do
-        # any work.
-        self._manager.wait_quorum()
+        if isinstance(opts, ReduceOp):
+            return self._manager.allreduce(tensors, reduce_op=opts)

-        if self._manager.errored() is not None:
-            return _DummyWork(tensors)
-        try:
-            work = super().allreduce(tensors, opts)
-        except Exception as e:
-            self._manager.report_error(e)
-            return _DummyWork(tensors)
+        if isinstance(opts, AllreduceOptions):
+            return self._manager.allreduce(tensors, reduce_op=opts.reduceOp)

-        return _ManagedWork(
-            self._manager,
-            work,
-            tensors,
-        )
+        assert False, "unreachable"

     def size(self) -> int:
         return self._manager.num_participants()
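
As a quick illustration of the new dispatch, here is a hedged, self-contained sketch: a c10d caller hands allreduce either a bare ReduceOp or an AllreduceOptions whose reduceOp field carries the op, and both now funnel straight into the manager's fault-tolerant allreduce (per the commit summary, quorum waiting and error handling move to the manager side). The managed_allreduce function and the mocked manager are stand-ins, not torchft code.

from typing import List
from unittest.mock import MagicMock

import torch
from torch.distributed.distributed_c10d import AllreduceOptions, ReduceOp


def managed_allreduce(manager, tensors: List[torch.Tensor], opts: object):
    # Illustrative stand-in for ManagedProcessGroup.allreduce after this
    # commit: unwrap the reduce op and delegate everything else to the manager.
    if isinstance(opts, ReduceOp):
        return manager.allreduce(tensors, reduce_op=opts)
    if isinstance(opts, AllreduceOptions):
        return manager.allreduce(tensors, reduce_op=opts.reduceOp)
    raise AssertionError("unreachable")


manager = MagicMock()

# Bare ReduceOp path.
managed_allreduce(manager, [torch.ones(2)], ReduceOp.SUM)

# AllreduceOptions path.
opts = AllreduceOptions()
opts.reduceOp = ReduceOp.MAX
managed_allreduce(manager, [torch.ones(2)], opts)

assert manager.allreduce.call_count == 2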

torchft/process_group_test.py

Lines changed: 1 addition & 5 deletions
@@ -48,7 +48,6 @@
     ProcessGroupNCCL,
     ProcessGroupWrapper,
     _ErrorSwallowingWork,
-    _ManagedWork,
     extend_device_mesh,
     ft_init_device_mesh,
 )
@@ -810,11 +809,8 @@ def test_managed_process_group(self) -> None:
         self.assertEqual(pg.size(), 123)

         works = _test_pg(pg)
-        self.assertIsInstance(list(works.values())[0], _ManagedWork)

-        self.assertEqual(manager.report_error.call_count, 0)
-        self.assertEqual(manager.wrap_future.call_count, 2)
-        self.assertEqual(manager.wait_quorum.call_count, 2)
+        self.assertEqual(manager.allreduce.call_count, 2)


 class DeviceMeshTest(TestCase):
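
Because the wrapper now simply delegates, the observable behavior in a test is the mocked manager's allreduce calls rather than wrap_future/wait_quorum counts. Below is a rough, self-contained sketch of that testing idea; _DelegatingPG and ManagedAllreduceDelegationTest are illustrative stand-ins, not part of the torchft test suite.

import unittest
from unittest.mock import MagicMock

import torch
from torch.distributed.distributed_c10d import ReduceOp


class _DelegatingPG:
    # Stand-in for ManagedProcessGroup: forwards allreduce to the manager.
    def __init__(self, manager) -> None:
        self._manager = manager

    def allreduce(self, tensors, opts):
        return self._manager.allreduce(tensors, reduce_op=opts)


class ManagedAllreduceDelegationTest(unittest.TestCase):
    def test_calls_manager_allreduce(self) -> None:
        manager = MagicMock()
        pg = _DelegatingPG(manager)

        t = torch.ones(2)
        pg.allreduce([t], ReduceOp.SUM)
        pg.allreduce([t], ReduceOp.SUM)

        # Mirrors the updated assertion in the diff: delegation is the
        # observable behavior, so the call count (and reduce_op) is checked.
        self.assertEqual(manager.allreduce.call_count, 2)
        self.assertEqual(manager.allreduce.call_args.kwargs["reduce_op"], ReduceOp.SUM)


if __name__ == "__main__":
    unittest.main()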
