Skip to content

Commit b72ddcc

Browse files
committed
make managed futures generic
1 parent 9a72b49 commit b72ddcc

File tree

5 files changed

+338
-126
lines changed

5 files changed

+338
-126
lines changed

torchft/_test/managed_work_test.py

Lines changed: 136 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@
77
import types
88
import unittest
99
from datetime import timedelta
10-
from typing import Callable, List, Optional
10+
from typing import Callable, Dict, List, Optional, Tuple, TypeVar, cast
11+
12+
# Define a type variable for the Future's value type
13+
T = TypeVar("T")
1114

1215
import parameterized
1316
import torch
@@ -51,7 +54,7 @@ def test_callbacks_execute_after_wait(
5154
self.skipTest("CUDA not available")
5255

5356
# Create a tensor to work with
54-
tensor = torch.ones(1, dtype=torch.float32, device=device)
57+
tensor: torch.Tensor = torch.ones(1, dtype=torch.float32, device=device)
5558

5659
# Create a simple work object
5760
work = SimpleWork([tensor])
@@ -65,20 +68,22 @@ def test_callbacks_execute_after_wait(
6568
)
6669

6770
# Create the managed work
68-
managed_work = _ManagedWork(work, manager, [tensor])
71+
managed_work = _ManagedWork(manager, work, [tensor])
6972

7073
# Track callback execution
7174
callback_executed: bool = False
7275

73-
def callback(fut: Future[List[torch.Tensor]]) -> List[torch.Tensor]:
74-
nonlocal callback_executed
76+
def callback(fut: Future[object]) -> List[torch.Tensor]:
77+
# Mutate the captured tensor directly; the future's value is not read
78+
nonlocal callback_executed, tensor
7579
callback_executed = True
7680
# Multiply tensor by 2 to verify the callback ran
77-
fut.value()[0].mul_(2)
78-
return fut.value()
81+
tensor.mul_(2)
82+
return [tensor]
7983

8084
# Add the callback
81-
managed_work.add_callback(callback)
85+
fut = managed_work.get_future()
86+
fut = fut.then(callback)
8287

8388
# Verify callback hasn't executed yet
8489
self.assertFalse(callback_executed)
@@ -106,7 +111,7 @@ def test_multiple_callbacks_execute_in_order(
106111
self.skipTest("CUDA not available")
107112

108113
# Create a tensor to work with
109-
tensor = torch.ones(1, dtype=torch.float32, device=device)
114+
tensor: torch.Tensor = torch.ones(1, dtype=torch.float32, device=device)
110115

111116
# Create a simple work object
112117
work = SimpleWork([tensor])
@@ -118,30 +123,35 @@ def test_multiple_callbacks_execute_in_order(
118123
)
119124

120125
# Create the managed work
121-
managed_work = _ManagedWork(work, manager, [tensor])
126+
managed_work = _ManagedWork(manager, work, [tensor])
122127

123128
# Track execution order
124129
execution_order: List[int] = []
125130

126-
def callback1(fut: Future[List[torch.Tensor]]) -> List[torch.Tensor]:
131+
def callback1(fut: Future[list[torch.Tensor]]) -> List[torch.Tensor]:
132+
nonlocal tensor
127133
execution_order.append(1)
128-
fut.value()[0].add_(1)
129-
return fut.value()
134+
tensor.add_(1)
135+
return [tensor]
130136

131-
def callback2(fut: Future[List[torch.Tensor]]) -> List[torch.Tensor]:
137+
def callback2(fut: Future[list[torch.Tensor]]) -> List[torch.Tensor]:
138+
nonlocal tensor
132139
execution_order.append(2)
133-
fut.value()[0].add_(2)
134-
return fut.value()
140+
tensor.add_(2)
141+
return [tensor]
135142

136-
def callback3(fut: Future[List[torch.Tensor]]) -> List[torch.Tensor]:
143+
def callback3(fut: Future[list[torch.Tensor]]) -> List[torch.Tensor]:
144+
nonlocal tensor
137145
execution_order.append(3)
138-
fut.value()[0].add_(3)
139-
return fut.value()
146+
tensor.add_(3)
147+
return [tensor]
140148

141149
# Add callbacks
142-
managed_work.add_callback(callback1)
143-
managed_work.add_callback(callback2)
144-
managed_work.add_callback(callback3)
150+
fut = managed_work.get_future()
151+
fut = cast(Future[list[torch.Tensor]], fut)
152+
fut = fut.then(callback1)
153+
fut = fut.then(callback2)
154+
fut = fut.then(callback3)
145155

146156
# Verify no callbacks have executed yet
147157
self.assertEqual(len(execution_order), 0)
@@ -169,7 +179,7 @@ def test_future_then_api(self, name: str, device: torch.device) -> None:
169179
self.skipTest("CUDA not available")
170180

171181
# Create a tensor to work with
172-
tensor = torch.ones(1, dtype=torch.float32, device=device)
182+
tensor: torch.Tensor = torch.ones(1, dtype=torch.float32, device=device)
173183

174184
# Create a simple work object
175185
work = SimpleWork([tensor])
@@ -181,35 +191,130 @@ def test_future_then_api(self, name: str, device: torch.device) -> None:
181191
)
182192

183193
# Create the managed work
184-
managed_work = _ManagedWork(work, manager, [tensor])
194+
managed_work = _ManagedWork(manager, work, [tensor])
185195

186196
# Get the future
187197
future = managed_work.get_future()
188198

189199
# Track callback execution
190200
callback_executed: bool = False
191201

192-
def callback(fut: Future[List[torch.Tensor]]) -> List[torch.Tensor]:
193-
nonlocal callback_executed
202+
def callback(fut: Future[object]) -> List[torch.Tensor]:
203+
# Mutate the captured tensor directly; the future's value is not read
204+
nonlocal callback_executed, tensor
194205
callback_executed = True
195206
# Multiply tensor by 3 to verify the callback ran
196-
fut.value()[0].mul_(3)
197-
return fut.value()
207+
tensor.mul_(3)
208+
return [tensor]
198209

199210
# Use the then API
200-
future.then(callback)
211+
future = future.then(callback)
201212

202213
# Verify callback hasn't executed yet
203214
self.assertFalse(callback_executed)
204215
self.assertEqual(tensor.item(), 1.0)
205216

206-
# Call wait() which should trigger the callback
207-
future.wait()
217+
# Call wait() on the managed_work first to set up the future properly
218+
managed_work.wait()
208219

209220
# Verify callback has executed
210221
self.assertTrue(callback_executed)
211222
self.assertEqual(tensor.item(), 3.0)
212223

224+
@parameterized.parameterized.expand(
225+
[
226+
("cpu", torch.device("cpu")),
227+
("cuda", torch.device("cuda:0")),
228+
]
229+
)
230+
def test_callbacks_changing_return_types(
231+
self, name: str, device: torch.device
232+
) -> None:
233+
"""
234+
Test that callbacks can change return types and that tensors are modified in-place.
235+
This test demonstrates:
236+
1. Callbacks changing return types (List[Tensor] -> Dict -> Tuple -> Tensor)
237+
2. Using Future.value() in one callback and nonlocal in another
238+
3. Verifying tensors are modified in-place for both approaches
239+
"""
240+
# Skip if CUDA is requested but not available
241+
if device.type == "cuda" and not torch.cuda.is_available():
242+
self.skipTest("CUDA not available")
243+
244+
# Create tensors to work with
245+
tensor1: torch.Tensor = torch.ones(1, dtype=torch.float32, device=device)
246+
tensor2: torch.Tensor = torch.ones(1, dtype=torch.float32, device=device) * 2
247+
248+
# Store original tensor memory addresses to verify in-place modification
249+
tensor1_address = tensor1.data_ptr()
250+
tensor2_address = tensor2.data_ptr()
251+
252+
# Create a simple work object
253+
work = SimpleWork([tensor1, tensor2])
254+
255+
# Create a minimal manager object with just the wrap_future method
256+
manager = Manager.__new__(Manager) # Create instance without calling __init__
257+
manager.wrap_future = types.MethodType( # type: ignore
258+
lambda self, fut, default, timeout=None: fut, manager
259+
)
260+
261+
# Create the managed work
262+
managed_work = _ManagedWork(manager, work, [tensor1, tensor2])
263+
264+
# Get the future
265+
future = managed_work.get_future()
266+
future = cast(Future[List[torch.Tensor]], future)
267+
268+
# First callback: Takes List[Tensor] and returns Dict[str, Tensor]
269+
# Uses nonlocal to modify tensor1
270+
def callback1(fut: Future[List[torch.Tensor]]) -> Dict[str, torch.Tensor]:
271+
tensors = fut.value()
272+
nonlocal tensor1
273+
# Modify tensor1 in-place using nonlocal
274+
tensor1.mul_(3)
275+
# Return a dictionary instead of a list
276+
return {"first": tensors[0], "second": tensors[1]}
277+
278+
# Second callback: Takes Dict[str, Tensor] and returns Tuple[Tensor, float]
279+
# Uses Future.value() to modify tensor2
280+
def callback2(
281+
fut: Future[Dict[str, torch.Tensor]]
282+
) -> Tuple[torch.Tensor, float]:
283+
data = fut.value()
284+
# Modify tensor2 in-place using the value from the future
285+
data["second"].add_(5) # Should modify tensor2 in-place
286+
# Return a tuple instead of a dict
287+
return (data["second"], data["first"].item())
288+
289+
# Third callback: Takes Tuple[Tensor, float] and returns a single Tensor
290+
def callback3(fut: Future[Tuple[torch.Tensor, float]]) -> torch.Tensor:
291+
tensor, value = fut.value()
292+
# Create a new tensor based on the tuple values
293+
result = tensor * value
294+
return result
295+
296+
# Chain the callbacks
297+
future = future.then(callback1)
298+
future = future.then(callback2)
299+
future = future.then(callback3)
300+
301+
# Call wait() to trigger the callbacks
302+
managed_work.wait()
303+
304+
# Verify tensor1 was modified in-place (using nonlocal)
305+
self.assertEqual(tensor1.item(), 3.0) # 1 * 3 = 3
306+
self.assertEqual(tensor1.data_ptr(), tensor1_address) # Same memory address
307+
308+
# Verify tensor2 was modified in-place (using Future.value())
309+
self.assertEqual(tensor2.item(), 7.0) # 2 + 5 = 7
310+
self.assertEqual(tensor2.data_ptr(), tensor2_address) # Same memory address
311+
312+
# Get the final result from the future
313+
final_result = future.wait()
314+
315+
# The final result should be tensor2 * tensor1.item() = 7 * 3 = 21
316+
self.assertEqual(final_result.item(), 21.0)
317+
213318

214319
if __name__ == "__main__":
215320
unittest.main()

torchft/ddp.py

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import os
1616
import sys
17-
from typing import TYPE_CHECKING, Optional
17+
from typing import TYPE_CHECKING, Optional, cast
1818
from unittest.mock import patch
1919

2020
import torch
@@ -26,7 +26,7 @@
2626
from torchft.process_group import ProcessGroup, ProcessGroupDummy, ProcessGroupGloo
2727

2828
if TYPE_CHECKING:
29-
from torchft.manager import Manager
29+
from torchft.manager import Manager, _ManagedFuture
3030

3131

3232
class DistributedDataParallel(parallel.DistributedDataParallel):
@@ -69,22 +69,14 @@ def _comm_hook(
6969
state: "Manager", bucket: dist.GradBucket
7070
) -> torch.futures.Future[torch.Tensor]:
7171
work = state.allreduce(bucket.buffer())
72-
73-
result_fut: torch.futures.Future[torch.Tensor] = torch.futures.Future()
74-
72+
work.wait()
7573
fut = work.get_future()
7674

77-
def callback(
78-
tensors: torch.futures.Future[list[torch.Tensor]],
79-
) -> list[torch.Tensor]:
80-
nonlocal result_fut
81-
result_fut.set_result(tensors.value()[0])
82-
return []
83-
84-
fut = fut.then(callback)
85-
86-
work.wait()
87-
return result_fut
75+
# We need to return the underlying future here otherwise
76+
# this can hang
77+
fut = cast("_ManagedFuture[torch.Tensor]", fut)
78+
assert fut._fut
79+
return fut._fut
8880

8981

9082
class PureDistributedDataParallel(nn.Module):

torchft/ddp_test.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from torch.futures import Future
1515

1616
from torchft.ddp import DistributedDataParallel, PureDistributedDataParallel
17-
from torchft.manager import Manager
17+
from torchft.manager import Manager, _ManagedWork
1818
from torchft.process_group import ProcessGroupBabyGloo, ProcessGroupGloo
1919
from torchft.work import _DummyWork
2020

@@ -41,14 +41,16 @@ def test_ddp(self) -> None:
4141

4242
call_count = 0
4343

44+
# pyre-ignore[53]: Captured variable `manager` is not annotated.
4445
def allreduce(
4546
tensor: torch.Tensor,
4647
) -> Work:
47-
nonlocal call_count
48+
nonlocal call_count, manager
4849

4950
call_count += 1
5051

51-
return _DummyWork(tensor)
52+
work = _DummyWork(tensor)
53+
return _ManagedWork(manager, work, tensor)
5254

5355
manager.allreduce = allreduce
5456

0 commit comments

Comments
 (0)