From b918cfaebc454d799d03444f7399a3540f3c58cd Mon Sep 17 00:00:00 2001 From: Tushar Jain Date: Wed, 13 Aug 2025 23:03:11 -0700 Subject: [PATCH] fix managed pg allreduce Summary: managed pg allreduce should just call manager's allreduce --- torchft/manager.py | 8 ++++--- torchft/process_group.py | 44 ++++------------------------------- torchft/process_group_test.py | 6 +---- 3 files changed, 11 insertions(+), 47 deletions(-) diff --git a/torchft/manager.py b/torchft/manager.py index 711e167e..c0fd2910 100644 --- a/torchft/manager.py +++ b/torchft/manager.py @@ -380,6 +380,7 @@ def allreduce( self, tensor: torch.Tensor, should_quantize: bool = False, + reduce_op: ReduceOp = ReduceOp.SUM, ) -> Work: """ Fault tolerant allreduce the tensor and return a Future that will be completed when @@ -414,10 +415,10 @@ def allreduce( # it later. if should_quantize and IS_TRITON_AVAILABLE: work = allreduce_quantized( - [tensor], ReduceOp.SUM, self._pg, torch.cuda.current_stream() + [tensor], reduce_op, self._pg, torch.cuda.current_stream() ) else: - work = self._pg.allreduce([tensor], ReduceOp.SUM) + work = self._pg.allreduce([tensor], reduce_op) # schedule grad normalization as a continuation # on the Future @@ -426,7 +427,8 @@ def callback( fut: torch.futures.Future[list[torch.Tensor]], ) -> torch.Tensor: nonlocal tensor - tensor /= num_participants + if reduce_op == ReduceOp.SUM: + tensor /= num_participants return tensor managed_work = _ManagedWork(self, work, tensor) diff --git a/torchft/process_group.py b/torchft/process_group.py index ac6617f2..854dee12 100644 --- a/torchft/process_group.py +++ b/torchft/process_group.py @@ -1062,30 +1062,6 @@ def callback( return work -class _ManagedWork(Work): - def __init__(self, manager: "Manager", work: Work, default_result: object) -> None: - super().__init__() - - self._manager = manager - self._work = work - self._default_result = default_result - - def wait(self, timeout: Optional[timedelta] = None) -> bool: - try: - if self._work is not None: - if timeout is not None: - self._work.wait(timeout) - else: - self._work.wait() - except Exception as e: - self._manager.report_error(e) - - return True - - def get_future(self) -> Future[object]: - return self._manager.wrap_future(self._work.get_future(), self._default_result) - - class ManagedProcessGroup(ProcessGroupWrapper): """ This is a wrapper around any ProcessGroup that is managed by a torchft @@ -1105,23 +1081,13 @@ def __init__(self, manager: "Manager") -> None: self._manager = manager def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work: - # Ensure we have a valid quorum and are configured before trying to do - # any work. - self._manager.wait_quorum() + if isinstance(opts, ReduceOp): + return self._manager.allreduce(tensors, reduce_op=opts) - if self._manager.errored() is not None: - return _DummyWork(tensors) - try: - work = super().allreduce(tensors, opts) - except Exception as e: - self._manager.report_error(e) - return _DummyWork(tensors) + if isinstance(opts, AllreduceOptions): + return self._manager.allreduce(tensors, reduce_op=opts.reduceOp) - return _ManagedWork( - self._manager, - work, - tensors, - ) + assert False, "unreachable" def size(self) -> int: return self._manager.num_participants() diff --git a/torchft/process_group_test.py b/torchft/process_group_test.py index 6d2c0a53..bc364e5f 100644 --- a/torchft/process_group_test.py +++ b/torchft/process_group_test.py @@ -48,7 +48,6 @@ ProcessGroupNCCL, ProcessGroupWrapper, _ErrorSwallowingWork, - _ManagedWork, extend_device_mesh, ft_init_device_mesh, ) @@ -810,11 +809,8 @@ def test_managed_process_group(self) -> None: self.assertEqual(pg.size(), 123) works = _test_pg(pg) - self.assertIsInstance(list(works.values())[0], _ManagedWork) - self.assertEqual(manager.report_error.call_count, 0) - self.assertEqual(manager.wrap_future.call_count, 2) - self.assertEqual(manager.wait_quorum.call_count, 2) + self.assertEqual(manager.allreduce.call_count, 2) class DeviceMeshTest(TestCase):