Skip to content

Commit 855bcad

Browse files
authored
clean up managed work (#253)
Summary: - improve abstraction and api contract offered by managed work - users can change future return types, ensures users use correct cuda streams - added a managed future that is lazily attaches callbacks - based on when users call wait on the work or the future
1 parent 1078c01 commit 855bcad

File tree

7 files changed

+605
-58
lines changed

7 files changed

+605
-58
lines changed

torchft/_test/managed_work_test.py

Lines changed: 320 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,320 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import types
8+
import unittest
9+
from datetime import timedelta
10+
from typing import Callable, Dict, List, Optional, Tuple, TypeVar, cast
11+
12+
# Define a type variable for the Future's value type
13+
T = TypeVar("T")
14+
15+
import parameterized
16+
import torch
17+
from torch.distributed.distributed_c10d import Work
18+
from torch.futures import Future
19+
20+
from torchft.manager import Manager, _ManagedWork
21+
22+
23+
class SimpleWork(Work):
24+
"""A simple implementation of torch.distributed.Work for testing."""
25+
26+
def __init__(self, tensors: List[torch.Tensor]) -> None:
27+
super().__init__()
28+
self._tensors = tensors
29+
self._future: Future[List[torch.Tensor]] = torch.futures.Future()
30+
self._is_completed: bool = False
31+
32+
def wait(self, timeout: Optional[timedelta] = None) -> bool:
33+
self._is_completed = True
34+
self._future.set_result(self._tensors)
35+
return True
36+
37+
def get_future(self) -> Future[List[torch.Tensor]]:
38+
return self._future
39+
40+
41+
class TestManagedWork(unittest.TestCase):
42+
@parameterized.parameterized.expand(
43+
[
44+
("cpu", torch.device("cpu")),
45+
("cuda", torch.device("cuda:0")),
46+
]
47+
)
48+
def test_callbacks_execute_after_wait(
49+
self, name: str, device: torch.device
50+
) -> None:
51+
"""Test that callbacks are only executed after wait() is called."""
52+
# Skip if CUDA is requested but not available
53+
if device.type == "cuda" and not torch.cuda.is_available():
54+
self.skipTest("CUDA not available")
55+
56+
# Create a tensor to work with
57+
tensor: torch.Tensor = torch.ones(1, dtype=torch.float32, device=device)
58+
59+
# Create a simple work object
60+
work = SimpleWork([tensor])
61+
62+
# Create a minimal manager object with just the wrap_future method
63+
manager = Manager.__new__(Manager) # Create instance without calling __init__
64+
# We're using types.MethodType to attach a method to the manager instance
65+
# This is just for testing purposes
66+
manager.wrap_future = types.MethodType( # type: ignore
67+
lambda self, fut, default, timeout=None: fut, manager
68+
)
69+
70+
# Create the managed work
71+
managed_work = _ManagedWork(manager, work, [tensor])
72+
73+
# Track callback execution
74+
callback_executed: bool = False
75+
76+
def callback(fut: Future[object]) -> List[torch.Tensor]:
77+
# Cast to the expected type
78+
nonlocal callback_executed, tensor
79+
callback_executed = True
80+
# Multiply tensor by 2 to verify the callback ran
81+
tensor.mul_(2)
82+
return [tensor]
83+
84+
# Add the callback
85+
fut = managed_work.get_future()
86+
fut = fut.then(callback)
87+
88+
# Verify callback hasn't executed yet
89+
self.assertFalse(callback_executed)
90+
self.assertEqual(tensor.item(), 1.0)
91+
92+
# Call wait() which should trigger the callback
93+
managed_work.wait()
94+
95+
# Verify callback has executed
96+
self.assertTrue(callback_executed)
97+
self.assertEqual(tensor.item(), 2.0)
98+
99+
@parameterized.parameterized.expand(
100+
[
101+
("cpu", torch.device("cpu")),
102+
("cuda", torch.device("cuda:0")),
103+
]
104+
)
105+
def test_multiple_callbacks_execute_in_order(
106+
self, name: str, device: torch.device
107+
) -> None:
108+
"""Test that multiple callbacks are executed in the order they were added."""
109+
# Skip if CUDA is requested but not available
110+
if device.type == "cuda" and not torch.cuda.is_available():
111+
self.skipTest("CUDA not available")
112+
113+
# Create a tensor to work with
114+
tensor: torch.Tensor = torch.ones(1, dtype=torch.float32, device=device)
115+
116+
# Create a simple work object
117+
work = SimpleWork([tensor])
118+
119+
# Create a minimal manager object with just the wrap_future method
120+
manager = Manager.__new__(Manager) # Create instance without calling __init__
121+
manager.wrap_future = types.MethodType( # type: ignore
122+
lambda self, fut, default, timeout=None: fut, manager
123+
)
124+
125+
# Create the managed work
126+
managed_work = _ManagedWork(manager, work, [tensor])
127+
128+
# Track execution order
129+
execution_order: List[int] = []
130+
131+
def callback1(fut: Future[list[torch.Tensor]]) -> List[torch.Tensor]:
132+
nonlocal tensor
133+
execution_order.append(1)
134+
tensor.add_(1)
135+
return [tensor]
136+
137+
def callback2(fut: Future[list[torch.Tensor]]) -> List[torch.Tensor]:
138+
nonlocal tensor
139+
execution_order.append(2)
140+
tensor.add_(2)
141+
return [tensor]
142+
143+
def callback3(fut: Future[list[torch.Tensor]]) -> List[torch.Tensor]:
144+
nonlocal tensor
145+
execution_order.append(3)
146+
tensor.add_(3)
147+
return [tensor]
148+
149+
# Add callbacks
150+
fut = managed_work.get_future()
151+
fut = cast(Future[list[torch.Tensor]], fut)
152+
fut = fut.then(callback1)
153+
fut = fut.then(callback2)
154+
fut = fut.then(callback3)
155+
156+
# Verify no callbacks have executed yet
157+
self.assertEqual(len(execution_order), 0)
158+
self.assertEqual(tensor.item(), 1.0)
159+
160+
# Call wait() which should trigger the callbacks
161+
managed_work.wait()
162+
163+
# Verify callbacks executed in order
164+
self.assertEqual(execution_order, [1, 2, 3])
165+
166+
# Each callback adds to the tensor, so final value should be 1 + 1 + 2 + 3 = 7
167+
self.assertEqual(tensor.item(), 7.0)
168+
169+
@parameterized.parameterized.expand(
170+
[
171+
("cpu", torch.device("cpu")),
172+
("cuda", torch.device("cuda:0")),
173+
]
174+
)
175+
def test_future_then_api(self, name: str, device: torch.device) -> None:
176+
"""Test that the future's then API works correctly with ManagedWork."""
177+
# Skip if CUDA is requested but not available
178+
if device.type == "cuda" and not torch.cuda.is_available():
179+
self.skipTest("CUDA not available")
180+
181+
# Create a tensor to work with
182+
tensor: torch.Tensor = torch.ones(1, dtype=torch.float32, device=device)
183+
184+
# Create a simple work object
185+
work = SimpleWork([tensor])
186+
187+
# Create a minimal manager object with just the wrap_future method
188+
manager = Manager.__new__(Manager) # Create instance without calling __init__
189+
manager.wrap_future = types.MethodType( # type: ignore
190+
lambda self, fut, default, timeout=None: fut, manager
191+
)
192+
193+
# Create the managed work
194+
managed_work = _ManagedWork(manager, work, [tensor])
195+
196+
# Get the future
197+
future = managed_work.get_future()
198+
199+
# Track callback execution
200+
callback_executed: bool = False
201+
202+
def callback(fut: Future[object]) -> List[torch.Tensor]:
203+
# Cast to the expected type
204+
nonlocal callback_executed, tensor
205+
callback_executed = True
206+
# Multiply tensor by 3 to verify the callback ran
207+
tensor.mul_(3)
208+
return [tensor]
209+
210+
# Use the then API
211+
future = future.then(callback)
212+
213+
# Verify callback hasn't executed yet
214+
self.assertFalse(callback_executed)
215+
self.assertEqual(tensor.item(), 1.0)
216+
217+
# Call wait() on the managed_work first to set up the future properly
218+
managed_work.wait()
219+
220+
# Verify callback has executed
221+
self.assertTrue(callback_executed)
222+
self.assertEqual(tensor.item(), 3.0)
223+
224+
@parameterized.parameterized.expand(
225+
[
226+
("cpu", torch.device("cpu")),
227+
("cuda", torch.device("cuda:0")),
228+
]
229+
)
230+
def test_callbacks_changing_return_types(
231+
self, name: str, device: torch.device
232+
) -> None:
233+
"""
234+
Test that callbacks can change return types and that tensors are modified in-place.
235+
This test demonstrates:
236+
1. Callbacks changing return types (List[Tensor] -> Dict -> Tuple)
237+
2. Using Future.value() instead of nonlocal
238+
3. Verifying tensors are modified in-place for both approaches
239+
"""
240+
# Skip if CUDA is requested but not available
241+
if device.type == "cuda" and not torch.cuda.is_available():
242+
self.skipTest("CUDA not available")
243+
244+
# Create tensors to work with
245+
tensor1: torch.Tensor = torch.ones(1, dtype=torch.float32, device=device)
246+
tensor2: torch.Tensor = torch.ones(1, dtype=torch.float32, device=device) * 2
247+
248+
# Store original tensor memory addresses to verify in-place modification
249+
tensor1_address = tensor1.data_ptr()
250+
tensor2_address = tensor2.data_ptr()
251+
252+
# Create a simple work object
253+
work = SimpleWork([tensor1, tensor2])
254+
255+
# Create a minimal manager object with just the wrap_future method
256+
manager = Manager.__new__(Manager) # Create instance without calling __init__
257+
manager.wrap_future = types.MethodType( # type: ignore
258+
lambda self, fut, default, timeout=None: fut, manager
259+
)
260+
261+
# Create the managed work
262+
managed_work = _ManagedWork(manager, work, [tensor1, tensor2])
263+
264+
# Get the future
265+
future = managed_work.get_future()
266+
future = cast(Future[List[torch.Tensor]], future)
267+
268+
# First callback: Takes List[Tensor] and returns Dict[str, Tensor]
269+
# Uses nonlocal to modify tensor1
270+
def callback1(fut: Future[List[torch.Tensor]]) -> Dict[str, torch.Tensor]:
271+
tensors = fut.value()
272+
nonlocal tensor1
273+
# Modify tensor1 in-place using nonlocal
274+
tensor1.mul_(3)
275+
# Return a dictionary instead of a list
276+
return {"first": tensors[0], "second": tensors[1]}
277+
278+
# Second callback: Takes Dict[str, Tensor] and returns Tuple[Tensor, float]
279+
# Uses Future.value() to modify tensor2
280+
def callback2(
281+
fut: Future[Dict[str, torch.Tensor]]
282+
) -> Tuple[torch.Tensor, float]:
283+
data = fut.value()
284+
# Modify tensor2 in-place using the value from the future
285+
data["second"].add_(5) # Should modify tensor2 in-place
286+
# Return a tuple instead of a dict
287+
return (data["second"], data["first"].item())
288+
289+
# Third callback: Takes Tuple[Tensor, float] and returns a single Tensor
290+
def callback3(fut: Future[Tuple[torch.Tensor, float]]) -> torch.Tensor:
291+
tensor, value = fut.value()
292+
# Create a new tensor based on the tuple values
293+
result = tensor * value
294+
return result
295+
296+
# Chain the callbacks
297+
future = future.then(callback1)
298+
future = future.then(callback2)
299+
future = future.then(callback3)
300+
301+
# Call wait() to trigger the callbacks
302+
managed_work.wait()
303+
304+
# Verify tensor1 was modified in-place (using nonlocal)
305+
self.assertEqual(tensor1.item(), 3.0) # 1 * 3 = 3
306+
self.assertEqual(tensor1.data_ptr(), tensor1_address) # Same memory address
307+
308+
# Verify tensor2 was modified in-place (using Future.value())
309+
self.assertEqual(tensor2.item(), 7.0) # 2 + 5 = 7
310+
self.assertEqual(tensor2.data_ptr(), tensor2_address) # Same memory address
311+
312+
# Get the final result from the future
313+
final_result = future.wait()
314+
315+
# The final result should be tensor2 * tensor1.item() = 7 * 3 = 21
316+
self.assertEqual(final_result.item(), 21.0)
317+
318+
319+
if __name__ == "__main__":
320+
unittest.main()

torchft/ddp.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import os
1616
import sys
17-
from typing import TYPE_CHECKING, Optional
17+
from typing import TYPE_CHECKING, Optional, cast
1818
from unittest.mock import patch
1919

2020
import torch
@@ -26,7 +26,7 @@
2626
from torchft.process_group import ProcessGroup, ProcessGroupDummy, ProcessGroupGloo
2727

2828
if TYPE_CHECKING:
29-
from torchft.manager import Manager
29+
from torchft.manager import Manager, _ManagedFuture
3030

3131

3232
class DistributedDataParallel(parallel.DistributedDataParallel):
@@ -69,8 +69,14 @@ def _comm_hook(
6969
state: "Manager", bucket: dist.GradBucket
7070
) -> torch.futures.Future[torch.Tensor]:
7171
work = state.allreduce(bucket.buffer())
72-
work.synchronize()
73-
return work.get_future()
72+
work.wait()
73+
fut = work.get_future()
74+
75+
# We need to return the underlying future here otherwise
76+
# this can hang
77+
fut = cast("_ManagedFuture[torch.Tensor]", fut)
78+
assert fut._fut
79+
return fut._fut
7480

7581

7682
class PureDistributedDataParallel(nn.Module):

torchft/ddp_test.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from torch.futures import Future
1515

1616
from torchft.ddp import DistributedDataParallel, PureDistributedDataParallel
17-
from torchft.manager import Manager
17+
from torchft.manager import Manager, _ManagedWork
1818
from torchft.process_group import ProcessGroupBabyGloo, ProcessGroupGloo
1919
from torchft.work import _DummyWork
2020

@@ -41,14 +41,16 @@ def test_ddp(self) -> None:
4141

4242
call_count = 0
4343

44+
# pyre-ignore[53]: Captured variable `manager` is not annotated.
4445
def allreduce(
4546
tensor: torch.Tensor,
4647
) -> Work:
4748
nonlocal call_count
4849

4950
call_count += 1
5051

51-
return _DummyWork(tensor)
52+
work = _DummyWork(tensor)
53+
return _ManagedWork(manager, work, tensor)
5254

5355
manager.allreduce = allreduce
5456

0 commit comments

Comments
 (0)