diff --git a/sc2learner/agents/ppo_agent.py b/sc2learner/agents/ppo_agent.py index 3811f1d..bd273ad 100644 --- a/sc2learner/agents/ppo_agent.py +++ b/sc2learner/agents/ppo_agent.py @@ -368,12 +368,20 @@ def _pull_data(self, zmq_context, data_queue, episode_infos, unroll_split, while True: data = receiver.recv_pyobj() if unroll_split > 1: - data_queue.extend(list(zip(*( - [list(zip(*transform_tuple( - data[0], lambda x: np.split(x, unroll_split))))] + \ - [np.split(arr, unroll_split) for arr in data[1:-2]] + \ - [[data[-2] for _ in range(unroll_split)]] - )))) + if isinstance(data[0], tuple): # use_action_mask + data_queue.extend(list(zip(*( + [list(zip(*transform_tuple( + data[0], lambda x: np.split(x, unroll_split))))] + \ + [np.split(arr, unroll_split) for arr in data[1:-2]] + \ + [[data[-2] for _ in range(unroll_split)]] + )))) + else: # nouse_action_mask + data_queue.extend(list(zip(*( + [transform_tuple( + data[0], lambda x: np.split(x, unroll_split))] + \ + [np.split(arr, unroll_split) for arr in data[1:-2]] + \ + [[data[-2] for _ in range(unroll_split)]] + )))) else: data_queue.append(data[:-1]) episode_infos.extend(data[-1])