23
23
from datetime import timedelta
24
24
from multiprocessing .connection import Connection
25
25
from typing import (
26
- TYPE_CHECKING ,
27
26
Any ,
28
27
Callable ,
28
+ cast ,
29
29
Dict ,
30
30
Generator ,
31
31
List ,
32
32
Optional ,
33
33
Tuple ,
34
+ TYPE_CHECKING ,
34
35
TypeVar ,
35
36
Union ,
36
- cast ,
37
37
)
38
38
39
39
import torch
44
44
# pyre-fixme[21]: no attribute ProcessGroupGloo
45
45
from torch .distributed import (
46
46
DeviceMesh ,
47
+ get_rank ,
48
+ init_device_mesh ,
47
49
PrefixStore ,
48
50
ProcessGroup as BaseProcessGroup ,
49
51
ProcessGroupGloo as BaseProcessGroupGloo ,
50
52
ProcessGroupNCCL as BaseProcessGroupNCCL ,
51
53
Store ,
52
54
TCPStore ,
53
- get_rank ,
54
- init_device_mesh ,
55
55
)
56
56
from torch .distributed .distributed_c10d import (
57
57
AllgatherOptions ,
@@ -970,6 +970,26 @@ def shutdown(self) -> None:
970
970
self ._p .kill ()
971
971
972
972
def configure (self , store_addr : str , rank : int , world_size : int ) -> None :
973
+ """
974
+ Structure
975
+ +-------------------+
976
+ | |
977
+ | Main Process | (updates self._futures)
978
+ | | <---------------
979
+ +-------------------+ |
980
+ | Pipe 1 |
981
+ v |
982
+ +-------------------+ +-------------------+
983
+ | | | |
984
+ | Worker Process | -> | Future Thread |
985
+ | | Pipe 2 | |
986
+ +-------------------+ +-------------------+
987
+
988
+ Main Process: Central controller, maintains self._futures.
989
+ Worker Process: Handles tasks, communicates with Future Thread.
990
+ Future Thread: Manages asynchronous tasks, updates self._futures.
991
+ """
992
+
973
993
self ._world_size = world_size
974
994
975
995
self .shutdown ()
@@ -990,7 +1010,7 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
990
1010
rank ,
991
1011
world_size ,
992
1012
req_remote ,
993
- future_remote ,
1013
+ future_local ,
994
1014
curr_device ,
995
1015
),
996
1016
daemon = True ,
@@ -1003,7 +1023,7 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
1003
1023
self ._futures = {}
1004
1024
self ._future_thread = threading .Thread (
1005
1025
target = self ._future_handler ,
1006
- args = (future_local ,),
1026
+ args = (future_remote ,),
1007
1027
daemon = True ,
1008
1028
)
1009
1029
self ._future_thread .start ()
@@ -1163,11 +1183,14 @@ def callback(fut: Future[object], metadata: _OpMetadata) -> None:
1163
1183
1164
1184
def _future_handler (self , future_pipe : _MonitoredPipe ) -> None :
1165
1185
try :
1166
- while True :
1186
+ while not self . _future_thread_shutdown_flag . is_set () :
1167
1187
try :
1168
1188
cmd = future_pipe .recv (timedelta (seconds = 10 ))
1169
1189
except TimeoutError :
1170
1190
continue
1191
+ # except EOFError:
1192
+ # # Pipe was closed, exit the loop
1193
+ # break
1171
1194
1172
1195
op_id , mode , data , event = cast (
1173
1196
Tuple [int , str , object , Optional [torch .cuda .Event ]], cmd
0 commit comments