Skip to content

Commit 15cbf05

Browse files
committed
make managed futures generic
1 parent c4fb543 commit 15cbf05

File tree

3 files changed

+156
-98
lines changed

3 files changed

+156
-98
lines changed

torchft/_test/managed_work_test.py

Lines changed: 48 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@
77
import types
88
import unittest
99
from datetime import timedelta
10-
from typing import Callable, List, Optional
10+
from typing import Callable, List, Optional, TypeVar, cast
11+
12+
# Define a type variable for the Future's value type
13+
T = TypeVar("T")
1114

1215
import parameterized
1316
import torch
@@ -65,20 +68,24 @@ def test_callbacks_execute_after_wait(
6568
)
6669

6770
# Create the managed work
68-
managed_work = _ManagedWork(work, manager, [tensor])
71+
managed_work = _ManagedWork(manager, work, [tensor])
6972

7073
# Track callback execution
7174
callback_executed: bool = False
7275

73-
def callback(fut: Future[List[torch.Tensor]]) -> List[torch.Tensor]:
76+
def callback(fut: Future[object]) -> List[torch.Tensor]:
77+
# Cast to the expected type
78+
tensor_fut = cast(Future[List[torch.Tensor]], fut)
7479
nonlocal callback_executed
7580
callback_executed = True
7681
# Multiply tensor by 2 to verify the callback ran
77-
fut.value()[0].mul_(2)
78-
return fut.value()
82+
value = tensor_fut.value()
83+
value[0].mul_(2)
84+
return value
7985

8086
# Add the callback
81-
managed_work.add_callback(callback)
87+
fut = managed_work.get_future()
88+
fut = fut.then(callback)
8289

8390
# Verify callback hasn't executed yet
8491
self.assertFalse(callback_executed)
@@ -118,30 +125,40 @@ def test_multiple_callbacks_execute_in_order(
118125
)
119126

120127
# Create the managed work
121-
managed_work = _ManagedWork(work, manager, [tensor])
128+
managed_work = _ManagedWork(manager, work, [tensor])
122129

123130
# Track execution order
124131
execution_order: List[int] = []
125132

126-
def callback1(fut: Future[List[torch.Tensor]]) -> List[torch.Tensor]:
133+
def callback1(fut: Future[T]) -> List[torch.Tensor]:
134+
# Cast to the expected type
135+
tensor_fut = cast(Future[List[torch.Tensor]], fut)
127136
execution_order.append(1)
128-
fut.value()[0].add_(1)
129-
return fut.value()
137+
value = tensor_fut.value()
138+
value[0].add_(1)
139+
return value
130140

131-
def callback2(fut: Future[List[torch.Tensor]]) -> List[torch.Tensor]:
141+
def callback2(fut: Future[T]) -> List[torch.Tensor]:
142+
# Cast to the expected type
143+
tensor_fut = cast(Future[List[torch.Tensor]], fut)
132144
execution_order.append(2)
133-
fut.value()[0].add_(2)
134-
return fut.value()
145+
value = tensor_fut.value()
146+
value[0].add_(2)
147+
return value
135148

136-
def callback3(fut: Future[List[torch.Tensor]]) -> List[torch.Tensor]:
149+
def callback3(fut: Future[T]) -> List[torch.Tensor]:
150+
# Cast to the expected type
151+
tensor_fut = cast(Future[List[torch.Tensor]], fut)
137152
execution_order.append(3)
138-
fut.value()[0].add_(3)
139-
return fut.value()
153+
value = tensor_fut.value()
154+
value[0].add_(3)
155+
return value
140156

141157
# Add callbacks
142-
managed_work.add_callback(callback1)
143-
managed_work.add_callback(callback2)
144-
managed_work.add_callback(callback3)
158+
fut = managed_work.get_future()
159+
fut = fut.then(callback1)
160+
fut = fut.then(callback2)
161+
fut = fut.then(callback3)
145162

146163
# Verify no callbacks have executed yet
147164
self.assertEqual(len(execution_order), 0)
@@ -181,29 +198,35 @@ def test_future_then_api(self, name: str, device: torch.device) -> None:
181198
)
182199

183200
# Create the managed work
184-
managed_work = _ManagedWork(work, manager, [tensor])
201+
managed_work = _ManagedWork(manager, work, [tensor])
185202

186203
# Get the future
187204
future = managed_work.get_future()
188205

189206
# Track callback execution
190207
callback_executed: bool = False
191208

192-
def callback(fut: Future[List[torch.Tensor]]) -> List[torch.Tensor]:
209+
def callback(fut: Future[object]) -> List[torch.Tensor]:
210+
# Cast to the expected type
211+
tensor_fut = cast(Future[List[torch.Tensor]], fut)
193212
nonlocal callback_executed
194213
callback_executed = True
195214
# Multiply tensor by 3 to verify the callback ran
196-
fut.value()[0].mul_(3)
197-
return fut.value()
215+
value = tensor_fut.value()
216+
value[0].mul_(3)
217+
return value
198218

199219
# Use the then API
200-
future.then(callback)
220+
future = future.then(callback)
201221

202222
# Verify callback hasn't executed yet
203223
self.assertFalse(callback_executed)
204224
self.assertEqual(tensor.item(), 1.0)
205225

206-
# Call wait() which should trigger the callback
226+
# Call wait() on the managed_work first to set up the future properly
227+
managed_work.wait()
228+
229+
# Now wait on the future
207230
future.wait()
208231

209232
# Verify callback has executed

torchft/ddp.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,11 +75,10 @@ def _comm_hook(
7575
fut = work.get_future()
7676

7777
def callback(
78-
tensors: torch.futures.Future[list[torch.Tensor]],
79-
) -> list[torch.Tensor]:
78+
tensor: torch.futures.Future[torch.Tensor],
79+
) -> None:
8080
nonlocal result_fut
81-
result_fut.set_result(tensors.value()[0])
82-
return []
81+
result_fut.set_result(tensor.value())
8382

8483
fut = fut.then(callback)
8584

0 commit comments

Comments
 (0)