
Commit dd00ebf

feature(wgt): add barrier middleware (opendilab#570)
* feature(wgt): Add barrier middleware
* fix trainer.py:multistep_trainer:policy.forward args bug
* Modify codecov concurrency mode from threading to multiprocessing
* polish barrier middleware
1 parent 3447f57 commit dd00ebf

File tree

6 files changed, +406 -0 lines changed

.coveragerc

+1
@@ -1,4 +1,5 @@
 [run]
+concurrency = multiprocessing,thread
 omit =
     ding/utils/slurm_helper.py
     ding/utils/file_helper.py

ding/framework/middleware/__init__.py

+1
@@ -3,3 +3,4 @@
 from .learner import OffPolicyLearner, HERLearner
 from .ckpt_handler import CkptSaver
 from .distributer import ContextExchanger, ModelExchanger
+from .barrier import Barrier, BarrierRuntime
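
With this export in place, user code can pull the new middleware straight from the package root; for example (a usage note, not part of the diff):

from ding.framework.middleware import Barrier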

ding/framework/middleware/barrier.py

+227
@@ -0,0 +1,227 @@
from time import sleep, time
from ditk import logging
from ding.framework import task
from ding.utils.lock_helper import LockContext, LockContextType
from ding.utils.design_helper import SingletonMetaclass


class BarrierRuntime(metaclass=SingletonMetaclass):

    def __init__(self, node_id: int, max_world_size: int = 100):
        """
        Overview:
            'BarrierRuntime' is a singleton class. In addition, it must be initialized before the
            class 'Parallel' starts the MQ, otherwise the messages sent by other nodes may be lost
            after the detection is completed. We don't have a message retransmission mechanism,
            and losing a message means deadlock.
        Arguments:
            - node_id (int): Process ID.
            - max_world_size (int, optional): The maximum total number of processes that can be
                synchronized, the default value is 100.
        """
        self.node_id = node_id
        self._has_detected = False
        self._range_len = len(str(max_world_size)) + 1

        self._barrier_epoch = 0
        self._barrier_recv_peers_buff = dict()
        self._barrier_recv_peers = dict()
        self._barrier_ack_peers = []
        self._barrier_lock = LockContext(LockContextType.THREAD_LOCK)

        self.mq_type = task.router.mq_type
        self._connected_peers = dict()
        self._connected_peers_lock = LockContext(LockContextType.THREAD_LOCK)
        self._keep_alive_daemon = False

        self._event_name_detect = "b_det"
        self.event_name_req = "b_req"
        self.event_name_ack = "b_ack"

    def _alive_msg_handler(self, peer_id):
        with self._connected_peers_lock:
            self._connected_peers[peer_id] = time()

    def _add_barrier_req(self, msg):
        peer, epoch = self._unpickle_barrier_tag(msg)
        logging.debug("Node:[{}] recv barrier request from node:{}, epoch:{}".format(self.node_id, peer, epoch))
        with self._barrier_lock:
            if peer not in self._barrier_recv_peers:
                self._barrier_recv_peers[peer] = []
            self._barrier_recv_peers[peer].append(epoch)

    def _add_barrier_ack(self, peer):
        logging.debug("Node:[{}] recv barrier ack from node:{}".format(self.node_id, peer))
        with self._barrier_lock:
            self._barrier_ack_peers.append(peer)

    def _unpickle_barrier_tag(self, msg):
        return msg % self._range_len, msg // self._range_len

    def pickle_barrier_tag(self):
        return int(self._barrier_epoch * self._range_len + self.node_id)

    def reset_all_peers(self):
        with self._barrier_lock:
            for peer, q in self._barrier_recv_peers.items():
                if len(q) != 0:
                    assert q.pop(0) == self._barrier_epoch
            self._barrier_ack_peers = []
            self._barrier_epoch += 1

    def get_recv_num(self):
        count = 0
        with self._barrier_lock:
            if len(self._barrier_recv_peers) > 0:
                for _, q in self._barrier_recv_peers.items():
                    if len(q) > 0 and q[0] == self._barrier_epoch:
                        count += 1
        return count

    def get_ack_num(self):
        with self._barrier_lock:
            return len(self._barrier_ack_peers)

    def detect_alive(self, expected, timeout):
        # The barrier can only block other nodes within the visible range of the current node.
        # If the 'attach_to' list of a node is empty, it does not know how many nodes will attach to it,
        # so we cannot specify the effective range of a barrier in advance.
        assert task._running
        task.on(self._event_name_detect, self._alive_msg_handler)
        task.on(self.event_name_req, self._add_barrier_req)
        task.on(self.event_name_ack, self._add_barrier_ack)
        start = time()
        while True:
            sleep(0.1)
            task.emit(self._event_name_detect, self.node_id, only_remote=True)
            # In case the other node has not had time to receive our detect message,
            # we will send an additional round.
            if self._has_detected:
                break
            with self._connected_peers_lock:
                if len(self._connected_peers) == expected:
                    self._has_detected = True

            if time() - start > timeout:
                raise TimeoutError("Node-[{}] timeout when waiting barrier! ".format(task.router.node_id))

        task.off(self._event_name_detect)
        logging.info(
            "Barrier detect node done, node-[{}] has connected with {} active nodes!".format(self.node_id, expected)
        )


class BarrierContext:

    def __init__(self, runtime: BarrierRuntime, detect_timeout, expected_peer_num: int = 0):
        self._runtime = runtime
        self._expected_peer_num = expected_peer_num
        self._timeout = detect_timeout

    def __enter__(self):
        if not self._runtime._has_detected:
            self._runtime.detect_alive(self._expected_peer_num, self._timeout)

    def __exit__(self, exc_type, exc_value, tb):
        if exc_type is not None:
            import traceback
            traceback.print_exception(exc_type, exc_value, tb)
        self._runtime.reset_all_peers()


class Barrier:

    def __init__(self, attch_from_nums: int, timeout: int = 60):
        """
        Overview:
            Barrier() is a middleware for debug or profiling. It can synchronize the task step of each
            process within the scope of all visible processes. When using Barrier(), you need to pay
            attention to the following points:

            1. All processes must call Barrier() the same number of times, otherwise a deadlock occurs.

            2. 'attch_from_nums' is a very important variable. This value indicates the number of times
               the current process will be attached to by other processes (the number of connections
               established).
               For example:
                   Node0: address: 127.0.0.1:12345, attach_to = []
                   Node1: address: 127.0.0.1:12346, attach_to = ["tcp://127.0.0.1:12345"]
               For Node0, the 'attch_from_nums' value is 1. (It will be attached by Node1.)
               For Node1, the 'attch_from_nums' value is 0. (No one will attach to Node1.)
               Please note that this value must be given correctly, otherwise a node whose 'attach_to'
               list is empty cannot perceive how many processes will establish connections with it,
               and no form of synchronization can be performed.

            3. Barrier() is thread-safe, but it is not recommended to use barrier in multithreading. You need
               to carefully calculate the number of times each thread calls Barrier() to avoid deadlock.

            4. In normal training tasks, please do not use Barrier(): it forces step synchronization
               between processes and therefore greatly damages training efficiency. In addition, if your
               training task has dynamic processes, do not use Barrier(), to prevent deadlock.

        Arguments:
            - attch_from_nums (int): The number of processes that will attach to the current process.
            - timeout (int, optional): The timeout for successful detection of 'expected_peer_num'
                number of nodes, the default value is 60 seconds.
        """
        self.node_id = task.router.node_id
        self.timeout = timeout
        self._runtime: BarrierRuntime = task.router.barrier_runtime
        self._barrier_peers_nums = task.get_attch_to_len() + attch_from_nums

        logging.info(
            "Node:[{}], attach to num is:{}, attach from num is:{}".format(
                self.node_id, task.get_attch_to_len(), attch_from_nums
            )
        )

    def __call__(self, ctx):
        self._wait_barrier(ctx)
        yield
        self._wait_barrier(ctx)

    def _wait_barrier(self, ctx):
        self_ready = False
        with BarrierContext(self._runtime, self.timeout, self._barrier_peers_nums):
            logging.debug("Node:[{}] enter barrier".format(self.node_id))
            # Step1: Notify all the attached nodes that we have reached the barrier.
            task.emit(self._runtime.event_name_req, self._runtime.pickle_barrier_tag(), only_remote=True)
            logging.debug("Node:[{}] sent barrier request".format(self.node_id))

            # Step2: Check the number of flags we have received.
            # In the current CI design of DI-engine, there will always be a node whose 'attach_to' list is empty,
            # so there will always be a node that will send ACK unconditionally, so deadlock will not occur.
            if self._runtime.get_recv_num() == self._barrier_peers_nums:
                self_ready = True

            # Step3: Wait until we ourselves are ready.
            # Even if the current process has reached the barrier, we will not send an ack immediately,
            # we need to wait for the slowest directly connected or indirectly connected peer to
            # reach the barrier.
            start = time()
            if not self_ready:
                while True:
                    if time() - start > self.timeout:
                        raise TimeoutError("Node-[{}] timeout when waiting barrier! ".format(task.router.node_id))

                    if self._runtime.get_recv_num() != self._barrier_peers_nums:
                        sleep(0.1)
                    else:
                        break

            # Step4: Notify all attached nodes that we are ready.
            task.emit(self._runtime.event_name_ack, self.node_id, only_remote=True)
            logging.debug("Node:[{}] sent barrier ack".format(self.node_id))

            # Step5: Wait until all directly or indirectly connected nodes are ready.
            start = time()
            while True:
                if time() - start > self.timeout:
                    raise TimeoutError("Node-[{}] timeout when waiting barrier! ".format(task.router.node_id))

                if self._runtime.get_ack_num() != self._barrier_peers_nums:
                    sleep(0.1)
                else:
                    break

            logging.info("Node-[{}] env_step:[{}] barrier finish".format(self.node_id, ctx.env_step))
@@ -0,0 +1,144 @@
import random
import time
import socket
import pytest
import multiprocessing as mp
from ditk import logging
from ding.framework import task
from ding.framework.parallel import Parallel
from ding.framework.context import OnlineRLContext
from ding.framework.middleware.barrier import Barrier

PORTS_LIST = ["1235", "1236", "1237"]


class EnvStepMiddleware:

    def __call__(self, ctx):
        yield
        ctx.env_step += 1


class SleepMiddleware:

    def __init__(self, node_id):
        self.node_id = node_id

    def random_sleep(self, direction, step):
        random.seed(self.node_id + step)
        sleep_second = random.randint(1, 5)
        logging.info("Node:[{}] env_step:[{}]-{} will sleep:{}s".format(self.node_id, step, direction, sleep_second))
        for i in range(sleep_second):
            time.sleep(1)
            print("Node:[{}] sleeping...".format(self.node_id))
        logging.info("Node:[{}] env_step:[{}]-{} wake up!".format(self.node_id, step, direction))

    def __call__(self, ctx):
        self.random_sleep("forward", ctx.env_step)
        yield
        self.random_sleep("backward", ctx.env_step)


def star_barrier():
    with task.start(ctx=OnlineRLContext()):
        node_id = task.router.node_id
        if node_id == 0:
            attch_from_nums = 3
        else:
            attch_from_nums = 0
        barrier = Barrier(attch_from_nums)
        task.use(barrier, lock=False)
        task.use(SleepMiddleware(node_id), lock=False)
        task.use(barrier, lock=False)
        task.use(EnvStepMiddleware(), lock=False)
        try:
            task.run(2)
        except Exception as e:
            logging.error(e)
            assert False


def mesh_barrier():
    with task.start(ctx=OnlineRLContext()):
        node_id = task.router.node_id
        attch_from_nums = 3 - task.router.node_id
        barrier = Barrier(attch_from_nums)
        task.use(barrier, lock=False)
        task.use(SleepMiddleware(node_id), lock=False)
        task.use(barrier, lock=False)
        task.use(EnvStepMiddleware(), lock=False)
        try:
            task.run(2)
        except Exception as e:
            logging.error(e)
            assert False


def unmatch_barrier():
    with task.start(ctx=OnlineRLContext()):
        node_id = task.router.node_id
        attch_from_nums = 3 - task.router.node_id
        task.use(Barrier(attch_from_nums, 5), lock=False)
        if node_id != 2:
            task.use(Barrier(attch_from_nums, 5), lock=False)
        try:
            task.run(2)
        except TimeoutError as e:
            assert node_id != 2
            logging.info("Node:[{}] timeout with barrier".format(node_id))
        else:
            time.sleep(5)
            assert node_id == 2
            logging.info("Node:[{}] finish barrier".format(node_id))


def launch_barrier(args):
    i, topo, fn, test_id = args
    address = socket.gethostbyname(socket.gethostname())
    topology = "alone"
    attach_to = []
    port_base = PORTS_LIST[test_id]
    port = port_base + str(i)
    if topo == 'star':
        if i != 0:
            attach_to = ['tcp://{}:{}{}'.format(address, port_base, 0)]
    elif topo == 'mesh':
        for j in range(i):
            attach_to.append('tcp://{}:{}{}'.format(address, port_base, j))

    Parallel.runner(
        node_ids=i,
        ports=int(port),
        attach_to=attach_to,
        topology=topology,
        protocol="tcp",
        n_parallel_workers=1,
        startup_interval=0
    )(fn)


@pytest.mark.unittest
def test_star_topology_barrier():
    ctx = mp.get_context("spawn")
    with ctx.Pool(processes=4) as pool:
        pool.map(launch_barrier, [[i, 'star', star_barrier, 0] for i in range(4)])
        pool.close()
        pool.join()


@pytest.mark.unittest
def test_mesh_topology_barrier():
    ctx = mp.get_context("spawn")
    with ctx.Pool(processes=4) as pool:
        pool.map(launch_barrier, [[i, 'mesh', mesh_barrier, 1] for i in range(4)])
        pool.close()
        pool.join()


@pytest.mark.unittest
def test_unmatch_barrier():
    ctx = mp.get_context("spawn")
    with ctx.Pool(processes=4) as pool:
        pool.map(launch_barrier, [[i, 'mesh', unmatch_barrier, 2] for i in range(4)])
        pool.close()
        pool.join()
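
The new tests are gated behind the 'unittest' pytest marker. As a rough sketch of running them programmatically, and with the caveat that this commit view does not show where the new test file lives, so the path below is only an assumed location:

import pytest

# Hypothetical path: the commit view does not show the new test file's location.
pytest.main(["-sv", "-m", "unittest", "ding/framework/middleware/tests/test_barrier.py"])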
