8 changes: 5 additions & 3 deletions torchft/manager.py
@@ -380,6 +380,7 @@ def allreduce(
         self,
         tensor: torch.Tensor,
         should_quantize: bool = False,
+        reduce_op: ReduceOp = ReduceOp.SUM,
     ) -> Work:
         """
         Fault tolerant allreduce the tensor and return a Future that will be completed when
@@ -414,10 +415,10 @@ def allreduce(
         # it later.
         if should_quantize and IS_TRITON_AVAILABLE:
             work = allreduce_quantized(
-                [tensor], ReduceOp.SUM, self._pg, torch.cuda.current_stream()
+                [tensor], reduce_op, self._pg, torch.cuda.current_stream()
             )
         else:
-            work = self._pg.allreduce([tensor], ReduceOp.SUM)
+            work = self._pg.allreduce([tensor], reduce_op)

         # schedule grad normalization as a continuation
         # on the Future
@@ -426,7 +427,8 @@ def callback(
             fut: torch.futures.Future[list[torch.Tensor]],
         ) -> torch.Tensor:
             nonlocal tensor
-            tensor /= num_participants
+            if reduce_op == ReduceOp.SUM:
+                tensor /= num_participants
             return tensor

         managed_work = _ManagedWork(self, work, tensor)
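
Note on the manager.py change above: Manager.allreduce now forwards a caller-supplied reduce_op to the underlying process group and only divides by num_participants when the op is ReduceOp.SUM, so the default behaviour (an average) is unchanged. A minimal usage sketch, assuming an already-constructed Manager instance; the helper name sync_stats is hypothetical and not part of this PR:

import torch
from torch.distributed import ReduceOp

def sync_stats(manager, grad: torch.Tensor, stat: torch.Tensor) -> None:
    # Default path (unchanged): SUM across participants, then divided by
    # num_participants in the continuation, i.e. a mean of the gradients.
    manager.allreduce(grad).wait()

    # New path: any other ReduceOp is forwarded as-is and the division is
    # skipped, so `stat` ends up holding the element-wise maximum.
    manager.allreduce(stat, reduce_op=ReduceOp.MAX).wait()
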
44 changes: 5 additions & 39 deletions torchft/process_group.py
@@ -1062,30 +1062,6 @@ def callback(
         return work


-class _ManagedWork(Work):
-    def __init__(self, manager: "Manager", work: Work, default_result: object) -> None:
-        super().__init__()
-
-        self._manager = manager
-        self._work = work
-        self._default_result = default_result
-
-    def wait(self, timeout: Optional[timedelta] = None) -> bool:
-        try:
-            if self._work is not None:
-                if timeout is not None:
-                    self._work.wait(timeout)
-                else:
-                    self._work.wait()
-        except Exception as e:
-            self._manager.report_error(e)
-
-        return True
-
-    def get_future(self) -> Future[object]:
-        return self._manager.wrap_future(self._work.get_future(), self._default_result)
-
-
 class ManagedProcessGroup(ProcessGroupWrapper):
     """
     This is a wrapper around any ProcessGroup that is managed by a torchft
@@ -1105,23 +1081,13 @@ def __init__(self, manager: "Manager") -> None:
         self._manager = manager

     def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
-        # Ensure we have a valid quorum and are configured before trying to do
-        # any work.
-        self._manager.wait_quorum()
+        if isinstance(opts, ReduceOp):
+            return self._manager.allreduce(tensors, reduce_op=opts)

-        if self._manager.errored() is not None:
-            return _DummyWork(tensors)
-        try:
-            work = super().allreduce(tensors, opts)
-        except Exception as e:
-            self._manager.report_error(e)
-            return _DummyWork(tensors)
+        if isinstance(opts, AllreduceOptions):
+            return self._manager.allreduce(tensors, reduce_op=opts.reduceOp)

-        return _ManagedWork(
-            self._manager,
-            work,
-            tensors,
-        )
+        assert False, "unreachable"

     def size(self) -> int:
         return self._manager.num_participants()
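
For context on the process_group.py change: ManagedProcessGroup.allreduce no longer wraps the collective itself; it dispatches on the type of opts and delegates to Manager.allreduce, which now owns quorum handling, error reporting, and result wrapping. A standalone sketch of that dispatch rule follows; the names mirror the diff, dispatch_allreduce itself is hypothetical, the import path for AllreduceOptions is an assumption, and the sketch raises TypeError where the PR uses an assert, purely for illustration:

from torch.distributed import ReduceOp
from torch.distributed.distributed_c10d import AllreduceOptions

def dispatch_allreduce(manager, tensors, opts):
    # Bare ReduceOp: forward it directly.
    if isinstance(opts, ReduceOp):
        return manager.allreduce(tensors, reduce_op=opts)
    # AllreduceOptions: unwrap its reduceOp field.
    if isinstance(opts, AllreduceOptions):
        return manager.allreduce(tensors, reduce_op=opts.reduceOp)
    raise TypeError(f"unsupported allreduce opts: {type(opts)!r}")
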
6 changes: 1 addition & 5 deletions torchft/process_group_test.py
@@ -48,7 +48,6 @@
     ProcessGroupNCCL,
     ProcessGroupWrapper,
     _ErrorSwallowingWork,
-    _ManagedWork,
     extend_device_mesh,
     ft_init_device_mesh,
 )
@@ -810,11 +809,8 @@ def test_managed_process_group(self) -> None:
         self.assertEqual(pg.size(), 123)

         works = _test_pg(pg)
-        self.assertIsInstance(list(works.values())[0], _ManagedWork)
-
         self.assertEqual(manager.report_error.call_count, 0)
-        self.assertEqual(manager.wrap_future.call_count, 2)
-        self.assertEqual(manager.wait_quorum.call_count, 2)
+        self.assertEqual(manager.allreduce.call_count, 2)


 class DeviceMeshTest(TestCase):