 import types
 import unittest
 from datetime import timedelta
-from typing import Callable, List, Optional
+from typing import Callable, Dict, List, Optional, Tuple, TypeVar, cast
+
+# Define a type variable for the Future's value type
+T = TypeVar("T")
 
 import parameterized
 import torch
@@ -51,7 +54,7 @@ def test_callbacks_execute_after_wait(
         self.skipTest("CUDA not available")
 
         # Create a tensor to work with
-        tensor = torch.ones(1, dtype=torch.float32, device=device)
+        tensor: torch.Tensor = torch.ones(1, dtype=torch.float32, device=device)
 
         # Create a simple work object
         work = SimpleWork([tensor])
@@ -65,20 +68,22 @@ def test_callbacks_execute_after_wait(
         )
 
         # Create the managed work
-        managed_work = _ManagedWork(work, manager, [tensor])
+        managed_work = _ManagedWork(manager, work, [tensor])
 
         # Track callback execution
         callback_executed: bool = False
 
-        def callback(fut: Future[List[torch.Tensor]]) -> List[torch.Tensor]:
-            nonlocal callback_executed
+        def callback(fut: Future[object]) -> List[torch.Tensor]:
+            # Update the captured state via nonlocal
+            nonlocal callback_executed, tensor
             callback_executed = True
             # Multiply tensor by 2 to verify the callback ran
-            fut.value()[0].mul_(2)
-            return fut.value()
+            tensor.mul_(2)
+            return [tensor]
 
         # Add the callback
-        managed_work.add_callback(callback)
+        fut = managed_work.get_future()
+        fut = fut.then(callback)
 
         # Verify callback hasn't executed yet
         self.assertFalse(callback_executed)
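
A side note on the fixture pattern used throughout these tests: a Manager instance is allocated with Manager.__new__(Manager) so that __init__ never runs, and wrap_future is then bound onto that instance via types.MethodType. Below is a minimal standalone sketch of that stubbing pattern, assuming the goal is only to skip real initialization; the Greeter class is purely illustrative and not part of this change.

import types

class Greeter:
    def __init__(self) -> None:
        raise RuntimeError("expensive setup we want to skip")

    def greet(self) -> str:
        return "hello"

# Allocate an instance without running __init__
g = Greeter.__new__(Greeter)

# Bind a replacement method to this single instance; MethodType supplies `self`
g.greet = types.MethodType(lambda self: "stubbed", g)

print(g.greet())  # -> "stubbed"
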
@@ -106,7 +111,7 @@ def test_multiple_callbacks_execute_in_order(
         self.skipTest("CUDA not available")
 
         # Create a tensor to work with
-        tensor = torch.ones(1, dtype=torch.float32, device=device)
+        tensor: torch.Tensor = torch.ones(1, dtype=torch.float32, device=device)
 
         # Create a simple work object
         work = SimpleWork([tensor])
@@ -118,30 +123,35 @@ def test_multiple_callbacks_execute_in_order(
         )
 
         # Create the managed work
-        managed_work = _ManagedWork(work, manager, [tensor])
+        managed_work = _ManagedWork(manager, work, [tensor])
 
         # Track execution order
         execution_order: List[int] = []
 
-        def callback1(fut: Future[List[torch.Tensor]]) -> List[torch.Tensor]:
+        def callback1(fut: Future[list[torch.Tensor]]) -> List[torch.Tensor]:
+            nonlocal tensor
             execution_order.append(1)
-            fut.value()[0].add_(1)
-            return fut.value()
+            tensor.add_(1)
+            return [tensor]
 
-        def callback2(fut: Future[List[torch.Tensor]]) -> List[torch.Tensor]:
+        def callback2(fut: Future[list[torch.Tensor]]) -> List[torch.Tensor]:
+            nonlocal tensor
             execution_order.append(2)
-            fut.value()[0].add_(2)
-            return fut.value()
+            tensor.add_(2)
+            return [tensor]
 
-        def callback3(fut: Future[List[torch.Tensor]]) -> List[torch.Tensor]:
+        def callback3(fut: Future[list[torch.Tensor]]) -> List[torch.Tensor]:
+            nonlocal tensor
             execution_order.append(3)
-            fut.value()[0].add_(3)
-            return fut.value()
+            tensor.add_(3)
+            return [tensor]
 
         # Add callbacks
-        managed_work.add_callback(callback1)
-        managed_work.add_callback(callback2)
-        managed_work.add_callback(callback3)
+        fut = managed_work.get_future()
+        fut = cast(Future[list[torch.Tensor]], fut)
+        fut = fut.then(callback1)
+        fut = fut.then(callback2)
+        fut = fut.then(callback3)
 
         # Verify no callbacks have executed yet
         self.assertEqual(len(execution_order), 0)
@@ -169,7 +179,7 @@ def test_future_then_api(self, name: str, device: torch.device) -> None:
         self.skipTest("CUDA not available")
 
         # Create a tensor to work with
-        tensor = torch.ones(1, dtype=torch.float32, device=device)
+        tensor: torch.Tensor = torch.ones(1, dtype=torch.float32, device=device)
 
         # Create a simple work object
         work = SimpleWork([tensor])
@@ -181,35 +191,130 @@ def test_future_then_api(self, name: str, device: torch.device) -> None:
         )
 
         # Create the managed work
-        managed_work = _ManagedWork(work, manager, [tensor])
+        managed_work = _ManagedWork(manager, work, [tensor])
 
         # Get the future
         future = managed_work.get_future()
 
         # Track callback execution
         callback_executed: bool = False
 
-        def callback(fut: Future[List[torch.Tensor]]) -> List[torch.Tensor]:
-            nonlocal callback_executed
+        def callback(fut: Future[object]) -> List[torch.Tensor]:
+            # Update the captured state via nonlocal
+            nonlocal callback_executed, tensor
             callback_executed = True
             # Multiply tensor by 3 to verify the callback ran
-            fut.value()[0].mul_(3)
-            return fut.value()
+            tensor.mul_(3)
+            return [tensor]
 
         # Use the then API
-        future.then(callback)
+        future = future.then(callback)
 
         # Verify callback hasn't executed yet
         self.assertFalse(callback_executed)
         self.assertEqual(tensor.item(), 1.0)
 
-        # Call wait() which should trigger the callback
-        future.wait()
+        # Call wait() on the managed_work first to set up the future properly
+        managed_work.wait()
 
         # Verify callback has executed
         self.assertTrue(callback_executed)
         self.assertEqual(tensor.item(), 3.0)
 
+    @parameterized.parameterized.expand(
+        [
+            ("cpu", torch.device("cpu")),
+            ("cuda", torch.device("cuda:0")),
+        ]
+    )
+    def test_callbacks_changing_return_types(
+        self, name: str, device: torch.device
+    ) -> None:
+        """
+        Test that callbacks can change return types and that tensors are modified in-place.
+        This test demonstrates:
+        1. Callbacks changing return types (List[Tensor] -> Dict -> Tuple)
+        2. Using Future.value() instead of nonlocal
+        3. Verifying tensors are modified in-place for both approaches
+        """
+        # Skip if CUDA is requested but not available
+        if device.type == "cuda" and not torch.cuda.is_available():
+            self.skipTest("CUDA not available")
+
+        # Create tensors to work with
+        tensor1: torch.Tensor = torch.ones(1, dtype=torch.float32, device=device)
+        tensor2: torch.Tensor = torch.ones(1, dtype=torch.float32, device=device) * 2
+
+        # Store original tensor memory addresses to verify in-place modification
+        tensor1_address = tensor1.data_ptr()
+        tensor2_address = tensor2.data_ptr()
+
+        # Create a simple work object
+        work = SimpleWork([tensor1, tensor2])
+
+        # Create a minimal manager object with just the wrap_future method
+        manager = Manager.__new__(Manager)  # Create instance without calling __init__
+        manager.wrap_future = types.MethodType(  # type: ignore
+            lambda self, fut, default, timeout=None: fut, manager
+        )
+
+        # Create the managed work
+        managed_work = _ManagedWork(manager, work, [tensor1, tensor2])
+
+        # Get the future
+        future = managed_work.get_future()
+        future = cast(Future[List[torch.Tensor]], future)
+
+        # First callback: takes List[Tensor] and returns Dict[str, Tensor].
+        # Uses nonlocal to modify tensor1.
+        def callback1(fut: Future[List[torch.Tensor]]) -> Dict[str, torch.Tensor]:
+            nonlocal tensor1
+            tensors = fut.value()
+            # Modify tensor1 in-place using nonlocal
+            tensor1.mul_(3)
+            # Return a dictionary instead of a list
+            return {"first": tensors[0], "second": tensors[1]}
+
+        # Second callback: takes Dict[str, Tensor] and returns Tuple[Tensor, float].
+        # Uses Future.value() to modify tensor2.
+        def callback2(
+            fut: Future[Dict[str, torch.Tensor]]
+        ) -> Tuple[torch.Tensor, float]:
+            data = fut.value()
+            # Modify tensor2 in-place using the value from the future
+            data["second"].add_(5)  # Should modify tensor2 in-place
+            # Return a tuple instead of a dict
+            return (data["second"], data["first"].item())
+
+        # Third callback: takes Tuple[Tensor, float] and returns a single Tensor
+        def callback3(fut: Future[Tuple[torch.Tensor, float]]) -> torch.Tensor:
+            tensor, value = fut.value()
+            # Create a new tensor based on the tuple values
+            result = tensor * value
+            return result
+
+        # Chain the callbacks
+        future = future.then(callback1)
+        future = future.then(callback2)
+        future = future.then(callback3)
+
+        # Call wait() to trigger the callbacks
+        managed_work.wait()
+
+        # Verify tensor1 was modified in-place (using nonlocal)
+        self.assertEqual(tensor1.item(), 3.0)  # 1 * 3 = 3
+        self.assertEqual(tensor1.data_ptr(), tensor1_address)  # Same memory address
+
+        # Verify tensor2 was modified in-place (using Future.value())
+        self.assertEqual(tensor2.item(), 7.0)  # 2 + 5 = 7
+        self.assertEqual(tensor2.data_ptr(), tensor2_address)  # Same memory address
+
+        # Get the final result from the future
+        final_result = future.wait()
+
+        # The final result should be tensor2 * tensor1.item() = 7 * 3 = 21
+        self.assertEqual(final_result.item(), 21.0)
+
 
 if __name__ == "__main__":
     unittest.main()
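
For context on the chaining behavior the new test exercises: the same value-type-changing `then` chain can be reproduced with a plain torch.futures.Future, independent of the Manager/_ManagedWork machinery under test. A minimal sketch; the helper names to_dict and to_scalar are illustrative only.

from typing import Dict, List

import torch
from torch.futures import Future

# Each `then` callback receives the completed upstream future and may
# change the value type: List[Tensor] -> Dict[str, Tensor] -> float.
def to_dict(fut: Future) -> Dict[str, torch.Tensor]:
    tensors: List[torch.Tensor] = fut.value()
    tensors[0].mul_(3)  # in-place edits remain visible to the caller
    return {"first": tensors[0]}

def to_scalar(fut: Future) -> float:
    return fut.value()["first"].item()

f: Future = Future()
chained = f.then(to_dict).then(to_scalar)

t = torch.ones(1)
f.set_result([t])      # completing the future runs the whole chain
print(chained.wait())  # 3.0
print(t.item())        # 3.0 -- mutated in place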