|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +import types |
| 8 | +import unittest |
| 9 | +from datetime import timedelta |
| 10 | +from typing import Callable, Dict, List, Optional, Tuple, TypeVar, cast |
| 11 | + |
| 12 | +# Define a type variable for the Future's value type |
| 13 | +T = TypeVar("T") |
| 14 | + |
| 15 | +import parameterized |
| 16 | +import torch |
| 17 | +from torch.distributed.distributed_c10d import Work |
| 18 | +from torch.futures import Future |
| 19 | + |
| 20 | +from torchft.manager import Manager, _ManagedWork |
| 21 | + |
| 22 | + |
| 23 | +class SimpleWork(Work): |
| 24 | + """A simple implementation of torch.distributed.Work for testing.""" |
| 25 | + |
| 26 | + def __init__(self, tensors: List[torch.Tensor]) -> None: |
| 27 | + super().__init__() |
| 28 | + self._tensors = tensors |
| 29 | + self._future: Future[List[torch.Tensor]] = torch.futures.Future() |
| 30 | + self._is_completed: bool = False |
| 31 | + |
| 32 | + def wait(self, timeout: Optional[timedelta] = None) -> bool: |
| 33 | + self._is_completed = True |
| 34 | + self._future.set_result(self._tensors) |
| 35 | + return True |
| 36 | + |
| 37 | + def get_future(self) -> Future[List[torch.Tensor]]: |
| 38 | + return self._future |
| 39 | + |
| 40 | + |
| 41 | +class TestManagedWork(unittest.TestCase): |
| 42 | + @parameterized.parameterized.expand( |
| 43 | + [ |
| 44 | + ("cpu", torch.device("cpu")), |
| 45 | + ("cuda", torch.device("cuda:0")), |
| 46 | + ] |
| 47 | + ) |
| 48 | + def test_callbacks_execute_after_wait( |
| 49 | + self, name: str, device: torch.device |
| 50 | + ) -> None: |
| 51 | + """Test that callbacks are only executed after wait() is called.""" |
| 52 | + # Skip if CUDA is requested but not available |
| 53 | + if device.type == "cuda" and not torch.cuda.is_available(): |
| 54 | + self.skipTest("CUDA not available") |
| 55 | + |
| 56 | + # Create a tensor to work with |
| 57 | + tensor: torch.Tensor = torch.ones(1, dtype=torch.float32, device=device) |
| 58 | + |
| 59 | + # Create a simple work object |
| 60 | + work = SimpleWork([tensor]) |
| 61 | + |
| 62 | + # Create a minimal manager object with just the wrap_future method |
| 63 | + manager = Manager.__new__(Manager) # Create instance without calling __init__ |
| 64 | + # We're using types.MethodType to attach a method to the manager instance |
| 65 | + # This is just for testing purposes |
| 66 | + manager.wrap_future = types.MethodType( # type: ignore |
| 67 | + lambda self, fut, default, timeout=None: fut, manager |
| 68 | + ) |
| 69 | + |
| 70 | + # Create the managed work |
| 71 | + managed_work = _ManagedWork(manager, work, [tensor]) |
| 72 | + |
| 73 | + # Track callback execution |
| 74 | + callback_executed: bool = False |
| 75 | + |
| 76 | + def callback(fut: Future[object]) -> List[torch.Tensor]: |
| 77 | + # Cast to the expected type |
| 78 | + nonlocal callback_executed, tensor |
| 79 | + callback_executed = True |
| 80 | + # Multiply tensor by 2 to verify the callback ran |
| 81 | + tensor.mul_(2) |
| 82 | + return [tensor] |
| 83 | + |
| 84 | + # Add the callback |
| 85 | + fut = managed_work.get_future() |
| 86 | + fut = fut.then(callback) |
| 87 | + |
| 88 | + # Verify callback hasn't executed yet |
| 89 | + self.assertFalse(callback_executed) |
| 90 | + self.assertEqual(tensor.item(), 1.0) |
| 91 | + |
| 92 | + # Call wait() which should trigger the callback |
| 93 | + managed_work.wait() |
| 94 | + |
| 95 | + # Verify callback has executed |
| 96 | + self.assertTrue(callback_executed) |
| 97 | + self.assertEqual(tensor.item(), 2.0) |
| 98 | + |
| 99 | + @parameterized.parameterized.expand( |
| 100 | + [ |
| 101 | + ("cpu", torch.device("cpu")), |
| 102 | + ("cuda", torch.device("cuda:0")), |
| 103 | + ] |
| 104 | + ) |
| 105 | + def test_multiple_callbacks_execute_in_order( |
| 106 | + self, name: str, device: torch.device |
| 107 | + ) -> None: |
| 108 | + """Test that multiple callbacks are executed in the order they were added.""" |
| 109 | + # Skip if CUDA is requested but not available |
| 110 | + if device.type == "cuda" and not torch.cuda.is_available(): |
| 111 | + self.skipTest("CUDA not available") |
| 112 | + |
| 113 | + # Create a tensor to work with |
| 114 | + tensor: torch.Tensor = torch.ones(1, dtype=torch.float32, device=device) |
| 115 | + |
| 116 | + # Create a simple work object |
| 117 | + work = SimpleWork([tensor]) |
| 118 | + |
| 119 | + # Create a minimal manager object with just the wrap_future method |
| 120 | + manager = Manager.__new__(Manager) # Create instance without calling __init__ |
| 121 | + manager.wrap_future = types.MethodType( # type: ignore |
| 122 | + lambda self, fut, default, timeout=None: fut, manager |
| 123 | + ) |
| 124 | + |
| 125 | + # Create the managed work |
| 126 | + managed_work = _ManagedWork(manager, work, [tensor]) |
| 127 | + |
| 128 | + # Track execution order |
| 129 | + execution_order: List[int] = [] |
| 130 | + |
| 131 | + def callback1(fut: Future[list[torch.Tensor]]) -> List[torch.Tensor]: |
| 132 | + nonlocal tensor |
| 133 | + execution_order.append(1) |
| 134 | + tensor.add_(1) |
| 135 | + return [tensor] |
| 136 | + |
| 137 | + def callback2(fut: Future[list[torch.Tensor]]) -> List[torch.Tensor]: |
| 138 | + nonlocal tensor |
| 139 | + execution_order.append(2) |
| 140 | + tensor.add_(2) |
| 141 | + return [tensor] |
| 142 | + |
| 143 | + def callback3(fut: Future[list[torch.Tensor]]) -> List[torch.Tensor]: |
| 144 | + nonlocal tensor |
| 145 | + execution_order.append(3) |
| 146 | + tensor.add_(3) |
| 147 | + return [tensor] |
| 148 | + |
| 149 | + # Add callbacks |
| 150 | + fut = managed_work.get_future() |
| 151 | + fut = cast(Future[list[torch.Tensor]], fut) |
| 152 | + fut = fut.then(callback1) |
| 153 | + fut = fut.then(callback2) |
| 154 | + fut = fut.then(callback3) |
| 155 | + |
| 156 | + # Verify no callbacks have executed yet |
| 157 | + self.assertEqual(len(execution_order), 0) |
| 158 | + self.assertEqual(tensor.item(), 1.0) |
| 159 | + |
| 160 | + # Call wait() which should trigger the callbacks |
| 161 | + managed_work.wait() |
| 162 | + |
| 163 | + # Verify callbacks executed in order |
| 164 | + self.assertEqual(execution_order, [1, 2, 3]) |
| 165 | + |
| 166 | + # Each callback adds to the tensor, so final value should be 1 + 1 + 2 + 3 = 7 |
| 167 | + self.assertEqual(tensor.item(), 7.0) |
| 168 | + |
| 169 | + @parameterized.parameterized.expand( |
| 170 | + [ |
| 171 | + ("cpu", torch.device("cpu")), |
| 172 | + ("cuda", torch.device("cuda:0")), |
| 173 | + ] |
| 174 | + ) |
| 175 | + def test_future_then_api(self, name: str, device: torch.device) -> None: |
| 176 | + """Test that the future's then API works correctly with ManagedWork.""" |
| 177 | + # Skip if CUDA is requested but not available |
| 178 | + if device.type == "cuda" and not torch.cuda.is_available(): |
| 179 | + self.skipTest("CUDA not available") |
| 180 | + |
| 181 | + # Create a tensor to work with |
| 182 | + tensor: torch.Tensor = torch.ones(1, dtype=torch.float32, device=device) |
| 183 | + |
| 184 | + # Create a simple work object |
| 185 | + work = SimpleWork([tensor]) |
| 186 | + |
| 187 | + # Create a minimal manager object with just the wrap_future method |
| 188 | + manager = Manager.__new__(Manager) # Create instance without calling __init__ |
| 189 | + manager.wrap_future = types.MethodType( # type: ignore |
| 190 | + lambda self, fut, default, timeout=None: fut, manager |
| 191 | + ) |
| 192 | + |
| 193 | + # Create the managed work |
| 194 | + managed_work = _ManagedWork(manager, work, [tensor]) |
| 195 | + |
| 196 | + # Get the future |
| 197 | + future = managed_work.get_future() |
| 198 | + |
| 199 | + # Track callback execution |
| 200 | + callback_executed: bool = False |
| 201 | + |
| 202 | + def callback(fut: Future[object]) -> List[torch.Tensor]: |
| 203 | + # Cast to the expected type |
| 204 | + nonlocal callback_executed, tensor |
| 205 | + callback_executed = True |
| 206 | + # Multiply tensor by 3 to verify the callback ran |
| 207 | + tensor.mul_(3) |
| 208 | + return [tensor] |
| 209 | + |
| 210 | + # Use the then API |
| 211 | + future = future.then(callback) |
| 212 | + |
| 213 | + # Verify callback hasn't executed yet |
| 214 | + self.assertFalse(callback_executed) |
| 215 | + self.assertEqual(tensor.item(), 1.0) |
| 216 | + |
| 217 | + # Call wait() on the managed_work first to set up the future properly |
| 218 | + managed_work.wait() |
| 219 | + |
| 220 | + # Verify callback has executed |
| 221 | + self.assertTrue(callback_executed) |
| 222 | + self.assertEqual(tensor.item(), 3.0) |
| 223 | + |
| 224 | + @parameterized.parameterized.expand( |
| 225 | + [ |
| 226 | + ("cpu", torch.device("cpu")), |
| 227 | + ("cuda", torch.device("cuda:0")), |
| 228 | + ] |
| 229 | + ) |
| 230 | + def test_callbacks_changing_return_types( |
| 231 | + self, name: str, device: torch.device |
| 232 | + ) -> None: |
| 233 | + """ |
| 234 | + Test that callbacks can change return types and that tensors are modified in-place. |
| 235 | + This test demonstrates: |
| 236 | + 1. Callbacks changing return types (List[Tensor] -> Dict -> Tuple) |
| 237 | + 2. Using Future.value() instead of nonlocal |
| 238 | + 3. Verifying tensors are modified in-place for both approaches |
| 239 | + """ |
| 240 | + # Skip if CUDA is requested but not available |
| 241 | + if device.type == "cuda" and not torch.cuda.is_available(): |
| 242 | + self.skipTest("CUDA not available") |
| 243 | + |
| 244 | + # Create tensors to work with |
| 245 | + tensor1: torch.Tensor = torch.ones(1, dtype=torch.float32, device=device) |
| 246 | + tensor2: torch.Tensor = torch.ones(1, dtype=torch.float32, device=device) * 2 |
| 247 | + |
| 248 | + # Store original tensor memory addresses to verify in-place modification |
| 249 | + tensor1_address = tensor1.data_ptr() |
| 250 | + tensor2_address = tensor2.data_ptr() |
| 251 | + |
| 252 | + # Create a simple work object |
| 253 | + work = SimpleWork([tensor1, tensor2]) |
| 254 | + |
| 255 | + # Create a minimal manager object with just the wrap_future method |
| 256 | + manager = Manager.__new__(Manager) # Create instance without calling __init__ |
| 257 | + manager.wrap_future = types.MethodType( # type: ignore |
| 258 | + lambda self, fut, default, timeout=None: fut, manager |
| 259 | + ) |
| 260 | + |
| 261 | + # Create the managed work |
| 262 | + managed_work = _ManagedWork(manager, work, [tensor1, tensor2]) |
| 263 | + |
| 264 | + # Get the future |
| 265 | + future = managed_work.get_future() |
| 266 | + future = cast(Future[List[torch.Tensor]], future) |
| 267 | + |
| 268 | + # First callback: Takes List[Tensor] and returns Dict[str, Tensor] |
| 269 | + # Uses nonlocal to modify tensor1 |
| 270 | + def callback1(fut: Future[List[torch.Tensor]]) -> Dict[str, torch.Tensor]: |
| 271 | + tensors = fut.value() |
| 272 | + nonlocal tensor1 |
| 273 | + # Modify tensor1 in-place using nonlocal |
| 274 | + tensor1.mul_(3) |
| 275 | + # Return a dictionary instead of a list |
| 276 | + return {"first": tensors[0], "second": tensors[1]} |
| 277 | + |
| 278 | + # Second callback: Takes Dict[str, Tensor] and returns Tuple[Tensor, float] |
| 279 | + # Uses Future.value() to modify tensor2 |
| 280 | + def callback2( |
| 281 | + fut: Future[Dict[str, torch.Tensor]] |
| 282 | + ) -> Tuple[torch.Tensor, float]: |
| 283 | + data = fut.value() |
| 284 | + # Modify tensor2 in-place using the value from the future |
| 285 | + data["second"].add_(5) # Should modify tensor2 in-place |
| 286 | + # Return a tuple instead of a dict |
| 287 | + return (data["second"], data["first"].item()) |
| 288 | + |
| 289 | + # Third callback: Takes Tuple[Tensor, float] and returns a single Tensor |
| 290 | + def callback3(fut: Future[Tuple[torch.Tensor, float]]) -> torch.Tensor: |
| 291 | + tensor, value = fut.value() |
| 292 | + # Create a new tensor based on the tuple values |
| 293 | + result = tensor * value |
| 294 | + return result |
| 295 | + |
| 296 | + # Chain the callbacks |
| 297 | + future = future.then(callback1) |
| 298 | + future = future.then(callback2) |
| 299 | + future = future.then(callback3) |
| 300 | + |
| 301 | + # Call wait() to trigger the callbacks |
| 302 | + managed_work.wait() |
| 303 | + |
| 304 | + # Verify tensor1 was modified in-place (using nonlocal) |
| 305 | + self.assertEqual(tensor1.item(), 3.0) # 1 * 3 = 3 |
| 306 | + self.assertEqual(tensor1.data_ptr(), tensor1_address) # Same memory address |
| 307 | + |
| 308 | + # Verify tensor2 was modified in-place (using Future.value()) |
| 309 | + self.assertEqual(tensor2.item(), 7.0) # 2 + 5 = 7 |
| 310 | + self.assertEqual(tensor2.data_ptr(), tensor2_address) # Same memory address |
| 311 | + |
| 312 | + # Get the final result from the future |
| 313 | + final_result = future.wait() |
| 314 | + |
| 315 | + # The final result should be tensor2 * tensor1.item() = 7 * 3 = 21 |
| 316 | + self.assertEqual(final_result.item(), 21.0) |
| 317 | + |
| 318 | + |
| 319 | +if __name__ == "__main__": |
| 320 | + unittest.main() |
0 commit comments