Skip to content

Commit 4033168

Browse files
committed
fix managed pg allreduce
Summary: the managed process group's allreduce should simply delegate to the manager's allreduce, rather than duplicating quorum/error handling itself.
1 parent eb099b5 commit 4033168

File tree

3 files changed

+12
-48
lines changed

3 files changed

+12
-48
lines changed

torchft/manager.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,7 @@ def allreduce(
379379
self,
380380
tensor: torch.Tensor,
381381
should_quantize: bool = False,
382+
reduce_op: ReduceOp = ReduceOp.SUM,
382383
) -> Work:
383384
"""
384385
Fault tolerant allreduce the tensor and return a Future that will be completed when
@@ -413,19 +414,20 @@ def allreduce(
413414
# it later.
414415
if should_quantize and IS_TRITON_AVAILABLE:
415416
work = allreduce_quantized(
416-
[tensor], ReduceOp.SUM, self._pg, torch.cuda.current_stream()
417+
[tensor], reduce_op, self._pg, torch.cuda.current_stream()
417418
)
418419
else:
419-
work = self._pg.allreduce([tensor], ReduceOp.SUM)
420+
work = self._pg.allreduce([tensor], reduce_op)
420421

421422
# schedule grad normalization as a continuation
422423
# on the Future
423424
@torch.profiler.record_function("torchft::manager::allreduce::callback")
424425
def callback(
425426
fut: torch.futures.Future[list[torch.Tensor]],
426427
) -> torch.Tensor:
427-
nonlocal num_participants, tensor
428-
tensor /= num_participants
428+
nonlocal num_participants, tensor, reduce_op
429+
if reduce_op == ReduceOp.SUM:
430+
tensor /= num_participants
429431
return tensor
430432

431433
managed_work = _ManagedWork(self, work, tensor)

torchft/process_group.py

Lines changed: 5 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1062,30 +1062,6 @@ def callback(
10621062
return work
10631063

10641064

1065-
class _ManagedWork(Work):
1066-
def __init__(self, manager: "Manager", work: Work, default_result: object) -> None:
1067-
super().__init__()
1068-
1069-
self._manager = manager
1070-
self._work = work
1071-
self._default_result = default_result
1072-
1073-
def wait(self, timeout: Optional[timedelta] = None) -> bool:
1074-
try:
1075-
if self._work is not None:
1076-
if timeout is not None:
1077-
self._work.wait(timeout)
1078-
else:
1079-
self._work.wait()
1080-
except Exception as e:
1081-
self._manager.report_error(e)
1082-
1083-
return True
1084-
1085-
def get_future(self) -> Future[object]:
1086-
return self._manager.wrap_future(self._work.get_future(), self._default_result)
1087-
1088-
10891065
class ManagedProcessGroup(ProcessGroupWrapper):
10901066
"""
10911067
This is a wrapper around any ProcessGroup that is managed by a torchft
@@ -1105,23 +1081,13 @@ def __init__(self, manager: "Manager") -> None:
11051081
self._manager = manager
11061082

11071083
def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
1108-
# Ensure we have a valid quorum and are configured before trying to do
1109-
# any work.
1110-
self._manager.wait_quorum()
1084+
if isinstance(opts, ReduceOp):
1085+
return self._manager.allreduce(tensors, reduce_op=opts)
11111086

1112-
if self._manager.errored() is not None:
1113-
return _DummyWork(tensors)
1114-
try:
1115-
work = super().allreduce(tensors, opts)
1116-
except Exception as e:
1117-
self._manager.report_error(e)
1118-
return _DummyWork(tensors)
1087+
if isinstance(opts, AllreduceOptions):
1088+
return self._manager.allreduce(tensors, reduce_op=opts.reduceOp)
11191089

1120-
return _ManagedWork(
1121-
self._manager,
1122-
work,
1123-
tensors,
1124-
)
1090+
assert False, "unreachable"
11251091

11261092
def size(self) -> int:
11271093
return self._manager.num_participants()

torchft/process_group_test.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@
4848
ProcessGroupNCCL,
4949
ProcessGroupWrapper,
5050
_ErrorSwallowingWork,
51-
_ManagedWork,
5251
extend_device_mesh,
5352
ft_init_device_mesh,
5453
)
@@ -810,11 +809,8 @@ def test_managed_process_group(self) -> None:
810809
self.assertEqual(pg.size(), 123)
811810

812811
works = _test_pg(pg)
813-
self.assertIsInstance(list(works.values())[0], _ManagedWork)
814812

815-
self.assertEqual(manager.report_error.call_count, 0)
816-
self.assertEqual(manager.wrap_future.call_count, 2)
817-
self.assertEqual(manager.wait_quorum.call_count, 2)
813+
self.assertEqual(manager.allreduce.call_count, 2)
818814

819815

820816
class DeviceMeshTest(TestCase):

Comments: 0 commit comments