diff --git a/torchft/_test/managed_work_test.py b/torchft/_test/managed_work_test.py new file mode 100644 index 00000000..118daf98 --- /dev/null +++ b/torchft/_test/managed_work_test.py @@ -0,0 +1,320 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import types +import unittest +from datetime import timedelta +from typing import Callable, Dict, List, Optional, Tuple, TypeVar, cast + +# Define a type variable for the Future's value type +T = TypeVar("T") + +import parameterized +import torch +from torch.distributed.distributed_c10d import Work +from torch.futures import Future + +from torchft.manager import Manager, _ManagedWork + + +class SimpleWork(Work): + """A simple implementation of torch.distributed.Work for testing.""" + + def __init__(self, tensors: List[torch.Tensor]) -> None: + super().__init__() + self._tensors = tensors + self._future: Future[List[torch.Tensor]] = torch.futures.Future() + self._is_completed: bool = False + + def wait(self, timeout: Optional[timedelta] = None) -> bool: + self._is_completed = True + self._future.set_result(self._tensors) + return True + + def get_future(self) -> Future[List[torch.Tensor]]: + return self._future + + +class TestManagedWork(unittest.TestCase): + @parameterized.parameterized.expand( + [ + ("cpu", torch.device("cpu")), + ("cuda", torch.device("cuda:0")), + ] + ) + def test_callbacks_execute_after_wait( + self, name: str, device: torch.device + ) -> None: + """Test that callbacks are only executed after wait() is called.""" + # Skip if CUDA is requested but not available + if device.type == "cuda" and not torch.cuda.is_available(): + self.skipTest("CUDA not available") + + # Create a tensor to work with + tensor: torch.Tensor = torch.ones(1, dtype=torch.float32, device=device) + + # Create a simple work object + work = SimpleWork([tensor]) + + 
# Create a minimal manager object with just the wrap_future method + manager = Manager.__new__(Manager) # Create instance without calling __init__ + # We're using types.MethodType to attach a method to the manager instance + # This is just for testing purposes + manager.wrap_future = types.MethodType( # type: ignore + lambda self, fut, default, timeout=None: fut, manager + ) + + # Create the managed work + managed_work = _ManagedWork(manager, work, [tensor]) + + # Track callback execution + callback_executed: bool = False + + def callback(fut: Future[object]) -> List[torch.Tensor]: + # Cast to the expected type + nonlocal callback_executed, tensor + callback_executed = True + # Multiply tensor by 2 to verify the callback ran + tensor.mul_(2) + return [tensor] + + # Add the callback + fut = managed_work.get_future() + fut = fut.then(callback) + + # Verify callback hasn't executed yet + self.assertFalse(callback_executed) + self.assertEqual(tensor.item(), 1.0) + + # Call wait() which should trigger the callback + managed_work.wait() + + # Verify callback has executed + self.assertTrue(callback_executed) + self.assertEqual(tensor.item(), 2.0) + + @parameterized.parameterized.expand( + [ + ("cpu", torch.device("cpu")), + ("cuda", torch.device("cuda:0")), + ] + ) + def test_multiple_callbacks_execute_in_order( + self, name: str, device: torch.device + ) -> None: + """Test that multiple callbacks are executed in the order they were added.""" + # Skip if CUDA is requested but not available + if device.type == "cuda" and not torch.cuda.is_available(): + self.skipTest("CUDA not available") + + # Create a tensor to work with + tensor: torch.Tensor = torch.ones(1, dtype=torch.float32, device=device) + + # Create a simple work object + work = SimpleWork([tensor]) + + # Create a minimal manager object with just the wrap_future method + manager = Manager.__new__(Manager) # Create instance without calling __init__ + manager.wrap_future = types.MethodType( # type: ignore + lambda 
self, fut, default, timeout=None: fut, manager + ) + + # Create the managed work + managed_work = _ManagedWork(manager, work, [tensor]) + + # Track execution order + execution_order: List[int] = [] + + def callback1(fut: Future[list[torch.Tensor]]) -> List[torch.Tensor]: + nonlocal tensor + execution_order.append(1) + tensor.add_(1) + return [tensor] + + def callback2(fut: Future[list[torch.Tensor]]) -> List[torch.Tensor]: + nonlocal tensor + execution_order.append(2) + tensor.add_(2) + return [tensor] + + def callback3(fut: Future[list[torch.Tensor]]) -> List[torch.Tensor]: + nonlocal tensor + execution_order.append(3) + tensor.add_(3) + return [tensor] + + # Add callbacks + fut = managed_work.get_future() + fut = cast(Future[list[torch.Tensor]], fut) + fut = fut.then(callback1) + fut = fut.then(callback2) + fut = fut.then(callback3) + + # Verify no callbacks have executed yet + self.assertEqual(len(execution_order), 0) + self.assertEqual(tensor.item(), 1.0) + + # Call wait() which should trigger the callbacks + managed_work.wait() + + # Verify callbacks executed in order + self.assertEqual(execution_order, [1, 2, 3]) + + # Each callback adds to the tensor, so final value should be 1 + 1 + 2 + 3 = 7 + self.assertEqual(tensor.item(), 7.0) + + @parameterized.parameterized.expand( + [ + ("cpu", torch.device("cpu")), + ("cuda", torch.device("cuda:0")), + ] + ) + def test_future_then_api(self, name: str, device: torch.device) -> None: + """Test that the future's then API works correctly with ManagedWork.""" + # Skip if CUDA is requested but not available + if device.type == "cuda" and not torch.cuda.is_available(): + self.skipTest("CUDA not available") + + # Create a tensor to work with + tensor: torch.Tensor = torch.ones(1, dtype=torch.float32, device=device) + + # Create a simple work object + work = SimpleWork([tensor]) + + # Create a minimal manager object with just the wrap_future method + manager = Manager.__new__(Manager) # Create instance without calling 
__init__ + manager.wrap_future = types.MethodType( # type: ignore + lambda self, fut, default, timeout=None: fut, manager + ) + + # Create the managed work + managed_work = _ManagedWork(manager, work, [tensor]) + + # Get the future + future = managed_work.get_future() + + # Track callback execution + callback_executed: bool = False + + def callback(fut: Future[object]) -> List[torch.Tensor]: + # Cast to the expected type + nonlocal callback_executed, tensor + callback_executed = True + # Multiply tensor by 3 to verify the callback ran + tensor.mul_(3) + return [tensor] + + # Use the then API + future = future.then(callback) + + # Verify callback hasn't executed yet + self.assertFalse(callback_executed) + self.assertEqual(tensor.item(), 1.0) + + # Call wait() on the managed_work first to set up the future properly + managed_work.wait() + + # Verify callback has executed + self.assertTrue(callback_executed) + self.assertEqual(tensor.item(), 3.0) + + @parameterized.parameterized.expand( + [ + ("cpu", torch.device("cpu")), + ("cuda", torch.device("cuda:0")), + ] + ) + def test_callbacks_changing_return_types( + self, name: str, device: torch.device + ) -> None: + """ + Test that callbacks can change return types and that tensors are modified in-place. + This test demonstrates: + 1. Callbacks changing return types (List[Tensor] -> Dict -> Tuple) + 2. Using Future.value() instead of nonlocal + 3. 
Verifying tensors are modified in-place for both approaches + """ + # Skip if CUDA is requested but not available + if device.type == "cuda" and not torch.cuda.is_available(): + self.skipTest("CUDA not available") + + # Create tensors to work with + tensor1: torch.Tensor = torch.ones(1, dtype=torch.float32, device=device) + tensor2: torch.Tensor = torch.ones(1, dtype=torch.float32, device=device) * 2 + + # Store original tensor memory addresses to verify in-place modification + tensor1_address = tensor1.data_ptr() + tensor2_address = tensor2.data_ptr() + + # Create a simple work object + work = SimpleWork([tensor1, tensor2]) + + # Create a minimal manager object with just the wrap_future method + manager = Manager.__new__(Manager) # Create instance without calling __init__ + manager.wrap_future = types.MethodType( # type: ignore + lambda self, fut, default, timeout=None: fut, manager + ) + + # Create the managed work + managed_work = _ManagedWork(manager, work, [tensor1, tensor2]) + + # Get the future + future = managed_work.get_future() + future = cast(Future[List[torch.Tensor]], future) + + # First callback: Takes List[Tensor] and returns Dict[str, Tensor] + # Uses nonlocal to modify tensor1 + def callback1(fut: Future[List[torch.Tensor]]) -> Dict[str, torch.Tensor]: + tensors = fut.value() + nonlocal tensor1 + # Modify tensor1 in-place using nonlocal + tensor1.mul_(3) + # Return a dictionary instead of a list + return {"first": tensors[0], "second": tensors[1]} + + # Second callback: Takes Dict[str, Tensor] and returns Tuple[Tensor, float] + # Uses Future.value() to modify tensor2 + def callback2( + fut: Future[Dict[str, torch.Tensor]] + ) -> Tuple[torch.Tensor, float]: + data = fut.value() + # Modify tensor2 in-place using the value from the future + data["second"].add_(5) # Should modify tensor2 in-place + # Return a tuple instead of a dict + return (data["second"], data["first"].item()) + + # Third callback: Takes Tuple[Tensor, float] and returns a single 
Tensor + def callback3(fut: Future[Tuple[torch.Tensor, float]]) -> torch.Tensor: + tensor, value = fut.value() + # Create a new tensor based on the tuple values + result = tensor * value + return result + + # Chain the callbacks + future = future.then(callback1) + future = future.then(callback2) + future = future.then(callback3) + + # Call wait() to trigger the callbacks + managed_work.wait() + + # Verify tensor1 was modified in-place (using nonlocal) + self.assertEqual(tensor1.item(), 3.0) # 1 * 3 = 3 + self.assertEqual(tensor1.data_ptr(), tensor1_address) # Same memory address + + # Verify tensor2 was modified in-place (using Future.value()) + self.assertEqual(tensor2.item(), 7.0) # 2 + 5 = 7 + self.assertEqual(tensor2.data_ptr(), tensor2_address) # Same memory address + + # Get the final result from the future + final_result = future.wait() + + # The final result should be tensor2 * tensor1.item() = 7 * 3 = 21 + self.assertEqual(final_result.item(), 21.0) + + +if __name__ == "__main__": + unittest.main() diff --git a/torchft/ddp.py b/torchft/ddp.py index 494a9b13..1af50876 100644 --- a/torchft/ddp.py +++ b/torchft/ddp.py @@ -14,7 +14,7 @@ import os import sys -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, cast from unittest.mock import patch import torch @@ -26,7 +26,7 @@ from torchft.process_group import ProcessGroup, ProcessGroupDummy, ProcessGroupGloo if TYPE_CHECKING: - from torchft.manager import Manager + from torchft.manager import Manager, _ManagedFuture class DistributedDataParallel(parallel.DistributedDataParallel): @@ -69,8 +69,14 @@ def _comm_hook( state: "Manager", bucket: dist.GradBucket ) -> torch.futures.Future[torch.Tensor]: work = state.allreduce(bucket.buffer()) - work.synchronize() - return work.get_future() + work.wait() + fut = work.get_future() + + # We need to return the underlying future here otherwise + # this can hang + fut = cast("_ManagedFuture[torch.Tensor]", fut) + assert fut._fut + return 
fut._fut class PureDistributedDataParallel(nn.Module): diff --git a/torchft/ddp_test.py b/torchft/ddp_test.py index 690bfd03..5ff42294 100644 --- a/torchft/ddp_test.py +++ b/torchft/ddp_test.py @@ -14,7 +14,7 @@ from torch.futures import Future from torchft.ddp import DistributedDataParallel, PureDistributedDataParallel -from torchft.manager import Manager +from torchft.manager import Manager, _ManagedWork from torchft.process_group import ProcessGroupBabyGloo, ProcessGroupGloo from torchft.work import _DummyWork @@ -41,14 +41,16 @@ def test_ddp(self) -> None: call_count = 0 + # pyre-ignore[53]: Captured variable `manager` is not annotated. def allreduce( tensor: torch.Tensor, ) -> Work: - nonlocal call_count + nonlocal call_count, manager call_count += 1 - return _DummyWork(tensor) + work = _DummyWork(tensor) + return _ManagedWork(manager, work, tensor) manager.allreduce = allreduce diff --git a/torchft/local_sgd.py b/torchft/local_sgd.py index e92d4bd7..5129e670 100644 --- a/torchft/local_sgd.py +++ b/torchft/local_sgd.py @@ -530,7 +530,9 @@ def _bucketize_and_allreduce( flat_buffer, should_quantize=self.should_quantize ) - def callback(fut: torch.futures.Future[torch.Tensor]) -> None: + def callback( + fut: torch.futures.Future[list[torch.Tensor]], + ) -> list[torch.Tensor]: with torch.cuda.stream(self._stream) if self._stream else nullcontext(): nonlocal bucket_tensors, flat_buffer # Setup stream dependency @@ -540,9 +542,10 @@ def callback(fut: torch.futures.Future[torch.Tensor]) -> None: flat_buffer[pack_offset : pack_offset + numel].view_as(t) ) - work.synchronize() + return [] + fut = work.get_future() - fut.add_done_callback(callback) + fut = fut.then(callback) self._allreduce_work.append(work) diff --git a/torchft/manager.py b/torchft/manager.py index ad9a0566..e947c729 100644 --- a/torchft/manager.py +++ b/torchft/manager.py @@ -31,6 +31,7 @@ import socket import traceback import uuid +import weakref from concurrent.futures import ThreadPoolExecutor from 
contextlib import nullcontext from datetime import timedelta @@ -379,6 +380,7 @@ def allreduce( self, tensor: torch.Tensor, should_quantize: bool = False, + reduce_op: ReduceOp = ReduceOp.SUM, ) -> Work: """ Fault tolerant allreduce the tensor and return a Future that will be completed when @@ -413,12 +415,26 @@ def allreduce( # it later. if should_quantize and IS_TRITON_AVAILABLE: work = allreduce_quantized( - [tensor], ReduceOp.SUM, self._pg, torch.cuda.current_stream() + [tensor], reduce_op, self._pg, torch.cuda.current_stream() ) else: - work = self._pg.allreduce([tensor], ReduceOp.SUM) - - return _ManagedWork(work, self, tensor, num_participants) + work = self._pg.allreduce([tensor], reduce_op) + + # schedule grad normalization as a continuation + # on the Future + @torch.profiler.record_function("torchft::manager::allreduce::callback") + def callback( + fut: torch.futures.Future[list[torch.Tensor]], + ) -> torch.Tensor: + nonlocal num_participants, tensor, reduce_op + if reduce_op == ReduceOp.SUM: + tensor /= num_participants + return tensor + + managed_work = _ManagedWork(self, work, tensor) + fut = managed_work.get_future() + fut = fut.then(callback) + return managed_work except Exception as e: self._logger.exception( @@ -943,81 +959,251 @@ def exception(self, msg: str) -> None: self._logger.exception(f"{self.prefix()} {msg}") +T = TypeVar("T") +S = TypeVar("S") + + +class _SimpleFuture(torch.futures.Future[T]): + """ + A simplified implementation of torch.futures.Future that wraps a value. + + This class provides a minimal Future implementation that holds a pre-determined value. + It's primarily used as a wrapper for values in the callback chain of `_ManagedFuture`. + Most methods raise `RuntimeError` as they're not intended to be called. + + This class is designed to be used only in specific contexts where we don't + want to call `value()` on the underlying `Future` as that would cause the CPU to block. 
+ """ + + def __init__(self, value: T) -> None: + super().__init__() + self._value = value + + def value(self) -> T: + return self._value + + def then( + self, callback: Callable[[torch.futures.Future[T]], S] + ) -> torch.futures.Future[S]: + raise RuntimeError("should not be called") + + def wait(self) -> T: + raise RuntimeError("should not be called") + + def done(self) -> bool: + raise RuntimeError("should not be called") + + def add_done_callback( + self, callback: Callable[[torch.futures.Future[T]], None] + ) -> None: + raise RuntimeError("should not be called") + + def set_result(self, result: T) -> None: + raise RuntimeError("should not be called") + + def set_exception(self, result: T) -> None: + raise RuntimeError("should not be called") + + +class _ManagedFuture(torch.futures.Future[T]): + """ + A specialized Future implementation that works alongside `_ManagedWork`. + + This class extends torch.futures.Future to provide future chaining that is + lazy - `then()` method simply stores the callback, which is only executed when + `wait()` is called on `_ManagedFuture` or `_ManagedWork` + + Callback chains are implemented as a linked list of `_ManagedFuture` objects through the + `_next` attribute. When appending a callback to the chain, it also updates the tail of the + linked list stored in `_ManagedWork`. + + Delegates actual future operations to an internal torch.futures.Future. + + Raises RuntimeError for methods that should not be called. 
+ """ + + def __init__(self, managed_work: weakref.ReferenceType["_ManagedWork"]) -> None: + super().__init__() + # Store a weak reference to _ManagedWork to avoid reference cycles + self._managed_work = managed_work + + # The underlying torch.futures.Future that this class delegates to + self._fut: Optional[torch.futures.Future[T]] = None + + # The next future in the callback chain + self._next: Optional[_ManagedFuture[object]] = None + + # The callback to be executed when the future is completed - this callback + # returns the next future in the chain + self._callback: Optional[Callable[[torch.futures.Future[T]], object]] = None + + def then( + self, + callback: Callable[[torch.futures.Future[T]], S], + ) -> torch.futures.Future[S]: + """ + Sets the callback to be executed when the future is completed. + + Since the callback returns a future, this method also creates a new future + in the chain and also updates the tail of the chain in `_ManagedWork`. + """ + managed_work = self._managed_work() + assert managed_work is not None, "got garbage collected" + + self._callback = callback + self._next = _ManagedFuture[object](self._managed_work) + managed_work._managed_fut_tail = self._next + return cast(torch.futures.Future[S], self._next) + + def wait(self) -> T: + assert self._fut + return self._fut.wait() + + def value(self) -> T: + raise RuntimeError("should not be called") + + def done(self) -> bool: + raise RuntimeError("should not be called") + + def add_done_callback( + self, callback: Callable[[torch.futures.Future[T]], None] + ) -> None: + raise RuntimeError("should not be called") + + def set_result(self, result: T) -> None: + raise RuntimeError("should not be called") + + def set_exception(self, result: T) -> None: + raise RuntimeError("should not be called") + + class _ManagedWork(dist._Work): + """ + A specialized `Work` implementation that works alongside `_ManagedFuture` to create + callback chains lazily. 
The callback chain is created when `wait()`, `block_current_stream()` + or `synchronize()` are called. + """ + def __init__( self, - work: dist._Work, manager: Manager, - tensor: torch.Tensor, - num_participants: int, + work: dist._Work, + value: object, ) -> None: super().__init__() - self._manager = manager + # Underlying `Work` returned from process group operations self._work = work - self._tensor = tensor - self._num_participants = num_participants - self._fut: Union[ - torch.futures.Future[torch.Tensor], torch.futures.Future[None] - ] = work.get_future() + # Used to report errors to the manager through `wrap_future()` + self._manager = manager + + # The value returned by the final future in the callback chain + self._value = value + + # The head of the callback chain + self._managed_fut_head = _ManagedFuture[object](weakref.ref(self)) + + # The tail of the callback chain + self._managed_fut_tail: _ManagedFuture[object] = self._managed_fut_head + + # The stream used to create the `Work` - we ensure all operations in the future + # callback chain are executed on this stream self._stream: Optional[torch.cuda.Stream] = ( torch.cuda.current_stream() if torch.cuda.is_available() else None ) + # To ensure the future callback chain is only created once self._is_set_future_callback_called = False def _set_future_callback( self, ) -> None: + """ + Sets up the stored future callback chain. + + This method creates a chain of callbacks for the futures in the managed work, + ensuring that each callback is executed in the proper order and with the + appropriate stream context. It also wraps the futures with error handling + through the manager's `wrap_future` method. + + The method is called internally when waiting or synchronizing on the work. 
+ """ if self._is_set_future_callback_called: return - # schedule grad normalization as a continuation - # on the Future - @torch.profiler.record_function("torchft::manager::allreduce::callback") - def callback( - fut: torch.futures.Future[List[torch.Tensor]], - ) -> torch.Tensor: - # change the stream to avoid making the callback stream - # dependent on process group stream running the allreduce - with ( - torch.cuda.stream(self._stream) - if self._stream is not None - else nullcontext() - ): - # Setup stream dependency - fut.wait() - self._tensor /= self._num_participants + managed_fut: _ManagedFuture[object] = self._managed_fut_head + managed_fut._fut = self._work.get_future() + value = self._value + + is_future_wrapped = False + while managed_fut._next: + + def callback( + fut: torch.futures.Future[object], + ) -> object: + nonlocal managed_fut, value + # change the stream to avoid making the callback stream + # dependent on process group stream running the allreduce + with ( + torch.cuda.stream(self._stream) + if self._stream is not None + else nullcontext() + ): + # Setup stream dependency + fut.wait() + assert managed_fut._callback + value = managed_fut._callback( + _SimpleFuture(value), + ) + return value - return self._tensor + assert managed_fut._fut + fut = managed_fut._fut.then(callback) + assert managed_fut._next + managed_fut = managed_fut._next + managed_fut._fut = fut - fut = self._fut - fut = fut.then(callback) - fut = self._manager.wrap_future(fut, self._tensor) - self._fut = fut + if is_future_wrapped: + continue + managed_fut._fut = self._manager.wrap_future(managed_fut._fut, value) + is_future_wrapped = True + + self._value = value self._is_set_future_callback_called = True + def _assert_same_stream(self) -> None: + """ + Asserts that the current CUDA stream is the same as the one used to create this work. + + This makes sure users of the API are aware about stream dependencies. 
+ """ + if self._stream is not None: + assert self._stream == torch.cuda.current_stream() + def wait(self, timeout: Optional[timedelta] = None) -> bool: + self._assert_same_stream() + with ( torch.cuda.stream(self._stream) if self._stream is not None else nullcontext() ): self._work.wait() - - self._set_future_callback() + self._set_future_callback() with ( torch.cuda.stream(self._stream) if self._stream is not None else nullcontext() ): - self._fut.wait() + self._managed_fut_tail.wait() return True def block_current_stream(self, timeout: Optional[timedelta] = None) -> None: + self._assert_same_stream() + with ( torch.cuda.stream(self._stream) if self._stream is not None @@ -1028,6 +1214,8 @@ def block_current_stream(self, timeout: Optional[timedelta] = None) -> None: self._set_future_callback() def synchronize(self) -> None: + self._assert_same_stream() + if torch.cuda.is_available(): self.block_current_stream() else: @@ -1036,8 +1224,11 @@ def synchronize(self) -> None: def get_future( self, - ) -> Union[torch.futures.Future[torch.Tensor], torch.futures.Future[None]]: - assert ( - self._is_set_future_callback_called - ), "getting the future without calling synchronize() is unsafe" - return self._fut + ) -> torch.futures.Future[object]: + """ + Returns: + The tail of the managed future chain, which represents the final + result of all the chained operations. This future will be completed when + all the work and its callbacks have been executed. 
+ """ + return self._managed_fut_tail diff --git a/torchft/manager_integ_test.py b/torchft/manager_integ_test.py index ed2d11e8..e75d5dde 100644 --- a/torchft/manager_integ_test.py +++ b/torchft/manager_integ_test.py @@ -40,6 +40,7 @@ ProcessGroupGloo, ) +logging.basicConfig(level=logging.INFO) logger: logging.Logger = logging.getLogger(__name__) INIT_LOCK: threading.Lock = threading.Lock() @@ -638,3 +639,9 @@ def all_reduce_callback( work.wait() return t1 return None + + +if __name__ == "__main__": + import unittest + + unittest.main() diff --git a/torchft/manager_test.py b/torchft/manager_test.py index 6960abce..ca8a07e8 100644 --- a/torchft/manager_test.py +++ b/torchft/manager_test.py @@ -590,20 +590,18 @@ def test_manager_numerics(self, client_mock: MagicMock) -> None: manager._pg.allreduce.return_value = _DummyWork(None) self.assertTrue(manager.is_participating()) - work = manager.allreduce(torch.tensor([1.0])) - work.synchronize() - fut = work.get_future() - result = fut.value() - torch.testing.assert_close(result, torch.tensor([1.0 / 5])) + tensor = torch.tensor([1.0]) + work = manager.allreduce(tensor) + work.wait() + torch.testing.assert_close(tensor, torch.tensor([1.0 / 5])) # check healing numerics manager._healing = True self.assertFalse(manager.is_participating()) - work = manager.allreduce(torch.tensor([1.0])) - work.synchronize() - fut = work.get_future() - result = fut.value() - torch.testing.assert_close(result, torch.tensor([0.0])) + tensor = torch.tensor([1.0]) + work = manager.allreduce(tensor) + work.wait() + torch.testing.assert_close(tensor, torch.tensor([0.0])) @patch("torchft.manager.ManagerClient", autospec=True) def test_quorum_happy_timeouts(self, client_mock: MagicMock) -> None: diff --git a/torchft/process_group.py b/torchft/process_group.py index ac6617f2..854dee12 100644 --- a/torchft/process_group.py +++ b/torchft/process_group.py @@ -1062,30 +1062,6 @@ def callback( return work -class _ManagedWork(Work): - def __init__(self, 
manager: "Manager", work: Work, default_result: object) -> None: - super().__init__() - - self._manager = manager - self._work = work - self._default_result = default_result - - def wait(self, timeout: Optional[timedelta] = None) -> bool: - try: - if self._work is not None: - if timeout is not None: - self._work.wait(timeout) - else: - self._work.wait() - except Exception as e: - self._manager.report_error(e) - - return True - - def get_future(self) -> Future[object]: - return self._manager.wrap_future(self._work.get_future(), self._default_result) - - class ManagedProcessGroup(ProcessGroupWrapper): """ This is a wrapper around any ProcessGroup that is managed by a torchft @@ -1105,23 +1081,13 @@ def __init__(self, manager: "Manager") -> None: self._manager = manager def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work: - # Ensure we have a valid quorum and are configured before trying to do - # any work. - self._manager.wait_quorum() + if isinstance(opts, ReduceOp): + return self._manager.allreduce(tensors, reduce_op=opts) - if self._manager.errored() is not None: - return _DummyWork(tensors) - try: - work = super().allreduce(tensors, opts) - except Exception as e: - self._manager.report_error(e) - return _DummyWork(tensors) + if isinstance(opts, AllreduceOptions): + return self._manager.allreduce(tensors, reduce_op=opts.reduceOp) - return _ManagedWork( - self._manager, - work, - tensors, - ) + assert False, "unreachable" def size(self) -> int: return self._manager.num_participants() diff --git a/torchft/process_group_test.py b/torchft/process_group_test.py index 6d2c0a53..bc364e5f 100644 --- a/torchft/process_group_test.py +++ b/torchft/process_group_test.py @@ -48,7 +48,6 @@ ProcessGroupNCCL, ProcessGroupWrapper, _ErrorSwallowingWork, - _ManagedWork, extend_device_mesh, ft_init_device_mesh, ) @@ -810,11 +809,8 @@ def test_managed_process_group(self) -> None: self.assertEqual(pg.size(), 123) works = _test_pg(pg) - 
self.assertIsInstance(list(works.values())[0], _ManagedWork) - self.assertEqual(manager.report_error.call_count, 0) - self.assertEqual(manager.wrap_future.call_count, 2) - self.assertEqual(manager.wait_quorum.call_count, 2) + self.assertEqual(manager.allreduce.call_count, 2) class DeviceMeshTest(TestCase):