Skip to content

Commit 222021b

Browse files
committed
fix managed pg allreduce
Summary: the managed process group's allreduce should simply delegate to the manager's allreduce.
1 parent 006d604 commit 222021b

File tree

2 files changed

+2
-46
lines changed

2 files changed

+2
-46
lines changed

torchft/process_group.py

Lines changed: 1 addition & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1059,30 +1059,6 @@ def callback(
10591059
return work
10601060

10611061

1062-
class _ManagedWork(Work):
1063-
def __init__(self, manager: "Manager", work: Work, default_result: object) -> None:
1064-
super().__init__()
1065-
1066-
self._manager = manager
1067-
self._work = work
1068-
self._default_result = default_result
1069-
1070-
def wait(self, timeout: Optional[timedelta] = None) -> bool:
1071-
try:
1072-
if self._work is not None:
1073-
if timeout is not None:
1074-
self._work.wait(timeout)
1075-
else:
1076-
self._work.wait()
1077-
except Exception as e:
1078-
self._manager.report_error(e)
1079-
1080-
return True
1081-
1082-
def get_future(self) -> Future[object]:
1083-
return self._manager.wrap_future(self._work.get_future(), self._default_result)
1084-
1085-
10861062
class ManagedProcessGroup(ProcessGroupWrapper):
10871063
"""
10881064
This is a wrapper around any ProcessGroup that is managed by a torchft
@@ -1102,23 +1078,7 @@ def __init__(self, manager: "Manager") -> None:
11021078
self._manager = manager
11031079

11041080
def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
1105-
# Ensure we have a valid quorum and are configured before trying to do
1106-
# any work.
1107-
self._manager.wait_quorum()
1108-
1109-
if self._manager.errored() is not None:
1110-
return _DummyWork(tensors)
1111-
try:
1112-
work = super().allreduce(tensors, opts)
1113-
except Exception as e:
1114-
self._manager.report_error(e)
1115-
return _DummyWork(tensors)
1116-
1117-
return _ManagedWork(
1118-
self._manager,
1119-
work,
1120-
tensors,
1121-
)
1081+
return self._manager.allreduce(tensors)
11221082

11231083
def size(self) -> int:
11241084
return self._manager.num_participants()

torchft/process_group_test.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@
4848
ProcessGroupNCCL,
4949
ProcessGroupWrapper,
5050
_ErrorSwallowingWork,
51-
_ManagedWork,
5251
extend_device_mesh,
5352
ft_init_device_mesh,
5453
)
@@ -810,11 +809,8 @@ def test_managed_process_group(self) -> None:
810809
self.assertEqual(pg.size(), 123)
811810

812811
works = _test_pg(pg)
813-
self.assertIsInstance(list(works.values())[0], _ManagedWork)
814812

815-
self.assertEqual(manager.report_error.call_count, 0)
816-
self.assertEqual(manager.wrap_future.call_count, 2)
817-
self.assertEqual(manager.wait_quorum.call_count, 2)
813+
self.assertEqual(manager.allreduce.call_count, 2)
818814

819815

820816
class DeviceMeshTest(TestCase):

0 commit comments

Comments
 (0)