from time import sleep, time
from ditk import logging
from ding.framework import task
from ding.utils.lock_helper import LockContext, LockContextType
from ding.utils.design_helper import SingletonMetaclass


class BarrierRuntime(metaclass=SingletonMetaclass):

    def __init__(self, node_id: int, max_world_size: int = 100):
        """
        Overview:
            'BarrierRuntime' is a singleton class. In addition, it must be initialized before the
            class 'Parallel' starts the MQ, otherwise messages sent by other nodes may be lost
            once their detection has completed. We don't have a message retransmission mechanism,
            and losing a message means deadlock.
        Arguments:
            - node_id (int): Process ID.
            - max_world_size (int, optional): The maximum total number of processes that can be
                synchronized, the default value is 100.
        """
        self.node_id = node_id
        self._has_detected = False
        self._range_len = len(str(max_world_size)) + 1

        self._barrier_epoch = 0
        self._barrier_recv_peers_buff = dict()
        self._barrier_recv_peers = dict()
        self._barrier_ack_peers = []
        self._barrier_lock = LockContext(LockContextType.THREAD_LOCK)

        self.mq_type = task.router.mq_type
        self._connected_peers = dict()
        self._connected_peers_lock = LockContext(LockContextType.THREAD_LOCK)
        self._keep_alive_daemon = False

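        # Event names used on the message queue. Assumed semantics, inferred from the
        # handlers below: "b_det" is the liveness-detection ping, "b_req" announces that a
        # peer has reached the barrier, and "b_ack" announces that a peer (and every node it
        # can see) is ready to leave it.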
        self._event_name_detect = "b_det"
        self.event_name_req = "b_req"
        self.event_name_ack = "b_ack"

    def _alive_msg_handler(self, peer_id):
        with self._connected_peers_lock:
            self._connected_peers[peer_id] = time()

    def _add_barrier_req(self, msg):
        peer, epoch = self._unpickle_barrier_tag(msg)
        logging.debug("Node:[{}] recv barrier request from node:{}, epoch:{}".format(self.node_id, peer, epoch))
        with self._barrier_lock:
            if peer not in self._barrier_recv_peers:
                self._barrier_recv_peers[peer] = []
            self._barrier_recv_peers[peer].append(epoch)

    def _add_barrier_ack(self, peer):
        logging.debug("Node:[{}] recv barrier ack from node:{}".format(self.node_id, peer))
        with self._barrier_lock:
            self._barrier_ack_peers.append(peer)

    def _unpickle_barrier_tag(self, msg):
        # Inverse of pickle_barrier_tag: the low '_range_len' decimal digits hold the
        # node id, the remaining high digits hold the epoch.
        return msg % 10 ** self._range_len, msg // 10 ** self._range_len

    def pickle_barrier_tag(self):
        # Encode (epoch, node_id) into one integer; node_id always fits in '_range_len'
        # decimal digits because _range_len = len(str(max_world_size)) + 1.
        return int(self._barrier_epoch * 10 ** self._range_len + self.node_id)
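    # Worked example of the tag encoding (a sketch, assuming max_world_size=100, hence
    # _range_len == 4): epoch=2, node_id=7 gives tag = 2 * 10**4 + 7 = 20007, and
    # unpickling yields 20007 % 10**4 == 7 (node) and 20007 // 10**4 == 2 (epoch).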

    def reset_all_peers(self):
        # Consume this epoch's request from every peer queue, clear the ACK list,
        # and advance to the next barrier epoch.
        with self._barrier_lock:
            for peer, q in self._barrier_recv_peers.items():
                if len(q) != 0:
                    assert q.pop(0) == self._barrier_epoch
            self._barrier_ack_peers = []
            self._barrier_epoch += 1

    def get_recv_num(self):
        count = 0
        with self._barrier_lock:
            if len(self._barrier_recv_peers) > 0:
                for _, q in self._barrier_recv_peers.items():
                    if len(q) > 0 and q[0] == self._barrier_epoch:
                        count += 1
        return count

    def get_ack_num(self):
        with self._barrier_lock:
            return len(self._barrier_ack_peers)

    def detect_alive(self, expected, timeout):
        # The barrier can only block other nodes within the visible range of the current node.
        # If the 'attach_to' list of a node is empty, it does not know how many nodes will attach
        # to it, so we cannot specify the effective range of a barrier in advance.
        assert task._running
        task.on(self._event_name_detect, self._alive_msg_handler)
        task.on(self.event_name_req, self._add_barrier_req)
        task.on(self.event_name_ack, self._add_barrier_ack)
        start = time()
        while True:
            sleep(0.1)
            task.emit(self._event_name_detect, self.node_id, only_remote=True)
            # In case the other nodes have not had time to receive our detect message,
            # we send one additional round after all peers have been seen.
            if self._has_detected:
                break
            with self._connected_peers_lock:
                if len(self._connected_peers) == expected:
                    self._has_detected = True

            if time() - start > timeout:
                raise TimeoutError("Node-[{}] timeout when waiting barrier! ".format(task.router.node_id))

        task.off(self._event_name_detect)
        logging.info(
            "Barrier detect node done, node-[{}] has connected with {} active nodes!".format(self.node_id, expected)
        )


class BarrierContext:
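    """
    Overview:
        Context manager used by Barrier: on entry it performs the one-off alive detection
        (only on first use per process), and on exit it resets the per-epoch barrier
        bookkeeping, printing the traceback first if an exception occurred in the block.
    """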

    def __init__(self, runtime: BarrierRuntime, detect_timeout, expected_peer_num: int = 0):
        self._runtime = runtime
        self._expected_peer_num = expected_peer_num
        self._timeout = detect_timeout

    def __enter__(self):
        if not self._runtime._has_detected:
            self._runtime.detect_alive(self._expected_peer_num, self._timeout)

    def __exit__(self, exc_type, exc_value, tb):
        if exc_type is not None:
            import traceback
            traceback.print_exception(exc_type, exc_value, tb)
        self._runtime.reset_all_peers()


class Barrier:

    def __init__(self, attch_from_nums: int, timeout: int = 60):
        """
        Overview:
            Barrier() is a middleware for debugging or profiling. It can synchronize the task
            step of each process within the scope of all visible processes. When using
            Barrier(), you need to pay attention to the following points:

            1. All processes must call Barrier() the same number of times, otherwise a deadlock
               occurs.

            2. 'attch_from_nums' is a very important argument. It indicates the number of times
               the current process will be attached to by other processes (i.e. the number of
               inbound connections that will be established).
               For example:
                   Node0: address: 127.0.0.1:12345, attach_to = []
                   Node1: address: 127.0.0.1:12346, attach_to = ["tcp://127.0.0.1:12345"]
               For Node0, the 'attch_from_nums' value is 1 (it will be attached by Node1).
               For Node1, the 'attch_from_nums' value is 0 (no one will attach to Node1).
               Please note that this value must be given correctly; otherwise, a node whose
               'attach_to' list is empty cannot perceive how many processes will establish
               connections with it, and no form of synchronization can be performed.
               A usage sketch is given at the end of this file.

            3. Barrier() is thread-safe, but using a barrier in multithreaded code is not
               recommended: you need to carefully count the number of times each thread calls
               Barrier() to avoid deadlock.

            4. In normal training tasks, please do not use Barrier(): it forces step
               synchronization between the processes and therefore greatly damages training
               efficiency. In addition, if your training task has dynamic processes, do not
               use Barrier(), to prevent deadlock.

        Arguments:
            - attch_from_nums (int): The number of times the current process will be attached
                to by other processes, i.e. the number of inbound connections.
            - timeout (int, optional): The timeout for successful detection of the
                'expected_peer_num' nodes, the default value is 60 seconds.
        """
        self.node_id = task.router.node_id
        self.timeout = timeout
        self._runtime: BarrierRuntime = task.router.barrier_runtime
        self._barrier_peers_nums = task.get_attch_to_len() + attch_from_nums

        logging.info(
            "Node:[{}], attach to num is:{}, attach from num is:{}".format(
                self.node_id, task.get_attch_to_len(), attch_from_nums
            )
        )

    def __call__(self, ctx):
        # Synchronize twice per task step: once before and once after the inner
        # middleware runs (the code around 'yield' in DI-engine's middleware protocol).
        self._wait_barrier(ctx)
        yield
        self._wait_barrier(ctx)

    def _wait_barrier(self, ctx):
        self_ready = False
        with BarrierContext(self._runtime, self.timeout, self._barrier_peers_nums):
            logging.debug("Node:[{}] enter barrier".format(self.node_id))
            # Step1: Notify all the attached nodes that we have reached the barrier.
            task.emit(self._runtime.event_name_req, self._runtime.pickle_barrier_tag(), only_remote=True)
            logging.debug("Node:[{}] sent barrier request".format(self.node_id))

            # Step2: Check how many barrier requests we have received.
            # In the current CI design of DI-engine, there will always be a node whose 'attach_to'
            # list is empty, so there will always be a node that sends the ACK unconditionally,
            # and deadlock will not occur.
            if self._runtime.get_recv_num() == self._barrier_peers_nums:
                self_ready = True

            # Step3: Wait until we ourselves are ready.
            # Even if the current process has reached the barrier, we do not send an ack
            # immediately; we need to wait for the slowest directly or indirectly connected
            # peer to reach the barrier.
            start = time()
            if not self_ready:
                while True:
                    if time() - start > self.timeout:
                        raise TimeoutError("Node-[{}] timeout when waiting barrier! ".format(task.router.node_id))

                    if self._runtime.get_recv_num() != self._barrier_peers_nums:
                        sleep(0.1)
                    else:
                        break

            # Step4: Notify all attached nodes that we are ready.
            task.emit(self._runtime.event_name_ack, self.node_id, only_remote=True)
            logging.debug("Node:[{}] sent barrier ack".format(self.node_id))

            # Step5: Wait until all directly or indirectly connected nodes are ready.
            start = time()
            while True:
                if time() - start > self.timeout:
                    raise TimeoutError("Node-[{}] timeout when waiting barrier! ".format(task.router.node_id))

                if self._runtime.get_ack_num() != self._barrier_peers_nums:
                    sleep(0.1)
                else:
                    break

            logging.info("Node-[{}] env_step:[{}] barrier finish".format(self.node_id, ctx.env_step))
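

# A minimal usage sketch, not part of the original file. It mirrors the Node0/Node1 example
# from the Barrier docstring; 'my_step' is a hypothetical stand-in middleware, and the
# two-worker mesh topology below is an assumption about Parallel.runner's defaults.
#
#     from ding.framework import Parallel, task
#
#     def my_step(ctx):
#         pass  # stand-in for real collect/train middleware
#
#     def main():
#         with task.start():
#             # Node0 (attach_to=[]) is attached by Node1, hence attch_from_nums=1;
#             # Node1 (attach_to=["tcp://..."]) is attached by nobody, hence 0.
#             task.use(Barrier(attch_from_nums=1 if task.router.node_id == 0 else 0))
#             task.use(my_step)
#             task.run(max_step=10)
#
#     if __name__ == "__main__":
#         Parallel.runner(n_parallel_workers=2, protocol="tcp")(main)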