|
7 | 7 | import types
|
8 | 8 | import unittest
|
9 | 9 | from datetime import timedelta
|
10 |
| -from typing import Callable, List, Optional |
| 10 | +from typing import Callable, List, Optional, TypeVar, cast |
| 11 | + |
| 12 | +# Define a type variable for the Future's value type |
| 13 | +T = TypeVar("T") |
11 | 14 |
|
12 | 15 | import parameterized
|
13 | 16 | import torch
|
@@ -65,20 +68,24 @@ def test_callbacks_execute_after_wait(
|
65 | 68 | )
|
66 | 69 |
|
67 | 70 | # Create the managed work
|
68 |
| - managed_work = _ManagedWork(work, manager, [tensor]) |
| 71 | + managed_work = _ManagedWork(manager, work, [tensor]) |
69 | 72 |
|
70 | 73 | # Track callback execution
|
71 | 74 | callback_executed: bool = False
|
72 | 75 |
|
73 |
| - def callback(fut: Future[List[torch.Tensor]]) -> List[torch.Tensor]: |
| 76 | + def callback(fut: Future[object]) -> List[torch.Tensor]: |
| 77 | + # Cast to the expected type |
| 78 | + tensor_fut = cast(Future[List[torch.Tensor]], fut) |
74 | 79 | nonlocal callback_executed
|
75 | 80 | callback_executed = True
|
76 | 81 | # Multiply tensor by 2 to verify the callback ran
|
77 |
| - fut.value()[0].mul_(2) |
78 |
| - return fut.value() |
| 82 | + value = tensor_fut.value() |
| 83 | + value[0].mul_(2) |
| 84 | + return value |
79 | 85 |
|
80 | 86 | # Add the callback
|
81 |
| - managed_work.add_callback(callback) |
| 87 | + fut = managed_work.get_future() |
| 88 | + fut = fut.then(callback) |
82 | 89 |
|
83 | 90 | # Verify callback hasn't executed yet
|
84 | 91 | self.assertFalse(callback_executed)
|
@@ -118,30 +125,40 @@ def test_multiple_callbacks_execute_in_order(
|
118 | 125 | )
|
119 | 126 |
|
120 | 127 | # Create the managed work
|
121 |
| - managed_work = _ManagedWork(work, manager, [tensor]) |
| 128 | + managed_work = _ManagedWork(manager, work, [tensor]) |
122 | 129 |
|
123 | 130 | # Track execution order
|
124 | 131 | execution_order: List[int] = []
|
125 | 132 |
|
126 |
| - def callback1(fut: Future[List[torch.Tensor]]) -> List[torch.Tensor]: |
| 133 | + def callback1(fut: Future[T]) -> List[torch.Tensor]: |
| 134 | + # Cast to the expected type |
| 135 | + tensor_fut = cast(Future[List[torch.Tensor]], fut) |
127 | 136 | execution_order.append(1)
|
128 |
| - fut.value()[0].add_(1) |
129 |
| - return fut.value() |
| 137 | + value = tensor_fut.value() |
| 138 | + value[0].add_(1) |
| 139 | + return value |
130 | 140 |
|
131 |
| - def callback2(fut: Future[List[torch.Tensor]]) -> List[torch.Tensor]: |
| 141 | + def callback2(fut: Future[T]) -> List[torch.Tensor]: |
| 142 | + # Cast to the expected type |
| 143 | + tensor_fut = cast(Future[List[torch.Tensor]], fut) |
132 | 144 | execution_order.append(2)
|
133 |
| - fut.value()[0].add_(2) |
134 |
| - return fut.value() |
| 145 | + value = tensor_fut.value() |
| 146 | + value[0].add_(2) |
| 147 | + return value |
135 | 148 |
|
136 |
| - def callback3(fut: Future[List[torch.Tensor]]) -> List[torch.Tensor]: |
| 149 | + def callback3(fut: Future[T]) -> List[torch.Tensor]: |
| 150 | + # Cast to the expected type |
| 151 | + tensor_fut = cast(Future[List[torch.Tensor]], fut) |
137 | 152 | execution_order.append(3)
|
138 |
| - fut.value()[0].add_(3) |
139 |
| - return fut.value() |
| 153 | + value = tensor_fut.value() |
| 154 | + value[0].add_(3) |
| 155 | + return value |
140 | 156 |
|
141 | 157 | # Add callbacks
|
142 |
| - managed_work.add_callback(callback1) |
143 |
| - managed_work.add_callback(callback2) |
144 |
| - managed_work.add_callback(callback3) |
| 158 | + fut = managed_work.get_future() |
| 159 | + fut = fut.then(callback1) |
| 160 | + fut = fut.then(callback2) |
| 161 | + fut = fut.then(callback3) |
145 | 162 |
|
146 | 163 | # Verify no callbacks have executed yet
|
147 | 164 | self.assertEqual(len(execution_order), 0)
|
@@ -181,29 +198,35 @@ def test_future_then_api(self, name: str, device: torch.device) -> None:
|
181 | 198 | )
|
182 | 199 |
|
183 | 200 | # Create the managed work
|
184 |
| - managed_work = _ManagedWork(work, manager, [tensor]) |
| 201 | + managed_work = _ManagedWork(manager, work, [tensor]) |
185 | 202 |
|
186 | 203 | # Get the future
|
187 | 204 | future = managed_work.get_future()
|
188 | 205 |
|
189 | 206 | # Track callback execution
|
190 | 207 | callback_executed: bool = False
|
191 | 208 |
|
192 |
| - def callback(fut: Future[List[torch.Tensor]]) -> List[torch.Tensor]: |
| 209 | + def callback(fut: Future[object]) -> List[torch.Tensor]: |
| 210 | + # Cast to the expected type |
| 211 | + tensor_fut = cast(Future[List[torch.Tensor]], fut) |
193 | 212 | nonlocal callback_executed
|
194 | 213 | callback_executed = True
|
195 | 214 | # Multiply tensor by 3 to verify the callback ran
|
196 |
| - fut.value()[0].mul_(3) |
197 |
| - return fut.value() |
| 215 | + value = tensor_fut.value() |
| 216 | + value[0].mul_(3) |
| 217 | + return value |
198 | 218 |
|
199 | 219 | # Use the then API
|
200 |
| - future.then(callback) |
| 220 | + future = future.then(callback) |
201 | 221 |
|
202 | 222 | # Verify callback hasn't executed yet
|
203 | 223 | self.assertFalse(callback_executed)
|
204 | 224 | self.assertEqual(tensor.item(), 1.0)
|
205 | 225 |
|
206 |
| - # Call wait() which should trigger the callback |
| 226 | + # Call wait() on the managed_work first to set up the future properly |
| 227 | + managed_work.wait() |
| 228 | + |
| 229 | + # Now wait on the future |
207 | 230 | future.wait()
|
208 | 231 |
|
209 | 232 | # Verify callback has executed
|
|
0 commit comments