Skip to content

Commit b7b5239

Browse files
committed
process_group pipe fix
1 parent 8f021e1 commit b7b5239

File tree

1 file changed

+30
-7
lines changed

1 file changed

+30
-7
lines changed

torchft/process_group.py

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,17 @@
2323
from datetime import timedelta
2424
from multiprocessing.connection import Connection
2525
from typing import (
26-
TYPE_CHECKING,
2726
Any,
2827
Callable,
28+
cast,
2929
Dict,
3030
Generator,
3131
List,
3232
Optional,
3333
Tuple,
34+
TYPE_CHECKING,
3435
TypeVar,
3536
Union,
36-
cast,
3737
)
3838

3939
import torch
@@ -44,14 +44,14 @@
4444
# pyre-fixme[21]: no attribute ProcessGroupGloo
4545
from torch.distributed import (
4646
DeviceMesh,
47+
get_rank,
48+
init_device_mesh,
4749
PrefixStore,
4850
ProcessGroup as BaseProcessGroup,
4951
ProcessGroupGloo as BaseProcessGroupGloo,
5052
ProcessGroupNCCL as BaseProcessGroupNCCL,
5153
Store,
5254
TCPStore,
53-
get_rank,
54-
init_device_mesh,
5555
)
5656
from torch.distributed.distributed_c10d import (
5757
AllgatherOptions,
@@ -970,6 +970,26 @@ def shutdown(self) -> None:
970970
self._p.kill()
971971

972972
def configure(self, store_addr: str, rank: int, world_size: int) -> None:
973+
"""
974+
Structure
975+
+-------------------+
976+
| |
977+
| Main Process | (updates self._futures)
978+
| | <---------------
979+
+-------------------+ |
980+
| Pipe 1 |
981+
v |
982+
+-------------------+ +-------------------+
983+
| | | |
984+
| Worker Process | -> | Future Thread |
985+
| | Pipe 2 | |
986+
+-------------------+ +-------------------+
987+
988+
Main Process: Central controller, maintains self._futures.
989+
Worker Process: Handles tasks, communicates with Future Thread.
990+
Future Thread: Manages asynchronous tasks, updates self._futures.
991+
"""
992+
973993
self._world_size = world_size
974994

975995
self.shutdown()
@@ -990,7 +1010,7 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
9901010
rank,
9911011
world_size,
9921012
req_remote,
993-
future_remote,
1013+
future_local,
9941014
curr_device,
9951015
),
9961016
daemon=True,
@@ -1003,7 +1023,7 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
10031023
self._futures = {}
10041024
self._future_thread = threading.Thread(
10051025
target=self._future_handler,
1006-
args=(future_local,),
1026+
args=(future_remote,),
10071027
daemon=True,
10081028
)
10091029
self._future_thread.start()
@@ -1163,11 +1183,14 @@ def callback(fut: Future[object], metadata: _OpMetadata) -> None:
11631183

11641184
def _future_handler(self, future_pipe: _MonitoredPipe) -> None:
11651185
try:
1166-
while True:
1186+
while not self._future_thread_shutdown_flag.is_set():
11671187
try:
11681188
cmd = future_pipe.recv(timedelta(seconds=10))
11691189
except TimeoutError:
11701190
continue
1191+
# except EOFError:
1192+
# # Pipe was closed, exit the loop
1193+
# break
11711194

11721195
op_id, mode, data, event = cast(
11731196
Tuple[int, str, object, Optional[torch.cuda.Event]], cmd

0 commit comments

Comments (0)