From ec0ba9dc90c98d3704133cc523feeaa38b78924c Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Thu, 1 Jun 2023 00:21:28 +0800 Subject: [PATCH 01/54] feature(yzj): adapt multi agent env gobigger with ez --- lzero/entry/__init__.py | 1 + lzero/entry/train_muzero_gobigger.py | 166 ++++ lzero/mcts/buffer/__init__.py | 1 + .../buffer/gobigger_game_buffer_muzero.py | 703 ++++++++++++++++ .../gobigger_efficientzero_model_mlp.py | 480 +++++++++++ lzero/model/gobigger/gobigger_model.py | 314 +++++++ lzero/model/gobigger/network/__init__.py | 8 + lzero/model/gobigger/network/activation.py | 96 +++ lzero/model/gobigger/network/encoder.py | 136 +++ lzero/model/gobigger/network/nn_module.py | 235 ++++++ lzero/model/gobigger/network/normalization.py | 36 + lzero/model/gobigger/network/res_block.py | 231 +++++ lzero/model/gobigger/network/rnn.py | 276 ++++++ .../gobigger/network/scatter_connection.py | 107 +++ lzero/model/gobigger/network/soft_argmax.py | 60 ++ lzero/model/gobigger/network/transformer.py | 397 +++++++++ lzero/policy/gobigger_efficientzero.py | 787 ++++++++++++++++++ lzero/worker/__init__.py | 2 + lzero/worker/gobigger_muzero_collector.py | 658 +++++++++++++++ lzero/worker/gobigger_muzero_evaluator.py | 434 ++++++++++ .../config/gobigger_efficientzero_config.py | 141 ++++ zoo/gobigger/env/gobigger_env.py | 507 +++++++++++ 22 files changed, 5776 insertions(+) create mode 100644 lzero/entry/train_muzero_gobigger.py create mode 100644 lzero/mcts/buffer/gobigger_game_buffer_muzero.py create mode 100644 lzero/model/gobigger/gobigger_efficientzero_model_mlp.py create mode 100644 lzero/model/gobigger/gobigger_model.py create mode 100644 lzero/model/gobigger/network/__init__.py create mode 100644 lzero/model/gobigger/network/activation.py create mode 100644 lzero/model/gobigger/network/encoder.py create mode 100644 lzero/model/gobigger/network/nn_module.py create mode 100644 lzero/model/gobigger/network/normalization.py create mode 100644 lzero/model/gobigger/network/res_block.py create mode 100644 lzero/model/gobigger/network/rnn.py create mode 100644 lzero/model/gobigger/network/scatter_connection.py create mode 100644 lzero/model/gobigger/network/soft_argmax.py create mode 100644 lzero/model/gobigger/network/transformer.py create mode 100644 lzero/policy/gobigger_efficientzero.py create mode 100644 lzero/worker/gobigger_muzero_collector.py create mode 100644 lzero/worker/gobigger_muzero_evaluator.py create mode 100644 zoo/gobigger/config/gobigger_efficientzero_config.py create mode 100644 zoo/gobigger/env/gobigger_env.py diff --git a/lzero/entry/__init__.py b/lzero/entry/__init__.py index 352d29ddf..60ac5f42b 100644 --- a/lzero/entry/__init__.py +++ b/lzero/entry/__init__.py @@ -4,3 +4,4 @@ from .eval_muzero import eval_muzero from .eval_muzero_with_gym_env import eval_muzero_with_gym_env from .train_muzero_with_gym_env import train_muzero_with_gym_env +from .train_muzero_gobigger import train_muzero_gobigger diff --git a/lzero/entry/train_muzero_gobigger.py b/lzero/entry/train_muzero_gobigger.py new file mode 100644 index 000000000..4f44ed49d --- /dev/null +++ b/lzero/entry/train_muzero_gobigger.py @@ -0,0 +1,166 @@ +import logging +import os +from functools import partial +from typing import Optional, Tuple + +import torch +from ding.config import compile_config +from ding.envs import create_env_manager +from ding.envs import get_vec_env_setting +from ding.policy import create_policy +from ding.utils import set_pkg_seed +from ding.worker import BaseLearner +from tensorboardX import 
SummaryWriter + +from lzero.entry.utils import log_buffer_memory_usage +from lzero.policy import visit_count_temperature +from lzero.worker import GoBiggerMuZeroCollector, GoBiggerMuZeroEvaluator +from gobigger.agents import BotAgent + +def train_muzero_gobigger( + input_cfg: Tuple[dict, dict], + seed: int = 0, + model: Optional[torch.nn.Module] = None, + model_path: Optional[str] = None, + max_train_iter: Optional[int] = int(1e10), + max_env_step: Optional[int] = int(1e10), +) -> 'Policy': # noqa + """ + Overview: + The train entry for MCTS+RL algorithms, including MuZero, EfficientZero, Sampled EfficientZero. + Arguments: + - input_cfg (:obj:`Tuple[dict, dict]`): Config in dict type. + ``Tuple[dict, dict]`` type means [user_config, create_cfg]. + - seed (:obj:`int`): Random seed. + - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module. + - model_path (:obj:`Optional[str]`): The pretrained model path, which should + point to the ckpt file of the pretrained model, and an absolute path is recommended. + In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``. + - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training. + - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps. + Returns: + - policy (:obj:`Policy`): Converged policy. + """ + + cfg, create_cfg = input_cfg + assert create_cfg.policy.type in ['efficientzero', 'muzero', 'sampled_efficientzero', 'gobigger_efficientzero'], \ + "train_muzero entry now only support the following algo.: 'efficientzero', 'muzero', 'sampled_efficientzero'" + + if create_cfg.policy.type == 'muzero': + from lzero.mcts import MuZeroGameBuffer as GameBuffer + elif create_cfg.policy.type == 'efficientzero': + from lzero.mcts import EfficientZeroGameBuffer as GameBuffer + elif create_cfg.policy.type == 'sampled_efficientzero': + from lzero.mcts import SampledEfficientZeroGameBuffer as GameBuffer + elif create_cfg.policy.type == 'gobigger_efficientzero': + from lzero.mcts import GoBiggerMuZeroGameBuffer as GameBuffer + + if cfg.policy.cuda and torch.cuda.is_available(): + cfg.policy.device = 'cuda' + else: + cfg.policy.device = 'cpu' + + cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) + # Create main components: env, policy + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) + collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) + evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + + collector_env.seed(cfg.seed) + evaluator_env.seed(cfg.seed, dynamic_seed=False) + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + + policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) + + # load pretrained model + if model_path is not None: + policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device)) + + # Create worker components: learner, collector, evaluator, replay buffer, commander. 
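+    # The TensorBoard logger created here is shared by the learner, collector and evaluator constructed below.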
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) + learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) + + # ============================================================== + # MCTS+RL algorithms related core code + # ============================================================== + policy_config = cfg.policy + batch_size = policy_config.batch_size + # specific game buffer for MCTS+RL algorithms + replay_buffer = GameBuffer(policy_config) + collector = GoBiggerMuZeroCollector( + collect_print_freq=cfg.collect.collector.collect_print_freq, + env=collector_env, + policy=policy.collect_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config + ) + evaluator = GoBiggerMuZeroEvaluator( + eval_freq=cfg.policy.eval_freq, + n_evaluator_episode=cfg.env.n_evaluator_episode, + stop_value=cfg.env.stop_value, + env=evaluator_env, + policy=policy.eval_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config + ) + + # ============================================================== + # Main loop + # ============================================================== + # Learner's before_run hook. + learner.call_hook('before_run') + while True: + log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger) + collect_kwargs = {} + # set temperature for visit count distributions according to the train_iter, + # please refer to Appendix D in MuZero paper for details. + collect_kwargs['temperature'] = visit_count_temperature( + policy_config.manual_temperature_decay, + policy_config.fixed_temperature_value, + policy_config.threshold_training_steps_for_final_temperature, + trained_steps=learner.train_iter + ) + + stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) + + # Evaluate policy performance. + if evaluator.should_eval(learner.train_iter): + stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) + if stop: + break + + # Collect data by default config n_sample/n_episode. + new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) + # save returned new_data collected by the collector + replay_buffer.push_game_segments(new_data) + # remove the oldest data if the replay buffer is full. + replay_buffer.remove_oldest_data_to_fit() + + # Learn policy from collected data. + for i in range(cfg.policy.update_per_collect): + # Learner will train ``update_per_collect`` times in one iteration. + if replay_buffer.get_num_of_transitions() > batch_size: + train_data = replay_buffer.sample(batch_size, policy) + else: + logging.warning( + f'The data in replay_buffer is not sufficient to sample a mini-batch: ' + f'batch_size: {batch_size}, ' + f'{replay_buffer} ' + f'continue to collect now ....' + ) + break + + # The core train steps for MCTS+RL algorithms. + log_vars = learner.train(train_data, collector.envstep) + + if cfg.policy.use_priority: + replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig']) + + if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter: + break + + # Learner's after_run hook. 
+ learner.call_hook('after_run') + return policy diff --git a/lzero/mcts/buffer/__init__.py b/lzero/mcts/buffer/__init__.py index 80d0cd87b..5be32de5c 100644 --- a/lzero/mcts/buffer/__init__.py +++ b/lzero/mcts/buffer/__init__.py @@ -1,3 +1,4 @@ from .game_buffer_muzero import MuZeroGameBuffer from .game_buffer_efficientzero import EfficientZeroGameBuffer from .game_buffer_sampled_efficientzero import SampledEfficientZeroGameBuffer +from .gobigger_game_buffer_muzero import GoBiggerMuZeroGameBuffer diff --git a/lzero/mcts/buffer/gobigger_game_buffer_muzero.py b/lzero/mcts/buffer/gobigger_game_buffer_muzero.py new file mode 100644 index 000000000..eb1a435b5 --- /dev/null +++ b/lzero/mcts/buffer/gobigger_game_buffer_muzero.py @@ -0,0 +1,703 @@ +from typing import Any, List, Tuple, Union, TYPE_CHECKING, Optional + +import numpy as np +import torch +from ding.utils import BUFFER_REGISTRY + +from lzero.mcts.tree_search.mcts_ctree import MuZeroMCTSCtree as MCTSCtree +from lzero.mcts.tree_search.mcts_ptree import MuZeroMCTSPtree as MCTSPtree +from lzero.mcts.utils import prepare_observation +from lzero.policy import to_detach_cpu_numpy, concat_output, concat_output_value, inverse_scalar_transform +from .game_buffer import GameBuffer + +if TYPE_CHECKING: + from lzero.policy import MuZeroPolicy, EfficientZeroPolicy, SampledEfficientZeroPolicy +from ding.torch_utils import to_tensor, squeeze + +@BUFFER_REGISTRY.register('gobigger_game_buffer_muzero') +class GoBiggerMuZeroGameBuffer(GameBuffer): + """ + Overview: + The specific game buffer for MuZero policy. + """ + + def __init__(self, cfg: dict): + super().__init__(cfg) + """ + Overview: + Use the default configuration mechanism. If a user passes in a cfg with a key that matches an existing key + in the default configuration, the user-provided value will override the default configuration. Otherwise, + the default configuration will be used. + """ + default_config = self.default_config() + default_config.update(cfg) + self._cfg = default_config + assert self._cfg.env_type in ['not_board_games', 'board_games'] + self.replay_buffer_size = self._cfg.replay_buffer_size + self.batch_size = self._cfg.batch_size + self._alpha = self._cfg.priority_prob_alpha + self._beta = self._cfg.priority_prob_beta + + self.keep_ratio = 1 + self.model_update_interval = 10 + self.num_of_collected_episodes = 0 + self.base_idx = 0 + self.clear_time = 0 + + self.game_segment_buffer = [] + self.game_pos_priorities = [] + self.game_segment_game_pos_look_up = [] + + self.tmp_obs = None # for value obs list [46 + 4(td_step)] not < 50(game_segment) + + def sample( + self, batch_size: int, policy: Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"] + ) -> List[Any]: + """ + Overview: + sample data from ``GameBuffer`` and prepare the current and target batch for training. + Arguments: + - batch_size (:obj:`int`): batch size. + - policy (:obj:`Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"]`): policy. + Returns: + - train_data (:obj:`List`): List of train data, including current_batch and target_batch. 
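+              ``current_batch`` is [obs_list, action_list, mask_list, batch_index_list, weights_list, make_time_list];
+              ``target_batch`` is [batch_rewards, batch_target_values, batch_target_policies].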
+ """ + policy._target_model.to(self._cfg.device) + policy._target_model.eval() + + # obtain the current_batch and prepare target context + reward_value_context, policy_re_context, policy_non_re_context, current_batch = self._make_batch( + batch_size, self._cfg.reanalyze_ratio + ) + # target reward, target value + batch_rewards, batch_target_values = self._compute_target_reward_value( + reward_value_context, policy._target_model + ) + # target policy + batch_target_policies_re = self._compute_target_policy_reanalyzed(policy_re_context, policy._target_model) + batch_target_policies_non_re = self._compute_target_policy_non_reanalyzed( + policy_non_re_context, self._cfg.model.action_space_size + ) + + # fusion of batch_target_policies_re and batch_target_policies_non_re to batch_target_policies + if 0 < self._cfg.reanalyze_ratio < 1: + batch_target_policies = np.concatenate([batch_target_policies_re, batch_target_policies_non_re]) + elif self._cfg.reanalyze_ratio == 1: + batch_target_policies = batch_target_policies_re + elif self._cfg.reanalyze_ratio == 0: + batch_target_policies = batch_target_policies_non_re + + target_batch = [batch_rewards, batch_target_values, batch_target_policies] + + # a batch contains the current_batch and the target_batch + train_data = [current_batch, target_batch] + return train_data + + def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: + """ + Overview: + first sample orig_data through ``_sample_orig_data()``, + then prepare the context of a batch: + reward_value_context: the context of reanalyzed value targets + policy_re_context: the context of reanalyzed policy targets + policy_non_re_context: the context of non-reanalyzed policy targets + current_batch: the inputs of batch + Arguments: + - batch_size (:obj:`int`): the batch size of orig_data from replay buffer. + - reanalyze_ratio (:obj:`float`): ratio of reanalyzed policy (value is 100% reanalyzed) + Returns: + - context (:obj:`Tuple`): reward_value_context, policy_re_context, policy_non_re_context, current_batch + """ + # obtain the batch context from replay buffer + orig_data = self._sample_orig_data(batch_size) + game_segment_list, pos_in_game_segment_list, batch_index_list, weights_list, make_time_list = orig_data + batch_size = len(batch_index_list) + obs_list, action_list, mask_list = [], [], [] + # prepare the inputs of a batch + for i in range(batch_size): + game = game_segment_list[i] + pos_in_game_segment = pos_in_game_segment_list[i] + + actions_tmp = game.action_segment[pos_in_game_segment:pos_in_game_segment + + self._cfg.num_unroll_steps].tolist() + # add mask for invalid actions (out of trajectory), 1 for valid, 0 for invalid + mask_tmp = [1. for i in range(len(actions_tmp))] + mask_tmp += [0. for _ in range(self._cfg.num_unroll_steps + 1 - len(mask_tmp))] + + # pad random action + actions_tmp += [ + np.random.randint(0, game.action_space_size) + for _ in range(self._cfg.num_unroll_steps - len(actions_tmp)) + ] + + # obtain the input observations + # pad if length of obs in game_segment is less than stack+num_unroll_steps + # e.g. 
stack+num_unroll_steps = 4+5 + obs_list.append( + game_segment_list[i].get_unroll_obs( + pos_in_game_segment_list[i], num_unroll_steps=self._cfg.num_unroll_steps, padding=True + ) + ) + action_list.append(actions_tmp) + mask_list.append(mask_tmp) + + # formalize the input observations + # obs_list = prepare_observation(obs_list, self._cfg.model.model_type) + + # formalize the inputs of a batch + current_batch = [obs_list, action_list, mask_list, batch_index_list, weights_list, make_time_list] + for i in range(len(current_batch)): + current_batch[i] = np.asarray(current_batch[i]) + + total_transitions = self.get_num_of_transitions() + + # obtain the context of value targets + reward_value_context = self._prepare_reward_value_context( + batch_index_list, game_segment_list, pos_in_game_segment_list, total_transitions + ) + """ + only reanalyze recent reanalyze_ratio (e.g. 50%) data + if self._cfg.reanalyze_outdated is True, batch_index_list is sorted according to its generated env_steps + 0: reanalyze_num -> reanalyzed policy, reanalyze_num:end -> non reanalyzed policy + """ + reanalyze_num = int(batch_size * reanalyze_ratio) + # reanalyzed policy + if reanalyze_num > 0: + # obtain the context of reanalyzed policy targets + policy_re_context = self._prepare_policy_reanalyzed_context( + batch_index_list[:reanalyze_num], game_segment_list[:reanalyze_num], + pos_in_game_segment_list[:reanalyze_num] + ) + else: + policy_re_context = None + + # non reanalyzed policy + if reanalyze_num < batch_size: + # obtain the context of non-reanalyzed policy targets + policy_non_re_context = self._prepare_policy_non_reanalyzed_context( + batch_index_list[reanalyze_num:], game_segment_list[reanalyze_num:], + pos_in_game_segment_list[reanalyze_num:] + ) + else: + policy_non_re_context = None + + context = reward_value_context, policy_re_context, policy_non_re_context, current_batch + return context + + def _prepare_reward_value_context( + self, batch_index_list: List[str], game_segment_list: List[Any], pos_in_game_segment_list: List[Any], + total_transitions: int + ) -> List[Any]: + """ + Overview: + prepare the context of rewards and values for calculating TD value target in reanalyzing part. 
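+            ``td_steps`` is clipped to the number of remaining transitions in each game segment, and positions whose
+            bootstrap index falls outside the segment are masked out via ``value_mask``.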
+ Arguments: + - batch_index_list (:obj:`list`): the index of start transition of sampled minibatch in replay buffer + - game_segment_list (:obj:`list`): list of game segments + - pos_in_game_segment_list (:obj:`list`): list of transition index in game_segment + - total_transitions (:obj:`int`): number of collected transitions + Returns: + - reward_value_context (:obj:`list`): value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, + td_steps_list, action_mask_segment, to_play_segment + """ + # zero_obs = game_segment_list[0].zero_obs() + value_obs_list = [] + # the value is valid or not (out of game_segment) + value_mask = [] + rewards_list = [] + game_segment_lens = [] + # for board games + action_mask_segment, to_play_segment = [], [] + + td_steps_list = [] + for game_segment, state_index, idx in zip(game_segment_list, pos_in_game_segment_list, batch_index_list): + game_segment_len = len(game_segment) + game_segment_lens.append(game_segment_len) + + td_steps = np.clip(self._cfg.td_steps, 1, max(1, game_segment_len - state_index)).astype(np.int32) + + # prepare the corresponding observations for bootstrapped values o_{t+k} + # o[t+ td_steps, t + td_steps + stack frames + num_unroll_steps] + # t=2+3 -> o[2+3, 2+3+4+5] -> o[5, 14] + game_obs = game_segment.get_unroll_obs(state_index + td_steps, self._cfg.num_unroll_steps) + + rewards_list.append(game_segment.reward_segment) + + # for board games + action_mask_segment.append(game_segment.action_mask_segment) + to_play_segment.append(game_segment.to_play_segment) + + for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1): + # get the bootstrapped target obs + td_steps_list.append(td_steps) + # index of bootstrapped obs o_{t+td_steps} + bootstrap_index = current_index + td_steps + + if bootstrap_index < game_segment_len: + value_mask.append(1) + # beg_index = bootstrap_index - (state_index + td_steps), max of beg_index is num_unroll_steps + beg_index = current_index - state_index + end_index = beg_index + self._cfg.model.frame_stack_num + # the stacked obs in time t + obs = game_obs[beg_index:end_index] + self.tmp_obs = obs # will be masked + else: + value_mask.append(0) + obs = self.tmp_obs # will be masked + value_obs_list.append(obs.tolist()) + + reward_value_context = [ + value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, td_steps_list, + action_mask_segment, to_play_segment + ] + return reward_value_context + + def _prepare_policy_non_reanalyzed_context( + self, batch_index_list: List[int], game_segment_list: List[Any], pos_in_game_segment_list: List[int] + ) -> List[Any]: + """ + Overview: + prepare the context of policies for calculating policy target in non-reanalyzing part, just return the policy in self-play + Arguments: + - batch_index_list (:obj:`list`): the index of start transition of sampled minibatch in replay buffer + - game_segment_list (:obj:`list`): list of game segments + - pos_in_game_segment_list (:obj:`list`): list transition index in game + Returns: + - policy_non_re_context (:obj:`list`): pos_in_game_segment_list, child_visits, game_segment_lens, action_mask_segment, to_play_segment + """ + child_visits = [] + game_segment_lens = [] + # for board games + action_mask_segment, to_play_segment = [], [] + + for game_segment, state_index, idx in zip(game_segment_list, pos_in_game_segment_list, batch_index_list): + game_segment_len = len(game_segment) + game_segment_lens.append(game_segment_len) + # for board games + 
action_mask_segment.append(game_segment.action_mask_segment) + to_play_segment.append(game_segment.to_play_segment) + + child_visits.append(game_segment.child_visit_segment) + + policy_non_re_context = [ + pos_in_game_segment_list, child_visits, game_segment_lens, action_mask_segment, to_play_segment + ] + return policy_non_re_context + + def _prepare_policy_reanalyzed_context( + self, batch_index_list: List[str], game_segment_list: List[Any], pos_in_game_segment_list: List[str] + ) -> List[Any]: + """ + Overview: + prepare the context of policies for calculating policy target in reanalyzing part. + Arguments: + - batch_index_list (:obj:'list'): start transition index in the replay buffer + - game_segment_list (:obj:'list'): list of game segments + - pos_in_game_segment_list (:obj:'list'): position of transition index in one game history + Returns: + - policy_re_context (:obj:`list`): policy_obs_list, policy_mask, pos_in_game_segment_list, indices, + child_visits, game_segment_lens, action_mask_segment, to_play_segment + """ + zero_obs = game_segment_list[0].zero_obs() + with torch.no_grad(): + # for policy + policy_obs_list = [] + policy_mask = [] + # 0 -> Invalid target policy for padding outside of game segments, + # 1 -> Previous target policy for game segments. + rewards, child_visits, game_segment_lens = [], [], [] + # for board games + action_mask_segment, to_play_segment = [], [] + for game_segment, state_index in zip(game_segment_list, pos_in_game_segment_list): + game_segment_len = len(game_segment) + game_segment_lens.append(game_segment_len) + rewards.append(game_segment.reward_segment) + # for board games + action_mask_segment.append(game_segment.action_mask_segment) + to_play_segment.append(game_segment.to_play_segment) + + child_visits.append(game_segment.child_visit_segment) + # prepare the corresponding observations + game_obs = game_segment.get_unroll_obs(state_index, self._cfg.num_unroll_steps) + for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1): + + if current_index < game_segment_len: + policy_mask.append(1) + beg_index = current_index - state_index + end_index = beg_index + self._cfg.model.frame_stack_num + obs = game_obs[beg_index:end_index] + else: + policy_mask.append(0) + obs = zero_obs + policy_obs_list.append(obs) + + policy_re_context = [ + policy_obs_list, policy_mask, pos_in_game_segment_list, batch_index_list, child_visits, game_segment_lens, + action_mask_segment, to_play_segment + ] + return policy_re_context + + def _compute_target_reward_value(self, reward_value_context: List[Any], model: Any) -> Tuple[Any, Any]: + """ + Overview: + prepare reward and value targets from the context of rewards and values. 
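+            The value target is the n-step TD target z_t = sum_{i=0}^{td_steps-1} gamma^i * r_{t+i} + gamma^{td_steps} * v_{t+td_steps},
+            where the bootstrap value v_{t+td_steps} comes from the target model (or from the MCTS root values when ``use_root_value`` is enabled).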
+ Arguments: + - reward_value_context (:obj:'list'): the reward value context + - model (:obj:'torch.tensor'):model of the target model + Returns: + - batch_value_prefixs (:obj:'np.ndarray): batch of value prefix + - batch_target_values (:obj:'np.ndarray): batch of value estimation + """ + value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, td_steps_list, action_mask_segment, \ + to_play_segment = reward_value_context # noqa + # transition_batch_size = game_segment_batch_size * (num_unroll_steps+1) + transition_batch_size = len(value_obs_list) + game_segment_batch_size = len(pos_in_game_segment_list) + + to_play, action_mask = self._preprocess_to_play_and_action_mask( + game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list + ) + if self._cfg.model.continuous_action_space is True: + # when the action space of the environment is continuous, action_mask[:] is None. + action_mask = [ + list(np.ones(self._cfg.model.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size) + ] + # NOTE: in continuous action space env: we set all legal_actions as -1 + legal_actions = [ + [-1 for _ in range(self._cfg.model.action_space_size)] for _ in range(transition_batch_size) + ] + else: + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)] + + batch_target_values, batch_rewards = [], [] + with torch.no_grad(): + # value_obs_list = prepare_observation(value_obs_list, self._cfg.model.model_type) + # split a full batch into slices of mini_infer_size: to save the GPU memory for more GPU actors + slices = int(np.ceil(transition_batch_size / self._cfg.mini_infer_size)) + network_output = [] + for i in range(slices): + beg_index = self._cfg.mini_infer_size * i + end_index = self._cfg.mini_infer_size * (i + 1) + + # m_obs = torch.from_numpy(value_obs_list[beg_index:end_index]).to(self._cfg.device).float() + m_obs = value_obs_list[beg_index:end_index] + m_obs = to_tensor(m_obs) + m_obs = sum(m_obs, []) + + # calculate the target value + m_output = model.initial_inference(m_obs) + + if not model.training: + # if not in training, obtain the scalars of the value/reward + [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( + [ + m_output.latent_state, + inverse_scalar_transform(m_output.value, self._cfg.model.support_scale), + m_output.policy_logits + ] + ) + + network_output.append(m_output) + + # concat the output slices after model inference + if self._cfg.use_root_value: + # use the root values from MCTS, as in EfficiientZero + # the root values have limited improvement but require much more GPU actors; + _, reward_pool, policy_logits_pool, latent_state_roots = concat_output( + network_output, data_type='muzero' + ) + reward_pool = reward_pool.squeeze().tolist() + policy_logits_pool = policy_logits_pool.tolist() + noises = [ + np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j])) + ).astype(np.float32).tolist() for j in range(transition_batch_size) + ] + if self._cfg.mcts_ctree: + # cpp mcts_tree + roots = MCTSCtree.roots(transition_batch_size, legal_actions) + roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play) + # do MCTS for a new policy with the recent target model + MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play) + else: + # python mcts_tree + roots = MCTSPtree.roots(transition_batch_size, legal_actions) + roots.prepare(self._cfg.root_noise_weight, 
noises, reward_pool, policy_logits_pool, to_play) + # do MCTS for a new policy with the recent target model + MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play) + + roots_values = roots.get_values() + value_list = np.array(roots_values) + else: + # use the predicted values + value_list = concat_output_value(network_output) + + # get last state value + if self._cfg.env_type == 'board_games' and to_play_segment[0][0] in [1, 2]: + # TODO(pu): for board_games, very important, to check + value_list = value_list.reshape(-1) * np.array( + [ + self._cfg.discount_factor ** td_steps_list[i] if int(td_steps_list[i]) % + 2 == 0 else -self._cfg.discount_factor ** td_steps_list[i] + for i in range(transition_batch_size) + ] + ) + else: + value_list = value_list.reshape(-1) * ( + np.array([self._cfg.discount_factor for _ in range(transition_batch_size)]) ** td_steps_list + ) + + value_list = value_list * np.array(value_mask) + value_list = value_list.tolist() + horizon_id, value_index = 0, 0 + + for game_segment_len_non_re, reward_list, state_index, to_play_list in zip(game_segment_lens, rewards_list, + pos_in_game_segment_list, + to_play_segment): + target_values = [] + target_rewards = [] + base_index = state_index + for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1): + bootstrap_index = current_index + td_steps_list[value_index] + # for i, reward in enumerate(game.rewards[current_index:bootstrap_index]): + for i, reward in enumerate(reward_list[current_index:bootstrap_index]): + if self._cfg.env_type == 'board_games' and to_play_segment[0][0] in [1, 2]: + # TODO(pu): for board_games, very important, to check + if to_play_list[base_index] == to_play_list[i]: + value_list[value_index] += reward * self._cfg.discount_factor ** i + else: + value_list[value_index] += -reward * self._cfg.discount_factor ** i + else: + value_list[value_index] += reward * self._cfg.discount_factor ** i + horizon_id += 1 + + if current_index < game_segment_len_non_re: + target_values.append(value_list[value_index]) + target_rewards.append(reward_list[current_index]) + else: + target_values.append(0) + target_rewards.append(0.0) + # TODO: check + # target_rewards.append(reward) + value_index += 1 + + batch_rewards.append(target_rewards) + batch_target_values.append(target_values) + + batch_rewards = np.asarray(batch_rewards, dtype=object) + batch_target_values = np.asarray(batch_target_values, dtype=object) + return batch_rewards, batch_target_values + + def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: Any) -> np.ndarray: + """ + Overview: + prepare policy targets from the reanalyzed context of policies + Arguments: + - policy_re_context (:obj:`List`): List of policy context to reanalyzed + Returns: + - batch_target_policies_re + """ + if policy_re_context is None: + return [] + batch_target_policies_re = [] + + # for board games + policy_obs_list, policy_mask, pos_in_game_segment_list, batch_index_list, child_visits, game_segment_lens, action_mask_segment, \ + to_play_segment = policy_re_context # noqa + # transition_batch_size = game_segment_batch_size * (self._cfg.num_unroll_steps + 1) + transition_batch_size = len(policy_obs_list) + game_segment_batch_size = len(pos_in_game_segment_list) + + to_play, action_mask = self._preprocess_to_play_and_action_mask( + game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list + ) + + if self._cfg.model.continuous_action_space is True: + # when the action space of the 
environment is continuous, action_mask[:] is None. + action_mask = [ + list(np.ones(self._cfg.model.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size) + ] + # NOTE: in continuous action space env: we set all legal_actions as -1 + legal_actions = [ + [-1 for _ in range(self._cfg.model.action_space_size)] for _ in range(transition_batch_size) + ] + else: + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)] + + with torch.no_grad(): + policy_obs_list = prepare_observation(policy_obs_list, self._cfg.model.model_type) + # split a full batch into slices of mini_infer_size: to save the GPU memory for more GPU actors + slices = int(np.ceil(transition_batch_size / self._cfg.mini_infer_size)) + network_output = [] + for i in range(slices): + beg_index = self._cfg.mini_infer_size * i + end_index = self._cfg.mini_infer_size * (i + 1) + m_obs = torch.from_numpy(policy_obs_list[beg_index:end_index]).to(self._cfg.device).float() + m_output = model.initial_inference(m_obs) + if not model.training: + # if not in training, obtain the scalars of the value/reward + [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( + [ + m_output.latent_state, + inverse_scalar_transform(m_output.value, self._cfg.model.support_scale), + m_output.policy_logits + ] + ) + + network_output.append(m_output) + + _, reward_pool, policy_logits_pool, latent_state_roots = concat_output(network_output, data_type='muzero') + reward_pool = reward_pool.squeeze().tolist() + policy_logits_pool = policy_logits_pool.tolist() + noises = [ + np.random.dirichlet([self._cfg.root_dirichlet_alpha] * self._cfg.model.action_space_size + ).astype(np.float32).tolist() for _ in range(transition_batch_size) + ] + if self._cfg.mcts_ctree: + # cpp mcts_tree + roots = MCTSCtree.roots(transition_batch_size, legal_actions) + roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play) + # do MCTS for a new policy with the recent target model + MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play) + else: + # python mcts_tree + roots = MCTSPtree.roots(transition_batch_size, legal_actions) + roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play) + # do MCTS for a new policy with the recent target model + MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play) + + roots_legal_actions_list = legal_actions + roots_distributions = roots.get_distributions() + policy_index = 0 + for state_index, game_index in zip(pos_in_game_segment_list, batch_index_list): + target_policies = [] + + for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1): + distributions = roots_distributions[policy_index] + + if policy_mask[policy_index] == 0: + # NOTE: the invalid padding target policy, O is to make sure the corresponding cross_entropy_loss=0 + target_policies.append([0 for _ in range(self._cfg.model.action_space_size)]) + else: + if distributions is None: + # if at some obs, the legal_action is None, add the fake target_policy + target_policies.append( + list(np.ones(self._cfg.model.action_space_size) / self._cfg.model.action_space_size) + ) + else: + if self._cfg.env_type == 'not_board_games': + # for atari/classic_control/box2d environments that only have one player. 
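+                            # normalize the MCTS root visit counts into a probability distribution used as the policy target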
+ sum_visits = sum(distributions) + policy = [visit_count / sum_visits for visit_count in distributions] + target_policies.append(policy) + else: + # for board games that have two players and legal_actions is dy + policy_tmp = [0 for _ in range(self._cfg.model.action_space_size)] + # to make sure target_policies have the same dimension + sum_visits = sum(distributions) + policy = [visit_count / sum_visits for visit_count in distributions] + for index, legal_action in enumerate(roots_legal_actions_list[policy_index]): + policy_tmp[legal_action] = policy[index] + target_policies.append(policy_tmp) + + policy_index += 1 + + batch_target_policies_re.append(target_policies) + + batch_target_policies_re = np.array(batch_target_policies_re) + + return batch_target_policies_re + + def _compute_target_policy_non_reanalyzed( + self, policy_non_re_context: List[Any], policy_shape: Optional[int] + ) -> np.ndarray: + """ + Overview: + prepare policy targets from the non-reanalyzed context of policies + Arguments: + - policy_non_re_context (:obj:`List`): List containing: + - pos_in_game_segment_list + - child_visits + - game_segment_lens + - action_mask_segment + - to_play_segment + - policy_shape: self._cfg.model.action_space_size + Returns: + - batch_target_policies_non_re + """ + batch_target_policies_non_re = [] + if policy_non_re_context is None: + return batch_target_policies_non_re + + pos_in_game_segment_list, child_visits, game_segment_lens, action_mask_segment, to_play_segment = policy_non_re_context + game_segment_batch_size = len(pos_in_game_segment_list) + transition_batch_size = game_segment_batch_size * (self._cfg.num_unroll_steps + 1) + + to_play, action_mask = self._preprocess_to_play_and_action_mask( + game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list + ) + + if self._cfg.model.continuous_action_space is True: + # when the action space of the environment is continuous, action_mask[:] is None. + action_mask = [ + list(np.ones(self._cfg.model.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size) + ] + # NOTE: in continuous action space env: we set all legal_actions as -1 + legal_actions = [ + [-1 for _ in range(self._cfg.model.action_space_size)] for _ in range(transition_batch_size) + ] + else: + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)] + + with torch.no_grad(): + policy_index = 0 + # 0 -> Invalid target policy for padding outside of game segments, + # 1 -> Previous target policy for game segments. + policy_mask = [] + for game_segment_len, child_visit, state_index in zip(game_segment_lens, child_visits, + pos_in_game_segment_list): + target_policies = [] + + for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1): + if current_index < game_segment_len: + policy_mask.append(1) + # NOTE: child_visit is already a distribution + distributions = child_visit[current_index] + if self._cfg.env_type == 'not_board_games': + # for atari/classic_control/box2d environments that only have one player. + target_policies.append(distributions) + else: + # for board games that have two players. 
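+                        # expand the stored visit distribution over the legal actions back onto the full action-space vector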
+ policy_tmp = [0 for _ in range(policy_shape)] + for index, legal_action in enumerate(legal_actions[policy_index]): + # only the action in ``legal_action`` the policy logits is nonzero + policy_tmp[legal_action] = distributions[index] + target_policies.append(policy_tmp) + else: + # NOTE: the invalid padding target policy, O is to make sure the correspoding cross_entropy_loss=0 + policy_mask.append(0) + target_policies.append([0 for _ in range(policy_shape)]) + + policy_index += 1 + + batch_target_policies_non_re.append(target_policies) + batch_target_policies_non_re = np.asarray(batch_target_policies_non_re) + return batch_target_policies_non_re + + def update_priority(self, train_data: List[np.ndarray], batch_priorities: Any) -> None: + """ + Overview: + Update the priority of training data. + Arguments: + - train_data (:obj:`Optional[List[Optional[np.ndarray]]]`): training data to be updated priority. + - batch_priorities (:obj:`batch_priorities`): priorities to update to. + NOTE: + train_data = [current_batch, target_batch] + current_batch = [obs_list, action_list, mask_list, batch_index_list, weights, make_time_list] + """ + indices = train_data[0][3] + metas = {'make_time': train_data[0][5], 'batch_priorities': batch_priorities} + # only update the priorities for data still in replay buffer + for i in range(len(indices)): + if metas['make_time'][i] > self.clear_time: + idx, prio = indices[i], metas['batch_priorities'][i] + self.game_pos_priorities[idx] = prio diff --git a/lzero/model/gobigger/gobigger_efficientzero_model_mlp.py b/lzero/model/gobigger/gobigger_efficientzero_model_mlp.py new file mode 100644 index 000000000..ee8dbfb40 --- /dev/null +++ b/lzero/model/gobigger/gobigger_efficientzero_model_mlp.py @@ -0,0 +1,480 @@ +from typing import Optional, Tuple + +import torch +import torch.nn as nn +from ding.torch_utils import MLP +from ding.utils import MODEL_REGISTRY, SequenceType +from numpy import ndarray + +from ..common import EZNetworkOutput, RepresentationNetworkMLP, PredictionNetworkMLP +from ..utils import renormalize, get_params_mean, get_dynamic_mean, get_reward_mean +from .gobigger_model import Encoder +import yaml +from easydict import EasyDict +from ding.utils.data import default_collate + +@MODEL_REGISTRY.register('EfficientZeroModelMLP') +class GoBiggerEfficientZeroModelMLP(nn.Module): + + def __init__( + self, + observation_shape: int = 2, + action_space_size: int = 6, + lstm_hidden_size: int = 512, + latent_state_dim: int = 256, + fc_reward_layers: SequenceType = [32], + fc_value_layers: SequenceType = [32], + fc_policy_layers: SequenceType = [32], + reward_support_size: int = 601, + value_support_size: int = 601, + proj_hid: int = 1024, + proj_out: int = 1024, + pred_hid: int = 512, + pred_out: int = 1024, + self_supervised_learning_loss: bool = True, + categorical_distribution: bool = True, + last_linear_layer_init_zero: bool = True, + state_norm: bool = False, + activation: Optional[nn.Module] = nn.ReLU(inplace=True), + norm_type: Optional[str] = 'BN', + discrete_action_encoding_type: str = 'one_hot', + res_connection_in_dynamics: bool = False, + *args, + **kwargs, + ): + """ + Overview: + The definition of the network model of EfficientZero, which is a generalization version for 1D vector obs. + The networks are mainly built on fully connected layers. + Sampled EfficientZero model consists of a representation network, a dynamics network and a prediction network. + The representation network is an MLP network which maps the raw observation to a latent state. 
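+            In this GoBigger variant, however, the representation network is replaced by the structured observation ``Encoder`` defined in ``gobigger_model.py`` (see ``__init__`` below).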
+ The dynamics network is an MLP+LSTM network which predicts the next latent state, reward_hidden_state and value_prefix given the current latent state and action. + The prediction network is an MLP network which predicts the value and policy given the current latent state. + Arguments: + - observation_shape (:obj:`int`): Observation space shape, e.g. 8 for Lunarlander. + - action_space_size: (:obj:`int`): Action space size, e.g. 4 for Lunarlander. + - lstm_hidden_size (:obj:`int`): The hidden size of LSTM in dynamics network to predict value_prefix. + - latent_state_dim (:obj:`int`): The dimension of latent state, such as 256. + - fc_reward_layers (:obj:`SequenceType`): The number of hidden layers of the reward head (MLP head). + - fc_value_layers (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head). + - fc_policy_layers (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head). + - reward_support_size (:obj:`int`): The size of categorical reward output + - value_support_size (:obj:`int`): The size of categorical value output. + - proj_hid (:obj:`int`): The size of projection hidden layer. + - proj_out (:obj:`int`): The size of projection output layer. + - pred_hid (:obj:`int`): The size of prediction hidden layer. + - pred_out (:obj:`int`): The size of prediction output layer. + - self_supervised_learning_loss (:obj:`bool`): Whether to use self_supervised_learning related networks in Sampled EfficientZero model, default set it to False. + - categorical_distribution (:obj:`bool`): Whether to use discrete support to represent categorical distribution for value, reward/value_prefix. + - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of value/policy mlp, default sets it to True. + - state_norm (:obj:`bool`): Whether to use normalization for latent states, default sets it to True. + - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ + operation to speedup, e.g. ReLU(inplace=True). + - discrete_action_encoding_type (:obj:`str`): The type of encoding for discrete action. Default sets it to 'one_hot'. options = {'one_hot', 'not_one_hot'} + - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. + - res_connection_in_dynamics (:obj:`bool`): Whether to use residual connection for dynamics network, default set it to False. + """ + super(GoBiggerEfficientZeroModelMLP, self).__init__() + if not categorical_distribution: + self.reward_support_size = 1 + self.value_support_size = 1 + else: + self.reward_support_size = reward_support_size + self.value_support_size = value_support_size + + self.action_space_size = action_space_size + self.continuous_action_space = False + # The dim of action space. For discrete action space, it is 1. + # For continuous action space, it is the dimension of continuous action. 
+ self.action_space_dim = action_space_size if self.continuous_action_space else 1 + assert discrete_action_encoding_type in ['one_hot', 'not_one_hot'], discrete_action_encoding_type + self.discrete_action_encoding_type = discrete_action_encoding_type + if self.continuous_action_space: + self.action_encoding_dim = action_space_size + else: + if self.discrete_action_encoding_type == 'one_hot': + self.action_encoding_dim = action_space_size + elif self.discrete_action_encoding_type == 'not_one_hot': + self.action_encoding_dim = 1 + + self.lstm_hidden_size = lstm_hidden_size + self.proj_hid = proj_hid + self.proj_out = proj_out + self.pred_hid = pred_hid + self.pred_out = pred_out + self.self_supervised_learning_loss = self_supervised_learning_loss + self.last_linear_layer_init_zero = last_linear_layer_init_zero + self.state_norm = state_norm + self.res_connection_in_dynamics = res_connection_in_dynamics + + # self.representation_network = RepresentationNetworkMLP( + # observation_shape=observation_shape, hidden_channels=latent_state_dim, norm_type=norm_type + # ) + with open('lzero/model/gobigger/default_model_config.yaml', "r") as f: + encoder_cfg = yaml.safe_load(f) + encoder_cfg = EasyDict(encoder_cfg) + self.representation_network = Encoder(encoder_cfg) + + self.dynamics_network = DynamicsNetworkMLP( + action_encoding_dim=self.action_encoding_dim, + num_channels=latent_state_dim + self.action_encoding_dim, + common_layer_num=2, + lstm_hidden_size=lstm_hidden_size, + fc_reward_layers=fc_reward_layers, + output_support_size=self.reward_support_size, + last_linear_layer_init_zero=self.last_linear_layer_init_zero, + norm_type=norm_type, + res_connection_in_dynamics=self.res_connection_in_dynamics, + ) + + self.prediction_network = PredictionNetworkMLP( + action_space_size=action_space_size, + num_channels=latent_state_dim, + fc_value_layers=fc_value_layers, + fc_policy_layers=fc_policy_layers, + output_support_size=self.value_support_size, + last_linear_layer_init_zero=self.last_linear_layer_init_zero, + norm_type=norm_type + ) + + if self.self_supervised_learning_loss: + # self_supervised_learning_loss related network proposed in EfficientZero + self.projection_input_dim = latent_state_dim + + self.projection = nn.Sequential( + nn.Linear(self.projection_input_dim, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, + nn.Linear(self.proj_hid, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, + nn.Linear(self.proj_hid, self.proj_out), nn.BatchNorm1d(self.proj_out) + ) + self.prediction_head = nn.Sequential( + nn.Linear(self.proj_out, self.pred_hid), + nn.BatchNorm1d(self.pred_hid), + activation, + nn.Linear(self.pred_hid, self.pred_out), + ) + + def initial_inference(self, obs: torch.Tensor) -> EZNetworkOutput: + """ + Overview: + Initial inference of EfficientZero model, which is the first step of the EfficientZero model. + To perform the initial inference, we first use the representation network to obtain the "latent_state" of the observation. + Then we use the prediction network to predict the "value" and "policy_logits" of the "latent_state", and + also prepare the zeros-like ``reward_hidden_state`` for the next step of the EfficientZero model. + Arguments: + - obs (:obj:`torch.Tensor`): The 1D vector observation data. + Returns (EZNetworkOutput): + - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. + - value_prefix (:obj:`torch.Tensor`): The predicted prefix sum of value for input state. 
\ + In initial inference, we set it to zero vector. + - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The hidden state of LSTM about reward. In initial inference, \ + we set it to the zeros-like hidden state (H and C). + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, obs_shape)`, where B is batch_size. + - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. + - value_prefix (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. + - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. + - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The shape of each element is :math:`(1, B, lstm_hidden_size)`, where B is batch_size. + """ + batch_size = len(obs) + obs = default_collate(obs) + latent_state = self._representation(obs) + device = latent_state.device + policy_logits, value = self._prediction(latent_state) + # zero initialization for reward hidden states + # (hn, cn), each element shape is (layer_num=1, batch_size, lstm_hidden_size) + reward_hidden_state = (torch.zeros(1, batch_size, self.lstm_hidden_size).to(device), + torch.zeros(1, batch_size, self.lstm_hidden_size).to(device),) + return EZNetworkOutput(value, [0. for _ in range(batch_size)], policy_logits, latent_state, reward_hidden_state) + + def recurrent_inference( + self, latent_state: torch.Tensor, reward_hidden_state: torch.Tensor, action: torch.Tensor + ) -> EZNetworkOutput: + """ + Overview: + Recurrent inference of EfficientZero model, which is the rollout step of the EfficientZero model. + To perform the recurrent inference, we first use the dynamics network to predict ``next_latent_state``, + ``reward_hidden_state``, ``value_prefix`` by the given current ``latent_state`` and ``action``. + We then use the prediction network to predict the ``value`` and ``policy_logits``. + Arguments: + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The input hidden state of LSTM about reward. + - action (:obj:`torch.Tensor`): The predicted action to rollout. + Returns (EZNetworkOutput): + - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. + - value_prefix (:obj:`torch.Tensor`): The predicted prefix sum of value for input state. + - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. + - next_latent_state (:obj:`torch.Tensor`): The predicted next latent state. + - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The output hidden state of LSTM about reward. + Shapes: + - action (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch_size. + - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. + - value_prefix (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. + - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. + - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - next_latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. 
+ - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The shape of each element is :math:`(1, B, lstm_hidden_size)`, where B is batch_size. + """ + next_latent_state, reward_hidden_state, value_prefix = self._dynamics(latent_state, reward_hidden_state, action) + policy_logits, value = self._prediction(next_latent_state) + return EZNetworkOutput(value, value_prefix, policy_logits, next_latent_state, reward_hidden_state) + + def _representation(self, observation: torch.Tensor) -> Tuple[torch.Tensor]: + """ + Overview: + Use the representation network to encode the observations into latent state. + Arguments: + - obs (:obj:`torch.Tensor`): The 1D vector observation data. + Returns: + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, obs_shape)`, where B is batch_size. + - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + """ + latent_state = self.representation_network(observation) + if self.state_norm: + latent_state = renormalize(latent_state) + return latent_state + + def _prediction(self, latent_state: torch.Tensor) -> Tuple[torch.Tensor]: + """ + Overview: + Use the representation network to encode the observations into latent state. + Arguments: + - obs (:obj:`torch.Tensor`): The 1D vector observation data. + Returns: + - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. + - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. + Shapes: + - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. + - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. + """ + policy_logits, value = self.prediction_network(latent_state) + return policy_logits, value + + def _dynamics(self, latent_state: torch.Tensor, reward_hidden_state: Tuple, + action: torch.Tensor) -> Tuple[torch.Tensor, Tuple[torch.Tensor], torch.Tensor]: + """ + Overview: + Concatenate ``latent_state`` and ``action`` and use the dynamics network to predict ``next_latent_state`` + ``value_prefix`` and ``next_reward_hidden_state``. + Arguments: + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The input hidden state of LSTM about reward. + - action (:obj:`torch.Tensor`): The predicted action to rollout. + Returns: + - next_latent_state (:obj:`torch.Tensor`): The predicted latent state of the next timestep. + - next_reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The output hidden state of LSTM about reward. + - value_prefix (:obj:`torch.Tensor`): The predicted prefix sum of value for input state. + Shapes: + - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - action (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch_size. + - next_latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - value_prefix (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. 
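+            - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The shape of each element is :math:`(1, B, lstm_hidden_size)`, where B is batch_size.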
+ """ + # NOTE: the discrete action encoding type is important for some environments + + # discrete action space + if self.discrete_action_encoding_type == 'one_hot': + # Stack latent_state with the one hot encoded action + if len(action.shape) == 1: + # (batch_size, ) -> (batch_size, 1) + # e.g., torch.Size([8]) -> torch.Size([8, 1]) + action = action.unsqueeze(-1) + + # transform action to one-hot encoding. + # action_one_hot shape: (batch_size, action_space_size), e.g., (8, 4) + action_one_hot = torch.zeros(action.shape[0], self.action_space_size, device=action.device) + # transform action to torch.int64 + action = action.long() + action_one_hot.scatter_(1, action, 1) + action_encoding = action_one_hot + elif self.discrete_action_encoding_type == 'not_one_hot': + action_encoding = action / self.action_space_size + if len(action_encoding.shape) == 1: + # (batch_size, ) -> (batch_size, 1) + # e.g., torch.Size([8]) -> torch.Size([8, 1]) + action_encoding = action_encoding.unsqueeze(-1) + + action_encoding = action_encoding.to(latent_state.device).float() + # state_action_encoding shape: (batch_size, latent_state[1] + action_dim]) or + # (batch_size, latent_state[1] + action_space_size]) depending on the discrete_action_encoding_type. + state_action_encoding = torch.cat((latent_state, action_encoding), dim=1) + + # NOTE: the key difference with MuZero + next_latent_state, next_reward_hidden_state, value_prefix = self.dynamics_network( + state_action_encoding, reward_hidden_state + ) + + if self.state_norm: + next_latent_state = renormalize(next_latent_state) + return next_latent_state, next_reward_hidden_state, value_prefix + + def project(self, latent_state: torch.Tensor, with_grad=True): + """ + Overview: + Project the latent state to a lower dimension to calculate the self-supervised loss, which is proposed in EfficientZero. + For more details, please refer to the paper ``Exploring Simple Siamese Representation Learning``. + Arguments: + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + - with_grad (:obj:`bool`): Whether to calculate gradient for the projection result. + Returns: + - proj (:obj:`torch.Tensor`): The result embedding vector of projection operation. + Shapes: + - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - proj (:obj:`torch.Tensor`): :math:`(B, projection_output_dim)`, where B is batch_size. + + Examples: + >>> latent_state = torch.randn(256, 64) + >>> output = self.project(latent_state) + >>> output.shape # (256, 1024) + """ + proj = self.projection(latent_state) + + if with_grad: + # with grad, use prediction_head + return self.prediction_head(proj) + else: + return proj.detach() + + def get_params_mean(self) -> float: + return get_params_mean(self) + + +class DynamicsNetworkMLP(nn.Module): + + def __init__( + self, + action_encoding_dim: int = 2, + num_channels: int = 64, + common_layer_num: int = 2, + fc_reward_layers: SequenceType = [32], + output_support_size: int = 601, + lstm_hidden_size: int = 512, + last_linear_layer_init_zero: bool = True, + activation: Optional[nn.Module] = nn.ReLU(inplace=True), + norm_type: Optional[str] = 'BN', + res_connection_in_dynamics: bool = False, + ): + """ + Overview: + The definition of dynamics network in EfficientZero algorithm, which is used to predict next latent state + value_prefix and reward_hidden_state by the given current latent state and action. + The networks are mainly built on fully connected layers. 
+ Arguments: + - action_encoding_dim (:obj:`int`): The dimension of action encoding. + - num_channels (:obj:`int`): The num of channels in latent states. + - common_layer_num (:obj:`int`): The number of common layers in dynamics network. + - fc_reward_layers (:obj:`SequenceType`): The number of hidden layers of the reward head (MLP head). + - output_support_size (:obj:`int`): The size of categorical reward output. + - lstm_hidden_size (:obj:`int`): The hidden size of lstm in dynamics network. + - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializationss for the last layer of value/policy head, default sets it to True. + - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ + operation to speedup, e.g. ReLU(inplace=True). + - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. + - res_connection_in_dynamics (:obj:`bool`): Whether to use residual connection in dynamics network. + """ + super().__init__() + assert num_channels > action_encoding_dim, f'num_channels:{num_channels} <= action_encoding_dim:{action_encoding_dim}' + + self.num_channels = num_channels + self.action_encoding_dim = action_encoding_dim + self.latent_state_dim = self.num_channels - self.action_encoding_dim + self.lstm_hidden_size = lstm_hidden_size + self.activation = activation + self.res_connection_in_dynamics = res_connection_in_dynamics + + if self.res_connection_in_dynamics: + self.fc_dynamics_1 = MLP( + in_channels=self.num_channels, + hidden_channels=self.latent_state_dim, + layer_num=common_layer_num, + out_channels=self.latent_state_dim, + activation=activation, + norm_type=norm_type, + output_activation=True, + output_norm=True, + # last_linear_layer_init_zero=False is important for convergence + last_linear_layer_init_zero=False, + ) + self.fc_dynamics_2 = MLP( + in_channels=self.latent_state_dim, + hidden_channels=self.latent_state_dim, + layer_num=common_layer_num, + out_channels=self.latent_state_dim, + activation=activation, + norm_type=norm_type, + output_activation=True, + output_norm=True, + # last_linear_layer_init_zero=False is important for convergence + last_linear_layer_init_zero=False, + ) + else: + self.fc_dynamics = MLP( + in_channels=self.num_channels, + hidden_channels=self.latent_state_dim, + layer_num=common_layer_num, + out_channels=self.latent_state_dim, + activation=activation, + norm_type=norm_type, + output_activation=True, + output_norm=True, + # last_linear_layer_init_zero=False is important for convergence + last_linear_layer_init_zero=False, + ) + + # input_shape: (sequence_length,batch_size,input_size) + # output_shape: (sequence_length, batch_size, hidden_size) + self.lstm = nn.LSTM(input_size=self.latent_state_dim, hidden_size=self.lstm_hidden_size) + + self.fc_reward_head = MLP( + in_channels=self.lstm_hidden_size, + hidden_channels=fc_reward_layers[0], + layer_num=2, + out_channels=output_support_size, + activation=self.activation, + norm_type=norm_type, + output_activation=False, + output_norm=False, + last_linear_layer_init_zero=last_linear_layer_init_zero + ) + + def forward(self, state_action_encoding: torch.Tensor, reward_hidden_state): + """ + Overview: + Forward computation of the dynamics network. Predict next latent state given current state_action_encoding and reward hidden state. 
+ Arguments: + - state_action_encoding (:obj:`torch.Tensor`): The state-action encoding, which is the concatenation of \ + latent state and action encoding, with shape (batch_size, num_channels, height, width). + - reward_hidden_state (:obj:`Tuple[torch.Tensor, torch.Tensor]`): The input hidden state of LSTM about reward. + Returns: + - next_latent_state (:obj:`torch.Tensor`): The next latent state, with shape (batch_size, latent_state_dim). + - next_reward_hidden_state (:obj:`torch.Tensor`): The input hidden state of LSTM about reward. + - value_prefix (:obj:`torch.Tensor`): The predicted prefix sum of value for input state. + """ + if self.res_connection_in_dynamics: + # take the state encoding (latent_state), state_action_encoding[:, -self.action_encoding_dim] + # is action encoding + latent_state = state_action_encoding[:, :-self.action_encoding_dim] + x = self.fc_dynamics_1(state_action_encoding) + # the residual link: add state encoding to the state_action encoding + next_latent_state = x + latent_state + next_latent_state_ = self.fc_dynamics_2(next_latent_state) + else: + next_latent_state = self.fc_dynamics(state_action_encoding) + next_latent_state_ = next_latent_state + + next_latent_state_unsqueeze = next_latent_state_.unsqueeze(0) + value_prefix, next_reward_hidden_state = self.lstm(next_latent_state_unsqueeze, reward_hidden_state) + value_prefix = self.fc_reward_head(value_prefix.squeeze(0)) + + return next_latent_state, next_reward_hidden_state, value_prefix + + def get_dynamic_mean(self) -> float: + return get_dynamic_mean(self) + + def get_reward_mean(self) -> Tuple[ndarray, float]: + return get_reward_mean(self) diff --git a/lzero/model/gobigger/gobigger_model.py b/lzero/model/gobigger/gobigger_model.py new file mode 100644 index 000000000..91ff3b7ee --- /dev/null +++ b/lzero/model/gobigger/gobigger_model.py @@ -0,0 +1,314 @@ +from typing import Dict + +import torch +import torch.nn as nn +from torch import Tensor + +from .network import sequence_mask, ScatterConnection +from .network.encoder import SignBinaryEncoder, BinaryEncoder, OnehotEncoder, TimeEncoder, UnsqueezeEncoder +from .network.nn_module import fc_block, conv2d_block, MLP +from .network.res_block import ResBlock +from .network.transformer import Transformer +from typing import Any, List, Tuple, Union, Optional, Callable + + +def sequence_mask(lengths: torch.Tensor, max_len: Optional[int] =None): + r""" + Overview: + create a mask for a batch sequences with different lengths + Arguments: + - lengths (:obj:`tensor`): lengths in each different sequences, shape could be (n, 1) or (n) + - max_len (:obj:`int`): the padding size, if max_len is None, the padding size is the + max length of sequences + Returns: + - masks (:obj:`torch.BoolTensor`): mask has the same device as lengths + """ + if len(lengths.shape) == 1: + lengths = lengths.unsqueeze(dim=1) + bz = lengths.numel() + if max_len is None: + max_len = lengths.max() + return torch.arange(0, max_len).type_as(lengths).repeat(bz, 1).lt(lengths).to(lengths.device) + + +class ScalarEncoder(nn.Module): + def __init__(self, cfg): + super(ScalarEncoder, self).__init__() + self.whole_cfg = cfg + self.cfg = self.whole_cfg.model.scalar_encoder + self.encode_modules = nn.ModuleDict() + for k, item in self.cfg.modules.items(): + if item['arc'] == 'time': + self.encode_modules[k] = TimeEncoder(embedding_dim=item['embedding_dim']) + elif item['arc'] == 'one_hot': + self.encode_modules[k] = OnehotEncoder(num_embeddings=item['num_embeddings'], ) + elif item['arc'] == 'binary': 
+ self.encode_modules[k] = BinaryEncoder(num_embeddings=item['num_embeddings'], ) + elif item['arc'] == 'sign_binary': + self.encode_modules[k] = SignBinaryEncoder(num_embeddings=item['num_embeddings'], ) + else: + print(f'cant implement {k} for arc {item["arc"]}') + raise NotImplementedError + + self.layers = MLP(in_channels=self.cfg.input_dim, hidden_channels=self.cfg.hidden_dim, + out_channels=self.cfg.output_dim, + layer_num=self.cfg.layer_num, + layer_fn=fc_block, + activation=self.cfg.activation, + norm_type=self.cfg.norm_type, + use_dropout=False + ) + + def forward(self, x: Dict[str, Tensor]): + embeddings = [] + for key, item in self.cfg.modules.items(): + assert key in x, key + embeddings.append(self.encode_modules[key](x[key])) + + out = torch.cat(embeddings, dim=-1) + out = self.layers(out) + return out + + +class TeamEncoder(nn.Module): + def __init__(self, cfg): + super(TeamEncoder, self).__init__() + self.whole_cfg = cfg + self.cfg = self.whole_cfg.model.team_encoder + self.encode_modules = nn.ModuleDict() + + for k, item in self.cfg.modules.items(): + if item['arc'] == 'one_hot': + self.encode_modules[k] = OnehotEncoder(num_embeddings=item['num_embeddings'], ) + elif item['arc'] == 'binary': + self.encode_modules[k] = BinaryEncoder(num_embeddings=item['num_embeddings'], ) + elif item['arc'] == 'sign_binary': + self.encode_modules[k] = SignBinaryEncoder(num_embeddings=item['num_embeddings'], ) + else: + print(f'cant implement {k} for arc {item["arc"]}') + raise NotImplementedError + + self.embedding_dim = self.cfg.embedding_dim + self.encoder_cfg = self.cfg.encoder + self.encode_layers = MLP(in_channels=self.encoder_cfg.input_dim, + hidden_channels=self.encoder_cfg.hidden_dim, + out_channels=self.embedding_dim, + layer_num=self.encoder_cfg.layer_num, + layer_fn=fc_block, + activation=self.encoder_cfg.activation, + norm_type=self.encoder_cfg.norm_type, + use_dropout=False) + # self.activation_type = self.cfg.activation + + self.transformer_cfg = self.cfg.transformer + self.transformer = Transformer( + n_heads=self.transformer_cfg.head_num, + embedding_size=self.embedding_dim, + ffn_size=self.transformer_cfg.ffn_size, + n_layers=self.transformer_cfg.layer_num, + attention_dropout=0.0, + relu_dropout=0.0, + dropout=0.0, + activation=self.transformer_cfg.activation, + variant=self.transformer_cfg.variant, + ) + self.output_cfg = self.cfg.output + self.output_fc = fc_block(self.embedding_dim, + self.output_cfg.output_dim, + norm_type=self.output_cfg.norm_type, + activation=self.output_cfg.activation) + + def forward(self, x): + embeddings = [] + player_num = x['player_num'] + mask = sequence_mask(player_num, max_len=x['view_x'].shape[1]) + for key, item in self.cfg.modules.items(): + assert key in x, f"{key} not implemented" + x_input = x[key] + embeddings.append(self.encode_modules[key](x_input)) + + x = torch.cat(embeddings, dim=-1) + x = self.encode_layers(x) + x = self.transformer(x, mask=mask) + team_info = self.output_fc(x.sum(dim=1) / player_num.unsqueeze(dim=-1)) + return team_info + + +class BallEncoder(nn.Module): + def __init__(self, cfg): + super(BallEncoder, self).__init__() + self.whole_cfg = cfg + self.cfg = self.whole_cfg.model.ball_encoder + self.encode_modules = nn.ModuleDict() + for k, item in self.cfg.modules.items(): + if item['arc'] == 'one_hot': + self.encode_modules[k] = OnehotEncoder(num_embeddings=item['num_embeddings'], ) + elif item['arc'] == 'binary': + self.encode_modules[k] = BinaryEncoder(num_embeddings=item['num_embeddings'], ) + elif item['arc'] 
== 'sign_binary': + self.encode_modules[k] = SignBinaryEncoder(num_embeddings=item['num_embeddings'], ) + elif item['arc'] == 'unsqueeze': + self.encode_modules[k] = UnsqueezeEncoder() + else: + print(f'cant implement {k} for arc {item["arc"]}') + raise NotImplementedError + self.embedding_dim = self.cfg.embedding_dim + self.encoder_cfg = self.cfg.encoder + self.encode_layers = MLP(in_channels=self.encoder_cfg.input_dim, + hidden_channels=self.encoder_cfg.hidden_dim, + out_channels=self.embedding_dim, + layer_num=self.encoder_cfg.layer_num, + layer_fn=fc_block, + activation=self.encoder_cfg.activation, + norm_type=self.encoder_cfg.norm_type, + use_dropout=False) + + self.transformer_cfg = self.cfg.transformer + self.transformer = Transformer( + n_heads=self.transformer_cfg.head_num, + embedding_size=self.embedding_dim, + ffn_size=self.transformer_cfg.ffn_size, + n_layers=self.transformer_cfg.layer_num, + attention_dropout=0.0, + relu_dropout=0.0, + dropout=0.0, + activation=self.transformer_cfg.activation, + variant=self.transformer_cfg.variant, + ) + self.output_cfg = self.cfg.output + self.output_fc = fc_block(self.embedding_dim, + self.output_cfg.output_dim, + norm_type=self.output_cfg.norm_type, + activation=self.output_cfg.activation) + + def forward(self, x): + ball_num = x['ball_num'] + embeddings = [] + mask = sequence_mask(ball_num, max_len=x['x'].shape[1]) + for key, item in self.cfg.modules.items(): + assert key in x, key + x_input = x[key] + embeddings.append(self.encode_modules[key](x_input)) + x = torch.cat(embeddings, dim=-1) + x = self.encode_layers(x) + x = self.transformer(x, mask=mask) + + ball_info = x.sum(dim=1) / ball_num.unsqueeze(dim=-1) + ball_info = self.output_fc(ball_info) + return x, ball_info + + +class SpatialEncoder(nn.Module): + def __init__(self, cfg): + super(SpatialEncoder, self).__init__() + self.whole_cfg = cfg + self.cfg = self.whole_cfg.model.spatial_encoder + + # scatter related + self.spatial_x = 64 + self.spatial_y = 64 + self.scatter_cfg = self.cfg.scatter + self.scatter_fc = fc_block(in_channels=self.scatter_cfg.input_dim, out_channels=self.scatter_cfg.output_dim, + activation=self.scatter_cfg.activation, norm_type=self.scatter_cfg.norm_type) + self.scatter_connection = ScatterConnection(self.scatter_cfg.scatter_type) + + # resnet related + self.resnet_cfg = self.cfg.resnet + self.get_resnet_blocks() + + self.output_cfg = self.cfg.output + self.output_fc = fc_block( + in_channels=self.spatial_x // 8 * self.spatial_y // 8 * self.resnet_cfg.down_channels[-1], + out_channels=self.output_cfg.output_dim, + norm_type=self.output_cfg.norm_type, + activation=self.output_cfg.activation) + + def get_resnet_blocks(self): + # 2 means food/spore embedding + project = conv2d_block(in_channels=self.scatter_cfg.output_dim + 2, + out_channels=self.resnet_cfg.project_dim, + kernel_size=1, + stride=1, + padding=0, + activation=self.resnet_cfg.activation, + norm_type=self.resnet_cfg.norm_type, + bias=False, + ) + + layers = [project] + dims = [self.resnet_cfg.project_dim] + self.resnet_cfg.down_channels + for i in range(len(dims) - 1): + layer = conv2d_block(in_channels=dims[i], + out_channels=dims[i + 1], + kernel_size=4, + stride=2, + padding=1, + activation=self.resnet_cfg.activation, + norm_type=self.resnet_cfg.norm_type, + bias=False, + ) + layers.append(layer) + layers.append(ResBlock(in_channels=dims[i + 1], + activation=self.resnet_cfg.activation, + norm_type=self.resnet_cfg.norm_type)) + self.resnet = torch.nn.Sequential(*layers) + + + def 
get_background_embedding(self, coord_x, coord_y, num, ): + + background_ones = torch.ones(size=(coord_x.shape[0], coord_x.shape[1]), device=coord_x.device) + background_mask = sequence_mask(num, max_len=coord_x.shape[1]) + background_ones = (background_ones * background_mask).unsqueeze(-1) + background_embedding = self.scatter_connection.xy_forward(background_ones, + spatial_size=[self.spatial_x, self.spatial_y], + coord_x=coord_x, + coord_y=coord_y) + + return background_embedding + + def forward(self, inputs, ball_embeddings, ): + spatial_info = inputs['spatial_info'] + # food and spore + food_embedding = self.get_background_embedding(coord_x=spatial_info['food_x'], + coord_y=spatial_info['food_y'], + num=spatial_info['food_num'], ) + + spore_embedding = self.get_background_embedding(coord_x=spatial_info['spore_x'], + coord_y=spatial_info['spore_y'], + num=spatial_info['spore_num'], ) + # scatter ball embeddings + ball_info = inputs['ball_info'] + ball_num = ball_info['ball_num'] + ball_mask = sequence_mask(ball_num, max_len=ball_embeddings.shape[1]) + ball_embedding = self.scatter_fc(ball_embeddings) * ball_mask.unsqueeze(dim=2) + + ball_embedding = self.scatter_connection.xy_forward(ball_embedding, + spatial_size=[self.spatial_x, self.spatial_y], + coord_x=spatial_info['ball_x'], + coord_y=spatial_info['ball_y']) + + x = torch.cat([food_embedding, spore_embedding, ball_embedding], dim=1) + + x = self.resnet(x) + + x = torch.flatten(x, start_dim=1, end_dim=-1) + x = self.output_fc(x) + return x + + +class Encoder(nn.Module): + def __init__(self, cfg): + super(Encoder, self).__init__() + self.whole_cfg = cfg + self.scalar_encoder = ScalarEncoder(cfg) + self.team_encoder = TeamEncoder(cfg) + self.ball_encoder = BallEncoder(cfg) + self.spatial_encoder = SpatialEncoder(cfg) + + def forward(self, x): + scalar_info = self.scalar_encoder(x['scalar_info']) + team_info = self.team_encoder(x['team_info']) + ball_embeddings, ball_info = self.ball_encoder(x['ball_info']) + spatial_info = self.spatial_encoder(x, ball_embeddings) + x = torch.cat([scalar_info, team_info, ball_info, spatial_info], dim=1) + return x diff --git a/lzero/model/gobigger/network/__init__.py b/lzero/model/gobigger/network/__init__.py new file mode 100644 index 000000000..50e7db84b --- /dev/null +++ b/lzero/model/gobigger/network/__init__.py @@ -0,0 +1,8 @@ +from .activation import build_activation +from .res_block import ResBlock, ResFCBlock,ResFCBlock2 +from .nn_module import fc_block, fc_block2, conv2d_block, MLP +from .normalization import build_normalization +from .rnn import get_lstm, sequence_mask +from .soft_argmax import SoftArgmax +from .transformer import Transformer +from .scatter_connection import ScatterConnection diff --git a/lzero/model/gobigger/network/activation.py b/lzero/model/gobigger/network/activation.py new file mode 100644 index 000000000..550bee3d5 --- /dev/null +++ b/lzero/model/gobigger/network/activation.py @@ -0,0 +1,96 @@ +""" +Copyright 2020 Sensetime X-lab. All Rights Reserved + +Main Function: + 1. build activation: you can use build_activation to build relu or glu +""" +import torch +import torch.nn as nn + + +class GLU(nn.Module): + r""" + Overview: + Gating Linear Unit. + This class does a thing like this: + + .. code:: python + + # Inputs: input, context, output_size + # The gate value is a learnt function of the input. + gate = sigmoid(linear(input.size)(context)) + # Gate the input and return an output of desired size. 
+ gated_input = gate * input + output = linear(output_size)(gated_input) + return output + Interfaces: + forward + + .. tip:: + + This module also supports 2D convolution, in which case, the input and context must have the same shape. + """ + + def __init__(self, input_dim: int, output_dim: int, context_dim: int, input_type: str = 'fc') -> None: + r""" + Overview: + Init GLU + Arguments: + - input_dim (:obj:`int`): the input dimension + - output_dim (:obj:`int`): the output dimension + - context_dim (:obj:`int`): the context dimension + - input_type (:obj:`str`): the type of input, now support ['fc', 'conv2d'] + """ + super(GLU, self).__init__() + assert (input_type in ['fc', 'conv2d']) + if input_type == 'fc': + self.layer1 = nn.Linear(context_dim, input_dim) + self.layer2 = nn.Linear(input_dim, output_dim) + elif input_type == 'conv2d': + self.layer1 = nn.Conv2d(context_dim, input_dim, 1, 1, 0) + self.layer2 = nn.Conv2d(input_dim, output_dim, 1, 1, 0) + + def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor: + r""" + Overview: + Return GLU computed tensor + Arguments: + - x (:obj:`torch.Tensor`) : the input tensor + - context (:obj:`torch.Tensor`) : the context tensor + Returns: + - x (:obj:`torch.Tensor`): the computed tensor + """ + gate = self.layer1(context) + gate = torch.sigmoid(gate) + x = gate * x + x = self.layer2(x) + return x + +class Swish(nn.Module): + + def __init__(self): + super(Swish, self).__init__() + + def forward(self, x): + x = x * torch.sigmoid(x) + return x + +def build_activation(activation: str, inplace: bool = None) -> nn.Module: + r""" + Overview: + Return the activation module according to the given type. + Arguments: + - actvation (:obj:`str`): the type of activation module, now supports ['relu', 'glu', 'prelu'] + - inplace (:obj:`bool`): can optionally do the operation in-place in relu. 
Default ``None`` + Returns: + - act_func (:obj:`nn.module`): the corresponding activation module + """ + if inplace is not None: + assert activation == 'relu', 'inplace argument is not compatible with {}'.format(activation) + else: + inplace = True + act_func = {'relu': nn.ReLU(inplace=inplace), 'glu': GLU, 'prelu': nn.PReLU(),'swish': Swish()} + if activation in act_func.keys(): + return act_func[activation] + else: + raise KeyError("invalid key for activation: {}".format(activation)) diff --git a/lzero/model/gobigger/network/encoder.py b/lzero/model/gobigger/network/encoder.py new file mode 100644 index 000000000..daa014ec5 --- /dev/null +++ b/lzero/model/gobigger/network/encoder.py @@ -0,0 +1,136 @@ +import numpy as np +import torch +import torch.nn as nn + + +class OnehotEncoder(nn.Module): + def __init__(self, num_embeddings: int): + super(OnehotEncoder, self).__init__() + self.num_embeddings = num_embeddings + self.main = nn.Embedding.from_pretrained(torch.eye(self.num_embeddings), freeze=True, + padding_idx=None) + + def forward(self, x: torch.Tensor): + x = x.long().clamp_(max=self.num_embeddings - 1) + return self.main(x) + + +class OnehotEmbedding(nn.Module): + def __init__(self, num_embeddings: int, embedding_dim: int): + super(OnehotEmbedding, self).__init__() + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.main = nn.Embedding(num_embeddings=self.num_embeddings, embedding_dim=self.embedding_dim) + + def forward(self, x: torch.Tensor): + x = x.long().clamp_(max=self.num_embeddings - 1) + return self.main(x) + + +class BinaryEncoder(nn.Module): + def __init__(self, num_embeddings: int): + super(BinaryEncoder, self).__init__() + self.bit_num = num_embeddings + self.main = nn.Embedding.from_pretrained(self.get_binary_embed_matrix(self.bit_num), freeze=True, + padding_idx=None) + + @staticmethod + def get_binary_embed_matrix(bit_num): + embedding_matrix = [] + for n in range(2 ** bit_num): + embedding = [n >> d & 1 for d in range(bit_num)][::-1] + embedding_matrix.append(embedding) + return torch.tensor(embedding_matrix, dtype=torch.float) + + def forward(self, x: torch.Tensor): + x = x.long().clamp_(max=2 ** self.bit_num - 1) + return self.main(x) + + +class SignBinaryEncoder(nn.Module): + def __init__(self, num_embeddings): + super(SignBinaryEncoder, self).__init__() + self.bit_num = num_embeddings + self.main = nn.Embedding.from_pretrained(self.get_sign_binary_matrix(self.bit_num), freeze=True, + padding_idx=None) + self.max_val = 2 ** (self.bit_num - 1) - 1 + + @staticmethod + def get_sign_binary_matrix(bit_num): + neg_embedding_matrix = [] + pos_embedding_matrix = [] + for n in range(1, 2 ** (bit_num - 1)): + embedding = [n >> d & 1 for d in range(bit_num - 1)][::-1] + neg_embedding_matrix.append([1] + embedding) + pos_embedding_matrix.append([0] + embedding) + embedding_matrix = neg_embedding_matrix[::-1] + [[0 for _ in range(bit_num)]] + pos_embedding_matrix + return torch.tensor(embedding_matrix, dtype=torch.float) + + def forward(self, x: torch.Tensor): + x = x.long().clamp_(max=self.max_val, min=- self.max_val) + return self.main(x + self.max_val) + + +class PositionEncoder(nn.Module): + def __init__(self, num_embeddings, embedding_dim=None): + super(PositionEncoder, self).__init__() + self.n_position = num_embeddings + self.embedding_dim = self.n_position if embedding_dim is None else embedding_dim + self.position_enc = nn.Embedding.from_pretrained( + self.position_encoding_init(self.n_position, self.embedding_dim), + freeze=True, 
padding_idx=None) + + @staticmethod + def position_encoding_init(n_position, embedding_dim): + ''' Init the sinusoid position encoding table ''' + + # keep dim 0 for padding token position encoding zero vector + position_enc = np.array([ + [pos / np.power(10000, 2 * (j // 2) / embedding_dim) for j in range(embedding_dim)] + for pos in range(n_position)]) + position_enc[:, 0::2] = np.sin(position_enc[:, 0::2]) # apply sin on 0th,2nd,4th...embedding_dim + position_enc[:, 1::2] = np.cos(position_enc[:, 1::2]) # apply cos on 1st,3rd,5th...embedding_dim + return torch.from_numpy(position_enc).type(torch.FloatTensor) + + def forward(self, x: torch.Tensor): + return self.position_enc(x) + + +class TimeEncoder(nn.Module): + def __init__(self, embedding_dim): + super(TimeEncoder, self).__init__() + self.embedding_dim = embedding_dim + self.position_array = torch.nn.Parameter(self.get_position_array(), requires_grad=False) + + def get_position_array(self): + x = torch.arange(0, self.embedding_dim, dtype=torch.float) + x = x // 2 * 2 + x = torch.div(x, self.embedding_dim) + x = torch.pow(10000., x) + x = torch.div(1., x) + return x + + def forward(self, x: torch.Tensor): + v = torch.zeros(size=(x.shape[0], self.embedding_dim), dtype=torch.float, device=x.device) + assert len(x.shape) == 1 + x = x.unsqueeze(dim=1) + v[:, 0::2] = torch.sin(x * self.position_array[0::2]) # even + v[:, 1::2] = torch.cos(x * self.position_array[1::2]) # odd + return v + + +class UnsqueezeEncoder(nn.Module): + def __init__(self, unsqueeze_dim: int = -1, norm_value: float = 1): + super(UnsqueezeEncoder, self).__init__() + self.unsqueeze_dim = unsqueeze_dim + self.norm_value = norm_value + + def forward(self, x: torch.Tensor): + x = x.float().unsqueeze(dim=self.unsqueeze_dim) + if self.norm_value != 1: + x = x / self.norm_value + return x + + +if __name__ == '__main__': + pass diff --git a/lzero/model/gobigger/network/nn_module.py b/lzero/model/gobigger/network/nn_module.py new file mode 100644 index 000000000..976831432 --- /dev/null +++ b/lzero/model/gobigger/network/nn_module.py @@ -0,0 +1,235 @@ +from typing import Callable + +import torch +import torch.nn as nn + +from .activation import build_activation +from .normalization import build_normalization + + +def fc_block( + in_channels: int, + out_channels: int, + activation: nn.Module = None, + norm_type: str = None, + use_dropout: bool = False, + dropout_probability: float = 0.5 +) -> nn.Sequential: + r""" + Overview: + Create a fully-connected block with activation, normalization and dropout. + Optional normalization can be done to the dim 1 (across the channels) + x -> fc -> norm -> act -> dropout -> out + Arguments: + - in_channels (:obj:`int`): Number of channels in the input tensor + - out_channels (:obj:`int`): Number of channels in the output tensor + - activation (:obj:`nn.Module`): the optional activation function + - norm_type (:obj:`str`): type of the normalization + - use_dropout (:obj:`bool`) : whether to use dropout in the fully-connected block + - dropout_probability (:obj:`float`) : probability of an element to be zeroed in the dropout. Default: 0.5 + Returns: + - block (:obj:`nn.Sequential`): a sequential list containing the torch layers of the fully-connected block + + .. 
note:: + + you can refer to nn.linear (https://pytorch.org/docs/master/generated/torch.nn.Linear.html) + """ + block = [] + block.append(nn.Linear(in_channels, out_channels)) + if norm_type is not None and norm_type != 'none': + block.append(build_normalization(norm_type, dim=1)(out_channels)) + if isinstance(activation, str) and activation != 'none': + block.append(build_activation(activation)) + elif isinstance(activation, torch.nn.Module): + block.append(activation) + if use_dropout: + block.append(nn.Dropout(dropout_probability)) + return nn.Sequential(*block) + + +def fc_block2( + in_channels, + out_channels, + activation=None, + norm_type=None, + use_dropout=False, + dropout_probability=0.5 +): + r""" + Overview: + create a fully-connected block with activation, normalization and dropout + optional normalization can be done to the dim 1 (across the channels) + x -> fc -> norm -> act -> dropout -> out + Arguments: + - in_channels (:obj:`int`): Number of channels in the input tensor + - out_channels (:obj:`int`): Number of channels in the output tensor + - init_type (:obj:`str`): the type of init to implement + - activation (:obj:`nn.Moduel`): the optional activation function + - norm_type (:obj:`str`): type of the normalization + - use_dropout (:obj:`bool`) : whether to use dropout in the fully-connected block + - dropout_probability (:obj:`float`) : probability of an element to be zeroed in the dropout. Default: 0.5 + Returns: + - block (:obj:`nn.Sequential`): a sequential list containing the torch layers of the fully-connected block + + .. note:: + you can refer to nn.linear (https://pytorch.org/docs/master/generated/torch.nn.Linear.html) + """ + block = [] + if norm_type is not None and norm_type != 'none': + block.append(build_normalization(norm_type, dim=1)(in_channels)) + if isinstance(activation, str) and activation != 'none': + block.append(build_activation(activation)) + elif isinstance(activation, torch.nn.Module): + block.append(activation) + block.append(nn.Linear(in_channels, out_channels)) + if use_dropout: + block.append(nn.Dropout(dropout_probability)) + return nn.Sequential(*block) + + +def conv2d_block( + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + groups: int = 1, + activation: str = None, + norm_type: str = None, + bias: bool = True, +) -> nn.Sequential: + r""" + Overview: + Create a 2-dim convlution layer with activation and normalization. + Arguments: + - in_channels (:obj:`int`): Number of channels in the input tensor + - out_channels (:obj:`int`): Number of channels in the output tensor + - kernel_size (:obj:`int`): Size of the convolving kernel + - stride (:obj:`int`): Stride of the convolution + - padding (:obj:`int`): Zero-padding added to both sides of the input + - dilation (:obj:`int`): Spacing between kernel elements + - groups (:obj:`int`): Number of blocked connections from input channels to output channels + - pad_type (:obj:`str`): the way to add padding, include ['zero', 'reflect', 'replicate'], default: None + - activation (:obj:`nn.Module`): the optional activation function + - norm_type (:obj:`str`): type of the normalization, default set to None, now support ['BN', 'IN', 'SyncBN'] + Returns: + - block (:obj:`nn.Sequential`): a sequential list containing the torch layers of the 2 dim convlution layer + + .. 
note:: + + Conv2d (https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html#torch.nn.Conv2d) + """ + block = [] + block.append( + nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, groups=groups,bias=bias) + ) + if norm_type is not None: + block.append(nn.GroupNorm(num_groups=1, num_channels=out_channels)) + if isinstance(activation, str) and activation != 'none': + block.append(build_activation(activation)) + elif isinstance(activation, torch.nn.Module): + block.append(activation) + return nn.Sequential(*block) + + +def conv2d_block2( + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + activation: str = None, + norm_type=None, + bias: bool = True, +): + r""" + Overview: + create a 2-dim convlution layer with activation and normalization. + + Note: + Conv2d (https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html#torch.nn.Conv2d) + + Arguments: + - in_channels (:obj:`int`): Number of channels in the input tensor + - out_channels (:obj:`int`): Number of channels in the output tensor + - kernel_size (:obj:`int`): Size of the convolving kernel + - stride (:obj:`int`): Stride of the convolution + - padding (:obj:`int`): Zero-padding added to both sides of the input + - dilation (:obj:`int`): Spacing between kernel elements + - groups (:obj:`int`): Number of blocked connections from input channels to output channels + - init_type (:obj:`str`): the type of init to implement + - pad_type (:obj:`str`): the way to add padding, include ['zero', 'reflect', 'replicate'], default: None + - activation (:obj:`nn.Moduel`): the optional activation function + - norm_type (:obj:`str`): type of the normalization, default set to None, now support ['BN', 'IN', 'SyncBN'] + + Returns: + - block (:obj:`nn.Sequential`): a sequential list containing the torch layers of the 2 dim convlution layer + """ + + block = [] + if norm_type is not None: + block.append(nn.GroupNorm(num_groups=1, num_channels=out_channels)) + if isinstance(activation, str) and activation != 'none': + block.append(build_activation(activation)) + elif isinstance(activation, torch.nn.Module): + block.append(activation) + block.append( + nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, groups=groups,bias=bias) + ) + return nn.Sequential(*block) + + +def MLP( + in_channels: int, + hidden_channels: int, + out_channels: int, + layer_num: int, + layer_fn: Callable = None, + activation: str = None, + norm_type: str = None, + use_dropout: bool = False, + dropout_probability: float = 0.5 +): + r""" + Overview: + create a multi-layer perceptron using fully-connected blocks with activation, normalization and dropout, + optional normalization can be done to the dim 1 (across the channels) + x -> fc -> norm -> act -> dropout -> out + Arguments: + - in_channels (:obj:`int`): Number of channels in the input tensor + - hidden_channels (:obj:`int`): Number of channels in the hidden tensor + - out_channels (:obj:`int`): Number of channels in the output tensor + - layer_num (:obj:`int`): Number of layers + - layer_fn (:obj:`Callable`): layer function + - activation (:obj:`nn.Module`): the optional activation function + - norm_type (:obj:`str`): type of the normalization + - use_dropout (:obj:`bool`): whether to use dropout in the fully-connected block + - dropout_probability (:obj:`float`): probability of an element to be zeroed 
in the dropout. Default: 0.5 + Returns: + - block (:obj:`nn.Sequential`): a sequential list containing the torch layers of the fully-connected block + + .. note:: + + you can refer to nn.linear (https://pytorch.org/docs/master/generated/torch.nn.Linear.html) + """ + assert layer_num >= 0, layer_num + if layer_num == 0: + return nn.Sequential(*[nn.Identity()]) + + channels = [in_channels] + [hidden_channels] * (layer_num - 1) + [out_channels] + if layer_fn is None: + layer_fn = fc_block + block = [] + for i, (in_channels, out_channels) in enumerate(zip(channels[:-1], channels[1:])): + block.append(layer_fn(in_channels=in_channels, + out_channels=out_channels, + activation=activation, + norm_type=norm_type, + use_dropout=use_dropout, + dropout_probability=dropout_probability)) + return nn.Sequential(*block) diff --git a/lzero/model/gobigger/network/normalization.py b/lzero/model/gobigger/network/normalization.py new file mode 100644 index 000000000..fd5831c14 --- /dev/null +++ b/lzero/model/gobigger/network/normalization.py @@ -0,0 +1,36 @@ +from typing import Optional +import torch.nn as nn + + +def build_normalization(norm_type: str, dim: Optional[int] = None) -> nn.Module: + r""" + Overview: + Build the corresponding normalization module + Arguments: + - norm_type (:obj:`str`): type of the normaliztion, now support ['BN', 'IN', 'SyncBN', 'AdaptiveIN'] + - dim (:obj:`int`): dimension of the normalization, when norm_type is in [BN, IN] + Returns: + - norm_func (:obj:`nn.Module`): the corresponding batch normalization function + + .. note:: + For beginers, you can refer to to learn more about batch normalization. + """ + if dim is None: + key = norm_type + else: + if norm_type in ['BN', 'IN', 'SyncBN']: + key = norm_type + str(dim) + elif norm_type in ['LN']: + key = norm_type + else: + raise NotImplementedError("not support indicated dim when creates {}".format(norm_type)) + norm_func = { + 'BN1': nn.BatchNorm1d, + 'BN2': nn.BatchNorm2d, + 'LN': nn.LayerNorm, + 'IN2': nn.InstanceNorm2d, + } + if key in norm_func.keys(): + return norm_func[key] + else: + raise KeyError("invalid norm type: {}".format(key)) \ No newline at end of file diff --git a/lzero/model/gobigger/network/res_block.py b/lzero/model/gobigger/network/res_block.py new file mode 100644 index 000000000..f64fae1db --- /dev/null +++ b/lzero/model/gobigger/network/res_block.py @@ -0,0 +1,231 @@ +""" +Copyright 2020 Sensetime X-lab. All Rights Reserved + +Main Function: + 1. 
build ResBlock: you can use this classes to build residual blocks +""" +import torch.nn as nn +from .nn_module import conv2d_block, fc_block,conv2d_block2,fc_block2 +from .activation import build_activation +from .normalization import build_normalization + + +class ResBlock(nn.Module): + r''' + Overview: + Residual Block with 2D convolution layers, including 2 types: + basic block: + input channel: C + x -> 3*3*C -> norm -> act -> 3*3*C -> norm -> act -> out + \__________________________________________/+ + bottleneck block: + x -> 1*1*(1/4*C) -> norm -> act -> 3*3*(1/4*C) -> norm -> act -> 1*1*C -> norm -> act -> out + \_____________________________________________________________________________/+ + + Interface: + __init__, forward + ''' + + def __init__(self, in_channels, out_channels=None,stride=1, downsample=None, activation='relu', norm_type='LN',): + r""" + Overview: + Init the Residual Block + + Arguments: + - in_channels (:obj:`int`): Number of channels in the input tensor + - activation (:obj:`nn.Module`): the optional activation function + - norm_type (:obj:`str`): type of the normalization, defalut set to batch normalization, + support ['BN', 'IN', 'SyncBN', None] + - res_type (:obj:`str`): type of residual block, support ['basic', 'bottleneck'], see overview for details + """ + super(ResBlock, self).__init__() + self.in_channels = in_channels + self.out_channels = self.in_channels if out_channels is None else out_channels + self.activation_type = activation + self.norm_type = norm_type + self.stride = stride + self.downsample = downsample + self.conv1 = conv2d_block(in_channels=self.in_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=self.stride, + padding= 1, + activation=self.activation_type, + norm_type=self.norm_type) + self.conv2 = conv2d_block(in_channels=self.out_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=self.stride, + padding= 1, + activation=None, + norm_type=self.norm_type) + self.activation = build_activation(self.activation_type) + + def forward(self, x): + r""" + Overview: + return the redisual block output + + Arguments: + - x (:obj:`tensor`): the input tensor + + Returns: + - x(:obj:`tensor`): the resblock output tensor + """ + residual = x + out = self.conv1(x) + out = self.conv2(out) + if self.downsample is not None: + residual = self.downsample(x) + out += residual + out = self.activation(out) + return out + + +class ResBlock2(nn.Module): + r''' + Overview: + Residual Block with 2D convolution layers, including 2 types: + basic block: + input channel: C + x -> 3*3*C -> norm -> act -> 3*3*C -> norm -> act -> out + \__________________________________________/+ + bottleneck block: + x -> 1*1*(1/4*C) -> norm -> act -> 3*3*(1/4*C) -> norm -> act -> 1*1*C -> norm -> act -> out + \_____________________________________________________________________________/+ + + Interface: + __init__, forward + ''' + + def __init__(self, in_channels, out_channels=None,stride=1, downsample=None, activation='relu', norm_type='LN',): + r""" + Overview: + Init the Residual Block + + Arguments: + - in_channels (:obj:`int`): Number of channels in the input tensor + - activation (:obj:`nn.Module`): the optional activation function + - norm_type (:obj:`str`): type of the normalization, defalut set to batch normalization, + support ['BN', 'IN', 'SyncBN', None] + - res_type (:obj:`str`): type of residual block, support ['basic', 'bottleneck'], see overview for details + """ + super(ResBlock2, self).__init__() + self.in_channels = in_channels + 
self.out_channels = self.in_channels if out_channels is None else out_channels + self.activation_type = activation + self.norm_type = norm_type + self.stride = stride + self.downsample = downsample + self.conv1 = conv2d_block2(in_channels=self.in_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=self.stride, + padding= 1, + activation=self.activation_type, + norm_type=self.norm_type) + self.conv2 = conv2d_block2(in_channels=self.out_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=self.stride, + padding= 1, + activation=self.activation_type, + norm_type=self.norm_type) + self.activation = build_activation(self.activation_type) + + + def forward(self, x): + r""" + Overview: + return the redisual block output + + Arguments: + - x (:obj:`tensor`): the input tensor + + Returns: + - x(:obj:`tensor`): the resblock output tensor + """ + residual = x + out = self.conv1(x) + out = self.conv2(out) + if self.downsample is not None: + residual = self.downsample(x) + out += residual + return x + +class ResFCBlock(nn.Module): + def __init__(self, in_channels, activation='relu', norm_type=None): + r""" + Overview: + Init the Residual Block + + Arguments: + - activation (:obj:`nn.Module`): the optional activation function + - norm_type (:obj:`str`): type of the normalization, defalut set to batch normalization + """ + super(ResFCBlock, self).__init__() + self.activation_type = activation + self.norm_type = norm_type + self.fc1 = fc_block(in_channels, in_channels, norm_type=self.norm_type, activation=self.activation_type) + self.fc2 = fc_block(in_channels, in_channels,norm_type=self.norm_type, activation=None) + self.activation = build_activation(self.activation_type) + + + def forward(self, x): + r""" + Overview: + return output of the residual block with 2 fully connected block + + Arguments: + - x (:obj:`tensor`): the input tensor + + Returns: + - x(:obj:`tensor`): the resblock output tensor + """ + residual = x + x = self.fc1(x) + x = self.fc2(x) + x = self.activation(x + residual) + return x + +class ResFCBlock2(nn.Module): + r''' + Overview: + Residual Block with 2 fully connected block + x -> fc1 -> norm -> act -> fc2 -> norm -> act -> out + \_____________________________________/+ + + Interface: + __init__, forward + ''' + + def __init__(self, in_channels, activation='relu', norm_type='LN'): + r""" + Overview: + Init the Residual Block + + Arguments: + - activation (:obj:`nn.Module`): the optional activation function + - norm_type (:obj:`str`): type of the normalization, defalut set to batch normalization + """ + super(ResFCBlock2, self).__init__() + self.activation_type = activation + self.fc1 = fc_block2(in_channels, in_channels, activation=self.activation_type, norm_type=norm_type) + self.fc2 = fc_block2(in_channels, in_channels, activation=self.activation_type, norm_type=norm_type) + + def forward(self, x): + r""" + Overview: + return output of the residual block with 2 fully connected block + + Arguments: + - x (:obj:`tensor`): the input tensor + + Returns: + - x(:obj:`tensor`): the resblock output tensor + """ + residual = x + x = self.fc1(x) + x = self.fc2(x) + x = x + residual + return x \ No newline at end of file diff --git a/lzero/model/gobigger/network/rnn.py b/lzero/model/gobigger/network/rnn.py new file mode 100644 index 000000000..363107360 --- /dev/null +++ b/lzero/model/gobigger/network/rnn.py @@ -0,0 +1,276 @@ +""" +Copyright 2020 Sensetime X-lab. All Rights Reserved + +Main Function: + 1. 
build LSTM: you can use build_LSTM to build the lstm module +""" +import math + +import torch +import torch.nn as nn + +from typing import Optional +from .normalization import build_normalization + + +def is_sequence(data): + return isinstance(data, list) or isinstance(data, tuple) + + +def sequence_mask(lengths: torch.Tensor, max_len: Optional[int] =None): + r""" + Overview: + create a mask for a batch sequences with different lengths + Arguments: + - lengths (:obj:`tensor`): lengths in each different sequences, shape could be (n, 1) or (n) + - max_len (:obj:`int`): the padding size, if max_len is None, the padding size is the + max length of sequences + Returns: + - masks (:obj:`torch.BoolTensor`): mask has the same device as lengths + """ + if len(lengths.shape) == 1: + lengths = lengths.unsqueeze(dim=1) + bz = lengths.numel() + if max_len is None: + max_len = lengths.max() + return torch.arange(0, max_len).type_as(lengths).repeat(bz, 1).lt(lengths).to(lengths.device) + + +class LSTMForwardWrapper(object): + r""" + Overview: + abstract class used to wrap the LSTM forward method + Interface: + _before_forward, _after_forward + """ + + def _before_forward(self, inputs, prev_state): + r""" + Overview: + preprocess the inputs and previous states + Arguments: + - inputs (:obj:`tensor`): input vector of cell, tensor of size [seq_len, batch_size, input_size] + - prev_state (:obj:`tensor` or :obj:`list`): + None or tensor of size [num_directions*num_layers, batch_size, hidden_size], if None then prv_state + will be initialized to all zeros. + Returns: + - prev_state (:obj:`tensor`): batch previous state in lstm + """ + assert hasattr(self, 'num_layers') + assert hasattr(self, 'hidden_size') + seq_len, batch_size = inputs.shape[:2] + if prev_state is None: + num_directions = 1 + zeros = torch.zeros( + num_directions * self.num_layers, + batch_size, + self.hidden_size, + dtype=inputs.dtype, + device=inputs.device + ) + prev_state = (zeros, zeros) + elif is_sequence(prev_state): + if len(prev_state) == 2 and isinstance(prev_state[0], torch.Tensor): + pass + else: + if len(prev_state) != batch_size: + raise RuntimeError( + "prev_state number is not equal to batch_size: {}/{}".format(len(prev_state), batch_size) + ) + num_directions = 1 + zeros = torch.zeros( + num_directions * self.num_layers, 1, self.hidden_size, dtype=inputs.dtype, device=inputs.device + ) + state = [] + for prev in prev_state: + if prev is None: + state.append([zeros, zeros]) + else: + state.append(prev) + state = list(zip(*state)) + prev_state = [torch.cat(t, dim=1) for t in state] + else: + raise TypeError("not support prev_state type: {}".format(type(prev_state))) + return prev_state + + def _after_forward(self, next_state, list_next_state=False): + r""" + Overview: + post process the next_state, return list or tensor type next_states + Arguments: + - next_state (:obj:`list` :obj:`Tuple` of :obj:`tensor`): list of Tuple contains the next (h, c) + - list_next_state (:obj:`bool`): whether return next_state with list format, default set to False + Returns: + - next_state(:obj:`list` of :obj:`tensor` or :obj:`tensor`): the formated next_state + """ + if list_next_state: + h, c = [torch.stack(t, dim=0) for t in zip(*next_state)] + batch_size = h.shape[1] + next_state = [torch.chunk(h, batch_size, dim=1), torch.chunk(c, batch_size, dim=1)] + next_state = list(zip(*next_state)) + else: + next_state = [torch.stack(t, dim=0) for t in zip(*next_state)] + return next_state + + +class LSTM(nn.Module, LSTMForwardWrapper): + r""" + 
Overview: + Implimentation of LSTM cell + + .. note:: + for begainners, you can reference to learn the basics about lstm + + Interface: + __init__, forward + """ + + def __init__(self, input_size, hidden_size, num_layers, norm_type=None, dropout=0.): + r""" + Overview: + initializate the LSTM cell + + Arguments: + - input_size (:obj:`int`): size of the input vector + - hidden_size (:obj:`int`): size of the hidden state vector + - num_layers (:obj:`int`): number of lstm layers + - norm_type (:obj:`str`): type of the normaliztion, (default: None) + - dropout (:obj:float): dropout rate, default set to .0 + """ + super(LSTM, self).__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.num_layers = num_layers + + norm_func = build_normalization(norm_type) + self.norm = nn.ModuleList([norm_func(hidden_size * 4) for _ in range(2 * num_layers)]) + self.wx = nn.ParameterList() + self.wh = nn.ParameterList() + dims = [input_size] + [hidden_size] * num_layers + for l in range(num_layers): + self.wx.append(nn.Parameter(torch.zeros(dims[l], dims[l + 1] * 4))) + self.wh.append(nn.Parameter(torch.zeros(hidden_size, hidden_size * 4))) + self.bias = nn.Parameter(torch.zeros(num_layers, hidden_size * 4)) + self.use_dropout = dropout > 0. + if self.use_dropout: + self.dropout = nn.Dropout(dropout) + self._init() + + def _init(self): + gain = math.sqrt(1. / self.hidden_size) + for l in range(self.num_layers): + torch.nn.init.uniform_(self.wx[l], -gain, gain) + torch.nn.init.uniform_(self.wh[l], -gain, gain) + if self.bias is not None: + torch.nn.init.uniform_(self.bias[l], -gain, gain) + + def forward(self, inputs, prev_state, list_next_state=True): + r""" + Overview: + Take the previous state and the input and calculate the output and the nextstate + Arguments: + - inputs (:obj:`tensor`): input vector of cell, tensor of size [seq_len, batch_size, input_size] + - prev_state (:obj:`tensor`): None or tensor of size [num_directions*num_layers, batch_size, hidden_size] + - list_next_state (:obj:`bool`): whether return next_state with list format, default set to False + Returns: + - x (:obj:`tensor`): output from lstm + - next_state (:obj:`tensor` or :obj:`list`): hidden state from lstm + """ + seq_len, batch_size = inputs.shape[:2] + prev_state = self._before_forward(inputs, prev_state) + + H, C = prev_state + x = inputs + next_state = [] + for l in range(self.num_layers): + h, c = H[l], C[l] + new_x = [] + for s in range(seq_len): + gate = self.norm[l * 2](torch.matmul(x[s], self.wx[l]) + ) + self.norm[l * 2 + 1](torch.matmul(h, self.wh[l])) + if self.bias is not None: + gate += self.bias[l] + gate = list(torch.chunk(gate, 4, dim=1)) + i, f, o, u = gate + i = torch.sigmoid(i) + f = torch.sigmoid(f) + o = torch.sigmoid(o) + u = torch.tanh(u) + c = f * c + i * u + h = o * torch.tanh(c) + new_x.append(h) + next_state.append((h, c)) + x = torch.stack(new_x, dim=0) + if self.use_dropout and l != self.num_layers - 1: + x = self.dropout(x) + + next_state = self._after_forward(next_state, list_next_state) + return x, next_state + + +class PytorchLSTM(nn.LSTM, LSTMForwardWrapper): + r""" + Overview: + Wrap the nn.LSTM , format the input and output + Interface: + forward + + .. 
note:: + you can reference the + """ + + def forward(self, inputs, prev_state, list_next_state=True): + r""" + Overview: + wrapped nn.LSTM.forward + Arguments: + - inputs (:obj:`tensor`): input vector of cell, tensor of size [seq_len, batch_size, input_size] + - prev_state (:obj:`tensor`): None or tensor of size [num_directions*num_layers, batch_size, hidden_size] + - list_next_state (:obj:`bool`): whether return next_state with list format, default set to False + Returns: + - output (:obj:`tensor`): output from lstm + - next_state (:obj:`tensor` or :obj:`list`): hidden state from lstm + """ + prev_state = self._before_forward(inputs, prev_state) + output, next_state = nn.LSTM.forward(self, inputs, prev_state) + next_state = self._after_forward(next_state, list_next_state) + return output, next_state + + def _after_forward(self, next_state, list_next_state=False): + r""" + Overview: + process hidden state after lstm, make it list or remains tensor + Arguments: + - nex_state (:obj:`tensor`): hidden state from lstm + - list_nex_state (:obj:`bool`): whether return next_state with list format, default set to False + Returns: + - next_state (:obj:`tensor` or :obj:`list`): hidden state from lstm + """ + if list_next_state: + h, c = next_state + batch_size = h.shape[1] + next_state = [torch.chunk(h, batch_size, dim=1), torch.chunk(c, batch_size, dim=1)] + return list(zip(*next_state)) + else: + return next_state + + +def get_lstm(lstm_type, input_size, hidden_size, num_layers=1, norm_type='LN', dropout=0.): + r""" + Overview: + build and return the corresponding LSTM cell + Arguments: + - lstm_type (:obj:`str`): version of lstm cell, now support ['normal', 'pytorch'] + - input_size (:obj:`int`): size of the input vector + - hidden_size (:obj:`int`): size of the hidden state vector + - num_layers (:obj:`int`): number of lstm layers + - norm_type (:obj:`str`): type of the normaliztion, (default: None) + - dropout (:obj:float): dropout rate, default set to .0 + Returns: + - lstm (:obj:`LSTM` or :obj:`PytorchLSTM`): the corresponding lstm cell + """ + assert lstm_type in ['normal', 'pytorch'] + if lstm_type == 'normal': + return LSTM(input_size, hidden_size, num_layers, norm_type, dropout=dropout) + elif lstm_type == 'pytorch': + return PytorchLSTM(input_size, hidden_size, num_layers, dropout=dropout) diff --git a/lzero/model/gobigger/network/scatter_connection.py b/lzero/model/gobigger/network/scatter_connection.py new file mode 100644 index 000000000..dbb6ab716 --- /dev/null +++ b/lzero/model/gobigger/network/scatter_connection.py @@ -0,0 +1,107 @@ +from typing import Tuple + +import torch +import torch.nn as nn + + +class ScatterConnection(nn.Module): + r""" + Overview: + Scatter feature to its corresponding location + In alphastar, each entity is embedded into a tensor, these tensors are scattered into a feature map + with map size + """ + + def __init__(self, scatter_type='add') -> None: + r""" + Overview: + Init class + Arguments: + - scatter_type (:obj:`str`): add or cover, if two entities have same location, scatter type decides the + first one should be covered or added to second one + """ + super(ScatterConnection, self).__init__() + self.scatter_type = scatter_type + assert self.scatter_type in ['cover', 'add'] + + def xy_forward(self, x: torch.Tensor, spatial_size: Tuple[int, int], coord_x: torch.Tensor,coord_y) -> torch.Tensor: + device = x.device + BatchSize, Num, EmbeddingSize = x.shape + x = x.permute(0, 2, 1) + H, W = spatial_size + indices = (coord_x * W + coord_y).long() + indices 
= indices.unsqueeze(dim=1).repeat(1, EmbeddingSize, 1) + output = torch.zeros(size=(BatchSize, EmbeddingSize, H, W), device=device).view(BatchSize, EmbeddingSize, + H * W) + if self.scatter_type == 'cover': + output.scatter_(dim=2, index=indices, src=x) + elif self.scatter_type == 'add': + output.scatter_add_(dim=2, index=indices, src=x) + output = output.view(BatchSize, EmbeddingSize, H, W) + return output + + def forward(self, x: torch.Tensor, spatial_size: Tuple[int, int], location: torch.Tensor) -> torch.Tensor: + """ + Overview: + scatter x into a spatial feature map + Arguments: + - x (:obj:`tensor`): input tensor :math: `(B, M, N)` where `M` means the number of entity, `N` means\ + the dimension of entity attributes + - spatial_size (:obj:`tuple`): Tuple[H, W], the size of spatial feature x will be scattered into + - location (:obj:`tensor`): :math: `(B, M, 2)` torch.LongTensor, each location should be (y, x) + Returns: + - output (:obj:`tensor`): :math: `(B, N, H, W)` where `H` and `W` are spatial_size, return the\ + scattered feature map + Shapes: + - Input: :math: `(B, M, N)` where `M` means the number of entity, `N` means\ + the dimension of entity attributes + - Size: Tuple[H, W] + - Location: :math: `(B, M, 2)` torch.LongTensor, each location should be (y, x) + - Output: :math: `(B, N, H, W)` where `H` and `W` are spatial_size + + .. note:: + when there are some overlapping in locations, ``cover`` mode will result in the loss of information, we + use the addition as temporal substitute. + """ + device = x.device + BatchSize, Num, EmbeddingSize = x.shape + x = x.permute(0, 2, 1) + H, W = spatial_size + indices = location[:, :, 1] + location[:, :, 0] * W + indices = indices.unsqueeze(dim=1).repeat(1, EmbeddingSize, 1) + output = torch.zeros(size=(BatchSize, EmbeddingSize, H, W), device=device).view(BatchSize, EmbeddingSize, + H * W) + if self.scatter_type == 'cover': + output.scatter_(dim=2, index=indices, src=x) + elif self.scatter_type == 'add': + output.scatter_add_(dim=2, index=indices, src=x) + output = output.view(BatchSize, EmbeddingSize, H, W) + + # device = x.device + # B, M, N = x.shape + # H, W = spatial_size + # index = location.view(-1, 2) + # bias = torch.arange(B).mul_(H * W).unsqueeze(1).repeat(1, M).view(-1).to(device) + # index = index[:, 0] * W + index[:, 1] + # index += bias + # index = index.repeat(N, 1) + # x = x.view(-1, N).permute(1, 0) + # output = torch.zeros(N, B * H * W, device=device) + # if self.scatter_type == 'cover': + # output.scatter_(dim=1, index=index, src=x) + # elif self.scatter_type == 'add': + # output.scatter_add_(dim=1, index=index, src=x) + # output = output.reshape(N, B, H, W) + # output = output.permute(1, 0, 2, 3).contiguous() + + return output + + +if __name__ == '__main__': + scatter_conn = ScatterConnection() + BatchSize, Num, EmbeddingSize = 10, 20, 3 + SpatialSize = (13, 17) + for _ in range(10): + x = torch.randn(size=(BatchSize, Num, EmbeddingSize)) + locations = torch.randint(low=0, high=12, size=(BatchSize, Num, 2)) + scatter_conn.forward(x, SpatialSize, location=locations) diff --git a/lzero/model/gobigger/network/soft_argmax.py b/lzero/model/gobigger/network/soft_argmax.py new file mode 100644 index 000000000..a963fd1ad --- /dev/null +++ b/lzero/model/gobigger/network/soft_argmax.py @@ -0,0 +1,60 @@ +""" +Copyright 2020 Sensetime X-lab. All Rights Reserved + +Main Function: + 1. 
SoftArgmax: a nn.Module that computes SoftArgmax +""" +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class SoftArgmax(nn.Module): + r""" + Overview: + a nn.Module that computes SoftArgmax + + Note: + for more softargmax info, you can reference the wiki page + or reference the lecture + + + Interface: + __init__, forward + """ + + def __init__(self): + r""" + Overview: + initialize the SoftArgmax module + """ + super(SoftArgmax, self).__init__() + + def forward(self, x): + r""" + Overview: + soft-argmax for location regression + + Arguments: + - x (:obj:`Tensor`): predict heat map + + Returns: + - location (:obj:`Tensor`): predict location + + Shapes: + - x (:obj:`Tensor`): :math:`(B, C, H, W)`, while B is the batch size, + C is number of channels , H and W stands for height and width + - location (:obj:`Tensor`): :math:`(B, 2)`, while B is the batch size + """ + B, C, H, W = x.shape + device, dtype = x.device, x.dtype + # 1 channel + assert (x.shape[1] == 1) + h_kernel = torch.arange(0, H, device=device).to(dtype) + h_kernel = h_kernel.view(1, 1, H, 1).repeat(1, 1, 1, W) + w_kernel = torch.arange(0, W, device=device).to(dtype) + w_kernel = w_kernel.view(1, 1, 1, W).repeat(1, 1, H, 1) + x = F.softmax(x.view(B, C, -1), dim=-1).view(B, C, H, W) + h = (x * h_kernel).sum(dim=[1, 2, 3]) + w = (x * w_kernel).sum(dim=[1, 2, 3]) + return torch.stack([h, w], dim=1) diff --git a/lzero/model/gobigger/network/transformer.py b/lzero/model/gobigger/network/transformer.py new file mode 100644 index 000000000..67ae4426d --- /dev/null +++ b/lzero/model/gobigger/network/transformer.py @@ -0,0 +1,397 @@ +import math +from typing import Dict, Tuple, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +LAYER_NORM_EPS = 1e-5 +NEAR_INF = 1e20 +NEAR_INF_FP16 = 65504 + + +def neginf(dtype: torch.dtype) -> float: + """ + Return a representable finite number near -inf for a dtype. 
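+
+ Example (values follow ``NEAR_INF`` / ``NEAR_INF_FP16`` defined above; shown for illustration):
+ >>> neginf(torch.float16)
+ -65504
+ >>> neginf(torch.float32)
+ -1e+20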
+    """
+    if dtype is torch.float16:
+        return -NEAR_INF_FP16
+    else:
+        return -NEAR_INF
+
+
+class MultiHeadAttention(nn.Module):
+    r"""
+    Overview:
+        For each entry embedding, compute attention over all entries and aggregate the attended values to get the output
+    """
+
+    def __init__(self, n_heads: int = None, dim: int = None, dropout: float = 0):
+        r"""
+        Overview:
+            Init attention
+        Arguments:
+            - n_heads (:obj:`int`): number of attention heads
+            - dim (:obj:`int`): embedding dimension, which must be divisible by ``n_heads``
+            - dropout (:obj:`float`): dropout probability applied to the attention weights
+        """
+        super(MultiHeadAttention, self).__init__()
+        self.n_heads = n_heads
+        self.dim = dim
+
+        self.attn_dropout = nn.Dropout(p=dropout)
+        self.q_lin = nn.Linear(dim, dim)
+        self.k_lin = nn.Linear(dim, dim)
+        self.v_lin = nn.Linear(dim, dim)
+
+        # TODO: merge for the initialization step
+        nn.init.xavier_normal_(self.q_lin.weight)
+        nn.init.xavier_normal_(self.k_lin.weight)
+        nn.init.xavier_normal_(self.v_lin.weight)
+        self.out_lin = nn.Linear(dim, dim)
+        nn.init.xavier_normal_(self.out_lin.weight)
+
+        # self.attention_pre = fc_block(self.dim, self.dim * 3)  # query, key, value
+        # self.project = fc_block(self.dim,self.dim)
+
+    def split(self, x, T=False):
+        r"""
+        Overview:
+            Split input to get multihead queries, keys, values
+        Arguments:
+            - x (:obj:`tensor`): query or key or value
+            - T (:obj:`bool`): whether to transpose the last two dimensions of the output
+        Returns:
+            - x (:obj:`tensor`): reshaped tensor of shape (B, n_heads, N, head_dim)
+        Note:
+            This helper is not used by ``forward`` (which uses the inner ``prepare_head`` function); it is kept for reference.
+        """
+        B, N = x.shape[:2]
+        head_dim = self.dim // self.n_heads
+        x = x.view(B, N, self.n_heads, head_dim)
+        x = x.permute(0, 2, 1, 3).contiguous()  # B, n_heads, N, head_dim
+        if T:
+            x = x.permute(0, 1, 3, 2).contiguous()
+        return x
+
+    def forward(self,
+                query: torch.Tensor,
+                key: Optional[torch.Tensor] = None,
+                value: Optional[torch.Tensor] = None,
+                mask: torch.Tensor = None,
+                ) -> Tuple[torch.Tensor, torch.Tensor]:
+        batch_size, query_len, dim = query.size()
+        assert (
+            dim == self.dim
+        ), 'Dimensions do not match: {} query vs {} configured'.format(dim, self.dim)
+        assert mask is not None, 'Mask is None, please specify a mask'
+        n_heads = self.n_heads
+        dim_per_head = dim // n_heads
+        scale = math.sqrt(dim_per_head)
+
+        def prepare_head(tensor):
+            # input is [batch_size, seq_len, n_heads * dim_per_head]
+            # output is [batch_size * n_heads, seq_len, dim_per_head]
+            bsz, seq_len, _ = tensor.size()
+            tensor = tensor.view(batch_size, tensor.size(1), n_heads, dim_per_head)
+            tensor = (
+                tensor.transpose(1, 2)
+                .contiguous()
+                .view(batch_size * n_heads, seq_len, dim_per_head)
+            )
+            return tensor
+
+        # q, k, v are the transformed values
+        if key is None and value is None:
+            # self attention
+            key = value = query
+            _, _key_len, dim = query.size()
+        elif value is None:
+            # key and value are the same, but query differs
+            value = key
+
+        assert key is not None  # let mypy know we sorted this
+        _, _key_len, dim = key.size()
+
+        q = prepare_head(self.q_lin(query))
+        k = prepare_head(self.k_lin(key))
+        v = prepare_head(self.v_lin(value))
+        full_key_len = k.size(1)
+        dot_prod = q.div_(scale).bmm(k.transpose(1, 2))
+        # [B * n_heads, query_len, key_len]
+        attn_mask = (
+            (mask == 0)
+            .view(batch_size, 1, -1, full_key_len)
+            .repeat(1, n_heads, 1, 1)
+            .expand(batch_size, n_heads, query_len, full_key_len)
+            .view(batch_size * n_heads, query_len, full_key_len)
+        )
+        assert attn_mask.shape == dot_prod.shape
+        dot_prod.masked_fill_(attn_mask, 
neginf(dot_prod.dtype)) + + attn_weights = F.softmax( + dot_prod, dim=-1, dtype=torch.float # type: ignore + ).type_as(query) + attn_weights = self.attn_dropout(attn_weights) # --attention-dropout + + attentioned = attn_weights.bmm(v) + attentioned = ( + attentioned.type_as(query) + .view(batch_size, n_heads, query_len, dim_per_head) + .transpose(1, 2) + .contiguous() + .view(batch_size, query_len, dim) + ) + + out = self.out_lin(attentioned) + + return out, dot_prod + # + # def forward(self, x, mask=None): + # r""" + # Overview: + # Compute attention + # Arguments: + # - x (:obj:`tensor`): input tensor + # - mask (:obj:`tensor`): mask out invalid entries + # Returns: + # - attention (:obj:`tensor`): attention tensor + # """ + # assert (len(x.shape) == 3) + # B, N = x.shape[:2] + # x = self.attention_pre(x) + # query, key, value = torch.chunk(x, 3, dim=2) + # query, key, value = self.split(query), self.split(key, T=True), self.split(value) + # + # score = torch.matmul(query, key) # B, head_num, N, N + # score /= math.sqrt(self.head_dim) + # if mask is not None: + # score.masked_fill_(~mask, value=-1e9) + # + # score = F.softmax(score, dim=-1) + # score = self.dropout(score) + # attention = torch.matmul(score, value) # B, head_num, N, head_dim + # + # attention = attention.permute(0, 2, 1, 3).contiguous() # B, N, head_num, head_dim + # attention = self.project(attention.view(B, N, -1)) # B, N, output_dim + # return attention + + +class TransformerFFN(nn.Module): + """ + Implements the FFN part of the transformer. + """ + + def __init__( + self, + dim: int = None, + dim_hidden: int = None, + dropout: float = 0, + activation: str = 'relu', + **kwargs, + ): + super(TransformerFFN, self).__init__(**kwargs) + self.dim = dim + self.dim_hidden = dim_hidden + self.dropout_ratio = dropout + self.relu_dropout = nn.Dropout(p=self.dropout_ratio) + if activation == 'relu': + self.nonlinear = F.relu + elif activation == 'gelu': + self.nonlinear = F.gelu + else: + raise ValueError( + "Don't know how to handle --activation {}".format(activation) + ) + self.lin1 = nn.Linear(self.dim, self.dim_hidden) + self.lin2 = nn.Linear(self.dim_hidden, self.dim) + nn.init.xavier_uniform_(self.lin1.weight) + nn.init.xavier_uniform_(self.lin2.weight) + # TODO: initialize biases to 0 + + def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: + """ + Forward pass. 
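+        Applies ``lin2(dropout(activation(lin1(x))))`` along the last dimension, so the output has the same shape as the input.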
+ """ + x = self.nonlinear(self.lin1(x)) + x = self.relu_dropout(x) # --relu-dropout + x = self.lin2(x) + return x + + +class TransformerLayer(nn.Module): + r""" + Overview: + In transformer layer, first computes entries's attention and applies a feedforward layer + """ + + def __init__(self, + n_heads: int = None, + embedding_size: int = None, + ffn_size: int = None, + attention_dropout: float = 0.0, + relu_dropout: float = 0.0, + dropout: float = 0.0, + activation: str = 'relu', + variant: Optional[str] = None, + ): + r""" + Overview: + Init transformer layer + Arguments: + - input_dim (:obj:`int`): dimension of input + - head_dim (:obj:`int`): dimension of each head + - hidden_dim (:obj:`int`): dimension of hidden layer in mlp + - output_dim (:obj:`int`): dimension of output + - head_num (:obj:`int`): number of heads for multihead attention + - mlp_num (:obj:`int`): number of mlp layers + - dropout (:obj:`nn.Module`): dropout layer + - activation (:obj:`nn.Module`): activation function + """ + super(TransformerLayer, self).__init__() + self.n_heads = n_heads + self.dim = embedding_size + self.ffn_dim = ffn_size + self.activation = activation + self.variant = variant + self.attention = MultiHeadAttention( + n_heads=self.n_heads, + dim=embedding_size, + dropout=attention_dropout) + self.norm1 = torch.nn.LayerNorm(embedding_size, eps=LAYER_NORM_EPS) + self.ffn = TransformerFFN(dim=embedding_size, + dim_hidden=ffn_size, + dropout=relu_dropout, + activation=activation, + ) + self.norm2 = torch.nn.LayerNorm(embedding_size, eps=LAYER_NORM_EPS) + self.dropout = nn.Dropout(dropout) + + def forward(self, x, mask): + """ + Overview: + transformer layer forward + Arguments: + - inputs (:obj:`tuple`): x and mask + Returns: + - output (:obj:`tuple`): x and mask + """ + residual = x + + if self.variant == 'prenorm': + x = self.norm1(x) + attended_tensor = self.attention(x, mask=mask)[0] + x = residual + self.dropout(attended_tensor) + if self.variant == 'postnorm': + x = self.norm1(x) + + residual = x + if self.variant == 'prenorm': + x = self.norm2(x) + x = residual + self.dropout(self.ffn(x)) + if self.variant == 'postnorm': + x = self.norm2(x) + + x *= mask.unsqueeze(-1).type_as(x) + return x + + +class Transformer(nn.Module): + ''' + Overview: + Transformer implementation + + Note: + For details refer to Attention is all you need: http://arxiv.org/abs/1706.03762 + ''' + + def __init__( + self, + n_heads=8, + embedding_size: int = 128, + ffn_size: int = 128, + n_layers: int = 3, + attention_dropout: float = 0.0, + relu_dropout: float = 0.0, + dropout: float = 0.0, + activation: Optional[str] = 'relu', + variant: Optional[str] = 'prenorm', + ): + r""" + Overview: + Init transformer + Arguments: + - input_dim (:obj:`int`): dimension of input + - head_dim (:obj:`int`): dimension of each head + - hidden_dim (:obj:`int`): dimension of hidden layer in mlp + - output_dim (:obj:`int`): dimension of output + - head_num (:obj:`int`): number of heads for multihead attention + - mlp_num (:obj:`int`): number of mlp layers + - layer_num (:obj:`int`): number of transformer layers + - dropout_ratio (:obj:`float`): dropout ratio + - activation (:obj:`nn.Module`): activation function + """ + super(Transformer, self).__init__() + self.n_heads = n_heads + self.dim = embedding_size + self.ffn_size = ffn_size + self.n_layers = n_layers + + self.dropout_ratio = dropout + self.attention_dropout = attention_dropout + self.relu_dropout = relu_dropout + self.activation = activation + self.variant = variant + + # build the 
model + self.layers = self.build_layers() + self.norm_embedding = torch.nn.LayerNorm(self.dim, eps=LAYER_NORM_EPS) + + def build_layers(self) -> nn.ModuleList: + layers = nn.ModuleList() + for _ in range(self.n_layers): + layer = TransformerLayer( + n_heads=self.n_heads, + embedding_size=self.dim, + ffn_size=self.ffn_size, + attention_dropout=self.attention_dropout, + relu_dropout=self.relu_dropout, + dropout=self.dropout_ratio, + variant=self.variant, + activation=self.activation, + ) + layers.append(layer) + return layers + + def forward(self, x, mask=None): + r""" + Overview: + Transformer forward + Arguments: + - x (:obj:`tensor`): input tensor, shape (B, N, C), B is batch size, N is number of entries, + C is feature dimension + - mask (:obj:`tensor` or :obj:`None`): bool tensor, can be used to mask out invalid entries in attention, + shape (B, N), B is batch size, N is number of entries + Returns: + - x (:obj:`tensor`): transformer output + """ + if self.variant == 'postnorm': + x = self.norm_embedding(x) + if mask is not None: + x *= mask.unsqueeze(-1).type_as(x) + else: + mask = torch.ones(size=x.shape[:2],dtype=torch.bool, device=x.device) + if self.variant == 'postnorm': + x = self.norm_embedding(x) + for i in range(self.n_layers): + x = self.layers[i](x, mask) + if self.variant == 'prenorm': + x = self.norm_embedding(x) + return x + +if __name__ == '__main__': + transformer = Transformer(n_heads=8,embedding_size=128) + from bigrl.core.torch_utils.network.rnn import sequence_mask + mask = sequence_mask(lengths=torch.tensor([1,2,3,4,5,6,2,3,0,0]),max_len=20) + y = transformer.forward(x = torch.randn(size=(10,20,128)),mask=mask) + print(y) \ No newline at end of file diff --git a/lzero/policy/gobigger_efficientzero.py b/lzero/policy/gobigger_efficientzero.py new file mode 100644 index 000000000..89277489f --- /dev/null +++ b/lzero/policy/gobigger_efficientzero.py @@ -0,0 +1,787 @@ +import copy +from typing import List, Dict, Any, Tuple, Union + +import numpy as np +import torch +import torch.optim as optim +from ding.model import model_wrap +from ding.policy.base_policy import Policy +from ding.torch_utils import to_tensor, squeeze +from ding.utils import POLICY_REGISTRY +from torch.distributions import Categorical +from torch.nn import L1Loss + +from lzero.mcts import EfficientZeroMCTSCtree as MCTSCtree +from lzero.mcts import EfficientZeroMCTSPtree as MCTSPtree +from lzero.model import ImageTransforms +from lzero.policy import scalar_transform, InverseScalarTransform, cross_entropy_loss, phi_transform, \ + DiscreteSupport, select_action, to_torch_float_tensor, ez_network_output_unpack, negative_cosine_similarity, prepare_obs, \ + configure_optimizers +from collections import defaultdict + + +@POLICY_REGISTRY.register('gobigger_efficientzero') +class GoBiggerEfficientZeroPolicy(Policy): + """ + Overview: + The policy class for EfficientZero. + """ + + # The default_config for EfficientZero policy. + config = dict( + model=dict( + # (str) The model type. For 1-dimensional vector obs, we use mlp model. For 3-dimensional image obs, we use conv model. + model_type='conv', # options={'mlp', 'conv'} + # (bool) If True, the action space of the environment is continuous, otherwise discrete. + continuous_action_space=False, + # (tuple) The stacked obs shape. + # observation_shape=(1, 96, 96), # if frame_stack_num=1 + observation_shape=(4, 96, 96), # if frame_stack_num=4 + # (bool) Whether to use the self-supervised learning loss. 
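+            # i.e. the temporal consistency term of EfficientZero: the dynamics-predicted latent state is compared
+            # with the representation of the actual next observation via a negative cosine similarity
+            # (see ``consistency_loss`` in ``_forward_learn`` below).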
+ self_supervised_learning_loss=True, + # (bool) Whether to use discrete support to represent categorical distribution for value/reward/value_prefix. + categorical_distribution=True, + # (int) The image channel in image observation. + image_channel=1, + # (int) The number of frames to stack together. + frame_stack_num=1, + # (int) The scale of supports used in categorical distribution. + # This variable is only effective when ``categorical_distribution=True``. + support_scale=300, + # (int) The hidden size in LSTM. + lstm_hidden_size=512, + # (bool) whether to learn bias in the last linear layer in value and policy head. + bias=True, + # (str) The type of action encoding. Options are ['one_hot', 'not_one_hot']. Default to 'one_hot'. + discrete_action_encoding_type='one_hot', + # (bool) whether to use res connection in dynamics. + res_connection_in_dynamics=True, + # (str) The type of normalization in MuZero model. Options are ['BN', 'LN']. Default to 'LN'. + norm_type='BN', + ), + # ****** common ****** + # (bool) Whether to enable the sampled-based algorithm (e.g. Sampled EfficientZero) + # this variable is used in ``collector``. + sampled_algo=False, + # (bool) Whether to use C++ MCTS in policy. If False, use Python implementation. + mcts_ctree=True, + # (bool) Whether to use cuda for network. + cuda=True, + # (int) The number of environments used in collecting data. + collector_env_num=8, + # (int) The number of environments used in evaluating policy. + evaluator_env_num=3, + # (str) The type of environment. The options are ['not_board_games', 'board_games']. + env_type='not_board_games', + # (str) The type of battle mode. The options are ['play_with_bot_mode', 'self_play_mode']. + battle_mode='play_with_bot_mode', + # (bool) Whether to monitor extra statistics in tensorboard. + monitor_extra_statistics=True, + # (int) The transition number of one ``GameSegment``. + game_segment_length=200, + + # ****** observation ****** + # (bool) Whether to transform image to string to save memory. + transform2string=False, + # (bool) Whether to use data augmentation. + use_augmentation=False, + # (list) The style of augmentation. + augmentation=['shift', 'intensity'], + + # ******* learn ****** + # (int) How many updates(iterations) to train after collector's one collection. + # Bigger "update_per_collect" means bigger off-policy. + # collect data -> update policy-> collect data -> ... + # For different env, we have different episode_length, + # we usually set update_per_collect = collector_env_num * episode_length / batch_size * reuse_factor + update_per_collect=100, + # (int) Minibatch size for one gradient descent. + batch_size=256, + # (str) Optimizer for training policy network. ['SGD', 'Adam', 'AdamW'] + optim_type='SGD', + # (float) Learning rate for training policy network. Initial lr for manually decay schedule. + learning_rate=0.2, + # (int) Frequency of target network update. + target_update_freq=100, + # (float) Weight decay for training policy network. + weight_decay=1e-4, + # (float) One-order Momentum in optimizer, which stabilizes the training process (gradient direction). + momentum=0.9, + # (float) The maximum constraint value of gradient norm clipping. + grad_clip_value=10, + # (int) The number of episode in each collecting stage. + n_episode=8, + # (float) the number of simulations in MCTS. + num_simulations=50, + # (float) Discount factor (gamma) for returns. + discount_factor=0.997, + # (int) The number of step for calculating target q_value. 
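+        # Roughly, the value target bootstraps after ``td_steps`` steps:
+        # z_t = r_{t+1} + gamma * r_{t+2} + ... + gamma^(td_steps-1) * r_{t+td_steps} + gamma^td_steps * v_{t+td_steps},
+        # where the bootstrap value comes from the target model (the targets themselves are computed in the game buffer).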
+ td_steps=5, + # (int) The number of unroll steps in dynamics network. + num_unroll_steps=5, + # (int) reset the hidden states in LSTM every ``lstm_horizon_len`` horizon steps. + lstm_horizon_len=5, + # (float) The weight of reward loss. + reward_loss_weight=1, + # (float) The weight of value loss. + value_loss_weight=0.25, + # (float) The weight of policy loss. + policy_loss_weight=1, + # (float) The weight of ssl (self-supervised learning) loss. + ssl_loss_weight=2, + # (bool) Whether to use piecewise constant learning rate decay. + # i.e. lr: 0.2 -> 0.02 -> 0.002 + lr_piecewise_constant_decay=True, + # (int) The number of final training iterations to control lr decay, which is only used for manually decay. + threshold_training_steps_for_final_lr=int(5e4), + # (int) The number of final training iterations to control temperature, which is only used for manually decay. + threshold_training_steps_for_final_temperature=int(1e5), + # (bool) Whether to use manually decayed temperature. + # i.e. temperature: 1 -> 0.5 -> 0.25 + manual_temperature_decay=False, + # (float) The fixed temperature value for MCTS action selection, which is used to control the exploration. + # The larger the value, the more exploration. This value is only used when manual_temperature_decay=False. + fixed_temperature_value=0.25, + + # ****** Priority ****** + # (bool) Whether to use priority when sampling training data from the buffer. + use_priority=True, + # (bool) Whether to use the maximum priority for new collecting data. + use_max_priority_for_new_data=True, + # (float) The degree of prioritization to use. A value of 0 means no prioritization, + # while a value of 1 means full prioritization. + priority_prob_alpha=0.6, + # (float) The degree of correction to use. A value of 0 means no correction, + # while a value of 1 means full correction. + priority_prob_beta=0.4, + + # ****** UCB ****** + # (float) The alpha value used in the Dirichlet distribution for exploration at the root node of the search tree. + root_dirichlet_alpha=0.3, + # (float) The noise weight at the root node of the search tree. + root_noise_weight=0.25, + ) + + def default_model(self) -> Tuple[str, List[str]]: + """ + Overview: + Return this algorithm default model setting. + Returns: + - model_info (:obj:`Tuple[str, List[str]]`): model name and model import_names. + - model_type (:obj:`str`): The model type used in this algorithm, which is registered in ModelRegistry. + - import_names (:obj:`List[str]`): The model class path list used in this algorithm. + .. note:: + The user can define and use customized network model but must obey the same interface definition indicated \ + by import_names path. For EfficientZero, ``lzero.model.efficientzero_model.EfficientZeroModel`` + """ + return 'EfficientZeroModelMLP', ['lzero.model.gobigger.gobigger_efficientzero_model_mlp'] + + def _init_learn(self) -> None: + """ + Overview: + Learn mode init method. Called by ``self.__init__``. Initialize the learn model, optimizer and MCTS utils. 
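+            Concretely, this builds the optimizer (SGD / Adam / AdamW), the optional piecewise-constant lr scheduler,
+            the target model wrapper (hard update every ``target_update_freq`` iterations), the optional image
+            augmentation transforms, and the value/reward ``DiscreteSupport`` together with the inverse scalar transform handle.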
+ """ + assert self._cfg.optim_type in ['SGD', 'Adam', 'AdamW'], self._cfg.optim_type + if self._cfg.optim_type == 'SGD': + self._optimizer = optim.SGD( + self._model.parameters(), + lr=self._cfg.learning_rate, + momentum=self._cfg.momentum, + weight_decay=self._cfg.weight_decay, + ) + + elif self._cfg.optim_type == 'Adam': + self._optimizer = optim.Adam( + self._model.parameters(), + lr=self._cfg.learning_rate, + weight_decay=self._cfg.weight_decay, + ) + elif self._cfg.optim_type == 'AdamW': + self._optimizer = configure_optimizers( + model=self._model, + weight_decay=self._cfg.weight_decay, + learning_rate=self._cfg.learning_rate, + device_type=self._cfg.device + ) + + if self._cfg.lr_piecewise_constant_decay: + from torch.optim.lr_scheduler import LambdaLR + max_step = self._cfg.threshold_training_steps_for_final_lr + # NOTE: the 1, 0.1, 0.01 is the decay rate, not the lr. + lr_lambda = lambda step: 1 if step < max_step * 0.5 else (0.1 if step < max_step else 0.01) # noqa + self.lr_scheduler = LambdaLR(self._optimizer, lr_lambda=lr_lambda) + + # use model_wrapper for specialized demands of different modes + self._target_model = copy.deepcopy(self._model) + self._target_model = model_wrap( + self._target_model, + wrapper_name='target', + update_type='assign', + update_kwargs={'freq': self._cfg.target_update_freq} + ) + self._learn_model = self._model + + if self._cfg.use_augmentation: + self.image_transforms = ImageTransforms( + self._cfg.augmentation, + image_shape=(self._cfg.model.observation_shape[1], self._cfg.model.observation_shape[2]) + ) + self.value_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) + self.reward_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) + + self.inverse_scalar_transform_handle = InverseScalarTransform( + self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution + ) + + def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: + """ + Overview: + The forward function for learning policy in learn mode, which is the core of the learning process. + The data is sampled from replay buffer. + The loss is calculated by the loss function and the loss is backpropagated to update the model. + Arguments: + - data (:obj:`Tuple[torch.Tensor]`): The data sampled from replay buffer, which is a tuple of tensors. + The first tensor is the current_batch, the second tensor is the target_batch. + Returns: + - info_dict (:obj:`Dict[str, Union[float, int]]`): The information dict to be logged, which contains \ + current learning loss and learning statistics. + """ + self._learn_model.train() + self._target_model.train() + + current_batch, target_batch = data + obs_batch_ori, action_batch, mask_batch, indices, weights, make_time = current_batch + target_value_prefix, target_value, target_policy = target_batch + + obs_batch_ori = obs_batch_ori.tolist() + obs_batch_ori = np.array(obs_batch_ori) + obs_batch = obs_batch_ori[:, 0:self._cfg.model.frame_stack_num] + if self._cfg.model.self_supervised_learning_loss: + obs_target_batch = obs_batch_ori[:, self._cfg.model.frame_stack_num:] + # obs_batch, obs_target_batch = obs_batch_ori.tolist() + + # # do augmentations + # if self._cfg.use_augmentation: + # obs_batch = self.image_transforms.transform(obs_batch) + # obs_target_batch = self.image_transforms.transform(obs_target_batch) + + # shape: (batch_size, num_unroll_steps, action_dim) + # NOTE: .long(), in discrete action space. 
+ action_batch = torch.from_numpy(action_batch).to(self._cfg.device).unsqueeze(-1).long() + data_list = [ + mask_batch, + target_value_prefix.astype('float64'), + target_value.astype('float64'), target_policy, weights + ] + [mask_batch, target_value_prefix, target_value, target_policy, + weights] = to_torch_float_tensor(data_list, self._cfg.device) + + target_value_prefix = target_value_prefix.view(self._cfg.batch_size, -1) + target_value = target_value.view(self._cfg.batch_size, -1) + assert obs_batch.size == self._cfg.batch_size == target_value_prefix.size(0) + + # ``scalar_transform`` to transform the original value to the scaled value, + # i.e. h(.) function in paper https://arxiv.org/pdf/1805.11593.pdf. + transformed_target_value_prefix = scalar_transform(target_value_prefix) + transformed_target_value = scalar_transform(target_value) + # transform a scalar to its categorical_distribution. After this transformation, each scalar is + # represented as the linear combination of its two adjacent supports. + target_value_prefix_categorical = phi_transform(self.reward_support, transformed_target_value_prefix) + target_value_categorical = phi_transform(self.value_support, transformed_target_value) + + # ============================================================== + # the core initial_inference in EfficientZero policy. + # ============================================================== + obs_batch = obs_batch.tolist() + obs_batch = sum(obs_batch, []) + network_output = self._learn_model.initial_inference(obs_batch) + # value_prefix shape: (batch_size, 10), the ``value_prefix`` at the first step is zero padding. + latent_state, value_prefix, reward_hidden_state, value, policy_logits = ez_network_output_unpack(network_output) + + # transform the scaled value or its categorical representation to its original value, + # i.e. h^(-1)(.) function in paper https://arxiv.org/pdf/1805.11593.pdf. + original_value = self.inverse_scalar_transform_handle(value) + + # Note: The following lines are just for debugging. + predicted_value_prefixs = [] + if self._cfg.monitor_extra_statistics: + latent_state_list = latent_state.detach().cpu().numpy() + predicted_values, predicted_policies = original_value.detach().cpu(), torch.softmax( + policy_logits, dim=1 + ).detach().cpu() + + # calculate the new priorities for each transition. + value_priority = L1Loss(reduction='none')(original_value.squeeze(-1), target_value[:, 0]) + value_priority = value_priority.data.cpu().numpy() + 1e-6 + + prob = torch.softmax(policy_logits, dim=-1) + dist = Categorical(prob) + policy_entropy = dist.entropy().mean() + + # ============================================================== + # calculate policy and value loss for the first step. + # ============================================================== + policy_loss = cross_entropy_loss(policy_logits, target_policy[:, 0]) + + # Here we take the init hypothetical step k=0. + target_normalized_visit_count_init_step = target_policy[:, 0] + + # ******* NOTE: target_policy_entropy is only for debug. 
****** + non_masked_indices = torch.nonzero(mask_batch[:, 0]).squeeze(-1) + # Check if there are any unmasked rows + if len(non_masked_indices) > 0: + target_normalized_visit_count_masked = torch.index_select( + target_normalized_visit_count_init_step, 0, non_masked_indices + ) + target_dist = Categorical(target_normalized_visit_count_masked) + target_policy_entropy = target_dist.entropy().mean() + else: + # Set target_policy_entropy to 0 if all rows are masked + target_policy_entropy = 0 + + value_loss = cross_entropy_loss(value, target_value_categorical[:, 0]) + + value_prefix_loss = torch.zeros(self._cfg.batch_size, device=self._cfg.device) + consistency_loss = torch.zeros(self._cfg.batch_size, device=self._cfg.device) + + # ============================================================== + # the core recurrent_inference in EfficientZero policy. + # ============================================================== + for step_i in range(self._cfg.num_unroll_steps): + # unroll with the dynamics function: predict the next ``latent_state``, ``reward_hidden_state``, + # `` value_prefix`` given current ``latent_state`` ``reward_hidden_state`` and ``action``. + # And then predict policy_logits and value with the prediction function. + network_output = self._learn_model.recurrent_inference( + latent_state, reward_hidden_state, action_batch[:, step_i] + ) + latent_state, value_prefix, reward_hidden_state, value, policy_logits = ez_network_output_unpack( + network_output + ) + + # transform the scaled value or its categorical representation to its original value, + # i.e. h^(-1)(.) function in paper https://arxiv.org/pdf/1805.11593.pdf. + original_value = self.inverse_scalar_transform_handle(value) + + # ============================================================== + # calculate consistency loss for the next ``num_unroll_steps`` unroll steps. + # ============================================================== + if self._cfg.ssl_loss_weight > 0: + beg_index = step_i + end_index = step_i + self._cfg.model.frame_stack_num + obs_target_batch_tmp = obs_target_batch[:, beg_index:end_index].tolist() + obs_target_batch_tmp = sum(obs_target_batch_tmp, []) + network_output = self._learn_model.initial_inference(obs_target_batch_tmp) + + latent_state = to_tensor(latent_state) + representation_state = to_tensor(network_output.latent_state) + + # NOTE: no grad for the representation_state branch. + dynamic_proj = self._learn_model.project(latent_state, with_grad=True) + observation_proj = self._learn_model.project(representation_state, with_grad=False) + temp_loss = negative_cosine_similarity(dynamic_proj, observation_proj) * mask_batch[:, step_i] + + consistency_loss += temp_loss + + # NOTE: the target policy, target_value_categorical, target_value_prefix_categorical is calculated in + # game buffer now. + # ============================================================== + # calculate policy loss for the next ``num_unroll_steps`` unroll steps. + # NOTE: the +=. + # ============================================================== + policy_loss += cross_entropy_loss(policy_logits, target_policy[:, step_i + 1]) + + # Here we take the hypothetical step k = step_i + 1 + prob = torch.softmax(policy_logits, dim=-1) + dist = Categorical(prob) + policy_entropy += dist.entropy().mean() + target_normalized_visit_count = target_policy[:, step_i + 1] + + # ******* NOTE: target_policy_entropy is only for debug. 
****** + non_masked_indices = torch.nonzero(mask_batch[:, step_i + 1]).squeeze(-1) + # Check if there are any unmasked rows + if len(non_masked_indices) > 0: + target_normalized_visit_count_masked = torch.index_select( + target_normalized_visit_count, 0, non_masked_indices + ) + target_dist = Categorical(target_normalized_visit_count_masked) + target_policy_entropy += target_dist.entropy().mean() + else: + # Set target_policy_entropy to 0 if all rows are masked + target_policy_entropy += 0 + + value_loss += cross_entropy_loss(value, target_value_categorical[:, step_i + 1]) + value_prefix_loss += cross_entropy_loss(value_prefix, target_value_prefix_categorical[:, step_i]) + + # reset hidden states every ``lstm_horizon_len`` unroll steps. + if (step_i + 1) % self._cfg.lstm_horizon_len == 0: + reward_hidden_state = ( + torch.zeros(1, self._cfg.batch_size, self._cfg.model.lstm_hidden_size).to(self._cfg.device), + torch.zeros(1, self._cfg.batch_size, self._cfg.model.lstm_hidden_size).to(self._cfg.device) + ) + + if self._cfg.monitor_extra_statistics: + original_value_prefixs = self.inverse_scalar_transform_handle(value_prefix) + original_value_prefixs_cpu = original_value_prefixs.detach().cpu() + + predicted_values = torch.cat( + (predicted_values, self.inverse_scalar_transform_handle(value).detach().cpu()) + ) + predicted_value_prefixs.append(original_value_prefixs_cpu) + predicted_policies = torch.cat((predicted_policies, torch.softmax(policy_logits, dim=1).detach().cpu())) + latent_state_list = np.concatenate((latent_state_list, latent_state.detach().cpu().numpy())) + + # ============================================================== + # the core learn model update step. + # ============================================================== + # weighted loss with masks (some invalid states which are out of trajectory.) + loss = ( + self._cfg.ssl_loss_weight * consistency_loss + self._cfg.policy_loss_weight * policy_loss + + self._cfg.value_loss_weight * value_loss + self._cfg.reward_loss_weight * value_prefix_loss + ) + weighted_total_loss = (weights * loss).mean() + # TODO(pu): test the effect of gradient scale. + gradient_scale = 1 / self._cfg.num_unroll_steps + weighted_total_loss.register_hook(lambda grad: grad * gradient_scale) + self._optimizer.zero_grad() + weighted_total_loss.backward() + total_grad_norm_before_clip = torch.nn.utils.clip_grad_norm_( + self._learn_model.parameters(), self._cfg.grad_clip_value + ) + self._optimizer.step() + if self._cfg.lr_piecewise_constant_decay is True: + self.lr_scheduler.step() + + # ============================================================== + # the core target model update step. 
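+        # With ``update_type='assign'`` (see ``_init_learn``), this hard-copies the learn model parameters
+        # into the target model once every ``target_update_freq`` training iterations.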
+ # ============================================================== + self._target_model.update(self._learn_model.state_dict()) + + # packing loss info for tensorboard logging + loss_info = ( + weighted_total_loss.item(), loss.mean().item(), policy_loss.mean().item(), value_prefix_loss.mean().item(), + value_loss.mean().item(), consistency_loss.mean() + ) + + if self._cfg.monitor_extra_statistics: + predicted_value_prefixs = torch.stack(predicted_value_prefixs).transpose(1, 0).squeeze(-1) + predicted_value_prefixs = predicted_value_prefixs.reshape(-1).unsqueeze(-1) + + td_data = ( + value_priority, target_value_prefix.detach().cpu().numpy(), target_value.detach().cpu().numpy(), + transformed_target_value_prefix.detach().cpu().numpy(), transformed_target_value.detach().cpu().numpy(), + target_value_prefix_categorical.detach().cpu().numpy(), target_value_categorical.detach().cpu().numpy(), + predicted_value_prefixs.detach().cpu().numpy(), predicted_values.detach().cpu().numpy(), + target_policy.detach().cpu().numpy(), predicted_policies.detach().cpu().numpy(), latent_state_list + ) + + return { + 'collect_mcts_temperature': self.collect_mcts_temperature, + 'cur_lr': self._optimizer.param_groups[0]['lr'], + 'weighted_total_loss': loss_info[0], + 'total_loss': loss_info[1], + 'policy_loss': loss_info[2], + 'policy_entropy': policy_entropy.item() / (self._cfg.num_unroll_steps + 1), + 'target_policy_entropy': target_policy_entropy.item() / (self._cfg.num_unroll_steps + 1), + 'value_prefix_loss': loss_info[3], + 'value_loss': loss_info[4], + 'consistency_loss': loss_info[5] / self._cfg.num_unroll_steps, + + # ============================================================== + # priority related + # ============================================================== + 'value_priority': td_data[0].flatten().mean().item(), + 'value_priority_orig': value_priority, + 'target_value_prefix': td_data[1].flatten().mean().item(), + 'target_value': td_data[2].flatten().mean().item(), + 'transformed_target_value_prefix': td_data[3].flatten().mean().item(), + 'transformed_target_value': td_data[4].flatten().mean().item(), + 'predicted_value_prefixs': td_data[7].flatten().mean().item(), + 'predicted_values': td_data[8].flatten().mean().item(), + 'total_grad_norm_before_clip': total_grad_norm_before_clip + } + + def _init_collect(self) -> None: + """ + Overview: + Collect mode init method. Called by ``self.__init__``. Initialize the collect model and MCTS utils. + """ + self._collect_model = self._model + if self._cfg.mcts_ctree: + self._mcts_collect = MCTSCtree(self._cfg) + else: + self._mcts_collect = MCTSPtree(self._cfg) + self.collect_mcts_temperature = 1 + + def _forward_collect( + self, + data: torch.Tensor, + action_mask: list = None, + temperature: float = 1, + to_play: List = [-1], + ready_env_id=None + ): + """ + Overview: + The forward function for collecting data in collect mode. Use model to execute MCTS search. + Choosing the action through sampling during the collect mode. + Arguments: + - data (:obj:`torch.Tensor`): The input data, i.e. the observation. + - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. + - temperature (:obj:`float`): The temperature of the policy. + - to_play (:obj:`int`): The player to play. + - ready_env_id (:obj:`list`): The id of the env that is ready to collect. 
+ Shape: + - data (:obj:`torch.Tensor`): + - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ + S is the number of stacked frames, H is the height of the image, W is the width of the image. + - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size. + - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env. + - temperature: :math:`(1, )`. + - to_play: :math:`(N, 1)`, where N is the number of collect_env. + - ready_env_id: None + Returns: + - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ + ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. + """ + self._collect_model.eval() + self.collect_mcts_temperature = temperature + + active_collect_env_num = len(data) + data = to_tensor(data) + data = sum(sum(data, []), []) + batch_size = len(data) + agent_num = batch_size // active_collect_env_num + to_play = np.array(to_play).reshape(-1).tolist() + + with torch.no_grad(): + # data shape [B, S x C, W, H], e.g. {Tensor:(B, 12, 96, 96)} + network_output = self._collect_model.initial_inference(data) + latent_state_roots, value_prefix_roots, reward_hidden_state_roots, pred_values, policy_logits = ez_network_output_unpack( + network_output + ) + + if not self._learn_model.training: + # if not in training, obtain the scalars of the value/reward + pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() + latent_state_roots = latent_state_roots.detach().cpu().numpy() + reward_hidden_state_roots = ( + reward_hidden_state_roots[0].detach().cpu().numpy(), + reward_hidden_state_roots[1].detach().cpu().numpy() + ) + policy_logits = policy_logits.detach().cpu().numpy().tolist() + + action_mask = sum(action_mask, []) + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(batch_size)] + # the only difference between collect and eval is the dirichlet noise. 
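+            # AlphaZero-style root exploration: for each legal action a, the prior is mixed as
+            # P'(s, a) = (1 - root_noise_weight) * P(s, a) + root_noise_weight * Dir(root_dirichlet_alpha)_a.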
+ noises = [ + np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j])) + ).astype(np.float32).tolist() for j in range(batch_size) + ] + if self._cfg.mcts_ctree: + # cpp mcts_tree + roots = MCTSCtree.roots(batch_size, legal_actions) + else: + # python mcts_tree + roots = MCTSPtree.roots(batch_size, legal_actions) + roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_roots, policy_logits, to_play) + self._mcts_collect.search( + roots, self._collect_model, latent_state_roots, reward_hidden_state_roots, to_play + ) + + roots_visit_count_distributions = roots.get_distributions() # shape: ``{list: batch_size} ->{list: action_space_size}`` + roots_values = roots.get_values() # shape: {list: batch_size} + + data_id = [i for i in range(active_collect_env_num)] + output = {i: defaultdict(list) for i in data_id} + if ready_env_id is None: + ready_env_id = np.arange(active_collect_env_num) + + for i in range(batch_size): + distributions, value = roots_visit_count_distributions[i], roots_values[i] + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=self.collect_mcts_temperature, deterministic=False + ) + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + output[i//agent_num]['action'].append(action) + output[i//agent_num]['distributions'].append(distributions) + output[i//agent_num]['visit_count_distribution_entropy'].append(visit_count_distribution_entropy) + output[i//agent_num]['value'].append(value) + output[i//agent_num]['pred_value'].append(pred_values[i]) + output[i//agent_num]['policy_logits'].append(policy_logits[i]) + + + # for i, env_id in enumerate(ready_env_id): + # distributions, value = roots_visit_count_distributions[i], roots_values[i] + # # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents + # # the index within the legal action set, rather than the index in the entire action set. + # action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + # distributions, temperature=self.collect_mcts_temperature, deterministic=False + # ) + # # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the entire action set. + # action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + # output[env_id] = { + # 'action': action, + # 'distributions': distributions, + # 'visit_count_distribution_entropy': visit_count_distribution_entropy, + # 'value': value, + # 'pred_value': pred_values[i], + # 'policy_logits': policy_logits[i], + # } + + return output + + def _init_eval(self) -> None: + """ + Overview: + Evaluate mode init method. Called by ``self.__init__``. Initialize the eval model and MCTS utils. + """ + self._eval_model = self._model + if self._cfg.mcts_ctree: + self._mcts_eval = MCTSCtree(self._cfg) + else: + self._mcts_eval = MCTSPtree(self._cfg) + + def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, ready_env_id=None): + """ + Overview: + The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search. + Choosing the action with the highest value (argmax) rather than sampling during the eval mode. + Arguments: + - data (:obj:`torch.Tensor`): The input data, i.e. the observation. + - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. + - to_play (:obj:`int`): The player to play. 
+ - ready_env_id (:obj:`list`): The id of the env that is ready to collect. + Shape: + - data (:obj:`torch.Tensor`): + - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ + S is the number of stacked frames, H is the height of the image, W is the width of the image. + - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size. + - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env. + - to_play: :math:`(N, 1)`, where N is the number of collect_env. + - ready_env_id: None + Returns: + - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ + ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. + """ + self._eval_model.eval() + active_eval_env_num = len(data) + data = to_tensor(data) + data = sum(sum(data, []), []) + batch_size = len(data) + agent_num = batch_size // active_eval_env_num + to_play = np.array(to_play).reshape(-1).tolist() + + with torch.no_grad(): + # data shape [B, S x C, W, H], e.g. {Tensor:(B, 12, 96, 96)} + network_output = self._eval_model.initial_inference(data) + latent_state_roots, value_prefix_roots, reward_hidden_state_roots, pred_values, policy_logits = ez_network_output_unpack( + network_output + ) + + if not self._eval_model.training: + # if not in training, obtain the scalars of the value/reward + pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape(B, 1) + latent_state_roots = latent_state_roots.detach().cpu().numpy() + reward_hidden_state_roots = ( + reward_hidden_state_roots[0].detach().cpu().numpy(), + reward_hidden_state_roots[1].detach().cpu().numpy() + ) + policy_logits = policy_logits.detach().cpu().numpy().tolist() # list shape(B, A) + + action_mask = sum(action_mask, []) + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(batch_size)] + if self._cfg.mcts_ctree: + # cpp mcts_tree + roots = MCTSCtree.roots(batch_size, legal_actions) + else: + # python mcts_tree + roots = MCTSPtree.roots(batch_size, legal_actions) + roots.prepare_no_noise(value_prefix_roots, policy_logits, to_play) + self._mcts_eval.search(roots, self._eval_model, latent_state_roots, reward_hidden_state_roots, to_play) + + roots_visit_count_distributions = roots.get_distributions( + ) # shape: ``{list: batch_size} ->{list: action_space_size}`` + roots_values = roots.get_values() # shape: {list: batch_size} + data_id = [i for i in range(active_eval_env_num)] + output = {i: defaultdict(list) for i in data_id} + + if ready_env_id is None: + ready_env_id = np.arange(active_eval_env_num) + + for i in range(batch_size): + distributions, value = roots_visit_count_distributions[i], roots_values[i] + # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents + # the index within the legal action set, rather than the index in the entire action set. + # Setting deterministic=True implies choosing the action with the highest value (argmax) rather than sampling during the evaluation phase. + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=1, deterministic=True + ) + # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the entire action set. 
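+                # e.g. if action_mask[i] = [0, 1, 0, 1] and action_index_in_legal_action_set = 1,
+                # then np.where(...)[0] = [1, 3] and the selected action in the full action set is 3.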
+ action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + output[i//agent_num]['action'].append(action) + output[i//agent_num]['distributions'].append(distributions) + output[i//agent_num]['visit_count_distribution_entropy'].append(visit_count_distribution_entropy) + output[i//agent_num]['value'].append(value) + output[i//agent_num]['pred_value'].append(pred_values[i]) + output[i//agent_num]['policy_logits'].append(policy_logits[i]) + + return output + + def _monitor_vars_learn(self) -> List[str]: + """ + Overview: + Register the variables to be monitored in learn mode. The registered variables will be logged in + tensorboard according to the return value ``_forward_learn``. + """ + return [ + 'collect_mcts_temperature', + 'cur_lr', + 'weighted_total_loss', + 'total_loss', + 'policy_loss', + 'policy_entropy', + 'target_policy_entropy', + 'value_prefix_loss', + 'value_loss', + 'consistency_loss', + 'value_priority', + 'target_value_prefix', + 'target_value', + 'predicted_value_prefixs', + 'predicted_values', + 'transformed_target_value_prefix', + 'transformed_target_value', + 'total_grad_norm_before_clip', + ] + + def _state_dict_learn(self) -> Dict[str, Any]: + """ + Overview: + Return the state_dict of learn mode, usually including model and optimizer. + Returns: + - state_dict (:obj:`Dict[str, Any]`): the dict of current policy learn state, for saving and restoring. + """ + return { + 'model': self._learn_model.state_dict(), + 'target_model': self._target_model.state_dict(), + 'optimizer': self._optimizer.state_dict(), + } + + def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: + """ + Overview: + Load the state_dict variable into policy learn mode. + Arguments: + - state_dict (:obj:`Dict[str, Any]`): the dict of policy learn state saved before. 
+ """ + self._learn_model.load_state_dict(state_dict['model']) + self._target_model.load_state_dict(state_dict['target_model']) + self._optimizer.load_state_dict(state_dict['optimizer']) + + def _process_transition(self, obs, policy_output, timestep): + # be compatible with DI-engine Policy class + pass + + def _get_train_sample(self, data): + # be compatible with DI-engine Policy class + pass diff --git a/lzero/worker/__init__.py b/lzero/worker/__init__.py index b74e1e745..62a97b743 100644 --- a/lzero/worker/__init__.py +++ b/lzero/worker/__init__.py @@ -2,3 +2,5 @@ from .alphazero_evaluator import AlphaZeroEvaluator from .muzero_collector import MuZeroCollector from .muzero_evaluator import MuZeroEvaluator +from .gobigger_muzero_collector import GoBiggerMuZeroCollector +from .gobigger_muzero_evaluator import GoBiggerMuZeroEvaluator diff --git a/lzero/worker/gobigger_muzero_collector.py b/lzero/worker/gobigger_muzero_collector.py new file mode 100644 index 000000000..d43527c9a --- /dev/null +++ b/lzero/worker/gobigger_muzero_collector.py @@ -0,0 +1,658 @@ +import time +from collections import deque, namedtuple +from typing import Optional, Any, List + +import numpy as np +import torch +from ding.envs import BaseEnvManager +from ding.torch_utils import to_ndarray +from ding.utils import build_logger, EasyTimer, SERIAL_COLLECTOR_REGISTRY +from ding.worker.collector.base_serial_collector import ISerialCollector +from torch.nn import L1Loss + +from lzero.mcts.buffer.game_segment import GameSegment +from lzero.mcts.utils import prepare_observation +from collections import defaultdict + + +@SERIAL_COLLECTOR_REGISTRY.register('gobigger_episode_muzero') +class GoBiggerMuZeroCollector(ISerialCollector): + """ + Overview: + The Collector for MCTS+RL algorithms, including MuZero, EfficientZero, Sampled EfficientZero. + Interfaces: + __init__, reset, reset_env, reset_policy, collect, close + Property: + envstep + """ + + # TO be compatible with ISerialCollector + config = dict() + + def __init__( + self, + collect_print_freq: int = 100, + env: BaseEnvManager = None, + policy: namedtuple = None, + tb_logger: 'SummaryWriter' = None, # noqa + exp_name: Optional[str] = 'default_experiment', + instance_name: Optional[str] = 'collector', + policy_config: 'policy_config' = None, # noqa + ) -> None: + """ + Overview: + Init the collector according to input arguments. + Arguments: + - collect_print_freq (:obj:`int`): collect_print_frequency in terms of training_steps. + - env (:obj:`BaseEnvManager`): the subclass of vectorized env_manager(BaseEnvManager) + - policy (:obj:`namedtuple`): the api namedtuple of collect_mode policy + - tb_logger (:obj:`SummaryWriter`): tensorboard handle + - instance_name (:obj:`Optional[str]`): Name of this instance. + - exp_name (:obj:`str`): Experiment name, which is used to indicate output directory. + - policy_config: Config of game. 
+ """ + self._exp_name = exp_name + self._instance_name = instance_name + self._collect_print_freq = collect_print_freq + self._timer = EasyTimer() + self._end_flag = False + + if tb_logger is not None: + self._logger, _ = build_logger( + path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name, need_tb=False + ) + self._tb_logger = tb_logger + else: + self._logger, self._tb_logger = build_logger( + path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name + ) + + self.policy_config = policy_config + + self.reset(policy, env) + + def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None: + """ + Overview: + Reset the environment. + If _env is None, reset the old environment. + If _env is not None, replace the old environment in the collector with the new passed \ + in environment and launch. + Arguments: + - env (:obj:`Optional[BaseEnvManager]`): instance of the subclass of vectorized \ + env_manager(BaseEnvManager) + """ + if _env is not None: + self._env = _env + self._env.launch() + self._env_num = self._env.env_num + else: + self._env.reset() + + def reset_policy(self, _policy: Optional[namedtuple] = None) -> None: + """ + Overview: + Reset the policy. + If _policy is None, reset the old policy. + If _policy is not None, replace the old policy in the collector with the new passed in policy. + Arguments: + - policy (:obj:`Optional[namedtuple]`): the api namedtuple of collect_mode policy + """ + assert hasattr(self, '_env'), "please set env first" + if _policy is not None: + self._policy = _policy + self._default_n_episode = _policy.get_attribute('cfg').get('n_episode', None) + self._logger.debug( + 'Set default n_episode mode(n_episode({}), env_num({}))'.format(self._default_n_episode, self._env_num) + ) + self._policy.reset() + + def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvManager] = None) -> None: + """ + Overview: + Reset the environment and policy. + If _env is None, reset the old environment. + If _env is not None, replace the old environment in the collector with the new passed \ + in environment and launch. + If _policy is None, reset the old policy. + If _policy is not None, replace the old policy in the collector with the new passed in policy. + Arguments: + - policy (:obj:`Optional[namedtuple]`): the api namedtuple of collect_mode policy + - env (:obj:`Optional[BaseEnvManager]`): instance of the subclass of vectorized \ + env_manager(BaseEnvManager) + """ + if _env is not None: + self.reset_env(_env) + if _policy is not None: + self.reset_policy(_policy) + + self._env_info = {env_id: {'time': 0., 'step': 0} for env_id in range(self._env_num)} + + self._episode_info = [] + self._total_envstep_count = 0 + self._total_episode_count = 0 + self._total_duration = 0 + self._last_train_iter = 0 + self._end_flag = False + + # A game_segment_pool implementation based on the deque structure. + self.game_segment_pool = deque(maxlen=int(1e6)) + self.unroll_plus_td_steps = self.policy_config.num_unroll_steps + self.policy_config.td_steps + + def _reset_stat(self, env_id: int) -> None: + """ + Overview: + Reset the collector's state. Including reset the traj_buffer, obs_pool, policy_output_pool\ + and env_info. Reset these states according to env_id. You can refer to base_serial_collector\ + to get more messages. 
+ Arguments: + - env_id (:obj:`int`): the id where we need to reset the collector's state + """ + self._env_info[env_id] = {'time': 0., 'step': 0} + + @property + def envstep(self) -> int: + """ + Overview: + Print the total envstep count. + Return: + - envstep (:obj:`int`): the total envstep count + """ + return self._total_envstep_count + + def close(self) -> None: + """ + Overview: + Close the collector. If end_flag is False, close the environment, flush the tb_logger\ + and close the tb_logger. + """ + if self._end_flag: + return + self._end_flag = True + self._env.close() + self._tb_logger.flush() + self._tb_logger.close() + + def __del__(self) -> None: + """ + Overview: + Execute the close command and close the collector. __del__ is automatically called to \ + destroy the collector instance when the collector finishes its work + """ + self.close() + + # ============================================================== + # MCTS+RL related core code + # ============================================================== + def _compute_priorities(self, i, agent_id, pred_values_lst, search_values_lst): + """ + Overview: + obtain the priorities at index i. + Arguments: + - i: index. + - pred_values_lst: The list of value being predicted. + - search_values_lst: The list of value obtained through search. + """ + if self.policy_config.use_priority and not self.policy_config.use_max_priority_for_new_data: + pred_values = torch.from_numpy(np.array(pred_values_lst[i][agent_id])).to(self.policy_config.device).float().view(-1) + search_values = torch.from_numpy(np.array(search_values_lst[i][agent_id])).to(self.policy_config.device + ).float().view(-1) + priorities = L1Loss(reduction='none' + )(pred_values, + search_values).detach().cpu().numpy() + self.policy_config.prioritized_replay_eps + else: + # priorities is None -> use the max priority for all newly collected data + priorities = None + + return priorities + + def pad_and_save_last_trajectory(self, i, agent_id, last_game_segments, last_game_priorities, game_segments, done) -> None: + """ + Overview: + put the last game block into the pool if the current game is finished + Arguments: + - last_game_segments (:obj:`list`): list of the last game segments + - last_game_priorities (:obj:`list`): list of the last game priorities + - game_segments (:obj:`list`): list of the current game segments + Note: + (last_game_segments[i].obs_segment[-4:][j] == game_segments[i].obs_segment[:4][j]).all() is True + """ + # pad over last block trajectory + beg_index = self.policy_config.model.frame_stack_num + end_index = beg_index + self.policy_config.num_unroll_steps + + # the start obs is init zero obs, so we take the [ : +] obs as the pad obs + # e.g. 
the start 4 obs is init zero obs, the num_unroll_steps is 5, so we take the [4:9] obs as the pad obs + pad_obs_lst = game_segments[i][agent_id].obs_segment[beg_index:end_index] + pad_child_visits_lst = game_segments[i][agent_id].child_visit_segment[:self.policy_config.num_unroll_steps] + # EfficientZero original repo bug: + # pad_child_visits_lst = game_segments[i].child_visit_segment[beg_index:end_index] + + beg_index = 0 + # self.unroll_plus_td_steps = self.policy_config.num_unroll_steps + self.policy_config.td_steps + end_index = beg_index + self.unroll_plus_td_steps - 1 + + pad_reward_lst = game_segments[i][agent_id].reward_segment[beg_index:end_index] + + beg_index = 0 + end_index = beg_index + self.unroll_plus_td_steps + + pad_root_values_lst = game_segments[i][agent_id].root_value_segment[beg_index:end_index] + + # pad over and save + last_game_segments[i][agent_id].pad_over(pad_obs_lst, pad_reward_lst, pad_root_values_lst, pad_child_visits_lst) + """ + Note: + game_segment element shape: + obs: game_segment_length + stack + num_unroll_steps, 20+4 +5 + rew: game_segment_length + stack + num_unroll_steps + td_steps -1 20 +5+3-1 + action: game_segment_length -> 20 + root_values: game_segment_length + num_unroll_steps + td_steps -> 20 +5+3 + child_visits: game_segment_length + num_unroll_steps -> 20 +5 + to_play: game_segment_length -> 20 + action_mask: game_segment_length -> 20 + """ + + last_game_segments[i][agent_id].game_segment_to_array() + + # put the game block into the pool + self.game_segment_pool.append((last_game_segments[i][agent_id], last_game_priorities[i][agent_id], done[i])) + + # reset last game_segments + last_game_segments[i][agent_id] = None + last_game_priorities[i][agent_id] = None + + return None + + def collect(self, + n_episode: Optional[int] = None, + train_iter: int = 0, + policy_kwargs: Optional[dict] = None) -> List[Any]: + """ + Overview: + Collect `n_episode` data with policy_kwargs, which is already trained `train_iter` iterations. + Arguments: + - n_episode (:obj:`int`): the number of collecting data episode. + - train_iter (:obj:`int`): the number of training iteration. + - policy_kwargs (:obj:`dict`): the keyword args for policy forward. + Returns: + - return_data (:obj:`List`): A list containing collected game_segments + """ + if n_episode is None: + if self._default_n_episode is None: + raise RuntimeError("Please specify collect n_episode") + else: + n_episode = self._default_n_episode + assert n_episode >= self._env_num, "Please make sure n_episode >= env_num{}/{}".format(n_episode, self._env_num) + if policy_kwargs is None: + policy_kwargs = {} + temperature = policy_kwargs['temperature'] + + collected_episode = 0 + env_nums = self._env_num + + # initializations + init_obs = self._env.ready_obs + + retry_waiting_time = 0.001 + while len(init_obs.keys()) != self._env_num: + # In order to be compatible with subprocess env_manager, in which sometimes self._env_num is not equal to + # len(self._env.ready_obs), especially in tictactoe env. + self._logger.info('The current init_obs.keys() is {}'.format(init_obs.keys())) + self._logger.info('Before sleeping, the _env_states is {}'.format(self._env._env_states)) + time.sleep(retry_waiting_time) + self._logger.info('=' * 10 + 'Wait for all environments (subprocess) to finish resetting.' 
+ '=' * 10) + self._logger.info( + 'After sleeping {}s, the current _env_states is {}'.format(retry_waiting_time, self._env._env_states) + ) + init_obs = self._env.ready_obs + + action_mask_dict = {i: to_ndarray(init_obs[i]['action_mask']) for i in range(env_nums)} + to_play_dict = {i: to_ndarray(init_obs[i]['to_play']) for i in range(env_nums)} + agent_num = len(init_obs[0]['action_mask']) + game_segments = [ + [GameSegment( + self._env.action_space, + game_segment_length=self.policy_config.game_segment_length, + config=self.policy_config + ) for _ in range(agent_num)] for _ in range(env_nums) + ] + # stacked observation windows in reset stage for init game_segments + observation_window_stack = [[[] for _ in range(agent_num)] for _ in range(env_nums)] + for env_id in range(env_nums): + for agent_id in range(agent_num): + observation_window_stack[env_id][agent_id] = deque( + [to_ndarray(init_obs[env_id]['observation'][agent_id]) for _ in range(self.policy_config.model.frame_stack_num)], + maxlen=self.policy_config.model.frame_stack_num + ) + game_segments[env_id][agent_id].reset(observation_window_stack[env_id][agent_id]) + + dones = np.array([False for _ in range(env_nums)]) + last_game_segments = [[None for _ in range(agent_num)] for _ in range(env_nums)] + last_game_priorities = [[None for _ in range(agent_num)] for _ in range(env_nums)] + # for priorities in self-play + search_values_lst = [[[] for _ in range(agent_num)] for _ in range(env_nums)] + pred_values_lst = [[[] for _ in range(agent_num)] for _ in range(env_nums)] + + + # some logs + eps_steps_lst, visit_entropies_lst = np.zeros(env_nums), np.zeros((env_nums, agent_num)) + self_play_moves = 0. + self_play_episodes = 0. + self_play_moves_max = 0 + self_play_visit_entropy = [] + total_transitions = 0 + + ready_env_id = set() + remain_episode = n_episode + + while True: + with self._timer: + # Get current ready env obs. 
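+ # ``ready_obs`` only contains the environments that currently provide a new observation
+ # (with the subprocess env manager some envs may still be stepping or resetting).
+ # Newly ready env ids are merged into ``ready_env_id``, capped by ``remain_episode`` so that
+ # at most ``n_episode`` episodes are started in total.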
+ obs = self._env.ready_obs + new_available_env_id = set(obs.keys()).difference(ready_env_id) + ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode])) + remain_episode -= min(len(new_available_env_id), remain_episode) + + stack_obs = defaultdict(list) + for env_id in ready_env_id: + for agent_id in range(agent_num): + stack_obs[env_id].append(game_segments[env_id][agent_id].get_obs()) + + # stack_obs = {env_id: [game_segments[env_id][agent_id].get_obs() for agent_id in agent_num] for env_id in ready_env_id} + stack_obs = list(stack_obs.values()) + + action_mask_dict = {env_id: action_mask_dict[env_id] for env_id in ready_env_id} + to_play_dict = {env_id: to_play_dict[env_id] for env_id in ready_env_id} + action_mask = [action_mask_dict[env_id] for env_id in ready_env_id] + to_play = [to_play_dict[env_id] for env_id in ready_env_id] + + stack_obs = to_ndarray(stack_obs) + + # stack_obs = prepare_observation(stack_obs, self.policy_config.model.model_type) + + # stack_obs = torch.from_numpy(stack_obs).to(self.policy_config.device).float() + + # ============================================================== + # policy forward + # ============================================================== + policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play) + + actions_no_env_id = {k: v['action'] for k, v in policy_output.items()} + distributions_dict_no_env_id = {k: v['distributions'] for k, v in policy_output.items()} + if self.policy_config.sampled_algo: + root_sampled_actions_dict_no_env_id = { + k: v['root_sampled_actions'] + for k, v in policy_output.items() + } + value_dict_no_env_id = {k: v['value'] for k, v in policy_output.items()} + pred_value_dict_no_env_id = {k: v['pred_value'] for k, v in policy_output.items()} + visit_entropy_dict_no_env_id = { + k: v['visit_count_distribution_entropy'] + for k, v in policy_output.items() + } + + # TODO(pu): subprocess + actions = {} + distributions_dict = {} + if self.policy_config.sampled_algo: + root_sampled_actions_dict = {} + value_dict = {} + pred_value_dict = {} + visit_entropy_dict = {} + for index, env_id in enumerate(ready_env_id): + actions[env_id] = actions_no_env_id.pop(index) + distributions_dict[env_id] = distributions_dict_no_env_id.pop(index) + if self.policy_config.sampled_algo: + root_sampled_actions_dict[env_id] = root_sampled_actions_dict_no_env_id.pop(index) + value_dict[env_id] = value_dict_no_env_id.pop(index) + pred_value_dict[env_id] = pred_value_dict_no_env_id.pop(index) + visit_entropy_dict[env_id] = visit_entropy_dict_no_env_id.pop(index) + + # ============================================================== + # Interact with env. + # ============================================================== + timesteps = self._env.step(actions) + + interaction_duration = self._timer.value / len(timesteps) + for env_id, timestep in timesteps.items(): + with self._timer: + if timestep.info.get('abnormal', False): + # If there is an abnormal timestep, reset all the related variables(including this env). 
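+ # (an abnormal timestep is reported by the env manager, typically when a sub-environment
+ # crashes or times out, so its data is discarded)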
+ # suppose there is no reset param, just reset this env
+ self._env.reset({env_id: None})
+ self._policy.reset([env_id])
+ self._reset_stat(env_id)
+ self._logger.info('Env{} returns an abnormal step, its info is {}'.format(env_id, timestep.info))
+ continue
+ obs, reward, done, info = timestep.obs, timestep.reward, timestep.done, timestep.info
+
+ if self.policy_config.sampled_algo:
+ for agent_id in range(agent_num):
+ game_segments[env_id][agent_id].store_search_stats(
+ distributions_dict[env_id][agent_id], value_dict[env_id][agent_id], root_sampled_actions_dict[env_id][agent_id])
+ else:
+ for agent_id in range(agent_num):
+ game_segments[env_id][agent_id].store_search_stats(distributions_dict[env_id][agent_id], value_dict[env_id][agent_id])
+ # append a transition tuple, including a_t, o_{t+1}, r_{t}, action_mask_{t}, to_play_{t}
+ # in ``game_segments[env_id].init``, we have appended o_{t} to ``self.obs_segment``
+ for agent_id in range(agent_num):
+ game_segments[env_id][agent_id].append(
+ actions[env_id][agent_id], to_ndarray(obs['observation'][agent_id]), reward[agent_id], action_mask_dict[env_id][agent_id],
+ to_play_dict[env_id]
+ )
+
+ # NOTE: the position of this code snippet is very important.
+ # the obs['action_mask'] and obs['to_play'] correspond to the next action
+ action_mask_dict[env_id] = to_ndarray(obs['action_mask'])
+ to_play_dict[env_id] = to_ndarray(obs['to_play'])
+ dones[env_id] = done
+ for agent_id in range(agent_num):
+ visit_entropies_lst[env_id][agent_id] += visit_entropy_dict[env_id][agent_id]
+ eps_steps_lst[env_id] += 1
+ total_transitions += 1
+
+ if self.policy_config.use_priority and not self.policy_config.use_max_priority_for_new_data:
+ for agent_id in range(agent_num):
+ pred_values_lst[env_id][agent_id].append(pred_value_dict[env_id][agent_id])
+ search_values_lst[env_id][agent_id].append(value_dict[env_id][agent_id])
+
+ # append the newest obs
+ for agent_id in range(agent_num):
+ observation_window_stack[env_id][agent_id].append(to_ndarray(obs['observation'][agent_id]))
+
+ # ==============================================================
+ # we will save a game block if it is the end of the game or the next game block is finished.
+ # ============================================================== + + # if game block is full, we will save the last game block + for agent_id in range(agent_num): + if game_segments[env_id][agent_id].is_full(): + # pad over last block trajectory + if last_game_segments[env_id][agent_id] is not None: + # TODO(pu): return the one game block + self.pad_and_save_last_trajectory( + env_id, agent_id, last_game_segments, last_game_priorities, game_segments, dones + ) + + # calculate priority + priorities = self._compute_priorities(env_id, agent_id, pred_values_lst, search_values_lst) + # pred_values_lst[env_id] = [] + # search_values_lst[env_id] = [] + search_values_lst = [[[] for _ in range(agent_num)] for _ in range(env_nums)] + pred_values_lst = [[[] for _ in range(agent_num)] for _ in range(env_nums)] + + # the current game_segments become last_game_segment + last_game_segments[env_id][agent_id] = game_segments[env_id][agent_id] + last_game_priorities[env_id][agent_id] = priorities + + # create new GameSegment + game_segments[env_id][agent_id] = GameSegment( + self._env.action_space, + game_segment_length=self.policy_config.game_segment_length, + config=self.policy_config + ) + game_segments[env_id][agent_id].reset(observation_window_stack[env_id][agent_id]) + + self._env_info[env_id]['step'] += 1 + self._total_envstep_count += 1 + self._env_info[env_id]['time'] += self._timer.value + interaction_duration + if timestep.done: + self._total_episode_count += 1 + reward = timestep.info['eval_episode_return'][0] + info = { + 'reward': reward, + 'time': self._env_info[env_id]['time'], + 'step': self._env_info[env_id]['step'], + 'visit_entropy': visit_entropies_lst[env_id] / eps_steps_lst[env_id], + } + collected_episode += 1 + self._episode_info.append(info) + + # ============================================================== + # if it is the end of the game, we will save the game block + # ============================================================== + + # NOTE: put the penultimate game block in one episode into the trajectory_pool + # pad over 2th last game_segment using the last game_segment + for agent_id in range(agent_num): + if last_game_segments[env_id] is not None: + self.pad_and_save_last_trajectory( + env_id, agent_id, last_game_segments, last_game_priorities, game_segments, dones + ) + + # store current block trajectory + priorities = self._compute_priorities(env_id, agent_id, pred_values_lst, search_values_lst) + + # NOTE: put the last game block in one episode into the trajectory_pool + game_segments[env_id][agent_id].game_segment_to_array() + + # assert len(game_segments[env_id]) == len(priorities) + # NOTE: save the last game block in one episode into the trajectory_pool if it's not null + if len(game_segments[env_id][agent_id].reward_segment) != 0: + self.game_segment_pool.append((game_segments[env_id][agent_id], priorities, dones[env_id])) + + # print(game_segments[env_id].reward_segment) + # reset the finished env and init game_segments + if n_episode > self._env_num: + # Get current ready env obs. + init_obs = self._env.ready_obs + retry_waiting_time = 0.001 + while len(init_obs.keys()) != self._env_num: + # In order to be compatible with subprocess env_manager, in which sometimes self._env_num is not equal to + # len(self._env.ready_obs), especially in tictactoe env. 
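+ # Keep polling ``ready_obs`` (sleeping ``retry_waiting_time`` seconds per retry) until every
+ # sub-environment has reported its reset observation.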
+ self._logger.info('The current init_obs.keys() is {}'.format(init_obs.keys())) + self._logger.info('Before sleeping, the _env_states is {}'.format(self._env._env_states)) + time.sleep(retry_waiting_time) + self._logger.info( + '=' * 10 + 'Wait for all environments (subprocess) to finish resetting.' + '=' * 10 + ) + self._logger.info( + 'After sleeping {}s, the current _env_states is {}'.format( + retry_waiting_time, self._env._env_states + ) + ) + init_obs = self._env.ready_obs + + new_available_env_id = set(init_obs.keys()).difference(ready_env_id) + ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode])) + remain_episode -= min(len(new_available_env_id), remain_episode) + + action_mask_dict[env_id] = to_ndarray(init_obs[env_id]['action_mask']) + to_play_dict[env_id] = to_ndarray(init_obs[env_id]['to_play']) + agent_num = len(init_obs[0]['action_mask']) + for agent_id in range(agent_num): + game_segments[env_id][agent_id] = GameSegment( + self._env.action_space, + game_segment_length=self.policy_config.game_segment_length, + config=self.policy_config + ) + observation_window_stack[env_id][agent_id] = deque( + [init_obs[env_id]['observation'][agent_id] for _ in range(self.policy_config.model.frame_stack_num)], + maxlen=self.policy_config.model.frame_stack_num + ) + game_segments[env_id][agent_id].reset(observation_window_stack[env_id][agent_id]) + last_game_segments[env_id] = [None for _ in range(agent_num)] + last_game_priorities[env_id] = [None for _ in range(agent_num)] + + # log + self_play_moves_max = max(self_play_moves_max, eps_steps_lst[env_id]) + self_play_visit_entropy.append(visit_entropies_lst[env_id] / eps_steps_lst[env_id]) + self_play_moves += eps_steps_lst[env_id] + self_play_episodes += 1 + + pred_values_lst[env_id] = [] + search_values_lst[env_id] = [] + eps_steps_lst[env_id] = 0 + visit_entropies_lst[env_id] = 0 + + # Env reset is done by env_manager automatically + self._policy.reset([env_id]) + self._reset_stat(env_id) + # TODO(pu): subprocess mode, when n_episode > self._env_num, occasionally the ready_env_id=() + # and the stack_obs is np.array(None, dtype=object) + ready_env_id.remove(env_id) + + if collected_episode >= n_episode: + # [data, meta_data] + return_data = [self.game_segment_pool[i][0] for i in range(len(self.game_segment_pool))], [ + { + 'priorities': self.game_segment_pool[i][1], + 'done': self.game_segment_pool[i][2], + 'unroll_plus_td_steps': self.unroll_plus_td_steps + } for i in range(len(self.game_segment_pool)) + ] + # for i in range(len(self.game_segment_pool)): + # print(self.game_segment_pool[i][0].obs_segment.__len__()) + # print(self.game_segment_pool[i][0].reward_segment) + # for i in range(len(return_data[0])): + # print(return_data[0][i].reward_segment) + break + # log + self._output_log(train_iter) + return return_data + + def _output_log(self, train_iter: int) -> None: + """ + Overview: + Print the output log information. You can refer to Docs/Best Practice/How to understand\ + training generated folders/Serial mode/log/collector for more details. + Arguments: + - train_iter (:obj:`int`): the number of training iteration. 
+ """ + if (train_iter - self._last_train_iter) >= self._collect_print_freq and len(self._episode_info) > 0: + self._last_train_iter = train_iter + episode_count = len(self._episode_info) + envstep_count = sum([d['step'] for d in self._episode_info]) + duration = sum([d['time'] for d in self._episode_info]) + episode_reward = [d['reward'] for d in self._episode_info] + visit_entropy = [d['visit_entropy'][0] for d in self._episode_info] + self._total_duration += duration + info = { + 'episode_count': episode_count, + 'envstep_count': envstep_count, + 'avg_envstep_per_episode': envstep_count / episode_count, + 'avg_envstep_per_sec': envstep_count / duration, + 'avg_episode_per_sec': episode_count / duration, + 'collect_time': duration, + 'reward_mean': np.mean(episode_reward), + 'reward_std': np.std(episode_reward), + 'reward_max': np.max(episode_reward), + 'reward_min': np.min(episode_reward), + 'total_envstep_count': self._total_envstep_count, + 'total_episode_count': self._total_episode_count, + 'total_duration': self._total_duration, + 'visit_entropy': np.mean(visit_entropy), + # 'each_reward': episode_reward, + } + self._episode_info.clear() + self._logger.info("collect end:\n{}".format('\n'.join(['{}: {}'.format(k, v) for k, v in info.items()]))) + for k, v in info.items(): + if k in ['each_reward']: + continue + self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter) + if k in ['total_envstep_count']: + continue + self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, self._total_envstep_count) diff --git a/lzero/worker/gobigger_muzero_evaluator.py b/lzero/worker/gobigger_muzero_evaluator.py new file mode 100644 index 000000000..f6a6aeb8b --- /dev/null +++ b/lzero/worker/gobigger_muzero_evaluator.py @@ -0,0 +1,434 @@ +import time +import copy +from collections import namedtuple +from typing import Optional, Callable, Tuple + +import numpy as np +import torch +from easydict import EasyDict + +from ding.envs import BaseEnvManager +from ding.torch_utils import to_ndarray +from ding.utils import build_logger, EasyTimer +from ding.worker.collector.base_serial_evaluator import ISerialEvaluator, VectorEvalMonitor +from lzero.mcts.buffer.game_segment import GameSegment +from lzero.mcts.utils import prepare_observation + + +class GoBiggerMuZeroEvaluator(ISerialEvaluator): + """ + Overview: + The Evaluator for MCTS+RL algorithms, including MuZero, EfficientZero, Sampled EfficientZero. + Interfaces: + __init__, reset, reset_policy, reset_env, close, should_eval, eval + Property: + env, policy + """ + + @classmethod + def default_config(cls: type) -> EasyDict: + """ + Overview: + Get evaluator's default config. We merge evaluator's default config with other default configs\ + and user's config to get the final config. + Return: + cfg (:obj:`EasyDict`): evaluator's default config + """ + cfg = EasyDict(copy.deepcopy(cls.config)) + cfg.cfg_type = cls.__name__ + 'Dict' + return cfg + + config = dict( + # Evaluate every "eval_freq" training iterations. + eval_freq=50, + ) + + def __init__( + self, + eval_freq: int = 1000, + n_evaluator_episode: int = 3, + stop_value: int = 1e6, + env: BaseEnvManager = None, + policy: namedtuple = None, + tb_logger: 'SummaryWriter' = None, # noqa + exp_name: Optional[str] = 'default_experiment', + instance_name: Optional[str] = 'evaluator', + policy_config: 'policy_config' = None, # noqa + ) -> None: + """ + Overview: + Init method. 
Load config and use ``self._cfg`` setting to build common serial evaluator components, + e.g. logger helper, timer. + Arguments: + - eval_freq (:obj:`int`): evaluation frequency in terms of training steps. + - n_evaluator_episode (:obj:`int`): the number of episodes to eval in total. + - env (:obj:`BaseEnvManager`): the subclass of vectorized env_manager(BaseEnvManager) + - policy (:obj:`namedtuple`): the api namedtuple of collect_mode policy + - tb_logger (:obj:`SummaryWriter`): tensorboard handle + - exp_name (:obj:`str`): Experiment name, which is used to indicate output directory. + - instance_name (:obj:`Optional[str]`): Name of this instance. + - policy_config: Config of game. + """ + self._eval_freq = eval_freq + self._exp_name = exp_name + self._instance_name = instance_name + if tb_logger is not None: + self._logger, _ = build_logger( + path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name, need_tb=False + ) + self._tb_logger = tb_logger + else: + self._logger, self._tb_logger = build_logger( + path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name + ) + self.reset(policy, env) + + self._timer = EasyTimer() + self._default_n_episode = n_evaluator_episode + self._stop_value = stop_value + + # ============================================================== + # MCTS+RL related core code + # ============================================================== + self.policy_config = policy_config + + def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None: + """ + Overview: + Reset evaluator's environment. In some case, we need evaluator use the same policy in different \ + environments. We can use reset_env to reset the environment. + If _env is None, reset the old environment. + If _env is not None, replace the old environment in the evaluator with the \ + new passed in environment and launch. + Arguments: + - env (:obj:`Optional[BaseEnvManager]`): instance of the subclass of vectorized \ + env_manager(BaseEnvManager) + """ + if _env is not None: + self._env = _env + self._env.launch() + self._env_num = self._env.env_num + else: + self._env.reset() + + def reset_policy(self, _policy: Optional[namedtuple] = None) -> None: + """ + Overview: + Reset evaluator's policy. In some case, we need evaluator work in this same environment but use\ + different policy. We can use reset_policy to reset the policy. + If _policy is None, reset the old policy. + If _policy is not None, replace the old policy in the evaluator with the new passed in policy. + Arguments: + - policy (:obj:`Optional[namedtuple]`): the api namedtuple of eval_mode policy + """ + assert hasattr(self, '_env'), "please set env first" + if _policy is not None: + self._policy = _policy + self._policy.reset() + + def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvManager] = None) -> None: + """ + Overview: + Reset evaluator's policy and environment. Use new policy and environment to collect data. + If _env is None, reset the old environment. + If _env is not None, replace the old environment in the evaluator with the new passed in \ + environment and launch. + If _policy is None, reset the old policy. + If _policy is not None, replace the old policy in the evaluator with the new passed in policy. 
+ Arguments: + - policy (:obj:`Optional[namedtuple]`): the api namedtuple of eval_mode policy + - env (:obj:`Optional[BaseEnvManager]`): instance of the subclass of vectorized \ + env_manager(BaseEnvManager) + """ + if _env is not None: + self.reset_env(_env) + if _policy is not None: + self.reset_policy(_policy) + self._max_eval_reward = float("-inf") + self._last_eval_iter = 0 + self._end_flag = False + + def close(self) -> None: + """ + Overview: + Close the evaluator. If end_flag is False, close the environment, flush the tb_logger\ + and close the tb_logger. + """ + if self._end_flag: + return + self._end_flag = True + self._env.close() + self._tb_logger.flush() + self._tb_logger.close() + + def __del__(self): + """ + Overview: + Execute the close command and close the evaluator. __del__ is automatically called \ + to destroy the evaluator instance when the evaluator finishes its work + """ + self.close() + + def should_eval(self, train_iter: int) -> bool: + """ + Overview: + Determine whether you need to start the evaluation mode, if the number of training has reached\ + the maximum number of times to start the evaluator, return True + Arguments: + - train_iter (:obj:`int`): Current training iteration. + """ + if train_iter == self._last_eval_iter: + return False + if (train_iter - self._last_eval_iter) < self._eval_freq and train_iter != 0: + return False + self._last_eval_iter = train_iter + return True + + def eval( + self, + save_ckpt_fn: Callable = None, + train_iter: int = -1, + envstep: int = -1, + n_episode: Optional[int] = None, + ) -> Tuple[bool, float]: + """ + Overview: + Evaluate policy and store the best policy based on whether it reaches the highest historical reward. + Arguments: + - save_ckpt_fn (:obj:`Callable`): Saving ckpt function, which will be triggered by getting the best reward. + - train_iter (:obj:`int`): Current training iteration. + - envstep (:obj:`int`): Current env interaction step. + - n_episode (:obj:`int`): Number of evaluation episodes. + Returns: + - stop_flag (:obj:`bool`): Whether this training program can be ended. + - eval_reward (:obj:`float`): Current eval_reward. + """ + if n_episode is None: + n_episode = self._default_n_episode + assert n_episode is not None, "please indicate eval n_episode" + envstep_count = 0 + eval_monitor = VectorEvalMonitor(self._env.env_num, n_episode) + env_nums = self._env.env_num + + self._env.reset() + self._policy.reset() + + # initializations + init_obs = self._env.ready_obs + + retry_waiting_time = 0.001 + while len(init_obs.keys()) != self._env_num: + # In order to be compatible with subprocess env_manager, in which sometimes self._env_num is not equal to + # len(self._env.ready_obs), especially in tictactoe env. + self._logger.info('The current init_obs.keys() is {}'.format(init_obs.keys())) + self._logger.info('Before sleeping, the _env_states is {}'.format(self._env._env_states)) + time.sleep(retry_waiting_time) + self._logger.info('=' * 10 + 'Wait for all environments (subprocess) to finish resetting.' 
+ '=' * 10) + self._logger.info( + 'After sleeping {}s, the current _env_states is {}'.format(retry_waiting_time, self._env._env_states) + ) + init_obs = self._env.ready_obs + + action_mask_dict = {i: to_ndarray(init_obs[i]['action_mask']) for i in range(env_nums)} + + to_play_dict = {i: to_ndarray(init_obs[i]['to_play']) for i in range(env_nums)} + dones = np.array([False for _ in range(env_nums)]) + + game_segments = [ + GameSegment( + self._env.action_space, + game_segment_length=self.policy_config.game_segment_length, + config=self.policy_config + ) for _ in range(env_nums) + ] + for i in range(env_nums): + game_segments[i].reset( + [to_ndarray(init_obs[i]['observation']) for _ in range(self.policy_config.model.frame_stack_num)] + ) + + ready_env_id = set() + remain_episode = n_episode + + with self._timer: + while not eval_monitor.is_finished(): + # Get current ready env obs. + obs = self._env.ready_obs + new_available_env_id = set(obs.keys()).difference(ready_env_id) + ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode])) + remain_episode -= min(len(new_available_env_id), remain_episode) + + stack_obs = {env_id: game_segments[env_id].get_obs() for env_id in ready_env_id} + stack_obs = list(stack_obs.values()) + + action_mask_dict = {env_id: action_mask_dict[env_id] for env_id in ready_env_id} + to_play_dict = {env_id: to_play_dict[env_id] for env_id in ready_env_id} + action_mask = [action_mask_dict[env_id] for env_id in ready_env_id] + to_play = [to_play_dict[env_id] for env_id in ready_env_id] + + stack_obs = to_ndarray(stack_obs) + # stack_obs = prepare_observation(stack_obs, self.policy_config.model.model_type) + # stack_obs = torch.from_numpy(stack_obs).to(self.policy_config.device).float() + + # ============================================================== + # policy forward + # ============================================================== + policy_output = self._policy.forward(stack_obs, action_mask, to_play) + + actions_no_env_id = {k: v['action'] for k, v in policy_output.items()} + distributions_dict_no_env_id = {k: v['distributions'] for k, v in policy_output.items()} + if self.policy_config.sampled_algo: + root_sampled_actions_dict_no_env_id = { + k: v['root_sampled_actions'] + for k, v in policy_output.items() + } + + value_dict_no_env_id = {k: v['value'] for k, v in policy_output.items()} + pred_value_dict_no_env_id = {k: v['pred_value'] for k, v in policy_output.items()} + visit_entropy_dict_no_env_id = { + k: v['visit_count_distribution_entropy'] + for k, v in policy_output.items() + } + + actions = {} + distributions_dict = {} + if self.policy_config.sampled_algo: + root_sampled_actions_dict = {} + value_dict = {} + pred_value_dict = {} + visit_entropy_dict = {} + for index, env_id in enumerate(ready_env_id): + actions[env_id] = actions_no_env_id.pop(index) + distributions_dict[env_id] = distributions_dict_no_env_id.pop(index) + if self.policy_config.sampled_algo: + root_sampled_actions_dict[env_id] = root_sampled_actions_dict_no_env_id.pop(index) + value_dict[env_id] = value_dict_no_env_id.pop(index) + pred_value_dict[env_id] = pred_value_dict_no_env_id.pop(index) + visit_entropy_dict[env_id] = visit_entropy_dict_no_env_id.pop(index) + + # ============================================================== + # Interact with env. 
+ # ============================================================== + timesteps = self._env.step(actions) + + for env_id, t in timesteps.items(): + obs, reward, done, info = t.obs, t.reward, t.done, t.info + + game_segments[env_id].append( + actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id], + to_play_dict[env_id] + ) + + # NOTE: in evaluator, we only need save the ``o_{t+1} = obs['observation']`` + # game_segments[env_id].obs_segment.append(to_ndarray(obs['observation'])) + + # NOTE: the position of code snippet is very important. + # the obs['action_mask'] and obs['to_play'] is corresponding to next action + action_mask_dict[env_id] = to_ndarray(obs['action_mask']) + to_play_dict[env_id] = to_ndarray(obs['to_play']) + + dones[env_id] = done + if t.done: + # Env reset is done by env_manager automatically. + self._policy.reset([env_id]) + reward = t.info['eval_episode_return'][0] + if 'episode_info' in t.info: + eval_monitor.update_info(env_id, t.info['episode_info']) + eval_monitor.update_reward(env_id, reward) + self._logger.info( + "[EVALUATOR]env {} finish episode, final reward: {}, current episode: {}".format( + env_id, eval_monitor.get_latest_reward(env_id), eval_monitor.get_current_episode() + ) + ) + + # reset the finished env and init game_segments + if n_episode > self._env_num: + # Get current ready env obs. + init_obs = self._env.ready_obs + retry_waiting_time = 0.001 + while len(init_obs.keys()) != self._env_num: + # In order to be compatible with subprocess env_manager, in which sometimes self._env_num is not equal to + # len(self._env.ready_obs), especially in tictactoe env. + self._logger.info('The current init_obs.keys() is {}'.format(init_obs.keys())) + self._logger.info( + 'Before sleeping, the _env_states is {}'.format(self._env._env_states) + ) + time.sleep(retry_waiting_time) + self._logger.info( + '=' * 10 + 'Wait for all environments (subprocess) to finish resetting.' + '=' * 10 + ) + self._logger.info( + 'After sleeping {}s, the current _env_states is {}'.format( + retry_waiting_time, self._env._env_states + ) + ) + init_obs = self._env.ready_obs + + new_available_env_id = set(init_obs.keys()).difference(ready_env_id) + ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode])) + remain_episode -= min(len(new_available_env_id), remain_episode) + + action_mask_dict[env_id] = to_ndarray(init_obs[env_id]['action_mask']) + to_play_dict[env_id] = to_ndarray(init_obs[env_id]['to_play']) + + game_segments[env_id] = GameSegment( + self._env.action_space, + game_segment_length=self.policy_config.game_segment_length, + config=self.policy_config + ) + + game_segments[env_id].reset( + [ + init_obs[env_id]['observation'] + for _ in range(self.policy_config.model.frame_stack_num) + ] + ) + + # Env reset is done by env_manager automatically. 
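+ # Only the policy's per-env state has to be reset explicitly here; the finished environment
+ # itself is re-launched by the env manager, as noted above.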
+ self._policy.reset([env_id]) + # TODO(pu): subprocess mode, when n_episode > self._env_num, occasionally the ready_env_id=() + # and the stack_obs is np.array(None, dtype=object) + ready_env_id.remove(env_id) + + envstep_count += 1 + duration = self._timer.value + episode_return = eval_monitor.get_episode_return() + info = { + 'train_iter': train_iter, + 'ckpt_name': 'iteration_{}.pth.tar'.format(train_iter), + 'episode_count': n_episode, + 'envstep_count': envstep_count, + 'avg_envstep_per_episode': envstep_count / n_episode, + 'evaluate_time': duration, + 'avg_envstep_per_sec': envstep_count / duration, + 'avg_time_per_episode': n_episode / duration, + 'reward_mean': np.mean(episode_return), + 'reward_std': np.std(episode_return), + 'reward_max': np.max(episode_return), + 'reward_min': np.min(episode_return), + # 'each_reward': episode_return, + } + episode_info = eval_monitor.get_episode_info() + if episode_info is not None: + info.update(episode_info) + self._logger.info(self._logger.get_tabulate_vars_hor(info)) + # self._logger.info(self._logger.get_tabulate_vars(info)) + for k, v in info.items(): + if k in ['train_iter', 'ckpt_name', 'each_reward']: + continue + if not np.isscalar(v): + continue + self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter) + self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, envstep) + eval_reward = np.mean(episode_return) + if eval_reward > self._max_eval_reward: + if save_ckpt_fn: + save_ckpt_fn('ckpt_best.pth.tar') + self._max_eval_reward = eval_reward + stop_flag = eval_reward >= self._stop_value and train_iter > 0 + if stop_flag: + self._logger.info( + "[LightZero serial pipeline] " + + "Current eval_reward: {} is greater than stop_value: {}".format(eval_reward, self._stop_value) + + ", so your MCTS/RL agent is converged, you can refer to 'log/evaluator/evaluator_logger.txt' for details." + ) + return stop_flag, eval_reward diff --git a/zoo/gobigger/config/gobigger_efficientzero_config.py b/zoo/gobigger/config/gobigger_efficientzero_config.py new file mode 100644 index 000000000..189f60e50 --- /dev/null +++ b/zoo/gobigger/config/gobigger_efficientzero_config.py @@ -0,0 +1,141 @@ +from easydict import EasyDict + +env_name = 'GoBigger' + +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +collector_env_num = 16 +n_episode = 16 +evaluator_env_num = 5 +num_simulations = 50 +update_per_collect = 1000 +batch_size = 256 +reanalyze_ratio = 0. 
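+# Illustrative consistency check (a sketch based on the collector and env code in this patch,
+# not required by LightZero itself): the GoBigger env builds its discrete action set from
+# ``direction_num`` slow moves, ``direction_num`` fast moves and three special actions
+# (no-op, eject, split), so ``action_space_size`` below should equal ``2 * direction_num + 3``
+# (2 * 12 + 3 = 27); the collector additionally asserts ``n_episode >= collector_env_num``.
+assert n_episode >= collector_env_num
+assert 0. <= reanalyze_ratio <= 1.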
+action_space_size = 27 +direction_num=12 +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +atari_efficientzero_config = dict( + exp_name= + f'data_ez_ctree/{env_name}_efficientzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', + env=dict( + env_name=env_name, + team_num=2, + player_num_per_team=2, + direction_num=direction_num, + step_mul=8, + map_width=64, + map_height=64, + frame_limit=3600, + action_space_size=action_space_size, + use_action_mask=False, + reward_div_value=0.1, + reward_type='log_reward', + start_spirit_progress=0.2, + end_spirit_progress=0.8, + manager_settings=dict( + food_manager=dict( + num_init=260, + num_min=260, + num_max=300, + ), + thorns_manager=dict( + num_init=3, + num_min=3, + num_max=4, + ), + player_manager=dict( + ball_settings=dict( + score_init=13000, + ), + ), + ), + playback_settings=dict( + playback_type='by_frame', + by_frame=dict( + save_frame=False, + # save_frame=True, + save_dir='./', + save_name_prefix='gobigger', + ), + ), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + model=dict( + # observation_shape=(4, 96, 96), + latent_state_dim=176, + frame_stack_num=1, + action_space_size=action_space_size, + downsample=True, + discrete_action_encoding_type='one_hot', + norm_type='BN', + ), + cuda=True, + mcts_ctree=True, + env_type='not_board_games', + game_segment_length=400, + use_augmentation=False, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='SGD', + lr_piecewise_constant_decay=True, + learning_rate=0.2, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + eval_freq=int(2e3), + replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. 
+ collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), + learn=dict( + learner=dict( + log_policy=True, + hook=dict( + log_show_after_iter=10, + ), + ), + ), + collect=dict( + collector=dict( + collect_print_freq=10, + ), + ), + eval=dict( + evaluator=dict( + eval_freq=5000, + stop_value=10000000000, + ), + ), +) +atari_efficientzero_config = EasyDict(atari_efficientzero_config) +main_config = atari_efficientzero_config + +atari_efficientzero_create_config = dict( + env=dict( + type='gobigger_lightzero', + import_names=['zoo.gobigger.env.gobigger_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='gobigger_efficientzero', + import_names=['lzero.policy.gobigger_efficientzero'], + ), + collector=dict( + type='gobigger_episode_muzero', + import_names=['lzero.worker.gobigger_muzero_collector'], + ) +) +atari_efficientzero_create_config = EasyDict(atari_efficientzero_create_config) +create_config = atari_efficientzero_create_config + +if __name__ == "__main__": + from lzero.entry import train_muzero_gobigger + train_muzero_gobigger([main_config, create_config], seed=0) diff --git a/zoo/gobigger/env/gobigger_env.py b/zoo/gobigger/env/gobigger_env.py new file mode 100644 index 000000000..46168b8ae --- /dev/null +++ b/zoo/gobigger/env/gobigger_env.py @@ -0,0 +1,507 @@ +import gym +import numpy as np +from ding.envs import BaseEnv, BaseEnvTimestep +from ding.utils import ENV_REGISTRY +from gobigger.envs import GoBiggerEnv +import math + + +@ENV_REGISTRY.register('gobigger_lightzero') +class GoBiggerLightZeroEnv(BaseEnv): + + def __init__(self, cfg: dict) -> None: + self._cfg = cfg + self._init_flag = False + self._observation_space = None + self._action_space = None + self._reward_space = None + self.team_num = self._cfg.team_num + self.player_num_per_team = self._cfg.player_num_per_team + self.direction_num = self._cfg.direction_num + self.use_action_mask = self._cfg.use_action_mask + self.action_space_size = self._cfg.action_space_size + self.max_player_num = self.player_num_per_team + self.step_mul = self._cfg.get('step_mul', 8) + self.setup_action() + # feature engineering + self.second_per_frame = 0.05 + self.spatial_x = 64 + self.spatial_y = 64 + self.max_ball_num = 80 + self.max_food_num = 256 + self.max_spore_num = 64 + self.reward_div_value = self._cfg.reward_div_value + self.reward_type = self._cfg.reward_type + self.player_init_score = self._cfg.manager_settings.player_manager.ball_settings.score_init + self.start_spirit_progress = self._cfg.start_spirit_progress + self.end_spirit_progress = self._cfg.end_spirit_progress + + def reset(self) -> np.ndarray: + if not self._init_flag: + self._env = GoBiggerEnv(self._cfg, step_mul=self.step_mul) + self._init_flag = True + self.last_action_types = {player_id: self.direction_num * 2 for player_id in range(self.player_num_per_team * self.team_num)} + raw_obs = self._env.reset() + self.eval_episode_return = [[] for _ in range(self.team_num)] + obs = self.observation(raw_obs) + self.last_action_types = {player_id: self.direction_num * 2 for player_id in + range(self.player_num_per_team * self.team_num)} + self.last_leaderboard = {team_idx: self.player_init_score * self.player_num_per_team for team_idx in range(self.team_num)} + self.last_player_scores = {player_id: self.player_init_score for player_id in range(self.player_num_per_team * self.team_num)} + return obs + + def observation(self, raw_obs): + obs = self.preprocess_obs(raw_obs) + action_mask = [np.logical_not(o['action_mask']) for o in 
obs] + to_play = [ -1 for _ in range(len(obs))] + obs = {'observation': obs, 'action_mask': action_mask, 'to_play': to_play} + return obs + + def step(self, action: list) -> BaseEnvTimestep: + action = {i: self.transform_action(a) for i, a in enumerate(action)} + raw_obs, raw_rew, done, info = self._env.step(action) + # print('current_frame={}'.format(raw_obs[0]['last_time'])) + # print('raw_rew={}'.format(raw_rew)) + # print('action={}'.format(action)) + # print('raw_rew={}, done={}'.format(raw_rew, done)) + rew = self.transform_reward(raw_obs) + for i in range(self.team_num): + self.eval_episode_return[i].append(raw_obs[0]['leaderboard'][i]) + obs = self.observation(raw_obs) + if done: + info['eval_episode_return'] = [np.mean(self.eval_episode_return[i]) for i in range(self.team_num)] + return BaseEnvTimestep(obs, rew, done, info) + + def seed(self, seed: int, dynamic_seed: bool = True) -> None: + self._seed = seed + self._dynamic_seed = dynamic_seed + np.random.seed(self._seed) + + def close(self) -> None: + if self._init_flag: + self._env.close() + self._init_flag = False + + @property + def observation_space(self) -> gym.spaces.Space: + return self._observation_space + + @property + def action_space(self) -> gym.spaces.Space: + return self._action_space + + @property + def reward_space(self) -> gym.spaces.Space: + return self._reward_space + + def __repr__(self) -> str: + return "LightZero Env({})".format(self.cfg.env_name) + + def transform_obs(self, obs, game_player_id=1, padding=True, last_action_type=None, ): + global_state, player_observations = obs + player2team = self.get_player2team() + own_player_id = game_player_id + leaderboard = global_state['leaderboard'] + team2rank = {key: rank for rank, key in enumerate(sorted(leaderboard, key=leaderboard.get, reverse=True), )} + + own_player_obs = player_observations[own_player_id] + own_team_id = player2team[own_player_id] + + # =========== + # scalar info + # =========== + scene_size = global_state['border'][0] + own_left_top_x, own_left_top_y, own_right_bottom_x, own_right_bottom_y = own_player_obs['rectangle'] + own_view_center = [(own_left_top_x + own_right_bottom_x - scene_size) / 2, + (own_left_top_y + own_right_bottom_y - scene_size) / 2] + own_view_width = float(own_right_bottom_x - own_left_top_x) + # own_view_height = float(own_right_bottom_y - own_left_top_y) + + own_score = own_player_obs['score'] / 100 + own_team_score = global_state['leaderboard'][own_team_id] / 100 + own_rank = team2rank[own_team_id] + + scalar_info = { + 'view_x': np.round(np.array(own_view_center[0])).astype(np.int64), + 'view_y': np.round(np.array(own_view_center[1])).astype(np.int64), + 'view_width': np.round(np.array(own_view_width)).astype(np.int64), + 'score': np.clip(np.round(np.log(np.array(own_score) / 10)).astype(np.int64), a_min=None, a_max=9), + 'team_score': np.clip(np.round(np.log(np.array(own_team_score / 10))).astype(np.int64), a_min=None, a_max=9), + 'time': np.array(global_state['last_time']//20, dtype=np.int64), + 'rank': np.array(own_rank, dtype=np.int64), + 'last_action_type': np.array(last_action_type, dtype=np.int64) + } + + # =========== + # team_info + # =========== + + all_players = [] + scene_size = global_state['border'][0] + + for game_player_id in player_observations.keys(): + game_team_id = player2team[game_player_id] + game_player_left_top_x, game_player_left_top_y, game_player_right_bottom_x, game_player_right_bottom_y = \ + player_observations[game_player_id]['rectangle'] + if game_player_id == own_player_id: + 
alliance = 0 + elif game_team_id == own_team_id: + alliance = 1 + else: + alliance = 2 + if alliance != 2: + game_player_view_x = (game_player_right_bottom_x + game_player_left_top_x - scene_size) / 2 + game_player_view_y = (game_player_right_bottom_y + game_player_left_top_y - scene_size) / 2 + + all_players.append([alliance, + game_player_view_x, + game_player_view_y, + ]) + + all_players = np.array(all_players) + player_padding_num = self.max_player_num - len(all_players) + player_num = len(all_players) + all_players = np.pad(all_players, pad_width=((0, player_padding_num), (0, 0)), mode='constant') + team_info = { + 'alliance': all_players[:, 0].astype(np.int64), + 'view_x': np.round(all_players[:, 1]).astype(np.int64), + 'view_y': np.round(all_players[:, 2]).astype(np.int64), + 'player_num': np.array(player_num, dtype=np.int64), + } + + # =========== + # ball info + # =========== + ball_type_map = {'clone': 1, 'food': 2, 'thorns': 3, 'spore': 4} + clone = own_player_obs['overlap']['clone'] + thorns = own_player_obs['overlap']['thorns'] + food = own_player_obs['overlap']['food'] + spore = own_player_obs['overlap']['spore'] + + neutral_team_id = self.team_num + neutral_player_id = self.team_num * self.player_num_per_team + neutral_team_rank = self.team_num + + clone = [[ball_type_map['clone'], bl[3], bl[-2], bl[-1], team2rank[bl[-1]], bl[0], bl[1], + *self.next_position(bl[0], bl[1], bl[4], bl[5])] for bl in clone] + thorns = [[ball_type_map['thorns'], bl[3], neutral_player_id, neutral_team_id, neutral_team_rank, bl[0], bl[1], + *self.next_position(bl[0], bl[1], bl[4], bl[5])] for bl in thorns] + food = [ + [ball_type_map['food'], bl[3], neutral_player_id, neutral_team_id, neutral_team_rank, bl[0], bl[1], bl[0], + bl[1]] for bl in food] + + spore = [ + [ball_type_map['spore'], bl[3], bl[-1], player2team[bl[-1]], team2rank[player2team[bl[-1]]], bl[0], + bl[1], + *self.next_position(bl[0], bl[1], bl[4], bl[5])] for bl in spore] + + all_balls = clone + thorns + food + spore + + for b in all_balls: + if b[2] == own_player_id and b[0] == 1: + if b[5] < own_left_top_x or b[5] > own_right_bottom_x or \ + b[6] < own_left_top_y or b[6] > own_right_bottom_y: + b[5] = int((own_left_top_x + own_right_bottom_x) / 2) + b[6] = int((own_left_top_y + own_right_bottom_y) / 2) + b[7], b[8] = b[5], b[6] + all_balls = np.array(all_balls) + + origin_x = own_left_top_x + origin_y = own_left_top_y + + all_balls[:, -4] = ((all_balls[:, -4] - origin_x) / own_view_width * self.spatial_x) + all_balls[:, -3] = ((all_balls[:, -3] - origin_y) / own_view_width * self.spatial_y) + all_balls[:, -2] = ((all_balls[:, -2] - origin_x) / own_view_width * self.spatial_x) + all_balls[:, -1] = ((all_balls[:, -1] - origin_y) / own_view_width * self.spatial_y) + + # ball + ball_indices = np.logical_and(all_balls[:, 0] != 2, all_balls[:, 0] != 4) + balls = all_balls[ball_indices] + + balls_num = len(balls) + + # consider position of thorns ball + if balls_num > self.max_ball_num: # filter small balls + own_indices = balls[:, 3] == own_player_id + teammate_indices = (balls[:, 4] == own_team_id) & ~own_indices + enemy_indices = balls[:, 4] != own_team_id + + own_balls = balls[own_indices] + teammate_balls = balls[teammate_indices] + enemy_balls = balls[enemy_indices] + + if own_balls.shape[0] + teammate_balls.shape[0] >= self.max_ball_num: + remain_ball_num = self.max_ball_num - own_balls.shape[0] + teammate_ball_score = teammate_balls[:, 1] + teammate_high_score_indices = 
teammate_ball_score.sort(descending=True)[1][:remain_ball_num] + teammate_remain_balls = teammate_balls[teammate_high_score_indices] + balls = np.concatenate([own_balls, teammate_remain_balls], axis=0) + else: + remain_ball_num = self.max_ball_num - own_balls.shape[0] - teammate_balls.shape[0] + enemy_ball_score = enemy_balls[:, 1] + enemy_high_score_ball_indices = enemy_ball_score.sort(descending=True)[1][:remain_ball_num] + remain_enemy_balls = enemy_balls[enemy_high_score_ball_indices] + + balls = np.concatenate([own_balls, teammate_balls, remain_enemy_balls], axis=0) + + ball_padding_num = self.max_ball_num - len(balls) + if padding or ball_padding_num < 0: + balls = np.pad(balls, ((0, ball_padding_num), (0, 0)), 'constant', constant_values=0) + alliance = np.zeros(self.max_ball_num) + balls_num = min(self.max_ball_num, balls_num) + else: + alliance = np.zeros(balls_num) + alliance[balls[:, 3] == own_team_id] = 2 + alliance[balls[:, 2] == own_player_id] = 1 + alliance[balls[:, 3] != own_team_id] = 3 + alliance[balls[:, 0] == 3] = 0 + + ## score&radius + scale_score = balls[:, 1] / 100 + radius = np.clip(np.sqrt(scale_score * 0.042 + 0.15) / own_view_width, a_max=1, a_min=0) + score = np.clip(np.round(np.sqrt(scale_score * 0.042 + 0.15) / own_view_width * 50), a_max=49, a_min=0).astype(int) + + ## rank: + ball_rank = balls[:, 4] + + ## coordinate + x = balls[:, -4] - self.spatial_x // 2 + y = balls[:, -3] - self.spatial_y // 2 + next_x = balls[:, -2] - self.spatial_x // 2 + next_y = balls[:, -1] - self.spatial_y // 2 + + ball_info = { + 'alliance': alliance.astype(np.int64), + 'score': score.astype(np.int64), + 'radius': radius, + 'rank': ball_rank.astype(np.int64), + 'x': np.round(x).astype(np.int64), + 'y': np.round(y).astype(np.int64), + 'next_x': np.round(next_x).astype(np.int64), + 'next_y': np.round(next_y).astype(np.int64), + 'ball_num': np.array(balls_num).astype(np.int64), + } + + # ============ + # spatial info + # ============ + # ball coordinate for scatter connection + ball_x = balls[:, -4] + ball_y = balls[:, -3] + + food_indices = all_balls[:, 0] == 2 + food_x = all_balls[food_indices, -4] + food_y = all_balls[food_indices, -3] + food_num = len(food_x) + food_padding_num = self.max_food_num - len(food_x) + if padding or food_padding_num < 0: + food_x = np.pad(food_x, (0, food_padding_num), 'constant', constant_values=0) + food_y = np.pad(food_y, (0, food_padding_num), 'constant', constant_values=0) + food_num = min(food_num, self.max_food_num) + + spore_indices = all_balls[:, 0] == 4 + spore_x = all_balls[spore_indices, -4] + spore_y = all_balls[spore_indices, -3] + spore_num = len(spore_x) + spore_padding_num = self.max_spore_num - len(spore_x) + if padding or spore_padding_num < 0: + spore_x = np.pad(spore_x, (0, spore_padding_num), 'constant', constant_values=0) + spore_y = np.pad(spore_y, (0, spore_padding_num), 'constant', constant_values=0) + spore_num = min(spore_num, self.max_spore_num) + + spatial_info = { + 'food_x': np.clip(np.round(food_x), 0, self.spatial_x - 1).astype(np.int64), + 'food_y': np.clip(np.round(food_y), 0, self.spatial_y - 1).astype(np.int64), + 'spore_x': np.clip(np.round(spore_x), 0, self.spatial_x - 1).astype(np.int64), + 'spore_y': np.clip(np.round(spore_y), 0, self.spatial_y - 1).astype(np.int64), + 'ball_x': np.clip(np.round(ball_x), 0, self.spatial_x - 1).astype(np.int64), + 'ball_y': np.clip(np.round(ball_y), 0, self.spatial_y - 1).astype(np.int64), + 'food_num': np.array(food_num).astype(np.int64), + 'spore_num': 
np.array(spore_num).astype(np.int64), + } + + output_obs = { + 'scalar_info': scalar_info, + 'team_info': team_info, + 'ball_info': ball_info, + 'spatial_info': spatial_info, + } + return output_obs + + def _preprocess_obs(self, raw_obs, env_status=None, eval_vsbot=False): + env_player_obs = [] + game_player_num = self.player_num_per_team if eval_vsbot else self.player_num_per_team * self.team_num + for game_player_id in range(game_player_num): + if env_status is None: + last_action_type = self.direction_num * 2 + else: + last_action_type = self.last_action_types[game_player_id] + if self.use_action_mask: + can_eject = raw_obs[1][game_player_id]['can_eject'] + can_split = raw_obs[1][game_player_id]['can_split'] + action_mask = self.generate_action_mask(can_eject=can_eject, can_split=can_split) + else: + action_mask = self.generate_action_mask(can_eject=True, can_split=True) + game_player_obs = self.transform_obs(raw_obs, game_player_id=game_player_id, padding=True, + last_action_type=last_action_type) + game_player_obs['action_mask'] = action_mask + env_player_obs.append(game_player_obs) + # env_player_obs = default_collate_with_dim(env_player_obs) + return env_player_obs + + # def collate_obs(self, env_player_obs): + # processed_obs_list = [] + # for env_id, env_obs in env_player_obs.items(): + # for game_player_id, game_player_obs in env_obs.items(): + # processed_obs_list.append(game_player_obs) + # obs_batch = default_collate_with_dim(processed_obs_list, device=self.device) + # return obs_batch + + def preprocess_obs(self, obs_list, env_status=None, eval_vsbot=False): + env_player_obs = self._preprocess_obs(obs_list, env_status, eval_vsbot) + return env_player_obs + + def generate_action_mask(self, can_eject, can_split, ): + action_mask = np.zeros((self.action_space_size,), dtype=np.bool_) + if not can_eject: + action_mask[self.direction_num * 2 + 1] = True + if not can_split: + action_mask[self.direction_num * 2 + 2] = True + return action_mask + + def get_player2team(self, ): + player2team = {} + for player_id in range(self.player_num_per_team * self.team_num): + player2team[player_id] = player_id // self.player_num_per_team + return player2team + + def next_position(self, x, y, vel_x, vel_y): + next_x = x + self.second_per_frame * vel_x * self.step_mul + next_y = y + self.second_per_frame * vel_y * self.step_mul + return next_x, next_y + + def transform_action(self, action_idx): + return self.x_y_action_List[int(action_idx)] + + def setup_action(self): + theta = math.pi * 2 / self.direction_num + self.x_y_action_List = [[0.3 * math.cos(theta * i), 0.3 * math.sin(theta * i), 0] for i in + range(self.direction_num)] + \ + [[math.cos(theta * i), math.sin(theta * i), 0] for i in + range(self.direction_num)] + \ + [[0, 0, 0], [0, 0, 1], [0, 0, 2]] + + def get_spirit(self, progress): + if progress < self.start_spirit_progress: + return 0 + elif progress <= self.end_spirit_progress: + spirit = (progress - self.start_spirit_progress) / (self.end_spirit_progress - self.start_spirit_progress) + return spirit + else: + return 1 + + def transform_reward(self, next_obs): + last_time = next_obs[0]['last_time'] + total_frame = next_obs[0]['total_frame'] + progress = last_time / total_frame + spirit = self.get_spirit(progress) + score_rewards_list = [] + for game_player_id in range(self.player_num_per_team * self.team_num): + game_team_id = game_player_id // self.player_num_per_team + player_score = next_obs[1][game_player_id]['score'] + team_score = next_obs[0]['leaderboard'][game_team_id] + if 
self.reward_type == 'log_reward': + player_reward = math.log(player_score) - math.log(self.last_player_scores[game_player_id]) + team_reward = math.log(team_score) - math.log(self.last_leaderboard[game_team_id]) + score_reward = (1 - spirit) * player_reward + spirit * team_reward / self.player_num_per_team + score_reward = score_reward / self.reward_div_value + score_rewards_list.append(score_reward) + elif self.reward_type == 'score': + player_reward = player_score - self.last_player_scores[game_player_id] + team_reward = team_score - self.last_leaderboard[game_team_id] + score_reward = (1 - spirit) * player_reward + spirit * team_reward / self.player_num_per_team + score_reward = score_reward / self.reward_div_value + score_rewards_list.append(score_reward) + elif self.reward_type == 'sqrt_player': + player_reward = player_score - self.last_player_scores[game_player_id] + reward_sign = (player_reward > 0) - (player_reward < 0) # np.sign + score_rewards_list.append(reward_sign * math.sqrt(abs(player_reward)) / 2) + elif self.reward_type == 'sqrt_team': + team_reward = team_score - self.last_leaderboard[game_team_id] + reward_sign = (team_reward > 0) - (team_reward < 0) # np.sign + score_rewards_list.append(reward_sign * math.sqrt(abs(team_reward)) / 2) + else: + raise NotImplementedError + self.last_player_scores[game_player_id] = player_score + self.last_leaderboard = next_obs[0]['leaderboard'] + return score_rewards_list + +if __name__ == '__main__': + from easydict import EasyDict + env_cfg=EasyDict(dict( + env_name='gobigger', + team_num=2, + player_num_per_team=2, + direction_num=12, + step_mul=8, + map_width=64, + map_height=64, + frame_limit=3600, + action_space_size=27, + use_action_mask=False, + reward_div_value=0.1, + reward_type='log_reward', + start_spirit_progress=0.2, + end_spirit_progress=0.8, + manager_settings=dict( + food_manager=dict( + num_init=260, + num_min=260, + num_max=300, + ), + thorns_manager=dict( + num_init=3, + num_min=3, + num_max=4, + ), + player_manager=dict( + ball_settings=dict( + score_init=13000, + ), + ), + ), + playback_settings=dict( + playback_type='by_frame', + by_frame=dict( + # save_frame=False, + save_frame=True, + save_dir='./', + save_name_prefix='gobigger-1', + ), + ), + )) + + env = GoBiggerLightZeroEnv(env_cfg) + env.reset() + import random + while True: + actions = [random.randint(0, 26), random.randint(0, 26), random.randint(0, 26), random.randint(0, 26)] + # actions = {0: [random.uniform(-1, 1), random.uniform(-1, 1), -1], + # 1: [random.uniform(-1, 1), random.uniform(-1, 1), -1], + # 2: [random.uniform(-1, 1), random.uniform(-1, 1), -1], + # 3: [random.uniform(-1, 1), random.uniform(-1, 1), -1]} + timestep = env.step(actions) + if timestep.done: + break + + # from ding.envs import create_env_manager + # from functools import partial + # env_manager=EasyDict({'episode_num': float('inf'), 'max_retry': 1, 'retry_type': 'reset', 'auto_reset': True, + # 'step_timeout': None, 'reset_timeout': None, 'retry_waiting_time': 0.1, 'cfg_type': 'BaseEnvManagerDict', + # 'type': 'base', 'shared_memory': False}) + + # collector_env = create_env_manager(env_manager, [partial(GoBiggerLightZeroEnv, cfg=c) for c in [env_cfg]]) + # collector_env.launch() + # print(collector_env._env_num) + # for i in range(500): + # timestep = collector_env.step({0:[0,0,0,0]}) \ No newline at end of file From 2c29842040745f277543cbfc220b2d86e597b817 Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Thu, 1 Jun 2023 13:22:29 +0800 Subject: [PATCH 02/54] fix(yzj): fix data 
device bug in gobigger ez pipeline --- lzero/mcts/buffer/gobigger_game_buffer_muzero.py | 2 ++ lzero/policy/gobigger_efficientzero.py | 7 +++++++ 2 files changed, 9 insertions(+) diff --git a/lzero/mcts/buffer/gobigger_game_buffer_muzero.py b/lzero/mcts/buffer/gobigger_game_buffer_muzero.py index eb1a435b5..3e30d0656 100644 --- a/lzero/mcts/buffer/gobigger_game_buffer_muzero.py +++ b/lzero/mcts/buffer/gobigger_game_buffer_muzero.py @@ -9,6 +9,7 @@ from lzero.mcts.utils import prepare_observation from lzero.policy import to_detach_cpu_numpy, concat_output, concat_output_value, inverse_scalar_transform from .game_buffer import GameBuffer +from ding.torch_utils import to_device if TYPE_CHECKING: from lzero.policy import MuZeroPolicy, EfficientZeroPolicy, SampledEfficientZeroPolicy @@ -383,6 +384,7 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A m_obs = value_obs_list[beg_index:end_index] m_obs = to_tensor(m_obs) m_obs = sum(m_obs, []) + m_obs = to_device(m_obs, self._cfg.device) # calculate the target value m_output = model.initial_inference(m_obs) diff --git a/lzero/policy/gobigger_efficientzero.py b/lzero/policy/gobigger_efficientzero.py index 89277489f..846300a95 100644 --- a/lzero/policy/gobigger_efficientzero.py +++ b/lzero/policy/gobigger_efficientzero.py @@ -18,6 +18,7 @@ DiscreteSupport, select_action, to_torch_float_tensor, ez_network_output_unpack, negative_cosine_similarity, prepare_obs, \ configure_optimizers from collections import defaultdict +from ding.torch_utils import to_device @POLICY_REGISTRY.register('gobigger_efficientzero') @@ -294,6 +295,8 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: # ============================================================== obs_batch = obs_batch.tolist() obs_batch = sum(obs_batch, []) + obs_batch = to_tensor(obs_batch) + obs_batch = to_device(obs_batch, self._cfg.device) network_output = self._learn_model.initial_inference(obs_batch) # value_prefix shape: (batch_size, 10), the ``value_prefix`` at the first step is zero padding. 
latent_state, value_prefix, reward_hidden_state, value, policy_logits = ez_network_output_unpack(network_output) @@ -370,6 +373,8 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: end_index = step_i + self._cfg.model.frame_stack_num obs_target_batch_tmp = obs_target_batch[:, beg_index:end_index].tolist() obs_target_batch_tmp = sum(obs_target_batch_tmp, []) + obs_target_batch_tmp = to_tensor(obs_target_batch_tmp) + obs_target_batch_tmp = to_device(obs_target_batch_tmp, self._cfg.device) network_output = self._learn_model.initial_inference(obs_target_batch_tmp) latent_state = to_tensor(latent_state) @@ -550,6 +555,7 @@ def _forward_collect( data = to_tensor(data) data = sum(sum(data, []), []) batch_size = len(data) + data = to_device(data, self._cfg.device) agent_num = batch_size // active_collect_env_num to_play = np.array(to_play).reshape(-1).tolist() @@ -668,6 +674,7 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, read data = to_tensor(data) data = sum(sum(data, []), []) batch_size = len(data) + data = to_device(data, self._cfg.device) agent_num = batch_size // active_eval_env_num to_play = np.array(to_play).reshape(-1).tolist() From 335b0fc2a64443343a4d9f85374b29a39e55de13 Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Thu, 1 Jun 2023 22:09:03 +0800 Subject: [PATCH 03/54] feature(yzj): add vsbot with ez pipeline and add eat-info in tensorboard --- lzero/entry/train_muzero_gobigger.py | 4 +- lzero/worker/gobigger_muzero_collector.py | 8 +- lzero/worker/gobigger_muzero_evaluator.py | 385 ++++++++++++++++-- .../config/gobigger_efficientzero_config.py | 7 +- zoo/gobigger/env/gobigger_env.py | 26 +- zoo/gobigger/env/gobigger_rule_bot.py | 211 ++++++++++ 6 files changed, 592 insertions(+), 49 deletions(-) create mode 100644 zoo/gobigger/env/gobigger_rule_bot.py diff --git a/lzero/entry/train_muzero_gobigger.py b/lzero/entry/train_muzero_gobigger.py index 4f44ed49d..20452f245 100644 --- a/lzero/entry/train_muzero_gobigger.py +++ b/lzero/entry/train_muzero_gobigger.py @@ -15,7 +15,6 @@ from lzero.entry.utils import log_buffer_memory_usage from lzero.policy import visit_count_temperature from lzero.worker import GoBiggerMuZeroCollector, GoBiggerMuZeroEvaluator -from gobigger.agents import BotAgent def train_muzero_gobigger( input_cfg: Tuple[dict, dict], @@ -123,11 +122,10 @@ def train_muzero_gobigger( trained_steps=learner.train_iter ) - stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) - # Evaluate policy performance. 
if evaluator.should_eval(learner.train_iter): stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) + stop, reward= evaluator.eval_vsbot(learner.save_checkpoint, learner.train_iter, collector.envstep) if stop: break diff --git a/lzero/worker/gobigger_muzero_collector.py b/lzero/worker/gobigger_muzero_collector.py index d43527c9a..a715cce5b 100644 --- a/lzero/worker/gobigger_muzero_collector.py +++ b/lzero/worker/gobigger_muzero_collector.py @@ -377,8 +377,12 @@ def collect(self, # policy forward # ============================================================== policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play) + actions_no_env_id=defaultdict(dict) + for k,v in policy_output.items(): + for agent_id, act in enumerate(v['action']): + actions_no_env_id[k][agent_id] = act - actions_no_env_id = {k: v['action'] for k, v in policy_output.items()} + # actions_no_env_id = {k: v['action'] for k, v in policy_output.items()} distributions_dict_no_env_id = {k: v['distributions'] for k, v in policy_output.items()} if self.policy_config.sampled_algo: root_sampled_actions_dict_no_env_id = { @@ -562,7 +566,7 @@ def collect(self, action_mask_dict[env_id] = to_ndarray(init_obs[env_id]['action_mask']) to_play_dict[env_id] = to_ndarray(init_obs[env_id]['to_play']) - agent_num = len(init_obs[0]['action_mask']) + for agent_id in range(agent_num): game_segments[env_id][agent_id] = GameSegment( self._env.action_space, diff --git a/lzero/worker/gobigger_muzero_evaluator.py b/lzero/worker/gobigger_muzero_evaluator.py index f6a6aeb8b..d7ace9b0b 100644 --- a/lzero/worker/gobigger_muzero_evaluator.py +++ b/lzero/worker/gobigger_muzero_evaluator.py @@ -1,7 +1,7 @@ import time import copy from collections import namedtuple -from typing import Optional, Callable, Tuple +from typing import Any, Optional, Callable, Tuple import numpy as np import torch @@ -13,6 +13,9 @@ from ding.worker.collector.base_serial_evaluator import ISerialEvaluator, VectorEvalMonitor from lzero.mcts.buffer.game_segment import GameSegment from lzero.mcts.utils import prepare_observation +from collections import defaultdict +from zoo.gobigger.env.gobigger_rule_bot import GoBiggerBot +from collections import namedtuple, deque class GoBiggerMuZeroEvaluator(ISerialEvaluator): @@ -232,19 +235,21 @@ def eval( action_mask_dict = {i: to_ndarray(init_obs[i]['action_mask']) for i in range(env_nums)} to_play_dict = {i: to_ndarray(init_obs[i]['to_play']) for i in range(env_nums)} + agent_num = len(init_obs[0]['action_mask']) dones = np.array([False for _ in range(env_nums)]) game_segments = [ - GameSegment( + [GameSegment( self._env.action_space, game_segment_length=self.policy_config.game_segment_length, config=self.policy_config - ) for _ in range(env_nums) + ) for _ in range(agent_num)] for _ in range(env_nums) ] - for i in range(env_nums): - game_segments[i].reset( - [to_ndarray(init_obs[i]['observation']) for _ in range(self.policy_config.model.frame_stack_num)] - ) + for env_id in range(env_nums): + for agent_id in range(agent_num): + game_segments[env_id][agent_id].reset( + [to_ndarray(init_obs[env_id]['observation'][agent_id]) for _ in range(self.policy_config.model.frame_stack_num)] + ) ready_env_id = set() remain_episode = n_episode @@ -257,7 +262,11 @@ def eval( ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode])) remain_episode -= min(len(new_available_env_id), remain_episode) - stack_obs = {env_id: game_segments[env_id].get_obs() for env_id in 
ready_env_id} + stack_obs = defaultdict(list) + for env_id in ready_env_id: + for agent_id in range(agent_num): + stack_obs[env_id].append(game_segments[env_id][agent_id].get_obs()) + # stack_obs = {env_id: game_segments[env_id].get_obs() for env_id in ready_env_id} stack_obs = list(stack_obs.values()) action_mask_dict = {env_id: action_mask_dict[env_id] for env_id in ready_env_id} @@ -273,8 +282,11 @@ def eval( # policy forward # ============================================================== policy_output = self._policy.forward(stack_obs, action_mask, to_play) - - actions_no_env_id = {k: v['action'] for k, v in policy_output.items()} + actions_no_env_id=defaultdict(dict) + for k,v in policy_output.items(): + for agent_id, act in enumerate(v['action']): + actions_no_env_id[k][agent_id] = act + # actions_no_env_id = {k: v['action'] for k, v in policy_output.items()} distributions_dict_no_env_id = {k: v['distributions'] for k, v in policy_output.items()} if self.policy_config.sampled_algo: root_sampled_actions_dict_no_env_id = { @@ -313,10 +325,11 @@ def eval( for env_id, t in timesteps.items(): obs, reward, done, info = t.obs, t.reward, t.done, t.info - game_segments[env_id].append( - actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id], - to_play_dict[env_id] - ) + for agent_id in range(agent_num): + game_segments[env_id][agent_id].append( + actions[env_id][agent_id], to_ndarray(obs['observation'][agent_id]), reward[agent_id], action_mask_dict[env_id][agent_id], + to_play_dict[env_id] + ) # NOTE: in evaluator, we only need save the ``o_{t+1} = obs['observation']`` # game_segments[env_id].obs_segment.append(to_ndarray(obs['observation'])) @@ -335,7 +348,7 @@ def eval( eval_monitor.update_info(env_id, t.info['episode_info']) eval_monitor.update_reward(env_id, reward) self._logger.info( - "[EVALUATOR]env {} finish episode, final reward: {}, current episode: {}".format( + "[EVALUATOR selfplay]env {} finish episode, final reward: {}, current episode: {}".format( env_id, eval_monitor.get_latest_reward(env_id), eval_monitor.get_current_episode() ) ) @@ -370,18 +383,19 @@ def eval( action_mask_dict[env_id] = to_ndarray(init_obs[env_id]['action_mask']) to_play_dict[env_id] = to_ndarray(init_obs[env_id]['to_play']) - game_segments[env_id] = GameSegment( - self._env.action_space, - game_segment_length=self.policy_config.game_segment_length, - config=self.policy_config - ) + for agent_id in range(agent_num): + game_segments[env_id][agent_id] = GameSegment( + self._env.action_space, + game_segment_length=self.policy_config.game_segment_length, + config=self.policy_config + ) - game_segments[env_id].reset( - [ - init_obs[env_id]['observation'] - for _ in range(self.policy_config.model.frame_stack_num) - ] - ) + game_segments[env_id][agent_id].reset( + [ + init_obs[env_id]['observation'][agent_id] + for _ in range(self.policy_config.model.frame_stack_num) + ] + ) # Env reset is done by env_manager automatically. 
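# A small, self-contained illustration (toy numbers, not from this patch) of the
# per-agent action handling above: the policy returns one action array per ready
# environment, and it is re-keyed into the {env_id: {agent_id: action}} layout
# that the GoBigger env manager consumes.
from collections import defaultdict

policy_output = {0: {'action': [3, 17]}, 1: {'action': [5, 9]}}  # hypothetical 2-env, 2-agent output
actions_no_env_id = defaultdict(dict)
for env_index, output in policy_output.items():
    for agent_id, act in enumerate(output['action']):
        actions_no_env_id[env_index][agent_id] = act
assert actions_no_env_id[0] == {0: 3, 1: 17} and actions_no_env_id[1] == {0: 5, 1: 9}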
self._policy.reset([env_id]) @@ -405,8 +419,12 @@ def eval( 'reward_std': np.std(episode_return), 'reward_max': np.max(episode_return), 'reward_min': np.min(episode_return), - # 'each_reward': episode_return, } + # add eat info + for i in range(len(t.info['eats'])//2): + for k,v in t.info['eats'][i].items(): + info['agent_{}_{}'.format(i, k)] = v + episode_info = eval_monitor.get_episode_info() if episode_info is not None: info.update(episode_info) @@ -420,6 +438,292 @@ def eval( self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter) self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, envstep) eval_reward = np.mean(episode_return) + # if eval_reward > self._max_eval_reward: + # if save_ckpt_fn: + # save_ckpt_fn('ckpt_best.pth.tar') + # self._max_eval_reward = eval_reward + stop_flag = eval_reward >= self._stop_value and train_iter > 0 + if stop_flag: + self._logger.info( + "[LightZero serial pipeline] " + + "Current eval_reward: {} is greater than stop_value: {}".format(eval_reward, self._stop_value) + + ", so your MCTS/RL agent is converged, you can refer to 'log/evaluator/evaluator_logger.txt' for details." + ) + return stop_flag, eval_reward + + def eval_vsbot( + self, + save_ckpt_fn: Callable = None, + train_iter: int = -1, + envstep: int = -1, + n_episode: Optional[int] = None, + ) -> Tuple[bool, float]: + """ + Overview: + Evaluate policy and store the best policy based on whether it reaches the highest historical reward. + Arguments: + - save_ckpt_fn (:obj:`Callable`): Saving ckpt function, which will be triggered by getting the best reward. + - train_iter (:obj:`int`): Current training iteration. + - envstep (:obj:`int`): Current env interaction step. + - n_episode (:obj:`int`): Number of evaluation episodes. + Returns: + - stop_flag (:obj:`bool`): Whether this training program can be ended. + - eval_reward (:obj:`float`): Current eval_reward. + """ + if n_episode is None: + n_episode = self._default_n_episode + assert n_episode is not None, "please indicate eval n_episode" + envstep_count = 0 + eval_monitor = GoBiggerVectorEvalMonitor(self._env.env_num, n_episode) + env_nums = self._env.env_num + + self._env.reset() + self._policy.reset() + self._bot_policy = GoBiggerBot(env_nums, agent_id=[2,3]) #TODO only support t2p2 + self._bot_policy.reset() + + # initializations + init_obs = self._env.ready_obs + agent_num = len(init_obs[0]['action_mask'])//2 #TODO only support t2p2 + + retry_waiting_time = 0.001 + while len(init_obs.keys()) != self._env_num: + # In order to be compatible with subprocess env_manager, in which sometimes self._env_num is not equal to + # len(self._env.ready_obs), especially in tictactoe env. + self._logger.info('The current init_obs.keys() is {}'.format(init_obs.keys())) + self._logger.info('Before sleeping, the _env_states is {}'.format(self._env._env_states)) + time.sleep(retry_waiting_time) + self._logger.info('=' * 10 + 'Wait for all environments (subprocess) to finish resetting.' 
+ '=' * 10) + self._logger.info( + 'After sleeping {}s, the current _env_states is {}'.format(retry_waiting_time, self._env._env_states) + ) + init_obs = self._env.ready_obs + + for i in range(env_nums): + for k, v in init_obs[i].items(): + init_obs[i][k] = v[:agent_num] + + action_mask_dict = {i: to_ndarray(init_obs[i]['action_mask']) for i in range(env_nums)} + to_play_dict = {i: to_ndarray(init_obs[i]['to_play']) for i in range(env_nums)} + dones = np.array([False for _ in range(env_nums)]) + + game_segments = [ + [GameSegment( + self._env.action_space, + game_segment_length=self.policy_config.game_segment_length, + config=self.policy_config + ) for _ in range(agent_num)] for _ in range(env_nums) + ] + for env_id in range(env_nums): + for agent_id in range(agent_num): + game_segments[env_id][agent_id].reset( + [to_ndarray(init_obs[env_id]['observation'][agent_id]) for _ in range(self.policy_config.model.frame_stack_num)] + ) + + ready_env_id = set() + remain_episode = n_episode + + with self._timer: + while not eval_monitor.is_finished(): + # Get current ready env obs. + obs = self._env.ready_obs + raw_obs = [v['raw_obs'] for k,v in obs.items()] + new_available_env_id = set(obs.keys()).difference(ready_env_id) + ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode])) + remain_episode -= min(len(new_available_env_id), remain_episode) + + stack_obs = defaultdict(list) + for env_id in ready_env_id: + for agent_id in range(agent_num): + stack_obs[env_id].append(game_segments[env_id][agent_id].get_obs()) + # stack_obs = {env_id: game_segments[env_id].get_obs() for env_id in ready_env_id} + stack_obs = list(stack_obs.values()) + + action_mask_dict = {env_id: action_mask_dict[env_id] for env_id in ready_env_id} + to_play_dict = {env_id: to_play_dict[env_id] for env_id in ready_env_id} + action_mask = [action_mask_dict[env_id] for env_id in ready_env_id] + to_play = [to_play_dict[env_id] for env_id in ready_env_id] + + stack_obs = to_ndarray(stack_obs) + # stack_obs = prepare_observation(stack_obs, self.policy_config.model.model_type) + # stack_obs = torch.from_numpy(stack_obs).to(self.policy_config.device).float() + + # ============================================================== + # bot forward + # ============================================================== + bot_actions = self._bot_policy.forward(raw_obs) + + # ============================================================== + # policy forward + # ============================================================== + policy_output = self._policy.forward(stack_obs, action_mask, to_play) + actions_no_env_id=defaultdict(dict) + for k,v in policy_output.items(): + for agent_id, act in enumerate(v['action']): + actions_no_env_id[k][agent_id] = act + # actions_no_env_id = {k: v['action'] for k, v in policy_output.items()} + distributions_dict_no_env_id = {k: v['distributions'] for k, v in policy_output.items()} + if self.policy_config.sampled_algo: + root_sampled_actions_dict_no_env_id = { + k: v['root_sampled_actions'] + for k, v in policy_output.items() + } + + value_dict_no_env_id = {k: v['value'] for k, v in policy_output.items()} + pred_value_dict_no_env_id = {k: v['pred_value'] for k, v in policy_output.items()} + visit_entropy_dict_no_env_id = { + k: v['visit_count_distribution_entropy'] + for k, v in policy_output.items() + } + + actions = {} + distributions_dict = {} + if self.policy_config.sampled_algo: + root_sampled_actions_dict = {} + value_dict = {} + pred_value_dict = {} + visit_entropy_dict = {} + for 
index, env_id in enumerate(ready_env_id): + actions[env_id] = actions_no_env_id.pop(index) + distributions_dict[env_id] = distributions_dict_no_env_id.pop(index) + if self.policy_config.sampled_algo: + root_sampled_actions_dict[env_id] = root_sampled_actions_dict_no_env_id.pop(index) + value_dict[env_id] = value_dict_no_env_id.pop(index) + pred_value_dict[env_id] = pred_value_dict_no_env_id.pop(index) + visit_entropy_dict[env_id] = visit_entropy_dict_no_env_id.pop(index) + + # ============================================================== + # Interact with env. + # ============================================================== + for env_id, v in bot_actions.items(): + actions[env_id].update(v) + + timesteps = self._env.step(actions) + + for env_id, t in timesteps.items(): + obs, reward, done, info = t.obs, t.reward, t.done, t.info + + for agent_id in range(agent_num): + game_segments[env_id][agent_id].append( + actions[env_id][agent_id], to_ndarray(obs['observation'][agent_id]), reward[agent_id], action_mask_dict[env_id][agent_id], + to_play_dict[env_id] + ) + + # NOTE: in evaluator, we only need save the ``o_{t+1} = obs['observation']`` + # game_segments[env_id].obs_segment.append(to_ndarray(obs['observation'])) + + # NOTE: the position of code snippet is very important. + # the obs['action_mask'] and obs['to_play'] is corresponding to next action + action_mask_dict[env_id] = to_ndarray(obs['action_mask']) + to_play_dict[env_id] = to_ndarray(obs['to_play']) + + dones[env_id] = done + if t.done: + # Env reset is done by env_manager automatically. + self._policy.reset([env_id]) + reward = t.info['eval_episode_return'][0] + bot_reward = t.info['eval_episode_return'][1] + if 'episode_info' in t.info: + eval_monitor.update_info(env_id, t.info['episode_info']) + eval_monitor.update_reward(env_id, reward) + eval_monitor.update_bot_reward(env_id, bot_reward) + self._logger.info( + "[EVALUATOR vsbot]env {} finish episode, final reward: {}, current episode: {}".format( + env_id, eval_monitor.get_latest_reward(env_id), eval_monitor.get_current_episode() + ) + ) + + # reset the finished env and init game_segments + if n_episode > self._env_num: + # Get current ready env obs. + init_obs = self._env.ready_obs + retry_waiting_time = 0.001 + while len(init_obs.keys()) != self._env_num: + # In order to be compatible with subprocess env_manager, in which sometimes self._env_num is not equal to + # len(self._env.ready_obs), especially in tictactoe env. + self._logger.info('The current init_obs.keys() is {}'.format(init_obs.keys())) + self._logger.info( + 'Before sleeping, the _env_states is {}'.format(self._env._env_states) + ) + time.sleep(retry_waiting_time) + self._logger.info( + '=' * 10 + 'Wait for all environments (subprocess) to finish resetting.' 
+ '=' * 10 + ) + self._logger.info( + 'After sleeping {}s, the current _env_states is {}'.format( + retry_waiting_time, self._env._env_states + ) + ) + init_obs = self._env.ready_obs + + new_available_env_id = set(init_obs.keys()).difference(ready_env_id) + ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode])) + remain_episode -= min(len(new_available_env_id), remain_episode) + + action_mask_dict[env_id] = to_ndarray(init_obs[env_id]['action_mask']) + to_play_dict[env_id] = to_ndarray(init_obs[env_id]['to_play']) + + for agent_id in range(agent_num): + game_segments[env_id][agent_id] = GameSegment( + self._env.action_space, + game_segment_length=self.policy_config.game_segment_length, + config=self.policy_config + ) + + game_segments[env_id][agent_id].reset( + [ + init_obs[env_id]['observation'][agent_id] + for _ in range(self.policy_config.model.frame_stack_num) + ] + ) + + # Env reset is done by env_manager automatically. + self._policy.reset([env_id]) + self._bot_policy.reset([env_id]) + # TODO(pu): subprocess mode, when n_episode > self._env_num, occasionally the ready_env_id=() + # and the stack_obs is np.array(None, dtype=object) + ready_env_id.remove(env_id) + + envstep_count += 1 + duration = self._timer.value + episode_return = eval_monitor.get_episode_return() + bot_episode_return = eval_monitor.get_bot_episode_return() + info = { + 'train_iter': train_iter, + 'ckpt_name': 'iteration_{}.pth.tar'.format(train_iter), + 'episode_count': n_episode, + 'envstep_count': envstep_count, + 'avg_envstep_per_episode': envstep_count / n_episode, + 'evaluate_time': duration, + 'avg_envstep_per_sec': envstep_count / duration, + 'avg_time_per_episode': n_episode / duration, + 'reward_mean': np.mean(episode_return), + 'reward_std': np.std(episode_return), + 'reward_max': np.max(episode_return), + 'reward_min': np.min(episode_return), + 'bot_reward_mean': np.mean(bot_episode_return), + 'bot_reward_std': np.std(bot_episode_return), + 'bot_reward_max': np.max(bot_episode_return), + 'bot_reward_min': np.min(bot_episode_return), + } + # add eat info + for i in range(len(t.info['eats'])): + for k,v in t.info['eats'][i].items(): + info['agent_{}_{}'.format(i, k)] = v + + episode_info = eval_monitor.get_episode_info() + if episode_info is not None: + info.update(episode_info) + self._logger.info(self._logger.get_tabulate_vars_hor(info)) + # self._logger.info(self._logger.get_tabulate_vars(info)) + for k, v in info.items(): + if k in ['train_iter', 'ckpt_name', 'each_reward']: + continue + if not np.isscalar(v): + continue + self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name+'_vsbot') + k, v, train_iter) + self._tb_logger.add_scalar('{}_step/'.format(self._instance_name+'_vsbot') + k, v, envstep) + eval_reward = np.mean(episode_return) if eval_reward > self._max_eval_reward: if save_ckpt_fn: save_ckpt_fn('ckpt_best.pth.tar') @@ -432,3 +736,30 @@ def eval( ", so your MCTS/RL agent is converged, you can refer to 'log/evaluator/evaluator_logger.txt' for details." ) return stop_flag, eval_reward + + +class GoBiggerVectorEvalMonitor(VectorEvalMonitor): + + def __init__(self, env_num: int, n_episode: int) -> None: + super().__init__(env_num, n_episode) + each_env_episode = [n_episode // env_num for _ in range(env_num)] + self._bot_reward = {env_id: deque(maxlen=maxlen) for env_id, maxlen in enumerate(each_env_episode)} + + def get_bot_episode_return(self) -> list: + """ + Overview: + Sum up all reward and get the total return of one episode. 
+ """ + return sum([list(v) for v in self._bot_reward.values()], []) # sum(iterable, start) + + def update_bot_reward(self, env_id: int, reward: Any) -> None: + """ + Overview: + Update the reward indicated by env_id. + Arguments: + - env_id: (:obj:`int`): the id of the environment we need to update the reward + - reward: (:obj:`Any`): the reward we need to update + """ + if isinstance(reward, torch.Tensor): + reward = reward.item() + self._bot_reward[env_id].append(reward) \ No newline at end of file diff --git a/zoo/gobigger/config/gobigger_efficientzero_config.py b/zoo/gobigger/config/gobigger_efficientzero_config.py index 189f60e50..b0268b566 100644 --- a/zoo/gobigger/config/gobigger_efficientzero_config.py +++ b/zoo/gobigger/config/gobigger_efficientzero_config.py @@ -5,11 +5,11 @@ # ============================================================== # begin of the most frequently changed config specified by the user # ============================================================== -collector_env_num = 16 -n_episode = 16 +collector_env_num = 32 +n_episode = 32 evaluator_env_num = 5 num_simulations = 50 -update_per_collect = 1000 +update_per_collect = 2000 batch_size = 256 reanalyze_ratio = 0. action_space_size = 27 @@ -69,7 +69,6 @@ ), policy=dict( model=dict( - # observation_shape=(4, 96, 96), latent_state_dim=176, frame_stack_num=1, action_space_size=action_space_size, diff --git a/zoo/gobigger/env/gobigger_env.py b/zoo/gobigger/env/gobigger_env.py index 46168b8ae..209407e28 100644 --- a/zoo/gobigger/env/gobigger_env.py +++ b/zoo/gobigger/env/gobigger_env.py @@ -54,11 +54,11 @@ def observation(self, raw_obs): obs = self.preprocess_obs(raw_obs) action_mask = [np.logical_not(o['action_mask']) for o in obs] to_play = [ -1 for _ in range(len(obs))] - obs = {'observation': obs, 'action_mask': action_mask, 'to_play': to_play} + obs = {'observation': obs, 'action_mask': action_mask, 'to_play': to_play, 'raw_obs':raw_obs} return obs - def step(self, action: list) -> BaseEnvTimestep: - action = {i: self.transform_action(a) for i, a in enumerate(action)} + def step(self, action: dict) -> BaseEnvTimestep: + action = {k: self.transform_action(v) if np.isscalar(v) else v for k, v in action.items()} raw_obs, raw_rew, done, info = self._env.step(action) # print('current_frame={}'.format(raw_obs[0]['last_time'])) # print('raw_rew={}'.format(raw_rew)) @@ -476,22 +476,22 @@ def transform_reward(self, next_obs): # save_frame=False, save_frame=True, save_dir='./', - save_name_prefix='gobigger-1', + save_name_prefix='gobigger-bot', ), ), )) env = GoBiggerLightZeroEnv(env_cfg) - env.reset() - import random + obs = env.reset() + from gobigger_rule_bot import BotAgent + bot = [BotAgent(i) for i in range(4)] while True: - actions = [random.randint(0, 26), random.randint(0, 26), random.randint(0, 26), random.randint(0, 26)] - # actions = {0: [random.uniform(-1, 1), random.uniform(-1, 1), -1], - # 1: [random.uniform(-1, 1), random.uniform(-1, 1), -1], - # 2: [random.uniform(-1, 1), random.uniform(-1, 1), -1], - # 3: [random.uniform(-1, 1), random.uniform(-1, 1), -1]} - timestep = env.step(actions) - if timestep.done: + actions = {} + for i in range(4): + # bot[i].step(obs['raw_obs'] is dict + actions.update(bot[i].step(obs['raw_obs'])) + obs, rew, done, info = env.step(actions) + if done: break # from ding.envs import create_env_manager diff --git a/zoo/gobigger/env/gobigger_rule_bot.py b/zoo/gobigger/env/gobigger_rule_bot.py new file mode 100644 index 000000000..d7a16b87f --- /dev/null +++ 
b/zoo/gobigger/env/gobigger_rule_bot.py @@ -0,0 +1,211 @@ +import copy +from ding.policy.base_policy import Policy +from ding.utils import POLICY_REGISTRY +import torch +import math +import queue +import random +from pygame.math import Vector2 +from typing import List, Dict, Any, Optional, Tuple, Union +from collections import namedtuple +from collections import defaultdict + + +@POLICY_REGISTRY.register('gobigger_bot') +class GoBiggerBot(Policy): + + def __init__(self, env_num, agent_id: List[int]): + self.env_num = env_num + self.agent_id = agent_id + self.bot = [[BotAgent(i) for i in self.agent_id] for _ in range(self.env_num)] + + def forward(self, raw_obs): + action = defaultdict(dict) + for env_id in range(self.env_num): + obs = raw_obs[env_id] + for agent in self.bot[env_id]: + action[env_id].update(agent.step(obs)) + return action + + def reset(self, env_id_lst=None): + if env_id_lst is None: + env_id_lst = range(self.env_num) + for env_id in env_id_lst: + for agent in self.bot[env_id]: + agent.reset() + + def _init_learn(self) -> None: + pass + + def _init_collect(self) -> None: + pass + + def _init_eval(self) -> None: + pass + + def _forward_learn(self, data: dict) -> dict: + pass + + def _forward_collect(self, envs: Dict, obs: Dict, temperature: float = 1) -> Dict[str, torch.Tensor]: + pass + + def _forward_eval(self, data: dict) -> dict: + pass + + def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict: + pass + + def _get_train_sample(self, data: list) -> Union[None, List[Any]]: + pass + + def default_model(self) -> Tuple[str, List[str]]: + return 'bot_model', ['lzero.model.bot_model'] + + def _monitor_vars_learn(self) -> List[str]: + pass + +class BotAgent(): + + def __init__(self, game_player_id): + self.game_player_id = game_player_id # start from 0 + self.actions_queue = queue.Queue() + + def step(self, obs): + obs = obs[1][self.game_player_id] + if self.actions_queue.qsize() > 0: + return {self.game_player_id: self.actions_queue.get()} + overlap = obs['overlap'] + overlap = self.preprocess(overlap) + food_balls = overlap['food'] + thorns_balls = overlap['thorns'] + spore_balls = overlap['spore'] + clone_balls = overlap['clone'] + + my_clone_balls, others_clone_balls = self.process_clone_balls(clone_balls) + + if len(my_clone_balls) >= 9 and my_clone_balls[4]['radius'] > 4: + self.actions_queue.put([0, 0, 0]) + self.actions_queue.put([0, 0, 0]) + self.actions_queue.put([0, 0, 0]) + self.actions_queue.put([0, 0, 0]) + self.actions_queue.put([0, 0, 0]) + self.actions_queue.put([0, 0, 0]) + self.actions_queue.put([0, 0, 0]) + self.actions_queue.put([None, None, 1]) + self.actions_queue.put([None, None, 1]) + self.actions_queue.put([None, None, 1]) + self.actions_queue.put([None, None, 1]) + self.actions_queue.put([None, None, 1]) + self.actions_queue.put([None, None, 1]) + self.actions_queue.put([None, None, 1]) + self.actions_queue.put([None, None, 1]) + action_ret = self.actions_queue.get() + return {self.game_player_id: action_ret} + + if len(others_clone_balls) > 0 and self.can_eat(others_clone_balls[0]['radius'], my_clone_balls[0]['radius']): + direction = (my_clone_balls[0]['position'] - others_clone_balls[0]['position']) + action_type = 0 + else: + min_distance, min_thorns_ball = self.process_thorns_balls(thorns_balls, my_clone_balls[0]) + if min_thorns_ball is not None: + direction = (min_thorns_ball['position'] - my_clone_balls[0]['position']) + else: + min_distance, min_food_ball = self.process_food_balls(food_balls, 
my_clone_balls[0]) + if min_food_ball is not None: + direction = (min_food_ball['position'] - my_clone_balls[0]['position']) + else: + direction = (Vector2(0, 0) - my_clone_balls[0]['position']) + action_random = random.random() + if action_random < 0.02: + action_type = 1 + if action_random < 0.04 and action_random > 0.02: + action_type = 2 + else: + action_type = 0 + if direction.length()>0: + direction = direction.normalize() + else: + direction = Vector2(1, 1).normalize() + direction = self.add_noise_to_direction(direction).normalize() + self.actions_queue.put([direction.x, direction.y, action_type]) + action_ret = self.actions_queue.get() + return {self.game_player_id: action_ret} + + def process_clone_balls(self, clone_balls): + my_clone_balls = [] + others_clone_balls = [] + for clone_ball in clone_balls: + if clone_ball['player'] == self.game_player_id: + my_clone_balls.append(copy.deepcopy(clone_ball)) + my_clone_balls.sort(key=lambda a: a['radius'], reverse=True) + + for clone_ball in clone_balls: + if clone_ball['player'] != self.game_player_id: + others_clone_balls.append(copy.deepcopy(clone_ball)) + others_clone_balls.sort(key=lambda a: a['radius'], reverse=True) + return my_clone_balls, others_clone_balls + + def process_thorns_balls(self, thorns_balls, my_max_clone_ball): + min_distance = 10000 + min_thorns_ball = None + for thorns_ball in thorns_balls: + if self.can_eat(my_max_clone_ball['radius'], thorns_ball['radius']): + distance = (thorns_ball['position'] - my_max_clone_ball['position']).length() + if distance < min_distance: + min_distance = distance + min_thorns_ball = copy.deepcopy(thorns_ball) + return min_distance, min_thorns_ball + + def process_food_balls(self, food_balls, my_max_clone_ball): + min_distance = 10000 + min_food_ball = None + for food_ball in food_balls: + distance = (food_ball['position'] - my_max_clone_ball['position']).length() + if distance < min_distance: + min_distance = distance + min_food_ball = copy.deepcopy(food_ball) + return min_distance, min_food_ball + + def preprocess(self, overlap): + new_overlap = {} + for k, v in overlap.items(): + if k =='clone': + new_overlap[k] = [] + for index, vv in enumerate(v): + tmp={} + tmp['position'] = Vector2(vv[0],vv[1]) + tmp['radius'] = vv[2] + tmp['player'] = int(vv[-2]) + tmp['team'] = int(vv[-1]) + new_overlap[k].append(tmp) + else: + new_overlap[k] = [] + for index, vv in enumerate(v): + tmp={} + tmp['position'] = Vector2(vv[0],vv[1]) + tmp['radius'] = vv[2] + new_overlap[k].append(tmp) + return new_overlap + + def preprocess_tuple2vector(self, overlap): + new_overlap = {} + for k, v in overlap.items(): + new_overlap[k] = [] + for index, vv in enumerate(v): + new_overlap[k].append(vv) + new_overlap[k][index]['position'] = Vector2(*vv['position']) + return new_overlap + + def add_noise_to_direction(self, direction, noise_ratio=0.1): + direction = direction + Vector2(((random.random() * 2 - 1)*noise_ratio)*direction.x, + ((random.random() * 2 - 1)*noise_ratio)*direction.y) + return direction + + def radius_to_score(self, radius): + return (math.pow(radius,2) - 0.15) / 0.042 * 100 + + def can_eat(self, radius1, radius2): + return self.radius_to_score(radius1) > 1.3 * self.radius_to_score(radius2) + + def reset(self,): + self.actions_queue.queue.clear() \ No newline at end of file From 0875e74c2dcbdb6faf2701eb1c8fb83157784ef5 Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Fri, 2 Jun 2023 00:46:22 +0800 Subject: [PATCH 04/54] feature(yzj): add vsbot with mz pipeline and polish model and buffer --- 
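For reference, a worked example of the eat test that drives the rule bot added above, assuming the same score mapping score = (radius^2 - 0.15) / 0.042 * 100 and the 1.3x margin used in ``BotAgent.can_eat``; the sample radii are illustrative only.

import math

def radius_to_score(radius):
    # same mapping as BotAgent.radius_to_score in zoo/gobigger/env/gobigger_rule_bot.py
    return (math.pow(radius, 2) - 0.15) / 0.042 * 100

def can_eat(radius1, radius2):
    # ball 1 only chases ball 2 if its score clears the 1.3x safety margin
    return radius_to_score(radius1) > 1.3 * radius_to_score(radius2)

assert can_eat(3.0, 2.5)        # ~21071 vs 1.3 * ~14524: safe to chase
assert not can_eat(3.0, 2.9)    # ~21071 vs 1.3 * ~19667: too risky, fall back to thorns/food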
lzero/entry/train_muzero_gobigger.py | 16 +- lzero/mcts/buffer/__init__.py | 1 + .../gobigger_game_buffer_efficientzero.py | 447 +++++++++++ .../buffer/gobigger_game_buffer_muzero.py | 3 +- ...mlp.py => gobigger_efficientzero_model.py} | 8 +- ...{gobigger_model.py => gobigger_encoder.py} | 0 lzero/model/gobigger/gobigger_muzero_model.py | 455 +++++++++++ lzero/policy/gobigger_efficientzero.py | 24 +- lzero/policy/gobigger_muzero.py | 724 ++++++++++++++++++ zoo/gobigger/config/gobigger_muzero_config.py | 142 ++++ 10 files changed, 1783 insertions(+), 37 deletions(-) create mode 100644 lzero/mcts/buffer/gobigger_game_buffer_efficientzero.py rename lzero/model/gobigger/{gobigger_efficientzero_model_mlp.py => gobigger_efficientzero_model.py} (99%) rename lzero/model/gobigger/{gobigger_model.py => gobigger_encoder.py} (100%) create mode 100644 lzero/model/gobigger/gobigger_muzero_model.py create mode 100644 lzero/policy/gobigger_muzero.py create mode 100644 zoo/gobigger/config/gobigger_muzero_config.py diff --git a/lzero/entry/train_muzero_gobigger.py b/lzero/entry/train_muzero_gobigger.py index 20452f245..60952a353 100644 --- a/lzero/entry/train_muzero_gobigger.py +++ b/lzero/entry/train_muzero_gobigger.py @@ -42,16 +42,12 @@ def train_muzero_gobigger( """ cfg, create_cfg = input_cfg - assert create_cfg.policy.type in ['efficientzero', 'muzero', 'sampled_efficientzero', 'gobigger_efficientzero'], \ - "train_muzero entry now only support the following algo.: 'efficientzero', 'muzero', 'sampled_efficientzero'" - - if create_cfg.policy.type == 'muzero': - from lzero.mcts import MuZeroGameBuffer as GameBuffer - elif create_cfg.policy.type == 'efficientzero': - from lzero.mcts import EfficientZeroGameBuffer as GameBuffer - elif create_cfg.policy.type == 'sampled_efficientzero': - from lzero.mcts import SampledEfficientZeroGameBuffer as GameBuffer - elif create_cfg.policy.type == 'gobigger_efficientzero': + assert create_cfg.policy.type in ['gobigger_efficientzero', 'gobigger_muzero'], \ + "train_muzero entry now only support the following algo.: 'gobigger_efficientzero', 'gobigger_muzero'" + + if create_cfg.policy.type == 'gobigger_efficientzero': + from lzero.mcts import GoBiggerEfficientZeroGameBuffer as GameBuffer + elif create_cfg.policy.type == 'gobigger_muzero': from lzero.mcts import GoBiggerMuZeroGameBuffer as GameBuffer if cfg.policy.cuda and torch.cuda.is_available(): diff --git a/lzero/mcts/buffer/__init__.py b/lzero/mcts/buffer/__init__.py index 5be32de5c..e5a245ee9 100644 --- a/lzero/mcts/buffer/__init__.py +++ b/lzero/mcts/buffer/__init__.py @@ -2,3 +2,4 @@ from .game_buffer_efficientzero import EfficientZeroGameBuffer from .game_buffer_sampled_efficientzero import SampledEfficientZeroGameBuffer from .gobigger_game_buffer_muzero import GoBiggerMuZeroGameBuffer +from .gobigger_game_buffer_efficientzero import GoBiggerEfficientZeroGameBuffer diff --git a/lzero/mcts/buffer/gobigger_game_buffer_efficientzero.py b/lzero/mcts/buffer/gobigger_game_buffer_efficientzero.py new file mode 100644 index 000000000..b8c76cac7 --- /dev/null +++ b/lzero/mcts/buffer/gobigger_game_buffer_efficientzero.py @@ -0,0 +1,447 @@ +from typing import Any, List + +import numpy as np +import torch +from ding.utils import BUFFER_REGISTRY + +from lzero.mcts.tree_search.mcts_ctree import EfficientZeroMCTSCtree as MCTSCtree +from lzero.mcts.tree_search.mcts_ptree import EfficientZeroMCTSPtree as MCTSPtree +from lzero.mcts.utils import prepare_observation +from lzero.policy import to_detach_cpu_numpy, concat_output, 
concat_output_value, inverse_scalar_transform +from .gobigger_game_buffer_muzero import GoBiggerMuZeroGameBuffer +from ding.torch_utils import to_device, to_tensor, to_ndarray + + +@BUFFER_REGISTRY.register('gobigger_game_buffer_efficientzero') +class GoBiggerEfficientZeroGameBuffer(GoBiggerMuZeroGameBuffer): + """ + Overview: + The specific game buffer for EfficientZero policy. + """ + + def __init__(self, cfg: dict): + super().__init__(cfg) + """ + Overview: + Use the default configuration mechanism. If a user passes in a cfg with a key that matches an existing key + in the default configuration, the user-provided value will override the default configuration. Otherwise, + the default configuration will be used. + """ + default_config = self.default_config() + default_config.update(cfg) + self._cfg = default_config + assert self._cfg.env_type in ['not_board_games', 'board_games'] + self.replay_buffer_size = self._cfg.replay_buffer_size + self.batch_size = self._cfg.batch_size + self._alpha = self._cfg.priority_prob_alpha + self._beta = self._cfg.priority_prob_beta + + self.game_segment_buffer = [] + self.game_pos_priorities = [] + self.game_segment_game_pos_look_up = [] + + self.keep_ratio = 1 + self.num_of_collected_episodes = 0 + self.base_idx = 0 + self.clear_time = 0 + + self.tmp_obs = None # for value obs list [46 + 4(td_step)] not < 50(game_segment) + + def sample(self, batch_size: int, policy: Any) -> List[Any]: + """ + Overview: + sample data from ``GameBuffer`` and prepare the current and target batch for training + Arguments: + - batch_size (:obj:`int`): batch size + - policy (:obj:`torch.tensor`): model of policy + Returns: + - train_data (:obj:`List`): List of train data + """ + policy._target_model.to(self._cfg.device) + policy._target_model.eval() + + # obtain the current_batch and prepare target context + reward_value_context, policy_re_context, policy_non_re_context, current_batch = self._make_batch( + batch_size, self._cfg.reanalyze_ratio + ) + + # target value_prefixs, target value + batch_value_prefixs, batch_target_values = self._compute_target_reward_value( + reward_value_context, policy._target_model + ) + # target policy + batch_target_policies_re = self._compute_target_policy_reanalyzed(policy_re_context, policy._target_model) + batch_target_policies_non_re = self._compute_target_policy_non_reanalyzed( + policy_non_re_context, self._cfg.model.action_space_size + ) + + if 0 < self._cfg.reanalyze_ratio < 1: + batch_target_policies = np.concatenate([batch_target_policies_re, batch_target_policies_non_re]) + elif self._cfg.reanalyze_ratio == 1: + batch_target_policies = batch_target_policies_re + elif self._cfg.reanalyze_ratio == 0: + batch_target_policies = batch_target_policies_non_re + + target_batch = [batch_value_prefixs, batch_target_values, batch_target_policies] + # a batch contains the current_batch and the target_batch + train_data = [current_batch, target_batch] + return train_data + + def _prepare_reward_value_context( + self, batch_index_list: List[str], game_segment_list: List[Any], pos_in_game_segment_list: List[Any], + total_transitions: int + ) -> List[Any]: + """ + Overview: + prepare the context of rewards and values for calculating TD value target in reanalyzing part. 
+ Arguments: + - batch_index_list (:obj:`list`): the index of start transition of sampled minibatch in replay buffer + - game_segment_list (:obj:`list`): list of game segments + - pos_in_game_segment_list (:obj:`list`): list of transition index in game_segment + - total_transitions (:obj:`int`): number of collected transitions + Returns: + - reward_value_context (:obj:`list`): value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, + td_steps_list, action_mask_segment, to_play_segment + """ + # zero_obs = game_segment_list[0].zero_obs() + value_obs_list = [] + # the value is valid or not (out of trajectory) + value_mask = [] + rewards_list = [] + game_segment_lens = [] + # for two_player board games + action_mask_segment, to_play_segment = [], [] + + td_steps_list = [] + for game_segment, state_index, idx in zip(game_segment_list, pos_in_game_segment_list, batch_index_list): + game_segment_len = len(game_segment) + game_segment_lens.append(game_segment_len) + + # ============================================================== + # EfficientZero related core code + # ============================================================== + # TODO(pu): + # for atari, off-policy correction: shorter horizon of td steps + # delta_td = (total_transitions - idx) // config.auto_td_steps + # td_steps = config.td_steps - delta_td + # td_steps = np.clip(td_steps, 1, 5).astype(np.int) + td_steps = np.clip(self._cfg.td_steps, 1, max(1, game_segment_len - state_index)).astype(np.int32) + + # prepare the corresponding observations for bootstrapped values o_{t+k} + # o[t+ td_steps, t + td_steps + stack frames + num_unroll_steps] + # t=2+3 -> o[2+3, 2+3+4+5] -> o[5, 14] + game_obs = game_segment.get_unroll_obs(state_index + td_steps, self._cfg.num_unroll_steps) + + rewards_list.append(game_segment.reward_segment) + + # for two_player board games + action_mask_segment.append(game_segment.action_mask_segment) + to_play_segment.append(game_segment.to_play_segment) + + for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1): + # get the bootstrapped target obs + td_steps_list.append(td_steps) + # index of bootstrapped obs o_{t+td_steps} + bootstrap_index = current_index + td_steps + + if bootstrap_index < game_segment_len: + value_mask.append(1) + # beg_index = bootstrap_index - (state_index + td_steps), max of beg_index is num_unroll_steps + beg_index = current_index - state_index + end_index = beg_index + self._cfg.model.frame_stack_num + # the stacked obs in time t + obs = game_obs[beg_index:end_index] + self.tmp_obs = obs # will be masked + else: + value_mask.append(0) + obs = self.tmp_obs # will be masked + value_obs_list.append(obs.tolist()) + + reward_value_context = [ + value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, td_steps_list, + action_mask_segment, to_play_segment + ] + return reward_value_context + + def _compute_target_reward_value(self, reward_value_context: List[Any], model: Any) -> List[np.ndarray]: + """ + Overview: + prepare reward and value targets from the context of rewards and values. 
+ Arguments: + - reward_value_context (:obj:'list'): the reward value context + - model (:obj:'torch.tensor'):model of the target model + Returns: + - batch_value_prefixs (:obj:'np.ndarray): batch of value prefix + - batch_target_values (:obj:'np.ndarray): batch of value estimation + """ + value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, td_steps_list, action_mask_segment, \ + to_play_segment = reward_value_context # noqa + # transition_batch_size = game_segment_batch_size * (num_unroll_steps+1) + transition_batch_size = len(value_obs_list) + game_segment_batch_size = len(pos_in_game_segment_list) + + to_play, action_mask = self._preprocess_to_play_and_action_mask( + game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list + ) + + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)] + + # ============================================================== + # EfficientZero related core code + # ============================================================== + batch_target_values, batch_value_prefixs = [], [] + with torch.no_grad(): + #value_obs_list = prepare_observation(value_obs_list, self._cfg.model.model_type) + # split a full batch into slices of mini_infer_size: to save the GPU memory for more GPU actors + slices = int(np.ceil(transition_batch_size / self._cfg.mini_infer_size)) + network_output = [] + for i in range(slices): + beg_index = self._cfg.mini_infer_size * i + end_index = self._cfg.mini_infer_size * (i + 1) + + #m_obs = torch.from_numpy(value_obs_list[beg_index:end_index]).to(self._cfg.device).float() + m_obs = value_obs_list[beg_index:end_index] + m_obs = to_tensor(m_obs) + m_obs = sum(m_obs, []) + m_obs = to_device(m_obs, self._cfg.device) + + # calculate the target value + m_output = model.initial_inference(m_obs) + if not model.training: + # ============================================================== + # EfficientZero related core code + # ============================================================== + # if not in training, obtain the scalars of the value/reward + [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( + [ + m_output.latent_state, + inverse_scalar_transform(m_output.value, self._cfg.model.support_scale), + m_output.policy_logits + ] + ) + m_output.reward_hidden_state = ( + m_output.reward_hidden_state[0].detach().cpu().numpy(), + m_output.reward_hidden_state[1].detach().cpu().numpy() + ) + network_output.append(m_output) + + # concat the output slices after model inference + if self._cfg.use_root_value: + # use the root values from MCTS, as in EfficiientZero + # the root values have limited improvement but require much more GPU actors; + _, value_prefix_pool, policy_logits_pool, latent_state_roots, reward_hidden_state_roots = concat_output( + network_output, data_type='efficientzero' + ) + value_prefix_pool = value_prefix_pool.squeeze().tolist() + policy_logits_pool = policy_logits_pool.tolist() + noises = [ + np.random.dirichlet([self._cfg.root_dirichlet_alpha] * self._cfg.model.action_space_size + ).astype(np.float32).tolist() for _ in range(transition_batch_size) + ] + if self._cfg.mcts_ctree: + # cpp mcts_tree + roots = MCTSCtree.roots(transition_batch_size, legal_actions) + roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, to_play) + # do MCTS for a new policy with the recent target model + MCTSCtree(self._cfg).search(roots, model, latent_state_roots, 
reward_hidden_state_roots, to_play) + else: + # python mcts_tree + roots = MCTSPtree.roots(transition_batch_size, legal_actions) + roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, to_play) + # do MCTS for a new policy with the recent target model + MCTSPtree(self._cfg).search( + roots, model, latent_state_roots, reward_hidden_state_roots, to_play=to_play + ) + roots_values = roots.get_values() + value_list = np.array(roots_values) + else: + # use the predicted values + value_list = concat_output_value(network_output) + + # get last state value + if self._cfg.env_type == 'board_games' and to_play_segment[0][0] in [1, 2]: + # TODO(pu): for board_games, very important, to check + value_list = value_list.reshape(-1) * np.array( + [ + self._cfg.discount_factor ** td_steps_list[i] if int(td_steps_list[i]) % + 2 == 0 else -self._cfg.discount_factor ** td_steps_list[i] + for i in range(transition_batch_size) + ] + ) + else: + value_list = value_list.reshape(-1) * ( + np.array([self._cfg.discount_factor for _ in range(transition_batch_size)]) ** td_steps_list + ) + + value_list = value_list * np.array(value_mask) + value_list = value_list.tolist() + horizon_id, value_index = 0, 0 + for game_segment_len_non_re, reward_list, state_index, to_play_list in zip(game_segment_lens, rewards_list, + pos_in_game_segment_list, + to_play_segment): + target_values = [] + target_value_prefixs = [] + value_prefix = 0.0 + base_index = state_index + for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1): + bootstrap_index = current_index + td_steps_list[value_index] + for i, reward in enumerate(reward_list[current_index:bootstrap_index]): + if self._cfg.env_type == 'board_games' and to_play_segment[0][0] in [1, 2]: + # TODO(pu): for board_games, very important, to check + if to_play_list[base_index] == to_play_list[i]: + value_list[value_index] += reward * self._cfg.discount_factor ** i + else: + value_list[value_index] += -reward * self._cfg.discount_factor ** i + else: + value_list[value_index] += reward * self._cfg.discount_factor ** i + + # reset every lstm_horizon_len + if horizon_id % self._cfg.lstm_horizon_len == 0: + value_prefix = 0.0 + base_index = current_index + horizon_id += 1 + + if current_index < game_segment_len_non_re: + target_values.append(value_list[value_index]) + # TODO: Since the horizon is small and the discount_factor is close to 1. 
+ # Compute the reward sum to approximate the value prefix for simplification + value_prefix += reward_list[current_index + ] # * self._cfg.discount_factor ** (current_index - base_index) + target_value_prefixs.append(value_prefix) + else: + target_values.append(0) + target_value_prefixs.append(value_prefix) + value_index += 1 + batch_value_prefixs.append(target_value_prefixs) + batch_target_values.append(target_values) + batch_value_prefixs = np.asarray(batch_value_prefixs, dtype=object) + batch_target_values = np.asarray(batch_target_values, dtype=object) + + return batch_value_prefixs, batch_target_values + + def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: Any) -> np.ndarray: + """ + Overview: + prepare policy targets from the reanalyzed context of policies + Arguments: + - policy_re_context (:obj:`List`): List of policy context to reanalyzed + Returns: + - batch_target_policies_re + """ + if policy_re_context is None: + return [] + batch_target_policies_re = [] + + policy_obs_list, policy_mask, pos_in_game_segment_list, batch_index_list, child_visits, game_segment_lens, action_mask_segment, \ + to_play_segment = policy_re_context # noqa + # transition_batch_size = game_segment_batch_size * (self._cfg.num_unroll_steps + 1) + transition_batch_size = len(policy_obs_list) + game_segment_batch_size = len(pos_in_game_segment_list) + to_play, action_mask = self._preprocess_to_play_and_action_mask( + game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list + ) + + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)] + with torch.no_grad(): + policy_obs_list = prepare_observation(policy_obs_list, self._cfg.model.model_type) + # split a full batch into slices of mini_infer_size: to save the GPU memory for more GPU actors + slices = int(np.ceil(transition_batch_size / self._cfg.mini_infer_size)) + network_output = [] + for i in range(slices): + beg_index = self._cfg.mini_infer_size * i + end_index = self._cfg.mini_infer_size * (i + 1) + m_obs = torch.from_numpy(policy_obs_list[beg_index:end_index]).to(self._cfg.device).float() + + m_output = model.initial_inference(m_obs) + + if not model.training: + # if not in training, obtain the scalars of the value/reward + [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( + [ + m_output.latent_state, + inverse_scalar_transform(m_output.value, self._cfg.model.support_scale), + m_output.policy_logits + ] + ) + m_output.reward_hidden_state = ( + m_output.reward_hidden_state[0].detach().cpu().numpy(), + m_output.reward_hidden_state[1].detach().cpu().numpy() + ) + + network_output.append(m_output) + + _, value_prefix_pool, policy_logits_pool, latent_state_roots, reward_hidden_state_roots = concat_output( + network_output, data_type='efficientzero' + ) + value_prefix_pool = value_prefix_pool.squeeze().tolist() + policy_logits_pool = policy_logits_pool.tolist() + noises = [ + np.random.dirichlet([self._cfg.root_dirichlet_alpha] * self._cfg.model.action_space_size + ).astype(np.float32).tolist() for _ in range(transition_batch_size) + ] + if self._cfg.mcts_ctree: + # cpp mcts_tree + roots = MCTSCtree.roots(transition_batch_size, legal_actions) + roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, to_play) + # do MCTS for a new policy with the recent target model + MCTSCtree(self._cfg).search(roots, model, latent_state_roots, reward_hidden_state_roots, to_play) + else: + # 
python mcts_tree + roots = MCTSPtree.roots(transition_batch_size, legal_actions) + roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, to_play) + # do MCTS for a new policy with the recent target model + MCTSPtree(self._cfg).search( + roots, model, latent_state_roots, reward_hidden_state_roots, to_play=to_play + ) + + roots_legal_actions_list = legal_actions + roots_distributions = roots.get_distributions() + policy_index = 0 + for state_index, game_index in zip(pos_in_game_segment_list, batch_index_list): + target_policies = [] + for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1): + distributions = roots_distributions[policy_index] + if policy_mask[policy_index] == 0: + # NOTE: the invalid padding target policy, O is to make sure the correspoding cross_entropy_loss=0 + target_policies.append([0 for _ in range(self._cfg.model.action_space_size)]) + else: + if distributions is None: + # if at some obs, the legal_action is None, add the fake target_policy + target_policies.append( + list(np.ones(self._cfg.model.action_space_size) / self._cfg.model.action_space_size) + ) + else: + if self._cfg.mcts_ctree: + # cpp mcts_tree + if self._cfg.env_type == 'not_board_games': + sum_visits = sum(distributions) + policy = [visit_count / sum_visits for visit_count in distributions] + target_policies.append(policy) + else: + # for two_player board games + policy_tmp = [0 for _ in range(self._cfg.model.action_space_size)] + # to make sure target_policies have the same dimension + sum_visits = sum(distributions) + policy = [visit_count / sum_visits for visit_count in distributions] + for index, legal_action in enumerate(roots_legal_actions_list[policy_index]): + policy_tmp[legal_action] = policy[index] + target_policies.append(policy_tmp) + else: + # python mcts_tree + if self._cfg.env_type == 'not_board_games': + sum_visits = sum(distributions) + policy = [visit_count / sum_visits for visit_count in distributions] + target_policies.append(policy) + else: + # for two_player board games + policy_tmp = [0 for _ in range(self._cfg.model.action_space_size)] + # to make sure target_policies have the same dimension + sum_visits = sum(distributions) + policy = [visit_count / sum_visits for visit_count in distributions] + for index, legal_action in enumerate(roots_legal_actions_list[policy_index]): + policy_tmp[legal_action] = policy[index] + target_policies.append(policy_tmp) + policy_index += 1 + batch_target_policies_re.append(target_policies) + batch_target_policies_re = np.array(batch_target_policies_re) + return batch_target_policies_re diff --git a/lzero/mcts/buffer/gobigger_game_buffer_muzero.py b/lzero/mcts/buffer/gobigger_game_buffer_muzero.py index 3e30d0656..5df565e74 100644 --- a/lzero/mcts/buffer/gobigger_game_buffer_muzero.py +++ b/lzero/mcts/buffer/gobigger_game_buffer_muzero.py @@ -9,11 +9,10 @@ from lzero.mcts.utils import prepare_observation from lzero.policy import to_detach_cpu_numpy, concat_output, concat_output_value, inverse_scalar_transform from .game_buffer import GameBuffer -from ding.torch_utils import to_device +from ding.torch_utils import to_device, to_tensor if TYPE_CHECKING: from lzero.policy import MuZeroPolicy, EfficientZeroPolicy, SampledEfficientZeroPolicy -from ding.torch_utils import to_tensor, squeeze @BUFFER_REGISTRY.register('gobigger_game_buffer_muzero') class GoBiggerMuZeroGameBuffer(GameBuffer): diff --git a/lzero/model/gobigger/gobigger_efficientzero_model_mlp.py 
b/lzero/model/gobigger/gobigger_efficientzero_model.py similarity index 99% rename from lzero/model/gobigger/gobigger_efficientzero_model_mlp.py rename to lzero/model/gobigger/gobigger_efficientzero_model.py index ee8dbfb40..0ab0c5c49 100644 --- a/lzero/model/gobigger/gobigger_efficientzero_model_mlp.py +++ b/lzero/model/gobigger/gobigger_efficientzero_model.py @@ -8,13 +8,13 @@ from ..common import EZNetworkOutput, RepresentationNetworkMLP, PredictionNetworkMLP from ..utils import renormalize, get_params_mean, get_dynamic_mean, get_reward_mean -from .gobigger_model import Encoder +from .gobigger_encoder import Encoder import yaml from easydict import EasyDict from ding.utils.data import default_collate -@MODEL_REGISTRY.register('EfficientZeroModelMLP') -class GoBiggerEfficientZeroModelMLP(nn.Module): +@MODEL_REGISTRY.register('GoBiggerEfficientZeroModel') +class GoBiggerEfficientZeroModel(nn.Module): def __init__( self, @@ -74,7 +74,7 @@ def __init__( - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. - res_connection_in_dynamics (:obj:`bool`): Whether to use residual connection for dynamics network, default set it to False. """ - super(GoBiggerEfficientZeroModelMLP, self).__init__() + super(GoBiggerEfficientZeroModel, self).__init__() if not categorical_distribution: self.reward_support_size = 1 self.value_support_size = 1 diff --git a/lzero/model/gobigger/gobigger_model.py b/lzero/model/gobigger/gobigger_encoder.py similarity index 100% rename from lzero/model/gobigger/gobigger_model.py rename to lzero/model/gobigger/gobigger_encoder.py diff --git a/lzero/model/gobigger/gobigger_muzero_model.py b/lzero/model/gobigger/gobigger_muzero_model.py new file mode 100644 index 000000000..709d23c53 --- /dev/null +++ b/lzero/model/gobigger/gobigger_muzero_model.py @@ -0,0 +1,455 @@ +from typing import Optional, Tuple + +import torch +import torch.nn as nn +from ding.torch_utils import MLP +from ding.utils import MODEL_REGISTRY, SequenceType + +from ..common import MZNetworkOutput, RepresentationNetworkMLP, PredictionNetworkMLP +from ..utils import renormalize, get_params_mean, get_dynamic_mean, get_reward_mean +from .gobigger_encoder import Encoder +import yaml +from easydict import EasyDict +from ding.utils.data import default_collate + + +@MODEL_REGISTRY.register('GoBiggerMuZeroModel') +class GoBiggerMuZeroModel(nn.Module): + + def __init__( + self, + observation_shape: int = 2, + action_space_size: int = 6, + latent_state_dim: int = 256, + fc_reward_layers: SequenceType = [32], + fc_value_layers: SequenceType = [32], + fc_policy_layers: SequenceType = [32], + reward_support_size: int = 601, + value_support_size: int = 601, + proj_hid: int = 1024, + proj_out: int = 1024, + pred_hid: int = 512, + pred_out: int = 1024, + self_supervised_learning_loss: bool = False, + categorical_distribution: bool = True, + activation: Optional[nn.Module] = nn.ReLU(inplace=True), + last_linear_layer_init_zero: bool = True, + state_norm: bool = False, + discrete_action_encoding_type: str = 'one_hot', + norm_type: Optional[str] = 'BN', + res_connection_in_dynamics: bool = False, + *args, + **kwargs + ): + """ + Overview: + The definition of the network model of MuZero, which is a generalization version for 1D vector obs. + The networks are mainly built on fully connected layers. + The representation network is an MLP network which maps the raw observation to a latent state. 
+ The dynamics network is an MLP network which predicts the next latent state, and reward given the current latent state and action. + The prediction network is an MLP network which predicts the value and policy given the current latent state. + Arguments: + - observation_shape (:obj:`int`): Observation space shape, e.g. 8 for Lunarlander. + - action_space_size: (:obj:`int`): Action space size, e.g. 4 for Lunarlander. + - latent_state_dim (:obj:`int`): The dimension of latent state, such as 256. + - fc_reward_layers (:obj:`SequenceType`): The number of hidden layers of the reward head (MLP head). + - fc_value_layers (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head). + - fc_policy_layers (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head). + - reward_support_size (:obj:`int`): The size of categorical reward output + - value_support_size (:obj:`int`): The size of categorical value output. + - proj_hid (:obj:`int`): The size of projection hidden layer. + - proj_out (:obj:`int`): The size of projection output layer. + - pred_hid (:obj:`int`): The size of prediction hidden layer. + - pred_out (:obj:`int`): The size of prediction output layer. + - self_supervised_learning_loss (:obj:`bool`): Whether to use self_supervised_learning related networks in MuZero model, default set it to False. + - categorical_distribution (:obj:`bool`): Whether to use discrete support to represent categorical distribution for value, reward/value_prefix. + - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ + operation to speedup, e.g. ReLU(inplace=True). + - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of value/policy mlp, default sets it to True. + - state_norm (:obj:`bool`): Whether to use normalization for latent states, default sets it to True. + - discrete_action_encoding_type (:obj:`str`): The encoding type of discrete action, which can be 'one_hot' or 'not_one_hot'. + - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. + - res_connection_in_dynamics (:obj:`bool`): Whether to use residual connection for dynamics network, default set it to False. + """ + super(GoBiggerMuZeroModel, self).__init__() + self.categorical_distribution = categorical_distribution + if not self.categorical_distribution: + self.reward_support_size = 1 + self.value_support_size = 1 + else: + self.reward_support_size = reward_support_size + self.value_support_size = value_support_size + + self.action_space_size = action_space_size + self.continuous_action_space = False + # The dim of action space. For discrete action space, it is 1. + # For continuous action space, it is the dimension of continuous action. 
+ self.action_space_dim = action_space_size if self.continuous_action_space else 1 + assert discrete_action_encoding_type in ['one_hot', 'not_one_hot'], discrete_action_encoding_type + self.discrete_action_encoding_type = discrete_action_encoding_type + if self.continuous_action_space: + self.action_encoding_dim = action_space_size + else: + if self.discrete_action_encoding_type == 'one_hot': + self.action_encoding_dim = action_space_size + elif self.discrete_action_encoding_type == 'not_one_hot': + self.action_encoding_dim = 1 + + self.latent_state_dim = latent_state_dim + self.proj_hid = proj_hid + self.proj_out = proj_out + self.pred_hid = pred_hid + self.pred_out = pred_out + self.self_supervised_learning_loss = self_supervised_learning_loss + self.last_linear_layer_init_zero = last_linear_layer_init_zero + self.state_norm = state_norm + self.res_connection_in_dynamics = res_connection_in_dynamics + + # self.representation_network = RepresentationNetworkMLP( + # observation_shape=observation_shape, hidden_channels=self.latent_state_dim, norm_type=norm_type + # ) + with open('lzero/model/gobigger/default_model_config.yaml', "r") as f: + encoder_cfg = yaml.safe_load(f) + encoder_cfg = EasyDict(encoder_cfg) + self.representation_network = Encoder(encoder_cfg) + + self.dynamics_network = DynamicsNetwork( + action_encoding_dim=self.action_encoding_dim, + num_channels=self.latent_state_dim + self.action_encoding_dim, + common_layer_num=2, + fc_reward_layers=fc_reward_layers, + output_support_size=self.reward_support_size, + last_linear_layer_init_zero=self.last_linear_layer_init_zero, + norm_type=norm_type, + res_connection_in_dynamics=self.res_connection_in_dynamics, + ) + + self.prediction_network = PredictionNetworkMLP( + action_space_size=action_space_size, + num_channels=latent_state_dim, + fc_value_layers=fc_value_layers, + fc_policy_layers=fc_policy_layers, + output_support_size=self.value_support_size, + last_linear_layer_init_zero=self.last_linear_layer_init_zero, + norm_type=norm_type + ) + + if self.self_supervised_learning_loss: + # self_supervised_learning_loss related network proposed in EfficientZero + self.projection_input_dim = latent_state_dim + + self.projection = nn.Sequential( + nn.Linear(self.projection_input_dim, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, + nn.Linear(self.proj_hid, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, + nn.Linear(self.proj_hid, self.proj_out), nn.BatchNorm1d(self.proj_out) + ) + self.prediction_head = nn.Sequential( + nn.Linear(self.proj_out, self.pred_hid), + nn.BatchNorm1d(self.pred_hid), + activation, + nn.Linear(self.pred_hid, self.pred_out), + ) + + def initial_inference(self, obs: torch.Tensor) -> MZNetworkOutput: + """ + Overview: + Initial inference of MuZero model, which is the first step of the MuZero model. + To perform the initial inference, we first use the representation network to obtain the "latent_state" of the observation. + Then we use the prediction network to predict the "value" and "policy_logits" of the "latent_state", and + also prepare the zeros-like ``reward`` for the next step of the MuZero model. + Arguments: + - obs (:obj:`torch.Tensor`): The 1D vector observation data. + Returns (MZNetworkOutput): + - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. + - value_prefix (:obj:`torch.Tensor`): The predicted prefix sum of value for input state. \ + In initial inference, we set it to zero vector. 
+ - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The hidden state of LSTM about reward. In initial inference, \ + we set it to the zeros-like hidden state (H and C). + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, obs_shape)`, where B is batch_size. + - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. + - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. + - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. + - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + """ + batch_size = len(obs) + obs = default_collate(obs) + latent_state = self._representation(obs) + policy_logits, value = self._prediction(latent_state) + return MZNetworkOutput( + value, + [0. for _ in range(batch_size)], + policy_logits, + latent_state, + ) + + def recurrent_inference(self, latent_state: torch.Tensor, action: torch.Tensor) -> MZNetworkOutput: + """ + Overview: + Recurrent inference of MuZero model, which is the rollout step of the MuZero model. + To perform the recurrent inference, we first use the dynamics network to predict ``next_latent_state``, + ``reward`` by the given current ``latent_state`` and ``action``. + We then use the prediction network to predict the ``value`` and ``policy_logits``. + Arguments: + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input obs. + - action (:obj:`torch.Tensor`): The predicted action to rollout. + Returns (MZNetworkOutput): + - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. + - reward (:obj:`torch.Tensor`): The predicted reward for input state. + - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. + - next_latent_state (:obj:`torch.Tensor`): The predicted next latent state. + Shapes: + - action (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch_size. + - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. + - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. + - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. + - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - next_latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + """ + next_latent_state, reward = self._dynamics(latent_state, action) + policy_logits, value = self._prediction(next_latent_state) + return MZNetworkOutput(value, reward, policy_logits, next_latent_state) + + def _representation(self, observation: torch.Tensor) -> Tuple[torch.Tensor]: + """ + Overview: + Use the representation network to encode the observations into latent state. + Arguments: + - obs (:obj:`torch.Tensor`): The 1D vector observation data. + Returns: + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, obs_shape)`, where B is batch_size. + - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. 
+ """ + latent_state = self.representation_network(observation) + if self.state_norm: + latent_state = renormalize(latent_state) + return latent_state + + def _prediction(self, latent_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Overview: + Use the representation network to encode the observations into latent state. + Arguments: + - obs (:obj:`torch.Tensor`): The 1D vector observation data. + Returns: + - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. + - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. + Shapes: + - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. + - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. + """ + policy_logits, value = self.prediction_network(latent_state) + return policy_logits, value + + def _dynamics(self, latent_state: torch.Tensor, action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Overview: + Concatenate ``latent_state`` and ``action`` and use the dynamics network to predict ``next_latent_state`` + ``reward`` and ``next_reward_hidden_state``. + Arguments: + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The input hidden state of LSTM about reward. + - action (:obj:`torch.Tensor`): The predicted action to rollout. + Returns: + - next_latent_state (:obj:`torch.Tensor`): The predicted latent state of the next timestep. + - next_reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The output hidden state of LSTM about reward. + - reward (:obj:`torch.Tensor`): The predicted reward for input state. + Shapes: + - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - action (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch_size. + - next_latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. + """ + # NOTE: the discrete action encoding type is important for some environments + + # discrete action space + if self.discrete_action_encoding_type == 'one_hot': + # Stack latent_state with the one hot encoded action + if len(action.shape) == 1: + # (batch_size, ) -> (batch_size, 1) + # e.g., torch.Size([8]) -> torch.Size([8, 1]) + action = action.unsqueeze(-1) + + # transform action to one-hot encoding. + # action_one_hot shape: (batch_size, action_space_size), e.g., (8, 4) + action_one_hot = torch.zeros(action.shape[0], self.action_space_size, device=action.device) + # transform action to torch.int64 + action = action.long() + action_one_hot.scatter_(1, action, 1) + action_encoding = action_one_hot + elif self.discrete_action_encoding_type == 'not_one_hot': + action_encoding = action / self.action_space_size + if len(action_encoding.shape) == 1: + # (batch_size, ) -> (batch_size, 1) + # e.g., torch.Size([8]) -> torch.Size([8, 1]) + action_encoding = action_encoding.unsqueeze(-1) + + action_encoding = action_encoding.to(latent_state.device).float() + # state_action_encoding shape: (batch_size, latent_state[1] + action_dim]) or + # (batch_size, latent_state[1] + action_space_size]) depending on the discrete_action_encoding_type. 
+ state_action_encoding = torch.cat((latent_state, action_encoding), dim=1) + + next_latent_state, reward = self.dynamics_network(state_action_encoding) + + if not self.state_norm: + return next_latent_state, reward + else: + next_latent_state_normalized = renormalize(next_latent_state) + return next_latent_state_normalized, reward + + def project(self, latent_state: torch.Tensor, with_grad=True) -> torch.Tensor: + """ + Overview: + Project the latent state to a lower dimension to calculate the self-supervised loss, which is proposed in EfficientZero. + For more details, please refer to the paper ``Exploring Simple Siamese Representation Learning``. + Arguments: + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + - with_grad (:obj:`bool`): Whether to calculate gradient for the projection result. + Returns: + - proj (:obj:`torch.Tensor`): The result embedding vector of projection operation. + Shapes: + - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - proj (:obj:`torch.Tensor`): :math:`(B, projection_output_dim)`, where B is batch_size. + + Examples: + >>> latent_state = torch.randn(256, 64) + >>> output = self.project(latent_state) + >>> output.shape # (256, 1024) + """ + proj = self.projection(latent_state) + + if with_grad: + # with grad, use prediction_head + return self.prediction_head(proj) + else: + return proj.detach() + + def get_params_mean(self) -> float: + return get_params_mean(self) + + +class DynamicsNetwork(nn.Module): + + def __init__( + self, + action_encoding_dim: int = 2, + num_channels: int = 64, + common_layer_num: int = 2, + fc_reward_layers: SequenceType = [32], + output_support_size: int = 601, + last_linear_layer_init_zero: bool = True, + activation: Optional[nn.Module] = nn.ReLU(inplace=True), + norm_type: Optional[str] = 'BN', + res_connection_in_dynamics: bool = False, + ): + """ + Overview: + The definition of dynamics network in MuZero algorithm, which is used to predict next latent state + reward by the given current latent state and action. + The networks are mainly built on fully connected layers. + Arguments: + - action_encoding_dim (:obj:`int`): The dimension of action encoding. + - num_channels (:obj:`int`): The num of channels in latent states. + - common_layer_num (:obj:`int`): The number of common layers in dynamics network. + - fc_reward_layers (:obj:`SequenceType`): The number of hidden layers of the reward head (MLP head). + - output_support_size (:obj:`int`): The size of categorical reward output. + - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of value/policy mlp, default sets it to True. + - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ + operation to speedup, e.g. ReLU(inplace=True). + - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. + - res_connection_in_dynamics (:obj:`bool`): Whether to use residual connection in dynamics network. 
+ """ + super().__init__() + self.num_channels = num_channels + self.action_encoding_dim = action_encoding_dim + self.latent_state_dim = self.num_channels - self.action_encoding_dim + + self.res_connection_in_dynamics = res_connection_in_dynamics + if self.res_connection_in_dynamics: + self.fc_dynamics_1 = MLP( + in_channels=self.num_channels, + hidden_channels=self.latent_state_dim, + layer_num=common_layer_num, + out_channels=self.latent_state_dim, + activation=activation, + norm_type=norm_type, + output_activation=True, + output_norm=True, + # last_linear_layer_init_zero=False is important for convergence + last_linear_layer_init_zero=False, + ) + self.fc_dynamics_2 = MLP( + in_channels=self.latent_state_dim, + hidden_channels=self.latent_state_dim, + layer_num=common_layer_num, + out_channels=self.latent_state_dim, + activation=activation, + norm_type=norm_type, + output_activation=True, + output_norm=True, + # last_linear_layer_init_zero=False is important for convergence + last_linear_layer_init_zero=False, + ) + else: + self.fc_dynamics = MLP( + in_channels=self.num_channels, + hidden_channels=self.latent_state_dim, + layer_num=common_layer_num, + out_channels=self.latent_state_dim, + activation=activation, + norm_type=norm_type, + output_activation=True, + output_norm=True, + # last_linear_layer_init_zero=False is important for convergence + last_linear_layer_init_zero=False, + ) + + self.fc_reward_head = MLP( + in_channels=self.latent_state_dim, + hidden_channels=fc_reward_layers[0], + layer_num=2, + out_channels=output_support_size, + activation=activation, + norm_type=norm_type, + output_activation=False, + output_norm=False, + last_linear_layer_init_zero=last_linear_layer_init_zero + ) + + def forward(self, state_action_encoding: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Overview: + Forward computation of the dynamics network. Predict the next latent state given current latent state and action. + Arguments: + - state_action_encoding (:obj:`torch.Tensor`): The state-action encoding, which is the concatenation of \ + latent state and action encoding, with shape (batch_size, num_channels, height, width). + Returns: + - next_latent_state (:obj:`torch.Tensor`): The next latent state, with shape (batch_size, latent_state_dim). + - reward (:obj:`torch.Tensor`): The predicted reward for input state. + """ + if self.res_connection_in_dynamics: + # take the state encoding (e.g. 
latent_state), + # state_action_encoding[:, -self.action_encoding_dim:] is action encoding + latent_state = state_action_encoding[:, :-self.action_encoding_dim] + x = self.fc_dynamics_1(state_action_encoding) + # the residual link: add the latent_state to the state_action encoding + next_latent_state = x + latent_state + next_latent_state_encoding = self.fc_dynamics_2(next_latent_state) + else: + next_latent_state = self.fc_dynamics(state_action_encoding) + next_latent_state_encoding = next_latent_state + + reward = self.fc_reward_head(next_latent_state_encoding) + + return next_latent_state, reward + + def get_dynamic_mean(self) -> float: + return get_dynamic_mean(self) + + def get_reward_mean(self) -> float: + return get_reward_mean(self) diff --git a/lzero/policy/gobigger_efficientzero.py b/lzero/policy/gobigger_efficientzero.py index 846300a95..fda441635 100644 --- a/lzero/policy/gobigger_efficientzero.py +++ b/lzero/policy/gobigger_efficientzero.py @@ -6,7 +6,7 @@ import torch.optim as optim from ding.model import model_wrap from ding.policy.base_policy import Policy -from ding.torch_utils import to_tensor, squeeze +from ding.torch_utils import to_tensor from ding.utils import POLICY_REGISTRY from torch.distributions import Categorical from torch.nn import L1Loss @@ -175,7 +175,7 @@ def default_model(self) -> Tuple[str, List[str]]: The user can define and use customized network model but must obey the same interface definition indicated \ by import_names path. For EfficientZero, ``lzero.model.efficientzero_model.EfficientZeroModel`` """ - return 'EfficientZeroModelMLP', ['lzero.model.gobigger.gobigger_efficientzero_model_mlp'] + return 'GoBiggerEfficientZeroModel', ['lzero.model.gobigger.gobigger_efficientzero_model'] def _init_learn(self) -> None: """ @@ -614,25 +614,7 @@ def _forward_collect( output[i//agent_num]['value'].append(value) output[i//agent_num]['pred_value'].append(pred_values[i]) output[i//agent_num]['policy_logits'].append(policy_logits[i]) - - - # for i, env_id in enumerate(ready_env_id): - # distributions, value = roots_visit_count_distributions[i], roots_values[i] - # # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents - # # the index within the legal action set, rather than the index in the entire action set. - # action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( - # distributions, temperature=self.collect_mcts_temperature, deterministic=False - # ) - # # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the entire action set. 
- # action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] - # output[env_id] = { - # 'action': action, - # 'distributions': distributions, - # 'visit_count_distribution_entropy': visit_count_distribution_entropy, - # 'value': value, - # 'pred_value': pred_values[i], - # 'policy_logits': policy_logits[i], - # } + return output diff --git a/lzero/policy/gobigger_muzero.py b/lzero/policy/gobigger_muzero.py new file mode 100644 index 000000000..027027a2e --- /dev/null +++ b/lzero/policy/gobigger_muzero.py @@ -0,0 +1,724 @@ +import copy +from typing import List, Dict, Any, Tuple, Union + +import numpy as np +import torch +import torch.optim as optim +from ding.model import model_wrap +from ding.policy.base_policy import Policy +from ding.torch_utils import to_tensor +from ding.utils import POLICY_REGISTRY +from torch.nn import L1Loss + +from lzero.mcts import MuZeroMCTSCtree as MCTSCtree +from lzero.mcts import MuZeroMCTSPtree as MCTSPtree +from lzero.model import ImageTransforms +from lzero.policy import scalar_transform, InverseScalarTransform, cross_entropy_loss, phi_transform, \ + DiscreteSupport, to_torch_float_tensor, mz_network_output_unpack, select_action, negative_cosine_similarity, prepare_obs, \ + configure_optimizers +from collections import defaultdict +from ding.torch_utils import to_device + +@POLICY_REGISTRY.register('gobigger_muzero') +class GoBiggerMuZeroPolicy(Policy): + """ + Overview: + The policy class for MuZero. + """ + + # The default_config for MuZero policy. + config = dict( + model=dict( + # (str) The model type. For 1-dimensional vector obs, we use mlp model. For the image obs, we use conv model. + model_type='conv', # options={'mlp', 'conv'} + # (bool) If True, the action space of the environment is continuous, otherwise discrete. + continuous_action_space=False, + # (tuple) The stacked obs shape. + # observation_shape=(1, 96, 96), # if frame_stack_num=1 + observation_shape=(4, 96, 96), # if frame_stack_num=4 + # (bool) Whether to use the self-supervised learning loss. + self_supervised_learning_loss=False, + # (bool) Whether to use discrete support to represent categorical distribution for value/reward/value_prefix. + categorical_distribution=True, + # (int) The image channel in image observation. + image_channel=1, + # (int) The number of frames to stack together. + frame_stack_num=1, + # (int) The number of res blocks in MuZero model. + num_res_blocks=1, + # (int) The number of channels of hidden states in MuZero model. + num_channels=64, + # (int) The scale of supports used in categorical distribution. + # This variable is only effective when ``categorical_distribution=True``. + support_scale=300, + # (bool) whether to learn bias in the last linear layer in value and policy head. + bias=True, + # (str) The type of action encoding. Options are ['one_hot', 'not_one_hot']. Default to 'one_hot'. + discrete_action_encoding_type='one_hot', + # (bool) whether to use res connection in dynamics. + res_connection_in_dynamics=True, + # (str) The type of normalization in MuZero model. Options are ['BN', 'LN']. Default to 'LN'. + norm_type='BN', + ), + # ****** common ****** + # (bool) Whether to enable the sampled-based algorithm (e.g. Sampled EfficientZero) + # this variable is used in ``collector``. + sampled_algo=False, + # (bool) Whether to use C++ MCTS in policy. If False, use Python implementation. + mcts_ctree=True, + # (bool) Whether to use cuda for network. + cuda=True, + # (int) The number of environments used in collecting data. 
+ collector_env_num=8, + # (int) The number of environments used in evaluating policy. + evaluator_env_num=3, + # (str) The type of environment. Options is ['not_board_games', 'board_games']. + env_type='not_board_games', + # (str) The type of battle mode. Options is ['play_with_bot_mode', 'self_play_mode']. + battle_mode='play_with_bot_mode', + # (bool) Whether to monitor extra statistics in tensorboard. + monitor_extra_statistics=True, + # (int) The transition number of one ``GameSegment``. + game_segment_length=200, + + # ****** observation ****** + # (bool) Whether to transform image to string to save memory. + transform2string=False, + # (bool) Whether to use data augmentation. + use_augmentation=False, + # (list) The style of augmentation. + augmentation=['shift', 'intensity'], + + # ******* learn ****** + # (int) How many updates(iterations) to train after collector's one collection. + # Bigger "update_per_collect" means bigger off-policy. + # collect data -> update policy-> collect data -> ... + # For different env, we have different episode_length, + # we usually set update_per_collect = collector_env_num * episode_length / batch_size * reuse_factor + update_per_collect=100, + # (int) Minibatch size for one gradient descent. + batch_size=256, + # (str) Optimizer for training policy network. ['SGD', 'Adam', 'AdamW'] + optim_type='SGD', + # (float) Learning rate for training policy network. Initial lr for manually decay schedule. + learning_rate=0.2, + # (int) Frequency of target network update. + target_update_freq=100, + # (float) Weight decay for training policy network. + weight_decay=1e-4, + # (float) One-order Momentum in optimizer, which stabilizes the training process (gradient direction). + momentum=0.9, + # (float) The maximum constraint value of gradient norm clipping. + grad_clip_value=10, + # (int) The number of episodes in each collecting stage. + n_episode=8, + # (int) the number of simulations in MCTS. + num_simulations=50, + # (float) Discount factor (gamma) for returns. + discount_factor=0.997, + # (int) The number of steps for calculating target q_value. + td_steps=5, + # (int) The number of unroll steps in dynamics network. + num_unroll_steps=5, + # (float) The weight of reward loss. + reward_loss_weight=1, + # (float) The weight of value loss. + value_loss_weight=0.25, + # (float) The weight of policy loss. + policy_loss_weight=1, + # (float) The weight of ssl (self-supervised learning) loss. + ssl_loss_weight=0, + # (bool) Whether to use piecewise constant learning rate decay. + # i.e. lr: 0.2 -> 0.02 -> 0.002 + lr_piecewise_constant_decay=True, + # (int) The number of final training iterations to control lr decay, which is only used for manually decay. + threshold_training_steps_for_final_lr=int(5e4), + # (bool) Whether to use manually decayed temperature. + manual_temperature_decay=False, + # (int) The number of final training iterations to control temperature, which is only used for manually decay. + threshold_training_steps_for_final_temperature=int(1e5), + # (float) The fixed temperature value for MCTS action selection, which is used to control the exploration. + # The larger the value, the more exploration. This value is only used when manual_temperature_decay=False. + fixed_temperature_value=0.25, + + # ****** Priority ****** + # (bool) Whether to use priority when sampling training data from the buffer. + use_priority=True, + # (bool) Whether to use the maximum priority for new collecting data. 
+ use_max_priority_for_new_data=True,
+ # (float) The degree of prioritization to use. A value of 0 means no prioritization,
+ # while a value of 1 means full prioritization.
+ priority_prob_alpha=0.6,
+ # (float) The degree of correction to use. A value of 0 means no correction,
+ # while a value of 1 means full correction.
+ priority_prob_beta=0.4,
+
+ # ****** UCB ******
+ # (float) The alpha value used in the Dirichlet distribution for exploration at the root node of search tree.
+ root_dirichlet_alpha=0.3,
+ # (float) The noise weight at the root node of the search tree.
+ root_noise_weight=0.25,
+ )
+
+ def default_model(self) -> Tuple[str, List[str]]:
+ """
+ Overview:
+ Return this algorithm's default model setting.
+ Returns:
+ - model_info (:obj:`Tuple[str, List[str]]`): model name and model import_names.
+ - model_type (:obj:`str`): The model type used in this algorithm, which is registered in ModelRegistry.
+ - import_names (:obj:`List[str]`): The model class path list used in this algorithm.
+ .. note::
+ The user can define and use customized network model but must obey the same interface definition indicated \
+ by import_names path. For GoBigger MuZero, ``lzero.model.gobigger.gobigger_muzero_model.GoBiggerMuZeroModel``
+ """
+ return 'GoBiggerMuZeroModel', ['lzero.model.gobigger.gobigger_muzero_model']
+
+
+ def _init_learn(self) -> None:
+ """
+ Overview:
+ Learn mode init method. Called by ``self.__init__``. Initialize the learn model, optimizer and MCTS utils.
+ """
+ assert self._cfg.optim_type in ['SGD', 'Adam', 'AdamW'], self._cfg.optim_type
+ # NOTE: in board_games, for fixed lr 0.003, 'Adam' is better than 'SGD'.
+ if self._cfg.optim_type == 'SGD':
+ self._optimizer = optim.SGD(
+ self._model.parameters(),
+ lr=self._cfg.learning_rate,
+ momentum=self._cfg.momentum,
+ weight_decay=self._cfg.weight_decay,
+ )
+ elif self._cfg.optim_type == 'Adam':
+ self._optimizer = optim.Adam(
+ self._model.parameters(), lr=self._cfg.learning_rate, weight_decay=self._cfg.weight_decay
+ )
+ elif self._cfg.optim_type == 'AdamW':
+ self._optimizer = configure_optimizers(
+ model=self._model,
+ weight_decay=self._cfg.weight_decay,
+ learning_rate=self._cfg.learning_rate,
+ device_type=self._cfg.device
+ )
+
+ if self._cfg.lr_piecewise_constant_decay:
+ from torch.optim.lr_scheduler import LambdaLR
+ max_step = self._cfg.threshold_training_steps_for_final_lr
+ # NOTE: 1, 0.1 and 0.01 are the decay rates, not the learning rates.
+ lr_lambda = lambda step: 1 if step < max_step * 0.5 else (0.1 if step < max_step else 0.01) # noqa + self.lr_scheduler = LambdaLR(self._optimizer, lr_lambda=lr_lambda) + + # use model_wrapper for specialized demands of different modes + self._target_model = copy.deepcopy(self._model) + self._target_model = model_wrap( + self._target_model, + wrapper_name='target', + update_type='assign', + update_kwargs={'freq': self._cfg.target_update_freq} + ) + self._learn_model = self._model + + if self._cfg.use_augmentation: + self.image_transforms = ImageTransforms( + self._cfg.augmentation, + image_shape=(self._cfg.model.observation_shape[1], self._cfg.model.observation_shape[2]) + ) + self.value_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) + self.reward_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) + self.inverse_scalar_transform_handle = InverseScalarTransform( + self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution + ) + + def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]: + """ + Overview: + The forward function for learning policy in learn mode, which is the core of the learning process. + The data is sampled from replay buffer. + The loss is calculated by the loss function and the loss is backpropagated to update the model. + Arguments: + - data (:obj:`Tuple[torch.Tensor]`): The data sampled from replay buffer, which is a tuple of tensors. + The first tensor is the current_batch, the second tensor is the target_batch. + Returns: + - info_dict (:obj:`Dict[str, Union[float, int]]`): The information dict to be logged, which contains \ + current learning loss and learning statistics. + """ + self._learn_model.train() + self._target_model.train() + + current_batch, target_batch = data + obs_batch_ori, action_batch, mask_batch, indices, weights, make_time = current_batch + target_reward, target_value, target_policy = target_batch + + obs_batch_ori = obs_batch_ori.tolist() + obs_batch_ori = np.array(obs_batch_ori) + obs_batch = obs_batch_ori[:, 0:self._cfg.model.frame_stack_num] + if self._cfg.model.self_supervised_learning_loss: + obs_target_batch = obs_batch_ori[:, self._cfg.model.frame_stack_num:] + # obs_batch, obs_target_batch = prepare_obs(obs_batch_ori, self._cfg) + + # # do augmentations + # if self._cfg.use_augmentation: + # obs_batch = self.image_transforms.transform(obs_batch) + # if self._cfg.model.self_supervised_learning_loss: + # obs_target_batch = self.image_transforms.transform(obs_target_batch) + + # shape: (batch_size, num_unroll_steps, action_dim) + # NOTE: .long(), in discrete action space. + action_batch = torch.from_numpy(action_batch).to(self._cfg.device).unsqueeze(-1).long() + data_list = [ + mask_batch, + target_reward.astype('float64'), + target_value.astype('float64'), target_policy, weights + ] + [mask_batch, target_reward, target_value, target_policy, + weights] = to_torch_float_tensor(data_list, self._cfg.device) + + target_reward = target_reward.view(self._cfg.batch_size, -1) + target_value = target_value.view(self._cfg.batch_size, -1) + + assert obs_batch.size == self._cfg.batch_size == target_reward.size(0) + + # ``scalar_transform`` to transform the original value to the scaled value, + # i.e. h(.) function in paper https://arxiv.org/pdf/1805.11593.pdf. 
+ transformed_target_reward = scalar_transform(target_reward) + transformed_target_value = scalar_transform(target_value) + + # transform a scalar to its categorical_distribution. After this transformation, each scalar is + # represented as the linear combination of its two adjacent supports. + target_reward_categorical = phi_transform(self.reward_support, transformed_target_reward) + target_value_categorical = phi_transform(self.value_support, transformed_target_value) + + # ============================================================== + # the core initial_inference in MuZero policy. + # ============================================================== + obs_batch = obs_batch.tolist() + obs_batch = sum(obs_batch, []) + obs_batch = to_tensor(obs_batch) + obs_batch = to_device(obs_batch, self._cfg.device) + network_output = self._learn_model.initial_inference(obs_batch) + + # value_prefix shape: (batch_size, 10), the ``value_prefix`` at the first step is zero padding. + latent_state, reward, value, policy_logits = mz_network_output_unpack(network_output) + + # transform the scaled value or its categorical representation to its original value, + # i.e. h^(-1)(.) function in paper https://arxiv.org/pdf/1805.11593.pdf. + original_value = self.inverse_scalar_transform_handle(value) + + # Note: The following lines are just for debugging. + predicted_rewards = [] + if self._cfg.monitor_extra_statistics: + latent_state_list = latent_state.detach().cpu().numpy() + predicted_values, predicted_policies = original_value.detach().cpu(), torch.softmax( + policy_logits, dim=1 + ).detach().cpu() + + # calculate the new priorities for each transition. + value_priority = L1Loss(reduction='none')(original_value.squeeze(-1), target_value[:, 0]) + value_priority = value_priority.data.cpu().numpy() + 1e-6 + + # ============================================================== + # calculate policy and value loss for the first step. + # ============================================================== + policy_loss = cross_entropy_loss(policy_logits, target_policy[:, 0]) + value_loss = cross_entropy_loss(value, target_value_categorical[:, 0]) + + reward_loss = torch.zeros(self._cfg.batch_size, device=self._cfg.device) + consistency_loss = torch.zeros(self._cfg.batch_size, device=self._cfg.device) + + gradient_scale = 1 / self._cfg.num_unroll_steps + + # ============================================================== + # the core recurrent_inference in MuZero policy. + # ============================================================== + for step_i in range(self._cfg.num_unroll_steps): + # unroll with the dynamics function: predict the next ``latent_state``, ``reward``, + # given current ``latent_state`` and ``action``. + # And then predict policy_logits and value with the prediction function. + network_output = self._learn_model.recurrent_inference(latent_state, action_batch[:, step_i]) + latent_state, reward, value, policy_logits = mz_network_output_unpack(network_output) + + # transform the scaled value or its categorical representation to its original value, + # i.e. h^(-1)(.) function in paper https://arxiv.org/pdf/1805.11593.pdf. + original_value = self.inverse_scalar_transform_handle(value) + + if self._cfg.model.self_supervised_learning_loss: + # ============================================================== + # calculate consistency loss for the next ``num_unroll_steps`` unroll steps. 
+ # ============================================================== + if self._cfg.ssl_loss_weight > 0: + beg_index = step_i + end_index = step_i + self._cfg.model.frame_stack_num + obs_target_batch_tmp = obs_target_batch[:, beg_index:end_index].tolist() + obs_target_batch_tmp = sum(obs_target_batch_tmp, []) + obs_target_batch_tmp = to_tensor(obs_target_batch_tmp) + obs_target_batch_tmp = to_device(obs_target_batch_tmp, self._cfg.device) + network_output = self._learn_model.initial_inference(obs_target_batch_tmp) + + latent_state = to_tensor(latent_state) + representation_state = to_tensor(network_output.latent_state) + + # NOTE: no grad for the representation_state branch + dynamic_proj = self._learn_model.project(latent_state, with_grad=True) + observation_proj = self._learn_model.project(representation_state, with_grad=False) + temp_loss = negative_cosine_similarity(dynamic_proj, observation_proj) * mask_batch[:, step_i] + consistency_loss += temp_loss + + # NOTE: the target policy, target_value_categorical, target_reward_categorical is calculated in + # game buffer now. + # ============================================================== + # calculate policy loss for the next ``num_unroll_steps`` unroll steps. + # NOTE: the +=. + # ============================================================== + policy_loss += cross_entropy_loss(policy_logits, target_policy[:, step_i + 1]) + + value_loss += cross_entropy_loss(value, target_value_categorical[:, step_i + 1]) + reward_loss += cross_entropy_loss(reward, target_reward_categorical[:, step_i]) + + # Follow MuZero, set half gradient + # latent_state.register_hook(lambda grad: grad * 0.5) + + if self._cfg.monitor_extra_statistics: + original_rewards = self.inverse_scalar_transform_handle(reward) + original_rewards_cpu = original_rewards.detach().cpu() + + predicted_values = torch.cat( + (predicted_values, self.inverse_scalar_transform_handle(value).detach().cpu()) + ) + predicted_rewards.append(original_rewards_cpu) + predicted_policies = torch.cat((predicted_policies, torch.softmax(policy_logits, dim=1).detach().cpu())) + latent_state_list = np.concatenate((latent_state_list, latent_state.detach().cpu().numpy())) + + # ============================================================== + # the core learn model update step. + # ============================================================== + # weighted loss with masks (some invalid states which are out of trajectory.) + loss = ( + self._cfg.ssl_loss_weight * consistency_loss + self._cfg.policy_loss_weight * policy_loss + + self._cfg.value_loss_weight * value_loss + self._cfg.reward_loss_weight * reward_loss + ) + weighted_total_loss = (weights * loss).mean() + + gradient_scale = 1 / self._cfg.num_unroll_steps + weighted_total_loss.register_hook(lambda grad: grad * gradient_scale) + self._optimizer.zero_grad() + weighted_total_loss.backward() + total_grad_norm_before_clip = torch.nn.utils.clip_grad_norm_( + self._learn_model.parameters(), self._cfg.grad_clip_value + ) + self._optimizer.step() + if self._cfg.lr_piecewise_constant_decay is True: + self.lr_scheduler.step() + + # ============================================================== + # the core target model update step. 
+ # ============================================================== + self._target_model.update(self._learn_model.state_dict()) + + # packing loss info for tensorboard logging + loss_info = ( + weighted_total_loss.item(), loss.mean().item(), policy_loss.mean().item(), reward_loss.mean().item(), + value_loss.mean().item(), consistency_loss.mean() + ) + if self._cfg.monitor_extra_statistics: + predicted_rewards = torch.stack(predicted_rewards).transpose(1, 0).squeeze(-1) + predicted_rewards = predicted_rewards.reshape(-1).unsqueeze(-1) + + td_data = ( + value_priority, + target_reward.detach().cpu().numpy(), + target_value.detach().cpu().numpy(), + transformed_target_reward.detach().cpu().numpy(), + transformed_target_value.detach().cpu().numpy(), + target_reward_categorical.detach().cpu().numpy(), + target_value_categorical.detach().cpu().numpy(), + predicted_rewards.detach().cpu().numpy(), + predicted_values.detach().cpu().numpy(), + target_policy.detach().cpu().numpy(), + predicted_policies.detach().cpu().numpy(), + latent_state_list, + ) + + return { + 'collect_mcts_temperature': self.collect_mcts_temperature, + 'cur_lr': self._optimizer.param_groups[0]['lr'], + 'weighted_total_loss': loss_info[0], + 'total_loss': loss_info[1], + 'policy_loss': loss_info[2], + 'reward_loss': loss_info[3], + 'value_loss': loss_info[4], + 'consistency_loss': loss_info[5] / self._cfg.num_unroll_steps, + + # ============================================================== + # priority related + # ============================================================== + 'value_priority_orig': value_priority, + 'value_priority': td_data[0].flatten().mean().item(), + 'target_reward': td_data[1].flatten().mean().item(), + 'target_value': td_data[2].flatten().mean().item(), + 'transformed_target_reward': td_data[3].flatten().mean().item(), + 'transformed_target_value': td_data[4].flatten().mean().item(), + 'predicted_rewards': td_data[7].flatten().mean().item(), + 'predicted_values': td_data[8].flatten().mean().item(), + 'total_grad_norm_before_clip': total_grad_norm_before_clip + } + + def _init_collect(self) -> None: + """ + Overview: + Collect mode init method. Called by ``self.__init__``. Initialize the collect model and MCTS utils. + """ + self._collect_model = self._model + if self._cfg.mcts_ctree: + self._mcts_collect = MCTSCtree(self._cfg) + else: + self._mcts_collect = MCTSPtree(self._cfg) + self.collect_mcts_temperature = 1 + + def _forward_collect( + self, + data: torch.Tensor, + action_mask: list = None, + temperature: float = 1, + to_play: List = [-1], + ready_env_id=None + ) -> Dict: + """ + Overview: + The forward function for collecting data in collect mode. Use model to execute MCTS search. + Choosing the action through sampling during the collect mode. + Arguments: + - data (:obj:`torch.Tensor`): The input data, i.e. the observation. + - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. + - temperature (:obj:`float`): The temperature of the policy. + - to_play (:obj:`int`): The player to play. + - ready_env_id (:obj:`list`): The id of the env that is ready to collect. + Shape: + - data (:obj:`torch.Tensor`): + - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ + S is the number of stacked frames, H is the height of the image, W is the width of the image. + - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size. 
+ - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env. + - temperature: :math:`(1, )`. + - to_play: :math:`(N, 1)`, where N is the number of collect_env. + - ready_env_id: None + Returns: + - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ + ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. + """ + self._collect_model.eval() + self.collect_mcts_temperature = temperature + active_collect_env_num = len(data) + data = to_tensor(data) + data = sum(sum(data, []), []) + batch_size = len(data) + data = to_device(data, self._cfg.device) + agent_num = batch_size // active_collect_env_num + to_play = np.array(to_play).reshape(-1).tolist() + + with torch.no_grad(): + # data shape [B, S x C, W, H], e.g. {Tensor:(B, 12, 96, 96)} + network_output = self._collect_model.initial_inference(data) + latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) + + if not self._learn_model.training: + # if not in training, obtain the scalars of the value/reward + pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() + latent_state_roots = latent_state_roots.detach().cpu().numpy() + policy_logits = policy_logits.detach().cpu().numpy().tolist() + + action_mask = sum(action_mask, []) + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(batch_size)] + # the only difference between collect and eval is the dirichlet noise + noises = [ + np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j])) + ).astype(np.float32).tolist() for j in range(batch_size) + ] + if self._cfg.mcts_ctree: + # cpp mcts_tree + roots = MCTSCtree.roots(batch_size, legal_actions) + else: + # python mcts_tree + roots = MCTSPtree.roots(batch_size, legal_actions) + + roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits, to_play) + self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play) + + roots_visit_count_distributions = roots.get_distributions() # shape: ``{list: batch_size} ->{list: action_space_size}`` + roots_values = roots.get_values() # shape: {list: batch_size} + + data_id = [i for i in range(active_collect_env_num)] + output = {i: defaultdict(list) for i in data_id} + + if ready_env_id is None: + ready_env_id = np.arange(active_collect_env_num) + + for i in range(batch_size): + distributions, value = roots_visit_count_distributions[i], roots_values[i] + # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents + # the index within the legal action set, rather than the index in the entire action set. + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=self.collect_mcts_temperature, deterministic=False + ) + # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the + # entire action set. 
+ action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + output[i//agent_num]['action'].append(action) + output[i//agent_num]['distributions'].append(distributions) + output[i//agent_num]['visit_count_distribution_entropy'].append(visit_count_distribution_entropy) + output[i//agent_num]['value'].append(value) + output[i//agent_num]['pred_value'].append(pred_values[i]) + output[i//agent_num]['policy_logits'].append(policy_logits[i]) + + return output + + def _init_eval(self) -> None: + """ + Overview: + Evaluate mode init method. Called by ``self.__init__``. Initialize the eval model and MCTS utils. + """ + self._eval_model = self._model + if self._cfg.mcts_ctree: + self._mcts_eval = MCTSCtree(self._cfg) + else: + self._mcts_eval = MCTSPtree(self._cfg) + + def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, ready_env_id=None) -> Dict: + """ + Overview: + The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search. + Choosing the action with the highest value (argmax) rather than sampling during the eval mode. + Arguments: + - data (:obj:`torch.Tensor`): The input data, i.e. the observation. + - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. + - to_play (:obj:`int`): The player to play. + - ready_env_id (:obj:`list`): The id of the env that is ready to collect. + Shape: + - data (:obj:`torch.Tensor`): + - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ + S is the number of stacked frames, H is the height of the image, W is the width of the image. + - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size. + - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env. + - to_play: :math:`(N, 1)`, where N is the number of collect_env. + - ready_env_id: None + Returns: + - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ + ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. + """ + self._eval_model.eval() + active_eval_env_num = len(data) + data = to_tensor(data) + data = sum(sum(data, []), []) + batch_size = len(data) + data = to_device(data, self._cfg.device) + agent_num = batch_size // active_eval_env_num + to_play = np.array(to_play).reshape(-1).tolist() + + with torch.no_grad(): + # data shape [B, S x C, W, H], e.g. 
{Tensor:(B, 12, 96, 96)} + network_output = self._collect_model.initial_inference(data) + latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) + + if not self._eval_model.training: + # if not in training, obtain the scalars of the value/reward + pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape(B, 1) + latent_state_roots = latent_state_roots.detach().cpu().numpy() + policy_logits = policy_logits.detach().cpu().numpy().tolist() # list shape(B, A) + + action_mask = sum(action_mask, []) + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(batch_size)] + if self._cfg.mcts_ctree: + # cpp mcts_tree + roots = MCTSCtree.roots(batch_size, legal_actions) + else: + # python mcts_tree + roots = MCTSPtree.roots(batch_size, legal_actions) + roots.prepare_no_noise(reward_roots, policy_logits, to_play) + self._mcts_eval.search(roots, self._eval_model, latent_state_roots, to_play) + + roots_visit_count_distributions = roots.get_distributions( + ) # shape: ``{list: batch_size} ->{list: action_space_size}`` + roots_values = roots.get_values() # shape: {list: batch_size} + + data_id = [i for i in range(active_eval_env_num)] + output = {i: defaultdict(list) for i in data_id} + + if ready_env_id is None: + ready_env_id = np.arange(active_eval_env_num) + + for i in range(batch_size): + distributions, value = roots_visit_count_distributions[i], roots_values[i] + # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents + # the index within the legal action set, rather than the index in the entire action set. + # Setting deterministic=True implies choosing the action with the highest value (argmax) rather than + # sampling during the evaluation phase. + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=1, deterministic=True + ) + # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the + # entire action set. + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + output[i//agent_num]['action'].append(action) + output[i//agent_num]['distributions'].append(distributions) + output[i//agent_num]['visit_count_distribution_entropy'].append(visit_count_distribution_entropy) + output[i//agent_num]['value'].append(value) + output[i//agent_num]['pred_value'].append(pred_values[i]) + output[i//agent_num]['policy_logits'].append(policy_logits[i]) + + return output + + def _monitor_vars_learn(self) -> List[str]: + """ + Overview: + Register the variables to be monitored in learn mode. The registered variables will be logged in + tensorboard according to the return value ``_forward_learn``. + """ + return [ + 'collect_mcts_temperature', + 'cur_lr', + 'weighted_total_loss', + 'total_loss', + 'policy_loss', + 'reward_loss', + 'value_loss', + 'consistency_loss', + 'value_priority', + 'target_reward', + 'target_value', + 'predicted_rewards', + 'predicted_values', + 'transformed_target_reward', + 'transformed_target_value', + 'total_grad_norm_before_clip', + ] + + def _state_dict_learn(self) -> Dict[str, Any]: + """ + Overview: + Return the state_dict of learn mode, usually including model, target_model and optimizer. + Returns: + - state_dict (:obj:`Dict[str, Any]`): The dict of current policy learn state, for saving and restoring. 
+ """ + return { + 'model': self._learn_model.state_dict(), + 'target_model': self._target_model.state_dict(), + 'optimizer': self._optimizer.state_dict(), + } + + def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: + """ + Overview: + Load the state_dict variable into policy learn mode. + Arguments: + - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before. + """ + self._learn_model.load_state_dict(state_dict['model']) + self._target_model.load_state_dict(state_dict['target_model']) + self._optimizer.load_state_dict(state_dict['optimizer']) + + def _process_transition(self, obs, policy_output, timestep): + # be compatible with DI-engine Policy class + pass + + def _get_train_sample(self, data): + # be compatible with DI-engine Policy class + pass diff --git a/zoo/gobigger/config/gobigger_muzero_config.py b/zoo/gobigger/config/gobigger_muzero_config.py new file mode 100644 index 000000000..25662d245 --- /dev/null +++ b/zoo/gobigger/config/gobigger_muzero_config.py @@ -0,0 +1,142 @@ +from easydict import EasyDict + +env_name = 'GoBigger' + +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +collector_env_num = 32 +n_episode = 32 +evaluator_env_num = 5 +num_simulations = 50 +update_per_collect = 2000 +batch_size = 256 +reanalyze_ratio = 0. +action_space_size = 27 +direction_num=12 +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +atari_muzero_config = dict( + exp_name= + f'data_mz_ctree/{env_name}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', + env=dict( + env_name=env_name, + team_num=2, + player_num_per_team=2, + direction_num=direction_num, + step_mul=8, + map_width=64, + map_height=64, + frame_limit=3600, + action_space_size=action_space_size, + use_action_mask=False, + reward_div_value=0.1, + reward_type='log_reward', + start_spirit_progress=0.2, + end_spirit_progress=0.8, + manager_settings=dict( + food_manager=dict( + num_init=260, + num_min=260, + num_max=300, + ), + thorns_manager=dict( + num_init=3, + num_min=3, + num_max=4, + ), + player_manager=dict( + ball_settings=dict( + score_init=13000, + ), + ), + ), + playback_settings=dict( + playback_type='by_frame', + by_frame=dict( + save_frame=False, + # save_frame=True, + save_dir='./', + save_name_prefix='gobigger', + ), + ), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + model=dict( + latent_state_dim=176, + frame_stack_num=1, + action_space_size=action_space_size, + downsample=True, + self_supervised_learning_loss=False, # default is False + discrete_action_encoding_type='one_hot', + norm_type='BN', + ), + cuda=True, + mcts_ctree=True, + env_type='not_board_games', + game_segment_length=400, + use_augmentation=False, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='SGD', + lr_piecewise_constant_decay=True, + learning_rate=0.2, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + ssl_loss_weight=0, # default is 0 + n_episode=n_episode, + eval_freq=int(2e3), + replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. 
+ collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), + learn=dict( + learner=dict( + log_policy=True, + hook=dict( + log_show_after_iter=10, + ), + ), + ), + collect=dict( + collector=dict( + collect_print_freq=10, + ), + ), + eval=dict( + evaluator=dict( + eval_freq=5000, + stop_value=10000000000, + ), + ), +) +atari_muzero_config = EasyDict(atari_muzero_config) +main_config = atari_muzero_config + +atari_muzero_create_config = dict( + env=dict( + type='gobigger_lightzero', + import_names=['zoo.gobigger.env.gobigger_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='gobigger_muzero', + import_names=['lzero.policy.gobigger_muzero'], + ), + collector=dict( + type='gobigger_episode_muzero', + import_names=['lzero.worker.gobigger_muzero_collector'], + ) +) +atari_muzero_create_config = EasyDict(atari_muzero_create_config) +create_config = atari_muzero_create_config + +if __name__ == "__main__": + from lzero.entry import train_muzero_gobigger + train_muzero_gobigger([main_config, create_config], seed=0) From d88d79cc3a0d15937833e257526e133d4d2460ed Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Fri, 2 Jun 2023 12:21:04 +0800 Subject: [PATCH 05/54] polish(yzj): polish gobigger env --- zoo/gobigger/env/gobigger_env.py | 24 +----------------------- 1 file changed, 1 insertion(+), 23 deletions(-) diff --git a/zoo/gobigger/env/gobigger_env.py b/zoo/gobigger/env/gobigger_env.py index 209407e28..1077a93db 100644 --- a/zoo/gobigger/env/gobigger_env.py +++ b/zoo/gobigger/env/gobigger_env.py @@ -61,7 +61,6 @@ def step(self, action: dict) -> BaseEnvTimestep: action = {k: self.transform_action(v) if np.isscalar(v) else v for k, v in action.items()} raw_obs, raw_rew, done, info = self._env.step(action) # print('current_frame={}'.format(raw_obs[0]['last_time'])) - # print('raw_rew={}'.format(raw_rew)) # print('action={}'.format(action)) # print('raw_rew={}, done={}'.format(raw_rew, done)) rew = self.transform_reward(raw_obs) @@ -346,17 +345,8 @@ def _preprocess_obs(self, raw_obs, env_status=None, eval_vsbot=False): last_action_type=last_action_type) game_player_obs['action_mask'] = action_mask env_player_obs.append(game_player_obs) - # env_player_obs = default_collate_with_dim(env_player_obs) return env_player_obs - # def collate_obs(self, env_player_obs): - # processed_obs_list = [] - # for env_id, env_obs in env_player_obs.items(): - # for game_player_id, game_player_obs in env_obs.items(): - # processed_obs_list.append(game_player_obs) - # obs_batch = default_collate_with_dim(processed_obs_list, device=self.device) - # return obs_batch - def preprocess_obs(self, obs_list, env_status=None, eval_vsbot=False): env_player_obs = self._preprocess_obs(obs_list, env_status, eval_vsbot) return env_player_obs @@ -492,16 +482,4 @@ def transform_reward(self, next_obs): actions.update(bot[i].step(obs['raw_obs'])) obs, rew, done, info = env.step(actions) if done: - break - - # from ding.envs import create_env_manager - # from functools import partial - # env_manager=EasyDict({'episode_num': float('inf'), 'max_retry': 1, 'retry_type': 'reset', 'auto_reset': True, - # 'step_timeout': None, 'reset_timeout': None, 'retry_waiting_time': 0.1, 'cfg_type': 'BaseEnvManagerDict', - # 'type': 'base', 'shared_memory': False}) - - # collector_env = create_env_manager(env_manager, [partial(GoBiggerLightZeroEnv, cfg=c) for c in [env_cfg]]) - # collector_env.launch() - # print(collector_env._env_num) - # for i in range(500): - # timestep = collector_env.step({0:[0,0,0,0]}) \ No 
newline at end of file + break \ No newline at end of file From 17992eb6755e78556e0a237b996db8cfc23e51bb Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Fri, 2 Jun 2023 20:38:08 +0800 Subject: [PATCH 06/54] feature(yzj): adapt multi agent env gobigger with sez --- lzero/entry/train_muzero_gobigger.py | 6 +- lzero/mcts/buffer/__init__.py | 1 + ...igger_game_buffer_sampled_efficientzero.py | 593 +++++++++ lzero/mcts/tree_search/mcts_ptree_sampled.py | 2 +- .../gobigger_sampled_efficientzero_model.py | 531 ++++++++ .../policy/gobigger_sampled_efficientzero.py | 1176 +++++++++++++++++ .../gobigger_sampled_efficientzero_config.py | 145 ++ 7 files changed, 2451 insertions(+), 3 deletions(-) create mode 100644 lzero/mcts/buffer/gobigger_game_buffer_sampled_efficientzero.py create mode 100644 lzero/model/gobigger/gobigger_sampled_efficientzero_model.py create mode 100644 lzero/policy/gobigger_sampled_efficientzero.py create mode 100644 zoo/gobigger/config/gobigger_sampled_efficientzero_config.py diff --git a/lzero/entry/train_muzero_gobigger.py b/lzero/entry/train_muzero_gobigger.py index 60952a353..f0b095b52 100644 --- a/lzero/entry/train_muzero_gobigger.py +++ b/lzero/entry/train_muzero_gobigger.py @@ -42,13 +42,15 @@ def train_muzero_gobigger( """ cfg, create_cfg = input_cfg - assert create_cfg.policy.type in ['gobigger_efficientzero', 'gobigger_muzero'], \ - "train_muzero entry now only support the following algo.: 'gobigger_efficientzero', 'gobigger_muzero'" + assert create_cfg.policy.type in ['gobigger_efficientzero', 'gobigger_muzero', 'gobigger_sampled_efficientzero'], \ + "train_muzero entry now only support the following algo.: 'gobigger_efficientzero', 'gobigger_muzero', 'gobigger_sampled_efficientzero'" if create_cfg.policy.type == 'gobigger_efficientzero': from lzero.mcts import GoBiggerEfficientZeroGameBuffer as GameBuffer elif create_cfg.policy.type == 'gobigger_muzero': from lzero.mcts import GoBiggerMuZeroGameBuffer as GameBuffer + elif create_cfg.policy.type == 'gobigger_sampled_efficientzero': + from lzero.mcts import GoBiggerSampledEfficientZeroGameBuffer as GameBuffer if cfg.policy.cuda and torch.cuda.is_available(): cfg.policy.device = 'cuda' diff --git a/lzero/mcts/buffer/__init__.py b/lzero/mcts/buffer/__init__.py index e5a245ee9..69e2f9d68 100644 --- a/lzero/mcts/buffer/__init__.py +++ b/lzero/mcts/buffer/__init__.py @@ -3,3 +3,4 @@ from .game_buffer_sampled_efficientzero import SampledEfficientZeroGameBuffer from .gobigger_game_buffer_muzero import GoBiggerMuZeroGameBuffer from .gobigger_game_buffer_efficientzero import GoBiggerEfficientZeroGameBuffer +from .gobigger_game_buffer_sampled_efficientzero import GoBiggerSampledEfficientZeroGameBuffer diff --git a/lzero/mcts/buffer/gobigger_game_buffer_sampled_efficientzero.py b/lzero/mcts/buffer/gobigger_game_buffer_sampled_efficientzero.py new file mode 100644 index 000000000..7dd8258e3 --- /dev/null +++ b/lzero/mcts/buffer/gobigger_game_buffer_sampled_efficientzero.py @@ -0,0 +1,593 @@ +from typing import Any, List, Tuple + +import numpy as np +import torch +from ding.utils import BUFFER_REGISTRY + +from lzero.mcts.tree_search.mcts_ctree_sampled import SampledEfficientZeroMCTSCtree as MCTSCtree +from lzero.mcts.tree_search.mcts_ptree_sampled import SampledEfficientZeroMCTSPtree as MCTSPtree +from lzero.mcts.utils import prepare_observation +from lzero.policy import to_detach_cpu_numpy, concat_output, concat_output_value, inverse_scalar_transform +from .gobigger_game_buffer_muzero import GoBiggerMuZeroGameBuffer +from 
ding.torch_utils import to_device, to_tensor, to_ndarray + + +@BUFFER_REGISTRY.register('gobigger_game_buffer_sampled_efficientzero') +class GoBiggerSampledEfficientZeroGameBuffer(GoBiggerMuZeroGameBuffer): + """ + Overview: + The specific game buffer for Sampled EfficientZero policy. + """ + + def __init__(self, cfg: dict): + super().__init__(cfg) + """ + Overview: + Use the default configuration mechanism. If a user passes in a cfg with a key that matches an existing key + in the default configuration, the user-provided value will override the default configuration. Otherwise, + the default configuration will be used. + """ + default_config = self.default_config() + default_config.update(cfg) + self._cfg = default_config + assert self._cfg.env_type in ['not_board_games', 'board_games'] + self.replay_buffer_size = self._cfg.replay_buffer_size + self.batch_size = self._cfg.batch_size + self._alpha = self._cfg.priority_prob_alpha + self._beta = self._cfg.priority_prob_beta + + self.game_segment_buffer = [] + self.game_pos_priorities = [] + self.game_segment_game_pos_look_up = [] + + self.keep_ratio = 1 + self.num_of_collected_episodes = 0 + self.base_idx = 0 + self.clear_time = 0 + + self.tmp_obs = None # for value obs list [46 + 4(td_step)] not < 50(game_segment) + + def sample(self, batch_size: int, policy: Any) -> List[Any]: + """ + Overview: + sample data from ``GameBuffer`` and prepare the current and target batch for training + Arguments: + - batch_size (:obj:`int`): batch size + - policy (:obj:`torch.tensor`): model of policy + Returns: + - train_data (:obj:`List`): List of train data + """ + + policy._target_model.to(self._cfg.device) + policy._target_model.eval() + + reward_value_context, policy_re_context, policy_non_re_context, current_batch = self._make_batch( + batch_size, self._cfg.reanalyze_ratio + ) + + # target reward, target value + batch_value_prefixs, batch_target_values = self._compute_target_reward_value( + reward_value_context, policy._target_model + ) + + batch_target_policies_non_re = self._compute_target_policy_non_reanalyzed( + policy_non_re_context, self._cfg.model.num_of_sampled_actions + ) + + if self._cfg.reanalyze_ratio > 0: + # target policy + batch_target_policies_re, root_sampled_actions = self._compute_target_policy_reanalyzed( + policy_re_context, policy._target_model + ) + # ============================================================== + # fix reanalyze in sez: + # use the latest root_sampled_actions after the reanalyze process, + # because the batch_target_policies_re is corresponding to the latest root_sampled_actions + # ============================================================== + + assert (self._cfg.reanalyze_ratio > 0 and self._cfg.reanalyze_outdated is True), \ + "in sampled effiicientzero, if self._cfg.reanalyze_ratio>0, you must set self._cfg.reanalyze_outdated=True" + # current_batch = [obs_list, action_list, root_sampled_actions_list, mask_list, batch_index_list, weights_list, make_time_list] + if self._cfg.model.continuous_action_space: + current_batch[2][:int(batch_size * self._cfg.reanalyze_ratio)] = root_sampled_actions.reshape( + int(batch_size * self._cfg.reanalyze_ratio), self._cfg.num_unroll_steps + 1, + self._cfg.model.num_of_sampled_actions, self._cfg.model.action_space_size + ) + else: + current_batch[2][:int(batch_size * self._cfg.reanalyze_ratio)] = root_sampled_actions.reshape( + int(batch_size * self._cfg.reanalyze_ratio), self._cfg.num_unroll_steps + 1, + self._cfg.model.num_of_sampled_actions, 1 + ) + + if 0 < 
self._cfg.reanalyze_ratio < 1: + try: + batch_target_policies = np.concatenate([batch_target_policies_re, batch_target_policies_non_re]) + except Exception as error: + print(error) + elif self._cfg.reanalyze_ratio == 1: + batch_target_policies = batch_target_policies_re + elif self._cfg.reanalyze_ratio == 0: + batch_target_policies = batch_target_policies_non_re + + target_batch = [batch_value_prefixs, batch_target_values, batch_target_policies] + # a batch contains the current_batch and the target_batch + train_data = [current_batch, target_batch] + return train_data + + def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: + """ + Overview: + first sample orig_data through ``_sample_orig_data()``, + then prepare the context of a batch: + reward_value_context: the context of reanalyzed value targets + policy_re_context: the context of reanalyzed policy targets + policy_non_re_context: the context of non-reanalyzed policy targets + current_batch: the inputs of batch + Arguments: + - batch_size (:obj:`int`): the batch size of orig_data from replay buffer. + - reanalyze_ratio (:obj:`float`): ratio of reanalyzed policy (value is 100% reanalyzed) + Returns: + - context (:obj:`Tuple`): reward_value_context, policy_re_context, policy_non_re_context, current_batch + """ + # obtain the batch context from replay buffer + orig_data = self._sample_orig_data(batch_size) + game_lst, pos_in_game_segment_list, batch_index_list, weights_list, make_time_list = orig_data + batch_size = len(batch_index_list) + obs_list, action_list, mask_list = [], [], [] + root_sampled_actions_list = [] + # prepare the inputs of a batch + for i in range(batch_size): + game = game_lst[i] + pos_in_game_segment = pos_in_game_segment_list[i] + # ============================================================== + # sampled related core code + # ============================================================== + actions_tmp = game.action_segment[pos_in_game_segment:pos_in_game_segment + + self._cfg.num_unroll_steps].tolist() + + # NOTE: self._cfg.num_unroll_steps + 1 + root_sampled_actions_tmp = game.root_sampled_actions[pos_in_game_segment:pos_in_game_segment + + self._cfg.num_unroll_steps + 1] + + # add mask for invalid actions (out of trajectory), 1 for valid, 0 for invalid + mask_tmp = [1. for i in range(len(root_sampled_actions_tmp))] + mask_tmp += [0. 
for _ in range(self._cfg.num_unroll_steps + 1 - len(mask_tmp))] + + # pad random action + if self._cfg.model.continuous_action_space: + actions_tmp += [ + np.random.randn(self._cfg.model.action_space_size) + for _ in range(self._cfg.num_unroll_steps - len(actions_tmp)) + ] + root_sampled_actions_tmp += [ + np.random.rand(self._cfg.model.num_of_sampled_actions, self._cfg.model.action_space_size) + for _ in range(self._cfg.num_unroll_steps + 1 - len(root_sampled_actions_tmp)) + ] + else: + actions_tmp += [ + np.random.randint(0, self._cfg.model.action_space_size, 1).item() + for _ in range(self._cfg.num_unroll_steps - len(actions_tmp)) + ] + if len(root_sampled_actions_tmp[0].shape) == 1: + root_sampled_actions_tmp += [ + np.random.randint(0, self._cfg.model.action_space_size, + self._cfg.model.num_of_sampled_actions) + # NOTE: self._cfg.num_unroll_steps + 1 + for _ in range(self._cfg.num_unroll_steps + 1 - len(root_sampled_actions_tmp)) + ] + else: + root_sampled_actions_tmp += [ + np.random.randint(0, self._cfg.model.action_space_size, + self._cfg.model.num_of_sampled_actions).reshape( + self._cfg.model.num_of_sampled_actions, 1 + ) # NOTE: self._cfg.num_unroll_steps + 1 + for _ in range(self._cfg.num_unroll_steps + 1 - len(root_sampled_actions_tmp)) + ] + + # obtain the input observations + # stack+num_unroll_steps 4+5 + # pad if length of obs in game_segment is less than stack+num_unroll_steps + obs_list.append( + game_lst[i].get_unroll_obs( + pos_in_game_segment_list[i], num_unroll_steps=self._cfg.num_unroll_steps, padding=True + ) + ) + action_list.append(actions_tmp) + root_sampled_actions_list.append(root_sampled_actions_tmp) + + mask_list.append(mask_tmp) + + # formalize the input observations + #obs_list = prepare_observation(obs_list, self._cfg.model.model_type) + # ============================================================== + # sampled related core code + # ============================================================== + # formalize the inputs of a batch + current_batch = [ + obs_list, action_list, root_sampled_actions_list, mask_list, batch_index_list, weights_list, make_time_list + ] + + for i in range(len(current_batch)): + current_batch[i] = np.asarray(current_batch[i]) + + total_transitions = self.get_num_of_transitions() + + # obtain the context of value targets + reward_value_context = self._prepare_reward_value_context( + batch_index_list, game_lst, pos_in_game_segment_list, total_transitions + ) + """ + only reanalyze recent reanalyze_ratio (e.g. 
50%) data + if self._cfg.reanalyze_outdated is True, batch_index_list is sorted according to its generated env_steps + 0: reanalyze_num -> reanalyzed policy, reanalyze_num:end -> non reanalyzed policy + """ + reanalyze_num = int(batch_size * reanalyze_ratio) + # reanalyzed policy + if reanalyze_num > 0: + # obtain the context of reanalyzed policy targets + policy_re_context = self._prepare_policy_reanalyzed_context( + batch_index_list[:reanalyze_num], game_lst[:reanalyze_num], pos_in_game_segment_list[:reanalyze_num] + ) + else: + policy_re_context = None + + # non reanalyzed policy + if reanalyze_num < batch_size: + # obtain the context of non-reanalyzed policy targets + policy_non_re_context = self._prepare_policy_non_reanalyzed_context( + batch_index_list[reanalyze_num:], game_lst[reanalyze_num:], pos_in_game_segment_list[reanalyze_num:] + ) + else: + policy_non_re_context = None + + context = reward_value_context, policy_re_context, policy_non_re_context, current_batch + return context + + def _compute_target_reward_value(self, reward_value_context: List[Any], model: Any) -> List[np.ndarray]: + """ + Overview: + prepare reward and value targets from the context of rewards and values. + Arguments: + - reward_value_context (:obj:'list'): the reward value context + - model (:obj:'torch.tensor'):model of the target model + Returns: + - batch_value_prefixs (:obj:'np.ndarray): batch of value prefix + - batch_target_values (:obj:'np.ndarray): batch of value estimation + """ + value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, td_steps_list, action_mask_segment, \ + to_play_segment = reward_value_context # noqa + + # transition_batch_size = game_segment_batch_size * (num_unroll_steps+1) + transition_batch_size = len(value_obs_list) + game_segment_batch_size = len(pos_in_game_segment_list) + + to_play, action_mask = self._preprocess_to_play_and_action_mask( + game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list + ) + if self._cfg.model.continuous_action_space is True: + # when the action space of the environment is continuous, action_mask[:] is None. 
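+            # Therefore we fall back to an all-ones mask below (every action dimension is treated as legal)
+            # and fill ``legal_actions`` with -1 placeholders, since in the continuous setting candidate
+            # actions are sampled from the policy distribution rather than enumerated.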
+ action_mask = [ + list(np.ones(self._cfg.model.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size) + ] + # NOTE: in continuous action space env: we set all legal_actions as -1 + legal_actions = [ + [-1 for _ in range(self._cfg.model.action_space_size)] for _ in range(transition_batch_size) + ] + else: + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)] + + batch_target_values, batch_value_prefixs = [], [] + with torch.no_grad(): + # value_obs_list = prepare_observation(value_obs_list, self._cfg.model.model_type) + # split a full batch into slices of mini_infer_size: to save the GPU memory for more GPU actors + slices = int(np.ceil(transition_batch_size / self._cfg.mini_infer_size)) + network_output = [] + for i in range(slices): + beg_index = self._cfg.mini_infer_size * i + end_index = self._cfg.mini_infer_size * (i + 1) + # m_obs = torch.from_numpy(value_obs_list[beg_index:end_index]).to(self._cfg.device).float() + m_obs = value_obs_list[beg_index:end_index] + m_obs = to_tensor(m_obs) + m_obs = sum(m_obs, []) + m_obs = to_device(m_obs, self._cfg.device) + + # calculate the target value + m_output = model.initial_inference(m_obs) + + # TODO(pu) + if not model.training: + # if not in training, obtain the scalars of the value/reward + [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( + [ + m_output.latent_state, + inverse_scalar_transform(m_output.value, self._cfg.model.support_scale), + m_output.policy_logits + ] + ) + m_output.reward_hidden_state = ( + m_output.reward_hidden_state[0].detach().cpu().numpy(), + m_output.reward_hidden_state[1].detach().cpu().numpy() + ) + + network_output.append(m_output) + + # concat the output slices after model inference + if self._cfg.use_root_value: + # use the root values from MCTS + # the root values have limited improvement but require much more GPU actors; + _, value_prefix_pool, policy_logits_pool, latent_state_roots, reward_hidden_state_roots = concat_output( + network_output, data_type='efficientzero' + ) + value_prefix_pool = value_prefix_pool.squeeze().tolist() + policy_logits_pool = policy_logits_pool.tolist() + # generate the noises for the root nodes + noises = [ + np.random.dirichlet([self._cfg.root_dirichlet_alpha] * self._cfg.model.num_of_sampled_actions + ).astype(np.float32).tolist() for _ in range(transition_batch_size) + ] + + if self._cfg.mcts_ctree: + # cpp mcts_tree + # prepare the root nodes for MCTS + roots = MCTSCtree.roots( + transition_batch_size, legal_actions, self._cfg.model.action_space_size, + self._cfg.model.num_of_sampled_actions, self._cfg.model.continuous_action_space + ) + + roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, to_play) + # do MCTS for a new policy with the recent target model + MCTSCtree(self._cfg).search(roots, model, latent_state_roots, reward_hidden_state_roots, to_play) + else: + # python mcts_tree + roots = MCTSPtree.roots( + transition_batch_size, legal_actions, self._cfg.model.action_space_size, + self._cfg.model.num_of_sampled_actions, self._cfg.model.continuous_action_space + ) + roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, to_play) + # do MCTS for a new policy with the recent target model + MCTSPtree.roots(self._cfg + ).search(roots, model, latent_state_roots, reward_hidden_state_roots, to_play) + + roots_values = roots.get_values() + value_list = np.array(roots_values) + else: + # use the 
predicted values + value_list = concat_output_value(network_output) + + # get last state value + if self._cfg.env_type == 'board_games' and to_play_segment[0][0] in [1, 2]: + # TODO(pu): for board_games, very important, to check + value_list = value_list.reshape(-1) * np.array( + [ + self._cfg.discount_factor ** td_steps_list[i] if int(td_steps_list[i]) % + 2 == 0 else -self._cfg.discount_factor ** td_steps_list[i] + for i in range(transition_batch_size) + ] + ) + else: + value_list = value_list.reshape(-1) * ( + np.array([self._cfg.discount_factor for _ in range(transition_batch_size)]) ** td_steps_list + ) + + value_list = value_list * np.array(value_mask) + value_list = value_list.tolist() + + horizon_id, value_index = 0, 0 + for game_segment_len_non_re, reward_list, state_index, to_play_list in zip(game_segment_lens, rewards_list, + pos_in_game_segment_list, + to_play_segment): + target_values = [] + target_value_prefixs = [] + + value_prefix = 0.0 + base_index = state_index + for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1): + bootstrap_index = current_index + td_steps_list[value_index] + # for i, reward in enumerate(game.rewards[current_index:bootstrap_index]): + for i, reward in enumerate(reward_list[current_index:bootstrap_index]): + if self._cfg.env_type == 'board_games' and to_play_segment[0][0] in [1, 2]: + # TODO(pu): for board_games, very important, to check + if to_play_list[base_index] == to_play_list[i]: + value_list[value_index] += reward * self._cfg.discount_factor ** i + else: + value_list[value_index] += -reward * self._cfg.discount_factor ** i + else: + value_list[value_index] += reward * self._cfg.discount_factor ** i + # TODO(pu): why value don't use discount_factor factor + + # reset every lstm_horizon_len + if horizon_id % self._cfg.lstm_horizon_len == 0: + value_prefix = 0.0 + base_index = current_index + horizon_id += 1 + + if current_index < game_segment_len_non_re: + target_values.append(value_list[value_index]) + # Since the horizon is small and the discount_factor is close to 1. 
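+                        # (i.e. the per-step discount term, left commented out below, is dropped and raw rewards are summed)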
+ # Compute the reward sum to approximate the value prefix for simplification + value_prefix += reward_list[current_index + ] # * config.discount_factor ** (current_index - base_index) + target_value_prefixs.append(value_prefix) + else: + target_values.append(0) + target_value_prefixs.append(value_prefix) + + value_index += 1 + + batch_value_prefixs.append(target_value_prefixs) + batch_target_values.append(target_values) + + batch_value_prefixs = np.asarray(batch_value_prefixs, dtype=object) + batch_target_values = np.asarray(batch_target_values, dtype=object) + + return batch_value_prefixs, batch_target_values + + def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: Any) -> np.ndarray: + """ + Overview: + prepare policy targets from the reanalyzed context of policies + Arguments: + - policy_re_context (:obj:`List`): List of policy context to reanalyzed + Returns: + - batch_target_policies_re + """ + if policy_re_context is None: + return [] + batch_target_policies_re = [] + + policy_obs_list, policy_mask, pos_in_game_segment_list, batch_index_list, child_visits, game_segment_lens, action_mask_segment, \ + to_play_segment = policy_re_context # noqa + # transition_batch_size = game_segment_batch_size * (self._cfg.num_unroll_steps + 1) + transition_batch_size = len(policy_obs_list) + game_segment_batch_size = len(pos_in_game_segment_list) + + to_play, action_mask = self._preprocess_to_play_and_action_mask( + game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list + ) + if self._cfg.model.continuous_action_space is True: + # when the action space of the environment is continuous, action_mask[:] is None. + action_mask = [ + list(np.ones(self._cfg.model.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size) + ] + # NOTE: in continuous action space env, we set all legal_actions as -1 + legal_actions = [ + [-1 for _ in range(self._cfg.model.action_space_size)] for _ in range(transition_batch_size) + ] + else: + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)] + + with torch.no_grad(): + policy_obs_list = prepare_observation(policy_obs_list, self._cfg.model.model_type) + # split a full batch into slices of mini_infer_size: to save the GPU memory for more GPU actors + self._cfg.mini_infer_size = self._cfg.mini_infer_size + slices = np.ceil(transition_batch_size / self._cfg.mini_infer_size).astype(np.int_) + network_output = [] + for i in range(slices): + beg_index = self._cfg.mini_infer_size * i + end_index = self._cfg.mini_infer_size * (i + 1) + m_obs = torch.from_numpy(policy_obs_list[beg_index:end_index]).to(self._cfg.device).float() + + m_output = model.initial_inference(m_obs) + + if not model.training: + # if not in training, obtain the scalars of the value/reward + [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( + [ + m_output.latent_state, + inverse_scalar_transform(m_output.value, self._cfg.model.support_scale), + m_output.policy_logits + ] + ) + m_output.reward_hidden_state = ( + m_output.reward_hidden_state[0].detach().cpu().numpy(), + m_output.reward_hidden_state[1].detach().cpu().numpy() + ) + + network_output.append(m_output) + + _, value_prefix_pool, policy_logits_pool, latent_state_roots, reward_hidden_state_roots = concat_output( + network_output, data_type='efficientzero' + ) + + value_prefix_pool = value_prefix_pool.squeeze().tolist() + policy_logits_pool = policy_logits_pool.tolist() + noises = [ + 
np.random.dirichlet([self._cfg.root_dirichlet_alpha] * self._cfg.model.num_of_sampled_actions + ).astype(np.float32).tolist() for _ in range(transition_batch_size) + ] + if self._cfg.mcts_ctree: + # ============================================================== + # sampled related core code + # ============================================================== + # cpp mcts_tree + roots = MCTSCtree.roots( + transition_batch_size, legal_actions, self._cfg.model.action_space_size, + self._cfg.model.num_of_sampled_actions, self._cfg.model.continuous_action_space + ) + roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, to_play) + # do MCTS for a new policy with the recent target model + MCTSCtree(self._cfg).search(roots, model, latent_state_roots, reward_hidden_state_roots, to_play) + else: + # python mcts_tree + roots = MCTSPtree.roots( + transition_batch_size, legal_actions, self._cfg.model.action_space_size, + self._cfg.model.num_of_sampled_actions, self._cfg.model.continuous_action_space + ) + roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, to_play) + # do MCTS for a new policy with the recent target model + MCTSPtree.roots(self._cfg).search(roots, model, latent_state_roots, reward_hidden_state_roots, to_play) + + roots_legal_actions_list = legal_actions + roots_distributions = roots.get_distributions() + + # ============================================================== + # fix reanalyze in sez + # ============================================================== + roots_sampled_actions = roots.get_sampled_actions() + try: + root_sampled_actions = np.array([action.value for action in roots_sampled_actions]) + except Exception: + root_sampled_actions = np.array([action for action in roots_sampled_actions]) + + policy_index = 0 + for state_index, game_idx in zip(pos_in_game_segment_list, batch_index_list): + target_policies = [] + for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1): + distributions = roots_distributions[policy_index] + # ============================================================== + # sampled related core code + # ============================================================== + if policy_mask[policy_index] == 0: + # NOTE: the invalid padding target policy, O is to make sure the correspoding cross_entropy_loss=0 + target_policies.append([0 for _ in range(self._cfg.model.num_of_sampled_actions)]) + else: + if distributions is None: + # if at some obs, the legal_action is None, then add the fake target_policy + target_policies.append( + list( + np.ones(self._cfg.model.num_of_sampled_actions) / + self._cfg.model.num_of_sampled_actions + ) + ) + else: + if self._cfg.env_type == 'not_board_games': + sum_visits = sum(distributions) + policy = [visit_count / sum_visits for visit_count in distributions] + target_policies.append(policy) + else: + # for two_player board games + policy_tmp = [0 for _ in range(self._cfg.model.num_of_sampled_actions)] + # to make sure target_policies have the same dimension + sum_visits = sum(distributions) + policy = [visit_count / sum_visits for visit_count in distributions] + for index, legal_action in enumerate(roots_legal_actions_list[policy_index]): + policy_tmp[legal_action] = policy[index] + target_policies.append(policy_tmp) + + policy_index += 1 + + batch_target_policies_re.append(target_policies) + + batch_target_policies_re = np.array(batch_target_policies_re) + + return batch_target_policies_re, root_sampled_actions + + def 
update_priority(self, train_data: List[np.ndarray], batch_priorities: Any) -> None: + """ + Overview: + Update the priority of training data. + Arguments: + - train_data (:obj:`Optional[List[Optional[np.ndarray]]]`): training data to be updated priority. + - batch_priorities (:obj:`batch_priorities`): priorities to update to. + NOTE: + train_data = [current_batch, target_batch] + current_batch = [obs_list, action_list, root_sampled_actions_list, mask_list, batch_index_list, weights_list, make_time_list] + """ + + batch_index_list = train_data[0][4] + metas = {'make_time': train_data[0][6], 'batch_priorities': batch_priorities} + # only update the priorities for data still in replay buffer + for i in range(len(batch_index_list)): + if metas['make_time'][i] > self.clear_time: + idx, prio = batch_index_list[i], metas['batch_priorities'][i] + self.game_pos_priorities[idx] = prio diff --git a/lzero/mcts/tree_search/mcts_ptree_sampled.py b/lzero/mcts/tree_search/mcts_ptree_sampled.py index 6c5b59cc7..bf26e58aa 100644 --- a/lzero/mcts/tree_search/mcts_ptree_sampled.py +++ b/lzero/mcts/tree_search/mcts_ptree_sampled.py @@ -177,7 +177,7 @@ def search( network_output.latent_state, network_output.policy_logits, self.inverse_scalar_transform_handle(network_output.value_prefix), - self.inverse_scalar_transform_handle(network_output.reward), + self.inverse_scalar_transform_handle(network_output.value_prefix), ] ) network_output.reward_hidden_state = ( diff --git a/lzero/model/gobigger/gobigger_sampled_efficientzero_model.py b/lzero/model/gobigger/gobigger_sampled_efficientzero_model.py new file mode 100644 index 000000000..8a73c8116 --- /dev/null +++ b/lzero/model/gobigger/gobigger_sampled_efficientzero_model.py @@ -0,0 +1,531 @@ +from typing import Optional, Tuple + +import torch +import torch.nn as nn +from ding.model.common import ReparameterizationHead +from ding.torch_utils import MLP +from ding.utils import MODEL_REGISTRY, SequenceType + +from ..common import EZNetworkOutput, RepresentationNetworkMLP +from ..efficientzero_model_mlp import DynamicsNetworkMLP +from ..utils import renormalize, get_params_mean +from .gobigger_encoder import Encoder +import yaml +from easydict import EasyDict +from ding.utils.data import default_collate + + +@MODEL_REGISTRY.register('GoBiggerSampledEfficientZeroModel') +class GoBiggerSampledEfficientZeroModel(nn.Module): + + def __init__( + self, + observation_shape: int = 2, + action_space_size: int = 6, + latent_state_dim: int = 256, + lstm_hidden_size: int = 512, + fc_reward_layers: SequenceType = [32], + fc_value_layers: SequenceType = [32], + fc_policy_layers: SequenceType = [32], + reward_support_size: int = 601, + value_support_size: int = 601, + proj_hid: int = 1024, + proj_out: int = 1024, + pred_hid: int = 512, + pred_out: int = 1024, + self_supervised_learning_loss: bool = True, + categorical_distribution: bool = True, + activation: Optional[nn.Module] = nn.ReLU(inplace=True), + last_linear_layer_init_zero: bool = True, + state_norm: bool = False, + # ============================================================== + # specific sampled related config + # ============================================================== + continuous_action_space: bool = False, + num_of_sampled_actions: int = 6, + sigma_type='conditioned', + fixed_sigma_value: float = 0.3, + bound_type: str = None, + norm_type: str = 'BN', + discrete_action_encoding_type: str = 'one_hot', + res_connection_in_dynamics: bool = False, + *args, + **kwargs, + ): + """ + Overview: + The definition of the 
network model of Sampled EfficientZero, which is a generalization version for 1D vector obs. + The networks are mainly built on fully connected layers. + Sampled EfficientZero model consists of a representation network, a dynamics network and a prediction network. + The representation network is an MLP network which maps the raw observation to a latent state. + The dynamics network is an MLP+LSTM network which predicts the next latent state, reward_hidden_state and value_prefix given the current latent state and action. + The prediction network is an MLP network which predicts the value and policy given the current latent state. + Arguments: + - observation_shape (:obj:`int`): Observation space shape, e.g. 8 for Lunarlander. + - action_space_size: (:obj:`int`): Action space size, which is an integer number. For discrete action space, it is the num of discrete actions, \ + e.g. 4 for Lunarlander. For continuous action space, it is the dimension of the continuous action, e.g. 4 for bipedalwalker. + - latent_state_dim (:obj:`int`): The dimension of latent state, such as 256. + - lstm_hidden_size (:obj:`int`): The hidden size of LSTM in dynamics network to predict value_prefix. + - fc_reward_layers (:obj:`SequenceType`): The number of hidden layers of the reward head (MLP head). + - fc_value_layers (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head). + - fc_policy_layers (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head). + - reward_support_size (:obj:`int`): The size of categorical reward output + - value_support_size (:obj:`int`): The size of categorical value output. + - proj_hid (:obj:`int`): The size of projection hidden layer. + - proj_out (:obj:`int`): The size of projection output layer. + - pred_hid (:obj:`int`): The size of prediction hidden layer. + - pred_out (:obj:`int`): The size of prediction output layer. + - self_supervised_learning_loss (:obj:`bool`): Whether to use self_supervised_learning related networks in Sampled EfficientZero model, default set it to False. + - categorical_distribution (:obj:`bool`): Whether to use discrete support to represent categorical distribution for value, reward/value_prefix. + - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ + operation to speedup, e.g. ReLU(inplace=True). + - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of value/policy mlp, default sets it to True. + - state_norm (:obj:`bool`): Whether to use normalization for latent states, default sets it to True. + # ============================================================== + # specific sampled related config + # ============================================================== + - continuous_action_space (:obj:`bool`): The type of action space. default set it to False. + - num_of_sampled_actions (:obj:`int`): the number of sampled actions, i.e. the K in original Sampled MuZero paper. + # see ``ReparameterizationHead`` in ``ding.model.common.head`` for more details about the following arguments. + - sigma_type (:obj:`str`): the type of sigma in policy head of prediction network, options={'conditioned', 'fixed'}. + - fixed_sigma_value (:obj:`float`): the fixed sigma value in policy head of prediction network, + - bound_type (:obj:`str`): The type of bound in networks. Default sets it to None. + - norm_type (:obj:`str`): The type of normalization in networks. default set it to 'BN'. 
+ - discrete_action_encoding_type (:obj:`str`): The type of encoding for discrete action. Default sets it to 'one_hot'. options = {'one_hot', 'not_one_hot'} + - res_connection_in_dynamics (:obj:`bool`): Whether to use residual connection for dynamics network, default set it to False. + """ + super(GoBiggerSampledEfficientZeroModel, self).__init__() + if not categorical_distribution: + self.reward_support_size = 1 + self.value_support_size = 1 + else: + self.reward_support_size = reward_support_size + self.value_support_size = value_support_size + + self.continuous_action_space = continuous_action_space + self.observation_shape = observation_shape + self.action_space_size = action_space_size + # The dim of action space. For discrete action space, it is 1. + # For continuous action space, it is the dimension of continuous action. + self.action_space_dim = action_space_size if self.continuous_action_space else 1 + assert discrete_action_encoding_type in ['one_hot', 'not_one_hot'], discrete_action_encoding_type + self.discrete_action_encoding_type = discrete_action_encoding_type + if self.continuous_action_space: + self.action_encoding_dim = action_space_size + else: + if self.discrete_action_encoding_type == 'one_hot': + self.action_encoding_dim = action_space_size + elif self.discrete_action_encoding_type == 'not_one_hot': + self.action_encoding_dim = 1 + + self.lstm_hidden_size = lstm_hidden_size + self.latent_state_dim = latent_state_dim + self.fc_reward_layers = fc_reward_layers + self.fc_value_layers = fc_value_layers + self.fc_policy_layers = fc_policy_layers + self.proj_hid = proj_hid + self.proj_out = proj_out + self.pred_hid = pred_hid + self.pred_out = pred_out + + self.last_linear_layer_init_zero = last_linear_layer_init_zero + self.state_norm = state_norm + self.self_supervised_learning_loss = self_supervised_learning_loss + + self.sigma_type = sigma_type + self.fixed_sigma_value = fixed_sigma_value + self.bound_type = bound_type + self.norm_type = norm_type + self.num_of_sampled_actions = num_of_sampled_actions + self.res_connection_in_dynamics = res_connection_in_dynamics + + # self.representation_network = RepresentationNetworkMLP( + # observation_shape=self.observation_shape, hidden_channels=self.latent_state_dim, norm_type=norm_type + # ) + with open('lzero/model/gobigger/default_model_config.yaml', "r") as f: + encoder_cfg = yaml.safe_load(f) + encoder_cfg = EasyDict(encoder_cfg) + self.representation_network = Encoder(encoder_cfg) + + self.dynamics_network = DynamicsNetworkMLP( + action_encoding_dim=self.action_encoding_dim, + num_channels=self.latent_state_dim + self.action_encoding_dim, + common_layer_num=2, + lstm_hidden_size=self.lstm_hidden_size, + fc_reward_layers=self.fc_reward_layers, + output_support_size=self.reward_support_size, + last_linear_layer_init_zero=self.last_linear_layer_init_zero, + norm_type=norm_type, + res_connection_in_dynamics=self.res_connection_in_dynamics, + ) + + self.prediction_network = PredictionNetworkMLP( + continuous_action_space=self.continuous_action_space, + action_space_size=self.action_space_size, + num_channels=self.latent_state_dim, + fc_value_layers=self.fc_value_layers, + fc_policy_layers=self.fc_policy_layers, + output_support_size=self.value_support_size, + last_linear_layer_init_zero=self.last_linear_layer_init_zero, + sigma_type=self.sigma_type, + fixed_sigma_value=self.fixed_sigma_value, + bound_type=self.bound_type, + norm_type=self.norm_type, + ) + + if self.self_supervised_learning_loss: + # self_supervised_learning_loss 
related network proposed in EfficientZero + self.projection_input_dim = latent_state_dim + self.projection = nn.Sequential( + nn.Linear(self.projection_input_dim, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, + nn.Linear(self.proj_hid, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, + nn.Linear(self.proj_hid, self.proj_out), nn.BatchNorm1d(self.proj_out) + ) + self.prediction_head = nn.Sequential( + nn.Linear(self.proj_out, self.pred_hid), + nn.BatchNorm1d(self.pred_hid), + activation, + nn.Linear(self.pred_hid, self.pred_out), + ) + + def initial_inference(self, obs: torch.Tensor) -> EZNetworkOutput: + """ + Overview: + Initial inference of SampledEfficientZero model, which is the first step of the SampledEfficientZero model. + To perform the initial inference, we first use the representation network to obtain the "latent_state" of the observation. + Then we use the prediction network to predict the "value" and "policy_logits" of the "latent_state", and + also prepare the zeros-like ``reward_hidden_state`` for the next step of the Sampled EfficientZero model. + Arguments: + - obs (:obj:`torch.Tensor`): The 1D vector observation data. + Returns (EZNetworkOutput): + - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. + - value_prefix (:obj:`torch.Tensor`): The predicted prefix sum of value for input state. \ + In initial inference, we set it to zero vector. + - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The hidden state of LSTM about reward. In initial inference, \ + we set it to the zeros-like hidden state (H and C). + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, obs_shape)`, where B is batch_size. + - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. + - value_prefix (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. + - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. + - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The shape of each element is :math:`(1, B, lstm_hidden_size)`, where B is batch_size. + """ + batch_size = len(obs) + obs = default_collate(obs) + latent_state = self._representation(obs) + device = latent_state.device + policy_logits, value = self._prediction(latent_state) + # zero initialization for reward hidden states + # (hn, cn), each element shape is (layer_num=1, batch_size, lstm_hidden_size) + reward_hidden_state = ( + torch.zeros(1, batch_size, + self.lstm_hidden_size).to(device), torch.zeros(1, batch_size, + self.lstm_hidden_size).to(device) + ) + return EZNetworkOutput(value, [0. for _ in range(batch_size)], policy_logits, latent_state, reward_hidden_state) + + def recurrent_inference( + self, latent_state: torch.Tensor, reward_hidden_state: torch.Tensor, action: torch.Tensor + ) -> EZNetworkOutput: + """ + Overview: + Recurrent inference of Sampled EfficientZero model, which is the rollout step of the Sampled EfficientZero model. + To perform the recurrent inference, we first use the dynamics network to predict ``next_latent_state``, + ``reward_hidden_state``, ``value_prefix`` by the given current ``latent_state`` and ``action``. 
+ We then use the prediction network to predict the ``value`` and ``policy_logits``. + Arguments: + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The input hidden state of LSTM about reward. + - action (:obj:`torch.Tensor`): The predicted action to rollout. + Returns (EZNetworkOutput): + - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. + - value_prefix (:obj:`torch.Tensor`): The predicted prefix sum of value for input state. + - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. + - next_latent_state (:obj:`torch.Tensor`): The predicted next latent state. + - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The output hidden state of LSTM about reward. + Shapes: + - action (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch_size. + - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. + - value_prefix (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. + - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. + - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - next_latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The shape of each element is :math:`(1, B, lstm_hidden_size)`, where B is batch_size. + """ + next_latent_state, reward_hidden_state, value_prefix = self._dynamics(latent_state, reward_hidden_state, action) + policy_logits, value = self._prediction(next_latent_state) + return EZNetworkOutput(value, value_prefix, policy_logits, next_latent_state, reward_hidden_state) + + def _representation(self, observation: torch.Tensor) -> Tuple[torch.Tensor]: + """ + Overview: + Use the representation network to encode the observations into latent state. + Arguments: + - obs (:obj:`torch.Tensor`): The 1D vector observation data. + Returns: + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, obs_shape)`, where B is batch_size. + - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + """ + latent_state = self.representation_network(observation) + if self.state_norm: + latent_state = renormalize(latent_state) + return latent_state + + def _prediction(self, latent_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Overview: + Use the representation network to encode the observations into latent state. + Arguments: + - obs (:obj:`torch.Tensor`): The 1D vector observation data. + Returns: + - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. + - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. + Shapes: + - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. + - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. 
+ """ + policy, value = self.prediction_network(latent_state) + return policy, value + + def _dynamics(self, latent_state: torch.Tensor, reward_hidden_state: Tuple, + action: torch.Tensor) -> Tuple[torch.Tensor, Tuple[torch.Tensor], torch.Tensor]: + """ + Overview: + Concatenate ``latent_state`` and ``action`` and use the dynamics network to predict ``next_latent_state`` + ``value_prefix`` and ``next_reward_hidden_state``. + Arguments: + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The input hidden state of LSTM about reward. + - action (:obj:`torch.Tensor`): The predicted action to rollout. + Returns: + - next_latent_state (:obj:`torch.Tensor`): The predicted latent state of the next timestep. + - next_reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The output hidden state of LSTM about reward. + - value_prefix (:obj:`torch.Tensor`): The predicted prefix sum of value for input state. + Shapes: + - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - action (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch_size. + - next_latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - value_prefix (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. + """ + # NOTE: the discrete action encoding type is important for some environments + + if not self.continuous_action_space: + # discrete action space + if self.discrete_action_encoding_type == 'one_hot': + # Stack latent_state with the one hot encoded action + if len(action.shape) == 1: + # (batch_size, ) -> (batch_size, 1) + # e.g., torch.Size([8]) -> torch.Size([8, 1]) + action = action.unsqueeze(-1) + + # transform action to one-hot encoding. + # action_one_hot shape: (batch_size, action_space_size), e.g., (8, 4) + action_one_hot = torch.zeros(action.shape[0], self.action_space_size, device=action.device) + # transform action to torch.int64 + action = action.long() + action_one_hot.scatter_(1, action, 1) + action_encoding = action_one_hot + elif self.discrete_action_encoding_type == 'not_one_hot': + action_encoding = action / self.action_space_size + if len(action_encoding.shape) == 1: + # (batch_size, ) -> (batch_size, 1) + # e.g., torch.Size([8]) -> torch.Size([8, 1]) + action_encoding = action_encoding.unsqueeze(-1) + else: + # continuous action space + if len(action.shape) == 1: + # (batch_size, ) -> (batch_size, 1) + # e.g., torch.Size([8]) -> torch.Size([8, 1]) + action = action.unsqueeze(-1) + elif len(action.shape) == 3: + # (batch_size, action_dim, 1) -> (batch_size, action_dim) + # e.g., torch.Size([8, 2, 1]) -> torch.Size([8, 2]) + action = action.squeeze(-1) + + action_encoding = action + + action_encoding = action_encoding.to(latent_state.device).float() + # state_action_encoding shape: (batch_size, latent_state[1] + action_dim]) or + # (batch_size, latent_state[1] + action_space_size]) depending on the discrete_action_encoding_type. 
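+        # Illustrative example: with the GoBigger MuZero config earlier in this patch series
+        # (latent_state_dim=176, one-hot action space of size 27), latent_state is (B, 176),
+        # action_encoding is (B, 27), and the concatenation below is (B, 203).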
+ state_action_encoding = torch.cat((latent_state, action_encoding), dim=1) + + next_latent_state, next_reward_hidden_state, value_prefix = self.dynamics_network( + state_action_encoding, reward_hidden_state + ) + + if not self.state_norm: + return next_latent_state, next_reward_hidden_state, value_prefix + else: + next_latent_state_normalized = renormalize(next_latent_state) + return next_latent_state_normalized, next_reward_hidden_state, value_prefix + + def project(self, latent_state: torch.Tensor, with_grad=True) -> torch.Tensor: + """ + Overview: + Project the latent state to a lower dimension to calculate the self-supervised loss, which is proposed in EfficientZero. + For more details, please refer to the paper ``Exploring Simple Siamese Representation Learning``. + Arguments: + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + - with_grad (:obj:`bool`): Whether to calculate gradient for the projection result. + Returns: + - proj (:obj:`torch.Tensor`): The result embedding vector of projection operation. + Shapes: + - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - proj (:obj:`torch.Tensor`): :math:`(B, projection_output_dim)`, where B is batch_size. + + Examples: + >>> latent_state = torch.randn(256, 64) + >>> output = self.project(latent_state) + >>> output.shape # (256, 1024) + """ + proj = self.projection(latent_state) + + if with_grad: + # with grad, use prediction_head + return self.prediction_head(proj) + else: + return proj.detach() + + def get_params_mean(self): + return get_params_mean(self) + + +class PredictionNetworkMLP(nn.Module): + + def __init__( + self, + continuous_action_space, + action_space_size, + num_channels, + common_layer_num: int = 2, + fc_value_layers: SequenceType = [32], + fc_policy_layers: SequenceType = [32], + output_support_size: int = 601, + last_linear_layer_init_zero: bool = True, + activation: Optional[nn.Module] = nn.ReLU(inplace=True), + # ============================================================== + # specific sampled related config + # ============================================================== + sigma_type='conditioned', + fixed_sigma_value: float = 0.3, + bound_type: str = None, + norm_type: str = 'BN', + ): + """ + Overview: + The definition of policy and value prediction network, which is used to predict value and policy by the + given latent state. + The networks are mainly built on fully connected layers. + Arguments: + - continuous_action_space (:obj:`bool`): The type of action space. default set it to False. + - action_space_size: (:obj:`int`): Action space size, usually an integer number. For discrete action \ + space, it is the number of discrete actions. For continuous action space, it is the dimension of \ + continuous action. + - num_channels (:obj:`int`): The num of channels in latent states. + - num_res_blocks (:obj:`int`): The number of res blocks. + - fc_value_layers (:obj:`SequenceType`): hidden layers of the value prediction head (MLP head). + - fc_policy_layers (:obj:`SequenceType`): hidden layers of the policy prediction head (MLP head). + - output_support_size (:obj:`int`): dim of value output. + - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of value/policy mlp, default sets it to True. 
+ # ============================================================== + # specific sampled related config + # ============================================================== + # see ``ReparameterizationHead`` in ``ding.model.common.head`` for more details about thee following arguments. + - sigma_type (:obj:`str`): the type of sigma in policy head of prediction network, options={'conditioned', 'fixed'}. + - fixed_sigma_value (:obj:`float`): the fixed sigma value in policy head of prediction network, + - bound_type (:obj:`str`): The type of bound in networks. default set it to None. + - norm_type (:obj:`str`): The type of normalization in networks. default set it to 'BN'. + """ + super().__init__() + self.num_channels = num_channels + self.continuous_action_space = continuous_action_space + self.norm_type = norm_type + self.sigma_type = sigma_type + self.fixed_sigma_value = fixed_sigma_value + self.bound_type = bound_type + self.action_space_size = action_space_size + if self.continuous_action_space: + self.action_encoding_dim = self.action_space_size + else: + self.action_encoding_dim = 1 + + # ******* common backbone ****** + self.fc_prediction_common = MLP( + in_channels=self.num_channels, + hidden_channels=self.num_channels, + out_channels=self.num_channels, + layer_num=common_layer_num, + activation=activation, + norm_type=norm_type, + output_activation=True, + output_norm=True, + # last_linear_layer_init_zero=False is important for convergence + last_linear_layer_init_zero=False, + ) + + # ******* value and policy head ****** + self.fc_value_head = MLP( + in_channels=self.num_channels, + hidden_channels=fc_value_layers[0], + out_channels=output_support_size, + layer_num=2, + activation=activation, + norm_type=norm_type, + output_activation=False, + output_norm=False, + # last_linear_layer_init_zero=True is beneficial for convergence speed. + last_linear_layer_init_zero=last_linear_layer_init_zero + ) + + # sampled related core code + if self.continuous_action_space: + self.fc_policy_head = ReparameterizationHead( + input_size=self.num_channels, + output_size=action_space_size, + layer_num=2, + sigma_type=self.sigma_type, + fixed_sigma_value=self.fixed_sigma_value, + activation=nn.ReLU(), + norm_type=None, + bound_type=self.bound_type + ) + else: + self.fc_policy_head = MLP( + in_channels=self.num_channels, + hidden_channels=fc_policy_layers[0], + out_channels=action_space_size, + layer_num=2, + activation=activation, + norm_type=self.norm_type, + output_activation=False, + output_norm=False, + # last_linear_layer_init_zero=True is beneficial for convergence speed. + last_linear_layer_init_zero=last_linear_layer_init_zero + ) + + def forward(self, latent_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Overview: + Forward computation of the prediction network. + Arguments: + - latent_state (:obj:`torch.Tensor`): input tensor with shape (B, in_channels). + Returns: + - policy (:obj:`torch.Tensor`): policy tensor. If action space is discrete, shape is (B, action_space_size). + If action space is continuous, shape is (B, action_space_size * 2). + - value (:obj:`torch.Tensor`): value tensor with shape (B, output_support_size). 
+ """ + x_prediction_common = self.fc_prediction_common(latent_state) + value = self.fc_value_head(x_prediction_common) + + # sampled related core code + policy = self.fc_policy_head(x_prediction_common) + if self.continuous_action_space: + policy = torch.cat([policy['mu'], policy['sigma']], dim=-1) + + return policy, value diff --git a/lzero/policy/gobigger_sampled_efficientzero.py b/lzero/policy/gobigger_sampled_efficientzero.py new file mode 100644 index 000000000..0f16c7e22 --- /dev/null +++ b/lzero/policy/gobigger_sampled_efficientzero.py @@ -0,0 +1,1176 @@ +import copy +from typing import List, Dict, Any, Tuple, Union + +import numpy as np +import torch +import torch.optim as optim +from ding.model import model_wrap +from ding.policy.base_policy import Policy +from ding.torch_utils import to_tensor +from ding.utils import POLICY_REGISTRY +from ditk import logging +from torch.distributions import Categorical, Independent, Normal +from torch.nn import L1Loss + +from lzero.mcts import SampledEfficientZeroMCTSCtree as MCTSCtree +from lzero.mcts import SampledEfficientZeroMCTSPtree as MCTSPtree +from lzero.model import ImageTransforms +from lzero.policy import scalar_transform, InverseScalarTransform, cross_entropy_loss, phi_transform, \ + DiscreteSupport, to_torch_float_tensor, ez_network_output_unpack, select_action, negative_cosine_similarity, prepare_obs, \ + configure_optimizers +from collections import defaultdict +from ding.torch_utils import to_device + + +@POLICY_REGISTRY.register('gobigger_sampled_efficientzero') +class GoBiggerSampledEfficientZeroPolicy(Policy): + """ + Overview: + The policy class for Sampled EfficientZero. + """ + + # The default_config for Sampled fEficientZero policy. + config = dict( + model=dict( + # (str) The model type. For 1-dimensional vector obs, we use mlp model. For 3-dimensional image obs, we use conv model. + model_type='conv', # options={'mlp', 'conv'} + # (bool) If True, the action space of the environment is continuous, otherwise discrete. + continuous_action_space=False, + # (tuple) the stacked obs shape. + # observation_shape=(1, 96, 96), # if frame_stack_num=1 + observation_shape=(4, 96, 96), # if frame_stack_num=4 + # (bool) Whether to use the self-supervised learning loss. + self_supervised_learning_loss=True, + # (int) The size of action space. For discrete action space, it is the number of actions. + # For continuous action space, it is the dimension of action. + action_space_size=6, + # (bool) Whether to use discrete support to represent categorical distribution for value/reward/value_prefix. + categorical_distribution=True, + # (int) the image channel in image observation. + image_channel=1, + # (int) The number of frames to stack together. + frame_stack_num=1, + # (int) The scale of supports used in categorical distribution. + # This variable is only effective when ``categorical_distribution=True``. + support_scale=300, + # (int) The hidden size in LSTM. + lstm_hidden_size=512, + # (str) The type of sigma. options={'conditioned', 'fixed'} + sigma_type='conditioned', + # (float) The fixed sigma value. Only effective when ``sigma_type='fixed'``. + fixed_sigma_value=0.3, + # (bool) whether to learn bias in the last linear layer in value and policy head. + bias=True, + # (str) The type of action encoding. Options are ['one_hot', 'not_one_hot']. Default to 'one_hot'. + discrete_action_encoding_type='one_hot', + # (bool) whether to use res connection in dynamics. 
+ res_connection_in_dynamics=True, + # (str) The type of normalization in MuZero model. Options are ['BN', 'LN']. Default to 'LN'. + norm_type='BN', + ), + # ****** common ****** + # (bool) ``sampled_algo=True`` means the policy is sampled-based algorithm (e.g. Sampled EfficientZero), which is used in ``collector``. + sampled_algo=True, + # (bool) Whether to use C++ MCTS in policy. If False, use Python implementation. + mcts_ctree=True, + # (bool) Whether to use cuda in policy. + cuda=True, + # (int) The number of environments used in collecting data. + collector_env_num=8, + # (int) The number of environments used in evaluating policy. + evaluator_env_num=3, + # (str) The type of environment. The options are ['not_board_games', 'board_games']. + env_type='not_board_games', + # (str) The type of battle mode. The options are ['play_with_bot_mode', 'self_play_mode']. + battle_mode='play_with_bot_mode', + # (bool) Whether to monitor extra statistics in tensorboard. + monitor_extra_statistics=True, + # (int) The transition number of one ``GameSegment``. + game_segment_length=200, + + # ****** observation ****** + # (bool) Whether to transform image to string to save memory. + transform2string=False, + # (bool) Whether to use data augmentation. + use_augmentation=False, + # (list) The style of augmentation. + augmentation=['shift', 'intensity'], + + # ******* learn ****** + # (int) How many updates(iterations) to train after collector's one collection. + # Bigger "update_per_collect" means bigger off-policy. + # collect data -> update policy-> collect data -> ... + # For different env, we have different episode_length, + # we usually set update_per_collect = collector_env_num * episode_length / batch_size * reuse_factor + update_per_collect=100, + # (int) Minibatch size for one gradient descent. + batch_size=256, + # (str) Optimizer for training policy network. ['SGD', 'Adam', 'AdamW'] + optim_type='SGD', + learning_rate=0.2, # init lr for manually decay schedule + # optim_type='Adam', + # lr_piecewise_constant_decay=False, + # learning_rate=0.003, # lr for Adam optimizer + # (float) Weight uniform initialization range in the last output layer + init_w=3e-3, + normalize_prob_of_sampled_actions=False, + policy_loss_type='cross_entropy', # options={'cross_entropy', 'KL'} + # (int) Frequency of target network update. + target_update_freq=100, + weight_decay=1e-4, + momentum=0.9, + grad_clip_value=10, + # You can use either "n_sample" or "n_episode" in collector.collect. + # Get "n_episode" episodes per collect. + n_episode=8, + # (float) the number of simulations in MCTS. + num_simulations=50, + # (float) Discount factor (gamma) for returns. + discount_factor=0.997, + # (int) The number of step for calculating target q_value. + td_steps=5, + # (int) The number of unroll steps in dynamics network. + num_unroll_steps=5, + # (int) reset the hidden states in LSTM every ``lstm_horizon_len`` horizon steps. + lstm_horizon_len=5, + # (float) The weight of reward loss. + reward_loss_weight=1, + # (float) The weight of value loss. + value_loss_weight=0.25, + # (float) The weight of policy loss. + policy_loss_weight=1, + # (float) The weight of policy entropy loss. + policy_entropy_loss_weight=0, + # (float) The weight of ssl (self-supervised learning) loss. + ssl_loss_weight=2, + # (bool) Whether to use the cosine learning rate decay. + cos_lr_scheduler=False, + # (bool) Whether to use piecewise constant learning rate decay. + # i.e. 
lr: 0.2 -> 0.02 -> 0.002 + lr_piecewise_constant_decay=True, + # (int) The number of final training iterations to control lr decay, which is only used for manually decay. + threshold_training_steps_for_final_lr=int(5e4), + # (int) The number of final training iterations to control temperature, which is only used for manually decay. + threshold_training_steps_for_final_temperature=int(1e5), + # (bool) Whether to use manually decayed temperature. + # i.e. temperature: 1 -> 0.5 -> 0.25 + manual_temperature_decay=False, + # (float) The fixed temperature value for MCTS action selection, which is used to control the exploration. + # The larger the value, the more exploration. This value is only used when manual_temperature_decay=False. + fixed_temperature_value=0.25, + + # ****** Priority ****** + # (bool) Whether to use priority when sampling training data from the buffer. + use_priority=True, + # (bool) Whether to use the maximum priority for new collecting data. + use_max_priority_for_new_data=True, + # (float) The degree of prioritization to use. A value of 0 means no prioritization, + # while a value of 1 means full prioritization. + priority_prob_alpha=0.6, + # (float) The degree of correction to use. A value of 0 means no correction, + # while a value of 1 means full correction. + priority_prob_beta=0.4, + + # ****** UCB ****** + # (float) The alpha value used in the Dirichlet distribution for exploration at the root node of the search tree. + root_dirichlet_alpha=0.3, + # (float) The noise weight at the root node of the search tree. + root_noise_weight=0.25, + ) + + def default_model(self) -> Tuple[str, List[str]]: + """ + Overview: + Return this algorithm default model setting. + Returns: + - model_info (:obj:`Tuple[str, List[str]]`): model name and model import_names. + - model_type (:obj:`str`): The model type used in this algorithm, which is registered in ModelRegistry. + - import_names (:obj:`List[str]`): The model class path list used in this algorithm. + + .. note:: + The user can define and use customized network model but must obey the same interface definition indicated \ + by import_names path. For Sampled EfficientZero, ``lzero.model.sampled_efficientzero_model.SampledEfficientZeroModel`` + """ + return 'GoBiggerSampledEfficientZeroModel', ['lzero.model.gobigger.gobigger_sampled_efficientzero_model'] + + def _init_learn(self) -> None: + """ + Overview: + Learn mode init method. Called by ``self.__init__``. Initialize the learn model, optimizer and MCTS utils. + """ + assert self._cfg.optim_type in ['SGD', 'Adam', 'AdamW'], self._cfg.optim_type + if self._cfg.model.continuous_action_space: + # Weight Init for the last output layer of gaussian policy head in prediction network. 
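+            # NOTE: a small uniform range such as U(-3e-3, 3e-3) is a common heuristic (e.g. in DDPG-style
+            # continuous-control heads) that keeps the initial Gaussian policy outputs close to zero.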
+ init_w = self._cfg.init_w + self._model.prediction_network.fc_policy_head.mu.weight.data.uniform_(-init_w, init_w) + self._model.prediction_network.fc_policy_head.mu.bias.data.uniform_(-init_w, init_w) + self._model.prediction_network.fc_policy_head.log_sigma_layer.weight.data.uniform_(-init_w, init_w) + try: + self._model.prediction_network.fc_policy_head.log_sigma_layer.bias.data.uniform_(-init_w, init_w) + except Exception as exception: + logging.warning(exception) + + if self._cfg.optim_type == 'SGD': + self._optimizer = optim.SGD( + self._model.parameters(), + lr=self._cfg.learning_rate, + momentum=self._cfg.momentum, + weight_decay=self._cfg.weight_decay, + ) + + elif self._cfg.optim_type == 'Adam': + self._optimizer = optim.Adam( + self._model.parameters(), lr=self._cfg.learning_rate, weight_decay=self._cfg.weight_decay + ) + elif self._cfg.optim_type == 'AdamW': + self._optimizer = configure_optimizers( + model=self._model, + weight_decay=self._cfg.weight_decay, + learning_rate=self._cfg.learning_rate, + device_type=self._cfg.device + ) + + if self._cfg.cos_lr_scheduler is True: + from torch.optim.lr_scheduler import CosineAnnealingLR + self.lr_scheduler = CosineAnnealingLR(self._optimizer, 1e6, eta_min=0, last_epoch=-1) + + if self._cfg.lr_piecewise_constant_decay: + from torch.optim.lr_scheduler import LambdaLR + max_step = self._cfg.threshold_training_steps_for_final_lr + # NOTE: the 1, 0.1, 0.01 is the decay rate, not the lr. + lr_lambda = lambda step: 1 if step < max_step * 0.5 else (0.1 if step < max_step else 0.01) # noqa + self.lr_scheduler = LambdaLR(self._optimizer, lr_lambda=lr_lambda) + + # use model_wrapper for specialized demands of different modes + self._target_model = copy.deepcopy(self._model) + self._target_model = model_wrap( + self._target_model, + wrapper_name='target', + update_type='assign', + update_kwargs={'freq': self._cfg.target_update_freq} + ) + self._learn_model = self._model + + if self._cfg.use_augmentation: + self.image_transforms = ImageTransforms( + self._cfg.augmentation, + image_shape=(self._cfg.model.observation_shape[1], self._cfg.model.observation_shape[2]) + ) + self.value_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) + self.reward_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) + self.inverse_scalar_transform_handle = InverseScalarTransform( + self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution + ) + + def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: + """ + Overview: + The forward function for learning policy in learn mode, which is the core of the learning process. + The data is sampled from replay buffer. + The loss is calculated by the loss function and the loss is backpropagated to update the model. + Arguments: + - data (:obj:`Tuple[torch.Tensor]`): The data sampled from replay buffer, which is a tuple of tensors. + The first tensor is the current_batch, the second tensor is the target_batch. + Returns: + - info_dict (:obj:`Dict[str, Union[float, int]]`): The information dict to be logged, which contains \ + current learning loss and learning statistics. 
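+        .. note::
+            In this GoBigger variant, ``current_batch`` unpacks into ``(obs, action, child_sampled_actions,
+            mask, indices, weights, make_time)`` and ``target_batch`` into ``(target_value_prefix,
+            target_value, target_policy)``, mirroring the output of the game buffer (see the unpacking below).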
+ """ + self._learn_model.train() + self._target_model.train() + + current_batch, target_batch = data + # ============================================================== + # sampled related core code + # ============================================================== + obs_batch_ori, action_batch, child_sampled_actions_batch, mask_batch, indices, weights, make_time = current_batch + target_value_prefix, target_value, target_policy = target_batch + + obs_batch_ori = obs_batch_ori.tolist() + obs_batch_ori = np.array(obs_batch_ori) + obs_batch = obs_batch_ori[:, 0:self._cfg.model.frame_stack_num] + if self._cfg.model.self_supervised_learning_loss: + obs_target_batch = obs_batch_ori[:, self._cfg.model.frame_stack_num:] + + # do augmentations + # if self._cfg.use_augmentation: + # obs_batch = self.image_transforms.transform(obs_batch) + # if self._cfg.model.self_supervised_learning_loss: + # obs_target_batch = self.image_transforms.transform(obs_target_batch) + + # shape: (batch_size, num_unroll_steps, action_dim) + # NOTE: .float(), in continuous action space. + action_batch = torch.from_numpy(action_batch).to(self._cfg.device).float().unsqueeze(-1) + data_list = [ + mask_batch, + target_value_prefix.astype('float64'), + target_value.astype('float64'), target_policy, weights + ] + [mask_batch, target_value_prefix, target_value, target_policy, + weights] = to_torch_float_tensor(data_list, self._cfg.device) + # ============================================================== + # sampled related core code + # ============================================================== + # shape: (batch_size, num_unroll_steps+1, num_of_sampled_actions, action_dim, 1), e.g. (4, 6, 5, 1, 1) + child_sampled_actions_batch = torch.from_numpy(child_sampled_actions_batch).to(self._cfg.device).unsqueeze(-1) + + target_value_prefix = target_value_prefix.view(self._cfg.batch_size, -1) + target_value = target_value.view(self._cfg.batch_size, -1) + + assert obs_batch.size == self._cfg.batch_size == target_value_prefix.size(0) + + # ``scalar_transform`` to transform the original value to the scaled value, + # i.e. h(.) function in paper https://arxiv.org/pdf/1805.11593.pdf. + transformed_target_value_prefix = scalar_transform(target_value_prefix) + transformed_target_value = scalar_transform(target_value) + # transform a scalar to its categorical_distribution. After this transformation, each scalar is + # represented as the linear combination of its two adjacent supports. + target_value_prefix_categorical = phi_transform(self.reward_support, transformed_target_value_prefix) + target_value_categorical = phi_transform(self.value_support, transformed_target_value) + + # ============================================================== + # the core initial_inference in SampledEfficientZero policy. + # ============================================================== + obs_batch = obs_batch.tolist() + obs_batch = sum(obs_batch, []) + obs_batch = to_tensor(obs_batch) + obs_batch = to_device(obs_batch, self._cfg.device) + network_output = self._learn_model.initial_inference(obs_batch) + # value_prefix shape: (batch_size, 10), the ``value_prefix`` at the first step is zero padding. + latent_state, value_prefix, reward_hidden_state, value, policy_logits = ez_network_output_unpack(network_output) + + # transform the scaled value or its categorical representation to its original value, + # i.e. h^(-1)(.) function in paper https://arxiv.org/pdf/1805.11593.pdf. 
+ original_value = self.inverse_scalar_transform_handle(value) + + # Note: The following lines are just for logging. + predicted_value_prefixs = [] + if self._cfg.monitor_extra_statistics: + latent_state_list = latent_state.detach().cpu().numpy() + predicted_values, predicted_policies = original_value.detach().cpu(), torch.softmax( + policy_logits, dim=1 + ).detach().cpu() + + # calculate the new priorities for each transition. + value_priority = L1Loss(reduction='none')(original_value.squeeze(-1), target_value[:, 0]) + value_priority = value_priority.data.cpu().numpy() + 1e-6 + + # ============================================================== + # calculate policy and value loss for the first step. + # ============================================================== + value_loss = cross_entropy_loss(value, target_value_categorical[:, 0]) + + policy_loss = torch.zeros(self._cfg.batch_size, device=self._cfg.device) + # ============================================================== + # sampled related core code: calculate policy loss, typically cross_entropy_loss + # ============================================================== + if self._cfg.model.continuous_action_space: + """continuous action space""" + policy_loss, policy_entropy, policy_entropy_loss, target_policy_entropy, target_sampled_actions, mu, sigma = self._calculate_policy_loss_cont( + policy_loss, policy_logits, target_policy, mask_batch, child_sampled_actions_batch, unroll_step=0 + ) + else: + """discrete action space""" + policy_loss, policy_entropy, policy_entropy_loss, target_policy_entropy, target_sampled_actions = self._calculate_policy_loss_disc( + policy_loss, policy_logits, target_policy, mask_batch, child_sampled_actions_batch, unroll_step=0 + ) + + value_prefix_loss = torch.zeros(self._cfg.batch_size, device=self._cfg.device) + consistency_loss = torch.zeros(self._cfg.batch_size, device=self._cfg.device) + + gradient_scale = 1 / self._cfg.num_unroll_steps + + # ============================================================== + # the core recurrent_inference in SampledEfficientZero policy. + # ============================================================== + for step_i in range(self._cfg.num_unroll_steps): + # unroll with the dynamics function: predict the next ``latent_state``, ``reward_hidden_state``, + # `` value_prefix`` given current ``latent_state`` ``reward_hidden_state`` and ``action``. + # And then predict policy_logits and value with the prediction function. + network_output = self._learn_model.recurrent_inference( + latent_state, reward_hidden_state, action_batch[:, step_i] + ) + latent_state, value_prefix, reward_hidden_state, value, policy_logits = ez_network_output_unpack( + network_output + ) + + # transform the scaled value or its categorical representation to its original value, + # i.e. h^(-1)(.) function in paper https://arxiv.org/pdf/1805.11593.pdf. + original_value = self.inverse_scalar_transform_handle(value) + + if self._cfg.model.self_supervised_learning_loss: + # ============================================================== + # calculate consistency loss for the next ``num_unroll_steps`` unroll steps. 
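+                # (SimSiam-style: a negative cosine similarity between the projected latent predicted by the
+                # dynamics network and a stop-gradient projection of the latent encoded from the real
+                # observation, masked for out-of-trajectory steps.)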
+ # ============================================================== + if self._cfg.ssl_loss_weight > 0: + beg_index = step_i + end_index = step_i + self._cfg.model.frame_stack_num + obs_target_batch_tmp = obs_target_batch[:, beg_index:end_index].tolist() + obs_target_batch_tmp = sum(obs_target_batch_tmp, []) + obs_target_batch_tmp = to_tensor(obs_target_batch_tmp) + obs_target_batch_tmp = to_device(obs_target_batch_tmp, self._cfg.device) + network_output = self._learn_model.initial_inference(obs_target_batch_tmp) + + + latent_state = to_tensor(latent_state) + representation_state = to_tensor(network_output.latent_state) + + # NOTE: no grad for the representation_state branch. + dynamic_proj = self._learn_model.project(latent_state, with_grad=True) + observation_proj = self._learn_model.project(representation_state, with_grad=False) + temp_loss = negative_cosine_similarity(dynamic_proj, observation_proj) * mask_batch[:, step_i] + + consistency_loss += temp_loss + + # NOTE: the target policy, target_value_categorical, target_value_prefix_categorical is calculated in + # game buffer now. + # ============================================================== + # sampled related core code: + # calculate policy loss for the next ``num_unroll_steps`` unroll steps. + # NOTE: the += in policy loss. + # ============================================================== + if self._cfg.model.continuous_action_space: + """continuous action space""" + policy_loss, policy_entropy, policy_entropy_loss, target_policy_entropy, target_sampled_actions, mu, sigma = self._calculate_policy_loss_cont( + policy_loss, + policy_logits, + target_policy, + mask_batch, + child_sampled_actions_batch, + unroll_step=step_i + 1 + ) + else: + """discrete action space""" + policy_loss, policy_entropy, policy_entropy_loss, target_policy_entropy, target_sampled_actions = self._calculate_policy_loss_disc( + policy_loss, + policy_logits, + target_policy, + mask_batch, + child_sampled_actions_batch, + unroll_step=step_i + 1 + ) + + value_loss += cross_entropy_loss(value, target_value_categorical[:, step_i + 1]) + value_prefix_loss += cross_entropy_loss(value_prefix, target_value_prefix_categorical[:, step_i]) + + # reset hidden states every ``lstm_horizon_len`` unroll steps. + if (step_i + 1) % self._cfg.lstm_horizon_len == 0: + reward_hidden_state = ( + torch.zeros(1, self._cfg.batch_size, self._cfg.model.lstm_hidden_size).to(self._cfg.device), + torch.zeros(1, self._cfg.batch_size, self._cfg.model.lstm_hidden_size).to(self._cfg.device) + ) + + if self._cfg.monitor_extra_statistics: + original_value_prefixs = self.inverse_scalar_transform_handle(value_prefix) + original_value_prefixs_cpu = original_value_prefixs.detach().cpu() + + predicted_values = torch.cat( + (predicted_values, self.inverse_scalar_transform_handle(value).detach().cpu()) + ) + predicted_value_prefixs.append(original_value_prefixs_cpu) + predicted_policies = torch.cat((predicted_policies, torch.softmax(policy_logits, dim=1).detach().cpu())) + latent_state_list = np.concatenate((latent_state_list, latent_state.detach().cpu().numpy())) + + # ============================================================== + # the core learn model update step. + # ============================================================== + # weighted loss with masks (some invalid states which are out of trajectory.) 
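+        # Roughly: loss = w_ssl * consistency + w_policy * policy + w_value * value + w_reward * value_prefix
+        #          + w_entropy * policy_entropy_loss, then averaged with the per-sample importance-sampling
+        # weights from prioritized replay, and its gradient is rescaled by 1 / num_unroll_steps via the hook below.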
+ loss = ( + self._cfg.ssl_loss_weight * consistency_loss + self._cfg.policy_loss_weight * policy_loss + + self._cfg.value_loss_weight * value_loss + self._cfg.reward_loss_weight * value_prefix_loss + + self._cfg.policy_entropy_loss_weight * policy_entropy_loss + ) + weighted_total_loss = (weights * loss).mean() + weighted_total_loss.register_hook(lambda grad: grad * gradient_scale) + self._optimizer.zero_grad() + weighted_total_loss.backward() + total_grad_norm_before_clip = torch.nn.utils.clip_grad_norm_( + self._learn_model.parameters(), self._cfg.grad_clip_value + ) + self._optimizer.step() + if self._cfg.cos_lr_scheduler is True or self._cfg.lr_piecewise_constant_decay is True: + self.lr_scheduler.step() + + # ============================================================== + # the core target model update step. + # ============================================================== + self._target_model.update(self._learn_model.state_dict()) + + loss_data = ( + weighted_total_loss.item(), loss.mean().item(), policy_loss.mean().item(), value_prefix_loss.mean().item(), + value_loss.mean().item(), consistency_loss.mean() + ) + if self._cfg.monitor_extra_statistics: + predicted_value_prefixs = torch.stack(predicted_value_prefixs).transpose(1, 0).squeeze(-1) + predicted_value_prefixs = predicted_value_prefixs.reshape(-1).unsqueeze(-1) + + td_data = ( + value_priority, target_value_prefix.detach().cpu().numpy(), target_value.detach().cpu().numpy(), + transformed_target_value_prefix.detach().cpu().numpy(), transformed_target_value.detach().cpu().numpy(), + target_value_prefix_categorical.detach().cpu().numpy(), target_value_categorical.detach().cpu().numpy(), + predicted_value_prefixs.detach().cpu().numpy(), predicted_values.detach().cpu().numpy(), + target_policy.detach().cpu().numpy(), predicted_policies.detach().cpu().numpy(), latent_state_list + ) + + if self._cfg.model.continuous_action_space: + return { + 'cur_lr': self._optimizer.param_groups[0]['lr'], + 'collect_mcts_temperature': self.collect_mcts_temperature, + 'weighted_total_loss': loss_data[0], + 'total_loss': loss_data[1], + 'policy_loss': loss_data[2], + 'policy_entropy': policy_entropy.item() / (self._cfg.num_unroll_steps + 1), + 'target_policy_entropy': target_policy_entropy.item() / (self._cfg.num_unroll_steps + 1), + 'value_prefix_loss': loss_data[3], + 'value_loss': loss_data[4], + 'consistency_loss': loss_data[5] / self._cfg.num_unroll_steps, + + # ============================================================== + # priority related + # ============================================================== + 'value_priority': td_data[0].flatten().mean().item(), + 'value_priority_orig': value_priority, + 'target_value_prefix': td_data[1].flatten().mean().item(), + 'target_value': td_data[2].flatten().mean().item(), + 'transformed_target_value_prefix': td_data[3].flatten().mean().item(), + 'transformed_target_value': td_data[4].flatten().mean().item(), + 'predicted_value_prefixs': td_data[7].flatten().mean().item(), + 'predicted_values': td_data[8].flatten().mean().item(), + + # ============================================================== + # sampled related core code + # ============================================================== + 'policy_mu_max': mu[:, 0].max().item(), + 'policy_mu_min': mu[:, 0].min().item(), + 'policy_mu_mean': mu[:, 0].mean().item(), + 'policy_sigma_max': sigma.max().item(), + 'policy_sigma_min': sigma.min().item(), + 'policy_sigma_mean': sigma.mean().item(), + # take the fist dim in action space + 
'target_sampled_actions_max': target_sampled_actions[:, :, 0].max().item(), + 'target_sampled_actions_min': target_sampled_actions[:, :, 0].min().item(), + 'target_sampled_actions_mean': target_sampled_actions[:, :, 0].mean().item(), + 'total_grad_norm_before_clip': total_grad_norm_before_clip + } + else: + return { + 'cur_lr': self._optimizer.param_groups[0]['lr'], + 'collect_mcts_temperature': self.collect_mcts_temperature, + 'weighted_total_loss': loss_data[0], + 'total_loss': loss_data[1], + 'policy_loss': loss_data[2], + 'policy_entropy': policy_entropy.item() / (self._cfg.num_unroll_steps + 1), + 'target_policy_entropy': target_policy_entropy.item() / (self._cfg.num_unroll_steps + 1), + 'value_prefix_loss': loss_data[3], + 'value_loss': loss_data[4], + 'consistency_loss': loss_data[5] / self._cfg.num_unroll_steps, + + # ============================================================== + # priority related + # ============================================================== + 'value_priority': td_data[0].flatten().mean().item(), + 'value_priority_orig': value_priority, + 'target_value_prefix': td_data[1].flatten().mean().item(), + 'target_value': td_data[2].flatten().mean().item(), + 'transformed_target_value_prefix': td_data[3].flatten().mean().item(), + 'transformed_target_value': td_data[4].flatten().mean().item(), + 'predicted_value_prefixs': td_data[7].flatten().mean().item(), + 'predicted_values': td_data[8].flatten().mean().item(), + + # ============================================================== + # sampled related core code + # ============================================================== + # take the fist dim in action space + 'target_sampled_actions_max': target_sampled_actions[:, :].float().max().item(), + 'target_sampled_actions_min': target_sampled_actions[:, :].float().min().item(), + 'target_sampled_actions_mean': target_sampled_actions[:, :].float().mean().item(), + 'total_grad_norm_before_clip': total_grad_norm_before_clip + } + + def _calculate_policy_loss_cont( + self, policy_loss: torch.Tensor, policy_logits: torch.Tensor, target_policy: torch.Tensor, + mask_batch: torch.Tensor, child_sampled_actions_batch: torch.Tensor, unroll_step: int + ) -> Tuple[torch.Tensor]: + """ + Overview: + Calculate the policy loss for continuous action space. + Arguments: + - policy_loss (:obj:`torch.Tensor`): The policy loss tensor. + - policy_logits (:obj:`torch.Tensor`): The policy logits tensor. + - target_policy (:obj:`torch.Tensor`): The target policy tensor. + - mask_batch (:obj:`torch.Tensor`): The mask tensor. + - child_sampled_actions_batch (:obj:`torch.Tensor`): The child sampled actions tensor. + - unroll_step (:obj:`int`): The unroll step. + Returns: + - policy_loss (:obj:`torch.Tensor`): The policy loss tensor. + - policy_entropy (:obj:`torch.Tensor`): The policy entropy tensor. + - policy_entropy_loss (:obj:`torch.Tensor`): The policy entropy loss tensor. + - target_policy_entropy (:obj:`torch.Tensor`): The target policy entropy tensor. + - target_sampled_actions (:obj:`torch.Tensor`): The target sampled actions tensor. + - mu (:obj:`torch.Tensor`): The mu tensor. + - sigma (:obj:`torch.Tensor`): The sigma tensor. + """ + (mu, sigma + ) = policy_logits[:, :self._cfg.model.action_space_size], policy_logits[:, -self._cfg.model.action_space_size:] + + dist = Independent(Normal(mu, sigma), 1) + + # take the init hypothetical step k=unroll_step + target_normalized_visit_count = target_policy[:, unroll_step] + + # ******* NOTE: target_policy_entropy is only for debug. 
****** + non_masked_indices = torch.nonzero(mask_batch[:, unroll_step]).squeeze(-1) + # Check if there are any unmasked rows + if len(non_masked_indices) > 0: + target_normalized_visit_count_masked = torch.index_select( + target_normalized_visit_count, 0, non_masked_indices + ) + target_dist = Categorical(target_normalized_visit_count_masked) + target_policy_entropy = target_dist.entropy().mean() + else: + # Set target_policy_entropy to 0 if all rows are masked + target_policy_entropy = 0 + + # shape: (batch_size, num_unroll_steps, num_of_sampled_actions, action_dim, 1) -> (batch_size, + # num_of_sampled_actions, action_dim) e.g. (4, 6, 20, 2, 1) -> (4, 20, 2) + target_sampled_actions = child_sampled_actions_batch[:, unroll_step].squeeze(-1) + + policy_entropy = dist.entropy().mean() + policy_entropy_loss = -dist.entropy() + + # Project the sampled-based improved policy back onto the space of representable policies. calculate KL + # loss (batch_size, num_of_sampled_actions) -> (4,20) target_normalized_visit_count is + # categorical distribution, the range of target_log_prob_sampled_actions is (-inf, 0), add 1e-6 for + # numerical stability. + target_log_prob_sampled_actions = torch.log(target_normalized_visit_count + 1e-6) + log_prob_sampled_actions = [] + for k in range(self._cfg.model.num_of_sampled_actions): + # target_sampled_actions[:,i,:].shape: batch_size, action_dim -> 4,2 + # dist.log_prob(target_sampled_actions[:,i,:]).shape: batch_size -> 4 + # dist is normal distribution, the range of log_prob_sampled_actions is (-inf, inf) + + # way 1: + # log_prob = dist.log_prob(target_sampled_actions[:, k, :]) + + # way 2: SAC-like + y = 1 - target_sampled_actions[:, k, :].pow(2) + + # NOTE: for numerical stability. + target_sampled_actions_clamped = torch.clamp( + target_sampled_actions[:, k, :], torch.tensor(-1 + 1e-6), torch.tensor(1 - 1e-6) + ) + target_sampled_actions_before_tanh = torch.arctanh(target_sampled_actions_clamped) + + # keep dimension for loss computation (usually for action space is 1 env. e.g. pendulum) + log_prob = dist.log_prob(target_sampled_actions_before_tanh).unsqueeze(-1) + log_prob = log_prob - torch.log(y + 1e-6).sum(-1, keepdim=True) + log_prob = log_prob.squeeze(-1) + + log_prob_sampled_actions.append(log_prob) + + # shape: (batch_size, num_of_sampled_actions) e.g. (4,20) + log_prob_sampled_actions = torch.stack(log_prob_sampled_actions, dim=-1) + + if self._cfg.normalize_prob_of_sampled_actions: + # normalize the prob of sampled actions + prob_sampled_actions_norm = torch.exp(log_prob_sampled_actions) / torch.exp(log_prob_sampled_actions).sum( + -1 + ).unsqueeze(-1).repeat(1, log_prob_sampled_actions.shape[-1]).detach() + # the above line is equal to the following line. + # prob_sampled_actions_norm = F.normalize(torch.exp(log_prob_sampled_actions), p=1., dim=-1, eps=1e-6) + log_prob_sampled_actions = torch.log(prob_sampled_actions_norm + 1e-6) + + # NOTE: the +=. 
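+        # The target probabilities (exp of the detached target log-probs) carry no gradient, so both loss
+        # variants below only backpropagate through ``log_prob_sampled_actions`` of the current policy.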
+ if self._cfg.policy_loss_type == 'KL': + # KL divergence loss: sum( p* log(p/q) ) = sum( p*log(p) - p*log(q) )= sum( p*log(p)) - sum( p*log(q) ) + policy_loss += ( + torch.exp(target_log_prob_sampled_actions.detach()) * + (target_log_prob_sampled_actions.detach() - log_prob_sampled_actions) + ).sum(-1) * mask_batch[:, unroll_step] + elif self._cfg.policy_loss_type == 'cross_entropy': + # cross_entropy loss: - sum(p * log (q) ) + policy_loss += -torch.sum( + torch.exp(target_log_prob_sampled_actions.detach()) * log_prob_sampled_actions, 1 + ) * mask_batch[:, unroll_step] + + return policy_loss, policy_entropy, policy_entropy_loss, target_policy_entropy, target_sampled_actions, mu, sigma + + def _calculate_policy_loss_disc( + self, policy_loss: torch.Tensor, policy_logits: torch.Tensor, target_policy: torch.Tensor, + mask_batch: torch.Tensor, child_sampled_actions_batch: torch.Tensor, unroll_step: int + ) -> Tuple[torch.Tensor]: + """ + Overview: + Calculate the policy loss for discrete action space. + Arguments: + - policy_loss (:obj:`torch.Tensor`): The policy loss tensor. + - policy_logits (:obj:`torch.Tensor`): The policy logits tensor. + - target_policy (:obj:`torch.Tensor`): The target policy tensor. + - mask_batch (:obj:`torch.Tensor`): The mask tensor. + - child_sampled_actions_batch (:obj:`torch.Tensor`): The child sampled actions tensor. + - unroll_step (:obj:`int`): The unroll step. + Returns: + - policy_loss (:obj:`torch.Tensor`): The policy loss tensor. + - policy_entropy (:obj:`torch.Tensor`): The policy entropy tensor. + - policy_entropy_loss (:obj:`torch.Tensor`): The policy entropy loss tensor. + - target_policy_entropy (:obj:`torch.Tensor`): The target policy entropy tensor. + - target_sampled_actions (:obj:`torch.Tensor`): The target sampled actions tensor. + """ + prob = torch.softmax(policy_logits, dim=-1) + dist = Categorical(prob) + + # take the init hypothetical step k=unroll_step + target_normalized_visit_count = target_policy[:, unroll_step] + + # Note: The target_policy_entropy is just for debugging. + target_normalized_visit_count_masked = torch.index_select( + target_normalized_visit_count, 0, + torch.nonzero(mask_batch[:, unroll_step]).squeeze(-1) + ) + target_dist = Categorical(target_normalized_visit_count_masked) + target_policy_entropy = target_dist.entropy().mean() + + # shape: (batch_size, num_unroll_steps, num_of_sampled_actions, action_dim, 1) -> (batch_size, + # num_of_sampled_actions, action_dim) e.g. (4, 6, 20, 2, 1) -> (4, 20, 2) + target_sampled_actions = child_sampled_actions_batch[:, unroll_step].squeeze(-1) + + policy_entropy = dist.entropy().mean() + policy_entropy_loss = -dist.entropy() + + # Project the sampled-based improved policy back onto the space of representable policies. calculate KL + # loss (batch_size, num_of_sampled_actions) -> (4,20) target_normalized_visit_count is + # categorical distribution, the range of target_log_prob_sampled_actions is (-inf, 0), add 1e-6 for + # numerical stability. + target_log_prob_sampled_actions = torch.log(target_normalized_visit_count + 1e-6) + + log_prob_sampled_actions = [] + for k in range(self._cfg.model.num_of_sampled_actions): + # target_sampled_actions[:,i,:] shape: (batch_size, action_dim) e.g. (4,2) + # dist.log_prob(target_sampled_actions[:,i,:]) shape: batch_size e.g. 
4 + # dist is normal distribution, the range of log_prob_sampled_actions is (-inf, inf) + + if len(target_sampled_actions.shape) == 2: + target_sampled_actions = target_sampled_actions.unsqueeze(-1) + + log_prob = torch.log(prob.gather(-1, target_sampled_actions[:, k].long()).squeeze(-1) + 1e-6) + log_prob_sampled_actions.append(log_prob) + + # (batch_size, num_of_sampled_actions) e.g. (4,20) + log_prob_sampled_actions = torch.stack(log_prob_sampled_actions, dim=-1) + + if self._cfg.normalize_prob_of_sampled_actions: + # normalize the prob of sampled actions + prob_sampled_actions_norm = torch.exp(log_prob_sampled_actions) / torch.exp(log_prob_sampled_actions).sum( + -1 + ).unsqueeze(-1).repeat(1, log_prob_sampled_actions.shape[-1]).detach() + # the above line is equal to the following line. + # prob_sampled_actions_norm = F.normalize(torch.exp(log_prob_sampled_actions), p=1., dim=-1, eps=1e-6) + log_prob_sampled_actions = torch.log(prob_sampled_actions_norm + 1e-6) + + # NOTE: the +=. + if self._cfg.policy_loss_type == 'KL': + # KL divergence loss: sum( p* log(p/q) ) = sum( p*log(p) - p*log(q) )= sum( p*log(p)) - sum( p*log(q) ) + policy_loss += ( + torch.exp(target_log_prob_sampled_actions.detach()) * + (target_log_prob_sampled_actions.detach() - log_prob_sampled_actions) + ).sum(-1) * mask_batch[:, unroll_step] + elif self._cfg.policy_loss_type == 'cross_entropy': + # cross_entropy loss: - sum(p * log (q) ) + policy_loss += -torch.sum( + torch.exp(target_log_prob_sampled_actions.detach()) * log_prob_sampled_actions, 1 + ) * mask_batch[:, unroll_step] + + return policy_loss, policy_entropy, policy_entropy_loss, target_policy_entropy, target_sampled_actions + + def _init_collect(self) -> None: + """ + Overview: + Collect mode init method. Called by ``self.__init__``. Initialize the collect model and MCTS utils. + """ + self._collect_model = self._model + if self._cfg.mcts_ctree: + self._mcts_collect = MCTSCtree(self._cfg) + else: + self._mcts_collect = MCTSPtree(self._cfg) + self.collect_mcts_temperature = 1 + + def _forward_collect( + self, data: torch.Tensor, action_mask: list = None, temperature: np.ndarray = 1, to_play=-1, ready_env_id=None + ): + """ + Overview: + The forward function for collecting data in collect mode. Use model to execute MCTS search. + Choosing the action through sampling during the collect mode. + Arguments: + - data (:obj:`torch.Tensor`): The input data, i.e. the observation. + - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. + - temperature (:obj:`float`): The temperature of the policy. + - to_play (:obj:`int`): The player to play. + - ready_env_id (:obj:`list`): The id of the env that is ready to collect. + Shape: + - data (:obj:`torch.Tensor`): + - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ + S is the number of stacked frames, H is the height of the image, W is the width of the image. + - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size. + - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env. + - temperature: :math:`(1, )`. + - to_play: :math:`(N, 1)`, where N is the number of collect_env. + - ready_env_id: None + Returns: + - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ + ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. 
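+        .. note::
+            In this multi-agent GoBigger variant the returned ``output`` is keyed by env id, and every field is
+            a per-agent list, e.g. ``output[env_id]['action'] == [action_of_agent_0, action_of_agent_1, ...]``.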
+ """ + self._collect_model.eval() + self.collect_mcts_temperature = temperature + + active_collect_env_num = len(data) + data = to_tensor(data) + data = sum(sum(data, []), []) + batch_size = len(data) + data = to_device(data, self._cfg.device) + agent_num = batch_size // active_collect_env_num + to_play = np.array(to_play).reshape(-1).tolist() + + with torch.no_grad(): + # data shape [B, S x C, W, H], e.g. {Tensor:(B, 12, 96, 96)} + network_output = self._collect_model.initial_inference(data) + latent_state_roots, value_prefix_roots, reward_hidden_state_roots, pred_values, policy_logits = ez_network_output_unpack( + network_output + ) + + if not self._learn_model.training: + # if not in training, obtain the scalars of the value/reward + pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() + latent_state_roots = latent_state_roots.detach().cpu().numpy() + reward_hidden_state_roots = ( + reward_hidden_state_roots[0].detach().cpu().numpy(), + reward_hidden_state_roots[1].detach().cpu().numpy() + ) + policy_logits = policy_logits.detach().cpu().numpy().tolist() + + action_mask = sum(action_mask, []) + if self._cfg.model.continuous_action_space is True: + # when the action space of the environment is continuous, action_mask[:] is None. + # NOTE: in continuous action space env: we set all legal_actions as -1 + legal_actions = [ + [-1 for _ in range(self._cfg.model.num_of_sampled_actions)] for _ in range(batch_size) + ] + else: + legal_actions = [ + [i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(batch_size) + ] + + if self._cfg.mcts_ctree: + # cpp mcts_tree + roots = MCTSCtree.roots( + batch_size, legal_actions, self._cfg.model.action_space_size, + self._cfg.model.num_of_sampled_actions, self._cfg.model.continuous_action_space + ) + else: + # python mcts_tree + roots = MCTSPtree.roots( + batch_size, legal_actions, self._cfg.model.action_space_size, + self._cfg.model.num_of_sampled_actions, self._cfg.model.continuous_action_space + ) + + # the only difference between collect and eval is the dirichlet noise + noises = [ + np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(self._cfg.model.num_of_sampled_actions) + ).astype(np.float32).tolist() for j in range(batch_size) + ] + + roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_roots, policy_logits, to_play) + self._mcts_collect.search( + roots, self._collect_model, latent_state_roots, reward_hidden_state_roots, to_play + ) + + roots_visit_count_distributions = roots.get_distributions( + ) # shape: ``{list: batch_size} ->{list: action_space_size}`` + roots_values = roots.get_values() # shape: {list: batch_size} + roots_sampled_actions = roots.get_sampled_actions() # {list: 1}->{list:6} + + data_id = [i for i in range(active_collect_env_num)] + output = {i: defaultdict(list) for i in data_id} + if ready_env_id is None: + ready_env_id = np.arange(active_collect_env_num) + + for i in range(batch_size): + distributions, value = roots_visit_count_distributions[i], roots_values[i] + try: + root_sampled_actions = np.array([action.value for action in roots_sampled_actions[i]]) + except Exception: + # logging.warning('ctree_sampled_efficientzero roots.get_sampled_actions() return list') + root_sampled_actions = np.array([action for action in roots_sampled_actions[i]]) + # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents + # the index within the legal action set, rather than the index in the entire action set. 
+                action, visit_count_distribution_entropy = select_action(
+                    distributions, temperature=self.collect_mcts_temperature, deterministic=False
+                )
+                try:
+                    action = roots_sampled_actions[i][action].value
+                    # logging.warning('ptree_sampled_efficientzero roots.get_sampled_actions() return array')
+                except Exception:
+                    # logging.warning('ctree_sampled_efficientzero roots.get_sampled_actions() return list')
+                    action = np.array(roots_sampled_actions[i][action])
+
+                if not self._cfg.model.continuous_action_space:
+                    if len(action.shape) == 0:
+                        action = int(action)
+                    elif len(action.shape) == 1:
+                        action = int(action[0])
+
+                output[i//agent_num]['action'].append(action)
+                output[i//agent_num]['distributions'].append(distributions)
+                output[i//agent_num]['root_sampled_actions'].append(root_sampled_actions)
+                output[i//agent_num]['visit_count_distribution_entropy'].append(visit_count_distribution_entropy)
+                output[i//agent_num]['value'].append(value)
+                output[i//agent_num]['pred_value'].append(pred_values[i])
+                output[i//agent_num]['policy_logits'].append(policy_logits[i])
+
+
+        return output
+
+    def _init_eval(self) -> None:
+        """
+        Overview:
+            Evaluate mode init method. Called by ``self.__init__``. Initialize the eval model and MCTS utils.
+        """
+        self._eval_model = self._model
+        if self._cfg.mcts_ctree:
+            self._mcts_eval = MCTSCtree(self._cfg)
+        else:
+            self._mcts_eval = MCTSPtree(self._cfg)
+
+    def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, ready_env_id=None):
+        """
+        Overview:
+            The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search.
+            Choosing the action with the highest value (argmax) rather than sampling during the eval mode.
+        Arguments:
+            - data (:obj:`torch.Tensor`): The input data, i.e. the observation.
+            - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected.
+            - to_play (:obj:`int`): The player to play.
+            - ready_env_id (:obj:`list`): The id of the env that is ready to collect.
+        Shape:
+            - data (:obj:`torch.Tensor`):
+                - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \
+                    S is the number of stacked frames, H is the height of the image, W is the width of the image.
+                - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size.
+            - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env.
+            - to_play: :math:`(N, 1)`, where N is the number of collect_env.
+            - ready_env_id: None
+        Returns:
+            - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \
+                ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``.
+        """
+        self._eval_model.eval()
+        active_eval_env_num = len(data)
+        data = to_tensor(data)
+        data = sum(sum(data, []), [])
+        batch_size = len(data)
+        data = to_device(data, self._cfg.device)
+        agent_num = batch_size // active_eval_env_num
+        to_play = np.array(to_play).reshape(-1).tolist()
+
+        with torch.no_grad():
+            # data shape [B, S x C, W, H], e.g.
{Tensor:(B, 12, 96, 96)} + network_output = self._eval_model.initial_inference(data) + latent_state_roots, value_prefix_roots, reward_hidden_state_roots, pred_values, policy_logits = ez_network_output_unpack( + network_output + ) + + if not self._eval_model.training: + # if not in training, obtain the scalars of the value/reward + pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape(B, 1) + latent_state_roots = latent_state_roots.detach().cpu().numpy() + reward_hidden_state_roots = ( + reward_hidden_state_roots[0].detach().cpu().numpy(), + reward_hidden_state_roots[1].detach().cpu().numpy() + ) + policy_logits = policy_logits.detach().cpu().numpy().tolist() # list shape(B, A) + + action_mask = sum(action_mask, []) + if self._cfg.model.continuous_action_space is True: + # when the action space of the environment is continuous, action_mask[:] is None. + # NOTE: in continuous action space env: we set all legal_actions as -1 + legal_actions = [ + [-1 for _ in range(self._cfg.model.num_of_sampled_actions)] for _ in range(batch_size) + ] + else: + legal_actions = [ + [i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(batch_size) + ] + + # cpp mcts_tree + if self._cfg.mcts_ctree: + roots = MCTSCtree.roots( + batch_size, legal_actions, self._cfg.model.action_space_size, + self._cfg.model.num_of_sampled_actions, self._cfg.model.continuous_action_space + ) + else: + # python mcts_tree + roots = MCTSPtree.roots( + batch_size, legal_actions, self._cfg.model.action_space_size, + self._cfg.model.num_of_sampled_actions, self._cfg.model.continuous_action_space + ) + + roots.prepare_no_noise(value_prefix_roots, policy_logits, to_play) + self._mcts_eval.search(roots, self._eval_model, latent_state_roots, reward_hidden_state_roots, to_play) + + roots_visit_count_distributions = roots.get_distributions( + ) # shape: ``{list: batch_size} ->{list: action_space_size}`` + roots_values = roots.get_values() # shape: {list: batch_size} + # ============================================================== + # sampled related core code + # ============================================================== + roots_sampled_actions = roots.get_sampled_actions( + ) # shape: ``{list: batch_size} ->{list: action_space_size}`` + + data_id = [i for i in range(active_eval_env_num)] + output = {i: defaultdict(list) for i in data_id} + + if ready_env_id is None: + ready_env_id = np.arange(active_eval_env_num) + + for i in range(batch_size): + distributions, value = roots_visit_count_distributions[i], roots_values[i] + try: + root_sampled_actions = np.array([action.value for action in roots_sampled_actions[i]]) + except Exception: + # logging.warning('ctree_sampled_efficientzero roots.get_sampled_actions() return list') + root_sampled_actions = np.array([action for action in roots_sampled_actions[i]]) + # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents + # the index within the legal action set, rather than the index in the entire action set. + # Setting deterministic=True implies choosing the action with the highest value (argmax) rather than sampling during the evaluation phase. 
+ action, visit_count_distribution_entropy = select_action( + distributions, temperature=1, deterministic=True + ) + # ============================================================== + # sampled related core code + # ============================================================== + + try: + action = roots_sampled_actions[i][action].value + # logging.warning('ptree_sampled_efficientzero roots.get_sampled_actions() return array') + except Exception: + # logging.warning('ctree_sampled_efficientzero roots.get_sampled_actions() return list') + action = np.array(roots_sampled_actions[i][action]) + + if not self._cfg.model.continuous_action_space: + if len(action.shape) == 0: + action = int(action) + elif len(action.shape) == 1: + action = int(action[0]) + output[i//agent_num]['action'].append(action) + output[i//agent_num]['distributions'].append(distributions) + output[i//agent_num]['root_sampled_actions'].append(root_sampled_actions) + output[i//agent_num]['visit_count_distribution_entropy'].append(visit_count_distribution_entropy) + output[i//agent_num]['value'].append(value) + output[i//agent_num]['pred_value'].append(pred_values[i]) + output[i//agent_num]['policy_logits'].append(policy_logits[i]) + + return output + + def _monitor_vars_learn(self) -> List[str]: + """ + Overview: + Register the variables to be monitored in learn mode. The registered variables will be logged in + tensorboard according to the return value ``_forward_learn``. + """ + if self._cfg.model.continuous_action_space: + return [ + 'collect_mcts_temperature', + 'cur_lr', + 'total_loss', + 'weighted_total_loss', + 'policy_loss', + 'value_prefix_loss', + 'value_loss', + 'consistency_loss', + 'value_priority', + 'target_value_prefix', + 'target_value', + 'predicted_value_prefixs', + 'predicted_values', + 'transformed_target_value_prefix', + 'transformed_target_value', + + # ============================================================== + # sampled related core code + # ============================================================== + 'policy_entropy', + 'target_policy_entropy', + 'policy_mu_max', + 'policy_mu_min', + 'policy_mu_mean', + 'policy_sigma_max', + 'policy_sigma_min', + 'policy_sigma_mean', + # take the fist dim in action space + 'target_sampled_actions_max', + 'target_sampled_actions_min', + 'target_sampled_actions_mean', + 'total_grad_norm_before_clip', + ] + else: + return [ + 'collect_mcts_temperature', + 'cur_lr', + 'total_loss', + 'weighted_total_loss', + 'loss_mean', + 'policy_loss', + 'value_prefix_loss', + 'value_loss', + 'consistency_loss', + 'value_priority', + 'target_value_prefix', + 'target_value', + 'predicted_value_prefixs', + 'predicted_values', + 'transformed_target_value_prefix', + 'transformed_target_value', + + # ============================================================== + # sampled related core code + # ============================================================== + 'policy_entropy', + 'target_policy_entropy', + + # take the fist dim in action space + 'target_sampled_actions_max', + 'target_sampled_actions_min', + 'target_sampled_actions_mean', + 'total_grad_norm_before_clip', + ] + + def _state_dict_learn(self) -> Dict[str, Any]: + """ + Overview: + Return the state_dict of learn mode, usually including model and optimizer. + Returns: + - state_dict (:obj:`Dict[str, Any]`): the dict of current policy learn state, for saving and restoring. 
+ """ + return { + 'model': self._learn_model.state_dict(), + 'target_model': self._target_model.state_dict(), + 'optimizer': self._optimizer.state_dict(), + } + + def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: + """ + Overview: + Load the state_dict variable into policy learn mode. + Arguments: + - state_dict (:obj:`Dict[str, Any]`): the dict of policy learn state saved before. + """ + self._learn_model.load_state_dict(state_dict['model']) + self._target_model.load_state_dict(state_dict['target_model']) + self._optimizer.load_state_dict(state_dict['optimizer']) + + def _process_transition(self, obs, policy_output, timestep): + # be compatible with DI-engine Policy class + pass + + def _get_train_sample(self, data): + # be compatible with DI-engine Policy class + pass diff --git a/zoo/gobigger/config/gobigger_sampled_efficientzero_config.py b/zoo/gobigger/config/gobigger_sampled_efficientzero_config.py new file mode 100644 index 000000000..8bbef8236 --- /dev/null +++ b/zoo/gobigger/config/gobigger_sampled_efficientzero_config.py @@ -0,0 +1,145 @@ +from easydict import EasyDict + +env_name = 'GoBigger' + +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +continuous_action_space = False +K = 20 # num_of_sampled_actions +collector_env_num = 32 +n_episode = 32 +evaluator_env_num = 5 +num_simulations = 50 +update_per_collect = 2000 +batch_size = 256 +reanalyze_ratio = 0. +action_space_size = 27 +direction_num=12 +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +atari_sampled_efficientzero_config = dict( + exp_name= + f'data_sez_ctree/{env_name[:-14]}_sampled_efficientzero_k{K}_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', + env=dict( + env_name=env_name, + team_num=2, + player_num_per_team=2, + direction_num=direction_num, + step_mul=8, + map_width=64, + map_height=64, + frame_limit=3600, + action_space_size=action_space_size, + use_action_mask=False, + reward_div_value=0.1, + reward_type='log_reward', + start_spirit_progress=0.2, + end_spirit_progress=0.8, + manager_settings=dict( + food_manager=dict( + num_init=260, + num_min=260, + num_max=300, + ), + thorns_manager=dict( + num_init=3, + num_min=3, + num_max=4, + ), + player_manager=dict( + ball_settings=dict( + score_init=13000, + ), + ), + ), + playback_settings=dict( + playback_type='by_frame', + by_frame=dict( + save_frame=False, + # save_frame=True, + save_dir='./', + save_name_prefix='gobigger', + ), + ), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + model=dict( + latent_state_dim=176, + frame_stack_num=1, + action_space_size=action_space_size, + downsample=True, + continuous_action_space=continuous_action_space, + num_of_sampled_actions=K, + discrete_action_encoding_type='one_hot', + norm_type='BN', + ), + cuda=True, + mcts_ctree=True, + env_type='not_board_games', + game_segment_length=400, + use_augmentation=False, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='SGD', + lr_piecewise_constant_decay=True, + learning_rate=0.2, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + 
policy_loss_type='cross_entropy', + n_episode=n_episode, + eval_freq=int(2e3), + replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), + learn=dict( + learner=dict( + log_policy=True, + hook=dict( + log_show_after_iter=10, + ), + ), + ), + collect=dict( + collector=dict( + collect_print_freq=10, + ), + ), + eval=dict( + evaluator=dict( + eval_freq=5000, + stop_value=10000000000, + ), + ), +) +atari_sampled_efficientzero_config = EasyDict(atari_sampled_efficientzero_config) +main_config = atari_sampled_efficientzero_config + +atari_sampled_efficientzero_create_config = dict( + env=dict( + type='gobigger_lightzero', + import_names=['zoo.gobigger.env.gobigger_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='gobigger_sampled_efficientzero', + import_names=['lzero.policy.gobigger_sampled_efficientzero'], + ), + collector=dict( + type='gobigger_episode_muzero', + import_names=['lzero.worker.gobigger_muzero_collector'], + ) +) +atari_sampled_efficientzero_create_config = EasyDict(atari_sampled_efficientzero_create_config) +create_config = atari_sampled_efficientzero_create_config + +if __name__ == "__main__": + from lzero.entry import train_muzero_gobigger + train_muzero_gobigger([main_config, create_config], seed=0) From 4925d01766baa9655131b1dc691f9ddcd898842e Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Wed, 7 Jun 2023 11:51:02 +0800 Subject: [PATCH 07/54] feature(yzj): add gobigger visualization and polish gobigger eval config --- lzero/entry/__init__.py | 1 + lzero/entry/eval_muzero_gobigger.py | 98 +++++++++++++++++++ .../config/gobigger_efficientzero_config.py | 3 +- zoo/gobigger/config/gobigger_eval_config.py | 33 +++++++ 4 files changed, 133 insertions(+), 2 deletions(-) create mode 100644 lzero/entry/eval_muzero_gobigger.py create mode 100644 zoo/gobigger/config/gobigger_eval_config.py diff --git a/lzero/entry/__init__.py b/lzero/entry/__init__.py index 60ac5f42b..c6db99f30 100644 --- a/lzero/entry/__init__.py +++ b/lzero/entry/__init__.py @@ -5,3 +5,4 @@ from .eval_muzero_with_gym_env import eval_muzero_with_gym_env from .train_muzero_with_gym_env import train_muzero_with_gym_env from .train_muzero_gobigger import train_muzero_gobigger +from .eval_muzero_gobigger import eval_muzero_gobigger diff --git a/lzero/entry/eval_muzero_gobigger.py b/lzero/entry/eval_muzero_gobigger.py new file mode 100644 index 000000000..e5d65901b --- /dev/null +++ b/lzero/entry/eval_muzero_gobigger.py @@ -0,0 +1,98 @@ +import logging +import os +from functools import partial +from typing import Optional, Tuple +import numpy as np +import torch +from tensorboardX import SummaryWriter + +from ding.config import compile_config +from ding.envs import create_env_manager +from ding.envs import get_vec_env_setting +from ding.policy import create_policy +from ding.utils import set_pkg_seed +from ding.worker import BaseLearner +from lzero.worker import GoBiggerMuZeroEvaluator + +def eval_muzero_gobigger( + input_cfg: Tuple[dict, dict], + seed: int = 0, + model: Optional[torch.nn.Module] = None, + model_path: Optional[str] = None, +) -> 'Policy': # noqa + """ + Overview: + The train entry for MCTS+RL algorithms, including MuZero, EfficientZero, Sampled EfficientZero. + Arguments: + - input_cfg (:obj:`Tuple[dict, dict]`): Config in dict type. + ``Tuple[dict, dict]`` type means [user_config, create_cfg]. + - seed (:obj:`int`): Random seed. 
+        - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
+        - model_path (:obj:`Optional[str]`): The pretrained model path, which should
+            point to the ckpt file of the pretrained model, and an absolute path is recommended.
+            In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``.
+    Returns:
+        - reward_sp, reward_vsbot (:obj:`float`): The mean episode return of the self-play and the vs-bot evaluation, respectively.
+    """
+    cfg, create_cfg = input_cfg
+    assert create_cfg.policy.type in ['gobigger_efficientzero', 'gobigger_muzero', 'gobigger_sampled_efficientzero'], \
+        "eval_muzero_gobigger entry now only supports the following algo.: 'gobigger_efficientzero', 'gobigger_muzero', 'gobigger_sampled_efficientzero'"
+
+    if create_cfg.policy.type == 'gobigger_efficientzero':
+        from lzero.mcts import GoBiggerEfficientZeroGameBuffer as GameBuffer
+    elif create_cfg.policy.type == 'gobigger_muzero':
+        from lzero.mcts import GoBiggerMuZeroGameBuffer as GameBuffer
+    elif create_cfg.policy.type == 'gobigger_sampled_efficientzero':
+        from lzero.mcts import GoBiggerSampledEfficientZeroGameBuffer as GameBuffer
+
+    if cfg.policy.cuda and torch.cuda.is_available():
+        cfg.policy.device = 'cuda'
+    else:
+        cfg.policy.device = 'cpu'
+
+    cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)
+    # Create main components: env, policy
+    env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
+    collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
+    evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
+
+    collector_env.seed(cfg.seed)
+    evaluator_env.seed(cfg.seed, dynamic_seed=False)
+    set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+    policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval'])
+
+    # load pretrained model
+    if model_path is not None:
+        policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device))
+
+    # Create worker components: learner, collector, evaluator, replay buffer, commander.
+    tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+    learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+
+    # ==============================================================
+    # MCTS+RL algorithms related core code
+    # ==============================================================
+    policy_config = cfg.policy
+    evaluator = GoBiggerMuZeroEvaluator(
+        eval_freq=cfg.policy.eval_freq,
+        n_evaluator_episode=cfg.env.n_evaluator_episode,
+        stop_value=cfg.env.stop_value,
+        env=evaluator_env,
+        policy=policy.eval_mode,
+        tb_logger=tb_logger,
+        exp_name=cfg.exp_name,
+        policy_config=policy_config
+    )
+
+    # ==============================================================
+    # Main loop
+    # ==============================================================
+    # Learner's before_run hook.
+ learner.call_hook('before_run') + # ============================================================== + # eval trained model + # ============================================================== + _, reward_sp = evaluator.eval(learner.save_checkpoint, learner.train_iter) + _, reward_vsbot= evaluator.eval_vsbot(learner.save_checkpoint, learner.train_iter) + return reward_sp, reward_vsbot diff --git a/zoo/gobigger/config/gobigger_efficientzero_config.py b/zoo/gobigger/config/gobigger_efficientzero_config.py index b0268b566..585b6d7d6 100644 --- a/zoo/gobigger/config/gobigger_efficientzero_config.py +++ b/zoo/gobigger/config/gobigger_efficientzero_config.py @@ -56,8 +56,7 @@ playback_settings=dict( playback_type='by_frame', by_frame=dict( - save_frame=False, - # save_frame=True, + save_frame=False, # when training should set as False save_dir='./', save_name_prefix='gobigger', ), diff --git a/zoo/gobigger/config/gobigger_eval_config.py b/zoo/gobigger/config/gobigger_eval_config.py new file mode 100644 index 000000000..aa94ea2fa --- /dev/null +++ b/zoo/gobigger/config/gobigger_eval_config.py @@ -0,0 +1,33 @@ +# According to the model you want to evaluate, import the corresponding config. +from lzero.entry import eval_muzero_gobigger +import numpy as np + +if __name__ == "__main__": + """ + model_path (:obj:`Optional[str]`): The pretrained model path, which should + point to the ckpt file of the pretrained model, and an absolute path is recommended. + In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``. + """ + # sez + # from gobigger_efficientzero_config import main_config, create_config + from gobigger_muzero_config import main_config, create_config + model_path = "/path/ckpt/ckpt_best.pth.tar" + + returns_mean_seeds = [] + returns_seeds = [] + seeds = [0] + create_config.env_manager.type = 'base' # when visualize must set as base + main_config.env.evaluator_env_num = 1 # when visualize must set as 1 + main_config.env.n_evaluator_episode = 2 # each seed eval episodes num + main_config.env.playback_settings.by_frame.save_frame=True + main_config.env.playback_settings.by_frame.save_name_prefix= 'gobigger' + + for seed in seeds: + returns_selfplay_mean, returns_vsbot_mean = eval_muzero_gobigger( + [main_config, create_config], + seed=seed, + model_path=model_path, + ) + print('seed: {}'.format(seed)) + print('returns_selfplay_mean: {}'.format(returns_selfplay_mean)) + print('returns_vsbot_mean: {}'.format(returns_vsbot_mean)) From b8e044eacb4e2ceb61e965c80a52ba7ccf0a8b14 Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Wed, 7 Jun 2023 17:00:16 +0800 Subject: [PATCH 08/54] fix(yzj): fix eval_episode_return and polish env --- zoo/gobigger/env/gobigger_env.py | 90 ++++++++++++++++++-------------- 1 file changed, 50 insertions(+), 40 deletions(-) diff --git a/zoo/gobigger/env/gobigger_env.py b/zoo/gobigger/env/gobigger_env.py index 1077a93db..c4f727edb 100644 --- a/zoo/gobigger/env/gobigger_env.py +++ b/zoo/gobigger/env/gobigger_env.py @@ -11,25 +11,29 @@ class GoBiggerLightZeroEnv(BaseEnv): def __init__(self, cfg: dict) -> None: self._cfg = cfg + # ding env info self._init_flag = False self._observation_space = None self._action_space = None self._reward_space = None + # gobigger env info self.team_num = self._cfg.team_num self.player_num_per_team = self._cfg.player_num_per_team self.direction_num = self._cfg.direction_num self.use_action_mask = self._cfg.use_action_mask - self.action_space_size = self._cfg.action_space_size - self.max_player_num = 
self.player_num_per_team + self.action_space_size = self._cfg.action_space_size # discrete action space size self.step_mul = self._cfg.get('step_mul', 8) self.setup_action() - # feature engineering + self.setup_feature() + + def setup_feature(self): self.second_per_frame = 0.05 self.spatial_x = 64 self.spatial_y = 64 self.max_ball_num = 80 self.max_food_num = 256 self.max_spore_num = 64 + self.max_player_num = self.player_num_per_team self.reward_div_value = self._cfg.reward_div_value self.reward_type = self._cfg.reward_type self.player_init_score = self._cfg.manager_settings.player_manager.ball_settings.score_init @@ -42,7 +46,6 @@ def reset(self) -> np.ndarray: self._init_flag = True self.last_action_types = {player_id: self.direction_num * 2 for player_id in range(self.player_num_per_team * self.team_num)} raw_obs = self._env.reset() - self.eval_episode_return = [[] for _ in range(self.team_num)] obs = self.observation(raw_obs) self.last_action_types = {player_id: self.direction_num * 2 for player_id in range(self.player_num_per_team * self.team_num)} @@ -52,23 +55,31 @@ def reset(self) -> np.ndarray: def observation(self, raw_obs): obs = self.preprocess_obs(raw_obs) - action_mask = [np.logical_not(o['action_mask']) for o in obs] - to_play = [ -1 for _ in range(len(obs))] + # for alignment with other environments, reverse the action mask + action_mask = [np.logical_not(o['action_mask']) for o in obs] + to_play = [ -1 for _ in range(len(obs))] # Moot, for alignment with other environments obs = {'observation': obs, 'action_mask': action_mask, 'to_play': to_play, 'raw_obs':raw_obs} return obs - - def step(self, action: dict) -> BaseEnvTimestep: - action = {k: self.transform_action(v) if np.isscalar(v) else v for k, v in action.items()} + + def postproecess(self, action_dict): + for k, v in action_dict.items(): + if np.isscalar(v): + self.last_action_types[k] = v + else: + self.last_action_types[k] = self.direction_num * 2 + + def step(self, action_dict: dict) -> BaseEnvTimestep: + action = {k: self.transform_action(v) if np.isscalar(v) else v for k, v in action_dict.items()} raw_obs, raw_rew, done, info = self._env.step(action) # print('current_frame={}'.format(raw_obs[0]['last_time'])) # print('action={}'.format(action)) # print('raw_rew={}, done={}'.format(raw_rew, done)) rew = self.transform_reward(raw_obs) - for i in range(self.team_num): - self.eval_episode_return[i].append(raw_obs[0]['leaderboard'][i]) obs = self.observation(raw_obs) + # postprocess + self.postproecess(action_dict) if done: - info['eval_episode_return'] = [np.mean(self.eval_episode_return[i]) for i in range(self.team_num)] + info['eval_episode_return'] = [raw_obs[0]['leaderboard'][i] for i in range(self.team_num)] return BaseEnvTimestep(obs, rew, done, info) def seed(self, seed: int, dynamic_seed: bool = True) -> None: @@ -96,10 +107,9 @@ def reward_space(self) -> gym.spaces.Space: def __repr__(self) -> str: return "LightZero Env({})".format(self.cfg.env_name) - def transform_obs(self, obs, game_player_id=1, padding=True, last_action_type=None, ): + def transform_obs(self, obs, own_player_id=1, padding=True, last_action_type=None, ): global_state, player_observations = obs player2team = self.get_player2team() - own_player_id = game_player_id leaderboard = global_state['leaderboard'] team2rank = {key: rank for rank, key in enumerate(sorted(leaderboard, key=leaderboard.get, reverse=True), )} @@ -113,8 +123,8 @@ def transform_obs(self, obs, game_player_id=1, padding=True, last_action_type=No own_left_top_x, 
own_left_top_y, own_right_bottom_x, own_right_bottom_y = own_player_obs['rectangle'] own_view_center = [(own_left_top_x + own_right_bottom_x - scene_size) / 2, (own_left_top_y + own_right_bottom_y - scene_size) / 2] + # own_view_width == own_view_height own_view_width = float(own_right_bottom_x - own_left_top_x) - # own_view_height = float(own_right_bottom_y - own_left_top_y) own_score = own_player_obs['score'] / 100 own_team_score = global_state['leaderboard'][own_team_id] / 100 @@ -181,14 +191,20 @@ def transform_obs(self, obs, game_player_id=1, padding=True, last_action_type=No neutral_player_id = self.team_num * self.player_num_per_team neutral_team_rank = self.team_num + # clone = [type, score, player_id, team_id, team_rank, x, y, next_x, next_y] clone = [[ball_type_map['clone'], bl[3], bl[-2], bl[-1], team2rank[bl[-1]], bl[0], bl[1], *self.next_position(bl[0], bl[1], bl[4], bl[5])] for bl in clone] + + # thorn = [type, score, player_id, team_id, team_rank, x, y, next_x, next_y] thorns = [[ball_type_map['thorns'], bl[3], neutral_player_id, neutral_team_id, neutral_team_rank, bl[0], bl[1], *self.next_position(bl[0], bl[1], bl[4], bl[5])] for bl in thorns] + + # thorn = [type, score, player_id, team_id, team_rank, x, y, next_x, next_y] food = [ [ball_type_map['food'], bl[3], neutral_player_id, neutral_team_id, neutral_team_rank, bl[0], bl[1], bl[0], bl[1]] for bl in food] + # spore = [type, score, player_id, team_id, team_rank, x, y, next_x, next_y] spore = [ [ball_type_map['spore'], bl[3], bl[-1], player2team[bl[-1]], team2rank[player2team[bl[-1]]], bl[0], bl[1], @@ -196,6 +212,7 @@ def transform_obs(self, obs, game_player_id=1, padding=True, last_action_type=No all_balls = clone + thorns + food + spore + # Particularly handle balls outside the field of view for b in all_balls: if b[2] == own_player_id and b[0] == 1: if b[5] < own_left_top_x or b[5] > own_right_bottom_x or \ @@ -214,7 +231,7 @@ def transform_obs(self, obs, game_player_id=1, padding=True, last_action_type=No all_balls[:, -1] = ((all_balls[:, -1] - origin_y) / own_view_width * self.spatial_y) # ball - ball_indices = np.logical_and(all_balls[:, 0] != 2, all_balls[:, 0] != 4) + ball_indices = np.logical_and(all_balls[:, 0] != 2, all_balls[:, 0] != 4) # include player balls and thorn balls balls = all_balls[ball_indices] balls_num = len(balls) @@ -242,7 +259,7 @@ def transform_obs(self, obs, game_player_id=1, padding=True, last_action_type=No remain_enemy_balls = enemy_balls[enemy_high_score_ball_indices] balls = np.concatenate([own_balls, teammate_balls, remain_enemy_balls], axis=0) - + balls_num = len(balls) ball_padding_num = self.max_ball_num - len(balls) if padding or ball_padding_num < 0: balls = np.pad(balls, ((0, ball_padding_num), (0, 0)), 'constant', constant_values=0) @@ -257,13 +274,12 @@ def transform_obs(self, obs, game_player_id=1, padding=True, last_action_type=No ## score&radius scale_score = balls[:, 1] / 100 - radius = np.clip(np.sqrt(scale_score * 0.042 + 0.15) / own_view_width, a_max=1, a_min=0) - score = np.clip(np.round(np.sqrt(scale_score * 0.042 + 0.15) / own_view_width * 50), a_max=49, a_min=0).astype(int) - + radius = np.clip(np.sqrt(scale_score * 0.042 + 0.15) / own_view_width, a_max=1, a_min=None) + score = np.clip(np.round(np.clip(np.sqrt(scale_score * 0.042 + 0.15) / own_view_width, a_max=1, a_min=None)*50).astype(int), a_max=49, a_min=None) ## rank: ball_rank = balls[:, 4] - ## coordinate + ## coordinates relative to the center of [spatial_x, spatial_y] x = balls[:, -4] - self.spatial_x // 2 
y = balls[:, -3] - self.spatial_y // 2 next_x = balls[:, -2] - self.spatial_x // 2 @@ -285,6 +301,7 @@ def transform_obs(self, obs, game_player_id=1, padding=True, last_action_type=No # spatial info # ============ # ball coordinate for scatter connection + # coordinates relative to the upper left corner of [spatial_x, spatial_y] ball_x = balls[:, -4] ball_y = balls[:, -3] @@ -327,31 +344,25 @@ def transform_obs(self, obs, game_player_id=1, padding=True, last_action_type=No } return output_obs - def _preprocess_obs(self, raw_obs, env_status=None, eval_vsbot=False): + def preprocess_obs(self, raw_obs): env_player_obs = [] - game_player_num = self.player_num_per_team if eval_vsbot else self.player_num_per_team * self.team_num - for game_player_id in range(game_player_num): - if env_status is None: - last_action_type = self.direction_num * 2 - else: - last_action_type = self.last_action_types[game_player_id] + for game_player_id in range(self.player_num_per_team): + last_action_type = self.last_action_types[game_player_id] if self.use_action_mask: can_eject = raw_obs[1][game_player_id]['can_eject'] can_split = raw_obs[1][game_player_id]['can_split'] action_mask = self.generate_action_mask(can_eject=can_eject, can_split=can_split) else: action_mask = self.generate_action_mask(can_eject=True, can_split=True) - game_player_obs = self.transform_obs(raw_obs, game_player_id=game_player_id, padding=True, - last_action_type=last_action_type) + game_player_obs = self.transform_obs(raw_obs, own_player_id=game_player_id, padding=True, last_action_type=last_action_type) game_player_obs['action_mask'] = action_mask env_player_obs.append(game_player_obs) return env_player_obs - - def preprocess_obs(self, obs_list, env_status=None, eval_vsbot=False): - env_player_obs = self._preprocess_obs(obs_list, env_status, eval_vsbot) - return env_player_obs - def generate_action_mask(self, can_eject, can_split, ): + def generate_action_mask(self, can_eject, can_split): + # action mask + # 1 represent can not do this action + # 0 represent can do this action action_mask = np.zeros((self.action_space_size,), dtype=np.bool_) if not can_eject: action_mask[self.direction_num * 2 + 1] = True @@ -375,10 +386,8 @@ def transform_action(self, action_idx): def setup_action(self): theta = math.pi * 2 / self.direction_num - self.x_y_action_List = [[0.3 * math.cos(theta * i), 0.3 * math.sin(theta * i), 0] for i in - range(self.direction_num)] + \ - [[math.cos(theta * i), math.sin(theta * i), 0] for i in - range(self.direction_num)] + \ + self.x_y_action_List = [[0.3 * math.cos(theta * i), 0.3 * math.sin(theta * i), 0] for i in range(self.direction_num)] + \ + [[math.cos(theta * i), math.sin(theta * i), 0] for i in range(self.direction_num)] + \ [[0, 0, 0], [0, 0, 1], [0, 0, 2]] def get_spirit(self, progress): @@ -466,7 +475,7 @@ def transform_reward(self, next_obs): # save_frame=False, save_frame=True, save_dir='./', - save_name_prefix='gobigger-bot', + save_name_prefix='test', ), ), )) @@ -481,5 +490,6 @@ def transform_reward(self, next_obs): # bot[i].step(obs['raw_obs'] is dict actions.update(bot[i].step(obs['raw_obs'])) obs, rew, done, info = env.step(actions) + print(rew, info) if done: break \ No newline at end of file From f229b6afdd3db3d9d4a0ae02224823c22767f6fd Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Wed, 7 Jun 2023 17:08:36 +0800 Subject: [PATCH 09/54] polish(yzj): polish gobigger env pytest --- zoo/gobigger/env/gobigger_env.py | 59 ----------------------- zoo/gobigger/env/test_gobbiger_env.py | 69 
+++++++++++++++++++++++++++ 2 files changed, 69 insertions(+), 59 deletions(-) create mode 100644 zoo/gobigger/env/test_gobbiger_env.py diff --git a/zoo/gobigger/env/gobigger_env.py b/zoo/gobigger/env/gobigger_env.py index c4f727edb..0a3963e51 100644 --- a/zoo/gobigger/env/gobigger_env.py +++ b/zoo/gobigger/env/gobigger_env.py @@ -434,62 +434,3 @@ def transform_reward(self, next_obs): self.last_player_scores[game_player_id] = player_score self.last_leaderboard = next_obs[0]['leaderboard'] return score_rewards_list - -if __name__ == '__main__': - from easydict import EasyDict - env_cfg=EasyDict(dict( - env_name='gobigger', - team_num=2, - player_num_per_team=2, - direction_num=12, - step_mul=8, - map_width=64, - map_height=64, - frame_limit=3600, - action_space_size=27, - use_action_mask=False, - reward_div_value=0.1, - reward_type='log_reward', - start_spirit_progress=0.2, - end_spirit_progress=0.8, - manager_settings=dict( - food_manager=dict( - num_init=260, - num_min=260, - num_max=300, - ), - thorns_manager=dict( - num_init=3, - num_min=3, - num_max=4, - ), - player_manager=dict( - ball_settings=dict( - score_init=13000, - ), - ), - ), - playback_settings=dict( - playback_type='by_frame', - by_frame=dict( - # save_frame=False, - save_frame=True, - save_dir='./', - save_name_prefix='test', - ), - ), - )) - - env = GoBiggerLightZeroEnv(env_cfg) - obs = env.reset() - from gobigger_rule_bot import BotAgent - bot = [BotAgent(i) for i in range(4)] - while True: - actions = {} - for i in range(4): - # bot[i].step(obs['raw_obs'] is dict - actions.update(bot[i].step(obs['raw_obs'])) - obs, rew, done, info = env.step(actions) - print(rew, info) - if done: - break \ No newline at end of file diff --git a/zoo/gobigger/env/test_gobbiger_env.py b/zoo/gobigger/env/test_gobbiger_env.py new file mode 100644 index 000000000..fb756efcf --- /dev/null +++ b/zoo/gobigger/env/test_gobbiger_env.py @@ -0,0 +1,69 @@ + +import pytest +from easydict import EasyDict +from gobigger_env import GoBiggerLightZeroEnv +from gobigger_rule_bot import BotAgent + + +env_cfg=EasyDict(dict( + env_name='gobigger', + team_num=2, + player_num_per_team=2, + direction_num=12, + step_mul=8, + map_width=64, + map_height=64, + frame_limit=3600, + action_space_size=27, + use_action_mask=False, + reward_div_value=0.1, + reward_type='log_reward', + start_spirit_progress=0.2, + end_spirit_progress=0.8, + manager_settings=dict( + food_manager=dict( + num_init=260, + num_min=260, + num_max=300, + ), + thorns_manager=dict( + num_init=3, + num_min=3, + num_max=4, + ), + player_manager=dict( + ball_settings=dict( + score_init=13000, + ), + ), + ), + playback_settings=dict( + playback_type='by_frame', + by_frame=dict( + # save_frame=False, + save_frame=True, + save_dir='./', + save_name_prefix='test', + ), + ), +)) + +@pytest.mark.envtest +class TestGoBiggerLightZeroEnv: + + def test_env(self): + env = GoBiggerLightZeroEnv(env_cfg) + obs = env.reset() + from gobigger_rule_bot import BotAgent + bot = [BotAgent(i) for i in range(4)] + while True: + actions = {} + for i in range(4): + # bot[i].step(obs['raw_obs'] is dict + actions.update(bot[i].step(obs['raw_obs'])) + obs, rew, done, info = env.step(actions) + print(rew, info) + if done: + break + +TestGoBiggerLightZeroEnv().test_env() \ No newline at end of file From 4bbbeb0fed29cac6fc5785f3c271d0bd36d959d7 Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Wed, 7 Jun 2023 19:39:22 +0800 Subject: [PATCH 10/54] polish(yzj): polish gobigger env and eat info in evaluator --- 
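Note: the evaluator hunk below accumulates per-agent eat statistics across evaluation episodes and logs their means. A minimal, self-contained sketch of that aggregation, assuming `eat_info` maps env_id to a per-agent list of eat-count dicts as produced in `t.info['eats']`; the helper name `aggregate_eat_info` and the demo data are illustrative and not part of this patch.

import numpy as np

def aggregate_eat_info(eat_info):
    # eat_info: {env_id: [per-agent dict of eat counts, ...]}, one entry per finished episode.
    info = {}
    for env_id, per_agent in eat_info.items():
        for i, stats in enumerate(per_agent):
            for name, value in stats.items():
                # collect one value per episode under an 'agent_{i}_{name}' key
                info.setdefault('agent_{}_{}'.format(i, name), []).append(value)
    # the evaluator keeps these keys alongside its reward statistics and averages only the 'agent_*' entries
    return {k: float(np.mean(v)) for k, v in info.items()}

if __name__ == '__main__':
    demo = {
        0: [{'food': 12, 'thorns': 1}, {'food': 8, 'thorns': 0}],
        1: [{'food': 20, 'thorns': 2}, {'food': 4, 'thorns': 1}],
    }
    print(aggregate_eat_info(demo))
    # {'agent_0_food': 16.0, 'agent_0_thorns': 1.5, 'agent_1_food': 6.0, 'agent_1_thorns': 0.5}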
lzero/worker/gobigger_muzero_evaluator.py | 16 ++++++++++++---- zoo/gobigger/env/gobigger_env.py | 2 +- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/lzero/worker/gobigger_muzero_evaluator.py b/lzero/worker/gobigger_muzero_evaluator.py index d7ace9b0b..efe39d241 100644 --- a/lzero/worker/gobigger_muzero_evaluator.py +++ b/lzero/worker/gobigger_muzero_evaluator.py @@ -501,7 +501,8 @@ def eval_vsbot( for i in range(env_nums): for k, v in init_obs[i].items(): - init_obs[i][k] = v[:agent_num] + if k != 'raw_obs': + init_obs[i][k] = v[:agent_num] action_mask_dict = {i: to_ndarray(init_obs[i]['action_mask']) for i in range(env_nums)} to_play_dict = {i: to_ndarray(init_obs[i]['to_play']) for i in range(env_nums)} @@ -522,6 +523,7 @@ def eval_vsbot( ready_env_id = set() remain_episode = n_episode + eat_info = defaultdict() with self._timer: while not eval_monitor.is_finished(): @@ -623,6 +625,7 @@ def eval_vsbot( self._policy.reset([env_id]) reward = t.info['eval_episode_return'][0] bot_reward = t.info['eval_episode_return'][1] + eat_info[env_id] = t.info['eats'] if 'episode_info' in t.info: eval_monitor.update_info(env_id, t.info['episode_info']) eval_monitor.update_reward(env_id, reward) @@ -707,9 +710,14 @@ def eval_vsbot( 'bot_reward_min': np.min(bot_episode_return), } # add eat info - for i in range(len(t.info['eats'])): - for k,v in t.info['eats'][i].items(): - info['agent_{}_{}'.format(i, k)] = v + for k,v in eat_info.items(): + for i in range(len(v)): + for k1, v1 in v[i].items(): + info['agent_{}_{}'.format(i, k1)] = info.get('agent_{}_{}'.format(i, k1), []) + [v1] + + for k,v in info.items(): + if 'agent' in k: + info[k] = np.mean(v) episode_info = eval_monitor.get_episode_info() if episode_info is not None: diff --git a/zoo/gobigger/env/gobigger_env.py b/zoo/gobigger/env/gobigger_env.py index 0a3963e51..69449e05b 100644 --- a/zoo/gobigger/env/gobigger_env.py +++ b/zoo/gobigger/env/gobigger_env.py @@ -346,7 +346,7 @@ def transform_obs(self, obs, own_player_id=1, padding=True, last_action_type=Non def preprocess_obs(self, raw_obs): env_player_obs = [] - for game_player_id in range(self.player_num_per_team): + for game_player_id in range(self.player_num_per_team*self.team_num): last_action_type = self.last_action_types[game_player_id] if self.use_action_mask: can_eject = raw_obs[1][game_player_id]['can_eject'] From 7529170cbec21142f8fbc9fe992c4173690ab14b Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Mon, 12 Jun 2023 18:21:10 +0800 Subject: [PATCH 11/54] fix(yzj): fix np.pad bug, which need padding_num>0 --- zoo/gobigger/env/gobigger_env.py | 21 +++++++++++++++++---- zoo/gobigger/env/test_gobbiger_env.py | 2 -- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/zoo/gobigger/env/gobigger_env.py b/zoo/gobigger/env/gobigger_env.py index 69449e05b..183d80d3c 100644 --- a/zoo/gobigger/env/gobigger_env.py +++ b/zoo/gobigger/env/gobigger_env.py @@ -170,7 +170,10 @@ def transform_obs(self, obs, own_player_id=1, padding=True, last_action_type=Non all_players = np.array(all_players) player_padding_num = self.max_player_num - len(all_players) player_num = len(all_players) - all_players = np.pad(all_players, pad_width=((0, player_padding_num), (0, 0)), mode='constant') + if player_padding_num < 0: + all_players = all_players[:self.max_player_num,:] + else: + all_players = np.pad(all_players, pad_width=((0, player_padding_num), (0, 0)), mode='constant') team_info = { 'alliance': all_players[:, 0].astype(np.int64), 'view_x': np.round(all_players[:, 1]).astype(np.int64), 
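Note on the hunks in this patch: `np.pad` rejects negative pad widths, so each entity array is now truncated when it already exceeds its maximum count and zero-padded otherwise; the remaining hunks below apply the same rule to balls, food and spores. A standalone sketch of that pattern, assuming 2-D NumPy feature arrays; the helper name `pad_or_truncate` is illustrative and does not appear in the diff.

import numpy as np

def pad_or_truncate(arr, max_num):
    # Mirror of the fix: np.pad raises ValueError for negative pad widths,
    # so slice when there are already more rows than max_num, and zero-pad otherwise.
    padding_num = max_num - len(arr)
    if padding_num < 0:
        return arr[:max_num]
    return np.pad(arr, pad_width=((0, padding_num), (0, 0)), mode='constant')

if __name__ == '__main__':
    players = np.ones((6, 3))
    print(pad_or_truncate(players, 4).shape)  # (4, 3): truncated to the maximum entity count
    print(pad_or_truncate(players, 8).shape)  # (8, 3): zero-padded up to the maximum entity count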
@@ -261,7 +264,11 @@ def transform_obs(self, obs, own_player_id=1, padding=True, last_action_type=Non balls = np.concatenate([own_balls, teammate_balls, remain_enemy_balls], axis=0) balls_num = len(balls) ball_padding_num = self.max_ball_num - len(balls) - if padding or ball_padding_num < 0: + if ball_padding_num < 0: + balls = balls[:self.max_ball_num, :] + alliance = np.zeros(self.max_ball_num) + balls_num = self.max_ball_num + elif padding: balls = np.pad(balls, ((0, ball_padding_num), (0, 0)), 'constant', constant_values=0) alliance = np.zeros(self.max_ball_num) balls_num = min(self.max_ball_num, balls_num) @@ -310,7 +317,10 @@ def transform_obs(self, obs, own_player_id=1, padding=True, last_action_type=Non food_y = all_balls[food_indices, -3] food_num = len(food_x) food_padding_num = self.max_food_num - len(food_x) - if padding or food_padding_num < 0: + if food_padding_num < 0: + food_x = food_x[:self.max_food_num] + food_y = food_y[:self.max_food_num] + elif padding: food_x = np.pad(food_x, (0, food_padding_num), 'constant', constant_values=0) food_y = np.pad(food_y, (0, food_padding_num), 'constant', constant_values=0) food_num = min(food_num, self.max_food_num) @@ -320,7 +330,10 @@ def transform_obs(self, obs, own_player_id=1, padding=True, last_action_type=Non spore_y = all_balls[spore_indices, -3] spore_num = len(spore_x) spore_padding_num = self.max_spore_num - len(spore_x) - if padding or spore_padding_num < 0: + if spore_padding_num < 0: + spore_x = spore_x[:self.max_spore_num] + spore_y = spore_y[:self.max_spore_num] + elif padding: spore_x = np.pad(spore_x, (0, spore_padding_num), 'constant', constant_values=0) spore_y = np.pad(spore_y, (0, spore_padding_num), 'constant', constant_values=0) spore_num = min(spore_num, self.max_spore_num) diff --git a/zoo/gobigger/env/test_gobbiger_env.py b/zoo/gobigger/env/test_gobbiger_env.py index fb756efcf..7a0140b94 100644 --- a/zoo/gobigger/env/test_gobbiger_env.py +++ b/zoo/gobigger/env/test_gobbiger_env.py @@ -65,5 +65,3 @@ def test_env(self): print(rew, info) if done: break - -TestGoBiggerLightZeroEnv().test_env() \ No newline at end of file From 85aeacf40f04948045e871a9839face117f630c3 Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Tue, 13 Jun 2023 20:07:52 +0800 Subject: [PATCH 12/54] polish(yzj): contain raw obs only on eval mode for save memory --- lzero/entry/eval_muzero_gobigger.py | 21 +++++++++++++++-- lzero/entry/train_muzero_gobigger.py | 23 ++++++++++++++++++- .../config/gobigger_efficientzero_config.py | 1 + zoo/gobigger/config/gobigger_muzero_config.py | 1 + .../gobigger_sampled_efficientzero_config.py | 1 + zoo/gobigger/env/gobigger_env.py | 6 ++++- zoo/gobigger/env/test_gobbiger_env.py | 1 + 7 files changed, 50 insertions(+), 4 deletions(-) diff --git a/lzero/entry/eval_muzero_gobigger.py b/lzero/entry/eval_muzero_gobigger.py index e5d65901b..73b7abb41 100644 --- a/lzero/entry/eval_muzero_gobigger.py +++ b/lzero/entry/eval_muzero_gobigger.py @@ -5,6 +5,7 @@ import numpy as np import torch from tensorboardX import SummaryWriter +import copy from ding.config import compile_config from ding.envs import create_env_manager @@ -56,8 +57,14 @@ def eval_muzero_gobigger( collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + env_cfg = copy.deepcopy(evaluator_env_cfg[0]) + env_cfg.contain_raw_obs = True + vsbot_evaluator_env_cfg = [env_cfg for _ in 
range(len(evaluator_env_cfg))] + vsbot_evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in vsbot_evaluator_env_cfg]) + collector_env.seed(cfg.seed) evaluator_env.seed(cfg.seed, dynamic_seed=False) + vsbot_evaluator_env.seed(cfg.seed, dynamic_seed=False) set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) @@ -84,7 +91,17 @@ def eval_muzero_gobigger( exp_name=cfg.exp_name, policy_config=policy_config ) - + vsbot_evaluator = GoBiggerMuZeroEvaluator( + eval_freq=cfg.policy.eval_freq, + n_evaluator_episode=cfg.env.n_evaluator_episode, + stop_value=cfg.env.stop_value, + env=vsbot_evaluator_env, + policy=policy.eval_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config, + instance_name='vsbot_evaluator' + ) # ============================================================== # Main loop # ============================================================== @@ -94,5 +111,5 @@ def eval_muzero_gobigger( # eval trained model # ============================================================== _, reward_sp = evaluator.eval(learner.save_checkpoint, learner.train_iter) - _, reward_vsbot= evaluator.eval_vsbot(learner.save_checkpoint, learner.train_iter) + _, reward_vsbot= vsbot_evaluator.eval_vsbot(learner.save_checkpoint, learner.train_iter) return reward_sp, reward_vsbot diff --git a/lzero/entry/train_muzero_gobigger.py b/lzero/entry/train_muzero_gobigger.py index f0b095b52..b8e33cadc 100644 --- a/lzero/entry/train_muzero_gobigger.py +++ b/lzero/entry/train_muzero_gobigger.py @@ -11,6 +11,7 @@ from ding.utils import set_pkg_seed from ding.worker import BaseLearner from tensorboardX import SummaryWriter +import copy from lzero.entry.utils import log_buffer_memory_usage from lzero.policy import visit_count_temperature @@ -63,8 +64,14 @@ def train_muzero_gobigger( collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + env_cfg = copy.deepcopy(evaluator_env_cfg[0]) + env_cfg.contain_raw_obs = True + vsbot_evaluator_env_cfg = [env_cfg for _ in range(len(evaluator_env_cfg))] + vsbot_evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in vsbot_evaluator_env_cfg]) + collector_env.seed(cfg.seed) evaluator_env.seed(cfg.seed, dynamic_seed=False) + vsbot_evaluator_env.seed(cfg.seed, dynamic_seed=False) set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) @@ -103,6 +110,18 @@ def train_muzero_gobigger( policy_config=policy_config ) + vsbot_evaluator = GoBiggerMuZeroEvaluator( + eval_freq=cfg.policy.eval_freq, + n_evaluator_episode=cfg.env.n_evaluator_episode, + stop_value=cfg.env.stop_value, + env=vsbot_evaluator_env, + policy=policy.eval_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config, + instance_name='vsbot_evaluator' + ) + # ============================================================== # Main loop # ============================================================== @@ -119,11 +138,13 @@ def train_muzero_gobigger( policy_config.threshold_training_steps_for_final_temperature, trained_steps=learner.train_iter ) + stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) + stop, reward = vsbot_evaluator.eval_vsbot(learner.save_checkpoint, 
learner.train_iter, collector.envstep) # Evaluate policy performance. if evaluator.should_eval(learner.train_iter): stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) - stop, reward= evaluator.eval_vsbot(learner.save_checkpoint, learner.train_iter, collector.envstep) + stop, reward = vsbot_evaluator.eval_vsbot(learner.save_checkpoint, learner.train_iter, collector.envstep) if stop: break diff --git a/zoo/gobigger/config/gobigger_efficientzero_config.py b/zoo/gobigger/config/gobigger_efficientzero_config.py index 585b6d7d6..11ffca0e3 100644 --- a/zoo/gobigger/config/gobigger_efficientzero_config.py +++ b/zoo/gobigger/config/gobigger_efficientzero_config.py @@ -34,6 +34,7 @@ use_action_mask=False, reward_div_value=0.1, reward_type='log_reward', + contain_raw_obs=False, # False on collect mode, True on eval vsbot mode, because bot need raw obs start_spirit_progress=0.2, end_spirit_progress=0.8, manager_settings=dict( diff --git a/zoo/gobigger/config/gobigger_muzero_config.py b/zoo/gobigger/config/gobigger_muzero_config.py index 25662d245..5cfcd9d66 100644 --- a/zoo/gobigger/config/gobigger_muzero_config.py +++ b/zoo/gobigger/config/gobigger_muzero_config.py @@ -34,6 +34,7 @@ use_action_mask=False, reward_div_value=0.1, reward_type='log_reward', + contain_raw_obs=False, # False on collect mode, True on eval vsbot mode, because bot need raw obs start_spirit_progress=0.2, end_spirit_progress=0.8, manager_settings=dict( diff --git a/zoo/gobigger/config/gobigger_sampled_efficientzero_config.py b/zoo/gobigger/config/gobigger_sampled_efficientzero_config.py index 8bbef8236..9db353d2d 100644 --- a/zoo/gobigger/config/gobigger_sampled_efficientzero_config.py +++ b/zoo/gobigger/config/gobigger_sampled_efficientzero_config.py @@ -36,6 +36,7 @@ use_action_mask=False, reward_div_value=0.1, reward_type='log_reward', + contain_raw_obs=False, # False on collect mode, True on eval vsbot mode, because bot need raw obs start_spirit_progress=0.2, end_spirit_progress=0.8, manager_settings=dict( diff --git a/zoo/gobigger/env/gobigger_env.py b/zoo/gobigger/env/gobigger_env.py index 183d80d3c..daa7abc3d 100644 --- a/zoo/gobigger/env/gobigger_env.py +++ b/zoo/gobigger/env/gobigger_env.py @@ -25,6 +25,7 @@ def __init__(self, cfg: dict) -> None: self.step_mul = self._cfg.get('step_mul', 8) self.setup_action() self.setup_feature() + self.contain_raw_obs = self._cfg.contain_raw_obs # for save memory def setup_feature(self): self.second_per_frame = 0.05 @@ -58,7 +59,10 @@ def observation(self, raw_obs): # for alignment with other environments, reverse the action mask action_mask = [np.logical_not(o['action_mask']) for o in obs] to_play = [ -1 for _ in range(len(obs))] # Moot, for alignment with other environments - obs = {'observation': obs, 'action_mask': action_mask, 'to_play': to_play, 'raw_obs':raw_obs} + if self.contain_raw_obs: + obs = {'observation': obs, 'action_mask': action_mask, 'to_play': to_play, 'raw_obs':raw_obs} + else: + obs = {'observation': obs, 'action_mask': action_mask, 'to_play': to_play} return obs def postproecess(self, action_dict): diff --git a/zoo/gobigger/env/test_gobbiger_env.py b/zoo/gobigger/env/test_gobbiger_env.py index 7a0140b94..d933c38c9 100644 --- a/zoo/gobigger/env/test_gobbiger_env.py +++ b/zoo/gobigger/env/test_gobbiger_env.py @@ -18,6 +18,7 @@ use_action_mask=False, reward_div_value=0.1, reward_type='log_reward', + contain_raw_obs=True, # False on collect mode, True on eval vsbot mode, because bot need raw obs start_spirit_progress=0.2, 
end_spirit_progress=0.8, manager_settings=dict( From f146c4d6b62f940ec67dec3c79fa3da22305d7b6 Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Tue, 13 Jun 2023 20:14:51 +0800 Subject: [PATCH 13/54] fix(yzj): fix mcts ptree sampled value/value-prefix bug --- lzero/mcts/tree_search/mcts_ptree_sampled.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lzero/mcts/tree_search/mcts_ptree_sampled.py b/lzero/mcts/tree_search/mcts_ptree_sampled.py index bf26e58aa..62d0cfec5 100644 --- a/lzero/mcts/tree_search/mcts_ptree_sampled.py +++ b/lzero/mcts/tree_search/mcts_ptree_sampled.py @@ -176,7 +176,7 @@ def search( [ network_output.latent_state, network_output.policy_logits, - self.inverse_scalar_transform_handle(network_output.value_prefix), + self.inverse_scalar_transform_handle(network_output.value), self.inverse_scalar_transform_handle(network_output.value_prefix), ] ) From 47b145ef58af0c521169fdb0a999e7cf61bb8519 Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Thu, 15 Jun 2023 18:05:22 +0800 Subject: [PATCH 14/54] polish(yzj): polish gobigger encoder model --- .../gobigger/gobigger_efficientzero_model.py | 10 +- lzero/model/gobigger/gobigger_encoder.py | 194 +++++++++++------- lzero/model/gobigger/gobigger_muzero_model.py | 10 +- .../gobigger_sampled_efficientzero_model.py | 10 +- lzero/worker/gobigger_muzero_evaluator.py | 4 +- 5 files changed, 124 insertions(+), 104 deletions(-) diff --git a/lzero/model/gobigger/gobigger_efficientzero_model.py b/lzero/model/gobigger/gobigger_efficientzero_model.py index 0ab0c5c49..f400ce0c0 100644 --- a/lzero/model/gobigger/gobigger_efficientzero_model.py +++ b/lzero/model/gobigger/gobigger_efficientzero_model.py @@ -8,7 +8,7 @@ from ..common import EZNetworkOutput, RepresentationNetworkMLP, PredictionNetworkMLP from ..utils import renormalize, get_params_mean, get_dynamic_mean, get_reward_mean -from .gobigger_encoder import Encoder +from .gobigger_encoder import GoBiggerEncoder import yaml from easydict import EasyDict from ding.utils.data import default_collate @@ -107,13 +107,7 @@ def __init__( self.state_norm = state_norm self.res_connection_in_dynamics = res_connection_in_dynamics - # self.representation_network = RepresentationNetworkMLP( - # observation_shape=observation_shape, hidden_channels=latent_state_dim, norm_type=norm_type - # ) - with open('lzero/model/gobigger/default_model_config.yaml', "r") as f: - encoder_cfg = yaml.safe_load(f) - encoder_cfg = EasyDict(encoder_cfg) - self.representation_network = Encoder(encoder_cfg) + self.representation_network = GoBiggerEncoder() self.dynamics_network = DynamicsNetworkMLP( action_encoding_dim=self.action_encoding_dim, diff --git a/lzero/model/gobigger/gobigger_encoder.py b/lzero/model/gobigger/gobigger_encoder.py index 91ff3b7ee..5ba0d80f7 100644 --- a/lzero/model/gobigger/gobigger_encoder.py +++ b/lzero/model/gobigger/gobigger_encoder.py @@ -10,7 +10,8 @@ from .network.res_block import ResBlock from .network.transformer import Transformer from typing import Any, List, Tuple, Union, Optional, Callable - +from easydict import EasyDict +from ding.utils.default_helper import deep_merge_dicts def sequence_mask(lengths: torch.Tensor, max_len: Optional[int] =None): r""" @@ -35,7 +36,7 @@ class ScalarEncoder(nn.Module): def __init__(self, cfg): super(ScalarEncoder, self).__init__() self.whole_cfg = cfg - self.cfg = self.whole_cfg.model.scalar_encoder + self.cfg = self.whole_cfg.scalar_encoder self.encode_modules = nn.ModuleDict() for k, item in self.cfg.modules.items(): if item['arc'] == 
'time': @@ -50,12 +51,12 @@ def __init__(self, cfg): print(f'cant implement {k} for arc {item["arc"]}') raise NotImplementedError - self.layers = MLP(in_channels=self.cfg.input_dim, hidden_channels=self.cfg.hidden_dim, - out_channels=self.cfg.output_dim, - layer_num=self.cfg.layer_num, + self.layers = MLP(in_channels=self.cfg.mlp.input_dim, hidden_channels=self.cfg.mlp.hidden_dim, + out_channels=self.cfg.mlp.output_dim, + layer_num=self.cfg.mlp.layer_num, layer_fn=fc_block, - activation=self.cfg.activation, - norm_type=self.cfg.norm_type, + activation=self.cfg.mlp.activation, + norm_type=self.cfg.mlp.norm_type, use_dropout=False ) @@ -74,7 +75,7 @@ class TeamEncoder(nn.Module): def __init__(self, cfg): super(TeamEncoder, self).__init__() self.whole_cfg = cfg - self.cfg = self.whole_cfg.model.team_encoder + self.cfg = self.whole_cfg.team_encoder self.encode_modules = nn.ModuleDict() for k, item in self.cfg.modules.items(): @@ -88,35 +89,30 @@ def __init__(self, cfg): print(f'cant implement {k} for arc {item["arc"]}') raise NotImplementedError - self.embedding_dim = self.cfg.embedding_dim - self.encoder_cfg = self.cfg.encoder - self.encode_layers = MLP(in_channels=self.encoder_cfg.input_dim, - hidden_channels=self.encoder_cfg.hidden_dim, - out_channels=self.embedding_dim, - layer_num=self.encoder_cfg.layer_num, + self.encode_layers = MLP(in_channels=self.cfg.mlp.input_dim, + hidden_channels=self.cfg.mlp.hidden_dim, + out_channels=self.cfg.mlp.output_dim, + layer_num=self.cfg.mlp.layer_num, layer_fn=fc_block, - activation=self.encoder_cfg.activation, - norm_type=self.encoder_cfg.norm_type, + activation=self.cfg.mlp.activation, + norm_type=self.cfg.mlp.norm_type, use_dropout=False) - # self.activation_type = self.cfg.activation - self.transformer_cfg = self.cfg.transformer self.transformer = Transformer( - n_heads=self.transformer_cfg.head_num, - embedding_size=self.embedding_dim, - ffn_size=self.transformer_cfg.ffn_size, - n_layers=self.transformer_cfg.layer_num, + n_heads=self.cfg.transformer.head_num, + embedding_size=self.cfg.transformer.embedding_dim, + ffn_size=self.cfg.transformer.ffn_size, + n_layers=self.cfg.transformer.layer_num, attention_dropout=0.0, relu_dropout=0.0, dropout=0.0, - activation=self.transformer_cfg.activation, - variant=self.transformer_cfg.variant, + activation=self.cfg.transformer.activation, + variant=self.cfg.transformer.variant, ) - self.output_cfg = self.cfg.output - self.output_fc = fc_block(self.embedding_dim, - self.output_cfg.output_dim, - norm_type=self.output_cfg.norm_type, - activation=self.output_cfg.activation) + self.output_fc = fc_block(self.cfg.fc_block.input_dim, + self.cfg.fc_block.output_dim, + norm_type=self.cfg.fc_block.norm_type, + activation=self.cfg.fc_block.activation) def forward(self, x): embeddings = [] @@ -138,7 +134,7 @@ class BallEncoder(nn.Module): def __init__(self, cfg): super(BallEncoder, self).__init__() self.whole_cfg = cfg - self.cfg = self.whole_cfg.model.ball_encoder + self.cfg = self.whole_cfg.ball_encoder self.encode_modules = nn.ModuleDict() for k, item in self.cfg.modules.items(): if item['arc'] == 'one_hot': @@ -152,34 +148,30 @@ def __init__(self, cfg): else: print(f'cant implement {k} for arc {item["arc"]}') raise NotImplementedError - self.embedding_dim = self.cfg.embedding_dim - self.encoder_cfg = self.cfg.encoder - self.encode_layers = MLP(in_channels=self.encoder_cfg.input_dim, - hidden_channels=self.encoder_cfg.hidden_dim, - out_channels=self.embedding_dim, - layer_num=self.encoder_cfg.layer_num, + 
self.encode_layers = MLP(in_channels=self.cfg.mlp.input_dim, + hidden_channels=self.cfg.mlp.hidden_dim, + out_channels=self.cfg.mlp.output_dim, + layer_num=self.cfg.mlp.layer_num, layer_fn=fc_block, - activation=self.encoder_cfg.activation, - norm_type=self.encoder_cfg.norm_type, + activation=self.cfg.mlp.activation, + norm_type=self.cfg.mlp.norm_type, use_dropout=False) - self.transformer_cfg = self.cfg.transformer self.transformer = Transformer( - n_heads=self.transformer_cfg.head_num, - embedding_size=self.embedding_dim, - ffn_size=self.transformer_cfg.ffn_size, - n_layers=self.transformer_cfg.layer_num, + n_heads=self.cfg.transformer.head_num, + embedding_size=self.cfg.transformer.embedding_dim, + ffn_size=self.cfg.transformer.ffn_size, + n_layers=self.cfg.transformer.layer_num, attention_dropout=0.0, relu_dropout=0.0, dropout=0.0, - activation=self.transformer_cfg.activation, - variant=self.transformer_cfg.variant, + activation=self.cfg.transformer.activation, + variant=self.cfg.transformer.variant, ) - self.output_cfg = self.cfg.output - self.output_fc = fc_block(self.embedding_dim, - self.output_cfg.output_dim, - norm_type=self.output_cfg.norm_type, - activation=self.output_cfg.activation) + self.output_fc = fc_block(self.cfg.fc_block.input_dim, + self.cfg.fc_block.output_dim, + norm_type=self.cfg.fc_block.norm_type, + activation=self.cfg.fc_block.activation) def forward(self, x): ball_num = x['ball_num'] @@ -202,55 +194,54 @@ class SpatialEncoder(nn.Module): def __init__(self, cfg): super(SpatialEncoder, self).__init__() self.whole_cfg = cfg - self.cfg = self.whole_cfg.model.spatial_encoder + self.cfg = self.whole_cfg.spatial_encoder # scatter related self.spatial_x = 64 self.spatial_y = 64 - self.scatter_cfg = self.cfg.scatter - self.scatter_fc = fc_block(in_channels=self.scatter_cfg.input_dim, out_channels=self.scatter_cfg.output_dim, - activation=self.scatter_cfg.activation, norm_type=self.scatter_cfg.norm_type) - self.scatter_connection = ScatterConnection(self.scatter_cfg.scatter_type) + self.scatter_fc = fc_block(in_channels=self.cfg.scatter.input_dim, + out_channels=self.cfg.scatter.output_dim, + activation=self.cfg.scatter.activation, + norm_type=self.cfg.scatter.norm_type) + self.scatter_connection = ScatterConnection(self.cfg.scatter.scatter_type) # resnet related - self.resnet_cfg = self.cfg.resnet self.get_resnet_blocks() - self.output_cfg = self.cfg.output self.output_fc = fc_block( - in_channels=self.spatial_x // 8 * self.spatial_y // 8 * self.resnet_cfg.down_channels[-1], - out_channels=self.output_cfg.output_dim, - norm_type=self.output_cfg.norm_type, - activation=self.output_cfg.activation) + in_channels=self.spatial_x // 8 * self.spatial_y // 8 * self.cfg.resnet.down_channels[-1], + out_channels=self.cfg.fc_block.output_dim, + norm_type=self.cfg.fc_block.norm_type, + activation=self.cfg.fc_block.activation) def get_resnet_blocks(self): # 2 means food/spore embedding - project = conv2d_block(in_channels=self.scatter_cfg.output_dim + 2, - out_channels=self.resnet_cfg.project_dim, + project = conv2d_block(in_channels=self.cfg.scatter.output_dim + 2, + out_channels=self.cfg.resnet.project_dim, kernel_size=1, stride=1, padding=0, - activation=self.resnet_cfg.activation, - norm_type=self.resnet_cfg.norm_type, + activation=self.cfg.resnet.activation, + norm_type=self.cfg.resnet.norm_type, bias=False, ) layers = [project] - dims = [self.resnet_cfg.project_dim] + self.resnet_cfg.down_channels + dims = [self.cfg.resnet.project_dim] + self.cfg.resnet.down_channels for i in 
range(len(dims) - 1): layer = conv2d_block(in_channels=dims[i], out_channels=dims[i + 1], kernel_size=4, stride=2, padding=1, - activation=self.resnet_cfg.activation, - norm_type=self.resnet_cfg.norm_type, + activation=self.cfg.resnet.activation, + norm_type=self.cfg.resnet.norm_type, bias=False, ) layers.append(layer) layers.append(ResBlock(in_channels=dims[i + 1], - activation=self.resnet_cfg.activation, - norm_type=self.resnet_cfg.norm_type)) + activation=self.cfg.resnet.activation, + norm_type=self.cfg.resnet.norm_type)) self.resnet = torch.nn.Sequential(*layers) @@ -296,14 +287,61 @@ def forward(self, inputs, ball_embeddings, ): return x -class Encoder(nn.Module): - def __init__(self, cfg): - super(Encoder, self).__init__() - self.whole_cfg = cfg - self.scalar_encoder = ScalarEncoder(cfg) - self.team_encoder = TeamEncoder(cfg) - self.ball_encoder = BallEncoder(cfg) - self.spatial_encoder = SpatialEncoder(cfg) +class GoBiggerEncoder(nn.Module): + config=dict( + scalar_encoder=dict( + modules=dict( + view_x=dict(arc='sign_binary', num_embeddings=7), + view_y=dict(arc='sign_binary', num_embeddings=7), + view_width=dict(arc='binary', num_embeddings=7), + score=dict(arc='one_hot', num_embeddings=10), + team_score=dict(arc='one_hot', num_embeddings=10), + rank=dict(arc='one_hot', num_embeddings=4), + time=dict(arc='time', embedding_dim=8), + last_action_type=dict(arc='one_hot', num_embeddings=27), + ), + mlp=dict(input_dim=80, hidden_dim=64, layer_num=2, norm_type='none', output_dim=32, activation='relu'), + ), + team_encoder=dict( + modules=dict( + alliance=dict(arc='one_hot', num_embeddings=2), + view_x=dict(arc='sign_binary', num_embeddings=7), + view_y=dict(arc='sign_binary', num_embeddings=7), + ), + mlp=dict(input_dim=16, hidden_dim=32, layer_num=2, norm_type='none', output_dim=16, activation='relu'), + transformer=dict(head_num=4, ffn_size=32, layer_num=2, embedding_dim=16, activation='relu', variant='postnorm'), + fc_block=dict(input_dim=16, output_dim=16, activation='relu', norm_type='none'), + ), + ball_encoder=dict( + modules=dict( + alliance=dict(arc='one_hot', num_embeddings=4), + score=dict(arc='one_hot', num_embeddings=50), + radius=dict(arc='unsqueeze',), + rank=dict(arc='one_hot', num_embeddings=5), + x=dict(arc='sign_binary', num_embeddings=8), + y=dict(arc='sign_binary', num_embeddings=8), + next_x=dict(arc='sign_binary', num_embeddings=8), + next_y=dict(arc='sign_binary', num_embeddings=8), + ), + mlp=dict(input_dim=92, hidden_dim=128, layer_num=2, norm_type='none', output_dim=64, activation='relu'), + transformer=dict(head_num=4, ffn_size=64, layer_num=3, embedding_dim=64, activation='relu', variant='postnorm'), + fc_block=dict(input_dim=64, output_dim=64, activation='relu', norm_type='none'), + ), + spatial_encoder=dict( + scatter=dict(input_dim=64, output_dim=16, scatter_type='add', activation='relu', norm_type='none'), + resnet=dict(project_dim=12, down_channels=[32, 32, 16 ], activation='relu', norm_type='none'), + fc_block=dict(output_dim=64, activation='relu', norm_type='none'), + ), + ) + + def __init__(self, cfg=None): + super(GoBiggerEncoder, self).__init__() + self._cfg = deep_merge_dicts(self.config, cfg) + self._cfg = EasyDict(self._cfg) + self.scalar_encoder = ScalarEncoder(self._cfg) + self.team_encoder = TeamEncoder(self._cfg) + self.ball_encoder = BallEncoder(self._cfg) + self.spatial_encoder = SpatialEncoder(self._cfg) def forward(self, x): scalar_info = self.scalar_encoder(x['scalar_info']) diff --git a/lzero/model/gobigger/gobigger_muzero_model.py 
b/lzero/model/gobigger/gobigger_muzero_model.py index 709d23c53..96dc78685 100644 --- a/lzero/model/gobigger/gobigger_muzero_model.py +++ b/lzero/model/gobigger/gobigger_muzero_model.py @@ -7,7 +7,7 @@ from ..common import MZNetworkOutput, RepresentationNetworkMLP, PredictionNetworkMLP from ..utils import renormalize, get_params_mean, get_dynamic_mean, get_reward_mean -from .gobigger_encoder import Encoder +from .gobigger_encoder import GoBiggerEncoder import yaml from easydict import EasyDict from ding.utils.data import default_collate @@ -105,13 +105,7 @@ def __init__( self.state_norm = state_norm self.res_connection_in_dynamics = res_connection_in_dynamics - # self.representation_network = RepresentationNetworkMLP( - # observation_shape=observation_shape, hidden_channels=self.latent_state_dim, norm_type=norm_type - # ) - with open('lzero/model/gobigger/default_model_config.yaml', "r") as f: - encoder_cfg = yaml.safe_load(f) - encoder_cfg = EasyDict(encoder_cfg) - self.representation_network = Encoder(encoder_cfg) + self.representation_network = GoBiggerEncoder() self.dynamics_network = DynamicsNetwork( action_encoding_dim=self.action_encoding_dim, diff --git a/lzero/model/gobigger/gobigger_sampled_efficientzero_model.py b/lzero/model/gobigger/gobigger_sampled_efficientzero_model.py index 8a73c8116..762f15cd6 100644 --- a/lzero/model/gobigger/gobigger_sampled_efficientzero_model.py +++ b/lzero/model/gobigger/gobigger_sampled_efficientzero_model.py @@ -9,7 +9,7 @@ from ..common import EZNetworkOutput, RepresentationNetworkMLP from ..efficientzero_model_mlp import DynamicsNetworkMLP from ..utils import renormalize, get_params_mean -from .gobigger_encoder import Encoder +from .gobigger_encoder import GoBiggerEncoder import yaml from easydict import EasyDict from ding.utils.data import default_collate @@ -139,13 +139,7 @@ def __init__( self.num_of_sampled_actions = num_of_sampled_actions self.res_connection_in_dynamics = res_connection_in_dynamics - # self.representation_network = RepresentationNetworkMLP( - # observation_shape=self.observation_shape, hidden_channels=self.latent_state_dim, norm_type=norm_type - # ) - with open('lzero/model/gobigger/default_model_config.yaml', "r") as f: - encoder_cfg = yaml.safe_load(f) - encoder_cfg = EasyDict(encoder_cfg) - self.representation_network = Encoder(encoder_cfg) + self.representation_network = GoBiggerEncoder() self.dynamics_network = DynamicsNetworkMLP( action_encoding_dim=self.action_encoding_dim, diff --git a/lzero/worker/gobigger_muzero_evaluator.py b/lzero/worker/gobigger_muzero_evaluator.py index efe39d241..40c741414 100644 --- a/lzero/worker/gobigger_muzero_evaluator.py +++ b/lzero/worker/gobigger_muzero_evaluator.py @@ -729,8 +729,8 @@ def eval_vsbot( continue if not np.isscalar(v): continue - self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name+'_vsbot') + k, v, train_iter) - self._tb_logger.add_scalar('{}_step/'.format(self._instance_name+'_vsbot') + k, v, envstep) + self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter) + self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, envstep) eval_reward = np.mean(episode_return) if eval_reward > self._max_eval_reward: if save_ckpt_fn: From 2772ffdd88d28cc7068f3407e5eae96d65bf9ca6 Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Fri, 16 Jun 2023 16:26:35 +0800 Subject: [PATCH 15/54] polish(yzj): polish gobigger encoder model with ding --- .../gobigger/gobigger_efficientzero_model.py | 2 +- 
lzero/model/gobigger/gobigger_muzero_model.py | 2 +- .../gobigger_sampled_efficientzero_model.py | 2 +- lzero/model/gobigger/network/__init__.py | 8 - lzero/model/gobigger/network/activation.py | 96 ----- .../{ => network}/gobigger_encoder.py | 84 ++-- lzero/model/gobigger/network/nn_module.py | 235 ----------- lzero/model/gobigger/network/normalization.py | 36 -- lzero/model/gobigger/network/res_block.py | 231 ---------- lzero/model/gobigger/network/rnn.py | 276 ------------ .../gobigger/network/scatter_connection.py | 18 - lzero/model/gobigger/network/soft_argmax.py | 60 --- lzero/model/gobigger/network/transformer.py | 397 ------------------ 13 files changed, 46 insertions(+), 1401 deletions(-) delete mode 100644 lzero/model/gobigger/network/activation.py rename lzero/model/gobigger/{ => network}/gobigger_encoder.py (85%) delete mode 100644 lzero/model/gobigger/network/nn_module.py delete mode 100644 lzero/model/gobigger/network/normalization.py delete mode 100644 lzero/model/gobigger/network/res_block.py delete mode 100644 lzero/model/gobigger/network/rnn.py delete mode 100644 lzero/model/gobigger/network/soft_argmax.py delete mode 100644 lzero/model/gobigger/network/transformer.py diff --git a/lzero/model/gobigger/gobigger_efficientzero_model.py b/lzero/model/gobigger/gobigger_efficientzero_model.py index f400ce0c0..5844729bf 100644 --- a/lzero/model/gobigger/gobigger_efficientzero_model.py +++ b/lzero/model/gobigger/gobigger_efficientzero_model.py @@ -8,7 +8,7 @@ from ..common import EZNetworkOutput, RepresentationNetworkMLP, PredictionNetworkMLP from ..utils import renormalize, get_params_mean, get_dynamic_mean, get_reward_mean -from .gobigger_encoder import GoBiggerEncoder +from .network.gobigger_encoder import GoBiggerEncoder import yaml from easydict import EasyDict from ding.utils.data import default_collate diff --git a/lzero/model/gobigger/gobigger_muzero_model.py b/lzero/model/gobigger/gobigger_muzero_model.py index 96dc78685..36585b4ef 100644 --- a/lzero/model/gobigger/gobigger_muzero_model.py +++ b/lzero/model/gobigger/gobigger_muzero_model.py @@ -7,7 +7,7 @@ from ..common import MZNetworkOutput, RepresentationNetworkMLP, PredictionNetworkMLP from ..utils import renormalize, get_params_mean, get_dynamic_mean, get_reward_mean -from .gobigger_encoder import GoBiggerEncoder +from .network.gobigger_encoder import GoBiggerEncoder import yaml from easydict import EasyDict from ding.utils.data import default_collate diff --git a/lzero/model/gobigger/gobigger_sampled_efficientzero_model.py b/lzero/model/gobigger/gobigger_sampled_efficientzero_model.py index 762f15cd6..aa7b79504 100644 --- a/lzero/model/gobigger/gobigger_sampled_efficientzero_model.py +++ b/lzero/model/gobigger/gobigger_sampled_efficientzero_model.py @@ -9,7 +9,7 @@ from ..common import EZNetworkOutput, RepresentationNetworkMLP from ..efficientzero_model_mlp import DynamicsNetworkMLP from ..utils import renormalize, get_params_mean -from .gobigger_encoder import GoBiggerEncoder +from .network.gobigger_encoder import GoBiggerEncoder import yaml from easydict import EasyDict from ding.utils.data import default_collate diff --git a/lzero/model/gobigger/network/__init__.py b/lzero/model/gobigger/network/__init__.py index 50e7db84b..e69de29bb 100644 --- a/lzero/model/gobigger/network/__init__.py +++ b/lzero/model/gobigger/network/__init__.py @@ -1,8 +0,0 @@ -from .activation import build_activation -from .res_block import ResBlock, ResFCBlock,ResFCBlock2 -from .nn_module import fc_block, fc_block2, conv2d_block, MLP 
-from .normalization import build_normalization -from .rnn import get_lstm, sequence_mask -from .soft_argmax import SoftArgmax -from .transformer import Transformer -from .scatter_connection import ScatterConnection diff --git a/lzero/model/gobigger/network/activation.py b/lzero/model/gobigger/network/activation.py deleted file mode 100644 index 550bee3d5..000000000 --- a/lzero/model/gobigger/network/activation.py +++ /dev/null @@ -1,96 +0,0 @@ -""" -Copyright 2020 Sensetime X-lab. All Rights Reserved - -Main Function: - 1. build activation: you can use build_activation to build relu or glu -""" -import torch -import torch.nn as nn - - -class GLU(nn.Module): - r""" - Overview: - Gating Linear Unit. - This class does a thing like this: - - .. code:: python - - # Inputs: input, context, output_size - # The gate value is a learnt function of the input. - gate = sigmoid(linear(input.size)(context)) - # Gate the input and return an output of desired size. - gated_input = gate * input - output = linear(output_size)(gated_input) - return output - Interfaces: - forward - - .. tip:: - - This module also supports 2D convolution, in which case, the input and context must have the same shape. - """ - - def __init__(self, input_dim: int, output_dim: int, context_dim: int, input_type: str = 'fc') -> None: - r""" - Overview: - Init GLU - Arguments: - - input_dim (:obj:`int`): the input dimension - - output_dim (:obj:`int`): the output dimension - - context_dim (:obj:`int`): the context dimension - - input_type (:obj:`str`): the type of input, now support ['fc', 'conv2d'] - """ - super(GLU, self).__init__() - assert (input_type in ['fc', 'conv2d']) - if input_type == 'fc': - self.layer1 = nn.Linear(context_dim, input_dim) - self.layer2 = nn.Linear(input_dim, output_dim) - elif input_type == 'conv2d': - self.layer1 = nn.Conv2d(context_dim, input_dim, 1, 1, 0) - self.layer2 = nn.Conv2d(input_dim, output_dim, 1, 1, 0) - - def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor: - r""" - Overview: - Return GLU computed tensor - Arguments: - - x (:obj:`torch.Tensor`) : the input tensor - - context (:obj:`torch.Tensor`) : the context tensor - Returns: - - x (:obj:`torch.Tensor`): the computed tensor - """ - gate = self.layer1(context) - gate = torch.sigmoid(gate) - x = gate * x - x = self.layer2(x) - return x - -class Swish(nn.Module): - - def __init__(self): - super(Swish, self).__init__() - - def forward(self, x): - x = x * torch.sigmoid(x) - return x - -def build_activation(activation: str, inplace: bool = None) -> nn.Module: - r""" - Overview: - Return the activation module according to the given type. - Arguments: - - actvation (:obj:`str`): the type of activation module, now supports ['relu', 'glu', 'prelu'] - - inplace (:obj:`bool`): can optionally do the operation in-place in relu. 
Default ``None`` - Returns: - - act_func (:obj:`nn.module`): the corresponding activation module - """ - if inplace is not None: - assert activation == 'relu', 'inplace argument is not compatible with {}'.format(activation) - else: - inplace = True - act_func = {'relu': nn.ReLU(inplace=inplace), 'glu': GLU, 'prelu': nn.PReLU(),'swish': Swish()} - if activation in act_func.keys(): - return act_func[activation] - else: - raise KeyError("invalid key for activation: {}".format(activation)) diff --git a/lzero/model/gobigger/gobigger_encoder.py b/lzero/model/gobigger/network/gobigger_encoder.py similarity index 85% rename from lzero/model/gobigger/gobigger_encoder.py rename to lzero/model/gobigger/network/gobigger_encoder.py index 5ba0d80f7..a7916d8a4 100644 --- a/lzero/model/gobigger/gobigger_encoder.py +++ b/lzero/model/gobigger/network/gobigger_encoder.py @@ -1,17 +1,14 @@ -from typing import Dict - +from typing import Dict, Optional import torch import torch.nn as nn from torch import Tensor - -from .network import sequence_mask, ScatterConnection -from .network.encoder import SignBinaryEncoder, BinaryEncoder, OnehotEncoder, TimeEncoder, UnsqueezeEncoder -from .network.nn_module import fc_block, conv2d_block, MLP -from .network.res_block import ResBlock -from .network.transformer import Transformer -from typing import Any, List, Tuple, Union, Optional, Callable -from easydict import EasyDict from ding.utils.default_helper import deep_merge_dicts +from ding.torch_utils import MLP, fc_block, conv2d_block, ResBlock +from ding.torch_utils import Transformer +from easydict import EasyDict + +from .encoder import SignBinaryEncoder, BinaryEncoder, OnehotEncoder, TimeEncoder, UnsqueezeEncoder +from .scatter_connection import ScatterConnection def sequence_mask(lengths: torch.Tensor, max_len: Optional[int] =None): r""" @@ -57,8 +54,10 @@ def __init__(self, cfg): layer_fn=fc_block, activation=self.cfg.mlp.activation, norm_type=self.cfg.mlp.norm_type, - use_dropout=False - ) + use_dropout=False, + output_activation=True, + output_norm=True, + last_linear_layer_init_zero=False) def forward(self, x: Dict[str, Tensor]): embeddings = [] @@ -96,18 +95,19 @@ def __init__(self, cfg): layer_fn=fc_block, activation=self.cfg.mlp.activation, norm_type=self.cfg.mlp.norm_type, - use_dropout=False) + use_dropout=False, + output_activation=True, + output_norm=True, + last_linear_layer_init_zero=False) self.transformer = Transformer( - n_heads=self.cfg.transformer.head_num, - embedding_size=self.cfg.transformer.embedding_dim, - ffn_size=self.cfg.transformer.ffn_size, - n_layers=self.cfg.transformer.layer_num, - attention_dropout=0.0, - relu_dropout=0.0, - dropout=0.0, + input_dim=self.cfg.transformer.input_dim, + output_dim=self.cfg.transformer.output_dim, + head_num=self.cfg.transformer.head_num, + head_dim=self.cfg.transformer.embedding_dim, + hidden_dim=self.cfg.transformer.ffn_size, + layer_num=self.cfg.transformer.layer_num, activation=self.cfg.transformer.activation, - variant=self.cfg.transformer.variant, ) self.output_fc = fc_block(self.cfg.fc_block.input_dim, self.cfg.fc_block.output_dim, @@ -155,18 +155,19 @@ def __init__(self, cfg): layer_fn=fc_block, activation=self.cfg.mlp.activation, norm_type=self.cfg.mlp.norm_type, - use_dropout=False) + use_dropout=False, + output_activation=True, + output_norm=True, + last_linear_layer_init_zero=False) self.transformer = Transformer( - n_heads=self.cfg.transformer.head_num, - embedding_size=self.cfg.transformer.embedding_dim, - ffn_size=self.cfg.transformer.ffn_size, 
- n_layers=self.cfg.transformer.layer_num, - attention_dropout=0.0, - relu_dropout=0.0, - dropout=0.0, + input_dim=self.cfg.transformer.input_dim, + output_dim=self.cfg.transformer.output_dim, + head_num=self.cfg.transformer.head_num, + head_dim=self.cfg.transformer.embedding_dim, + hidden_dim=self.cfg.transformer.ffn_size, + layer_num=self.cfg.transformer.layer_num, activation=self.cfg.transformer.activation, - variant=self.cfg.transformer.variant, ) self.output_fc = fc_block(self.cfg.fc_block.input_dim, self.cfg.fc_block.output_dim, @@ -239,7 +240,8 @@ def get_resnet_blocks(self): bias=False, ) layers.append(layer) - layers.append(ResBlock(in_channels=dims[i + 1], + layers.append(ResBlock(res_type='basic', + in_channels=dims[i + 1], activation=self.cfg.resnet.activation, norm_type=self.cfg.resnet.norm_type)) self.resnet = torch.nn.Sequential(*layers) @@ -300,7 +302,7 @@ class GoBiggerEncoder(nn.Module): time=dict(arc='time', embedding_dim=8), last_action_type=dict(arc='one_hot', num_embeddings=27), ), - mlp=dict(input_dim=80, hidden_dim=64, layer_num=2, norm_type='none', output_dim=32, activation='relu'), + mlp=dict(input_dim=80, hidden_dim=64, layer_num=2, norm_type='BN', output_dim=32, activation=nn.ReLU(inplace=True)), ), team_encoder=dict( modules=dict( @@ -308,9 +310,9 @@ class GoBiggerEncoder(nn.Module): view_x=dict(arc='sign_binary', num_embeddings=7), view_y=dict(arc='sign_binary', num_embeddings=7), ), - mlp=dict(input_dim=16, hidden_dim=32, layer_num=2, norm_type='none', output_dim=16, activation='relu'), - transformer=dict(head_num=4, ffn_size=32, layer_num=2, embedding_dim=16, activation='relu', variant='postnorm'), - fc_block=dict(input_dim=16, output_dim=16, activation='relu', norm_type='none'), + mlp=dict(input_dim=16, hidden_dim=32, layer_num=2, norm_type=None, output_dim=16, activation=nn.ReLU(inplace=True)), + transformer=dict(input_dim=16, output_dim=16, head_num=4, ffn_size=32, layer_num=2, embedding_dim=16, activation=nn.ReLU(inplace=True), variant='postnorm'), + fc_block=dict(input_dim=16, output_dim=16, activation=nn.ReLU(inplace=True), norm_type='BN'), ), ball_encoder=dict( modules=dict( @@ -323,14 +325,14 @@ class GoBiggerEncoder(nn.Module): next_x=dict(arc='sign_binary', num_embeddings=8), next_y=dict(arc='sign_binary', num_embeddings=8), ), - mlp=dict(input_dim=92, hidden_dim=128, layer_num=2, norm_type='none', output_dim=64, activation='relu'), - transformer=dict(head_num=4, ffn_size=64, layer_num=3, embedding_dim=64, activation='relu', variant='postnorm'), - fc_block=dict(input_dim=64, output_dim=64, activation='relu', norm_type='none'), + mlp=dict(input_dim=92, hidden_dim=128, layer_num=2, norm_type=None, output_dim=64, activation=nn.ReLU(inplace=True)), + transformer=dict(input_dim=64, output_dim=64, head_num=4, ffn_size=64, layer_num=3, embedding_dim=64, activation=nn.ReLU(inplace=True), variant='postnorm'), + fc_block=dict(input_dim=64, output_dim=64, activation=nn.ReLU(inplace=True), norm_type='BN'), ), spatial_encoder=dict( - scatter=dict(input_dim=64, output_dim=16, scatter_type='add', activation='relu', norm_type='none'), - resnet=dict(project_dim=12, down_channels=[32, 32, 16 ], activation='relu', norm_type='none'), - fc_block=dict(output_dim=64, activation='relu', norm_type='none'), + scatter=dict(input_dim=64, output_dim=16, scatter_type='add', activation=nn.ReLU(inplace=True), norm_type=None), + resnet=dict(project_dim=12, down_channels=[32, 32, 16 ], activation=nn.ReLU(inplace=True), norm_type='BN'), + fc_block=dict(output_dim=64, 
activation=nn.ReLU(inplace=True), norm_type='BN'), ), ) diff --git a/lzero/model/gobigger/network/nn_module.py b/lzero/model/gobigger/network/nn_module.py deleted file mode 100644 index 976831432..000000000 --- a/lzero/model/gobigger/network/nn_module.py +++ /dev/null @@ -1,235 +0,0 @@ -from typing import Callable - -import torch -import torch.nn as nn - -from .activation import build_activation -from .normalization import build_normalization - - -def fc_block( - in_channels: int, - out_channels: int, - activation: nn.Module = None, - norm_type: str = None, - use_dropout: bool = False, - dropout_probability: float = 0.5 -) -> nn.Sequential: - r""" - Overview: - Create a fully-connected block with activation, normalization and dropout. - Optional normalization can be done to the dim 1 (across the channels) - x -> fc -> norm -> act -> dropout -> out - Arguments: - - in_channels (:obj:`int`): Number of channels in the input tensor - - out_channels (:obj:`int`): Number of channels in the output tensor - - activation (:obj:`nn.Module`): the optional activation function - - norm_type (:obj:`str`): type of the normalization - - use_dropout (:obj:`bool`) : whether to use dropout in the fully-connected block - - dropout_probability (:obj:`float`) : probability of an element to be zeroed in the dropout. Default: 0.5 - Returns: - - block (:obj:`nn.Sequential`): a sequential list containing the torch layers of the fully-connected block - - .. note:: - - you can refer to nn.linear (https://pytorch.org/docs/master/generated/torch.nn.Linear.html) - """ - block = [] - block.append(nn.Linear(in_channels, out_channels)) - if norm_type is not None and norm_type != 'none': - block.append(build_normalization(norm_type, dim=1)(out_channels)) - if isinstance(activation, str) and activation != 'none': - block.append(build_activation(activation)) - elif isinstance(activation, torch.nn.Module): - block.append(activation) - if use_dropout: - block.append(nn.Dropout(dropout_probability)) - return nn.Sequential(*block) - - -def fc_block2( - in_channels, - out_channels, - activation=None, - norm_type=None, - use_dropout=False, - dropout_probability=0.5 -): - r""" - Overview: - create a fully-connected block with activation, normalization and dropout - optional normalization can be done to the dim 1 (across the channels) - x -> fc -> norm -> act -> dropout -> out - Arguments: - - in_channels (:obj:`int`): Number of channels in the input tensor - - out_channels (:obj:`int`): Number of channels in the output tensor - - init_type (:obj:`str`): the type of init to implement - - activation (:obj:`nn.Moduel`): the optional activation function - - norm_type (:obj:`str`): type of the normalization - - use_dropout (:obj:`bool`) : whether to use dropout in the fully-connected block - - dropout_probability (:obj:`float`) : probability of an element to be zeroed in the dropout. Default: 0.5 - Returns: - - block (:obj:`nn.Sequential`): a sequential list containing the torch layers of the fully-connected block - - .. 
note:: - you can refer to nn.linear (https://pytorch.org/docs/master/generated/torch.nn.Linear.html) - """ - block = [] - if norm_type is not None and norm_type != 'none': - block.append(build_normalization(norm_type, dim=1)(in_channels)) - if isinstance(activation, str) and activation != 'none': - block.append(build_activation(activation)) - elif isinstance(activation, torch.nn.Module): - block.append(activation) - block.append(nn.Linear(in_channels, out_channels)) - if use_dropout: - block.append(nn.Dropout(dropout_probability)) - return nn.Sequential(*block) - - -def conv2d_block( - in_channels: int, - out_channels: int, - kernel_size: int, - stride: int = 1, - padding: int = 0, - dilation: int = 1, - groups: int = 1, - activation: str = None, - norm_type: str = None, - bias: bool = True, -) -> nn.Sequential: - r""" - Overview: - Create a 2-dim convlution layer with activation and normalization. - Arguments: - - in_channels (:obj:`int`): Number of channels in the input tensor - - out_channels (:obj:`int`): Number of channels in the output tensor - - kernel_size (:obj:`int`): Size of the convolving kernel - - stride (:obj:`int`): Stride of the convolution - - padding (:obj:`int`): Zero-padding added to both sides of the input - - dilation (:obj:`int`): Spacing between kernel elements - - groups (:obj:`int`): Number of blocked connections from input channels to output channels - - pad_type (:obj:`str`): the way to add padding, include ['zero', 'reflect', 'replicate'], default: None - - activation (:obj:`nn.Module`): the optional activation function - - norm_type (:obj:`str`): type of the normalization, default set to None, now support ['BN', 'IN', 'SyncBN'] - Returns: - - block (:obj:`nn.Sequential`): a sequential list containing the torch layers of the 2 dim convlution layer - - .. note:: - - Conv2d (https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html#torch.nn.Conv2d) - """ - block = [] - block.append( - nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, - padding=padding, dilation=dilation, groups=groups,bias=bias) - ) - if norm_type is not None: - block.append(nn.GroupNorm(num_groups=1, num_channels=out_channels)) - if isinstance(activation, str) and activation != 'none': - block.append(build_activation(activation)) - elif isinstance(activation, torch.nn.Module): - block.append(activation) - return nn.Sequential(*block) - - -def conv2d_block2( - in_channels, - out_channels, - kernel_size, - stride=1, - padding=0, - dilation=1, - groups=1, - activation: str = None, - norm_type=None, - bias: bool = True, -): - r""" - Overview: - create a 2-dim convlution layer with activation and normalization. 
- - Note: - Conv2d (https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html#torch.nn.Conv2d) - - Arguments: - - in_channels (:obj:`int`): Number of channels in the input tensor - - out_channels (:obj:`int`): Number of channels in the output tensor - - kernel_size (:obj:`int`): Size of the convolving kernel - - stride (:obj:`int`): Stride of the convolution - - padding (:obj:`int`): Zero-padding added to both sides of the input - - dilation (:obj:`int`): Spacing between kernel elements - - groups (:obj:`int`): Number of blocked connections from input channels to output channels - - init_type (:obj:`str`): the type of init to implement - - pad_type (:obj:`str`): the way to add padding, include ['zero', 'reflect', 'replicate'], default: None - - activation (:obj:`nn.Moduel`): the optional activation function - - norm_type (:obj:`str`): type of the normalization, default set to None, now support ['BN', 'IN', 'SyncBN'] - - Returns: - - block (:obj:`nn.Sequential`): a sequential list containing the torch layers of the 2 dim convlution layer - """ - - block = [] - if norm_type is not None: - block.append(nn.GroupNorm(num_groups=1, num_channels=out_channels)) - if isinstance(activation, str) and activation != 'none': - block.append(build_activation(activation)) - elif isinstance(activation, torch.nn.Module): - block.append(activation) - block.append( - nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, - padding=padding, dilation=dilation, groups=groups,bias=bias) - ) - return nn.Sequential(*block) - - -def MLP( - in_channels: int, - hidden_channels: int, - out_channels: int, - layer_num: int, - layer_fn: Callable = None, - activation: str = None, - norm_type: str = None, - use_dropout: bool = False, - dropout_probability: float = 0.5 -): - r""" - Overview: - create a multi-layer perceptron using fully-connected blocks with activation, normalization and dropout, - optional normalization can be done to the dim 1 (across the channels) - x -> fc -> norm -> act -> dropout -> out - Arguments: - - in_channels (:obj:`int`): Number of channels in the input tensor - - hidden_channels (:obj:`int`): Number of channels in the hidden tensor - - out_channels (:obj:`int`): Number of channels in the output tensor - - layer_num (:obj:`int`): Number of layers - - layer_fn (:obj:`Callable`): layer function - - activation (:obj:`nn.Module`): the optional activation function - - norm_type (:obj:`str`): type of the normalization - - use_dropout (:obj:`bool`): whether to use dropout in the fully-connected block - - dropout_probability (:obj:`float`): probability of an element to be zeroed in the dropout. Default: 0.5 - Returns: - - block (:obj:`nn.Sequential`): a sequential list containing the torch layers of the fully-connected block - - .. 
note:: - - you can refer to nn.linear (https://pytorch.org/docs/master/generated/torch.nn.Linear.html) - """ - assert layer_num >= 0, layer_num - if layer_num == 0: - return nn.Sequential(*[nn.Identity()]) - - channels = [in_channels] + [hidden_channels] * (layer_num - 1) + [out_channels] - if layer_fn is None: - layer_fn = fc_block - block = [] - for i, (in_channels, out_channels) in enumerate(zip(channels[:-1], channels[1:])): - block.append(layer_fn(in_channels=in_channels, - out_channels=out_channels, - activation=activation, - norm_type=norm_type, - use_dropout=use_dropout, - dropout_probability=dropout_probability)) - return nn.Sequential(*block) diff --git a/lzero/model/gobigger/network/normalization.py b/lzero/model/gobigger/network/normalization.py deleted file mode 100644 index fd5831c14..000000000 --- a/lzero/model/gobigger/network/normalization.py +++ /dev/null @@ -1,36 +0,0 @@ -from typing import Optional -import torch.nn as nn - - -def build_normalization(norm_type: str, dim: Optional[int] = None) -> nn.Module: - r""" - Overview: - Build the corresponding normalization module - Arguments: - - norm_type (:obj:`str`): type of the normaliztion, now support ['BN', 'IN', 'SyncBN', 'AdaptiveIN'] - - dim (:obj:`int`): dimension of the normalization, when norm_type is in [BN, IN] - Returns: - - norm_func (:obj:`nn.Module`): the corresponding batch normalization function - - .. note:: - For beginers, you can refer to to learn more about batch normalization. - """ - if dim is None: - key = norm_type - else: - if norm_type in ['BN', 'IN', 'SyncBN']: - key = norm_type + str(dim) - elif norm_type in ['LN']: - key = norm_type - else: - raise NotImplementedError("not support indicated dim when creates {}".format(norm_type)) - norm_func = { - 'BN1': nn.BatchNorm1d, - 'BN2': nn.BatchNorm2d, - 'LN': nn.LayerNorm, - 'IN2': nn.InstanceNorm2d, - } - if key in norm_func.keys(): - return norm_func[key] - else: - raise KeyError("invalid norm type: {}".format(key)) \ No newline at end of file diff --git a/lzero/model/gobigger/network/res_block.py b/lzero/model/gobigger/network/res_block.py deleted file mode 100644 index f64fae1db..000000000 --- a/lzero/model/gobigger/network/res_block.py +++ /dev/null @@ -1,231 +0,0 @@ -""" -Copyright 2020 Sensetime X-lab. All Rights Reserved - -Main Function: - 1. 
build ResBlock: you can use this classes to build residual blocks -""" -import torch.nn as nn -from .nn_module import conv2d_block, fc_block,conv2d_block2,fc_block2 -from .activation import build_activation -from .normalization import build_normalization - - -class ResBlock(nn.Module): - r''' - Overview: - Residual Block with 2D convolution layers, including 2 types: - basic block: - input channel: C - x -> 3*3*C -> norm -> act -> 3*3*C -> norm -> act -> out - \__________________________________________/+ - bottleneck block: - x -> 1*1*(1/4*C) -> norm -> act -> 3*3*(1/4*C) -> norm -> act -> 1*1*C -> norm -> act -> out - \_____________________________________________________________________________/+ - - Interface: - __init__, forward - ''' - - def __init__(self, in_channels, out_channels=None,stride=1, downsample=None, activation='relu', norm_type='LN',): - r""" - Overview: - Init the Residual Block - - Arguments: - - in_channels (:obj:`int`): Number of channels in the input tensor - - activation (:obj:`nn.Module`): the optional activation function - - norm_type (:obj:`str`): type of the normalization, defalut set to batch normalization, - support ['BN', 'IN', 'SyncBN', None] - - res_type (:obj:`str`): type of residual block, support ['basic', 'bottleneck'], see overview for details - """ - super(ResBlock, self).__init__() - self.in_channels = in_channels - self.out_channels = self.in_channels if out_channels is None else out_channels - self.activation_type = activation - self.norm_type = norm_type - self.stride = stride - self.downsample = downsample - self.conv1 = conv2d_block(in_channels=self.in_channels, - out_channels=self.out_channels, - kernel_size=3, - stride=self.stride, - padding= 1, - activation=self.activation_type, - norm_type=self.norm_type) - self.conv2 = conv2d_block(in_channels=self.out_channels, - out_channels=self.out_channels, - kernel_size=3, - stride=self.stride, - padding= 1, - activation=None, - norm_type=self.norm_type) - self.activation = build_activation(self.activation_type) - - def forward(self, x): - r""" - Overview: - return the redisual block output - - Arguments: - - x (:obj:`tensor`): the input tensor - - Returns: - - x(:obj:`tensor`): the resblock output tensor - """ - residual = x - out = self.conv1(x) - out = self.conv2(out) - if self.downsample is not None: - residual = self.downsample(x) - out += residual - out = self.activation(out) - return out - - -class ResBlock2(nn.Module): - r''' - Overview: - Residual Block with 2D convolution layers, including 2 types: - basic block: - input channel: C - x -> 3*3*C -> norm -> act -> 3*3*C -> norm -> act -> out - \__________________________________________/+ - bottleneck block: - x -> 1*1*(1/4*C) -> norm -> act -> 3*3*(1/4*C) -> norm -> act -> 1*1*C -> norm -> act -> out - \_____________________________________________________________________________/+ - - Interface: - __init__, forward - ''' - - def __init__(self, in_channels, out_channels=None,stride=1, downsample=None, activation='relu', norm_type='LN',): - r""" - Overview: - Init the Residual Block - - Arguments: - - in_channels (:obj:`int`): Number of channels in the input tensor - - activation (:obj:`nn.Module`): the optional activation function - - norm_type (:obj:`str`): type of the normalization, defalut set to batch normalization, - support ['BN', 'IN', 'SyncBN', None] - - res_type (:obj:`str`): type of residual block, support ['basic', 'bottleneck'], see overview for details - """ - super(ResBlock2, self).__init__() - self.in_channels = in_channels - 
self.out_channels = self.in_channels if out_channels is None else out_channels - self.activation_type = activation - self.norm_type = norm_type - self.stride = stride - self.downsample = downsample - self.conv1 = conv2d_block2(in_channels=self.in_channels, - out_channels=self.out_channels, - kernel_size=3, - stride=self.stride, - padding= 1, - activation=self.activation_type, - norm_type=self.norm_type) - self.conv2 = conv2d_block2(in_channels=self.out_channels, - out_channels=self.out_channels, - kernel_size=3, - stride=self.stride, - padding= 1, - activation=self.activation_type, - norm_type=self.norm_type) - self.activation = build_activation(self.activation_type) - - - def forward(self, x): - r""" - Overview: - return the redisual block output - - Arguments: - - x (:obj:`tensor`): the input tensor - - Returns: - - x(:obj:`tensor`): the resblock output tensor - """ - residual = x - out = self.conv1(x) - out = self.conv2(out) - if self.downsample is not None: - residual = self.downsample(x) - out += residual - return x - -class ResFCBlock(nn.Module): - def __init__(self, in_channels, activation='relu', norm_type=None): - r""" - Overview: - Init the Residual Block - - Arguments: - - activation (:obj:`nn.Module`): the optional activation function - - norm_type (:obj:`str`): type of the normalization, defalut set to batch normalization - """ - super(ResFCBlock, self).__init__() - self.activation_type = activation - self.norm_type = norm_type - self.fc1 = fc_block(in_channels, in_channels, norm_type=self.norm_type, activation=self.activation_type) - self.fc2 = fc_block(in_channels, in_channels,norm_type=self.norm_type, activation=None) - self.activation = build_activation(self.activation_type) - - - def forward(self, x): - r""" - Overview: - return output of the residual block with 2 fully connected block - - Arguments: - - x (:obj:`tensor`): the input tensor - - Returns: - - x(:obj:`tensor`): the resblock output tensor - """ - residual = x - x = self.fc1(x) - x = self.fc2(x) - x = self.activation(x + residual) - return x - -class ResFCBlock2(nn.Module): - r''' - Overview: - Residual Block with 2 fully connected block - x -> fc1 -> norm -> act -> fc2 -> norm -> act -> out - \_____________________________________/+ - - Interface: - __init__, forward - ''' - - def __init__(self, in_channels, activation='relu', norm_type='LN'): - r""" - Overview: - Init the Residual Block - - Arguments: - - activation (:obj:`nn.Module`): the optional activation function - - norm_type (:obj:`str`): type of the normalization, defalut set to batch normalization - """ - super(ResFCBlock2, self).__init__() - self.activation_type = activation - self.fc1 = fc_block2(in_channels, in_channels, activation=self.activation_type, norm_type=norm_type) - self.fc2 = fc_block2(in_channels, in_channels, activation=self.activation_type, norm_type=norm_type) - - def forward(self, x): - r""" - Overview: - return output of the residual block with 2 fully connected block - - Arguments: - - x (:obj:`tensor`): the input tensor - - Returns: - - x(:obj:`tensor`): the resblock output tensor - """ - residual = x - x = self.fc1(x) - x = self.fc2(x) - x = x + residual - return x \ No newline at end of file diff --git a/lzero/model/gobigger/network/rnn.py b/lzero/model/gobigger/network/rnn.py deleted file mode 100644 index 363107360..000000000 --- a/lzero/model/gobigger/network/rnn.py +++ /dev/null @@ -1,276 +0,0 @@ -""" -Copyright 2020 Sensetime X-lab. All Rights Reserved - -Main Function: - 1. 
build LSTM: you can use build_LSTM to build the lstm module -""" -import math - -import torch -import torch.nn as nn - -from typing import Optional -from .normalization import build_normalization - - -def is_sequence(data): - return isinstance(data, list) or isinstance(data, tuple) - - -def sequence_mask(lengths: torch.Tensor, max_len: Optional[int] =None): - r""" - Overview: - create a mask for a batch sequences with different lengths - Arguments: - - lengths (:obj:`tensor`): lengths in each different sequences, shape could be (n, 1) or (n) - - max_len (:obj:`int`): the padding size, if max_len is None, the padding size is the - max length of sequences - Returns: - - masks (:obj:`torch.BoolTensor`): mask has the same device as lengths - """ - if len(lengths.shape) == 1: - lengths = lengths.unsqueeze(dim=1) - bz = lengths.numel() - if max_len is None: - max_len = lengths.max() - return torch.arange(0, max_len).type_as(lengths).repeat(bz, 1).lt(lengths).to(lengths.device) - - -class LSTMForwardWrapper(object): - r""" - Overview: - abstract class used to wrap the LSTM forward method - Interface: - _before_forward, _after_forward - """ - - def _before_forward(self, inputs, prev_state): - r""" - Overview: - preprocess the inputs and previous states - Arguments: - - inputs (:obj:`tensor`): input vector of cell, tensor of size [seq_len, batch_size, input_size] - - prev_state (:obj:`tensor` or :obj:`list`): - None or tensor of size [num_directions*num_layers, batch_size, hidden_size], if None then prv_state - will be initialized to all zeros. - Returns: - - prev_state (:obj:`tensor`): batch previous state in lstm - """ - assert hasattr(self, 'num_layers') - assert hasattr(self, 'hidden_size') - seq_len, batch_size = inputs.shape[:2] - if prev_state is None: - num_directions = 1 - zeros = torch.zeros( - num_directions * self.num_layers, - batch_size, - self.hidden_size, - dtype=inputs.dtype, - device=inputs.device - ) - prev_state = (zeros, zeros) - elif is_sequence(prev_state): - if len(prev_state) == 2 and isinstance(prev_state[0], torch.Tensor): - pass - else: - if len(prev_state) != batch_size: - raise RuntimeError( - "prev_state number is not equal to batch_size: {}/{}".format(len(prev_state), batch_size) - ) - num_directions = 1 - zeros = torch.zeros( - num_directions * self.num_layers, 1, self.hidden_size, dtype=inputs.dtype, device=inputs.device - ) - state = [] - for prev in prev_state: - if prev is None: - state.append([zeros, zeros]) - else: - state.append(prev) - state = list(zip(*state)) - prev_state = [torch.cat(t, dim=1) for t in state] - else: - raise TypeError("not support prev_state type: {}".format(type(prev_state))) - return prev_state - - def _after_forward(self, next_state, list_next_state=False): - r""" - Overview: - post process the next_state, return list or tensor type next_states - Arguments: - - next_state (:obj:`list` :obj:`Tuple` of :obj:`tensor`): list of Tuple contains the next (h, c) - - list_next_state (:obj:`bool`): whether return next_state with list format, default set to False - Returns: - - next_state(:obj:`list` of :obj:`tensor` or :obj:`tensor`): the formated next_state - """ - if list_next_state: - h, c = [torch.stack(t, dim=0) for t in zip(*next_state)] - batch_size = h.shape[1] - next_state = [torch.chunk(h, batch_size, dim=1), torch.chunk(c, batch_size, dim=1)] - next_state = list(zip(*next_state)) - else: - next_state = [torch.stack(t, dim=0) for t in zip(*next_state)] - return next_state - - -class LSTM(nn.Module, LSTMForwardWrapper): - r""" - 
Overview: - Implimentation of LSTM cell - - .. note:: - for begainners, you can reference to learn the basics about lstm - - Interface: - __init__, forward - """ - - def __init__(self, input_size, hidden_size, num_layers, norm_type=None, dropout=0.): - r""" - Overview: - initializate the LSTM cell - - Arguments: - - input_size (:obj:`int`): size of the input vector - - hidden_size (:obj:`int`): size of the hidden state vector - - num_layers (:obj:`int`): number of lstm layers - - norm_type (:obj:`str`): type of the normaliztion, (default: None) - - dropout (:obj:float): dropout rate, default set to .0 - """ - super(LSTM, self).__init__() - self.input_size = input_size - self.hidden_size = hidden_size - self.num_layers = num_layers - - norm_func = build_normalization(norm_type) - self.norm = nn.ModuleList([norm_func(hidden_size * 4) for _ in range(2 * num_layers)]) - self.wx = nn.ParameterList() - self.wh = nn.ParameterList() - dims = [input_size] + [hidden_size] * num_layers - for l in range(num_layers): - self.wx.append(nn.Parameter(torch.zeros(dims[l], dims[l + 1] * 4))) - self.wh.append(nn.Parameter(torch.zeros(hidden_size, hidden_size * 4))) - self.bias = nn.Parameter(torch.zeros(num_layers, hidden_size * 4)) - self.use_dropout = dropout > 0. - if self.use_dropout: - self.dropout = nn.Dropout(dropout) - self._init() - - def _init(self): - gain = math.sqrt(1. / self.hidden_size) - for l in range(self.num_layers): - torch.nn.init.uniform_(self.wx[l], -gain, gain) - torch.nn.init.uniform_(self.wh[l], -gain, gain) - if self.bias is not None: - torch.nn.init.uniform_(self.bias[l], -gain, gain) - - def forward(self, inputs, prev_state, list_next_state=True): - r""" - Overview: - Take the previous state and the input and calculate the output and the nextstate - Arguments: - - inputs (:obj:`tensor`): input vector of cell, tensor of size [seq_len, batch_size, input_size] - - prev_state (:obj:`tensor`): None or tensor of size [num_directions*num_layers, batch_size, hidden_size] - - list_next_state (:obj:`bool`): whether return next_state with list format, default set to False - Returns: - - x (:obj:`tensor`): output from lstm - - next_state (:obj:`tensor` or :obj:`list`): hidden state from lstm - """ - seq_len, batch_size = inputs.shape[:2] - prev_state = self._before_forward(inputs, prev_state) - - H, C = prev_state - x = inputs - next_state = [] - for l in range(self.num_layers): - h, c = H[l], C[l] - new_x = [] - for s in range(seq_len): - gate = self.norm[l * 2](torch.matmul(x[s], self.wx[l]) - ) + self.norm[l * 2 + 1](torch.matmul(h, self.wh[l])) - if self.bias is not None: - gate += self.bias[l] - gate = list(torch.chunk(gate, 4, dim=1)) - i, f, o, u = gate - i = torch.sigmoid(i) - f = torch.sigmoid(f) - o = torch.sigmoid(o) - u = torch.tanh(u) - c = f * c + i * u - h = o * torch.tanh(c) - new_x.append(h) - next_state.append((h, c)) - x = torch.stack(new_x, dim=0) - if self.use_dropout and l != self.num_layers - 1: - x = self.dropout(x) - - next_state = self._after_forward(next_state, list_next_state) - return x, next_state - - -class PytorchLSTM(nn.LSTM, LSTMForwardWrapper): - r""" - Overview: - Wrap the nn.LSTM , format the input and output - Interface: - forward - - .. 
note:: - you can reference the - """ - - def forward(self, inputs, prev_state, list_next_state=True): - r""" - Overview: - wrapped nn.LSTM.forward - Arguments: - - inputs (:obj:`tensor`): input vector of cell, tensor of size [seq_len, batch_size, input_size] - - prev_state (:obj:`tensor`): None or tensor of size [num_directions*num_layers, batch_size, hidden_size] - - list_next_state (:obj:`bool`): whether return next_state with list format, default set to False - Returns: - - output (:obj:`tensor`): output from lstm - - next_state (:obj:`tensor` or :obj:`list`): hidden state from lstm - """ - prev_state = self._before_forward(inputs, prev_state) - output, next_state = nn.LSTM.forward(self, inputs, prev_state) - next_state = self._after_forward(next_state, list_next_state) - return output, next_state - - def _after_forward(self, next_state, list_next_state=False): - r""" - Overview: - process hidden state after lstm, make it list or remains tensor - Arguments: - - nex_state (:obj:`tensor`): hidden state from lstm - - list_nex_state (:obj:`bool`): whether return next_state with list format, default set to False - Returns: - - next_state (:obj:`tensor` or :obj:`list`): hidden state from lstm - """ - if list_next_state: - h, c = next_state - batch_size = h.shape[1] - next_state = [torch.chunk(h, batch_size, dim=1), torch.chunk(c, batch_size, dim=1)] - return list(zip(*next_state)) - else: - return next_state - - -def get_lstm(lstm_type, input_size, hidden_size, num_layers=1, norm_type='LN', dropout=0.): - r""" - Overview: - build and return the corresponding LSTM cell - Arguments: - - lstm_type (:obj:`str`): version of lstm cell, now support ['normal', 'pytorch'] - - input_size (:obj:`int`): size of the input vector - - hidden_size (:obj:`int`): size of the hidden state vector - - num_layers (:obj:`int`): number of lstm layers - - norm_type (:obj:`str`): type of the normaliztion, (default: None) - - dropout (:obj:float): dropout rate, default set to .0 - Returns: - - lstm (:obj:`LSTM` or :obj:`PytorchLSTM`): the corresponding lstm cell - """ - assert lstm_type in ['normal', 'pytorch'] - if lstm_type == 'normal': - return LSTM(input_size, hidden_size, num_layers, norm_type, dropout=dropout) - elif lstm_type == 'pytorch': - return PytorchLSTM(input_size, hidden_size, num_layers, dropout=dropout) diff --git a/lzero/model/gobigger/network/scatter_connection.py b/lzero/model/gobigger/network/scatter_connection.py index dbb6ab716..263ff0f57 100644 --- a/lzero/model/gobigger/network/scatter_connection.py +++ b/lzero/model/gobigger/network/scatter_connection.py @@ -76,24 +76,6 @@ def forward(self, x: torch.Tensor, spatial_size: Tuple[int, int], location: torc elif self.scatter_type == 'add': output.scatter_add_(dim=2, index=indices, src=x) output = output.view(BatchSize, EmbeddingSize, H, W) - - # device = x.device - # B, M, N = x.shape - # H, W = spatial_size - # index = location.view(-1, 2) - # bias = torch.arange(B).mul_(H * W).unsqueeze(1).repeat(1, M).view(-1).to(device) - # index = index[:, 0] * W + index[:, 1] - # index += bias - # index = index.repeat(N, 1) - # x = x.view(-1, N).permute(1, 0) - # output = torch.zeros(N, B * H * W, device=device) - # if self.scatter_type == 'cover': - # output.scatter_(dim=1, index=index, src=x) - # elif self.scatter_type == 'add': - # output.scatter_add_(dim=1, index=index, src=x) - # output = output.reshape(N, B, H, W) - # output = output.permute(1, 0, 2, 3).contiguous() - return output diff --git a/lzero/model/gobigger/network/soft_argmax.py 
b/lzero/model/gobigger/network/soft_argmax.py deleted file mode 100644 index a963fd1ad..000000000 --- a/lzero/model/gobigger/network/soft_argmax.py +++ /dev/null @@ -1,60 +0,0 @@ -""" -Copyright 2020 Sensetime X-lab. All Rights Reserved - -Main Function: - 1. SoftArgmax: a nn.Module that computes SoftArgmax -""" -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class SoftArgmax(nn.Module): - r""" - Overview: - a nn.Module that computes SoftArgmax - - Note: - for more softargmax info, you can reference the wiki page - or reference the lecture - - - Interface: - __init__, forward - """ - - def __init__(self): - r""" - Overview: - initialize the SoftArgmax module - """ - super(SoftArgmax, self).__init__() - - def forward(self, x): - r""" - Overview: - soft-argmax for location regression - - Arguments: - - x (:obj:`Tensor`): predict heat map - - Returns: - - location (:obj:`Tensor`): predict location - - Shapes: - - x (:obj:`Tensor`): :math:`(B, C, H, W)`, while B is the batch size, - C is number of channels , H and W stands for height and width - - location (:obj:`Tensor`): :math:`(B, 2)`, while B is the batch size - """ - B, C, H, W = x.shape - device, dtype = x.device, x.dtype - # 1 channel - assert (x.shape[1] == 1) - h_kernel = torch.arange(0, H, device=device).to(dtype) - h_kernel = h_kernel.view(1, 1, H, 1).repeat(1, 1, 1, W) - w_kernel = torch.arange(0, W, device=device).to(dtype) - w_kernel = w_kernel.view(1, 1, 1, W).repeat(1, 1, H, 1) - x = F.softmax(x.view(B, C, -1), dim=-1).view(B, C, H, W) - h = (x * h_kernel).sum(dim=[1, 2, 3]) - w = (x * w_kernel).sum(dim=[1, 2, 3]) - return torch.stack([h, w], dim=1) diff --git a/lzero/model/gobigger/network/transformer.py b/lzero/model/gobigger/network/transformer.py deleted file mode 100644 index 67ae4426d..000000000 --- a/lzero/model/gobigger/network/transformer.py +++ /dev/null @@ -1,397 +0,0 @@ -import math -from typing import Dict, Tuple, Optional - -import torch -import torch.nn as nn -import torch.nn.functional as F - -LAYER_NORM_EPS = 1e-5 -NEAR_INF = 1e20 -NEAR_INF_FP16 = 65504 - - -def neginf(dtype: torch.dtype) -> float: - """ - Return a representable finite number near -inf for a dtype. 
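For reference only (not part of the patch): the SoftArgmax module deleted above turns a single-channel heat map into an expected (h, w) coordinate by softmax-weighting the pixel grid, which keeps the location regression differentiable. A self-contained sketch of the same idea in plain PyTorch:

    import torch
    import torch.nn.functional as F

    def soft_argmax(heatmap: torch.Tensor) -> torch.Tensor:
        # heatmap: (B, 1, H, W) -> expected (h, w) coordinates, shape (B, 2)
        B, C, H, W = heatmap.shape
        assert C == 1
        probs = F.softmax(heatmap.view(B, -1), dim=-1).view(B, H, W)
        h_idx = torch.arange(H, dtype=probs.dtype, device=probs.device).view(1, H, 1)
        w_idx = torch.arange(W, dtype=probs.dtype, device=probs.device).view(1, 1, W)
        h = (probs * h_idx).sum(dim=(1, 2))
        w = (probs * w_idx).sum(dim=(1, 2))
        return torch.stack([h, w], dim=1)

    x = torch.zeros(1, 1, 5, 5)
    x[0, 0, 3, 1] = 10.0          # sharp peak at row 3, column 1
    print(soft_argmax(x))         # approximately tensor([[3., 1.]])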
- """ - if dtype is torch.float16: - return -NEAR_INF_FP16 - else: - return -NEAR_INF - - -class MultiHeadAttention(nn.Module): - r""" - Overview: - For each entry embedding, compute individual attention across all entries, add them up to get output attention - """ - - def __init__(self, n_heads: int = None, dim: int = None, dropout: float = 0): - r""" - Overview: - Init attention - Arguments: - - input_dim (:obj:`int`): dimension of input - - head_dim (:obj:`int`): dimension of each head - - output_dim (:obj:`int`): dimension of output - - head_num (:obj:`int`): head num for multihead attention - - dropout (:obj:`nn.Module`): dropout layer - """ - super(MultiHeadAttention, self).__init__() - self.n_heads = n_heads - self.dim = dim - - self.attn_dropout = nn.Dropout(p=dropout) - self.q_lin = nn.Linear(dim, dim) - self.k_lin = nn.Linear(dim, dim) - self.v_lin = nn.Linear(dim, dim) - - # TODO: merge for the initialization step - nn.init.xavier_normal_(self.q_lin.weight) - nn.init.xavier_normal_(self.k_lin.weight) - nn.init.xavier_normal_(self.v_lin.weight) - self.out_lin = nn.Linear(dim, dim) - nn.init.xavier_normal_(self.out_lin.weight) - - # self.attention_pre = fc_block(self.dim, self.dim * 3) # query, key, value - # self.project = fc_block(self.dim,self.dim) - - def split(self, x, T=False): - r""" - Overview: - Split input to get multihead queries, keys, values - Arguments: - - x (:obj:`tensor`): query or key or value - - T (:obj:`bool`): whether to transpose output - Returns: - - x (:obj:`list`): list of output tensors for each head - """ - B, N = x.shape[:2] - x = x.view(B, N, self.head_num, self.head_dim) - x = x.permute(0, 2, 1, 3).contiguous() # B, head_num, N, head_dim - if T: - x = x.permute(0, 1, 3, 2).contiguous() - return x - - def forward(self, - query: torch.Tensor, - key: Optional[torch.Tensor] = None, - value: Optional[torch.Tensor] = None, - mask: torch.Tensor = None, - ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: - batch_size, query_len, dim = query.size() - assert ( - dim == self.dim - ), 'Dimensions do not match: {} query vs {} configured'.format(dim, self.dim) - assert mask is not None, 'Mask is None, please specify a mask' - n_heads = self.n_heads - dim_per_head = dim // n_heads - scale = math.sqrt(dim_per_head) - - def prepare_head(tensor): - # input is [batch_size, seq_len, n_heads * dim_per_head] - # output is [batch_size * n_heads, seq_len, dim_per_head] - bsz, seq_len, _ = tensor.size() - tensor = tensor.view(batch_size, tensor.size(1), n_heads, dim_per_head) - tensor = ( - tensor.transpose(1, 2) - .contiguous() - .view(batch_size * n_heads, seq_len, dim_per_head) - ) - return tensor - - # q, k, v are the transformed values - if key is None and value is None: - # self attention - key = value = query - _, _key_len, dim = query.size() - elif value is None: - # key and value are the same, but query differs - # self attention - value = key - - assert key is not None # let mypy know we sorted this - _, _key_len, dim = key.size() - - q = prepare_head(self.q_lin(query)) - k = prepare_head(self.k_lin(key)) - v = prepare_head(self.v_lin(value)) - full_key_len = k.size(1) - dot_prod = q.div_(scale).bmm(k.transpose(1, 2)) - # [B * n_heads, query_len, key_len] - attn_mask = ( - (mask == 0) - .view(batch_size, 1, -1, full_key_len) - .repeat(1, n_heads, 1, 1) - .expand(batch_size, n_heads, query_len, full_key_len) - .view(batch_size * n_heads, query_len, full_key_len) - ) - assert attn_mask.shape == dot_prod.shape - dot_prod.masked_fill_(attn_mask, 
neginf(dot_prod.dtype)) - - attn_weights = F.softmax( - dot_prod, dim=-1, dtype=torch.float # type: ignore - ).type_as(query) - attn_weights = self.attn_dropout(attn_weights) # --attention-dropout - - attentioned = attn_weights.bmm(v) - attentioned = ( - attentioned.type_as(query) - .view(batch_size, n_heads, query_len, dim_per_head) - .transpose(1, 2) - .contiguous() - .view(batch_size, query_len, dim) - ) - - out = self.out_lin(attentioned) - - return out, dot_prod - # - # def forward(self, x, mask=None): - # r""" - # Overview: - # Compute attention - # Arguments: - # - x (:obj:`tensor`): input tensor - # - mask (:obj:`tensor`): mask out invalid entries - # Returns: - # - attention (:obj:`tensor`): attention tensor - # """ - # assert (len(x.shape) == 3) - # B, N = x.shape[:2] - # x = self.attention_pre(x) - # query, key, value = torch.chunk(x, 3, dim=2) - # query, key, value = self.split(query), self.split(key, T=True), self.split(value) - # - # score = torch.matmul(query, key) # B, head_num, N, N - # score /= math.sqrt(self.head_dim) - # if mask is not None: - # score.masked_fill_(~mask, value=-1e9) - # - # score = F.softmax(score, dim=-1) - # score = self.dropout(score) - # attention = torch.matmul(score, value) # B, head_num, N, head_dim - # - # attention = attention.permute(0, 2, 1, 3).contiguous() # B, N, head_num, head_dim - # attention = self.project(attention.view(B, N, -1)) # B, N, output_dim - # return attention - - -class TransformerFFN(nn.Module): - """ - Implements the FFN part of the transformer. - """ - - def __init__( - self, - dim: int = None, - dim_hidden: int = None, - dropout: float = 0, - activation: str = 'relu', - **kwargs, - ): - super(TransformerFFN, self).__init__(**kwargs) - self.dim = dim - self.dim_hidden = dim_hidden - self.dropout_ratio = dropout - self.relu_dropout = nn.Dropout(p=self.dropout_ratio) - if activation == 'relu': - self.nonlinear = F.relu - elif activation == 'gelu': - self.nonlinear = F.gelu - else: - raise ValueError( - "Don't know how to handle --activation {}".format(activation) - ) - self.lin1 = nn.Linear(self.dim, self.dim_hidden) - self.lin2 = nn.Linear(self.dim_hidden, self.dim) - nn.init.xavier_uniform_(self.lin1.weight) - nn.init.xavier_uniform_(self.lin2.weight) - # TODO: initialize biases to 0 - - def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: - """ - Forward pass. 
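As an illustrative aside (not part of the patch): the masking in MultiHeadAttention above is the standard key-padding pattern, where a boolean mask of shape (batch, key_len) is broadcast over every head and query position and the masked scores are pushed to a large negative value before the softmax. A toy sketch of that pattern with made-up shapes:

    import torch
    import torch.nn.functional as F

    B, n_heads, q_len, k_len = 2, 4, 3, 5
    scores = torch.randn(B * n_heads, q_len, k_len)            # scaled q @ k^T
    mask = torch.tensor([[1, 1, 1, 0, 0],
                         [1, 1, 1, 1, 1]], dtype=torch.bool)   # (B, k_len), True = valid key

    attn_mask = (
        (~mask)                                  # True where the key must be ignored
        .view(B, 1, 1, k_len)
        .expand(B, n_heads, q_len, k_len)
        .reshape(B * n_heads, q_len, k_len)
    )
    scores = scores.masked_fill(attn_mask, torch.finfo(scores.dtype).min)
    attn = F.softmax(scores, dim=-1)
    print(attn[0, 0])                            # weights on the two padded keys are ~0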
- """ - x = self.nonlinear(self.lin1(x)) - x = self.relu_dropout(x) # --relu-dropout - x = self.lin2(x) - return x - - -class TransformerLayer(nn.Module): - r""" - Overview: - In transformer layer, first computes entries's attention and applies a feedforward layer - """ - - def __init__(self, - n_heads: int = None, - embedding_size: int = None, - ffn_size: int = None, - attention_dropout: float = 0.0, - relu_dropout: float = 0.0, - dropout: float = 0.0, - activation: str = 'relu', - variant: Optional[str] = None, - ): - r""" - Overview: - Init transformer layer - Arguments: - - input_dim (:obj:`int`): dimension of input - - head_dim (:obj:`int`): dimension of each head - - hidden_dim (:obj:`int`): dimension of hidden layer in mlp - - output_dim (:obj:`int`): dimension of output - - head_num (:obj:`int`): number of heads for multihead attention - - mlp_num (:obj:`int`): number of mlp layers - - dropout (:obj:`nn.Module`): dropout layer - - activation (:obj:`nn.Module`): activation function - """ - super(TransformerLayer, self).__init__() - self.n_heads = n_heads - self.dim = embedding_size - self.ffn_dim = ffn_size - self.activation = activation - self.variant = variant - self.attention = MultiHeadAttention( - n_heads=self.n_heads, - dim=embedding_size, - dropout=attention_dropout) - self.norm1 = torch.nn.LayerNorm(embedding_size, eps=LAYER_NORM_EPS) - self.ffn = TransformerFFN(dim=embedding_size, - dim_hidden=ffn_size, - dropout=relu_dropout, - activation=activation, - ) - self.norm2 = torch.nn.LayerNorm(embedding_size, eps=LAYER_NORM_EPS) - self.dropout = nn.Dropout(dropout) - - def forward(self, x, mask): - """ - Overview: - transformer layer forward - Arguments: - - inputs (:obj:`tuple`): x and mask - Returns: - - output (:obj:`tuple`): x and mask - """ - residual = x - - if self.variant == 'prenorm': - x = self.norm1(x) - attended_tensor = self.attention(x, mask=mask)[0] - x = residual + self.dropout(attended_tensor) - if self.variant == 'postnorm': - x = self.norm1(x) - - residual = x - if self.variant == 'prenorm': - x = self.norm2(x) - x = residual + self.dropout(self.ffn(x)) - if self.variant == 'postnorm': - x = self.norm2(x) - - x *= mask.unsqueeze(-1).type_as(x) - return x - - -class Transformer(nn.Module): - ''' - Overview: - Transformer implementation - - Note: - For details refer to Attention is all you need: http://arxiv.org/abs/1706.03762 - ''' - - def __init__( - self, - n_heads=8, - embedding_size: int = 128, - ffn_size: int = 128, - n_layers: int = 3, - attention_dropout: float = 0.0, - relu_dropout: float = 0.0, - dropout: float = 0.0, - activation: Optional[str] = 'relu', - variant: Optional[str] = 'prenorm', - ): - r""" - Overview: - Init transformer - Arguments: - - input_dim (:obj:`int`): dimension of input - - head_dim (:obj:`int`): dimension of each head - - hidden_dim (:obj:`int`): dimension of hidden layer in mlp - - output_dim (:obj:`int`): dimension of output - - head_num (:obj:`int`): number of heads for multihead attention - - mlp_num (:obj:`int`): number of mlp layers - - layer_num (:obj:`int`): number of transformer layers - - dropout_ratio (:obj:`float`): dropout ratio - - activation (:obj:`nn.Module`): activation function - """ - super(Transformer, self).__init__() - self.n_heads = n_heads - self.dim = embedding_size - self.ffn_size = ffn_size - self.n_layers = n_layers - - self.dropout_ratio = dropout - self.attention_dropout = attention_dropout - self.relu_dropout = relu_dropout - self.activation = activation - self.variant = variant - - # build the 
model - self.layers = self.build_layers() - self.norm_embedding = torch.nn.LayerNorm(self.dim, eps=LAYER_NORM_EPS) - - def build_layers(self) -> nn.ModuleList: - layers = nn.ModuleList() - for _ in range(self.n_layers): - layer = TransformerLayer( - n_heads=self.n_heads, - embedding_size=self.dim, - ffn_size=self.ffn_size, - attention_dropout=self.attention_dropout, - relu_dropout=self.relu_dropout, - dropout=self.dropout_ratio, - variant=self.variant, - activation=self.activation, - ) - layers.append(layer) - return layers - - def forward(self, x, mask=None): - r""" - Overview: - Transformer forward - Arguments: - - x (:obj:`tensor`): input tensor, shape (B, N, C), B is batch size, N is number of entries, - C is feature dimension - - mask (:obj:`tensor` or :obj:`None`): bool tensor, can be used to mask out invalid entries in attention, - shape (B, N), B is batch size, N is number of entries - Returns: - - x (:obj:`tensor`): transformer output - """ - if self.variant == 'postnorm': - x = self.norm_embedding(x) - if mask is not None: - x *= mask.unsqueeze(-1).type_as(x) - else: - mask = torch.ones(size=x.shape[:2],dtype=torch.bool, device=x.device) - if self.variant == 'postnorm': - x = self.norm_embedding(x) - for i in range(self.n_layers): - x = self.layers[i](x, mask) - if self.variant == 'prenorm': - x = self.norm_embedding(x) - return x - -if __name__ == '__main__': - transformer = Transformer(n_heads=8,embedding_size=128) - from bigrl.core.torch_utils.network.rnn import sequence_mask - mask = sequence_mask(lengths=torch.tensor([1,2,3,4,5,6,2,3,0,0]),max_len=20) - y = transformer.forward(x = torch.randn(size=(10,20,128)),mask=mask) - print(y) \ No newline at end of file From e36e7528c469c63378eea0db270b88685186b69f Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Mon, 19 Jun 2023 15:55:25 +0800 Subject: [PATCH 16/54] polish(yzj): polish gobigger entry evaluator --- lzero/entry/train_muzero_gobigger.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/lzero/entry/train_muzero_gobigger.py b/lzero/entry/train_muzero_gobigger.py index b8e33cadc..0a434c01f 100644 --- a/lzero/entry/train_muzero_gobigger.py +++ b/lzero/entry/train_muzero_gobigger.py @@ -138,9 +138,6 @@ def train_muzero_gobigger( policy_config.threshold_training_steps_for_final_temperature, trained_steps=learner.train_iter ) - stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) - stop, reward = vsbot_evaluator.eval_vsbot(learner.save_checkpoint, learner.train_iter, collector.envstep) - # Evaluate policy performance. 
if evaluator.should_eval(learner.train_iter): stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) From 7098899e178925747a9655b3ce1d811af2b45629 Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Mon, 19 Jun 2023 18:39:49 +0800 Subject: [PATCH 17/54] feature(yzj): add eps_greedy and random_collect_episode in gobigger ez --- lzero/entry/train_muzero_gobigger.py | 14 +- lzero/entry/utils.py | 37 ++ .../gobigger/network/gobigger_encoder.py | 10 +- lzero/policy/gobigger_efficientzero.py | 23 +- lzero/policy/gobigger_random_policy.py | 390 ++++++++++++++++++ lzero/worker/gobigger_muzero_collector.py | 3 +- .../config/gobigger_efficientzero_config.py | 15 +- 7 files changed, 477 insertions(+), 15 deletions(-) create mode 100644 lzero/policy/gobigger_random_policy.py diff --git a/lzero/entry/train_muzero_gobigger.py b/lzero/entry/train_muzero_gobigger.py index 0a434c01f..7d7a629ac 100644 --- a/lzero/entry/train_muzero_gobigger.py +++ b/lzero/entry/train_muzero_gobigger.py @@ -12,8 +12,8 @@ from ding.worker import BaseLearner from tensorboardX import SummaryWriter import copy - -from lzero.entry.utils import log_buffer_memory_usage +from ding.rl_utils import get_epsilon_greedy_fn +from lzero.entry.utils import log_buffer_memory_usage, random_collect from lzero.policy import visit_count_temperature from lzero.worker import GoBiggerMuZeroCollector, GoBiggerMuZeroEvaluator @@ -127,6 +127,11 @@ def train_muzero_gobigger( # ============================================================== # Learner's before_run hook. learner.call_hook('before_run') + if cfg.policy.random_collect_episode_num > 0: + random_collect(cfg.policy, policy, collector, collector_env, replay_buffer) + # reset the random_collect_episode_num to 0 + cfg.policy.random_collect_episode_num = 0 + while True: log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger) collect_kwargs = {} @@ -138,6 +143,11 @@ def train_muzero_gobigger( policy_config.threshold_training_steps_for_final_temperature, trained_steps=learner.train_iter ) + if policy_config.eps.eps_greedy_exploration_in_collect: + epsilon_greedy_fn = get_epsilon_greedy_fn(start=policy_config.eps.start, end=policy_config.eps.end, decay=policy_config.eps.decay, type_=policy_config.eps.type) + collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep) + else: + collect_kwargs['epsilon'] = 0.0 # Evaluate policy performance. 
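A short illustration (not part of the patch) of the epsilon schedule wired in above: `get_epsilon_greedy_fn` returns a callable mapping the collector's env step to the current epsilon, which is then handed to the collector through `collect_kwargs['epsilon']`. The schedule values below are placeholders, not the ones in this repo's config:

    from ding.rl_utils import get_epsilon_greedy_fn

    # Placeholder schedule: decay epsilon from 0.95 towards 0.05 over ~5e4 env steps.
    epsilon_greedy_fn = get_epsilon_greedy_fn(start=0.95, end=0.05, decay=int(5e4), type_='exp')
    for envstep in (0, 10_000, 50_000, 200_000):
        print(envstep, round(epsilon_greedy_fn(envstep), 3))
    # With type_='exp' epsilon decays exponentially towards `end`; with type_='linear'
    # it reaches `end` after `decay` steps and stays there.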
if evaluator.should_eval(learner.train_iter): stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) diff --git a/lzero/entry/utils.py b/lzero/entry/utils.py index 2da99f3fa..f37f1c1cb 100644 --- a/lzero/entry/utils.py +++ b/lzero/entry/utils.py @@ -1,9 +1,46 @@ import os +from typing import Optional, Callable import psutil +from easydict import EasyDict from pympler.asizeof import asizeof from tensorboardX import SummaryWriter +from lzero.policy.gobigger_random_policy import GoBiggerRandomPolicy + + +def random_collect( + policy_cfg: 'EasyDict', # noqa + policy: 'Policy', # noqa + collector: 'ISerialCollector', # noqa + collector_env: 'BaseEnvManager', # noqa + replay_buffer: 'IBuffer', # noqa + postprocess_data_fn: Optional[Callable] = None +) -> None: # noqa + assert policy_cfg.random_collect_episode_num > 0 + + random_policy = GoBiggerRandomPolicy(cfg=policy_cfg) + # set the policy to random policy + collector.reset_policy(random_policy.collect_mode) + + collect_kwargs = {} + # set temperature for visit count distributions according to the train_iter, + # please refer to Appendix D in MuZero paper for details. + collect_kwargs['temperature'] = 1 + + # Collect data by default config n_sample/n_episode. + new_data = collector.collect(train_iter=0, policy_kwargs=collect_kwargs) + + if postprocess_data_fn is not None: + new_data = postprocess_data_fn(new_data) + + # save returned new_data collected by the collector + replay_buffer.push_game_segments(new_data) + # remove the oldest data if the replay buffer is full. + replay_buffer.remove_oldest_data_to_fit() + + # restore the policy + collector.reset_policy(policy.collect_mode) def log_buffer_memory_usage(train_iter: int, buffer: "GameBuffer", writer: SummaryWriter) -> None: """ diff --git a/lzero/model/gobigger/network/gobigger_encoder.py b/lzero/model/gobigger/network/gobigger_encoder.py index a7916d8a4..ef72f2314 100644 --- a/lzero/model/gobigger/network/gobigger_encoder.py +++ b/lzero/model/gobigger/network/gobigger_encoder.py @@ -302,7 +302,7 @@ class GoBiggerEncoder(nn.Module): time=dict(arc='time', embedding_dim=8), last_action_type=dict(arc='one_hot', num_embeddings=27), ), - mlp=dict(input_dim=80, hidden_dim=64, layer_num=2, norm_type='BN', output_dim=32, activation=nn.ReLU(inplace=True)), + mlp=dict(input_dim=80, hidden_dim=64, layer_num=2, norm_type=None, output_dim=32, activation=nn.ReLU(inplace=True)), ), team_encoder=dict( modules=dict( @@ -312,7 +312,7 @@ class GoBiggerEncoder(nn.Module): ), mlp=dict(input_dim=16, hidden_dim=32, layer_num=2, norm_type=None, output_dim=16, activation=nn.ReLU(inplace=True)), transformer=dict(input_dim=16, output_dim=16, head_num=4, ffn_size=32, layer_num=2, embedding_dim=16, activation=nn.ReLU(inplace=True), variant='postnorm'), - fc_block=dict(input_dim=16, output_dim=16, activation=nn.ReLU(inplace=True), norm_type='BN'), + fc_block=dict(input_dim=16, output_dim=16, activation=nn.ReLU(inplace=True), norm_type=None), ), ball_encoder=dict( modules=dict( @@ -327,12 +327,12 @@ class GoBiggerEncoder(nn.Module): ), mlp=dict(input_dim=92, hidden_dim=128, layer_num=2, norm_type=None, output_dim=64, activation=nn.ReLU(inplace=True)), transformer=dict(input_dim=64, output_dim=64, head_num=4, ffn_size=64, layer_num=3, embedding_dim=64, activation=nn.ReLU(inplace=True), variant='postnorm'), - fc_block=dict(input_dim=64, output_dim=64, activation=nn.ReLU(inplace=True), norm_type='BN'), + fc_block=dict(input_dim=64, output_dim=64, activation=nn.ReLU(inplace=True), 
norm_type=None), ), spatial_encoder=dict( scatter=dict(input_dim=64, output_dim=16, scatter_type='add', activation=nn.ReLU(inplace=True), norm_type=None), - resnet=dict(project_dim=12, down_channels=[32, 32, 16 ], activation=nn.ReLU(inplace=True), norm_type='BN'), - fc_block=dict(output_dim=64, activation=nn.ReLU(inplace=True), norm_type='BN'), + resnet=dict(project_dim=12, down_channels=[32, 32, 16 ], activation=nn.ReLU(inplace=True), norm_type=None), + fc_block=dict(output_dim=64, activation=nn.ReLU(inplace=True), norm_type=None), ), ) diff --git a/lzero/policy/gobigger_efficientzero.py b/lzero/policy/gobigger_efficientzero.py index fda441635..bb5e2d0c3 100644 --- a/lzero/policy/gobigger_efficientzero.py +++ b/lzero/policy/gobigger_efficientzero.py @@ -481,6 +481,7 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: return { 'collect_mcts_temperature': self.collect_mcts_temperature, + 'collect_epsilon': self.collect_epsilon, 'cur_lr': self._optimizer.param_groups[0]['lr'], 'weighted_total_loss': loss_info[0], 'total_loss': loss_info[1], @@ -516,6 +517,7 @@ def _init_collect(self) -> None: else: self._mcts_collect = MCTSPtree(self._cfg) self.collect_mcts_temperature = 1 + self.collect_epsilon = 1 def _forward_collect( self, @@ -523,6 +525,7 @@ def _forward_collect( action_mask: list = None, temperature: float = 1, to_play: List = [-1], + epsilon: float = 0.25, ready_env_id=None ): """ @@ -550,6 +553,7 @@ def _forward_collect( """ self._collect_model.eval() self.collect_mcts_temperature = temperature + self.collect_epsilon = epsilon active_collect_env_num = len(data) data = to_tensor(data) @@ -601,13 +605,21 @@ def _forward_collect( output = {i: defaultdict(list) for i in data_id} if ready_env_id is None: ready_env_id = np.arange(active_collect_env_num) - + for i in range(batch_size): distributions, value = roots_visit_count_distributions[i], roots_values[i] - action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( - distributions, temperature=self.collect_mcts_temperature, deterministic=False - ) - action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + if self._cfg.eps.eps_greedy_exploration_in_collect: + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=self.collect_mcts_temperature, deterministic=True + ) + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + if np.random.rand() < self.collect_epsilon: + action = np.random.choice(legal_actions[i]) + else: + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=self.collect_mcts_temperature, deterministic=False + ) + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] output[i//agent_num]['action'].append(action) output[i//agent_num]['distributions'].append(distributions) output[i//agent_num]['visit_count_distribution_entropy'].append(visit_count_distribution_entropy) @@ -724,6 +736,7 @@ def _monitor_vars_learn(self) -> List[str]: """ return [ 'collect_mcts_temperature', + 'collect_epsilon', 'cur_lr', 'weighted_total_loss', 'total_loss', diff --git a/lzero/policy/gobigger_random_policy.py b/lzero/policy/gobigger_random_policy.py new file mode 100644 index 000000000..13cfa69a0 --- /dev/null +++ b/lzero/policy/gobigger_random_policy.py @@ -0,0 +1,390 @@ +import copy +from typing import List, Dict, Any, Tuple, Union + +import numpy as np +import torch +import torch.optim as optim +from 
ding.model import model_wrap +from ding.policy.base_policy import Policy +from ding.torch_utils import to_tensor +from ding.utils import POLICY_REGISTRY +from torch.distributions import Categorical +from torch.nn import L1Loss + +from lzero.mcts import EfficientZeroMCTSCtree as MCTSCtree +from lzero.mcts import EfficientZeroMCTSPtree as MCTSPtree +from lzero.model import ImageTransforms +from lzero.policy import scalar_transform, InverseScalarTransform, cross_entropy_loss, phi_transform, \ + DiscreteSupport, select_action, to_torch_float_tensor, ez_network_output_unpack, negative_cosine_similarity, prepare_obs, \ + configure_optimizers +from collections import defaultdict +from ding.torch_utils import to_device + + +@POLICY_REGISTRY.register('gobigger_random_policy') +class GoBiggerRandomPolicy(Policy): + """ + Overview: + The policy class for EfficientZero. + """ + + # The default_config for EfficientZero policy. + config = dict( + model=dict( + # (str) The model type. For 1-dimensional vector obs, we use mlp model. For 3-dimensional image obs, we use conv model. + model_type='conv', # options={'mlp', 'conv'} + # (bool) If True, the action space of the environment is continuous, otherwise discrete. + continuous_action_space=False, + # (tuple) The stacked obs shape. + # observation_shape=(1, 96, 96), # if frame_stack_num=1 + observation_shape=(4, 96, 96), # if frame_stack_num=4 + # (bool) Whether to use the self-supervised learning loss. + self_supervised_learning_loss=True, + # (bool) Whether to use discrete support to represent categorical distribution for value/reward/value_prefix. + categorical_distribution=True, + # (int) The image channel in image observation. + image_channel=1, + # (int) The number of frames to stack together. + frame_stack_num=1, + # (int) The scale of supports used in categorical distribution. + # This variable is only effective when ``categorical_distribution=True``. + support_scale=300, + # (int) The hidden size in LSTM. + lstm_hidden_size=512, + # (bool) whether to learn bias in the last linear layer in value and policy head. + bias=True, + # (str) The type of action encoding. Options are ['one_hot', 'not_one_hot']. Default to 'one_hot'. + discrete_action_encoding_type='one_hot', + # (bool) whether to use res connection in dynamics. + res_connection_in_dynamics=True, + # (str) The type of normalization in MuZero model. Options are ['BN', 'LN']. Default to 'LN'. + norm_type='BN', + ), + # ****** common ****** + # (bool) Whether to enable the sampled-based algorithm (e.g. Sampled EfficientZero) + # this variable is used in ``collector``. + sampled_algo=False, + # (bool) Whether to use C++ MCTS in policy. If False, use Python implementation. + mcts_ctree=True, + # (bool) Whether to use cuda for network. + cuda=True, + # (int) The number of environments used in collecting data. + collector_env_num=8, + # (int) The number of environments used in evaluating policy. + evaluator_env_num=3, + # (str) The type of environment. The options are ['not_board_games', 'board_games']. + env_type='not_board_games', + # (str) The type of battle mode. The options are ['play_with_bot_mode', 'self_play_mode']. + battle_mode='play_with_bot_mode', + # (bool) Whether to monitor extra statistics in tensorboard. + monitor_extra_statistics=True, + # (int) The transition number of one ``GameSegment``. + game_segment_length=200, + + # ****** observation ****** + # (bool) Whether to transform image to string to save memory. 
+ transform2string=False, + # (bool) Whether to use data augmentation. + use_augmentation=False, + # (list) The style of augmentation. + augmentation=['shift', 'intensity'], + + # ******* learn ****** + # (int) How many updates(iterations) to train after collector's one collection. + # Bigger "update_per_collect" means bigger off-policy. + # collect data -> update policy-> collect data -> ... + # For different env, we have different episode_length, + # we usually set update_per_collect = collector_env_num * episode_length / batch_size * reuse_factor + # if we set update_per_collect=None, we will set update_per_collect = collected_transitions_num * cfg.policy.model_update_ratio automatically. + update_per_collect=None, + # (float) The ratio of the collected data used for training. Only effective when ``update_per_collect`` is not None. + model_update_ratio=0.1, + # (int) Minibatch size for one gradient descent. + batch_size=256, + # (str) Optimizer for training policy network. ['SGD', 'Adam', 'AdamW'] + optim_type='SGD', + # (float) Learning rate for training policy network. Initial lr for manually decay schedule. + learning_rate=0.2, + # (int) Frequency of target network update. + target_update_freq=100, + # (float) Weight decay for training policy network. + weight_decay=1e-4, + # (float) One-order Momentum in optimizer, which stabilizes the training process (gradient direction). + momentum=0.9, + # (float) The maximum constraint value of gradient norm clipping. + grad_clip_value=10, + # (int) The number of episodes in each collecting stage. + n_episode=8, + # (float) the number of simulations in MCTS. + num_simulations=50, + # (float) Discount factor (gamma) for returns. + discount_factor=0.997, + # (int) The number of steps for calculating target q_value. + td_steps=5, + # (int) The number of unroll steps in dynamics network. + num_unroll_steps=5, + # (int) reset the hidden states in LSTM every ``lstm_horizon_len`` horizon steps. + lstm_horizon_len=5, + # (float) The weight of reward loss. + reward_loss_weight=1, + # (float) The weight of value loss. + value_loss_weight=0.25, + # (float) The weight of policy loss. + policy_loss_weight=1, + # (float) The weight of ssl (self-supervised learning) loss. + ssl_loss_weight=2, + # (bool) Whether to use piecewise constant learning rate decay. + # i.e. lr: 0.2 -> 0.02 -> 0.002 + lr_piecewise_constant_decay=True, + # (int) The number of final training iterations to control lr decay, which is only used for manually decay. + threshold_training_steps_for_final_lr=int(5e4), + # (int) The number of final training iterations to control temperature, which is only used for manually decay. + threshold_training_steps_for_final_temperature=int(1e5), + # (bool) Whether to use manually decayed temperature. + # i.e. temperature: 1 -> 0.5 -> 0.25 + manual_temperature_decay=False, + # (float) The fixed temperature value for MCTS action selection, which is used to control the exploration. + # The larger the value, the more exploration. This value is only used when manual_temperature_decay=False. + fixed_temperature_value=0.25, + + # ****** Priority ****** + # (bool) Whether to use priority when sampling training data from the buffer. + use_priority=True, + # (bool) Whether to use the maximum priority for new collecting data. + use_max_priority_for_new_data=True, + # (float) The degree of prioritization to use. A value of 0 means no prioritization, + # while a value of 1 means full prioritization. 
+ priority_prob_alpha=0.6, + # (float) The degree of correction to use. A value of 0 means no correction, + # while a value of 1 means full correction. + priority_prob_beta=0.4, + + # ****** UCB ****** + # (float) The alpha value used in the Dirichlet distribution for exploration at the root node of the search tree. + root_dirichlet_alpha=0.3, + # (float) The noise weight at the root node of the search tree. + root_noise_weight=0.25, + ) + + def default_model(self) -> Tuple[str, List[str]]: + """ + Overview: + Return this algorithm default model setting. + Returns: + - model_info (:obj:`Tuple[str, List[str]]`): model name and model import_names. + - model_type (:obj:`str`): The model type used in this algorithm, which is registered in ModelRegistry. + - import_names (:obj:`List[str]`): The model class path list used in this algorithm. + .. note:: + The user can define and use customized network model but must obey the same interface definition indicated \ + by import_names path. For EfficientZero, ``lzero.model.efficientzero_model.EfficientZeroModel`` + """ + return 'GoBiggerEfficientZeroModel', ['lzero.model.gobigger.gobigger_efficientzero_model'] + + def _init_learn(self) -> None: + """ + Overview: + Learn mode init method. Called by ``self.__init__``. Initialize the learn model, optimizer and MCTS utils. + """ + pass + + def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: + """ + Overview: + The forward function for learning policy in learn mode, which is the core of the learning process. + The data is sampled from replay buffer. + The loss is calculated by the loss function and the loss is backpropagated to update the model. + Arguments: + - data (:obj:`Tuple[torch.Tensor]`): The data sampled from replay buffer, which is a tuple of tensors. + The first tensor is the current_batch, the second tensor is the target_batch. + Returns: + - info_dict (:obj:`Dict[str, Union[float, int]]`): The information dict to be logged, which contains \ + current learning loss and learning statistics. + """ + pass + + def _init_collect(self) -> None: + """ + Overview: + Collect mode init method. Called by ``self.__init__``. Initialize the collect model and MCTS utils. + """ + self._collect_model = self._model + if self._cfg.mcts_ctree: + self._mcts_collect = MCTSCtree(self._cfg) + else: + self._mcts_collect = MCTSPtree(self._cfg) + self.collect_mcts_temperature = 1.0 + self.inverse_scalar_transform_handle = InverseScalarTransform( + self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution + ) + + def _forward_collect( + self, + data: torch.Tensor, + action_mask: list = None, + temperature: float = 1, + to_play: List = [-1], + ready_env_id=None + ): + """ + Overview: + The forward function for collecting data in collect mode. Use model to execute MCTS search. + Choosing the action through sampling during the collect mode. + Arguments: + - data (:obj:`torch.Tensor`): The input data, i.e. the observation. + - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. + - temperature (:obj:`float`): The temperature of the policy. + - to_play (:obj:`int`): The player to play. + - ready_env_id (:obj:`list`): The id of the env that is ready to collect. + Shape: + - data (:obj:`torch.Tensor`): + - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ + S is the number of stacked frames, H is the height of the image, W is the width of the image. 
+ - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size. + - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env. + - temperature: :math:`(1, )`. + - to_play: :math:`(N, 1)`, where N is the number of collect_env. + - ready_env_id: None + Returns: + - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ + ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. + """ + self._collect_model.eval() + self.collect_mcts_temperature = temperature + + active_collect_env_num = len(data) + data = to_tensor(data) + data = sum(sum(data, []), []) + batch_size = len(data) + data = to_device(data, self._cfg.device) + agent_num = batch_size // active_collect_env_num + to_play = np.array(to_play).reshape(-1).tolist() + + with torch.no_grad(): + # data shape [B, S x C, W, H], e.g. {Tensor:(B, 12, 96, 96)} + network_output = self._collect_model.initial_inference(data) + latent_state_roots, value_prefix_roots, reward_hidden_state_roots, pred_values, policy_logits = ez_network_output_unpack( + network_output + ) + + if not self._learn_model.training: + # if not in training, obtain the scalars of the value/reward + pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() + latent_state_roots = latent_state_roots.detach().cpu().numpy() + reward_hidden_state_roots = ( + reward_hidden_state_roots[0].detach().cpu().numpy(), + reward_hidden_state_roots[1].detach().cpu().numpy() + ) + policy_logits = policy_logits.detach().cpu().numpy().tolist() + + action_mask = sum(action_mask, []) + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(batch_size)] + # the only difference between collect and eval is the dirichlet noise. 
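As an aside on the Dirichlet-noise comment above: the root-noise trick is to blend the normalized policy prior with a Dirichlet sample before the search starts. A minimal NumPy sketch, assuming the config defaults root_dirichlet_alpha=0.3 and root_noise_weight=0.25; the function and variable names here are illustrative, not part of the library API:

import numpy as np

def mix_root_noise(prior, alpha=0.3, noise_weight=0.25):
    # One Dirichlet sample over the legal actions, blended with the prior;
    # this is roughly what ``roots.prepare`` does with the noises built below.
    noise = np.random.dirichlet([alpha] * len(prior))
    return (1 - noise_weight) * prior + noise_weight * noise

prior = np.array([0.7, 0.1, 0.1, 0.1])
print(mix_root_noise(prior))  # still sums to 1; some mass is moved to random legal actions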
+ noises = [ + np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j])) + ).astype(np.float32).tolist() for j in range(batch_size) + ] + if self._cfg.mcts_ctree: + # cpp mcts_tree + roots = MCTSCtree.roots(batch_size, legal_actions) + else: + # python mcts_tree + roots = MCTSPtree.roots(batch_size, legal_actions) + roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_roots, policy_logits, to_play) + self._mcts_collect.search( + roots, self._collect_model, latent_state_roots, reward_hidden_state_roots, to_play + ) + + roots_visit_count_distributions = roots.get_distributions() # shape: ``{list: batch_size} ->{list: action_space_size}`` + roots_values = roots.get_values() # shape: {list: batch_size} + + data_id = [i for i in range(active_collect_env_num)] + output = {i: defaultdict(list) for i in data_id} + if ready_env_id is None: + ready_env_id = np.arange(active_collect_env_num) + + for i in range(batch_size): + distributions, value = roots_visit_count_distributions[i], roots_values[i] + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=self.collect_mcts_temperature, deterministic=False + ) + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + # ************* random action ************* + action = int(np.random.choice(legal_actions[i], 1)) + output[i//agent_num]['action'].append(action) + output[i//agent_num]['distributions'].append(distributions) + output[i//agent_num]['visit_count_distribution_entropy'].append(visit_count_distribution_entropy) + output[i//agent_num]['value'].append(value) + output[i//agent_num]['pred_value'].append(pred_values[i]) + output[i//agent_num]['policy_logits'].append(policy_logits[i]) + + + return output + + def _init_eval(self) -> None: + """ + Overview: + Evaluate mode init method. Called by ``self.__init__``. Initialize the eval model and MCTS utils. + """ + self._eval_model = self._model + if self._cfg.mcts_ctree: + self._mcts_eval = MCTSCtree(self._cfg) + else: + self._mcts_eval = MCTSPtree(self._cfg) + + def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, ready_env_id=None): + """ + Overview: + The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search. + Choosing the action with the highest value (argmax) rather than sampling during the eval mode. + Arguments: + - data (:obj:`torch.Tensor`): The input data, i.e. the observation. + - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. + - to_play (:obj:`int`): The player to play. + - ready_env_id (:obj:`list`): The id of the env that is ready to collect. + Shape: + - data (:obj:`torch.Tensor`): + - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ + S is the number of stacked frames, H is the height of the image, W is the width of the image. + - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size. + - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env. + - to_play: :math:`(N, 1)`, where N is the number of collect_env. + - ready_env_id: None + Returns: + - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ + ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. 
+ """ + pass + + def _monitor_vars_learn(self) -> List[str]: + """ + Overview: + Register the variables to be monitored in learn mode. The registered variables will be logged in + tensorboard according to the return value ``_forward_learn``. + """ + pass + + def _state_dict_learn(self) -> Dict[str, Any]: + """ + Overview: + Return the state_dict of learn mode, usually including model and optimizer. + Returns: + - state_dict (:obj:`Dict[str, Any]`): the dict of current policy learn state, for saving and restoring. + """ + pass + + def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: + """ + Overview: + Load the state_dict variable into policy learn mode. + Arguments: + - state_dict (:obj:`Dict[str, Any]`): the dict of policy learn state saved before. + """ + pass + + def _process_transition(self, obs, policy_output, timestep): + # be compatible with DI-engine Policy class + pass + + def _get_train_sample(self, data): + # be compatible with DI-engine Policy class + pass diff --git a/lzero/worker/gobigger_muzero_collector.py b/lzero/worker/gobigger_muzero_collector.py index a715cce5b..d18329c14 100644 --- a/lzero/worker/gobigger_muzero_collector.py +++ b/lzero/worker/gobigger_muzero_collector.py @@ -287,6 +287,7 @@ def collect(self, if policy_kwargs is None: policy_kwargs = {} temperature = policy_kwargs['temperature'] + epsilon = policy_kwargs['epsilon'] collected_episode = 0 env_nums = self._env_num @@ -376,7 +377,7 @@ def collect(self, # ============================================================== # policy forward # ============================================================== - policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play) + policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon) actions_no_env_id=defaultdict(dict) for k,v in policy_output.items(): for agent_id, act in enumerate(v['action']): diff --git a/zoo/gobigger/config/gobigger_efficientzero_config.py b/zoo/gobigger/config/gobigger_efficientzero_config.py index 11ffca0e3..fc51723a7 100644 --- a/zoo/gobigger/config/gobigger_efficientzero_config.py +++ b/zoo/gobigger/config/gobigger_efficientzero_config.py @@ -5,6 +5,7 @@ # ============================================================== # begin of the most frequently changed config specified by the user # ============================================================== +seed=0 collector_env_num = 32 n_episode = 32 evaluator_env_num = 5 @@ -14,13 +15,15 @@ reanalyze_ratio = 0. 
action_space_size = 27 direction_num=12 +eps_greedy_exploration_in_collect=True +random_collect_episode_num=0 # ============================================================== # end of the most frequently changed config specified by the user # ============================================================== atari_efficientzero_config = dict( exp_name= - f'data_ez_ctree/{env_name}_efficientzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', + f'data_ez_ctree/{env_name}_efficientzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed{seed}', env=dict( env_name=env_name, team_num=2, @@ -80,6 +83,14 @@ mcts_ctree=True, env_type='not_board_games', game_segment_length=400, + random_collect_episode_num=random_collect_episode_num, + eps=dict( + eps_greedy_exploration_in_collect=eps_greedy_exploration_in_collect, + type='exp', + start=1., + end=0.05, + decay=int(1e5), + ), use_augmentation=False, update_per_collect=update_per_collect, batch_size=batch_size, @@ -137,4 +148,4 @@ if __name__ == "__main__": from lzero.entry import train_muzero_gobigger - train_muzero_gobigger([main_config, create_config], seed=0) + train_muzero_gobigger([main_config, create_config], seed=seed) From b94deaed5623a09ad6ca29a3511654c4c1ca6147 Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Tue, 20 Jun 2023 14:37:59 +0800 Subject: [PATCH 18/54] fix(yzj): fix key bug in entry utils when random collect --- lzero/entry/utils.py | 1 + lzero/policy/gobigger_random_policy.py | 17 ++++++++--------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/lzero/entry/utils.py b/lzero/entry/utils.py index f37f1c1cb..3e4cc180e 100644 --- a/lzero/entry/utils.py +++ b/lzero/entry/utils.py @@ -27,6 +27,7 @@ def random_collect( # set temperature for visit count distributions according to the train_iter, # please refer to Appendix D in MuZero paper for details. collect_kwargs['temperature'] = 1 + collect_kwargs['epsilon'] = 0.0 # Collect data by default config n_sample/n_episode. 
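For reference, the eps block configured above (type='exp', start=1., end=0.05) describes an exponential schedule over collected env steps, while the random-collect phase simply pins epsilon to 0. A small sketch of the assumed form of DI-engine's ``get_epsilon_greedy_fn`` for ``type_='exp'``; the exact formula is an assumption, so check ``ding.rl_utils`` for the authoritative version:

import math

def exp_epsilon(envstep, start=1.0, end=0.05, decay=int(1e5)):
    # Decays from ``start`` toward ``end`` with a time constant of ``decay`` env steps.
    return end + (start - end) * math.exp(-envstep / decay)

for step in (0, 10_000, 100_000, 500_000):
    print(step, round(exp_epsilon(step), 3))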
new_data = collector.collect(train_iter=0, policy_kwargs=collect_kwargs) diff --git a/lzero/policy/gobigger_random_policy.py b/lzero/policy/gobigger_random_policy.py index 13cfa69a0..1138d5d88 100644 --- a/lzero/policy/gobigger_random_policy.py +++ b/lzero/policy/gobigger_random_policy.py @@ -266,15 +266,14 @@ def _forward_collect( network_output ) - if not self._learn_model.training: - # if not in training, obtain the scalars of the value/reward - pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() - latent_state_roots = latent_state_roots.detach().cpu().numpy() - reward_hidden_state_roots = ( - reward_hidden_state_roots[0].detach().cpu().numpy(), - reward_hidden_state_roots[1].detach().cpu().numpy() - ) - policy_logits = policy_logits.detach().cpu().numpy().tolist() + # if not in training, obtain the scalars of the value/reward + pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() + latent_state_roots = latent_state_roots.detach().cpu().numpy() + reward_hidden_state_roots = ( + reward_hidden_state_roots[0].detach().cpu().numpy(), + reward_hidden_state_roots[1].detach().cpu().numpy() + ) + policy_logits = policy_logits.detach().cpu().numpy().tolist() action_mask = sum(action_mask, []) legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(batch_size)] From dfa4671464b59c28e8fb96070a863537305ac514 Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Sun, 25 Jun 2023 13:04:25 +0800 Subject: [PATCH 19/54] fix(yzj): fix gobigger encoder bn bug --- lzero/model/gobigger/network/gobigger_encoder.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/lzero/model/gobigger/network/gobigger_encoder.py b/lzero/model/gobigger/network/gobigger_encoder.py index ef72f2314..a7916d8a4 100644 --- a/lzero/model/gobigger/network/gobigger_encoder.py +++ b/lzero/model/gobigger/network/gobigger_encoder.py @@ -302,7 +302,7 @@ class GoBiggerEncoder(nn.Module): time=dict(arc='time', embedding_dim=8), last_action_type=dict(arc='one_hot', num_embeddings=27), ), - mlp=dict(input_dim=80, hidden_dim=64, layer_num=2, norm_type=None, output_dim=32, activation=nn.ReLU(inplace=True)), + mlp=dict(input_dim=80, hidden_dim=64, layer_num=2, norm_type='BN', output_dim=32, activation=nn.ReLU(inplace=True)), ), team_encoder=dict( modules=dict( @@ -312,7 +312,7 @@ class GoBiggerEncoder(nn.Module): ), mlp=dict(input_dim=16, hidden_dim=32, layer_num=2, norm_type=None, output_dim=16, activation=nn.ReLU(inplace=True)), transformer=dict(input_dim=16, output_dim=16, head_num=4, ffn_size=32, layer_num=2, embedding_dim=16, activation=nn.ReLU(inplace=True), variant='postnorm'), - fc_block=dict(input_dim=16, output_dim=16, activation=nn.ReLU(inplace=True), norm_type=None), + fc_block=dict(input_dim=16, output_dim=16, activation=nn.ReLU(inplace=True), norm_type='BN'), ), ball_encoder=dict( modules=dict( @@ -327,12 +327,12 @@ class GoBiggerEncoder(nn.Module): ), mlp=dict(input_dim=92, hidden_dim=128, layer_num=2, norm_type=None, output_dim=64, activation=nn.ReLU(inplace=True)), transformer=dict(input_dim=64, output_dim=64, head_num=4, ffn_size=64, layer_num=3, embedding_dim=64, activation=nn.ReLU(inplace=True), variant='postnorm'), - fc_block=dict(input_dim=64, output_dim=64, activation=nn.ReLU(inplace=True), norm_type=None), + fc_block=dict(input_dim=64, output_dim=64, activation=nn.ReLU(inplace=True), norm_type='BN'), ), spatial_encoder=dict( scatter=dict(input_dim=64, output_dim=16, scatter_type='add', 
activation=nn.ReLU(inplace=True), norm_type=None), - resnet=dict(project_dim=12, down_channels=[32, 32, 16 ], activation=nn.ReLU(inplace=True), norm_type=None), - fc_block=dict(output_dim=64, activation=nn.ReLU(inplace=True), norm_type=None), + resnet=dict(project_dim=12, down_channels=[32, 32, 16 ], activation=nn.ReLU(inplace=True), norm_type='BN'), + fc_block=dict(output_dim=64, activation=nn.ReLU(inplace=True), norm_type='BN'), ), ) From ff1182160c6cc0fc1b8dfdc252e49847301930c2 Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Sun, 25 Jun 2023 13:07:10 +0800 Subject: [PATCH 20/54] polish(yzj): polish ez config and set eps as 1.5e4 learner iter --- zoo/gobigger/config/gobigger_efficientzero_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zoo/gobigger/config/gobigger_efficientzero_config.py b/zoo/gobigger/config/gobigger_efficientzero_config.py index fc51723a7..56ef9dc04 100644 --- a/zoo/gobigger/config/gobigger_efficientzero_config.py +++ b/zoo/gobigger/config/gobigger_efficientzero_config.py @@ -89,7 +89,7 @@ type='exp', start=1., end=0.05, - decay=int(1e5), + decay=int(1.5e4), ), use_augmentation=False, update_per_collect=update_per_collect, From a95c19c7f505a04978ed7e5d2483f2eeea655dd5 Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Sun, 25 Jun 2023 13:24:44 +0800 Subject: [PATCH 21/54] polish(yzj): polish code style by format.sh --- lzero/entry/eval_muzero_gobigger.py | 3 +- lzero/entry/train_muzero_gobigger.py | 8 +- lzero/entry/utils.py | 1 + .../gobigger/gobigger_efficientzero_model.py | 7 +- .../gobigger_sampled_efficientzero_model.py | 3 +- lzero/model/gobigger/network/encoder.py | 35 ++- .../gobigger/network/gobigger_encoder.py | 271 +++++++++++------- .../gobigger/network/scatter_connection.py | 10 +- lzero/policy/gobigger_efficientzero.py | 30 +- lzero/policy/gobigger_muzero.py | 33 +-- lzero/policy/gobigger_random_policy.py | 21 +- .../policy/gobigger_sampled_efficientzero.py | 48 ++-- lzero/worker/gobigger_muzero_collector.py | 52 ++-- lzero/worker/gobigger_muzero_evaluator.py | 68 +++-- .../config/gobigger_efficientzero_config.py | 68 ++--- zoo/gobigger/config/gobigger_eval_config.py | 10 +- zoo/gobigger/config/gobigger_muzero_config.py | 61 ++-- .../gobigger_sampled_efficientzero_config.py | 60 ++-- zoo/gobigger/env/gobigger_env.py | 150 ++++++---- zoo/gobigger/env/gobigger_rule_bot.py | 33 ++- zoo/gobigger/env/test_gobbiger_env.py | 63 ++-- 21 files changed, 575 insertions(+), 460 deletions(-) diff --git a/lzero/entry/eval_muzero_gobigger.py b/lzero/entry/eval_muzero_gobigger.py index 73b7abb41..dee5a98b5 100644 --- a/lzero/entry/eval_muzero_gobigger.py +++ b/lzero/entry/eval_muzero_gobigger.py @@ -15,6 +15,7 @@ from ding.worker import BaseLearner from lzero.worker import GoBiggerMuZeroEvaluator + def eval_muzero_gobigger( input_cfg: Tuple[dict, dict], seed: int = 0, @@ -111,5 +112,5 @@ def eval_muzero_gobigger( # eval trained model # ============================================================== _, reward_sp = evaluator.eval(learner.save_checkpoint, learner.train_iter) - _, reward_vsbot= vsbot_evaluator.eval_vsbot(learner.save_checkpoint, learner.train_iter) + _, reward_vsbot = vsbot_evaluator.eval_vsbot(learner.save_checkpoint, learner.train_iter) return reward_sp, reward_vsbot diff --git a/lzero/entry/train_muzero_gobigger.py b/lzero/entry/train_muzero_gobigger.py index 7d7a629ac..2abdaf2f5 100644 --- a/lzero/entry/train_muzero_gobigger.py +++ b/lzero/entry/train_muzero_gobigger.py @@ -17,6 +17,7 @@ from lzero.policy import 
visit_count_temperature from lzero.worker import GoBiggerMuZeroCollector, GoBiggerMuZeroEvaluator + def train_muzero_gobigger( input_cfg: Tuple[dict, dict], seed: int = 0, @@ -144,7 +145,12 @@ def train_muzero_gobigger( trained_steps=learner.train_iter ) if policy_config.eps.eps_greedy_exploration_in_collect: - epsilon_greedy_fn = get_epsilon_greedy_fn(start=policy_config.eps.start, end=policy_config.eps.end, decay=policy_config.eps.decay, type_=policy_config.eps.type) + epsilon_greedy_fn = get_epsilon_greedy_fn( + start=policy_config.eps.start, + end=policy_config.eps.end, + decay=policy_config.eps.decay, + type_=policy_config.eps.type + ) collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep) else: collect_kwargs['epsilon'] = 0.0 diff --git a/lzero/entry/utils.py b/lzero/entry/utils.py index 3e4cc180e..97c838fdc 100644 --- a/lzero/entry/utils.py +++ b/lzero/entry/utils.py @@ -43,6 +43,7 @@ def random_collect( # restore the policy collector.reset_policy(policy.collect_mode) + def log_buffer_memory_usage(train_iter: int, buffer: "GameBuffer", writer: SummaryWriter) -> None: """ Overview: diff --git a/lzero/model/gobigger/gobigger_efficientzero_model.py b/lzero/model/gobigger/gobigger_efficientzero_model.py index 5844729bf..11b4125cc 100644 --- a/lzero/model/gobigger/gobigger_efficientzero_model.py +++ b/lzero/model/gobigger/gobigger_efficientzero_model.py @@ -13,6 +13,7 @@ from easydict import EasyDict from ding.utils.data import default_collate + @MODEL_REGISTRY.register('GoBiggerEfficientZeroModel') class GoBiggerEfficientZeroModel(nn.Module): @@ -179,8 +180,10 @@ def initial_inference(self, obs: torch.Tensor) -> EZNetworkOutput: policy_logits, value = self._prediction(latent_state) # zero initialization for reward hidden states # (hn, cn), each element shape is (layer_num=1, batch_size, lstm_hidden_size) - reward_hidden_state = (torch.zeros(1, batch_size, self.lstm_hidden_size).to(device), - torch.zeros(1, batch_size, self.lstm_hidden_size).to(device),) + reward_hidden_state = ( + torch.zeros(1, batch_size, self.lstm_hidden_size).to(device), + torch.zeros(1, batch_size, self.lstm_hidden_size).to(device), + ) return EZNetworkOutput(value, [0. for _ in range(batch_size)], policy_logits, latent_state, reward_hidden_state) def recurrent_inference( diff --git a/lzero/model/gobigger/gobigger_sampled_efficientzero_model.py b/lzero/model/gobigger/gobigger_sampled_efficientzero_model.py index aa7b79504..38990464b 100644 --- a/lzero/model/gobigger/gobigger_sampled_efficientzero_model.py +++ b/lzero/model/gobigger/gobigger_sampled_efficientzero_model.py @@ -216,8 +216,7 @@ def initial_inference(self, obs: torch.Tensor) -> EZNetworkOutput: # (hn, cn), each element shape is (layer_num=1, batch_size, lstm_hidden_size) reward_hidden_state = ( torch.zeros(1, batch_size, - self.lstm_hidden_size).to(device), torch.zeros(1, batch_size, - self.lstm_hidden_size).to(device) + self.lstm_hidden_size).to(device), torch.zeros(1, batch_size, self.lstm_hidden_size).to(device) ) return EZNetworkOutput(value, [0. 
for _ in range(batch_size)], policy_logits, latent_state, reward_hidden_state) diff --git a/lzero/model/gobigger/network/encoder.py b/lzero/model/gobigger/network/encoder.py index daa014ec5..adcbc2d89 100644 --- a/lzero/model/gobigger/network/encoder.py +++ b/lzero/model/gobigger/network/encoder.py @@ -4,11 +4,11 @@ class OnehotEncoder(nn.Module): + def __init__(self, num_embeddings: int): super(OnehotEncoder, self).__init__() self.num_embeddings = num_embeddings - self.main = nn.Embedding.from_pretrained(torch.eye(self.num_embeddings), freeze=True, - padding_idx=None) + self.main = nn.Embedding.from_pretrained(torch.eye(self.num_embeddings), freeze=True, padding_idx=None) def forward(self, x: torch.Tensor): x = x.long().clamp_(max=self.num_embeddings - 1) @@ -16,6 +16,7 @@ def forward(self, x: torch.Tensor): class OnehotEmbedding(nn.Module): + def __init__(self, num_embeddings: int, embedding_dim: int): super(OnehotEmbedding, self).__init__() self.num_embeddings = num_embeddings @@ -28,11 +29,13 @@ def forward(self, x: torch.Tensor): class BinaryEncoder(nn.Module): + def __init__(self, num_embeddings: int): super(BinaryEncoder, self).__init__() self.bit_num = num_embeddings - self.main = nn.Embedding.from_pretrained(self.get_binary_embed_matrix(self.bit_num), freeze=True, - padding_idx=None) + self.main = nn.Embedding.from_pretrained( + self.get_binary_embed_matrix(self.bit_num), freeze=True, padding_idx=None + ) @staticmethod def get_binary_embed_matrix(bit_num): @@ -48,11 +51,13 @@ def forward(self, x: torch.Tensor): class SignBinaryEncoder(nn.Module): + def __init__(self, num_embeddings): super(SignBinaryEncoder, self).__init__() self.bit_num = num_embeddings - self.main = nn.Embedding.from_pretrained(self.get_sign_binary_matrix(self.bit_num), freeze=True, - padding_idx=None) + self.main = nn.Embedding.from_pretrained( + self.get_sign_binary_matrix(self.bit_num), freeze=True, padding_idx=None + ) self.max_val = 2 ** (self.bit_num - 1) - 1 @staticmethod @@ -67,27 +72,31 @@ def get_sign_binary_matrix(bit_num): return torch.tensor(embedding_matrix, dtype=torch.float) def forward(self, x: torch.Tensor): - x = x.long().clamp_(max=self.max_val, min=- self.max_val) + x = x.long().clamp_(max=self.max_val, min=-self.max_val) return self.main(x + self.max_val) class PositionEncoder(nn.Module): + def __init__(self, num_embeddings, embedding_dim=None): super(PositionEncoder, self).__init__() self.n_position = num_embeddings self.embedding_dim = self.n_position if embedding_dim is None else embedding_dim self.position_enc = nn.Embedding.from_pretrained( - self.position_encoding_init(self.n_position, self.embedding_dim), - freeze=True, padding_idx=None) + self.position_encoding_init(self.n_position, self.embedding_dim), freeze=True, padding_idx=None + ) @staticmethod def position_encoding_init(n_position, embedding_dim): ''' Init the sinusoid position encoding table ''' # keep dim 0 for padding token position encoding zero vector - position_enc = np.array([ - [pos / np.power(10000, 2 * (j // 2) / embedding_dim) for j in range(embedding_dim)] - for pos in range(n_position)]) + position_enc = np.array( + [ + [pos / np.power(10000, 2 * (j // 2) / embedding_dim) for j in range(embedding_dim)] + for pos in range(n_position) + ] + ) position_enc[:, 0::2] = np.sin(position_enc[:, 0::2]) # apply sin on 0th,2nd,4th...embedding_dim position_enc[:, 1::2] = np.cos(position_enc[:, 1::2]) # apply cos on 1st,3rd,5th...embedding_dim return torch.from_numpy(position_enc).type(torch.FloatTensor) @@ -97,6 +106,7 @@ 
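The lookup-table encoders reformatted in this file all follow one pattern: a frozen ``nn.Embedding`` whose weight matrix is precomputed (identity rows for one-hot, bit patterns for binary), with the input clamped to the table size. A self-contained toy reproduction of the one-hot case, not the module itself:

import torch
import torch.nn as nn

num_embeddings = 4
# Frozen identity embedding: row i of the weight is the one-hot vector e_i.
onehot = nn.Embedding.from_pretrained(torch.eye(num_embeddings), freeze=True)
idx = torch.tensor([0, 2, 9]).clamp_(max=num_embeddings - 1)
print(onehot(idx))  # rows e_0, e_2, e_3 (the out-of-range 9 is clamped to 3)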
def forward(self, x: torch.Tensor): class TimeEncoder(nn.Module): + def __init__(self, embedding_dim): super(TimeEncoder, self).__init__() self.embedding_dim = embedding_dim @@ -120,6 +130,7 @@ def forward(self, x: torch.Tensor): class UnsqueezeEncoder(nn.Module): + def __init__(self, unsqueeze_dim: int = -1, norm_value: float = 1): super(UnsqueezeEncoder, self).__init__() self.unsqueeze_dim = unsqueeze_dim diff --git a/lzero/model/gobigger/network/gobigger_encoder.py b/lzero/model/gobigger/network/gobigger_encoder.py index a7916d8a4..7feeebd7f 100644 --- a/lzero/model/gobigger/network/gobigger_encoder.py +++ b/lzero/model/gobigger/network/gobigger_encoder.py @@ -10,7 +10,8 @@ from .encoder import SignBinaryEncoder, BinaryEncoder, OnehotEncoder, TimeEncoder, UnsqueezeEncoder from .scatter_connection import ScatterConnection -def sequence_mask(lengths: torch.Tensor, max_len: Optional[int] =None): + +def sequence_mask(lengths: torch.Tensor, max_len: Optional[int] = None): r""" Overview: create a mask for a batch sequences with different lengths @@ -30,6 +31,7 @@ def sequence_mask(lengths: torch.Tensor, max_len: Optional[int] =None): class ScalarEncoder(nn.Module): + def __init__(self, cfg): super(ScalarEncoder, self).__init__() self.whole_cfg = cfg @@ -48,16 +50,19 @@ def __init__(self, cfg): print(f'cant implement {k} for arc {item["arc"]}') raise NotImplementedError - self.layers = MLP(in_channels=self.cfg.mlp.input_dim, hidden_channels=self.cfg.mlp.hidden_dim, - out_channels=self.cfg.mlp.output_dim, - layer_num=self.cfg.mlp.layer_num, - layer_fn=fc_block, - activation=self.cfg.mlp.activation, - norm_type=self.cfg.mlp.norm_type, - use_dropout=False, - output_activation=True, - output_norm=True, - last_linear_layer_init_zero=False) + self.layers = MLP( + in_channels=self.cfg.mlp.input_dim, + hidden_channels=self.cfg.mlp.hidden_dim, + out_channels=self.cfg.mlp.output_dim, + layer_num=self.cfg.mlp.layer_num, + layer_fn=fc_block, + activation=self.cfg.mlp.activation, + norm_type=self.cfg.mlp.norm_type, + use_dropout=False, + output_activation=True, + output_norm=True, + last_linear_layer_init_zero=False + ) def forward(self, x: Dict[str, Tensor]): embeddings = [] @@ -71,6 +76,7 @@ def forward(self, x: Dict[str, Tensor]): class TeamEncoder(nn.Module): + def __init__(self, cfg): super(TeamEncoder, self).__init__() self.whole_cfg = cfg @@ -88,17 +94,19 @@ def __init__(self, cfg): print(f'cant implement {k} for arc {item["arc"]}') raise NotImplementedError - self.encode_layers = MLP(in_channels=self.cfg.mlp.input_dim, - hidden_channels=self.cfg.mlp.hidden_dim, - out_channels=self.cfg.mlp.output_dim, - layer_num=self.cfg.mlp.layer_num, - layer_fn=fc_block, - activation=self.cfg.mlp.activation, - norm_type=self.cfg.mlp.norm_type, - use_dropout=False, - output_activation=True, - output_norm=True, - last_linear_layer_init_zero=False) + self.encode_layers = MLP( + in_channels=self.cfg.mlp.input_dim, + hidden_channels=self.cfg.mlp.hidden_dim, + out_channels=self.cfg.mlp.output_dim, + layer_num=self.cfg.mlp.layer_num, + layer_fn=fc_block, + activation=self.cfg.mlp.activation, + norm_type=self.cfg.mlp.norm_type, + use_dropout=False, + output_activation=True, + output_norm=True, + last_linear_layer_init_zero=False + ) self.transformer = Transformer( input_dim=self.cfg.transformer.input_dim, @@ -109,10 +117,12 @@ def __init__(self, cfg): layer_num=self.cfg.transformer.layer_num, activation=self.cfg.transformer.activation, ) - self.output_fc = fc_block(self.cfg.fc_block.input_dim, - 
self.cfg.fc_block.output_dim, - norm_type=self.cfg.fc_block.norm_type, - activation=self.cfg.fc_block.activation) + self.output_fc = fc_block( + self.cfg.fc_block.input_dim, + self.cfg.fc_block.output_dim, + norm_type=self.cfg.fc_block.norm_type, + activation=self.cfg.fc_block.activation + ) def forward(self, x): embeddings = [] @@ -131,6 +141,7 @@ def forward(self, x): class BallEncoder(nn.Module): + def __init__(self, cfg): super(BallEncoder, self).__init__() self.whole_cfg = cfg @@ -148,17 +159,19 @@ def __init__(self, cfg): else: print(f'cant implement {k} for arc {item["arc"]}') raise NotImplementedError - self.encode_layers = MLP(in_channels=self.cfg.mlp.input_dim, - hidden_channels=self.cfg.mlp.hidden_dim, - out_channels=self.cfg.mlp.output_dim, - layer_num=self.cfg.mlp.layer_num, - layer_fn=fc_block, - activation=self.cfg.mlp.activation, - norm_type=self.cfg.mlp.norm_type, - use_dropout=False, - output_activation=True, - output_norm=True, - last_linear_layer_init_zero=False) + self.encode_layers = MLP( + in_channels=self.cfg.mlp.input_dim, + hidden_channels=self.cfg.mlp.hidden_dim, + out_channels=self.cfg.mlp.output_dim, + layer_num=self.cfg.mlp.layer_num, + layer_fn=fc_block, + activation=self.cfg.mlp.activation, + norm_type=self.cfg.mlp.norm_type, + use_dropout=False, + output_activation=True, + output_norm=True, + last_linear_layer_init_zero=False + ) self.transformer = Transformer( input_dim=self.cfg.transformer.input_dim, @@ -169,10 +182,12 @@ def __init__(self, cfg): layer_num=self.cfg.transformer.layer_num, activation=self.cfg.transformer.activation, ) - self.output_fc = fc_block(self.cfg.fc_block.input_dim, - self.cfg.fc_block.output_dim, - norm_type=self.cfg.fc_block.norm_type, - activation=self.cfg.fc_block.activation) + self.output_fc = fc_block( + self.cfg.fc_block.input_dim, + self.cfg.fc_block.output_dim, + norm_type=self.cfg.fc_block.norm_type, + activation=self.cfg.fc_block.activation + ) def forward(self, x): ball_num = x['ball_num'] @@ -192,6 +207,7 @@ def forward(self, x): class SpatialEncoder(nn.Module): + def __init__(self, cfg): super(SpatialEncoder, self).__init__() self.whole_cfg = cfg @@ -200,10 +216,12 @@ def __init__(self, cfg): # scatter related self.spatial_x = 64 self.spatial_y = 64 - self.scatter_fc = fc_block(in_channels=self.cfg.scatter.input_dim, - out_channels=self.cfg.scatter.output_dim, - activation=self.cfg.scatter.activation, - norm_type=self.cfg.scatter.norm_type) + self.scatter_fc = fc_block( + in_channels=self.cfg.scatter.input_dim, + out_channels=self.cfg.scatter.output_dim, + activation=self.cfg.scatter.activation, + norm_type=self.cfg.scatter.norm_type + ) self.scatter_connection = ScatterConnection(self.cfg.scatter.scatter_type) # resnet related @@ -213,72 +231,92 @@ def __init__(self, cfg): in_channels=self.spatial_x // 8 * self.spatial_y // 8 * self.cfg.resnet.down_channels[-1], out_channels=self.cfg.fc_block.output_dim, norm_type=self.cfg.fc_block.norm_type, - activation=self.cfg.fc_block.activation) + activation=self.cfg.fc_block.activation + ) def get_resnet_blocks(self): # 2 means food/spore embedding - project = conv2d_block(in_channels=self.cfg.scatter.output_dim + 2, - out_channels=self.cfg.resnet.project_dim, - kernel_size=1, - stride=1, - padding=0, - activation=self.cfg.resnet.activation, - norm_type=self.cfg.resnet.norm_type, - bias=False, - ) + project = conv2d_block( + in_channels=self.cfg.scatter.output_dim + 2, + out_channels=self.cfg.resnet.project_dim, + kernel_size=1, + stride=1, + padding=0, + 
activation=self.cfg.resnet.activation, + norm_type=self.cfg.resnet.norm_type, + bias=False, + ) layers = [project] dims = [self.cfg.resnet.project_dim] + self.cfg.resnet.down_channels for i in range(len(dims) - 1): - layer = conv2d_block(in_channels=dims[i], - out_channels=dims[i + 1], - kernel_size=4, - stride=2, - padding=1, - activation=self.cfg.resnet.activation, - norm_type=self.cfg.resnet.norm_type, - bias=False, - ) + layer = conv2d_block( + in_channels=dims[i], + out_channels=dims[i + 1], + kernel_size=4, + stride=2, + padding=1, + activation=self.cfg.resnet.activation, + norm_type=self.cfg.resnet.norm_type, + bias=False, + ) layers.append(layer) - layers.append(ResBlock(res_type='basic', - in_channels=dims[i + 1], - activation=self.cfg.resnet.activation, - norm_type=self.cfg.resnet.norm_type)) + layers.append( + ResBlock( + res_type='basic', + in_channels=dims[i + 1], + activation=self.cfg.resnet.activation, + norm_type=self.cfg.resnet.norm_type + ) + ) self.resnet = torch.nn.Sequential(*layers) - - def get_background_embedding(self, coord_x, coord_y, num, ): + def get_background_embedding( + self, + coord_x, + coord_y, + num, + ): background_ones = torch.ones(size=(coord_x.shape[0], coord_x.shape[1]), device=coord_x.device) background_mask = sequence_mask(num, max_len=coord_x.shape[1]) background_ones = (background_ones * background_mask).unsqueeze(-1) - background_embedding = self.scatter_connection.xy_forward(background_ones, - spatial_size=[self.spatial_x, self.spatial_y], - coord_x=coord_x, - coord_y=coord_y) + background_embedding = self.scatter_connection.xy_forward( + background_ones, spatial_size=[self.spatial_x, self.spatial_y], coord_x=coord_x, coord_y=coord_y + ) return background_embedding - def forward(self, inputs, ball_embeddings, ): + def forward( + self, + inputs, + ball_embeddings, + ): spatial_info = inputs['spatial_info'] # food and spore - food_embedding = self.get_background_embedding(coord_x=spatial_info['food_x'], - coord_y=spatial_info['food_y'], - num=spatial_info['food_num'], ) + food_embedding = self.get_background_embedding( + coord_x=spatial_info['food_x'], + coord_y=spatial_info['food_y'], + num=spatial_info['food_num'], + ) - spore_embedding = self.get_background_embedding(coord_x=spatial_info['spore_x'], - coord_y=spatial_info['spore_y'], - num=spatial_info['spore_num'], ) + spore_embedding = self.get_background_embedding( + coord_x=spatial_info['spore_x'], + coord_y=spatial_info['spore_y'], + num=spatial_info['spore_num'], + ) # scatter ball embeddings ball_info = inputs['ball_info'] ball_num = ball_info['ball_num'] ball_mask = sequence_mask(ball_num, max_len=ball_embeddings.shape[1]) ball_embedding = self.scatter_fc(ball_embeddings) * ball_mask.unsqueeze(dim=2) - ball_embedding = self.scatter_connection.xy_forward(ball_embedding, - spatial_size=[self.spatial_x, self.spatial_y], - coord_x=spatial_info['ball_x'], - coord_y=spatial_info['ball_y']) + ball_embedding = self.scatter_connection.xy_forward( + ball_embedding, + spatial_size=[self.spatial_x, self.spatial_y], + coord_x=spatial_info['ball_x'], + coord_y=spatial_info['ball_y'] + ) x = torch.cat([food_embedding, spore_embedding, ball_embedding], dim=1) @@ -290,7 +328,7 @@ def forward(self, inputs, ball_embeddings, ): class GoBiggerEncoder(nn.Module): - config=dict( + config = dict( scalar_encoder=dict( modules=dict( view_x=dict(arc='sign_binary', num_embeddings=7), @@ -301,37 +339,78 @@ class GoBiggerEncoder(nn.Module): rank=dict(arc='one_hot', num_embeddings=4), time=dict(arc='time', 
embedding_dim=8), last_action_type=dict(arc='one_hot', num_embeddings=27), - ), - mlp=dict(input_dim=80, hidden_dim=64, layer_num=2, norm_type='BN', output_dim=32, activation=nn.ReLU(inplace=True)), + ), + mlp=dict( + input_dim=80, + hidden_dim=64, + layer_num=2, + norm_type='BN', + output_dim=32, + activation=nn.ReLU(inplace=True) + ), ), team_encoder=dict( modules=dict( alliance=dict(arc='one_hot', num_embeddings=2), view_x=dict(arc='sign_binary', num_embeddings=7), view_y=dict(arc='sign_binary', num_embeddings=7), - ), - mlp=dict(input_dim=16, hidden_dim=32, layer_num=2, norm_type=None, output_dim=16, activation=nn.ReLU(inplace=True)), - transformer=dict(input_dim=16, output_dim=16, head_num=4, ffn_size=32, layer_num=2, embedding_dim=16, activation=nn.ReLU(inplace=True), variant='postnorm'), + ), + mlp=dict( + input_dim=16, + hidden_dim=32, + layer_num=2, + norm_type=None, + output_dim=16, + activation=nn.ReLU(inplace=True) + ), + transformer=dict( + input_dim=16, + output_dim=16, + head_num=4, + ffn_size=32, + layer_num=2, + embedding_dim=16, + activation=nn.ReLU(inplace=True), + variant='postnorm' + ), fc_block=dict(input_dim=16, output_dim=16, activation=nn.ReLU(inplace=True), norm_type='BN'), ), ball_encoder=dict( modules=dict( alliance=dict(arc='one_hot', num_embeddings=4), score=dict(arc='one_hot', num_embeddings=50), - radius=dict(arc='unsqueeze',), + radius=dict(arc='unsqueeze', ), rank=dict(arc='one_hot', num_embeddings=5), x=dict(arc='sign_binary', num_embeddings=8), y=dict(arc='sign_binary', num_embeddings=8), next_x=dict(arc='sign_binary', num_embeddings=8), next_y=dict(arc='sign_binary', num_embeddings=8), ), - mlp=dict(input_dim=92, hidden_dim=128, layer_num=2, norm_type=None, output_dim=64, activation=nn.ReLU(inplace=True)), - transformer=dict(input_dim=64, output_dim=64, head_num=4, ffn_size=64, layer_num=3, embedding_dim=64, activation=nn.ReLU(inplace=True), variant='postnorm'), + mlp=dict( + input_dim=92, + hidden_dim=128, + layer_num=2, + norm_type=None, + output_dim=64, + activation=nn.ReLU(inplace=True) + ), + transformer=dict( + input_dim=64, + output_dim=64, + head_num=4, + ffn_size=64, + layer_num=3, + embedding_dim=64, + activation=nn.ReLU(inplace=True), + variant='postnorm' + ), fc_block=dict(input_dim=64, output_dim=64, activation=nn.ReLU(inplace=True), norm_type='BN'), ), spatial_encoder=dict( - scatter=dict(input_dim=64, output_dim=16, scatter_type='add', activation=nn.ReLU(inplace=True), norm_type=None), - resnet=dict(project_dim=12, down_channels=[32, 32, 16 ], activation=nn.ReLU(inplace=True), norm_type='BN'), + scatter=dict( + input_dim=64, output_dim=16, scatter_type='add', activation=nn.ReLU(inplace=True), norm_type=None + ), + resnet=dict(project_dim=12, down_channels=[32, 32, 16], activation=nn.ReLU(inplace=True), norm_type='BN'), fc_block=dict(output_dim=64, activation=nn.ReLU(inplace=True), norm_type='BN'), ), ) diff --git a/lzero/model/gobigger/network/scatter_connection.py b/lzero/model/gobigger/network/scatter_connection.py index 263ff0f57..1005d1719 100644 --- a/lzero/model/gobigger/network/scatter_connection.py +++ b/lzero/model/gobigger/network/scatter_connection.py @@ -24,15 +24,16 @@ def __init__(self, scatter_type='add') -> None: self.scatter_type = scatter_type assert self.scatter_type in ['cover', 'add'] - def xy_forward(self, x: torch.Tensor, spatial_size: Tuple[int, int], coord_x: torch.Tensor,coord_y) -> torch.Tensor: + def xy_forward( + self, x: torch.Tensor, spatial_size: Tuple[int, int], coord_x: torch.Tensor, coord_y + ) -> 
torch.Tensor: device = x.device BatchSize, Num, EmbeddingSize = x.shape x = x.permute(0, 2, 1) H, W = spatial_size indices = (coord_x * W + coord_y).long() indices = indices.unsqueeze(dim=1).repeat(1, EmbeddingSize, 1) - output = torch.zeros(size=(BatchSize, EmbeddingSize, H, W), device=device).view(BatchSize, EmbeddingSize, - H * W) + output = torch.zeros(size=(BatchSize, EmbeddingSize, H, W), device=device).view(BatchSize, EmbeddingSize, H * W) if self.scatter_type == 'cover': output.scatter_(dim=2, index=indices, src=x) elif self.scatter_type == 'add': @@ -69,8 +70,7 @@ def forward(self, x: torch.Tensor, spatial_size: Tuple[int, int], location: torc H, W = spatial_size indices = location[:, :, 1] + location[:, :, 0] * W indices = indices.unsqueeze(dim=1).repeat(1, EmbeddingSize, 1) - output = torch.zeros(size=(BatchSize, EmbeddingSize, H, W), device=device).view(BatchSize, EmbeddingSize, - H * W) + output = torch.zeros(size=(BatchSize, EmbeddingSize, H, W), device=device).view(BatchSize, EmbeddingSize, H * W) if self.scatter_type == 'cover': output.scatter_(dim=2, index=indices, src=x) elif self.scatter_type == 'add': diff --git a/lzero/policy/gobigger_efficientzero.py b/lzero/policy/gobigger_efficientzero.py index bb5e2d0c3..3234fdfb4 100644 --- a/lzero/policy/gobigger_efficientzero.py +++ b/lzero/policy/gobigger_efficientzero.py @@ -58,7 +58,7 @@ class GoBiggerEfficientZeroPolicy(Policy): # (bool) whether to use res connection in dynamics. res_connection_in_dynamics=True, # (str) The type of normalization in MuZero model. Options are ['BN', 'LN']. Default to 'LN'. - norm_type='BN', + norm_type='BN', ), # ****** common ****** # (bool) Whether to enable the sampled-based algorithm (e.g. Sampled EfficientZero) @@ -598,7 +598,8 @@ def _forward_collect( roots, self._collect_model, latent_state_roots, reward_hidden_state_roots, to_play ) - roots_visit_count_distributions = roots.get_distributions() # shape: ``{list: batch_size} ->{list: action_space_size}`` + roots_visit_count_distributions = roots.get_distributions( + ) # shape: ``{list: batch_size} ->{list: action_space_size}`` roots_values = roots.get_values() # shape: {list: batch_size} data_id = [i for i in range(active_collect_env_num)] @@ -620,13 +621,12 @@ def _forward_collect( distributions, temperature=self.collect_mcts_temperature, deterministic=False ) action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] - output[i//agent_num]['action'].append(action) - output[i//agent_num]['distributions'].append(distributions) - output[i//agent_num]['visit_count_distribution_entropy'].append(visit_count_distribution_entropy) - output[i//agent_num]['value'].append(value) - output[i//agent_num]['pred_value'].append(pred_values[i]) - output[i//agent_num]['policy_logits'].append(policy_logits[i]) - + output[i // agent_num]['action'].append(action) + output[i // agent_num]['distributions'].append(distributions) + output[i // agent_num]['visit_count_distribution_entropy'].append(visit_count_distribution_entropy) + output[i // agent_num]['value'].append(value) + output[i // agent_num]['pred_value'].append(pred_values[i]) + output[i // agent_num]['policy_logits'].append(policy_logits[i]) return output @@ -719,12 +719,12 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, read ) # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the entire action set. 
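The ``xy_forward`` path reformatted above scatters per-ball embeddings onto an H x W grid at integer coordinates, summing when several balls land in the same cell (scatter_type='add'). A toy-shape sketch of the index computation and the ``scatter_add_`` call; names and sizes are illustrative:

import torch

B, N, C, H, W = 1, 3, 2, 4, 4
x = torch.ones(B, N, C)                      # N unit embeddings of size C
coord_x = torch.tensor([[0, 0, 3]])          # the first two balls share cell (0, 1)
coord_y = torch.tensor([[1, 1, 2]])

# Flatten (x, y) into an index over the H*W grid and repeat it per channel.
indices = (coord_x * W + coord_y).long().unsqueeze(1).repeat(1, C, 1)       # (B, C, N)
out = torch.zeros(B, C, H * W).scatter_add_(2, indices, x.permute(0, 2, 1))
print(out.view(B, C, H, W)[0, 0])            # 2.0 at (0, 1), 1.0 at (3, 2)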
action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] - output[i//agent_num]['action'].append(action) - output[i//agent_num]['distributions'].append(distributions) - output[i//agent_num]['visit_count_distribution_entropy'].append(visit_count_distribution_entropy) - output[i//agent_num]['value'].append(value) - output[i//agent_num]['pred_value'].append(pred_values[i]) - output[i//agent_num]['policy_logits'].append(policy_logits[i]) + output[i // agent_num]['action'].append(action) + output[i // agent_num]['distributions'].append(distributions) + output[i // agent_num]['visit_count_distribution_entropy'].append(visit_count_distribution_entropy) + output[i // agent_num]['value'].append(value) + output[i // agent_num]['pred_value'].append(pred_values[i]) + output[i // agent_num]['policy_logits'].append(policy_logits[i]) return output diff --git a/lzero/policy/gobigger_muzero.py b/lzero/policy/gobigger_muzero.py index 027027a2e..dbcec0531 100644 --- a/lzero/policy/gobigger_muzero.py +++ b/lzero/policy/gobigger_muzero.py @@ -19,6 +19,7 @@ from collections import defaultdict from ding.torch_utils import to_device + @POLICY_REGISTRY.register('gobigger_muzero') class GoBiggerMuZeroPolicy(Policy): """ @@ -58,7 +59,7 @@ class GoBiggerMuZeroPolicy(Policy): # (bool) whether to use res connection in dynamics. res_connection_in_dynamics=True, # (str) The type of normalization in MuZero model. Options are ['BN', 'LN']. Default to 'LN'. - norm_type='BN', + norm_type='BN', ), # ****** common ****** # (bool) Whether to enable the sampled-based algorithm (e.g. Sampled EfficientZero) @@ -173,7 +174,6 @@ def default_model(self) -> Tuple[str, List[str]]: by import_names path. For MuZero, ``lzero.model.muzero_model.MuZeroModel`` """ return 'GoBiggerMuZeroModel', ['lzero.model.gobigger.gobigger_muzero_model'] - def _init_learn(self) -> None: """ @@ -543,7 +543,8 @@ def _forward_collect( roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits, to_play) self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play) - roots_visit_count_distributions = roots.get_distributions() # shape: ``{list: batch_size} ->{list: action_space_size}`` + roots_visit_count_distributions = roots.get_distributions( + ) # shape: ``{list: batch_size} ->{list: action_space_size}`` roots_values = roots.get_values() # shape: {list: batch_size} data_id = [i for i in range(active_collect_env_num)] @@ -562,12 +563,12 @@ def _forward_collect( # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the # entire action set. 
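All of these collect/eval loops share the same multi-agent bookkeeping: the MCTS batch holds env_num * agent_num entries, and batch index i is folded back into the output dict of environment i // agent_num. A small sketch with made-up values:

from collections import defaultdict

env_num, agent_num = 2, 3
actions = [10, 11, 12, 20, 21, 22]           # flat batch, agents grouped per env
output = {env_id: defaultdict(list) for env_id in range(env_num)}
for i, act in enumerate(actions):
    output[i // agent_num]['action'].append(act)
print(dict(output[0]))                       # {'action': [10, 11, 12]}
print(dict(output[1]))                       # {'action': [20, 21, 22]}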
action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] - output[i//agent_num]['action'].append(action) - output[i//agent_num]['distributions'].append(distributions) - output[i//agent_num]['visit_count_distribution_entropy'].append(visit_count_distribution_entropy) - output[i//agent_num]['value'].append(value) - output[i//agent_num]['pred_value'].append(pred_values[i]) - output[i//agent_num]['policy_logits'].append(policy_logits[i]) + output[i // agent_num]['action'].append(action) + output[i // agent_num]['distributions'].append(distributions) + output[i // agent_num]['visit_count_distribution_entropy'].append(visit_count_distribution_entropy) + output[i // agent_num]['value'].append(value) + output[i // agent_num]['pred_value'].append(pred_values[i]) + output[i // agent_num]['policy_logits'].append(policy_logits[i]) return output @@ -640,7 +641,7 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1 roots_values = roots.get_values() # shape: {list: batch_size} data_id = [i for i in range(active_eval_env_num)] - output = {i: defaultdict(list) for i in data_id} + output = {i: defaultdict(list) for i in data_id} if ready_env_id is None: ready_env_id = np.arange(active_eval_env_num) @@ -657,12 +658,12 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1 # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the # entire action set. action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] - output[i//agent_num]['action'].append(action) - output[i//agent_num]['distributions'].append(distributions) - output[i//agent_num]['visit_count_distribution_entropy'].append(visit_count_distribution_entropy) - output[i//agent_num]['value'].append(value) - output[i//agent_num]['pred_value'].append(pred_values[i]) - output[i//agent_num]['policy_logits'].append(policy_logits[i]) + output[i // agent_num]['action'].append(action) + output[i // agent_num]['distributions'].append(distributions) + output[i // agent_num]['visit_count_distribution_entropy'].append(visit_count_distribution_entropy) + output[i // agent_num]['value'].append(value) + output[i // agent_num]['pred_value'].append(pred_values[i]) + output[i // agent_num]['policy_logits'].append(policy_logits[i]) return output diff --git a/lzero/policy/gobigger_random_policy.py b/lzero/policy/gobigger_random_policy.py index 1138d5d88..15b53efbe 100644 --- a/lzero/policy/gobigger_random_policy.py +++ b/lzero/policy/gobigger_random_policy.py @@ -58,7 +58,7 @@ class GoBiggerRandomPolicy(Policy): # (bool) whether to use res connection in dynamics. res_connection_in_dynamics=True, # (str) The type of normalization in MuZero model. Options are ['BN', 'LN']. Default to 'LN'. - norm_type='BN', + norm_type='BN', ), # ****** common ****** # (bool) Whether to enable the sampled-based algorithm (e.g. 
Sampled EfficientZero) @@ -270,8 +270,7 @@ def _forward_collect( pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() latent_state_roots = latent_state_roots.detach().cpu().numpy() reward_hidden_state_roots = ( - reward_hidden_state_roots[0].detach().cpu().numpy(), - reward_hidden_state_roots[1].detach().cpu().numpy() + reward_hidden_state_roots[0].detach().cpu().numpy(), reward_hidden_state_roots[1].detach().cpu().numpy() ) policy_logits = policy_logits.detach().cpu().numpy().tolist() @@ -293,7 +292,8 @@ def _forward_collect( roots, self._collect_model, latent_state_roots, reward_hidden_state_roots, to_play ) - roots_visit_count_distributions = roots.get_distributions() # shape: ``{list: batch_size} ->{list: action_space_size}`` + roots_visit_count_distributions = roots.get_distributions( + ) # shape: ``{list: batch_size} ->{list: action_space_size}`` roots_values = roots.get_values() # shape: {list: batch_size} data_id = [i for i in range(active_collect_env_num)] @@ -309,13 +309,12 @@ def _forward_collect( action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] # ************* random action ************* action = int(np.random.choice(legal_actions[i], 1)) - output[i//agent_num]['action'].append(action) - output[i//agent_num]['distributions'].append(distributions) - output[i//agent_num]['visit_count_distribution_entropy'].append(visit_count_distribution_entropy) - output[i//agent_num]['value'].append(value) - output[i//agent_num]['pred_value'].append(pred_values[i]) - output[i//agent_num]['policy_logits'].append(policy_logits[i]) - + output[i // agent_num]['action'].append(action) + output[i // agent_num]['distributions'].append(distributions) + output[i // agent_num]['visit_count_distribution_entropy'].append(visit_count_distribution_entropy) + output[i // agent_num]['value'].append(value) + output[i // agent_num]['pred_value'].append(pred_values[i]) + output[i // agent_num]['policy_logits'].append(policy_logits[i]) return output diff --git a/lzero/policy/gobigger_sampled_efficientzero.py b/lzero/policy/gobigger_sampled_efficientzero.py index 0f16c7e22..8bf5d2b10 100644 --- a/lzero/policy/gobigger_sampled_efficientzero.py +++ b/lzero/policy/gobigger_sampled_efficientzero.py @@ -66,7 +66,7 @@ class GoBiggerSampledEfficientZeroPolicy(Policy): # (bool) whether to use res connection in dynamics. res_connection_in_dynamics=True, # (str) The type of normalization in MuZero model. Options are ['BN', 'LN']. Default to 'LN'. - norm_type='BN', + norm_type='BN', ), # ****** common ****** # (bool) ``sampled_algo=True`` means the policy is sampled-based algorithm (e.g. Sampled EfficientZero), which is used in ``collector``. @@ -410,7 +410,6 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: obs_target_batch_tmp = to_device(obs_target_batch_tmp, self._cfg.device) network_output = self._learn_model.initial_inference(obs_target_batch_tmp) - latent_state = to_tensor(latent_state) representation_state = to_tensor(network_output.latent_state) @@ -851,13 +850,9 @@ def _forward_collect( if self._cfg.model.continuous_action_space is True: # when the action space of the environment is continuous, action_mask[:] is None. 
# NOTE: in continuous action space env: we set all legal_actions as -1 - legal_actions = [ - [-1 for _ in range(self._cfg.model.num_of_sampled_actions)] for _ in range(batch_size) - ] + legal_actions = [[-1 for _ in range(self._cfg.model.num_of_sampled_actions)] for _ in range(batch_size)] else: - legal_actions = [ - [i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(batch_size) - ] + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(batch_size)] if self._cfg.mcts_ctree: # cpp mcts_tree @@ -918,14 +913,13 @@ def _forward_collect( elif len(action.shape) == 1: action = int(action[0]) - output[i//agent_num]['action'].append(action) - output[i//agent_num]['distributions'].append(distributions) - output[i//agent_num]['root_sampled_actions'].append(root_sampled_actions) - output[i//agent_num]['visit_count_distribution_entropy'].append(visit_count_distribution_entropy) - output[i//agent_num]['value'].append(value) - output[i//agent_num]['pred_value'].append(pred_values[i]) - output[i//agent_num]['policy_logits'].append(policy_logits[i]) - + output[i // agent_num]['action'].append(action) + output[i // agent_num]['distributions'].append(distributions) + output[i // agent_num]['root_sampled_actions'].append(root_sampled_actions) + output[i // agent_num]['visit_count_distribution_entropy'].append(visit_count_distribution_entropy) + output[i // agent_num]['value'].append(value) + output[i // agent_num]['pred_value'].append(pred_values[i]) + output[i // agent_num]['policy_logits'].append(policy_logits[i]) return output @@ -992,13 +986,9 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, read if self._cfg.model.continuous_action_space is True: # when the action space of the environment is continuous, action_mask[:] is None. 
# NOTE: in continuous action space env: we set all legal_actions as -1 - legal_actions = [ - [-1 for _ in range(self._cfg.model.num_of_sampled_actions)] for _ in range(batch_size) - ] + legal_actions = [[-1 for _ in range(self._cfg.model.num_of_sampled_actions)] for _ in range(batch_size)] else: - legal_actions = [ - [i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(batch_size) - ] + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(batch_size)] # cpp mcts_tree if self._cfg.mcts_ctree: @@ -1060,13 +1050,13 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, read action = int(action) elif len(action.shape) == 1: action = int(action[0]) - output[i//agent_num]['action'].append(action) - output[i//agent_num]['distributions'].append(distributions) - output[i//agent_num]['root_sampled_actions'].append(root_sampled_actions) - output[i//agent_num]['visit_count_distribution_entropy'].append(visit_count_distribution_entropy) - output[i//agent_num]['value'].append(value) - output[i//agent_num]['pred_value'].append(pred_values[i]) - output[i//agent_num]['policy_logits'].append(policy_logits[i]) + output[i // agent_num]['action'].append(action) + output[i // agent_num]['distributions'].append(distributions) + output[i // agent_num]['root_sampled_actions'].append(root_sampled_actions) + output[i // agent_num]['visit_count_distribution_entropy'].append(visit_count_distribution_entropy) + output[i // agent_num]['value'].append(value) + output[i // agent_num]['pred_value'].append(pred_values[i]) + output[i // agent_num]['policy_logits'].append(policy_logits[i]) return output diff --git a/lzero/worker/gobigger_muzero_collector.py b/lzero/worker/gobigger_muzero_collector.py index d18329c14..840f26753 100644 --- a/lzero/worker/gobigger_muzero_collector.py +++ b/lzero/worker/gobigger_muzero_collector.py @@ -194,9 +194,10 @@ def _compute_priorities(self, i, agent_id, pred_values_lst, search_values_lst): - search_values_lst: The list of value obtained through search. 
""" if self.policy_config.use_priority and not self.policy_config.use_max_priority_for_new_data: - pred_values = torch.from_numpy(np.array(pred_values_lst[i][agent_id])).to(self.policy_config.device).float().view(-1) + pred_values = torch.from_numpy(np.array(pred_values_lst[i][agent_id])).to(self.policy_config.device + ).float().view(-1) search_values = torch.from_numpy(np.array(search_values_lst[i][agent_id])).to(self.policy_config.device - ).float().view(-1) + ).float().view(-1) priorities = L1Loss(reduction='none' )(pred_values, search_values).detach().cpu().numpy() + self.policy_config.prioritized_replay_eps @@ -206,7 +207,9 @@ def _compute_priorities(self, i, agent_id, pred_values_lst, search_values_lst): return priorities - def pad_and_save_last_trajectory(self, i, agent_id, last_game_segments, last_game_priorities, game_segments, done) -> None: + def pad_and_save_last_trajectory( + self, i, agent_id, last_game_segments, last_game_priorities, game_segments, done + ) -> None: """ Overview: put the last game block into the pool if the current game is finished @@ -312,18 +315,23 @@ def collect(self, to_play_dict = {i: to_ndarray(init_obs[i]['to_play']) for i in range(env_nums)} agent_num = len(init_obs[0]['action_mask']) game_segments = [ - [GameSegment( - self._env.action_space, - game_segment_length=self.policy_config.game_segment_length, - config=self.policy_config - ) for _ in range(agent_num)] for _ in range(env_nums) + [ + GameSegment( + self._env.action_space, + game_segment_length=self.policy_config.game_segment_length, + config=self.policy_config + ) for _ in range(agent_num) + ] for _ in range(env_nums) ] # stacked observation windows in reset stage for init game_segments observation_window_stack = [[[] for _ in range(agent_num)] for _ in range(env_nums)] for env_id in range(env_nums): for agent_id in range(agent_num): observation_window_stack[env_id][agent_id] = deque( - [to_ndarray(init_obs[env_id]['observation'][agent_id]) for _ in range(self.policy_config.model.frame_stack_num)], + [ + to_ndarray(init_obs[env_id]['observation'][agent_id]) + for _ in range(self.policy_config.model.frame_stack_num) + ], maxlen=self.policy_config.model.frame_stack_num ) game_segments[env_id][agent_id].reset(observation_window_stack[env_id][agent_id]) @@ -335,7 +343,6 @@ def collect(self, search_values_lst = [[[] for _ in range(agent_num)] for _ in range(env_nums)] pred_values_lst = [[[] for _ in range(agent_num)] for _ in range(env_nums)] - # some logs eps_steps_lst, visit_entropies_lst = np.zeros(env_nums), np.zeros((env_nums, agent_num)) self_play_moves = 0. 
@@ -378,8 +385,8 @@ def collect(self, # policy forward # ============================================================== policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon) - actions_no_env_id=defaultdict(dict) - for k,v in policy_output.items(): + actions_no_env_id = defaultdict(dict) + for k, v in policy_output.items(): for agent_id, act in enumerate(v['action']): actions_no_env_id[k][agent_id] = act @@ -435,18 +442,22 @@ def collect(self, if self.policy_config.sampled_algo: for agent_id in range(agent_num): game_segments[env_id][agent_id].store_search_stats( - distributions_dict[env_id][agent_id], value_dict[env_id][agent_id], root_sampled_actions_dict[env_id][agent_id]) + distributions_dict[env_id][agent_id], value_dict[env_id][agent_id], + root_sampled_actions_dict[env_id][agent_id] + ) else: for agent_id in range(agent_num): - if len(distributions_dict[env_id][agent_id])!=27: + if len(distributions_dict[env_id][agent_id]) != 27: print('') - game_segments[env_id][agent_id].store_search_stats(distributions_dict[env_id][agent_id], value_dict[env_id][agent_id]) + game_segments[env_id][agent_id].store_search_stats( + distributions_dict[env_id][agent_id], value_dict[env_id][agent_id] + ) # append a transition tuple, including a_t, o_{t+1}, r_{t}, action_mask_{t}, to_play_{t} # in ``game_segments[env_id].init``, we have append o_{t} in ``self.obs_segment`` for agent_id in range(agent_num): game_segments[env_id][agent_id].append( - actions[env_id][agent_id], to_ndarray(obs['observation'][agent_id]), reward[agent_id], action_mask_dict[env_id][agent_id], - to_play_dict[env_id] + actions[env_id][agent_id], to_ndarray(obs['observation'][agent_id]), reward[agent_id], + action_mask_dict[env_id][agent_id], to_play_dict[env_id] ) # NOTE: the position of code snippet is very important. 
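Two recurring pieces of the collector logic above, pulled out as standalone sketches (hypothetical helper names, not part of the patch): the per-agent replay priority computed in `_compute_priorities`, which reduces to an element-wise L1 distance between the predicted values and the MCTS search values plus `prioritized_replay_eps`, and the regrouping of the policy output (one list of per-agent results per environment) into the `{env_id: {agent_id: action}}` mapping used to step the environments.

import numpy as np
from collections import defaultdict

def compute_priorities(pred_values, search_values, eps=1e-6):
    # |predicted value - searched value| per transition; eps stands in for
    # prioritized_replay_eps (value here is illustrative) and keeps every
    # priority strictly positive.
    pred = np.asarray(pred_values, dtype=np.float32).reshape(-1)
    search = np.asarray(search_values, dtype=np.float32).reshape(-1)
    return np.abs(pred - search) + eps

def regroup_actions(policy_output):
    # {env_id: {'action': [a_agent0, a_agent1, ...], ...}}
    #   -> {env_id: {agent_id: action}}
    actions = defaultdict(dict)
    for env_id, out in policy_output.items():
        for agent_id, act in enumerate(out['action']):
            actions[env_id][agent_id] = act
    return actions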
@@ -575,11 +586,14 @@ def collect(self, config=self.policy_config ) observation_window_stack[env_id][agent_id] = deque( - [init_obs[env_id]['observation'][agent_id] for _ in range(self.policy_config.model.frame_stack_num)], + [ + init_obs[env_id]['observation'][agent_id] + for _ in range(self.policy_config.model.frame_stack_num) + ], maxlen=self.policy_config.model.frame_stack_num ) game_segments[env_id][agent_id].reset(observation_window_stack[env_id][agent_id]) - last_game_segments[env_id] = [None for _ in range(agent_num)] + last_game_segments[env_id] = [None for _ in range(agent_num)] last_game_priorities[env_id] = [None for _ in range(agent_num)] # log diff --git a/lzero/worker/gobigger_muzero_evaluator.py b/lzero/worker/gobigger_muzero_evaluator.py index 40c741414..d4634cda4 100644 --- a/lzero/worker/gobigger_muzero_evaluator.py +++ b/lzero/worker/gobigger_muzero_evaluator.py @@ -239,16 +239,21 @@ def eval( dones = np.array([False for _ in range(env_nums)]) game_segments = [ - [GameSegment( - self._env.action_space, - game_segment_length=self.policy_config.game_segment_length, - config=self.policy_config - ) for _ in range(agent_num)] for _ in range(env_nums) + [ + GameSegment( + self._env.action_space, + game_segment_length=self.policy_config.game_segment_length, + config=self.policy_config + ) for _ in range(agent_num) + ] for _ in range(env_nums) ] for env_id in range(env_nums): for agent_id in range(agent_num): game_segments[env_id][agent_id].reset( - [to_ndarray(init_obs[env_id]['observation'][agent_id]) for _ in range(self.policy_config.model.frame_stack_num)] + [ + to_ndarray(init_obs[env_id]['observation'][agent_id]) + for _ in range(self.policy_config.model.frame_stack_num) + ] ) ready_env_id = set() @@ -282,8 +287,8 @@ def eval( # policy forward # ============================================================== policy_output = self._policy.forward(stack_obs, action_mask, to_play) - actions_no_env_id=defaultdict(dict) - for k,v in policy_output.items(): + actions_no_env_id = defaultdict(dict) + for k, v in policy_output.items(): for agent_id, act in enumerate(v['action']): actions_no_env_id[k][agent_id] = act # actions_no_env_id = {k: v['action'] for k, v in policy_output.items()} @@ -327,8 +332,8 @@ def eval( for agent_id in range(agent_num): game_segments[env_id][agent_id].append( - actions[env_id][agent_id], to_ndarray(obs['observation'][agent_id]), reward[agent_id], action_mask_dict[env_id][agent_id], - to_play_dict[env_id] + actions[env_id][agent_id], to_ndarray(obs['observation'][agent_id]), reward[agent_id], + action_mask_dict[env_id][agent_id], to_play_dict[env_id] ) # NOTE: in evaluator, we only need save the ``o_{t+1} = obs['observation']`` @@ -421,8 +426,8 @@ def eval( 'reward_min': np.min(episode_return), } # add eat info - for i in range(len(t.info['eats'])//2): - for k,v in t.info['eats'][i].items(): + for i in range(len(t.info['eats']) // 2): + for k, v in t.info['eats'][i].items(): info['agent_{}_{}'.format(i, k)] = v episode_info = eval_monitor.get_episode_info() @@ -479,12 +484,12 @@ def eval_vsbot( self._env.reset() self._policy.reset() - self._bot_policy = GoBiggerBot(env_nums, agent_id=[2,3]) #TODO only support t2p2 + self._bot_policy = GoBiggerBot(env_nums, agent_id=[2, 3]) #TODO only support t2p2 self._bot_policy.reset() # initializations init_obs = self._env.ready_obs - agent_num = len(init_obs[0]['action_mask'])//2 #TODO only support t2p2 + agent_num = len(init_obs[0]['action_mask']) // 2 #TODO only support t2p2 retry_waiting_time = 0.001 while 
len(init_obs.keys()) != self._env_num: @@ -509,16 +514,21 @@ def eval_vsbot( dones = np.array([False for _ in range(env_nums)]) game_segments = [ - [GameSegment( - self._env.action_space, - game_segment_length=self.policy_config.game_segment_length, - config=self.policy_config - ) for _ in range(agent_num)] for _ in range(env_nums) + [ + GameSegment( + self._env.action_space, + game_segment_length=self.policy_config.game_segment_length, + config=self.policy_config + ) for _ in range(agent_num) + ] for _ in range(env_nums) ] for env_id in range(env_nums): for agent_id in range(agent_num): game_segments[env_id][agent_id].reset( - [to_ndarray(init_obs[env_id]['observation'][agent_id]) for _ in range(self.policy_config.model.frame_stack_num)] + [ + to_ndarray(init_obs[env_id]['observation'][agent_id]) + for _ in range(self.policy_config.model.frame_stack_num) + ] ) ready_env_id = set() @@ -529,7 +539,7 @@ def eval_vsbot( while not eval_monitor.is_finished(): # Get current ready env obs. obs = self._env.ready_obs - raw_obs = [v['raw_obs'] for k,v in obs.items()] + raw_obs = [v['raw_obs'] for k, v in obs.items()] new_available_env_id = set(obs.keys()).difference(ready_env_id) ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode])) remain_episode -= min(len(new_available_env_id), remain_episode) @@ -559,8 +569,8 @@ def eval_vsbot( # policy forward # ============================================================== policy_output = self._policy.forward(stack_obs, action_mask, to_play) - actions_no_env_id=defaultdict(dict) - for k,v in policy_output.items(): + actions_no_env_id = defaultdict(dict) + for k, v in policy_output.items(): for agent_id, act in enumerate(v['action']): actions_no_env_id[k][agent_id] = act # actions_no_env_id = {k: v['action'] for k, v in policy_output.items()} @@ -607,8 +617,8 @@ def eval_vsbot( for agent_id in range(agent_num): game_segments[env_id][agent_id].append( - actions[env_id][agent_id], to_ndarray(obs['observation'][agent_id]), reward[agent_id], action_mask_dict[env_id][agent_id], - to_play_dict[env_id] + actions[env_id][agent_id], to_ndarray(obs['observation'][agent_id]), reward[agent_id], + action_mask_dict[env_id][agent_id], to_play_dict[env_id] ) # NOTE: in evaluator, we only need save the ``o_{t+1} = obs['observation']`` @@ -710,12 +720,12 @@ def eval_vsbot( 'bot_reward_min': np.min(bot_episode_return), } # add eat info - for k,v in eat_info.items(): + for k, v in eat_info.items(): for i in range(len(v)): for k1, v1 in v[i].items(): info['agent_{}_{}'.format(i, k1)] = info.get('agent_{}_{}'.format(i, k1), []) + [v1] - - for k,v in info.items(): + + for k, v in info.items(): if 'agent' in k: info[k] = np.mean(v) @@ -770,4 +780,4 @@ def update_bot_reward(self, env_id: int, reward: Any) -> None: """ if isinstance(reward, torch.Tensor): reward = reward.item() - self._bot_reward[env_id].append(reward) \ No newline at end of file + self._bot_reward[env_id].append(reward) diff --git a/zoo/gobigger/config/gobigger_efficientzero_config.py b/zoo/gobigger/config/gobigger_efficientzero_config.py index 56ef9dc04..8fb60fce6 100644 --- a/zoo/gobigger/config/gobigger_efficientzero_config.py +++ b/zoo/gobigger/config/gobigger_efficientzero_config.py @@ -5,7 +5,7 @@ # ============================================================== # begin of the most frequently changed config specified by the user # ============================================================== -seed=0 +seed = 0 collector_env_num = 32 n_episode = 32 evaluator_env_num = 5 @@ -14,9 +14,9 
@@ batch_size = 256 reanalyze_ratio = 0. action_space_size = 27 -direction_num=12 -eps_greedy_exploration_in_collect=True -random_collect_episode_num=0 +direction_num = 12 +eps_greedy_exploration_in_collect = True +random_collect_episode_num = 0 # ============================================================== # end of the most frequently changed config specified by the user # ============================================================== @@ -37,30 +37,26 @@ use_action_mask=False, reward_div_value=0.1, reward_type='log_reward', - contain_raw_obs=False, # False on collect mode, True on eval vsbot mode, because bot need raw obs + contain_raw_obs=False, # False on collect mode, True on eval vsbot mode, because bot need raw obs start_spirit_progress=0.2, end_spirit_progress=0.8, manager_settings=dict( - food_manager=dict( - num_init=260, - num_min=260, - num_max=300, - ), - thorns_manager=dict( - num_init=3, - num_min=3, - num_max=4, - ), - player_manager=dict( - ball_settings=dict( - score_init=13000, - ), - ), + food_manager=dict( + num_init=260, + num_min=260, + num_max=300, + ), + thorns_manager=dict( + num_init=3, + num_min=3, + num_max=4, + ), + player_manager=dict(ball_settings=dict(score_init=13000, ), ), ), playback_settings=dict( playback_type='by_frame', by_frame=dict( - save_frame=False, # when training should set as False + save_frame=False, # when training should set as False save_dir='./', save_name_prefix='gobigger', ), @@ -77,7 +73,7 @@ action_space_size=action_space_size, downsample=True, discrete_action_encoding_type='one_hot', - norm_type='BN', + norm_type='BN', ), cuda=True, mcts_ctree=True, @@ -105,25 +101,15 @@ collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, ), - learn=dict( - learner=dict( - log_policy=True, - hook=dict( - log_show_after_iter=10, - ), - ), - ), - collect=dict( - collector=dict( - collect_print_freq=10, - ), - ), - eval=dict( - evaluator=dict( - eval_freq=5000, - stop_value=10000000000, - ), - ), + learn=dict(learner=dict( + log_policy=True, + hook=dict(log_show_after_iter=10, ), + ), ), + collect=dict(collector=dict(collect_print_freq=10, ), ), + eval=dict(evaluator=dict( + eval_freq=5000, + stop_value=10000000000, + ), ), ) atari_efficientzero_config = EasyDict(atari_efficientzero_config) main_config = atari_efficientzero_config diff --git a/zoo/gobigger/config/gobigger_eval_config.py b/zoo/gobigger/config/gobigger_eval_config.py index aa94ea2fa..3fc62da27 100644 --- a/zoo/gobigger/config/gobigger_eval_config.py +++ b/zoo/gobigger/config/gobigger_eval_config.py @@ -16,11 +16,11 @@ returns_mean_seeds = [] returns_seeds = [] seeds = [0] - create_config.env_manager.type = 'base' # when visualize must set as base - main_config.env.evaluator_env_num = 1 # when visualize must set as 1 - main_config.env.n_evaluator_episode = 2 # each seed eval episodes num - main_config.env.playback_settings.by_frame.save_frame=True - main_config.env.playback_settings.by_frame.save_name_prefix= 'gobigger' + create_config.env_manager.type = 'base' # when visualize must set as base + main_config.env.evaluator_env_num = 1 # when visualize must set as 1 + main_config.env.n_evaluator_episode = 2 # each seed eval episodes num + main_config.env.playback_settings.by_frame.save_frame = True + main_config.env.playback_settings.by_frame.save_name_prefix = 'gobigger' for seed in seeds: returns_selfplay_mean, returns_vsbot_mean = eval_muzero_gobigger( diff --git a/zoo/gobigger/config/gobigger_muzero_config.py b/zoo/gobigger/config/gobigger_muzero_config.py index 
5cfcd9d66..40c447665 100644 --- a/zoo/gobigger/config/gobigger_muzero_config.py +++ b/zoo/gobigger/config/gobigger_muzero_config.py @@ -13,14 +13,13 @@ batch_size = 256 reanalyze_ratio = 0. action_space_size = 27 -direction_num=12 +direction_num = 12 # ============================================================== # end of the most frequently changed config specified by the user # ============================================================== atari_muzero_config = dict( - exp_name= - f'data_mz_ctree/{env_name}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', + exp_name=f'data_mz_ctree/{env_name}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', env=dict( env_name=env_name, team_num=2, @@ -38,21 +37,17 @@ start_spirit_progress=0.2, end_spirit_progress=0.8, manager_settings=dict( - food_manager=dict( - num_init=260, - num_min=260, - num_max=300, - ), - thorns_manager=dict( - num_init=3, - num_min=3, - num_max=4, - ), - player_manager=dict( - ball_settings=dict( - score_init=13000, - ), - ), + food_manager=dict( + num_init=260, + num_min=260, + num_max=300, + ), + thorns_manager=dict( + num_init=3, + num_min=3, + num_max=4, + ), + player_manager=dict(ball_settings=dict(score_init=13000, ), ), ), playback_settings=dict( playback_type='by_frame', @@ -76,7 +71,7 @@ downsample=True, self_supervised_learning_loss=False, # default is False discrete_action_encoding_type='one_hot', - norm_type='BN', + norm_type='BN', ), cuda=True, mcts_ctree=True, @@ -97,25 +92,15 @@ collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, ), - learn=dict( - learner=dict( - log_policy=True, - hook=dict( - log_show_after_iter=10, - ), - ), - ), - collect=dict( - collector=dict( - collect_print_freq=10, - ), - ), - eval=dict( - evaluator=dict( - eval_freq=5000, - stop_value=10000000000, - ), - ), + learn=dict(learner=dict( + log_policy=True, + hook=dict(log_show_after_iter=10, ), + ), ), + collect=dict(collector=dict(collect_print_freq=10, ), ), + eval=dict(evaluator=dict( + eval_freq=5000, + stop_value=10000000000, + ), ), ) atari_muzero_config = EasyDict(atari_muzero_config) main_config = atari_muzero_config diff --git a/zoo/gobigger/config/gobigger_sampled_efficientzero_config.py b/zoo/gobigger/config/gobigger_sampled_efficientzero_config.py index 9db353d2d..97d17159e 100644 --- a/zoo/gobigger/config/gobigger_sampled_efficientzero_config.py +++ b/zoo/gobigger/config/gobigger_sampled_efficientzero_config.py @@ -15,7 +15,7 @@ batch_size = 256 reanalyze_ratio = 0. 
action_space_size = 27 -direction_num=12 +direction_num = 12 # ============================================================== # end of the most frequently changed config specified by the user # ============================================================== @@ -36,25 +36,21 @@ use_action_mask=False, reward_div_value=0.1, reward_type='log_reward', - contain_raw_obs=False, # False on collect mode, True on eval vsbot mode, because bot need raw obs + contain_raw_obs=False, # False on collect mode, True on eval vsbot mode, because bot need raw obs start_spirit_progress=0.2, end_spirit_progress=0.8, manager_settings=dict( - food_manager=dict( - num_init=260, - num_min=260, - num_max=300, - ), - thorns_manager=dict( - num_init=3, - num_min=3, - num_max=4, - ), - player_manager=dict( - ball_settings=dict( - score_init=13000, - ), - ), + food_manager=dict( + num_init=260, + num_min=260, + num_max=300, + ), + thorns_manager=dict( + num_init=3, + num_min=3, + num_max=4, + ), + player_manager=dict(ball_settings=dict(score_init=13000, ), ), ), playback_settings=dict( playback_type='by_frame', @@ -79,7 +75,7 @@ continuous_action_space=continuous_action_space, num_of_sampled_actions=K, discrete_action_encoding_type='one_hot', - norm_type='BN', + norm_type='BN', ), cuda=True, mcts_ctree=True, @@ -100,25 +96,15 @@ collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, ), - learn=dict( - learner=dict( - log_policy=True, - hook=dict( - log_show_after_iter=10, - ), - ), - ), - collect=dict( - collector=dict( - collect_print_freq=10, - ), - ), - eval=dict( - evaluator=dict( - eval_freq=5000, - stop_value=10000000000, - ), - ), + learn=dict(learner=dict( + log_policy=True, + hook=dict(log_show_after_iter=10, ), + ), ), + collect=dict(collector=dict(collect_print_freq=10, ), ), + eval=dict(evaluator=dict( + eval_freq=5000, + stop_value=10000000000, + ), ), ) atari_sampled_efficientzero_config = EasyDict(atari_sampled_efficientzero_config) main_config = atari_sampled_efficientzero_config diff --git a/zoo/gobigger/env/gobigger_env.py b/zoo/gobigger/env/gobigger_env.py index daa7abc3d..46a8b9d98 100644 --- a/zoo/gobigger/env/gobigger_env.py +++ b/zoo/gobigger/env/gobigger_env.py @@ -21,12 +21,12 @@ def __init__(self, cfg: dict) -> None: self.player_num_per_team = self._cfg.player_num_per_team self.direction_num = self._cfg.direction_num self.use_action_mask = self._cfg.use_action_mask - self.action_space_size = self._cfg.action_space_size # discrete action space size + self.action_space_size = self._cfg.action_space_size # discrete action space size self.step_mul = self._cfg.get('step_mul', 8) self.setup_action() self.setup_feature() - self.contain_raw_obs = self._cfg.contain_raw_obs # for save memory - + self.contain_raw_obs = self._cfg.contain_raw_obs # for save memory + def setup_feature(self): self.second_per_frame = 0.05 self.spatial_x = 64 @@ -40,27 +40,38 @@ def setup_feature(self): self.player_init_score = self._cfg.manager_settings.player_manager.ball_settings.score_init self.start_spirit_progress = self._cfg.start_spirit_progress self.end_spirit_progress = self._cfg.end_spirit_progress - + def reset(self) -> np.ndarray: if not self._init_flag: self._env = GoBiggerEnv(self._cfg, step_mul=self.step_mul) self._init_flag = True - self.last_action_types = {player_id: self.direction_num * 2 for player_id in range(self.player_num_per_team * self.team_num)} + self.last_action_types = { + player_id: self.direction_num * 2 + for player_id in range(self.player_num_per_team * self.team_num) + } 
raw_obs = self._env.reset() obs = self.observation(raw_obs) - self.last_action_types = {player_id: self.direction_num * 2 for player_id in - range(self.player_num_per_team * self.team_num)} - self.last_leaderboard = {team_idx: self.player_init_score * self.player_num_per_team for team_idx in range(self.team_num)} - self.last_player_scores = {player_id: self.player_init_score for player_id in range(self.player_num_per_team * self.team_num)} + self.last_action_types = { + player_id: self.direction_num * 2 + for player_id in range(self.player_num_per_team * self.team_num) + } + self.last_leaderboard = { + team_idx: self.player_init_score * self.player_num_per_team + for team_idx in range(self.team_num) + } + self.last_player_scores = { + player_id: self.player_init_score + for player_id in range(self.player_num_per_team * self.team_num) + } return obs - + def observation(self, raw_obs): obs = self.preprocess_obs(raw_obs) # for alignment with other environments, reverse the action mask - action_mask = [np.logical_not(o['action_mask']) for o in obs] - to_play = [ -1 for _ in range(len(obs))] # Moot, for alignment with other environments + action_mask = [np.logical_not(o['action_mask']) for o in obs] + to_play = [-1 for _ in range(len(obs))] # Moot, for alignment with other environments if self.contain_raw_obs: - obs = {'observation': obs, 'action_mask': action_mask, 'to_play': to_play, 'raw_obs':raw_obs} + obs = {'observation': obs, 'action_mask': action_mask, 'to_play': to_play, 'raw_obs': raw_obs} else: obs = {'observation': obs, 'action_mask': action_mask, 'to_play': to_play} return obs @@ -85,17 +96,17 @@ def step(self, action_dict: dict) -> BaseEnvTimestep: if done: info['eval_episode_return'] = [raw_obs[0]['leaderboard'][i] for i in range(self.team_num)] return BaseEnvTimestep(obs, rew, done, info) - + def seed(self, seed: int, dynamic_seed: bool = True) -> None: self._seed = seed self._dynamic_seed = dynamic_seed np.random.seed(self._seed) - + def close(self) -> None: if self._init_flag: self._env.close() self._init_flag = False - + @property def observation_space(self) -> gym.spaces.Space: return self._observation_space @@ -107,11 +118,17 @@ def action_space(self) -> gym.spaces.Space: @property def reward_space(self) -> gym.spaces.Space: return self._reward_space - + def __repr__(self) -> str: return "LightZero Env({})".format(self.cfg.env_name) - - def transform_obs(self, obs, own_player_id=1, padding=True, last_action_type=None, ): + + def transform_obs( + self, + obs, + own_player_id=1, + padding=True, + last_action_type=None, + ): global_state, player_observations = obs player2team = self.get_player2team() leaderboard = global_state['leaderboard'] @@ -125,8 +142,10 @@ def transform_obs(self, obs, own_player_id=1, padding=True, last_action_type=Non # =========== scene_size = global_state['border'][0] own_left_top_x, own_left_top_y, own_right_bottom_x, own_right_bottom_y = own_player_obs['rectangle'] - own_view_center = [(own_left_top_x + own_right_bottom_x - scene_size) / 2, - (own_left_top_y + own_right_bottom_y - scene_size) / 2] + own_view_center = [ + (own_left_top_x + own_right_bottom_x - scene_size) / 2, + (own_left_top_y + own_right_bottom_y - scene_size) / 2 + ] # own_view_width == own_view_height own_view_width = float(own_right_bottom_x - own_left_top_x) @@ -139,8 +158,10 @@ def transform_obs(self, obs, own_player_id=1, padding=True, last_action_type=Non 'view_y': np.round(np.array(own_view_center[1])).astype(np.int64), 'view_width': 
np.round(np.array(own_view_width)).astype(np.int64), 'score': np.clip(np.round(np.log(np.array(own_score) / 10)).astype(np.int64), a_min=None, a_max=9), - 'team_score': np.clip(np.round(np.log(np.array(own_team_score / 10))).astype(np.int64), a_min=None, a_max=9), - 'time': np.array(global_state['last_time']//20, dtype=np.int64), + 'team_score': np.clip( + np.round(np.log(np.array(own_team_score / 10))).astype(np.int64), a_min=None, a_max=9 + ), + 'time': np.array(global_state['last_time'] // 20, dtype=np.int64), 'rank': np.array(own_rank, dtype=np.int64), 'last_action_type': np.array(last_action_type, dtype=np.int64) } @@ -166,16 +187,17 @@ def transform_obs(self, obs, own_player_id=1, padding=True, last_action_type=Non game_player_view_x = (game_player_right_bottom_x + game_player_left_top_x - scene_size) / 2 game_player_view_y = (game_player_right_bottom_y + game_player_left_top_y - scene_size) / 2 - all_players.append([alliance, - game_player_view_x, - game_player_view_y, - ]) + all_players.append([ + alliance, + game_player_view_x, + game_player_view_y, + ]) all_players = np.array(all_players) player_padding_num = self.max_player_num - len(all_players) player_num = len(all_players) if player_padding_num < 0: - all_players = all_players[:self.max_player_num,:] + all_players = all_players[:self.max_player_num, :] else: all_players = np.pad(all_players, pad_width=((0, player_padding_num), (0, 0)), mode='constant') team_info = { @@ -198,24 +220,37 @@ def transform_obs(self, obs, own_player_id=1, padding=True, last_action_type=Non neutral_player_id = self.team_num * self.player_num_per_team neutral_team_rank = self.team_num - # clone = [type, score, player_id, team_id, team_rank, x, y, next_x, next_y] - clone = [[ball_type_map['clone'], bl[3], bl[-2], bl[-1], team2rank[bl[-1]], bl[0], bl[1], - *self.next_position(bl[0], bl[1], bl[4], bl[5])] for bl in clone] - + # clone = [type, score, player_id, team_id, team_rank, x, y, next_x, next_y] + clone = [ + [ + ball_type_map['clone'], bl[3], bl[-2], bl[-1], team2rank[bl[-1]], bl[0], bl[1], + *self.next_position(bl[0], bl[1], bl[4], bl[5]) + ] for bl in clone + ] + # thorn = [type, score, player_id, team_id, team_rank, x, y, next_x, next_y] - thorns = [[ball_type_map['thorns'], bl[3], neutral_player_id, neutral_team_id, neutral_team_rank, bl[0], bl[1], - *self.next_position(bl[0], bl[1], bl[4], bl[5])] for bl in thorns] - + thorns = [ + [ + ball_type_map['thorns'], bl[3], neutral_player_id, neutral_team_id, neutral_team_rank, bl[0], bl[1], + *self.next_position(bl[0], bl[1], bl[4], bl[5]) + ] for bl in thorns + ] + # thorn = [type, score, player_id, team_id, team_rank, x, y, next_x, next_y] food = [ - [ball_type_map['food'], bl[3], neutral_player_id, neutral_team_id, neutral_team_rank, bl[0], bl[1], bl[0], - bl[1]] for bl in food] + [ + ball_type_map['food'], bl[3], neutral_player_id, neutral_team_id, neutral_team_rank, bl[0], bl[1], + bl[0], bl[1] + ] for bl in food + ] # spore = [type, score, player_id, team_id, team_rank, x, y, next_x, next_y] spore = [ - [ball_type_map['spore'], bl[3], bl[-1], player2team[bl[-1]], team2rank[player2team[bl[-1]]], bl[0], - bl[1], - *self.next_position(bl[0], bl[1], bl[4], bl[5])] for bl in spore] + [ + ball_type_map['spore'], bl[3], bl[-1], player2team[bl[-1]], team2rank[player2team[bl[-1]]], bl[0], + bl[1], *self.next_position(bl[0], bl[1], bl[4], bl[5]) + ] for bl in spore + ] all_balls = clone + thorns + food + spore @@ -238,7 +273,9 @@ def transform_obs(self, obs, own_player_id=1, padding=True, 
last_action_type=Non all_balls[:, -1] = ((all_balls[:, -1] - origin_y) / own_view_width * self.spatial_y) # ball - ball_indices = np.logical_and(all_balls[:, 0] != 2, all_balls[:, 0] != 4) # include player balls and thorn balls + ball_indices = np.logical_and( + all_balls[:, 0] != 2, all_balls[:, 0] != 4 + ) # include player balls and thorn balls balls = all_balls[ball_indices] balls_num = len(balls) @@ -286,7 +323,12 @@ def transform_obs(self, obs, own_player_id=1, padding=True, last_action_type=Non ## score&radius scale_score = balls[:, 1] / 100 radius = np.clip(np.sqrt(scale_score * 0.042 + 0.15) / own_view_width, a_max=1, a_min=None) - score = np.clip(np.round(np.clip(np.sqrt(scale_score * 0.042 + 0.15) / own_view_width, a_max=1, a_min=None)*50).astype(int), a_max=49, a_min=None) + score = np.clip( + np.round(np.clip(np.sqrt(scale_score * 0.042 + 0.15) / own_view_width, a_max=1, a_min=None) * 50 + ).astype(int), + a_max=49, + a_min=None + ) ## rank: ball_rank = balls[:, 4] @@ -363,7 +405,7 @@ def transform_obs(self, obs, own_player_id=1, padding=True, last_action_type=Non def preprocess_obs(self, raw_obs): env_player_obs = [] - for game_player_id in range(self.player_num_per_team*self.team_num): + for game_player_id in range(self.player_num_per_team * self.team_num): last_action_type = self.last_action_types[game_player_id] if self.use_action_mask: can_eject = raw_obs[1][game_player_id]['can_eject'] @@ -371,36 +413,38 @@ def preprocess_obs(self, raw_obs): action_mask = self.generate_action_mask(can_eject=can_eject, can_split=can_split) else: action_mask = self.generate_action_mask(can_eject=True, can_split=True) - game_player_obs = self.transform_obs(raw_obs, own_player_id=game_player_id, padding=True, last_action_type=last_action_type) + game_player_obs = self.transform_obs( + raw_obs, own_player_id=game_player_id, padding=True, last_action_type=last_action_type + ) game_player_obs['action_mask'] = action_mask env_player_obs.append(game_player_obs) return env_player_obs - + def generate_action_mask(self, can_eject, can_split): # action mask # 1 represent can not do this action # 0 represent can do this action - action_mask = np.zeros((self.action_space_size,), dtype=np.bool_) + action_mask = np.zeros((self.action_space_size, ), dtype=np.bool_) if not can_eject: action_mask[self.direction_num * 2 + 1] = True if not can_split: action_mask[self.direction_num * 2 + 2] = True return action_mask - + def get_player2team(self, ): player2team = {} for player_id in range(self.player_num_per_team * self.team_num): player2team[player_id] = player_id // self.player_num_per_team return player2team - + def next_position(self, x, y, vel_x, vel_y): next_x = x + self.second_per_frame * vel_x * self.step_mul next_y = y + self.second_per_frame * vel_y * self.step_mul return next_x, next_y - + def transform_action(self, action_idx): return self.x_y_action_List[int(action_idx)] - + def setup_action(self): theta = math.pi * 2 / self.direction_num self.x_y_action_List = [[0.3 * math.cos(theta * i), 0.3 * math.sin(theta * i), 0] for i in range(self.direction_num)] + \ @@ -415,7 +459,7 @@ def get_spirit(self, progress): return spirit else: return 1 - + def transform_reward(self, next_obs): last_time = next_obs[0]['last_time'] total_frame = next_obs[0]['total_frame'] @@ -440,11 +484,11 @@ def transform_reward(self, next_obs): score_rewards_list.append(score_reward) elif self.reward_type == 'sqrt_player': player_reward = player_score - self.last_player_scores[game_player_id] - reward_sign = (player_reward > 0) - 
(player_reward < 0) # np.sign + reward_sign = (player_reward > 0) - (player_reward < 0) # np.sign score_rewards_list.append(reward_sign * math.sqrt(abs(player_reward)) / 2) elif self.reward_type == 'sqrt_team': team_reward = team_score - self.last_leaderboard[game_team_id] - reward_sign = (team_reward > 0) - (team_reward < 0) # np.sign + reward_sign = (team_reward > 0) - (team_reward < 0) # np.sign score_rewards_list.append(reward_sign * math.sqrt(abs(team_reward)) / 2) else: raise NotImplementedError diff --git a/zoo/gobigger/env/gobigger_rule_bot.py b/zoo/gobigger/env/gobigger_rule_bot.py index d7a16b87f..c8519a9ad 100644 --- a/zoo/gobigger/env/gobigger_rule_bot.py +++ b/zoo/gobigger/env/gobigger_rule_bot.py @@ -18,7 +18,7 @@ def __init__(self, env_num, agent_id: List[int]): self.env_num = env_num self.agent_id = agent_id self.bot = [[BotAgent(i) for i in self.agent_id] for _ in range(self.env_num)] - + def forward(self, raw_obs): action = defaultdict(dict) for env_id in range(self.env_num): @@ -64,6 +64,7 @@ def default_model(self) -> Tuple[str, List[str]]: def _monitor_vars_learn(self) -> List[str]: pass + class BotAgent(): def __init__(self, game_player_id): @@ -122,7 +123,7 @@ def step(self, obs): action_type = 2 else: action_type = 0 - if direction.length()>0: + if direction.length() > 0: direction = direction.normalize() else: direction = Vector2(1, 1).normalize() @@ -169,11 +170,11 @@ def process_food_balls(self, food_balls, my_max_clone_ball): def preprocess(self, overlap): new_overlap = {} for k, v in overlap.items(): - if k =='clone': + if k == 'clone': new_overlap[k] = [] for index, vv in enumerate(v): - tmp={} - tmp['position'] = Vector2(vv[0],vv[1]) + tmp = {} + tmp['position'] = Vector2(vv[0], vv[1]) tmp['radius'] = vv[2] tmp['player'] = int(vv[-2]) tmp['team'] = int(vv[-1]) @@ -181,8 +182,8 @@ def preprocess(self, overlap): else: new_overlap[k] = [] for index, vv in enumerate(v): - tmp={} - tmp['position'] = Vector2(vv[0],vv[1]) + tmp = {} + tmp['position'] = Vector2(vv[0], vv[1]) tmp['radius'] = vv[2] new_overlap[k].append(tmp) return new_overlap @@ -195,17 +196,19 @@ def preprocess_tuple2vector(self, overlap): new_overlap[k].append(vv) new_overlap[k][index]['position'] = Vector2(*vv['position']) return new_overlap - + def add_noise_to_direction(self, direction, noise_ratio=0.1): - direction = direction + Vector2(((random.random() * 2 - 1)*noise_ratio)*direction.x, - ((random.random() * 2 - 1)*noise_ratio)*direction.y) + direction = direction + Vector2( + ((random.random() * 2 - 1) * noise_ratio) * direction.x, + ((random.random() * 2 - 1) * noise_ratio) * direction.y + ) return direction def radius_to_score(self, radius): - return (math.pow(radius,2) - 0.15) / 0.042 * 100 - + return (math.pow(radius, 2) - 0.15) / 0.042 * 100 + def can_eat(self, radius1, radius2): return self.radius_to_score(radius1) > 1.3 * self.radius_to_score(radius2) - - def reset(self,): - self.actions_queue.queue.clear() \ No newline at end of file + + def reset(self, ): + self.actions_queue.queue.clear() diff --git a/zoo/gobigger/env/test_gobbiger_env.py b/zoo/gobigger/env/test_gobbiger_env.py index d933c38c9..416bf4c43 100644 --- a/zoo/gobigger/env/test_gobbiger_env.py +++ b/zoo/gobigger/env/test_gobbiger_env.py @@ -1,27 +1,26 @@ - import pytest from easydict import EasyDict from gobigger_env import GoBiggerLightZeroEnv from gobigger_rule_bot import BotAgent - -env_cfg=EasyDict(dict( - env_name='gobigger', - team_num=2, - player_num_per_team=2, - direction_num=12, - step_mul=8, - map_width=64, - 
map_height=64, - frame_limit=3600, - action_space_size=27, - use_action_mask=False, - reward_div_value=0.1, - reward_type='log_reward', - contain_raw_obs=True, # False on collect mode, True on eval vsbot mode, because bot need raw obs - start_spirit_progress=0.2, - end_spirit_progress=0.8, - manager_settings=dict( +env_cfg = EasyDict( + dict( + env_name='gobigger', + team_num=2, + player_num_per_team=2, + direction_num=12, + step_mul=8, + map_width=64, + map_height=64, + frame_limit=3600, + action_space_size=27, + use_action_mask=False, + reward_div_value=0.1, + reward_type='log_reward', + contain_raw_obs=True, # False on collect mode, True on eval vsbot mode, because bot need raw obs + start_spirit_progress=0.2, + end_spirit_progress=0.8, + manager_settings=dict( food_manager=dict( num_init=260, num_min=260, @@ -32,22 +31,20 @@ num_min=3, num_max=4, ), - player_manager=dict( - ball_settings=dict( - score_init=13000, - ), + player_manager=dict(ball_settings=dict(score_init=13000, ), ), + ), + playback_settings=dict( + playback_type='by_frame', + by_frame=dict( + # save_frame=False, + save_frame=True, + save_dir='./', + save_name_prefix='test', ), - ), - playback_settings=dict( - playback_type='by_frame', - by_frame=dict( - # save_frame=False, - save_frame=True, - save_dir='./', - save_name_prefix='test', ), - ), -)) + ) +) + @pytest.mark.envtest class TestGoBiggerLightZeroEnv: From 6da29975fd9c815275b852e2b5083647b0c48245 Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Sun, 25 Jun 2023 15:34:41 +0800 Subject: [PATCH 22/54] polish(yzj): polish code comments about gobigger in worker/policy/entry --- lzero/entry/eval_muzero_gobigger.py | 2 +- lzero/entry/train_muzero_gobigger.py | 2 +- lzero/policy/gobigger_efficientzero.py | 2 +- lzero/policy/gobigger_muzero.py | 2 +- lzero/policy/gobigger_random_policy.py | 2 +- lzero/policy/gobigger_sampled_efficientzero.py | 2 +- lzero/worker/gobigger_muzero_collector.py | 5 ++--- lzero/worker/gobigger_muzero_evaluator.py | 2 +- zoo/gobigger/env/gobigger_env.py | 3 +++ zoo/gobigger/env/gobigger_rule_bot.py | 1 + 10 files changed, 13 insertions(+), 10 deletions(-) diff --git a/lzero/entry/eval_muzero_gobigger.py b/lzero/entry/eval_muzero_gobigger.py index dee5a98b5..8c463246e 100644 --- a/lzero/entry/eval_muzero_gobigger.py +++ b/lzero/entry/eval_muzero_gobigger.py @@ -24,7 +24,7 @@ def eval_muzero_gobigger( ) -> 'Policy': # noqa """ Overview: - The train entry for MCTS+RL algorithms, including MuZero, EfficientZero, Sampled EfficientZero. + The eval entry for GoBigger MCTS+RL algorithms, including MuZero, EfficientZero, Sampled EfficientZero. Arguments: - input_cfg (:obj:`Tuple[dict, dict]`): Config in dict type. ``Tuple[dict, dict]`` type means [user_config, create_cfg]. diff --git a/lzero/entry/train_muzero_gobigger.py b/lzero/entry/train_muzero_gobigger.py index 2abdaf2f5..e06a7aab5 100644 --- a/lzero/entry/train_muzero_gobigger.py +++ b/lzero/entry/train_muzero_gobigger.py @@ -28,7 +28,7 @@ def train_muzero_gobigger( ) -> 'Policy': # noqa """ Overview: - The train entry for MCTS+RL algorithms, including MuZero, EfficientZero, Sampled EfficientZero. + The train entry for GoBigger MCTS+RL algorithms, including MuZero, EfficientZero, Sampled EfficientZero. Arguments: - input_cfg (:obj:`Tuple[dict, dict]`): Config in dict type. ``Tuple[dict, dict]`` type means [user_config, create_cfg]. 
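As an aside (not part of the patch), both GoBigger entries are launched the same way as the other LightZero entries: pass the `[main_config, create_config]` pair from a zoo config together with a seed. A minimal launch sketch mirroring the `__main__` blocks of the GoBigger zoo configs (the particular config module picked here is illustrative):

from lzero.entry import train_muzero_gobigger
from zoo.gobigger.config.gobigger_efficientzero_config import main_config, create_config

if __name__ == "__main__":
    # Trains GoBigger EfficientZero; eval_muzero_gobigger takes the same config
    # pair and reports self-play and vs-bot returns (see gobigger_eval_config.py).
    train_muzero_gobigger([main_config, create_config], seed=0)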
diff --git a/lzero/policy/gobigger_efficientzero.py b/lzero/policy/gobigger_efficientzero.py index 3234fdfb4..f2f89d96d 100644 --- a/lzero/policy/gobigger_efficientzero.py +++ b/lzero/policy/gobigger_efficientzero.py @@ -25,7 +25,7 @@ class GoBiggerEfficientZeroPolicy(Policy): """ Overview: - The policy class for EfficientZero. + The policy class for GoBiggerEfficientZero. """ # The default_config for EfficientZero policy. diff --git a/lzero/policy/gobigger_muzero.py b/lzero/policy/gobigger_muzero.py index dbcec0531..e721bb291 100644 --- a/lzero/policy/gobigger_muzero.py +++ b/lzero/policy/gobigger_muzero.py @@ -24,7 +24,7 @@ class GoBiggerMuZeroPolicy(Policy): """ Overview: - The policy class for MuZero. + The policy class for GoBiggerMuZero. """ # The default_config for MuZero policy. diff --git a/lzero/policy/gobigger_random_policy.py b/lzero/policy/gobigger_random_policy.py index 15b53efbe..c10172894 100644 --- a/lzero/policy/gobigger_random_policy.py +++ b/lzero/policy/gobigger_random_policy.py @@ -25,7 +25,7 @@ class GoBiggerRandomPolicy(Policy): """ Overview: - The policy class for EfficientZero. + The policy class for GoBiggerRandom. """ # The default_config for EfficientZero policy. diff --git a/lzero/policy/gobigger_sampled_efficientzero.py b/lzero/policy/gobigger_sampled_efficientzero.py index 8bf5d2b10..f402091fc 100644 --- a/lzero/policy/gobigger_sampled_efficientzero.py +++ b/lzero/policy/gobigger_sampled_efficientzero.py @@ -26,7 +26,7 @@ class GoBiggerSampledEfficientZeroPolicy(Policy): """ Overview: - The policy class for Sampled EfficientZero. + The policy class for GoBigger Sampled EfficientZero. """ # The default_config for Sampled fEficientZero policy. diff --git a/lzero/worker/gobigger_muzero_collector.py b/lzero/worker/gobigger_muzero_collector.py index 840f26753..49c6cfbba 100644 --- a/lzero/worker/gobigger_muzero_collector.py +++ b/lzero/worker/gobigger_muzero_collector.py @@ -19,7 +19,8 @@ class GoBiggerMuZeroCollector(ISerialCollector): """ Overview: - The Collector for MCTS+RL algorithms, including MuZero, EfficientZero, Sampled EfficientZero. + The Collector for GoBigger MCTS+RL algorithms, including MuZero, EfficientZero, Sampled EfficientZero. + For GoBigger, add agent_num dim in game_segment. Interfaces: __init__, reset, reset_env, reset_policy, collect, close Property: @@ -447,8 +448,6 @@ def collect(self, ) else: for agent_id in range(agent_num): - if len(distributions_dict[env_id][agent_id]) != 27: - print('') game_segments[env_id][agent_id].store_search_stats( distributions_dict[env_id][agent_id], value_dict[env_id][agent_id] ) diff --git a/lzero/worker/gobigger_muzero_evaluator.py b/lzero/worker/gobigger_muzero_evaluator.py index d4634cda4..b57265203 100644 --- a/lzero/worker/gobigger_muzero_evaluator.py +++ b/lzero/worker/gobigger_muzero_evaluator.py @@ -21,7 +21,7 @@ class GoBiggerMuZeroEvaluator(ISerialEvaluator): """ Overview: - The Evaluator for MCTS+RL algorithms, including MuZero, EfficientZero, Sampled EfficientZero. + The Evaluator for GoBigger MCTS+RL algorithms, including MuZero, EfficientZero, Sampled EfficientZero. 
Interfaces: __init__, reset, reset_policy, reset_env, close, should_eval, eval Property: diff --git a/zoo/gobigger/env/gobigger_env.py b/zoo/gobigger/env/gobigger_env.py index 46a8b9d98..d996a7619 100644 --- a/zoo/gobigger/env/gobigger_env.py +++ b/zoo/gobigger/env/gobigger_env.py @@ -109,14 +109,17 @@ def close(self) -> None: @property def observation_space(self) -> gym.spaces.Space: + # The following ensures compatibility with the DI-engine Env class. return self._observation_space @property def action_space(self) -> gym.spaces.Space: + # The following ensures compatibility with the DI-engine Env class. return self._action_space @property def reward_space(self) -> gym.spaces.Space: + # The following ensures compatibility with the DI-engine Env class. return self._reward_space def __repr__(self) -> str: diff --git a/zoo/gobigger/env/gobigger_rule_bot.py b/zoo/gobigger/env/gobigger_rule_bot.py index c8519a9ad..2bc8894f8 100644 --- a/zoo/gobigger/env/gobigger_rule_bot.py +++ b/zoo/gobigger/env/gobigger_rule_bot.py @@ -34,6 +34,7 @@ def reset(self, env_id_lst=None): for agent in self.bot[env_id]: agent.reset() + # The following ensures compatibility with the DI-engine Policy class. def _init_learn(self) -> None: pass From a2ca5ee9b9cde2b83eeb5cde66d899792666e83b Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Sun, 25 Jun 2023 16:12:55 +0800 Subject: [PATCH 23/54] feature(yzj): add eps_greedy and random_collect_episode in gobigger mz --- lzero/policy/gobigger_muzero.py | 24 ++++++++++++------- zoo/gobigger/config/gobigger_muzero_config.py | 15 ++++++++++-- 2 files changed, 29 insertions(+), 10 deletions(-) diff --git a/lzero/policy/gobigger_muzero.py b/lzero/policy/gobigger_muzero.py index e721bb291..c06737ffd 100644 --- a/lzero/policy/gobigger_muzero.py +++ b/lzero/policy/gobigger_muzero.py @@ -440,6 +440,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in return { 'collect_mcts_temperature': self.collect_mcts_temperature, + 'collect_epsilon': self.collect_epsilon, 'cur_lr': self._optimizer.param_groups[0]['lr'], 'weighted_total_loss': loss_info[0], 'total_loss': loss_info[1], @@ -473,6 +474,7 @@ def _init_collect(self) -> None: else: self._mcts_collect = MCTSPtree(self._cfg) self.collect_mcts_temperature = 1 + self.collect_epsilon = 1 def _forward_collect( self, @@ -480,6 +482,7 @@ def _forward_collect( action_mask: list = None, temperature: float = 1, to_play: List = [-1], + epsilon: float = 0.25, ready_env_id=None ) -> Dict: """ @@ -507,6 +510,7 @@ def _forward_collect( """ self._collect_model.eval() self.collect_mcts_temperature = temperature + self.collect_epsilon = epsilon active_collect_env_num = len(data) data = to_tensor(data) data = sum(sum(data, []), []) @@ -555,14 +559,18 @@ def _forward_collect( for i in range(batch_size): distributions, value = roots_visit_count_distributions[i], roots_values[i] - # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents - # the index within the legal action set, rather than the index in the entire action set. - action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( - distributions, temperature=self.collect_mcts_temperature, deterministic=False - ) - # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the - # entire action set. 
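# Illustrative aside, not part of the patch: the eps-greedy collect rule introduced
# in the hunk below, written as a standalone sketch. `select_action` is the same
# visit-count sampler used in the surrounding code (it returns the index within the
# legal action set and the visit-count entropy); it is passed in here rather than
# imported, and `legal_actions_i` is the list of legal action ids for agent i.
import numpy as np

def choose_collect_action(distributions, legal_actions_i, temperature, epsilon,
                          eps_greedy, select_action):
    if eps_greedy:
        # Greedy pick from the visit counts, occasionally replaced by a uniformly
        # random legal action with probability epsilon (collect_epsilon).
        idx, _ = select_action(distributions, temperature=temperature, deterministic=True)
        action = legal_actions_i[idx]
        if np.random.rand() < epsilon:
            action = int(np.random.choice(legal_actions_i))
    else:
        # Pure visit-count sampling with the current temperature.
        idx, _ = select_action(distributions, temperature=temperature, deterministic=False)
        action = legal_actions_i[idx]
    return action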
- action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + if self._cfg.eps.eps_greedy_exploration_in_collect: + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=self.collect_mcts_temperature, deterministic=True + ) + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + if np.random.rand() < self.collect_epsilon: + action = np.random.choice(legal_actions[i]) + else: + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=self.collect_mcts_temperature, deterministic=False + ) + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] output[i // agent_num]['action'].append(action) output[i // agent_num]['distributions'].append(distributions) output[i // agent_num]['visit_count_distribution_entropy'].append(visit_count_distribution_entropy) diff --git a/zoo/gobigger/config/gobigger_muzero_config.py b/zoo/gobigger/config/gobigger_muzero_config.py index 40c447665..b16030f34 100644 --- a/zoo/gobigger/config/gobigger_muzero_config.py +++ b/zoo/gobigger/config/gobigger_muzero_config.py @@ -5,6 +5,7 @@ # ============================================================== # begin of the most frequently changed config specified by the user # ============================================================== +seed = 0 collector_env_num = 32 n_episode = 32 evaluator_env_num = 5 @@ -14,12 +15,14 @@ reanalyze_ratio = 0. action_space_size = 27 direction_num = 12 +eps_greedy_exploration_in_collect = True +random_collect_episode_num = 0 # ============================================================== # end of the most frequently changed config specified by the user # ============================================================== atari_muzero_config = dict( - exp_name=f'data_mz_ctree/{env_name}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', + exp_name=f'data_mz_ctree/{env_name}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed{seed}', env=dict( env_name=env_name, team_num=2, @@ -77,6 +80,14 @@ mcts_ctree=True, env_type='not_board_games', game_segment_length=400, + random_collect_episode_num=random_collect_episode_num, + eps=dict( + eps_greedy_exploration_in_collect=eps_greedy_exploration_in_collect, + type='exp', + start=1., + end=0.05, + decay=int(1.5e4), + ), use_augmentation=False, update_per_collect=update_per_collect, batch_size=batch_size, @@ -125,4 +136,4 @@ if __name__ == "__main__": from lzero.entry import train_muzero_gobigger - train_muzero_gobigger([main_config, create_config], seed=0) + train_muzero_gobigger([main_config, create_config], seed=seed) From 8c4c5a040c531ca007970b46c42876363d5fefde Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Wed, 28 Jun 2023 15:41:42 +0800 Subject: [PATCH 24/54] polish(yzj): polish entry/buffer/policy/config/model/env comments and codes, fix typo in test_gobigger_env, use numpy instead of pygame in gobigger bot --- lzero/entry/eval_muzero_gobigger.py | 3 +- lzero/entry/train_muzero_gobigger.py | 5 +- lzero/entry/utils.py | 11 + .../gobigger_game_buffer_efficientzero.py | 3 +- .../buffer/gobigger_game_buffer_muzero.py | 5 +- ...igger_game_buffer_sampled_efficientzero.py | 2 +- lzero/model/gobigger/network/encoder.py | 202 ++++++++++++++++-- lzero/policy/gobigger_random_policy.py | 76 +------ .../config/gobigger_efficientzero_config.py | 6 +- zoo/gobigger/config/gobigger_eval_config.py | 9 +- 
zoo/gobigger/config/gobigger_muzero_config.py | 5 - .../gobigger_sampled_efficientzero_config.py | 5 - zoo/gobigger/env/gobigger_env.py | 8 +- zoo/gobigger/env/gobigger_rule_bot.py | 31 +-- ...t_gobbiger_env.py => test_gobigger_env.py} | 0 15 files changed, 240 insertions(+), 131 deletions(-) rename zoo/gobigger/env/{test_gobbiger_env.py => test_gobigger_env.py} (100%) diff --git a/lzero/entry/eval_muzero_gobigger.py b/lzero/entry/eval_muzero_gobigger.py index 8c463246e..d19a20986 100644 --- a/lzero/entry/eval_muzero_gobigger.py +++ b/lzero/entry/eval_muzero_gobigger.py @@ -34,7 +34,8 @@ def eval_muzero_gobigger( point to the ckpt file of the pretrained model, and an absolute path is recommended. In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``. Returns: - - policy (:obj:`Policy`): Converged policy. + - reward_sp (:obj:`List`): reward of self-play mode. + - reward_vsbot (:obj:`List`): reward of vsbot mode. """ cfg, create_cfg = input_cfg assert create_cfg.policy.type in ['gobigger_efficientzero', 'gobigger_muzero', 'gobigger_sampled_efficientzero'], \ diff --git a/lzero/entry/train_muzero_gobigger.py b/lzero/entry/train_muzero_gobigger.py index e06a7aab5..57d2d139d 100644 --- a/lzero/entry/train_muzero_gobigger.py +++ b/lzero/entry/train_muzero_gobigger.py @@ -44,9 +44,7 @@ def train_muzero_gobigger( """ cfg, create_cfg = input_cfg - assert create_cfg.policy.type in ['gobigger_efficientzero', 'gobigger_muzero', 'gobigger_sampled_efficientzero'], \ - "train_muzero entry now only support the following algo.: 'gobigger_efficientzero', 'gobigger_muzero', 'gobigger_sampled_efficientzero'" - + assert create_cfg.policy.type in ['gobigger_efficientzero', 'gobigger_muzero', 'gobigger_sampled_efficientzero'] if create_cfg.policy.type == 'gobigger_efficientzero': from lzero.mcts import GoBiggerEfficientZeroGameBuffer as GameBuffer elif create_cfg.policy.type == 'gobigger_muzero': @@ -93,7 +91,6 @@ def train_muzero_gobigger( # specific game buffer for MCTS+RL algorithms replay_buffer = GameBuffer(policy_config) collector = GoBiggerMuZeroCollector( - collect_print_freq=cfg.collect.collector.collect_print_freq, env=collector_env, policy=policy.collect_mode, tb_logger=tb_logger, diff --git a/lzero/entry/utils.py b/lzero/entry/utils.py index 97c838fdc..9c6e777df 100644 --- a/lzero/entry/utils.py +++ b/lzero/entry/utils.py @@ -17,6 +17,17 @@ def random_collect( replay_buffer: 'IBuffer', # noqa postprocess_data_fn: Optional[Callable] = None ) -> None: # noqa + """ + Overview: + Collect data by random policy. + Arguments: + - policy_cfg (:obj:`EasyDict`): The policy config. + - policy (:obj:`Policy`): The policy. + - collector (:obj:`ISerialCollector`): The collector. + - collector_env (:obj:`BaseEnvManager`): The collector env manager. + - replay_buffer (:obj:`IBuffer`): The replay buffer. + - postprocess_data_fn (:obj:`Optional[Callable]`): The postprocess function for the collected data. + """ assert policy_cfg.random_collect_episode_num > 0 random_policy = GoBiggerRandomPolicy(cfg=policy_cfg) diff --git a/lzero/mcts/buffer/gobigger_game_buffer_efficientzero.py b/lzero/mcts/buffer/gobigger_game_buffer_efficientzero.py index b8c76cac7..4f74999aa 100644 --- a/lzero/mcts/buffer/gobigger_game_buffer_efficientzero.py +++ b/lzero/mcts/buffer/gobigger_game_buffer_efficientzero.py @@ -16,7 +16,7 @@ class GoBiggerEfficientZeroGameBuffer(GoBiggerMuZeroGameBuffer): """ Overview: - The specific game buffer for EfficientZero policy. 
+ The specific game buffer for GoBigger EfficientZero policy. """ def __init__(self, cfg: dict): @@ -103,7 +103,6 @@ def _prepare_reward_value_context( - reward_value_context (:obj:`list`): value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, td_steps_list, action_mask_segment, to_play_segment """ - # zero_obs = game_segment_list[0].zero_obs() value_obs_list = [] # the value is valid or not (out of trajectory) value_mask = [] diff --git a/lzero/mcts/buffer/gobigger_game_buffer_muzero.py b/lzero/mcts/buffer/gobigger_game_buffer_muzero.py index 5df565e74..8563789ff 100644 --- a/lzero/mcts/buffer/gobigger_game_buffer_muzero.py +++ b/lzero/mcts/buffer/gobigger_game_buffer_muzero.py @@ -18,7 +18,7 @@ class GoBiggerMuZeroGameBuffer(GameBuffer): """ Overview: - The specific game buffer for MuZero policy. + The specific game buffer for GoBigger MuZero policy. """ def __init__(self, cfg: dict): @@ -48,7 +48,7 @@ def __init__(self, cfg: dict): self.game_pos_priorities = [] self.game_segment_game_pos_look_up = [] - self.tmp_obs = None # for value obs list [46 + 4(td_step)] not < 50(game_segment) + self.tmp_obs = None # a tmp value which records obs when value obs list [current_index + 4(td_step)] > 50(game_segment) def sample( self, batch_size: int, policy: Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"] @@ -200,7 +200,6 @@ def _prepare_reward_value_context( - reward_value_context (:obj:`list`): value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, td_steps_list, action_mask_segment, to_play_segment """ - # zero_obs = game_segment_list[0].zero_obs() value_obs_list = [] # the value is valid or not (out of game_segment) value_mask = [] diff --git a/lzero/mcts/buffer/gobigger_game_buffer_sampled_efficientzero.py b/lzero/mcts/buffer/gobigger_game_buffer_sampled_efficientzero.py index 7dd8258e3..877b721a1 100644 --- a/lzero/mcts/buffer/gobigger_game_buffer_sampled_efficientzero.py +++ b/lzero/mcts/buffer/gobigger_game_buffer_sampled_efficientzero.py @@ -16,7 +16,7 @@ class GoBiggerSampledEfficientZeroGameBuffer(GoBiggerMuZeroGameBuffer): """ Overview: - The specific game buffer for Sampled EfficientZero policy. + The specific game buffer for GoBigger Sampled EfficientZero policy. """ def __init__(self, cfg: dict): diff --git a/lzero/model/gobigger/network/encoder.py b/lzero/model/gobigger/network/encoder.py index adcbc2d89..a42de3282 100644 --- a/lzero/model/gobigger/network/encoder.py +++ b/lzero/model/gobigger/network/encoder.py @@ -4,33 +4,82 @@ class OnehotEncoder(nn.Module): - + """ + Overview: + For encoding integers into one-hot vectors using PyTorch's Embedding layer. + """ def __init__(self, num_embeddings: int): + """ + Overview: + The initializer for OnehotEncoder. It initializes an Embedding layer with an identity matrix as its weights. + Arguments: + - num_embeddings (int): The size of the dictionary of embeddings, i.e., the number of rows in the embedding matrix. + """ super(OnehotEncoder, self).__init__() self.num_embeddings = num_embeddings self.main = nn.Embedding.from_pretrained(torch.eye(self.num_embeddings), freeze=True, padding_idx=None) def forward(self, x: torch.Tensor): + """ + Overview: + The common computation graph of the OnehotEncoder. It encodes the input tensor into one-hot vectors. + Arguments: + - x (torch.Tensor): The input tensor should be integers and its maximum value should be less than the 'num_embeddings' specified in the initializer. 
+ Returns: + - (torch.Tensor): Return the x-th row of the embedding layer. + """ x = x.long().clamp_(max=self.num_embeddings - 1) return self.main(x) class OnehotEmbedding(nn.Module): - + """ + Overview: + For encoding integers into higher-dimensional embedding vectors using PyTorch's Embedding layer. + """ def __init__(self, num_embeddings: int, embedding_dim: int): + """ + Overview: + The initializer for OnehotEmbedding. It initializes an Embedding layer with 'num_embeddings' rows and 'embedding_dim' columns. + Arguments: + - num_embeddings (int): The size of the dictionary of embeddings, i.e., the number of rows in the embedding matrix. + - embedding_dim (int): The size of each embedding vector, i.e., the number of columns in the embedding matrix. + """ super(OnehotEmbedding, self).__init__() self.num_embeddings = num_embeddings self.embedding_dim = embedding_dim self.main = nn.Embedding(num_embeddings=self.num_embeddings, embedding_dim=self.embedding_dim) def forward(self, x: torch.Tensor): + """ + Overview: + The common computation graph of the OnehotEmbedding. + It encodes the input tensor into higher-dimensional embedding vectors. + Arguments: + - x (torch.Tensor): The input tensor should be integers, and its maximum value should be less than the 'num_embeddings' specified in the initializer. + Returns: + - (torch.Tensor): Return the x-th row of the embedding layer. + """ x = x.long().clamp_(max=self.num_embeddings - 1) return self.main(x) class BinaryEncoder(nn.Module): - + """ + Overview: + For encoding integers into binary vectors using PyTorch's Embedding layer. + """ def __init__(self, num_embeddings: int): + """ + Overview: + The initializer for BinaryEncoder. It initializes an Embedding layer with a binary embedding matrix + as its weights. The binary embedding matrix is constructed by representing each integer (from 0 to 2^bit_num-1) + as a binary vector. + Arguments: + - num_embeddings (int): The number of bits in the binary representation. It determines the size of the dictionary of embeddings, + i.e., the number of rows in the embedding matrix (2^bit_num), and the size of each embedding vector, + i.e., the number of columns in the embedding matrix (bit_num). + """ super(BinaryEncoder, self).__init__() self.bit_num = num_embeddings self.main = nn.Embedding.from_pretrained( @@ -39,6 +88,14 @@ def __init__(self, num_embeddings: int): @staticmethod def get_binary_embed_matrix(bit_num): + """ + Overview: + A helper function that generates the binary embedding matrix. + Arguments: + - bit_num (int): The number of bits in the binary representation. + Returns: + - (torch.Tensor): A tensor of shape (2^bit_num, bit_num), where each row is the binary representation of the row index. + """ embedding_matrix = [] for n in range(2 ** bit_num): embedding = [n >> d & 1 for d in range(bit_num)][::-1] @@ -46,13 +103,34 @@ def get_binary_embed_matrix(bit_num): return torch.tensor(embedding_matrix, dtype=torch.float) def forward(self, x: torch.Tensor): + """ + Overview: + The common computation graph of the BinaryEncoder. It encodes the input tensor into binary vectors. + Arguments: + - x (torch.Tensor): The input tensor should be integers, and its maximum value should be less than 2^bit_num. + Returns: + - (torch.Tensor): Return the x-th row of the embedding layer. 
+ """ x = x.long().clamp_(max=2 ** self.bit_num - 1) return self.main(x) class SignBinaryEncoder(nn.Module): - - def __init__(self, num_embeddings): + """ + Overview: + For encoding integers into signed binary vectors using PyTorch's Embedding layer. + """ + def __init__(self, num_embeddings: int): + """ + Overview: + The initializer for SignBinaryEncoder. It initializes an Embedding layer with a signed binary embedding matrix + as its weights. The signed binary embedding matrix is constructed by representing each integer (from -2^(bit_num-1) to 2^(bit_num-1)-1) + as a signed binary vector. The first bit is the sign bit, with 1 representing negative and 0 representing nonnegative. + Arguments: + - num_embeddings (int): The number of bits in the signed binary representation. It determines the size of the dictionary of embeddings, + i.e., the number of rows in the embedding matrix (2^bit_num), and the size of each embedding vector, + i.e., the number of columns in the embedding matrix (bit_num). + """ super(SignBinaryEncoder, self).__init__() self.bit_num = num_embeddings self.main = nn.Embedding.from_pretrained( @@ -62,6 +140,15 @@ def __init__(self, num_embeddings): @staticmethod def get_sign_binary_matrix(bit_num): + """ + Overview: + A helper function that generates the signed binary embedding matrix. + Arguments: + - bit_num (int): The number of bits in the signed binary representation. + Returns: + - (torch.Tensor): A tensor of shape (2^bit_num, bit_num), where each row is the signed binary representation of the row index minus 2^(bit_num-1). + The first column is the sign bit, with 1 representing negative and 0 representing nonnegative. + """ neg_embedding_matrix = [] pos_embedding_matrix = [] for n in range(1, 2 ** (bit_num - 1)): @@ -72,13 +159,35 @@ def get_sign_binary_matrix(bit_num): return torch.tensor(embedding_matrix, dtype=torch.float) def forward(self, x: torch.Tensor): + """ + Overview: + The common computation graph of the SignBinaryEncoder. It encodes the input tensor into signed binary vectors. + Arguments: + - x (torch.Tensor): The input tensor. Its data type should be integers, and its maximum absolute value should be less than 2^(bit_num-1). + Returns: + - (torch.Tensor): Return the x-th row of the embedding layer. + """ x = x.long().clamp_(max=self.max_val, min=-self.max_val) return self.main(x + self.max_val) class PositionEncoder(nn.Module): - - def __init__(self, num_embeddings, embedding_dim=None): + """ + Overview: + For encoding the position of elements into higher-dimensional vectors using PyTorch's Embedding layer. + This is typically used in Transformer models to add positional information to the input tokens. + The position encoding is initialized using a sinusoidal formula, as proposed in the "Attention is All You Need" paper. + """ + def __init__(self, num_embeddings: int, embedding_dim: int = None): + """ + Overview: + The initializer for PositionEncoder. It initializes an Embedding layer with a sinusoidal position encoding matrix + as its weights. + Arguments: + - num_embeddings (int): The maximum number of positions to be encoded, i.e., the number of rows in the position encoding matrix. + - embedding_dim (int, optional): The size of each position encoding vector, i.e., the number of columns in the position encoding matrix. + If not provided, it is set equal to 'num_embeddings'. 
+ """ super(PositionEncoder, self).__init__() self.n_position = num_embeddings self.embedding_dim = self.n_position if embedding_dim is None else embedding_dim @@ -88,9 +197,15 @@ def __init__(self, num_embeddings, embedding_dim=None): @staticmethod def position_encoding_init(n_position, embedding_dim): - ''' Init the sinusoid position encoding table ''' - - # keep dim 0 for padding token position encoding zero vector + """ + Overview: + A helper function that generates the sinusoidal position encoding matrix. + Arguments: + - n_position (int): The maximum number of positions to be encoded. + - embedding_dim (int): The size of each position encoding vector. + Returns: + - (torch.Tensor): A tensor of shape (n_position, embedding_dim), where each row is the sinusoidal position encoding of the row index. + """ position_enc = np.array( [ [pos / np.power(10000, 2 * (j // 2) / embedding_dim) for j in range(embedding_dim)] @@ -102,17 +217,43 @@ def position_encoding_init(n_position, embedding_dim): return torch.from_numpy(position_enc).type(torch.FloatTensor) def forward(self, x: torch.Tensor): + """ + Overview: + The common computation graph of the PositionEncoder. It encodes the input tensor into positional vectors. + Arguments: + - x (torch.Tensor): The input tensor should be integers, and its maximum value should be less than 'num_embeddings' specified in the initializer. Each value in 'x' represents a position to be encoded. + Returns: + - (torch.Tensor): Return the x-th row of the embedding layer. + """ return self.position_enc(x) class TimeEncoder(nn.Module): - - def __init__(self, embedding_dim): + """ + Overview: + For encoding temporal or sequential data into higher-dimensional vectors using the sinusoidal position + encoding mechanism used in Transformer models. This is useful when working with time series data or sequences where + the position of each element (in this case time) is important. + """ + def __init__(self, embedding_dim: int): + """ + Overview: + The initializer for TimeEncoder. It initializes the position array which is used to scale the input data in the sinusoidal encoding function. + Arguments: + - embedding_dim (int): The size of each position encoding vector, i.e., the number of features in the encoded representation. + """ super(TimeEncoder, self).__init__() self.embedding_dim = embedding_dim self.position_array = torch.nn.Parameter(self.get_position_array(), requires_grad=False) def get_position_array(self): + """ + Overview: + A helper function that generates the position array used in the sinusoidal encoding function. Each element in the array is 1 / (10000^(2i/d)), + where i is the position in the array and d is the embedding dimension. This array is used to scale the input data in the encoding function. + Returns: + - (torch.Tensor): A tensor of shape (embedding_dim,) containing the position array. + """ x = torch.arange(0, self.embedding_dim, dtype=torch.float) x = x // 2 * 2 x = torch.div(x, self.embedding_dim) @@ -121,27 +262,50 @@ def get_position_array(self): return x def forward(self, x: torch.Tensor): + """ + Overview: + The common computation graph of the TimeEncoder. It encodes the input tensor into temporal vectors. + Arguments: + - x (torch.Tensor): The input tensor. Its data type should be one-dimensional, and each value in 'x' represents a timestamp or a position in sequence to be encoded. + Returns: + - (torch.Tensor): A tensor containing the temporal encoded vectors of the input tensor 'x'. 
+ """ v = torch.zeros(size=(x.shape[0], self.embedding_dim), dtype=torch.float, device=x.device) assert len(x.shape) == 1 x = x.unsqueeze(dim=1) - v[:, 0::2] = torch.sin(x * self.position_array[0::2]) # even - v[:, 1::2] = torch.cos(x * self.position_array[1::2]) # odd + v[:, 0::2] = torch.sin(x * self.position_array[0::2]) # apply sin on even-indexed positions + v[:, 1::2] = torch.cos(x * self.position_array[1::2]) # apply cos on odd-indexed positions return v class UnsqueezeEncoder(nn.Module): - + """ + Overview: + For unsqueezes a tensor along the specified dimension and then optionally normalizes the tensor. + This is useful when we want to add an extra dimension to the input tensor and potentially scale its values. + """ def __init__(self, unsqueeze_dim: int = -1, norm_value: float = 1): + """ + Overview: + The initializer for UnsqueezeEncoder. + Arguments: + - unsqueeze_dim (int, optional): The dimension to unsqueeze. Default is -1, which unsqueezes at the last dimension. + - norm_value (float, optional): The value to normalize the tensor by. Default is 1, which means no normalization. + """ super(UnsqueezeEncoder, self).__init__() self.unsqueeze_dim = unsqueeze_dim self.norm_value = norm_value def forward(self, x: torch.Tensor): + """ + Overview: + The common computation graph of the UnsqueezeEncoder. It unsqueezes the input tensor along the specified dimension and then normalizes the tensor. + Arguments: + - x (torch.Tensor): The input tensor. + Returns: + - (torch.Tensor): The unsqueezed and normalized tensor. Its shape is the same as the input tensor, but with an extra dimension at the position specified by 'unsqueeze_dim'. Its values are the values of the input tensor divided by 'norm_value'. + """ x = x.float().unsqueeze(dim=self.unsqueeze_dim) if self.norm_value != 1: x = x / self.norm_value return x - - -if __name__ == '__main__': - pass diff --git a/lzero/policy/gobigger_random_policy.py b/lzero/policy/gobigger_random_policy.py index c10172894..b1aef6709 100644 --- a/lzero/policy/gobigger_random_policy.py +++ b/lzero/policy/gobigger_random_policy.py @@ -3,17 +3,12 @@ import numpy as np import torch -import torch.optim as optim -from ding.model import model_wrap from ding.policy.base_policy import Policy from ding.torch_utils import to_tensor from ding.utils import POLICY_REGISTRY -from torch.distributions import Categorical -from torch.nn import L1Loss from lzero.mcts import EfficientZeroMCTSCtree as MCTSCtree from lzero.mcts import EfficientZeroMCTSPtree as MCTSPtree -from lzero.model import ImageTransforms from lzero.policy import scalar_transform, InverseScalarTransform, cross_entropy_loss, phi_transform, \ DiscreteSupport, select_action, to_torch_float_tensor, ez_network_output_unpack, negative_cosine_similarity, prepare_obs, \ configure_optimizers @@ -28,7 +23,7 @@ class GoBiggerRandomPolicy(Policy): The policy class for GoBiggerRandom. """ - # The default_config for EfficientZero policy. + # The default_config for GoBiggerRandom policy. config = dict( model=dict( # (str) The model type. For 1-dimensional vector obs, we use mlp model. For 3-dimensional image obs, we use conv model. @@ -180,28 +175,6 @@ def default_model(self) -> Tuple[str, List[str]]: """ return 'GoBiggerEfficientZeroModel', ['lzero.model.gobigger.gobigger_efficientzero_model'] - def _init_learn(self) -> None: - """ - Overview: - Learn mode init method. Called by ``self.__init__``. Initialize the learn model, optimizer and MCTS utils. 
- """ - pass - - def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: - """ - Overview: - The forward function for learning policy in learn mode, which is the core of the learning process. - The data is sampled from replay buffer. - The loss is calculated by the loss function and the loss is backpropagated to update the model. - Arguments: - - data (:obj:`Tuple[torch.Tensor]`): The data sampled from replay buffer, which is a tuple of tensors. - The first tensor is the current_batch, the second tensor is the target_batch. - Returns: - - info_dict (:obj:`Dict[str, Union[float, int]]`): The information dict to be logged, which contains \ - current learning loss and learning statistics. - """ - pass - def _init_collect(self) -> None: """ Overview: @@ -329,60 +302,27 @@ def _init_eval(self) -> None: else: self._mcts_eval = MCTSPtree(self._cfg) + # be compatible with DI-engine Policy class + def _init_learn(self) -> None: + pass + + def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: + pass + def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, ready_env_id=None): - """ - Overview: - The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search. - Choosing the action with the highest value (argmax) rather than sampling during the eval mode. - Arguments: - - data (:obj:`torch.Tensor`): The input data, i.e. the observation. - - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. - - to_play (:obj:`int`): The player to play. - - ready_env_id (:obj:`list`): The id of the env that is ready to collect. - Shape: - - data (:obj:`torch.Tensor`): - - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ - S is the number of stacked frames, H is the height of the image, W is the width of the image. - - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size. - - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env. - - to_play: :math:`(N, 1)`, where N is the number of collect_env. - - ready_env_id: None - Returns: - - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ - ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. - """ pass def _monitor_vars_learn(self) -> List[str]: - """ - Overview: - Register the variables to be monitored in learn mode. The registered variables will be logged in - tensorboard according to the return value ``_forward_learn``. - """ pass def _state_dict_learn(self) -> Dict[str, Any]: - """ - Overview: - Return the state_dict of learn mode, usually including model and optimizer. - Returns: - - state_dict (:obj:`Dict[str, Any]`): the dict of current policy learn state, for saving and restoring. - """ pass def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: - """ - Overview: - Load the state_dict variable into policy learn mode. - Arguments: - - state_dict (:obj:`Dict[str, Any]`): the dict of policy learn state saved before. 
- """ pass def _process_transition(self, obs, policy_output, timestep): - # be compatible with DI-engine Policy class pass def _get_train_sample(self, data): - # be compatible with DI-engine Policy class pass diff --git a/zoo/gobigger/config/gobigger_efficientzero_config.py b/zoo/gobigger/config/gobigger_efficientzero_config.py index 8fb60fce6..dfdadc105 100644 --- a/zoo/gobigger/config/gobigger_efficientzero_config.py +++ b/zoo/gobigger/config/gobigger_efficientzero_config.py @@ -77,6 +77,7 @@ ), cuda=True, mcts_ctree=True, + gumbel_algo=False, env_type='not_board_games', game_segment_length=400, random_collect_episode_num=random_collect_episode_num, @@ -105,11 +106,6 @@ log_policy=True, hook=dict(log_show_after_iter=10, ), ), ), - collect=dict(collector=dict(collect_print_freq=10, ), ), - eval=dict(evaluator=dict( - eval_freq=5000, - stop_value=10000000000, - ), ), ) atari_efficientzero_config = EasyDict(atari_efficientzero_config) main_config = atari_efficientzero_config diff --git a/zoo/gobigger/config/gobigger_eval_config.py b/zoo/gobigger/config/gobigger_eval_config.py index 3fc62da27..e32b71151 100644 --- a/zoo/gobigger/config/gobigger_eval_config.py +++ b/zoo/gobigger/config/gobigger_eval_config.py @@ -8,10 +8,15 @@ point to the ckpt file of the pretrained model, and an absolute path is recommended. In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``. """ - # sez + # ez # from gobigger_efficientzero_config import main_config, create_config + + # sez + # from gobigger_sampled_efficientzero_config import main_config, create_config + + # mz from gobigger_muzero_config import main_config, create_config - model_path = "/path/ckpt/ckpt_best.pth.tar" + model_path = "exp_name/ckpt/ckpt_best.pth.tar" returns_mean_seeds = [] returns_seeds = [] diff --git a/zoo/gobigger/config/gobigger_muzero_config.py b/zoo/gobigger/config/gobigger_muzero_config.py index b16030f34..6effc4dc5 100644 --- a/zoo/gobigger/config/gobigger_muzero_config.py +++ b/zoo/gobigger/config/gobigger_muzero_config.py @@ -107,11 +107,6 @@ log_policy=True, hook=dict(log_show_after_iter=10, ), ), ), - collect=dict(collector=dict(collect_print_freq=10, ), ), - eval=dict(evaluator=dict( - eval_freq=5000, - stop_value=10000000000, - ), ), ) atari_muzero_config = EasyDict(atari_muzero_config) main_config = atari_muzero_config diff --git a/zoo/gobigger/config/gobigger_sampled_efficientzero_config.py b/zoo/gobigger/config/gobigger_sampled_efficientzero_config.py index 97d17159e..f847077fc 100644 --- a/zoo/gobigger/config/gobigger_sampled_efficientzero_config.py +++ b/zoo/gobigger/config/gobigger_sampled_efficientzero_config.py @@ -100,11 +100,6 @@ log_policy=True, hook=dict(log_show_after_iter=10, ), ), ), - collect=dict(collector=dict(collect_print_freq=10, ), ), - eval=dict(evaluator=dict( - eval_freq=5000, - stop_value=10000000000, - ), ), ) atari_sampled_efficientzero_config = EasyDict(atari_sampled_efficientzero_config) main_config = atari_sampled_efficientzero_config diff --git a/zoo/gobigger/env/gobigger_env.py b/zoo/gobigger/env/gobigger_env.py index d996a7619..29153163b 100644 --- a/zoo/gobigger/env/gobigger_env.py +++ b/zoo/gobigger/env/gobigger_env.py @@ -1,9 +1,15 @@ import gym import numpy as np +from ditk import logging from ding.envs import BaseEnv, BaseEnvTimestep from ding.utils import ENV_REGISTRY -from gobigger.envs import GoBiggerEnv import math +try: + from gobigger.envs import GoBiggerEnv +except ImportError: + import sys + logging.warning("not found gobigger package, please 
install it through `pip install git+https://github.com/opendilab/GoBigger.git") + sys.exit(1) @ENV_REGISTRY.register('gobigger_lightzero') diff --git a/zoo/gobigger/env/gobigger_rule_bot.py b/zoo/gobigger/env/gobigger_rule_bot.py index 2bc8894f8..916d57bae 100644 --- a/zoo/gobigger/env/gobigger_rule_bot.py +++ b/zoo/gobigger/env/gobigger_rule_bot.py @@ -5,7 +5,7 @@ import math import queue import random -from pygame.math import Vector2 +import numpy as np from typing import List, Dict, Any, Optional, Tuple, Union from collections import namedtuple from collections import defaultdict @@ -116,7 +116,7 @@ def step(self, obs): if min_food_ball is not None: direction = (min_food_ball['position'] - my_clone_balls[0]['position']) else: - direction = (Vector2(0, 0) - my_clone_balls[0]['position']) + direction = (np.array([0, 0]) - my_clone_balls[0]['position']) action_random = random.random() if action_random < 0.02: action_type = 1 @@ -124,12 +124,13 @@ def step(self, obs): action_type = 2 else: action_type = 0 - if direction.length() > 0: - direction = direction.normalize() + if np.linalg.norm(direction) > 0: + direction = direction / np.linalg.norm(direction) else: - direction = Vector2(1, 1).normalize() - direction = self.add_noise_to_direction(direction).normalize() - self.actions_queue.put([direction.x, direction.y, action_type]) + direction = np.array([1,1]) / np.linalg.norm(np.array([1,1])) + direction = self.add_noise_to_direction(direction) + direction = direction / np.linalg.norm(direction) + self.actions_queue.put([direction[0], direction[1], action_type]) action_ret = self.actions_queue.get() return {self.game_player_id: action_ret} @@ -152,7 +153,7 @@ def process_thorns_balls(self, thorns_balls, my_max_clone_ball): min_thorns_ball = None for thorns_ball in thorns_balls: if self.can_eat(my_max_clone_ball['radius'], thorns_ball['radius']): - distance = (thorns_ball['position'] - my_max_clone_ball['position']).length() + distance = np.linalg.norm((thorns_ball['position'] - my_max_clone_ball['position'])) if distance < min_distance: min_distance = distance min_thorns_ball = copy.deepcopy(thorns_ball) @@ -162,7 +163,7 @@ def process_food_balls(self, food_balls, my_max_clone_ball): min_distance = 10000 min_food_ball = None for food_ball in food_balls: - distance = (food_ball['position'] - my_max_clone_ball['position']).length() + distance = np.linalg.norm(food_ball['position'] - my_max_clone_ball['position']) if distance < min_distance: min_distance = distance min_food_ball = copy.deepcopy(food_ball) @@ -175,7 +176,7 @@ def preprocess(self, overlap): new_overlap[k] = [] for index, vv in enumerate(v): tmp = {} - tmp['position'] = Vector2(vv[0], vv[1]) + tmp['position'] = np.array([vv[0], vv[1]]) tmp['radius'] = vv[2] tmp['player'] = int(vv[-2]) tmp['team'] = int(vv[-1]) @@ -184,7 +185,7 @@ def preprocess(self, overlap): new_overlap[k] = [] for index, vv in enumerate(v): tmp = {} - tmp['position'] = Vector2(vv[0], vv[1]) + tmp['position'] = np.array([vv[0], vv[1]]) tmp['radius'] = vv[2] new_overlap[k].append(tmp) return new_overlap @@ -195,13 +196,13 @@ def preprocess_tuple2vector(self, overlap): new_overlap[k] = [] for index, vv in enumerate(v): new_overlap[k].append(vv) - new_overlap[k][index]['position'] = Vector2(*vv['position']) + new_overlap[k][index]['position'] = np.array(*vv['position']) return new_overlap def add_noise_to_direction(self, direction, noise_ratio=0.1): - direction = direction + Vector2( - ((random.random() * 2 - 1) * noise_ratio) * direction.x, - ((random.random() * 2 
- 1) * noise_ratio) * direction.y + direction = direction + np.array( + ((random.random() * 2 - 1) * noise_ratio) * direction[0], + ((random.random() * 2 - 1) * noise_ratio) * direction[1] ) return direction diff --git a/zoo/gobigger/env/test_gobbiger_env.py b/zoo/gobigger/env/test_gobigger_env.py similarity index 100% rename from zoo/gobigger/env/test_gobbiger_env.py rename to zoo/gobigger/env/test_gobigger_env.py From 377f66459d62c96de6468a4734a3fe1220782f2e Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Fri, 30 Jun 2023 17:41:53 +0800 Subject: [PATCH 25/54] polish(yzj): use ding scatter_model, muzero_collector add multi_agent option --- .../gobigger/network/gobigger_encoder.py | 3 +- .../gobigger/network/scatter_connection.py | 89 --- lzero/worker/gobigger_muzero_collector.py | 546 +----------------- lzero/worker/gobigger_muzero_evaluator.py | 6 +- lzero/worker/muzero_collector.py | 342 +++++++---- .../config/gobigger_efficientzero_config.py | 3 + zoo/gobigger/env/gobigger_env.py | 3 +- 7 files changed, 257 insertions(+), 735 deletions(-) delete mode 100644 lzero/model/gobigger/network/scatter_connection.py diff --git a/lzero/model/gobigger/network/gobigger_encoder.py b/lzero/model/gobigger/network/gobigger_encoder.py index 7feeebd7f..6fda88c7e 100644 --- a/lzero/model/gobigger/network/gobigger_encoder.py +++ b/lzero/model/gobigger/network/gobigger_encoder.py @@ -4,11 +4,10 @@ from torch import Tensor from ding.utils.default_helper import deep_merge_dicts from ding.torch_utils import MLP, fc_block, conv2d_block, ResBlock -from ding.torch_utils import Transformer +from ding.torch_utils import Transformer, ScatterConnection from easydict import EasyDict from .encoder import SignBinaryEncoder, BinaryEncoder, OnehotEncoder, TimeEncoder, UnsqueezeEncoder -from .scatter_connection import ScatterConnection def sequence_mask(lengths: torch.Tensor, max_len: Optional[int] = None): diff --git a/lzero/model/gobigger/network/scatter_connection.py b/lzero/model/gobigger/network/scatter_connection.py deleted file mode 100644 index 1005d1719..000000000 --- a/lzero/model/gobigger/network/scatter_connection.py +++ /dev/null @@ -1,89 +0,0 @@ -from typing import Tuple - -import torch -import torch.nn as nn - - -class ScatterConnection(nn.Module): - r""" - Overview: - Scatter feature to its corresponding location - In alphastar, each entity is embedded into a tensor, these tensors are scattered into a feature map - with map size - """ - - def __init__(self, scatter_type='add') -> None: - r""" - Overview: - Init class - Arguments: - - scatter_type (:obj:`str`): add or cover, if two entities have same location, scatter type decides the - first one should be covered or added to second one - """ - super(ScatterConnection, self).__init__() - self.scatter_type = scatter_type - assert self.scatter_type in ['cover', 'add'] - - def xy_forward( - self, x: torch.Tensor, spatial_size: Tuple[int, int], coord_x: torch.Tensor, coord_y - ) -> torch.Tensor: - device = x.device - BatchSize, Num, EmbeddingSize = x.shape - x = x.permute(0, 2, 1) - H, W = spatial_size - indices = (coord_x * W + coord_y).long() - indices = indices.unsqueeze(dim=1).repeat(1, EmbeddingSize, 1) - output = torch.zeros(size=(BatchSize, EmbeddingSize, H, W), device=device).view(BatchSize, EmbeddingSize, H * W) - if self.scatter_type == 'cover': - output.scatter_(dim=2, index=indices, src=x) - elif self.scatter_type == 'add': - output.scatter_add_(dim=2, index=indices, src=x) - output = output.view(BatchSize, EmbeddingSize, H, W) - return output - - 
def forward(self, x: torch.Tensor, spatial_size: Tuple[int, int], location: torch.Tensor) -> torch.Tensor: - """ - Overview: - scatter x into a spatial feature map - Arguments: - - x (:obj:`tensor`): input tensor :math: `(B, M, N)` where `M` means the number of entity, `N` means\ - the dimension of entity attributes - - spatial_size (:obj:`tuple`): Tuple[H, W], the size of spatial feature x will be scattered into - - location (:obj:`tensor`): :math: `(B, M, 2)` torch.LongTensor, each location should be (y, x) - Returns: - - output (:obj:`tensor`): :math: `(B, N, H, W)` where `H` and `W` are spatial_size, return the\ - scattered feature map - Shapes: - - Input: :math: `(B, M, N)` where `M` means the number of entity, `N` means\ - the dimension of entity attributes - - Size: Tuple[H, W] - - Location: :math: `(B, M, 2)` torch.LongTensor, each location should be (y, x) - - Output: :math: `(B, N, H, W)` where `H` and `W` are spatial_size - - .. note:: - when there are some overlapping in locations, ``cover`` mode will result in the loss of information, we - use the addition as temporal substitute. - """ - device = x.device - BatchSize, Num, EmbeddingSize = x.shape - x = x.permute(0, 2, 1) - H, W = spatial_size - indices = location[:, :, 1] + location[:, :, 0] * W - indices = indices.unsqueeze(dim=1).repeat(1, EmbeddingSize, 1) - output = torch.zeros(size=(BatchSize, EmbeddingSize, H, W), device=device).view(BatchSize, EmbeddingSize, H * W) - if self.scatter_type == 'cover': - output.scatter_(dim=2, index=indices, src=x) - elif self.scatter_type == 'add': - output.scatter_add_(dim=2, index=indices, src=x) - output = output.view(BatchSize, EmbeddingSize, H, W) - return output - - -if __name__ == '__main__': - scatter_conn = ScatterConnection() - BatchSize, Num, EmbeddingSize = 10, 20, 3 - SpatialSize = (13, 17) - for _ in range(10): - x = torch.randn(size=(BatchSize, Num, EmbeddingSize)) - locations = torch.randint(low=0, high=12, size=(BatchSize, Num, 2)) - scatter_conn.forward(x, SpatialSize, location=locations) diff --git a/lzero/worker/gobigger_muzero_collector.py b/lzero/worker/gobigger_muzero_collector.py index 49c6cfbba..835feb1e8 100644 --- a/lzero/worker/gobigger_muzero_collector.py +++ b/lzero/worker/gobigger_muzero_collector.py @@ -7,7 +7,7 @@ from ding.envs import BaseEnvManager from ding.torch_utils import to_ndarray from ding.utils import build_logger, EasyTimer, SERIAL_COLLECTOR_REGISTRY -from ding.worker.collector.base_serial_collector import ISerialCollector +from .muzero_collector import MuZeroCollector from torch.nn import L1Loss from lzero.mcts.buffer.game_segment import GameSegment @@ -16,7 +16,7 @@ @SERIAL_COLLECTOR_REGISTRY.register('gobigger_episode_muzero') -class GoBiggerMuZeroCollector(ISerialCollector): +class GoBiggerMuZeroCollector(MuZeroCollector): """ Overview: The Collector for GoBigger MCTS+RL algorithms, including MuZero, EfficientZero, Sampled EfficientZero. @@ -52,139 +52,8 @@ def __init__( - exp_name (:obj:`str`): Experiment name, which is used to indicate output directory. - policy_config: Config of game. 
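# A compact, standalone sketch of the scatter operation that this deleted module
# implemented and that ding.torch_utils.ScatterConnection now provides: entity
# embeddings of shape (B, M, N) are written into a (B, N, H, W) spatial map at their
# (y, x) locations, with embeddings that land on the same cell summed ('add' mode).
import torch

B, M, N, H, W = 1, 3, 2, 4, 5                                   # batch, entities, embed dim, map size
x = torch.arange(B * M * N, dtype=torch.float).view(B, M, N)    # entity embeddings
location = torch.tensor([[[0, 0], [2, 3], [2, 3]]])             # (y, x) per entity; two overlap

indices = (location[:, :, 0] * W + location[:, :, 1]).long()    # flatten (y, x) -> y * W + x
indices = indices.unsqueeze(1).repeat(1, N, 1)                  # (B, N, M)
out = torch.zeros(B, N, H * W).scatter_add_(2, indices, x.permute(0, 2, 1))
out = out.view(B, N, H, W)
print(out[0, :, 2, 3])    # tensor([6., 8.]): the two overlapping entities are summed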
""" - self._exp_name = exp_name - self._instance_name = instance_name - self._collect_print_freq = collect_print_freq - self._timer = EasyTimer() - self._end_flag = False + super().__init__(collect_print_freq, env, policy, tb_logger, exp_name, instance_name, policy_config) - if tb_logger is not None: - self._logger, _ = build_logger( - path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name, need_tb=False - ) - self._tb_logger = tb_logger - else: - self._logger, self._tb_logger = build_logger( - path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name - ) - - self.policy_config = policy_config - - self.reset(policy, env) - - def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None: - """ - Overview: - Reset the environment. - If _env is None, reset the old environment. - If _env is not None, replace the old environment in the collector with the new passed \ - in environment and launch. - Arguments: - - env (:obj:`Optional[BaseEnvManager]`): instance of the subclass of vectorized \ - env_manager(BaseEnvManager) - """ - if _env is not None: - self._env = _env - self._env.launch() - self._env_num = self._env.env_num - else: - self._env.reset() - - def reset_policy(self, _policy: Optional[namedtuple] = None) -> None: - """ - Overview: - Reset the policy. - If _policy is None, reset the old policy. - If _policy is not None, replace the old policy in the collector with the new passed in policy. - Arguments: - - policy (:obj:`Optional[namedtuple]`): the api namedtuple of collect_mode policy - """ - assert hasattr(self, '_env'), "please set env first" - if _policy is not None: - self._policy = _policy - self._default_n_episode = _policy.get_attribute('cfg').get('n_episode', None) - self._logger.debug( - 'Set default n_episode mode(n_episode({}), env_num({}))'.format(self._default_n_episode, self._env_num) - ) - self._policy.reset() - - def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvManager] = None) -> None: - """ - Overview: - Reset the environment and policy. - If _env is None, reset the old environment. - If _env is not None, replace the old environment in the collector with the new passed \ - in environment and launch. - If _policy is None, reset the old policy. - If _policy is not None, replace the old policy in the collector with the new passed in policy. - Arguments: - - policy (:obj:`Optional[namedtuple]`): the api namedtuple of collect_mode policy - - env (:obj:`Optional[BaseEnvManager]`): instance of the subclass of vectorized \ - env_manager(BaseEnvManager) - """ - if _env is not None: - self.reset_env(_env) - if _policy is not None: - self.reset_policy(_policy) - - self._env_info = {env_id: {'time': 0., 'step': 0} for env_id in range(self._env_num)} - - self._episode_info = [] - self._total_envstep_count = 0 - self._total_episode_count = 0 - self._total_duration = 0 - self._last_train_iter = 0 - self._end_flag = False - - # A game_segment_pool implementation based on the deque structure. - self.game_segment_pool = deque(maxlen=int(1e6)) - self.unroll_plus_td_steps = self.policy_config.num_unroll_steps + self.policy_config.td_steps - - def _reset_stat(self, env_id: int) -> None: - """ - Overview: - Reset the collector's state. Including reset the traj_buffer, obs_pool, policy_output_pool\ - and env_info. Reset these states according to env_id. You can refer to base_serial_collector\ - to get more messages. 
- Arguments: - - env_id (:obj:`int`): the id where we need to reset the collector's state - """ - self._env_info[env_id] = {'time': 0., 'step': 0} - - @property - def envstep(self) -> int: - """ - Overview: - Print the total envstep count. - Return: - - envstep (:obj:`int`): the total envstep count - """ - return self._total_envstep_count - - def close(self) -> None: - """ - Overview: - Close the collector. If end_flag is False, close the environment, flush the tb_logger\ - and close the tb_logger. - """ - if self._end_flag: - return - self._end_flag = True - self._env.close() - self._tb_logger.flush() - self._tb_logger.close() - - def __del__(self) -> None: - """ - Overview: - Execute the close command and close the collector. __del__ is automatically called to \ - destroy the collector instance when the collector finishes its work - """ - self.close() - - # ============================================================== - # MCTS+RL related core code - # ============================================================== def _compute_priorities(self, i, agent_id, pred_values_lst, search_values_lst): """ Overview: @@ -252,7 +121,7 @@ def pad_and_save_last_trajectory( rew: game_segment_length + stack + num_unroll_steps + td_steps -1 20 +5+3-1 action: game_segment_length -> 20 root_values: game_segment_length + num_unroll_steps + td_steps -> 20 +5+3 - child_visits: game_segment_length + num_unroll_steps -> 20 +5 + child_visits: game_segment_length + num_unroll_steps -> 20 +5 to_play: game_segment_length -> 20 action_mask: game_segment_length -> 20 """ @@ -267,410 +136,3 @@ def pad_and_save_last_trajectory( last_game_priorities[i][agent_id] = None return None - - def collect(self, - n_episode: Optional[int] = None, - train_iter: int = 0, - policy_kwargs: Optional[dict] = None) -> List[Any]: - """ - Overview: - Collect `n_episode` data with policy_kwargs, which is already trained `train_iter` iterations. - Arguments: - - n_episode (:obj:`int`): the number of collecting data episode. - - train_iter (:obj:`int`): the number of training iteration. - - policy_kwargs (:obj:`dict`): the keyword args for policy forward. - Returns: - - return_data (:obj:`List`): A list containing collected game_segments - """ - if n_episode is None: - if self._default_n_episode is None: - raise RuntimeError("Please specify collect n_episode") - else: - n_episode = self._default_n_episode - assert n_episode >= self._env_num, "Please make sure n_episode >= env_num{}/{}".format(n_episode, self._env_num) - if policy_kwargs is None: - policy_kwargs = {} - temperature = policy_kwargs['temperature'] - epsilon = policy_kwargs['epsilon'] - - collected_episode = 0 - env_nums = self._env_num - - # initializations - init_obs = self._env.ready_obs - - retry_waiting_time = 0.001 - while len(init_obs.keys()) != self._env_num: - # In order to be compatible with subprocess env_manager, in which sometimes self._env_num is not equal to - # len(self._env.ready_obs), especially in tictactoe env. - self._logger.info('The current init_obs.keys() is {}'.format(init_obs.keys())) - self._logger.info('Before sleeping, the _env_states is {}'.format(self._env._env_states)) - time.sleep(retry_waiting_time) - self._logger.info('=' * 10 + 'Wait for all environments (subprocess) to finish resetting.' 
+ '=' * 10) - self._logger.info( - 'After sleeping {}s, the current _env_states is {}'.format(retry_waiting_time, self._env._env_states) - ) - init_obs = self._env.ready_obs - - action_mask_dict = {i: to_ndarray(init_obs[i]['action_mask']) for i in range(env_nums)} - to_play_dict = {i: to_ndarray(init_obs[i]['to_play']) for i in range(env_nums)} - agent_num = len(init_obs[0]['action_mask']) - game_segments = [ - [ - GameSegment( - self._env.action_space, - game_segment_length=self.policy_config.game_segment_length, - config=self.policy_config - ) for _ in range(agent_num) - ] for _ in range(env_nums) - ] - # stacked observation windows in reset stage for init game_segments - observation_window_stack = [[[] for _ in range(agent_num)] for _ in range(env_nums)] - for env_id in range(env_nums): - for agent_id in range(agent_num): - observation_window_stack[env_id][agent_id] = deque( - [ - to_ndarray(init_obs[env_id]['observation'][agent_id]) - for _ in range(self.policy_config.model.frame_stack_num) - ], - maxlen=self.policy_config.model.frame_stack_num - ) - game_segments[env_id][agent_id].reset(observation_window_stack[env_id][agent_id]) - - dones = np.array([False for _ in range(env_nums)]) - last_game_segments = [[None for _ in range(agent_num)] for _ in range(env_nums)] - last_game_priorities = [[None for _ in range(agent_num)] for _ in range(env_nums)] - # for priorities in self-play - search_values_lst = [[[] for _ in range(agent_num)] for _ in range(env_nums)] - pred_values_lst = [[[] for _ in range(agent_num)] for _ in range(env_nums)] - - # some logs - eps_steps_lst, visit_entropies_lst = np.zeros(env_nums), np.zeros((env_nums, agent_num)) - self_play_moves = 0. - self_play_episodes = 0. - self_play_moves_max = 0 - self_play_visit_entropy = [] - total_transitions = 0 - - ready_env_id = set() - remain_episode = n_episode - - while True: - with self._timer: - # Get current ready env obs. 
- obs = self._env.ready_obs - new_available_env_id = set(obs.keys()).difference(ready_env_id) - ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode])) - remain_episode -= min(len(new_available_env_id), remain_episode) - - stack_obs = defaultdict(list) - for env_id in ready_env_id: - for agent_id in range(agent_num): - stack_obs[env_id].append(game_segments[env_id][agent_id].get_obs()) - - # stack_obs = {env_id: [game_segments[env_id][agent_id].get_obs() for agent_id in agent_num] for env_id in ready_env_id} - stack_obs = list(stack_obs.values()) - - action_mask_dict = {env_id: action_mask_dict[env_id] for env_id in ready_env_id} - to_play_dict = {env_id: to_play_dict[env_id] for env_id in ready_env_id} - action_mask = [action_mask_dict[env_id] for env_id in ready_env_id] - to_play = [to_play_dict[env_id] for env_id in ready_env_id] - - stack_obs = to_ndarray(stack_obs) - - # stack_obs = prepare_observation(stack_obs, self.policy_config.model.model_type) - - # stack_obs = torch.from_numpy(stack_obs).to(self.policy_config.device).float() - - # ============================================================== - # policy forward - # ============================================================== - policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon) - actions_no_env_id = defaultdict(dict) - for k, v in policy_output.items(): - for agent_id, act in enumerate(v['action']): - actions_no_env_id[k][agent_id] = act - - # actions_no_env_id = {k: v['action'] for k, v in policy_output.items()} - distributions_dict_no_env_id = {k: v['distributions'] for k, v in policy_output.items()} - if self.policy_config.sampled_algo: - root_sampled_actions_dict_no_env_id = { - k: v['root_sampled_actions'] - for k, v in policy_output.items() - } - value_dict_no_env_id = {k: v['value'] for k, v in policy_output.items()} - pred_value_dict_no_env_id = {k: v['pred_value'] for k, v in policy_output.items()} - visit_entropy_dict_no_env_id = { - k: v['visit_count_distribution_entropy'] - for k, v in policy_output.items() - } - - # TODO(pu): subprocess - actions = {} - distributions_dict = {} - if self.policy_config.sampled_algo: - root_sampled_actions_dict = {} - value_dict = {} - pred_value_dict = {} - visit_entropy_dict = {} - for index, env_id in enumerate(ready_env_id): - actions[env_id] = actions_no_env_id.pop(index) - distributions_dict[env_id] = distributions_dict_no_env_id.pop(index) - if self.policy_config.sampled_algo: - root_sampled_actions_dict[env_id] = root_sampled_actions_dict_no_env_id.pop(index) - value_dict[env_id] = value_dict_no_env_id.pop(index) - pred_value_dict[env_id] = pred_value_dict_no_env_id.pop(index) - visit_entropy_dict[env_id] = visit_entropy_dict_no_env_id.pop(index) - - # ============================================================== - # Interact with env. - # ============================================================== - timesteps = self._env.step(actions) - - interaction_duration = self._timer.value / len(timesteps) - for env_id, timestep in timesteps.items(): - with self._timer: - if timestep.info.get('abnormal', False): - # If there is an abnormal timestep, reset all the related variables(including this env). 
- # suppose there is no reset param, just reset this env - self._env.reset({env_id: None}) - self._policy.reset([env_id]) - self._reset_stat(env_id) - self._logger.info('Env{} returns a abnormal step, its info is {}'.format(env_id, timestep.info)) - continue - obs, reward, done, info = timestep.obs, timestep.reward, timestep.done, timestep.info - - if self.policy_config.sampled_algo: - for agent_id in range(agent_num): - game_segments[env_id][agent_id].store_search_stats( - distributions_dict[env_id][agent_id], value_dict[env_id][agent_id], - root_sampled_actions_dict[env_id][agent_id] - ) - else: - for agent_id in range(agent_num): - game_segments[env_id][agent_id].store_search_stats( - distributions_dict[env_id][agent_id], value_dict[env_id][agent_id] - ) - # append a transition tuple, including a_t, o_{t+1}, r_{t}, action_mask_{t}, to_play_{t} - # in ``game_segments[env_id].init``, we have append o_{t} in ``self.obs_segment`` - for agent_id in range(agent_num): - game_segments[env_id][agent_id].append( - actions[env_id][agent_id], to_ndarray(obs['observation'][agent_id]), reward[agent_id], - action_mask_dict[env_id][agent_id], to_play_dict[env_id] - ) - - # NOTE: the position of code snippet is very important. - # the obs['action_mask'] and obs['to_play'] is corresponding to next action - action_mask_dict[env_id] = to_ndarray(obs['action_mask']) - to_play_dict[env_id] = to_ndarray(obs['to_play']) - dones[env_id] = done - for agent_id in range(agent_num): - visit_entropies_lst[env_id][agent_id] += visit_entropy_dict[env_id][agent_id] - eps_steps_lst[env_id] += 1 - total_transitions += 1 - - if self.policy_config.use_priority and not self.policy_config.use_max_priority_for_new_data: - for agent_id in range(agent_num): - pred_values_lst[env_id][agent_id].append(pred_value_dict[env_id][agent_id]) - search_values_lst[env_id][agent_id].append(value_dict[env_id][agent_id]) - - # append the newest obs - for agent_id in range(agent_num): - observation_window_stack[env_id][agent_id].append(to_ndarray(obs['observation'][agent_id])) - - # ============================================================== - # we will save a game block if it is the end of the game or the next game block is finished. 
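# The collector above fills pred_values_lst / search_values_lst per env and per agent
# and turns them into replay priorities via _compute_priorities (only its signature and
# docstring appear in this patch). As a hedged sketch of that idea, a per-transition
# priority is typically the L1 distance between the network's predicted value and the
# MCTS search value, so poorly evaluated transitions are replayed more often; the exact
# LightZero formula may differ.
import torch
from torch.nn import L1Loss

def l1_priorities(pred_values, search_values, eps: float = 1e-6) -> torch.Tensor:
    pred = torch.tensor(pred_values, dtype=torch.float)
    target = torch.tensor(search_values, dtype=torch.float)
    # element-wise |pred - search|; eps keeps every priority strictly positive
    return L1Loss(reduction='none')(pred, target) + eps

print(l1_priorities([0.1, 0.5, -0.2], [0.0, 0.8, -0.2]))   # ~tensor([0.1000, 0.3000, 0.0000])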
- # ============================================================== - - # if game block is full, we will save the last game block - for agent_id in range(agent_num): - if game_segments[env_id][agent_id].is_full(): - # pad over last block trajectory - if last_game_segments[env_id][agent_id] is not None: - # TODO(pu): return the one game block - self.pad_and_save_last_trajectory( - env_id, agent_id, last_game_segments, last_game_priorities, game_segments, dones - ) - - # calculate priority - priorities = self._compute_priorities(env_id, agent_id, pred_values_lst, search_values_lst) - # pred_values_lst[env_id] = [] - # search_values_lst[env_id] = [] - search_values_lst = [[[] for _ in range(agent_num)] for _ in range(env_nums)] - pred_values_lst = [[[] for _ in range(agent_num)] for _ in range(env_nums)] - - # the current game_segments become last_game_segment - last_game_segments[env_id][agent_id] = game_segments[env_id][agent_id] - last_game_priorities[env_id][agent_id] = priorities - - # create new GameSegment - game_segments[env_id][agent_id] = GameSegment( - self._env.action_space, - game_segment_length=self.policy_config.game_segment_length, - config=self.policy_config - ) - game_segments[env_id][agent_id].reset(observation_window_stack[env_id][agent_id]) - - self._env_info[env_id]['step'] += 1 - self._total_envstep_count += 1 - self._env_info[env_id]['time'] += self._timer.value + interaction_duration - if timestep.done: - self._total_episode_count += 1 - reward = timestep.info['eval_episode_return'][0] - info = { - 'reward': reward, - 'time': self._env_info[env_id]['time'], - 'step': self._env_info[env_id]['step'], - 'visit_entropy': visit_entropies_lst[env_id] / eps_steps_lst[env_id], - } - collected_episode += 1 - self._episode_info.append(info) - - # ============================================================== - # if it is the end of the game, we will save the game block - # ============================================================== - - # NOTE: put the penultimate game block in one episode into the trajectory_pool - # pad over 2th last game_segment using the last game_segment - for agent_id in range(agent_num): - if last_game_segments[env_id] is not None: - self.pad_and_save_last_trajectory( - env_id, agent_id, last_game_segments, last_game_priorities, game_segments, dones - ) - - # store current block trajectory - priorities = self._compute_priorities(env_id, agent_id, pred_values_lst, search_values_lst) - - # NOTE: put the last game block in one episode into the trajectory_pool - game_segments[env_id][agent_id].game_segment_to_array() - - # assert len(game_segments[env_id]) == len(priorities) - # NOTE: save the last game block in one episode into the trajectory_pool if it's not null - if len(game_segments[env_id][agent_id].reward_segment) != 0: - self.game_segment_pool.append((game_segments[env_id][agent_id], priorities, dones[env_id])) - - # print(game_segments[env_id].reward_segment) - # reset the finished env and init game_segments - if n_episode > self._env_num: - # Get current ready env obs. - init_obs = self._env.ready_obs - retry_waiting_time = 0.001 - while len(init_obs.keys()) != self._env_num: - # In order to be compatible with subprocess env_manager, in which sometimes self._env_num is not equal to - # len(self._env.ready_obs), especially in tictactoe env. 
- self._logger.info('The current init_obs.keys() is {}'.format(init_obs.keys())) - self._logger.info('Before sleeping, the _env_states is {}'.format(self._env._env_states)) - time.sleep(retry_waiting_time) - self._logger.info( - '=' * 10 + 'Wait for all environments (subprocess) to finish resetting.' + '=' * 10 - ) - self._logger.info( - 'After sleeping {}s, the current _env_states is {}'.format( - retry_waiting_time, self._env._env_states - ) - ) - init_obs = self._env.ready_obs - - new_available_env_id = set(init_obs.keys()).difference(ready_env_id) - ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode])) - remain_episode -= min(len(new_available_env_id), remain_episode) - - action_mask_dict[env_id] = to_ndarray(init_obs[env_id]['action_mask']) - to_play_dict[env_id] = to_ndarray(init_obs[env_id]['to_play']) - - for agent_id in range(agent_num): - game_segments[env_id][agent_id] = GameSegment( - self._env.action_space, - game_segment_length=self.policy_config.game_segment_length, - config=self.policy_config - ) - observation_window_stack[env_id][agent_id] = deque( - [ - init_obs[env_id]['observation'][agent_id] - for _ in range(self.policy_config.model.frame_stack_num) - ], - maxlen=self.policy_config.model.frame_stack_num - ) - game_segments[env_id][agent_id].reset(observation_window_stack[env_id][agent_id]) - last_game_segments[env_id] = [None for _ in range(agent_num)] - last_game_priorities[env_id] = [None for _ in range(agent_num)] - - # log - self_play_moves_max = max(self_play_moves_max, eps_steps_lst[env_id]) - self_play_visit_entropy.append(visit_entropies_lst[env_id] / eps_steps_lst[env_id]) - self_play_moves += eps_steps_lst[env_id] - self_play_episodes += 1 - - pred_values_lst[env_id] = [] - search_values_lst[env_id] = [] - eps_steps_lst[env_id] = 0 - visit_entropies_lst[env_id] = 0 - - # Env reset is done by env_manager automatically - self._policy.reset([env_id]) - self._reset_stat(env_id) - # TODO(pu): subprocess mode, when n_episode > self._env_num, occasionally the ready_env_id=() - # and the stack_obs is np.array(None, dtype=object) - ready_env_id.remove(env_id) - - if collected_episode >= n_episode: - # [data, meta_data] - return_data = [self.game_segment_pool[i][0] for i in range(len(self.game_segment_pool))], [ - { - 'priorities': self.game_segment_pool[i][1], - 'done': self.game_segment_pool[i][2], - 'unroll_plus_td_steps': self.unroll_plus_td_steps - } for i in range(len(self.game_segment_pool)) - ] - # for i in range(len(self.game_segment_pool)): - # print(self.game_segment_pool[i][0].obs_segment.__len__()) - # print(self.game_segment_pool[i][0].reward_segment) - # for i in range(len(return_data[0])): - # print(return_data[0][i].reward_segment) - break - # log - self._output_log(train_iter) - return return_data - - def _output_log(self, train_iter: int) -> None: - """ - Overview: - Print the output log information. You can refer to Docs/Best Practice/How to understand\ - training generated folders/Serial mode/log/collector for more details. - Arguments: - - train_iter (:obj:`int`): the number of training iteration. 
- """ - if (train_iter - self._last_train_iter) >= self._collect_print_freq and len(self._episode_info) > 0: - self._last_train_iter = train_iter - episode_count = len(self._episode_info) - envstep_count = sum([d['step'] for d in self._episode_info]) - duration = sum([d['time'] for d in self._episode_info]) - episode_reward = [d['reward'] for d in self._episode_info] - visit_entropy = [d['visit_entropy'][0] for d in self._episode_info] - self._total_duration += duration - info = { - 'episode_count': episode_count, - 'envstep_count': envstep_count, - 'avg_envstep_per_episode': envstep_count / episode_count, - 'avg_envstep_per_sec': envstep_count / duration, - 'avg_episode_per_sec': episode_count / duration, - 'collect_time': duration, - 'reward_mean': np.mean(episode_reward), - 'reward_std': np.std(episode_reward), - 'reward_max': np.max(episode_reward), - 'reward_min': np.min(episode_reward), - 'total_envstep_count': self._total_envstep_count, - 'total_episode_count': self._total_episode_count, - 'total_duration': self._total_duration, - 'visit_entropy': np.mean(visit_entropy), - # 'each_reward': episode_reward, - } - self._episode_info.clear() - self._logger.info("collect end:\n{}".format('\n'.join(['{}: {}'.format(k, v) for k, v in info.items()]))) - for k, v in info.items(): - if k in ['each_reward']: - continue - self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter) - if k in ['total_envstep_count']: - continue - self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, self._total_envstep_count) diff --git a/lzero/worker/gobigger_muzero_evaluator.py b/lzero/worker/gobigger_muzero_evaluator.py index b57265203..ad9665f36 100644 --- a/lzero/worker/gobigger_muzero_evaluator.py +++ b/lzero/worker/gobigger_muzero_evaluator.py @@ -348,7 +348,7 @@ def eval( if t.done: # Env reset is done by env_manager automatically. self._policy.reset([env_id]) - reward = t.info['eval_episode_return'][0] + reward = t.info['eval_episode_return'] if 'episode_info' in t.info: eval_monitor.update_info(env_id, t.info['episode_info']) eval_monitor.update_reward(env_id, reward) @@ -633,8 +633,8 @@ def eval_vsbot( if t.done: # Env reset is done by env_manager automatically. 
self._policy.reset([env_id]) - reward = t.info['eval_episode_return'][0] - bot_reward = t.info['eval_episode_return'][1] + reward = t.info['eval_episode_return'] + bot_reward = t.info['eval_bot_episode_return'] eat_info[env_id] = t.info['eats'] if 'episode_info' in t.info: eval_monitor.update_info(env_id, t.info['episode_info']) diff --git a/lzero/worker/muzero_collector.py b/lzero/worker/muzero_collector.py index fffe6e84e..a269b1ad2 100644 --- a/lzero/worker/muzero_collector.py +++ b/lzero/worker/muzero_collector.py @@ -1,5 +1,5 @@ import time -from collections import deque, namedtuple +from collections import deque, namedtuple, defaultdict from typing import Optional, Any, List import numpy as np @@ -67,6 +67,10 @@ def __init__( ) self.policy_config = policy_config + if 'multi_agent' in self.policy_config.keys() and self.policy_config.multi_agent: + self._multi_agent = self.policy_config.multi_agent + else: + self._multi_agent = False self.reset(policy, env) @@ -292,6 +296,7 @@ def collect(self, if policy_kwargs is None: policy_kwargs = {} temperature = policy_kwargs['temperature'] + epsilon = policy_kwargs['epsilon'] collected_episode = 0 env_nums = self._env_num @@ -315,34 +320,68 @@ def collect(self, action_mask_dict = {i: to_ndarray(init_obs[i]['action_mask']) for i in range(env_nums)} to_play_dict = {i: to_ndarray(init_obs[i]['to_play']) for i in range(env_nums)} - game_segments = [ - GameSegment( - self._env.action_space, - game_segment_length=self.policy_config.game_segment_length, - config=self.policy_config - ) for _ in range(env_nums) - ] - # stacked observation windows in reset stage for init game_segments - observation_window_stack = [[] for _ in range(env_nums)] - for env_id in range(env_nums): - observation_window_stack[env_id] = deque( - [to_ndarray(init_obs[env_id]['observation']) for _ in range(self.policy_config.model.frame_stack_num)], - maxlen=self.policy_config.model.frame_stack_num - ) - - game_segments[env_id].reset(observation_window_stack[env_id]) - - dones = np.array([False for _ in range(env_nums)]) - last_game_segments = [None for _ in range(env_nums)] - last_game_priorities = [None for _ in range(env_nums)] - # for priorities in self-play - search_values_lst = [[] for _ in range(env_nums)] - pred_values_lst = [[] for _ in range(env_nums)] + if self._multi_agent: + agent_num = len(init_obs[0]['action_mask']) + game_segments = [ + [ + GameSegment( + self._env.action_space, + game_segment_length=self.policy_config.game_segment_length, + config=self.policy_config + ) for _ in range(agent_num) + ] for _ in range(env_nums) + ] + # stacked observation windows in reset stage for init game_segments + observation_window_stack = [[[] for _ in range(agent_num)] for _ in range(env_nums)] + for env_id in range(env_nums): + for agent_id in range(agent_num): + observation_window_stack[env_id][agent_id] = deque( + [ + to_ndarray(init_obs[env_id]['observation'][agent_id]) + for _ in range(self.policy_config.model.frame_stack_num) + ], + maxlen=self.policy_config.model.frame_stack_num + ) + game_segments[env_id][agent_id].reset(observation_window_stack[env_id][agent_id]) + + dones = np.array([False for _ in range(env_nums)]) + last_game_segments = [[None for _ in range(agent_num)] for _ in range(env_nums)] + last_game_priorities = [[None for _ in range(agent_num)] for _ in range(env_nums)] + # for priorities in self-play + search_values_lst = [[[] for _ in range(agent_num)] for _ in range(env_nums)] + pred_values_lst = [[[] for _ in range(agent_num)] for _ in range(env_nums)] 
+ else: + # some logs + game_segments = [ + GameSegment( + self._env.action_space, + game_segment_length=self.policy_config.game_segment_length, + config=self.policy_config + ) for _ in range(env_nums) + ] + # stacked observation windows in reset stage for init game_segments + observation_window_stack = [[] for _ in range(env_nums)] + for env_id in range(env_nums): + observation_window_stack[env_id] = deque( + [to_ndarray(init_obs[env_id]['observation']) for _ in range(self.policy_config.model.frame_stack_num)], + maxlen=self.policy_config.model.frame_stack_num + ) + game_segments[env_id].reset(observation_window_stack[env_id]) + + dones = np.array([False for _ in range(env_nums)]) + last_game_segments = [None for _ in range(env_nums)] + last_game_priorities = [None for _ in range(env_nums)] + # for priorities in self-play + search_values_lst = [[] for _ in range(env_nums)] + pred_values_lst = [[] for _ in range(env_nums)] if self.policy_config.gumbel_algo: improved_policy_lst = [[] for _ in range(env_nums)] # some logs - eps_steps_lst, visit_entropies_lst = np.zeros(env_nums), np.zeros(env_nums) + if self._multi_agent: + eps_steps_lst, visit_entropies_lst = np.zeros(env_nums), np.zeros((env_nums, agent_num)) + else: + eps_steps_lst, visit_entropies_lst = np.zeros(env_nums), np.zeros(env_nums) if self.policy_config.gumbel_algo: completed_value_lst = np.zeros(env_nums) self_play_moves = 0. @@ -361,8 +400,13 @@ def collect(self, new_available_env_id = set(obs.keys()).difference(ready_env_id) ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode])) remain_episode -= min(len(new_available_env_id), remain_episode) - - stack_obs = {env_id: game_segments[env_id].get_obs() for env_id in ready_env_id} + if self._multi_agent: + stack_obs = defaultdict(list) + for env_id in ready_env_id: + for agent_id in range(agent_num): + stack_obs[env_id].append(game_segments[env_id][agent_id].get_obs()) + else: + stack_obs = {env_id: game_segments[env_id].get_obs() for env_id in ready_env_id} stack_obs = list(stack_obs.values()) action_mask_dict = {env_id: action_mask_dict[env_id] for env_id in ready_env_id} @@ -372,16 +416,21 @@ def collect(self, stack_obs = to_ndarray(stack_obs) - stack_obs = prepare_observation(stack_obs, self.policy_config.model.model_type) - - stack_obs = torch.from_numpy(stack_obs).to(self.policy_config.device).float() + if self.policy_config.model.model_type and self.policy_config.model.model_type in ['conv', 'mlp']: + stack_obs = prepare_observation(stack_obs, self.policy_config.model.model_type) + stack_obs = torch.from_numpy(stack_obs).to(self.policy_config.device).float() # ============================================================== # policy forward # ============================================================== - policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play) - - actions_no_env_id = {k: v['action'] for k, v in policy_output.items()} + policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon) + if self._multi_agent: + actions_no_env_id = defaultdict(dict) + for k, v in policy_output.items(): + for agent_id, act in enumerate(v['action']): + actions_no_env_id[k][agent_id] = act + else: + actions_no_env_id = {k: v['action'] for k, v in policy_output.items()} distributions_dict_no_env_id = {k: v['distributions'] for k, v in policy_output.items()} if self.policy_config.sampled_algo: root_sampled_actions_dict_no_env_id = { @@ -444,19 +493,39 @@ def collect(self, obs, reward, done, info = 
timestep.obs, timestep.reward, timestep.done, timestep.info if self.policy_config.sampled_algo: - game_segments[env_id].store_search_stats( - distributions_dict[env_id], value_dict[env_id], root_sampled_actions_dict[env_id] - ) + if self._multi_agent: + for agent_id in range(agent_num): + game_segments[env_id][agent_id].store_search_stats( + distributions_dict[env_id][agent_id], value_dict[env_id][agent_id], + root_sampled_actions_dict[env_id][agent_id] + ) + else: + game_segments[env_id].store_search_stats( + distributions_dict[env_id], value_dict[env_id], root_sampled_actions_dict[env_id] + ) elif self.policy_config.gumbel_algo: game_segments[env_id].store_search_stats(distributions_dict[env_id], value_dict[env_id], improved_policy = improved_policy_dict[env_id]) else: - game_segments[env_id].store_search_stats(distributions_dict[env_id], value_dict[env_id]) + if self._multi_agent: + for agent_id in range(agent_num): + game_segments[env_id][agent_id].store_search_stats( + distributions_dict[env_id][agent_id], value_dict[env_id][agent_id] + ) + else: + game_segments[env_id].store_search_stats(distributions_dict[env_id], value_dict[env_id]) # append a transition tuple, including a_t, o_{t+1}, r_{t}, action_mask_{t}, to_play_{t} # in ``game_segments[env_id].init``, we have append o_{t} in ``self.obs_segment`` - game_segments[env_id].append( - actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id], - to_play_dict[env_id] - ) + if self._multi_agent: + for agent_id in range(agent_num): + game_segments[env_id][agent_id].append( + actions[env_id][agent_id], to_ndarray(obs['observation'][agent_id]), reward[agent_id], + action_mask_dict[env_id][agent_id], to_play_dict[env_id] + ) + else: + game_segments[env_id].append( + actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id], + to_play_dict[env_id] + ) # NOTE: the position of code snippet is very important. 
# the obs['action_mask'] and obs['to_play'] is corresponding to next action @@ -464,53 +533,94 @@ def collect(self, to_play_dict[env_id] = to_ndarray(obs['to_play']) dones[env_id] = done - visit_entropies_lst[env_id] += visit_entropy_dict[env_id] - if self.policy_config.gumbel_algo: - completed_value_lst[env_id] += np.mean(np.array(completed_value_dict[env_id])) + if self._multi_agent: + for agent_id in range(agent_num): + visit_entropies_lst[env_id][agent_id] += visit_entropy_dict[env_id][agent_id] + else: + visit_entropies_lst[env_id] += visit_entropy_dict[env_id] + if self.policy_config.gumbel_algo: + completed_value_lst[env_id] += np.mean(np.array(completed_value_dict[env_id])) eps_steps_lst[env_id] += 1 total_transitions += 1 if self.policy_config.use_priority and not self.policy_config.use_max_priority_for_new_data: - pred_values_lst[env_id].append(pred_value_dict[env_id]) - search_values_lst[env_id].append(value_dict[env_id]) - if self.policy_config.gumbel_algo: - improved_policy_lst[env_id].append(improved_policy_dict[env_id]) + if self._multi_agent: + for agent_id in range(agent_num): + pred_values_lst[env_id][agent_id].append(pred_value_dict[env_id][agent_id]) + search_values_lst[env_id][agent_id].append(value_dict[env_id][agent_id]) + else: + pred_values_lst[env_id].append(pred_value_dict[env_id]) + search_values_lst[env_id].append(value_dict[env_id]) + if self.policy_config.gumbel_algo: + improved_policy_lst[env_id].append(improved_policy_dict[env_id]) # append the newest obs - observation_window_stack[env_id].append(to_ndarray(obs['observation'])) + if self._multi_agent: + observation_window_stack[env_id][agent_id].append(to_ndarray(obs['observation'][agent_id])) + else: + observation_window_stack[env_id].append(to_ndarray(obs['observation'])) # ============================================================== # we will save a game block if it is the end of the game or the next game block is finished. 
# ============================================================== # if game block is full, we will save the last game block - if game_segments[env_id].is_full(): - # pad over last block trajectory - if last_game_segments[env_id] is not None: - # TODO(pu): return the one game block - self.pad_and_save_last_trajectory( - env_id, last_game_segments, last_game_priorities, game_segments, dones - ) - - # calculate priority - priorities = self._compute_priorities(env_id, pred_values_lst, search_values_lst) - pred_values_lst[env_id] = [] - search_values_lst[env_id] = [] - if self.policy_config.gumbel_algo: - improved_policy_lst[env_id] = [] - - # the current game_segments become last_game_segment - last_game_segments[env_id] = game_segments[env_id] - last_game_priorities[env_id] = priorities + if self._multi_agent: + for agent_id in range(agent_num): + if game_segments[env_id][agent_id].is_full(): + # pad over last block trajectory + if last_game_segments[env_id][agent_id] is not None: + # TODO(pu): return the one game block + self.pad_and_save_last_trajectory( + env_id, agent_id, last_game_segments, last_game_priorities, game_segments, dones + ) + + # calculate priority + priorities = self._compute_priorities(env_id, agent_id, pred_values_lst, search_values_lst) + # pred_values_lst[env_id] = [] + # search_values_lst[env_id] = [] + search_values_lst = [[[] for _ in range(agent_num)] for _ in range(env_nums)] + pred_values_lst = [[[] for _ in range(agent_num)] for _ in range(env_nums)] + + # the current game_segments become last_game_segment + last_game_segments[env_id][agent_id] = game_segments[env_id][agent_id] + last_game_priorities[env_id][agent_id] = priorities + + # create new GameSegment + game_segments[env_id][agent_id] = GameSegment( + self._env.action_space, + game_segment_length=self.policy_config.game_segment_length, + config=self.policy_config + ) + game_segments[env_id][agent_id].reset(observation_window_stack[env_id][agent_id]) + else: + if game_segments[env_id].is_full(): + # pad over last block trajectory + if last_game_segments[env_id] is not None: + # TODO(pu): return the one game block + self.pad_and_save_last_trajectory( + env_id, last_game_segments, last_game_priorities, game_segments, dones + ) - # create new GameSegment - game_segments[env_id] = GameSegment( - self._env.action_space, - game_segment_length=self.policy_config.game_segment_length, - config=self.policy_config - ) - game_segments[env_id].reset(observation_window_stack[env_id]) + # calculate priority + priorities = self._compute_priorities(env_id, pred_values_lst, search_values_lst) + pred_values_lst[env_id] = [] + search_values_lst[env_id] = [] + if self.policy_config.gumbel_algo: + improved_policy_lst[env_id] = [] + + # the current game_segments become last_game_segment + last_game_segments[env_id] = game_segments[env_id] + last_game_priorities[env_id] = priorities + + # create new GameSegment + game_segments[env_id] = GameSegment( + self._env.action_space, + game_segment_length=self.policy_config.game_segment_length, + config=self.policy_config + ) + game_segments[env_id].reset(observation_window_stack[env_id]) self._env_info[env_id]['step'] += 1 self._total_envstep_count += 1 @@ -537,21 +647,38 @@ def collect(self, # NOTE: put the penultimate game block in one episode into the trajectory_pool # pad over 2th last game_segment using the last game_segment - if last_game_segments[env_id] is not None: - self.pad_and_save_last_trajectory( - env_id, last_game_segments, last_game_priorities, game_segments, dones - ) + 
if self._multi_agent: + for agent_id in range(agent_num): + if last_game_segments[env_id] is not None: + self.pad_and_save_last_trajectory( + env_id, agent_id, last_game_segments, last_game_priorities, game_segments, dones + ) + + # store current block trajectory + priorities = self._compute_priorities(env_id, agent_id, pred_values_lst, search_values_lst) - # store current block trajectory - priorities = self._compute_priorities(env_id, pred_values_lst, search_values_lst) + # NOTE: put the last game block in one episode into the trajectory_pool + game_segments[env_id][agent_id].game_segment_to_array() - # NOTE: put the last game block in one episode into the trajectory_pool - game_segments[env_id].game_segment_to_array() + # assert len(game_segments[env_id]) == len(priorities) + # NOTE: save the last game block in one episode into the trajectory_pool if it's not null + if len(game_segments[env_id][agent_id].reward_segment) != 0: + self.game_segment_pool.append((game_segments[env_id][agent_id], priorities, dones[env_id])) + else: + if last_game_segments[env_id] is not None: + self.pad_and_save_last_trajectory( + env_id, last_game_segments, last_game_priorities, game_segments, dones + ) + # store current block trajectory + priorities = self._compute_priorities(env_id, pred_values_lst, search_values_lst) + + # NOTE: put the last game block in one episode into the trajectory_pool + game_segments[env_id].game_segment_to_array() - # assert len(game_segments[env_id]) == len(priorities) - # NOTE: save the last game block in one episode into the trajectory_pool if it's not null - if len(game_segments[env_id].reward_segment) != 0: - self.game_segment_pool.append((game_segments[env_id], priorities, dones[env_id])) + # assert len(game_segments[env_id]) == len(priorities) + # NOTE: save the last game block in one episode into the trajectory_pool if it's not null + if len(game_segments[env_id].reward_segment) != 0: + self.game_segment_pool.append((game_segments[env_id], priorities, dones[env_id])) # print(game_segments[env_id].reward_segment) # reset the finished env and init game_segments @@ -582,18 +709,37 @@ def collect(self, action_mask_dict[env_id] = to_ndarray(init_obs[env_id]['action_mask']) to_play_dict[env_id] = to_ndarray(init_obs[env_id]['to_play']) - game_segments[env_id] = GameSegment( - self._env.action_space, - game_segment_length=self.policy_config.game_segment_length, - config=self.policy_config - ) - observation_window_stack[env_id] = deque( - [init_obs[env_id]['observation'] for _ in range(self.policy_config.model.frame_stack_num)], - maxlen=self.policy_config.model.frame_stack_num - ) - game_segments[env_id].reset(observation_window_stack[env_id]) - last_game_segments[env_id] = None - last_game_priorities[env_id] = None + if self._multi_agent: + for agent_id in range(agent_num): + game_segments[env_id][agent_id] = GameSegment( + self._env.action_space, + game_segment_length=self.policy_config.game_segment_length, + config=self.policy_config + ) + observation_window_stack[env_id][agent_id] = deque( + [ + init_obs[env_id]['observation'][agent_id] + for _ in range(self.policy_config.model.frame_stack_num) + ], + maxlen=self.policy_config.model.frame_stack_num + ) + game_segments[env_id][agent_id].reset(observation_window_stack[env_id][agent_id]) + last_game_segments[env_id] = [None for _ in range(agent_num)] + last_game_priorities[env_id] = [None for _ in range(agent_num)] + + else: + game_segments[env_id] = GameSegment( + self._env.action_space, + 
game_segment_length=self.policy_config.game_segment_length, + config=self.policy_config + ) + observation_window_stack[env_id] = deque( + [init_obs[env_id]['observation'] for _ in range(self.policy_config.model.frame_stack_num)], + maxlen=self.policy_config.model.frame_stack_num + ) + game_segments[env_id].reset(observation_window_stack[env_id]) + last_game_segments[env_id] = None + last_game_priorities[env_id] = None # log self_play_moves_max = max(self_play_moves_max, eps_steps_lst[env_id]) diff --git a/zoo/gobigger/config/gobigger_efficientzero_config.py b/zoo/gobigger/config/gobigger_efficientzero_config.py index dfdadc105..4d67c6cc8 100644 --- a/zoo/gobigger/config/gobigger_efficientzero_config.py +++ b/zoo/gobigger/config/gobigger_efficientzero_config.py @@ -1,6 +1,7 @@ from easydict import EasyDict env_name = 'GoBigger' +multi_agent = True # ============================================================== # begin of the most frequently changed config specified by the user @@ -67,7 +68,9 @@ manager=dict(shared_memory=False, ), ), policy=dict( + multi_agent=multi_agent, model=dict( + model_type='structured', latent_state_dim=176, frame_stack_num=1, action_space_size=action_space_size, diff --git a/zoo/gobigger/env/gobigger_env.py b/zoo/gobigger/env/gobigger_env.py index 29153163b..e9197296a 100644 --- a/zoo/gobigger/env/gobigger_env.py +++ b/zoo/gobigger/env/gobigger_env.py @@ -100,7 +100,8 @@ def step(self, action_dict: dict) -> BaseEnvTimestep: # postprocess self.postproecess(action_dict) if done: - info['eval_episode_return'] = [raw_obs[0]['leaderboard'][i] for i in range(self.team_num)] + info['eval_episode_return'] = raw_obs[0]['leaderboard'][0] + info['eval_bot_episode_return'] = raw_obs[0]['leaderboard'][1] #TODO only support t2p2 return BaseEnvTimestep(obs, rew, done, info) def seed(self, seed: int, dynamic_seed: bool = True) -> None: From 1ed22b2ef54342cb8a5929957997b31a84b60f05 Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Mon, 3 Jul 2023 21:29:41 +0800 Subject: [PATCH 26/54] fix(yzj): fix collector bug that observation_window_stack no for_loop on agent_id --- lzero/worker/muzero_collector.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lzero/worker/muzero_collector.py b/lzero/worker/muzero_collector.py index a269b1ad2..34500f924 100644 --- a/lzero/worker/muzero_collector.py +++ b/lzero/worker/muzero_collector.py @@ -557,7 +557,8 @@ def collect(self, # append the newest obs if self._multi_agent: - observation_window_stack[env_id][agent_id].append(to_ndarray(obs['observation'][agent_id])) + for agent_id in range(agent_num): + observation_window_stack[env_id][agent_id].append(to_ndarray(obs['observation'][agent_id])) else: observation_window_stack[env_id].append(to_ndarray(obs['observation'])) From 35e77149fdc071cfec4176481dfa916497c55e54 Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Wed, 5 Jul 2023 22:51:52 +0800 Subject: [PATCH 27/54] fix(yzj): fix ignore done in collector --- lzero/worker/muzero_collector.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/lzero/worker/muzero_collector.py b/lzero/worker/muzero_collector.py index 34500f924..f3dc907a8 100644 --- a/lzero/worker/muzero_collector.py +++ b/lzero/worker/muzero_collector.py @@ -532,7 +532,11 @@ def collect(self, action_mask_dict[env_id] = to_ndarray(obs['action_mask']) to_play_dict[env_id] = to_ndarray(obs['to_play']) - dones[env_id] = done + if self.policy_config.ignore_done: + dones[env_id] = False + else: + dones[env_id] = done + if self._multi_agent: 
for agent_id in range(agent_num): visit_entropies_lst[env_id][agent_id] += visit_entropy_dict[env_id][agent_id] @@ -670,14 +674,15 @@ def collect(self, self.pad_and_save_last_trajectory( env_id, last_game_segments, last_game_priorities, game_segments, dones ) - # store current block trajectory + + # store current segment trajectory priorities = self._compute_priorities(env_id, pred_values_lst, search_values_lst) - # NOTE: put the last game block in one episode into the trajectory_pool + # NOTE: put the last game segment in one episode into the trajectory_pool game_segments[env_id].game_segment_to_array() # assert len(game_segments[env_id]) == len(priorities) - # NOTE: save the last game block in one episode into the trajectory_pool if it's not null + # NOTE: save the last game segment in one episode into the trajectory_pool if it's not null if len(game_segments[env_id].reward_segment) != 0: self.game_segment_pool.append((game_segments[env_id], priorities, dones[env_id])) From 4df3ada691964887fe62263b76aa3971d607a377 Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Wed, 5 Jul 2023 22:59:35 +0800 Subject: [PATCH 28/54] polish(yzj): polish ez config ignore done --- zoo/gobigger/config/gobigger_efficientzero_config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/zoo/gobigger/config/gobigger_efficientzero_config.py b/zoo/gobigger/config/gobigger_efficientzero_config.py index 4d67c6cc8..dade662ac 100644 --- a/zoo/gobigger/config/gobigger_efficientzero_config.py +++ b/zoo/gobigger/config/gobigger_efficientzero_config.py @@ -69,6 +69,7 @@ ), policy=dict( multi_agent=multi_agent, + ignore_done=True, model=dict( model_type='structured', latent_state_dim=176, From 272611fa5c43cb2f222cc5614e6876ee00db3038 Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Wed, 5 Jul 2023 23:02:24 +0800 Subject: [PATCH 29/54] fix(yzj): add game_segment_pool clear() --- lzero/worker/muzero_collector.py | 1 + 1 file changed, 1 insertion(+) diff --git a/lzero/worker/muzero_collector.py b/lzero/worker/muzero_collector.py index f3dc907a8..c6efb8a50 100644 --- a/lzero/worker/muzero_collector.py +++ b/lzero/worker/muzero_collector.py @@ -774,6 +774,7 @@ def collect(self, 'unroll_plus_td_steps': self.unroll_plus_td_steps } for i in range(len(self.game_segment_pool)) ] + self.game_segment_pool.clear() # for i in range(len(self.game_segment_pool)): # print(self.game_segment_pool[i][0].obs_segment.__len__()) # print(self.game_segment_pool[i][0].reward_segment) From cc5499678aeabb3d52dc7f1626d0945958292318 Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Wed, 12 Jul 2023 11:38:25 +0800 Subject: [PATCH 30/54] polish(yzj): add gobigger/entry , polish gobigger config and add default env config(t2p2) --- lzero/entry/__init__.py | 2 - .../config/gobigger_efficientzero_config.py | 43 ++------------ zoo/gobigger/config/gobigger_eval_config.py | 2 +- zoo/gobigger/config/gobigger_muzero_config.py | 49 +++------------- .../gobigger_sampled_efficientzero_config.py | 58 ++++++------------- zoo/gobigger/entry/__init__.py | 2 + .../gobigger}/entry/eval_muzero_gobigger.py | 0 .../gobigger}/entry/train_muzero_gobigger.py | 4 ++ zoo/gobigger/env/gobigger_env.py | 45 +++++++++++++- 9 files changed, 82 insertions(+), 123 deletions(-) create mode 100644 zoo/gobigger/entry/__init__.py rename {lzero => zoo/gobigger}/entry/eval_muzero_gobigger.py (100%) rename {lzero => zoo/gobigger}/entry/train_muzero_gobigger.py (97%) diff --git a/lzero/entry/__init__.py b/lzero/entry/__init__.py index c6db99f30..352d29ddf 100644 --- a/lzero/entry/__init__.py 
+++ b/lzero/entry/__init__.py @@ -4,5 +4,3 @@ from .eval_muzero import eval_muzero from .eval_muzero_with_gym_env import eval_muzero_with_gym_env from .train_muzero_with_gym_env import train_muzero_with_gym_env -from .train_muzero_gobigger import train_muzero_gobigger -from .eval_muzero_gobigger import eval_muzero_gobigger diff --git a/zoo/gobigger/config/gobigger_efficientzero_config.py b/zoo/gobigger/config/gobigger_efficientzero_config.py index dade662ac..e823713d2 100644 --- a/zoo/gobigger/config/gobigger_efficientzero_config.py +++ b/zoo/gobigger/config/gobigger_efficientzero_config.py @@ -26,42 +26,7 @@ exp_name= f'data_ez_ctree/{env_name}_efficientzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed{seed}', env=dict( - env_name=env_name, - team_num=2, - player_num_per_team=2, - direction_num=direction_num, - step_mul=8, - map_width=64, - map_height=64, - frame_limit=3600, - action_space_size=action_space_size, - use_action_mask=False, - reward_div_value=0.1, - reward_type='log_reward', - contain_raw_obs=False, # False on collect mode, True on eval vsbot mode, because bot need raw obs - start_spirit_progress=0.2, - end_spirit_progress=0.8, - manager_settings=dict( - food_manager=dict( - num_init=260, - num_min=260, - num_max=300, - ), - thorns_manager=dict( - num_init=3, - num_min=3, - num_max=4, - ), - player_manager=dict(ball_settings=dict(score_init=13000, ), ), - ), - playback_settings=dict( - playback_type='by_frame', - by_frame=dict( - save_frame=False, # when training should set as False - save_dir='./', - save_name_prefix='gobigger', - ), - ), + env_name=env_name, # default is 'GoBigger T2P2' collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, n_evaluator_episode=evaluator_env_num, @@ -87,10 +52,10 @@ random_collect_episode_num=random_collect_episode_num, eps=dict( eps_greedy_exploration_in_collect=eps_greedy_exploration_in_collect, - type='exp', + type='linear', start=1., end=0.05, - decay=int(1.5e4), + decay=int(1e5), ), use_augmentation=False, update_per_collect=update_per_collect, @@ -133,5 +98,5 @@ create_config = atari_efficientzero_create_config if __name__ == "__main__": - from lzero.entry import train_muzero_gobigger + from zoo.gobigger.entry import train_muzero_gobigger train_muzero_gobigger([main_config, create_config], seed=seed) diff --git a/zoo/gobigger/config/gobigger_eval_config.py b/zoo/gobigger/config/gobigger_eval_config.py index e32b71151..8b98c1a1f 100644 --- a/zoo/gobigger/config/gobigger_eval_config.py +++ b/zoo/gobigger/config/gobigger_eval_config.py @@ -1,5 +1,5 @@ # According to the model you want to evaluate, import the corresponding config. 
-from lzero.entry import eval_muzero_gobigger +from zoo.gobigger.entry import eval_muzero_gobigger import numpy as np if __name__ == "__main__": diff --git a/zoo/gobigger/config/gobigger_muzero_config.py b/zoo/gobigger/config/gobigger_muzero_config.py index 6effc4dc5..05e91a8a1 100644 --- a/zoo/gobigger/config/gobigger_muzero_config.py +++ b/zoo/gobigger/config/gobigger_muzero_config.py @@ -1,6 +1,7 @@ from easydict import EasyDict env_name = 'GoBigger' +multi_agent = True # ============================================================== # begin of the most frequently changed config specified by the user @@ -24,50 +25,17 @@ atari_muzero_config = dict( exp_name=f'data_mz_ctree/{env_name}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed{seed}', env=dict( - env_name=env_name, - team_num=2, - player_num_per_team=2, - direction_num=direction_num, - step_mul=8, - map_width=64, - map_height=64, - frame_limit=3600, - action_space_size=action_space_size, - use_action_mask=False, - reward_div_value=0.1, - reward_type='log_reward', - contain_raw_obs=False, # False on collect mode, True on eval vsbot mode, because bot need raw obs - start_spirit_progress=0.2, - end_spirit_progress=0.8, - manager_settings=dict( - food_manager=dict( - num_init=260, - num_min=260, - num_max=300, - ), - thorns_manager=dict( - num_init=3, - num_min=3, - num_max=4, - ), - player_manager=dict(ball_settings=dict(score_init=13000, ), ), - ), - playback_settings=dict( - playback_type='by_frame', - by_frame=dict( - save_frame=False, - # save_frame=True, - save_dir='./', - save_name_prefix='gobigger', - ), - ), + env_name=env_name, # default is 'GoBigger T2P2' collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, n_evaluator_episode=evaluator_env_num, manager=dict(shared_memory=False, ), ), policy=dict( + multi_agent=multi_agent, + ignore_done=True, model=dict( + model_type='structured', latent_state_dim=176, frame_stack_num=1, action_space_size=action_space_size, @@ -78,15 +46,16 @@ ), cuda=True, mcts_ctree=True, + gumbel_algo=False, env_type='not_board_games', game_segment_length=400, random_collect_episode_num=random_collect_episode_num, eps=dict( eps_greedy_exploration_in_collect=eps_greedy_exploration_in_collect, - type='exp', + type='linear', start=1., end=0.05, - decay=int(1.5e4), + decay=int(1e5), ), use_augmentation=False, update_per_collect=update_per_collect, @@ -130,5 +99,5 @@ create_config = atari_muzero_create_config if __name__ == "__main__": - from lzero.entry import train_muzero_gobigger + from zoo.gobigger.entry import train_muzero_gobigger train_muzero_gobigger([main_config, create_config], seed=seed) diff --git a/zoo/gobigger/config/gobigger_sampled_efficientzero_config.py b/zoo/gobigger/config/gobigger_sampled_efficientzero_config.py index f847077fc..a9e9306e6 100644 --- a/zoo/gobigger/config/gobigger_sampled_efficientzero_config.py +++ b/zoo/gobigger/config/gobigger_sampled_efficientzero_config.py @@ -1,10 +1,12 @@ from easydict import EasyDict env_name = 'GoBigger' +multi_agent = True # ============================================================== # begin of the most frequently changed config specified by the user # ============================================================== +seed = 0 continuous_action_space = False K = 20 # num_of_sampled_actions collector_env_num = 32 @@ -16,58 +18,27 @@ reanalyze_ratio = 0. 
action_space_size = 27 direction_num = 12 +eps_greedy_exploration_in_collect = True +random_collect_episode_num = 0 # ============================================================== # end of the most frequently changed config specified by the user # ============================================================== atari_sampled_efficientzero_config = dict( exp_name= - f'data_sez_ctree/{env_name[:-14]}_sampled_efficientzero_k{K}_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', + f'data_sez_ctree/{env_name}_sampled_efficientzero_k{K}_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed{seed}', env=dict( - env_name=env_name, - team_num=2, - player_num_per_team=2, - direction_num=direction_num, - step_mul=8, - map_width=64, - map_height=64, - frame_limit=3600, - action_space_size=action_space_size, - use_action_mask=False, - reward_div_value=0.1, - reward_type='log_reward', - contain_raw_obs=False, # False on collect mode, True on eval vsbot mode, because bot need raw obs - start_spirit_progress=0.2, - end_spirit_progress=0.8, - manager_settings=dict( - food_manager=dict( - num_init=260, - num_min=260, - num_max=300, - ), - thorns_manager=dict( - num_init=3, - num_min=3, - num_max=4, - ), - player_manager=dict(ball_settings=dict(score_init=13000, ), ), - ), - playback_settings=dict( - playback_type='by_frame', - by_frame=dict( - save_frame=False, - # save_frame=True, - save_dir='./', - save_name_prefix='gobigger', - ), - ), + env_name=env_name, # default is 'GoBigger T2P2' collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, n_evaluator_episode=evaluator_env_num, manager=dict(shared_memory=False, ), ), policy=dict( + multi_agent=multi_agent, + ignore_done=True, model=dict( + model_type='structured', latent_state_dim=176, frame_stack_num=1, action_space_size=action_space_size, @@ -79,8 +50,17 @@ ), cuda=True, mcts_ctree=True, + gumbel_algo=False, env_type='not_board_games', game_segment_length=400, + random_collect_episode_num=random_collect_episode_num, + eps=dict( + eps_greedy_exploration_in_collect=eps_greedy_exploration_in_collect, + type='linear', + start=1., + end=0.05, + decay=int(1e5), + ), use_augmentation=False, update_per_collect=update_per_collect, batch_size=batch_size, @@ -123,5 +103,5 @@ create_config = atari_sampled_efficientzero_create_config if __name__ == "__main__": - from lzero.entry import train_muzero_gobigger + from zoo.gobigger.entry import train_muzero_gobigger train_muzero_gobigger([main_config, create_config], seed=0) diff --git a/zoo/gobigger/entry/__init__.py b/zoo/gobigger/entry/__init__.py new file mode 100644 index 000000000..641457eff --- /dev/null +++ b/zoo/gobigger/entry/__init__.py @@ -0,0 +1,2 @@ +from .train_muzero_gobigger import train_muzero_gobigger +from .eval_muzero_gobigger import eval_muzero_gobigger diff --git a/lzero/entry/eval_muzero_gobigger.py b/zoo/gobigger/entry/eval_muzero_gobigger.py similarity index 100% rename from lzero/entry/eval_muzero_gobigger.py rename to zoo/gobigger/entry/eval_muzero_gobigger.py diff --git a/lzero/entry/train_muzero_gobigger.py b/zoo/gobigger/entry/train_muzero_gobigger.py similarity index 97% rename from lzero/entry/train_muzero_gobigger.py rename to zoo/gobigger/entry/train_muzero_gobigger.py index 57d2d139d..9f83ceb0e 100644 --- a/lzero/entry/train_muzero_gobigger.py +++ b/zoo/gobigger/entry/train_muzero_gobigger.py @@ -151,6 +151,10 @@ def train_muzero_gobigger( collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep) else: 
collect_kwargs['epsilon'] = 0.0 + + stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) + stop, reward = vsbot_evaluator.eval_vsbot(learner.save_checkpoint, learner.train_iter, collector.envstep) + # Evaluate policy performance. if evaluator.should_eval(learner.train_iter): stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) diff --git a/zoo/gobigger/env/gobigger_env.py b/zoo/gobigger/env/gobigger_env.py index e9197296a..b9a2b0f63 100644 --- a/zoo/gobigger/env/gobigger_env.py +++ b/zoo/gobigger/env/gobigger_env.py @@ -2,8 +2,9 @@ import numpy as np from ditk import logging from ding.envs import BaseEnv, BaseEnvTimestep -from ding.utils import ENV_REGISTRY +from ding.utils import ENV_REGISTRY, deep_merge_dicts import math +from easydict import EasyDict try: from gobigger.envs import GoBiggerEnv except ImportError: @@ -12,11 +13,51 @@ sys.exit(1) +default_t2p2_config = dict( + team_num=2, + player_num_per_team=2, + direction_num=12, + step_mul=8, + map_width=64, + map_height=64, + frame_limit=3600, + action_space_size=27, + use_action_mask=False, + reward_div_value=0.1, + reward_type='log_reward', + contain_raw_obs=False, # False on collect mode, True on eval vsbot mode, because bot need raw obs + start_spirit_progress=0.2, + end_spirit_progress=0.8, + manager_settings=dict( + food_manager=dict( + num_init=260, + num_min=260, + num_max=300, + ), + thorns_manager=dict( + num_init=3, + num_min=3, + num_max=4, + ), + player_manager=dict(ball_settings=dict(score_init=13000, ), ), + ), + playback_settings=dict( + playback_type='by_frame', + by_frame=dict( + save_frame=False, # when training should set as False + save_dir='./', + save_name_prefix='gobigger', + ), + ), + ) + + @ENV_REGISTRY.register('gobigger_lightzero') class GoBiggerLightZeroEnv(BaseEnv): def __init__(self, cfg: dict) -> None: - self._cfg = cfg + self._cfg = deep_merge_dicts(default_t2p2_config, cfg) + self._cfg = EasyDict(self._cfg) # ding env info self._init_flag = False self._observation_space = None From 39802f55616cb376039da66ce7444fae6c1c6d39 Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Mon, 17 Jul 2023 16:51:44 +0800 Subject: [PATCH 31/54] polish(yzj): polish eps greedy and random policy --- lzero/entry/train_muzero.py | 12 + lzero/entry/utils.py | 49 -- lzero/policy/efficientzero.py | 4 + lzero/policy/gobigger_efficientzero.py | 307 +-------- lzero/policy/gobigger_muzero.py | 328 +-------- lzero/policy/gobigger_random_policy.py | 328 --------- .../policy/gobigger_sampled_efficientzero.py | 630 +----------------- lzero/policy/muzero.py | 4 + lzero/policy/sampled_efficientzero.py | 5 +- lzero/worker/gobigger_muzero_evaluator.py | 499 ++------------ lzero/worker/muzero_collector.py | 5 +- lzero/worker/muzero_evaluator.py | 130 +++- .../config/atari_efficientzero_config.py | 12 + zoo/gobigger/entry/train_muzero_gobigger.py | 14 +- 14 files changed, 279 insertions(+), 2048 deletions(-) delete mode 100644 lzero/policy/gobigger_random_policy.py diff --git a/lzero/entry/train_muzero.py b/lzero/entry/train_muzero.py index 06ae053cc..5650988d5 100644 --- a/lzero/entry/train_muzero.py +++ b/lzero/entry/train_muzero.py @@ -9,6 +9,7 @@ from ding.envs import get_vec_env_setting from ding.policy import create_policy from ding.utils import set_pkg_seed +from ding.rl_utils import get_epsilon_greedy_fn from ding.worker import BaseLearner from tensorboardX import SummaryWriter @@ -126,6 +127,17 @@ def train_muzero( trained_steps=learner.train_iter ) + 
if policy_config.eps.eps_greedy_exploration_in_collect: + epsilon_greedy_fn = get_epsilon_greedy_fn( + start=policy_config.eps.start, + end=policy_config.eps.end, + decay=policy_config.eps.decay, + type_=policy_config.eps.type + ) + collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep) + else: + collect_kwargs['epsilon'] = 0.0 + # Evaluate policy performance. if evaluator.should_eval(learner.train_iter): stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) diff --git a/lzero/entry/utils.py b/lzero/entry/utils.py index 9c6e777df..9963af566 100644 --- a/lzero/entry/utils.py +++ b/lzero/entry/utils.py @@ -6,55 +6,6 @@ from pympler.asizeof import asizeof from tensorboardX import SummaryWriter -from lzero.policy.gobigger_random_policy import GoBiggerRandomPolicy - - -def random_collect( - policy_cfg: 'EasyDict', # noqa - policy: 'Policy', # noqa - collector: 'ISerialCollector', # noqa - collector_env: 'BaseEnvManager', # noqa - replay_buffer: 'IBuffer', # noqa - postprocess_data_fn: Optional[Callable] = None -) -> None: # noqa - """ - Overview: - Collect data by random policy. - Arguments: - - policy_cfg (:obj:`EasyDict`): The policy config. - - policy (:obj:`Policy`): The policy. - - collector (:obj:`ISerialCollector`): The collector. - - collector_env (:obj:`BaseEnvManager`): The collector env manager. - - replay_buffer (:obj:`IBuffer`): The replay buffer. - - postprocess_data_fn (:obj:`Optional[Callable]`): The postprocess function for the collected data. - """ - assert policy_cfg.random_collect_episode_num > 0 - - random_policy = GoBiggerRandomPolicy(cfg=policy_cfg) - # set the policy to random policy - collector.reset_policy(random_policy.collect_mode) - - collect_kwargs = {} - # set temperature for visit count distributions according to the train_iter, - # please refer to Appendix D in MuZero paper for details. - collect_kwargs['temperature'] = 1 - collect_kwargs['epsilon'] = 0.0 - - # Collect data by default config n_sample/n_episode. - new_data = collector.collect(train_iter=0, policy_kwargs=collect_kwargs) - - if postprocess_data_fn is not None: - new_data = postprocess_data_fn(new_data) - - # save returned new_data collected by the collector - replay_buffer.push_game_segments(new_data) - # remove the oldest data if the replay buffer is full. 
- replay_buffer.remove_oldest_data_to_fit() - - # restore the policy - collector.reset_policy(policy.collect_mode) - - def log_buffer_memory_usage(train_iter: int, buffer: "GameBuffer", writer: SummaryWriter) -> None: """ Overview: diff --git a/lzero/policy/efficientzero.py b/lzero/policy/efficientzero.py index 6b59ab752..84dd6f1d9 100644 --- a/lzero/policy/efficientzero.py +++ b/lzero/policy/efficientzero.py @@ -501,6 +501,7 @@ def _init_collect(self) -> None: else: self._mcts_collect = MCTSPtree(self._cfg) self.collect_mcts_temperature = 1 + self.collect_epsilon = 1 def _forward_collect( self, @@ -508,6 +509,8 @@ def _forward_collect( action_mask: list = None, temperature: float = 1, to_play: List = [-1], + random_collect_episode_num: int = 0, + epsilon: float = 0.25, ready_env_id=None ): """ @@ -697,6 +700,7 @@ def _monitor_vars_learn(self) -> List[str]: """ return [ 'collect_mcts_temperature', + 'collect_epsilon', 'cur_lr', 'weighted_total_loss', 'total_loss', diff --git a/lzero/policy/gobigger_efficientzero.py b/lzero/policy/gobigger_efficientzero.py index f2f89d96d..0b8885fbc 100644 --- a/lzero/policy/gobigger_efficientzero.py +++ b/lzero/policy/gobigger_efficientzero.py @@ -1,11 +1,8 @@ -import copy from typing import List, Dict, Any, Tuple, Union import numpy as np import torch -import torch.optim as optim -from ding.model import model_wrap -from ding.policy.base_policy import Policy +from .efficientzero import EfficientZeroPolicy from ding.torch_utils import to_tensor from ding.utils import POLICY_REGISTRY from torch.distributions import Categorical @@ -13,7 +10,6 @@ from lzero.mcts import EfficientZeroMCTSCtree as MCTSCtree from lzero.mcts import EfficientZeroMCTSPtree as MCTSPtree -from lzero.model import ImageTransforms from lzero.policy import scalar_transform, InverseScalarTransform, cross_entropy_loss, phi_transform, \ DiscreteSupport, select_action, to_torch_float_tensor, ez_network_output_unpack, negative_cosine_similarity, prepare_obs, \ configure_optimizers @@ -22,147 +18,12 @@ @POLICY_REGISTRY.register('gobigger_efficientzero') -class GoBiggerEfficientZeroPolicy(Policy): +class GoBiggerEfficientZeroPolicy(EfficientZeroPolicy): """ Overview: The policy class for GoBiggerEfficientZero. """ - # The default_config for EfficientZero policy. - config = dict( - model=dict( - # (str) The model type. For 1-dimensional vector obs, we use mlp model. For 3-dimensional image obs, we use conv model. - model_type='conv', # options={'mlp', 'conv'} - # (bool) If True, the action space of the environment is continuous, otherwise discrete. - continuous_action_space=False, - # (tuple) The stacked obs shape. - # observation_shape=(1, 96, 96), # if frame_stack_num=1 - observation_shape=(4, 96, 96), # if frame_stack_num=4 - # (bool) Whether to use the self-supervised learning loss. - self_supervised_learning_loss=True, - # (bool) Whether to use discrete support to represent categorical distribution for value/reward/value_prefix. - categorical_distribution=True, - # (int) The image channel in image observation. - image_channel=1, - # (int) The number of frames to stack together. - frame_stack_num=1, - # (int) The scale of supports used in categorical distribution. - # This variable is only effective when ``categorical_distribution=True``. - support_scale=300, - # (int) The hidden size in LSTM. - lstm_hidden_size=512, - # (bool) whether to learn bias in the last linear layer in value and policy head. - bias=True, - # (str) The type of action encoding. Options are ['one_hot', 'not_one_hot']. 
Default to 'one_hot'. - discrete_action_encoding_type='one_hot', - # (bool) whether to use res connection in dynamics. - res_connection_in_dynamics=True, - # (str) The type of normalization in MuZero model. Options are ['BN', 'LN']. Default to 'LN'. - norm_type='BN', - ), - # ****** common ****** - # (bool) Whether to enable the sampled-based algorithm (e.g. Sampled EfficientZero) - # this variable is used in ``collector``. - sampled_algo=False, - # (bool) Whether to use C++ MCTS in policy. If False, use Python implementation. - mcts_ctree=True, - # (bool) Whether to use cuda for network. - cuda=True, - # (int) The number of environments used in collecting data. - collector_env_num=8, - # (int) The number of environments used in evaluating policy. - evaluator_env_num=3, - # (str) The type of environment. The options are ['not_board_games', 'board_games']. - env_type='not_board_games', - # (str) The type of battle mode. The options are ['play_with_bot_mode', 'self_play_mode']. - battle_mode='play_with_bot_mode', - # (bool) Whether to monitor extra statistics in tensorboard. - monitor_extra_statistics=True, - # (int) The transition number of one ``GameSegment``. - game_segment_length=200, - - # ****** observation ****** - # (bool) Whether to transform image to string to save memory. - transform2string=False, - # (bool) Whether to use data augmentation. - use_augmentation=False, - # (list) The style of augmentation. - augmentation=['shift', 'intensity'], - - # ******* learn ****** - # (int) How many updates(iterations) to train after collector's one collection. - # Bigger "update_per_collect" means bigger off-policy. - # collect data -> update policy-> collect data -> ... - # For different env, we have different episode_length, - # we usually set update_per_collect = collector_env_num * episode_length / batch_size * reuse_factor - update_per_collect=100, - # (int) Minibatch size for one gradient descent. - batch_size=256, - # (str) Optimizer for training policy network. ['SGD', 'Adam', 'AdamW'] - optim_type='SGD', - # (float) Learning rate for training policy network. Initial lr for manually decay schedule. - learning_rate=0.2, - # (int) Frequency of target network update. - target_update_freq=100, - # (float) Weight decay for training policy network. - weight_decay=1e-4, - # (float) One-order Momentum in optimizer, which stabilizes the training process (gradient direction). - momentum=0.9, - # (float) The maximum constraint value of gradient norm clipping. - grad_clip_value=10, - # (int) The number of episode in each collecting stage. - n_episode=8, - # (float) the number of simulations in MCTS. - num_simulations=50, - # (float) Discount factor (gamma) for returns. - discount_factor=0.997, - # (int) The number of step for calculating target q_value. - td_steps=5, - # (int) The number of unroll steps in dynamics network. - num_unroll_steps=5, - # (int) reset the hidden states in LSTM every ``lstm_horizon_len`` horizon steps. - lstm_horizon_len=5, - # (float) The weight of reward loss. - reward_loss_weight=1, - # (float) The weight of value loss. - value_loss_weight=0.25, - # (float) The weight of policy loss. - policy_loss_weight=1, - # (float) The weight of ssl (self-supervised learning) loss. - ssl_loss_weight=2, - # (bool) Whether to use piecewise constant learning rate decay. - # i.e. lr: 0.2 -> 0.02 -> 0.002 - lr_piecewise_constant_decay=True, - # (int) The number of final training iterations to control lr decay, which is only used for manually decay. 
- threshold_training_steps_for_final_lr=int(5e4), - # (int) The number of final training iterations to control temperature, which is only used for manually decay. - threshold_training_steps_for_final_temperature=int(1e5), - # (bool) Whether to use manually decayed temperature. - # i.e. temperature: 1 -> 0.5 -> 0.25 - manual_temperature_decay=False, - # (float) The fixed temperature value for MCTS action selection, which is used to control the exploration. - # The larger the value, the more exploration. This value is only used when manual_temperature_decay=False. - fixed_temperature_value=0.25, - - # ****** Priority ****** - # (bool) Whether to use priority when sampling training data from the buffer. - use_priority=True, - # (bool) Whether to use the maximum priority for new collecting data. - use_max_priority_for_new_data=True, - # (float) The degree of prioritization to use. A value of 0 means no prioritization, - # while a value of 1 means full prioritization. - priority_prob_alpha=0.6, - # (float) The degree of correction to use. A value of 0 means no correction, - # while a value of 1 means full correction. - priority_prob_beta=0.4, - - # ****** UCB ****** - # (float) The alpha value used in the Dirichlet distribution for exploration at the root node of the search tree. - root_dirichlet_alpha=0.3, - # (float) The noise weight at the root node of the search tree. - root_noise_weight=0.25, - ) - def default_model(self) -> Tuple[str, List[str]]: """ Overview: @@ -177,63 +38,6 @@ def default_model(self) -> Tuple[str, List[str]]: """ return 'GoBiggerEfficientZeroModel', ['lzero.model.gobigger.gobigger_efficientzero_model'] - def _init_learn(self) -> None: - """ - Overview: - Learn mode init method. Called by ``self.__init__``. Initialize the learn model, optimizer and MCTS utils. - """ - assert self._cfg.optim_type in ['SGD', 'Adam', 'AdamW'], self._cfg.optim_type - if self._cfg.optim_type == 'SGD': - self._optimizer = optim.SGD( - self._model.parameters(), - lr=self._cfg.learning_rate, - momentum=self._cfg.momentum, - weight_decay=self._cfg.weight_decay, - ) - - elif self._cfg.optim_type == 'Adam': - self._optimizer = optim.Adam( - self._model.parameters(), - lr=self._cfg.learning_rate, - weight_decay=self._cfg.weight_decay, - ) - elif self._cfg.optim_type == 'AdamW': - self._optimizer = configure_optimizers( - model=self._model, - weight_decay=self._cfg.weight_decay, - learning_rate=self._cfg.learning_rate, - device_type=self._cfg.device - ) - - if self._cfg.lr_piecewise_constant_decay: - from torch.optim.lr_scheduler import LambdaLR - max_step = self._cfg.threshold_training_steps_for_final_lr - # NOTE: the 1, 0.1, 0.01 is the decay rate, not the lr. 
- lr_lambda = lambda step: 1 if step < max_step * 0.5 else (0.1 if step < max_step else 0.01) # noqa - self.lr_scheduler = LambdaLR(self._optimizer, lr_lambda=lr_lambda) - - # use model_wrapper for specialized demands of different modes - self._target_model = copy.deepcopy(self._model) - self._target_model = model_wrap( - self._target_model, - wrapper_name='target', - update_type='assign', - update_kwargs={'freq': self._cfg.target_update_freq} - ) - self._learn_model = self._model - - if self._cfg.use_augmentation: - self.image_transforms = ImageTransforms( - self._cfg.augmentation, - image_shape=(self._cfg.model.observation_shape[1], self._cfg.model.observation_shape[2]) - ) - self.value_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) - self.reward_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) - - self.inverse_scalar_transform_handle = InverseScalarTransform( - self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution - ) - def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: """ Overview: @@ -506,25 +310,13 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: 'total_grad_norm_before_clip': total_grad_norm_before_clip } - def _init_collect(self) -> None: - """ - Overview: - Collect mode init method. Called by ``self.__init__``. Initialize the collect model and MCTS utils. - """ - self._collect_model = self._model - if self._cfg.mcts_ctree: - self._mcts_collect = MCTSCtree(self._cfg) - else: - self._mcts_collect = MCTSPtree(self._cfg) - self.collect_mcts_temperature = 1 - self.collect_epsilon = 1 - def _forward_collect( self, data: torch.Tensor, action_mask: list = None, temperature: float = 1, to_play: List = [-1], + random_collect_episode_num: int = 0, epsilon: float = 0.25, ready_env_id=None ): @@ -609,18 +401,25 @@ def _forward_collect( for i in range(batch_size): distributions, value = roots_visit_count_distributions[i], roots_values[i] - if self._cfg.eps.eps_greedy_exploration_in_collect: - action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( - distributions, temperature=self.collect_mcts_temperature, deterministic=True - ) - action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] - if np.random.rand() < self.collect_epsilon: - action = np.random.choice(legal_actions[i]) - else: + if random_collect_episode_num>0: # random collect + distributions, value = roots_visit_count_distributions[i], roots_values[i] action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( distributions, temperature=self.collect_mcts_temperature, deterministic=False ) action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + else: + if self._cfg.eps.eps_greedy_exploration_in_collect: # eps greedy collect + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=self.collect_mcts_temperature, deterministic=True + ) + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + if np.random.rand() < self.collect_epsilon: + action = np.random.choice(legal_actions[i]) + else: # collect + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=self.collect_mcts_temperature, deterministic=False + ) + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] output[i // 
agent_num]['action'].append(action) output[i // agent_num]['distributions'].append(distributions) output[i // agent_num]['visit_count_distribution_entropy'].append(visit_count_distribution_entropy) @@ -630,17 +429,6 @@ def _forward_collect( return output - def _init_eval(self) -> None: - """ - Overview: - Evaluate mode init method. Called by ``self.__init__``. Initialize the eval model and MCTS utils. - """ - self._eval_model = self._model - if self._cfg.mcts_ctree: - self._mcts_eval = MCTSCtree(self._cfg) - else: - self._mcts_eval = MCTSPtree(self._cfg) - def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, ready_env_id=None): """ Overview: @@ -728,62 +516,3 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, read return output - def _monitor_vars_learn(self) -> List[str]: - """ - Overview: - Register the variables to be monitored in learn mode. The registered variables will be logged in - tensorboard according to the return value ``_forward_learn``. - """ - return [ - 'collect_mcts_temperature', - 'collect_epsilon', - 'cur_lr', - 'weighted_total_loss', - 'total_loss', - 'policy_loss', - 'policy_entropy', - 'target_policy_entropy', - 'value_prefix_loss', - 'value_loss', - 'consistency_loss', - 'value_priority', - 'target_value_prefix', - 'target_value', - 'predicted_value_prefixs', - 'predicted_values', - 'transformed_target_value_prefix', - 'transformed_target_value', - 'total_grad_norm_before_clip', - ] - - def _state_dict_learn(self) -> Dict[str, Any]: - """ - Overview: - Return the state_dict of learn mode, usually including model and optimizer. - Returns: - - state_dict (:obj:`Dict[str, Any]`): the dict of current policy learn state, for saving and restoring. - """ - return { - 'model': self._learn_model.state_dict(), - 'target_model': self._target_model.state_dict(), - 'optimizer': self._optimizer.state_dict(), - } - - def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: - """ - Overview: - Load the state_dict variable into policy learn mode. - Arguments: - - state_dict (:obj:`Dict[str, Any]`): the dict of policy learn state saved before. 
- """ - self._learn_model.load_state_dict(state_dict['model']) - self._target_model.load_state_dict(state_dict['target_model']) - self._optimizer.load_state_dict(state_dict['optimizer']) - - def _process_transition(self, obs, policy_output, timestep): - # be compatible with DI-engine Policy class - pass - - def _get_train_sample(self, data): - # be compatible with DI-engine Policy class - pass diff --git a/lzero/policy/gobigger_muzero.py b/lzero/policy/gobigger_muzero.py index c06737ffd..d5dab230a 100644 --- a/lzero/policy/gobigger_muzero.py +++ b/lzero/policy/gobigger_muzero.py @@ -1,18 +1,14 @@ -import copy from typing import List, Dict, Any, Tuple, Union import numpy as np import torch -import torch.optim as optim -from ding.model import model_wrap -from ding.policy.base_policy import Policy +from .muzero import MuZeroPolicy from ding.torch_utils import to_tensor from ding.utils import POLICY_REGISTRY from torch.nn import L1Loss from lzero.mcts import MuZeroMCTSCtree as MCTSCtree from lzero.mcts import MuZeroMCTSPtree as MCTSPtree -from lzero.model import ImageTransforms from lzero.policy import scalar_transform, InverseScalarTransform, cross_entropy_loss, phi_transform, \ DiscreteSupport, to_torch_float_tensor, mz_network_output_unpack, select_action, negative_cosine_similarity, prepare_obs, \ configure_optimizers @@ -21,146 +17,12 @@ @POLICY_REGISTRY.register('gobigger_muzero') -class GoBiggerMuZeroPolicy(Policy): +class GoBiggerMuZeroPolicy(MuZeroPolicy): """ Overview: The policy class for GoBiggerMuZero. """ - - # The default_config for MuZero policy. - config = dict( - model=dict( - # (str) The model type. For 1-dimensional vector obs, we use mlp model. For the image obs, we use conv model. - model_type='conv', # options={'mlp', 'conv'} - # (bool) If True, the action space of the environment is continuous, otherwise discrete. - continuous_action_space=False, - # (tuple) The stacked obs shape. - # observation_shape=(1, 96, 96), # if frame_stack_num=1 - observation_shape=(4, 96, 96), # if frame_stack_num=4 - # (bool) Whether to use the self-supervised learning loss. - self_supervised_learning_loss=False, - # (bool) Whether to use discrete support to represent categorical distribution for value/reward/value_prefix. - categorical_distribution=True, - # (int) The image channel in image observation. - image_channel=1, - # (int) The number of frames to stack together. - frame_stack_num=1, - # (int) The number of res blocks in MuZero model. - num_res_blocks=1, - # (int) The number of channels of hidden states in MuZero model. - num_channels=64, - # (int) The scale of supports used in categorical distribution. - # This variable is only effective when ``categorical_distribution=True``. - support_scale=300, - # (bool) whether to learn bias in the last linear layer in value and policy head. - bias=True, - # (str) The type of action encoding. Options are ['one_hot', 'not_one_hot']. Default to 'one_hot'. - discrete_action_encoding_type='one_hot', - # (bool) whether to use res connection in dynamics. - res_connection_in_dynamics=True, - # (str) The type of normalization in MuZero model. Options are ['BN', 'LN']. Default to 'LN'. - norm_type='BN', - ), - # ****** common ****** - # (bool) Whether to enable the sampled-based algorithm (e.g. Sampled EfficientZero) - # this variable is used in ``collector``. - sampled_algo=False, - # (bool) Whether to use C++ MCTS in policy. If False, use Python implementation. - mcts_ctree=True, - # (bool) Whether to use cuda for network. 
- cuda=True, - # (int) The number of environments used in collecting data. - collector_env_num=8, - # (int) The number of environments used in evaluating policy. - evaluator_env_num=3, - # (str) The type of environment. Options is ['not_board_games', 'board_games']. - env_type='not_board_games', - # (str) The type of battle mode. Options is ['play_with_bot_mode', 'self_play_mode']. - battle_mode='play_with_bot_mode', - # (bool) Whether to monitor extra statistics in tensorboard. - monitor_extra_statistics=True, - # (int) The transition number of one ``GameSegment``. - game_segment_length=200, - - # ****** observation ****** - # (bool) Whether to transform image to string to save memory. - transform2string=False, - # (bool) Whether to use data augmentation. - use_augmentation=False, - # (list) The style of augmentation. - augmentation=['shift', 'intensity'], - - # ******* learn ****** - # (int) How many updates(iterations) to train after collector's one collection. - # Bigger "update_per_collect" means bigger off-policy. - # collect data -> update policy-> collect data -> ... - # For different env, we have different episode_length, - # we usually set update_per_collect = collector_env_num * episode_length / batch_size * reuse_factor - update_per_collect=100, - # (int) Minibatch size for one gradient descent. - batch_size=256, - # (str) Optimizer for training policy network. ['SGD', 'Adam', 'AdamW'] - optim_type='SGD', - # (float) Learning rate for training policy network. Initial lr for manually decay schedule. - learning_rate=0.2, - # (int) Frequency of target network update. - target_update_freq=100, - # (float) Weight decay for training policy network. - weight_decay=1e-4, - # (float) One-order Momentum in optimizer, which stabilizes the training process (gradient direction). - momentum=0.9, - # (float) The maximum constraint value of gradient norm clipping. - grad_clip_value=10, - # (int) The number of episodes in each collecting stage. - n_episode=8, - # (int) the number of simulations in MCTS. - num_simulations=50, - # (float) Discount factor (gamma) for returns. - discount_factor=0.997, - # (int) The number of steps for calculating target q_value. - td_steps=5, - # (int) The number of unroll steps in dynamics network. - num_unroll_steps=5, - # (float) The weight of reward loss. - reward_loss_weight=1, - # (float) The weight of value loss. - value_loss_weight=0.25, - # (float) The weight of policy loss. - policy_loss_weight=1, - # (float) The weight of ssl (self-supervised learning) loss. - ssl_loss_weight=0, - # (bool) Whether to use piecewise constant learning rate decay. - # i.e. lr: 0.2 -> 0.02 -> 0.002 - lr_piecewise_constant_decay=True, - # (int) The number of final training iterations to control lr decay, which is only used for manually decay. - threshold_training_steps_for_final_lr=int(5e4), - # (bool) Whether to use manually decayed temperature. - manual_temperature_decay=False, - # (int) The number of final training iterations to control temperature, which is only used for manually decay. - threshold_training_steps_for_final_temperature=int(1e5), - # (float) The fixed temperature value for MCTS action selection, which is used to control the exploration. - # The larger the value, the more exploration. This value is only used when manual_temperature_decay=False. - fixed_temperature_value=0.25, - - # ****** Priority ****** - # (bool) Whether to use priority when sampling training data from the buffer. 
- use_priority=True, - # (bool) Whether to use the maximum priority for new collecting data. - use_max_priority_for_new_data=True, - # (float) The degree of prioritization to use. A value of 0 means no prioritization, - # while a value of 1 means full prioritization. - priority_prob_alpha=0.6, - # (float) The degree of correction to use. A value of 0 means no correction, - # while a value of 1 means full correction. - priority_prob_beta=0.4, - - # ****** UCB ****** - # (float) The alpha value used in the Dirichlet distribution for exploration at the root node of search tree. - root_dirichlet_alpha=0.3, - # (float) The noise weight at the root node of the search tree. - root_noise_weight=0.25, - ) - + def default_model(self) -> Tuple[str, List[str]]: """ Overview: @@ -175,60 +37,6 @@ def default_model(self) -> Tuple[str, List[str]]: """ return 'GoBiggerMuZeroModel', ['lzero.model.gobigger.gobigger_muzero_model'] - def _init_learn(self) -> None: - """ - Overview: - Learn mode init method. Called by ``self.__init__``. Initialize the learn model, optimizer and MCTS utils. - """ - assert self._cfg.optim_type in ['SGD', 'Adam', 'AdamW'], self._cfg.optim_type - # NOTE: in board_gmaes, for fixed lr 0.003, 'Adam' is better than 'SGD'. - if self._cfg.optim_type == 'SGD': - self._optimizer = optim.SGD( - self._model.parameters(), - lr=self._cfg.learning_rate, - momentum=self._cfg.momentum, - weight_decay=self._cfg.weight_decay, - ) - elif self._cfg.optim_type == 'Adam': - self._optimizer = optim.Adam( - self._model.parameters(), lr=self._cfg.learning_rate, weight_decay=self._cfg.weight_decay - ) - elif self._cfg.optim_type == 'AdamW': - self._optimizer = configure_optimizers( - model=self._model, - weight_decay=self._cfg.weight_decay, - learning_rate=self._cfg.learning_rate, - device_type=self._cfg.device - ) - - if self._cfg.lr_piecewise_constant_decay: - from torch.optim.lr_scheduler import LambdaLR - max_step = self._cfg.threshold_training_steps_for_final_lr - # NOTE: the 1, 0.1, 0.01 is the decay rate, not the lr. 
- lr_lambda = lambda step: 1 if step < max_step * 0.5 else (0.1 if step < max_step else 0.01) # noqa - self.lr_scheduler = LambdaLR(self._optimizer, lr_lambda=lr_lambda) - - # use model_wrapper for specialized demands of different modes - self._target_model = copy.deepcopy(self._model) - self._target_model = model_wrap( - self._target_model, - wrapper_name='target', - update_type='assign', - update_kwargs={'freq': self._cfg.target_update_freq} - ) - self._learn_model = self._model - - if self._cfg.use_augmentation: - self.image_transforms = ImageTransforms( - self._cfg.augmentation, - image_shape=(self._cfg.model.observation_shape[1], self._cfg.model.observation_shape[2]) - ) - self.value_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) - self.reward_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) - self.inverse_scalar_transform_handle = InverseScalarTransform( - self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution - ) - def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]: """ Overview: @@ -406,7 +214,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in self._learn_model.parameters(), self._cfg.grad_clip_value ) self._optimizer.step() - if self._cfg.lr_piecewise_constant_decay is True: + if self._cfg.lr_piecewise_constant_decay: self.lr_scheduler.step() # ============================================================== @@ -414,68 +222,34 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in # ============================================================== self._target_model.update(self._learn_model.state_dict()) - # packing loss info for tensorboard logging - loss_info = ( - weighted_total_loss.item(), loss.mean().item(), policy_loss.mean().item(), reward_loss.mean().item(), - value_loss.mean().item(), consistency_loss.mean() - ) if self._cfg.monitor_extra_statistics: predicted_rewards = torch.stack(predicted_rewards).transpose(1, 0).squeeze(-1) predicted_rewards = predicted_rewards.reshape(-1).unsqueeze(-1) - - td_data = ( - value_priority, - target_reward.detach().cpu().numpy(), - target_value.detach().cpu().numpy(), - transformed_target_reward.detach().cpu().numpy(), - transformed_target_value.detach().cpu().numpy(), - target_reward_categorical.detach().cpu().numpy(), - target_value_categorical.detach().cpu().numpy(), - predicted_rewards.detach().cpu().numpy(), - predicted_values.detach().cpu().numpy(), - target_policy.detach().cpu().numpy(), - predicted_policies.detach().cpu().numpy(), - latent_state_list, - ) - return { 'collect_mcts_temperature': self.collect_mcts_temperature, 'collect_epsilon': self.collect_epsilon, 'cur_lr': self._optimizer.param_groups[0]['lr'], - 'weighted_total_loss': loss_info[0], - 'total_loss': loss_info[1], - 'policy_loss': loss_info[2], - 'reward_loss': loss_info[3], - 'value_loss': loss_info[4], - 'consistency_loss': loss_info[5] / self._cfg.num_unroll_steps, + 'weighted_total_loss': weighted_total_loss.item(), + 'total_loss': loss.mean().item(), + 'policy_loss': policy_loss.mean().item(), + 'reward_loss': reward_loss.mean().item(), + 'value_loss': value_loss.mean().item(), + 'consistency_loss': consistency_loss.mean() / self._cfg.num_unroll_steps, # ============================================================== # priority related # ============================================================== 'value_priority_orig': value_priority, 
- 'value_priority': td_data[0].flatten().mean().item(), - 'target_reward': td_data[1].flatten().mean().item(), - 'target_value': td_data[2].flatten().mean().item(), - 'transformed_target_reward': td_data[3].flatten().mean().item(), - 'transformed_target_value': td_data[4].flatten().mean().item(), - 'predicted_rewards': td_data[7].flatten().mean().item(), - 'predicted_values': td_data[8].flatten().mean().item(), + 'value_priority': value_priority.mean().item(), + 'target_reward': target_reward.detach().cpu().numpy().mean().item(), + 'target_value': target_value.detach().cpu().numpy().mean().item(), + 'transformed_target_reward': transformed_target_reward.detach().cpu().numpy().mean().item(), + 'transformed_target_value': transformed_target_value.detach().cpu().numpy().mean().item(), + 'predicted_rewards': predicted_rewards.detach().cpu().numpy().mean().item(), + 'predicted_values': predicted_values.detach().cpu().numpy().mean().item(), 'total_grad_norm_before_clip': total_grad_norm_before_clip } - def _init_collect(self) -> None: - """ - Overview: - Collect mode init method. Called by ``self.__init__``. Initialize the collect model and MCTS utils. - """ - self._collect_model = self._model - if self._cfg.mcts_ctree: - self._mcts_collect = MCTSCtree(self._cfg) - else: - self._mcts_collect = MCTSPtree(self._cfg) - self.collect_mcts_temperature = 1 - self.collect_epsilon = 1 - def _forward_collect( self, data: torch.Tensor, @@ -580,17 +354,6 @@ def _forward_collect( return output - def _init_eval(self) -> None: - """ - Overview: - Evaluate mode init method. Called by ``self.__init__``. Initialize the eval model and MCTS utils. - """ - self._eval_model = self._model - if self._cfg.mcts_ctree: - self._mcts_eval = MCTSCtree(self._cfg) - else: - self._mcts_eval = MCTSPtree(self._cfg) - def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, ready_env_id=None) -> Dict: """ Overview: @@ -674,60 +437,3 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1 output[i // agent_num]['policy_logits'].append(policy_logits[i]) return output - - def _monitor_vars_learn(self) -> List[str]: - """ - Overview: - Register the variables to be monitored in learn mode. The registered variables will be logged in - tensorboard according to the return value ``_forward_learn``. - """ - return [ - 'collect_mcts_temperature', - 'cur_lr', - 'weighted_total_loss', - 'total_loss', - 'policy_loss', - 'reward_loss', - 'value_loss', - 'consistency_loss', - 'value_priority', - 'target_reward', - 'target_value', - 'predicted_rewards', - 'predicted_values', - 'transformed_target_reward', - 'transformed_target_value', - 'total_grad_norm_before_clip', - ] - - def _state_dict_learn(self) -> Dict[str, Any]: - """ - Overview: - Return the state_dict of learn mode, usually including model, target_model and optimizer. - Returns: - - state_dict (:obj:`Dict[str, Any]`): The dict of current policy learn state, for saving and restoring. - """ - return { - 'model': self._learn_model.state_dict(), - 'target_model': self._target_model.state_dict(), - 'optimizer': self._optimizer.state_dict(), - } - - def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: - """ - Overview: - Load the state_dict variable into policy learn mode. - Arguments: - - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before. 
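The `_init_learn` removed from this file (now inherited from the parent policy) builds a piecewise-constant learning-rate schedule with `LambdaLR`: the multiplier is 1 for the first half of `threshold_training_steps_for_final_lr`, then 0.1, then 0.01, and `scheduler.step()` is called after each optimizer step when `lr_piecewise_constant_decay` is set. A self-contained sketch of that schedule; the model and training loop here are placeholders, not the patch's code:

# Illustrative sketch of the piecewise-constant lr decay used by the removed
# _init_learn (the 1 / 0.1 / 0.01 values are multipliers, not the lr itself).
import torch
from torch.optim.lr_scheduler import LambdaLR

model = torch.nn.Linear(4, 2)  # placeholder model, not part of the patch
optimizer = torch.optim.SGD(model.parameters(), lr=0.2, momentum=0.9, weight_decay=1e-4)

max_step = int(5e4)  # threshold_training_steps_for_final_lr in the default config
lr_lambda = lambda step: 1 if step < max_step * 0.5 else (0.1 if step < max_step else 0.01)  # noqa
scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)

for _ in range(3):
    optimizer.step()   # one training iteration (loss/backward omitted)
    scheduler.step()   # lr: 0.2 for the first 25k steps, then 0.02, then 0.002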
- """ - self._learn_model.load_state_dict(state_dict['model']) - self._target_model.load_state_dict(state_dict['target_model']) - self._optimizer.load_state_dict(state_dict['optimizer']) - - def _process_transition(self, obs, policy_output, timestep): - # be compatible with DI-engine Policy class - pass - - def _get_train_sample(self, data): - # be compatible with DI-engine Policy class - pass diff --git a/lzero/policy/gobigger_random_policy.py b/lzero/policy/gobigger_random_policy.py deleted file mode 100644 index b1aef6709..000000000 --- a/lzero/policy/gobigger_random_policy.py +++ /dev/null @@ -1,328 +0,0 @@ -import copy -from typing import List, Dict, Any, Tuple, Union - -import numpy as np -import torch -from ding.policy.base_policy import Policy -from ding.torch_utils import to_tensor -from ding.utils import POLICY_REGISTRY - -from lzero.mcts import EfficientZeroMCTSCtree as MCTSCtree -from lzero.mcts import EfficientZeroMCTSPtree as MCTSPtree -from lzero.policy import scalar_transform, InverseScalarTransform, cross_entropy_loss, phi_transform, \ - DiscreteSupport, select_action, to_torch_float_tensor, ez_network_output_unpack, negative_cosine_similarity, prepare_obs, \ - configure_optimizers -from collections import defaultdict -from ding.torch_utils import to_device - - -@POLICY_REGISTRY.register('gobigger_random_policy') -class GoBiggerRandomPolicy(Policy): - """ - Overview: - The policy class for GoBiggerRandom. - """ - - # The default_config for GoBiggerRandom policy. - config = dict( - model=dict( - # (str) The model type. For 1-dimensional vector obs, we use mlp model. For 3-dimensional image obs, we use conv model. - model_type='conv', # options={'mlp', 'conv'} - # (bool) If True, the action space of the environment is continuous, otherwise discrete. - continuous_action_space=False, - # (tuple) The stacked obs shape. - # observation_shape=(1, 96, 96), # if frame_stack_num=1 - observation_shape=(4, 96, 96), # if frame_stack_num=4 - # (bool) Whether to use the self-supervised learning loss. - self_supervised_learning_loss=True, - # (bool) Whether to use discrete support to represent categorical distribution for value/reward/value_prefix. - categorical_distribution=True, - # (int) The image channel in image observation. - image_channel=1, - # (int) The number of frames to stack together. - frame_stack_num=1, - # (int) The scale of supports used in categorical distribution. - # This variable is only effective when ``categorical_distribution=True``. - support_scale=300, - # (int) The hidden size in LSTM. - lstm_hidden_size=512, - # (bool) whether to learn bias in the last linear layer in value and policy head. - bias=True, - # (str) The type of action encoding. Options are ['one_hot', 'not_one_hot']. Default to 'one_hot'. - discrete_action_encoding_type='one_hot', - # (bool) whether to use res connection in dynamics. - res_connection_in_dynamics=True, - # (str) The type of normalization in MuZero model. Options are ['BN', 'LN']. Default to 'LN'. - norm_type='BN', - ), - # ****** common ****** - # (bool) Whether to enable the sampled-based algorithm (e.g. Sampled EfficientZero) - # this variable is used in ``collector``. - sampled_algo=False, - # (bool) Whether to use C++ MCTS in policy. If False, use Python implementation. - mcts_ctree=True, - # (bool) Whether to use cuda for network. - cuda=True, - # (int) The number of environments used in collecting data. - collector_env_num=8, - # (int) The number of environments used in evaluating policy. 
- evaluator_env_num=3, - # (str) The type of environment. The options are ['not_board_games', 'board_games']. - env_type='not_board_games', - # (str) The type of battle mode. The options are ['play_with_bot_mode', 'self_play_mode']. - battle_mode='play_with_bot_mode', - # (bool) Whether to monitor extra statistics in tensorboard. - monitor_extra_statistics=True, - # (int) The transition number of one ``GameSegment``. - game_segment_length=200, - - # ****** observation ****** - # (bool) Whether to transform image to string to save memory. - transform2string=False, - # (bool) Whether to use data augmentation. - use_augmentation=False, - # (list) The style of augmentation. - augmentation=['shift', 'intensity'], - - # ******* learn ****** - # (int) How many updates(iterations) to train after collector's one collection. - # Bigger "update_per_collect" means bigger off-policy. - # collect data -> update policy-> collect data -> ... - # For different env, we have different episode_length, - # we usually set update_per_collect = collector_env_num * episode_length / batch_size * reuse_factor - # if we set update_per_collect=None, we will set update_per_collect = collected_transitions_num * cfg.policy.model_update_ratio automatically. - update_per_collect=None, - # (float) The ratio of the collected data used for training. Only effective when ``update_per_collect`` is not None. - model_update_ratio=0.1, - # (int) Minibatch size for one gradient descent. - batch_size=256, - # (str) Optimizer for training policy network. ['SGD', 'Adam', 'AdamW'] - optim_type='SGD', - # (float) Learning rate for training policy network. Initial lr for manually decay schedule. - learning_rate=0.2, - # (int) Frequency of target network update. - target_update_freq=100, - # (float) Weight decay for training policy network. - weight_decay=1e-4, - # (float) One-order Momentum in optimizer, which stabilizes the training process (gradient direction). - momentum=0.9, - # (float) The maximum constraint value of gradient norm clipping. - grad_clip_value=10, - # (int) The number of episodes in each collecting stage. - n_episode=8, - # (float) the number of simulations in MCTS. - num_simulations=50, - # (float) Discount factor (gamma) for returns. - discount_factor=0.997, - # (int) The number of steps for calculating target q_value. - td_steps=5, - # (int) The number of unroll steps in dynamics network. - num_unroll_steps=5, - # (int) reset the hidden states in LSTM every ``lstm_horizon_len`` horizon steps. - lstm_horizon_len=5, - # (float) The weight of reward loss. - reward_loss_weight=1, - # (float) The weight of value loss. - value_loss_weight=0.25, - # (float) The weight of policy loss. - policy_loss_weight=1, - # (float) The weight of ssl (self-supervised learning) loss. - ssl_loss_weight=2, - # (bool) Whether to use piecewise constant learning rate decay. - # i.e. lr: 0.2 -> 0.02 -> 0.002 - lr_piecewise_constant_decay=True, - # (int) The number of final training iterations to control lr decay, which is only used for manually decay. - threshold_training_steps_for_final_lr=int(5e4), - # (int) The number of final training iterations to control temperature, which is only used for manually decay. - threshold_training_steps_for_final_temperature=int(1e5), - # (bool) Whether to use manually decayed temperature. - # i.e. temperature: 1 -> 0.5 -> 0.25 - manual_temperature_decay=False, - # (float) The fixed temperature value for MCTS action selection, which is used to control the exploration. 
- # The larger the value, the more exploration. This value is only used when manual_temperature_decay=False. - fixed_temperature_value=0.25, - - # ****** Priority ****** - # (bool) Whether to use priority when sampling training data from the buffer. - use_priority=True, - # (bool) Whether to use the maximum priority for new collecting data. - use_max_priority_for_new_data=True, - # (float) The degree of prioritization to use. A value of 0 means no prioritization, - # while a value of 1 means full prioritization. - priority_prob_alpha=0.6, - # (float) The degree of correction to use. A value of 0 means no correction, - # while a value of 1 means full correction. - priority_prob_beta=0.4, - - # ****** UCB ****** - # (float) The alpha value used in the Dirichlet distribution for exploration at the root node of the search tree. - root_dirichlet_alpha=0.3, - # (float) The noise weight at the root node of the search tree. - root_noise_weight=0.25, - ) - - def default_model(self) -> Tuple[str, List[str]]: - """ - Overview: - Return this algorithm default model setting. - Returns: - - model_info (:obj:`Tuple[str, List[str]]`): model name and model import_names. - - model_type (:obj:`str`): The model type used in this algorithm, which is registered in ModelRegistry. - - import_names (:obj:`List[str]`): The model class path list used in this algorithm. - .. note:: - The user can define and use customized network model but must obey the same interface definition indicated \ - by import_names path. For EfficientZero, ``lzero.model.efficientzero_model.EfficientZeroModel`` - """ - return 'GoBiggerEfficientZeroModel', ['lzero.model.gobigger.gobigger_efficientzero_model'] - - def _init_collect(self) -> None: - """ - Overview: - Collect mode init method. Called by ``self.__init__``. Initialize the collect model and MCTS utils. - """ - self._collect_model = self._model - if self._cfg.mcts_ctree: - self._mcts_collect = MCTSCtree(self._cfg) - else: - self._mcts_collect = MCTSPtree(self._cfg) - self.collect_mcts_temperature = 1.0 - self.inverse_scalar_transform_handle = InverseScalarTransform( - self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution - ) - - def _forward_collect( - self, - data: torch.Tensor, - action_mask: list = None, - temperature: float = 1, - to_play: List = [-1], - ready_env_id=None - ): - """ - Overview: - The forward function for collecting data in collect mode. Use model to execute MCTS search. - Choosing the action through sampling during the collect mode. - Arguments: - - data (:obj:`torch.Tensor`): The input data, i.e. the observation. - - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. - - temperature (:obj:`float`): The temperature of the policy. - - to_play (:obj:`int`): The player to play. - - ready_env_id (:obj:`list`): The id of the env that is ready to collect. - Shape: - - data (:obj:`torch.Tensor`): - - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ - S is the number of stacked frames, H is the height of the image, W is the width of the image. - - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size. - - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env. - - temperature: :math:`(1, )`. - - to_play: :math:`(N, 1)`, where N is the number of collect_env. 
- - ready_env_id: None - Returns: - - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ - ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. - """ - self._collect_model.eval() - self.collect_mcts_temperature = temperature - - active_collect_env_num = len(data) - data = to_tensor(data) - data = sum(sum(data, []), []) - batch_size = len(data) - data = to_device(data, self._cfg.device) - agent_num = batch_size // active_collect_env_num - to_play = np.array(to_play).reshape(-1).tolist() - - with torch.no_grad(): - # data shape [B, S x C, W, H], e.g. {Tensor:(B, 12, 96, 96)} - network_output = self._collect_model.initial_inference(data) - latent_state_roots, value_prefix_roots, reward_hidden_state_roots, pred_values, policy_logits = ez_network_output_unpack( - network_output - ) - - # if not in training, obtain the scalars of the value/reward - pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() - latent_state_roots = latent_state_roots.detach().cpu().numpy() - reward_hidden_state_roots = ( - reward_hidden_state_roots[0].detach().cpu().numpy(), reward_hidden_state_roots[1].detach().cpu().numpy() - ) - policy_logits = policy_logits.detach().cpu().numpy().tolist() - - action_mask = sum(action_mask, []) - legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(batch_size)] - # the only difference between collect and eval is the dirichlet noise. - noises = [ - np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j])) - ).astype(np.float32).tolist() for j in range(batch_size) - ] - if self._cfg.mcts_ctree: - # cpp mcts_tree - roots = MCTSCtree.roots(batch_size, legal_actions) - else: - # python mcts_tree - roots = MCTSPtree.roots(batch_size, legal_actions) - roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_roots, policy_logits, to_play) - self._mcts_collect.search( - roots, self._collect_model, latent_state_roots, reward_hidden_state_roots, to_play - ) - - roots_visit_count_distributions = roots.get_distributions( - ) # shape: ``{list: batch_size} ->{list: action_space_size}`` - roots_values = roots.get_values() # shape: {list: batch_size} - - data_id = [i for i in range(active_collect_env_num)] - output = {i: defaultdict(list) for i in data_id} - if ready_env_id is None: - ready_env_id = np.arange(active_collect_env_num) - - for i in range(batch_size): - distributions, value = roots_visit_count_distributions[i], roots_values[i] - action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( - distributions, temperature=self.collect_mcts_temperature, deterministic=False - ) - action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] - # ************* random action ************* - action = int(np.random.choice(legal_actions[i], 1)) - output[i // agent_num]['action'].append(action) - output[i // agent_num]['distributions'].append(distributions) - output[i // agent_num]['visit_count_distribution_entropy'].append(visit_count_distribution_entropy) - output[i // agent_num]['value'].append(value) - output[i // agent_num]['pred_value'].append(pred_values[i]) - output[i // agent_num]['policy_logits'].append(policy_logits[i]) - - return output - - def _init_eval(self) -> None: - """ - Overview: - Evaluate mode init method. Called by ``self.__init__``. Initialize the eval model and MCTS utils. 
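The deleted random policy's `_forward_collect` above mixes Dirichlet noise into the root priors before the MCTS search, with one noise vector per sample sized by its legal-action count; this noise is the only difference between collect and eval. A small sketch of how such per-sample noise can be generated from an action mask (the masks below are hypothetical):

# Illustrative sketch: per-sample Dirichlet noise over legal actions, as prepared
# for `roots.prepare(root_noise_weight, noises, ...)` in the hunk above.
import numpy as np

root_dirichlet_alpha = 0.3                   # default config value
action_mask = [[1, 1, 0, 1], [1, 0, 0, 1]]   # hypothetical masks for 2 samples

noises = [
    np.random.dirichlet([root_dirichlet_alpha] * int(sum(mask))).astype(np.float32).tolist()
    for mask in action_mask
]
legal_actions = [[i for i, x in enumerate(mask) if x == 1] for mask in action_mask]
print(legal_actions)              # [[0, 1, 3], [0, 3]]
print([len(n) for n in noises])   # one noise weight per legal action: [3, 2]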
- """ - self._eval_model = self._model - if self._cfg.mcts_ctree: - self._mcts_eval = MCTSCtree(self._cfg) - else: - self._mcts_eval = MCTSPtree(self._cfg) - - # be compatible with DI-engine Policy class - def _init_learn(self) -> None: - pass - - def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: - pass - - def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, ready_env_id=None): - pass - - def _monitor_vars_learn(self) -> List[str]: - pass - - def _state_dict_learn(self) -> Dict[str, Any]: - pass - - def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: - pass - - def _process_transition(self, obs, policy_output, timestep): - pass - - def _get_train_sample(self, data): - pass diff --git a/lzero/policy/gobigger_sampled_efficientzero.py b/lzero/policy/gobigger_sampled_efficientzero.py index f402091fc..58991bda3 100644 --- a/lzero/policy/gobigger_sampled_efficientzero.py +++ b/lzero/policy/gobigger_sampled_efficientzero.py @@ -1,11 +1,8 @@ -import copy from typing import List, Dict, Any, Tuple, Union import numpy as np import torch -import torch.optim as optim -from ding.model import model_wrap -from ding.policy.base_policy import Policy +from .sampled_efficientzero import SampledEfficientZeroPolicy from ding.torch_utils import to_tensor from ding.utils import POLICY_REGISTRY from ditk import logging @@ -14,7 +11,6 @@ from lzero.mcts import SampledEfficientZeroMCTSCtree as MCTSCtree from lzero.mcts import SampledEfficientZeroMCTSPtree as MCTSPtree -from lzero.model import ImageTransforms from lzero.policy import scalar_transform, InverseScalarTransform, cross_entropy_loss, phi_transform, \ DiscreteSupport, to_torch_float_tensor, ez_network_output_unpack, select_action, negative_cosine_similarity, prepare_obs, \ configure_optimizers @@ -23,161 +19,12 @@ @POLICY_REGISTRY.register('gobigger_sampled_efficientzero') -class GoBiggerSampledEfficientZeroPolicy(Policy): +class GoBiggerSampledEfficientZeroPolicy(SampledEfficientZeroPolicy): """ Overview: The policy class for GoBigger Sampled EfficientZero. """ - # The default_config for Sampled fEficientZero policy. - config = dict( - model=dict( - # (str) The model type. For 1-dimensional vector obs, we use mlp model. For 3-dimensional image obs, we use conv model. - model_type='conv', # options={'mlp', 'conv'} - # (bool) If True, the action space of the environment is continuous, otherwise discrete. - continuous_action_space=False, - # (tuple) the stacked obs shape. - # observation_shape=(1, 96, 96), # if frame_stack_num=1 - observation_shape=(4, 96, 96), # if frame_stack_num=4 - # (bool) Whether to use the self-supervised learning loss. - self_supervised_learning_loss=True, - # (int) The size of action space. For discrete action space, it is the number of actions. - # For continuous action space, it is the dimension of action. - action_space_size=6, - # (bool) Whether to use discrete support to represent categorical distribution for value/reward/value_prefix. - categorical_distribution=True, - # (int) the image channel in image observation. - image_channel=1, - # (int) The number of frames to stack together. - frame_stack_num=1, - # (int) The scale of supports used in categorical distribution. - # This variable is only effective when ``categorical_distribution=True``. - support_scale=300, - # (int) The hidden size in LSTM. - lstm_hidden_size=512, - # (str) The type of sigma. options={'conditioned', 'fixed'} - sigma_type='conditioned', - # (float) The fixed sigma value. 
Only effective when ``sigma_type='fixed'``. - fixed_sigma_value=0.3, - # (bool) whether to learn bias in the last linear layer in value and policy head. - bias=True, - # (str) The type of action encoding. Options are ['one_hot', 'not_one_hot']. Default to 'one_hot'. - discrete_action_encoding_type='one_hot', - # (bool) whether to use res connection in dynamics. - res_connection_in_dynamics=True, - # (str) The type of normalization in MuZero model. Options are ['BN', 'LN']. Default to 'LN'. - norm_type='BN', - ), - # ****** common ****** - # (bool) ``sampled_algo=True`` means the policy is sampled-based algorithm (e.g. Sampled EfficientZero), which is used in ``collector``. - sampled_algo=True, - # (bool) Whether to use C++ MCTS in policy. If False, use Python implementation. - mcts_ctree=True, - # (bool) Whether to use cuda in policy. - cuda=True, - # (int) The number of environments used in collecting data. - collector_env_num=8, - # (int) The number of environments used in evaluating policy. - evaluator_env_num=3, - # (str) The type of environment. The options are ['not_board_games', 'board_games']. - env_type='not_board_games', - # (str) The type of battle mode. The options are ['play_with_bot_mode', 'self_play_mode']. - battle_mode='play_with_bot_mode', - # (bool) Whether to monitor extra statistics in tensorboard. - monitor_extra_statistics=True, - # (int) The transition number of one ``GameSegment``. - game_segment_length=200, - - # ****** observation ****** - # (bool) Whether to transform image to string to save memory. - transform2string=False, - # (bool) Whether to use data augmentation. - use_augmentation=False, - # (list) The style of augmentation. - augmentation=['shift', 'intensity'], - - # ******* learn ****** - # (int) How many updates(iterations) to train after collector's one collection. - # Bigger "update_per_collect" means bigger off-policy. - # collect data -> update policy-> collect data -> ... - # For different env, we have different episode_length, - # we usually set update_per_collect = collector_env_num * episode_length / batch_size * reuse_factor - update_per_collect=100, - # (int) Minibatch size for one gradient descent. - batch_size=256, - # (str) Optimizer for training policy network. ['SGD', 'Adam', 'AdamW'] - optim_type='SGD', - learning_rate=0.2, # init lr for manually decay schedule - # optim_type='Adam', - # lr_piecewise_constant_decay=False, - # learning_rate=0.003, # lr for Adam optimizer - # (float) Weight uniform initialization range in the last output layer - init_w=3e-3, - normalize_prob_of_sampled_actions=False, - policy_loss_type='cross_entropy', # options={'cross_entropy', 'KL'} - # (int) Frequency of target network update. - target_update_freq=100, - weight_decay=1e-4, - momentum=0.9, - grad_clip_value=10, - # You can use either "n_sample" or "n_episode" in collector.collect. - # Get "n_episode" episodes per collect. - n_episode=8, - # (float) the number of simulations in MCTS. - num_simulations=50, - # (float) Discount factor (gamma) for returns. - discount_factor=0.997, - # (int) The number of step for calculating target q_value. - td_steps=5, - # (int) The number of unroll steps in dynamics network. - num_unroll_steps=5, - # (int) reset the hidden states in LSTM every ``lstm_horizon_len`` horizon steps. - lstm_horizon_len=5, - # (float) The weight of reward loss. - reward_loss_weight=1, - # (float) The weight of value loss. - value_loss_weight=0.25, - # (float) The weight of policy loss. 
- policy_loss_weight=1, - # (float) The weight of policy entropy loss. - policy_entropy_loss_weight=0, - # (float) The weight of ssl (self-supervised learning) loss. - ssl_loss_weight=2, - # (bool) Whether to use the cosine learning rate decay. - cos_lr_scheduler=False, - # (bool) Whether to use piecewise constant learning rate decay. - # i.e. lr: 0.2 -> 0.02 -> 0.002 - lr_piecewise_constant_decay=True, - # (int) The number of final training iterations to control lr decay, which is only used for manually decay. - threshold_training_steps_for_final_lr=int(5e4), - # (int) The number of final training iterations to control temperature, which is only used for manually decay. - threshold_training_steps_for_final_temperature=int(1e5), - # (bool) Whether to use manually decayed temperature. - # i.e. temperature: 1 -> 0.5 -> 0.25 - manual_temperature_decay=False, - # (float) The fixed temperature value for MCTS action selection, which is used to control the exploration. - # The larger the value, the more exploration. This value is only used when manual_temperature_decay=False. - fixed_temperature_value=0.25, - - # ****** Priority ****** - # (bool) Whether to use priority when sampling training data from the buffer. - use_priority=True, - # (bool) Whether to use the maximum priority for new collecting data. - use_max_priority_for_new_data=True, - # (float) The degree of prioritization to use. A value of 0 means no prioritization, - # while a value of 1 means full prioritization. - priority_prob_alpha=0.6, - # (float) The degree of correction to use. A value of 0 means no correction, - # while a value of 1 means full correction. - priority_prob_beta=0.4, - - # ****** UCB ****** - # (float) The alpha value used in the Dirichlet distribution for exploration at the root node of the search tree. - root_dirichlet_alpha=0.3, - # (float) The noise weight at the root node of the search tree. - root_noise_weight=0.25, - ) - def default_model(self) -> Tuple[str, List[str]]: """ Overview: @@ -193,75 +40,6 @@ def default_model(self) -> Tuple[str, List[str]]: """ return 'GoBiggerSampledEfficientZeroModel', ['lzero.model.gobigger.gobigger_sampled_efficientzero_model'] - def _init_learn(self) -> None: - """ - Overview: - Learn mode init method. Called by ``self.__init__``. Initialize the learn model, optimizer and MCTS utils. - """ - assert self._cfg.optim_type in ['SGD', 'Adam', 'AdamW'], self._cfg.optim_type - if self._cfg.model.continuous_action_space: - # Weight Init for the last output layer of gaussian policy head in prediction network. 
- init_w = self._cfg.init_w - self._model.prediction_network.fc_policy_head.mu.weight.data.uniform_(-init_w, init_w) - self._model.prediction_network.fc_policy_head.mu.bias.data.uniform_(-init_w, init_w) - self._model.prediction_network.fc_policy_head.log_sigma_layer.weight.data.uniform_(-init_w, init_w) - try: - self._model.prediction_network.fc_policy_head.log_sigma_layer.bias.data.uniform_(-init_w, init_w) - except Exception as exception: - logging.warning(exception) - - if self._cfg.optim_type == 'SGD': - self._optimizer = optim.SGD( - self._model.parameters(), - lr=self._cfg.learning_rate, - momentum=self._cfg.momentum, - weight_decay=self._cfg.weight_decay, - ) - - elif self._cfg.optim_type == 'Adam': - self._optimizer = optim.Adam( - self._model.parameters(), lr=self._cfg.learning_rate, weight_decay=self._cfg.weight_decay - ) - elif self._cfg.optim_type == 'AdamW': - self._optimizer = configure_optimizers( - model=self._model, - weight_decay=self._cfg.weight_decay, - learning_rate=self._cfg.learning_rate, - device_type=self._cfg.device - ) - - if self._cfg.cos_lr_scheduler is True: - from torch.optim.lr_scheduler import CosineAnnealingLR - self.lr_scheduler = CosineAnnealingLR(self._optimizer, 1e6, eta_min=0, last_epoch=-1) - - if self._cfg.lr_piecewise_constant_decay: - from torch.optim.lr_scheduler import LambdaLR - max_step = self._cfg.threshold_training_steps_for_final_lr - # NOTE: the 1, 0.1, 0.01 is the decay rate, not the lr. - lr_lambda = lambda step: 1 if step < max_step * 0.5 else (0.1 if step < max_step else 0.01) # noqa - self.lr_scheduler = LambdaLR(self._optimizer, lr_lambda=lr_lambda) - - # use model_wrapper for specialized demands of different modes - self._target_model = copy.deepcopy(self._model) - self._target_model = model_wrap( - self._target_model, - wrapper_name='target', - update_type='assign', - update_kwargs={'freq': self._cfg.target_update_freq} - ) - self._learn_model = self._model - - if self._cfg.use_augmentation: - self.image_transforms = ImageTransforms( - self._cfg.augmentation, - image_shape=(self._cfg.model.observation_shape[1], self._cfg.model.observation_shape[2]) - ) - self.value_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) - self.reward_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) - self.inverse_scalar_transform_handle = InverseScalarTransform( - self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution - ) - def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: """ Overview: @@ -486,7 +264,7 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: self._learn_model.parameters(), self._cfg.grad_clip_value ) self._optimizer.step() - if self._cfg.cos_lr_scheduler is True or self._cfg.lr_piecewise_constant_decay is True: + if self._cfg.cos_lr_scheduler or self._cfg.lr_piecewise_constant_decay: self.lr_scheduler.step() # ============================================================== @@ -494,47 +272,38 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: # ============================================================== self._target_model.update(self._learn_model.state_dict()) - loss_data = ( - weighted_total_loss.item(), loss.mean().item(), policy_loss.mean().item(), value_prefix_loss.mean().item(), - value_loss.mean().item(), consistency_loss.mean() - ) if self._cfg.monitor_extra_statistics: predicted_value_prefixs = 
torch.stack(predicted_value_prefixs).transpose(1, 0).squeeze(-1) predicted_value_prefixs = predicted_value_prefixs.reshape(-1).unsqueeze(-1) - td_data = ( - value_priority, target_value_prefix.detach().cpu().numpy(), target_value.detach().cpu().numpy(), - transformed_target_value_prefix.detach().cpu().numpy(), transformed_target_value.detach().cpu().numpy(), - target_value_prefix_categorical.detach().cpu().numpy(), target_value_categorical.detach().cpu().numpy(), - predicted_value_prefixs.detach().cpu().numpy(), predicted_values.detach().cpu().numpy(), - target_policy.detach().cpu().numpy(), predicted_policies.detach().cpu().numpy(), latent_state_list - ) - - if self._cfg.model.continuous_action_space: - return { + return_data = { 'cur_lr': self._optimizer.param_groups[0]['lr'], + 'collect_epsilon': self.collect_epsilon, 'collect_mcts_temperature': self.collect_mcts_temperature, - 'weighted_total_loss': loss_data[0], - 'total_loss': loss_data[1], - 'policy_loss': loss_data[2], + 'weighted_total_loss': weighted_total_loss.item(), + 'total_loss': loss.mean().item(), + 'policy_loss': policy_loss.mean().item(), 'policy_entropy': policy_entropy.item() / (self._cfg.num_unroll_steps + 1), 'target_policy_entropy': target_policy_entropy.item() / (self._cfg.num_unroll_steps + 1), - 'value_prefix_loss': loss_data[3], - 'value_loss': loss_data[4], - 'consistency_loss': loss_data[5] / self._cfg.num_unroll_steps, + 'value_prefix_loss': value_prefix_loss.mean().item(), + 'value_loss': value_loss.mean().item(), + 'consistency_loss': consistency_loss.mean() / self._cfg.num_unroll_steps, # ============================================================== # priority related # ============================================================== - 'value_priority': td_data[0].flatten().mean().item(), + 'value_priority': value_priority.flatten().mean().item(), 'value_priority_orig': value_priority, - 'target_value_prefix': td_data[1].flatten().mean().item(), - 'target_value': td_data[2].flatten().mean().item(), - 'transformed_target_value_prefix': td_data[3].flatten().mean().item(), - 'transformed_target_value': td_data[4].flatten().mean().item(), - 'predicted_value_prefixs': td_data[7].flatten().mean().item(), - 'predicted_values': td_data[8].flatten().mean().item(), + 'target_value_prefix': target_value_prefix.detach().cpu().numpy().mean().item(), + 'target_value': target_value.detach().cpu().numpy().mean().item(), + 'transformed_target_value_prefix': transformed_target_value_prefix.detach().cpu().numpy().mean().item(), + 'transformed_target_value': transformed_target_value.detach().cpu().numpy().mean().item(), + 'predicted_value_prefixs': predicted_value_prefixs.detach().cpu().numpy().mean().item(), + 'predicted_values': predicted_values.detach().cpu().numpy().mean().item() + } + if self._cfg.model.continuous_action_space: + return_data.update({ # ============================================================== # sampled related core code # ============================================================== @@ -549,32 +318,9 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: 'target_sampled_actions_min': target_sampled_actions[:, :, 0].min().item(), 'target_sampled_actions_mean': target_sampled_actions[:, :, 0].mean().item(), 'total_grad_norm_before_clip': total_grad_norm_before_clip - } + }) else: - return { - 'cur_lr': self._optimizer.param_groups[0]['lr'], - 'collect_mcts_temperature': self.collect_mcts_temperature, - 'weighted_total_loss': loss_data[0], - 'total_loss': loss_data[1], - 
'policy_loss': loss_data[2], - 'policy_entropy': policy_entropy.item() / (self._cfg.num_unroll_steps + 1), - 'target_policy_entropy': target_policy_entropy.item() / (self._cfg.num_unroll_steps + 1), - 'value_prefix_loss': loss_data[3], - 'value_loss': loss_data[4], - 'consistency_loss': loss_data[5] / self._cfg.num_unroll_steps, - - # ============================================================== - # priority related - # ============================================================== - 'value_priority': td_data[0].flatten().mean().item(), - 'value_priority_orig': value_priority, - 'target_value_prefix': td_data[1].flatten().mean().item(), - 'target_value': td_data[2].flatten().mean().item(), - 'transformed_target_value_prefix': td_data[3].flatten().mean().item(), - 'transformed_target_value': td_data[4].flatten().mean().item(), - 'predicted_value_prefixs': td_data[7].flatten().mean().item(), - 'predicted_values': td_data[8].flatten().mean().item(), - + return_data.update({ # ============================================================== # sampled related core code # ============================================================== @@ -583,217 +329,12 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: 'target_sampled_actions_min': target_sampled_actions[:, :].float().min().item(), 'target_sampled_actions_mean': target_sampled_actions[:, :].float().mean().item(), 'total_grad_norm_before_clip': total_grad_norm_before_clip - } - - def _calculate_policy_loss_cont( - self, policy_loss: torch.Tensor, policy_logits: torch.Tensor, target_policy: torch.Tensor, - mask_batch: torch.Tensor, child_sampled_actions_batch: torch.Tensor, unroll_step: int - ) -> Tuple[torch.Tensor]: - """ - Overview: - Calculate the policy loss for continuous action space. - Arguments: - - policy_loss (:obj:`torch.Tensor`): The policy loss tensor. - - policy_logits (:obj:`torch.Tensor`): The policy logits tensor. - - target_policy (:obj:`torch.Tensor`): The target policy tensor. - - mask_batch (:obj:`torch.Tensor`): The mask tensor. - - child_sampled_actions_batch (:obj:`torch.Tensor`): The child sampled actions tensor. - - unroll_step (:obj:`int`): The unroll step. - Returns: - - policy_loss (:obj:`torch.Tensor`): The policy loss tensor. - - policy_entropy (:obj:`torch.Tensor`): The policy entropy tensor. - - policy_entropy_loss (:obj:`torch.Tensor`): The policy entropy loss tensor. - - target_policy_entropy (:obj:`torch.Tensor`): The target policy entropy tensor. - - target_sampled_actions (:obj:`torch.Tensor`): The target sampled actions tensor. - - mu (:obj:`torch.Tensor`): The mu tensor. - - sigma (:obj:`torch.Tensor`): The sigma tensor. - """ - (mu, sigma - ) = policy_logits[:, :self._cfg.model.action_space_size], policy_logits[:, -self._cfg.model.action_space_size:] - - dist = Independent(Normal(mu, sigma), 1) - - # take the init hypothetical step k=unroll_step - target_normalized_visit_count = target_policy[:, unroll_step] - - # ******* NOTE: target_policy_entropy is only for debug. 
****** - non_masked_indices = torch.nonzero(mask_batch[:, unroll_step]).squeeze(-1) - # Check if there are any unmasked rows - if len(non_masked_indices) > 0: - target_normalized_visit_count_masked = torch.index_select( - target_normalized_visit_count, 0, non_masked_indices - ) - target_dist = Categorical(target_normalized_visit_count_masked) - target_policy_entropy = target_dist.entropy().mean() - else: - # Set target_policy_entropy to 0 if all rows are masked - target_policy_entropy = 0 - - # shape: (batch_size, num_unroll_steps, num_of_sampled_actions, action_dim, 1) -> (batch_size, - # num_of_sampled_actions, action_dim) e.g. (4, 6, 20, 2, 1) -> (4, 20, 2) - target_sampled_actions = child_sampled_actions_batch[:, unroll_step].squeeze(-1) - - policy_entropy = dist.entropy().mean() - policy_entropy_loss = -dist.entropy() - - # Project the sampled-based improved policy back onto the space of representable policies. calculate KL - # loss (batch_size, num_of_sampled_actions) -> (4,20) target_normalized_visit_count is - # categorical distribution, the range of target_log_prob_sampled_actions is (-inf, 0), add 1e-6 for - # numerical stability. - target_log_prob_sampled_actions = torch.log(target_normalized_visit_count + 1e-6) - log_prob_sampled_actions = [] - for k in range(self._cfg.model.num_of_sampled_actions): - # target_sampled_actions[:,i,:].shape: batch_size, action_dim -> 4,2 - # dist.log_prob(target_sampled_actions[:,i,:]).shape: batch_size -> 4 - # dist is normal distribution, the range of log_prob_sampled_actions is (-inf, inf) - - # way 1: - # log_prob = dist.log_prob(target_sampled_actions[:, k, :]) - - # way 2: SAC-like - y = 1 - target_sampled_actions[:, k, :].pow(2) - - # NOTE: for numerical stability. - target_sampled_actions_clamped = torch.clamp( - target_sampled_actions[:, k, :], torch.tensor(-1 + 1e-6), torch.tensor(1 - 1e-6) - ) - target_sampled_actions_before_tanh = torch.arctanh(target_sampled_actions_clamped) - - # keep dimension for loss computation (usually for action space is 1 env. e.g. pendulum) - log_prob = dist.log_prob(target_sampled_actions_before_tanh).unsqueeze(-1) - log_prob = log_prob - torch.log(y + 1e-6).sum(-1, keepdim=True) - log_prob = log_prob.squeeze(-1) - - log_prob_sampled_actions.append(log_prob) - - # shape: (batch_size, num_of_sampled_actions) e.g. (4,20) - log_prob_sampled_actions = torch.stack(log_prob_sampled_actions, dim=-1) - - if self._cfg.normalize_prob_of_sampled_actions: - # normalize the prob of sampled actions - prob_sampled_actions_norm = torch.exp(log_prob_sampled_actions) / torch.exp(log_prob_sampled_actions).sum( - -1 - ).unsqueeze(-1).repeat(1, log_prob_sampled_actions.shape[-1]).detach() - # the above line is equal to the following line. - # prob_sampled_actions_norm = F.normalize(torch.exp(log_prob_sampled_actions), p=1., dim=-1, eps=1e-6) - log_prob_sampled_actions = torch.log(prob_sampled_actions_norm + 1e-6) - - # NOTE: the +=. 
- if self._cfg.policy_loss_type == 'KL': - # KL divergence loss: sum( p* log(p/q) ) = sum( p*log(p) - p*log(q) )= sum( p*log(p)) - sum( p*log(q) ) - policy_loss += ( - torch.exp(target_log_prob_sampled_actions.detach()) * - (target_log_prob_sampled_actions.detach() - log_prob_sampled_actions) - ).sum(-1) * mask_batch[:, unroll_step] - elif self._cfg.policy_loss_type == 'cross_entropy': - # cross_entropy loss: - sum(p * log (q) ) - policy_loss += -torch.sum( - torch.exp(target_log_prob_sampled_actions.detach()) * log_prob_sampled_actions, 1 - ) * mask_batch[:, unroll_step] - - return policy_loss, policy_entropy, policy_entropy_loss, target_policy_entropy, target_sampled_actions, mu, sigma - - def _calculate_policy_loss_disc( - self, policy_loss: torch.Tensor, policy_logits: torch.Tensor, target_policy: torch.Tensor, - mask_batch: torch.Tensor, child_sampled_actions_batch: torch.Tensor, unroll_step: int - ) -> Tuple[torch.Tensor]: - """ - Overview: - Calculate the policy loss for discrete action space. - Arguments: - - policy_loss (:obj:`torch.Tensor`): The policy loss tensor. - - policy_logits (:obj:`torch.Tensor`): The policy logits tensor. - - target_policy (:obj:`torch.Tensor`): The target policy tensor. - - mask_batch (:obj:`torch.Tensor`): The mask tensor. - - child_sampled_actions_batch (:obj:`torch.Tensor`): The child sampled actions tensor. - - unroll_step (:obj:`int`): The unroll step. - Returns: - - policy_loss (:obj:`torch.Tensor`): The policy loss tensor. - - policy_entropy (:obj:`torch.Tensor`): The policy entropy tensor. - - policy_entropy_loss (:obj:`torch.Tensor`): The policy entropy loss tensor. - - target_policy_entropy (:obj:`torch.Tensor`): The target policy entropy tensor. - - target_sampled_actions (:obj:`torch.Tensor`): The target sampled actions tensor. - """ - prob = torch.softmax(policy_logits, dim=-1) - dist = Categorical(prob) - - # take the init hypothetical step k=unroll_step - target_normalized_visit_count = target_policy[:, unroll_step] - - # Note: The target_policy_entropy is just for debugging. - target_normalized_visit_count_masked = torch.index_select( - target_normalized_visit_count, 0, - torch.nonzero(mask_batch[:, unroll_step]).squeeze(-1) - ) - target_dist = Categorical(target_normalized_visit_count_masked) - target_policy_entropy = target_dist.entropy().mean() - - # shape: (batch_size, num_unroll_steps, num_of_sampled_actions, action_dim, 1) -> (batch_size, - # num_of_sampled_actions, action_dim) e.g. (4, 6, 20, 2, 1) -> (4, 20, 2) - target_sampled_actions = child_sampled_actions_batch[:, unroll_step].squeeze(-1) - - policy_entropy = dist.entropy().mean() - policy_entropy_loss = -dist.entropy() - - # Project the sampled-based improved policy back onto the space of representable policies. calculate KL - # loss (batch_size, num_of_sampled_actions) -> (4,20) target_normalized_visit_count is - # categorical distribution, the range of target_log_prob_sampled_actions is (-inf, 0), add 1e-6 for - # numerical stability. - target_log_prob_sampled_actions = torch.log(target_normalized_visit_count + 1e-6) - - log_prob_sampled_actions = [] - for k in range(self._cfg.model.num_of_sampled_actions): - # target_sampled_actions[:,i,:] shape: (batch_size, action_dim) e.g. (4,2) - # dist.log_prob(target_sampled_actions[:,i,:]) shape: batch_size e.g. 
4 - # dist is normal distribution, the range of log_prob_sampled_actions is (-inf, inf) - - if len(target_sampled_actions.shape) == 2: - target_sampled_actions = target_sampled_actions.unsqueeze(-1) - - log_prob = torch.log(prob.gather(-1, target_sampled_actions[:, k].long()).squeeze(-1) + 1e-6) - log_prob_sampled_actions.append(log_prob) - - # (batch_size, num_of_sampled_actions) e.g. (4,20) - log_prob_sampled_actions = torch.stack(log_prob_sampled_actions, dim=-1) - - if self._cfg.normalize_prob_of_sampled_actions: - # normalize the prob of sampled actions - prob_sampled_actions_norm = torch.exp(log_prob_sampled_actions) / torch.exp(log_prob_sampled_actions).sum( - -1 - ).unsqueeze(-1).repeat(1, log_prob_sampled_actions.shape[-1]).detach() - # the above line is equal to the following line. - # prob_sampled_actions_norm = F.normalize(torch.exp(log_prob_sampled_actions), p=1., dim=-1, eps=1e-6) - log_prob_sampled_actions = torch.log(prob_sampled_actions_norm + 1e-6) - - # NOTE: the +=. - if self._cfg.policy_loss_type == 'KL': - # KL divergence loss: sum( p* log(p/q) ) = sum( p*log(p) - p*log(q) )= sum( p*log(p)) - sum( p*log(q) ) - policy_loss += ( - torch.exp(target_log_prob_sampled_actions.detach()) * - (target_log_prob_sampled_actions.detach() - log_prob_sampled_actions) - ).sum(-1) * mask_batch[:, unroll_step] - elif self._cfg.policy_loss_type == 'cross_entropy': - # cross_entropy loss: - sum(p * log (q) ) - policy_loss += -torch.sum( - torch.exp(target_log_prob_sampled_actions.detach()) * log_prob_sampled_actions, 1 - ) * mask_batch[:, unroll_step] - - return policy_loss, policy_entropy, policy_entropy_loss, target_policy_entropy, target_sampled_actions - - def _init_collect(self) -> None: - """ - Overview: - Collect mode init method. Called by ``self.__init__``. Initialize the collect model and MCTS utils. - """ - self._collect_model = self._model - if self._cfg.mcts_ctree: - self._mcts_collect = MCTSCtree(self._cfg) - else: - self._mcts_collect = MCTSPtree(self._cfg) - self.collect_mcts_temperature = 1 + }) + + return return_data def _forward_collect( - self, data: torch.Tensor, action_mask: list = None, temperature: np.ndarray = 1, to_play=-1, ready_env_id=None + self, data: torch.Tensor, action_mask: list = None, temperature: np.ndarray = 1, to_play=-1, epsilon: float = 0.25, ready_env_id=None ): """ Overview: @@ -820,6 +361,7 @@ def _forward_collect( """ self._collect_model.eval() self.collect_mcts_temperature = temperature + self.collect_epsilon = epsilon active_collect_env_num = len(data) data = to_tensor(data) @@ -923,17 +465,6 @@ def _forward_collect( return output - def _init_eval(self) -> None: - """ - Overview: - Evaluate mode init method. Called by ``self.__init__``. Initialize the eval model and MCTS utils. - """ - self._eval_model = self._model - if self._cfg.mcts_ctree: - self._mcts_eval = MCTSCtree(self._cfg) - else: - self._mcts_eval = MCTSPtree(self._cfg) - def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, ready_env_id=None): """ Overview: @@ -1059,108 +590,3 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, read output[i // agent_num]['policy_logits'].append(policy_logits[i]) return output - - def _monitor_vars_learn(self) -> List[str]: - """ - Overview: - Register the variables to be monitored in learn mode. The registered variables will be logged in - tensorboard according to the return value ``_forward_learn``. 
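The removed `_calculate_policy_loss_cont` / `_calculate_policy_loss_disc` above (now provided by the parent `SampledEfficientZeroPolicy`) project the MCTS-improved policy back onto the space of representable policies using either a KL or a cross-entropy objective over the sampled actions. A minimal sketch of the cross-entropy branch, -sum(p_target * log q_model), using random placeholder tensors in place of the real visit counts and network log-probabilities:

# Illustrative sketch of the cross-entropy policy loss over sampled actions:
#   loss = - sum_k p_target(a_k) * log q_model(a_k), masked per unroll step.
import torch

batch_size, num_sampled_actions = 4, 20
target_visit_count = torch.softmax(torch.randn(batch_size, num_sampled_actions), dim=-1)          # placeholder target
log_prob_sampled_actions = torch.log_softmax(torch.randn(batch_size, num_sampled_actions), dim=-1)  # placeholder model log-probs
mask = torch.ones(batch_size)  # placeholder mask for this unroll step

target_log_prob = torch.log(target_visit_count + 1e-6)  # +1e-6 for numerical stability
policy_loss = -torch.sum(torch.exp(target_log_prob.detach()) * log_prob_sampled_actions, dim=1) * mask
print(policy_loss.shape)  # torch.Size([4])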
- """ - if self._cfg.model.continuous_action_space: - return [ - 'collect_mcts_temperature', - 'cur_lr', - 'total_loss', - 'weighted_total_loss', - 'policy_loss', - 'value_prefix_loss', - 'value_loss', - 'consistency_loss', - 'value_priority', - 'target_value_prefix', - 'target_value', - 'predicted_value_prefixs', - 'predicted_values', - 'transformed_target_value_prefix', - 'transformed_target_value', - - # ============================================================== - # sampled related core code - # ============================================================== - 'policy_entropy', - 'target_policy_entropy', - 'policy_mu_max', - 'policy_mu_min', - 'policy_mu_mean', - 'policy_sigma_max', - 'policy_sigma_min', - 'policy_sigma_mean', - # take the fist dim in action space - 'target_sampled_actions_max', - 'target_sampled_actions_min', - 'target_sampled_actions_mean', - 'total_grad_norm_before_clip', - ] - else: - return [ - 'collect_mcts_temperature', - 'cur_lr', - 'total_loss', - 'weighted_total_loss', - 'loss_mean', - 'policy_loss', - 'value_prefix_loss', - 'value_loss', - 'consistency_loss', - 'value_priority', - 'target_value_prefix', - 'target_value', - 'predicted_value_prefixs', - 'predicted_values', - 'transformed_target_value_prefix', - 'transformed_target_value', - - # ============================================================== - # sampled related core code - # ============================================================== - 'policy_entropy', - 'target_policy_entropy', - - # take the fist dim in action space - 'target_sampled_actions_max', - 'target_sampled_actions_min', - 'target_sampled_actions_mean', - 'total_grad_norm_before_clip', - ] - - def _state_dict_learn(self) -> Dict[str, Any]: - """ - Overview: - Return the state_dict of learn mode, usually including model and optimizer. - Returns: - - state_dict (:obj:`Dict[str, Any]`): the dict of current policy learn state, for saving and restoring. - """ - return { - 'model': self._learn_model.state_dict(), - 'target_model': self._target_model.state_dict(), - 'optimizer': self._optimizer.state_dict(), - } - - def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: - """ - Overview: - Load the state_dict variable into policy learn mode. - Arguments: - - state_dict (:obj:`Dict[str, Any]`): the dict of policy learn state saved before. 
- """ - self._learn_model.load_state_dict(state_dict['model']) - self._target_model.load_state_dict(state_dict['target_model']) - self._optimizer.load_state_dict(state_dict['optimizer']) - - def _process_transition(self, obs, policy_output, timestep): - # be compatible with DI-engine Policy class - pass - - def _get_train_sample(self, data): - # be compatible with DI-engine Policy class - pass diff --git a/lzero/policy/muzero.py b/lzero/policy/muzero.py index 65958924b..ed46fad7a 100644 --- a/lzero/policy/muzero.py +++ b/lzero/policy/muzero.py @@ -454,6 +454,7 @@ def _init_collect(self) -> None: else: self._mcts_collect = MCTSPtree(self._cfg) self.collect_mcts_temperature = 1 + self.collect_epsilon = 1 def _forward_collect( self, @@ -461,6 +462,8 @@ def _forward_collect( action_mask: list = None, temperature: float = 1, to_play: List = [-1], + random_collect_episode_num: int = 0, + epsilon: float = 0.25, ready_env_id=None ) -> Dict: """ @@ -643,6 +646,7 @@ def _monitor_vars_learn(self) -> List[str]: """ return [ 'collect_mcts_temperature', + 'collect_epsilon', 'cur_lr', 'weighted_total_loss', 'total_loss', diff --git a/lzero/policy/sampled_efficientzero.py b/lzero/policy/sampled_efficientzero.py index 83ac3aa38..a8b32724b 100644 --- a/lzero/policy/sampled_efficientzero.py +++ b/lzero/policy/sampled_efficientzero.py @@ -762,9 +762,10 @@ def _init_collect(self) -> None: else: self._mcts_collect = MCTSPtree(self._cfg) self.collect_mcts_temperature = 1 + self.collect_epsilon = 1 def _forward_collect( - self, data: torch.Tensor, action_mask: list = None, temperature: np.ndarray = 1, to_play=-1, ready_env_id=None + self, data: torch.Tensor, action_mask: list = None, temperature: np.ndarray = 1, to_play=-1, random_collect_episode_num: int = 0, epsilon: float = 0.25, ready_env_id=None ): """ Overview: @@ -1034,6 +1035,7 @@ def _monitor_vars_learn(self) -> List[str]: if self._cfg.model.continuous_action_space: return [ 'collect_mcts_temperature', + 'collect_epsilon', 'cur_lr', 'total_loss', 'weighted_total_loss', @@ -1069,6 +1071,7 @@ def _monitor_vars_learn(self) -> List[str]: else: return [ 'collect_mcts_temperature', + 'collect_epsilon', 'cur_lr', 'total_loss', 'weighted_total_loss', diff --git a/lzero/worker/gobigger_muzero_evaluator.py b/lzero/worker/gobigger_muzero_evaluator.py index ad9665f36..440827936 100644 --- a/lzero/worker/gobigger_muzero_evaluator.py +++ b/lzero/worker/gobigger_muzero_evaluator.py @@ -1,24 +1,23 @@ import time -import copy from collections import namedtuple from typing import Any, Optional, Callable, Tuple import numpy as np import torch -from easydict import EasyDict from ding.envs import BaseEnvManager from ding.torch_utils import to_ndarray -from ding.utils import build_logger, EasyTimer -from ding.worker.collector.base_serial_evaluator import ISerialEvaluator, VectorEvalMonitor +from .muzero_evaluator import MuZeroEvaluator +from ding.worker.collector.base_serial_evaluator import VectorEvalMonitor from lzero.mcts.buffer.game_segment import GameSegment from lzero.mcts.utils import prepare_observation from collections import defaultdict + from zoo.gobigger.env.gobigger_rule_bot import GoBiggerBot from collections import namedtuple, deque -class GoBiggerMuZeroEvaluator(ISerialEvaluator): +class GoBiggerMuZeroEvaluator(MuZeroEvaluator): """ Overview: The Evaluator for GoBigger MCTS+RL algorithms, including MuZero, EfficientZero, Sampled EfficientZero. 
@@ -27,25 +26,6 @@ class GoBiggerMuZeroEvaluator(ISerialEvaluator): Property: env, policy """ - - @classmethod - def default_config(cls: type) -> EasyDict: - """ - Overview: - Get evaluator's default config. We merge evaluator's default config with other default configs\ - and user's config to get the final config. - Return: - cfg (:obj:`EasyDict`): evaluator's default config - """ - cfg = EasyDict(copy.deepcopy(cls.config)) - cfg.cfg_type = cls.__name__ + 'Dict' - return cfg - - config = dict( - # Evaluate every "eval_freq" training iterations. - eval_freq=50, - ) - def __init__( self, eval_freq: int = 1000, @@ -72,390 +52,15 @@ def __init__( - instance_name (:obj:`Optional[str]`): Name of this instance. - policy_config: Config of game. """ - self._eval_freq = eval_freq - self._exp_name = exp_name - self._instance_name = instance_name - if tb_logger is not None: - self._logger, _ = build_logger( - path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name, need_tb=False - ) - self._tb_logger = tb_logger - else: - self._logger, self._tb_logger = build_logger( - path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name - ) - self.reset(policy, env) - - self._timer = EasyTimer() - self._default_n_episode = n_evaluator_episode - self._stop_value = stop_value - - # ============================================================== - # MCTS+RL related core code - # ============================================================== - self.policy_config = policy_config - - def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None: - """ - Overview: - Reset evaluator's environment. In some case, we need evaluator use the same policy in different \ - environments. We can use reset_env to reset the environment. - If _env is None, reset the old environment. - If _env is not None, replace the old environment in the evaluator with the \ - new passed in environment and launch. - Arguments: - - env (:obj:`Optional[BaseEnvManager]`): instance of the subclass of vectorized \ - env_manager(BaseEnvManager) - """ - if _env is not None: - self._env = _env - self._env.launch() - self._env_num = self._env.env_num - else: - self._env.reset() - - def reset_policy(self, _policy: Optional[namedtuple] = None) -> None: - """ - Overview: - Reset evaluator's policy. In some case, we need evaluator work in this same environment but use\ - different policy. We can use reset_policy to reset the policy. - If _policy is None, reset the old policy. - If _policy is not None, replace the old policy in the evaluator with the new passed in policy. - Arguments: - - policy (:obj:`Optional[namedtuple]`): the api namedtuple of eval_mode policy - """ - assert hasattr(self, '_env'), "please set env first" - if _policy is not None: - self._policy = _policy - self._policy.reset() - - def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvManager] = None) -> None: - """ - Overview: - Reset evaluator's policy and environment. Use new policy and environment to collect data. - If _env is None, reset the old environment. - If _env is not None, replace the old environment in the evaluator with the new passed in \ - environment and launch. - If _policy is None, reset the old policy. - If _policy is not None, replace the old policy in the evaluator with the new passed in policy. 
- Arguments: - - policy (:obj:`Optional[namedtuple]`): the api namedtuple of eval_mode policy - - env (:obj:`Optional[BaseEnvManager]`): instance of the subclass of vectorized \ - env_manager(BaseEnvManager) - """ - if _env is not None: - self.reset_env(_env) - if _policy is not None: - self.reset_policy(_policy) - self._max_eval_reward = float("-inf") - self._last_eval_iter = 0 - self._end_flag = False - - def close(self) -> None: - """ - Overview: - Close the evaluator. If end_flag is False, close the environment, flush the tb_logger\ - and close the tb_logger. - """ - if self._end_flag: - return - self._end_flag = True - self._env.close() - self._tb_logger.flush() - self._tb_logger.close() - - def __del__(self): - """ - Overview: - Execute the close command and close the evaluator. __del__ is automatically called \ - to destroy the evaluator instance when the evaluator finishes its work - """ - self.close() - - def should_eval(self, train_iter: int) -> bool: - """ - Overview: - Determine whether you need to start the evaluation mode, if the number of training has reached\ - the maximum number of times to start the evaluator, return True - Arguments: - - train_iter (:obj:`int`): Current training iteration. - """ - if train_iter == self._last_eval_iter: - return False - if (train_iter - self._last_eval_iter) < self._eval_freq and train_iter != 0: - return False - self._last_eval_iter = train_iter - return True - - def eval( - self, - save_ckpt_fn: Callable = None, - train_iter: int = -1, - envstep: int = -1, - n_episode: Optional[int] = None, - ) -> Tuple[bool, float]: - """ - Overview: - Evaluate policy and store the best policy based on whether it reaches the highest historical reward. - Arguments: - - save_ckpt_fn (:obj:`Callable`): Saving ckpt function, which will be triggered by getting the best reward. - - train_iter (:obj:`int`): Current training iteration. - - envstep (:obj:`int`): Current env interaction step. - - n_episode (:obj:`int`): Number of evaluation episodes. - Returns: - - stop_flag (:obj:`bool`): Whether this training program can be ended. - - eval_reward (:obj:`float`): Current eval_reward. - """ - if n_episode is None: - n_episode = self._default_n_episode - assert n_episode is not None, "please indicate eval n_episode" - envstep_count = 0 - eval_monitor = VectorEvalMonitor(self._env.env_num, n_episode) - env_nums = self._env.env_num - - self._env.reset() - self._policy.reset() - - # initializations - init_obs = self._env.ready_obs - - retry_waiting_time = 0.001 - while len(init_obs.keys()) != self._env_num: - # In order to be compatible with subprocess env_manager, in which sometimes self._env_num is not equal to - # len(self._env.ready_obs), especially in tictactoe env. - self._logger.info('The current init_obs.keys() is {}'.format(init_obs.keys())) - self._logger.info('Before sleeping, the _env_states is {}'.format(self._env._env_states)) - time.sleep(retry_waiting_time) - self._logger.info('=' * 10 + 'Wait for all environments (subprocess) to finish resetting.' 
+ '=' * 10) - self._logger.info( - 'After sleeping {}s, the current _env_states is {}'.format(retry_waiting_time, self._env._env_states) - ) - init_obs = self._env.ready_obs - - action_mask_dict = {i: to_ndarray(init_obs[i]['action_mask']) for i in range(env_nums)} - - to_play_dict = {i: to_ndarray(init_obs[i]['to_play']) for i in range(env_nums)} - agent_num = len(init_obs[0]['action_mask']) - dones = np.array([False for _ in range(env_nums)]) - - game_segments = [ - [ - GameSegment( - self._env.action_space, - game_segment_length=self.policy_config.game_segment_length, - config=self.policy_config - ) for _ in range(agent_num) - ] for _ in range(env_nums) - ] - for env_id in range(env_nums): - for agent_id in range(agent_num): - game_segments[env_id][agent_id].reset( - [ - to_ndarray(init_obs[env_id]['observation'][agent_id]) - for _ in range(self.policy_config.model.frame_stack_num) - ] - ) - - ready_env_id = set() - remain_episode = n_episode - - with self._timer: - while not eval_monitor.is_finished(): - # Get current ready env obs. - obs = self._env.ready_obs - new_available_env_id = set(obs.keys()).difference(ready_env_id) - ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode])) - remain_episode -= min(len(new_available_env_id), remain_episode) - - stack_obs = defaultdict(list) - for env_id in ready_env_id: - for agent_id in range(agent_num): - stack_obs[env_id].append(game_segments[env_id][agent_id].get_obs()) - # stack_obs = {env_id: game_segments[env_id].get_obs() for env_id in ready_env_id} - stack_obs = list(stack_obs.values()) - - action_mask_dict = {env_id: action_mask_dict[env_id] for env_id in ready_env_id} - to_play_dict = {env_id: to_play_dict[env_id] for env_id in ready_env_id} - action_mask = [action_mask_dict[env_id] for env_id in ready_env_id] - to_play = [to_play_dict[env_id] for env_id in ready_env_id] - - stack_obs = to_ndarray(stack_obs) - # stack_obs = prepare_observation(stack_obs, self.policy_config.model.model_type) - # stack_obs = torch.from_numpy(stack_obs).to(self.policy_config.device).float() - - # ============================================================== - # policy forward - # ============================================================== - policy_output = self._policy.forward(stack_obs, action_mask, to_play) - actions_no_env_id = defaultdict(dict) - for k, v in policy_output.items(): - for agent_id, act in enumerate(v['action']): - actions_no_env_id[k][agent_id] = act - # actions_no_env_id = {k: v['action'] for k, v in policy_output.items()} - distributions_dict_no_env_id = {k: v['distributions'] for k, v in policy_output.items()} - if self.policy_config.sampled_algo: - root_sampled_actions_dict_no_env_id = { - k: v['root_sampled_actions'] - for k, v in policy_output.items() - } - - value_dict_no_env_id = {k: v['value'] for k, v in policy_output.items()} - pred_value_dict_no_env_id = {k: v['pred_value'] for k, v in policy_output.items()} - visit_entropy_dict_no_env_id = { - k: v['visit_count_distribution_entropy'] - for k, v in policy_output.items() - } - - actions = {} - distributions_dict = {} - if self.policy_config.sampled_algo: - root_sampled_actions_dict = {} - value_dict = {} - pred_value_dict = {} - visit_entropy_dict = {} - for index, env_id in enumerate(ready_env_id): - actions[env_id] = actions_no_env_id.pop(index) - distributions_dict[env_id] = distributions_dict_no_env_id.pop(index) - if self.policy_config.sampled_algo: - root_sampled_actions_dict[env_id] = root_sampled_actions_dict_no_env_id.pop(index) - 
value_dict[env_id] = value_dict_no_env_id.pop(index) - pred_value_dict[env_id] = pred_value_dict_no_env_id.pop(index) - visit_entropy_dict[env_id] = visit_entropy_dict_no_env_id.pop(index) - - # ============================================================== - # Interact with env. - # ============================================================== - timesteps = self._env.step(actions) - - for env_id, t in timesteps.items(): - obs, reward, done, info = t.obs, t.reward, t.done, t.info - - for agent_id in range(agent_num): - game_segments[env_id][agent_id].append( - actions[env_id][agent_id], to_ndarray(obs['observation'][agent_id]), reward[agent_id], - action_mask_dict[env_id][agent_id], to_play_dict[env_id] - ) - - # NOTE: in evaluator, we only need save the ``o_{t+1} = obs['observation']`` - # game_segments[env_id].obs_segment.append(to_ndarray(obs['observation'])) - - # NOTE: the position of code snippet is very important. - # the obs['action_mask'] and obs['to_play'] is corresponding to next action - action_mask_dict[env_id] = to_ndarray(obs['action_mask']) - to_play_dict[env_id] = to_ndarray(obs['to_play']) - - dones[env_id] = done - if t.done: - # Env reset is done by env_manager automatically. - self._policy.reset([env_id]) - reward = t.info['eval_episode_return'] - if 'episode_info' in t.info: - eval_monitor.update_info(env_id, t.info['episode_info']) - eval_monitor.update_reward(env_id, reward) - self._logger.info( - "[EVALUATOR selfplay]env {} finish episode, final reward: {}, current episode: {}".format( - env_id, eval_monitor.get_latest_reward(env_id), eval_monitor.get_current_episode() - ) - ) - - # reset the finished env and init game_segments - if n_episode > self._env_num: - # Get current ready env obs. - init_obs = self._env.ready_obs - retry_waiting_time = 0.001 - while len(init_obs.keys()) != self._env_num: - # In order to be compatible with subprocess env_manager, in which sometimes self._env_num is not equal to - # len(self._env.ready_obs), especially in tictactoe env. - self._logger.info('The current init_obs.keys() is {}'.format(init_obs.keys())) - self._logger.info( - 'Before sleeping, the _env_states is {}'.format(self._env._env_states) - ) - time.sleep(retry_waiting_time) - self._logger.info( - '=' * 10 + 'Wait for all environments (subprocess) to finish resetting.' + '=' * 10 - ) - self._logger.info( - 'After sleeping {}s, the current _env_states is {}'.format( - retry_waiting_time, self._env._env_states - ) - ) - init_obs = self._env.ready_obs - - new_available_env_id = set(init_obs.keys()).difference(ready_env_id) - ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode])) - remain_episode -= min(len(new_available_env_id), remain_episode) - - action_mask_dict[env_id] = to_ndarray(init_obs[env_id]['action_mask']) - to_play_dict[env_id] = to_ndarray(init_obs[env_id]['to_play']) - - for agent_id in range(agent_num): - game_segments[env_id][agent_id] = GameSegment( - self._env.action_space, - game_segment_length=self.policy_config.game_segment_length, - config=self.policy_config - ) - - game_segments[env_id][agent_id].reset( - [ - init_obs[env_id]['observation'][agent_id] - for _ in range(self.policy_config.model.frame_stack_num) - ] - ) - - # Env reset is done by env_manager automatically. 
- self._policy.reset([env_id]) - # TODO(pu): subprocess mode, when n_episode > self._env_num, occasionally the ready_env_id=() - # and the stack_obs is np.array(None, dtype=object) - ready_env_id.remove(env_id) - - envstep_count += 1 - duration = self._timer.value - episode_return = eval_monitor.get_episode_return() - info = { - 'train_iter': train_iter, - 'ckpt_name': 'iteration_{}.pth.tar'.format(train_iter), - 'episode_count': n_episode, - 'envstep_count': envstep_count, - 'avg_envstep_per_episode': envstep_count / n_episode, - 'evaluate_time': duration, - 'avg_envstep_per_sec': envstep_count / duration, - 'avg_time_per_episode': n_episode / duration, - 'reward_mean': np.mean(episode_return), - 'reward_std': np.std(episode_return), - 'reward_max': np.max(episode_return), - 'reward_min': np.min(episode_return), - } + super().__init__(eval_freq, n_evaluator_episode, stop_value, env, policy, tb_logger, exp_name, instance_name, policy_config) + + def _add_info(self, last_timestep, info): # add eat info - for i in range(len(t.info['eats']) // 2): - for k, v in t.info['eats'][i].items(): + for i in range(len(last_timestep.info['eats']) // 2): + for k, v in last_timestep.info['eats'][i].items(): info['agent_{}_{}'.format(i, k)] = v - - episode_info = eval_monitor.get_episode_info() - if episode_info is not None: - info.update(episode_info) - self._logger.info(self._logger.get_tabulate_vars_hor(info)) - # self._logger.info(self._logger.get_tabulate_vars(info)) - for k, v in info.items(): - if k in ['train_iter', 'ckpt_name', 'each_reward']: - continue - if not np.isscalar(v): - continue - self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter) - self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, envstep) - eval_reward = np.mean(episode_return) - # if eval_reward > self._max_eval_reward: - # if save_ckpt_fn: - # save_ckpt_fn('ckpt_best.pth.tar') - # self._max_eval_reward = eval_reward - stop_flag = eval_reward >= self._stop_value and train_iter > 0 - if stop_flag: - self._logger.info( - "[LightZero serial pipeline] " + - "Current eval_reward: {} is greater than stop_value: {}".format(eval_reward, self._stop_value) + - ", so your MCTS/RL agent is converged, you can refer to 'log/evaluator/evaluator_logger.txt' for details." - ) - return stop_flag, eval_reward - + return info + def eval_vsbot( self, save_ckpt_fn: Callable = None, @@ -479,17 +84,19 @@ def eval_vsbot( n_episode = self._default_n_episode assert n_episode is not None, "please indicate eval n_episode" envstep_count = 0 + # specifically for vs bot eval_monitor = GoBiggerVectorEvalMonitor(self._env.env_num, n_episode) env_nums = self._env.env_num self._env.reset() self._policy.reset() + + # specifically for vs bot self._bot_policy = GoBiggerBot(env_nums, agent_id=[2, 3]) #TODO only support t2p2 self._bot_policy.reset() # initializations init_obs = self._env.ready_obs - agent_num = len(init_obs[0]['action_mask']) // 2 #TODO only support t2p2 retry_waiting_time = 0.001 while len(init_obs.keys()) != self._env_num: @@ -504,6 +111,8 @@ def eval_vsbot( ) init_obs = self._env.ready_obs + # specifically for vs bot + agent_num = len(init_obs[0]['action_mask']) // 2 #TODO only support t2p2 for i in range(env_nums): for k, v in init_obs[i].items(): if k != 'raw_obs': @@ -533,22 +142,26 @@ def eval_vsbot( ready_env_id = set() remain_episode = n_episode + # specifically for vs bot eat_info = defaultdict() with self._timer: while not eval_monitor.is_finished(): # Get current ready env obs. 
obs = self._env.ready_obs + # specifically for vs bot raw_obs = [v['raw_obs'] for k, v in obs.items()] new_available_env_id = set(obs.keys()).difference(ready_env_id) ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode])) remain_episode -= min(len(new_available_env_id), remain_episode) - stack_obs = defaultdict(list) - for env_id in ready_env_id: - for agent_id in range(agent_num): - stack_obs[env_id].append(game_segments[env_id][agent_id].get_obs()) - # stack_obs = {env_id: game_segments[env_id].get_obs() for env_id in ready_env_id} + if self._multi_agent: + stack_obs = defaultdict(list) + for env_id in ready_env_id: + for agent_id in range(agent_num): + stack_obs[env_id].append(game_segments[env_id][agent_id].get_obs()) + else: + stack_obs = {env_id: game_segments[env_id].get_obs() for env_id in ready_env_id} stack_obs = list(stack_obs.values()) action_mask_dict = {env_id: action_mask_dict[env_id] for env_id in ready_env_id} @@ -557,8 +170,9 @@ def eval_vsbot( to_play = [to_play_dict[env_id] for env_id in ready_env_id] stack_obs = to_ndarray(stack_obs) - # stack_obs = prepare_observation(stack_obs, self.policy_config.model.model_type) - # stack_obs = torch.from_numpy(stack_obs).to(self.policy_config.device).float() + if self.policy_config.model.model_type and self.policy_config.model.model_type in ['conv', 'mlp']: + stack_obs = prepare_observation(stack_obs, self.policy_config.model.model_type) + stack_obs = torch.from_numpy(stack_obs).to(self.policy_config.device).float() # ============================================================== # bot forward @@ -569,11 +183,13 @@ def eval_vsbot( # policy forward # ============================================================== policy_output = self._policy.forward(stack_obs, action_mask, to_play) - actions_no_env_id = defaultdict(dict) - for k, v in policy_output.items(): - for agent_id, act in enumerate(v['action']): - actions_no_env_id[k][agent_id] = act - # actions_no_env_id = {k: v['action'] for k, v in policy_output.items()} + if self._multi_agent: + actions_no_env_id = defaultdict(dict) + for k, v in policy_output.items(): + for agent_id, act in enumerate(v['action']): + actions_no_env_id[k][agent_id] = act + else: + actions_no_env_id = {k: v['action'] for k, v in policy_output.items()} distributions_dict_no_env_id = {k: v['distributions'] for k, v in policy_output.items()} if self.policy_config.sampled_algo: root_sampled_actions_dict_no_env_id = { @@ -607,6 +223,7 @@ def eval_vsbot( # ============================================================== # Interact with env. 
# ============================================================== + # specifically for vs bot for env_id, v in bot_actions.items(): actions[env_id].update(v) @@ -614,11 +231,16 @@ def eval_vsbot( for env_id, t in timesteps.items(): obs, reward, done, info = t.obs, t.reward, t.done, t.info - - for agent_id in range(agent_num): - game_segments[env_id][agent_id].append( - actions[env_id][agent_id], to_ndarray(obs['observation'][agent_id]), reward[agent_id], - action_mask_dict[env_id][agent_id], to_play_dict[env_id] + if self._multi_agent: + for agent_id in range(agent_num): + game_segments[env_id][agent_id].append( + actions[env_id][agent_id], to_ndarray(obs['observation'][agent_id]), reward[agent_id], + action_mask_dict[env_id][agent_id], to_play_dict[env_id] + ) + else: + game_segments[env_id].append( + actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id], + to_play_dict[env_id] ) # NOTE: in evaluator, we only need save the ``o_{t+1} = obs['observation']`` @@ -634,11 +256,13 @@ def eval_vsbot( # Env reset is done by env_manager automatically. self._policy.reset([env_id]) reward = t.info['eval_episode_return'] + # specifically for vs bot bot_reward = t.info['eval_bot_episode_return'] eat_info[env_id] = t.info['eats'] if 'episode_info' in t.info: eval_monitor.update_info(env_id, t.info['episode_info']) eval_monitor.update_reward(env_id, reward) + # specifically for vs bot eval_monitor.update_bot_reward(env_id, bot_reward) self._logger.info( "[EVALUATOR vsbot]env {} finish episode, final reward: {}, current episode: {}".format( @@ -676,22 +300,37 @@ def eval_vsbot( action_mask_dict[env_id] = to_ndarray(init_obs[env_id]['action_mask']) to_play_dict[env_id] = to_ndarray(init_obs[env_id]['to_play']) - for agent_id in range(agent_num): - game_segments[env_id][agent_id] = GameSegment( + if self._multi_agent: + for agent_id in range(agent_num): + game_segments[env_id][agent_id] = GameSegment( + self._env.action_space, + game_segment_length=self.policy_config.game_segment_length, + config=self.policy_config + ) + + game_segments[env_id][agent_id].reset( + [ + init_obs[env_id]['observation'][agent_id] + for _ in range(self.policy_config.model.frame_stack_num) + ] + ) + else: + game_segments[env_id] = GameSegment( self._env.action_space, game_segment_length=self.policy_config.game_segment_length, config=self.policy_config ) - game_segments[env_id][agent_id].reset( + game_segments[env_id].reset( [ - init_obs[env_id]['observation'][agent_id] + init_obs[env_id]['observation'] for _ in range(self.policy_config.model.frame_stack_num) ] ) # Env reset is done by env_manager automatically. 
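# A minimal sketch (hypothetical action values) of the vs-bot action merge shown above:
# the learned policy supplies actions for agents 0-1, GoBiggerBot supplies agents 2-3
# (the t2p2 setup), and the per-env dicts are merged before the joint environment step.
policy_actions = {0: {0: 3, 1: 7}}   # env_id -> {agent_id: action} from self._policy
bot_actions = {0: {2: 5, 3: 1}}      # env_id -> {agent_id: action} from GoBiggerBot
for env_id, v in bot_actions.items():
    policy_actions[env_id].update(v)
# policy_actions is now {0: {0: 3, 1: 7, 2: 5, 3: 1}} and can be passed to self._env.step().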
self._policy.reset([env_id]) + # specifically for vs bot self._bot_policy.reset([env_id]) # TODO(pu): subprocess mode, when n_episode > self._env_num, occasionally the ready_env_id=() # and the stack_obs is np.array(None, dtype=object) @@ -700,6 +339,7 @@ def eval_vsbot( envstep_count += 1 duration = self._timer.value episode_return = eval_monitor.get_episode_return() + # specifically for vs bot bot_episode_return = eval_monitor.get_bot_episode_return() info = { 'train_iter': train_iter, @@ -714,11 +354,13 @@ def eval_vsbot( 'reward_std': np.std(episode_return), 'reward_max': np.max(episode_return), 'reward_min': np.min(episode_return), + # specifically for vs bot 'bot_reward_mean': np.mean(bot_episode_return), 'bot_reward_std': np.std(bot_episode_return), 'bot_reward_max': np.max(bot_episode_return), 'bot_reward_min': np.min(bot_episode_return), } + # specifically for vs bot # add eat info for k, v in eat_info.items(): for i in range(len(v)): @@ -755,7 +397,6 @@ def eval_vsbot( ) return stop_flag, eval_reward - class GoBiggerVectorEvalMonitor(VectorEvalMonitor): def __init__(self, env_num: int, n_episode: int) -> None: diff --git a/lzero/worker/muzero_collector.py b/lzero/worker/muzero_collector.py index c6efb8a50..2c3a451ec 100644 --- a/lzero/worker/muzero_collector.py +++ b/lzero/worker/muzero_collector.py @@ -292,6 +292,9 @@ def collect(self, raise RuntimeError("Please specify collect n_episode") else: n_episode = self._default_n_episode + random_collect_episode_num = 0 + else: + random_collect_episode_num = n_episode assert n_episode >= self._env_num, "Please make sure n_episode >= env_num{}/{}".format(n_episode, self._env_num) if policy_kwargs is None: policy_kwargs = {} @@ -423,7 +426,7 @@ def collect(self, # ============================================================== # policy forward # ============================================================== - policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon) + policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, random_collect_episode_num, epsilon) if self._multi_agent: actions_no_env_id = defaultdict(dict) for k, v in policy_output.items(): diff --git a/lzero/worker/muzero_evaluator.py b/lzero/worker/muzero_evaluator.py index 04d6fece9..f2e585a03 100644 --- a/lzero/worker/muzero_evaluator.py +++ b/lzero/worker/muzero_evaluator.py @@ -13,6 +13,7 @@ from ding.worker.collector.base_serial_evaluator import ISerialEvaluator, VectorEvalMonitor from lzero.mcts.buffer.game_segment import GameSegment from lzero.mcts.utils import prepare_observation +from collections import defaultdict class MuZeroEvaluator(ISerialEvaluator): @@ -92,6 +93,11 @@ def __init__( # ============================================================== self.policy_config = policy_config + if 'multi_agent' in self.policy_config.keys() and self.policy_config.multi_agent: + self._multi_agent = self.policy_config.multi_agent + else: + self._multi_agent = False + def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None: """ Overview: @@ -183,6 +189,9 @@ def should_eval(self, train_iter: int) -> bool: return False self._last_eval_iter = train_iter return True + + def _add_info(self, last_timestep, info): + pass def eval( self, @@ -234,17 +243,37 @@ def eval( to_play_dict = {i: to_ndarray(init_obs[i]['to_play']) for i in range(env_nums)} dones = np.array([False for _ in range(env_nums)]) - game_segments = [ - GameSegment( - self._env.action_space, - 
game_segment_length=self.policy_config.game_segment_length, - config=self.policy_config - ) for _ in range(env_nums) - ] - for i in range(env_nums): - game_segments[i].reset( - [to_ndarray(init_obs[i]['observation']) for _ in range(self.policy_config.model.frame_stack_num)] - ) + if self._multi_agent: + agent_num = len(init_obs[0]['action_mask']) + game_segments = [ + [ + GameSegment( + self._env.action_space, + game_segment_length=self.policy_config.game_segment_length, + config=self.policy_config + ) for _ in range(agent_num) + ] for _ in range(env_nums) + ] + for env_id in range(env_nums): + for agent_id in range(agent_num): + game_segments[env_id][agent_id].reset( + [ + to_ndarray(init_obs[env_id]['observation'][agent_id]) + for _ in range(self.policy_config.model.frame_stack_num) + ] + ) + else: + game_segments = [ + GameSegment( + self._env.action_space, + game_segment_length=self.policy_config.game_segment_length, + config=self.policy_config + ) for _ in range(env_nums) + ] + for i in range(env_nums): + game_segments[i].reset( + [to_ndarray(init_obs[i]['observation']) for _ in range(self.policy_config.model.frame_stack_num)] + ) ready_env_id = set() remain_episode = n_episode @@ -257,7 +286,13 @@ def eval( ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode])) remain_episode -= min(len(new_available_env_id), remain_episode) - stack_obs = {env_id: game_segments[env_id].get_obs() for env_id in ready_env_id} + if self._multi_agent: + stack_obs = defaultdict(list) + for env_id in ready_env_id: + for agent_id in range(agent_num): + stack_obs[env_id].append(game_segments[env_id][agent_id].get_obs()) + else: + stack_obs = {env_id: game_segments[env_id].get_obs() for env_id in ready_env_id} stack_obs = list(stack_obs.values()) action_mask_dict = {env_id: action_mask_dict[env_id] for env_id in ready_env_id} @@ -266,15 +301,21 @@ def eval( to_play = [to_play_dict[env_id] for env_id in ready_env_id] stack_obs = to_ndarray(stack_obs) - stack_obs = prepare_observation(stack_obs, self.policy_config.model.model_type) - stack_obs = torch.from_numpy(stack_obs).to(self.policy_config.device).float() + if self.policy_config.model.model_type and self.policy_config.model.model_type in ['conv', 'mlp']: + stack_obs = prepare_observation(stack_obs, self.policy_config.model.model_type) + stack_obs = torch.from_numpy(stack_obs).to(self.policy_config.device).float() # ============================================================== # policy forward # ============================================================== policy_output = self._policy.forward(stack_obs, action_mask, to_play) - - actions_no_env_id = {k: v['action'] for k, v in policy_output.items()} + if self._multi_agent: + actions_no_env_id = defaultdict(dict) + for k, v in policy_output.items(): + for agent_id, act in enumerate(v['action']): + actions_no_env_id[k][agent_id] = act + else: + actions_no_env_id = {k: v['action'] for k, v in policy_output.items()} distributions_dict_no_env_id = {k: v['distributions'] for k, v in policy_output.items()} if self.policy_config.sampled_algo: root_sampled_actions_dict_no_env_id = { @@ -312,11 +353,17 @@ def eval( for env_id, t in timesteps.items(): obs, reward, done, info = t.obs, t.reward, t.done, t.info - - game_segments[env_id].append( - actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id], - to_play_dict[env_id] - ) + if self._multi_agent: + for agent_id in range(agent_num): + game_segments[env_id][agent_id].append( + actions[env_id][agent_id], 
to_ndarray(obs['observation'][agent_id]), reward[agent_id], + action_mask_dict[env_id][agent_id], to_play_dict[env_id] + ) + else: + game_segments[env_id].append( + actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id], + to_play_dict[env_id] + ) # NOTE: in evaluator, we only need save the ``o_{t+1} = obs['observation']`` # game_segments[env_id].obs_segment.append(to_ndarray(obs['observation'])) @@ -370,18 +417,33 @@ def eval( action_mask_dict[env_id] = to_ndarray(init_obs[env_id]['action_mask']) to_play_dict[env_id] = to_ndarray(init_obs[env_id]['to_play']) - game_segments[env_id] = GameSegment( - self._env.action_space, - game_segment_length=self.policy_config.game_segment_length, - config=self.policy_config - ) + if self._multi_agent: + for agent_id in range(agent_num): + game_segments[env_id][agent_id] = GameSegment( + self._env.action_space, + game_segment_length=self.policy_config.game_segment_length, + config=self.policy_config + ) - game_segments[env_id].reset( - [ - init_obs[env_id]['observation'] - for _ in range(self.policy_config.model.frame_stack_num) - ] - ) + game_segments[env_id][agent_id].reset( + [ + init_obs[env_id]['observation'][agent_id] + for _ in range(self.policy_config.model.frame_stack_num) + ] + ) + else: + game_segments[env_id] = GameSegment( + self._env.action_space, + game_segment_length=self.policy_config.game_segment_length, + config=self.policy_config + ) + + game_segments[env_id].reset( + [ + init_obs[env_id]['observation'] + for _ in range(self.policy_config.model.frame_stack_num) + ] + ) # Env reset is done by env_manager automatically. self._policy.reset([env_id]) @@ -405,8 +467,10 @@ def eval( 'reward_std': np.std(episode_return), 'reward_max': np.max(episode_return), 'reward_min': np.min(episode_return), - # 'each_reward': episode_return, } + + # add other info + info = self._add_info(t, info) episode_info = eval_monitor.get_episode_info() if episode_info is not None: info.update(episode_info) diff --git a/zoo/atari/config/atari_efficientzero_config.py b/zoo/atari/config/atari_efficientzero_config.py index 6d6af6a1f..8a6b3e4d2 100644 --- a/zoo/atari/config/atari_efficientzero_config.py +++ b/zoo/atari/config/atari_efficientzero_config.py @@ -25,6 +25,9 @@ batch_size = 256 max_env_step = int(1e6) reanalyze_ratio = 0. 
+ +random_collect_episode_num=0 +eps_greedy_exploration_in_collect = False # ============================================================== # end of the most frequently changed config specified by the user # ============================================================== @@ -50,8 +53,17 @@ norm_type='BN', ), cuda=True, + ignore_done=True, env_type='not_board_games', game_segment_length=400, + random_collect_episode_num=random_collect_episode_num, + eps=dict( + eps_greedy_exploration_in_collect=eps_greedy_exploration_in_collect, + type='linear', + start=1., + end=0.05, + decay=int(1e5), + ), use_augmentation=True, update_per_collect=update_per_collect, batch_size=batch_size, diff --git a/zoo/gobigger/entry/train_muzero_gobigger.py b/zoo/gobigger/entry/train_muzero_gobigger.py index 9f83ceb0e..141a650f1 100644 --- a/zoo/gobigger/entry/train_muzero_gobigger.py +++ b/zoo/gobigger/entry/train_muzero_gobigger.py @@ -13,7 +13,7 @@ from tensorboardX import SummaryWriter import copy from ding.rl_utils import get_epsilon_greedy_fn -from lzero.entry.utils import log_buffer_memory_usage, random_collect +from lzero.entry.utils import log_buffer_memory_usage from lzero.policy import visit_count_temperature from lzero.worker import GoBiggerMuZeroCollector, GoBiggerMuZeroEvaluator @@ -126,7 +126,14 @@ def train_muzero_gobigger( # Learner's before_run hook. learner.call_hook('before_run') if cfg.policy.random_collect_episode_num > 0: - random_collect(cfg.policy, policy, collector, collector_env, replay_buffer) + collect_kwargs = {} + collect_kwargs['temperature'] = 1 + collect_kwargs['epsilon'] = 0.0 + new_data = collector.collect(n_episode=cfg.policy.random_collect_episode_num, train_iter=0, policy_kwargs=collect_kwargs) + # save returned new_data collected by the collector + replay_buffer.push_game_segments(new_data) + # remove the oldest data if the replay buffer is full. + replay_buffer.remove_oldest_data_to_fit() # reset the random_collect_episode_num to 0 cfg.policy.random_collect_episode_num = 0 @@ -151,9 +158,6 @@ def train_muzero_gobigger( collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep) else: collect_kwargs['epsilon'] = 0.0 - - stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) - stop, reward = vsbot_evaluator.eval_vsbot(learner.save_checkpoint, learner.train_iter, collector.envstep) # Evaluate policy performance. 
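# Illustrative sketch (an assumed behaviour, not DI-engine's exact get_epsilon_greedy_fn)
# of the linear epsilon schedule configured above with start=1., end=0.05, decay=1e5:
# epsilon decays linearly with the number of collected env steps and is then clamped at `end`.
def linear_epsilon(envstep: int, start: float = 1.0, end: float = 0.05, decay: int = int(1e5)) -> float:
    if envstep >= decay:
        return end
    return start + (end - start) * envstep / decay

# e.g. linear_epsilon(0) == 1.0, linear_epsilon(50_000) == 0.525, linear_epsilon(200_000) == 0.05,
# which is the value that `collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep)`
# passes to `_forward_collect` when eps_greedy_exploration_in_collect is enabled.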
if evaluator.should_eval(learner.train_iter): From 58281d6350d30bc34e7b35db97a0b854dacabcf8 Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Mon, 17 Jul 2023 17:45:32 +0800 Subject: [PATCH 32/54] fix(yzj): fix random collect in gobigger ez policy --- lzero/policy/gobigger_efficientzero.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lzero/policy/gobigger_efficientzero.py b/lzero/policy/gobigger_efficientzero.py index 0b8885fbc..02e450739 100644 --- a/lzero/policy/gobigger_efficientzero.py +++ b/lzero/policy/gobigger_efficientzero.py @@ -402,11 +402,11 @@ def _forward_collect( for i in range(batch_size): distributions, value = roots_visit_count_distributions[i], roots_values[i] if random_collect_episode_num>0: # random collect - distributions, value = roots_visit_count_distributions[i], roots_values[i] action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( distributions, temperature=self.collect_mcts_temperature, deterministic=False ) - action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + # action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + action = np.random.choice(legal_actions[i]) else: if self._cfg.eps.eps_greedy_exploration_in_collect: # eps greedy collect action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( From c29abafc07e903517f5f07ed46f82fe752a78b0b Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Fri, 4 Aug 2023 12:38:15 +0800 Subject: [PATCH 33/54] feature(yzj): add peetingzoo mz/ez algo, add multi agent buffer/policy/model, and polish gobigger code --- lzero/entry/train_muzero.py | 25 +- lzero/mcts/buffer/__init__.py | 5 +- .../mcts/buffer/game_buffer_efficientzero.py | 8 +- lzero/mcts/buffer/game_buffer_muzero.py | 8 +- .../buffer/gobigger_game_buffer_muzero.py | 703 ------------------ ...igger_game_buffer_sampled_efficientzero.py | 593 --------------- ... 
multi_agent_game_buffer_efficientzero.py} | 206 +---- .../buffer/multi_agent_game_buffer_muzero.py | 261 +++++++ lzero/mcts/utils.py | 5 +- lzero/model/efficientzero_model_structure.py | 187 +++++ .../model/gobigger/{network => }/__init__.py | 0 lzero/model/gobigger/{network => }/encoder.py | 0 .../gobigger/gobigger_efficientzero_model.py | 477 ------------ .../{network => }/gobigger_encoder.py | 0 lzero/model/gobigger/gobigger_muzero_model.py | 449 ----------- .../gobigger_sampled_efficientzero_model.py | 524 ------------- lzero/model/muzero_model_structure.py | 183 +++++ lzero/model/petting_zoo/__init__.py | 0 lzero/model/petting_zoo/encoder.py | 12 + lzero/policy/efficientzero.py | 16 +- lzero/policy/gobigger_efficientzero.py | 504 ------------- lzero/policy/gobigger_muzero.py | 445 ----------- lzero/policy/multi_agent_efficientzero.py | 229 ++++++ lzero/policy/multi_agent_muzero.py | 225 ++++++ ...policy.py => multi_agent_random_policy.py} | 78 +- lzero/policy/muzero.py | 16 +- lzero/policy/random_policy.py | 7 + lzero/policy/utils.py | 14 + lzero/worker/__init__.py | 4 +- ...tor.py => multi_agent_muzero_collector.py} | 8 +- lzero/worker/muzero_collector.py | 9 +- lzero/worker/muzero_evaluator.py | 4 +- .../config/gobigger_efficientzero_config.py | 27 +- zoo/gobigger/config/gobigger_muzero_config.py | 27 +- zoo/gobigger/entry/train_muzero_gobigger.py | 46 +- zoo/petting_zoo/__init__.py | 0 zoo/petting_zoo/config/__init__.py | 1 + .../config/ptz_simple_spread_ez_config.py | 114 +++ .../config/ptz_simple_spread_mz_config.py | 116 +++ zoo/petting_zoo/envs/__init__.py | 0 .../envs/petting_zoo_simple_spread_env.py | 368 +++++++++ .../test_petting_zoo_simple_spread_env.py | 133 ++++ 42 files changed, 2036 insertions(+), 4001 deletions(-) delete mode 100644 lzero/mcts/buffer/gobigger_game_buffer_muzero.py delete mode 100644 lzero/mcts/buffer/gobigger_game_buffer_sampled_efficientzero.py rename lzero/mcts/buffer/{gobigger_game_buffer_efficientzero.py => multi_agent_game_buffer_efficientzero.py} (54%) create mode 100644 lzero/mcts/buffer/multi_agent_game_buffer_muzero.py create mode 100644 lzero/model/efficientzero_model_structure.py rename lzero/model/gobigger/{network => }/__init__.py (100%) rename lzero/model/gobigger/{network => }/encoder.py (100%) delete mode 100644 lzero/model/gobigger/gobigger_efficientzero_model.py rename lzero/model/gobigger/{network => }/gobigger_encoder.py (100%) delete mode 100644 lzero/model/gobigger/gobigger_muzero_model.py delete mode 100644 lzero/model/gobigger/gobigger_sampled_efficientzero_model.py create mode 100644 lzero/model/muzero_model_structure.py create mode 100644 lzero/model/petting_zoo/__init__.py create mode 100644 lzero/model/petting_zoo/encoder.py delete mode 100644 lzero/policy/gobigger_efficientzero.py delete mode 100644 lzero/policy/gobigger_muzero.py create mode 100644 lzero/policy/multi_agent_efficientzero.py create mode 100644 lzero/policy/multi_agent_muzero.py rename lzero/policy/{gobigger_random_policy.py => multi_agent_random_policy.py} (65%) rename lzero/worker/{gobigger_muzero_collector.py => multi_agent_muzero_collector.py} (95%) create mode 100644 zoo/petting_zoo/__init__.py create mode 100644 zoo/petting_zoo/config/__init__.py create mode 100644 zoo/petting_zoo/config/ptz_simple_spread_ez_config.py create mode 100644 zoo/petting_zoo/config/ptz_simple_spread_mz_config.py create mode 100644 zoo/petting_zoo/envs/__init__.py create mode 100644 zoo/petting_zoo/envs/petting_zoo_simple_spread_env.py create mode 100644 
zoo/petting_zoo/envs/test_petting_zoo_simple_spread_env.py diff --git a/lzero/entry/train_muzero.py b/lzero/entry/train_muzero.py index 6960619ad..3ffd4202c 100644 --- a/lzero/entry/train_muzero.py +++ b/lzero/entry/train_muzero.py @@ -15,9 +15,6 @@ from lzero.entry.utils import log_buffer_memory_usage from lzero.policy import visit_count_temperature -from lzero.policy.random_policy import LightZeroRandomPolicy -from lzero.worker import MuZeroCollector as Collector -from lzero.worker import MuZeroEvaluator as Evaluator from .utils import random_collect @@ -47,8 +44,8 @@ def train_muzero( """ cfg, create_cfg = input_cfg - assert create_cfg.policy.type in ['efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero'], \ - "train_muzero entry now only support the following algo.: 'efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero'" + assert create_cfg.policy.type in ['efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero', 'multi_agent_efficientzero', 'multi_agent_muzero'], \ + "train_muzero entry now only support the following algo.: 'efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero', 'multi_agent_efficientzero', 'multi_agent_muzero'" if create_cfg.policy.type == 'muzero': from lzero.mcts import MuZeroGameBuffer as GameBuffer @@ -58,6 +55,10 @@ def train_muzero( from lzero.mcts import SampledEfficientZeroGameBuffer as GameBuffer elif create_cfg.policy.type == 'gumbel_muzero': from lzero.mcts import GumbelMuZeroGameBuffer as GameBuffer + elif create_cfg.policy.type == 'multi_agent_muzero': + from lzero.mcts import MultiAgentMuZeroGameBuffer as GameBuffer + elif create_cfg.policy.type == 'multi_agent_efficientzero': + from lzero.mcts import MultiAgentSampledEfficientZeroGameBuffer as GameBuffer if cfg.policy.cuda and torch.cuda.is_available(): cfg.policy.device = 'cuda' @@ -92,6 +93,14 @@ def train_muzero( batch_size = policy_config.batch_size # specific game buffer for MCTS+RL algorithms replay_buffer = GameBuffer(policy_config) + + if policy_config.multi_agent: + from lzero.worker import MultiAgentMuZeroCollector as Collector + from lzero.worker import MuZeroEvaluator as Evaluator + else: + from lzero.worker import MuZeroCollector as Collector + from lzero.worker import MuZeroEvaluator as Evaluator + collector = Collector( env=collector_env, policy=policy.collect_mode, @@ -123,7 +132,11 @@ def train_muzero( # Exploration: The collection of random data aids the agent in exploring the environment and prevents premature convergence to a suboptimal policy. # Comparation: The agent's performance during random action-taking can be used as a reference point to evaluate the efficacy of reinforcement learning algorithms. 
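# A minimal sketch of the warm-up collection that `random_collect` performs below,
# mirroring the inline variant used in the GoBigger entry earlier in this patch;
# the exact helper in lzero.entry.utils may differ in details (e.g. it also swaps in
# the random policy passed as `RandomPolicy`).
def random_collect_sketch(cfg_policy, collector, replay_buffer):
    collect_kwargs = {'temperature': 1, 'epsilon': 0.0}
    # Collect `random_collect_episode_num` episodes with the (random) collect policy.
    new_data = collector.collect(
        n_episode=cfg_policy.random_collect_episode_num, train_iter=0, policy_kwargs=collect_kwargs
    )
    # Save the returned game segments and trim the buffer if it is over capacity.
    replay_buffer.push_game_segments(new_data)
    replay_buffer.remove_oldest_data_to_fit()
    # Warm up only once.
    cfg_policy.random_collect_episode_num = 0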
if cfg.policy.random_collect_episode_num > 0: - random_collect(cfg.policy, policy, LightZeroRandomPolicy, collector, collector_env, replay_buffer) + if policy_config.multi_agent: + from lzero.policy.multi_agent_random_policy import MultiAgentLightZeroRandomPolicy as RandomPolicy + else: + from lzero.policy.random_policy import LightZeroRandomPolicy as RandomPolicy + random_collect(cfg.policy, policy, RandomPolicy, collector, collector_env, replay_buffer) while True: log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger) diff --git a/lzero/mcts/buffer/__init__.py b/lzero/mcts/buffer/__init__.py index 1ccee0471..78864d59e 100644 --- a/lzero/mcts/buffer/__init__.py +++ b/lzero/mcts/buffer/__init__.py @@ -1,7 +1,6 @@ from .game_buffer_muzero import MuZeroGameBuffer from .game_buffer_efficientzero import EfficientZeroGameBuffer from .game_buffer_sampled_efficientzero import SampledEfficientZeroGameBuffer -from .gobigger_game_buffer_muzero import GoBiggerMuZeroGameBuffer -from .gobigger_game_buffer_efficientzero import GoBiggerEfficientZeroGameBuffer -from .gobigger_game_buffer_sampled_efficientzero import GoBiggerSampledEfficientZeroGameBuffer from .game_buffer_gumbel_muzero import GumbelMuZeroGameBuffer +from .multi_agent_game_buffer_muzero import MultiAgentMuZeroGameBuffer +from .multi_agent_game_buffer_efficientzero import MultiAgentSampledEfficientZeroGameBuffer \ No newline at end of file diff --git a/lzero/mcts/buffer/game_buffer_efficientzero.py b/lzero/mcts/buffer/game_buffer_efficientzero.py index 4ab12259e..75e4649d9 100644 --- a/lzero/mcts/buffer/game_buffer_efficientzero.py +++ b/lzero/mcts/buffer/game_buffer_efficientzero.py @@ -44,6 +44,8 @@ def __init__(self, cfg: dict): self.base_idx = 0 self.clear_time = 0 + self.tmp_obs = None # for value obs list [46 + 4(td_step)] not < 50(game_segment) + def sample(self, batch_size: int, policy: Any) -> List[Any]: """ Overview: @@ -100,7 +102,6 @@ def _prepare_reward_value_context( - reward_value_context (:obj:`list`): value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, td_steps_list, action_mask_segment, to_play_segment """ - zero_obs = game_segment_list[0].zero_obs() value_obs_list = [] # the value is valid or not (out of trajectory) value_mask = [] @@ -148,11 +149,12 @@ def _prepare_reward_value_context( end_index = beg_index + self._cfg.model.frame_stack_num # the stacked obs in time t obs = game_obs[beg_index:end_index] + self.tmp_obs = obs # will be masked else: value_mask.append(0) - obs = zero_obs + obs = self.tmp_obs # will be masked - value_obs_list.append(obs) + value_obs_list.append(obs.tolist()) reward_value_context = [ value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, td_steps_list, diff --git a/lzero/mcts/buffer/game_buffer_muzero.py b/lzero/mcts/buffer/game_buffer_muzero.py index daddf6f9f..6aaf04dad 100644 --- a/lzero/mcts/buffer/game_buffer_muzero.py +++ b/lzero/mcts/buffer/game_buffer_muzero.py @@ -48,6 +48,8 @@ def __init__(self, cfg: dict): self.game_pos_priorities = [] self.game_segment_game_pos_look_up = [] + self.tmp_obs = None # a tmp value which records obs when value obs list [current_index + 4(td_step)] > 50(game_segment) + def sample( self, batch_size: int, policy: Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"] ) -> List[Any]: @@ -198,7 +200,6 @@ def _prepare_reward_value_context( - reward_value_context (:obj:`list`): value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, 
game_segment_lens, td_steps_list, action_mask_segment, to_play_segment """ - zero_obs = game_segment_list[0].zero_obs() value_obs_list = [] # the value is valid or not (out of game_segment) value_mask = [] @@ -238,11 +239,12 @@ def _prepare_reward_value_context( end_index = beg_index + self._cfg.model.frame_stack_num # the stacked obs in time t obs = game_obs[beg_index:end_index] + self.tmp_obs = obs # will be masked else: value_mask.append(0) - obs = zero_obs + obs = self.tmp_obs # will be masked - value_obs_list.append(obs) + value_obs_list.append(obs.tolist()) reward_value_context = [ value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, td_steps_list, diff --git a/lzero/mcts/buffer/gobigger_game_buffer_muzero.py b/lzero/mcts/buffer/gobigger_game_buffer_muzero.py deleted file mode 100644 index 8563789ff..000000000 --- a/lzero/mcts/buffer/gobigger_game_buffer_muzero.py +++ /dev/null @@ -1,703 +0,0 @@ -from typing import Any, List, Tuple, Union, TYPE_CHECKING, Optional - -import numpy as np -import torch -from ding.utils import BUFFER_REGISTRY - -from lzero.mcts.tree_search.mcts_ctree import MuZeroMCTSCtree as MCTSCtree -from lzero.mcts.tree_search.mcts_ptree import MuZeroMCTSPtree as MCTSPtree -from lzero.mcts.utils import prepare_observation -from lzero.policy import to_detach_cpu_numpy, concat_output, concat_output_value, inverse_scalar_transform -from .game_buffer import GameBuffer -from ding.torch_utils import to_device, to_tensor - -if TYPE_CHECKING: - from lzero.policy import MuZeroPolicy, EfficientZeroPolicy, SampledEfficientZeroPolicy - -@BUFFER_REGISTRY.register('gobigger_game_buffer_muzero') -class GoBiggerMuZeroGameBuffer(GameBuffer): - """ - Overview: - The specific game buffer for GoBigger MuZero policy. - """ - - def __init__(self, cfg: dict): - super().__init__(cfg) - """ - Overview: - Use the default configuration mechanism. If a user passes in a cfg with a key that matches an existing key - in the default configuration, the user-provided value will override the default configuration. Otherwise, - the default configuration will be used. - """ - default_config = self.default_config() - default_config.update(cfg) - self._cfg = default_config - assert self._cfg.env_type in ['not_board_games', 'board_games'] - self.replay_buffer_size = self._cfg.replay_buffer_size - self.batch_size = self._cfg.batch_size - self._alpha = self._cfg.priority_prob_alpha - self._beta = self._cfg.priority_prob_beta - - self.keep_ratio = 1 - self.model_update_interval = 10 - self.num_of_collected_episodes = 0 - self.base_idx = 0 - self.clear_time = 0 - - self.game_segment_buffer = [] - self.game_pos_priorities = [] - self.game_segment_game_pos_look_up = [] - - self.tmp_obs = None # a tmp value which records obs when value obs list [current_index + 4(td_step)] > 50(game_segment) - - def sample( - self, batch_size: int, policy: Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"] - ) -> List[Any]: - """ - Overview: - sample data from ``GameBuffer`` and prepare the current and target batch for training. - Arguments: - - batch_size (:obj:`int`): batch size. - - policy (:obj:`Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"]`): policy. - Returns: - - train_data (:obj:`List`): List of train data, including current_batch and target_batch. 
- """ - policy._target_model.to(self._cfg.device) - policy._target_model.eval() - - # obtain the current_batch and prepare target context - reward_value_context, policy_re_context, policy_non_re_context, current_batch = self._make_batch( - batch_size, self._cfg.reanalyze_ratio - ) - # target reward, target value - batch_rewards, batch_target_values = self._compute_target_reward_value( - reward_value_context, policy._target_model - ) - # target policy - batch_target_policies_re = self._compute_target_policy_reanalyzed(policy_re_context, policy._target_model) - batch_target_policies_non_re = self._compute_target_policy_non_reanalyzed( - policy_non_re_context, self._cfg.model.action_space_size - ) - - # fusion of batch_target_policies_re and batch_target_policies_non_re to batch_target_policies - if 0 < self._cfg.reanalyze_ratio < 1: - batch_target_policies = np.concatenate([batch_target_policies_re, batch_target_policies_non_re]) - elif self._cfg.reanalyze_ratio == 1: - batch_target_policies = batch_target_policies_re - elif self._cfg.reanalyze_ratio == 0: - batch_target_policies = batch_target_policies_non_re - - target_batch = [batch_rewards, batch_target_values, batch_target_policies] - - # a batch contains the current_batch and the target_batch - train_data = [current_batch, target_batch] - return train_data - - def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: - """ - Overview: - first sample orig_data through ``_sample_orig_data()``, - then prepare the context of a batch: - reward_value_context: the context of reanalyzed value targets - policy_re_context: the context of reanalyzed policy targets - policy_non_re_context: the context of non-reanalyzed policy targets - current_batch: the inputs of batch - Arguments: - - batch_size (:obj:`int`): the batch size of orig_data from replay buffer. - - reanalyze_ratio (:obj:`float`): ratio of reanalyzed policy (value is 100% reanalyzed) - Returns: - - context (:obj:`Tuple`): reward_value_context, policy_re_context, policy_non_re_context, current_batch - """ - # obtain the batch context from replay buffer - orig_data = self._sample_orig_data(batch_size) - game_segment_list, pos_in_game_segment_list, batch_index_list, weights_list, make_time_list = orig_data - batch_size = len(batch_index_list) - obs_list, action_list, mask_list = [], [], [] - # prepare the inputs of a batch - for i in range(batch_size): - game = game_segment_list[i] - pos_in_game_segment = pos_in_game_segment_list[i] - - actions_tmp = game.action_segment[pos_in_game_segment:pos_in_game_segment + - self._cfg.num_unroll_steps].tolist() - # add mask for invalid actions (out of trajectory), 1 for valid, 0 for invalid - mask_tmp = [1. for i in range(len(actions_tmp))] - mask_tmp += [0. for _ in range(self._cfg.num_unroll_steps + 1 - len(mask_tmp))] - - # pad random action - actions_tmp += [ - np.random.randint(0, game.action_space_size) - for _ in range(self._cfg.num_unroll_steps - len(actions_tmp)) - ] - - # obtain the input observations - # pad if length of obs in game_segment is less than stack+num_unroll_steps - # e.g. 
stack+num_unroll_steps = 4+5 - obs_list.append( - game_segment_list[i].get_unroll_obs( - pos_in_game_segment_list[i], num_unroll_steps=self._cfg.num_unroll_steps, padding=True - ) - ) - action_list.append(actions_tmp) - mask_list.append(mask_tmp) - - # formalize the input observations - # obs_list = prepare_observation(obs_list, self._cfg.model.model_type) - - # formalize the inputs of a batch - current_batch = [obs_list, action_list, mask_list, batch_index_list, weights_list, make_time_list] - for i in range(len(current_batch)): - current_batch[i] = np.asarray(current_batch[i]) - - total_transitions = self.get_num_of_transitions() - - # obtain the context of value targets - reward_value_context = self._prepare_reward_value_context( - batch_index_list, game_segment_list, pos_in_game_segment_list, total_transitions - ) - """ - only reanalyze recent reanalyze_ratio (e.g. 50%) data - if self._cfg.reanalyze_outdated is True, batch_index_list is sorted according to its generated env_steps - 0: reanalyze_num -> reanalyzed policy, reanalyze_num:end -> non reanalyzed policy - """ - reanalyze_num = int(batch_size * reanalyze_ratio) - # reanalyzed policy - if reanalyze_num > 0: - # obtain the context of reanalyzed policy targets - policy_re_context = self._prepare_policy_reanalyzed_context( - batch_index_list[:reanalyze_num], game_segment_list[:reanalyze_num], - pos_in_game_segment_list[:reanalyze_num] - ) - else: - policy_re_context = None - - # non reanalyzed policy - if reanalyze_num < batch_size: - # obtain the context of non-reanalyzed policy targets - policy_non_re_context = self._prepare_policy_non_reanalyzed_context( - batch_index_list[reanalyze_num:], game_segment_list[reanalyze_num:], - pos_in_game_segment_list[reanalyze_num:] - ) - else: - policy_non_re_context = None - - context = reward_value_context, policy_re_context, policy_non_re_context, current_batch - return context - - def _prepare_reward_value_context( - self, batch_index_list: List[str], game_segment_list: List[Any], pos_in_game_segment_list: List[Any], - total_transitions: int - ) -> List[Any]: - """ - Overview: - prepare the context of rewards and values for calculating TD value target in reanalyzing part. 
- Arguments: - - batch_index_list (:obj:`list`): the index of start transition of sampled minibatch in replay buffer - - game_segment_list (:obj:`list`): list of game segments - - pos_in_game_segment_list (:obj:`list`): list of transition index in game_segment - - total_transitions (:obj:`int`): number of collected transitions - Returns: - - reward_value_context (:obj:`list`): value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, - td_steps_list, action_mask_segment, to_play_segment - """ - value_obs_list = [] - # the value is valid or not (out of game_segment) - value_mask = [] - rewards_list = [] - game_segment_lens = [] - # for board games - action_mask_segment, to_play_segment = [], [] - - td_steps_list = [] - for game_segment, state_index, idx in zip(game_segment_list, pos_in_game_segment_list, batch_index_list): - game_segment_len = len(game_segment) - game_segment_lens.append(game_segment_len) - - td_steps = np.clip(self._cfg.td_steps, 1, max(1, game_segment_len - state_index)).astype(np.int32) - - # prepare the corresponding observations for bootstrapped values o_{t+k} - # o[t+ td_steps, t + td_steps + stack frames + num_unroll_steps] - # t=2+3 -> o[2+3, 2+3+4+5] -> o[5, 14] - game_obs = game_segment.get_unroll_obs(state_index + td_steps, self._cfg.num_unroll_steps) - - rewards_list.append(game_segment.reward_segment) - - # for board games - action_mask_segment.append(game_segment.action_mask_segment) - to_play_segment.append(game_segment.to_play_segment) - - for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1): - # get the bootstrapped target obs - td_steps_list.append(td_steps) - # index of bootstrapped obs o_{t+td_steps} - bootstrap_index = current_index + td_steps - - if bootstrap_index < game_segment_len: - value_mask.append(1) - # beg_index = bootstrap_index - (state_index + td_steps), max of beg_index is num_unroll_steps - beg_index = current_index - state_index - end_index = beg_index + self._cfg.model.frame_stack_num - # the stacked obs in time t - obs = game_obs[beg_index:end_index] - self.tmp_obs = obs # will be masked - else: - value_mask.append(0) - obs = self.tmp_obs # will be masked - value_obs_list.append(obs.tolist()) - - reward_value_context = [ - value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, td_steps_list, - action_mask_segment, to_play_segment - ] - return reward_value_context - - def _prepare_policy_non_reanalyzed_context( - self, batch_index_list: List[int], game_segment_list: List[Any], pos_in_game_segment_list: List[int] - ) -> List[Any]: - """ - Overview: - prepare the context of policies for calculating policy target in non-reanalyzing part, just return the policy in self-play - Arguments: - - batch_index_list (:obj:`list`): the index of start transition of sampled minibatch in replay buffer - - game_segment_list (:obj:`list`): list of game segments - - pos_in_game_segment_list (:obj:`list`): list transition index in game - Returns: - - policy_non_re_context (:obj:`list`): pos_in_game_segment_list, child_visits, game_segment_lens, action_mask_segment, to_play_segment - """ - child_visits = [] - game_segment_lens = [] - # for board games - action_mask_segment, to_play_segment = [], [] - - for game_segment, state_index, idx in zip(game_segment_list, pos_in_game_segment_list, batch_index_list): - game_segment_len = len(game_segment) - game_segment_lens.append(game_segment_len) - # for board games - 
action_mask_segment.append(game_segment.action_mask_segment) - to_play_segment.append(game_segment.to_play_segment) - - child_visits.append(game_segment.child_visit_segment) - - policy_non_re_context = [ - pos_in_game_segment_list, child_visits, game_segment_lens, action_mask_segment, to_play_segment - ] - return policy_non_re_context - - def _prepare_policy_reanalyzed_context( - self, batch_index_list: List[str], game_segment_list: List[Any], pos_in_game_segment_list: List[str] - ) -> List[Any]: - """ - Overview: - prepare the context of policies for calculating policy target in reanalyzing part. - Arguments: - - batch_index_list (:obj:'list'): start transition index in the replay buffer - - game_segment_list (:obj:'list'): list of game segments - - pos_in_game_segment_list (:obj:'list'): position of transition index in one game history - Returns: - - policy_re_context (:obj:`list`): policy_obs_list, policy_mask, pos_in_game_segment_list, indices, - child_visits, game_segment_lens, action_mask_segment, to_play_segment - """ - zero_obs = game_segment_list[0].zero_obs() - with torch.no_grad(): - # for policy - policy_obs_list = [] - policy_mask = [] - # 0 -> Invalid target policy for padding outside of game segments, - # 1 -> Previous target policy for game segments. - rewards, child_visits, game_segment_lens = [], [], [] - # for board games - action_mask_segment, to_play_segment = [], [] - for game_segment, state_index in zip(game_segment_list, pos_in_game_segment_list): - game_segment_len = len(game_segment) - game_segment_lens.append(game_segment_len) - rewards.append(game_segment.reward_segment) - # for board games - action_mask_segment.append(game_segment.action_mask_segment) - to_play_segment.append(game_segment.to_play_segment) - - child_visits.append(game_segment.child_visit_segment) - # prepare the corresponding observations - game_obs = game_segment.get_unroll_obs(state_index, self._cfg.num_unroll_steps) - for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1): - - if current_index < game_segment_len: - policy_mask.append(1) - beg_index = current_index - state_index - end_index = beg_index + self._cfg.model.frame_stack_num - obs = game_obs[beg_index:end_index] - else: - policy_mask.append(0) - obs = zero_obs - policy_obs_list.append(obs) - - policy_re_context = [ - policy_obs_list, policy_mask, pos_in_game_segment_list, batch_index_list, child_visits, game_segment_lens, - action_mask_segment, to_play_segment - ] - return policy_re_context - - def _compute_target_reward_value(self, reward_value_context: List[Any], model: Any) -> Tuple[Any, Any]: - """ - Overview: - prepare reward and value targets from the context of rewards and values. 
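The window arithmetic in `_prepare_policy_reanalyzed_context` above is easier to see in isolation. Below is a small self-contained sketch, with invented values for `frame_stack_num`, `num_unroll_steps` and the segment length, of how a stacked-frame slice is cut from the unrolled observations and how positions beyond the segment end are masked out.

import numpy as np

frame_stack_num = 4      # hypothetical stack size
num_unroll_steps = 5     # hypothetical unroll length
game_segment_len = 8     # pretend the segment holds 8 transitions
state_index = 6          # sampled position near the end of the segment

# get_unroll_obs() returns frame_stack_num + num_unroll_steps observations
# starting at state_index; here we fake it with scalar "observations".
game_obs = np.arange(frame_stack_num + num_unroll_steps)

zero_obs = np.zeros(frame_stack_num)  # stands in for game_segment.zero_obs()
policy_obs_list, policy_mask = [], []

for current_index in range(state_index, state_index + num_unroll_steps + 1):
    if current_index < game_segment_len:
        policy_mask.append(1)
        beg_index = current_index - state_index
        obs = game_obs[beg_index:beg_index + frame_stack_num]
    else:
        # Outside the segment: the target is masked, so any placeholder works.
        policy_mask.append(0)
        obs = zero_obs
    policy_obs_list.append(obs)

print(policy_mask)  # [1, 1, 0, 0, 0, 0]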
- Arguments: - - reward_value_context (:obj:'list'): the reward value context - - model (:obj:'torch.tensor'):model of the target model - Returns: - - batch_value_prefixs (:obj:'np.ndarray): batch of value prefix - - batch_target_values (:obj:'np.ndarray): batch of value estimation - """ - value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, td_steps_list, action_mask_segment, \ - to_play_segment = reward_value_context # noqa - # transition_batch_size = game_segment_batch_size * (num_unroll_steps+1) - transition_batch_size = len(value_obs_list) - game_segment_batch_size = len(pos_in_game_segment_list) - - to_play, action_mask = self._preprocess_to_play_and_action_mask( - game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list - ) - if self._cfg.model.continuous_action_space is True: - # when the action space of the environment is continuous, action_mask[:] is None. - action_mask = [ - list(np.ones(self._cfg.model.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size) - ] - # NOTE: in continuous action space env: we set all legal_actions as -1 - legal_actions = [ - [-1 for _ in range(self._cfg.model.action_space_size)] for _ in range(transition_batch_size) - ] - else: - legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)] - - batch_target_values, batch_rewards = [], [] - with torch.no_grad(): - # value_obs_list = prepare_observation(value_obs_list, self._cfg.model.model_type) - # split a full batch into slices of mini_infer_size: to save the GPU memory for more GPU actors - slices = int(np.ceil(transition_batch_size / self._cfg.mini_infer_size)) - network_output = [] - for i in range(slices): - beg_index = self._cfg.mini_infer_size * i - end_index = self._cfg.mini_infer_size * (i + 1) - - # m_obs = torch.from_numpy(value_obs_list[beg_index:end_index]).to(self._cfg.device).float() - m_obs = value_obs_list[beg_index:end_index] - m_obs = to_tensor(m_obs) - m_obs = sum(m_obs, []) - m_obs = to_device(m_obs, self._cfg.device) - - # calculate the target value - m_output = model.initial_inference(m_obs) - - if not model.training: - # if not in training, obtain the scalars of the value/reward - [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( - [ - m_output.latent_state, - inverse_scalar_transform(m_output.value, self._cfg.model.support_scale), - m_output.policy_logits - ] - ) - - network_output.append(m_output) - - # concat the output slices after model inference - if self._cfg.use_root_value: - # use the root values from MCTS, as in EfficiientZero - # the root values have limited improvement but require much more GPU actors; - _, reward_pool, policy_logits_pool, latent_state_roots = concat_output( - network_output, data_type='muzero' - ) - reward_pool = reward_pool.squeeze().tolist() - policy_logits_pool = policy_logits_pool.tolist() - noises = [ - np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j])) - ).astype(np.float32).tolist() for j in range(transition_batch_size) - ] - if self._cfg.mcts_ctree: - # cpp mcts_tree - roots = MCTSCtree.roots(transition_batch_size, legal_actions) - roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play) - # do MCTS for a new policy with the recent target model - MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play) - else: - # python mcts_tree - roots = MCTSPtree.roots(transition_batch_size, legal_actions) - 
roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play) - # do MCTS for a new policy with the recent target model - MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play) - - roots_values = roots.get_values() - value_list = np.array(roots_values) - else: - # use the predicted values - value_list = concat_output_value(network_output) - - # get last state value - if self._cfg.env_type == 'board_games' and to_play_segment[0][0] in [1, 2]: - # TODO(pu): for board_games, very important, to check - value_list = value_list.reshape(-1) * np.array( - [ - self._cfg.discount_factor ** td_steps_list[i] if int(td_steps_list[i]) % - 2 == 0 else -self._cfg.discount_factor ** td_steps_list[i] - for i in range(transition_batch_size) - ] - ) - else: - value_list = value_list.reshape(-1) * ( - np.array([self._cfg.discount_factor for _ in range(transition_batch_size)]) ** td_steps_list - ) - - value_list = value_list * np.array(value_mask) - value_list = value_list.tolist() - horizon_id, value_index = 0, 0 - - for game_segment_len_non_re, reward_list, state_index, to_play_list in zip(game_segment_lens, rewards_list, - pos_in_game_segment_list, - to_play_segment): - target_values = [] - target_rewards = [] - base_index = state_index - for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1): - bootstrap_index = current_index + td_steps_list[value_index] - # for i, reward in enumerate(game.rewards[current_index:bootstrap_index]): - for i, reward in enumerate(reward_list[current_index:bootstrap_index]): - if self._cfg.env_type == 'board_games' and to_play_segment[0][0] in [1, 2]: - # TODO(pu): for board_games, very important, to check - if to_play_list[base_index] == to_play_list[i]: - value_list[value_index] += reward * self._cfg.discount_factor ** i - else: - value_list[value_index] += -reward * self._cfg.discount_factor ** i - else: - value_list[value_index] += reward * self._cfg.discount_factor ** i - horizon_id += 1 - - if current_index < game_segment_len_non_re: - target_values.append(value_list[value_index]) - target_rewards.append(reward_list[current_index]) - else: - target_values.append(0) - target_rewards.append(0.0) - # TODO: check - # target_rewards.append(reward) - value_index += 1 - - batch_rewards.append(target_rewards) - batch_target_values.append(target_values) - - batch_rewards = np.asarray(batch_rewards, dtype=object) - batch_target_values = np.asarray(batch_target_values, dtype=object) - return batch_rewards, batch_target_values - - def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: Any) -> np.ndarray: - """ - Overview: - prepare policy targets from the reanalyzed context of policies - Arguments: - - policy_re_context (:obj:`List`): List of policy context to reanalyzed - Returns: - - batch_target_policies_re - """ - if policy_re_context is None: - return [] - batch_target_policies_re = [] - - # for board games - policy_obs_list, policy_mask, pos_in_game_segment_list, batch_index_list, child_visits, game_segment_lens, action_mask_segment, \ - to_play_segment = policy_re_context # noqa - # transition_batch_size = game_segment_batch_size * (self._cfg.num_unroll_steps + 1) - transition_batch_size = len(policy_obs_list) - game_segment_batch_size = len(pos_in_game_segment_list) - - to_play, action_mask = self._preprocess_to_play_and_action_mask( - game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list - ) - - if self._cfg.model.continuous_action_space is 
True: - # when the action space of the environment is continuous, action_mask[:] is None. - action_mask = [ - list(np.ones(self._cfg.model.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size) - ] - # NOTE: in continuous action space env: we set all legal_actions as -1 - legal_actions = [ - [-1 for _ in range(self._cfg.model.action_space_size)] for _ in range(transition_batch_size) - ] - else: - legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)] - - with torch.no_grad(): - policy_obs_list = prepare_observation(policy_obs_list, self._cfg.model.model_type) - # split a full batch into slices of mini_infer_size: to save the GPU memory for more GPU actors - slices = int(np.ceil(transition_batch_size / self._cfg.mini_infer_size)) - network_output = [] - for i in range(slices): - beg_index = self._cfg.mini_infer_size * i - end_index = self._cfg.mini_infer_size * (i + 1) - m_obs = torch.from_numpy(policy_obs_list[beg_index:end_index]).to(self._cfg.device).float() - m_output = model.initial_inference(m_obs) - if not model.training: - # if not in training, obtain the scalars of the value/reward - [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( - [ - m_output.latent_state, - inverse_scalar_transform(m_output.value, self._cfg.model.support_scale), - m_output.policy_logits - ] - ) - - network_output.append(m_output) - - _, reward_pool, policy_logits_pool, latent_state_roots = concat_output(network_output, data_type='muzero') - reward_pool = reward_pool.squeeze().tolist() - policy_logits_pool = policy_logits_pool.tolist() - noises = [ - np.random.dirichlet([self._cfg.root_dirichlet_alpha] * self._cfg.model.action_space_size - ).astype(np.float32).tolist() for _ in range(transition_batch_size) - ] - if self._cfg.mcts_ctree: - # cpp mcts_tree - roots = MCTSCtree.roots(transition_batch_size, legal_actions) - roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play) - # do MCTS for a new policy with the recent target model - MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play) - else: - # python mcts_tree - roots = MCTSPtree.roots(transition_batch_size, legal_actions) - roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play) - # do MCTS for a new policy with the recent target model - MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play) - - roots_legal_actions_list = legal_actions - roots_distributions = roots.get_distributions() - policy_index = 0 - for state_index, game_index in zip(pos_in_game_segment_list, batch_index_list): - target_policies = [] - - for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1): - distributions = roots_distributions[policy_index] - - if policy_mask[policy_index] == 0: - # NOTE: the invalid padding target policy, O is to make sure the corresponding cross_entropy_loss=0 - target_policies.append([0 for _ in range(self._cfg.model.action_space_size)]) - else: - if distributions is None: - # if at some obs, the legal_action is None, add the fake target_policy - target_policies.append( - list(np.ones(self._cfg.model.action_space_size) / self._cfg.model.action_space_size) - ) - else: - if self._cfg.env_type == 'not_board_games': - # for atari/classic_control/box2d environments that only have one player. 
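For intuition, the Dirichlet noise generated above is mixed into the root prior before the re-search; in LightZero the mixing itself happens inside `roots.prepare` and the tree implementation, so the snippet below is only an illustrative sketch of the usual AlphaZero-style formulation, with made-up logits and config values.

import numpy as np

# Hypothetical values for illustration only.
action_space_size = 4
root_dirichlet_alpha = 0.3
root_noise_weight = 0.25

# Prior policy at the root, e.g. softmax over the predicted policy logits.
policy_logits = np.array([1.2, 0.3, -0.5, 0.1])
prior = np.exp(policy_logits) / np.exp(policy_logits).sum()

# One Dirichlet noise vector per root, as generated in the buffer code above.
noise = np.random.dirichlet([root_dirichlet_alpha] * action_space_size)

# AlphaZero/MuZero-style exploration mixture applied to the root prior.
mixed_prior = (1 - root_noise_weight) * prior + root_noise_weight * noise

assert np.isclose(mixed_prior.sum(), 1.0)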
- sum_visits = sum(distributions) - policy = [visit_count / sum_visits for visit_count in distributions] - target_policies.append(policy) - else: - # for board games that have two players and legal_actions is dy - policy_tmp = [0 for _ in range(self._cfg.model.action_space_size)] - # to make sure target_policies have the same dimension - sum_visits = sum(distributions) - policy = [visit_count / sum_visits for visit_count in distributions] - for index, legal_action in enumerate(roots_legal_actions_list[policy_index]): - policy_tmp[legal_action] = policy[index] - target_policies.append(policy_tmp) - - policy_index += 1 - - batch_target_policies_re.append(target_policies) - - batch_target_policies_re = np.array(batch_target_policies_re) - - return batch_target_policies_re - - def _compute_target_policy_non_reanalyzed( - self, policy_non_re_context: List[Any], policy_shape: Optional[int] - ) -> np.ndarray: - """ - Overview: - prepare policy targets from the non-reanalyzed context of policies - Arguments: - - policy_non_re_context (:obj:`List`): List containing: - - pos_in_game_segment_list - - child_visits - - game_segment_lens - - action_mask_segment - - to_play_segment - - policy_shape: self._cfg.model.action_space_size - Returns: - - batch_target_policies_non_re - """ - batch_target_policies_non_re = [] - if policy_non_re_context is None: - return batch_target_policies_non_re - - pos_in_game_segment_list, child_visits, game_segment_lens, action_mask_segment, to_play_segment = policy_non_re_context - game_segment_batch_size = len(pos_in_game_segment_list) - transition_batch_size = game_segment_batch_size * (self._cfg.num_unroll_steps + 1) - - to_play, action_mask = self._preprocess_to_play_and_action_mask( - game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list - ) - - if self._cfg.model.continuous_action_space is True: - # when the action space of the environment is continuous, action_mask[:] is None. - action_mask = [ - list(np.ones(self._cfg.model.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size) - ] - # NOTE: in continuous action space env: we set all legal_actions as -1 - legal_actions = [ - [-1 for _ in range(self._cfg.model.action_space_size)] for _ in range(transition_batch_size) - ] - else: - legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)] - - with torch.no_grad(): - policy_index = 0 - # 0 -> Invalid target policy for padding outside of game segments, - # 1 -> Previous target policy for game segments. - policy_mask = [] - for game_segment_len, child_visit, state_index in zip(game_segment_lens, child_visits, - pos_in_game_segment_list): - target_policies = [] - - for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1): - if current_index < game_segment_len: - policy_mask.append(1) - # NOTE: child_visit is already a distribution - distributions = child_visit[current_index] - if self._cfg.env_type == 'not_board_games': - # for atari/classic_control/box2d environments that only have one player. - target_policies.append(distributions) - else: - # for board games that have two players. 
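The normalization just shown converts raw MCTS visit counts into a probability distribution, and for board games that distribution is additionally scattered back onto the full fixed-size action space through the legal-action indices. A tiny standalone example with fabricated counts:

action_space_size = 6                 # hypothetical full action space
legal_actions = [0, 2, 5]             # indices of legal moves at this node
distributions = [10, 25, 15]          # MCTS visit counts, one per legal action

sum_visits = sum(distributions)
policy = [visit_count / sum_visits for visit_count in distributions]  # [0.2, 0.5, 0.3]

# Scatter the normalized visits back into a fixed-size target vector so every
# target policy in the batch has the same dimension; illegal actions stay 0.
policy_tmp = [0.0] * action_space_size
for index, legal_action in enumerate(legal_actions):
    policy_tmp[legal_action] = policy[index]

print(policy_tmp)  # [0.2, 0.0, 0.5, 0.0, 0.0, 0.3]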
- policy_tmp = [0 for _ in range(policy_shape)] - for index, legal_action in enumerate(legal_actions[policy_index]): - # only the action in ``legal_action`` the policy logits is nonzero - policy_tmp[legal_action] = distributions[index] - target_policies.append(policy_tmp) - else: - # NOTE: the invalid padding target policy, O is to make sure the correspoding cross_entropy_loss=0 - policy_mask.append(0) - target_policies.append([0 for _ in range(policy_shape)]) - - policy_index += 1 - - batch_target_policies_non_re.append(target_policies) - batch_target_policies_non_re = np.asarray(batch_target_policies_non_re) - return batch_target_policies_non_re - - def update_priority(self, train_data: List[np.ndarray], batch_priorities: Any) -> None: - """ - Overview: - Update the priority of training data. - Arguments: - - train_data (:obj:`Optional[List[Optional[np.ndarray]]]`): training data to be updated priority. - - batch_priorities (:obj:`batch_priorities`): priorities to update to. - NOTE: - train_data = [current_batch, target_batch] - current_batch = [obs_list, action_list, mask_list, batch_index_list, weights, make_time_list] - """ - indices = train_data[0][3] - metas = {'make_time': train_data[0][5], 'batch_priorities': batch_priorities} - # only update the priorities for data still in replay buffer - for i in range(len(indices)): - if metas['make_time'][i] > self.clear_time: - idx, prio = indices[i], metas['batch_priorities'][i] - self.game_pos_priorities[idx] = prio diff --git a/lzero/mcts/buffer/gobigger_game_buffer_sampled_efficientzero.py b/lzero/mcts/buffer/gobigger_game_buffer_sampled_efficientzero.py deleted file mode 100644 index 877b721a1..000000000 --- a/lzero/mcts/buffer/gobigger_game_buffer_sampled_efficientzero.py +++ /dev/null @@ -1,593 +0,0 @@ -from typing import Any, List, Tuple - -import numpy as np -import torch -from ding.utils import BUFFER_REGISTRY - -from lzero.mcts.tree_search.mcts_ctree_sampled import SampledEfficientZeroMCTSCtree as MCTSCtree -from lzero.mcts.tree_search.mcts_ptree_sampled import SampledEfficientZeroMCTSPtree as MCTSPtree -from lzero.mcts.utils import prepare_observation -from lzero.policy import to_detach_cpu_numpy, concat_output, concat_output_value, inverse_scalar_transform -from .gobigger_game_buffer_muzero import GoBiggerMuZeroGameBuffer -from ding.torch_utils import to_device, to_tensor, to_ndarray - - -@BUFFER_REGISTRY.register('gobigger_game_buffer_sampled_efficientzero') -class GoBiggerSampledEfficientZeroGameBuffer(GoBiggerMuZeroGameBuffer): - """ - Overview: - The specific game buffer for GoBigger Sampled EfficientZero policy. - """ - - def __init__(self, cfg: dict): - super().__init__(cfg) - """ - Overview: - Use the default configuration mechanism. If a user passes in a cfg with a key that matches an existing key - in the default configuration, the user-provided value will override the default configuration. Otherwise, - the default configuration will be used. 
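The `update_priority` logic above only writes back priorities for samples whose `make_time` is newer than the last buffer clear, since older slots may already have been recycled. A rough self-contained sketch of that gating (buffer size, indices, times and priorities are all invented):

import numpy as np

# Fabricated example state: 8 buffer positions with uniform initial priority.
game_pos_priorities = np.ones(8, dtype=np.float32)
clear_time = 100.0

# Indices and creation times of the sampled transitions, plus their new
# priorities (e.g. |value target - predicted value| plus a small epsilon).
batch_index_list = [2, 5, 7]
make_time_list = [90.0, 120.0, 130.0]
batch_priorities = [0.7, 1.3, 0.2]

for i, idx in enumerate(batch_index_list):
    # Skip samples collected before the last buffer clear.
    if make_time_list[i] > clear_time:
        game_pos_priorities[idx] = batch_priorities[i]

print(game_pos_priorities)  # index 2 unchanged, indices 5 and 7 updated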
- """ - default_config = self.default_config() - default_config.update(cfg) - self._cfg = default_config - assert self._cfg.env_type in ['not_board_games', 'board_games'] - self.replay_buffer_size = self._cfg.replay_buffer_size - self.batch_size = self._cfg.batch_size - self._alpha = self._cfg.priority_prob_alpha - self._beta = self._cfg.priority_prob_beta - - self.game_segment_buffer = [] - self.game_pos_priorities = [] - self.game_segment_game_pos_look_up = [] - - self.keep_ratio = 1 - self.num_of_collected_episodes = 0 - self.base_idx = 0 - self.clear_time = 0 - - self.tmp_obs = None # for value obs list [46 + 4(td_step)] not < 50(game_segment) - - def sample(self, batch_size: int, policy: Any) -> List[Any]: - """ - Overview: - sample data from ``GameBuffer`` and prepare the current and target batch for training - Arguments: - - batch_size (:obj:`int`): batch size - - policy (:obj:`torch.tensor`): model of policy - Returns: - - train_data (:obj:`List`): List of train data - """ - - policy._target_model.to(self._cfg.device) - policy._target_model.eval() - - reward_value_context, policy_re_context, policy_non_re_context, current_batch = self._make_batch( - batch_size, self._cfg.reanalyze_ratio - ) - - # target reward, target value - batch_value_prefixs, batch_target_values = self._compute_target_reward_value( - reward_value_context, policy._target_model - ) - - batch_target_policies_non_re = self._compute_target_policy_non_reanalyzed( - policy_non_re_context, self._cfg.model.num_of_sampled_actions - ) - - if self._cfg.reanalyze_ratio > 0: - # target policy - batch_target_policies_re, root_sampled_actions = self._compute_target_policy_reanalyzed( - policy_re_context, policy._target_model - ) - # ============================================================== - # fix reanalyze in sez: - # use the latest root_sampled_actions after the reanalyze process, - # because the batch_target_policies_re is corresponding to the latest root_sampled_actions - # ============================================================== - - assert (self._cfg.reanalyze_ratio > 0 and self._cfg.reanalyze_outdated is True), \ - "in sampled effiicientzero, if self._cfg.reanalyze_ratio>0, you must set self._cfg.reanalyze_outdated=True" - # current_batch = [obs_list, action_list, root_sampled_actions_list, mask_list, batch_index_list, weights_list, make_time_list] - if self._cfg.model.continuous_action_space: - current_batch[2][:int(batch_size * self._cfg.reanalyze_ratio)] = root_sampled_actions.reshape( - int(batch_size * self._cfg.reanalyze_ratio), self._cfg.num_unroll_steps + 1, - self._cfg.model.num_of_sampled_actions, self._cfg.model.action_space_size - ) - else: - current_batch[2][:int(batch_size * self._cfg.reanalyze_ratio)] = root_sampled_actions.reshape( - int(batch_size * self._cfg.reanalyze_ratio), self._cfg.num_unroll_steps + 1, - self._cfg.model.num_of_sampled_actions, 1 - ) - - if 0 < self._cfg.reanalyze_ratio < 1: - try: - batch_target_policies = np.concatenate([batch_target_policies_re, batch_target_policies_non_re]) - except Exception as error: - print(error) - elif self._cfg.reanalyze_ratio == 1: - batch_target_policies = batch_target_policies_re - elif self._cfg.reanalyze_ratio == 0: - batch_target_policies = batch_target_policies_non_re - - target_batch = [batch_value_prefixs, batch_target_values, batch_target_policies] - # a batch contains the current_batch and the target_batch - train_data = [current_batch, target_batch] - return train_data - - def _make_batch(self, batch_size: int, reanalyze_ratio: 
float) -> Tuple[Any]: - """ - Overview: - first sample orig_data through ``_sample_orig_data()``, - then prepare the context of a batch: - reward_value_context: the context of reanalyzed value targets - policy_re_context: the context of reanalyzed policy targets - policy_non_re_context: the context of non-reanalyzed policy targets - current_batch: the inputs of batch - Arguments: - - batch_size (:obj:`int`): the batch size of orig_data from replay buffer. - - reanalyze_ratio (:obj:`float`): ratio of reanalyzed policy (value is 100% reanalyzed) - Returns: - - context (:obj:`Tuple`): reward_value_context, policy_re_context, policy_non_re_context, current_batch - """ - # obtain the batch context from replay buffer - orig_data = self._sample_orig_data(batch_size) - game_lst, pos_in_game_segment_list, batch_index_list, weights_list, make_time_list = orig_data - batch_size = len(batch_index_list) - obs_list, action_list, mask_list = [], [], [] - root_sampled_actions_list = [] - # prepare the inputs of a batch - for i in range(batch_size): - game = game_lst[i] - pos_in_game_segment = pos_in_game_segment_list[i] - # ============================================================== - # sampled related core code - # ============================================================== - actions_tmp = game.action_segment[pos_in_game_segment:pos_in_game_segment + - self._cfg.num_unroll_steps].tolist() - - # NOTE: self._cfg.num_unroll_steps + 1 - root_sampled_actions_tmp = game.root_sampled_actions[pos_in_game_segment:pos_in_game_segment + - self._cfg.num_unroll_steps + 1] - - # add mask for invalid actions (out of trajectory), 1 for valid, 0 for invalid - mask_tmp = [1. for i in range(len(root_sampled_actions_tmp))] - mask_tmp += [0. for _ in range(self._cfg.num_unroll_steps + 1 - len(mask_tmp))] - - # pad random action - if self._cfg.model.continuous_action_space: - actions_tmp += [ - np.random.randn(self._cfg.model.action_space_size) - for _ in range(self._cfg.num_unroll_steps - len(actions_tmp)) - ] - root_sampled_actions_tmp += [ - np.random.rand(self._cfg.model.num_of_sampled_actions, self._cfg.model.action_space_size) - for _ in range(self._cfg.num_unroll_steps + 1 - len(root_sampled_actions_tmp)) - ] - else: - actions_tmp += [ - np.random.randint(0, self._cfg.model.action_space_size, 1).item() - for _ in range(self._cfg.num_unroll_steps - len(actions_tmp)) - ] - if len(root_sampled_actions_tmp[0].shape) == 1: - root_sampled_actions_tmp += [ - np.random.randint(0, self._cfg.model.action_space_size, - self._cfg.model.num_of_sampled_actions) - # NOTE: self._cfg.num_unroll_steps + 1 - for _ in range(self._cfg.num_unroll_steps + 1 - len(root_sampled_actions_tmp)) - ] - else: - root_sampled_actions_tmp += [ - np.random.randint(0, self._cfg.model.action_space_size, - self._cfg.model.num_of_sampled_actions).reshape( - self._cfg.model.num_of_sampled_actions, 1 - ) # NOTE: self._cfg.num_unroll_steps + 1 - for _ in range(self._cfg.num_unroll_steps + 1 - len(root_sampled_actions_tmp)) - ] - - # obtain the input observations - # stack+num_unroll_steps 4+5 - # pad if length of obs in game_segment is less than stack+num_unroll_steps - obs_list.append( - game_lst[i].get_unroll_obs( - pos_in_game_segment_list[i], num_unroll_steps=self._cfg.num_unroll_steps, padding=True - ) - ) - action_list.append(actions_tmp) - root_sampled_actions_list.append(root_sampled_actions_tmp) - - mask_list.append(mask_tmp) - - # formalize the input observations - #obs_list = prepare_observation(obs_list, self._cfg.model.model_type) - # 
============================================================== - # sampled related core code - # ============================================================== - # formalize the inputs of a batch - current_batch = [ - obs_list, action_list, root_sampled_actions_list, mask_list, batch_index_list, weights_list, make_time_list - ] - - for i in range(len(current_batch)): - current_batch[i] = np.asarray(current_batch[i]) - - total_transitions = self.get_num_of_transitions() - - # obtain the context of value targets - reward_value_context = self._prepare_reward_value_context( - batch_index_list, game_lst, pos_in_game_segment_list, total_transitions - ) - """ - only reanalyze recent reanalyze_ratio (e.g. 50%) data - if self._cfg.reanalyze_outdated is True, batch_index_list is sorted according to its generated env_steps - 0: reanalyze_num -> reanalyzed policy, reanalyze_num:end -> non reanalyzed policy - """ - reanalyze_num = int(batch_size * reanalyze_ratio) - # reanalyzed policy - if reanalyze_num > 0: - # obtain the context of reanalyzed policy targets - policy_re_context = self._prepare_policy_reanalyzed_context( - batch_index_list[:reanalyze_num], game_lst[:reanalyze_num], pos_in_game_segment_list[:reanalyze_num] - ) - else: - policy_re_context = None - - # non reanalyzed policy - if reanalyze_num < batch_size: - # obtain the context of non-reanalyzed policy targets - policy_non_re_context = self._prepare_policy_non_reanalyzed_context( - batch_index_list[reanalyze_num:], game_lst[reanalyze_num:], pos_in_game_segment_list[reanalyze_num:] - ) - else: - policy_non_re_context = None - - context = reward_value_context, policy_re_context, policy_non_re_context, current_batch - return context - - def _compute_target_reward_value(self, reward_value_context: List[Any], model: Any) -> List[np.ndarray]: - """ - Overview: - prepare reward and value targets from the context of rewards and values. - Arguments: - - reward_value_context (:obj:'list'): the reward value context - - model (:obj:'torch.tensor'):model of the target model - Returns: - - batch_value_prefixs (:obj:'np.ndarray): batch of value prefix - - batch_target_values (:obj:'np.ndarray): batch of value estimation - """ - value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, td_steps_list, action_mask_segment, \ - to_play_segment = reward_value_context # noqa - - # transition_batch_size = game_segment_batch_size * (num_unroll_steps+1) - transition_batch_size = len(value_obs_list) - game_segment_batch_size = len(pos_in_game_segment_list) - - to_play, action_mask = self._preprocess_to_play_and_action_mask( - game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list - ) - if self._cfg.model.continuous_action_space is True: - # when the action space of the environment is continuous, action_mask[:] is None. 
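As in the MuZero buffer earlier, the split above sends the first `reanalyze_num` samples through a fresh MCTS reanalysis while the remainder reuse the stored child-visit distributions, and the two target blocks are later concatenated in that same order. A toy illustration with invented sizes:

import numpy as np

batch_size = 8
reanalyze_ratio = 0.5
action_space_size = 3

reanalyze_num = int(batch_size * reanalyze_ratio)  # 4

# Pretend targets: the reanalyzed block comes from a new MCTS run, the rest
# from the child-visit distributions recorded during self-play.
batch_target_policies_re = np.full((reanalyze_num, action_space_size), 0.5)
batch_target_policies_non_re = np.full((batch_size - reanalyze_num, action_space_size), 0.1)

if 0 < reanalyze_ratio < 1:
    batch_target_policies = np.concatenate([batch_target_policies_re, batch_target_policies_non_re])
elif reanalyze_ratio == 1:
    batch_target_policies = batch_target_policies_re
else:
    batch_target_policies = batch_target_policies_non_re

print(batch_target_policies.shape)  # (8, 3)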
- action_mask = [ - list(np.ones(self._cfg.model.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size) - ] - # NOTE: in continuous action space env: we set all legal_actions as -1 - legal_actions = [ - [-1 for _ in range(self._cfg.model.action_space_size)] for _ in range(transition_batch_size) - ] - else: - legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)] - - batch_target_values, batch_value_prefixs = [], [] - with torch.no_grad(): - # value_obs_list = prepare_observation(value_obs_list, self._cfg.model.model_type) - # split a full batch into slices of mini_infer_size: to save the GPU memory for more GPU actors - slices = int(np.ceil(transition_batch_size / self._cfg.mini_infer_size)) - network_output = [] - for i in range(slices): - beg_index = self._cfg.mini_infer_size * i - end_index = self._cfg.mini_infer_size * (i + 1) - # m_obs = torch.from_numpy(value_obs_list[beg_index:end_index]).to(self._cfg.device).float() - m_obs = value_obs_list[beg_index:end_index] - m_obs = to_tensor(m_obs) - m_obs = sum(m_obs, []) - m_obs = to_device(m_obs, self._cfg.device) - - # calculate the target value - m_output = model.initial_inference(m_obs) - - # TODO(pu) - if not model.training: - # if not in training, obtain the scalars of the value/reward - [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( - [ - m_output.latent_state, - inverse_scalar_transform(m_output.value, self._cfg.model.support_scale), - m_output.policy_logits - ] - ) - m_output.reward_hidden_state = ( - m_output.reward_hidden_state[0].detach().cpu().numpy(), - m_output.reward_hidden_state[1].detach().cpu().numpy() - ) - - network_output.append(m_output) - - # concat the output slices after model inference - if self._cfg.use_root_value: - # use the root values from MCTS - # the root values have limited improvement but require much more GPU actors; - _, value_prefix_pool, policy_logits_pool, latent_state_roots, reward_hidden_state_roots = concat_output( - network_output, data_type='efficientzero' - ) - value_prefix_pool = value_prefix_pool.squeeze().tolist() - policy_logits_pool = policy_logits_pool.tolist() - # generate the noises for the root nodes - noises = [ - np.random.dirichlet([self._cfg.root_dirichlet_alpha] * self._cfg.model.num_of_sampled_actions - ).astype(np.float32).tolist() for _ in range(transition_batch_size) - ] - - if self._cfg.mcts_ctree: - # cpp mcts_tree - # prepare the root nodes for MCTS - roots = MCTSCtree.roots( - transition_batch_size, legal_actions, self._cfg.model.action_space_size, - self._cfg.model.num_of_sampled_actions, self._cfg.model.continuous_action_space - ) - - roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, to_play) - # do MCTS for a new policy with the recent target model - MCTSCtree(self._cfg).search(roots, model, latent_state_roots, reward_hidden_state_roots, to_play) - else: - # python mcts_tree - roots = MCTSPtree.roots( - transition_batch_size, legal_actions, self._cfg.model.action_space_size, - self._cfg.model.num_of_sampled_actions, self._cfg.model.continuous_action_space - ) - roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, to_play) - # do MCTS for a new policy with the recent target model - MCTSPtree.roots(self._cfg - ).search(roots, model, latent_state_roots, reward_hidden_state_roots, to_play) - - roots_values = roots.get_values() - value_list = np.array(roots_values) - else: - # use the 
predicted values - value_list = concat_output_value(network_output) - - # get last state value - if self._cfg.env_type == 'board_games' and to_play_segment[0][0] in [1, 2]: - # TODO(pu): for board_games, very important, to check - value_list = value_list.reshape(-1) * np.array( - [ - self._cfg.discount_factor ** td_steps_list[i] if int(td_steps_list[i]) % - 2 == 0 else -self._cfg.discount_factor ** td_steps_list[i] - for i in range(transition_batch_size) - ] - ) - else: - value_list = value_list.reshape(-1) * ( - np.array([self._cfg.discount_factor for _ in range(transition_batch_size)]) ** td_steps_list - ) - - value_list = value_list * np.array(value_mask) - value_list = value_list.tolist() - - horizon_id, value_index = 0, 0 - for game_segment_len_non_re, reward_list, state_index, to_play_list in zip(game_segment_lens, rewards_list, - pos_in_game_segment_list, - to_play_segment): - target_values = [] - target_value_prefixs = [] - - value_prefix = 0.0 - base_index = state_index - for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1): - bootstrap_index = current_index + td_steps_list[value_index] - # for i, reward in enumerate(game.rewards[current_index:bootstrap_index]): - for i, reward in enumerate(reward_list[current_index:bootstrap_index]): - if self._cfg.env_type == 'board_games' and to_play_segment[0][0] in [1, 2]: - # TODO(pu): for board_games, very important, to check - if to_play_list[base_index] == to_play_list[i]: - value_list[value_index] += reward * self._cfg.discount_factor ** i - else: - value_list[value_index] += -reward * self._cfg.discount_factor ** i - else: - value_list[value_index] += reward * self._cfg.discount_factor ** i - # TODO(pu): why value don't use discount_factor factor - - # reset every lstm_horizon_len - if horizon_id % self._cfg.lstm_horizon_len == 0: - value_prefix = 0.0 - base_index = current_index - horizon_id += 1 - - if current_index < game_segment_len_non_re: - target_values.append(value_list[value_index]) - # Since the horizon is small and the discount_factor is close to 1. 
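The loop above builds the usual n-step bootstrapped target, roughly z_t = sum_{i=0}^{td_steps-1} gamma^i * r_{t+i} + gamma^{td_steps} * v_{t+td_steps}, together with EfficientZero's value prefix, a running reward sum that restarts every `lstm_horizon_len` steps. The single-player sketch below uses fabricated rewards and bootstrap values and omits the board-game sign flip handled above.

import numpy as np

discount_factor = 0.997
td_steps = 3
lstm_horizon_len = 2

rewards = [0.0, 1.0, 0.0, 2.0, 1.0, 0.0]           # fabricated reward trace
bootstrap_values = [0.5, 0.4, 0.6]                  # fabricated v(s_{t + td_steps}) for t = 0, 1, 2

target_values, target_value_prefixs = [], []
value_prefix = 0.0
for t in range(len(rewards) - td_steps):
    # n-step return: discounted rewards up to the bootstrap index,
    # then the discounted bootstrap value.
    z = sum(discount_factor ** i * r for i, r in enumerate(rewards[t:t + td_steps]))
    z += discount_factor ** td_steps * bootstrap_values[t]
    target_values.append(z)

    # EfficientZero-style value prefix: a plain reward sum within one LSTM
    # horizon, reset whenever a new horizon starts.
    if t % lstm_horizon_len == 0:
        value_prefix = 0.0
    value_prefix += rewards[t]
    target_value_prefixs.append(value_prefix)

print(np.round(target_values, 3), target_value_prefixs)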
- # Compute the reward sum to approximate the value prefix for simplification - value_prefix += reward_list[current_index - ] # * config.discount_factor ** (current_index - base_index) - target_value_prefixs.append(value_prefix) - else: - target_values.append(0) - target_value_prefixs.append(value_prefix) - - value_index += 1 - - batch_value_prefixs.append(target_value_prefixs) - batch_target_values.append(target_values) - - batch_value_prefixs = np.asarray(batch_value_prefixs, dtype=object) - batch_target_values = np.asarray(batch_target_values, dtype=object) - - return batch_value_prefixs, batch_target_values - - def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: Any) -> np.ndarray: - """ - Overview: - prepare policy targets from the reanalyzed context of policies - Arguments: - - policy_re_context (:obj:`List`): List of policy context to reanalyzed - Returns: - - batch_target_policies_re - """ - if policy_re_context is None: - return [] - batch_target_policies_re = [] - - policy_obs_list, policy_mask, pos_in_game_segment_list, batch_index_list, child_visits, game_segment_lens, action_mask_segment, \ - to_play_segment = policy_re_context # noqa - # transition_batch_size = game_segment_batch_size * (self._cfg.num_unroll_steps + 1) - transition_batch_size = len(policy_obs_list) - game_segment_batch_size = len(pos_in_game_segment_list) - - to_play, action_mask = self._preprocess_to_play_and_action_mask( - game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list - ) - if self._cfg.model.continuous_action_space is True: - # when the action space of the environment is continuous, action_mask[:] is None. - action_mask = [ - list(np.ones(self._cfg.model.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size) - ] - # NOTE: in continuous action space env, we set all legal_actions as -1 - legal_actions = [ - [-1 for _ in range(self._cfg.model.action_space_size)] for _ in range(transition_batch_size) - ] - else: - legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)] - - with torch.no_grad(): - policy_obs_list = prepare_observation(policy_obs_list, self._cfg.model.model_type) - # split a full batch into slices of mini_infer_size: to save the GPU memory for more GPU actors - self._cfg.mini_infer_size = self._cfg.mini_infer_size - slices = np.ceil(transition_batch_size / self._cfg.mini_infer_size).astype(np.int_) - network_output = [] - for i in range(slices): - beg_index = self._cfg.mini_infer_size * i - end_index = self._cfg.mini_infer_size * (i + 1) - m_obs = torch.from_numpy(policy_obs_list[beg_index:end_index]).to(self._cfg.device).float() - - m_output = model.initial_inference(m_obs) - - if not model.training: - # if not in training, obtain the scalars of the value/reward - [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( - [ - m_output.latent_state, - inverse_scalar_transform(m_output.value, self._cfg.model.support_scale), - m_output.policy_logits - ] - ) - m_output.reward_hidden_state = ( - m_output.reward_hidden_state[0].detach().cpu().numpy(), - m_output.reward_hidden_state[1].detach().cpu().numpy() - ) - - network_output.append(m_output) - - _, value_prefix_pool, policy_logits_pool, latent_state_roots, reward_hidden_state_roots = concat_output( - network_output, data_type='efficientzero' - ) - - value_prefix_pool = value_prefix_pool.squeeze().tolist() - policy_logits_pool = policy_logits_pool.tolist() - noises = [ - 
np.random.dirichlet([self._cfg.root_dirichlet_alpha] * self._cfg.model.num_of_sampled_actions - ).astype(np.float32).tolist() for _ in range(transition_batch_size) - ] - if self._cfg.mcts_ctree: - # ============================================================== - # sampled related core code - # ============================================================== - # cpp mcts_tree - roots = MCTSCtree.roots( - transition_batch_size, legal_actions, self._cfg.model.action_space_size, - self._cfg.model.num_of_sampled_actions, self._cfg.model.continuous_action_space - ) - roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, to_play) - # do MCTS for a new policy with the recent target model - MCTSCtree(self._cfg).search(roots, model, latent_state_roots, reward_hidden_state_roots, to_play) - else: - # python mcts_tree - roots = MCTSPtree.roots( - transition_batch_size, legal_actions, self._cfg.model.action_space_size, - self._cfg.model.num_of_sampled_actions, self._cfg.model.continuous_action_space - ) - roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, to_play) - # do MCTS for a new policy with the recent target model - MCTSPtree.roots(self._cfg).search(roots, model, latent_state_roots, reward_hidden_state_roots, to_play) - - roots_legal_actions_list = legal_actions - roots_distributions = roots.get_distributions() - - # ============================================================== - # fix reanalyze in sez - # ============================================================== - roots_sampled_actions = roots.get_sampled_actions() - try: - root_sampled_actions = np.array([action.value for action in roots_sampled_actions]) - except Exception: - root_sampled_actions = np.array([action for action in roots_sampled_actions]) - - policy_index = 0 - for state_index, game_idx in zip(pos_in_game_segment_list, batch_index_list): - target_policies = [] - for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1): - distributions = roots_distributions[policy_index] - # ============================================================== - # sampled related core code - # ============================================================== - if policy_mask[policy_index] == 0: - # NOTE: the invalid padding target policy, O is to make sure the correspoding cross_entropy_loss=0 - target_policies.append([0 for _ in range(self._cfg.model.num_of_sampled_actions)]) - else: - if distributions is None: - # if at some obs, the legal_action is None, then add the fake target_policy - target_policies.append( - list( - np.ones(self._cfg.model.num_of_sampled_actions) / - self._cfg.model.num_of_sampled_actions - ) - ) - else: - if self._cfg.env_type == 'not_board_games': - sum_visits = sum(distributions) - policy = [visit_count / sum_visits for visit_count in distributions] - target_policies.append(policy) - else: - # for two_player board games - policy_tmp = [0 for _ in range(self._cfg.model.num_of_sampled_actions)] - # to make sure target_policies have the same dimension - sum_visits = sum(distributions) - policy = [visit_count / sum_visits for visit_count in distributions] - for index, legal_action in enumerate(roots_legal_actions_list[policy_index]): - policy_tmp[legal_action] = policy[index] - target_policies.append(policy_tmp) - - policy_index += 1 - - batch_target_policies_re.append(target_policies) - - batch_target_policies_re = np.array(batch_target_policies_re) - - return batch_target_policies_re, root_sampled_actions - - def 
update_priority(self, train_data: List[np.ndarray], batch_priorities: Any) -> None: - """ - Overview: - Update the priority of training data. - Arguments: - - train_data (:obj:`Optional[List[Optional[np.ndarray]]]`): training data to be updated priority. - - batch_priorities (:obj:`batch_priorities`): priorities to update to. - NOTE: - train_data = [current_batch, target_batch] - current_batch = [obs_list, action_list, root_sampled_actions_list, mask_list, batch_index_list, weights_list, make_time_list] - """ - - batch_index_list = train_data[0][4] - metas = {'make_time': train_data[0][6], 'batch_priorities': batch_priorities} - # only update the priorities for data still in replay buffer - for i in range(len(batch_index_list)): - if metas['make_time'][i] > self.clear_time: - idx, prio = batch_index_list[i], metas['batch_priorities'][i] - self.game_pos_priorities[idx] = prio diff --git a/lzero/mcts/buffer/gobigger_game_buffer_efficientzero.py b/lzero/mcts/buffer/multi_agent_game_buffer_efficientzero.py similarity index 54% rename from lzero/mcts/buffer/gobigger_game_buffer_efficientzero.py rename to lzero/mcts/buffer/multi_agent_game_buffer_efficientzero.py index 4f74999aa..84382895b 100644 --- a/lzero/mcts/buffer/gobigger_game_buffer_efficientzero.py +++ b/lzero/mcts/buffer/multi_agent_game_buffer_efficientzero.py @@ -8,84 +8,18 @@ from lzero.mcts.tree_search.mcts_ptree import EfficientZeroMCTSPtree as MCTSPtree from lzero.mcts.utils import prepare_observation from lzero.policy import to_detach_cpu_numpy, concat_output, concat_output_value, inverse_scalar_transform -from .gobigger_game_buffer_muzero import GoBiggerMuZeroGameBuffer +from .multi_agent_game_buffer_muzero import MultiAgentMuZeroGameBuffer from ding.torch_utils import to_device, to_tensor, to_ndarray +from ding.utils.data import default_collate -@BUFFER_REGISTRY.register('gobigger_game_buffer_efficientzero') -class GoBiggerEfficientZeroGameBuffer(GoBiggerMuZeroGameBuffer): +@BUFFER_REGISTRY.register('multi_agent_game_buffer_efficientzero') +class MultiAgentSampledEfficientZeroGameBuffer(MultiAgentMuZeroGameBuffer): """ Overview: - The specific game buffer for GoBigger EfficientZero policy. + The specific game buffer for Multi Agent EfficientZero policy. """ - def __init__(self, cfg: dict): - super().__init__(cfg) - """ - Overview: - Use the default configuration mechanism. If a user passes in a cfg with a key that matches an existing key - in the default configuration, the user-provided value will override the default configuration. Otherwise, - the default configuration will be used. 
- """ - default_config = self.default_config() - default_config.update(cfg) - self._cfg = default_config - assert self._cfg.env_type in ['not_board_games', 'board_games'] - self.replay_buffer_size = self._cfg.replay_buffer_size - self.batch_size = self._cfg.batch_size - self._alpha = self._cfg.priority_prob_alpha - self._beta = self._cfg.priority_prob_beta - - self.game_segment_buffer = [] - self.game_pos_priorities = [] - self.game_segment_game_pos_look_up = [] - - self.keep_ratio = 1 - self.num_of_collected_episodes = 0 - self.base_idx = 0 - self.clear_time = 0 - - self.tmp_obs = None # for value obs list [46 + 4(td_step)] not < 50(game_segment) - - def sample(self, batch_size: int, policy: Any) -> List[Any]: - """ - Overview: - sample data from ``GameBuffer`` and prepare the current and target batch for training - Arguments: - - batch_size (:obj:`int`): batch size - - policy (:obj:`torch.tensor`): model of policy - Returns: - - train_data (:obj:`List`): List of train data - """ - policy._target_model.to(self._cfg.device) - policy._target_model.eval() - - # obtain the current_batch and prepare target context - reward_value_context, policy_re_context, policy_non_re_context, current_batch = self._make_batch( - batch_size, self._cfg.reanalyze_ratio - ) - - # target value_prefixs, target value - batch_value_prefixs, batch_target_values = self._compute_target_reward_value( - reward_value_context, policy._target_model - ) - # target policy - batch_target_policies_re = self._compute_target_policy_reanalyzed(policy_re_context, policy._target_model) - batch_target_policies_non_re = self._compute_target_policy_non_reanalyzed( - policy_non_re_context, self._cfg.model.action_space_size - ) - - if 0 < self._cfg.reanalyze_ratio < 1: - batch_target_policies = np.concatenate([batch_target_policies_re, batch_target_policies_non_re]) - elif self._cfg.reanalyze_ratio == 1: - batch_target_policies = batch_target_policies_re - elif self._cfg.reanalyze_ratio == 0: - batch_target_policies = batch_target_policies_non_re - - target_batch = [batch_value_prefixs, batch_target_values, batch_target_policies] - # a batch contains the current_batch and the target_batch - train_data = [current_batch, target_batch] - return train_data def _prepare_reward_value_context( self, batch_index_list: List[str], game_segment_list: List[Any], pos_in_game_segment_list: List[Any], @@ -154,6 +88,7 @@ def _prepare_reward_value_context( else: value_mask.append(0) obs = self.tmp_obs # will be masked + value_obs_list.append(obs.tolist()) reward_value_context = [ @@ -190,7 +125,7 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A # ============================================================== batch_target_values, batch_value_prefixs = [], [] with torch.no_grad(): - #value_obs_list = prepare_observation(value_obs_list, self._cfg.model.model_type) + value_obs_list = prepare_observation(value_obs_list, self._cfg.model.model_type) # split a full batch into slices of mini_infer_size: to save the GPU memory for more GPU actors slices = int(np.ceil(transition_batch_size / self._cfg.mini_infer_size)) network_output = [] @@ -203,6 +138,7 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A m_obs = to_tensor(m_obs) m_obs = sum(m_obs, []) m_obs = to_device(m_obs, self._cfg.device) + m_obs = default_collate(m_obs) # calculate the target value m_output = model.initial_inference(m_obs) @@ -318,129 +254,3 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A 
return batch_value_prefixs, batch_target_values - def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: Any) -> np.ndarray: - """ - Overview: - prepare policy targets from the reanalyzed context of policies - Arguments: - - policy_re_context (:obj:`List`): List of policy context to reanalyzed - Returns: - - batch_target_policies_re - """ - if policy_re_context is None: - return [] - batch_target_policies_re = [] - - policy_obs_list, policy_mask, pos_in_game_segment_list, batch_index_list, child_visits, game_segment_lens, action_mask_segment, \ - to_play_segment = policy_re_context # noqa - # transition_batch_size = game_segment_batch_size * (self._cfg.num_unroll_steps + 1) - transition_batch_size = len(policy_obs_list) - game_segment_batch_size = len(pos_in_game_segment_list) - to_play, action_mask = self._preprocess_to_play_and_action_mask( - game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list - ) - - legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)] - with torch.no_grad(): - policy_obs_list = prepare_observation(policy_obs_list, self._cfg.model.model_type) - # split a full batch into slices of mini_infer_size: to save the GPU memory for more GPU actors - slices = int(np.ceil(transition_batch_size / self._cfg.mini_infer_size)) - network_output = [] - for i in range(slices): - beg_index = self._cfg.mini_infer_size * i - end_index = self._cfg.mini_infer_size * (i + 1) - m_obs = torch.from_numpy(policy_obs_list[beg_index:end_index]).to(self._cfg.device).float() - - m_output = model.initial_inference(m_obs) - - if not model.training: - # if not in training, obtain the scalars of the value/reward - [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( - [ - m_output.latent_state, - inverse_scalar_transform(m_output.value, self._cfg.model.support_scale), - m_output.policy_logits - ] - ) - m_output.reward_hidden_state = ( - m_output.reward_hidden_state[0].detach().cpu().numpy(), - m_output.reward_hidden_state[1].detach().cpu().numpy() - ) - - network_output.append(m_output) - - _, value_prefix_pool, policy_logits_pool, latent_state_roots, reward_hidden_state_roots = concat_output( - network_output, data_type='efficientzero' - ) - value_prefix_pool = value_prefix_pool.squeeze().tolist() - policy_logits_pool = policy_logits_pool.tolist() - noises = [ - np.random.dirichlet([self._cfg.root_dirichlet_alpha] * self._cfg.model.action_space_size - ).astype(np.float32).tolist() for _ in range(transition_batch_size) - ] - if self._cfg.mcts_ctree: - # cpp mcts_tree - roots = MCTSCtree.roots(transition_batch_size, legal_actions) - roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, to_play) - # do MCTS for a new policy with the recent target model - MCTSCtree(self._cfg).search(roots, model, latent_state_roots, reward_hidden_state_roots, to_play) - else: - # python mcts_tree - roots = MCTSPtree.roots(transition_batch_size, legal_actions) - roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, to_play) - # do MCTS for a new policy with the recent target model - MCTSPtree(self._cfg).search( - roots, model, latent_state_roots, reward_hidden_state_roots, to_play=to_play - ) - - roots_legal_actions_list = legal_actions - roots_distributions = roots.get_distributions() - policy_index = 0 - for state_index, game_index in zip(pos_in_game_segment_list, batch_index_list): - target_policies = [] 
- for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1): - distributions = roots_distributions[policy_index] - if policy_mask[policy_index] == 0: - # NOTE: the invalid padding target policy, O is to make sure the correspoding cross_entropy_loss=0 - target_policies.append([0 for _ in range(self._cfg.model.action_space_size)]) - else: - if distributions is None: - # if at some obs, the legal_action is None, add the fake target_policy - target_policies.append( - list(np.ones(self._cfg.model.action_space_size) / self._cfg.model.action_space_size) - ) - else: - if self._cfg.mcts_ctree: - # cpp mcts_tree - if self._cfg.env_type == 'not_board_games': - sum_visits = sum(distributions) - policy = [visit_count / sum_visits for visit_count in distributions] - target_policies.append(policy) - else: - # for two_player board games - policy_tmp = [0 for _ in range(self._cfg.model.action_space_size)] - # to make sure target_policies have the same dimension - sum_visits = sum(distributions) - policy = [visit_count / sum_visits for visit_count in distributions] - for index, legal_action in enumerate(roots_legal_actions_list[policy_index]): - policy_tmp[legal_action] = policy[index] - target_policies.append(policy_tmp) - else: - # python mcts_tree - if self._cfg.env_type == 'not_board_games': - sum_visits = sum(distributions) - policy = [visit_count / sum_visits for visit_count in distributions] - target_policies.append(policy) - else: - # for two_player board games - policy_tmp = [0 for _ in range(self._cfg.model.action_space_size)] - # to make sure target_policies have the same dimension - sum_visits = sum(distributions) - policy = [visit_count / sum_visits for visit_count in distributions] - for index, legal_action in enumerate(roots_legal_actions_list[policy_index]): - policy_tmp[legal_action] = policy[index] - target_policies.append(policy_tmp) - policy_index += 1 - batch_target_policies_re.append(target_policies) - batch_target_policies_re = np.array(batch_target_policies_re) - return batch_target_policies_re diff --git a/lzero/mcts/buffer/multi_agent_game_buffer_muzero.py b/lzero/mcts/buffer/multi_agent_game_buffer_muzero.py new file mode 100644 index 000000000..eadb473d9 --- /dev/null +++ b/lzero/mcts/buffer/multi_agent_game_buffer_muzero.py @@ -0,0 +1,261 @@ +from typing import Any, List, Tuple, Union, TYPE_CHECKING, Optional + +import numpy as np +import torch +from ding.utils import BUFFER_REGISTRY + +from lzero.mcts.tree_search.mcts_ctree import MuZeroMCTSCtree as MCTSCtree +from lzero.mcts.tree_search.mcts_ptree import MuZeroMCTSPtree as MCTSPtree +from lzero.mcts.utils import prepare_observation +from lzero.policy import to_detach_cpu_numpy, concat_output, concat_output_value, inverse_scalar_transform +from .game_buffer import GameBuffer +from ding.torch_utils import to_device, to_tensor +from ding.utils.data import default_collate +from .game_buffer_muzero import MuZeroGameBuffer + +if TYPE_CHECKING: + from lzero.policy import MuZeroPolicy, EfficientZeroPolicy, SampledEfficientZeroPolicy + +@BUFFER_REGISTRY.register('multi_agent_game_buffer_muzero') +class MultiAgentMuZeroGameBuffer(MuZeroGameBuffer): + """ + Overview: + The specific game buffer for Multi Agent MuZero policy. 
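Compared with the single-agent buffers, the multi-agent variants flatten the per-env, per-agent observation dicts and collate them with ding's `default_collate` before calling `initial_inference` (see the `m_obs = default_collate(m_obs)` line added in the diff above). The sketch below only illustrates that collation step; the observation keys and shapes are placeholders, not the real GoBigger observation space, and the printed shapes assume `default_collate` stacks leaf arrays along a new batch dimension.

import numpy as np
from ding.utils.data import default_collate

# Two environments with two agents each: a nested list of per-agent observation dicts.
obs_batch = [
    [{'scalar': np.zeros(5, dtype=np.float32), 'units': np.zeros((4, 3), dtype=np.float32)}
     for _ in range(2)]
    for _ in range(2)
]

# Flatten the env/agent nesting into one list, as the buffer does with sum(m_obs, []).
flat_obs = sum(obs_batch, [])

# default_collate turns the list of dicts into a dict of batched tensors that
# the model's initial_inference can consume.
collated = default_collate(flat_obs)
print(collated['scalar'].shape, collated['units'].shape)  # expected: [4, 5] and [4, 4, 3]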
+        """
+
+    def _prepare_policy_non_reanalyzed_context(
+            self, batch_index_list: List[int], game_segment_list: List[Any], pos_in_game_segment_list: List[int]
+    ) -> List[Any]:
+        """
+        Overview:
+            Prepare the policy context for computing the policy target in the non-reanalyzed part, i.e. directly
+            reuse the policy recorded during self-play.
+        Arguments:
+            - batch_index_list (:obj:`list`): the indices of the start transitions of the sampled minibatch in the replay buffer
+            - game_segment_list (:obj:`list`): list of game segments
+            - pos_in_game_segment_list (:obj:`list`): list of transition indices in the game segment
+        Returns:
+            - policy_non_re_context (:obj:`list`): pos_in_game_segment_list, child_visits, game_segment_lens, action_mask_segment, to_play_segment
+        """
+        child_visits = []
+        game_segment_lens = []
+        # for board games
+        action_mask_segment, to_play_segment = [], []
+
+        for game_segment, state_index, idx in zip(game_segment_list, pos_in_game_segment_list, batch_index_list):
+            game_segment_len = len(game_segment)
+            game_segment_lens.append(game_segment_len)
+            # for board games
+            action_mask_segment.append(game_segment.action_mask_segment)
+            to_play_segment.append(game_segment.to_play_segment)
+
+            child_visits.append(game_segment.child_visit_segment)
+
+        policy_non_re_context = [
+            pos_in_game_segment_list, child_visits, game_segment_lens, action_mask_segment, to_play_segment
+        ]
+        return policy_non_re_context
+
+    def _prepare_policy_reanalyzed_context(
+            self, batch_index_list: List[str], game_segment_list: List[Any], pos_in_game_segment_list: List[str]
+    ) -> List[Any]:
+        """
+        Overview:
+            Prepare the policy context for computing the policy target in the reanalyzed part.
+        Arguments:
+            - batch_index_list (:obj:`list`): start transition indices in the replay buffer
+            - game_segment_list (:obj:`list`): list of game segments
+            - pos_in_game_segment_list (:obj:`list`): positions of the transition indices in one game segment
+        Returns:
+            - policy_re_context (:obj:`list`): policy_obs_list, policy_mask, pos_in_game_segment_list, batch_index_list,
+              child_visits, game_segment_lens, action_mask_segment, to_play_segment
+        """
+        zero_obs = game_segment_list[0].zero_obs()
+        with torch.no_grad():
+            # for policy
+            policy_obs_list = []
+            policy_mask = []
+            # 0 -> Invalid target policy for padding outside of game segments,
+            # 1 -> Previous target policy for game segments.
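Positions flagged 0 by policy_mask later receive an all-zero target policy, which makes their cross-entropy term vanish, as the removed reanalysis code above also notes. A tiny demonstration with made-up numbers:

import torch

pred_logits = torch.randn(4)                        # arbitrary policy logits
valid_target = torch.tensor([0.6, 0.2, 0.2, 0.0])   # normalized root visit counts
padded_target = torch.zeros(4)                      # padding outside the game segment

def cross_entropy(target, logits):
    return -(target * torch.log_softmax(logits, dim=-1)).sum()

cross_entropy(valid_target, pred_logits)    # ordinary policy loss term
cross_entropy(padded_target, pred_logits)   # exactly 0, padded steps contribute nothing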
+ rewards, child_visits, game_segment_lens = [], [], [] + # for board games + action_mask_segment, to_play_segment = [], [] + for game_segment, state_index in zip(game_segment_list, pos_in_game_segment_list): + game_segment_len = len(game_segment) + game_segment_lens.append(game_segment_len) + rewards.append(game_segment.reward_segment) + # for board games + action_mask_segment.append(game_segment.action_mask_segment) + to_play_segment.append(game_segment.to_play_segment) + + child_visits.append(game_segment.child_visit_segment) + # prepare the corresponding observations + game_obs = game_segment.get_unroll_obs(state_index, self._cfg.num_unroll_steps) + for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1): + + if current_index < game_segment_len: + policy_mask.append(1) + beg_index = current_index - state_index + end_index = beg_index + self._cfg.model.frame_stack_num + obs = game_obs[beg_index:end_index] + else: + policy_mask.append(0) + obs = zero_obs + policy_obs_list.append(obs) + + policy_re_context = [ + policy_obs_list, policy_mask, pos_in_game_segment_list, batch_index_list, child_visits, game_segment_lens, + action_mask_segment, to_play_segment + ] + return policy_re_context + + def _compute_target_reward_value(self, reward_value_context: List[Any], model: Any) -> Tuple[Any, Any]: + """ + Overview: + prepare reward and value targets from the context of rewards and values. + Arguments: + - reward_value_context (:obj:'list'): the reward value context + - model (:obj:'torch.tensor'):model of the target model + Returns: + - batch_value_prefixs (:obj:'np.ndarray): batch of value prefix + - batch_target_values (:obj:'np.ndarray): batch of value estimation + """ + value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, td_steps_list, action_mask_segment, \ + to_play_segment = reward_value_context # noqa + # transition_batch_size = game_segment_batch_size * (num_unroll_steps+1) + transition_batch_size = len(value_obs_list) + game_segment_batch_size = len(pos_in_game_segment_list) + + to_play, action_mask = self._preprocess_to_play_and_action_mask( + game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list + ) + if self._cfg.model.continuous_action_space is True: + # when the action space of the environment is continuous, action_mask[:] is None. 
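Before the masking details that follow, it may help to see the quantity this method ultimately builds: the standard n-step bootstrapped value target. A compact single-position reference (illustrative; the buffer code below vectorizes this and additionally handles the board-game sign flips):

def n_step_value_target(rewards, bootstrap_value, td_steps, discount):
    # rewards: [r_t, r_{t+1}, ...]; bootstrap_value: the estimated value at step t + td_steps
    target = bootstrap_value * discount ** td_steps
    for i, r in enumerate(rewards[:td_steps]):
        target += r * discount ** i
    return target

# n_step_value_target([1.0, 0.0, 2.0], bootstrap_value=0.5, td_steps=3, discount=0.99)
# == 1.0 + 0.0 * 0.99 + 2.0 * 0.99 ** 2 + 0.5 * 0.99 ** 3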
+ action_mask = [ + list(np.ones(self._cfg.model.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size) + ] + # NOTE: in continuous action space env: we set all legal_actions as -1 + legal_actions = [ + [-1 for _ in range(self._cfg.model.action_space_size)] for _ in range(transition_batch_size) + ] + else: + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)] + + batch_target_values, batch_rewards = [], [] + with torch.no_grad(): + value_obs_list = prepare_observation(value_obs_list, self._cfg.model.model_type) + # split a full batch into slices of mini_infer_size: to save the GPU memory for more GPU actors + slices = int(np.ceil(transition_batch_size / self._cfg.mini_infer_size)) + network_output = [] + for i in range(slices): + beg_index = self._cfg.mini_infer_size * i + end_index = self._cfg.mini_infer_size * (i + 1) + + # m_obs = torch.from_numpy(value_obs_list[beg_index:end_index]).to(self._cfg.device).float() + m_obs = value_obs_list[beg_index:end_index] + m_obs = to_tensor(m_obs) + m_obs = sum(m_obs, []) + m_obs = to_device(m_obs, self._cfg.device) + m_obs = default_collate(m_obs) + + # calculate the target value + m_output = model.initial_inference(m_obs) + + if not model.training: + # if not in training, obtain the scalars of the value/reward + [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( + [ + m_output.latent_state, + inverse_scalar_transform(m_output.value, self._cfg.model.support_scale), + m_output.policy_logits + ] + ) + + network_output.append(m_output) + + # concat the output slices after model inference + if self._cfg.use_root_value: + # use the root values from MCTS, as in EfficiientZero + # the root values have limited improvement but require much more GPU actors; + _, reward_pool, policy_logits_pool, latent_state_roots = concat_output( + network_output, data_type='muzero' + ) + reward_pool = reward_pool.squeeze().tolist() + policy_logits_pool = policy_logits_pool.tolist() + noises = [ + np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j])) + ).astype(np.float32).tolist() for j in range(transition_batch_size) + ] + if self._cfg.mcts_ctree: + # cpp mcts_tree + roots = MCTSCtree.roots(transition_batch_size, legal_actions) + roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play) + # do MCTS for a new policy with the recent target model + MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play) + else: + # python mcts_tree + roots = MCTSPtree.roots(transition_batch_size, legal_actions) + roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play) + # do MCTS for a new policy with the recent target model + MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play) + + roots_values = roots.get_values() + value_list = np.array(roots_values) + else: + # use the predicted values + value_list = concat_output_value(network_output) + + # get last state value + if self._cfg.env_type == 'board_games' and to_play_segment[0][0] in [1, 2]: + # TODO(pu): for board_games, very important, to check + value_list = value_list.reshape(-1) * np.array( + [ + self._cfg.discount_factor ** td_steps_list[i] if int(td_steps_list[i]) % + 2 == 0 else -self._cfg.discount_factor ** td_steps_list[i] + for i in range(transition_batch_size) + ] + ) + else: + value_list = value_list.reshape(-1) * ( + np.array([self._cfg.discount_factor for _ in range(transition_batch_size)]) 
** td_steps_list + ) + + value_list = value_list * np.array(value_mask) + value_list = value_list.tolist() + horizon_id, value_index = 0, 0 + + for game_segment_len_non_re, reward_list, state_index, to_play_list in zip(game_segment_lens, rewards_list, + pos_in_game_segment_list, + to_play_segment): + target_values = [] + target_rewards = [] + base_index = state_index + for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1): + bootstrap_index = current_index + td_steps_list[value_index] + # for i, reward in enumerate(game.rewards[current_index:bootstrap_index]): + for i, reward in enumerate(reward_list[current_index:bootstrap_index]): + if self._cfg.env_type == 'board_games' and to_play_segment[0][0] in [1, 2]: + # TODO(pu): for board_games, very important, to check + if to_play_list[base_index] == to_play_list[i]: + value_list[value_index] += reward * self._cfg.discount_factor ** i + else: + value_list[value_index] += -reward * self._cfg.discount_factor ** i + else: + value_list[value_index] += reward * self._cfg.discount_factor ** i + horizon_id += 1 + + if current_index < game_segment_len_non_re: + target_values.append(value_list[value_index]) + target_rewards.append(reward_list[current_index]) + else: + target_values.append(0) + target_rewards.append(0.0) + # TODO: check + # target_rewards.append(reward) + value_index += 1 + + batch_rewards.append(target_rewards) + batch_target_values.append(target_values) + + batch_rewards = np.asarray(batch_rewards, dtype=object) + batch_target_values = np.asarray(batch_target_values, dtype=object) + return batch_rewards, batch_target_values diff --git a/lzero/mcts/utils.py b/lzero/mcts/utils.py index 6f5737acb..57915ee8c 100644 --- a/lzero/mcts/utils.py +++ b/lzero/mcts/utils.py @@ -63,7 +63,7 @@ def prepare_observation(observation_list, model_type='conv'): - observation_list (:obj:`List`): list of observations. - model_type (:obj:`str`): type of the model. 
(default is 'conv') """ - assert model_type in ['conv', 'mlp'] + assert model_type in ['conv', 'mlp', 'structure'] observation_array = np.array(observation_list) if model_type == 'conv': @@ -98,6 +98,9 @@ def prepare_observation(observation_list, model_type='conv'): observation_array = observation_array.reshape(observation_array.shape[0], -1) # print(observation_array.shape) + elif model_type == 'structure': + return observation_list + return observation_array diff --git a/lzero/model/efficientzero_model_structure.py b/lzero/model/efficientzero_model_structure.py new file mode 100644 index 000000000..93afeab2b --- /dev/null +++ b/lzero/model/efficientzero_model_structure.py @@ -0,0 +1,187 @@ +from typing import Optional, Tuple + +import torch +import torch.nn as nn +from ding.torch_utils import MLP, ResBlock +from ding.utils import MODEL_REGISTRY, SequenceType +from numpy import ndarray + +from .common import EZNetworkOutput, RepresentationNetwork, PredictionNetwork +from .utils import renormalize, get_params_mean, get_dynamic_mean, get_reward_mean +from lzero.model.efficientzero_model_mlp import EfficientZeroModelMLP, DynamicsNetworkMLP, PredictionNetworkMLP + +@MODEL_REGISTRY.register('EfficientZeroModelStructure') +class EfficientZeroModelStructure(EfficientZeroModelMLP): + def __init__( + self, + env_name: str, + action_space_size: int = 6, + lstm_hidden_size: int = 512, + latent_state_dim: int = 256, + fc_reward_layers: SequenceType = [32], + fc_value_layers: SequenceType = [32], + fc_policy_layers: SequenceType = [32], + reward_support_size: int = 601, + value_support_size: int = 601, + proj_hid: int = 1024, + proj_out: int = 1024, + pred_hid: int = 512, + pred_out: int = 1024, + self_supervised_learning_loss: bool = True, + categorical_distribution: bool = True, + last_linear_layer_init_zero: bool = True, + state_norm: bool = False, + activation: Optional[nn.Module] = nn.ReLU(inplace=True), + norm_type: Optional[str] = 'BN', + discrete_action_encoding_type: str = 'one_hot', + res_connection_in_dynamics: bool = False, + *args, + **kwargs, + ): + """ + Overview: + The definition of the network model of EfficientZero, which is a generalization version for 1D vector obs. + The networks are mainly built on fully connected layers. + Sampled EfficientZero model consists of a representation network, a dynamics network and a prediction network. + The representation network is an MLP network which maps the raw observation to a latent state. + The dynamics network is an MLP+LSTM network which predicts the next latent state, reward_hidden_state and value_prefix given the current latent state and action. + The prediction network is an MLP network which predicts the value and policy given the current latent state. + Arguments: + - env_name (:obj:`str`): Env name, e.g. ptz_simple_spread, gobigger etc. + - action_space_size: (:obj:`int`): Action space size, e.g. 4 for Lunarlander. + - lstm_hidden_size (:obj:`int`): The hidden size of LSTM in dynamics network to predict value_prefix. + - latent_state_dim (:obj:`int`): The dimension of latent state, such as 256. + - fc_reward_layers (:obj:`SequenceType`): The number of hidden layers of the reward head (MLP head). + - fc_value_layers (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head). + - fc_policy_layers (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head). 
+ - reward_support_size (:obj:`int`): The size of categorical reward output + - value_support_size (:obj:`int`): The size of categorical value output. + - proj_hid (:obj:`int`): The size of projection hidden layer. + - proj_out (:obj:`int`): The size of projection output layer. + - pred_hid (:obj:`int`): The size of prediction hidden layer. + - pred_out (:obj:`int`): The size of prediction output layer. + - self_supervised_learning_loss (:obj:`bool`): Whether to use self_supervised_learning related networks in Sampled EfficientZero model, default set it to False. + - categorical_distribution (:obj:`bool`): Whether to use discrete support to represent categorical distribution for value, reward/value_prefix. + - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of value/policy mlp, default sets it to True. + - state_norm (:obj:`bool`): Whether to use normalization for latent states, default sets it to True. + - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ + operation to speedup, e.g. ReLU(inplace=True). + - discrete_action_encoding_type (:obj:`str`): The type of encoding for discrete action. Default sets it to 'one_hot'. options = {'one_hot', 'not_one_hot'} + - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. + - res_connection_in_dynamics (:obj:`bool`): Whether to use residual connection for dynamics network, default set it to False. + """ + super(EfficientZeroModelStructure, self).__init__() + if not categorical_distribution: + self.reward_support_size = 1 + self.value_support_size = 1 + else: + self.reward_support_size = reward_support_size + self.value_support_size = value_support_size + + self.action_space_size = action_space_size + self.continuous_action_space = False + # The dim of action space. For discrete action space, it is 1. + # For continuous action space, it is the dimension of continuous action. 
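Given the constructor above, a hypothetical instantiation for the GoBigger setting might look like the following; the argument values are illustrative only and do not come from the config in this patch.

from lzero.model.efficientzero_model_structure import EfficientZeroModelStructure

model = EfficientZeroModelStructure(
    env_name='gobigger',                  # selects GoBiggerEncoder as the representation network
    action_space_size=16,                 # illustrative value, not the config used in this patch
    lstm_hidden_size=512,
    latent_state_dim=256,
    self_supervised_learning_loss=True,
)
# initial_inference expects a structured (dict) observation batch containing 'action_mask'.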
+ self.action_space_dim = action_space_size if self.continuous_action_space else 1 + assert discrete_action_encoding_type in ['one_hot', 'not_one_hot'], discrete_action_encoding_type + self.discrete_action_encoding_type = discrete_action_encoding_type + if self.continuous_action_space: + self.action_encoding_dim = action_space_size + else: + if self.discrete_action_encoding_type == 'one_hot': + self.action_encoding_dim = action_space_size + elif self.discrete_action_encoding_type == 'not_one_hot': + self.action_encoding_dim = 1 + + self.lstm_hidden_size = lstm_hidden_size + self.proj_hid = proj_hid + self.proj_out = proj_out + self.pred_hid = pred_hid + self.pred_out = pred_out + self.self_supervised_learning_loss = self_supervised_learning_loss + self.last_linear_layer_init_zero = last_linear_layer_init_zero + self.state_norm = state_norm + self.res_connection_in_dynamics = res_connection_in_dynamics + + if env_name == 'gobigger': + from lzero.model.gobigger.gobigger_encoder import GoBiggerEncoder as Encoder + elif env_name == 'ptz_simple_spread': + from lzero.model.petting_zoo.encoder import PettingZooEncoder as Encoder + else: + raise NotImplementedError + self.representation_network = Encoder() + + self.dynamics_network = DynamicsNetworkMLP( + action_encoding_dim=self.action_encoding_dim, + num_channels=latent_state_dim + self.action_encoding_dim, + common_layer_num=2, + lstm_hidden_size=lstm_hidden_size, + fc_reward_layers=fc_reward_layers, + output_support_size=self.reward_support_size, + last_linear_layer_init_zero=self.last_linear_layer_init_zero, + norm_type=norm_type, + res_connection_in_dynamics=self.res_connection_in_dynamics, + ) + + self.prediction_network = PredictionNetworkMLP( + action_space_size=action_space_size, + num_channels=latent_state_dim, + fc_value_layers=fc_value_layers, + fc_policy_layers=fc_policy_layers, + output_support_size=self.value_support_size, + last_linear_layer_init_zero=self.last_linear_layer_init_zero, + norm_type=norm_type + ) + + if self.self_supervised_learning_loss: + # self_supervised_learning_loss related network proposed in EfficientZero + self.projection_input_dim = latent_state_dim + + self.projection = nn.Sequential( + nn.Linear(self.projection_input_dim, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, + nn.Linear(self.proj_hid, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, + nn.Linear(self.proj_hid, self.proj_out), nn.BatchNorm1d(self.proj_out) + ) + self.prediction_head = nn.Sequential( + nn.Linear(self.proj_out, self.pred_hid), + nn.BatchNorm1d(self.pred_hid), + activation, + nn.Linear(self.pred_hid, self.pred_out), + ) + + def initial_inference(self, obs: torch.Tensor) -> EZNetworkOutput: + """ + Overview: + Initial inference of EfficientZero model, which is the first step of the EfficientZero model. + To perform the initial inference, we first use the representation network to obtain the "latent_state" of the observation. + Then we use the prediction network to predict the "value" and "policy_logits" of the "latent_state", and + also prepare the zeros-like ``reward_hidden_state`` for the next step of the EfficientZero model. + Arguments: + - obs (:obj:`torch.Tensor`): The 1D vector observation data. + Returns (EZNetworkOutput): + - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. + - value_prefix (:obj:`torch.Tensor`): The predicted prefix sum of value for input state. \ + In initial inference, we set it to zero vector. 
+ - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The hidden state of LSTM about reward. In initial inference, \ + we set it to the zeros-like hidden state (H and C). + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, obs_shape)`, where B is batch_size. + - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. + - value_prefix (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. + - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. + - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The shape of each element is :math:`(1, B, lstm_hidden_size)`, where B is batch_size. + """ + batch_size = obs['action_mask'].shape[0] + latent_state = self._representation(obs) + device = latent_state.device + policy_logits, value = self._prediction(latent_state) + # zero initialization for reward hidden states + # (hn, cn), each element shape is (layer_num=1, batch_size, lstm_hidden_size) + reward_hidden_state = ( + torch.zeros(1, batch_size, self.lstm_hidden_size).to(device), + torch.zeros(1, batch_size, self.lstm_hidden_size).to(device), + ) + return EZNetworkOutput(value, [0. for _ in range(batch_size)], policy_logits, latent_state, reward_hidden_state) diff --git a/lzero/model/gobigger/network/__init__.py b/lzero/model/gobigger/__init__.py similarity index 100% rename from lzero/model/gobigger/network/__init__.py rename to lzero/model/gobigger/__init__.py diff --git a/lzero/model/gobigger/network/encoder.py b/lzero/model/gobigger/encoder.py similarity index 100% rename from lzero/model/gobigger/network/encoder.py rename to lzero/model/gobigger/encoder.py diff --git a/lzero/model/gobigger/gobigger_efficientzero_model.py b/lzero/model/gobigger/gobigger_efficientzero_model.py deleted file mode 100644 index 11b4125cc..000000000 --- a/lzero/model/gobigger/gobigger_efficientzero_model.py +++ /dev/null @@ -1,477 +0,0 @@ -from typing import Optional, Tuple - -import torch -import torch.nn as nn -from ding.torch_utils import MLP -from ding.utils import MODEL_REGISTRY, SequenceType -from numpy import ndarray - -from ..common import EZNetworkOutput, RepresentationNetworkMLP, PredictionNetworkMLP -from ..utils import renormalize, get_params_mean, get_dynamic_mean, get_reward_mean -from .network.gobigger_encoder import GoBiggerEncoder -import yaml -from easydict import EasyDict -from ding.utils.data import default_collate - - -@MODEL_REGISTRY.register('GoBiggerEfficientZeroModel') -class GoBiggerEfficientZeroModel(nn.Module): - - def __init__( - self, - observation_shape: int = 2, - action_space_size: int = 6, - lstm_hidden_size: int = 512, - latent_state_dim: int = 256, - fc_reward_layers: SequenceType = [32], - fc_value_layers: SequenceType = [32], - fc_policy_layers: SequenceType = [32], - reward_support_size: int = 601, - value_support_size: int = 601, - proj_hid: int = 1024, - proj_out: int = 1024, - pred_hid: int = 512, - pred_out: int = 1024, - self_supervised_learning_loss: bool = True, - categorical_distribution: bool = True, - last_linear_layer_init_zero: bool = True, - state_norm: bool = False, - activation: Optional[nn.Module] = nn.ReLU(inplace=True), - norm_type: Optional[str] = 'BN', - discrete_action_encoding_type: str = 
'one_hot', - res_connection_in_dynamics: bool = False, - *args, - **kwargs, - ): - """ - Overview: - The definition of the network model of EfficientZero, which is a generalization version for 1D vector obs. - The networks are mainly built on fully connected layers. - Sampled EfficientZero model consists of a representation network, a dynamics network and a prediction network. - The representation network is an MLP network which maps the raw observation to a latent state. - The dynamics network is an MLP+LSTM network which predicts the next latent state, reward_hidden_state and value_prefix given the current latent state and action. - The prediction network is an MLP network which predicts the value and policy given the current latent state. - Arguments: - - observation_shape (:obj:`int`): Observation space shape, e.g. 8 for Lunarlander. - - action_space_size: (:obj:`int`): Action space size, e.g. 4 for Lunarlander. - - lstm_hidden_size (:obj:`int`): The hidden size of LSTM in dynamics network to predict value_prefix. - - latent_state_dim (:obj:`int`): The dimension of latent state, such as 256. - - fc_reward_layers (:obj:`SequenceType`): The number of hidden layers of the reward head (MLP head). - - fc_value_layers (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head). - - fc_policy_layers (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head). - - reward_support_size (:obj:`int`): The size of categorical reward output - - value_support_size (:obj:`int`): The size of categorical value output. - - proj_hid (:obj:`int`): The size of projection hidden layer. - - proj_out (:obj:`int`): The size of projection output layer. - - pred_hid (:obj:`int`): The size of prediction hidden layer. - - pred_out (:obj:`int`): The size of prediction output layer. - - self_supervised_learning_loss (:obj:`bool`): Whether to use self_supervised_learning related networks in Sampled EfficientZero model, default set it to False. - - categorical_distribution (:obj:`bool`): Whether to use discrete support to represent categorical distribution for value, reward/value_prefix. - - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of value/policy mlp, default sets it to True. - - state_norm (:obj:`bool`): Whether to use normalization for latent states, default sets it to True. - - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ - operation to speedup, e.g. ReLU(inplace=True). - - discrete_action_encoding_type (:obj:`str`): The type of encoding for discrete action. Default sets it to 'one_hot'. options = {'one_hot', 'not_one_hot'} - - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. - - res_connection_in_dynamics (:obj:`bool`): Whether to use residual connection for dynamics network, default set it to False. - """ - super(GoBiggerEfficientZeroModel, self).__init__() - if not categorical_distribution: - self.reward_support_size = 1 - self.value_support_size = 1 - else: - self.reward_support_size = reward_support_size - self.value_support_size = value_support_size - - self.action_space_size = action_space_size - self.continuous_action_space = False - # The dim of action space. For discrete action space, it is 1. - # For continuous action space, it is the dimension of continuous action. 
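The 601-sized supports configured above correspond to a categorical value/reward representation over integer atoms (symmetric around zero, i.e. the support_scale passed to inverse_scalar_transform elsewhere in this series). A simplified sketch of projecting a scalar onto such a support, omitting the invertible value transform applied in the real pipeline:

import math
import torch

def scalar_to_support_sketch(x: float, support_size: int = 300) -> torch.Tensor:
    # Spread a scalar over the support {-support_size, ..., +support_size} by splitting
    # its probability mass between the two neighbouring integer atoms.
    x = max(-float(support_size), min(float(support_size), x))
    target = torch.zeros(2 * support_size + 1)
    low = math.floor(x)
    frac = x - low
    target[low + support_size] += 1 - frac
    if frac > 0:
        target[low + support_size + 1] += frac
    return target

# scalar_to_support_sketch(2.3) puts 0.7 of the mass on atom +2 and 0.3 on atom +3.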
- self.action_space_dim = action_space_size if self.continuous_action_space else 1 - assert discrete_action_encoding_type in ['one_hot', 'not_one_hot'], discrete_action_encoding_type - self.discrete_action_encoding_type = discrete_action_encoding_type - if self.continuous_action_space: - self.action_encoding_dim = action_space_size - else: - if self.discrete_action_encoding_type == 'one_hot': - self.action_encoding_dim = action_space_size - elif self.discrete_action_encoding_type == 'not_one_hot': - self.action_encoding_dim = 1 - - self.lstm_hidden_size = lstm_hidden_size - self.proj_hid = proj_hid - self.proj_out = proj_out - self.pred_hid = pred_hid - self.pred_out = pred_out - self.self_supervised_learning_loss = self_supervised_learning_loss - self.last_linear_layer_init_zero = last_linear_layer_init_zero - self.state_norm = state_norm - self.res_connection_in_dynamics = res_connection_in_dynamics - - self.representation_network = GoBiggerEncoder() - - self.dynamics_network = DynamicsNetworkMLP( - action_encoding_dim=self.action_encoding_dim, - num_channels=latent_state_dim + self.action_encoding_dim, - common_layer_num=2, - lstm_hidden_size=lstm_hidden_size, - fc_reward_layers=fc_reward_layers, - output_support_size=self.reward_support_size, - last_linear_layer_init_zero=self.last_linear_layer_init_zero, - norm_type=norm_type, - res_connection_in_dynamics=self.res_connection_in_dynamics, - ) - - self.prediction_network = PredictionNetworkMLP( - action_space_size=action_space_size, - num_channels=latent_state_dim, - fc_value_layers=fc_value_layers, - fc_policy_layers=fc_policy_layers, - output_support_size=self.value_support_size, - last_linear_layer_init_zero=self.last_linear_layer_init_zero, - norm_type=norm_type - ) - - if self.self_supervised_learning_loss: - # self_supervised_learning_loss related network proposed in EfficientZero - self.projection_input_dim = latent_state_dim - - self.projection = nn.Sequential( - nn.Linear(self.projection_input_dim, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, - nn.Linear(self.proj_hid, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, - nn.Linear(self.proj_hid, self.proj_out), nn.BatchNorm1d(self.proj_out) - ) - self.prediction_head = nn.Sequential( - nn.Linear(self.proj_out, self.pred_hid), - nn.BatchNorm1d(self.pred_hid), - activation, - nn.Linear(self.pred_hid, self.pred_out), - ) - - def initial_inference(self, obs: torch.Tensor) -> EZNetworkOutput: - """ - Overview: - Initial inference of EfficientZero model, which is the first step of the EfficientZero model. - To perform the initial inference, we first use the representation network to obtain the "latent_state" of the observation. - Then we use the prediction network to predict the "value" and "policy_logits" of the "latent_state", and - also prepare the zeros-like ``reward_hidden_state`` for the next step of the EfficientZero model. - Arguments: - - obs (:obj:`torch.Tensor`): The 1D vector observation data. - Returns (EZNetworkOutput): - - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. - - value_prefix (:obj:`torch.Tensor`): The predicted prefix sum of value for input state. \ - In initial inference, we set it to zero vector. - - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. - - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. - - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The hidden state of LSTM about reward. 
In initial inference, \ - we set it to the zeros-like hidden state (H and C). - Shapes: - - obs (:obj:`torch.Tensor`): :math:`(B, obs_shape)`, where B is batch_size. - - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. - - value_prefix (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. - - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. - - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. - - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The shape of each element is :math:`(1, B, lstm_hidden_size)`, where B is batch_size. - """ - batch_size = len(obs) - obs = default_collate(obs) - latent_state = self._representation(obs) - device = latent_state.device - policy_logits, value = self._prediction(latent_state) - # zero initialization for reward hidden states - # (hn, cn), each element shape is (layer_num=1, batch_size, lstm_hidden_size) - reward_hidden_state = ( - torch.zeros(1, batch_size, self.lstm_hidden_size).to(device), - torch.zeros(1, batch_size, self.lstm_hidden_size).to(device), - ) - return EZNetworkOutput(value, [0. for _ in range(batch_size)], policy_logits, latent_state, reward_hidden_state) - - def recurrent_inference( - self, latent_state: torch.Tensor, reward_hidden_state: torch.Tensor, action: torch.Tensor - ) -> EZNetworkOutput: - """ - Overview: - Recurrent inference of EfficientZero model, which is the rollout step of the EfficientZero model. - To perform the recurrent inference, we first use the dynamics network to predict ``next_latent_state``, - ``reward_hidden_state``, ``value_prefix`` by the given current ``latent_state`` and ``action``. - We then use the prediction network to predict the ``value`` and ``policy_logits``. - Arguments: - - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. - - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The input hidden state of LSTM about reward. - - action (:obj:`torch.Tensor`): The predicted action to rollout. - Returns (EZNetworkOutput): - - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. - - value_prefix (:obj:`torch.Tensor`): The predicted prefix sum of value for input state. - - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. - - next_latent_state (:obj:`torch.Tensor`): The predicted next latent state. - - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The output hidden state of LSTM about reward. - Shapes: - - action (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch_size. - - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. - - value_prefix (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. - - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. - - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. - - next_latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. - - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The shape of each element is :math:`(1, B, lstm_hidden_size)`, where B is batch_size. 
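Both the removed model here and its structure-based replacement batch raw dict observations with default_collate rather than stacking arrays; a toy example of what that collation produces (the fields are invented for illustration, not GoBigger's real observation schema):

import torch
from ding.utils.data import default_collate

obs_list = [
    {'action_mask': torch.ones(16), 'scalar': torch.randn(5)},
    {'action_mask': torch.ones(16), 'scalar': torch.randn(5)},
]
batch = default_collate(obs_list)
# batch['action_mask'].shape == (2, 16), batch['scalar'].shape == (2, 5)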
- """ - next_latent_state, reward_hidden_state, value_prefix = self._dynamics(latent_state, reward_hidden_state, action) - policy_logits, value = self._prediction(next_latent_state) - return EZNetworkOutput(value, value_prefix, policy_logits, next_latent_state, reward_hidden_state) - - def _representation(self, observation: torch.Tensor) -> Tuple[torch.Tensor]: - """ - Overview: - Use the representation network to encode the observations into latent state. - Arguments: - - obs (:obj:`torch.Tensor`): The 1D vector observation data. - Returns: - - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. - Shapes: - - obs (:obj:`torch.Tensor`): :math:`(B, obs_shape)`, where B is batch_size. - - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. - """ - latent_state = self.representation_network(observation) - if self.state_norm: - latent_state = renormalize(latent_state) - return latent_state - - def _prediction(self, latent_state: torch.Tensor) -> Tuple[torch.Tensor]: - """ - Overview: - Use the representation network to encode the observations into latent state. - Arguments: - - obs (:obj:`torch.Tensor`): The 1D vector observation data. - Returns: - - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. - - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. - Shapes: - - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. - - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. - - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. - """ - policy_logits, value = self.prediction_network(latent_state) - return policy_logits, value - - def _dynamics(self, latent_state: torch.Tensor, reward_hidden_state: Tuple, - action: torch.Tensor) -> Tuple[torch.Tensor, Tuple[torch.Tensor], torch.Tensor]: - """ - Overview: - Concatenate ``latent_state`` and ``action`` and use the dynamics network to predict ``next_latent_state`` - ``value_prefix`` and ``next_reward_hidden_state``. - Arguments: - - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. - - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The input hidden state of LSTM about reward. - - action (:obj:`torch.Tensor`): The predicted action to rollout. - Returns: - - next_latent_state (:obj:`torch.Tensor`): The predicted latent state of the next timestep. - - next_reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The output hidden state of LSTM about reward. - - value_prefix (:obj:`torch.Tensor`): The predicted prefix sum of value for input state. - Shapes: - - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. - - action (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch_size. - - next_latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. - - value_prefix (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. 
- """ - # NOTE: the discrete action encoding type is important for some environments - - # discrete action space - if self.discrete_action_encoding_type == 'one_hot': - # Stack latent_state with the one hot encoded action - if len(action.shape) == 1: - # (batch_size, ) -> (batch_size, 1) - # e.g., torch.Size([8]) -> torch.Size([8, 1]) - action = action.unsqueeze(-1) - - # transform action to one-hot encoding. - # action_one_hot shape: (batch_size, action_space_size), e.g., (8, 4) - action_one_hot = torch.zeros(action.shape[0], self.action_space_size, device=action.device) - # transform action to torch.int64 - action = action.long() - action_one_hot.scatter_(1, action, 1) - action_encoding = action_one_hot - elif self.discrete_action_encoding_type == 'not_one_hot': - action_encoding = action / self.action_space_size - if len(action_encoding.shape) == 1: - # (batch_size, ) -> (batch_size, 1) - # e.g., torch.Size([8]) -> torch.Size([8, 1]) - action_encoding = action_encoding.unsqueeze(-1) - - action_encoding = action_encoding.to(latent_state.device).float() - # state_action_encoding shape: (batch_size, latent_state[1] + action_dim]) or - # (batch_size, latent_state[1] + action_space_size]) depending on the discrete_action_encoding_type. - state_action_encoding = torch.cat((latent_state, action_encoding), dim=1) - - # NOTE: the key difference with MuZero - next_latent_state, next_reward_hidden_state, value_prefix = self.dynamics_network( - state_action_encoding, reward_hidden_state - ) - - if self.state_norm: - next_latent_state = renormalize(next_latent_state) - return next_latent_state, next_reward_hidden_state, value_prefix - - def project(self, latent_state: torch.Tensor, with_grad=True): - """ - Overview: - Project the latent state to a lower dimension to calculate the self-supervised loss, which is proposed in EfficientZero. - For more details, please refer to the paper ``Exploring Simple Siamese Representation Learning``. - Arguments: - - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. - - with_grad (:obj:`bool`): Whether to calculate gradient for the projection result. - Returns: - - proj (:obj:`torch.Tensor`): The result embedding vector of projection operation. - Shapes: - - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. - - proj (:obj:`torch.Tensor`): :math:`(B, projection_output_dim)`, where B is batch_size. - - Examples: - >>> latent_state = torch.randn(256, 64) - >>> output = self.project(latent_state) - >>> output.shape # (256, 1024) - """ - proj = self.projection(latent_state) - - if with_grad: - # with grad, use prediction_head - return self.prediction_head(proj) - else: - return proj.detach() - - def get_params_mean(self) -> float: - return get_params_mean(self) - - -class DynamicsNetworkMLP(nn.Module): - - def __init__( - self, - action_encoding_dim: int = 2, - num_channels: int = 64, - common_layer_num: int = 2, - fc_reward_layers: SequenceType = [32], - output_support_size: int = 601, - lstm_hidden_size: int = 512, - last_linear_layer_init_zero: bool = True, - activation: Optional[nn.Module] = nn.ReLU(inplace=True), - norm_type: Optional[str] = 'BN', - res_connection_in_dynamics: bool = False, - ): - """ - Overview: - The definition of dynamics network in EfficientZero algorithm, which is used to predict next latent state - value_prefix and reward_hidden_state by the given current latent state and action. - The networks are mainly built on fully connected layers. 
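For reference, the two discrete-action encodings handled in _dynamics above produce very different inputs to the dynamics network; a toy comparison with a batch of 3 actions and action_space_size=4:

import torch

action = torch.tensor([0, 2, 3])   # (batch_size,) discrete actions
action_space_size = 4

# 'one_hot': one column per action in the space
one_hot = torch.zeros(3, action_space_size)
one_hot.scatter_(1, action.unsqueeze(-1), 1)
# -> [[1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]

# 'not_one_hot': a single column with the action index scaled into [0, 1)
not_one_hot = (action / action_space_size).unsqueeze(-1).float()
# -> [[0.00], [0.50], [0.75]]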
- Arguments: - - action_encoding_dim (:obj:`int`): The dimension of action encoding. - - num_channels (:obj:`int`): The num of channels in latent states. - - common_layer_num (:obj:`int`): The number of common layers in dynamics network. - - fc_reward_layers (:obj:`SequenceType`): The number of hidden layers of the reward head (MLP head). - - output_support_size (:obj:`int`): The size of categorical reward output. - - lstm_hidden_size (:obj:`int`): The hidden size of lstm in dynamics network. - - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializationss for the last layer of value/policy head, default sets it to True. - - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ - operation to speedup, e.g. ReLU(inplace=True). - - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. - - res_connection_in_dynamics (:obj:`bool`): Whether to use residual connection in dynamics network. - """ - super().__init__() - assert num_channels > action_encoding_dim, f'num_channels:{num_channels} <= action_encoding_dim:{action_encoding_dim}' - - self.num_channels = num_channels - self.action_encoding_dim = action_encoding_dim - self.latent_state_dim = self.num_channels - self.action_encoding_dim - self.lstm_hidden_size = lstm_hidden_size - self.activation = activation - self.res_connection_in_dynamics = res_connection_in_dynamics - - if self.res_connection_in_dynamics: - self.fc_dynamics_1 = MLP( - in_channels=self.num_channels, - hidden_channels=self.latent_state_dim, - layer_num=common_layer_num, - out_channels=self.latent_state_dim, - activation=activation, - norm_type=norm_type, - output_activation=True, - output_norm=True, - # last_linear_layer_init_zero=False is important for convergence - last_linear_layer_init_zero=False, - ) - self.fc_dynamics_2 = MLP( - in_channels=self.latent_state_dim, - hidden_channels=self.latent_state_dim, - layer_num=common_layer_num, - out_channels=self.latent_state_dim, - activation=activation, - norm_type=norm_type, - output_activation=True, - output_norm=True, - # last_linear_layer_init_zero=False is important for convergence - last_linear_layer_init_zero=False, - ) - else: - self.fc_dynamics = MLP( - in_channels=self.num_channels, - hidden_channels=self.latent_state_dim, - layer_num=common_layer_num, - out_channels=self.latent_state_dim, - activation=activation, - norm_type=norm_type, - output_activation=True, - output_norm=True, - # last_linear_layer_init_zero=False is important for convergence - last_linear_layer_init_zero=False, - ) - - # input_shape: (sequence_length,batch_size,input_size) - # output_shape: (sequence_length, batch_size, hidden_size) - self.lstm = nn.LSTM(input_size=self.latent_state_dim, hidden_size=self.lstm_hidden_size) - - self.fc_reward_head = MLP( - in_channels=self.lstm_hidden_size, - hidden_channels=fc_reward_layers[0], - layer_num=2, - out_channels=output_support_size, - activation=self.activation, - norm_type=norm_type, - output_activation=False, - output_norm=False, - last_linear_layer_init_zero=last_linear_layer_init_zero - ) - - def forward(self, state_action_encoding: torch.Tensor, reward_hidden_state): - """ - Overview: - Forward computation of the dynamics network. Predict next latent state given current state_action_encoding and reward hidden state. 
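The value-prefix LSTM constructed above uses PyTorch's (seq_len, batch, input_size) layout with seq_len fixed to 1 per rollout step, which is why the forward pass unsqueezes the latent state before the reward head; a minimal standalone illustration (dimensions are examples only):

import torch
import torch.nn as nn

latent_state_dim, lstm_hidden_size, batch_size = 256, 512, 8
lstm = nn.LSTM(input_size=latent_state_dim, hidden_size=lstm_hidden_size)

latent = torch.randn(batch_size, latent_state_dim)
hidden = (torch.zeros(1, batch_size, lstm_hidden_size),   # h0
          torch.zeros(1, batch_size, lstm_hidden_size))   # c0

out, hidden = lstm(latent.unsqueeze(0), hidden)   # out: (1, batch_size, lstm_hidden_size)
value_prefix_features = out.squeeze(0)            # fed into the reward head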
- Arguments: - - state_action_encoding (:obj:`torch.Tensor`): The state-action encoding, which is the concatenation of \ - latent state and action encoding, with shape (batch_size, num_channels, height, width). - - reward_hidden_state (:obj:`Tuple[torch.Tensor, torch.Tensor]`): The input hidden state of LSTM about reward. - Returns: - - next_latent_state (:obj:`torch.Tensor`): The next latent state, with shape (batch_size, latent_state_dim). - - next_reward_hidden_state (:obj:`torch.Tensor`): The input hidden state of LSTM about reward. - - value_prefix (:obj:`torch.Tensor`): The predicted prefix sum of value for input state. - """ - if self.res_connection_in_dynamics: - # take the state encoding (latent_state), state_action_encoding[:, -self.action_encoding_dim] - # is action encoding - latent_state = state_action_encoding[:, :-self.action_encoding_dim] - x = self.fc_dynamics_1(state_action_encoding) - # the residual link: add state encoding to the state_action encoding - next_latent_state = x + latent_state - next_latent_state_ = self.fc_dynamics_2(next_latent_state) - else: - next_latent_state = self.fc_dynamics(state_action_encoding) - next_latent_state_ = next_latent_state - - next_latent_state_unsqueeze = next_latent_state_.unsqueeze(0) - value_prefix, next_reward_hidden_state = self.lstm(next_latent_state_unsqueeze, reward_hidden_state) - value_prefix = self.fc_reward_head(value_prefix.squeeze(0)) - - return next_latent_state, next_reward_hidden_state, value_prefix - - def get_dynamic_mean(self) -> float: - return get_dynamic_mean(self) - - def get_reward_mean(self) -> Tuple[ndarray, float]: - return get_reward_mean(self) diff --git a/lzero/model/gobigger/network/gobigger_encoder.py b/lzero/model/gobigger/gobigger_encoder.py similarity index 100% rename from lzero/model/gobigger/network/gobigger_encoder.py rename to lzero/model/gobigger/gobigger_encoder.py diff --git a/lzero/model/gobigger/gobigger_muzero_model.py b/lzero/model/gobigger/gobigger_muzero_model.py deleted file mode 100644 index 36585b4ef..000000000 --- a/lzero/model/gobigger/gobigger_muzero_model.py +++ /dev/null @@ -1,449 +0,0 @@ -from typing import Optional, Tuple - -import torch -import torch.nn as nn -from ding.torch_utils import MLP -from ding.utils import MODEL_REGISTRY, SequenceType - -from ..common import MZNetworkOutput, RepresentationNetworkMLP, PredictionNetworkMLP -from ..utils import renormalize, get_params_mean, get_dynamic_mean, get_reward_mean -from .network.gobigger_encoder import GoBiggerEncoder -import yaml -from easydict import EasyDict -from ding.utils.data import default_collate - - -@MODEL_REGISTRY.register('GoBiggerMuZeroModel') -class GoBiggerMuZeroModel(nn.Module): - - def __init__( - self, - observation_shape: int = 2, - action_space_size: int = 6, - latent_state_dim: int = 256, - fc_reward_layers: SequenceType = [32], - fc_value_layers: SequenceType = [32], - fc_policy_layers: SequenceType = [32], - reward_support_size: int = 601, - value_support_size: int = 601, - proj_hid: int = 1024, - proj_out: int = 1024, - pred_hid: int = 512, - pred_out: int = 1024, - self_supervised_learning_loss: bool = False, - categorical_distribution: bool = True, - activation: Optional[nn.Module] = nn.ReLU(inplace=True), - last_linear_layer_init_zero: bool = True, - state_norm: bool = False, - discrete_action_encoding_type: str = 'one_hot', - norm_type: Optional[str] = 'BN', - res_connection_in_dynamics: bool = False, - *args, - **kwargs - ): - """ - Overview: - The definition of the network model of MuZero, 
which is a generalization version for 1D vector obs. - The networks are mainly built on fully connected layers. - The representation network is an MLP network which maps the raw observation to a latent state. - The dynamics network is an MLP network which predicts the next latent state, and reward given the current latent state and action. - The prediction network is an MLP network which predicts the value and policy given the current latent state. - Arguments: - - observation_shape (:obj:`int`): Observation space shape, e.g. 8 for Lunarlander. - - action_space_size: (:obj:`int`): Action space size, e.g. 4 for Lunarlander. - - latent_state_dim (:obj:`int`): The dimension of latent state, such as 256. - - fc_reward_layers (:obj:`SequenceType`): The number of hidden layers of the reward head (MLP head). - - fc_value_layers (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head). - - fc_policy_layers (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head). - - reward_support_size (:obj:`int`): The size of categorical reward output - - value_support_size (:obj:`int`): The size of categorical value output. - - proj_hid (:obj:`int`): The size of projection hidden layer. - - proj_out (:obj:`int`): The size of projection output layer. - - pred_hid (:obj:`int`): The size of prediction hidden layer. - - pred_out (:obj:`int`): The size of prediction output layer. - - self_supervised_learning_loss (:obj:`bool`): Whether to use self_supervised_learning related networks in MuZero model, default set it to False. - - categorical_distribution (:obj:`bool`): Whether to use discrete support to represent categorical distribution for value, reward/value_prefix. - - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ - operation to speedup, e.g. ReLU(inplace=True). - - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of value/policy mlp, default sets it to True. - - state_norm (:obj:`bool`): Whether to use normalization for latent states, default sets it to True. - - discrete_action_encoding_type (:obj:`str`): The encoding type of discrete action, which can be 'one_hot' or 'not_one_hot'. - - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. - - res_connection_in_dynamics (:obj:`bool`): Whether to use residual connection for dynamics network, default set it to False. - """ - super(GoBiggerMuZeroModel, self).__init__() - self.categorical_distribution = categorical_distribution - if not self.categorical_distribution: - self.reward_support_size = 1 - self.value_support_size = 1 - else: - self.reward_support_size = reward_support_size - self.value_support_size = value_support_size - - self.action_space_size = action_space_size - self.continuous_action_space = False - # The dim of action space. For discrete action space, it is 1. - # For continuous action space, it is the dimension of continuous action. 
- self.action_space_dim = action_space_size if self.continuous_action_space else 1 - assert discrete_action_encoding_type in ['one_hot', 'not_one_hot'], discrete_action_encoding_type - self.discrete_action_encoding_type = discrete_action_encoding_type - if self.continuous_action_space: - self.action_encoding_dim = action_space_size - else: - if self.discrete_action_encoding_type == 'one_hot': - self.action_encoding_dim = action_space_size - elif self.discrete_action_encoding_type == 'not_one_hot': - self.action_encoding_dim = 1 - - self.latent_state_dim = latent_state_dim - self.proj_hid = proj_hid - self.proj_out = proj_out - self.pred_hid = pred_hid - self.pred_out = pred_out - self.self_supervised_learning_loss = self_supervised_learning_loss - self.last_linear_layer_init_zero = last_linear_layer_init_zero - self.state_norm = state_norm - self.res_connection_in_dynamics = res_connection_in_dynamics - - self.representation_network = GoBiggerEncoder() - - self.dynamics_network = DynamicsNetwork( - action_encoding_dim=self.action_encoding_dim, - num_channels=self.latent_state_dim + self.action_encoding_dim, - common_layer_num=2, - fc_reward_layers=fc_reward_layers, - output_support_size=self.reward_support_size, - last_linear_layer_init_zero=self.last_linear_layer_init_zero, - norm_type=norm_type, - res_connection_in_dynamics=self.res_connection_in_dynamics, - ) - - self.prediction_network = PredictionNetworkMLP( - action_space_size=action_space_size, - num_channels=latent_state_dim, - fc_value_layers=fc_value_layers, - fc_policy_layers=fc_policy_layers, - output_support_size=self.value_support_size, - last_linear_layer_init_zero=self.last_linear_layer_init_zero, - norm_type=norm_type - ) - - if self.self_supervised_learning_loss: - # self_supervised_learning_loss related network proposed in EfficientZero - self.projection_input_dim = latent_state_dim - - self.projection = nn.Sequential( - nn.Linear(self.projection_input_dim, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, - nn.Linear(self.proj_hid, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, - nn.Linear(self.proj_hid, self.proj_out), nn.BatchNorm1d(self.proj_out) - ) - self.prediction_head = nn.Sequential( - nn.Linear(self.proj_out, self.pred_hid), - nn.BatchNorm1d(self.pred_hid), - activation, - nn.Linear(self.pred_hid, self.pred_out), - ) - - def initial_inference(self, obs: torch.Tensor) -> MZNetworkOutput: - """ - Overview: - Initial inference of MuZero model, which is the first step of the MuZero model. - To perform the initial inference, we first use the representation network to obtain the "latent_state" of the observation. - Then we use the prediction network to predict the "value" and "policy_logits" of the "latent_state", and - also prepare the zeros-like ``reward`` for the next step of the MuZero model. - Arguments: - - obs (:obj:`torch.Tensor`): The 1D vector observation data. - Returns (MZNetworkOutput): - - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. - - value_prefix (:obj:`torch.Tensor`): The predicted prefix sum of value for input state. \ - In initial inference, we set it to zero vector. - - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. - - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. - - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The hidden state of LSTM about reward. In initial inference, \ - we set it to the zeros-like hidden state (H and C). 
- Shapes: - - obs (:obj:`torch.Tensor`): :math:`(B, obs_shape)`, where B is batch_size. - - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. - - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. - - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. - - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. - """ - batch_size = len(obs) - obs = default_collate(obs) - latent_state = self._representation(obs) - policy_logits, value = self._prediction(latent_state) - return MZNetworkOutput( - value, - [0. for _ in range(batch_size)], - policy_logits, - latent_state, - ) - - def recurrent_inference(self, latent_state: torch.Tensor, action: torch.Tensor) -> MZNetworkOutput: - """ - Overview: - Recurrent inference of MuZero model, which is the rollout step of the MuZero model. - To perform the recurrent inference, we first use the dynamics network to predict ``next_latent_state``, - ``reward`` by the given current ``latent_state`` and ``action``. - We then use the prediction network to predict the ``value`` and ``policy_logits``. - Arguments: - - latent_state (:obj:`torch.Tensor`): The encoding latent state of input obs. - - action (:obj:`torch.Tensor`): The predicted action to rollout. - Returns (MZNetworkOutput): - - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. - - reward (:obj:`torch.Tensor`): The predicted reward for input state. - - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. - - next_latent_state (:obj:`torch.Tensor`): The predicted next latent state. - Shapes: - - action (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch_size. - - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. - - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. - - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. - - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. - - next_latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. - """ - next_latent_state, reward = self._dynamics(latent_state, action) - policy_logits, value = self._prediction(next_latent_state) - return MZNetworkOutput(value, reward, policy_logits, next_latent_state) - - def _representation(self, observation: torch.Tensor) -> Tuple[torch.Tensor]: - """ - Overview: - Use the representation network to encode the observations into latent state. - Arguments: - - obs (:obj:`torch.Tensor`): The 1D vector observation data. - Returns: - - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. - Shapes: - - obs (:obj:`torch.Tensor`): :math:`(B, obs_shape)`, where B is batch_size. - - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. - """ - latent_state = self.representation_network(observation) - if self.state_norm: - latent_state = renormalize(latent_state) - return latent_state - - def _prediction(self, latent_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Overview: - Use the representation network to encode the observations into latent state. - Arguments: - - obs (:obj:`torch.Tensor`): The 1D vector observation data. 
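initial_inference and recurrent_inference above are what MCTS calls at the root and at every expanded node; a condensed sketch of unrolling such a model for k imagined steps (loop structure only, the real search lives in the MCTS ctree/ptree modules):

def unroll_sketch(model, obs_batch, actions):
    # actions: a list of k action tensors of shape (batch_size,) to imagine after the root.
    output = model.initial_inference(obs_batch)
    latent_state, trajectory = output.latent_state, [output]
    for action in actions:
        output = model.recurrent_inference(latent_state, action)
        latent_state = output.latent_state
        trajectory.append(output)
    return trajectory   # each entry carries value, reward and policy_logits for its step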
- Returns: - - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. - - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. - Shapes: - - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. - - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. - - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. - """ - policy_logits, value = self.prediction_network(latent_state) - return policy_logits, value - - def _dynamics(self, latent_state: torch.Tensor, action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Overview: - Concatenate ``latent_state`` and ``action`` and use the dynamics network to predict ``next_latent_state`` - ``reward`` and ``next_reward_hidden_state``. - Arguments: - - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. - - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The input hidden state of LSTM about reward. - - action (:obj:`torch.Tensor`): The predicted action to rollout. - Returns: - - next_latent_state (:obj:`torch.Tensor`): The predicted latent state of the next timestep. - - next_reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The output hidden state of LSTM about reward. - - reward (:obj:`torch.Tensor`): The predicted reward for input state. - Shapes: - - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. - - action (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch_size. - - next_latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. - - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. - """ - # NOTE: the discrete action encoding type is important for some environments - - # discrete action space - if self.discrete_action_encoding_type == 'one_hot': - # Stack latent_state with the one hot encoded action - if len(action.shape) == 1: - # (batch_size, ) -> (batch_size, 1) - # e.g., torch.Size([8]) -> torch.Size([8, 1]) - action = action.unsqueeze(-1) - - # transform action to one-hot encoding. - # action_one_hot shape: (batch_size, action_space_size), e.g., (8, 4) - action_one_hot = torch.zeros(action.shape[0], self.action_space_size, device=action.device) - # transform action to torch.int64 - action = action.long() - action_one_hot.scatter_(1, action, 1) - action_encoding = action_one_hot - elif self.discrete_action_encoding_type == 'not_one_hot': - action_encoding = action / self.action_space_size - if len(action_encoding.shape) == 1: - # (batch_size, ) -> (batch_size, 1) - # e.g., torch.Size([8]) -> torch.Size([8, 1]) - action_encoding = action_encoding.unsqueeze(-1) - - action_encoding = action_encoding.to(latent_state.device).float() - # state_action_encoding shape: (batch_size, latent_state[1] + action_dim]) or - # (batch_size, latent_state[1] + action_space_size]) depending on the discrete_action_encoding_type. 
- state_action_encoding = torch.cat((latent_state, action_encoding), dim=1) - - next_latent_state, reward = self.dynamics_network(state_action_encoding) - - if not self.state_norm: - return next_latent_state, reward - else: - next_latent_state_normalized = renormalize(next_latent_state) - return next_latent_state_normalized, reward - - def project(self, latent_state: torch.Tensor, with_grad=True) -> torch.Tensor: - """ - Overview: - Project the latent state to a lower dimension to calculate the self-supervised loss, which is proposed in EfficientZero. - For more details, please refer to the paper ``Exploring Simple Siamese Representation Learning``. - Arguments: - - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. - - with_grad (:obj:`bool`): Whether to calculate gradient for the projection result. - Returns: - - proj (:obj:`torch.Tensor`): The result embedding vector of projection operation. - Shapes: - - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. - - proj (:obj:`torch.Tensor`): :math:`(B, projection_output_dim)`, where B is batch_size. - - Examples: - >>> latent_state = torch.randn(256, 64) - >>> output = self.project(latent_state) - >>> output.shape # (256, 1024) - """ - proj = self.projection(latent_state) - - if with_grad: - # with grad, use prediction_head - return self.prediction_head(proj) - else: - return proj.detach() - - def get_params_mean(self) -> float: - return get_params_mean(self) - - -class DynamicsNetwork(nn.Module): - - def __init__( - self, - action_encoding_dim: int = 2, - num_channels: int = 64, - common_layer_num: int = 2, - fc_reward_layers: SequenceType = [32], - output_support_size: int = 601, - last_linear_layer_init_zero: bool = True, - activation: Optional[nn.Module] = nn.ReLU(inplace=True), - norm_type: Optional[str] = 'BN', - res_connection_in_dynamics: bool = False, - ): - """ - Overview: - The definition of dynamics network in MuZero algorithm, which is used to predict next latent state - reward by the given current latent state and action. - The networks are mainly built on fully connected layers. - Arguments: - - action_encoding_dim (:obj:`int`): The dimension of action encoding. - - num_channels (:obj:`int`): The num of channels in latent states. - - common_layer_num (:obj:`int`): The number of common layers in dynamics network. - - fc_reward_layers (:obj:`SequenceType`): The number of hidden layers of the reward head (MLP head). - - output_support_size (:obj:`int`): The size of categorical reward output. - - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of value/policy mlp, default sets it to True. - - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ - operation to speedup, e.g. ReLU(inplace=True). - - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. - - res_connection_in_dynamics (:obj:`bool`): Whether to use residual connection in dynamics network. 
- """ - super().__init__() - self.num_channels = num_channels - self.action_encoding_dim = action_encoding_dim - self.latent_state_dim = self.num_channels - self.action_encoding_dim - - self.res_connection_in_dynamics = res_connection_in_dynamics - if self.res_connection_in_dynamics: - self.fc_dynamics_1 = MLP( - in_channels=self.num_channels, - hidden_channels=self.latent_state_dim, - layer_num=common_layer_num, - out_channels=self.latent_state_dim, - activation=activation, - norm_type=norm_type, - output_activation=True, - output_norm=True, - # last_linear_layer_init_zero=False is important for convergence - last_linear_layer_init_zero=False, - ) - self.fc_dynamics_2 = MLP( - in_channels=self.latent_state_dim, - hidden_channels=self.latent_state_dim, - layer_num=common_layer_num, - out_channels=self.latent_state_dim, - activation=activation, - norm_type=norm_type, - output_activation=True, - output_norm=True, - # last_linear_layer_init_zero=False is important for convergence - last_linear_layer_init_zero=False, - ) - else: - self.fc_dynamics = MLP( - in_channels=self.num_channels, - hidden_channels=self.latent_state_dim, - layer_num=common_layer_num, - out_channels=self.latent_state_dim, - activation=activation, - norm_type=norm_type, - output_activation=True, - output_norm=True, - # last_linear_layer_init_zero=False is important for convergence - last_linear_layer_init_zero=False, - ) - - self.fc_reward_head = MLP( - in_channels=self.latent_state_dim, - hidden_channels=fc_reward_layers[0], - layer_num=2, - out_channels=output_support_size, - activation=activation, - norm_type=norm_type, - output_activation=False, - output_norm=False, - last_linear_layer_init_zero=last_linear_layer_init_zero - ) - - def forward(self, state_action_encoding: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Overview: - Forward computation of the dynamics network. Predict the next latent state given current latent state and action. - Arguments: - - state_action_encoding (:obj:`torch.Tensor`): The state-action encoding, which is the concatenation of \ - latent state and action encoding, with shape (batch_size, num_channels, height, width). - Returns: - - next_latent_state (:obj:`torch.Tensor`): The next latent state, with shape (batch_size, latent_state_dim). - - reward (:obj:`torch.Tensor`): The predicted reward for input state. - """ - if self.res_connection_in_dynamics: - # take the state encoding (e.g. 
latent_state), - # state_action_encoding[:, -self.action_encoding_dim:] is action encoding - latent_state = state_action_encoding[:, :-self.action_encoding_dim] - x = self.fc_dynamics_1(state_action_encoding) - # the residual link: add the latent_state to the state_action encoding - next_latent_state = x + latent_state - next_latent_state_encoding = self.fc_dynamics_2(next_latent_state) - else: - next_latent_state = self.fc_dynamics(state_action_encoding) - next_latent_state_encoding = next_latent_state - - reward = self.fc_reward_head(next_latent_state_encoding) - - return next_latent_state, reward - - def get_dynamic_mean(self) -> float: - return get_dynamic_mean(self) - - def get_reward_mean(self) -> float: - return get_reward_mean(self) diff --git a/lzero/model/gobigger/gobigger_sampled_efficientzero_model.py b/lzero/model/gobigger/gobigger_sampled_efficientzero_model.py deleted file mode 100644 index 38990464b..000000000 --- a/lzero/model/gobigger/gobigger_sampled_efficientzero_model.py +++ /dev/null @@ -1,524 +0,0 @@ -from typing import Optional, Tuple - -import torch -import torch.nn as nn -from ding.model.common import ReparameterizationHead -from ding.torch_utils import MLP -from ding.utils import MODEL_REGISTRY, SequenceType - -from ..common import EZNetworkOutput, RepresentationNetworkMLP -from ..efficientzero_model_mlp import DynamicsNetworkMLP -from ..utils import renormalize, get_params_mean -from .network.gobigger_encoder import GoBiggerEncoder -import yaml -from easydict import EasyDict -from ding.utils.data import default_collate - - -@MODEL_REGISTRY.register('GoBiggerSampledEfficientZeroModel') -class GoBiggerSampledEfficientZeroModel(nn.Module): - - def __init__( - self, - observation_shape: int = 2, - action_space_size: int = 6, - latent_state_dim: int = 256, - lstm_hidden_size: int = 512, - fc_reward_layers: SequenceType = [32], - fc_value_layers: SequenceType = [32], - fc_policy_layers: SequenceType = [32], - reward_support_size: int = 601, - value_support_size: int = 601, - proj_hid: int = 1024, - proj_out: int = 1024, - pred_hid: int = 512, - pred_out: int = 1024, - self_supervised_learning_loss: bool = True, - categorical_distribution: bool = True, - activation: Optional[nn.Module] = nn.ReLU(inplace=True), - last_linear_layer_init_zero: bool = True, - state_norm: bool = False, - # ============================================================== - # specific sampled related config - # ============================================================== - continuous_action_space: bool = False, - num_of_sampled_actions: int = 6, - sigma_type='conditioned', - fixed_sigma_value: float = 0.3, - bound_type: str = None, - norm_type: str = 'BN', - discrete_action_encoding_type: str = 'one_hot', - res_connection_in_dynamics: bool = False, - *args, - **kwargs, - ): - """ - Overview: - The definition of the network model of Sampled EfficientZero, which is a generalization version for 1D vector obs. - The networks are mainly built on fully connected layers. - Sampled EfficientZero model consists of a representation network, a dynamics network and a prediction network. - The representation network is an MLP network which maps the raw observation to a latent state. - The dynamics network is an MLP+LSTM network which predicts the next latent state, reward_hidden_state and value_prefix given the current latent state and action. - The prediction network is an MLP network which predicts the value and policy given the current latent state. 
- Arguments: - - observation_shape (:obj:`int`): Observation space shape, e.g. 8 for Lunarlander. - - action_space_size: (:obj:`int`): Action space size, which is an integer number. For discrete action space, it is the num of discrete actions, \ - e.g. 4 for Lunarlander. For continuous action space, it is the dimension of the continuous action, e.g. 4 for bipedalwalker. - - latent_state_dim (:obj:`int`): The dimension of latent state, such as 256. - - lstm_hidden_size (:obj:`int`): The hidden size of LSTM in dynamics network to predict value_prefix. - - fc_reward_layers (:obj:`SequenceType`): The number of hidden layers of the reward head (MLP head). - - fc_value_layers (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head). - - fc_policy_layers (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head). - - reward_support_size (:obj:`int`): The size of categorical reward output - - value_support_size (:obj:`int`): The size of categorical value output. - - proj_hid (:obj:`int`): The size of projection hidden layer. - - proj_out (:obj:`int`): The size of projection output layer. - - pred_hid (:obj:`int`): The size of prediction hidden layer. - - pred_out (:obj:`int`): The size of prediction output layer. - - self_supervised_learning_loss (:obj:`bool`): Whether to use self_supervised_learning related networks in Sampled EfficientZero model, default set it to False. - - categorical_distribution (:obj:`bool`): Whether to use discrete support to represent categorical distribution for value, reward/value_prefix. - - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ - operation to speedup, e.g. ReLU(inplace=True). - - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of value/policy mlp, default sets it to True. - - state_norm (:obj:`bool`): Whether to use normalization for latent states, default sets it to True. - # ============================================================== - # specific sampled related config - # ============================================================== - - continuous_action_space (:obj:`bool`): The type of action space. default set it to False. - - num_of_sampled_actions (:obj:`int`): the number of sampled actions, i.e. the K in original Sampled MuZero paper. - # see ``ReparameterizationHead`` in ``ding.model.common.head`` for more details about the following arguments. - - sigma_type (:obj:`str`): the type of sigma in policy head of prediction network, options={'conditioned', 'fixed'}. - - fixed_sigma_value (:obj:`float`): the fixed sigma value in policy head of prediction network, - - bound_type (:obj:`str`): The type of bound in networks. Default sets it to None. - - norm_type (:obj:`str`): The type of normalization in networks. default set it to 'BN'. - - discrete_action_encoding_type (:obj:`str`): The type of encoding for discrete action. Default sets it to 'one_hot'. options = {'one_hot', 'not_one_hot'} - - res_connection_in_dynamics (:obj:`bool`): Whether to use residual connection for dynamics network, default set it to False. 
- """ - super(GoBiggerSampledEfficientZeroModel, self).__init__() - if not categorical_distribution: - self.reward_support_size = 1 - self.value_support_size = 1 - else: - self.reward_support_size = reward_support_size - self.value_support_size = value_support_size - - self.continuous_action_space = continuous_action_space - self.observation_shape = observation_shape - self.action_space_size = action_space_size - # The dim of action space. For discrete action space, it is 1. - # For continuous action space, it is the dimension of continuous action. - self.action_space_dim = action_space_size if self.continuous_action_space else 1 - assert discrete_action_encoding_type in ['one_hot', 'not_one_hot'], discrete_action_encoding_type - self.discrete_action_encoding_type = discrete_action_encoding_type - if self.continuous_action_space: - self.action_encoding_dim = action_space_size - else: - if self.discrete_action_encoding_type == 'one_hot': - self.action_encoding_dim = action_space_size - elif self.discrete_action_encoding_type == 'not_one_hot': - self.action_encoding_dim = 1 - - self.lstm_hidden_size = lstm_hidden_size - self.latent_state_dim = latent_state_dim - self.fc_reward_layers = fc_reward_layers - self.fc_value_layers = fc_value_layers - self.fc_policy_layers = fc_policy_layers - self.proj_hid = proj_hid - self.proj_out = proj_out - self.pred_hid = pred_hid - self.pred_out = pred_out - - self.last_linear_layer_init_zero = last_linear_layer_init_zero - self.state_norm = state_norm - self.self_supervised_learning_loss = self_supervised_learning_loss - - self.sigma_type = sigma_type - self.fixed_sigma_value = fixed_sigma_value - self.bound_type = bound_type - self.norm_type = norm_type - self.num_of_sampled_actions = num_of_sampled_actions - self.res_connection_in_dynamics = res_connection_in_dynamics - - self.representation_network = GoBiggerEncoder() - - self.dynamics_network = DynamicsNetworkMLP( - action_encoding_dim=self.action_encoding_dim, - num_channels=self.latent_state_dim + self.action_encoding_dim, - common_layer_num=2, - lstm_hidden_size=self.lstm_hidden_size, - fc_reward_layers=self.fc_reward_layers, - output_support_size=self.reward_support_size, - last_linear_layer_init_zero=self.last_linear_layer_init_zero, - norm_type=norm_type, - res_connection_in_dynamics=self.res_connection_in_dynamics, - ) - - self.prediction_network = PredictionNetworkMLP( - continuous_action_space=self.continuous_action_space, - action_space_size=self.action_space_size, - num_channels=self.latent_state_dim, - fc_value_layers=self.fc_value_layers, - fc_policy_layers=self.fc_policy_layers, - output_support_size=self.value_support_size, - last_linear_layer_init_zero=self.last_linear_layer_init_zero, - sigma_type=self.sigma_type, - fixed_sigma_value=self.fixed_sigma_value, - bound_type=self.bound_type, - norm_type=self.norm_type, - ) - - if self.self_supervised_learning_loss: - # self_supervised_learning_loss related network proposed in EfficientZero - self.projection_input_dim = latent_state_dim - self.projection = nn.Sequential( - nn.Linear(self.projection_input_dim, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, - nn.Linear(self.proj_hid, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, - nn.Linear(self.proj_hid, self.proj_out), nn.BatchNorm1d(self.proj_out) - ) - self.prediction_head = nn.Sequential( - nn.Linear(self.proj_out, self.pred_hid), - nn.BatchNorm1d(self.pred_hid), - activation, - nn.Linear(self.pred_hid, self.pred_out), - ) - - def initial_inference(self, obs: 
torch.Tensor) -> EZNetworkOutput: - """ - Overview: - Initial inference of SampledEfficientZero model, which is the first step of the SampledEfficientZero model. - To perform the initial inference, we first use the representation network to obtain the "latent_state" of the observation. - Then we use the prediction network to predict the "value" and "policy_logits" of the "latent_state", and - also prepare the zeros-like ``reward_hidden_state`` for the next step of the Sampled EfficientZero model. - Arguments: - - obs (:obj:`torch.Tensor`): The 1D vector observation data. - Returns (EZNetworkOutput): - - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. - - value_prefix (:obj:`torch.Tensor`): The predicted prefix sum of value for input state. \ - In initial inference, we set it to zero vector. - - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. - - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. - - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The hidden state of LSTM about reward. In initial inference, \ - we set it to the zeros-like hidden state (H and C). - Shapes: - - obs (:obj:`torch.Tensor`): :math:`(B, obs_shape)`, where B is batch_size. - - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. - - value_prefix (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. - - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. - - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. - - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The shape of each element is :math:`(1, B, lstm_hidden_size)`, where B is batch_size. - """ - batch_size = len(obs) - obs = default_collate(obs) - latent_state = self._representation(obs) - device = latent_state.device - policy_logits, value = self._prediction(latent_state) - # zero initialization for reward hidden states - # (hn, cn), each element shape is (layer_num=1, batch_size, lstm_hidden_size) - reward_hidden_state = ( - torch.zeros(1, batch_size, - self.lstm_hidden_size).to(device), torch.zeros(1, batch_size, self.lstm_hidden_size).to(device) - ) - return EZNetworkOutput(value, [0. for _ in range(batch_size)], policy_logits, latent_state, reward_hidden_state) - - def recurrent_inference( - self, latent_state: torch.Tensor, reward_hidden_state: torch.Tensor, action: torch.Tensor - ) -> EZNetworkOutput: - """ - Overview: - Recurrent inference of Sampled EfficientZero model, which is the rollout step of the Sampled EfficientZero model. - To perform the recurrent inference, we first use the dynamics network to predict ``next_latent_state``, - ``reward_hidden_state``, ``value_prefix`` by the given current ``latent_state`` and ``action``. - We then use the prediction network to predict the ``value`` and ``policy_logits``. - Arguments: - - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. - - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The input hidden state of LSTM about reward. - - action (:obj:`torch.Tensor`): The predicted action to rollout. - Returns (EZNetworkOutput): - - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. - - value_prefix (:obj:`torch.Tensor`): The predicted prefix sum of value for input state. - - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. 
- - next_latent_state (:obj:`torch.Tensor`): The predicted next latent state. - - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The output hidden state of LSTM about reward. - Shapes: - - action (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch_size. - - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. - - value_prefix (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. - - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. - - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. - - next_latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. - - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The shape of each element is :math:`(1, B, lstm_hidden_size)`, where B is batch_size. - """ - next_latent_state, reward_hidden_state, value_prefix = self._dynamics(latent_state, reward_hidden_state, action) - policy_logits, value = self._prediction(next_latent_state) - return EZNetworkOutput(value, value_prefix, policy_logits, next_latent_state, reward_hidden_state) - - def _representation(self, observation: torch.Tensor) -> Tuple[torch.Tensor]: - """ - Overview: - Use the representation network to encode the observations into latent state. - Arguments: - - obs (:obj:`torch.Tensor`): The 1D vector observation data. - Returns: - - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. - Shapes: - - obs (:obj:`torch.Tensor`): :math:`(B, obs_shape)`, where B is batch_size. - - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. - """ - latent_state = self.representation_network(observation) - if self.state_norm: - latent_state = renormalize(latent_state) - return latent_state - - def _prediction(self, latent_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Overview: - Use the representation network to encode the observations into latent state. - Arguments: - - obs (:obj:`torch.Tensor`): The 1D vector observation data. - Returns: - - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. - - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. - Shapes: - - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. - - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. - - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. - """ - policy, value = self.prediction_network(latent_state) - return policy, value - - def _dynamics(self, latent_state: torch.Tensor, reward_hidden_state: Tuple, - action: torch.Tensor) -> Tuple[torch.Tensor, Tuple[torch.Tensor], torch.Tensor]: - """ - Overview: - Concatenate ``latent_state`` and ``action`` and use the dynamics network to predict ``next_latent_state`` - ``value_prefix`` and ``next_reward_hidden_state``. - Arguments: - - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. - - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The input hidden state of LSTM about reward. - - action (:obj:`torch.Tensor`): The predicted action to rollout. - Returns: - - next_latent_state (:obj:`torch.Tensor`): The predicted latent state of the next timestep. 
- - next_reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The output hidden state of LSTM about reward. - - value_prefix (:obj:`torch.Tensor`): The predicted prefix sum of value for input state. - Shapes: - - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. - - action (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch_size. - - next_latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. - - value_prefix (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. - """ - # NOTE: the discrete action encoding type is important for some environments - - if not self.continuous_action_space: - # discrete action space - if self.discrete_action_encoding_type == 'one_hot': - # Stack latent_state with the one hot encoded action - if len(action.shape) == 1: - # (batch_size, ) -> (batch_size, 1) - # e.g., torch.Size([8]) -> torch.Size([8, 1]) - action = action.unsqueeze(-1) - - # transform action to one-hot encoding. - # action_one_hot shape: (batch_size, action_space_size), e.g., (8, 4) - action_one_hot = torch.zeros(action.shape[0], self.action_space_size, device=action.device) - # transform action to torch.int64 - action = action.long() - action_one_hot.scatter_(1, action, 1) - action_encoding = action_one_hot - elif self.discrete_action_encoding_type == 'not_one_hot': - action_encoding = action / self.action_space_size - if len(action_encoding.shape) == 1: - # (batch_size, ) -> (batch_size, 1) - # e.g., torch.Size([8]) -> torch.Size([8, 1]) - action_encoding = action_encoding.unsqueeze(-1) - else: - # continuous action space - if len(action.shape) == 1: - # (batch_size, ) -> (batch_size, 1) - # e.g., torch.Size([8]) -> torch.Size([8, 1]) - action = action.unsqueeze(-1) - elif len(action.shape) == 3: - # (batch_size, action_dim, 1) -> (batch_size, action_dim) - # e.g., torch.Size([8, 2, 1]) -> torch.Size([8, 2]) - action = action.squeeze(-1) - - action_encoding = action - - action_encoding = action_encoding.to(latent_state.device).float() - # state_action_encoding shape: (batch_size, latent_state[1] + action_dim]) or - # (batch_size, latent_state[1] + action_space_size]) depending on the discrete_action_encoding_type. - state_action_encoding = torch.cat((latent_state, action_encoding), dim=1) - - next_latent_state, next_reward_hidden_state, value_prefix = self.dynamics_network( - state_action_encoding, reward_hidden_state - ) - - if not self.state_norm: - return next_latent_state, next_reward_hidden_state, value_prefix - else: - next_latent_state_normalized = renormalize(next_latent_state) - return next_latent_state_normalized, next_reward_hidden_state, value_prefix - - def project(self, latent_state: torch.Tensor, with_grad=True) -> torch.Tensor: - """ - Overview: - Project the latent state to a lower dimension to calculate the self-supervised loss, which is proposed in EfficientZero. - For more details, please refer to the paper ``Exploring Simple Siamese Representation Learning``. - Arguments: - - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. - - with_grad (:obj:`bool`): Whether to calculate gradient for the projection result. - Returns: - - proj (:obj:`torch.Tensor`): The result embedding vector of projection operation. - Shapes: - - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. 
- - proj (:obj:`torch.Tensor`): :math:`(B, projection_output_dim)`, where B is batch_size. - - Examples: - >>> latent_state = torch.randn(256, 64) - >>> output = self.project(latent_state) - >>> output.shape # (256, 1024) - """ - proj = self.projection(latent_state) - - if with_grad: - # with grad, use prediction_head - return self.prediction_head(proj) - else: - return proj.detach() - - def get_params_mean(self): - return get_params_mean(self) - - -class PredictionNetworkMLP(nn.Module): - - def __init__( - self, - continuous_action_space, - action_space_size, - num_channels, - common_layer_num: int = 2, - fc_value_layers: SequenceType = [32], - fc_policy_layers: SequenceType = [32], - output_support_size: int = 601, - last_linear_layer_init_zero: bool = True, - activation: Optional[nn.Module] = nn.ReLU(inplace=True), - # ============================================================== - # specific sampled related config - # ============================================================== - sigma_type='conditioned', - fixed_sigma_value: float = 0.3, - bound_type: str = None, - norm_type: str = 'BN', - ): - """ - Overview: - The definition of policy and value prediction network, which is used to predict value and policy by the - given latent state. - The networks are mainly built on fully connected layers. - Arguments: - - continuous_action_space (:obj:`bool`): The type of action space. default set it to False. - - action_space_size: (:obj:`int`): Action space size, usually an integer number. For discrete action \ - space, it is the number of discrete actions. For continuous action space, it is the dimension of \ - continuous action. - - num_channels (:obj:`int`): The num of channels in latent states. - - num_res_blocks (:obj:`int`): The number of res blocks. - - fc_value_layers (:obj:`SequenceType`): hidden layers of the value prediction head (MLP head). - - fc_policy_layers (:obj:`SequenceType`): hidden layers of the policy prediction head (MLP head). - - output_support_size (:obj:`int`): dim of value output. - - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of value/policy mlp, default sets it to True. - # ============================================================== - # specific sampled related config - # ============================================================== - # see ``ReparameterizationHead`` in ``ding.model.common.head`` for more details about thee following arguments. - - sigma_type (:obj:`str`): the type of sigma in policy head of prediction network, options={'conditioned', 'fixed'}. - - fixed_sigma_value (:obj:`float`): the fixed sigma value in policy head of prediction network, - - bound_type (:obj:`str`): The type of bound in networks. default set it to None. - - norm_type (:obj:`str`): The type of normalization in networks. default set it to 'BN'. 
- """ - super().__init__() - self.num_channels = num_channels - self.continuous_action_space = continuous_action_space - self.norm_type = norm_type - self.sigma_type = sigma_type - self.fixed_sigma_value = fixed_sigma_value - self.bound_type = bound_type - self.action_space_size = action_space_size - if self.continuous_action_space: - self.action_encoding_dim = self.action_space_size - else: - self.action_encoding_dim = 1 - - # ******* common backbone ****** - self.fc_prediction_common = MLP( - in_channels=self.num_channels, - hidden_channels=self.num_channels, - out_channels=self.num_channels, - layer_num=common_layer_num, - activation=activation, - norm_type=norm_type, - output_activation=True, - output_norm=True, - # last_linear_layer_init_zero=False is important for convergence - last_linear_layer_init_zero=False, - ) - - # ******* value and policy head ****** - self.fc_value_head = MLP( - in_channels=self.num_channels, - hidden_channels=fc_value_layers[0], - out_channels=output_support_size, - layer_num=2, - activation=activation, - norm_type=norm_type, - output_activation=False, - output_norm=False, - # last_linear_layer_init_zero=True is beneficial for convergence speed. - last_linear_layer_init_zero=last_linear_layer_init_zero - ) - - # sampled related core code - if self.continuous_action_space: - self.fc_policy_head = ReparameterizationHead( - input_size=self.num_channels, - output_size=action_space_size, - layer_num=2, - sigma_type=self.sigma_type, - fixed_sigma_value=self.fixed_sigma_value, - activation=nn.ReLU(), - norm_type=None, - bound_type=self.bound_type - ) - else: - self.fc_policy_head = MLP( - in_channels=self.num_channels, - hidden_channels=fc_policy_layers[0], - out_channels=action_space_size, - layer_num=2, - activation=activation, - norm_type=self.norm_type, - output_activation=False, - output_norm=False, - # last_linear_layer_init_zero=True is beneficial for convergence speed. - last_linear_layer_init_zero=last_linear_layer_init_zero - ) - - def forward(self, latent_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Overview: - Forward computation of the prediction network. - Arguments: - - latent_state (:obj:`torch.Tensor`): input tensor with shape (B, in_channels). - Returns: - - policy (:obj:`torch.Tensor`): policy tensor. If action space is discrete, shape is (B, action_space_size). - If action space is continuous, shape is (B, action_space_size * 2). - - value (:obj:`torch.Tensor`): value tensor with shape (B, output_support_size). 
- """ - x_prediction_common = self.fc_prediction_common(latent_state) - value = self.fc_value_head(x_prediction_common) - - # sampled related core code - policy = self.fc_policy_head(x_prediction_common) - if self.continuous_action_space: - policy = torch.cat([policy['mu'], policy['sigma']], dim=-1) - - return policy, value diff --git a/lzero/model/muzero_model_structure.py b/lzero/model/muzero_model_structure.py new file mode 100644 index 000000000..af5e0419e --- /dev/null +++ b/lzero/model/muzero_model_structure.py @@ -0,0 +1,183 @@ +from typing import Optional, Tuple + +import torch +import torch.nn as nn +from ding.torch_utils import MLP +from ding.utils import MODEL_REGISTRY, SequenceType + +from .common import MZNetworkOutput, PredictionNetworkMLP +from .utils import renormalize, get_params_mean, get_dynamic_mean, get_reward_mean +from lzero.model.muzero_model_mlp import MuZeroModelMLP, DynamicsNetwork + + + +@MODEL_REGISTRY.register('MuZeroModelStructure') +class MuZeroModelMLPStructure(MuZeroModelMLP): + + def __init__( + self, + env_name: str, + action_space_size: int = 6, + latent_state_dim: int = 256, + fc_reward_layers: SequenceType = [32], + fc_value_layers: SequenceType = [32], + fc_policy_layers: SequenceType = [32], + reward_support_size: int = 601, + value_support_size: int = 601, + proj_hid: int = 1024, + proj_out: int = 1024, + pred_hid: int = 512, + pred_out: int = 1024, + self_supervised_learning_loss: bool = False, + categorical_distribution: bool = True, + activation: Optional[nn.Module] = nn.ReLU(inplace=True), + last_linear_layer_init_zero: bool = True, + state_norm: bool = False, + discrete_action_encoding_type: str = 'one_hot', + norm_type: Optional[str] = 'BN', + res_connection_in_dynamics: bool = False, + *args, + **kwargs + ): + """ + Overview: + The definition of the network model of MuZero, which is a generalization version for 1D vector obs. + The networks are mainly built on fully connected layers. + The representation network is an MLP network which maps the raw observation to a latent state. + The dynamics network is an MLP network which predicts the next latent state, and reward given the current latent state and action. + The prediction network is an MLP network which predicts the value and policy given the current latent state. + Arguments: + - observation_shape (:obj:`int`): Observation space shape, e.g. 8 for Lunarlander. + - action_space_size: (:obj:`int`): Action space size, e.g. 4 for Lunarlander. + - latent_state_dim (:obj:`int`): The dimension of latent state, such as 256. + - fc_reward_layers (:obj:`SequenceType`): The number of hidden layers of the reward head (MLP head). + - fc_value_layers (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head). + - fc_policy_layers (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head). + - reward_support_size (:obj:`int`): The size of categorical reward output + - value_support_size (:obj:`int`): The size of categorical value output. + - proj_hid (:obj:`int`): The size of projection hidden layer. + - proj_out (:obj:`int`): The size of projection output layer. + - pred_hid (:obj:`int`): The size of prediction hidden layer. + - pred_out (:obj:`int`): The size of prediction output layer. + - self_supervised_learning_loss (:obj:`bool`): Whether to use self_supervised_learning related networks in MuZero model, default set it to False. 
+ - categorical_distribution (:obj:`bool`): Whether to use discrete support to represent categorical distribution for value, reward/value_prefix. + - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ + operation to speedup, e.g. ReLU(inplace=True). + - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of value/policy mlp, default sets it to True. + - state_norm (:obj:`bool`): Whether to use normalization for latent states, default sets it to True. + - discrete_action_encoding_type (:obj:`str`): The encoding type of discrete action, which can be 'one_hot' or 'not_one_hot'. + - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. + - res_connection_in_dynamics (:obj:`bool`): Whether to use residual connection for dynamics network, default set it to False. + """ + super(MuZeroModelMLP, self).__init__() + self.categorical_distribution = categorical_distribution + if not self.categorical_distribution: + self.reward_support_size = 1 + self.value_support_size = 1 + else: + self.reward_support_size = reward_support_size + self.value_support_size = value_support_size + + self.action_space_size = action_space_size + self.continuous_action_space = False + # The dim of action space. For discrete action space, it is 1. + # For continuous action space, it is the dimension of continuous action. + self.action_space_dim = action_space_size if self.continuous_action_space else 1 + assert discrete_action_encoding_type in ['one_hot', 'not_one_hot'], discrete_action_encoding_type + self.discrete_action_encoding_type = discrete_action_encoding_type + if self.continuous_action_space: + self.action_encoding_dim = action_space_size + else: + if self.discrete_action_encoding_type == 'one_hot': + self.action_encoding_dim = action_space_size + elif self.discrete_action_encoding_type == 'not_one_hot': + self.action_encoding_dim = 1 + + self.latent_state_dim = latent_state_dim + self.proj_hid = proj_hid + self.proj_out = proj_out + self.pred_hid = pred_hid + self.pred_out = pred_out + self.self_supervised_learning_loss = self_supervised_learning_loss + self.last_linear_layer_init_zero = last_linear_layer_init_zero + self.state_norm = state_norm + self.res_connection_in_dynamics = res_connection_in_dynamics + + if env_name == 'gobigger': + from lzero.model.gobigger.gobigger_encoder import GoBiggerEncoder as Encoder + elif env_name == 'ptz_simple_spread': + from lzero.model.petting_zoo.encoder import PettingZooEncoder as Encoder + else: + raise NotImplementedError + self.representation_network = Encoder() + + self.dynamics_network = DynamicsNetwork( + action_encoding_dim=self.action_encoding_dim, + num_channels=self.latent_state_dim + self.action_encoding_dim, + common_layer_num=2, + fc_reward_layers=fc_reward_layers, + output_support_size=self.reward_support_size, + last_linear_layer_init_zero=self.last_linear_layer_init_zero, + norm_type=norm_type, + res_connection_in_dynamics=self.res_connection_in_dynamics, + ) + + self.prediction_network = PredictionNetworkMLP( + action_space_size=action_space_size, + num_channels=latent_state_dim, + fc_value_layers=fc_value_layers, + fc_policy_layers=fc_policy_layers, + output_support_size=self.value_support_size, + last_linear_layer_init_zero=self.last_linear_layer_init_zero, + norm_type=norm_type + ) + + if self.self_supervised_learning_loss: + # self_supervised_learning_loss related network proposed in EfficientZero + self.projection_input_dim = 
latent_state_dim
+
+            self.projection = nn.Sequential(
+                nn.Linear(self.projection_input_dim, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation,
+                nn.Linear(self.proj_hid, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation,
+                nn.Linear(self.proj_hid, self.proj_out), nn.BatchNorm1d(self.proj_out)
+            )
+            self.prediction_head = nn.Sequential(
+                nn.Linear(self.proj_out, self.pred_hid),
+                nn.BatchNorm1d(self.pred_hid),
+                activation,
+                nn.Linear(self.pred_hid, self.pred_out),
+            )
+
+    def initial_inference(self, obs: torch.Tensor) -> MZNetworkOutput:
+        """
+        Overview:
+            Initial inference of MuZero model, which is the first step of the MuZero model.
+            To perform the initial inference, we first use the representation network to obtain the "latent_state" of the observation.
+            Then we use the prediction network to predict the "value" and "policy_logits" of the "latent_state", and
+            also prepare the zeros-like ``reward`` for the next step of the MuZero model.
+        Arguments:
+            - obs (:obj:`torch.Tensor`): The 1D vector observation data.
+        Returns (MZNetworkOutput):
+            - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation.
+            - reward (:obj:`torch.Tensor`): The predicted reward of input state. In initial inference, we set it to a zero vector.
+            - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action.
+            - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state.
+        Shapes:
+            - obs (:obj:`torch.Tensor`): :math:`(B, obs_shape)`, where B is batch_size.
+            - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size.
+            - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size.
+            - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size.
+            - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state.
+        """
+        batch_size = obs['action_mask'].shape[0]
+        latent_state = self._representation(obs)
+        policy_logits, value = self._prediction(latent_state)
+        return MZNetworkOutput(
+            value,
+            [0.
for _ in range(batch_size)], + policy_logits, + latent_state, + ) diff --git a/lzero/model/petting_zoo/__init__.py b/lzero/model/petting_zoo/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lzero/model/petting_zoo/encoder.py b/lzero/model/petting_zoo/encoder.py new file mode 100644 index 000000000..015ca6066 --- /dev/null +++ b/lzero/model/petting_zoo/encoder.py @@ -0,0 +1,12 @@ +import torch.nn as nn + +class PettingZooEncoder(nn.Module): + + def __init__(self): + super().__init__() + self.encoder = nn.Identity() + + def forward(self, x): + x = x['agent_state'] + x = self.encoder(x) + return x \ No newline at end of file diff --git a/lzero/policy/efficientzero.py b/lzero/policy/efficientzero.py index 953abcc11..1071c3475 100644 --- a/lzero/policy/efficientzero.py +++ b/lzero/policy/efficientzero.py @@ -17,6 +17,9 @@ from lzero.policy import scalar_transform, InverseScalarTransform, cross_entropy_loss, phi_transform, \ DiscreteSupport, select_action, to_torch_float_tensor, ez_network_output_unpack, negative_cosine_similarity, prepare_obs, \ configure_optimizers +from collections import defaultdict +from ding.torch_utils import to_device, to_tensor +from ding.utils.data import default_collate @POLICY_REGISTRY.register('efficientzero') @@ -204,6 +207,8 @@ def default_model(self) -> Tuple[str, List[str]]: return 'EfficientZeroModel', ['lzero.model.efficientzero_model'] elif self._cfg.model.model_type == "mlp": return 'EfficientZeroModelMLP', ['lzero.model.efficientzero_model_mlp'] + elif self._cfg.model.model_type == "structure": + return 'EfficientZeroModelStructure', ['lzero.model.efficientzero_model_structure'] def _init_learn(self) -> None: """ @@ -302,7 +307,7 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: target_value_prefix = target_value_prefix.view(self._cfg.batch_size, -1) target_value = target_value.view(self._cfg.batch_size, -1) - assert obs_batch.size(0) == self._cfg.batch_size == target_value_prefix.size(0) + assert self._cfg.batch_size == target_value_prefix.size(0) # ``scalar_transform`` to transform the original value to the scaled value, # i.e. h(.) function in paper https://arxiv.org/pdf/1805.11593.pdf. 
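# For reference, a minimal sketch of the invertible value transform h(.) cited in the comment
# above (arXiv:1805.11593), with the commonly used epsilon = 0.001. The library's
# ``scalar_transform`` / ``InverseScalarTransform`` wrap this together with the categorical
# support handling and may differ in detail; the function names here are illustrative.
import torch

def h(x: torch.Tensor, eps: float = 0.001) -> torch.Tensor:
    # h(x) = sign(x) * (sqrt(|x| + 1) - 1) + eps * x
    return torch.sign(x) * (torch.sqrt(torch.abs(x) + 1) - 1) + eps * x

def h_inv(y: torch.Tensor, eps: float = 0.001) -> torch.Tensor:
    # exact inverse of h, applied to the expected value of the categorical output
    return torch.sign(y) * (((torch.sqrt(1 + 4 * eps * (torch.abs(y) + 1 + eps)) - 1) / (2 * eps)) ** 2 - 1)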
@@ -397,6 +402,15 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: beg_index = self._cfg.model.observation_shape * step_i end_index = self._cfg.model.observation_shape * (step_i + self._cfg.model.frame_stack_num) network_output = self._learn_model.initial_inference(obs_target_batch[:, beg_index:end_index]) + elif self._cfg.model.model_type == 'structure': + beg_index = step_i + end_index = step_i + self._cfg.model.frame_stack_num + obs_target_batch_tmp = obs_target_batch[:, beg_index:end_index].tolist() + obs_target_batch_tmp = sum(obs_target_batch_tmp, []) + obs_target_batch_tmp = to_tensor(obs_target_batch_tmp) + obs_target_batch_tmp = to_device(obs_target_batch_tmp, self._cfg.device) + obs_target_batch_tmp = default_collate(obs_target_batch_tmp) + network_output = self._learn_model.initial_inference(obs_target_batch_tmp) latent_state = to_tensor(latent_state) representation_state = to_tensor(network_output.latent_state) diff --git a/lzero/policy/gobigger_efficientzero.py b/lzero/policy/gobigger_efficientzero.py deleted file mode 100644 index ff54a45dd..000000000 --- a/lzero/policy/gobigger_efficientzero.py +++ /dev/null @@ -1,504 +0,0 @@ -from typing import List, Dict, Any, Tuple, Union - -import numpy as np -import torch -from .efficientzero import EfficientZeroPolicy -from ding.torch_utils import to_tensor -from ding.utils import POLICY_REGISTRY -from torch.distributions import Categorical -from torch.nn import L1Loss - -from lzero.mcts import EfficientZeroMCTSCtree as MCTSCtree -from lzero.mcts import EfficientZeroMCTSPtree as MCTSPtree -from lzero.policy import scalar_transform, InverseScalarTransform, cross_entropy_loss, phi_transform, \ - DiscreteSupport, select_action, to_torch_float_tensor, ez_network_output_unpack, negative_cosine_similarity, prepare_obs, \ - configure_optimizers -from collections import defaultdict -from ding.torch_utils import to_device - - -@POLICY_REGISTRY.register('gobigger_efficientzero') -class GoBiggerEfficientZeroPolicy(EfficientZeroPolicy): - """ - Overview: - The policy class for GoBiggerEfficientZero. - """ - - def default_model(self) -> Tuple[str, List[str]]: - """ - Overview: - Return this algorithm default model setting. - Returns: - - model_info (:obj:`Tuple[str, List[str]]`): model name and model import_names. - - model_type (:obj:`str`): The model type used in this algorithm, which is registered in ModelRegistry. - - import_names (:obj:`List[str]`): The model class path list used in this algorithm. - .. note:: - The user can define and use customized network model but must obey the same interface definition indicated \ - by import_names path. For EfficientZero, ``lzero.model.efficientzero_model.EfficientZeroModel`` - """ - return 'GoBiggerEfficientZeroModel', ['lzero.model.gobigger.gobigger_efficientzero_model'] - - def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: - """ - Overview: - The forward function for learning policy in learn mode, which is the core of the learning process. - The data is sampled from replay buffer. - The loss is calculated by the loss function and the loss is backpropagated to update the model. - Arguments: - - data (:obj:`Tuple[torch.Tensor]`): The data sampled from replay buffer, which is a tuple of tensors. - The first tensor is the current_batch, the second tensor is the target_batch. - Returns: - - info_dict (:obj:`Dict[str, Union[float, int]]`): The information dict to be logged, which contains \ - current learning loss and learning statistics. 
- """ - self._learn_model.train() - self._target_model.train() - - current_batch, target_batch = data - obs_batch_ori, action_batch, mask_batch, indices, weights, make_time = current_batch - target_value_prefix, target_value, target_policy = target_batch - - obs_batch_ori = obs_batch_ori.tolist() - obs_batch_ori = np.array(obs_batch_ori) - obs_batch = obs_batch_ori[:, 0:self._cfg.model.frame_stack_num] - if self._cfg.model.self_supervised_learning_loss: - obs_target_batch = obs_batch_ori[:, self._cfg.model.frame_stack_num:] - # obs_batch, obs_target_batch = obs_batch_ori.tolist() - - # # do augmentations - # if self._cfg.use_augmentation: - # obs_batch = self.image_transforms.transform(obs_batch) - # obs_target_batch = self.image_transforms.transform(obs_target_batch) - - # shape: (batch_size, num_unroll_steps, action_dim) - # NOTE: .long(), in discrete action space. - action_batch = torch.from_numpy(action_batch).to(self._cfg.device).unsqueeze(-1).long() - data_list = [ - mask_batch, - target_value_prefix.astype('float64'), - target_value.astype('float64'), target_policy, weights - ] - [mask_batch, target_value_prefix, target_value, target_policy, - weights] = to_torch_float_tensor(data_list, self._cfg.device) - - target_value_prefix = target_value_prefix.view(self._cfg.batch_size, -1) - target_value = target_value.view(self._cfg.batch_size, -1) - assert obs_batch.size == self._cfg.batch_size == target_value_prefix.size(0) - - # ``scalar_transform`` to transform the original value to the scaled value, - # i.e. h(.) function in paper https://arxiv.org/pdf/1805.11593.pdf. - transformed_target_value_prefix = scalar_transform(target_value_prefix) - transformed_target_value = scalar_transform(target_value) - # transform a scalar to its categorical_distribution. After this transformation, each scalar is - # represented as the linear combination of its two adjacent supports. - target_value_prefix_categorical = phi_transform(self.reward_support, transformed_target_value_prefix) - target_value_categorical = phi_transform(self.value_support, transformed_target_value) - - # ============================================================== - # the core initial_inference in EfficientZero policy. - # ============================================================== - obs_batch = obs_batch.tolist() - obs_batch = sum(obs_batch, []) - obs_batch = to_tensor(obs_batch) - obs_batch = to_device(obs_batch, self._cfg.device) - network_output = self._learn_model.initial_inference(obs_batch) - # value_prefix shape: (batch_size, 10), the ``value_prefix`` at the first step is zero padding. - latent_state, value_prefix, reward_hidden_state, value, policy_logits = ez_network_output_unpack(network_output) - - # transform the scaled value or its categorical representation to its original value, - # i.e. h^(-1)(.) function in paper https://arxiv.org/pdf/1805.11593.pdf. - original_value = self.inverse_scalar_transform_handle(value) - - # Note: The following lines are just for debugging. - predicted_value_prefixs = [] - if self._cfg.monitor_extra_statistics: - latent_state_list = latent_state.detach().cpu().numpy() - predicted_values, predicted_policies = original_value.detach().cpu(), torch.softmax( - policy_logits, dim=1 - ).detach().cpu() - - # calculate the new priorities for each transition. 
- value_priority = L1Loss(reduction='none')(original_value.squeeze(-1), target_value[:, 0]) - value_priority = value_priority.data.cpu().numpy() + 1e-6 - - prob = torch.softmax(policy_logits, dim=-1) - dist = Categorical(prob) - policy_entropy = dist.entropy().mean() - - # ============================================================== - # calculate policy and value loss for the first step. - # ============================================================== - policy_loss = cross_entropy_loss(policy_logits, target_policy[:, 0]) - - # Here we take the init hypothetical step k=0. - target_normalized_visit_count_init_step = target_policy[:, 0] - - # ******* NOTE: target_policy_entropy is only for debug. ****** - non_masked_indices = torch.nonzero(mask_batch[:, 0]).squeeze(-1) - # Check if there are any unmasked rows - if len(non_masked_indices) > 0: - target_normalized_visit_count_masked = torch.index_select( - target_normalized_visit_count_init_step, 0, non_masked_indices - ) - target_dist = Categorical(target_normalized_visit_count_masked) - target_policy_entropy = target_dist.entropy().mean() - else: - # Set target_policy_entropy to 0 if all rows are masked - target_policy_entropy = 0 - - value_loss = cross_entropy_loss(value, target_value_categorical[:, 0]) - - value_prefix_loss = torch.zeros(self._cfg.batch_size, device=self._cfg.device) - consistency_loss = torch.zeros(self._cfg.batch_size, device=self._cfg.device) - - # ============================================================== - # the core recurrent_inference in EfficientZero policy. - # ============================================================== - for step_i in range(self._cfg.num_unroll_steps): - # unroll with the dynamics function: predict the next ``latent_state``, ``reward_hidden_state``, - # `` value_prefix`` given current ``latent_state`` ``reward_hidden_state`` and ``action``. - # And then predict policy_logits and value with the prediction function. - network_output = self._learn_model.recurrent_inference( - latent_state, reward_hidden_state, action_batch[:, step_i] - ) - latent_state, value_prefix, reward_hidden_state, value, policy_logits = ez_network_output_unpack( - network_output - ) - - # transform the scaled value or its categorical representation to its original value, - # i.e. h^(-1)(.) function in paper https://arxiv.org/pdf/1805.11593.pdf. - original_value = self.inverse_scalar_transform_handle(value) - - # ============================================================== - # calculate consistency loss for the next ``num_unroll_steps`` unroll steps. - # ============================================================== - if self._cfg.ssl_loss_weight > 0: - beg_index = step_i - end_index = step_i + self._cfg.model.frame_stack_num - obs_target_batch_tmp = obs_target_batch[:, beg_index:end_index].tolist() - obs_target_batch_tmp = sum(obs_target_batch_tmp, []) - obs_target_batch_tmp = to_tensor(obs_target_batch_tmp) - obs_target_batch_tmp = to_device(obs_target_batch_tmp, self._cfg.device) - network_output = self._learn_model.initial_inference(obs_target_batch_tmp) - - latent_state = to_tensor(latent_state) - representation_state = to_tensor(network_output.latent_state) - - # NOTE: no grad for the representation_state branch. 
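# For context, one common definition of the ``negative_cosine_similarity`` consistency loss used
# just below (SimSiam-style); the stop-gradient is realized by projecting the observation branch
# with ``with_grad=False``. This is a sketch under that assumption, not the repo's exact helper.
import torch
import torch.nn.functional as F

def negative_cosine_similarity_sketch(p: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    # p: prediction-head output of the dynamics branch, shape (batch_size, D)
    # z: (detached) projection of the encoder branch, shape (batch_size, D)
    return -(F.normalize(p, dim=1) * F.normalize(z, dim=1)).sum(dim=1)  # shape: (batch_size, )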
- dynamic_proj = self._learn_model.project(latent_state, with_grad=True) - observation_proj = self._learn_model.project(representation_state, with_grad=False) - temp_loss = negative_cosine_similarity(dynamic_proj, observation_proj) * mask_batch[:, step_i] - - consistency_loss += temp_loss - - # NOTE: the target policy, target_value_categorical, target_value_prefix_categorical is calculated in - # game buffer now. - # ============================================================== - # calculate policy loss for the next ``num_unroll_steps`` unroll steps. - # NOTE: the +=. - # ============================================================== - policy_loss += cross_entropy_loss(policy_logits, target_policy[:, step_i + 1]) - - # Here we take the hypothetical step k = step_i + 1 - prob = torch.softmax(policy_logits, dim=-1) - dist = Categorical(prob) - policy_entropy += dist.entropy().mean() - target_normalized_visit_count = target_policy[:, step_i + 1] - - # ******* NOTE: target_policy_entropy is only for debug. ****** - non_masked_indices = torch.nonzero(mask_batch[:, step_i + 1]).squeeze(-1) - # Check if there are any unmasked rows - if len(non_masked_indices) > 0: - target_normalized_visit_count_masked = torch.index_select( - target_normalized_visit_count, 0, non_masked_indices - ) - target_dist = Categorical(target_normalized_visit_count_masked) - target_policy_entropy += target_dist.entropy().mean() - else: - # Set target_policy_entropy to 0 if all rows are masked - target_policy_entropy += 0 - - value_loss += cross_entropy_loss(value, target_value_categorical[:, step_i + 1]) - value_prefix_loss += cross_entropy_loss(value_prefix, target_value_prefix_categorical[:, step_i]) - - # reset hidden states every ``lstm_horizon_len`` unroll steps. - if (step_i + 1) % self._cfg.lstm_horizon_len == 0: - reward_hidden_state = ( - torch.zeros(1, self._cfg.batch_size, self._cfg.model.lstm_hidden_size).to(self._cfg.device), - torch.zeros(1, self._cfg.batch_size, self._cfg.model.lstm_hidden_size).to(self._cfg.device) - ) - - if self._cfg.monitor_extra_statistics: - original_value_prefixs = self.inverse_scalar_transform_handle(value_prefix) - original_value_prefixs_cpu = original_value_prefixs.detach().cpu() - - predicted_values = torch.cat( - (predicted_values, self.inverse_scalar_transform_handle(value).detach().cpu()) - ) - predicted_value_prefixs.append(original_value_prefixs_cpu) - predicted_policies = torch.cat((predicted_policies, torch.softmax(policy_logits, dim=1).detach().cpu())) - latent_state_list = np.concatenate((latent_state_list, latent_state.detach().cpu().numpy())) - - # ============================================================== - # the core learn model update step. - # ============================================================== - # weighted loss with masks (some invalid states which are out of trajectory.) - loss = ( - self._cfg.ssl_loss_weight * consistency_loss + self._cfg.policy_loss_weight * policy_loss + - self._cfg.value_loss_weight * value_loss + self._cfg.reward_loss_weight * value_prefix_loss - ) - weighted_total_loss = (weights * loss).mean() - # TODO(pu): test the effect of gradient scale. 
- gradient_scale = 1 / self._cfg.num_unroll_steps - weighted_total_loss.register_hook(lambda grad: grad * gradient_scale) - self._optimizer.zero_grad() - weighted_total_loss.backward() - total_grad_norm_before_clip = torch.nn.utils.clip_grad_norm_( - self._learn_model.parameters(), self._cfg.grad_clip_value - ) - self._optimizer.step() - if self._cfg.lr_piecewise_constant_decay: - self.lr_scheduler.step() - - # ============================================================== - # the core target model update step. - # ============================================================== - self._target_model.update(self._learn_model.state_dict()) - - # packing loss info for tensorboard logging - loss_info = ( - weighted_total_loss.item(), loss.mean().item(), policy_loss.mean().item(), value_prefix_loss.mean().item(), - value_loss.mean().item(), consistency_loss.mean() - ) - - if self._cfg.monitor_extra_statistics: - predicted_value_prefixs = torch.stack(predicted_value_prefixs).transpose(1, 0).squeeze(-1) - predicted_value_prefixs = predicted_value_prefixs.reshape(-1).unsqueeze(-1) - - return { - 'collect_mcts_temperature': self.collect_mcts_temperature, - 'collect_epsilon': self.collect_epsilon, - 'cur_lr': self._optimizer.param_groups[0]['lr'], - 'weighted_total_loss': weighted_total_loss.item(), - 'total_loss': loss.mean().item(), - 'policy_loss': policy_loss.mean().item(), - 'policy_entropy': policy_entropy.item() / (self._cfg.num_unroll_steps + 1), - 'target_policy_entropy': target_policy_entropy.item() / (self._cfg.num_unroll_steps + 1), - 'value_prefix_loss': value_prefix_loss.mean().item(), - 'value_loss': value_loss.mean().item(), - 'consistency_loss': consistency_loss.mean() / self._cfg.num_unroll_steps, - - # ============================================================== - # priority related - # ============================================================== - 'value_priority': value_priority.mean().item(), - 'value_priority_orig': value_priority, - 'target_value_prefix': target_value_prefix.detach().cpu().numpy().mean().item(), - 'target_value': target_value.detach().cpu().numpy().mean().item(), - 'transformed_target_value_prefix': transformed_target_value_prefix.detach().cpu().numpy().mean().item(), - 'transformed_target_value': transformed_target_value.detach().cpu().numpy().mean().item(), - 'predicted_value_prefixs': predicted_value_prefixs.detach().cpu().numpy().mean().item(), - 'predicted_values': predicted_values.detach().cpu().numpy().mean().item(), - 'total_grad_norm_before_clip': total_grad_norm_before_clip - } - - def _forward_collect( - self, - data: torch.Tensor, - action_mask: list = None, - temperature: float = 1, - to_play: List = [-1], - epsilon: float = 0.25, - ready_env_id = None - ): - """ - Overview: - The forward function for collecting data in collect mode. Use model to execute MCTS search. - Choosing the action through sampling during the collect mode. - Arguments: - - data (:obj:`torch.Tensor`): The input data, i.e. the observation. - - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. - - temperature (:obj:`float`): The temperature of the policy. - - to_play (:obj:`int`): The player to play. - - ready_env_id (:obj:`list`): The id of the env that is ready to collect. - Shape: - - data (:obj:`torch.Tensor`): - - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ - S is the number of stacked frames, H is the height of the image, W is the width of the image. 
- - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size. - - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env. - - temperature: :math:`(1, )`. - - to_play: :math:`(N, 1)`, where N is the number of collect_env. - - ready_env_id: None - Returns: - - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ - ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. - """ - self._collect_model.eval() - self.collect_mcts_temperature = temperature - self.collect_epsilon = epsilon - - active_collect_env_num = len(data) - data = to_tensor(data) - data = sum(sum(data, []), []) - batch_size = len(data) - data = to_device(data, self._cfg.device) - agent_num = batch_size // active_collect_env_num - to_play = np.array(to_play).reshape(-1).tolist() - - with torch.no_grad(): - # data shape [B, S x C, W, H], e.g. {Tensor:(B, 12, 96, 96)} - network_output = self._collect_model.initial_inference(data) - latent_state_roots, value_prefix_roots, reward_hidden_state_roots, pred_values, policy_logits = ez_network_output_unpack( - network_output - ) - - pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() - latent_state_roots = latent_state_roots.detach().cpu().numpy() - reward_hidden_state_roots = ( - reward_hidden_state_roots[0].detach().cpu().numpy(), - reward_hidden_state_roots[1].detach().cpu().numpy() - ) - policy_logits = policy_logits.detach().cpu().numpy().tolist() - - action_mask = sum(action_mask, []) - legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(batch_size)] - # the only difference between collect and eval is the dirichlet noise. - noises = [ - np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j])) - ).astype(np.float32).tolist() for j in range(batch_size) - ] - if self._cfg.mcts_ctree: - # cpp mcts_tree - roots = MCTSCtree.roots(batch_size, legal_actions) - else: - # python mcts_tree - roots = MCTSPtree.roots(batch_size, legal_actions) - roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_roots, policy_logits, to_play) - self._mcts_collect.search( - roots, self._collect_model, latent_state_roots, reward_hidden_state_roots, to_play - ) - - roots_visit_count_distributions = roots.get_distributions( - ) # shape: ``{list: batch_size} ->{list: action_space_size}`` - roots_values = roots.get_values() # shape: {list: batch_size} - - data_id = [i for i in range(active_collect_env_num)] - output = {i: defaultdict(list) for i in data_id} - if ready_env_id is None: - ready_env_id = np.arange(active_collect_env_num) - - for i in range(batch_size): - distributions, value = roots_visit_count_distributions[i], roots_values[i] - if self._cfg.eps.eps_greedy_exploration_in_collect: - # eps-greedy collect - action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( - distributions, temperature=self.collect_mcts_temperature, deterministic=True - ) - action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] - if np.random.rand() < self.collect_epsilon: - action = np.random.choice(legal_actions[i]) - else: - # normal collect - # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents - # the index within the legal action set, rather than the index in the entire action set. 
- action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( - distributions, temperature=self.collect_mcts_temperature, deterministic=False - ) - # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the entire action set. - action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] - output[i // agent_num]['action'].append(action) - output[i // agent_num]['distributions'].append(distributions) - output[i // agent_num]['visit_count_distribution_entropy'].append(visit_count_distribution_entropy) - output[i // agent_num]['value'].append(value) - output[i // agent_num]['pred_value'].append(pred_values[i]) - output[i // agent_num]['policy_logits'].append(policy_logits[i]) - - return output - - def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, ready_env_id=None): - """ - Overview: - The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search. - Choosing the action with the highest value (argmax) rather than sampling during the eval mode. - Arguments: - - data (:obj:`torch.Tensor`): The input data, i.e. the observation. - - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. - - to_play (:obj:`int`): The player to play. - - ready_env_id (:obj:`list`): The id of the env that is ready to collect. - Shape: - - data (:obj:`torch.Tensor`): - - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ - S is the number of stacked frames, H is the height of the image, W is the width of the image. - - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size. - - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env. - - to_play: :math:`(N, 1)`, where N is the number of collect_env. - - ready_env_id: None - Returns: - - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ - ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. - """ - self._eval_model.eval() - active_eval_env_num = len(data) - data = to_tensor(data) - data = sum(sum(data, []), []) - batch_size = len(data) - data = to_device(data, self._cfg.device) - agent_num = batch_size // active_eval_env_num - to_play = np.array(to_play).reshape(-1).tolist() - - with torch.no_grad(): - # data shape [B, S x C, W, H], e.g. 
{Tensor:(B, 12, 96, 96)} - network_output = self._eval_model.initial_inference(data) - latent_state_roots, value_prefix_roots, reward_hidden_state_roots, pred_values, policy_logits = ez_network_output_unpack( - network_output - ) - - if not self._eval_model.training: - # if not in training, obtain the scalars of the value/reward - pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape(B, 1) - latent_state_roots = latent_state_roots.detach().cpu().numpy() - reward_hidden_state_roots = ( - reward_hidden_state_roots[0].detach().cpu().numpy(), - reward_hidden_state_roots[1].detach().cpu().numpy() - ) - policy_logits = policy_logits.detach().cpu().numpy().tolist() # list shape(B, A) - - action_mask = sum(action_mask, []) - legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(batch_size)] - if self._cfg.mcts_ctree: - # cpp mcts_tree - roots = MCTSCtree.roots(batch_size, legal_actions) - else: - # python mcts_tree - roots = MCTSPtree.roots(batch_size, legal_actions) - roots.prepare_no_noise(value_prefix_roots, policy_logits, to_play) - self._mcts_eval.search(roots, self._eval_model, latent_state_roots, reward_hidden_state_roots, to_play) - - roots_visit_count_distributions = roots.get_distributions( - ) # shape: ``{list: batch_size} ->{list: action_space_size}`` - roots_values = roots.get_values() # shape: {list: batch_size} - data_id = [i for i in range(active_eval_env_num)] - output = {i: defaultdict(list) for i in data_id} - - if ready_env_id is None: - ready_env_id = np.arange(active_eval_env_num) - - for i in range(batch_size): - distributions, value = roots_visit_count_distributions[i], roots_values[i] - # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents - # the index within the legal action set, rather than the index in the entire action set. - # Setting deterministic=True implies choosing the action with the highest value (argmax) rather than sampling during the evaluation phase. - action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( - distributions, temperature=1, deterministic=True - ) - # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the entire action set. 
- action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] - output[i // agent_num]['action'].append(action) - output[i // agent_num]['distributions'].append(distributions) - output[i // agent_num]['visit_count_distribution_entropy'].append(visit_count_distribution_entropy) - output[i // agent_num]['value'].append(value) - output[i // agent_num]['pred_value'].append(pred_values[i]) - output[i // agent_num]['policy_logits'].append(policy_logits[i]) - - return output \ No newline at end of file diff --git a/lzero/policy/gobigger_muzero.py b/lzero/policy/gobigger_muzero.py deleted file mode 100644 index 3ae716c91..000000000 --- a/lzero/policy/gobigger_muzero.py +++ /dev/null @@ -1,445 +0,0 @@ -from typing import List, Dict, Any, Tuple, Union - -import numpy as np -import torch -from .muzero import MuZeroPolicy -from ding.torch_utils import to_tensor -from ding.utils import POLICY_REGISTRY -from torch.nn import L1Loss - -from lzero.mcts import MuZeroMCTSCtree as MCTSCtree -from lzero.mcts import MuZeroMCTSPtree as MCTSPtree -from lzero.policy import scalar_transform, InverseScalarTransform, cross_entropy_loss, phi_transform, \ - DiscreteSupport, to_torch_float_tensor, mz_network_output_unpack, select_action, negative_cosine_similarity, prepare_obs, \ - configure_optimizers -from collections import defaultdict -from ding.torch_utils import to_device - - -@POLICY_REGISTRY.register('gobigger_muzero') -class GoBiggerMuZeroPolicy(MuZeroPolicy): - """ - Overview: - The policy class for GoBiggerMuZero. - """ - - def default_model(self) -> Tuple[str, List[str]]: - """ - Overview: - Return this algorithm default model setting. - Returns: - - model_info (:obj:`Tuple[str, List[str]]`): model name and model import_names. - - model_type (:obj:`str`): The model type used in this algorithm, which is registered in ModelRegistry. - - import_names (:obj:`List[str]`): The model class path list used in this algorithm. - .. note:: - The user can define and use customized network model but must obey the same interface definition indicated \ - by import_names path. For MuZero, ``lzero.model.muzero_model.MuZeroModel`` - """ - return 'GoBiggerMuZeroModel', ['lzero.model.gobigger.gobigger_muzero_model'] - - def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]: - """ - Overview: - The forward function for learning policy in learn mode, which is the core of the learning process. - The data is sampled from replay buffer. - The loss is calculated by the loss function and the loss is backpropagated to update the model. - Arguments: - - data (:obj:`Tuple[torch.Tensor]`): The data sampled from replay buffer, which is a tuple of tensors. - The first tensor is the current_batch, the second tensor is the target_batch. - Returns: - - info_dict (:obj:`Dict[str, Union[float, int]]`): The information dict to be logged, which contains \ - current learning loss and learning statistics. 
- """ - self._learn_model.train() - self._target_model.train() - - current_batch, target_batch = data - obs_batch_ori, action_batch, mask_batch, indices, weights, make_time = current_batch - target_reward, target_value, target_policy = target_batch - - obs_batch_ori = obs_batch_ori.tolist() - obs_batch_ori = np.array(obs_batch_ori) - obs_batch = obs_batch_ori[:, 0:self._cfg.model.frame_stack_num] - if self._cfg.model.self_supervised_learning_loss: - obs_target_batch = obs_batch_ori[:, self._cfg.model.frame_stack_num:] - # obs_batch, obs_target_batch = prepare_obs(obs_batch_ori, self._cfg) - - # # do augmentations - # if self._cfg.use_augmentation: - # obs_batch = self.image_transforms.transform(obs_batch) - # if self._cfg.model.self_supervised_learning_loss: - # obs_target_batch = self.image_transforms.transform(obs_target_batch) - - # shape: (batch_size, num_unroll_steps, action_dim) - # NOTE: .long(), in discrete action space. - action_batch = torch.from_numpy(action_batch).to(self._cfg.device).unsqueeze(-1).long() - data_list = [ - mask_batch, - target_reward.astype('float32'), - target_value.astype('float32'), target_policy, weights - ] - [mask_batch, target_reward, target_value, target_policy, - weights] = to_torch_float_tensor(data_list, self._cfg.device) - - target_reward = target_reward.view(self._cfg.batch_size, -1) - target_value = target_value.view(self._cfg.batch_size, -1) - - assert obs_batch.size == self._cfg.batch_size == target_reward.size(0) - - # ``scalar_transform`` to transform the original value to the scaled value, - # i.e. h(.) function in paper https://arxiv.org/pdf/1805.11593.pdf. - transformed_target_reward = scalar_transform(target_reward) - transformed_target_value = scalar_transform(target_value) - - # transform a scalar to its categorical_distribution. After this transformation, each scalar is - # represented as the linear combination of its two adjacent supports. - target_reward_categorical = phi_transform(self.reward_support, transformed_target_reward) - target_value_categorical = phi_transform(self.value_support, transformed_target_value) - - # ============================================================== - # the core initial_inference in MuZero policy. - # ============================================================== - obs_batch = obs_batch.tolist() - obs_batch = sum(obs_batch, []) - obs_batch = to_tensor(obs_batch) - obs_batch = to_device(obs_batch, self._cfg.device) - network_output = self._learn_model.initial_inference(obs_batch) - - # value_prefix shape: (batch_size, 10), the ``value_prefix`` at the first step is zero padding. - latent_state, reward, value, policy_logits = mz_network_output_unpack(network_output) - - # transform the scaled value or its categorical representation to its original value, - # i.e. h^(-1)(.) function in paper https://arxiv.org/pdf/1805.11593.pdf. - original_value = self.inverse_scalar_transform_handle(value) - - # Note: The following lines are just for debugging. - predicted_rewards = [] - if self._cfg.monitor_extra_statistics: - latent_state_list = latent_state.detach().cpu().numpy() - predicted_values, predicted_policies = original_value.detach().cpu(), torch.softmax( - policy_logits, dim=1 - ).detach().cpu() - - # calculate the new priorities for each transition. 
- value_priority = L1Loss(reduction='none')(original_value.squeeze(-1), target_value[:, 0]) - value_priority = value_priority.data.cpu().numpy() + 1e-6 - - # ============================================================== - # calculate policy and value loss for the first step. - # ============================================================== - policy_loss = cross_entropy_loss(policy_logits, target_policy[:, 0]) - value_loss = cross_entropy_loss(value, target_value_categorical[:, 0]) - - reward_loss = torch.zeros(self._cfg.batch_size, device=self._cfg.device) - consistency_loss = torch.zeros(self._cfg.batch_size, device=self._cfg.device) - - gradient_scale = 1 / self._cfg.num_unroll_steps - - # ============================================================== - # the core recurrent_inference in MuZero policy. - # ============================================================== - for step_i in range(self._cfg.num_unroll_steps): - # unroll with the dynamics function: predict the next ``latent_state``, ``reward``, - # given current ``latent_state`` and ``action``. - # And then predict policy_logits and value with the prediction function. - network_output = self._learn_model.recurrent_inference(latent_state, action_batch[:, step_i]) - latent_state, reward, value, policy_logits = mz_network_output_unpack(network_output) - - # transform the scaled value or its categorical representation to its original value, - # i.e. h^(-1)(.) function in paper https://arxiv.org/pdf/1805.11593.pdf. - original_value = self.inverse_scalar_transform_handle(value) - - if self._cfg.model.self_supervised_learning_loss: - # ============================================================== - # calculate consistency loss for the next ``num_unroll_steps`` unroll steps. - # ============================================================== - if self._cfg.ssl_loss_weight > 0: - beg_index = step_i - end_index = step_i + self._cfg.model.frame_stack_num - obs_target_batch_tmp = obs_target_batch[:, beg_index:end_index].tolist() - obs_target_batch_tmp = sum(obs_target_batch_tmp, []) - obs_target_batch_tmp = to_tensor(obs_target_batch_tmp) - obs_target_batch_tmp = to_device(obs_target_batch_tmp, self._cfg.device) - network_output = self._learn_model.initial_inference(obs_target_batch_tmp) - - latent_state = to_tensor(latent_state) - representation_state = to_tensor(network_output.latent_state) - - # NOTE: no grad for the representation_state branch - dynamic_proj = self._learn_model.project(latent_state, with_grad=True) - observation_proj = self._learn_model.project(representation_state, with_grad=False) - temp_loss = negative_cosine_similarity(dynamic_proj, observation_proj) * mask_batch[:, step_i] - consistency_loss += temp_loss - - # NOTE: the target policy, target_value_categorical, target_reward_categorical is calculated in - # game buffer now. - # ============================================================== - # calculate policy loss for the next ``num_unroll_steps`` unroll steps. - # NOTE: the +=. 
- # ============================================================== - policy_loss += cross_entropy_loss(policy_logits, target_policy[:, step_i + 1]) - - value_loss += cross_entropy_loss(value, target_value_categorical[:, step_i + 1]) - reward_loss += cross_entropy_loss(reward, target_reward_categorical[:, step_i]) - - # Follow MuZero, set half gradient - # latent_state.register_hook(lambda grad: grad * 0.5) - - if self._cfg.monitor_extra_statistics: - original_rewards = self.inverse_scalar_transform_handle(reward) - original_rewards_cpu = original_rewards.detach().cpu() - - predicted_values = torch.cat( - (predicted_values, self.inverse_scalar_transform_handle(value).detach().cpu()) - ) - predicted_rewards.append(original_rewards_cpu) - predicted_policies = torch.cat((predicted_policies, torch.softmax(policy_logits, dim=1).detach().cpu())) - latent_state_list = np.concatenate((latent_state_list, latent_state.detach().cpu().numpy())) - - # ============================================================== - # the core learn model update step. - # ============================================================== - # weighted loss with masks (some invalid states which are out of trajectory.) - loss = ( - self._cfg.ssl_loss_weight * consistency_loss + self._cfg.policy_loss_weight * policy_loss + - self._cfg.value_loss_weight * value_loss + self._cfg.reward_loss_weight * reward_loss - ) - weighted_total_loss = (weights * loss).mean() - - gradient_scale = 1 / self._cfg.num_unroll_steps - weighted_total_loss.register_hook(lambda grad: grad * gradient_scale) - self._optimizer.zero_grad() - weighted_total_loss.backward() - total_grad_norm_before_clip = torch.nn.utils.clip_grad_norm_( - self._learn_model.parameters(), self._cfg.grad_clip_value - ) - self._optimizer.step() - if self._cfg.lr_piecewise_constant_decay: - self.lr_scheduler.step() - - # ============================================================== - # the core target model update step. 
- # ============================================================== - self._target_model.update(self._learn_model.state_dict()) - - if self._cfg.monitor_extra_statistics: - predicted_rewards = torch.stack(predicted_rewards).transpose(1, 0).squeeze(-1) - predicted_rewards = predicted_rewards.reshape(-1).unsqueeze(-1) - - return { - 'collect_mcts_temperature': self.collect_mcts_temperature, - 'collect_epsilon': self.collect_epsilon, - 'cur_lr': self._optimizer.param_groups[0]['lr'], - 'weighted_total_loss': weighted_total_loss.item(), - 'total_loss': loss.mean().item(), - 'policy_loss': policy_loss.mean().item(), - 'reward_loss': reward_loss.mean().item(), - 'value_loss': value_loss.mean().item(), - 'consistency_loss': consistency_loss.mean() / self._cfg.num_unroll_steps, - - # ============================================================== - # priority related - # ============================================================== - 'value_priority_orig': value_priority, - 'value_priority': value_priority.mean().item(), - 'target_reward': target_reward.detach().cpu().numpy().mean().item(), - 'target_value': target_value.detach().cpu().numpy().mean().item(), - 'transformed_target_reward': transformed_target_reward.detach().cpu().numpy().mean().item(), - 'transformed_target_value': transformed_target_value.detach().cpu().numpy().mean().item(), - 'predicted_rewards': predicted_rewards.detach().cpu().numpy().mean().item(), - 'predicted_values': predicted_values.detach().cpu().numpy().mean().item(), - 'total_grad_norm_before_clip': total_grad_norm_before_clip - } - - def _forward_collect( - self, - data: torch.Tensor, - action_mask: list = None, - temperature: float = 1, - to_play: List = [-1], - epsilon: float = 0.25, - ready_env_id = None - ) -> Dict: - """ - Overview: - The forward function for collecting data in collect mode. Use model to execute MCTS search. - Choosing the action through sampling during the collect mode. - Arguments: - - data (:obj:`torch.Tensor`): The input data, i.e. the observation. - - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. - - temperature (:obj:`float`): The temperature of the policy. - - to_play (:obj:`int`): The player to play. - - ready_env_id (:obj:`list`): The id of the env that is ready to collect. - Shape: - - data (:obj:`torch.Tensor`): - - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ - S is the number of stacked frames, H is the height of the image, W is the width of the image. - - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size. - - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env. - - temperature: :math:`(1, )`. - - to_play: :math:`(N, 1)`, where N is the number of collect_env. - - ready_env_id: None - Returns: - - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ - ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. - """ - self._collect_model.eval() - self.collect_mcts_temperature = temperature - self.collect_epsilon = epsilon - active_collect_env_num = len(data) - data = to_tensor(data) - data = sum(sum(data, []), []) - batch_size = len(data) - data = to_device(data, self._cfg.device) - agent_num = batch_size // active_collect_env_num - to_play = np.array(to_play).reshape(-1).tolist() - - with torch.no_grad(): - # data shape [B, S x C, W, H], e.g. 
{Tensor:(B, 12, 96, 96)} - network_output = self._collect_model.initial_inference(data) - latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) - - if not self._learn_model.training: - # if not in training, obtain the scalars of the value/reward - pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() - latent_state_roots = latent_state_roots.detach().cpu().numpy() - policy_logits = policy_logits.detach().cpu().numpy().tolist() - - action_mask = sum(action_mask, []) - legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(batch_size)] - # the only difference between collect and eval is the dirichlet noise - noises = [ - np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j])) - ).astype(np.float32).tolist() for j in range(batch_size) - ] - if self._cfg.mcts_ctree: - # cpp mcts_tree - roots = MCTSCtree.roots(batch_size, legal_actions) - else: - # python mcts_tree - roots = MCTSPtree.roots(batch_size, legal_actions) - - roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits, to_play) - self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play) - - roots_visit_count_distributions = roots.get_distributions( - ) # shape: ``{list: batch_size} ->{list: action_space_size}`` - roots_values = roots.get_values() # shape: {list: batch_size} - - data_id = [i for i in range(active_collect_env_num)] - output = {i: defaultdict(list) for i in data_id} - - if ready_env_id is None: - ready_env_id = np.arange(active_collect_env_num) - - for i in range(batch_size): - distributions, value = roots_visit_count_distributions[i], roots_values[i] - if self._cfg.eps.eps_greedy_exploration_in_collect: - # eps greedy collect - action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( - distributions, temperature=self.collect_mcts_temperature, deterministic=True - ) - action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] - if np.random.rand() < self.collect_epsilon: - action = np.random.choice(legal_actions[i]) - else: - # normal collect - # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents - # the index within the legal action set, rather than the index in the entire action set. - action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( - distributions, temperature=self.collect_mcts_temperature, deterministic=False - ) - # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the entire action set. - action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] - output[i // agent_num]['action'].append(action) - output[i // agent_num]['distributions'].append(distributions) - output[i // agent_num]['visit_count_distribution_entropy'].append(visit_count_distribution_entropy) - output[i // agent_num]['value'].append(value) - output[i // agent_num]['pred_value'].append(pred_values[i]) - output[i // agent_num]['policy_logits'].append(policy_logits[i]) - - return output - - def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, ready_env_id=None) -> Dict: - """ - Overview: - The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search. - Choosing the action with the highest value (argmax) rather than sampling during the eval mode. - Arguments: - - data (:obj:`torch.Tensor`): The input data, i.e. 
the observation. - - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. - - to_play (:obj:`int`): The player to play. - - ready_env_id (:obj:`list`): The id of the env that is ready to collect. - Shape: - - data (:obj:`torch.Tensor`): - - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ - S is the number of stacked frames, H is the height of the image, W is the width of the image. - - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size. - - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env. - - to_play: :math:`(N, 1)`, where N is the number of collect_env. - - ready_env_id: None - Returns: - - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ - ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. - """ - self._eval_model.eval() - active_eval_env_num = len(data) - data = to_tensor(data) - data = sum(sum(data, []), []) - batch_size = len(data) - data = to_device(data, self._cfg.device) - agent_num = batch_size // active_eval_env_num - to_play = np.array(to_play).reshape(-1).tolist() - - with torch.no_grad(): - # data shape [B, S x C, W, H], e.g. {Tensor:(B, 12, 96, 96)} - network_output = self._collect_model.initial_inference(data) - latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) - - if not self._eval_model.training: - # if not in training, obtain the scalars of the value/reward - pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape(B, 1) - latent_state_roots = latent_state_roots.detach().cpu().numpy() - policy_logits = policy_logits.detach().cpu().numpy().tolist() # list shape(B, A) - - action_mask = sum(action_mask, []) - legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(batch_size)] - if self._cfg.mcts_ctree: - # cpp mcts_tree - roots = MCTSCtree.roots(batch_size, legal_actions) - else: - # python mcts_tree - roots = MCTSPtree.roots(batch_size, legal_actions) - roots.prepare_no_noise(reward_roots, policy_logits, to_play) - self._mcts_eval.search(roots, self._eval_model, latent_state_roots, to_play) - - roots_visit_count_distributions = roots.get_distributions( - ) # shape: ``{list: batch_size} ->{list: action_space_size}`` - roots_values = roots.get_values() # shape: {list: batch_size} - - data_id = [i for i in range(active_eval_env_num)] - output = {i: defaultdict(list) for i in data_id} - - if ready_env_id is None: - ready_env_id = np.arange(active_eval_env_num) - - for i in range(batch_size): - distributions, value = roots_visit_count_distributions[i], roots_values[i] - # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents - # the index within the legal action set, rather than the index in the entire action set. - # Setting deterministic=True implies choosing the action with the highest value (argmax) rather than - # sampling during the evaluation phase. - action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( - distributions, temperature=1, deterministic=True - ) - # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the - # entire action set. 
- action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] - output[i // agent_num]['action'].append(action) - output[i // agent_num]['distributions'].append(distributions) - output[i // agent_num]['visit_count_distribution_entropy'].append(visit_count_distribution_entropy) - output[i // agent_num]['value'].append(value) - output[i // agent_num]['pred_value'].append(pred_values[i]) - output[i // agent_num]['policy_logits'].append(policy_logits[i]) - - return output diff --git a/lzero/policy/multi_agent_efficientzero.py b/lzero/policy/multi_agent_efficientzero.py new file mode 100644 index 000000000..d01c66f58 --- /dev/null +++ b/lzero/policy/multi_agent_efficientzero.py @@ -0,0 +1,229 @@ +from typing import List, Dict, Any, Tuple, Union + +import numpy as np +import torch +from .efficientzero import EfficientZeroPolicy +from ding.torch_utils import to_tensor +from ding.utils import POLICY_REGISTRY +from torch.distributions import Categorical +from torch.nn import L1Loss + +from lzero.mcts import EfficientZeroMCTSCtree as MCTSCtree +from lzero.mcts import EfficientZeroMCTSPtree as MCTSPtree +from lzero.policy import scalar_transform, InverseScalarTransform, cross_entropy_loss, phi_transform, \ + DiscreteSupport, select_action, to_torch_float_tensor, ez_network_output_unpack, negative_cosine_similarity, prepare_obs, \ + configure_optimizers +from collections import defaultdict +from ding.torch_utils import to_device, to_tensor +from ding.utils.data import default_collate + + +@POLICY_REGISTRY.register('multi_agent_efficientzero') +class MultiAgentEfficientZeroPolicy(EfficientZeroPolicy): + """ + Overview: + The policy class for Multi Agent EfficientZero. + """ + + def _forward_collect( + self, + data: torch.Tensor, + action_mask: list = None, + temperature: float = 1, + to_play: List = [-1], + epsilon: float = 0.25, + ready_env_id = None + ): + """ + Overview: + The forward function for collecting data in collect mode. Use model to execute MCTS search. + Choosing the action through sampling during the collect mode. + Arguments: + - data (:obj:`torch.Tensor`): The input data, i.e. the observation. + - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. + - temperature (:obj:`float`): The temperature of the policy. + - to_play (:obj:`int`): The player to play. + - ready_env_id (:obj:`list`): The id of the env that is ready to collect. + Shape: + - data (:obj:`torch.Tensor`): + - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ + S is the number of stacked frames, H is the height of the image, W is the width of the image. + - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size. + - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env. + - temperature: :math:`(1, )`. + - to_play: :math:`(N, 1)`, where N is the number of collect_env. + - ready_env_id: None + Returns: + - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ + ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. 
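+ The dict is keyed by env id; each field is a list with one entry per agent of that env, \
+ i.e. ``agent_num = batch_size // active_collect_env_num`` entries per key.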
+ """ + self._collect_model.eval() + self.collect_mcts_temperature = temperature + self.collect_epsilon = epsilon + + active_collect_env_num = len(data) + data = to_tensor(data) + data = sum(sum(data, []), []) + batch_size = len(data) + data = to_device(data, self._cfg.device) + data = default_collate(data) + agent_num = batch_size // active_collect_env_num + to_play = np.array(to_play).reshape(-1).tolist() + + with torch.no_grad(): + # data shape [B, S x C, W, H], e.g. {Tensor:(B, 12, 96, 96)} + network_output = self._collect_model.initial_inference(data) + latent_state_roots, value_prefix_roots, reward_hidden_state_roots, pred_values, policy_logits = ez_network_output_unpack( + network_output + ) + + pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() + latent_state_roots = latent_state_roots.detach().cpu().numpy() + reward_hidden_state_roots = ( + reward_hidden_state_roots[0].detach().cpu().numpy(), + reward_hidden_state_roots[1].detach().cpu().numpy() + ) + policy_logits = policy_logits.detach().cpu().numpy().tolist() + + action_mask = sum(action_mask, []) + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(batch_size)] + # the only difference between collect and eval is the dirichlet noise. + noises = [ + np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j])) + ).astype(np.float32).tolist() for j in range(batch_size) + ] + if self._cfg.mcts_ctree: + # cpp mcts_tree + roots = MCTSCtree.roots(batch_size, legal_actions) + else: + # python mcts_tree + roots = MCTSPtree.roots(batch_size, legal_actions) + roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_roots, policy_logits, to_play) + self._mcts_collect.search( + roots, self._collect_model, latent_state_roots, reward_hidden_state_roots, to_play + ) + + roots_visit_count_distributions = roots.get_distributions( + ) # shape: ``{list: batch_size} ->{list: action_space_size}`` + roots_values = roots.get_values() # shape: {list: batch_size} + + data_id = [i for i in range(active_collect_env_num)] + output = {i: defaultdict(list) for i in data_id} + if ready_env_id is None: + ready_env_id = np.arange(active_collect_env_num) + + for i in range(batch_size): + distributions, value = roots_visit_count_distributions[i], roots_values[i] + if self._cfg.eps.eps_greedy_exploration_in_collect: + # eps-greedy collect + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=self.collect_mcts_temperature, deterministic=True + ) + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + if np.random.rand() < self.collect_epsilon: + action = np.random.choice(legal_actions[i]) + else: + # normal collect + # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents + # the index within the legal action set, rather than the index in the entire action set. + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=self.collect_mcts_temperature, deterministic=False + ) + # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the entire action set. 
+ action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set]
+ output[i // agent_num]['action'].append(action)
+ output[i // agent_num]['distributions'].append(distributions)
+ output[i // agent_num]['visit_count_distribution_entropy'].append(visit_count_distribution_entropy)
+ output[i // agent_num]['value'].append(value)
+ output[i // agent_num]['pred_value'].append(pred_values[i])
+ output[i // agent_num]['policy_logits'].append(policy_logits[i])
+
+ return output
+
+ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, ready_env_id=None):
+ """
+ Overview:
+ The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search.
+ Choosing the action with the highest value (argmax) rather than sampling during the eval mode.
+ Arguments:
+ - data (:obj:`torch.Tensor`): The input data, i.e. the observation.
+ - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected.
+ - to_play (:obj:`int`): The player to play.
+ - ready_env_id (:obj:`list`): The id of the env that is ready to collect.
+ Shape:
+ - data (:obj:`torch.Tensor`):
+ - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \
+ S is the number of stacked frames, H is the height of the image, W is the width of the image.
+ - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size.
+ - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env.
+ - to_play: :math:`(N, 1)`, where N is the number of collect_env.
+ - ready_env_id: None
+ Returns:
+ - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \
+ ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``.
+ """
+ self._eval_model.eval()
+ active_eval_env_num = len(data)
+ data = to_tensor(data)
+ data = sum(sum(data, []), [])
+ batch_size = len(data)
+ data = to_device(data, self._cfg.device)
+ data = default_collate(data)
+ agent_num = batch_size // active_eval_env_num
+ to_play = np.array(to_play).reshape(-1).tolist()
+
+ with torch.no_grad():
+ # data shape [B, S x C, W, H], e.g.
{Tensor:(B, 12, 96, 96)} + network_output = self._eval_model.initial_inference(data) + latent_state_roots, value_prefix_roots, reward_hidden_state_roots, pred_values, policy_logits = ez_network_output_unpack( + network_output + ) + + if not self._eval_model.training: + # if not in training, obtain the scalars of the value/reward + pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape(B, 1) + latent_state_roots = latent_state_roots.detach().cpu().numpy() + reward_hidden_state_roots = ( + reward_hidden_state_roots[0].detach().cpu().numpy(), + reward_hidden_state_roots[1].detach().cpu().numpy() + ) + policy_logits = policy_logits.detach().cpu().numpy().tolist() # list shape(B, A) + + action_mask = sum(action_mask, []) + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(batch_size)] + if self._cfg.mcts_ctree: + # cpp mcts_tree + roots = MCTSCtree.roots(batch_size, legal_actions) + else: + # python mcts_tree + roots = MCTSPtree.roots(batch_size, legal_actions) + roots.prepare_no_noise(value_prefix_roots, policy_logits, to_play) + self._mcts_eval.search(roots, self._eval_model, latent_state_roots, reward_hidden_state_roots, to_play) + + roots_visit_count_distributions = roots.get_distributions( + ) # shape: ``{list: batch_size} ->{list: action_space_size}`` + roots_values = roots.get_values() # shape: {list: batch_size} + data_id = [i for i in range(active_eval_env_num)] + output = {i: defaultdict(list) for i in data_id} + + if ready_env_id is None: + ready_env_id = np.arange(active_eval_env_num) + + for i in range(batch_size): + distributions, value = roots_visit_count_distributions[i], roots_values[i] + # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents + # the index within the legal action set, rather than the index in the entire action set. + # Setting deterministic=True implies choosing the action with the highest value (argmax) rather than sampling during the evaluation phase. + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=1, deterministic=True + ) + # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the entire action set. 
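+ # The flat batch index i is mapped back to its env via i // agent_num below; e.g. with 2 envs and
+ # agent_num = 3, indices 0-2 fill output[0] and indices 3-5 fill output[1], so each field of an env's
+ # entry ends up with one value per agent.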
+ action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + output[i // agent_num]['action'].append(action) + output[i // agent_num]['distributions'].append(distributions) + output[i // agent_num]['visit_count_distribution_entropy'].append(visit_count_distribution_entropy) + output[i // agent_num]['value'].append(value) + output[i // agent_num]['pred_value'].append(pred_values[i]) + output[i // agent_num]['policy_logits'].append(policy_logits[i]) + + return output \ No newline at end of file diff --git a/lzero/policy/multi_agent_muzero.py b/lzero/policy/multi_agent_muzero.py new file mode 100644 index 000000000..6e5476a40 --- /dev/null +++ b/lzero/policy/multi_agent_muzero.py @@ -0,0 +1,225 @@ +from typing import List, Dict, Any, Tuple, Union + +import numpy as np +import torch +import torch.optim as optim +from ding.model import model_wrap +from ding.policy.base_policy import Policy +from ding.torch_utils import to_tensor +from ding.utils import POLICY_REGISTRY +from torch.nn import L1Loss + +from lzero.mcts import MuZeroMCTSCtree as MCTSCtree +from lzero.mcts import MuZeroMCTSPtree as MCTSPtree +from lzero.model import ImageTransforms +from lzero.policy import scalar_transform, InverseScalarTransform, cross_entropy_loss, phi_transform, \ + DiscreteSupport, to_torch_float_tensor, mz_network_output_unpack, select_action, negative_cosine_similarity, prepare_obs, \ + configure_optimizers + +from collections import defaultdict +from ding.torch_utils import to_device +from ding.utils.data import default_collate +from .muzero import MuZeroPolicy + + +@POLICY_REGISTRY.register('multi_agent_muzero') +class MultiAgentMuZeroPolicy(MuZeroPolicy): + """ + Overview: + The policy class for Multi Agent MuZero. + """ + + def _forward_collect( + self, + data: torch.Tensor, + action_mask: list = None, + temperature: float = 1, + to_play: List = [-1], + epsilon: float = 0.25, + ready_env_id = None + ) -> Dict: + """ + Overview: + The forward function for collecting data in collect mode. Use model to execute MCTS search. + Choosing the action through sampling during the collect mode. + Arguments: + - data (:obj:`torch.Tensor`): The input data, i.e. the observation. + - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. + - temperature (:obj:`float`): The temperature of the policy. + - to_play (:obj:`int`): The player to play. + - ready_env_id (:obj:`list`): The id of the env that is ready to collect. + Shape: + - data (:obj:`torch.Tensor`): + - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ + S is the number of stacked frames, H is the height of the image, W is the width of the image. + - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size. + - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env. + - temperature: :math:`(1, )`. + - to_play: :math:`(N, 1)`, where N is the number of collect_env. + - ready_env_id: None + Returns: + - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ + ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. 
+ """ + self._collect_model.eval() + self.collect_mcts_temperature = temperature + self.collect_epsilon = epsilon + active_collect_env_num = len(data) + data = to_tensor(data) + data = sum(sum(data, []), []) + batch_size = len(data) + data = to_device(data, self._cfg.device) + data = default_collate(data) + agent_num = batch_size // active_collect_env_num + to_play = np.array(to_play).reshape(-1).tolist() + + with torch.no_grad(): + # data shape [B, S x C, W, H], e.g. {Tensor:(B, 12, 96, 96)} + network_output = self._collect_model.initial_inference(data) + latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) + + if not self._learn_model.training: + # if not in training, obtain the scalars of the value/reward + pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() + latent_state_roots = latent_state_roots.detach().cpu().numpy() + policy_logits = policy_logits.detach().cpu().numpy().tolist() + + action_mask = sum(action_mask, []) + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(batch_size)] + # the only difference between collect and eval is the dirichlet noise + noises = [ + np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j])) + ).astype(np.float32).tolist() for j in range(batch_size) + ] + if self._cfg.mcts_ctree: + # cpp mcts_tree + roots = MCTSCtree.roots(batch_size, legal_actions) + else: + # python mcts_tree + roots = MCTSPtree.roots(batch_size, legal_actions) + + roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits, to_play) + self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play) + + roots_visit_count_distributions = roots.get_distributions( + ) # shape: ``{list: batch_size} ->{list: action_space_size}`` + roots_values = roots.get_values() # shape: {list: batch_size} + + data_id = [i for i in range(active_collect_env_num)] + output = {i: defaultdict(list) for i in data_id} + + if ready_env_id is None: + ready_env_id = np.arange(active_collect_env_num) + + for i in range(batch_size): + distributions, value = roots_visit_count_distributions[i], roots_values[i] + if self._cfg.eps.eps_greedy_exploration_in_collect: + # eps greedy collect + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=self.collect_mcts_temperature, deterministic=True + ) + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + if np.random.rand() < self.collect_epsilon: + action = np.random.choice(legal_actions[i]) + else: + # normal collect + # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents + # the index within the legal action set, rather than the index in the entire action set. + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=self.collect_mcts_temperature, deterministic=False + ) + # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the entire action set. 
+ action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + output[i // agent_num]['action'].append(action) + output[i // agent_num]['distributions'].append(distributions) + output[i // agent_num]['visit_count_distribution_entropy'].append(visit_count_distribution_entropy) + output[i // agent_num]['value'].append(value) + output[i // agent_num]['pred_value'].append(pred_values[i]) + output[i // agent_num]['policy_logits'].append(policy_logits[i]) + + return output + + def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, ready_env_id=None) -> Dict: + """ + Overview: + The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search. + Choosing the action with the highest value (argmax) rather than sampling during the eval mode. + Arguments: + - data (:obj:`torch.Tensor`): The input data, i.e. the observation. + - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. + - to_play (:obj:`int`): The player to play. + - ready_env_id (:obj:`list`): The id of the env that is ready to collect. + Shape: + - data (:obj:`torch.Tensor`): + - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ + S is the number of stacked frames, H is the height of the image, W is the width of the image. + - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size. + - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env. + - to_play: :math:`(N, 1)`, where N is the number of collect_env. + - ready_env_id: None + Returns: + - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ + ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. + """ + self._eval_model.eval() + active_eval_env_num = len(data) + data = to_tensor(data) + data = sum(sum(data, []), []) + batch_size = len(data) + data = to_device(data, self._cfg.device) + data = default_collate(data) + agent_num = batch_size // active_eval_env_num + to_play = np.array(to_play).reshape(-1).tolist() + + with torch.no_grad(): + # data shape [B, S x C, W, H], e.g. 
{Tensor:(B, 12, 96, 96)}
+ network_output = self._eval_model.initial_inference(data)
+ latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output)
+
+ if not self._eval_model.training:
+ # if not in training, obtain the scalars of the value/reward
+ pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape(B, 1)
+ latent_state_roots = latent_state_roots.detach().cpu().numpy()
+ policy_logits = policy_logits.detach().cpu().numpy().tolist() # list shape(B, A)
+
+ action_mask = sum(action_mask, [])
+ legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(batch_size)]
+ if self._cfg.mcts_ctree:
+ # cpp mcts_tree
+ roots = MCTSCtree.roots(batch_size, legal_actions)
+ else:
+ # python mcts_tree
+ roots = MCTSPtree.roots(batch_size, legal_actions)
+ roots.prepare_no_noise(reward_roots, policy_logits, to_play)
+ self._mcts_eval.search(roots, self._eval_model, latent_state_roots, to_play)
+
+ roots_visit_count_distributions = roots.get_distributions(
+ ) # shape: ``{list: batch_size} ->{list: action_space_size}``
+ roots_values = roots.get_values() # shape: {list: batch_size}
+
+ data_id = [i for i in range(active_eval_env_num)]
+ output = {i: defaultdict(list) for i in data_id}
+
+ if ready_env_id is None:
+ ready_env_id = np.arange(active_eval_env_num)
+
+ for i in range(batch_size):
+ distributions, value = roots_visit_count_distributions[i], roots_values[i]
+ # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents
+ # the index within the legal action set, rather than the index in the entire action set.
+ # Setting deterministic=True implies choosing the action with the highest value (argmax) rather than
+ # sampling during the evaluation phase.
+ action_index_in_legal_action_set, visit_count_distribution_entropy = select_action(
+ distributions, temperature=1, deterministic=True
+ )
+ # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the
+ # entire action set.
+ action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + output[i // agent_num]['action'].append(action) + output[i // agent_num]['distributions'].append(distributions) + output[i // agent_num]['visit_count_distribution_entropy'].append(visit_count_distribution_entropy) + output[i // agent_num]['value'].append(value) + output[i // agent_num]['pred_value'].append(pred_values[i]) + output[i // agent_num]['policy_logits'].append(policy_logits[i]) + + return output diff --git a/lzero/policy/gobigger_random_policy.py b/lzero/policy/multi_agent_random_policy.py similarity index 65% rename from lzero/policy/gobigger_random_policy.py rename to lzero/policy/multi_agent_random_policy.py index be001ed2d..1fc110b54 100644 --- a/lzero/policy/gobigger_random_policy.py +++ b/lzero/policy/multi_agent_random_policy.py @@ -5,35 +5,20 @@ from ding.policy.base_policy import Policy from ding.utils import POLICY_REGISTRY -from lzero.mcts import EfficientZeroMCTSCtree as MCTSCtree -from lzero.mcts import EfficientZeroMCTSPtree as MCTSPtree -from lzero.policy import InverseScalarTransform, select_action, ez_network_output_unpack +from lzero.policy import InverseScalarTransform, select_action, ez_network_output_unpack, mz_network_output_unpack from .random_policy import LightZeroRandomPolicy from collections import defaultdict from ding.torch_utils import to_device, to_tensor +from ding.utils.data import default_collate -@POLICY_REGISTRY.register('gobigger_lightzero_random_policy') -class GoBiggerLightZeroRandomPolicy(LightZeroRandomPolicy): +@POLICY_REGISTRY.register('multi_agent_lightzero_random_policy') +class MultiAgentLightZeroRandomPolicy(LightZeroRandomPolicy): """ Overview: - The policy class for GoBiggerRandom. + The policy class for Multi Agent LightZero Random Policy. """ - def default_model(self) -> Tuple[str, List[str]]: - """ - Overview: - Return this algorithm default model setting. - Returns: - - model_info (:obj:`Tuple[str, List[str]]`): model name and model import_names. - - model_type (:obj:`str`): The model type used in this algorithm, which is registered in ModelRegistry. - - import_names (:obj:`List[str]`): The model class path list used in this algorithm. - .. note:: - The user can define and use customized network model but must obey the same interface definition indicated \ - by import_names path. For EfficientZero, ``lzero.model.efficientzero_model.EfficientZeroModel`` - """ - return 'GoBiggerEfficientZeroModel', ['lzero.model.gobigger.gobigger_efficientzero_model'] - def _forward_collect( self, data: torch.Tensor, @@ -41,7 +26,7 @@ def _forward_collect( temperature: float = 1, to_play: List = [-1], epsilon: float = 0.25, - ready_env_id=None + ready_env_id = None ): """ Overview: @@ -68,28 +53,36 @@ def _forward_collect( """ self._collect_model.eval() self.collect_mcts_temperature = temperature + self.collect_epsilon = epsilon active_collect_env_num = len(data) data = to_tensor(data) data = sum(sum(data, []), []) batch_size = len(data) data = to_device(data, self._cfg.device) + data = default_collate(data) agent_num = batch_size // active_collect_env_num to_play = np.array(to_play).reshape(-1).tolist() with torch.no_grad(): # data shape [B, S x C, W, H], e.g. 
{Tensor:(B, 12, 96, 96)} network_output = self._collect_model.initial_inference(data) - latent_state_roots, value_prefix_roots, reward_hidden_state_roots, pred_values, policy_logits = ez_network_output_unpack( - network_output - ) + if 'efficientzero' in self._cfg.type: # efficientzero or multi_agent_efficientzero + latent_state_roots, value_prefix_roots, reward_hidden_state_roots, pred_values, policy_logits = ez_network_output_unpack( + network_output + ) + elif 'muzero' in self._cfg.type: # muzero or multi_agent_muzero + latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) + else: + raise NotImplementedError("need to implement pipeline: {}".format(self._cfg.type)) - # if not in training, obtain the scalars of the value/reward pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() latent_state_roots = latent_state_roots.detach().cpu().numpy() - reward_hidden_state_roots = ( - reward_hidden_state_roots[0].detach().cpu().numpy(), reward_hidden_state_roots[1].detach().cpu().numpy() - ) + if 'efficientzero' in self._cfg.type: + reward_hidden_state_roots = ( + reward_hidden_state_roots[0].detach().cpu().numpy(), + reward_hidden_state_roots[1].detach().cpu().numpy() + ) policy_logits = policy_logits.detach().cpu().numpy().tolist() action_mask = sum(action_mask, []) @@ -101,14 +94,20 @@ def _forward_collect( ] if self._cfg.mcts_ctree: # cpp mcts_tree - roots = MCTSCtree.roots(batch_size, legal_actions) + roots = self.MCTSCtree.roots(batch_size, legal_actions) else: # python mcts_tree - roots = MCTSPtree.roots(batch_size, legal_actions) - roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_roots, policy_logits, to_play) - self._mcts_collect.search( - roots, self._collect_model, latent_state_roots, reward_hidden_state_roots, to_play - ) + roots = self.MCTSPtree.roots(batch_size, legal_actions) + if 'efficientzero' in self._cfg.type: # efficientzero or multi_agent_efficientzero + roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_roots, policy_logits, to_play) + self._mcts_collect.search( + roots, self._collect_model, latent_state_roots, reward_hidden_state_roots, to_play + ) + elif 'muzero' in self._cfg.type: # muzero or multi_agent_muzero + roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits, to_play) + self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play) + else: + raise NotImplementedError("need to implement pipeline: {}".format(self._cfg.type)) roots_visit_count_distributions = roots.get_distributions( ) # shape: ``{list: batch_size} ->{list: action_space_size}`` @@ -124,10 +123,13 @@ def _forward_collect( action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( distributions, temperature=self.collect_mcts_temperature, deterministic=False ) - action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] - # ************* random action ************* - action = int(np.random.choice(legal_actions[i], 1)) - output[i // agent_num]['action'].append(action) + + # ****** sample a random action from the legal action set ******** + # all items except action are formally obtained from MCTS + random_action = int(np.random.choice(legal_actions[i], 1)) + # **************************************************************** + + output[i // agent_num]['action'].append(random_action) output[i // agent_num]['distributions'].append(distributions) output[i // 
agent_num]['visit_count_distribution_entropy'].append(visit_count_distribution_entropy) output[i // agent_num]['value'].append(value) diff --git a/lzero/policy/muzero.py b/lzero/policy/muzero.py index 3536b6967..7d3297f1e 100644 --- a/lzero/policy/muzero.py +++ b/lzero/policy/muzero.py @@ -16,6 +16,9 @@ from lzero.policy import scalar_transform, InverseScalarTransform, cross_entropy_loss, phi_transform, \ DiscreteSupport, to_torch_float_tensor, mz_network_output_unpack, select_action, negative_cosine_similarity, prepare_obs, \ configure_optimizers +from collections import defaultdict +from ding.torch_utils import to_device +from ding.utils.data import default_collate @POLICY_REGISTRY.register('muzero') @@ -202,6 +205,8 @@ def default_model(self) -> Tuple[str, List[str]]: return 'MuZeroModel', ['lzero.model.muzero_model'] elif self._cfg.model.model_type == "mlp": return 'MuZeroModelMLP', ['lzero.model.muzero_model_mlp'] + elif self._cfg.model.model_type == "structure": + return 'MuZeroModelStructure', ['lzero.model.muzero_model_structure'] def _init_learn(self) -> None: """ @@ -299,7 +304,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in target_reward = target_reward.view(self._cfg.batch_size, -1) target_value = target_value.view(self._cfg.batch_size, -1) - assert obs_batch.size(0) == self._cfg.batch_size == target_reward.size(0) + assert self._cfg.batch_size == target_reward.size(0) # ``scalar_transform`` to transform the original value to the scaled value, # i.e. h(.) function in paper https://arxiv.org/pdf/1805.11593.pdf. @@ -376,6 +381,15 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in beg_index = self._cfg.model.observation_shape * step_i end_index = self._cfg.model.observation_shape * (step_i + self._cfg.model.frame_stack_num) network_output = self._learn_model.initial_inference(obs_target_batch[:, beg_index:end_index]) + elif self._cfg.model.model_type == 'structure': + beg_index = step_i + end_index = step_i + self._cfg.model.frame_stack_num + obs_target_batch_tmp = obs_target_batch[:, beg_index:end_index].tolist() + obs_target_batch_tmp = sum(obs_target_batch_tmp, []) + obs_target_batch_tmp = to_tensor(obs_target_batch_tmp) + obs_target_batch_tmp = to_device(obs_target_batch_tmp, self._cfg.device) + obs_target_batch_tmp = default_collate(obs_target_batch_tmp) + network_output = self._learn_model.initial_inference(obs_target_batch_tmp) latent_state = to_tensor(latent_state) representation_state = to_tensor(network_output.latent_state) diff --git a/lzero/policy/random_policy.py b/lzero/policy/random_policy.py index 384b83d24..b2d1cb0d7 100644 --- a/lzero/policy/random_policy.py +++ b/lzero/policy/random_policy.py @@ -59,6 +59,13 @@ def default_model(self) -> Tuple[str, List[str]]: return 'MuZeroModelMLP', ['lzero.model.muzero_model_mlp'] else: raise NotImplementedError("need to implement pipeline: {}".format(self._cfg.type)) + elif self._cfg.model.model_type == 'structure': + if 'efficientzero' in self._cfg.type: # efficientzero or multi_agent_efficientzero + return 'EfficientZeroModelStructure', ['lzero.model.efficientzero_model_structure'] + elif 'muzero' in self._cfg.type: # muzero or multi_agent_muzero + return 'MuZeroModelStructure', ['lzero.model.muzero_model_structure'] + else: + raise NotImplementedError("need to implement pipeline: {}".format(self._cfg.type)) def _init_collect(self) -> None: """ diff --git a/lzero/policy/utils.py b/lzero/policy/utils.py index cd220bd0e..f65db53fa 100644 --- 
a/lzero/policy/utils.py +++ b/lzero/policy/utils.py @@ -11,6 +11,8 @@ import torch import torch.nn as nn from torch.nn import functional as F +from ding.torch_utils import to_device, to_tensor +from ding.utils.data import default_collate class LayerNorm(nn.Module): @@ -176,6 +178,18 @@ def prepare_obs(obs_batch_ori: np.ndarray, cfg: EasyDict) -> Tuple[torch.Tensor, # ``obs_target_batch`` is only used for calculate consistency loss, which take the all obs other than # timestep t1, and is only performed in the last 8 timesteps in the second dim in ``obs_batch_ori``. obs_target_batch = obs_batch_ori[:, cfg.model.observation_shape:] + elif cfg.model.model_type == 'structure': + obs_batch_ori = obs_batch_ori.tolist() + obs_batch_ori = np.array(obs_batch_ori) + obs_batch = obs_batch_ori[:, 0:cfg.model.frame_stack_num] + if cfg.model.self_supervised_learning_loss: + obs_target_batch = obs_batch_ori[:, cfg.model.frame_stack_num:] + + obs_batch = obs_batch.tolist() + obs_batch = sum(obs_batch, []) + obs_batch = to_tensor(obs_batch) + obs_batch = to_device(obs_batch, cfg.device) + obs_batch = default_collate(obs_batch) return obs_batch, obs_target_batch diff --git a/lzero/worker/__init__.py b/lzero/worker/__init__.py index 62a97b743..341112f60 100644 --- a/lzero/worker/__init__.py +++ b/lzero/worker/__init__.py @@ -2,5 +2,5 @@ from .alphazero_evaluator import AlphaZeroEvaluator from .muzero_collector import MuZeroCollector from .muzero_evaluator import MuZeroEvaluator -from .gobigger_muzero_collector import GoBiggerMuZeroCollector -from .gobigger_muzero_evaluator import GoBiggerMuZeroEvaluator +from .multi_agent_muzero_collector import MultiAgentMuZeroCollector +from .gobigger_muzero_evaluator import GoBiggerMuZeroEvaluator \ No newline at end of file diff --git a/lzero/worker/gobigger_muzero_collector.py b/lzero/worker/multi_agent_muzero_collector.py similarity index 95% rename from lzero/worker/gobigger_muzero_collector.py rename to lzero/worker/multi_agent_muzero_collector.py index 835feb1e8..ad0690d0b 100644 --- a/lzero/worker/gobigger_muzero_collector.py +++ b/lzero/worker/multi_agent_muzero_collector.py @@ -15,12 +15,12 @@ from collections import defaultdict -@SERIAL_COLLECTOR_REGISTRY.register('gobigger_episode_muzero') -class GoBiggerMuZeroCollector(MuZeroCollector): +@SERIAL_COLLECTOR_REGISTRY.register('multi_agent_episode_muzero') +class MultiAgentMuZeroCollector(MuZeroCollector): """ Overview: - The Collector for GoBigger MCTS+RL algorithms, including MuZero, EfficientZero, Sampled EfficientZero. - For GoBigger, add agent_num dim in game_segment. + The Collector for Multi Agent MCTS+RL algorithms, including MuZero, EfficientZero, Sampled EfficientZero. + For Multi Agent, add agent_num dim in game_segment. 
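For illustration, a minimal sketch of the flatten-and-collate pattern that the new 'structure' branch of prepare_obs (and the multi-agent buffers above) relies on: structured observations stay as nested Python lists of dict observations, are flattened with sum(..., []), converted to tensors, batched with ding's default_collate, and moved to the target device. The field names and shapes below are hypothetical examples; only the helper functions come from the code above.

import numpy as np
from ding.torch_utils import to_device, to_tensor
from ding.utils.data import default_collate

# Hypothetical batch: 4 samples, frame_stack_num = 1, each frame a dict observation,
# mirroring obs_batch_ori[:, 0:frame_stack_num].tolist() in prepare_obs.
obs_batch = [
    [{'agent_state': np.zeros(18, dtype=np.float32), 'action_mask': np.ones(5, dtype=np.float32)}]
    for _ in range(4)
]
flat = sum(obs_batch, [])            # drop the frame dimension -> list of 4 dict observations
flat = to_tensor(flat)               # ndarray leaves -> torch tensors
batched = default_collate(flat)      # dict of stacked tensors, e.g. agent_state: (4, 18)
batched = to_device(batched, 'cpu')  # or cfg.device in the real pipeline
print(batched['agent_state'].shape)  # torch.Size([4, 18])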
Interfaces: __init__, reset, reset_env, reset_policy, collect, close Property: diff --git a/lzero/worker/muzero_collector.py b/lzero/worker/muzero_collector.py index bc8d90daf..63ac5100b 100644 --- a/lzero/worker/muzero_collector.py +++ b/lzero/worker/muzero_collector.py @@ -292,9 +292,6 @@ def collect(self, raise RuntimeError("Please specify collect n_episode") else: n_episode = self._default_n_episode - random_collect_episode_num = 0 - else: - random_collect_episode_num = n_episode assert n_episode >= self._env_num, "Please make sure n_episode >= env_num{}/{}".format(n_episode, self._env_num) if policy_kwargs is None: policy_kwargs = {} @@ -325,6 +322,7 @@ def collect(self, if self._multi_agent: agent_num = len(init_obs[0]['action_mask']) + assert agent_num == self.policy_config.model.agent_num, "Please make sure agent_num == env.agent_num" game_segments = [ [ GameSegment( @@ -522,7 +520,8 @@ def collect(self, if self._multi_agent: for agent_id in range(agent_num): game_segments[env_id][agent_id].append( - actions[env_id][agent_id], to_ndarray(obs['observation'][agent_id]), reward[agent_id], + actions[env_id][agent_id], + to_ndarray(obs['observation'][agent_id]), reward[agent_id] if isinstance(reward, list) else reward, action_mask_dict[env_id][agent_id], to_play_dict[env_id] ) else: @@ -658,7 +657,7 @@ def collect(self, # pad over 2th last game_segment using the last game_segment if self._multi_agent: for agent_id in range(agent_num): - if last_game_segments[env_id] is not None: + if last_game_segments[env_id][agent_id] is not None: self.pad_and_save_last_trajectory( env_id, agent_id, last_game_segments, last_game_priorities, game_segments, dones ) diff --git a/lzero/worker/muzero_evaluator.py b/lzero/worker/muzero_evaluator.py index f2e585a03..d21ae9559 100644 --- a/lzero/worker/muzero_evaluator.py +++ b/lzero/worker/muzero_evaluator.py @@ -191,7 +191,7 @@ def should_eval(self, train_iter: int) -> bool: return True def _add_info(self, last_timestep, info): - pass + return info def eval( self, @@ -356,7 +356,7 @@ def eval( if self._multi_agent: for agent_id in range(agent_num): game_segments[env_id][agent_id].append( - actions[env_id][agent_id], to_ndarray(obs['observation'][agent_id]), reward[agent_id], + actions[env_id][agent_id], to_ndarray(obs['observation'][agent_id]), reward[agent_id] if isinstance(reward, list) else reward, action_mask_dict[env_id][agent_id], to_play_dict[env_id] ) else: diff --git a/zoo/gobigger/config/gobigger_efficientzero_config.py b/zoo/gobigger/config/gobigger_efficientzero_config.py index cea9fe087..23c8e5a11 100644 --- a/zoo/gobigger/config/gobigger_efficientzero_config.py +++ b/zoo/gobigger/config/gobigger_efficientzero_config.py @@ -1,6 +1,6 @@ from easydict import EasyDict -env_name = 'GoBigger' +env_name = 'gobigger' multi_agent = True # ============================================================== @@ -21,7 +21,7 @@ # end of the most frequently changed config specified by the user # ============================================================== -atari_efficientzero_config = dict( +gobigger_efficientzero_config = dict( exp_name= f'data_ez_ctree/{env_name}_efficientzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed{seed}', env=dict( @@ -35,11 +35,12 @@ multi_agent=multi_agent, ignore_done=True, model=dict( - model_type='structured', + model_type='structure', + env_name=env_name, + agent_num=4, # default is t2p2 latent_state_dim=176, frame_stack_num=1, action_space_size=action_space_size, - downsample=True, 
discrete_action_encoding_type='one_hot', norm_type='BN', ), @@ -75,26 +76,26 @@ hook=dict(log_show_after_iter=10, ), ), ), ) -atari_efficientzero_config = EasyDict(atari_efficientzero_config) -main_config = atari_efficientzero_config +gobigger_efficientzero_config = EasyDict(gobigger_efficientzero_config) +main_config = gobigger_efficientzero_config -atari_efficientzero_create_config = dict( +gobigger_efficientzero_create_config = dict( env=dict( type='gobigger_lightzero', import_names=['zoo.gobigger.env.gobigger_env'], ), env_manager=dict(type='subprocess'), policy=dict( - type='gobigger_efficientzero', - import_names=['lzero.policy.gobigger_efficientzero'], + type='multi_agent_efficientzero', + import_names=['lzero.policy.multi_agent_efficientzero'], ), collector=dict( - type='gobigger_episode_muzero', - import_names=['lzero.worker.gobigger_muzero_collector'], + type='multi_agent_episode_muzero', + import_names=['lzero.worker.multi_agent_muzero_collector'], ) ) -atari_efficientzero_create_config = EasyDict(atari_efficientzero_create_config) -create_config = atari_efficientzero_create_config +gobigger_efficientzero_create_config = EasyDict(gobigger_efficientzero_create_config) +create_config = gobigger_efficientzero_create_config if __name__ == "__main__": from zoo.gobigger.entry import train_muzero_gobigger diff --git a/zoo/gobigger/config/gobigger_muzero_config.py b/zoo/gobigger/config/gobigger_muzero_config.py index 3ac5eab3b..df2eacf73 100644 --- a/zoo/gobigger/config/gobigger_muzero_config.py +++ b/zoo/gobigger/config/gobigger_muzero_config.py @@ -1,6 +1,6 @@ from easydict import EasyDict -env_name = 'GoBigger' +env_name = 'gobigger' multi_agent = True # ============================================================== @@ -21,7 +21,7 @@ # end of the most frequently changed config specified by the user # ============================================================== -atari_muzero_config = dict( +gobigger_muzero_config = dict( exp_name=f'data_mz_ctree/{env_name}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed{seed}', env=dict( env_name=env_name, # default is 'GoBigger T2P2' @@ -34,11 +34,12 @@ multi_agent=multi_agent, ignore_done=True, model=dict( - model_type='structured', + model_type='structure', + env_name=env_name, + agent_num=4, # default is t2p2 latent_state_dim=176, frame_stack_num=1, action_space_size=action_space_size, - downsample=True, self_supervised_learning_loss=False, # default is False discrete_action_encoding_type='one_hot', norm_type='BN', @@ -76,26 +77,26 @@ hook=dict(log_show_after_iter=10, ), ), ), ) -atari_muzero_config = EasyDict(atari_muzero_config) -main_config = atari_muzero_config +gobigger_muzero_config = EasyDict(gobigger_muzero_config) +main_config = gobigger_muzero_config -atari_muzero_create_config = dict( +gobigger_muzero_create_config = dict( env=dict( type='gobigger_lightzero', import_names=['zoo.gobigger.env.gobigger_env'], ), env_manager=dict(type='subprocess'), policy=dict( - type='gobigger_muzero', - import_names=['lzero.policy.gobigger_muzero'], + type='multi_agent_muzero', + import_names=['lzero.policy.multi_agent_muzero'], ), collector=dict( - type='gobigger_episode_muzero', - import_names=['lzero.worker.gobigger_muzero_collector'], + type='multi_agent_episode_muzero', + import_names=['lzero.worker.multi_agent_muzero_collector'], ) ) -atari_muzero_create_config = EasyDict(atari_muzero_create_config) -create_config = atari_muzero_create_config +gobigger_muzero_create_config = EasyDict(gobigger_muzero_create_config) 
+create_config = gobigger_muzero_create_config if __name__ == "__main__": from zoo.gobigger.entry import train_muzero_gobigger diff --git a/zoo/gobigger/entry/train_muzero_gobigger.py b/zoo/gobigger/entry/train_muzero_gobigger.py index 25f17096f..068c76ee1 100644 --- a/zoo/gobigger/entry/train_muzero_gobigger.py +++ b/zoo/gobigger/entry/train_muzero_gobigger.py @@ -15,9 +15,9 @@ from ding.rl_utils import get_epsilon_greedy_fn from lzero.entry.utils import log_buffer_memory_usage from lzero.policy import visit_count_temperature -from lzero.worker import GoBiggerMuZeroCollector, GoBiggerMuZeroEvaluator +from lzero.worker import GoBiggerMuZeroEvaluator from lzero.entry.utils import random_collect -from lzero.policy.gobigger_random_policy import GoBiggerLightZeroRandomPolicy +from lzero.policy.multi_agent_random_policy import MultiAgentLightZeroRandomPolicy def train_muzero_gobigger( @@ -46,13 +46,21 @@ def train_muzero_gobigger( """ cfg, create_cfg = input_cfg - assert create_cfg.policy.type in ['gobigger_efficientzero', 'gobigger_muzero', 'gobigger_sampled_efficientzero'] - if create_cfg.policy.type == 'gobigger_efficientzero': - from lzero.mcts import GoBiggerEfficientZeroGameBuffer as GameBuffer - elif create_cfg.policy.type == 'gobigger_muzero': - from lzero.mcts import GoBiggerMuZeroGameBuffer as GameBuffer - elif create_cfg.policy.type == 'gobigger_sampled_efficientzero': - from lzero.mcts import GoBiggerSampledEfficientZeroGameBuffer as GameBuffer + assert create_cfg.policy.type in ['efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero', 'multi_agent_efficientzero', 'multi_agent_muzero'], \ + "train_muzero entry now only support the following algo.: 'efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero', 'multi_agent_efficientzero', 'multi_agent_muzero'" + + if create_cfg.policy.type == 'muzero': + from lzero.mcts import MuZeroGameBuffer as GameBuffer + elif create_cfg.policy.type == 'efficientzero': + from lzero.mcts import EfficientZeroGameBuffer as GameBuffer + elif create_cfg.policy.type == 'sampled_efficientzero': + from lzero.mcts import SampledEfficientZeroGameBuffer as GameBuffer + elif create_cfg.policy.type == 'gumbel_muzero': + from lzero.mcts import GumbelMuZeroGameBuffer as GameBuffer + elif create_cfg.policy.type == 'multi_agent_efficientzero': + from lzero.mcts import MultiAgentSampledEfficientZeroGameBuffer as GameBuffer + elif create_cfg.policy.type == 'multi_agent_muzero': + from lzero.mcts import MultiAgentMuZeroGameBuffer as GameBuffer if cfg.policy.cuda and torch.cuda.is_available(): cfg.policy.device = 'cuda' @@ -92,7 +100,13 @@ def train_muzero_gobigger( batch_size = policy_config.batch_size # specific game buffer for MCTS+RL algorithms replay_buffer = GameBuffer(policy_config) - collector = GoBiggerMuZeroCollector( + if policy_config.multi_agent: + from lzero.worker import MultiAgentMuZeroCollector as Collector + from lzero.worker import MuZeroEvaluator as Evaluator + else: + from lzero.worker import MuZeroCollector as Collector + from lzero.worker import MuZeroEvaluator as Evaluator + collector = Collector( env=collector_env, policy=policy.collect_mode, tb_logger=tb_logger, @@ -127,6 +141,7 @@ def train_muzero_gobigger( # ============================================================== # Learner's before_run hook. 
learner.call_hook('before_run') + if cfg.policy.update_per_collect is not None: update_per_collect = cfg.policy.update_per_collect @@ -134,7 +149,7 @@ def train_muzero_gobigger( # Exploration: The collection of random data aids the agent in exploring the environment and prevents premature convergence to a suboptimal policy. # Comparation: The agent's performance during random action-taking can be used as a reference point to evaluate the efficacy of reinforcement learning algorithms. if cfg.policy.random_collect_episode_num > 0: - random_collect(cfg.policy, policy, GoBiggerLightZeroRandomPolicy, collector, collector_env, replay_buffer) + random_collect(cfg.policy, policy, MultiAgentLightZeroRandomPolicy, collector, collector_env, replay_buffer) while True: log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger) @@ -147,6 +162,7 @@ def train_muzero_gobigger( policy_config.threshold_training_steps_for_final_temperature, trained_steps=learner.train_iter ) + if policy_config.eps.eps_greedy_exploration_in_collect: epsilon_greedy_fn = get_epsilon_greedy_fn( start=policy_config.eps.start, @@ -160,20 +176,24 @@ def train_muzero_gobigger( # Evaluate policy performance. if evaluator.should_eval(learner.train_iter): - stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) + stop, reward = evaluator.eval(None, learner.train_iter, collector.envstep) # save_ckpt_fn = None stop, reward = vsbot_evaluator.eval_vsbot(learner.save_checkpoint, learner.train_iter, collector.envstep) if stop: break # Collect data by default config n_sample/n_episode. new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) + if cfg.policy.update_per_collect is None: + # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the model_update_ratio. + collected_transitions_num = sum([len(game_segment) for game_segment in new_data[0]]) + update_per_collect = int(collected_transitions_num * cfg.policy.model_update_ratio) # save returned new_data collected by the collector replay_buffer.push_game_segments(new_data) # remove the oldest data if the replay buffer is full. replay_buffer.remove_oldest_data_to_fit() # Learn policy from collected data. - for i in range(cfg.policy.update_per_collect): + for i in range(update_per_collect): # Learner will train ``update_per_collect`` times in one iteration. 
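# As a rough, hypothetical example of the auto-derived loop count above: if the collector
# returned 8 game segments of 400 transitions each and cfg.policy.model_update_ratio were
# 0.1 (taken purely as an example value), update_per_collect would be
# int(8 * 400 * 0.1) = 320 learner steps per collection phase.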
if replay_buffer.get_num_of_transitions() > batch_size: train_data = replay_buffer.sample(batch_size, policy) diff --git a/zoo/petting_zoo/__init__.py b/zoo/petting_zoo/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/zoo/petting_zoo/config/__init__.py b/zoo/petting_zoo/config/__init__.py new file mode 100644 index 000000000..5348554ae --- /dev/null +++ b/zoo/petting_zoo/config/__init__.py @@ -0,0 +1 @@ +from .ptz_simple_spread_ez_config import main_config, create_config diff --git a/zoo/petting_zoo/config/ptz_simple_spread_ez_config.py b/zoo/petting_zoo/config/ptz_simple_spread_ez_config.py new file mode 100644 index 000000000..109b472af --- /dev/null +++ b/zoo/petting_zoo/config/ptz_simple_spread_ez_config.py @@ -0,0 +1,114 @@ +from easydict import EasyDict + +env_name = 'ptz_simple_spread' +multi_agent = True + +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +n_agent = 3 +n_landmark = n_agent +collector_env_num = 8 +evaluator_env_num = 8 +n_episode = 8 +batch_size = 256 +num_simulations = 50 +update_per_collect = 1000 +reanalyze_ratio = 0. +action_space_size = 5 +seed = 0 +eps_greedy_exploration_in_collect = True +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +main_config = dict( + exp_name= + f'data_ez_ctree/{env_name}_efficientzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed{seed}', + env=dict( + env_family='mpe', + env_id='simple_spread_v2', + n_agent=n_agent, + n_landmark=n_landmark, + max_cycles=25, + agent_obs_only=False, + agent_specific_global_state=True, + continuous_actions=False, + stop_value=0, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + multi_agent=multi_agent, + ignore_done=False, + model=dict( + model_type='structure', + env_name=env_name, + latent_state_dim=18, + frame_stack_num=1, + action_space='discrete', + action_space_size=action_space_size, + agent_num=n_agent, + agent_obs_shape=2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2, + global_obs_shape=2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2 + n_agent * (2 + 2) + + n_landmark * 2 + n_agent * (n_agent - 1) * 2, + discrete_action_encoding_type='one_hot', + norm_type='BN', + ), + cuda=True, + mcts_ctree=True, + gumbel_algo=False, + env_type='not_board_games', + game_segment_length=400, + random_collect_episode_num=0, + eps=dict( + eps_greedy_exploration_in_collect=eps_greedy_exploration_in_collect, + type='linear', + start=1., + end=0.05, + decay=int(1e5), + ), + use_augmentation=False, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='SGD', + lr_piecewise_constant_decay=True, + learning_rate=0.2, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + eval_freq=int(2e3), + replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. 
+ collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), + learn=dict(learner=dict( + log_policy=True, + hook=dict(log_show_after_iter=10, ), + ), ), +) +main_config = EasyDict(main_config) +create_config = dict( + env=dict( + import_names=['zoo.petting_zoo.envs.petting_zoo_simple_spread_env'], + type='petting_zoo', + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='multi_agent_efficientzero', + import_names=['lzero.policy.multi_agent_efficientzero'], + ), + collector=dict( + type='multi_agent_episode_muzero', + import_names=['lzero.worker.multi_agent_muzero_collector'], + ) +) +create_config = EasyDict(create_config) +ptz_simple_spread_efficientzero_config = main_config +ptz_simple_spread_efficientzero_create_config = create_config + +if __name__ == '__main__': + from lzero.entry import train_muzero + train_muzero([main_config, create_config], seed=seed) diff --git a/zoo/petting_zoo/config/ptz_simple_spread_mz_config.py b/zoo/petting_zoo/config/ptz_simple_spread_mz_config.py new file mode 100644 index 000000000..a17f1fa6d --- /dev/null +++ b/zoo/petting_zoo/config/ptz_simple_spread_mz_config.py @@ -0,0 +1,116 @@ +from easydict import EasyDict + +env_name = 'ptz_simple_spread' +multi_agent = True + +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +seed = 0 +n_agent = 3 +n_landmark = n_agent +collector_env_num = 8 +evaluator_env_num = 8 +n_episode = 8 +batch_size = 256 +num_simulations = 50 +update_per_collect = 1000 +reanalyze_ratio = 0. +action_space_size = 5 +eps_greedy_exploration_in_collect = True +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +main_config = dict( + exp_name= + f'data_mz_ctree/{env_name}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed{seed}', + env=dict( + env_family='mpe', + env_id='simple_spread_v2', + n_agent=n_agent, + n_landmark=n_landmark, + max_cycles=25, + agent_obs_only=False, + agent_specific_global_state=True, + continuous_actions=False, + stop_value=0, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + multi_agent=multi_agent, + ignore_done=False, + model=dict( + model_type='structure', + env_name=env_name, + latent_state_dim=18, + frame_stack_num=1, + action_space='discrete', + action_space_size=action_space_size, + agent_num=n_agent, + self_supervised_learning_loss=False, # default is False + agent_obs_shape=2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2, + global_obs_shape=2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2 + n_agent * (2 + 2) + + n_landmark * 2 + n_agent * (n_agent - 1) * 2, + discrete_action_encoding_type='one_hot', + norm_type='BN', + ), + cuda=True, + mcts_ctree=True, + gumbel_algo=False, + env_type='not_board_games', + game_segment_length=400, + random_collect_episode_num=0, + eps=dict( + eps_greedy_exploration_in_collect=eps_greedy_exploration_in_collect, + type='linear', + start=1., + end=0.05, + decay=int(1e5), + ), + use_augmentation=False, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='SGD', + lr_piecewise_constant_decay=True, + learning_rate=0.2, + ssl_loss_weight=0, # 
default is 0 + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + eval_freq=int(2e3), + replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), + learn=dict(learner=dict( + log_policy=True, + hook=dict(log_show_after_iter=10, ), + ), ), +) +main_config = EasyDict(main_config) +create_config = dict( + env=dict( + import_names=['zoo.petting_zoo.envs.petting_zoo_simple_spread_env'], + type='petting_zoo', + ), + env_manager=dict(type='base'), + policy=dict( + type='multi_agent_muzero', + import_names=['lzero.policy.multi_agent_muzero'], + ), + collector=dict( + type='multi_agent_episode_muzero', + import_names=['lzero.worker.multi_agent_muzero_collector'], + ) +) +create_config = EasyDict(create_config) +ptz_simple_spread_muzero_config = main_config +ptz_simple_spread_muzero_create_config = create_config + +if __name__ == '__main__': + from lzero.entry import train_muzero + train_muzero([main_config, create_config], seed=seed) diff --git a/zoo/petting_zoo/envs/__init__.py b/zoo/petting_zoo/envs/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/zoo/petting_zoo/envs/petting_zoo_simple_spread_env.py b/zoo/petting_zoo/envs/petting_zoo_simple_spread_env.py new file mode 100644 index 000000000..9944d8180 --- /dev/null +++ b/zoo/petting_zoo/envs/petting_zoo_simple_spread_env.py @@ -0,0 +1,368 @@ +from typing import Any, List, Union, Optional, Dict +import gymnasium as gym +import numpy as np +import pettingzoo +from functools import reduce + +from ding.envs import BaseEnv, BaseEnvTimestep, FrameStackWrapper +from ding.torch_utils import to_ndarray, to_list +from ding.envs.common.common_function import affine_transform +from ding.utils import ENV_REGISTRY, import_module +from pettingzoo.utils.conversions import parallel_wrapper_fn +from pettingzoo.mpe._mpe_utils.simple_env import SimpleEnv, make_env +from pettingzoo.mpe.simple_spread.simple_spread import Scenario + + +@ENV_REGISTRY.register('petting_zoo') +class PettingZooEnv(BaseEnv): + # Now only supports simple_spread_v2. + # All agents' observations should have the same shape. 
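For reference, a minimal sketch of how the PettingZooEnv added here can be instantiated and stepped. The config fields mirror those used by the configs and tests in this patch; the concrete values (3 agents, 25 cycles, seed 0) are illustrative only.

from easydict import EasyDict
from zoo.petting_zoo.envs.petting_zoo_simple_spread_env import PettingZooEnv

cfg = EasyDict(dict(
    env_family='mpe',
    env_id='simple_spread_v2',
    n_agent=3,
    n_landmark=3,
    max_cycles=25,
    agent_obs_only=False,
    agent_specific_global_state=True,
    continuous_actions=False,
))
env = PettingZooEnv(cfg)
env.seed(0)
obs = env.reset()
# obs is a dict with per-agent 'observation' entries plus 'action_mask' and 'to_play'.
timestep = env.step(env.random_action())  # step accepts the {agent_name: action} dict
env.close()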
+ + def __init__(self, cfg: dict) -> None: + self._cfg = cfg + self._init_flag = False + self._replay_path = None + self._env_family = self._cfg.env_family + self._env_id = self._cfg.env_id + self._num_agents = self._cfg.n_agent + self._num_landmarks = self._cfg.n_landmark + self._continuous_actions = self._cfg.get('continuous_actions', False) + self._max_cycles = self._cfg.get('max_cycles', 25) + self._act_scale = self._cfg.get('act_scale', False) + self._agent_specific_global_state = self._cfg.get('agent_specific_global_state', False) + if self._act_scale: + assert self._continuous_actions, 'Only continuous action space env needs act_scale' + + def reset(self) -> np.ndarray: + if not self._init_flag: + # In order to align with the simple spread in Multiagent Particle Env (MPE), + # instead of adopting the pettingzoo interface directly, + # we have redefined the way rewards are calculated + + # import_module(['pettingzoo.{}.{}'.format(self._env_family, self._env_id)]) + # self._env = pettingzoo.__dict__[self._env_family].__dict__[self._env_id].parallel_env( + # N=self._cfg.n_agent, continuous_actions=self._continuous_actions, max_cycles=self._max_cycles + # ) + + # init parallel_env wrapper + _env = make_env(simple_spread_raw_env) + parallel_env = parallel_wrapper_fn(_env) + # init env + self._env = parallel_env( + N=self._cfg.n_agent, continuous_actions=self._continuous_actions, max_cycles=self._max_cycles + ) + # dynamic seed reduces training speed greatly + # if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed: + # np_seed = 100 * np.random.randint(1, 1000) + # self._env.seed(self._seed + np_seed) + if self._replay_path is not None: + self._env = gym.wrappers.Monitor( + self._env, self._replay_path, video_callable=lambda episode_id: True, force=True + ) + if hasattr(self, '_seed'): + obs = self._env.reset(seed=self._seed) + else: + obs = self._env.reset() + if not self._init_flag: + self._agents = self._env.agents + + self._action_space = gym.spaces.Dict({agent: self._env.action_space(agent) for agent in self._agents}) + single_agent_obs_space = self._env.action_space(self._agents[0]) + if isinstance(single_agent_obs_space, gym.spaces.Box): + self._action_dim = single_agent_obs_space.shape + elif isinstance(single_agent_obs_space, gym.spaces.Discrete): + self._action_dim = (single_agent_obs_space.n, ) + else: + raise Exception('Only support `Box` or `Discrete` obs space for single agent.') + + # only for env 'simple_spread_v2', n_agent = 5 + # now only for the case that each agent in the team have the same obs structure and corresponding shape. 
+ if not self._cfg.agent_obs_only: + self._observation_space = gym.spaces.Dict( + { + 'agent_state': gym.spaces.Box( + low=float("-inf"), + high=float("inf"), + shape=(self._num_agents, + self._env.observation_space('agent_0').shape[0]), # (self._num_agents, 30) + dtype=np.float32 + ), + 'global_state': gym.spaces.Box( + low=float("-inf"), + high=float("inf"), + shape=( + 4 * self._num_agents + 2 * self._num_landmarks + 2 * self._num_agents * + (self._num_agents - 1), + ), + dtype=np.float32 + ), + 'agent_alone_state': gym.spaces.Box( + low=float("-inf"), + high=float("inf"), + shape=(self._num_agents, 4 + 2 * self._num_landmarks + 2 * (self._num_agents - 1)), + dtype=np.float32 + ), + 'agent_alone_padding_state': gym.spaces.Box( + low=float("-inf"), + high=float("inf"), + shape=(self._num_agents, + self._env.observation_space('agent_0').shape[0]), # (self._num_agents, 30) + dtype=np.float32 + ), + 'action_mask': gym.spaces.Box( + low=float("-inf"), + high=float("inf"), + shape=(self._num_agents, self._action_dim[0]), # (self._num_agents, 5) + dtype=np.float32 + ) + } + ) + # whether use agent_specific_global_state. It is usually used in AC multiagent algos, e.g., mappo, masac, etc. + if self._agent_specific_global_state: + agent_specifig_global_state = gym.spaces.Box( + low=float("-inf"), + high=float("inf"), + shape=( + self._num_agents, self._env.observation_space('agent_0').shape[0] + 4 * self._num_agents + + 2 * self._num_landmarks + 2 * self._num_agents * (self._num_agents - 1) + ), + dtype=np.float32 + ) + self._observation_space['global_state'] = agent_specifig_global_state + else: + # for case when env.agent_obs_only=True + self._observation_space = gym.spaces.Box( + low=float("-inf"), + high=float("inf"), + shape=(self._num_agents, self._env.observation_space('agent_0').shape[0]), + dtype=np.float32 + ) + + self._reward_space = gym.spaces.Dict( + { + agent: gym.spaces.Box(low=float("-inf"), high=float("inf"), shape=(1, ), dtype=np.float32) + for agent in self._agents + } + ) + self._init_flag = True + # self._eval_episode_return = {agent: 0. for agent in self._agents} + self._eval_episode_return = 0. 
+ self._step_count = 0 + obs_n = self._process_obs(obs) + return obs_n + + def close(self) -> None: + if self._init_flag: + self._env.close() + self._init_flag = False + + def render(self) -> None: + self._env.render() + + def seed(self, seed: int, dynamic_seed: bool = True) -> None: + self._seed = seed + self._dynamic_seed = dynamic_seed + np.random.seed(self._seed) + + def step(self, action: dict) -> BaseEnvTimestep: + self._step_count += 1 + action = np.array(list(action.values())) + action = self._process_action(action) + if self._act_scale: + for agent in self._agents: + # print(action[agent]) + # print(self.action_space[agent]) + # print(self.action_space[agent].low, self.action_space[agent].high) + action[agent] = affine_transform( + action[agent], min_val=self.action_space[agent].low, max_val=self.action_space[agent].high + ) + + obs, rew, done, trunc, info = self._env.step(action) + obs_n = self._process_obs(obs) + rew_n = np.array([sum([rew[agent] for agent in self._agents])]) + rew_n = rew_n.astype(np.float32) + # collide_sum = 0 + # for i in range(self._num_agents): + # collide_sum += info['n'][i][1] + # collide_penalty = self._cfg.get('collide_penal', self._num_agent) + # rew_n += collide_sum * (1.0 - collide_penalty) + # rew_n = rew_n / (self._cfg.get('max_cycles', 25) * self._num_agent) + self._eval_episode_return += rew_n.item() + + # occupied_landmarks = info['n'][0][3] + # if self._step_count >= self._max_step or occupied_landmarks >= self._n_agent \ + # or occupied_landmarks >= self._num_landmarks: + # done_n = True + # else: + # done_n = False + done_n = reduce(lambda x, y: x and y, done.values()) or self._step_count >= self._max_cycles + + # for agent in self._agents: + # self._eval_episode_return[agent] += rew[agent] + if done_n: # or reduce(lambda x, y: x and y, done.values()) + info['eval_episode_return'] = self._eval_episode_return + # for agent in rew: + # rew[agent] = to_ndarray([rew[agent]]) + return BaseEnvTimestep(obs_n, rew_n, done_n, info) + + def enable_save_replay(self, replay_path: Optional[str] = None) -> None: + if replay_path is None: + replay_path = './video' + self._replay_path = replay_path + + def _process_obs(self, obs: 'torch.Tensor') -> np.ndarray: # noqa + obs = np.array([obs[agent] for agent in self._agents]).astype(np.float32) + if self._cfg.get('agent_obs_only', False): + return obs + ret = {} + # Raw agent observation structure is -- + # [self_vel, self_pos, landmark_rel_positions, other_agent_rel_positions, communication] + # where `communication` are signals from other agents (two for each agent in `simple_spread_v2`` env) + + # agent_state: Shape (n_agent, 2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2). + # Stacked observation. Contains + # - agent itself's state(velocity + position) + # - position of items that the agent can observe(e.g. other agents, landmarks) + # - communication + ret['agent_state'] = obs + # global_state: Shape (n_agent * (2 + 2) + n_landmark * 2 + n_agent * (n_agent - 1) * 2, ). + # 1-dim vector. 
Contains + # - all agents' state(velocity + position) + + # - all landmarks' position + + # - all agents' communication + ret['global_state'] = np.concatenate( + [ + obs[0, 2:-(self._num_agents - 1) * 2], # all agents' position + all landmarks' position + obs[:, 0:2].flatten(), # all agents' velocity + obs[:, -(self._num_agents - 1) * 2:].flatten() # all agents' communication + ] + ) + # agent_specific_global_state: Shape (n_agent, 2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2 + n_agent * (2 + 2) + n_landmark * 2 + n_agent * (n_agent - 1) * 2). + # 2-dim vector. contains + # - agent_state info + # - global_state info + if self._agent_specific_global_state: + ret['global_state'] = np.concatenate( + [ret['agent_state'], + np.expand_dims(ret['global_state'], axis=0).repeat(self._num_agents, axis=0)], + axis=1 + ) + # agent_alone_state: Shape (n_agent, 2 + 2 + n_landmark * 2 + (n_agent - 1) * 2). + # Stacked observation. Exclude other agents' positions from agent_state. Contains + # - agent itself's state(velocity + position) + + # - landmarks' positions (do not include other agents' positions) + # - communication + ret['agent_alone_state'] = np.concatenate( + [ + obs[:, 0:(4 + self._num_agents * 2)], # agent itself's state + landmarks' position + obs[:, -(self._num_agents - 1) * 2:], # communication + ], + 1 + ) + # agent_alone_padding_state: Shape (n_agent, 2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2). + # Contains the same information as agent_alone_state; + # But 0-padding other agents' positions. + ret['agent_alone_padding_state'] = np.concatenate( + [ + obs[:, 0:(4 + self._num_agents * 2)], # agent itself's state + landmarks' position + np.zeros((self._num_agents, + (self._num_agents - 1) * 2), np.float32), # Other agents' position(0-padding) + obs[:, -(self._num_agents - 1) * 2:] # communication + ], + 1 + ) + # action_mask: All actions are of use(either 1 for discrete or 5 for continuous). Thus all 1. 
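# For concreteness (illustrative only, using the n_agent = n_landmark = 3 setting from the
# simple_spread configs in this patch), the shapes documented above evaluate to:
#   agent_state:                 (3, 2 + 2 + 6 + 4 + 4) = (3, 18)
#   global_state:                (3 * 4 + 3 * 2 + 3 * 2 * 2,) = (30,)
#   agent_specific_global_state: (3, 18 + 30) = (3, 48)
#   agent_alone_state:           (3, 2 + 2 + 6 + 4) = (3, 14)
#   agent_alone_padding_state:   (3, 18)
#   action_mask:                 (3, 5)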
+ # action_mask = np.ones((self._num_agents, *self._action_dim)).astype(np.float32) + action_mask = [[1 for _ in range(*self._action_dim)] for _ in range(self._num_agents)] + to_play = [-1 for _ in self._agents] # Moot, for alignment with other environments + + ret_transform = [] + for i in range(len(self.agents)): + tmp = {} + for k,v in ret.items(): + tmp[k] = v[i] + tmp['action_mask'] = [1 for _ in range(*self._action_dim)] + ret_transform.append(tmp) + + return {'observation': ret_transform, 'action_mask': action_mask, 'to_play': to_play} + + def _process_action(self, action: 'torch.Tensor') -> Dict[str, np.ndarray]: # noqa + dict_action = {} + for i, agent in enumerate(self._agents): + agent_action = action[i] + if agent_action.shape == (1, ): + agent_action = agent_action.squeeze() # 0-dim array + dict_action[agent] = agent_action + return dict_action + + def random_action(self) -> np.ndarray: + random_action = self.action_space.sample() + for k in random_action: + if isinstance(random_action[k], np.ndarray): + pass + elif isinstance(random_action[k], int): + random_action[k] = to_ndarray([random_action[k]], dtype=np.int64) + return random_action + + def __repr__(self) -> str: + return "DI-engine PettingZoo Env" + + @property + def agents(self) -> List[str]: + return self._agents + + @property + def observation_space(self) -> gym.spaces.Space: + return self._observation_space + + @property + def action_space(self) -> gym.spaces.Space: + return self._action_space + + @property + def reward_space(self) -> gym.spaces.Space: + return self._reward_space + + +class simple_spread_raw_env(SimpleEnv): + + def __init__(self, N=3, local_ratio=0.5, max_cycles=25, continuous_actions=False): + assert 0. <= local_ratio <= 1., "local_ratio is a proportion. Must be between 0 and 1." + scenario = Scenario() + world = scenario.make_world(N) + super().__init__(scenario, world, max_cycles, continuous_actions=continuous_actions, local_ratio=local_ratio) + self.metadata['name'] = "simple_spread_v2" + + def _execute_world_step(self): + # set action for each agent + for i, agent in enumerate(self.world.agents): + action = self.current_actions[i] + scenario_action = [] + if agent.movable: + mdim = self.world.dim_p * 2 + 1 + if self.continuous_actions: + scenario_action.append(action[0:mdim]) + action = action[mdim:] + else: + scenario_action.append(action % mdim) + action //= mdim + if not agent.silent: + scenario_action.append(action) + self._set_action(scenario_action, agent, self.action_spaces[agent.name]) + + self.world.step() + + global_reward = 0. 
+ if self.local_ratio is not None: + global_reward = float(self.scenario.global_reward(self.world)) + + for agent in self.world.agents: + agent_reward = float(self.scenario.reward(agent, self.world)) + if self.local_ratio is not None: + # we changed reward calc way to keep same with mpe + # reward = global_reward * (1 - self.local_ratio) + agent_reward * self.local_ratio + reward = global_reward + agent_reward + else: + reward = agent_reward + + self.rewards[agent.name] = reward diff --git a/zoo/petting_zoo/envs/test_petting_zoo_simple_spread_env.py b/zoo/petting_zoo/envs/test_petting_zoo_simple_spread_env.py new file mode 100644 index 000000000..22117cf85 --- /dev/null +++ b/zoo/petting_zoo/envs/test_petting_zoo_simple_spread_env.py @@ -0,0 +1,133 @@ +from easydict import EasyDict +import pytest +import numpy as np +import pettingzoo +from ding.utils import import_module + +from dizoo.petting_zoo.envs.petting_zoo_simple_spread_env import PettingZooEnv + + +@pytest.mark.envtest +class TestPettingZooEnv: + + def test_agent_obs_only(self): + n_agent = 5 + n_landmark = n_agent + env = PettingZooEnv( + EasyDict( + dict( + env_family='mpe', + env_id='simple_spread_v2', + n_agent=n_agent, + n_landmark=n_landmark, + max_step=100, + agent_obs_only=True, + continuous_actions=True, + ) + ) + ) + env.seed(123) + assert env._seed == 123 + obs = env.reset() + assert obs.shape == (n_agent, 2 + 2 + (n_agent - 1) * 2 + n_agent * 2 + (n_agent - 1) * 2) + for i in range(10): + random_action = env.random_action() + random_action = np.array([random_action[agent] for agent in random_action]) + timestep = env.step(random_action) + print(timestep) + assert isinstance(timestep.obs, np.ndarray), timestep.obs + assert timestep.obs.shape == (n_agent, 2 + 2 + (n_agent - 1) * 2 + n_agent * 2 + (n_agent - 1) * 2) + assert isinstance(timestep.done, bool), timestep.done + assert isinstance(timestep.reward, np.ndarray), timestep.reward + assert timestep.reward.dtype == np.float32 + print(env.observation_space, env.action_space, env.reward_space) + env.close() + + def test_dict_obs(self): + n_agent = 5 + n_landmark = n_agent + env = PettingZooEnv( + EasyDict( + dict( + env_family='mpe', + env_id='simple_spread_v2', + n_agent=n_agent, + n_landmark=n_landmark, + max_step=100, + agent_obs_only=False, + continuous_actions=True, + ) + ) + ) + env.seed(123) + assert env._seed == 123 + obs = env.reset() + for k, v in obs.items(): + print(k, v.shape) + for i in range(10): + random_action = env.random_action() + random_action = np.array([random_action[agent] for agent in random_action]) + timestep = env.step(random_action) + print(timestep) + assert isinstance(timestep.obs, dict), timestep.obs + assert isinstance(timestep.obs['agent_state'], np.ndarray), timestep.obs + assert timestep.obs['agent_state'].shape == ( + n_agent, 2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2 + ) + assert timestep.obs['global_state'].shape == ( + n_agent * (2 + 2) + n_landmark * 2 + n_agent * (n_agent - 1) * 2, + ) + assert timestep.obs['agent_alone_state'].shape == (n_agent, 2 + 2 + n_landmark * 2 + (n_agent - 1) * 2) + assert timestep.obs['agent_alone_padding_state'].shape == ( + n_agent, 2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2 + ) + assert timestep.obs['action_mask'].dtype == np.float32 + assert isinstance(timestep.done, bool), timestep.done + assert isinstance(timestep.reward, np.ndarray), timestep.reward + print(env.observation_space, env.action_space, env.reward_space) + env.close() + + def 
test_agent_specific_global_state(self): + n_agent = 5 + n_landmark = n_agent + env = PettingZooEnv( + EasyDict( + dict( + env_family='mpe', + env_id='simple_spread_v2', + n_agent=n_agent, + n_landmark=n_landmark, + max_step=100, + agent_obs_only=False, + agent_specific_global_state=True, + continuous_actions=True, + ) + ) + ) + env.seed(123) + assert env._seed == 123 + obs = env.reset() + for k, v in obs.items(): + print(k, v.shape) + for i in range(10): + random_action = env.random_action() + random_action = np.array([random_action[agent] for agent in random_action]) + timestep = env.step(random_action) + print(timestep) + assert isinstance(timestep.obs, dict), timestep.obs + assert isinstance(timestep.obs['agent_state'], np.ndarray), timestep.obs + assert timestep.obs['agent_state'].shape == ( + n_agent, 2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2 + ) + assert timestep.obs['global_state'].shape == ( + n_agent, 2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2 + n_agent * (2 + 2) + + n_landmark * 2 + n_agent * (n_agent - 1) * 2 + ) + assert timestep.obs['agent_alone_state'].shape == (n_agent, 2 + 2 + n_landmark * 2 + (n_agent - 1) * 2) + assert timestep.obs['agent_alone_padding_state'].shape == ( + n_agent, 2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2 + ) + assert isinstance(timestep.done, bool), timestep.done + assert isinstance(timestep.reward, np.ndarray), timestep.reward + print(env.observation_space, env.action_space, env.reward_space) + env.close() From e4667df1d7a895976dd8ec9fc1c9008f2793c5c6 Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Fri, 4 Aug 2023 13:31:41 +0800 Subject: [PATCH 34/54] polish(yzj): polish multi agent muzero collector --- .../multi_agent_game_buffer_efficientzero.py | 1 - lzero/worker/multi_agent_muzero_collector.py | 27 ------------------- 2 files changed, 28 deletions(-) diff --git a/lzero/mcts/buffer/multi_agent_game_buffer_efficientzero.py b/lzero/mcts/buffer/multi_agent_game_buffer_efficientzero.py index 84382895b..5debf8db8 100644 --- a/lzero/mcts/buffer/multi_agent_game_buffer_efficientzero.py +++ b/lzero/mcts/buffer/multi_agent_game_buffer_efficientzero.py @@ -20,7 +20,6 @@ class MultiAgentSampledEfficientZeroGameBuffer(MultiAgentMuZeroGameBuffer): The specific game buffer for Multi Agent EfficientZero policy. """ - def _prepare_reward_value_context( self, batch_index_list: List[str], game_segment_list: List[Any], pos_in_game_segment_list: List[Any], total_transitions: int diff --git a/lzero/worker/multi_agent_muzero_collector.py b/lzero/worker/multi_agent_muzero_collector.py index ad0690d0b..bd5dc7544 100644 --- a/lzero/worker/multi_agent_muzero_collector.py +++ b/lzero/worker/multi_agent_muzero_collector.py @@ -27,33 +27,6 @@ class MultiAgentMuZeroCollector(MuZeroCollector): envstep """ - # TO be compatible with ISerialCollector - config = dict() - - def __init__( - self, - collect_print_freq: int = 100, - env: BaseEnvManager = None, - policy: namedtuple = None, - tb_logger: 'SummaryWriter' = None, # noqa - exp_name: Optional[str] = 'default_experiment', - instance_name: Optional[str] = 'collector', - policy_config: 'policy_config' = None, # noqa - ) -> None: - """ - Overview: - Init the collector according to input arguments. - Arguments: - - collect_print_freq (:obj:`int`): collect_print_frequency in terms of training_steps. 
- - env (:obj:`BaseEnvManager`): the subclass of vectorized env_manager(BaseEnvManager) - - policy (:obj:`namedtuple`): the api namedtuple of collect_mode policy - - tb_logger (:obj:`SummaryWriter`): tensorboard handle - - instance_name (:obj:`Optional[str]`): Name of this instance. - - exp_name (:obj:`str`): Experiment name, which is used to indicate output directory. - - policy_config: Config of game. - """ - super().__init__(collect_print_freq, env, policy, tb_logger, exp_name, instance_name, policy_config) - def _compute_priorities(self, i, agent_id, pred_values_lst, search_values_lst): """ Overview: From b6dca699cd8d1163cb49315be305715a3a71bfcd Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Tue, 8 Aug 2023 19:30:48 +0800 Subject: [PATCH 35/54] polish(yzj): polish gobigger collector and config to support t2p3 --- lzero/worker/gobigger_muzero_evaluator.py | 40 ++----------------- .../config/gobigger_efficientzero_config.py | 10 ++++- zoo/gobigger/config/gobigger_muzero_config.py | 10 ++++- 3 files changed, 19 insertions(+), 41 deletions(-) diff --git a/lzero/worker/gobigger_muzero_evaluator.py b/lzero/worker/gobigger_muzero_evaluator.py index 440827936..0bf6bdc05 100644 --- a/lzero/worker/gobigger_muzero_evaluator.py +++ b/lzero/worker/gobigger_muzero_evaluator.py @@ -18,41 +18,6 @@ class GoBiggerMuZeroEvaluator(MuZeroEvaluator): - """ - Overview: - The Evaluator for GoBigger MCTS+RL algorithms, including MuZero, EfficientZero, Sampled EfficientZero. - Interfaces: - __init__, reset, reset_policy, reset_env, close, should_eval, eval - Property: - env, policy - """ - def __init__( - self, - eval_freq: int = 1000, - n_evaluator_episode: int = 3, - stop_value: int = 1e6, - env: BaseEnvManager = None, - policy: namedtuple = None, - tb_logger: 'SummaryWriter' = None, # noqa - exp_name: Optional[str] = 'default_experiment', - instance_name: Optional[str] = 'evaluator', - policy_config: 'policy_config' = None, # noqa - ) -> None: - """ - Overview: - Init method. Load config and use ``self._cfg`` setting to build common serial evaluator components, - e.g. logger helper, timer. - Arguments: - - eval_freq (:obj:`int`): evaluation frequency in terms of training steps. - - n_evaluator_episode (:obj:`int`): the number of episodes to eval in total. - - env (:obj:`BaseEnvManager`): the subclass of vectorized env_manager(BaseEnvManager) - - policy (:obj:`namedtuple`): the api namedtuple of collect_mode policy - - tb_logger (:obj:`SummaryWriter`): tensorboard handle - - exp_name (:obj:`str`): Experiment name, which is used to indicate output directory. - - instance_name (:obj:`Optional[str]`): Name of this instance. - - policy_config: Config of game. 
- """ - super().__init__(eval_freq, n_evaluator_episode, stop_value, env, policy, tb_logger, exp_name, instance_name, policy_config) def _add_info(self, last_timestep, info): # add eat info @@ -92,7 +57,9 @@ def eval_vsbot( self._policy.reset() # specifically for vs bot - self._bot_policy = GoBiggerBot(env_nums, agent_id=[2, 3]) #TODO only support t2p2 + agent_num = self.policy_config['model']['agent_num'] + team_num = self.policy_config['model']['team_num'] + self._bot_policy = GoBiggerBot(env_nums, agent_id=[i for i in range(agent_num//team_num, agent_num)]) #TODO only support t2p2 self._bot_policy.reset() # initializations @@ -112,7 +79,6 @@ def eval_vsbot( init_obs = self._env.ready_obs # specifically for vs bot - agent_num = len(init_obs[0]['action_mask']) // 2 #TODO only support t2p2 for i in range(env_nums): for k, v in init_obs[i].items(): if k != 'raw_obs': diff --git a/zoo/gobigger/config/gobigger_efficientzero_config.py b/zoo/gobigger/config/gobigger_efficientzero_config.py index 23c8e5a11..b6fb47086 100644 --- a/zoo/gobigger/config/gobigger_efficientzero_config.py +++ b/zoo/gobigger/config/gobigger_efficientzero_config.py @@ -17,6 +17,9 @@ action_space_size = 27 direction_num = 12 eps_greedy_exploration_in_collect = True +player_num_per_team = 3 +team_num = 2 +agent_num = player_num_per_team*team_num # default is GoBigger T2P2 # ============================================================== # end of the most frequently changed config specified by the user # ============================================================== @@ -25,7 +28,9 @@ exp_name= f'data_ez_ctree/{env_name}_efficientzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed{seed}', env=dict( - env_name=env_name, # default is 'GoBigger T2P2' + env_name=env_name, + player_num_per_team=player_num_per_team, + team_num=team_num, collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, n_evaluator_episode=evaluator_env_num, @@ -37,7 +42,8 @@ model=dict( model_type='structure', env_name=env_name, - agent_num=4, # default is t2p2 + agent_num=agent_num, + team_num=team_num, latent_state_dim=176, frame_stack_num=1, action_space_size=action_space_size, diff --git a/zoo/gobigger/config/gobigger_muzero_config.py b/zoo/gobigger/config/gobigger_muzero_config.py index df2eacf73..a5ad0b80e 100644 --- a/zoo/gobigger/config/gobigger_muzero_config.py +++ b/zoo/gobigger/config/gobigger_muzero_config.py @@ -17,6 +17,9 @@ action_space_size = 27 direction_num = 12 eps_greedy_exploration_in_collect = True +player_num_per_team = 2 +team_num = 2 +agent_num = player_num_per_team*team_num # default is GoBigger T2P2 # ============================================================== # end of the most frequently changed config specified by the user # ============================================================== @@ -24,7 +27,9 @@ gobigger_muzero_config = dict( exp_name=f'data_mz_ctree/{env_name}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed{seed}', env=dict( - env_name=env_name, # default is 'GoBigger T2P2' + env_name=env_name, + player_num_per_team=player_num_per_team, + team_num=team_num, collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, n_evaluator_episode=evaluator_env_num, @@ -36,7 +41,8 @@ model=dict( model_type='structure', env_name=env_name, - agent_num=4, # default is t2p2 + agent_num=agent_num, + team_num=team_num, latent_state_dim=176, frame_stack_num=1, action_space_size=action_space_size, From 09a4440a9cfde30ff503a0de9ca2d5d5f96807d8 Mon Sep 17 
00:00:00 2001 From: jayyoung0802 Date: Tue, 8 Aug 2023 20:17:41 +0800 Subject: [PATCH 36/54] feature(yzj): add fc encoder on ptz env instead of identity --- lzero/model/petting_zoo/encoder.py | 3 ++- zoo/petting_zoo/config/ptz_simple_spread_ez_config.py | 4 +++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/lzero/model/petting_zoo/encoder.py b/lzero/model/petting_zoo/encoder.py index 015ca6066..9f3313731 100644 --- a/lzero/model/petting_zoo/encoder.py +++ b/lzero/model/petting_zoo/encoder.py @@ -1,10 +1,11 @@ import torch.nn as nn +from ding.model.common import FCEncoder class PettingZooEncoder(nn.Module): def __init__(self): super().__init__() - self.encoder = nn.Identity() + self.encoder = FCEncoder(obs_shape=18, hidden_size_list=[256, 256], activation=nn.ReLU(), norm_type=None) def forward(self, x): x = x['agent_state'] diff --git a/zoo/petting_zoo/config/ptz_simple_spread_ez_config.py b/zoo/petting_zoo/config/ptz_simple_spread_ez_config.py index 109b472af..5a874d393 100644 --- a/zoo/petting_zoo/config/ptz_simple_spread_ez_config.py +++ b/zoo/petting_zoo/config/ptz_simple_spread_ez_config.py @@ -46,7 +46,7 @@ model=dict( model_type='structure', env_name=env_name, - latent_state_dim=18, + latent_state_dim=256, frame_stack_num=1, action_space='discrete', action_space_size=action_space_size, @@ -55,6 +55,8 @@ global_obs_shape=2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2 + n_agent * (2 + 2) + n_landmark * 2 + n_agent * (n_agent - 1) * 2, discrete_action_encoding_type='one_hot', + global_cooperation=True, # TODO: doesn't work now + hidden_size_list=[256, 256], norm_type='BN', ), cuda=True, From 407329a1e359be8d9e1125b846b4c6b5dbe25dee Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Thu, 10 Aug 2023 17:56:55 +0800 Subject: [PATCH 37/54] polish(yzj): polish buffer name and remove ignore done in atari config --- lzero/entry/train_muzero.py | 2 +- lzero/mcts/buffer/__init__.py | 2 +- lzero/mcts/buffer/multi_agent_game_buffer_efficientzero.py | 2 +- zoo/atari/config/atari_efficientzero_config.py | 1 - 4 files changed, 3 insertions(+), 4 deletions(-) diff --git a/lzero/entry/train_muzero.py b/lzero/entry/train_muzero.py index 3ffd4202c..f463b266f 100644 --- a/lzero/entry/train_muzero.py +++ b/lzero/entry/train_muzero.py @@ -58,7 +58,7 @@ def train_muzero( elif create_cfg.policy.type == 'multi_agent_muzero': from lzero.mcts import MultiAgentMuZeroGameBuffer as GameBuffer elif create_cfg.policy.type == 'multi_agent_efficientzero': - from lzero.mcts import MultiAgentSampledEfficientZeroGameBuffer as GameBuffer + from lzero.mcts import MultiAgentEfficientZeroGameBuffer as GameBuffer if cfg.policy.cuda and torch.cuda.is_available(): cfg.policy.device = 'cuda' diff --git a/lzero/mcts/buffer/__init__.py b/lzero/mcts/buffer/__init__.py index 78864d59e..25cea5ccb 100644 --- a/lzero/mcts/buffer/__init__.py +++ b/lzero/mcts/buffer/__init__.py @@ -3,4 +3,4 @@ from .game_buffer_sampled_efficientzero import SampledEfficientZeroGameBuffer from .game_buffer_gumbel_muzero import GumbelMuZeroGameBuffer from .multi_agent_game_buffer_muzero import MultiAgentMuZeroGameBuffer -from .multi_agent_game_buffer_efficientzero import MultiAgentSampledEfficientZeroGameBuffer \ No newline at end of file +from .multi_agent_game_buffer_efficientzero import MultiAgentEfficientZeroGameBuffer \ No newline at end of file diff --git a/lzero/mcts/buffer/multi_agent_game_buffer_efficientzero.py b/lzero/mcts/buffer/multi_agent_game_buffer_efficientzero.py index 5debf8db8..0f04fa8b0 100644 --- 
a/lzero/mcts/buffer/multi_agent_game_buffer_efficientzero.py +++ b/lzero/mcts/buffer/multi_agent_game_buffer_efficientzero.py @@ -14,7 +14,7 @@ @BUFFER_REGISTRY.register('multi_agent_game_buffer_efficientzero') -class MultiAgentSampledEfficientZeroGameBuffer(MultiAgentMuZeroGameBuffer): +class MultiAgentEfficientZeroGameBuffer(MultiAgentMuZeroGameBuffer): """ Overview: The specific game buffer for Multi Agent EfficientZero policy. diff --git a/zoo/atari/config/atari_efficientzero_config.py b/zoo/atari/config/atari_efficientzero_config.py index 5861deb50..1de1800e2 100644 --- a/zoo/atari/config/atari_efficientzero_config.py +++ b/zoo/atari/config/atari_efficientzero_config.py @@ -52,7 +52,6 @@ norm_type='BN', ), cuda=True, - ignore_done=True, env_type='not_board_games', game_segment_length=400, random_collect_episode_num=0, From 592fab188bf37128953c4ba342c8913a4dce9a41 Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Mon, 14 Aug 2023 21:04:18 +0800 Subject: [PATCH 38/54] fix(yzj): fix ssl data bug and polish to_device code --- .../multi_agent_game_buffer_efficientzero.py | 3 +- .../buffer/multi_agent_game_buffer_muzero.py | 3 +- lzero/mcts/utils.py | 2 +- lzero/policy/efficientzero.py | 32 +++++++++++---- lzero/policy/multi_agent_efficientzero.py | 12 +----- lzero/policy/multi_agent_muzero.py | 11 +----- lzero/policy/multi_agent_random_policy.py | 1 - lzero/policy/muzero.py | 18 +++++---- lzero/policy/utils.py | 39 +++++++++++++++---- lzero/worker/muzero_collector.py | 14 ++++--- .../config/gobigger_efficientzero_config.py | 2 +- zoo/gobigger/config/gobigger_muzero_config.py | 2 +- zoo/gobigger/entry/train_muzero_gobigger.py | 2 +- .../config/ptz_simple_spread_ez_config.py | 2 +- .../config/ptz_simple_spread_mz_config.py | 8 ++-- .../envs/petting_zoo_simple_spread_env.py | 4 +- 16 files changed, 93 insertions(+), 62 deletions(-) diff --git a/lzero/mcts/buffer/multi_agent_game_buffer_efficientzero.py b/lzero/mcts/buffer/multi_agent_game_buffer_efficientzero.py index 0f04fa8b0..a592df4b0 100644 --- a/lzero/mcts/buffer/multi_agent_game_buffer_efficientzero.py +++ b/lzero/mcts/buffer/multi_agent_game_buffer_efficientzero.py @@ -134,10 +134,9 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A #m_obs = torch.from_numpy(value_obs_list[beg_index:end_index]).to(self._cfg.device).float() m_obs = value_obs_list[beg_index:end_index] - m_obs = to_tensor(m_obs) m_obs = sum(m_obs, []) - m_obs = to_device(m_obs, self._cfg.device) m_obs = default_collate(m_obs) + m_obs = to_device(m_obs, self._cfg.device) # calculate the target value m_output = model.initial_inference(m_obs) diff --git a/lzero/mcts/buffer/multi_agent_game_buffer_muzero.py b/lzero/mcts/buffer/multi_agent_game_buffer_muzero.py index eadb473d9..53518ac82 100644 --- a/lzero/mcts/buffer/multi_agent_game_buffer_muzero.py +++ b/lzero/mcts/buffer/multi_agent_game_buffer_muzero.py @@ -152,10 +152,9 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A # m_obs = torch.from_numpy(value_obs_list[beg_index:end_index]).to(self._cfg.device).float() m_obs = value_obs_list[beg_index:end_index] - m_obs = to_tensor(m_obs) m_obs = sum(m_obs, []) - m_obs = to_device(m_obs, self._cfg.device) m_obs = default_collate(m_obs) + m_obs = to_device(m_obs, self._cfg.device) # calculate the target value m_output = model.initial_inference(m_obs) diff --git a/lzero/mcts/utils.py b/lzero/mcts/utils.py index 57915ee8c..c46b95578 100644 --- a/lzero/mcts/utils.py +++ b/lzero/mcts/utils.py @@ -99,7 +99,7 @@ def 
prepare_observation(observation_list, model_type='conv'): # print(observation_array.shape) elif model_type == 'structure': - return observation_list + return observation_list return observation_array diff --git a/lzero/policy/efficientzero.py b/lzero/policy/efficientzero.py index 1071c3475..bfd4681d2 100644 --- a/lzero/policy/efficientzero.py +++ b/lzero/policy/efficientzero.py @@ -403,14 +403,30 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: end_index = self._cfg.model.observation_shape * (step_i + self._cfg.model.frame_stack_num) network_output = self._learn_model.initial_inference(obs_target_batch[:, beg_index:end_index]) elif self._cfg.model.model_type == 'structure': - beg_index = step_i - end_index = step_i + self._cfg.model.frame_stack_num - obs_target_batch_tmp = obs_target_batch[:, beg_index:end_index].tolist() - obs_target_batch_tmp = sum(obs_target_batch_tmp, []) - obs_target_batch_tmp = to_tensor(obs_target_batch_tmp) - obs_target_batch_tmp = to_device(obs_target_batch_tmp, self._cfg.device) - obs_target_batch_tmp = default_collate(obs_target_batch_tmp) - network_output = self._learn_model.initial_inference(obs_target_batch_tmp) + obs_target_batch_new = {} + for k, v in obs_target_batch.items(): + if k == 'action_mask': + obs_target_batch_new[k] = v + continue + if isinstance(v, dict): + obs_target_batch_new[k] = {} + for k1, v1 in v.items(): + if len(v1.shape) == 1: + observation_shape = v1.shape[0]//self._cfg.num_unroll_steps + beg_index = observation_shape * step_i + end_index = observation_shape * (step_i + self._cfg.model.frame_stack_num) + obs_target_batch_new[k][k1] = v1[beg_index:end_index] + else: + observation_shape = v1.shape[1]//self._cfg.num_unroll_steps + beg_index = observation_shape * step_i + end_index = observation_shape * (step_i + self._cfg.model.frame_stack_num) + obs_target_batch_new[k][k1] = v1[:, beg_index:end_index] + else: + observation_shape = v.shape[1]//self._cfg.num_unroll_steps + beg_index = observation_shape * step_i + end_index = observation_shape * (step_i + self._cfg.model.frame_stack_num) + obs_target_batch_new[k] = v[:, beg_index:end_index] + network_output = self._learn_model.initial_inference(obs_target_batch_new) latent_state = to_tensor(latent_state) representation_state = to_tensor(network_output.latent_state) diff --git a/lzero/policy/multi_agent_efficientzero.py b/lzero/policy/multi_agent_efficientzero.py index d01c66f58..ce5f2eb54 100644 --- a/lzero/policy/multi_agent_efficientzero.py +++ b/lzero/policy/multi_agent_efficientzero.py @@ -3,10 +3,7 @@ import numpy as np import torch from .efficientzero import EfficientZeroPolicy -from ding.torch_utils import to_tensor from ding.utils import POLICY_REGISTRY -from torch.distributions import Categorical -from torch.nn import L1Loss from lzero.mcts import EfficientZeroMCTSCtree as MCTSCtree from lzero.mcts import EfficientZeroMCTSPtree as MCTSPtree @@ -14,7 +11,6 @@ DiscreteSupport, select_action, to_torch_float_tensor, ez_network_output_unpack, negative_cosine_similarity, prepare_obs, \ configure_optimizers from collections import defaultdict -from ding.torch_utils import to_device, to_tensor from ding.utils.data import default_collate @@ -62,12 +58,10 @@ def _forward_collect( self.collect_epsilon = epsilon active_collect_env_num = len(data) - data = to_tensor(data) data = sum(sum(data, []), []) batch_size = len(data) - data = to_device(data, self._cfg.device) data = default_collate(data) - agent_num = batch_size // active_collect_env_num + agent_num = 
self._cfg['model']['agent_num'] to_play = np.array(to_play).reshape(-1).tolist() with torch.no_grad(): @@ -164,12 +158,10 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, read """ self._eval_model.eval() active_eval_env_num = len(data) - data = to_tensor(data) data = sum(sum(data, []), []) batch_size = len(data) - data = to_device(data, self._cfg.device) data = default_collate(data) - agent_num = batch_size // active_eval_env_num + agent_num = self._cfg['model']['agent_num'] to_play = np.array(to_play).reshape(-1).tolist() with torch.no_grad(): diff --git a/lzero/policy/multi_agent_muzero.py b/lzero/policy/multi_agent_muzero.py index 6e5476a40..b77b5356b 100644 --- a/lzero/policy/multi_agent_muzero.py +++ b/lzero/policy/multi_agent_muzero.py @@ -4,20 +4,17 @@ import torch import torch.optim as optim from ding.model import model_wrap -from ding.policy.base_policy import Policy from ding.torch_utils import to_tensor from ding.utils import POLICY_REGISTRY from torch.nn import L1Loss from lzero.mcts import MuZeroMCTSCtree as MCTSCtree from lzero.mcts import MuZeroMCTSPtree as MCTSPtree -from lzero.model import ImageTransforms from lzero.policy import scalar_transform, InverseScalarTransform, cross_entropy_loss, phi_transform, \ DiscreteSupport, to_torch_float_tensor, mz_network_output_unpack, select_action, negative_cosine_similarity, prepare_obs, \ configure_optimizers from collections import defaultdict -from ding.torch_utils import to_device from ding.utils.data import default_collate from .muzero import MuZeroPolicy @@ -65,12 +62,10 @@ def _forward_collect( self.collect_mcts_temperature = temperature self.collect_epsilon = epsilon active_collect_env_num = len(data) - data = to_tensor(data) data = sum(sum(data, []), []) batch_size = len(data) - data = to_device(data, self._cfg.device) data = default_collate(data) - agent_num = batch_size // active_collect_env_num + agent_num = self._cfg['model']['agent_num'] to_play = np.array(to_play).reshape(-1).tolist() with torch.no_grad(): @@ -163,12 +158,10 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1 """ self._eval_model.eval() active_eval_env_num = len(data) - data = to_tensor(data) data = sum(sum(data, []), []) batch_size = len(data) - data = to_device(data, self._cfg.device) data = default_collate(data) - agent_num = batch_size // active_eval_env_num + agent_num = self._cfg['model']['agent_num'] to_play = np.array(to_play).reshape(-1).tolist() with torch.no_grad(): diff --git a/lzero/policy/multi_agent_random_policy.py b/lzero/policy/multi_agent_random_policy.py index 1fc110b54..8d3e7083f 100644 --- a/lzero/policy/multi_agent_random_policy.py +++ b/lzero/policy/multi_agent_random_policy.py @@ -59,7 +59,6 @@ def _forward_collect( data = to_tensor(data) data = sum(sum(data, []), []) batch_size = len(data) - data = to_device(data, self._cfg.device) data = default_collate(data) agent_num = batch_size // active_collect_env_num to_play = np.array(to_play).reshape(-1).tolist() diff --git a/lzero/policy/muzero.py b/lzero/policy/muzero.py index 7d3297f1e..e2a8c8c97 100644 --- a/lzero/policy/muzero.py +++ b/lzero/policy/muzero.py @@ -382,14 +382,16 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in end_index = self._cfg.model.observation_shape * (step_i + self._cfg.model.frame_stack_num) network_output = self._learn_model.initial_inference(obs_target_batch[:, beg_index:end_index]) elif self._cfg.model.model_type == 'structure': - beg_index = step_i - 
end_index = step_i + self._cfg.model.frame_stack_num - obs_target_batch_tmp = obs_target_batch[:, beg_index:end_index].tolist() - obs_target_batch_tmp = sum(obs_target_batch_tmp, []) - obs_target_batch_tmp = to_tensor(obs_target_batch_tmp) - obs_target_batch_tmp = to_device(obs_target_batch_tmp, self._cfg.device) - obs_target_batch_tmp = default_collate(obs_target_batch_tmp) - network_output = self._learn_model.initial_inference(obs_target_batch_tmp) + obs_target_batch_new = {} + for k, v in obs_target_batch.items(): + if k == 'action_mask': + obs_target_batch_new[k] = v + continue + observation_shape = v.shape[1]//self._cfg.num_unroll_steps + beg_index = observation_shape * step_i + end_index = observation_shape * (step_i + self._cfg.model.frame_stack_num) + obs_target_batch_new[k] = v[:, beg_index:end_index] + network_output = self._learn_model.initial_inference(obs_target_batch_new) latent_state = to_tensor(latent_state) representation_state = to_tensor(network_output.latent_state) diff --git a/lzero/policy/utils.py b/lzero/policy/utils.py index f65db53fa..cea065232 100644 --- a/lzero/policy/utils.py +++ b/lzero/policy/utils.py @@ -179,19 +179,44 @@ def prepare_obs(obs_batch_ori: np.ndarray, cfg: EasyDict) -> Tuple[torch.Tensor, # timestep t1, and is only performed in the last 8 timesteps in the second dim in ``obs_batch_ori``. obs_target_batch = obs_batch_ori[:, cfg.model.observation_shape:] elif cfg.model.model_type == 'structure': - obs_batch_ori = obs_batch_ori.tolist() - obs_batch_ori = np.array(obs_batch_ori) + # dict obs_shape = 1 + batch_size = obs_batch_ori.shape[0] obs_batch = obs_batch_ori[:, 0:cfg.model.frame_stack_num] if cfg.model.self_supervised_learning_loss: obs_target_batch = obs_batch_ori[:, cfg.model.frame_stack_num:] - + + # obs_batch obs_batch = obs_batch.tolist() obs_batch = sum(obs_batch, []) - obs_batch = to_tensor(obs_batch) - obs_batch = to_device(obs_batch, cfg.device) obs_batch = default_collate(obs_batch) - - return obs_batch, obs_target_batch + obs_batch_new = {} + for k, v in obs_batch.items(): + if isinstance(v, dict): # espaecially for gobigger obs, { {'k':{'k1':[], 'k2':[]},} + obs_batch_new[k] = {} + for k1, v1 in v.items(): + if len(v1.shape) == 1: + obs_batch_new[k][k1] = v1 + else: + obs_batch_new[k][k1] = v1.reshape(batch_size, -1) + else: # espaecially for ptz obs, {'k1':[], 'k2':[]} + obs_batch_new[k] = v.reshape(batch_size, -1) + + # obs_target_batch + obs_target_batch = obs_target_batch.tolist() + obs_target_batch = sum(obs_target_batch, []) + obs_target_batch = default_collate(obs_target_batch) + obs_target_batch_new = {} + for k, v in obs_target_batch.items(): + if isinstance(v, dict): # espaecially for gobigger obs, { {'k':{'k1':[], 'k2':[]},} + obs_target_batch_new[k] = {} + for k1, v1 in v.items(): + if len(v1.shape) == 1: + obs_target_batch_new[k][k1] = v1 + else: + obs_target_batch_new[k][k1] = v1.reshape(batch_size, -1) + else: # espaecially for ptz obs, {'k1':[], 'k2':[]} + obs_target_batch_new[k] = v.reshape(batch_size, -1) + return obs_batch_new, obs_target_batch_new def negative_cosine_similarity(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: diff --git a/lzero/worker/muzero_collector.py b/lzero/worker/muzero_collector.py index 63ac5100b..dc1aafbb5 100644 --- a/lzero/worker/muzero_collector.py +++ b/lzero/worker/muzero_collector.py @@ -415,11 +415,15 @@ def collect(self, action_mask = [action_mask_dict[env_id] for env_id in ready_env_id] to_play = [to_play_dict[env_id] for env_id in ready_env_id] - stack_obs = 
to_ndarray(stack_obs) - - if self.policy_config.model.model_type and self.policy_config.model.model_type in ['conv', 'mlp']: - stack_obs = prepare_observation(stack_obs, self.policy_config.model.model_type) - stack_obs = torch.from_numpy(stack_obs).to(self.policy_config.device).float() + if self.policy_config.model.model_type: + if self.policy_config.model.model_type in ['conv', 'mlp']: + stack_obs = to_ndarray(stack_obs) + stack_obs = prepare_observation(stack_obs, self.policy_config.model.model_type) + stack_obs = torch.from_numpy(stack_obs).to(self.policy_config.device).float() + elif self.policy_config.model.model_type == 'structure': + stack_obs = prepare_observation(stack_obs, self.policy_config.model.model_type) + else: + raise ValueError('model_type must be one of [conv, mlp, structure]') # ============================================================== # policy forward diff --git a/zoo/gobigger/config/gobigger_efficientzero_config.py b/zoo/gobigger/config/gobigger_efficientzero_config.py index b6fb47086..a5951aa32 100644 --- a/zoo/gobigger/config/gobigger_efficientzero_config.py +++ b/zoo/gobigger/config/gobigger_efficientzero_config.py @@ -17,7 +17,7 @@ action_space_size = 27 direction_num = 12 eps_greedy_exploration_in_collect = True -player_num_per_team = 3 +player_num_per_team = 2 team_num = 2 agent_num = player_num_per_team*team_num # default is GoBigger T2P2 # ============================================================== diff --git a/zoo/gobigger/config/gobigger_muzero_config.py b/zoo/gobigger/config/gobigger_muzero_config.py index a5ad0b80e..6a3da6632 100644 --- a/zoo/gobigger/config/gobigger_muzero_config.py +++ b/zoo/gobigger/config/gobigger_muzero_config.py @@ -46,7 +46,7 @@ latent_state_dim=176, frame_stack_num=1, action_space_size=action_space_size, - self_supervised_learning_loss=False, # default is False + self_supervised_learning_loss=True, # default is False discrete_action_encoding_type='one_hot', norm_type='BN', ), diff --git a/zoo/gobigger/entry/train_muzero_gobigger.py b/zoo/gobigger/entry/train_muzero_gobigger.py index 068c76ee1..73e3d231a 100644 --- a/zoo/gobigger/entry/train_muzero_gobigger.py +++ b/zoo/gobigger/entry/train_muzero_gobigger.py @@ -58,7 +58,7 @@ def train_muzero_gobigger( elif create_cfg.policy.type == 'gumbel_muzero': from lzero.mcts import GumbelMuZeroGameBuffer as GameBuffer elif create_cfg.policy.type == 'multi_agent_efficientzero': - from lzero.mcts import MultiAgentSampledEfficientZeroGameBuffer as GameBuffer + from lzero.mcts import MultiAgentEfficientZeroGameBuffer as GameBuffer elif create_cfg.policy.type == 'multi_agent_muzero': from lzero.mcts import MultiAgentMuZeroGameBuffer as GameBuffer diff --git a/zoo/petting_zoo/config/ptz_simple_spread_ez_config.py b/zoo/petting_zoo/config/ptz_simple_spread_ez_config.py index 5a874d393..51990d255 100644 --- a/zoo/petting_zoo/config/ptz_simple_spread_ez_config.py +++ b/zoo/petting_zoo/config/ptz_simple_spread_ez_config.py @@ -6,6 +6,7 @@ # ============================================================== # begin of the most frequently changed config specified by the user # ============================================================== +seed = 0 n_agent = 3 n_landmark = n_agent collector_env_num = 8 @@ -16,7 +17,6 @@ update_per_collect = 1000 reanalyze_ratio = 0. 
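# A quick arithmetic note on the team-size settings introduced in the config
# patches above (plain Python, runnable on its own, not part of the diffs):
# the configs now derive ``agent_num`` from ``player_num_per_team * team_num``,
# and the vs-bot evaluator hands every player outside the first (learned) team
# to the built-in bot, replacing the previously hard-coded ``agent_id=[2, 3]``.
player_num_per_team = 2                        # GoBigger T2P2 default in these configs
team_num = 2
agent_num = player_num_per_team * team_num     # 4 controlled players in total
bot_agent_ids = list(range(agent_num // team_num, agent_num))
assert bot_agent_ids == [2, 3]                 # the learned team keeps ids 0 and 1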
action_space_size = 5 -seed = 0 eps_greedy_exploration_in_collect = True # ============================================================== # end of the most frequently changed config specified by the user diff --git a/zoo/petting_zoo/config/ptz_simple_spread_mz_config.py b/zoo/petting_zoo/config/ptz_simple_spread_mz_config.py index a17f1fa6d..4550388c5 100644 --- a/zoo/petting_zoo/config/ptz_simple_spread_mz_config.py +++ b/zoo/petting_zoo/config/ptz_simple_spread_mz_config.py @@ -46,16 +46,18 @@ model=dict( model_type='structure', env_name=env_name, - latent_state_dim=18, + latent_state_dim=256, frame_stack_num=1, action_space='discrete', action_space_size=action_space_size, agent_num=n_agent, - self_supervised_learning_loss=False, # default is False + self_supervised_learning_loss=True, # default is False agent_obs_shape=2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2, global_obs_shape=2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2 + n_agent * (2 + 2) + n_landmark * 2 + n_agent * (n_agent - 1) * 2, discrete_action_encoding_type='one_hot', + global_cooperation=True, # TODO: doesn't work now + hidden_size_list=[256, 256], norm_type='BN', ), cuda=True, @@ -97,7 +99,7 @@ import_names=['zoo.petting_zoo.envs.petting_zoo_simple_spread_env'], type='petting_zoo', ), - env_manager=dict(type='base'), + env_manager=dict(type='subprocess'), policy=dict( type='multi_agent_muzero', import_names=['lzero.policy.multi_agent_muzero'], diff --git a/zoo/petting_zoo/envs/petting_zoo_simple_spread_env.py b/zoo/petting_zoo/envs/petting_zoo_simple_spread_env.py index 9944d8180..87df2dc34 100644 --- a/zoo/petting_zoo/envs/petting_zoo_simple_spread_env.py +++ b/zoo/petting_zoo/envs/petting_zoo_simple_spread_env.py @@ -274,10 +274,10 @@ def _process_obs(self, obs: 'torch.Tensor') -> np.ndarray: # noqa # action_mask: All actions are of use(either 1 for discrete or 5 for continuous). Thus all 1. # action_mask = np.ones((self._num_agents, *self._action_dim)).astype(np.float32) action_mask = [[1 for _ in range(*self._action_dim)] for _ in range(self._num_agents)] - to_play = [-1 for _ in self._agents] # Moot, for alignment with other environments + to_play = [-1 for _ in range(self._num_agents)] # Moot, for alignment with other environments ret_transform = [] - for i in range(len(self.agents)): + for i in range(self._num_agents): tmp = {} for k,v in ret.items(): tmp[k] = v[i] From 3392d61b95471ddccdfa95b479f2fc652fa1ba6e Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Mon, 14 Aug 2023 21:37:27 +0800 Subject: [PATCH 39/54] fix(yzj): fix policy utils obs batch --- lzero/policy/efficientzero.py | 3 +++ lzero/policy/muzero.py | 3 +++ lzero/policy/utils.py | 4 +++- 3 files changed, 9 insertions(+), 1 deletion(-) diff --git a/lzero/policy/efficientzero.py b/lzero/policy/efficientzero.py index bfd4681d2..8820fc160 100644 --- a/lzero/policy/efficientzero.py +++ b/lzero/policy/efficientzero.py @@ -189,6 +189,9 @@ class EfficientZeroPolicy(Policy): # (int) The decay steps from start to end eps. decay=int(1e5), ), + + # (bool) Whether it is a multi-agent environment. + multi_agent=False, ) def default_model(self) -> Tuple[str, List[str]]: diff --git a/lzero/policy/muzero.py b/lzero/policy/muzero.py index e2a8c8c97..98c258df1 100644 --- a/lzero/policy/muzero.py +++ b/lzero/policy/muzero.py @@ -187,6 +187,9 @@ class MuZeroPolicy(Policy): # (int) The decay steps from start to end eps. decay=int(1e5), ), + + # (bool) Whether it is a multi-agent environment. 
+ multi_agent=False, ) def default_model(self) -> Tuple[str, List[str]]: diff --git a/lzero/policy/utils.py b/lzero/policy/utils.py index cea065232..ec2377f58 100644 --- a/lzero/policy/utils.py +++ b/lzero/policy/utils.py @@ -155,6 +155,7 @@ def prepare_obs(obs_batch_ori: np.ndarray, cfg: EasyDict) -> Tuple[torch.Tensor, # ``obs_target_batch`` is only used for calculate consistency loss, which take the all obs other than # timestep t1, and is only performed in the last 8 timesteps in the second dim in ``obs_batch_ori``. obs_target_batch = obs_batch_ori[:, cfg.model.image_channel:, :, :] + return obs_batch, obs_target_batch elif cfg.model.model_type == 'mlp': # for 1-dimensional vector obs """ @@ -178,6 +179,7 @@ def prepare_obs(obs_batch_ori: np.ndarray, cfg: EasyDict) -> Tuple[torch.Tensor, # ``obs_target_batch`` is only used for calculate consistency loss, which take the all obs other than # timestep t1, and is only performed in the last 8 timesteps in the second dim in ``obs_batch_ori``. obs_target_batch = obs_batch_ori[:, cfg.model.observation_shape:] + return obs_batch, obs_target_batch elif cfg.model.model_type == 'structure': # dict obs_shape = 1 batch_size = obs_batch_ori.shape[0] @@ -216,7 +218,7 @@ def prepare_obs(obs_batch_ori: np.ndarray, cfg: EasyDict) -> Tuple[torch.Tensor, obs_target_batch_new[k][k1] = v1.reshape(batch_size, -1) else: # espaecially for ptz obs, {'k1':[], 'k2':[]} obs_target_batch_new[k] = v.reshape(batch_size, -1) - return obs_batch_new, obs_target_batch_new + return obs_batch_new, obs_target_batch_new def negative_cosine_similarity(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: From 9337ce3c98817e700975d03834644fe8231c9c2c Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Mon, 14 Aug 2023 21:48:36 +0800 Subject: [PATCH 40/54] fix(yzj): fix collect mode and eval mode to device --- lzero/policy/multi_agent_efficientzero.py | 3 +++ lzero/policy/multi_agent_muzero.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/lzero/policy/multi_agent_efficientzero.py b/lzero/policy/multi_agent_efficientzero.py index ce5f2eb54..e2e966d14 100644 --- a/lzero/policy/multi_agent_efficientzero.py +++ b/lzero/policy/multi_agent_efficientzero.py @@ -12,6 +12,7 @@ configure_optimizers from collections import defaultdict from ding.utils.data import default_collate +from ding.torch_utils import to_device, to_tensor @POLICY_REGISTRY.register('multi_agent_efficientzero') @@ -61,6 +62,7 @@ def _forward_collect( data = sum(sum(data, []), []) batch_size = len(data) data = default_collate(data) + data = to_device(data, self._device) agent_num = self._cfg['model']['agent_num'] to_play = np.array(to_play).reshape(-1).tolist() @@ -161,6 +163,7 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, read data = sum(sum(data, []), []) batch_size = len(data) data = default_collate(data) + data = to_device(data, self._device) agent_num = self._cfg['model']['agent_num'] to_play = np.array(to_play).reshape(-1).tolist() diff --git a/lzero/policy/multi_agent_muzero.py b/lzero/policy/multi_agent_muzero.py index b77b5356b..2e29c29c4 100644 --- a/lzero/policy/multi_agent_muzero.py +++ b/lzero/policy/multi_agent_muzero.py @@ -17,6 +17,7 @@ from collections import defaultdict from ding.utils.data import default_collate from .muzero import MuZeroPolicy +from ding.torch_utils import to_device, to_tensor @POLICY_REGISTRY.register('multi_agent_muzero') @@ -65,6 +66,7 @@ def _forward_collect( data = sum(sum(data, []), []) batch_size = len(data) data = default_collate(data) + 
data = to_device(data, self._device) agent_num = self._cfg['model']['agent_num'] to_play = np.array(to_play).reshape(-1).tolist() @@ -161,6 +163,7 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1 data = sum(sum(data, []), []) batch_size = len(data) data = default_collate(data) + data = to_device(data, self._device) agent_num = self._cfg['model']['agent_num'] to_play = np.array(to_play).reshape(-1).tolist() From deab81172d0fcef1f4d2a4f32cdb8b7f2773ab9a Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Tue, 15 Aug 2023 17:06:24 +0800 Subject: [PATCH 41/54] fix(yzj): fix to device bug on policy utils --- lzero/policy/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lzero/policy/utils.py b/lzero/policy/utils.py index ec2377f58..8d0738d3e 100644 --- a/lzero/policy/utils.py +++ b/lzero/policy/utils.py @@ -218,6 +218,8 @@ def prepare_obs(obs_batch_ori: np.ndarray, cfg: EasyDict) -> Tuple[torch.Tensor, obs_target_batch_new[k][k1] = v1.reshape(batch_size, -1) else: # espaecially for ptz obs, {'k1':[], 'k2':[]} obs_target_batch_new[k] = v.reshape(batch_size, -1) + obs_batch_new = to_device(obs_batch_new, device=cfg.device) + obs_target_batch_new = to_device(obs_target_batch_new, device=cfg.device) return obs_batch_new, obs_target_batch_new From 705b5f9c4a6ca1c782a1e680eabe9e9127ec376e Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Tue, 15 Aug 2023 17:46:40 +0800 Subject: [PATCH 42/54] polish(yzj): polish multi agent game buffer code --- lzero/entry/train_muzero.py | 10 +- lzero/mcts/buffer/__init__.py | 4 +- .../mcts/buffer/game_buffer_efficientzero.py | 10 +- lzero/mcts/buffer/game_buffer_muzero.py | 12 +- .../multi_agent_game_buffer_efficientzero.py | 254 ----------------- .../buffer/multi_agent_game_buffer_muzero.py | 260 ------------------ zoo/gobigger/entry/train_muzero_gobigger.py | 23 +- 7 files changed, 36 insertions(+), 537 deletions(-) delete mode 100644 lzero/mcts/buffer/multi_agent_game_buffer_efficientzero.py delete mode 100644 lzero/mcts/buffer/multi_agent_game_buffer_muzero.py diff --git a/lzero/entry/train_muzero.py b/lzero/entry/train_muzero.py index f463b266f..f8d1c5358 100644 --- a/lzero/entry/train_muzero.py +++ b/lzero/entry/train_muzero.py @@ -49,16 +49,12 @@ def train_muzero( if create_cfg.policy.type == 'muzero': from lzero.mcts import MuZeroGameBuffer as GameBuffer - elif create_cfg.policy.type == 'efficientzero': + elif create_cfg.policy.type == 'efficientzero' or create_cfg.policy.type == 'multi_agent_efficientzero': from lzero.mcts import EfficientZeroGameBuffer as GameBuffer elif create_cfg.policy.type == 'sampled_efficientzero': from lzero.mcts import SampledEfficientZeroGameBuffer as GameBuffer - elif create_cfg.policy.type == 'gumbel_muzero': + elif create_cfg.policy.type == 'gumbel_muzero' or create_cfg.policy.type == 'multi_agent_muzero': from lzero.mcts import GumbelMuZeroGameBuffer as GameBuffer - elif create_cfg.policy.type == 'multi_agent_muzero': - from lzero.mcts import MultiAgentMuZeroGameBuffer as GameBuffer - elif create_cfg.policy.type == 'multi_agent_efficientzero': - from lzero.mcts import MultiAgentEfficientZeroGameBuffer as GameBuffer if cfg.policy.cuda and torch.cuda.is_available(): cfg.policy.device = 'cuda' @@ -124,7 +120,7 @@ def train_muzero( # ============================================================== # Learner's before_run hook. 
learner.call_hook('before_run') - + if cfg.policy.update_per_collect is not None: update_per_collect = cfg.policy.update_per_collect diff --git a/lzero/mcts/buffer/__init__.py b/lzero/mcts/buffer/__init__.py index 25cea5ccb..26794ffa0 100644 --- a/lzero/mcts/buffer/__init__.py +++ b/lzero/mcts/buffer/__init__.py @@ -1,6 +1,4 @@ from .game_buffer_muzero import MuZeroGameBuffer from .game_buffer_efficientzero import EfficientZeroGameBuffer from .game_buffer_sampled_efficientzero import SampledEfficientZeroGameBuffer -from .game_buffer_gumbel_muzero import GumbelMuZeroGameBuffer -from .multi_agent_game_buffer_muzero import MultiAgentMuZeroGameBuffer -from .multi_agent_game_buffer_efficientzero import MultiAgentEfficientZeroGameBuffer \ No newline at end of file +from .game_buffer_gumbel_muzero import GumbelMuZeroGameBuffer \ No newline at end of file diff --git a/lzero/mcts/buffer/game_buffer_efficientzero.py b/lzero/mcts/buffer/game_buffer_efficientzero.py index 75e4649d9..154034305 100644 --- a/lzero/mcts/buffer/game_buffer_efficientzero.py +++ b/lzero/mcts/buffer/game_buffer_efficientzero.py @@ -9,6 +9,8 @@ from lzero.mcts.utils import prepare_observation from lzero.policy import to_detach_cpu_numpy, concat_output, concat_output_value, inverse_scalar_transform from .game_buffer_muzero import MuZeroGameBuffer +from ding.torch_utils import to_device, to_tensor, to_ndarray +from ding.utils.data import default_collate @BUFFER_REGISTRY.register('game_buffer_efficientzero') @@ -198,7 +200,13 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A beg_index = self._cfg.mini_infer_size * i end_index = self._cfg.mini_infer_size * (i + 1) - m_obs = torch.from_numpy(value_obs_list[beg_index:end_index]).to(self._cfg.device).float() + if self._cfg.model.model_type and self._cfg.model.model_type in ['conv', 'mlp']: + m_obs = torch.from_numpy(value_obs_list[beg_index:end_index]).to(self._cfg.device).float() + elif self._cfg.model.model_type and self._cfg.model.model_type == 'structure': + m_obs = value_obs_list[beg_index:end_index] + m_obs = sum(m_obs, []) + m_obs = default_collate(m_obs) + m_obs = to_device(m_obs, self._cfg.device) # calculate the target value m_output = model.initial_inference(m_obs) diff --git a/lzero/mcts/buffer/game_buffer_muzero.py b/lzero/mcts/buffer/game_buffer_muzero.py index 6aaf04dad..9243b0958 100644 --- a/lzero/mcts/buffer/game_buffer_muzero.py +++ b/lzero/mcts/buffer/game_buffer_muzero.py @@ -9,6 +9,8 @@ from lzero.mcts.utils import prepare_observation from lzero.policy import to_detach_cpu_numpy, concat_output, concat_output_value, inverse_scalar_transform from .game_buffer import GameBuffer +from ding.torch_utils import to_device, to_tensor +from ding.utils.data import default_collate if TYPE_CHECKING: from lzero.policy import MuZeroPolicy, EfficientZeroPolicy, SampledEfficientZeroPolicy @@ -378,8 +380,14 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A for i in range(slices): beg_index = self._cfg.mini_infer_size * i end_index = self._cfg.mini_infer_size * (i + 1) - - m_obs = torch.from_numpy(value_obs_list[beg_index:end_index]).to(self._cfg.device).float() + + if self._cfg.model.model_type and self._cfg.model.model_type in ['conv', 'mlp']: + m_obs = torch.from_numpy(value_obs_list[beg_index:end_index]).to(self._cfg.device).float() + elif self._cfg.model.model_type and self._cfg.model.model_type == 'structure': + m_obs = value_obs_list[beg_index:end_index] + m_obs = sum(m_obs, []) + m_obs = 
default_collate(m_obs) + m_obs = to_device(m_obs, self._cfg.device) # calculate the target value m_output = model.initial_inference(m_obs) diff --git a/lzero/mcts/buffer/multi_agent_game_buffer_efficientzero.py b/lzero/mcts/buffer/multi_agent_game_buffer_efficientzero.py deleted file mode 100644 index a592df4b0..000000000 --- a/lzero/mcts/buffer/multi_agent_game_buffer_efficientzero.py +++ /dev/null @@ -1,254 +0,0 @@ -from typing import Any, List - -import numpy as np -import torch -from ding.utils import BUFFER_REGISTRY - -from lzero.mcts.tree_search.mcts_ctree import EfficientZeroMCTSCtree as MCTSCtree -from lzero.mcts.tree_search.mcts_ptree import EfficientZeroMCTSPtree as MCTSPtree -from lzero.mcts.utils import prepare_observation -from lzero.policy import to_detach_cpu_numpy, concat_output, concat_output_value, inverse_scalar_transform -from .multi_agent_game_buffer_muzero import MultiAgentMuZeroGameBuffer -from ding.torch_utils import to_device, to_tensor, to_ndarray -from ding.utils.data import default_collate - - -@BUFFER_REGISTRY.register('multi_agent_game_buffer_efficientzero') -class MultiAgentEfficientZeroGameBuffer(MultiAgentMuZeroGameBuffer): - """ - Overview: - The specific game buffer for Multi Agent EfficientZero policy. - """ - - def _prepare_reward_value_context( - self, batch_index_list: List[str], game_segment_list: List[Any], pos_in_game_segment_list: List[Any], - total_transitions: int - ) -> List[Any]: - """ - Overview: - prepare the context of rewards and values for calculating TD value target in reanalyzing part. - Arguments: - - batch_index_list (:obj:`list`): the index of start transition of sampled minibatch in replay buffer - - game_segment_list (:obj:`list`): list of game segments - - pos_in_game_segment_list (:obj:`list`): list of transition index in game_segment - - total_transitions (:obj:`int`): number of collected transitions - Returns: - - reward_value_context (:obj:`list`): value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, - td_steps_list, action_mask_segment, to_play_segment - """ - value_obs_list = [] - # the value is valid or not (out of trajectory) - value_mask = [] - rewards_list = [] - game_segment_lens = [] - # for two_player board games - action_mask_segment, to_play_segment = [], [] - - td_steps_list = [] - for game_segment, state_index, idx in zip(game_segment_list, pos_in_game_segment_list, batch_index_list): - game_segment_len = len(game_segment) - game_segment_lens.append(game_segment_len) - - # ============================================================== - # EfficientZero related core code - # ============================================================== - # TODO(pu): - # for atari, off-policy correction: shorter horizon of td steps - # delta_td = (total_transitions - idx) // config.auto_td_steps - # td_steps = config.td_steps - delta_td - # td_steps = np.clip(td_steps, 1, 5).astype(np.int) - td_steps = np.clip(self._cfg.td_steps, 1, max(1, game_segment_len - state_index)).astype(np.int32) - - # prepare the corresponding observations for bootstrapped values o_{t+k} - # o[t+ td_steps, t + td_steps + stack frames + num_unroll_steps] - # t=2+3 -> o[2+3, 2+3+4+5] -> o[5, 14] - game_obs = game_segment.get_unroll_obs(state_index + td_steps, self._cfg.num_unroll_steps) - - rewards_list.append(game_segment.reward_segment) - - # for two_player board games - action_mask_segment.append(game_segment.action_mask_segment) - to_play_segment.append(game_segment.to_play_segment) - - for current_index in 
range(state_index, state_index + self._cfg.num_unroll_steps + 1): - # get the bootstrapped target obs - td_steps_list.append(td_steps) - # index of bootstrapped obs o_{t+td_steps} - bootstrap_index = current_index + td_steps - - if bootstrap_index < game_segment_len: - value_mask.append(1) - # beg_index = bootstrap_index - (state_index + td_steps), max of beg_index is num_unroll_steps - beg_index = current_index - state_index - end_index = beg_index + self._cfg.model.frame_stack_num - # the stacked obs in time t - obs = game_obs[beg_index:end_index] - self.tmp_obs = obs # will be masked - else: - value_mask.append(0) - obs = self.tmp_obs # will be masked - - value_obs_list.append(obs.tolist()) - - reward_value_context = [ - value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, td_steps_list, - action_mask_segment, to_play_segment - ] - return reward_value_context - - def _compute_target_reward_value(self, reward_value_context: List[Any], model: Any) -> List[np.ndarray]: - """ - Overview: - prepare reward and value targets from the context of rewards and values. - Arguments: - - reward_value_context (:obj:'list'): the reward value context - - model (:obj:'torch.tensor'):model of the target model - Returns: - - batch_value_prefixs (:obj:'np.ndarray): batch of value prefix - - batch_target_values (:obj:'np.ndarray): batch of value estimation - """ - value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, td_steps_list, action_mask_segment, \ - to_play_segment = reward_value_context # noqa - # transition_batch_size = game_segment_batch_size * (num_unroll_steps+1) - transition_batch_size = len(value_obs_list) - game_segment_batch_size = len(pos_in_game_segment_list) - - to_play, action_mask = self._preprocess_to_play_and_action_mask( - game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list - ) - - legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)] - - # ============================================================== - # EfficientZero related core code - # ============================================================== - batch_target_values, batch_value_prefixs = [], [] - with torch.no_grad(): - value_obs_list = prepare_observation(value_obs_list, self._cfg.model.model_type) - # split a full batch into slices of mini_infer_size: to save the GPU memory for more GPU actors - slices = int(np.ceil(transition_batch_size / self._cfg.mini_infer_size)) - network_output = [] - for i in range(slices): - beg_index = self._cfg.mini_infer_size * i - end_index = self._cfg.mini_infer_size * (i + 1) - - #m_obs = torch.from_numpy(value_obs_list[beg_index:end_index]).to(self._cfg.device).float() - m_obs = value_obs_list[beg_index:end_index] - m_obs = sum(m_obs, []) - m_obs = default_collate(m_obs) - m_obs = to_device(m_obs, self._cfg.device) - - # calculate the target value - m_output = model.initial_inference(m_obs) - if not model.training: - # ============================================================== - # EfficientZero related core code - # ============================================================== - # if not in training, obtain the scalars of the value/reward - [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( - [ - m_output.latent_state, - inverse_scalar_transform(m_output.value, self._cfg.model.support_scale), - m_output.policy_logits - ] - ) - m_output.reward_hidden_state = ( - 
m_output.reward_hidden_state[0].detach().cpu().numpy(), - m_output.reward_hidden_state[1].detach().cpu().numpy() - ) - network_output.append(m_output) - - # concat the output slices after model inference - if self._cfg.use_root_value: - # use the root values from MCTS, as in EfficiientZero - # the root values have limited improvement but require much more GPU actors; - _, value_prefix_pool, policy_logits_pool, latent_state_roots, reward_hidden_state_roots = concat_output( - network_output, data_type='efficientzero' - ) - value_prefix_pool = value_prefix_pool.squeeze().tolist() - policy_logits_pool = policy_logits_pool.tolist() - noises = [ - np.random.dirichlet([self._cfg.root_dirichlet_alpha] * self._cfg.model.action_space_size - ).astype(np.float32).tolist() for _ in range(transition_batch_size) - ] - if self._cfg.mcts_ctree: - # cpp mcts_tree - roots = MCTSCtree.roots(transition_batch_size, legal_actions) - roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, to_play) - # do MCTS for a new policy with the recent target model - MCTSCtree(self._cfg).search(roots, model, latent_state_roots, reward_hidden_state_roots, to_play) - else: - # python mcts_tree - roots = MCTSPtree.roots(transition_batch_size, legal_actions) - roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, to_play) - # do MCTS for a new policy with the recent target model - MCTSPtree(self._cfg).search( - roots, model, latent_state_roots, reward_hidden_state_roots, to_play=to_play - ) - roots_values = roots.get_values() - value_list = np.array(roots_values) - else: - # use the predicted values - value_list = concat_output_value(network_output) - - # get last state value - if self._cfg.env_type == 'board_games' and to_play_segment[0][0] in [1, 2]: - # TODO(pu): for board_games, very important, to check - value_list = value_list.reshape(-1) * np.array( - [ - self._cfg.discount_factor ** td_steps_list[i] if int(td_steps_list[i]) % - 2 == 0 else -self._cfg.discount_factor ** td_steps_list[i] - for i in range(transition_batch_size) - ] - ) - else: - value_list = value_list.reshape(-1) * ( - np.array([self._cfg.discount_factor for _ in range(transition_batch_size)]) ** td_steps_list - ) - - value_list = value_list * np.array(value_mask) - value_list = value_list.tolist() - horizon_id, value_index = 0, 0 - for game_segment_len_non_re, reward_list, state_index, to_play_list in zip(game_segment_lens, rewards_list, - pos_in_game_segment_list, - to_play_segment): - target_values = [] - target_value_prefixs = [] - value_prefix = 0.0 - base_index = state_index - for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1): - bootstrap_index = current_index + td_steps_list[value_index] - for i, reward in enumerate(reward_list[current_index:bootstrap_index]): - if self._cfg.env_type == 'board_games' and to_play_segment[0][0] in [1, 2]: - # TODO(pu): for board_games, very important, to check - if to_play_list[base_index] == to_play_list[i]: - value_list[value_index] += reward * self._cfg.discount_factor ** i - else: - value_list[value_index] += -reward * self._cfg.discount_factor ** i - else: - value_list[value_index] += reward * self._cfg.discount_factor ** i - - # reset every lstm_horizon_len - if horizon_id % self._cfg.lstm_horizon_len == 0: - value_prefix = 0.0 - base_index = current_index - horizon_id += 1 - - if current_index < game_segment_len_non_re: - target_values.append(value_list[value_index]) - # TODO: Since the horizon is small 
and the discount_factor is close to 1. - # Compute the reward sum to approximate the value prefix for simplification - value_prefix += reward_list[current_index - ] # * self._cfg.discount_factor ** (current_index - base_index) - target_value_prefixs.append(value_prefix) - else: - target_values.append(0) - target_value_prefixs.append(value_prefix) - value_index += 1 - batch_value_prefixs.append(target_value_prefixs) - batch_target_values.append(target_values) - batch_value_prefixs = np.asarray(batch_value_prefixs, dtype=object) - batch_target_values = np.asarray(batch_target_values, dtype=object) - - return batch_value_prefixs, batch_target_values - diff --git a/lzero/mcts/buffer/multi_agent_game_buffer_muzero.py b/lzero/mcts/buffer/multi_agent_game_buffer_muzero.py deleted file mode 100644 index 53518ac82..000000000 --- a/lzero/mcts/buffer/multi_agent_game_buffer_muzero.py +++ /dev/null @@ -1,260 +0,0 @@ -from typing import Any, List, Tuple, Union, TYPE_CHECKING, Optional - -import numpy as np -import torch -from ding.utils import BUFFER_REGISTRY - -from lzero.mcts.tree_search.mcts_ctree import MuZeroMCTSCtree as MCTSCtree -from lzero.mcts.tree_search.mcts_ptree import MuZeroMCTSPtree as MCTSPtree -from lzero.mcts.utils import prepare_observation -from lzero.policy import to_detach_cpu_numpy, concat_output, concat_output_value, inverse_scalar_transform -from .game_buffer import GameBuffer -from ding.torch_utils import to_device, to_tensor -from ding.utils.data import default_collate -from .game_buffer_muzero import MuZeroGameBuffer - -if TYPE_CHECKING: - from lzero.policy import MuZeroPolicy, EfficientZeroPolicy, SampledEfficientZeroPolicy - -@BUFFER_REGISTRY.register('multi_agent_game_buffer_muzero') -class MultiAgentMuZeroGameBuffer(MuZeroGameBuffer): - """ - Overview: - The specific game buffer for Multi Agent MuZero policy. - """ - - def _prepare_policy_non_reanalyzed_context( - self, batch_index_list: List[int], game_segment_list: List[Any], pos_in_game_segment_list: List[int] - ) -> List[Any]: - """ - Overview: - prepare the context of policies for calculating policy target in non-reanalyzing part, just return the policy in self-play - Arguments: - - batch_index_list (:obj:`list`): the index of start transition of sampled minibatch in replay buffer - - game_segment_list (:obj:`list`): list of game segments - - pos_in_game_segment_list (:obj:`list`): list transition index in game - Returns: - - policy_non_re_context (:obj:`list`): pos_in_game_segment_list, child_visits, game_segment_lens, action_mask_segment, to_play_segment - """ - child_visits = [] - game_segment_lens = [] - # for board games - action_mask_segment, to_play_segment = [], [] - - for game_segment, state_index, idx in zip(game_segment_list, pos_in_game_segment_list, batch_index_list): - game_segment_len = len(game_segment) - game_segment_lens.append(game_segment_len) - # for board games - action_mask_segment.append(game_segment.action_mask_segment) - to_play_segment.append(game_segment.to_play_segment) - - child_visits.append(game_segment.child_visit_segment) - - policy_non_re_context = [ - pos_in_game_segment_list, child_visits, game_segment_lens, action_mask_segment, to_play_segment - ] - return policy_non_re_context - - def _prepare_policy_reanalyzed_context( - self, batch_index_list: List[str], game_segment_list: List[Any], pos_in_game_segment_list: List[str] - ) -> List[Any]: - """ - Overview: - prepare the context of policies for calculating policy target in reanalyzing part. 
- Arguments: - - batch_index_list (:obj:'list'): start transition index in the replay buffer - - game_segment_list (:obj:'list'): list of game segments - - pos_in_game_segment_list (:obj:'list'): position of transition index in one game history - Returns: - - policy_re_context (:obj:`list`): policy_obs_list, policy_mask, pos_in_game_segment_list, indices, - child_visits, game_segment_lens, action_mask_segment, to_play_segment - """ - zero_obs = game_segment_list[0].zero_obs() - with torch.no_grad(): - # for policy - policy_obs_list = [] - policy_mask = [] - # 0 -> Invalid target policy for padding outside of game segments, - # 1 -> Previous target policy for game segments. - rewards, child_visits, game_segment_lens = [], [], [] - # for board games - action_mask_segment, to_play_segment = [], [] - for game_segment, state_index in zip(game_segment_list, pos_in_game_segment_list): - game_segment_len = len(game_segment) - game_segment_lens.append(game_segment_len) - rewards.append(game_segment.reward_segment) - # for board games - action_mask_segment.append(game_segment.action_mask_segment) - to_play_segment.append(game_segment.to_play_segment) - - child_visits.append(game_segment.child_visit_segment) - # prepare the corresponding observations - game_obs = game_segment.get_unroll_obs(state_index, self._cfg.num_unroll_steps) - for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1): - - if current_index < game_segment_len: - policy_mask.append(1) - beg_index = current_index - state_index - end_index = beg_index + self._cfg.model.frame_stack_num - obs = game_obs[beg_index:end_index] - else: - policy_mask.append(0) - obs = zero_obs - policy_obs_list.append(obs) - - policy_re_context = [ - policy_obs_list, policy_mask, pos_in_game_segment_list, batch_index_list, child_visits, game_segment_lens, - action_mask_segment, to_play_segment - ] - return policy_re_context - - def _compute_target_reward_value(self, reward_value_context: List[Any], model: Any) -> Tuple[Any, Any]: - """ - Overview: - prepare reward and value targets from the context of rewards and values. - Arguments: - - reward_value_context (:obj:'list'): the reward value context - - model (:obj:'torch.tensor'):model of the target model - Returns: - - batch_value_prefixs (:obj:'np.ndarray): batch of value prefix - - batch_target_values (:obj:'np.ndarray): batch of value estimation - """ - value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, td_steps_list, action_mask_segment, \ - to_play_segment = reward_value_context # noqa - # transition_batch_size = game_segment_batch_size * (num_unroll_steps+1) - transition_batch_size = len(value_obs_list) - game_segment_batch_size = len(pos_in_game_segment_list) - - to_play, action_mask = self._preprocess_to_play_and_action_mask( - game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list - ) - if self._cfg.model.continuous_action_space is True: - # when the action space of the environment is continuous, action_mask[:] is None. 
- action_mask = [ - list(np.ones(self._cfg.model.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size) - ] - # NOTE: in continuous action space env: we set all legal_actions as -1 - legal_actions = [ - [-1 for _ in range(self._cfg.model.action_space_size)] for _ in range(transition_batch_size) - ] - else: - legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)] - - batch_target_values, batch_rewards = [], [] - with torch.no_grad(): - value_obs_list = prepare_observation(value_obs_list, self._cfg.model.model_type) - # split a full batch into slices of mini_infer_size: to save the GPU memory for more GPU actors - slices = int(np.ceil(transition_batch_size / self._cfg.mini_infer_size)) - network_output = [] - for i in range(slices): - beg_index = self._cfg.mini_infer_size * i - end_index = self._cfg.mini_infer_size * (i + 1) - - # m_obs = torch.from_numpy(value_obs_list[beg_index:end_index]).to(self._cfg.device).float() - m_obs = value_obs_list[beg_index:end_index] - m_obs = sum(m_obs, []) - m_obs = default_collate(m_obs) - m_obs = to_device(m_obs, self._cfg.device) - - # calculate the target value - m_output = model.initial_inference(m_obs) - - if not model.training: - # if not in training, obtain the scalars of the value/reward - [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( - [ - m_output.latent_state, - inverse_scalar_transform(m_output.value, self._cfg.model.support_scale), - m_output.policy_logits - ] - ) - - network_output.append(m_output) - - # concat the output slices after model inference - if self._cfg.use_root_value: - # use the root values from MCTS, as in EfficiientZero - # the root values have limited improvement but require much more GPU actors; - _, reward_pool, policy_logits_pool, latent_state_roots = concat_output( - network_output, data_type='muzero' - ) - reward_pool = reward_pool.squeeze().tolist() - policy_logits_pool = policy_logits_pool.tolist() - noises = [ - np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j])) - ).astype(np.float32).tolist() for j in range(transition_batch_size) - ] - if self._cfg.mcts_ctree: - # cpp mcts_tree - roots = MCTSCtree.roots(transition_batch_size, legal_actions) - roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play) - # do MCTS for a new policy with the recent target model - MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play) - else: - # python mcts_tree - roots = MCTSPtree.roots(transition_batch_size, legal_actions) - roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play) - # do MCTS for a new policy with the recent target model - MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play) - - roots_values = roots.get_values() - value_list = np.array(roots_values) - else: - # use the predicted values - value_list = concat_output_value(network_output) - - # get last state value - if self._cfg.env_type == 'board_games' and to_play_segment[0][0] in [1, 2]: - # TODO(pu): for board_games, very important, to check - value_list = value_list.reshape(-1) * np.array( - [ - self._cfg.discount_factor ** td_steps_list[i] if int(td_steps_list[i]) % - 2 == 0 else -self._cfg.discount_factor ** td_steps_list[i] - for i in range(transition_batch_size) - ] - ) - else: - value_list = value_list.reshape(-1) * ( - np.array([self._cfg.discount_factor for _ in range(transition_batch_size)]) ** td_steps_list - ) - - 
value_list = value_list * np.array(value_mask) - value_list = value_list.tolist() - horizon_id, value_index = 0, 0 - - for game_segment_len_non_re, reward_list, state_index, to_play_list in zip(game_segment_lens, rewards_list, - pos_in_game_segment_list, - to_play_segment): - target_values = [] - target_rewards = [] - base_index = state_index - for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1): - bootstrap_index = current_index + td_steps_list[value_index] - # for i, reward in enumerate(game.rewards[current_index:bootstrap_index]): - for i, reward in enumerate(reward_list[current_index:bootstrap_index]): - if self._cfg.env_type == 'board_games' and to_play_segment[0][0] in [1, 2]: - # TODO(pu): for board_games, very important, to check - if to_play_list[base_index] == to_play_list[i]: - value_list[value_index] += reward * self._cfg.discount_factor ** i - else: - value_list[value_index] += -reward * self._cfg.discount_factor ** i - else: - value_list[value_index] += reward * self._cfg.discount_factor ** i - horizon_id += 1 - - if current_index < game_segment_len_non_re: - target_values.append(value_list[value_index]) - target_rewards.append(reward_list[current_index]) - else: - target_values.append(0) - target_rewards.append(0.0) - # TODO: check - # target_rewards.append(reward) - value_index += 1 - - batch_rewards.append(target_rewards) - batch_target_values.append(target_values) - - batch_rewards = np.asarray(batch_rewards, dtype=object) - batch_target_values = np.asarray(batch_target_values, dtype=object) - return batch_rewards, batch_target_values diff --git a/zoo/gobigger/entry/train_muzero_gobigger.py b/zoo/gobigger/entry/train_muzero_gobigger.py index 73e3d231a..44dbe08f6 100644 --- a/zoo/gobigger/entry/train_muzero_gobigger.py +++ b/zoo/gobigger/entry/train_muzero_gobigger.py @@ -9,15 +9,16 @@ from ding.envs import get_vec_env_setting from ding.policy import create_policy from ding.utils import set_pkg_seed +from ding.rl_utils import get_epsilon_greedy_fn from ding.worker import BaseLearner from tensorboardX import SummaryWriter -import copy -from ding.rl_utils import get_epsilon_greedy_fn + from lzero.entry.utils import log_buffer_memory_usage from lzero.policy import visit_count_temperature from lzero.worker import GoBiggerMuZeroEvaluator from lzero.entry.utils import random_collect -from lzero.policy.multi_agent_random_policy import MultiAgentLightZeroRandomPolicy +import copy + def train_muzero_gobigger( @@ -51,16 +52,12 @@ def train_muzero_gobigger( if create_cfg.policy.type == 'muzero': from lzero.mcts import MuZeroGameBuffer as GameBuffer - elif create_cfg.policy.type == 'efficientzero': + elif create_cfg.policy.type == 'efficientzero' or create_cfg.policy.type == 'multi_agent_efficientzero': from lzero.mcts import EfficientZeroGameBuffer as GameBuffer elif create_cfg.policy.type == 'sampled_efficientzero': from lzero.mcts import SampledEfficientZeroGameBuffer as GameBuffer - elif create_cfg.policy.type == 'gumbel_muzero': + elif create_cfg.policy.type == 'gumbel_muzero' or create_cfg.policy.type == 'multi_agent_muzero': from lzero.mcts import GumbelMuZeroGameBuffer as GameBuffer - elif create_cfg.policy.type == 'multi_agent_efficientzero': - from lzero.mcts import MultiAgentEfficientZeroGameBuffer as GameBuffer - elif create_cfg.policy.type == 'multi_agent_muzero': - from lzero.mcts import MultiAgentMuZeroGameBuffer as GameBuffer if cfg.policy.cuda and torch.cuda.is_available(): cfg.policy.device = 'cuda' @@ -70,6 +67,7 @@ def 
train_muzero_gobigger( cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) # Create main components: env, policy env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) + collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) @@ -100,6 +98,7 @@ def train_muzero_gobigger( batch_size = policy_config.batch_size # specific game buffer for MCTS+RL algorithms replay_buffer = GameBuffer(policy_config) + if policy_config.multi_agent: from lzero.worker import MultiAgentMuZeroCollector as Collector from lzero.worker import MuZeroEvaluator as Evaluator @@ -149,7 +148,11 @@ def train_muzero_gobigger( # Exploration: The collection of random data aids the agent in exploring the environment and prevents premature convergence to a suboptimal policy. # Comparation: The agent's performance during random action-taking can be used as a reference point to evaluate the efficacy of reinforcement learning algorithms. if cfg.policy.random_collect_episode_num > 0: - random_collect(cfg.policy, policy, MultiAgentLightZeroRandomPolicy, collector, collector_env, replay_buffer) + if policy_config.multi_agent: + from lzero.policy.multi_agent_random_policy import MultiAgentLightZeroRandomPolicy as RandomPolicy + else: + from lzero.policy.random_policy import LightZeroRandomPolicy as RandomPolicy + random_collect(cfg.policy, policy, RandomPolicy, collector, collector_env, replay_buffer) while True: log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger) From 43b2bb5d99c44dc53b74a1b9e7b012d597c13b1b Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Tue, 15 Aug 2023 17:54:42 +0800 Subject: [PATCH 43/54] polish(yzj): polish code --- lzero/entry/__init__.py | 2 +- lzero/entry/utils.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/lzero/entry/__init__.py b/lzero/entry/__init__.py index 352d29ddf..f68d876a6 100644 --- a/lzero/entry/__init__.py +++ b/lzero/entry/__init__.py @@ -3,4 +3,4 @@ from .train_muzero import train_muzero from .eval_muzero import eval_muzero from .eval_muzero_with_gym_env import eval_muzero_with_gym_env -from .train_muzero_with_gym_env import train_muzero_with_gym_env +from .train_muzero_with_gym_env import train_muzero_with_gym_env \ No newline at end of file diff --git a/lzero/entry/utils.py b/lzero/entry/utils.py index 2f8e2aae1..b11e37d0a 100644 --- a/lzero/entry/utils.py +++ b/lzero/entry/utils.py @@ -1,8 +1,6 @@ import os -from typing import Optional, Callable import psutil -from easydict import EasyDict from pympler.asizeof import asizeof from tensorboardX import SummaryWriter from typing import Optional, Callable @@ -41,6 +39,7 @@ def random_collect( # restore the policy collector.reset_policy(policy.collect_mode) + def log_buffer_memory_usage(train_iter: int, buffer: "GameBuffer", writer: SummaryWriter) -> None: """ Overview: From 3d88a17cde078653fd19c2ecfa61dcf83053602e Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Tue, 15 Aug 2023 20:30:54 +0800 Subject: [PATCH 44/54] fix(yzj): fix priority bug, polish priority related config, add all agent obs to ptz --- lzero/model/petting_zoo/encoder.py | 2 +- lzero/policy/efficientzero.py | 2 -- lzero/policy/gumbel_muzero.py | 2 -- lzero/policy/multi_agent_efficientzero.py | 4 ++-- lzero/policy/multi_agent_muzero.py | 4 ++-- lzero/policy/muzero.py | 2 -- lzero/policy/sampled_efficientzero.py | 2 -- 
lzero/worker/multi_agent_muzero_collector.py | 4 ++-- lzero/worker/muzero_collector.py | 6 +++--- zoo/gobigger/config/gobigger_efficientzero_config.py | 2 +- zoo/gobigger/config/gobigger_muzero_config.py | 2 +- zoo/petting_zoo/config/ptz_simple_spread_ez_config.py | 2 +- zoo/petting_zoo/config/ptz_simple_spread_mz_config.py | 2 +- zoo/petting_zoo/envs/petting_zoo_simple_spread_env.py | 7 +++++++ 14 files changed, 21 insertions(+), 22 deletions(-) diff --git a/lzero/model/petting_zoo/encoder.py b/lzero/model/petting_zoo/encoder.py index 9f3313731..8c9e92c59 100644 --- a/lzero/model/petting_zoo/encoder.py +++ b/lzero/model/petting_zoo/encoder.py @@ -5,7 +5,7 @@ class PettingZooEncoder(nn.Module): def __init__(self): super().__init__() - self.encoder = FCEncoder(obs_shape=18, hidden_size_list=[256, 256], activation=nn.ReLU(), norm_type=None) + self.encoder = FCEncoder(obs_shape=54, hidden_size_list=[256, 256], activation=nn.ReLU(), norm_type=None) def forward(self, x): x = x['agent_state'] diff --git a/lzero/policy/efficientzero.py b/lzero/policy/efficientzero.py index 8820fc160..55fe57cb2 100644 --- a/lzero/policy/efficientzero.py +++ b/lzero/policy/efficientzero.py @@ -157,8 +157,6 @@ class EfficientZeroPolicy(Policy): # ****** Priority ****** # (bool) Whether to use priority when sampling training data from the buffer. use_priority=True, - # (bool) Whether to use the maximum priority for new collecting data. - use_max_priority_for_new_data=True, # (float) The degree of prioritization to use. A value of 0 means no prioritization, # while a value of 1 means full prioritization. priority_prob_alpha=0.6, diff --git a/lzero/policy/gumbel_muzero.py b/lzero/policy/gumbel_muzero.py index f7db4d148..707623f9a 100644 --- a/lzero/policy/gumbel_muzero.py +++ b/lzero/policy/gumbel_muzero.py @@ -153,8 +153,6 @@ class GumeblMuZeroPolicy(Policy): # ****** Priority ****** # (bool) Whether to use priority when sampling training data from the buffer. use_priority=True, - # (bool) Whether to use the maximum priority for new collecting data. - use_max_priority_for_new_data=True, # (float) The degree of prioritization to use. A value of 0 means no prioritization, # while a value of 1 means full prioritization. priority_prob_alpha=0.6, diff --git a/lzero/policy/multi_agent_efficientzero.py b/lzero/policy/multi_agent_efficientzero.py index e2e966d14..7ed1ab414 100644 --- a/lzero/policy/multi_agent_efficientzero.py +++ b/lzero/policy/multi_agent_efficientzero.py @@ -64,6 +64,7 @@ def _forward_collect( data = default_collate(data) data = to_device(data, self._device) agent_num = self._cfg['model']['agent_num'] + action_mask = sum(action_mask, []) to_play = np.array(to_play).reshape(-1).tolist() with torch.no_grad(): @@ -81,7 +82,6 @@ def _forward_collect( ) policy_logits = policy_logits.detach().cpu().numpy().tolist() - action_mask = sum(action_mask, []) legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(batch_size)] # the only difference between collect and eval is the dirichlet noise. 
noises = [ @@ -165,6 +165,7 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, read data = default_collate(data) data = to_device(data, self._device) agent_num = self._cfg['model']['agent_num'] + action_mask = sum(action_mask, []) to_play = np.array(to_play).reshape(-1).tolist() with torch.no_grad(): @@ -184,7 +185,6 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, read ) policy_logits = policy_logits.detach().cpu().numpy().tolist() # list shape(B, A) - action_mask = sum(action_mask, []) legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(batch_size)] if self._cfg.mcts_ctree: # cpp mcts_tree diff --git a/lzero/policy/multi_agent_muzero.py b/lzero/policy/multi_agent_muzero.py index 2e29c29c4..97ebe2eec 100644 --- a/lzero/policy/multi_agent_muzero.py +++ b/lzero/policy/multi_agent_muzero.py @@ -68,6 +68,7 @@ def _forward_collect( data = default_collate(data) data = to_device(data, self._device) agent_num = self._cfg['model']['agent_num'] + action_mask = sum(action_mask, []) to_play = np.array(to_play).reshape(-1).tolist() with torch.no_grad(): @@ -81,7 +82,6 @@ def _forward_collect( latent_state_roots = latent_state_roots.detach().cpu().numpy() policy_logits = policy_logits.detach().cpu().numpy().tolist() - action_mask = sum(action_mask, []) legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(batch_size)] # the only difference between collect and eval is the dirichlet noise noises = [ @@ -165,6 +165,7 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1 data = default_collate(data) data = to_device(data, self._device) agent_num = self._cfg['model']['agent_num'] + action_mask = sum(action_mask, []) to_play = np.array(to_play).reshape(-1).tolist() with torch.no_grad(): @@ -178,7 +179,6 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1 latent_state_roots = latent_state_roots.detach().cpu().numpy() policy_logits = policy_logits.detach().cpu().numpy().tolist() # list shape(B, A) - action_mask = sum(action_mask, []) legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(batch_size)] if self._cfg.mcts_ctree: # cpp mcts_tree diff --git a/lzero/policy/muzero.py b/lzero/policy/muzero.py index 98c258df1..a7e9041e1 100644 --- a/lzero/policy/muzero.py +++ b/lzero/policy/muzero.py @@ -155,8 +155,6 @@ class MuZeroPolicy(Policy): # ****** Priority ****** # (bool) Whether to use priority when sampling training data from the buffer. use_priority=True, - # (bool) Whether to use the maximum priority for new collecting data. - use_max_priority_for_new_data=True, # (float) The degree of prioritization to use. A value of 0 means no prioritization, # while a value of 1 means full prioritization. priority_prob_alpha=0.6, diff --git a/lzero/policy/sampled_efficientzero.py b/lzero/policy/sampled_efficientzero.py index 665cf011b..b39720bee 100644 --- a/lzero/policy/sampled_efficientzero.py +++ b/lzero/policy/sampled_efficientzero.py @@ -169,8 +169,6 @@ class SampledEfficientZeroPolicy(Policy): # ****** Priority ****** # (bool) Whether to use priority when sampling training data from the buffer. use_priority=True, - # (bool) Whether to use the maximum priority for new collecting data. - use_max_priority_for_new_data=True, # (float) The degree of prioritization to use. A value of 0 means no prioritization, # while a value of 1 means full prioritization. 
priority_prob_alpha=0.6, diff --git a/lzero/worker/multi_agent_muzero_collector.py b/lzero/worker/multi_agent_muzero_collector.py index bd5dc7544..76916b44d 100644 --- a/lzero/worker/multi_agent_muzero_collector.py +++ b/lzero/worker/multi_agent_muzero_collector.py @@ -36,14 +36,14 @@ def _compute_priorities(self, i, agent_id, pred_values_lst, search_values_lst): - pred_values_lst: The list of value being predicted. - search_values_lst: The list of value obtained through search. """ - if self.policy_config.use_priority and not self.policy_config.use_max_priority_for_new_data: + if self.policy_config.use_priority: pred_values = torch.from_numpy(np.array(pred_values_lst[i][agent_id])).to(self.policy_config.device ).float().view(-1) search_values = torch.from_numpy(np.array(search_values_lst[i][agent_id])).to(self.policy_config.device ).float().view(-1) priorities = L1Loss(reduction='none' )(pred_values, - search_values).detach().cpu().numpy() + self.policy_config.prioritized_replay_eps + search_values).detach().cpu().numpy() + 1e-6 # avoid zero priority else: # priorities is None -> use the max priority for all newly collected data priorities = None diff --git a/lzero/worker/muzero_collector.py b/lzero/worker/muzero_collector.py index dc1aafbb5..67747a321 100644 --- a/lzero/worker/muzero_collector.py +++ b/lzero/worker/muzero_collector.py @@ -196,13 +196,13 @@ def _compute_priorities(self, i, pred_values_lst, search_values_lst): - pred_values_lst: The list of value being predicted. - search_values_lst: The list of value obtained through search. """ - if self.policy_config.use_priority and not self.policy_config.use_max_priority_for_new_data: + if self.policy_config.use_priority: pred_values = torch.from_numpy(np.array(pred_values_lst[i])).to(self.policy_config.device).float().view(-1) search_values = torch.from_numpy(np.array(search_values_lst[i])).to(self.policy_config.device ).float().view(-1) priorities = L1Loss(reduction='none' )(pred_values, - search_values).detach().cpu().numpy() + self.policy_config.prioritized_replay_eps + search_values).detach().cpu().numpy() + 1e-6 # avoid zero priority else: # priorities is None -> use the max priority for all newly collected data priorities = None @@ -555,7 +555,7 @@ def collect(self, eps_steps_lst[env_id] += 1 total_transitions += 1 - if self.policy_config.use_priority and not self.policy_config.use_max_priority_for_new_data: + if self.policy_config.use_priority: if self._multi_agent: for agent_id in range(agent_num): pred_values_lst[env_id][agent_id].append(pred_value_dict[env_id][agent_id]) diff --git a/zoo/gobigger/config/gobigger_efficientzero_config.py b/zoo/gobigger/config/gobigger_efficientzero_config.py index a5951aa32..d9780281b 100644 --- a/zoo/gobigger/config/gobigger_efficientzero_config.py +++ b/zoo/gobigger/config/gobigger_efficientzero_config.py @@ -54,7 +54,7 @@ mcts_ctree=True, gumbel_algo=False, env_type='not_board_games', - game_segment_length=400, + game_segment_length=500, random_collect_episode_num=0, eps=dict( eps_greedy_exploration_in_collect=eps_greedy_exploration_in_collect, diff --git a/zoo/gobigger/config/gobigger_muzero_config.py b/zoo/gobigger/config/gobigger_muzero_config.py index 6a3da6632..81fd0397a 100644 --- a/zoo/gobigger/config/gobigger_muzero_config.py +++ b/zoo/gobigger/config/gobigger_muzero_config.py @@ -54,7 +54,7 @@ mcts_ctree=True, gumbel_algo=False, env_type='not_board_games', - game_segment_length=400, + game_segment_length=500, random_collect_episode_num=0, eps=dict( 
eps_greedy_exploration_in_collect=eps_greedy_exploration_in_collect, diff --git a/zoo/petting_zoo/config/ptz_simple_spread_ez_config.py b/zoo/petting_zoo/config/ptz_simple_spread_ez_config.py index 51990d255..995b90988 100644 --- a/zoo/petting_zoo/config/ptz_simple_spread_ez_config.py +++ b/zoo/petting_zoo/config/ptz_simple_spread_ez_config.py @@ -51,7 +51,7 @@ action_space='discrete', action_space_size=action_space_size, agent_num=n_agent, - agent_obs_shape=2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2, + agent_obs_shape=(2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2)*3, global_obs_shape=2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2 + n_agent * (2 + 2) + n_landmark * 2 + n_agent * (n_agent - 1) * 2, discrete_action_encoding_type='one_hot', diff --git a/zoo/petting_zoo/config/ptz_simple_spread_mz_config.py b/zoo/petting_zoo/config/ptz_simple_spread_mz_config.py index 4550388c5..3facc5e3e 100644 --- a/zoo/petting_zoo/config/ptz_simple_spread_mz_config.py +++ b/zoo/petting_zoo/config/ptz_simple_spread_mz_config.py @@ -52,7 +52,7 @@ action_space_size=action_space_size, agent_num=n_agent, self_supervised_learning_loss=True, # default is False - agent_obs_shape=2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2, + agent_obs_shape=(2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2)*3, global_obs_shape=2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2 + n_agent * (2 + 2) + n_landmark * 2 + n_agent * (n_agent - 1) * 2, discrete_action_encoding_type='one_hot', diff --git a/zoo/petting_zoo/envs/petting_zoo_simple_spread_env.py b/zoo/petting_zoo/envs/petting_zoo_simple_spread_env.py index 87df2dc34..4e8638fd8 100644 --- a/zoo/petting_zoo/envs/petting_zoo_simple_spread_env.py +++ b/zoo/petting_zoo/envs/petting_zoo_simple_spread_env.py @@ -284,6 +284,13 @@ def _process_obs(self, obs: 'torch.Tensor') -> np.ndarray: # noqa tmp['action_mask'] = [1 for _ in range(*self._action_dim)] ret_transform.append(tmp) + concat_state_0 = np.concatenate((ret_transform[0]['agent_state'], ret_transform[1]['agent_state'], ret_transform[2]['agent_state']), axis=0) + concat_state_1 = np.concatenate((ret_transform[1]['agent_state'], ret_transform[0]['agent_state'], ret_transform[2]['agent_state']), axis=0) + concat_state_2 = np.concatenate((ret_transform[2]['agent_state'], ret_transform[0]['agent_state'], ret_transform[1]['agent_state']), axis=0) + + ret_transform[0]['agent_state'] = concat_state_0 + ret_transform[1]['agent_state'] = concat_state_1 + ret_transform[2]['agent_state'] = concat_state_2 return {'observation': ret_transform, 'action_mask': action_mask, 'to_play': to_play} def _process_action(self, action: 'torch.Tensor') -> Dict[str, np.ndarray]: # noqa From a09517abd65bae3678b36eed6c95b72eef1c32a3 Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Tue, 15 Aug 2023 23:42:42 +0800 Subject: [PATCH 45/54] polish(yzj): polish train entry --- lzero/entry/train_muzero.py | 4 ++-- zoo/gobigger/entry/train_muzero_gobigger.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/lzero/entry/train_muzero.py b/lzero/entry/train_muzero.py index f8d1c5358..e91f19209 100644 --- a/lzero/entry/train_muzero.py +++ b/lzero/entry/train_muzero.py @@ -47,13 +47,13 @@ def train_muzero( assert create_cfg.policy.type in ['efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero', 'multi_agent_efficientzero', 'multi_agent_muzero'], \ "train_muzero entry now only support the following algo.: 'efficientzero', 'muzero', 
'sampled_efficientzero', 'gumbel_muzero', 'multi_agent_efficientzero', 'multi_agent_muzero'" - if create_cfg.policy.type == 'muzero': + if create_cfg.policy.type == 'muzero' or create_cfg.policy.type == 'multi_agent_muzero': from lzero.mcts import MuZeroGameBuffer as GameBuffer elif create_cfg.policy.type == 'efficientzero' or create_cfg.policy.type == 'multi_agent_efficientzero': from lzero.mcts import EfficientZeroGameBuffer as GameBuffer elif create_cfg.policy.type == 'sampled_efficientzero': from lzero.mcts import SampledEfficientZeroGameBuffer as GameBuffer - elif create_cfg.policy.type == 'gumbel_muzero' or create_cfg.policy.type == 'multi_agent_muzero': + elif create_cfg.policy.type == 'gumbel_muzero': from lzero.mcts import GumbelMuZeroGameBuffer as GameBuffer if cfg.policy.cuda and torch.cuda.is_available(): diff --git a/zoo/gobigger/entry/train_muzero_gobigger.py b/zoo/gobigger/entry/train_muzero_gobigger.py index 44dbe08f6..b261c4786 100644 --- a/zoo/gobigger/entry/train_muzero_gobigger.py +++ b/zoo/gobigger/entry/train_muzero_gobigger.py @@ -50,13 +50,13 @@ def train_muzero_gobigger( assert create_cfg.policy.type in ['efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero', 'multi_agent_efficientzero', 'multi_agent_muzero'], \ "train_muzero entry now only support the following algo.: 'efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero', 'multi_agent_efficientzero', 'multi_agent_muzero'" - if create_cfg.policy.type == 'muzero': + if create_cfg.policy.type == 'muzero' or create_cfg.policy.type == 'multi_agent_muzero': from lzero.mcts import MuZeroGameBuffer as GameBuffer elif create_cfg.policy.type == 'efficientzero' or create_cfg.policy.type == 'multi_agent_efficientzero': from lzero.mcts import EfficientZeroGameBuffer as GameBuffer elif create_cfg.policy.type == 'sampled_efficientzero': from lzero.mcts import SampledEfficientZeroGameBuffer as GameBuffer - elif create_cfg.policy.type == 'gumbel_muzero' or create_cfg.policy.type == 'multi_agent_muzero': + elif create_cfg.policy.type == 'gumbel_muzero': from lzero.mcts import GumbelMuZeroGameBuffer as GameBuffer if cfg.policy.cuda and torch.cuda.is_available(): From 714ba4b7b0f161a18dc597aa4feaf5854e896f1f Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Wed, 16 Aug 2023 22:57:26 +0800 Subject: [PATCH 46/54] polish(yzj): polish gobigger config --- zoo/gobigger/config/gobigger_efficientzero_config.py | 3 ++- zoo/gobigger/config/gobigger_muzero_config.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/zoo/gobigger/config/gobigger_efficientzero_config.py b/zoo/gobigger/config/gobigger_efficientzero_config.py index d9780281b..0383501fc 100644 --- a/zoo/gobigger/config/gobigger_efficientzero_config.py +++ b/zoo/gobigger/config/gobigger_efficientzero_config.py @@ -11,7 +11,7 @@ n_episode = 32 evaluator_env_num = 5 num_simulations = 50 -update_per_collect = 2000 +update_per_collect = 1000 batch_size = 256 reanalyze_ratio = 0. 
action_space_size = 27 @@ -69,6 +69,7 @@ optim_type='SGD', lr_piecewise_constant_decay=True, learning_rate=0.2, + ssl_loss_weight=0, num_simulations=num_simulations, reanalyze_ratio=reanalyze_ratio, n_episode=n_episode, diff --git a/zoo/gobigger/config/gobigger_muzero_config.py b/zoo/gobigger/config/gobigger_muzero_config.py index 81fd0397a..b09bd8cb5 100644 --- a/zoo/gobigger/config/gobigger_muzero_config.py +++ b/zoo/gobigger/config/gobigger_muzero_config.py @@ -11,7 +11,7 @@ n_episode = 32 evaluator_env_num = 5 num_simulations = 50 -update_per_collect = 2000 +update_per_collect = 1000 batch_size = 256 reanalyze_ratio = 0. action_space_size = 27 From 0ee0122244030ae7522efce77dab45914fdbb6d2 Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Fri, 18 Aug 2023 15:37:08 +0800 Subject: [PATCH 47/54] polish(yzj): polish best gobigger config on ez/mz --- zoo/gobigger/config/gobigger_muzero_config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/zoo/gobigger/config/gobigger_muzero_config.py b/zoo/gobigger/config/gobigger_muzero_config.py index b09bd8cb5..1378ab1ba 100644 --- a/zoo/gobigger/config/gobigger_muzero_config.py +++ b/zoo/gobigger/config/gobigger_muzero_config.py @@ -46,7 +46,7 @@ latent_state_dim=176, frame_stack_num=1, action_space_size=action_space_size, - self_supervised_learning_loss=True, # default is False + self_supervised_learning_loss=False, discrete_action_encoding_type='one_hot', norm_type='BN', ), @@ -69,9 +69,9 @@ optim_type='SGD', lr_piecewise_constant_decay=True, learning_rate=0.2, + ssl_loss_weight=0, num_simulations=num_simulations, reanalyze_ratio=reanalyze_ratio, - ssl_loss_weight=0, # default is 0 n_episode=n_episode, eval_freq=int(2e3), replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. 
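Note (illustrative, not part of the patch series): the priority-related commits above drop the `use_max_priority_for_new_data` flag and replace `prioritized_replay_eps` with a fixed 1e-6, so priorities for newly collected transitions are now always derived from the gap between the model's predicted root value and the MCTS search value whenever `use_priority` is enabled. The sketch below mirrors that computation from the collector diffs; the helper name, `device` default, and dummy inputs are assumptions made for the example only.

    import numpy as np
    import torch
    from torch.nn import L1Loss

    def compute_new_data_priorities(pred_values, search_values, use_priority=True, device='cpu'):
        """Return per-step priorities, or None to fall back to the buffer's max priority."""
        if not use_priority:
            # priorities is None -> the replay buffer assigns the max priority to new data
            return None
        pred = torch.from_numpy(np.asarray(pred_values)).to(device).float().view(-1)
        search = torch.from_numpy(np.asarray(search_values)).to(device).float().view(-1)
        # Element-wise L1 distance between predicted and searched values, plus a small
        # epsilon (1e-6, as in the diff) so no transition gets exactly zero priority.
        return L1Loss(reduction='none')(pred, search).detach().cpu().numpy() + 1e-6

    # Example usage with dummy values for a 3-step segment:
    # priorities = compute_new_data_priorities([0.1, 0.4, 0.2], [0.3, 0.5, 0.1])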
From 71ce58e7443b31440708d63ccd281ada3a1853e1 Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Fri, 18 Aug 2023 16:30:53 +0800 Subject: [PATCH 48/54] polish(yzj): polish collector to adapt multi-agent mode --- lzero/entry/train_muzero.py | 9 +- lzero/worker/__init__.py | 1 - lzero/worker/multi_agent_muzero_collector.py | 111 ------------------ lzero/worker/muzero_collector.py | 91 +++++++++++++- .../config/gobigger_efficientzero_config.py | 4 +- zoo/gobigger/config/gobigger_muzero_config.py | 4 +- zoo/gobigger/entry/train_muzero_gobigger.py | 9 +- .../config/ptz_simple_spread_ez_config.py | 4 +- .../config/ptz_simple_spread_mz_config.py | 4 +- 9 files changed, 99 insertions(+), 138 deletions(-) delete mode 100644 lzero/worker/multi_agent_muzero_collector.py diff --git a/lzero/entry/train_muzero.py b/lzero/entry/train_muzero.py index e91f19209..74e454038 100644 --- a/lzero/entry/train_muzero.py +++ b/lzero/entry/train_muzero.py @@ -13,6 +13,8 @@ from ding.worker import BaseLearner from tensorboardX import SummaryWriter +from lzero.worker import MuZeroCollector as Collector +from lzero.worker import MuZeroEvaluator as Evaluator from lzero.entry.utils import log_buffer_memory_usage from lzero.policy import visit_count_temperature from .utils import random_collect @@ -90,13 +92,6 @@ def train_muzero( # specific game buffer for MCTS+RL algorithms replay_buffer = GameBuffer(policy_config) - if policy_config.multi_agent: - from lzero.worker import MultiAgentMuZeroCollector as Collector - from lzero.worker import MuZeroEvaluator as Evaluator - else: - from lzero.worker import MuZeroCollector as Collector - from lzero.worker import MuZeroEvaluator as Evaluator - collector = Collector( env=collector_env, policy=policy.collect_mode, diff --git a/lzero/worker/__init__.py b/lzero/worker/__init__.py index 341112f60..7e81f3e72 100644 --- a/lzero/worker/__init__.py +++ b/lzero/worker/__init__.py @@ -2,5 +2,4 @@ from .alphazero_evaluator import AlphaZeroEvaluator from .muzero_collector import MuZeroCollector from .muzero_evaluator import MuZeroEvaluator -from .multi_agent_muzero_collector import MultiAgentMuZeroCollector from .gobigger_muzero_evaluator import GoBiggerMuZeroEvaluator \ No newline at end of file diff --git a/lzero/worker/multi_agent_muzero_collector.py b/lzero/worker/multi_agent_muzero_collector.py deleted file mode 100644 index 76916b44d..000000000 --- a/lzero/worker/multi_agent_muzero_collector.py +++ /dev/null @@ -1,111 +0,0 @@ -import time -from collections import deque, namedtuple -from typing import Optional, Any, List - -import numpy as np -import torch -from ding.envs import BaseEnvManager -from ding.torch_utils import to_ndarray -from ding.utils import build_logger, EasyTimer, SERIAL_COLLECTOR_REGISTRY -from .muzero_collector import MuZeroCollector -from torch.nn import L1Loss - -from lzero.mcts.buffer.game_segment import GameSegment -from lzero.mcts.utils import prepare_observation -from collections import defaultdict - - -@SERIAL_COLLECTOR_REGISTRY.register('multi_agent_episode_muzero') -class MultiAgentMuZeroCollector(MuZeroCollector): - """ - Overview: - The Collector for Multi Agent MCTS+RL algorithms, including MuZero, EfficientZero, Sampled EfficientZero. - For Multi Agent, add agent_num dim in game_segment. - Interfaces: - __init__, reset, reset_env, reset_policy, collect, close - Property: - envstep - """ - - def _compute_priorities(self, i, agent_id, pred_values_lst, search_values_lst): - """ - Overview: - obtain the priorities at index i. - Arguments: - - i: index. 
- - pred_values_lst: The list of value being predicted. - - search_values_lst: The list of value obtained through search. - """ - if self.policy_config.use_priority: - pred_values = torch.from_numpy(np.array(pred_values_lst[i][agent_id])).to(self.policy_config.device - ).float().view(-1) - search_values = torch.from_numpy(np.array(search_values_lst[i][agent_id])).to(self.policy_config.device - ).float().view(-1) - priorities = L1Loss(reduction='none' - )(pred_values, - search_values).detach().cpu().numpy() + 1e-6 # avoid zero priority - else: - # priorities is None -> use the max priority for all newly collected data - priorities = None - - return priorities - - def pad_and_save_last_trajectory( - self, i, agent_id, last_game_segments, last_game_priorities, game_segments, done - ) -> None: - """ - Overview: - put the last game block into the pool if the current game is finished - Arguments: - - last_game_segments (:obj:`list`): list of the last game segments - - last_game_priorities (:obj:`list`): list of the last game priorities - - game_segments (:obj:`list`): list of the current game segments - Note: - (last_game_segments[i].obs_segment[-4:][j] == game_segments[i].obs_segment[:4][j]).all() is True - """ - # pad over last block trajectory - beg_index = self.policy_config.model.frame_stack_num - end_index = beg_index + self.policy_config.num_unroll_steps - - # the start obs is init zero obs, so we take the [ : +] obs as the pad obs - # e.g. the start 4 obs is init zero obs, the num_unroll_steps is 5, so we take the [4:9] obs as the pad obs - pad_obs_lst = game_segments[i][agent_id].obs_segment[beg_index:end_index] - pad_child_visits_lst = game_segments[i][agent_id].child_visit_segment[:self.policy_config.num_unroll_steps] - # EfficientZero original repo bug: - # pad_child_visits_lst = game_segments[i].child_visit_segment[beg_index:end_index] - - beg_index = 0 - # self.unroll_plus_td_steps = self.policy_config.num_unroll_steps + self.policy_config.td_steps - end_index = beg_index + self.unroll_plus_td_steps - 1 - - pad_reward_lst = game_segments[i][agent_id].reward_segment[beg_index:end_index] - - beg_index = 0 - end_index = beg_index + self.unroll_plus_td_steps - - pad_root_values_lst = game_segments[i][agent_id].root_value_segment[beg_index:end_index] - - # pad over and save - last_game_segments[i][agent_id].pad_over(pad_obs_lst, pad_reward_lst, pad_root_values_lst, pad_child_visits_lst) - """ - Note: - game_segment element shape: - obs: game_segment_length + stack + num_unroll_steps, 20+4 +5 - rew: game_segment_length + stack + num_unroll_steps + td_steps -1 20 +5+3-1 - action: game_segment_length -> 20 - root_values: game_segment_length + num_unroll_steps + td_steps -> 20 +5+3 - child_visits: game_segment_length + num_unroll_steps -> 20 +5 - to_play: game_segment_length -> 20 - action_mask: game_segment_length -> 20 - """ - - last_game_segments[i][agent_id].game_segment_to_array() - - # put the game block into the pool - self.game_segment_pool.append((last_game_segments[i][agent_id], last_game_priorities[i][agent_id], done[i])) - - # reset last game_segments - last_game_segments[i][agent_id] = None - last_game_priorities[i][agent_id] = None - - return None diff --git a/lzero/worker/muzero_collector.py b/lzero/worker/muzero_collector.py index 67747a321..9f759a6ae 100644 --- a/lzero/worker/muzero_collector.py +++ b/lzero/worker/muzero_collector.py @@ -272,6 +272,89 @@ def pad_and_save_last_trajectory(self, i, last_game_segments, last_game_prioriti last_game_priorities[i] = None return 
None + + def _compute_priorities_for_agent(self, i, agent_id, pred_values_lst, search_values_lst): + """ + Overview: + obtain the priorities at index i. + Arguments: + - i: index. + - pred_values_lst: The list of value being predicted. + - search_values_lst: The list of value obtained through search. + """ + if self.policy_config.use_priority: + pred_values = torch.from_numpy(np.array(pred_values_lst[i][agent_id])).to(self.policy_config.device + ).float().view(-1) + search_values = torch.from_numpy(np.array(search_values_lst[i][agent_id])).to(self.policy_config.device + ).float().view(-1) + priorities = L1Loss(reduction='none' + )(pred_values, + search_values).detach().cpu().numpy() + 1e-6 # avoid zero priority + else: + # priorities is None -> use the max priority for all newly collected data + priorities = None + + return priorities + + def pad_and_save_last_trajectory_for_agent( + self, i, agent_id, last_game_segments, last_game_priorities, game_segments, done + ) -> None: + """ + Overview: + put the last game block into the pool if the current game is finished + Arguments: + - last_game_segments (:obj:`list`): list of the last game segments + - last_game_priorities (:obj:`list`): list of the last game priorities + - game_segments (:obj:`list`): list of the current game segments + Note: + (last_game_segments[i].obs_segment[-4:][j] == game_segments[i].obs_segment[:4][j]).all() is True + """ + # pad over last block trajectory + beg_index = self.policy_config.model.frame_stack_num + end_index = beg_index + self.policy_config.num_unroll_steps + + # the start obs is init zero obs, so we take the [ : +] obs as the pad obs + # e.g. the start 4 obs is init zero obs, the num_unroll_steps is 5, so we take the [4:9] obs as the pad obs + pad_obs_lst = game_segments[i][agent_id].obs_segment[beg_index:end_index] + pad_child_visits_lst = game_segments[i][agent_id].child_visit_segment[:self.policy_config.num_unroll_steps] + # EfficientZero original repo bug: + # pad_child_visits_lst = game_segments[i].child_visit_segment[beg_index:end_index] + + beg_index = 0 + # self.unroll_plus_td_steps = self.policy_config.num_unroll_steps + self.policy_config.td_steps + end_index = beg_index + self.unroll_plus_td_steps - 1 + + pad_reward_lst = game_segments[i][agent_id].reward_segment[beg_index:end_index] + + beg_index = 0 + end_index = beg_index + self.unroll_plus_td_steps + + pad_root_values_lst = game_segments[i][agent_id].root_value_segment[beg_index:end_index] + + # pad over and save + last_game_segments[i][agent_id].pad_over(pad_obs_lst, pad_reward_lst, pad_root_values_lst, pad_child_visits_lst) + """ + Note: + game_segment element shape: + obs: game_segment_length + stack + num_unroll_steps, 20+4 +5 + rew: game_segment_length + stack + num_unroll_steps + td_steps -1 20 +5+3-1 + action: game_segment_length -> 20 + root_values: game_segment_length + num_unroll_steps + td_steps -> 20 +5+3 + child_visits: game_segment_length + num_unroll_steps -> 20 +5 + to_play: game_segment_length -> 20 + action_mask: game_segment_length -> 20 + """ + + last_game_segments[i][agent_id].game_segment_to_array() + + # put the game block into the pool + self.game_segment_pool.append((last_game_segments[i][agent_id], last_game_priorities[i][agent_id], done[i])) + + # reset last game_segments + last_game_segments[i][agent_id] = None + last_game_priorities[i][agent_id] = None + + return None def collect(self, n_episode: Optional[int] = None, @@ -584,12 +667,12 @@ def collect(self, # pad over last block trajectory if 
last_game_segments[env_id][agent_id] is not None: # TODO(pu): return the one game block - self.pad_and_save_last_trajectory( + self.pad_and_save_last_trajectory_for_agent( env_id, agent_id, last_game_segments, last_game_priorities, game_segments, dones ) # calculate priority - priorities = self._compute_priorities(env_id, agent_id, pred_values_lst, search_values_lst) + priorities = self._compute_priorities_for_agent(env_id, agent_id, pred_values_lst, search_values_lst) # pred_values_lst[env_id] = [] # search_values_lst[env_id] = [] search_values_lst = [[[] for _ in range(agent_num)] for _ in range(env_nums)] @@ -662,12 +745,12 @@ def collect(self, if self._multi_agent: for agent_id in range(agent_num): if last_game_segments[env_id][agent_id] is not None: - self.pad_and_save_last_trajectory( + self.pad_and_save_last_trajectory_for_agent( env_id, agent_id, last_game_segments, last_game_priorities, game_segments, dones ) # store current block trajectory - priorities = self._compute_priorities(env_id, agent_id, pred_values_lst, search_values_lst) + priorities = self._compute_priorities_for_agent(env_id, agent_id, pred_values_lst, search_values_lst) # NOTE: put the last game block in one episode into the trajectory_pool game_segments[env_id][agent_id].game_segment_to_array() diff --git a/zoo/gobigger/config/gobigger_efficientzero_config.py b/zoo/gobigger/config/gobigger_efficientzero_config.py index 0383501fc..57e358422 100644 --- a/zoo/gobigger/config/gobigger_efficientzero_config.py +++ b/zoo/gobigger/config/gobigger_efficientzero_config.py @@ -97,8 +97,8 @@ import_names=['lzero.policy.multi_agent_efficientzero'], ), collector=dict( - type='multi_agent_episode_muzero', - import_names=['lzero.worker.multi_agent_muzero_collector'], + type='episode_muzero', + import_names=['lzero.worker.muzero_collector'], ) ) gobigger_efficientzero_create_config = EasyDict(gobigger_efficientzero_create_config) diff --git a/zoo/gobigger/config/gobigger_muzero_config.py b/zoo/gobigger/config/gobigger_muzero_config.py index 1378ab1ba..9638b527e 100644 --- a/zoo/gobigger/config/gobigger_muzero_config.py +++ b/zoo/gobigger/config/gobigger_muzero_config.py @@ -97,8 +97,8 @@ import_names=['lzero.policy.multi_agent_muzero'], ), collector=dict( - type='multi_agent_episode_muzero', - import_names=['lzero.worker.multi_agent_muzero_collector'], + type='episode_muzero', + import_names=['lzero.worker.muzero_collector'], ) ) gobigger_muzero_create_config = EasyDict(gobigger_muzero_create_config) diff --git a/zoo/gobigger/entry/train_muzero_gobigger.py b/zoo/gobigger/entry/train_muzero_gobigger.py index b261c4786..991e64575 100644 --- a/zoo/gobigger/entry/train_muzero_gobigger.py +++ b/zoo/gobigger/entry/train_muzero_gobigger.py @@ -15,12 +15,13 @@ from lzero.entry.utils import log_buffer_memory_usage from lzero.policy import visit_count_temperature +from lzero.worker import MuZeroCollector as Collector +from lzero.worker import MuZeroEvaluator as Evaluator from lzero.worker import GoBiggerMuZeroEvaluator from lzero.entry.utils import random_collect import copy - def train_muzero_gobigger( input_cfg: Tuple[dict, dict], seed: int = 0, @@ -99,12 +100,6 @@ def train_muzero_gobigger( # specific game buffer for MCTS+RL algorithms replay_buffer = GameBuffer(policy_config) - if policy_config.multi_agent: - from lzero.worker import MultiAgentMuZeroCollector as Collector - from lzero.worker import MuZeroEvaluator as Evaluator - else: - from lzero.worker import MuZeroCollector as Collector - from lzero.worker import MuZeroEvaluator as 
Evaluator collector = Collector( env=collector_env, policy=policy.collect_mode, diff --git a/zoo/petting_zoo/config/ptz_simple_spread_ez_config.py b/zoo/petting_zoo/config/ptz_simple_spread_ez_config.py index 995b90988..8e1cb9984 100644 --- a/zoo/petting_zoo/config/ptz_simple_spread_ez_config.py +++ b/zoo/petting_zoo/config/ptz_simple_spread_ez_config.py @@ -103,8 +103,8 @@ import_names=['lzero.policy.multi_agent_efficientzero'], ), collector=dict( - type='multi_agent_episode_muzero', - import_names=['lzero.worker.multi_agent_muzero_collector'], + type='episode_muzero', + import_names=['lzero.worker.muzero_collector'], ) ) create_config = EasyDict(create_config) diff --git a/zoo/petting_zoo/config/ptz_simple_spread_mz_config.py b/zoo/petting_zoo/config/ptz_simple_spread_mz_config.py index 3facc5e3e..dac1cfc39 100644 --- a/zoo/petting_zoo/config/ptz_simple_spread_mz_config.py +++ b/zoo/petting_zoo/config/ptz_simple_spread_mz_config.py @@ -105,8 +105,8 @@ import_names=['lzero.policy.multi_agent_muzero'], ), collector=dict( - type='multi_agent_episode_muzero', - import_names=['lzero.worker.multi_agent_muzero_collector'], + type='episode_muzero', + import_names=['lzero.worker.muzero_collector'], ) ) create_config = EasyDict(create_config) From 5bec18b8a53a19624609a27746b8eb41b2426717 Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Fri, 18 Aug 2023 19:51:39 +0800 Subject: [PATCH 49/54] polish(yzj): polish multi agent model --- lzero/model/efficientzero_model_mlp.py | 19 +- lzero/model/efficientzero_model_structure.py | 187 ---------------- lzero/model/gobigger/__init__.py | 0 lzero/model/muzero_model_mlp.py | 14 +- lzero/model/muzero_model_structure.py | 183 ---------------- lzero/model/petting_zoo/__init__.py | 0 lzero/policy/efficientzero.py | 2 - lzero/policy/muzero.py | 2 - lzero/policy/utils.py | 34 +-- lzero/worker/gobigger_muzero_evaluator.py | 10 +- .../config/gobigger_efficientzero_config.py | 1 - zoo/gobigger/config/gobigger_muzero_config.py | 1 - zoo/gobigger/entry/train_muzero_gobigger.py | 14 +- zoo/gobigger/model/__init__.py | 1 + .../gobigger/model}/encoder.py | 0 .../gobigger/model/model.py | 0 .../config/ptz_simple_spread_ez_config.py | 4 +- .../config/ptz_simple_spread_mz_config.py | 5 +- zoo/petting_zoo/entry/__init__.py | 1 + zoo/petting_zoo/entry/train_muzero.py | 200 ++++++++++++++++++ .../envs/petting_zoo_simple_spread_env.py | 8 - zoo/petting_zoo/model/__init__.py | 1 + .../petting_zoo/model/model.py | 4 +- 23 files changed, 264 insertions(+), 427 deletions(-) delete mode 100644 lzero/model/efficientzero_model_structure.py delete mode 100644 lzero/model/gobigger/__init__.py delete mode 100644 lzero/model/muzero_model_structure.py delete mode 100644 lzero/model/petting_zoo/__init__.py create mode 100644 zoo/gobigger/model/__init__.py rename {lzero/model/gobigger => zoo/gobigger/model}/encoder.py (100%) rename lzero/model/gobigger/gobigger_encoder.py => zoo/gobigger/model/model.py (100%) create mode 100644 zoo/petting_zoo/entry/__init__.py create mode 100644 zoo/petting_zoo/entry/train_muzero.py create mode 100644 zoo/petting_zoo/model/__init__.py rename lzero/model/petting_zoo/encoder.py => zoo/petting_zoo/model/model.py (74%) diff --git a/lzero/model/efficientzero_model_mlp.py b/lzero/model/efficientzero_model_mlp.py index a491cdb75..de852350e 100644 --- a/lzero/model/efficientzero_model_mlp.py +++ b/lzero/model/efficientzero_model_mlp.py @@ -4,6 +4,8 @@ import torch.nn as nn from ding.torch_utils import MLP from ding.utils import MODEL_REGISTRY, SequenceType +from 
ding.utils.default_helper import get_shape0 + from numpy import ndarray from .common import EZNetworkOutput, RepresentationNetworkMLP, PredictionNetworkMLP @@ -36,6 +38,7 @@ def __init__( norm_type: Optional[str] = 'BN', discrete_action_encoding_type: str = 'one_hot', res_connection_in_dynamics: bool = False, + state_encoder=None, *args, **kwargs, ): @@ -104,9 +107,12 @@ def __init__( self.state_norm = state_norm self.res_connection_in_dynamics = res_connection_in_dynamics - self.representation_network = RepresentationNetworkMLP( - observation_shape=observation_shape, hidden_channels=latent_state_dim, norm_type=norm_type - ) + if state_encoder == None: + self.representation_network = RepresentationNetworkMLP( + observation_shape=observation_shape, hidden_channels=latent_state_dim, norm_type=norm_type + ) + else: + self.representation_network = state_encoder self.dynamics_network = DynamicsNetworkMLP( action_encoding_dim=self.action_encoding_dim, @@ -171,15 +177,16 @@ def initial_inference(self, obs: torch.Tensor) -> EZNetworkOutput: - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The shape of each element is :math:`(1, B, lstm_hidden_size)`, where B is batch_size. """ - batch_size = obs.size(0) + batch_size = get_shape0(obs) latent_state = self._representation(obs) + device = latent_state.device policy_logits, value = self._prediction(latent_state) # zero initialization for reward hidden states # (hn, cn), each element shape is (layer_num=1, batch_size, lstm_hidden_size) reward_hidden_state = ( torch.zeros(1, batch_size, - self.lstm_hidden_size).to(obs.device), torch.zeros(1, batch_size, - self.lstm_hidden_size).to(obs.device) + self.lstm_hidden_size).to(device), torch.zeros(1, batch_size, + self.lstm_hidden_size).to(device) ) return EZNetworkOutput(value, [0. 
for _ in range(batch_size)], policy_logits, latent_state, reward_hidden_state) diff --git a/lzero/model/efficientzero_model_structure.py b/lzero/model/efficientzero_model_structure.py deleted file mode 100644 index 93afeab2b..000000000 --- a/lzero/model/efficientzero_model_structure.py +++ /dev/null @@ -1,187 +0,0 @@ -from typing import Optional, Tuple - -import torch -import torch.nn as nn -from ding.torch_utils import MLP, ResBlock -from ding.utils import MODEL_REGISTRY, SequenceType -from numpy import ndarray - -from .common import EZNetworkOutput, RepresentationNetwork, PredictionNetwork -from .utils import renormalize, get_params_mean, get_dynamic_mean, get_reward_mean -from lzero.model.efficientzero_model_mlp import EfficientZeroModelMLP, DynamicsNetworkMLP, PredictionNetworkMLP - -@MODEL_REGISTRY.register('EfficientZeroModelStructure') -class EfficientZeroModelStructure(EfficientZeroModelMLP): - def __init__( - self, - env_name: str, - action_space_size: int = 6, - lstm_hidden_size: int = 512, - latent_state_dim: int = 256, - fc_reward_layers: SequenceType = [32], - fc_value_layers: SequenceType = [32], - fc_policy_layers: SequenceType = [32], - reward_support_size: int = 601, - value_support_size: int = 601, - proj_hid: int = 1024, - proj_out: int = 1024, - pred_hid: int = 512, - pred_out: int = 1024, - self_supervised_learning_loss: bool = True, - categorical_distribution: bool = True, - last_linear_layer_init_zero: bool = True, - state_norm: bool = False, - activation: Optional[nn.Module] = nn.ReLU(inplace=True), - norm_type: Optional[str] = 'BN', - discrete_action_encoding_type: str = 'one_hot', - res_connection_in_dynamics: bool = False, - *args, - **kwargs, - ): - """ - Overview: - The definition of the network model of EfficientZero, which is a generalization version for 1D vector obs. - The networks are mainly built on fully connected layers. - Sampled EfficientZero model consists of a representation network, a dynamics network and a prediction network. - The representation network is an MLP network which maps the raw observation to a latent state. - The dynamics network is an MLP+LSTM network which predicts the next latent state, reward_hidden_state and value_prefix given the current latent state and action. - The prediction network is an MLP network which predicts the value and policy given the current latent state. - Arguments: - - env_name (:obj:`str`): Env name, e.g. ptz_simple_spread, gobigger etc. - - action_space_size: (:obj:`int`): Action space size, e.g. 4 for Lunarlander. - - lstm_hidden_size (:obj:`int`): The hidden size of LSTM in dynamics network to predict value_prefix. - - latent_state_dim (:obj:`int`): The dimension of latent state, such as 256. - - fc_reward_layers (:obj:`SequenceType`): The number of hidden layers of the reward head (MLP head). - - fc_value_layers (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head). - - fc_policy_layers (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head). - - reward_support_size (:obj:`int`): The size of categorical reward output - - value_support_size (:obj:`int`): The size of categorical value output. - - proj_hid (:obj:`int`): The size of projection hidden layer. - - proj_out (:obj:`int`): The size of projection output layer. - - pred_hid (:obj:`int`): The size of prediction hidden layer. - - pred_out (:obj:`int`): The size of prediction output layer. 
- - self_supervised_learning_loss (:obj:`bool`): Whether to use self_supervised_learning related networks in Sampled EfficientZero model, default set it to False. - - categorical_distribution (:obj:`bool`): Whether to use discrete support to represent categorical distribution for value, reward/value_prefix. - - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of value/policy mlp, default sets it to True. - - state_norm (:obj:`bool`): Whether to use normalization for latent states, default sets it to True. - - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ - operation to speedup, e.g. ReLU(inplace=True). - - discrete_action_encoding_type (:obj:`str`): The type of encoding for discrete action. Default sets it to 'one_hot'. options = {'one_hot', 'not_one_hot'} - - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. - - res_connection_in_dynamics (:obj:`bool`): Whether to use residual connection for dynamics network, default set it to False. - """ - super(EfficientZeroModelStructure, self).__init__() - if not categorical_distribution: - self.reward_support_size = 1 - self.value_support_size = 1 - else: - self.reward_support_size = reward_support_size - self.value_support_size = value_support_size - - self.action_space_size = action_space_size - self.continuous_action_space = False - # The dim of action space. For discrete action space, it is 1. - # For continuous action space, it is the dimension of continuous action. - self.action_space_dim = action_space_size if self.continuous_action_space else 1 - assert discrete_action_encoding_type in ['one_hot', 'not_one_hot'], discrete_action_encoding_type - self.discrete_action_encoding_type = discrete_action_encoding_type - if self.continuous_action_space: - self.action_encoding_dim = action_space_size - else: - if self.discrete_action_encoding_type == 'one_hot': - self.action_encoding_dim = action_space_size - elif self.discrete_action_encoding_type == 'not_one_hot': - self.action_encoding_dim = 1 - - self.lstm_hidden_size = lstm_hidden_size - self.proj_hid = proj_hid - self.proj_out = proj_out - self.pred_hid = pred_hid - self.pred_out = pred_out - self.self_supervised_learning_loss = self_supervised_learning_loss - self.last_linear_layer_init_zero = last_linear_layer_init_zero - self.state_norm = state_norm - self.res_connection_in_dynamics = res_connection_in_dynamics - - if env_name == 'gobigger': - from lzero.model.gobigger.gobigger_encoder import GoBiggerEncoder as Encoder - elif env_name == 'ptz_simple_spread': - from lzero.model.petting_zoo.encoder import PettingZooEncoder as Encoder - else: - raise NotImplementedError - self.representation_network = Encoder() - - self.dynamics_network = DynamicsNetworkMLP( - action_encoding_dim=self.action_encoding_dim, - num_channels=latent_state_dim + self.action_encoding_dim, - common_layer_num=2, - lstm_hidden_size=lstm_hidden_size, - fc_reward_layers=fc_reward_layers, - output_support_size=self.reward_support_size, - last_linear_layer_init_zero=self.last_linear_layer_init_zero, - norm_type=norm_type, - res_connection_in_dynamics=self.res_connection_in_dynamics, - ) - - self.prediction_network = PredictionNetworkMLP( - action_space_size=action_space_size, - num_channels=latent_state_dim, - fc_value_layers=fc_value_layers, - fc_policy_layers=fc_policy_layers, - output_support_size=self.value_support_size, - last_linear_layer_init_zero=self.last_linear_layer_init_zero, - 
norm_type=norm_type - ) - - if self.self_supervised_learning_loss: - # self_supervised_learning_loss related network proposed in EfficientZero - self.projection_input_dim = latent_state_dim - - self.projection = nn.Sequential( - nn.Linear(self.projection_input_dim, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, - nn.Linear(self.proj_hid, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, - nn.Linear(self.proj_hid, self.proj_out), nn.BatchNorm1d(self.proj_out) - ) - self.prediction_head = nn.Sequential( - nn.Linear(self.proj_out, self.pred_hid), - nn.BatchNorm1d(self.pred_hid), - activation, - nn.Linear(self.pred_hid, self.pred_out), - ) - - def initial_inference(self, obs: torch.Tensor) -> EZNetworkOutput: - """ - Overview: - Initial inference of EfficientZero model, which is the first step of the EfficientZero model. - To perform the initial inference, we first use the representation network to obtain the "latent_state" of the observation. - Then we use the prediction network to predict the "value" and "policy_logits" of the "latent_state", and - also prepare the zeros-like ``reward_hidden_state`` for the next step of the EfficientZero model. - Arguments: - - obs (:obj:`torch.Tensor`): The 1D vector observation data. - Returns (EZNetworkOutput): - - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. - - value_prefix (:obj:`torch.Tensor`): The predicted prefix sum of value for input state. \ - In initial inference, we set it to zero vector. - - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. - - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. - - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The hidden state of LSTM about reward. In initial inference, \ - we set it to the zeros-like hidden state (H and C). - Shapes: - - obs (:obj:`torch.Tensor`): :math:`(B, obs_shape)`, where B is batch_size. - - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. - - value_prefix (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. - - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. - - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. - - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The shape of each element is :math:`(1, B, lstm_hidden_size)`, where B is batch_size. - """ - batch_size = obs['action_mask'].shape[0] - latent_state = self._representation(obs) - device = latent_state.device - policy_logits, value = self._prediction(latent_state) - # zero initialization for reward hidden states - # (hn, cn), each element shape is (layer_num=1, batch_size, lstm_hidden_size) - reward_hidden_state = ( - torch.zeros(1, batch_size, self.lstm_hidden_size).to(device), - torch.zeros(1, batch_size, self.lstm_hidden_size).to(device), - ) - return EZNetworkOutput(value, [0. 
for _ in range(batch_size)], policy_logits, latent_state, reward_hidden_state) diff --git a/lzero/model/gobigger/__init__.py b/lzero/model/gobigger/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/lzero/model/muzero_model_mlp.py b/lzero/model/muzero_model_mlp.py index caf1df15d..25dfff530 100644 --- a/lzero/model/muzero_model_mlp.py +++ b/lzero/model/muzero_model_mlp.py @@ -4,6 +4,8 @@ import torch.nn as nn from ding.torch_utils import MLP from ding.utils import MODEL_REGISTRY, SequenceType +from ding.utils.default_helper import get_shape0 + from .common import MZNetworkOutput, RepresentationNetworkMLP, PredictionNetworkMLP from .utils import renormalize, get_params_mean, get_dynamic_mean, get_reward_mean @@ -34,6 +36,7 @@ def __init__( discrete_action_encoding_type: str = 'one_hot', norm_type: Optional[str] = 'BN', res_connection_in_dynamics: bool = False, + state_encoder=None, *args, **kwargs ): @@ -101,9 +104,12 @@ def __init__( self.state_norm = state_norm self.res_connection_in_dynamics = res_connection_in_dynamics - self.representation_network = RepresentationNetworkMLP( - observation_shape=observation_shape, hidden_channels=self.latent_state_dim, norm_type=norm_type - ) + if state_encoder == None: + self.representation_network = RepresentationNetworkMLP( + observation_shape=observation_shape, hidden_channels=latent_state_dim, norm_type=norm_type + ) + else: + self.representation_network = state_encoder self.dynamics_network = DynamicsNetwork( action_encoding_dim=self.action_encoding_dim, @@ -166,7 +172,7 @@ def initial_inference(self, obs: torch.Tensor) -> MZNetworkOutput: - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. 
""" - batch_size = obs.size(0) + batch_size = get_shape0(obs) latent_state = self._representation(obs) policy_logits, value = self._prediction(latent_state) return MZNetworkOutput( diff --git a/lzero/model/muzero_model_structure.py b/lzero/model/muzero_model_structure.py deleted file mode 100644 index af5e0419e..000000000 --- a/lzero/model/muzero_model_structure.py +++ /dev/null @@ -1,183 +0,0 @@ -from typing import Optional, Tuple - -import torch -import torch.nn as nn -from ding.torch_utils import MLP -from ding.utils import MODEL_REGISTRY, SequenceType - -from .common import MZNetworkOutput, PredictionNetworkMLP -from .utils import renormalize, get_params_mean, get_dynamic_mean, get_reward_mean -from lzero.model.muzero_model_mlp import MuZeroModelMLP, DynamicsNetwork - - - -@MODEL_REGISTRY.register('MuZeroModelStructure') -class MuZeroModelMLPStructure(MuZeroModelMLP): - - def __init__( - self, - env_name: str, - action_space_size: int = 6, - latent_state_dim: int = 256, - fc_reward_layers: SequenceType = [32], - fc_value_layers: SequenceType = [32], - fc_policy_layers: SequenceType = [32], - reward_support_size: int = 601, - value_support_size: int = 601, - proj_hid: int = 1024, - proj_out: int = 1024, - pred_hid: int = 512, - pred_out: int = 1024, - self_supervised_learning_loss: bool = False, - categorical_distribution: bool = True, - activation: Optional[nn.Module] = nn.ReLU(inplace=True), - last_linear_layer_init_zero: bool = True, - state_norm: bool = False, - discrete_action_encoding_type: str = 'one_hot', - norm_type: Optional[str] = 'BN', - res_connection_in_dynamics: bool = False, - *args, - **kwargs - ): - """ - Overview: - The definition of the network model of MuZero, which is a generalization version for 1D vector obs. - The networks are mainly built on fully connected layers. - The representation network is an MLP network which maps the raw observation to a latent state. - The dynamics network is an MLP network which predicts the next latent state, and reward given the current latent state and action. - The prediction network is an MLP network which predicts the value and policy given the current latent state. - Arguments: - - observation_shape (:obj:`int`): Observation space shape, e.g. 8 for Lunarlander. - - action_space_size: (:obj:`int`): Action space size, e.g. 4 for Lunarlander. - - latent_state_dim (:obj:`int`): The dimension of latent state, such as 256. - - fc_reward_layers (:obj:`SequenceType`): The number of hidden layers of the reward head (MLP head). - - fc_value_layers (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head). - - fc_policy_layers (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head). - - reward_support_size (:obj:`int`): The size of categorical reward output - - value_support_size (:obj:`int`): The size of categorical value output. - - proj_hid (:obj:`int`): The size of projection hidden layer. - - proj_out (:obj:`int`): The size of projection output layer. - - pred_hid (:obj:`int`): The size of prediction hidden layer. - - pred_out (:obj:`int`): The size of prediction output layer. - - self_supervised_learning_loss (:obj:`bool`): Whether to use self_supervised_learning related networks in MuZero model, default set it to False. - - categorical_distribution (:obj:`bool`): Whether to use discrete support to represent categorical distribution for value, reward/value_prefix. 
- - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ - operation to speedup, e.g. ReLU(inplace=True). - - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of value/policy mlp, default sets it to True. - - state_norm (:obj:`bool`): Whether to use normalization for latent states, default sets it to True. - - discrete_action_encoding_type (:obj:`str`): The encoding type of discrete action, which can be 'one_hot' or 'not_one_hot'. - - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. - - res_connection_in_dynamics (:obj:`bool`): Whether to use residual connection for dynamics network, default set it to False. - """ - super(MuZeroModelMLP, self).__init__() - self.categorical_distribution = categorical_distribution - if not self.categorical_distribution: - self.reward_support_size = 1 - self.value_support_size = 1 - else: - self.reward_support_size = reward_support_size - self.value_support_size = value_support_size - - self.action_space_size = action_space_size - self.continuous_action_space = False - # The dim of action space. For discrete action space, it is 1. - # For continuous action space, it is the dimension of continuous action. - self.action_space_dim = action_space_size if self.continuous_action_space else 1 - assert discrete_action_encoding_type in ['one_hot', 'not_one_hot'], discrete_action_encoding_type - self.discrete_action_encoding_type = discrete_action_encoding_type - if self.continuous_action_space: - self.action_encoding_dim = action_space_size - else: - if self.discrete_action_encoding_type == 'one_hot': - self.action_encoding_dim = action_space_size - elif self.discrete_action_encoding_type == 'not_one_hot': - self.action_encoding_dim = 1 - - self.latent_state_dim = latent_state_dim - self.proj_hid = proj_hid - self.proj_out = proj_out - self.pred_hid = pred_hid - self.pred_out = pred_out - self.self_supervised_learning_loss = self_supervised_learning_loss - self.last_linear_layer_init_zero = last_linear_layer_init_zero - self.state_norm = state_norm - self.res_connection_in_dynamics = res_connection_in_dynamics - - if env_name == 'gobigger': - from lzero.model.gobigger.gobigger_encoder import GoBiggerEncoder as Encoder - elif env_name == 'ptz_simple_spread': - from lzero.model.petting_zoo.encoder import PettingZooEncoder as Encoder - else: - raise NotImplementedError - self.representation_network = Encoder() - - self.dynamics_network = DynamicsNetwork( - action_encoding_dim=self.action_encoding_dim, - num_channels=self.latent_state_dim + self.action_encoding_dim, - common_layer_num=2, - fc_reward_layers=fc_reward_layers, - output_support_size=self.reward_support_size, - last_linear_layer_init_zero=self.last_linear_layer_init_zero, - norm_type=norm_type, - res_connection_in_dynamics=self.res_connection_in_dynamics, - ) - - self.prediction_network = PredictionNetworkMLP( - action_space_size=action_space_size, - num_channels=latent_state_dim, - fc_value_layers=fc_value_layers, - fc_policy_layers=fc_policy_layers, - output_support_size=self.value_support_size, - last_linear_layer_init_zero=self.last_linear_layer_init_zero, - norm_type=norm_type - ) - - if self.self_supervised_learning_loss: - # self_supervised_learning_loss related network proposed in EfficientZero - self.projection_input_dim = latent_state_dim - - self.projection = nn.Sequential( - nn.Linear(self.projection_input_dim, self.proj_hid), nn.BatchNorm1d(self.proj_hid), 
activation, - nn.Linear(self.proj_hid, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, - nn.Linear(self.proj_hid, self.proj_out), nn.BatchNorm1d(self.proj_out) - ) - self.prediction_head = nn.Sequential( - nn.Linear(self.proj_out, self.pred_hid), - nn.BatchNorm1d(self.pred_hid), - activation, - nn.Linear(self.pred_hid, self.pred_out), - ) - - def initial_inference(self, obs: torch.Tensor) -> MZNetworkOutput: - """ - Overview: - Initial inference of MuZero model, which is the first step of the MuZero model. - To perform the initial inference, we first use the representation network to obtain the "latent_state" of the observation. - Then we use the prediction network to predict the "value" and "policy_logits" of the "latent_state", and - also prepare the zeros-like ``reward`` for the next step of the MuZero model. - Arguments: - - obs (:obj:`torch.Tensor`): The 1D vector observation data. - Returns (MZNetworkOutput): - - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. - - value_prefix (:obj:`torch.Tensor`): The predicted prefix sum of value for input state. \ - In initial inference, we set it to zero vector. - - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. - - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. - - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The hidden state of LSTM about reward. In initial inference, \ - we set it to the zeros-like hidden state (H and C). - Shapes: - - obs (:obj:`torch.Tensor`): :math:`(B, obs_shape)`, where B is batch_size. - - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. - - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. - - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. - - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. - """ - batch_size = obs['action_mask'].shape[0] - latent_state = self._representation(obs) - policy_logits, value = self._prediction(latent_state) - return MZNetworkOutput( - value, - [0. 
for _ in range(batch_size)], - policy_logits, - latent_state, - ) diff --git a/lzero/model/petting_zoo/__init__.py b/lzero/model/petting_zoo/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/lzero/policy/efficientzero.py b/lzero/policy/efficientzero.py index 5cb58c1c1..1b3d43ac5 100644 --- a/lzero/policy/efficientzero.py +++ b/lzero/policy/efficientzero.py @@ -210,8 +210,6 @@ def default_model(self) -> Tuple[str, List[str]]: return 'EfficientZeroModel', ['lzero.model.efficientzero_model'] elif self._cfg.model.model_type == "mlp": return 'EfficientZeroModelMLP', ['lzero.model.efficientzero_model_mlp'] - elif self._cfg.model.model_type == "structure": - return 'EfficientZeroModelStructure', ['lzero.model.efficientzero_model_structure'] def _init_learn(self) -> None: """ diff --git a/lzero/policy/muzero.py b/lzero/policy/muzero.py index 9cccbc833..fc90e76a4 100644 --- a/lzero/policy/muzero.py +++ b/lzero/policy/muzero.py @@ -208,8 +208,6 @@ def default_model(self) -> Tuple[str, List[str]]: return 'MuZeroModel', ['lzero.model.muzero_model'] elif self._cfg.model.model_type == "mlp": return 'MuZeroModelMLP', ['lzero.model.muzero_model_mlp'] - elif self._cfg.model.model_type == "structure": - return 'MuZeroModelStructure', ['lzero.model.muzero_model_structure'] def _init_learn(self) -> None: """ diff --git a/lzero/policy/utils.py b/lzero/policy/utils.py index 8d0738d3e..22258a738 100644 --- a/lzero/policy/utils.py +++ b/lzero/policy/utils.py @@ -202,24 +202,26 @@ def prepare_obs(obs_batch_ori: np.ndarray, cfg: EasyDict) -> Tuple[torch.Tensor, obs_batch_new[k][k1] = v1.reshape(batch_size, -1) else: # espaecially for ptz obs, {'k1':[], 'k2':[]} obs_batch_new[k] = v.reshape(batch_size, -1) + obs_batch_new = to_device(obs_batch_new, device=cfg.device) # obs_target_batch - obs_target_batch = obs_target_batch.tolist() - obs_target_batch = sum(obs_target_batch, []) - obs_target_batch = default_collate(obs_target_batch) - obs_target_batch_new = {} - for k, v in obs_target_batch.items(): - if isinstance(v, dict): # espaecially for gobigger obs, { {'k':{'k1':[], 'k2':[]},} - obs_target_batch_new[k] = {} - for k1, v1 in v.items(): - if len(v1.shape) == 1: - obs_target_batch_new[k][k1] = v1 - else: - obs_target_batch_new[k][k1] = v1.reshape(batch_size, -1) - else: # espaecially for ptz obs, {'k1':[], 'k2':[]} - obs_target_batch_new[k] = v.reshape(batch_size, -1) - obs_batch_new = to_device(obs_batch_new, device=cfg.device) - obs_target_batch_new = to_device(obs_target_batch_new, device=cfg.device) + obs_target_batch_new = None + if cfg.model.self_supervised_learning_loss: + obs_target_batch = obs_target_batch.tolist() + obs_target_batch = sum(obs_target_batch, []) + obs_target_batch = default_collate(obs_target_batch) + obs_target_batch_new = {} + for k, v in obs_target_batch.items(): + if isinstance(v, dict): # espaecially for gobigger obs, { {'k':{'k1':[], 'k2':[]},} + obs_target_batch_new[k] = {} + for k1, v1 in v.items(): + if len(v1.shape) == 1: + obs_target_batch_new[k][k1] = v1 + else: + obs_target_batch_new[k][k1] = v1.reshape(batch_size, -1) + else: # espaecially for ptz obs, {'k1':[], 'k2':[]} + obs_target_batch_new[k] = v.reshape(batch_size, -1) + obs_target_batch_new = to_device(obs_target_batch_new, device=cfg.device) return obs_batch_new, obs_target_batch_new diff --git a/lzero/worker/gobigger_muzero_evaluator.py b/lzero/worker/gobigger_muzero_evaluator.py index 0bf6bdc05..a2e3ad5a1 100644 --- a/lzero/worker/gobigger_muzero_evaluator.py +++ 
b/lzero/worker/gobigger_muzero_evaluator.py @@ -349,19 +349,19 @@ def eval_vsbot( continue self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter) self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, envstep) - eval_reward = np.mean(episode_return) - if eval_reward > self._max_eval_reward: + episode_return = np.mean(episode_return) + if episode_return > self._max_episode_return: if save_ckpt_fn: save_ckpt_fn('ckpt_best.pth.tar') - self._max_eval_reward = eval_reward - stop_flag = eval_reward >= self._stop_value and train_iter > 0 + self._max_episode_return = episode_return + stop_flag = episode_return >= self._stop_value and train_iter > 0 if stop_flag: self._logger.info( "[LightZero serial pipeline] " + "Current eval_reward: {} is greater than stop_value: {}".format(eval_reward, self._stop_value) + ", so your MCTS/RL agent is converged, you can refer to 'log/evaluator/evaluator_logger.txt' for details." ) - return stop_flag, eval_reward + return stop_flag, episode_return class GoBiggerVectorEvalMonitor(VectorEvalMonitor): diff --git a/zoo/gobigger/config/gobigger_efficientzero_config.py b/zoo/gobigger/config/gobigger_efficientzero_config.py index 57e358422..ce58c26d2 100644 --- a/zoo/gobigger/config/gobigger_efficientzero_config.py +++ b/zoo/gobigger/config/gobigger_efficientzero_config.py @@ -41,7 +41,6 @@ ignore_done=True, model=dict( model_type='structure', - env_name=env_name, agent_num=agent_num, team_num=team_num, latent_state_dim=176, diff --git a/zoo/gobigger/config/gobigger_muzero_config.py b/zoo/gobigger/config/gobigger_muzero_config.py index 9638b527e..90d1680d4 100644 --- a/zoo/gobigger/config/gobigger_muzero_config.py +++ b/zoo/gobigger/config/gobigger_muzero_config.py @@ -40,7 +40,6 @@ ignore_done=True, model=dict( model_type='structure', - env_name=env_name, agent_num=agent_num, team_num=team_num, latent_state_dim=176, diff --git a/zoo/gobigger/entry/train_muzero_gobigger.py b/zoo/gobigger/entry/train_muzero_gobigger.py index 991e64575..03ce590e9 100644 --- a/zoo/gobigger/entry/train_muzero_gobigger.py +++ b/zoo/gobigger/entry/train_muzero_gobigger.py @@ -8,19 +8,20 @@ from ding.envs import create_env_manager from ding.envs import get_vec_env_setting from ding.policy import create_policy -from ding.utils import set_pkg_seed +from ding.utils import set_pkg_seed, get_rank from ding.rl_utils import get_epsilon_greedy_fn from ding.worker import BaseLearner from tensorboardX import SummaryWriter -from lzero.entry.utils import log_buffer_memory_usage -from lzero.policy import visit_count_temperature from lzero.worker import MuZeroCollector as Collector from lzero.worker import MuZeroEvaluator as Evaluator -from lzero.worker import GoBiggerMuZeroEvaluator +from lzero.entry.utils import log_buffer_memory_usage +from lzero.policy import visit_count_temperature from lzero.entry.utils import random_collect -import copy +import copy +from lzero.worker import GoBiggerMuZeroEvaluator +from zoo.gobigger.model import GoBiggerEncoder def train_muzero_gobigger( input_cfg: Tuple[dict, dict], @@ -53,8 +54,10 @@ def train_muzero_gobigger( if create_cfg.policy.type == 'muzero' or create_cfg.policy.type == 'multi_agent_muzero': from lzero.mcts import MuZeroGameBuffer as GameBuffer + from lzero.model.muzero_model_mlp import MuZeroModelMLP as Encoder elif create_cfg.policy.type == 'efficientzero' or create_cfg.policy.type == 'multi_agent_efficientzero': from lzero.mcts import EfficientZeroGameBuffer as GameBuffer + from 
lzero.model.efficientzero_model_mlp import EfficientZeroModelMLP as Encoder elif create_cfg.policy.type == 'sampled_efficientzero': from lzero.mcts import SampledEfficientZeroGameBuffer as GameBuffer elif create_cfg.policy.type == 'gumbel_muzero': @@ -82,6 +85,7 @@ def train_muzero_gobigger( vsbot_evaluator_env.seed(cfg.seed, dynamic_seed=False) set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + model = Encoder(**cfg.policy.model, state_encoder=GoBiggerEncoder()) policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) # load pretrained model diff --git a/zoo/gobigger/model/__init__.py b/zoo/gobigger/model/__init__.py new file mode 100644 index 000000000..c7ba3e2e0 --- /dev/null +++ b/zoo/gobigger/model/__init__.py @@ -0,0 +1 @@ +from .model import GoBiggerEncoder \ No newline at end of file diff --git a/lzero/model/gobigger/encoder.py b/zoo/gobigger/model/encoder.py similarity index 100% rename from lzero/model/gobigger/encoder.py rename to zoo/gobigger/model/encoder.py diff --git a/lzero/model/gobigger/gobigger_encoder.py b/zoo/gobigger/model/model.py similarity index 100% rename from lzero/model/gobigger/gobigger_encoder.py rename to zoo/gobigger/model/model.py diff --git a/zoo/petting_zoo/config/ptz_simple_spread_ez_config.py b/zoo/petting_zoo/config/ptz_simple_spread_ez_config.py index 8e1cb9984..ef385ee70 100644 --- a/zoo/petting_zoo/config/ptz_simple_spread_ez_config.py +++ b/zoo/petting_zoo/config/ptz_simple_spread_ez_config.py @@ -45,7 +45,6 @@ ignore_done=False, model=dict( model_type='structure', - env_name=env_name, latent_state_dim=256, frame_stack_num=1, action_space='discrete', @@ -78,6 +77,7 @@ optim_type='SGD', lr_piecewise_constant_decay=True, learning_rate=0.2, + ssl_loss_weight=0, num_simulations=num_simulations, reanalyze_ratio=reanalyze_ratio, n_episode=n_episode, @@ -112,5 +112,5 @@ ptz_simple_spread_efficientzero_create_config = create_config if __name__ == '__main__': - from lzero.entry import train_muzero + from zoo.petting_zoo.entry import train_muzero train_muzero([main_config, create_config], seed=seed) diff --git a/zoo/petting_zoo/config/ptz_simple_spread_mz_config.py b/zoo/petting_zoo/config/ptz_simple_spread_mz_config.py index dac1cfc39..d9d9c653f 100644 --- a/zoo/petting_zoo/config/ptz_simple_spread_mz_config.py +++ b/zoo/petting_zoo/config/ptz_simple_spread_mz_config.py @@ -45,13 +45,12 @@ ignore_done=False, model=dict( model_type='structure', - env_name=env_name, latent_state_dim=256, frame_stack_num=1, action_space='discrete', action_space_size=action_space_size, agent_num=n_agent, - self_supervised_learning_loss=True, # default is False + self_supervised_learning_loss=False, # default is False agent_obs_shape=(2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2)*3, global_obs_shape=2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2 + n_agent * (2 + 2) + n_landmark * 2 + n_agent * (n_agent - 1) * 2, @@ -114,5 +113,5 @@ ptz_simple_spread_muzero_create_config = create_config if __name__ == '__main__': - from lzero.entry import train_muzero + from zoo.petting_zoo.entry import train_muzero train_muzero([main_config, create_config], seed=seed) diff --git a/zoo/petting_zoo/entry/__init__.py b/zoo/petting_zoo/entry/__init__.py new file mode 100644 index 000000000..cc6e3cbb5 --- /dev/null +++ b/zoo/petting_zoo/entry/__init__.py @@ -0,0 +1 @@ +from .train_muzero import train_muzero \ No newline at end of file diff --git a/zoo/petting_zoo/entry/train_muzero.py b/zoo/petting_zoo/entry/train_muzero.py new file 
mode 100644 index 000000000..b36dda8f8 --- /dev/null +++ b/zoo/petting_zoo/entry/train_muzero.py @@ -0,0 +1,200 @@ +import logging +import os +from functools import partial +from typing import Optional, Tuple + +import torch +from ding.config import compile_config +from ding.envs import create_env_manager +from ding.envs import get_vec_env_setting +from ding.policy import create_policy +from ding.utils import set_pkg_seed, get_rank +from ding.rl_utils import get_epsilon_greedy_fn +from ding.worker import BaseLearner +from tensorboardX import SummaryWriter + +from lzero.worker import MuZeroCollector as Collector +from lzero.worker import MuZeroEvaluator as Evaluator +from lzero.entry.utils import log_buffer_memory_usage +from lzero.policy import visit_count_temperature +from lzero.entry.utils import random_collect +from zoo.petting_zoo.model import PettingZooEncoder + +def train_muzero( + input_cfg: Tuple[dict, dict], + seed: int = 0, + model: Optional[torch.nn.Module] = None, + model_path: Optional[str] = None, + max_train_iter: Optional[int] = int(1e10), + max_env_step: Optional[int] = int(1e10), +) -> 'Policy': # noqa + """ + Overview: + The train entry for MCTS+RL algorithms, including MuZero, EfficientZero, Sampled EfficientZero. + Arguments: + - input_cfg (:obj:`Tuple[dict, dict]`): Config in dict type. + ``Tuple[dict, dict]`` type means [user_config, create_cfg]. + - seed (:obj:`int`): Random seed. + - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module. + - model_path (:obj:`Optional[str]`): The pretrained model path, which should + point to the ckpt file of the pretrained model, and an absolute path is recommended. + In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``. + - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training. + - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps. + Returns: + - policy (:obj:`Policy`): Converged policy. 
+ """ + + cfg, create_cfg = input_cfg + assert create_cfg.policy.type in ['efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero', 'multi_agent_efficientzero', 'multi_agent_muzero'], \ + "train_muzero entry now only support the following algo.: 'efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero', 'multi_agent_efficientzero', 'multi_agent_muzero'" + + if create_cfg.policy.type == 'muzero' or create_cfg.policy.type == 'multi_agent_muzero': + from lzero.mcts import MuZeroGameBuffer as GameBuffer + from lzero.model.muzero_model_mlp import MuZeroModelMLP as Encoder + elif create_cfg.policy.type == 'efficientzero' or create_cfg.policy.type == 'multi_agent_efficientzero': + from lzero.mcts import EfficientZeroGameBuffer as GameBuffer + from lzero.model.efficientzero_model_mlp import EfficientZeroModelMLP as Encoder + elif create_cfg.policy.type == 'sampled_efficientzero': + from lzero.mcts import SampledEfficientZeroGameBuffer as GameBuffer + elif create_cfg.policy.type == 'gumbel_muzero': + from lzero.mcts import GumbelMuZeroGameBuffer as GameBuffer + + if cfg.policy.cuda and torch.cuda.is_available(): + cfg.policy.device = 'cuda' + else: + cfg.policy.device = 'cpu' + + cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) + # Create main components: env, policy + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) + + collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) + evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + + collector_env.seed(cfg.seed) + evaluator_env.seed(cfg.seed, dynamic_seed=False) + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + + model = Encoder(**cfg.policy.model, state_encoder=PettingZooEncoder()) + policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) + + # load pretrained model + if model_path is not None: + policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device)) + + # Create worker components: learner, collector, evaluator, replay buffer, commander. + tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) if get_rank() == 0 else None + learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) + + # ============================================================== + # MCTS+RL algorithms related core code + # ============================================================== + policy_config = cfg.policy + batch_size = policy_config.batch_size + # specific game buffer for MCTS+RL algorithms + replay_buffer = GameBuffer(policy_config) + + collector = Collector( + env=collector_env, + policy=policy.collect_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config + ) + evaluator = Evaluator( + eval_freq=cfg.policy.eval_freq, + n_evaluator_episode=cfg.env.n_evaluator_episode, + stop_value=cfg.env.stop_value, + env=evaluator_env, + policy=policy.eval_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config + ) + + # ============================================================== + # Main loop + # ============================================================== + # Learner's before_run hook. 
+    learner.call_hook('before_run')
+
+    if cfg.policy.update_per_collect is not None:
+        update_per_collect = cfg.policy.update_per_collect
+
+    # The purpose of collecting random data before training:
+    # Exploration: The collection of random data aids the agent in exploring the environment and prevents premature convergence to a suboptimal policy.
+    # Comparison: The agent's performance during random action-taking can be used as a reference point to evaluate the efficacy of reinforcement learning algorithms.
+    if cfg.policy.random_collect_episode_num > 0:
+        if policy_config.multi_agent:
+            from lzero.policy.multi_agent_random_policy import MultiAgentLightZeroRandomPolicy as RandomPolicy
+        else:
+            from lzero.policy.random_policy import LightZeroRandomPolicy as RandomPolicy
+        random_collect(cfg.policy, policy, RandomPolicy, collector, collector_env, replay_buffer)
+
+    while True:
+        log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger)
+        collect_kwargs = {}
+        # set temperature for visit count distributions according to the train_iter,
+        # please refer to Appendix D in MuZero paper for details.
+        collect_kwargs['temperature'] = visit_count_temperature(
+            policy_config.manual_temperature_decay,
+            policy_config.fixed_temperature_value,
+            policy_config.threshold_training_steps_for_final_temperature,
+            trained_steps=learner.train_iter
+        )
+
+        if policy_config.eps.eps_greedy_exploration_in_collect:
+            epsilon_greedy_fn = get_epsilon_greedy_fn(
+                start=policy_config.eps.start,
+                end=policy_config.eps.end,
+                decay=policy_config.eps.decay,
+                type_=policy_config.eps.type
+            )
+            collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep)
+        else:
+            collect_kwargs['epsilon'] = 0.0
+
+        # Evaluate policy performance.
+        if evaluator.should_eval(learner.train_iter):
+            stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+            if stop:
+                break
+
+        # Collect data by default config n_sample/n_episode.
+        new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
+        if cfg.policy.update_per_collect is None:
+            # If update_per_collect is None, it is set to the number of collected transitions multiplied by the model_update_ratio.
+            collected_transitions_num = sum([len(game_segment) for game_segment in new_data[0]])
+            update_per_collect = int(collected_transitions_num * cfg.policy.model_update_ratio)
+        # save returned new_data collected by the collector
+        replay_buffer.push_game_segments(new_data)
+        # remove the oldest data if the replay buffer is full.
+        replay_buffer.remove_oldest_data_to_fit()
+
+        # Learn policy from collected data.
+        for i in range(update_per_collect):
+            # Learner will train ``update_per_collect`` times in one iteration.
+            if replay_buffer.get_num_of_transitions() > batch_size:
+                train_data = replay_buffer.sample(batch_size, policy)
+            else:
+                logging.warning(
+                    f'The data in replay_buffer is not sufficient to sample a mini-batch: '
+                    f'batch_size: {batch_size}, '
+                    f'{replay_buffer} '
+                    f'continue to collect now ....'
+                )
+                break
+
+            # The core train steps for MCTS+RL algorithms.
+            log_vars = learner.train(train_data, collector.envstep)
+
+            if cfg.policy.use_priority:
+                replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig'])
+
+        if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
+            break
+
+    # Learner's after_run hook.
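+    # The converged policy is returned; its checkpoints (e.g. ``ckpt_best.pth.tar``, saved by the
+    # evaluator via ``learner.save_checkpoint`` on a new best evaluation) can be reloaded later for
+    # evaluation or visualization.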
+ learner.call_hook('after_run') + return policy diff --git a/zoo/petting_zoo/envs/petting_zoo_simple_spread_env.py b/zoo/petting_zoo/envs/petting_zoo_simple_spread_env.py index 4e8638fd8..87d84455a 100644 --- a/zoo/petting_zoo/envs/petting_zoo_simple_spread_env.py +++ b/zoo/petting_zoo/envs/petting_zoo_simple_spread_env.py @@ -283,14 +283,6 @@ def _process_obs(self, obs: 'torch.Tensor') -> np.ndarray: # noqa tmp[k] = v[i] tmp['action_mask'] = [1 for _ in range(*self._action_dim)] ret_transform.append(tmp) - - concat_state_0 = np.concatenate((ret_transform[0]['agent_state'], ret_transform[1]['agent_state'], ret_transform[2]['agent_state']), axis=0) - concat_state_1 = np.concatenate((ret_transform[1]['agent_state'], ret_transform[0]['agent_state'], ret_transform[2]['agent_state']), axis=0) - concat_state_2 = np.concatenate((ret_transform[2]['agent_state'], ret_transform[0]['agent_state'], ret_transform[1]['agent_state']), axis=0) - - ret_transform[0]['agent_state'] = concat_state_0 - ret_transform[1]['agent_state'] = concat_state_1 - ret_transform[2]['agent_state'] = concat_state_2 return {'observation': ret_transform, 'action_mask': action_mask, 'to_play': to_play} def _process_action(self, action: 'torch.Tensor') -> Dict[str, np.ndarray]: # noqa diff --git a/zoo/petting_zoo/model/__init__.py b/zoo/petting_zoo/model/__init__.py new file mode 100644 index 000000000..821e014a8 --- /dev/null +++ b/zoo/petting_zoo/model/__init__.py @@ -0,0 +1 @@ +from .model import PettingZooEncoder \ No newline at end of file diff --git a/lzero/model/petting_zoo/encoder.py b/zoo/petting_zoo/model/model.py similarity index 74% rename from lzero/model/petting_zoo/encoder.py rename to zoo/petting_zoo/model/model.py index 8c9e92c59..8d2723676 100644 --- a/lzero/model/petting_zoo/encoder.py +++ b/zoo/petting_zoo/model/model.py @@ -5,9 +5,9 @@ class PettingZooEncoder(nn.Module): def __init__(self): super().__init__() - self.encoder = FCEncoder(obs_shape=54, hidden_size_list=[256, 256], activation=nn.ReLU(), norm_type=None) + self.encoder = FCEncoder(obs_shape=48, hidden_size_list=[256, 256], activation=nn.ReLU(), norm_type=None) def forward(self, x): - x = x['agent_state'] + x = x['global_state'] x = self.encoder(x) return x \ No newline at end of file From 920dc383963a27132c011f1b6fad68d7af23987e Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Mon, 21 Aug 2023 19:40:34 +0800 Subject: [PATCH 50/54] polish(yzj): polish gobigger entry and evaluator --- lzero/entry/train_muzero.py | 10 +- lzero/worker/gobigger_muzero_evaluator.py | 623 ++++++++++---------- lzero/worker/muzero_evaluator.py | 1 + zoo/gobigger/entry/train_muzero_gobigger.py | 6 +- 4 files changed, 336 insertions(+), 304 deletions(-) diff --git a/lzero/entry/train_muzero.py b/lzero/entry/train_muzero.py index 20d857276..3f00d91d6 100644 --- a/lzero/entry/train_muzero.py +++ b/lzero/entry/train_muzero.py @@ -13,10 +13,11 @@ from ding.worker import BaseLearner from tensorboardX import SummaryWriter -from lzero.worker import MuZeroCollector as Collector -from lzero.worker import MuZeroEvaluator as Evaluator from lzero.entry.utils import log_buffer_memory_usage from lzero.policy import visit_count_temperature +from lzero.policy.random_policy import LightZeroRandomPolicy +from lzero.worker import MuZeroCollector as Collector +from lzero.worker import MuZeroEvaluator as Evaluator from .utils import random_collect @@ -91,7 +92,6 @@ def train_muzero( batch_size = policy_config.batch_size # specific game buffer for MCTS+RL algorithms replay_buffer = 
GameBuffer(policy_config) - collector = Collector( env=collector_env, policy=policy.collect_mode, @@ -115,7 +115,7 @@ def train_muzero( # ============================================================== # Learner's before_run hook. learner.call_hook('before_run') - + if cfg.policy.update_per_collect is not None: update_per_collect = cfg.policy.update_per_collect @@ -194,4 +194,4 @@ def train_muzero( # Learner's after_run hook. learner.call_hook('after_run') - return policy + return policy \ No newline at end of file diff --git a/lzero/worker/gobigger_muzero_evaluator.py b/lzero/worker/gobigger_muzero_evaluator.py index a2e3ad5a1..1694c8880 100644 --- a/lzero/worker/gobigger_muzero_evaluator.py +++ b/lzero/worker/gobigger_muzero_evaluator.py @@ -1,20 +1,24 @@ +import copy import time from collections import namedtuple -from typing import Any, Optional, Callable, Tuple +from typing import Optional, Callable, Tuple, Any, List, Dict import numpy as np import torch - from ding.envs import BaseEnvManager -from ding.torch_utils import to_ndarray -from .muzero_evaluator import MuZeroEvaluator -from ding.worker.collector.base_serial_evaluator import VectorEvalMonitor +from ding.torch_utils import to_ndarray, to_item +from ding.utils import build_logger, EasyTimer +from ding.utils import get_world_size, get_rank, broadcast_object_list +from ding.worker.collector.base_serial_evaluator import ISerialEvaluator, VectorEvalMonitor +from easydict import EasyDict + from lzero.mcts.buffer.game_segment import GameSegment from lzero.mcts.utils import prepare_observation from collections import defaultdict from zoo.gobigger.env.gobigger_rule_bot import GoBiggerBot from collections import namedtuple, deque +from .muzero_evaluator import MuZeroEvaluator class GoBiggerMuZeroEvaluator(MuZeroEvaluator): @@ -45,323 +49,350 @@ def eval_vsbot( - stop_flag (:obj:`bool`): Whether this training program can be ended. - eval_reward (:obj:`float`): Current eval_reward. """ - if n_episode is None: - n_episode = self._default_n_episode - assert n_episode is not None, "please indicate eval n_episode" - envstep_count = 0 - # specifically for vs bot - eval_monitor = GoBiggerVectorEvalMonitor(self._env.env_num, n_episode) - env_nums = self._env.env_num - - self._env.reset() - self._policy.reset() - - # specifically for vs bot - agent_num = self.policy_config['model']['agent_num'] - team_num = self.policy_config['model']['team_num'] - self._bot_policy = GoBiggerBot(env_nums, agent_id=[i for i in range(agent_num//team_num, agent_num)]) #TODO only support t2p2 - self._bot_policy.reset() - - # initializations - init_obs = self._env.ready_obs - - retry_waiting_time = 0.001 - while len(init_obs.keys()) != self._env_num: - # In order to be compatible with subprocess env_manager, in which sometimes self._env_num is not equal to - # len(self._env.ready_obs), especially in tictactoe env. - self._logger.info('The current init_obs.keys() is {}'.format(init_obs.keys())) - self._logger.info('Before sleeping, the _env_states is {}'.format(self._env._env_states)) - time.sleep(retry_waiting_time) - self._logger.info('=' * 10 + 'Wait for all environments (subprocess) to finish resetting.' 
+ '=' * 10) - self._logger.info( - 'After sleeping {}s, the current _env_states is {}'.format(retry_waiting_time, self._env._env_states) - ) + episode_info = None + stop_flag = False + if get_rank() == 0: + if n_episode is None: + n_episode = self._default_n_episode + assert n_episode is not None, "please indicate eval n_episode" + envstep_count = 0 + # specifically for vs bot + eval_monitor = GoBiggerVectorEvalMonitor(self._env.env_num, n_episode) + env_nums = self._env.env_num + + self._env.reset() + self._policy.reset() + + # initializations init_obs = self._env.ready_obs - # specifically for vs bot - for i in range(env_nums): - for k, v in init_obs[i].items(): - if k != 'raw_obs': - init_obs[i][k] = v[:agent_num] - - action_mask_dict = {i: to_ndarray(init_obs[i]['action_mask']) for i in range(env_nums)} - to_play_dict = {i: to_ndarray(init_obs[i]['to_play']) for i in range(env_nums)} - dones = np.array([False for _ in range(env_nums)]) - - game_segments = [ - [ - GameSegment( - self._env.action_space, - game_segment_length=self.policy_config.game_segment_length, - config=self.policy_config - ) for _ in range(agent_num) - ] for _ in range(env_nums) - ] - for env_id in range(env_nums): - for agent_id in range(agent_num): - game_segments[env_id][agent_id].reset( - [ - to_ndarray(init_obs[env_id]['observation'][agent_id]) - for _ in range(self.policy_config.model.frame_stack_num) - ] + retry_waiting_time = 0.001 + while len(init_obs.keys()) != self._env_num: + # In order to be compatible with subprocess env_manager, in which sometimes self._env_num is not equal to + # len(self._env.ready_obs), especially in tictactoe env. + self._logger.info('The current init_obs.keys() is {}'.format(init_obs.keys())) + self._logger.info('Before sleeping, the _env_states is {}'.format(self._env._env_states)) + time.sleep(retry_waiting_time) + self._logger.info('=' * 10 + 'Wait for all environments (subprocess) to finish resetting.' + '=' * 10) + self._logger.info( + 'After sleeping {}s, the current _env_states is {}'.format(retry_waiting_time, self._env._env_states) ) + init_obs = self._env.ready_obs - ready_env_id = set() - remain_episode = n_episode - # specifically for vs bot - eat_info = defaultdict() + # specifically for vs bot + agent_num = self.policy_config['model']['agent_num'] + team_num = self.policy_config['model']['team_num'] + self._bot_policy = GoBiggerBot(env_nums, agent_id=[i for i in range(agent_num//team_num, agent_num)]) #TODO only support t2p2 + self._bot_policy.reset() - with self._timer: - while not eval_monitor.is_finished(): - # Get current ready env obs. 
- obs = self._env.ready_obs - # specifically for vs bot - raw_obs = [v['raw_obs'] for k, v in obs.items()] - new_available_env_id = set(obs.keys()).difference(ready_env_id) - ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode])) - remain_episode -= min(len(new_available_env_id), remain_episode) - - if self._multi_agent: - stack_obs = defaultdict(list) - for env_id in ready_env_id: - for agent_id in range(agent_num): - stack_obs[env_id].append(game_segments[env_id][agent_id].get_obs()) - else: - stack_obs = {env_id: game_segments[env_id].get_obs() for env_id in ready_env_id} - stack_obs = list(stack_obs.values()) - - action_mask_dict = {env_id: action_mask_dict[env_id] for env_id in ready_env_id} - to_play_dict = {env_id: to_play_dict[env_id] for env_id in ready_env_id} - action_mask = [action_mask_dict[env_id] for env_id in ready_env_id] - to_play = [to_play_dict[env_id] for env_id in ready_env_id] - - stack_obs = to_ndarray(stack_obs) - if self.policy_config.model.model_type and self.policy_config.model.model_type in ['conv', 'mlp']: - stack_obs = prepare_observation(stack_obs, self.policy_config.model.model_type) - stack_obs = torch.from_numpy(stack_obs).to(self.policy_config.device).float() - - # ============================================================== - # bot forward - # ============================================================== - bot_actions = self._bot_policy.forward(raw_obs) - - # ============================================================== - # policy forward - # ============================================================== - policy_output = self._policy.forward(stack_obs, action_mask, to_play) - if self._multi_agent: - actions_no_env_id = defaultdict(dict) - for k, v in policy_output.items(): - for agent_id, act in enumerate(v['action']): - actions_no_env_id[k][agent_id] = act - else: - actions_no_env_id = {k: v['action'] for k, v in policy_output.items()} - distributions_dict_no_env_id = {k: v['distributions'] for k, v in policy_output.items()} - if self.policy_config.sampled_algo: - root_sampled_actions_dict_no_env_id = { - k: v['root_sampled_actions'] - for k, v in policy_output.items() - } + # specifically for vs bot + for i in range(env_nums): + for k, v in init_obs[i].items(): + if k != 'raw_obs': + init_obs[i][k] = v[:agent_num] - value_dict_no_env_id = {k: v['value'] for k, v in policy_output.items()} - pred_value_dict_no_env_id = {k: v['pred_value'] for k, v in policy_output.items()} - visit_entropy_dict_no_env_id = { - k: v['visit_count_distribution_entropy'] - for k, v in policy_output.items() - } - - actions = {} - distributions_dict = {} - if self.policy_config.sampled_algo: - root_sampled_actions_dict = {} - value_dict = {} - pred_value_dict = {} - visit_entropy_dict = {} - for index, env_id in enumerate(ready_env_id): - actions[env_id] = actions_no_env_id.pop(index) - distributions_dict[env_id] = distributions_dict_no_env_id.pop(index) - if self.policy_config.sampled_algo: - root_sampled_actions_dict[env_id] = root_sampled_actions_dict_no_env_id.pop(index) - value_dict[env_id] = value_dict_no_env_id.pop(index) - pred_value_dict[env_id] = pred_value_dict_no_env_id.pop(index) - visit_entropy_dict[env_id] = visit_entropy_dict_no_env_id.pop(index) - - # ============================================================== - # Interact with env. 
- # ============================================================== - # specifically for vs bot - for env_id, v in bot_actions.items(): - actions[env_id].update(v) + action_mask_dict = {i: to_ndarray(init_obs[i]['action_mask']) for i in range(env_nums)} + + to_play_dict = {i: to_ndarray(init_obs[i]['to_play']) for i in range(env_nums)} + dones = np.array([False for _ in range(env_nums)]) - timesteps = self._env.step(actions) - for env_id, t in timesteps.items(): - obs, reward, done, info = t.obs, t.reward, t.done, t.info + if self._multi_agent: + agent_num = len(init_obs[0]['action_mask']) + assert agent_num == self.policy_config.model.agent_num, "Please make sure agent_num == env.agent_num" + game_segments = [ + [ + GameSegment( + self._env.action_space, + game_segment_length=self.policy_config.game_segment_length, + config=self.policy_config + ) for _ in range(agent_num) + ] for _ in range(env_nums) + ] + for env_id in range(env_nums): + for agent_id in range(agent_num): + game_segments[env_id][agent_id].reset( + [ + to_ndarray(init_obs[env_id]['observation'][agent_id]) + for _ in range(self.policy_config.model.frame_stack_num) + ] + ) + else: + game_segments = [ + GameSegment( + self._env.action_space, + game_segment_length=self.policy_config.game_segment_length, + config=self.policy_config + ) for _ in range(env_nums) + ] + for i in range(env_nums): + game_segments[i].reset( + [to_ndarray(init_obs[i]['observation']) for _ in range(self.policy_config.model.frame_stack_num)] + ) + + ready_env_id = set() + remain_episode = n_episode + # specifically for vs bot + eat_info = defaultdict() + + with self._timer: + while not eval_monitor.is_finished(): + # Get current ready env obs. + obs = self._env.ready_obs + # specifically for vs bot + raw_obs = [v['raw_obs'] for k, v in obs.items()] + new_available_env_id = set(obs.keys()).difference(ready_env_id) + ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode])) + remain_episode -= min(len(new_available_env_id), remain_episode) + if self._multi_agent: - for agent_id in range(agent_num): - game_segments[env_id][agent_id].append( - actions[env_id][agent_id], to_ndarray(obs['observation'][agent_id]), reward[agent_id], - action_mask_dict[env_id][agent_id], to_play_dict[env_id] - ) + stack_obs = defaultdict(list) + for env_id in ready_env_id: + for agent_id in range(agent_num): + stack_obs[env_id].append(game_segments[env_id][agent_id].get_obs()) else: - game_segments[env_id].append( - actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id], - to_play_dict[env_id] - ) + stack_obs = {env_id: game_segments[env_id].get_obs() for env_id in ready_env_id} + stack_obs = list(stack_obs.values()) + + action_mask_dict = {env_id: action_mask_dict[env_id] for env_id in ready_env_id} + to_play_dict = {env_id: to_play_dict[env_id] for env_id in ready_env_id} + action_mask = [action_mask_dict[env_id] for env_id in ready_env_id] + to_play = [to_play_dict[env_id] for env_id in ready_env_id] + + stack_obs = to_ndarray(stack_obs) + if self.policy_config.model.model_type and self.policy_config.model.model_type in ['conv', 'mlp']: + stack_obs = prepare_observation(stack_obs, self.policy_config.model.model_type) + stack_obs = torch.from_numpy(stack_obs).to(self.policy_config.device).float() + + # ============================================================== + # bot forward + # ============================================================== + bot_actions = self._bot_policy.forward(raw_obs) + + # 
============================================================== + # policy forward + # ============================================================== + policy_output = self._policy.forward(stack_obs, action_mask, to_play) + if self._multi_agent: + actions_no_env_id = defaultdict(dict) + for k, v in policy_output.items(): + for agent_id, act in enumerate(v['action']): + actions_no_env_id[k][agent_id] = act + else: + actions_no_env_id = {k: v['action'] for k, v in policy_output.items()} + distributions_dict_no_env_id = {k: v['distributions'] for k, v in policy_output.items()} + if self.policy_config.sampled_algo: + root_sampled_actions_dict_no_env_id = { + k: v['root_sampled_actions'] + for k, v in policy_output.items() + } + + value_dict_no_env_id = {k: v['value'] for k, v in policy_output.items()} + pred_value_dict_no_env_id = {k: v['pred_value'] for k, v in policy_output.items()} + visit_entropy_dict_no_env_id = { + k: v['visit_count_distribution_entropy'] + for k, v in policy_output.items() + } - # NOTE: in evaluator, we only need save the ``o_{t+1} = obs['observation']`` - # game_segments[env_id].obs_segment.append(to_ndarray(obs['observation'])) - - # NOTE: the position of code snippet is very important. - # the obs['action_mask'] and obs['to_play'] is corresponding to next action - action_mask_dict[env_id] = to_ndarray(obs['action_mask']) - to_play_dict[env_id] = to_ndarray(obs['to_play']) - - dones[env_id] = done - if t.done: - # Env reset is done by env_manager automatically. - self._policy.reset([env_id]) - reward = t.info['eval_episode_return'] - # specifically for vs bot - bot_reward = t.info['eval_bot_episode_return'] - eat_info[env_id] = t.info['eats'] - if 'episode_info' in t.info: - eval_monitor.update_info(env_id, t.info['episode_info']) - eval_monitor.update_reward(env_id, reward) - # specifically for vs bot - eval_monitor.update_bot_reward(env_id, bot_reward) - self._logger.info( - "[EVALUATOR vsbot]env {} finish episode, final reward: {}, current episode: {}".format( - env_id, eval_monitor.get_latest_reward(env_id), eval_monitor.get_current_episode() + actions = {} + distributions_dict = {} + if self.policy_config.sampled_algo: + root_sampled_actions_dict = {} + value_dict = {} + pred_value_dict = {} + visit_entropy_dict = {} + for index, env_id in enumerate(ready_env_id): + actions[env_id] = actions_no_env_id.pop(index) + distributions_dict[env_id] = distributions_dict_no_env_id.pop(index) + if self.policy_config.sampled_algo: + root_sampled_actions_dict[env_id] = root_sampled_actions_dict_no_env_id.pop(index) + value_dict[env_id] = value_dict_no_env_id.pop(index) + pred_value_dict[env_id] = pred_value_dict_no_env_id.pop(index) + visit_entropy_dict[env_id] = visit_entropy_dict_no_env_id.pop(index) + + # ============================================================== + # Interact with env. 
+ # ============================================================== + # specifically for vs bot + for env_id, v in bot_actions.items(): + actions[env_id].update(v) + + timesteps = self._env.step(actions) + + for env_id, t in timesteps.items(): + obs, reward, done, info = t.obs, t.reward, t.done, t.info + if self._multi_agent: + for agent_id in range(agent_num): + game_segments[env_id][agent_id].append( + actions[env_id][agent_id], to_ndarray(obs['observation'][agent_id]), reward[agent_id] if isinstance(reward, list) else reward, + action_mask_dict[env_id][agent_id], to_play_dict[env_id] + ) + else: + game_segments[env_id].append( + actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id], + to_play_dict[env_id] ) - ) - # reset the finished env and init game_segments - if n_episode > self._env_num: - # Get current ready env obs. - init_obs = self._env.ready_obs - retry_waiting_time = 0.001 - while len(init_obs.keys()) != self._env_num: - # In order to be compatible with subprocess env_manager, in which sometimes self._env_num is not equal to - # len(self._env.ready_obs), especially in tictactoe env. - self._logger.info('The current init_obs.keys() is {}'.format(init_obs.keys())) - self._logger.info( - 'Before sleeping, the _env_states is {}'.format(self._env._env_states) - ) - time.sleep(retry_waiting_time) - self._logger.info( - '=' * 10 + 'Wait for all environments (subprocess) to finish resetting.' + '=' * 10 + # NOTE: in evaluator, we only need save the ``o_{t+1} = obs['observation']`` + # game_segments[env_id].obs_segment.append(to_ndarray(obs['observation'])) + + # NOTE: the position of code snippet is very important. + # the obs['action_mask'] and obs['to_play'] is corresponding to next action + action_mask_dict[env_id] = to_ndarray(obs['action_mask']) + to_play_dict[env_id] = to_ndarray(obs['to_play']) + + dones[env_id] = done + if t.done: + # Env reset is done by env_manager automatically. + self._policy.reset([env_id]) + reward = t.info['eval_episode_return'] + # specifically for vs bot + bot_reward = t.info['eval_bot_episode_return'] + eat_info[env_id] = t.info['eats'] + if 'episode_info' in t.info: + eval_monitor.update_info(env_id, t.info['episode_info']) + eval_monitor.update_reward(env_id, reward) + # specifically for vs bot + eval_monitor.update_bot_reward(env_id, bot_reward) + self._logger.info( + "[EVALUATOR vsbot]env {} finish episode, final reward: {}, current episode: {}".format( + env_id, eval_monitor.get_latest_reward(env_id), eval_monitor.get_current_episode() ) - self._logger.info( - 'After sleeping {}s, the current _env_states is {}'.format( - retry_waiting_time, self._env._env_states - ) - ) - init_obs = self._env.ready_obs - - new_available_env_id = set(init_obs.keys()).difference(ready_env_id) - ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode])) - remain_episode -= min(len(new_available_env_id), remain_episode) - - action_mask_dict[env_id] = to_ndarray(init_obs[env_id]['action_mask']) - to_play_dict[env_id] = to_ndarray(init_obs[env_id]['to_play']) + ) - if self._multi_agent: - for agent_id in range(agent_num): - game_segments[env_id][agent_id] = GameSegment( + # reset the finished env and init game_segments + if n_episode > self._env_num: + # Get current ready env obs. 
+ init_obs = self._env.ready_obs + retry_waiting_time = 0.001 + while len(init_obs.keys()) != self._env_num: + # In order to be compatible with subprocess env_manager, in which sometimes self._env_num is not equal to + # len(self._env.ready_obs), especially in tictactoe env. + self._logger.info('The current init_obs.keys() is {}'.format(init_obs.keys())) + self._logger.info( + 'Before sleeping, the _env_states is {}'.format(self._env._env_states) + ) + time.sleep(retry_waiting_time) + self._logger.info( + '=' * 10 + 'Wait for all environments (subprocess) to finish resetting.' + '=' * 10 + ) + self._logger.info( + 'After sleeping {}s, the current _env_states is {}'.format( + retry_waiting_time, self._env._env_states + ) + ) + init_obs = self._env.ready_obs + + new_available_env_id = set(init_obs.keys()).difference(ready_env_id) + ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode])) + remain_episode -= min(len(new_available_env_id), remain_episode) + + action_mask_dict[env_id] = to_ndarray(init_obs[env_id]['action_mask']) + to_play_dict[env_id] = to_ndarray(init_obs[env_id]['to_play']) + + if self._multi_agent: + for agent_id in range(agent_num): + game_segments[env_id][agent_id] = GameSegment( + self._env.action_space, + game_segment_length=self.policy_config.game_segment_length, + config=self.policy_config + ) + + game_segments[env_id][agent_id].reset( + [ + init_obs[env_id]['observation'][agent_id] + for _ in range(self.policy_config.model.frame_stack_num) + ] + ) + else: + game_segments[env_id] = GameSegment( self._env.action_space, game_segment_length=self.policy_config.game_segment_length, config=self.policy_config ) - game_segments[env_id][agent_id].reset( + game_segments[env_id].reset( [ - init_obs[env_id]['observation'][agent_id] + init_obs[env_id]['observation'] for _ in range(self.policy_config.model.frame_stack_num) ] ) - else: - game_segments[env_id] = GameSegment( - self._env.action_space, - game_segment_length=self.policy_config.game_segment_length, - config=self.policy_config - ) - - game_segments[env_id].reset( - [ - init_obs[env_id]['observation'] - for _ in range(self.policy_config.model.frame_stack_num) - ] - ) - # Env reset is done by env_manager automatically. - self._policy.reset([env_id]) - # specifically for vs bot - self._bot_policy.reset([env_id]) - # TODO(pu): subprocess mode, when n_episode > self._env_num, occasionally the ready_env_id=() - # and the stack_obs is np.array(None, dtype=object) - ready_env_id.remove(env_id) - - envstep_count += 1 - duration = self._timer.value - episode_return = eval_monitor.get_episode_return() - # specifically for vs bot - bot_episode_return = eval_monitor.get_bot_episode_return() - info = { - 'train_iter': train_iter, - 'ckpt_name': 'iteration_{}.pth.tar'.format(train_iter), - 'episode_count': n_episode, - 'envstep_count': envstep_count, - 'avg_envstep_per_episode': envstep_count / n_episode, - 'evaluate_time': duration, - 'avg_envstep_per_sec': envstep_count / duration, - 'avg_time_per_episode': n_episode / duration, - 'reward_mean': np.mean(episode_return), - 'reward_std': np.std(episode_return), - 'reward_max': np.max(episode_return), - 'reward_min': np.min(episode_return), + # Env reset is done by env_manager automatically. 
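+                            # Only the per-env state of the learned policy and of the scripted bot
+                            # needs to be cleared manually below.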
+ self._policy.reset([env_id]) + # specifically for vs bot + self._bot_policy.reset([env_id]) + # TODO(pu): subprocess mode, when n_episode > self._env_num, occasionally the ready_env_id=() + # and the stack_obs is np.array(None, dtype=object) + ready_env_id.remove(env_id) + + envstep_count += 1 + duration = self._timer.value + episode_return = eval_monitor.get_episode_return() # specifically for vs bot - 'bot_reward_mean': np.mean(bot_episode_return), - 'bot_reward_std': np.std(bot_episode_return), - 'bot_reward_max': np.max(bot_episode_return), - 'bot_reward_min': np.min(bot_episode_return), - } - # specifically for vs bot - # add eat info - for k, v in eat_info.items(): - for i in range(len(v)): - for k1, v1 in v[i].items(): - info['agent_{}_{}'.format(i, k1)] = info.get('agent_{}_{}'.format(i, k1), []) + [v1] - - for k, v in info.items(): - if 'agent' in k: - info[k] = np.mean(v) - - episode_info = eval_monitor.get_episode_info() - if episode_info is not None: - info.update(episode_info) - self._logger.info(self._logger.get_tabulate_vars_hor(info)) - # self._logger.info(self._logger.get_tabulate_vars(info)) - for k, v in info.items(): - if k in ['train_iter', 'ckpt_name', 'each_reward']: - continue - if not np.isscalar(v): - continue - self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter) - self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, envstep) - episode_return = np.mean(episode_return) - if episode_return > self._max_episode_return: - if save_ckpt_fn: - save_ckpt_fn('ckpt_best.pth.tar') - self._max_episode_return = episode_return - stop_flag = episode_return >= self._stop_value and train_iter > 0 - if stop_flag: - self._logger.info( - "[LightZero serial pipeline] " + - "Current eval_reward: {} is greater than stop_value: {}".format(eval_reward, self._stop_value) + - ", so your MCTS/RL agent is converged, you can refer to 'log/evaluator/evaluator_logger.txt' for details." 
- ) - return stop_flag, episode_return + bot_episode_return = eval_monitor.get_bot_episode_return() + info = { + 'train_iter': train_iter, + 'ckpt_name': 'iteration_{}.pth.tar'.format(train_iter), + 'episode_count': n_episode, + 'envstep_count': envstep_count, + 'avg_envstep_per_episode': envstep_count / n_episode, + 'evaluate_time': duration, + 'avg_envstep_per_sec': envstep_count / duration, + 'avg_time_per_episode': n_episode / duration, + 'reward_mean': np.mean(episode_return), + 'reward_std': np.std(episode_return), + 'reward_max': np.max(episode_return), + 'reward_min': np.min(episode_return), + # specifically for vs bot + 'bot_reward_mean': np.mean(bot_episode_return), + 'bot_reward_std': np.std(bot_episode_return), + 'bot_reward_max': np.max(bot_episode_return), + 'bot_reward_min': np.min(bot_episode_return), + } + # specifically for vs bot + # add eat info + for k, v in eat_info.items(): + for i in range(len(v)): + for k1, v1 in v[i].items(): + info['agent_{}_{}'.format(i, k1)] = info.get('agent_{}_{}'.format(i, k1), []) + [v1] + + for k, v in info.items(): + if 'agent' in k: + info[k] = np.mean(v) + + episode_info = eval_monitor.get_episode_info() + if episode_info is not None: + info.update(episode_info) + self._logger.info(self._logger.get_tabulate_vars_hor(info)) + # self._logger.info(self._logger.get_tabulate_vars(info)) + for k, v in info.items(): + if k in ['train_iter', 'ckpt_name', 'each_reward']: + continue + if not np.isscalar(v): + continue + self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter) + self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, envstep) + episode_return = np.mean(episode_return) + if episode_return > self._max_episode_return: + if save_ckpt_fn: + save_ckpt_fn('ckpt_best.pth.tar') + self._max_episode_return = episode_return + stop_flag = episode_return >= self._stop_value and train_iter > 0 + if stop_flag: + self._logger.info( + "[LightZero serial pipeline] " + + "Current episode_return: {} is greater than stop_value: {}".format(episode_return, self._stop_value) + + ", so your MCTS/RL agent is converged, you can refer to 'log/evaluator/evaluator_logger.txt' for details." 
+ ) + + if get_world_size() > 1: + objects = [stop_flag, episode_info] + broadcast_object_list(objects, src=0) + stop_flag, episode_info = objects + + episode_info = to_item(episode_info) + return stop_flag, episode_info class GoBiggerVectorEvalMonitor(VectorEvalMonitor): diff --git a/lzero/worker/muzero_evaluator.py b/lzero/worker/muzero_evaluator.py index c8972cc3d..b0611a1f2 100644 --- a/lzero/worker/muzero_evaluator.py +++ b/lzero/worker/muzero_evaluator.py @@ -258,6 +258,7 @@ def eval( if self._multi_agent: agent_num = len(init_obs[0]['action_mask']) + assert agent_num == self.policy_config.model.agent_num, "Please make sure agent_num == env.agent_num" game_segments = [ [ GameSegment( diff --git a/zoo/gobigger/entry/train_muzero_gobigger.py b/zoo/gobigger/entry/train_muzero_gobigger.py index 03ce590e9..31c194bab 100644 --- a/zoo/gobigger/entry/train_muzero_gobigger.py +++ b/zoo/gobigger/entry/train_muzero_gobigger.py @@ -13,10 +13,11 @@ from ding.worker import BaseLearner from tensorboardX import SummaryWriter -from lzero.worker import MuZeroCollector as Collector -from lzero.worker import MuZeroEvaluator as Evaluator from lzero.entry.utils import log_buffer_memory_usage from lzero.policy import visit_count_temperature +from lzero.policy.random_policy import LightZeroRandomPolicy +from lzero.worker import MuZeroCollector as Collector +from lzero.worker import MuZeroEvaluator as Evaluator from lzero.entry.utils import random_collect import copy @@ -103,7 +104,6 @@ def train_muzero_gobigger( batch_size = policy_config.batch_size # specific game buffer for MCTS+RL algorithms replay_buffer = GameBuffer(policy_config) - collector = Collector( env=collector_env, policy=policy.collect_mode, From 1c1fde9e2073ec72350b5c65f0f9c4e35639d575 Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Tue, 29 Aug 2023 15:52:02 +0800 Subject: [PATCH 51/54] feature(yzj): add pettingzoo visualization --- lzero/policy/multi_agent_muzero.py | 8 +- zoo/petting_zoo/entry/__init__.py | 3 +- zoo/petting_zoo/entry/eval_muzero.py | 81 +++++++++++++++++++ .../envs/petting_zoo_simple_spread_env.py | 37 +++++++-- 4 files changed, 118 insertions(+), 11 deletions(-) create mode 100644 zoo/petting_zoo/entry/eval_muzero.py diff --git a/lzero/policy/multi_agent_muzero.py b/lzero/policy/multi_agent_muzero.py index 97ebe2eec..701fc7b49 100644 --- a/lzero/policy/multi_agent_muzero.py +++ b/lzero/policy/multi_agent_muzero.py @@ -76,11 +76,9 @@ def _forward_collect( network_output = self._collect_model.initial_inference(data) latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) - if not self._learn_model.training: - # if not in training, obtain the scalars of the value/reward - pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() - latent_state_roots = latent_state_roots.detach().cpu().numpy() - policy_logits = policy_logits.detach().cpu().numpy().tolist() + pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() + latent_state_roots = latent_state_roots.detach().cpu().numpy() + policy_logits = policy_logits.detach().cpu().numpy().tolist() legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(batch_size)] # the only difference between collect and eval is the dirichlet noise diff --git a/zoo/petting_zoo/entry/__init__.py b/zoo/petting_zoo/entry/__init__.py index cc6e3cbb5..5e8144157 100644 --- a/zoo/petting_zoo/entry/__init__.py +++ b/zoo/petting_zoo/entry/__init__.py @@ -1 +1,2 @@ -from 
.train_muzero import train_muzero \ No newline at end of file +from .train_muzero import train_muzero +from .eval_muzero import eval_muzero \ No newline at end of file diff --git a/zoo/petting_zoo/entry/eval_muzero.py b/zoo/petting_zoo/entry/eval_muzero.py new file mode 100644 index 000000000..33ef567c3 --- /dev/null +++ b/zoo/petting_zoo/entry/eval_muzero.py @@ -0,0 +1,81 @@ +import logging +import os +from functools import partial +from typing import Optional, Tuple + +import torch +from ding.config import compile_config +from ding.envs import create_env_manager +from ding.envs import get_vec_env_setting +from ding.policy import create_policy +from ding.utils import set_pkg_seed, get_rank +from ding.rl_utils import get_epsilon_greedy_fn +from ding.worker import BaseLearner +from tensorboardX import SummaryWriter + +from lzero.worker import MuZeroCollector as Collector +from lzero.worker import MuZeroEvaluator as Evaluator +from lzero.entry.utils import log_buffer_memory_usage +from lzero.policy import visit_count_temperature +from lzero.entry.utils import random_collect +from zoo.petting_zoo.model import PettingZooEncoder + +def eval_muzero(main_cfg, create_cfg, seed=0): + assert create_cfg.policy.type in ['efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero', 'multi_agent_efficientzero', 'multi_agent_muzero'], \ + "train_muzero entry now only support the following algo.: 'efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero', 'multi_agent_efficientzero', 'multi_agent_muzero'" + + if create_cfg.policy.type == 'muzero' or create_cfg.policy.type == 'multi_agent_muzero': + from lzero.mcts import MuZeroGameBuffer as GameBuffer + from lzero.model.muzero_model_mlp import MuZeroModelMLP as Encoder + elif create_cfg.policy.type == 'efficientzero' or create_cfg.policy.type == 'multi_agent_efficientzero': + from lzero.mcts import EfficientZeroGameBuffer as GameBuffer + from lzero.model.efficientzero_model_mlp import EfficientZeroModelMLP as Encoder + elif create_cfg.policy.type == 'sampled_efficientzero': + from lzero.mcts import SampledEfficientZeroGameBuffer as GameBuffer + elif create_cfg.policy.type == 'gumbel_muzero': + from lzero.mcts import GumbelMuZeroGameBuffer as GameBuffer + + main_cfg.policy.device = 'cpu' + main_cfg.policy.load_path = 'exp_name/ckpt/ckpt_best.pth.tar' + main_cfg.env.replay_path = './' # when visualize must set as base + create_cfg.env_manager.type = 'base' # when visualize must set as base + main_cfg.env.evaluator_env_num = 1 # only 1 env for save replay + main_cfg.env.n_evaluator_episode = 1 + + cfg = compile_config(main_cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) + # Create main components: env, policy + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) + + evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + + evaluator_env.seed(cfg.seed, dynamic_seed=False) + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + + model = Encoder(**cfg.policy.model, state_encoder=PettingZooEncoder()) + policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) + policy.learn_mode.load_state_dict(torch.load(cfg.policy.load_path, map_location=cfg.policy.device)) + + # Create worker components: learner, collector, evaluator, replay buffer, commander. 
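+    # Usage sketch (this mirrors the ``__main__`` block at the bottom of this file):
+    #     from zoo.petting_zoo.config.ptz_simple_spread_ez_config import main_config, create_config
+    #     stop, reward = eval_muzero(main_config, create_config, seed=0)
+    # Note that ``load_path`` above ('exp_name/ckpt/ckpt_best.pth.tar') is a placeholder and must
+    # point to an existing checkpoint before running.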
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) if get_rank() == 0 else None + learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) + + # ============================================================== + # MCTS+RL algorithms related core code + # ============================================================== + policy_config = cfg.policy + evaluator = Evaluator( + eval_freq=cfg.policy.eval_freq, + n_evaluator_episode=cfg.env.n_evaluator_episode, + stop_value=cfg.env.stop_value, + env=evaluator_env, + policy=policy.eval_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config + ) + stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter) + return stop, reward + +if __name__ == '__main__': + from zoo.petting_zoo.config.ptz_simple_spread_ez_config import main_config, create_config + eval_muzero(main_config, create_config, seed=0) \ No newline at end of file diff --git a/zoo/petting_zoo/envs/petting_zoo_simple_spread_env.py b/zoo/petting_zoo/envs/petting_zoo_simple_spread_env.py index 87d84455a..f24f66563 100644 --- a/zoo/petting_zoo/envs/petting_zoo_simple_spread_env.py +++ b/zoo/petting_zoo/envs/petting_zoo_simple_spread_env.py @@ -11,6 +11,8 @@ from pettingzoo.utils.conversions import parallel_wrapper_fn from pettingzoo.mpe._mpe_utils.simple_env import SimpleEnv, make_env from pettingzoo.mpe.simple_spread.simple_spread import Scenario +from PIL import Image +import pygame @ENV_REGISTRY.register('petting_zoo') @@ -21,7 +23,8 @@ class PettingZooEnv(BaseEnv): def __init__(self, cfg: dict) -> None: self._cfg = cfg self._init_flag = False - self._replay_path = None + self._replay_path = self._cfg.get('replay_path', None) + self.frame_list = [] self._env_family = self._cfg.env_family self._env_id = self._cfg.env_id self._num_agents = self._cfg.n_agent @@ -55,10 +58,10 @@ def reset(self) -> np.ndarray: # if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed: # np_seed = 100 * np.random.randint(1, 1000) # self._env.seed(self._seed + np_seed) - if self._replay_path is not None: - self._env = gym.wrappers.Monitor( - self._env, self._replay_path, video_callable=lambda episode_id: True, force=True - ) + # if self._replay_path is not None: + # self._env = gym.wrappers.Monitor( + # self._env, self._replay_path, video_callable=lambda episode_id: True, force=True + # ) if hasattr(self, '_seed'): obs = self._env.reset(seed=self._seed) else: @@ -199,8 +202,12 @@ def step(self, action: dict) -> BaseEnvTimestep: # for agent in self._agents: # self._eval_episode_return[agent] += rew[agent] + if self._replay_path is not None: + self.frame_list.append(Image.fromarray(self._env.render())) if done_n: # or reduce(lambda x, y: x and y, done.values()) info['eval_episode_return'] = self._eval_episode_return + if self._replay_path is not None: + self.frame_list[0].save('out.gif', save_all=True, append_images=self.frame_list[1:], duration=3, loop=0) # for agent in rew: # rew[agent] = to_ndarray([rew[agent]]) return BaseEnvTimestep(obs_n, rew_n, done_n, info) @@ -330,6 +337,7 @@ def __init__(self, N=3, local_ratio=0.5, max_cycles=25, continuous_actions=False scenario = Scenario() world = scenario.make_world(N) super().__init__(scenario, world, max_cycles, continuous_actions=continuous_actions, local_ratio=local_ratio) + self.render_mode = 'rgb_array' self.metadata['name'] = "simple_spread_v2" def _execute_world_step(self): @@ -365,3 +373,22 @@ def 
_execute_world_step(self): reward = agent_reward self.rewards[agent.name] = reward + + def render(self): + if self.render_mode is None: + gym.logger.warn( + "You are calling render method without specifying any render mode." + ) + return + + self.enable_render(self.render_mode) + + self.draw() + observation = np.array(pygame.surfarray.pixels3d(self.screen)) + if self.render_mode == "human": + pygame.display.flip() + return ( + np.transpose(observation, axes=(1, 0, 2)) + if self.render_mode == "rgb_array" + else None + ) From 72c669b7ba3b8a69202f400c4163d4ebedf8f6a5 Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Tue, 29 Aug 2023 17:06:15 +0800 Subject: [PATCH 52/54] polish(yzj): polish ptz config and model --- lzero/policy/multi_agent_efficientzero.py | 16 +++++++--------- .../config/ptz_simple_spread_ez_config.py | 18 +++++++++--------- zoo/petting_zoo/entry/eval_muzero.py | 2 +- zoo/petting_zoo/model/model.py | 4 ++-- 4 files changed, 19 insertions(+), 21 deletions(-) diff --git a/lzero/policy/multi_agent_efficientzero.py b/lzero/policy/multi_agent_efficientzero.py index 7ed1ab414..149eefa47 100644 --- a/lzero/policy/multi_agent_efficientzero.py +++ b/lzero/policy/multi_agent_efficientzero.py @@ -175,15 +175,13 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, read network_output ) - if not self._eval_model.training: - # if not in training, obtain the scalars of the value/reward - pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape(B, 1) - latent_state_roots = latent_state_roots.detach().cpu().numpy() - reward_hidden_state_roots = ( - reward_hidden_state_roots[0].detach().cpu().numpy(), - reward_hidden_state_roots[1].detach().cpu().numpy() - ) - policy_logits = policy_logits.detach().cpu().numpy().tolist() # list shape(B, A) + pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape(B, 1) + latent_state_roots = latent_state_roots.detach().cpu().numpy() + reward_hidden_state_roots = ( + reward_hidden_state_roots[0].detach().cpu().numpy(), + reward_hidden_state_roots[1].detach().cpu().numpy() + ) + policy_logits = policy_logits.detach().cpu().numpy().tolist() # list shape(B, A) legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(batch_size)] if self._cfg.mcts_ctree: diff --git a/zoo/petting_zoo/config/ptz_simple_spread_ez_config.py b/zoo/petting_zoo/config/ptz_simple_spread_ez_config.py index ef385ee70..c2fa602a5 100644 --- a/zoo/petting_zoo/config/ptz_simple_spread_ez_config.py +++ b/zoo/petting_zoo/config/ptz_simple_spread_ez_config.py @@ -13,11 +13,11 @@ evaluator_env_num = 8 n_episode = 8 batch_size = 256 -num_simulations = 50 -update_per_collect = 1000 +num_simulations = 25 +update_per_collect = 100 reanalyze_ratio = 0. 
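The replay path added to petting_zoo_simple_spread_env.py above collects one RGB frame per step via the new render() method and flushes them to out.gif when the episode terminates. A minimal, self-contained sketch of that PIL pattern (random arrays stand in here for the frames returned by env.render() in 'rgb_array' mode):

import numpy as np
from PIL import Image

# Random uint8 arrays play the role of rendered frames; the save() call mirrors
# the out.gif line added in the env patch above.
frames = [
    Image.fromarray(np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8))
    for _ in range(10)
]
frames[0].save('out.gif', save_all=True, append_images=frames[1:], duration=3, loop=0)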
action_space_size = 5 -eps_greedy_exploration_in_collect = True +eps_greedy_exploration_in_collect = False # ============================================================== # end of the most frequently changed config specified by the user # ============================================================== @@ -50,7 +50,7 @@ action_space='discrete', action_space_size=action_space_size, agent_num=n_agent, - agent_obs_shape=(2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2)*3, + agent_obs_shape=2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2, global_obs_shape=2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2 + n_agent * (2 + 2) + n_landmark * 2 + n_agent * (n_agent - 1) * 2, discrete_action_encoding_type='one_hot', @@ -62,21 +62,21 @@ mcts_ctree=True, gumbel_algo=False, env_type='not_board_games', - game_segment_length=400, + game_segment_length=50, random_collect_episode_num=0, eps=dict( eps_greedy_exploration_in_collect=eps_greedy_exploration_in_collect, type='linear', start=1., end=0.05, - decay=int(1e5), + decay=int(2e4), ), use_augmentation=False, update_per_collect=update_per_collect, batch_size=batch_size, - optim_type='SGD', - lr_piecewise_constant_decay=True, - learning_rate=0.2, + optim_type='Adam', + lr_piecewise_constant_decay=False, + learning_rate=0.003, ssl_loss_weight=0, num_simulations=num_simulations, reanalyze_ratio=reanalyze_ratio, diff --git a/zoo/petting_zoo/entry/eval_muzero.py b/zoo/petting_zoo/entry/eval_muzero.py index 33ef567c3..7eb3e4d17 100644 --- a/zoo/petting_zoo/entry/eval_muzero.py +++ b/zoo/petting_zoo/entry/eval_muzero.py @@ -53,7 +53,7 @@ def eval_muzero(main_cfg, create_cfg, seed=0): model = Encoder(**cfg.policy.model, state_encoder=PettingZooEncoder()) policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) - policy.learn_mode.load_state_dict(torch.load(cfg.policy.load_path, map_location=cfg.policy.device)) + policy.eval_mode.load_state_dict(torch.load(cfg.policy.load_path, map_location=cfg.policy.device)) # Create worker components: learner, collector, evaluator, replay buffer, commander. 
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) if get_rank() == 0 else None diff --git a/zoo/petting_zoo/model/model.py b/zoo/petting_zoo/model/model.py index 8d2723676..9f3313731 100644 --- a/zoo/petting_zoo/model/model.py +++ b/zoo/petting_zoo/model/model.py @@ -5,9 +5,9 @@ class PettingZooEncoder(nn.Module): def __init__(self): super().__init__() - self.encoder = FCEncoder(obs_shape=48, hidden_size_list=[256, 256], activation=nn.ReLU(), norm_type=None) + self.encoder = FCEncoder(obs_shape=18, hidden_size_list=[256, 256], activation=nn.ReLU(), norm_type=None) def forward(self, x): - x = x['global_state'] + x = x['agent_state'] x = self.encoder(x) return x \ No newline at end of file From 11ef08f4bb56dcf213e5bb13ab140570a772c22a Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Mon, 4 Sep 2023 17:46:31 +0800 Subject: [PATCH 53/54] feature(yzj): add ptz simple ez config --- .../config/ptz_simple_ez_config.py | 116 ++++++++++++++++++ .../envs/petting_zoo_simple_spread_env.py | 5 +- zoo/petting_zoo/model/model.py | 2 +- 3 files changed, 121 insertions(+), 2 deletions(-) create mode 100644 zoo/petting_zoo/config/ptz_simple_ez_config.py diff --git a/zoo/petting_zoo/config/ptz_simple_ez_config.py b/zoo/petting_zoo/config/ptz_simple_ez_config.py new file mode 100644 index 000000000..52862344a --- /dev/null +++ b/zoo/petting_zoo/config/ptz_simple_ez_config.py @@ -0,0 +1,116 @@ +from easydict import EasyDict + +env_name = 'ptz_simple_spread' +multi_agent = True + +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +seed = 0 +n_agent = 1 +n_landmark = n_agent +collector_env_num = 8 +evaluator_env_num = 8 +n_episode = 8 +batch_size = 256 +num_simulations = 25 +update_per_collect = 100 +reanalyze_ratio = 0. 
+action_space_size = 5 +eps_greedy_exploration_in_collect = False +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +main_config = dict( + exp_name= + f'data_ez_ctree/{env_name}_efficientzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed{seed}', + env=dict( + env_family='mpe', + env_id='simple_v2', + n_agent=n_agent, + n_landmark=n_landmark, + max_cycles=25, + agent_obs_only=False, + agent_specific_global_state=True, + continuous_actions=False, + stop_value=0, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + multi_agent=multi_agent, + ignore_done=False, + model=dict( + model_type='structure', + latent_state_dim=256, + frame_stack_num=1, + action_space='discrete', + action_space_size=action_space_size, + agent_num=n_agent, + agent_obs_shape=2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2, + global_obs_shape=2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2 + n_agent * (2 + 2) + + n_landmark * 2 + n_agent * (n_agent - 1) * 2, + discrete_action_encoding_type='one_hot', + global_cooperation=True, # TODO: doesn't work now + hidden_size_list=[256, 256], + norm_type='BN', + ), + cuda=True, + mcts_ctree=True, + gumbel_algo=False, + env_type='not_board_games', + game_segment_length=50, + random_collect_episode_num=0, + eps=dict( + eps_greedy_exploration_in_collect=eps_greedy_exploration_in_collect, + type='linear', + start=1., + end=0.05, + decay=int(2e4), + ), + use_augmentation=False, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='Adam', + lr_piecewise_constant_decay=False, + learning_rate=0.003, + ssl_loss_weight=0, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + eval_freq=int(2e3), + replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. 
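The observation sizes wired through these configs and through PettingZooEncoder follow directly from the MPE layout used here (self velocity, self position, landmark relative positions, other-agent relative positions, communication). A small sanity-check sketch; the final 48 lines up with the obs_shape the encoder consumed before it was switched from global_state to agent_state:

def agent_obs_shape(n_agent, n_landmark):
    # self_vel(2) + self_pos(2) + landmark rel-pos + other-agent rel-pos + communication
    return 2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2

def global_obs_shape(n_agent, n_landmark):
    # all agents' vel + pos, all landmark positions, all agents' communication
    return n_agent * (2 + 2) + n_landmark * 2 + n_agent * (n_agent - 1) * 2

assert agent_obs_shape(1, 1) == 6    # simple_v2 config      -> FCEncoder(obs_shape=6)
assert agent_obs_shape(3, 3) == 18   # simple_spread_v2 config -> FCEncoder(obs_shape=18)
assert agent_obs_shape(3, 3) + global_obs_shape(3, 3) == 48  # agent-specific global state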
+ collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), + learn=dict(learner=dict( + log_policy=True, + hook=dict(log_show_after_iter=10, ), + ), ), +) +main_config = EasyDict(main_config) +create_config = dict( + env=dict( + import_names=['zoo.petting_zoo.envs.petting_zoo_simple_spread_env'], + type='petting_zoo', + ), + env_manager=dict(type='base'), + policy=dict( + type='multi_agent_efficientzero', + import_names=['lzero.policy.multi_agent_efficientzero'], + ), + collector=dict( + type='episode_muzero', + import_names=['lzero.worker.muzero_collector'], + ) +) +create_config = EasyDict(create_config) +ptz_simple_spread_efficientzero_config = main_config +ptz_simple_spread_efficientzero_create_config = create_config + +if __name__ == '__main__': + from zoo.petting_zoo.entry import train_muzero + train_muzero([main_config, create_config], seed=seed) diff --git a/zoo/petting_zoo/envs/petting_zoo_simple_spread_env.py b/zoo/petting_zoo/envs/petting_zoo_simple_spread_env.py index f24f66563..64cae850a 100644 --- a/zoo/petting_zoo/envs/petting_zoo_simple_spread_env.py +++ b/zoo/petting_zoo/envs/petting_zoo_simple_spread_env.py @@ -169,7 +169,10 @@ def seed(self, seed: int, dynamic_seed: bool = True) -> None: def step(self, action: dict) -> BaseEnvTimestep: self._step_count += 1 - action = np.array(list(action.values())) + if isinstance(action, dict): + action = np.array(list(action.values())) + else: + action = np.array(action) action = self._process_action(action) if self._act_scale: for agent in self._agents: diff --git a/zoo/petting_zoo/model/model.py b/zoo/petting_zoo/model/model.py index 9f3313731..273a423d3 100644 --- a/zoo/petting_zoo/model/model.py +++ b/zoo/petting_zoo/model/model.py @@ -5,7 +5,7 @@ class PettingZooEncoder(nn.Module): def __init__(self): super().__init__() - self.encoder = FCEncoder(obs_shape=18, hidden_size_list=[256, 256], activation=nn.ReLU(), norm_type=None) + self.encoder = FCEncoder(obs_shape=6, hidden_size_list=[256, 256], activation=nn.ReLU(), norm_type=None) def forward(self, x): x = x['agent_state'] From 1e143bce6048a9f3a31c60ecbac59cdb63628cc0 Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Thu, 7 Dec 2023 15:54:06 +0800 Subject: [PATCH 54/54] polish(yzj): polish code base --- .../mcts/buffer/game_buffer_efficientzero.py | 2 +- lzero/model/muzero_model_mlp.py | 1 + lzero/policy/multi_agent_efficientzero.py | 5 +- lzero/policy/multi_agent_muzero.py | 2 + zoo/petting_zoo/__init__.py | 0 zoo/petting_zoo/config/__init__.py | 1 - .../config/ptz_simple_ez_config.py | 116 ----- .../config/ptz_simple_spread_ez_config.py | 116 ----- .../config/ptz_simple_spread_mz_config.py | 117 ------ zoo/petting_zoo/entry/__init__.py | 2 - zoo/petting_zoo/entry/eval_muzero.py | 81 ---- zoo/petting_zoo/entry/train_muzero.py | 200 --------- zoo/petting_zoo/envs/__init__.py | 0 .../envs/petting_zoo_simple_spread_env.py | 397 ------------------ .../test_petting_zoo_simple_spread_env.py | 133 ------ zoo/petting_zoo/model/__init__.py | 1 - zoo/petting_zoo/model/model.py | 13 - 17 files changed, 8 insertions(+), 1179 deletions(-) delete mode 100644 zoo/petting_zoo/__init__.py delete mode 100644 zoo/petting_zoo/config/__init__.py delete mode 100644 zoo/petting_zoo/config/ptz_simple_ez_config.py delete mode 100644 zoo/petting_zoo/config/ptz_simple_spread_ez_config.py delete mode 100644 zoo/petting_zoo/config/ptz_simple_spread_mz_config.py delete mode 100644 zoo/petting_zoo/entry/__init__.py delete mode 100644 zoo/petting_zoo/entry/eval_muzero.py delete mode 
100644 zoo/petting_zoo/entry/train_muzero.py delete mode 100644 zoo/petting_zoo/envs/__init__.py delete mode 100644 zoo/petting_zoo/envs/petting_zoo_simple_spread_env.py delete mode 100644 zoo/petting_zoo/envs/test_petting_zoo_simple_spread_env.py delete mode 100644 zoo/petting_zoo/model/__init__.py delete mode 100644 zoo/petting_zoo/model/model.py diff --git a/lzero/mcts/buffer/game_buffer_efficientzero.py b/lzero/mcts/buffer/game_buffer_efficientzero.py index 154034305..cad35a658 100644 --- a/lzero/mcts/buffer/game_buffer_efficientzero.py +++ b/lzero/mcts/buffer/game_buffer_efficientzero.py @@ -46,7 +46,7 @@ def __init__(self, cfg: dict): self.base_idx = 0 self.clear_time = 0 - self.tmp_obs = None # for value obs list [46 + 4(td_step)] not < 50(game_segment) + self.tmp_obs = None # since value obs list [46 + 4(td_step)] >= 50(game_segment), need pad def sample(self, batch_size: int, policy: Any) -> List[Any]: """ diff --git a/lzero/model/muzero_model_mlp.py b/lzero/model/muzero_model_mlp.py index 25dfff530..cc1707909 100644 --- a/lzero/model/muzero_model_mlp.py +++ b/lzero/model/muzero_model_mlp.py @@ -69,6 +69,7 @@ def __init__( - discrete_action_encoding_type (:obj:`str`): The encoding type of discrete action, which can be 'one_hot' or 'not_one_hot'. - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. - res_connection_in_dynamics (:obj:`bool`): Whether to use residual connection for dynamics network, default set it to False. + - state_encoder (:obj:`Optional[nn.Module]`): The state encoder network, which is used to encode the raw observation to latent state. """ super(MuZeroModelMLP, self).__init__() self.categorical_distribution = categorical_distribution diff --git a/lzero/policy/multi_agent_efficientzero.py b/lzero/policy/multi_agent_efficientzero.py index 149eefa47..49426816e 100644 --- a/lzero/policy/multi_agent_efficientzero.py +++ b/lzero/policy/multi_agent_efficientzero.py @@ -19,7 +19,9 @@ class MultiAgentEfficientZeroPolicy(EfficientZeroPolicy): """ Overview: - The policy class for Multi Agent EfficientZero. + The policy class for Multi Agent EfficientZero. + Independent Learning mode is a method in which each agent learns and adapts to the environment independently \ + without directly considering the learning and strategies of other agents. """ def _forward_collect( @@ -212,6 +214,7 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, read ) # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the entire action set. action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + # Save according to agent dimension output[i // agent_num]['action'].append(action) output[i // agent_num]['distributions'].append(distributions) output[i // agent_num]['visit_count_distribution_entropy'].append(visit_count_distribution_entropy) diff --git a/lzero/policy/multi_agent_muzero.py b/lzero/policy/multi_agent_muzero.py index 701fc7b49..65fd3e2cf 100644 --- a/lzero/policy/multi_agent_muzero.py +++ b/lzero/policy/multi_agent_muzero.py @@ -25,6 +25,8 @@ class MultiAgentMuZeroPolicy(MuZeroPolicy): """ Overview: The policy class for Multi Agent MuZero. + Independent Learning mode is a method in which each agent learns and adapts to the environment independently \ + without directly considering the learning and strategies of other agents. 
""" def _forward_collect( diff --git a/zoo/petting_zoo/__init__.py b/zoo/petting_zoo/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/zoo/petting_zoo/config/__init__.py b/zoo/petting_zoo/config/__init__.py deleted file mode 100644 index 5348554ae..000000000 --- a/zoo/petting_zoo/config/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .ptz_simple_spread_ez_config import main_config, create_config diff --git a/zoo/petting_zoo/config/ptz_simple_ez_config.py b/zoo/petting_zoo/config/ptz_simple_ez_config.py deleted file mode 100644 index 52862344a..000000000 --- a/zoo/petting_zoo/config/ptz_simple_ez_config.py +++ /dev/null @@ -1,116 +0,0 @@ -from easydict import EasyDict - -env_name = 'ptz_simple_spread' -multi_agent = True - -# ============================================================== -# begin of the most frequently changed config specified by the user -# ============================================================== -seed = 0 -n_agent = 1 -n_landmark = n_agent -collector_env_num = 8 -evaluator_env_num = 8 -n_episode = 8 -batch_size = 256 -num_simulations = 25 -update_per_collect = 100 -reanalyze_ratio = 0. -action_space_size = 5 -eps_greedy_exploration_in_collect = False -# ============================================================== -# end of the most frequently changed config specified by the user -# ============================================================== - -main_config = dict( - exp_name= - f'data_ez_ctree/{env_name}_efficientzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed{seed}', - env=dict( - env_family='mpe', - env_id='simple_v2', - n_agent=n_agent, - n_landmark=n_landmark, - max_cycles=25, - agent_obs_only=False, - agent_specific_global_state=True, - continuous_actions=False, - stop_value=0, - collector_env_num=collector_env_num, - evaluator_env_num=evaluator_env_num, - n_evaluator_episode=evaluator_env_num, - manager=dict(shared_memory=False, ), - ), - policy=dict( - multi_agent=multi_agent, - ignore_done=False, - model=dict( - model_type='structure', - latent_state_dim=256, - frame_stack_num=1, - action_space='discrete', - action_space_size=action_space_size, - agent_num=n_agent, - agent_obs_shape=2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2, - global_obs_shape=2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2 + n_agent * (2 + 2) + - n_landmark * 2 + n_agent * (n_agent - 1) * 2, - discrete_action_encoding_type='one_hot', - global_cooperation=True, # TODO: doesn't work now - hidden_size_list=[256, 256], - norm_type='BN', - ), - cuda=True, - mcts_ctree=True, - gumbel_algo=False, - env_type='not_board_games', - game_segment_length=50, - random_collect_episode_num=0, - eps=dict( - eps_greedy_exploration_in_collect=eps_greedy_exploration_in_collect, - type='linear', - start=1., - end=0.05, - decay=int(2e4), - ), - use_augmentation=False, - update_per_collect=update_per_collect, - batch_size=batch_size, - optim_type='Adam', - lr_piecewise_constant_decay=False, - learning_rate=0.003, - ssl_loss_weight=0, - num_simulations=num_simulations, - reanalyze_ratio=reanalyze_ratio, - n_episode=n_episode, - eval_freq=int(2e3), - replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. 
- collector_env_num=collector_env_num, - evaluator_env_num=evaluator_env_num, - ), - learn=dict(learner=dict( - log_policy=True, - hook=dict(log_show_after_iter=10, ), - ), ), -) -main_config = EasyDict(main_config) -create_config = dict( - env=dict( - import_names=['zoo.petting_zoo.envs.petting_zoo_simple_spread_env'], - type='petting_zoo', - ), - env_manager=dict(type='base'), - policy=dict( - type='multi_agent_efficientzero', - import_names=['lzero.policy.multi_agent_efficientzero'], - ), - collector=dict( - type='episode_muzero', - import_names=['lzero.worker.muzero_collector'], - ) -) -create_config = EasyDict(create_config) -ptz_simple_spread_efficientzero_config = main_config -ptz_simple_spread_efficientzero_create_config = create_config - -if __name__ == '__main__': - from zoo.petting_zoo.entry import train_muzero - train_muzero([main_config, create_config], seed=seed) diff --git a/zoo/petting_zoo/config/ptz_simple_spread_ez_config.py b/zoo/petting_zoo/config/ptz_simple_spread_ez_config.py deleted file mode 100644 index c2fa602a5..000000000 --- a/zoo/petting_zoo/config/ptz_simple_spread_ez_config.py +++ /dev/null @@ -1,116 +0,0 @@ -from easydict import EasyDict - -env_name = 'ptz_simple_spread' -multi_agent = True - -# ============================================================== -# begin of the most frequently changed config specified by the user -# ============================================================== -seed = 0 -n_agent = 3 -n_landmark = n_agent -collector_env_num = 8 -evaluator_env_num = 8 -n_episode = 8 -batch_size = 256 -num_simulations = 25 -update_per_collect = 100 -reanalyze_ratio = 0. -action_space_size = 5 -eps_greedy_exploration_in_collect = False -# ============================================================== -# end of the most frequently changed config specified by the user -# ============================================================== - -main_config = dict( - exp_name= - f'data_ez_ctree/{env_name}_efficientzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed{seed}', - env=dict( - env_family='mpe', - env_id='simple_spread_v2', - n_agent=n_agent, - n_landmark=n_landmark, - max_cycles=25, - agent_obs_only=False, - agent_specific_global_state=True, - continuous_actions=False, - stop_value=0, - collector_env_num=collector_env_num, - evaluator_env_num=evaluator_env_num, - n_evaluator_episode=evaluator_env_num, - manager=dict(shared_memory=False, ), - ), - policy=dict( - multi_agent=multi_agent, - ignore_done=False, - model=dict( - model_type='structure', - latent_state_dim=256, - frame_stack_num=1, - action_space='discrete', - action_space_size=action_space_size, - agent_num=n_agent, - agent_obs_shape=2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2, - global_obs_shape=2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2 + n_agent * (2 + 2) + - n_landmark * 2 + n_agent * (n_agent - 1) * 2, - discrete_action_encoding_type='one_hot', - global_cooperation=True, # TODO: doesn't work now - hidden_size_list=[256, 256], - norm_type='BN', - ), - cuda=True, - mcts_ctree=True, - gumbel_algo=False, - env_type='not_board_games', - game_segment_length=50, - random_collect_episode_num=0, - eps=dict( - eps_greedy_exploration_in_collect=eps_greedy_exploration_in_collect, - type='linear', - start=1., - end=0.05, - decay=int(2e4), - ), - use_augmentation=False, - update_per_collect=update_per_collect, - batch_size=batch_size, - optim_type='Adam', - lr_piecewise_constant_decay=False, - learning_rate=0.003, - ssl_loss_weight=0, - 
num_simulations=num_simulations, - reanalyze_ratio=reanalyze_ratio, - n_episode=n_episode, - eval_freq=int(2e3), - replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. - collector_env_num=collector_env_num, - evaluator_env_num=evaluator_env_num, - ), - learn=dict(learner=dict( - log_policy=True, - hook=dict(log_show_after_iter=10, ), - ), ), -) -main_config = EasyDict(main_config) -create_config = dict( - env=dict( - import_names=['zoo.petting_zoo.envs.petting_zoo_simple_spread_env'], - type='petting_zoo', - ), - env_manager=dict(type='subprocess'), - policy=dict( - type='multi_agent_efficientzero', - import_names=['lzero.policy.multi_agent_efficientzero'], - ), - collector=dict( - type='episode_muzero', - import_names=['lzero.worker.muzero_collector'], - ) -) -create_config = EasyDict(create_config) -ptz_simple_spread_efficientzero_config = main_config -ptz_simple_spread_efficientzero_create_config = create_config - -if __name__ == '__main__': - from zoo.petting_zoo.entry import train_muzero - train_muzero([main_config, create_config], seed=seed) diff --git a/zoo/petting_zoo/config/ptz_simple_spread_mz_config.py b/zoo/petting_zoo/config/ptz_simple_spread_mz_config.py deleted file mode 100644 index d9d9c653f..000000000 --- a/zoo/petting_zoo/config/ptz_simple_spread_mz_config.py +++ /dev/null @@ -1,117 +0,0 @@ -from easydict import EasyDict - -env_name = 'ptz_simple_spread' -multi_agent = True - -# ============================================================== -# begin of the most frequently changed config specified by the user -# ============================================================== -seed = 0 -n_agent = 3 -n_landmark = n_agent -collector_env_num = 8 -evaluator_env_num = 8 -n_episode = 8 -batch_size = 256 -num_simulations = 50 -update_per_collect = 1000 -reanalyze_ratio = 0. 
-action_space_size = 5 -eps_greedy_exploration_in_collect = True -# ============================================================== -# end of the most frequently changed config specified by the user -# ============================================================== - -main_config = dict( - exp_name= - f'data_mz_ctree/{env_name}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed{seed}', - env=dict( - env_family='mpe', - env_id='simple_spread_v2', - n_agent=n_agent, - n_landmark=n_landmark, - max_cycles=25, - agent_obs_only=False, - agent_specific_global_state=True, - continuous_actions=False, - stop_value=0, - collector_env_num=collector_env_num, - evaluator_env_num=evaluator_env_num, - n_evaluator_episode=evaluator_env_num, - manager=dict(shared_memory=False, ), - ), - policy=dict( - multi_agent=multi_agent, - ignore_done=False, - model=dict( - model_type='structure', - latent_state_dim=256, - frame_stack_num=1, - action_space='discrete', - action_space_size=action_space_size, - agent_num=n_agent, - self_supervised_learning_loss=False, # default is False - agent_obs_shape=(2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2)*3, - global_obs_shape=2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2 + n_agent * (2 + 2) + - n_landmark * 2 + n_agent * (n_agent - 1) * 2, - discrete_action_encoding_type='one_hot', - global_cooperation=True, # TODO: doesn't work now - hidden_size_list=[256, 256], - norm_type='BN', - ), - cuda=True, - mcts_ctree=True, - gumbel_algo=False, - env_type='not_board_games', - game_segment_length=400, - random_collect_episode_num=0, - eps=dict( - eps_greedy_exploration_in_collect=eps_greedy_exploration_in_collect, - type='linear', - start=1., - end=0.05, - decay=int(1e5), - ), - use_augmentation=False, - update_per_collect=update_per_collect, - batch_size=batch_size, - optim_type='SGD', - lr_piecewise_constant_decay=True, - learning_rate=0.2, - ssl_loss_weight=0, # default is 0 - num_simulations=num_simulations, - reanalyze_ratio=reanalyze_ratio, - n_episode=n_episode, - eval_freq=int(2e3), - replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. 
- collector_env_num=collector_env_num, - evaluator_env_num=evaluator_env_num, - ), - learn=dict(learner=dict( - log_policy=True, - hook=dict(log_show_after_iter=10, ), - ), ), -) -main_config = EasyDict(main_config) -create_config = dict( - env=dict( - import_names=['zoo.petting_zoo.envs.petting_zoo_simple_spread_env'], - type='petting_zoo', - ), - env_manager=dict(type='subprocess'), - policy=dict( - type='multi_agent_muzero', - import_names=['lzero.policy.multi_agent_muzero'], - ), - collector=dict( - type='episode_muzero', - import_names=['lzero.worker.muzero_collector'], - ) -) -create_config = EasyDict(create_config) -ptz_simple_spread_muzero_config = main_config -ptz_simple_spread_muzero_create_config = create_config - -if __name__ == '__main__': - from zoo.petting_zoo.entry import train_muzero - train_muzero([main_config, create_config], seed=seed) diff --git a/zoo/petting_zoo/entry/__init__.py b/zoo/petting_zoo/entry/__init__.py deleted file mode 100644 index 5e8144157..000000000 --- a/zoo/petting_zoo/entry/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .train_muzero import train_muzero -from .eval_muzero import eval_muzero \ No newline at end of file diff --git a/zoo/petting_zoo/entry/eval_muzero.py b/zoo/petting_zoo/entry/eval_muzero.py deleted file mode 100644 index 7eb3e4d17..000000000 --- a/zoo/petting_zoo/entry/eval_muzero.py +++ /dev/null @@ -1,81 +0,0 @@ -import logging -import os -from functools import partial -from typing import Optional, Tuple - -import torch -from ding.config import compile_config -from ding.envs import create_env_manager -from ding.envs import get_vec_env_setting -from ding.policy import create_policy -from ding.utils import set_pkg_seed, get_rank -from ding.rl_utils import get_epsilon_greedy_fn -from ding.worker import BaseLearner -from tensorboardX import SummaryWriter - -from lzero.worker import MuZeroCollector as Collector -from lzero.worker import MuZeroEvaluator as Evaluator -from lzero.entry.utils import log_buffer_memory_usage -from lzero.policy import visit_count_temperature -from lzero.entry.utils import random_collect -from zoo.petting_zoo.model import PettingZooEncoder - -def eval_muzero(main_cfg, create_cfg, seed=0): - assert create_cfg.policy.type in ['efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero', 'multi_agent_efficientzero', 'multi_agent_muzero'], \ - "train_muzero entry now only support the following algo.: 'efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero', 'multi_agent_efficientzero', 'multi_agent_muzero'" - - if create_cfg.policy.type == 'muzero' or create_cfg.policy.type == 'multi_agent_muzero': - from lzero.mcts import MuZeroGameBuffer as GameBuffer - from lzero.model.muzero_model_mlp import MuZeroModelMLP as Encoder - elif create_cfg.policy.type == 'efficientzero' or create_cfg.policy.type == 'multi_agent_efficientzero': - from lzero.mcts import EfficientZeroGameBuffer as GameBuffer - from lzero.model.efficientzero_model_mlp import EfficientZeroModelMLP as Encoder - elif create_cfg.policy.type == 'sampled_efficientzero': - from lzero.mcts import SampledEfficientZeroGameBuffer as GameBuffer - elif create_cfg.policy.type == 'gumbel_muzero': - from lzero.mcts import GumbelMuZeroGameBuffer as GameBuffer - - main_cfg.policy.device = 'cpu' - main_cfg.policy.load_path = 'exp_name/ckpt/ckpt_best.pth.tar' - main_cfg.env.replay_path = './' # when visualize must set as base - create_cfg.env_manager.type = 'base' # when visualize must set as base - main_cfg.env.evaluator_env_num = 1 # only 1 env for 
save replay - main_cfg.env.n_evaluator_episode = 1 - - cfg = compile_config(main_cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) - # Create main components: env, policy - env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) - - evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) - - evaluator_env.seed(cfg.seed, dynamic_seed=False) - set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) - - model = Encoder(**cfg.policy.model, state_encoder=PettingZooEncoder()) - policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) - policy.eval_mode.load_state_dict(torch.load(cfg.policy.load_path, map_location=cfg.policy.device)) - - # Create worker components: learner, collector, evaluator, replay buffer, commander. - tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) if get_rank() == 0 else None - learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) - - # ============================================================== - # MCTS+RL algorithms related core code - # ============================================================== - policy_config = cfg.policy - evaluator = Evaluator( - eval_freq=cfg.policy.eval_freq, - n_evaluator_episode=cfg.env.n_evaluator_episode, - stop_value=cfg.env.stop_value, - env=evaluator_env, - policy=policy.eval_mode, - tb_logger=tb_logger, - exp_name=cfg.exp_name, - policy_config=policy_config - ) - stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter) - return stop, reward - -if __name__ == '__main__': - from zoo.petting_zoo.config.ptz_simple_spread_ez_config import main_config, create_config - eval_muzero(main_config, create_config, seed=0) \ No newline at end of file diff --git a/zoo/petting_zoo/entry/train_muzero.py b/zoo/petting_zoo/entry/train_muzero.py deleted file mode 100644 index b36dda8f8..000000000 --- a/zoo/petting_zoo/entry/train_muzero.py +++ /dev/null @@ -1,200 +0,0 @@ -import logging -import os -from functools import partial -from typing import Optional, Tuple - -import torch -from ding.config import compile_config -from ding.envs import create_env_manager -from ding.envs import get_vec_env_setting -from ding.policy import create_policy -from ding.utils import set_pkg_seed, get_rank -from ding.rl_utils import get_epsilon_greedy_fn -from ding.worker import BaseLearner -from tensorboardX import SummaryWriter - -from lzero.worker import MuZeroCollector as Collector -from lzero.worker import MuZeroEvaluator as Evaluator -from lzero.entry.utils import log_buffer_memory_usage -from lzero.policy import visit_count_temperature -from lzero.entry.utils import random_collect -from zoo.petting_zoo.model import PettingZooEncoder - -def train_muzero( - input_cfg: Tuple[dict, dict], - seed: int = 0, - model: Optional[torch.nn.Module] = None, - model_path: Optional[str] = None, - max_train_iter: Optional[int] = int(1e10), - max_env_step: Optional[int] = int(1e10), -) -> 'Policy': # noqa - """ - Overview: - The train entry for MCTS+RL algorithms, including MuZero, EfficientZero, Sampled EfficientZero. - Arguments: - - input_cfg (:obj:`Tuple[dict, dict]`): Config in dict type. - ``Tuple[dict, dict]`` type means [user_config, create_cfg]. - - seed (:obj:`int`): Random seed. - - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module. 
- - model_path (:obj:`Optional[str]`): The pretrained model path, which should - point to the ckpt file of the pretrained model, and an absolute path is recommended. - In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``. - - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training. - - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps. - Returns: - - policy (:obj:`Policy`): Converged policy. - """ - - cfg, create_cfg = input_cfg - assert create_cfg.policy.type in ['efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero', 'multi_agent_efficientzero', 'multi_agent_muzero'], \ - "train_muzero entry now only support the following algo.: 'efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero', 'multi_agent_efficientzero', 'multi_agent_muzero'" - - if create_cfg.policy.type == 'muzero' or create_cfg.policy.type == 'multi_agent_muzero': - from lzero.mcts import MuZeroGameBuffer as GameBuffer - from lzero.model.muzero_model_mlp import MuZeroModelMLP as Encoder - elif create_cfg.policy.type == 'efficientzero' or create_cfg.policy.type == 'multi_agent_efficientzero': - from lzero.mcts import EfficientZeroGameBuffer as GameBuffer - from lzero.model.efficientzero_model_mlp import EfficientZeroModelMLP as Encoder - elif create_cfg.policy.type == 'sampled_efficientzero': - from lzero.mcts import SampledEfficientZeroGameBuffer as GameBuffer - elif create_cfg.policy.type == 'gumbel_muzero': - from lzero.mcts import GumbelMuZeroGameBuffer as GameBuffer - - if cfg.policy.cuda and torch.cuda.is_available(): - cfg.policy.device = 'cuda' - else: - cfg.policy.device = 'cpu' - - cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) - # Create main components: env, policy - env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) - - collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) - evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) - - collector_env.seed(cfg.seed) - evaluator_env.seed(cfg.seed, dynamic_seed=False) - set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) - - model = Encoder(**cfg.policy.model, state_encoder=PettingZooEncoder()) - policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) - - # load pretrained model - if model_path is not None: - policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device)) - - # Create worker components: learner, collector, evaluator, replay buffer, commander. 
- tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) if get_rank() == 0 else None - learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) - - # ============================================================== - # MCTS+RL algorithms related core code - # ============================================================== - policy_config = cfg.policy - batch_size = policy_config.batch_size - # specific game buffer for MCTS+RL algorithms - replay_buffer = GameBuffer(policy_config) - - collector = Collector( - env=collector_env, - policy=policy.collect_mode, - tb_logger=tb_logger, - exp_name=cfg.exp_name, - policy_config=policy_config - ) - evaluator = Evaluator( - eval_freq=cfg.policy.eval_freq, - n_evaluator_episode=cfg.env.n_evaluator_episode, - stop_value=cfg.env.stop_value, - env=evaluator_env, - policy=policy.eval_mode, - tb_logger=tb_logger, - exp_name=cfg.exp_name, - policy_config=policy_config - ) - - # ============================================================== - # Main loop - # ============================================================== - # Learner's before_run hook. - learner.call_hook('before_run') - - if cfg.policy.update_per_collect is not None: - update_per_collect = cfg.policy.update_per_collect - - # The purpose of collecting random data before training: - # Exploration: The collection of random data aids the agent in exploring the environment and prevents premature convergence to a suboptimal policy. - # Comparation: The agent's performance during random action-taking can be used as a reference point to evaluate the efficacy of reinforcement learning algorithms. - if cfg.policy.random_collect_episode_num > 0: - if policy_config.multi_agent: - from lzero.policy.multi_agent_random_policy import MultiAgentLightZeroRandomPolicy as RandomPolicy - else: - from lzero.policy.random_policy import LightZeroRandomPolicy as RandomPolicy - random_collect(cfg.policy, policy, RandomPolicy, collector, collector_env, replay_buffer) - - while True: - log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger) - collect_kwargs = {} - # set temperature for visit count distributions according to the train_iter, - # please refer to Appendix D in MuZero paper for details. - collect_kwargs['temperature'] = visit_count_temperature( - policy_config.manual_temperature_decay, - policy_config.fixed_temperature_value, - policy_config.threshold_training_steps_for_final_temperature, - trained_steps=learner.train_iter - ) - - if policy_config.eps.eps_greedy_exploration_in_collect: - epsilon_greedy_fn = get_epsilon_greedy_fn( - start=policy_config.eps.start, - end=policy_config.eps.end, - decay=policy_config.eps.decay, - type_=policy_config.eps.type - ) - collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep) - else: - collect_kwargs['epsilon'] = 0.0 - - # Evaluate policy performance. - if evaluator.should_eval(learner.train_iter): - stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) - if stop: - break - - # Collect data by default config n_sample/n_episode. - new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) - if cfg.policy.update_per_collect is None: - # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the model_update_ratio. 
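As a hedged numeric illustration of the comment above: when update_per_collect is None it is derived from the amount of freshly collected data. model_update_ratio is not set anywhere in these configs, so the numbers below are assumptions, not values from the patch:

# Purely illustrative numbers; model_update_ratio is not specified in the configs above.
collected_transitions_num = 400    # e.g. total length of the game segments just collected
model_update_ratio = 0.25          # hypothetical ratio
update_per_collect = int(collected_transitions_num * model_update_ratio)
assert update_per_collect == 100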
- collected_transitions_num = sum([len(game_segment) for game_segment in new_data[0]]) - update_per_collect = int(collected_transitions_num * cfg.policy.model_update_ratio) - # save returned new_data collected by the collector - replay_buffer.push_game_segments(new_data) - # remove the oldest data if the replay buffer is full. - replay_buffer.remove_oldest_data_to_fit() - - # Learn policy from collected data. - for i in range(update_per_collect): - # Learner will train ``update_per_collect`` times in one iteration. - if replay_buffer.get_num_of_transitions() > batch_size: - train_data = replay_buffer.sample(batch_size, policy) - else: - logging.warning( - f'The data in replay_buffer is not sufficient to sample a mini-batch: ' - f'batch_size: {batch_size}, ' - f'{replay_buffer} ' - f'continue to collect now ....' - ) - break - - # The core train steps for MCTS+RL algorithms. - log_vars = learner.train(train_data, collector.envstep) - - if cfg.policy.use_priority: - replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig']) - - if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter: - break - - # Learner's after_run hook. - learner.call_hook('after_run') - return policy diff --git a/zoo/petting_zoo/envs/__init__.py b/zoo/petting_zoo/envs/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/zoo/petting_zoo/envs/petting_zoo_simple_spread_env.py b/zoo/petting_zoo/envs/petting_zoo_simple_spread_env.py deleted file mode 100644 index 64cae850a..000000000 --- a/zoo/petting_zoo/envs/petting_zoo_simple_spread_env.py +++ /dev/null @@ -1,397 +0,0 @@ -from typing import Any, List, Union, Optional, Dict -import gymnasium as gym -import numpy as np -import pettingzoo -from functools import reduce - -from ding.envs import BaseEnv, BaseEnvTimestep, FrameStackWrapper -from ding.torch_utils import to_ndarray, to_list -from ding.envs.common.common_function import affine_transform -from ding.utils import ENV_REGISTRY, import_module -from pettingzoo.utils.conversions import parallel_wrapper_fn -from pettingzoo.mpe._mpe_utils.simple_env import SimpleEnv, make_env -from pettingzoo.mpe.simple_spread.simple_spread import Scenario -from PIL import Image -import pygame - - -@ENV_REGISTRY.register('petting_zoo') -class PettingZooEnv(BaseEnv): - # Now only supports simple_spread_v2. - # All agents' observations should have the same shape. 
- - def __init__(self, cfg: dict) -> None: - self._cfg = cfg - self._init_flag = False - self._replay_path = self._cfg.get('replay_path', None) - self.frame_list = [] - self._env_family = self._cfg.env_family - self._env_id = self._cfg.env_id - self._num_agents = self._cfg.n_agent - self._num_landmarks = self._cfg.n_landmark - self._continuous_actions = self._cfg.get('continuous_actions', False) - self._max_cycles = self._cfg.get('max_cycles', 25) - self._act_scale = self._cfg.get('act_scale', False) - self._agent_specific_global_state = self._cfg.get('agent_specific_global_state', False) - if self._act_scale: - assert self._continuous_actions, 'Only continuous action space env needs act_scale' - - def reset(self) -> np.ndarray: - if not self._init_flag: - # In order to align with the simple spread in Multiagent Particle Env (MPE), - # instead of adopting the pettingzoo interface directly, - # we have redefined the way rewards are calculated - - # import_module(['pettingzoo.{}.{}'.format(self._env_family, self._env_id)]) - # self._env = pettingzoo.__dict__[self._env_family].__dict__[self._env_id].parallel_env( - # N=self._cfg.n_agent, continuous_actions=self._continuous_actions, max_cycles=self._max_cycles - # ) - - # init parallel_env wrapper - _env = make_env(simple_spread_raw_env) - parallel_env = parallel_wrapper_fn(_env) - # init env - self._env = parallel_env( - N=self._cfg.n_agent, continuous_actions=self._continuous_actions, max_cycles=self._max_cycles - ) - # dynamic seed reduces training speed greatly - # if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed: - # np_seed = 100 * np.random.randint(1, 1000) - # self._env.seed(self._seed + np_seed) - # if self._replay_path is not None: - # self._env = gym.wrappers.Monitor( - # self._env, self._replay_path, video_callable=lambda episode_id: True, force=True - # ) - if hasattr(self, '_seed'): - obs = self._env.reset(seed=self._seed) - else: - obs = self._env.reset() - if not self._init_flag: - self._agents = self._env.agents - - self._action_space = gym.spaces.Dict({agent: self._env.action_space(agent) for agent in self._agents}) - single_agent_obs_space = self._env.action_space(self._agents[0]) - if isinstance(single_agent_obs_space, gym.spaces.Box): - self._action_dim = single_agent_obs_space.shape - elif isinstance(single_agent_obs_space, gym.spaces.Discrete): - self._action_dim = (single_agent_obs_space.n, ) - else: - raise Exception('Only support `Box` or `Discrete` obs space for single agent.') - - # only for env 'simple_spread_v2', n_agent = 5 - # now only for the case that each agent in the team have the same obs structure and corresponding shape. 
- if not self._cfg.agent_obs_only: - self._observation_space = gym.spaces.Dict( - { - 'agent_state': gym.spaces.Box( - low=float("-inf"), - high=float("inf"), - shape=(self._num_agents, - self._env.observation_space('agent_0').shape[0]), # (self._num_agents, 30) - dtype=np.float32 - ), - 'global_state': gym.spaces.Box( - low=float("-inf"), - high=float("inf"), - shape=( - 4 * self._num_agents + 2 * self._num_landmarks + 2 * self._num_agents * - (self._num_agents - 1), - ), - dtype=np.float32 - ), - 'agent_alone_state': gym.spaces.Box( - low=float("-inf"), - high=float("inf"), - shape=(self._num_agents, 4 + 2 * self._num_landmarks + 2 * (self._num_agents - 1)), - dtype=np.float32 - ), - 'agent_alone_padding_state': gym.spaces.Box( - low=float("-inf"), - high=float("inf"), - shape=(self._num_agents, - self._env.observation_space('agent_0').shape[0]), # (self._num_agents, 30) - dtype=np.float32 - ), - 'action_mask': gym.spaces.Box( - low=float("-inf"), - high=float("inf"), - shape=(self._num_agents, self._action_dim[0]), # (self._num_agents, 5) - dtype=np.float32 - ) - } - ) - # whether use agent_specific_global_state. It is usually used in AC multiagent algos, e.g., mappo, masac, etc. - if self._agent_specific_global_state: - agent_specifig_global_state = gym.spaces.Box( - low=float("-inf"), - high=float("inf"), - shape=( - self._num_agents, self._env.observation_space('agent_0').shape[0] + 4 * self._num_agents + - 2 * self._num_landmarks + 2 * self._num_agents * (self._num_agents - 1) - ), - dtype=np.float32 - ) - self._observation_space['global_state'] = agent_specifig_global_state - else: - # for case when env.agent_obs_only=True - self._observation_space = gym.spaces.Box( - low=float("-inf"), - high=float("inf"), - shape=(self._num_agents, self._env.observation_space('agent_0').shape[0]), - dtype=np.float32 - ) - - self._reward_space = gym.spaces.Dict( - { - agent: gym.spaces.Box(low=float("-inf"), high=float("inf"), shape=(1, ), dtype=np.float32) - for agent in self._agents - } - ) - self._init_flag = True - # self._eval_episode_return = {agent: 0. for agent in self._agents} - self._eval_episode_return = 0. 
- self._step_count = 0 - obs_n = self._process_obs(obs) - return obs_n - - def close(self) -> None: - if self._init_flag: - self._env.close() - self._init_flag = False - - def render(self) -> None: - self._env.render() - - def seed(self, seed: int, dynamic_seed: bool = True) -> None: - self._seed = seed - self._dynamic_seed = dynamic_seed - np.random.seed(self._seed) - - def step(self, action: dict) -> BaseEnvTimestep: - self._step_count += 1 - if isinstance(action, dict): - action = np.array(list(action.values())) - else: - action = np.array(action) - action = self._process_action(action) - if self._act_scale: - for agent in self._agents: - # print(action[agent]) - # print(self.action_space[agent]) - # print(self.action_space[agent].low, self.action_space[agent].high) - action[agent] = affine_transform( - action[agent], min_val=self.action_space[agent].low, max_val=self.action_space[agent].high - ) - - obs, rew, done, trunc, info = self._env.step(action) - obs_n = self._process_obs(obs) - rew_n = np.array([sum([rew[agent] for agent in self._agents])]) - rew_n = rew_n.astype(np.float32) - # collide_sum = 0 - # for i in range(self._num_agents): - # collide_sum += info['n'][i][1] - # collide_penalty = self._cfg.get('collide_penal', self._num_agent) - # rew_n += collide_sum * (1.0 - collide_penalty) - # rew_n = rew_n / (self._cfg.get('max_cycles', 25) * self._num_agent) - self._eval_episode_return += rew_n.item() - - # occupied_landmarks = info['n'][0][3] - # if self._step_count >= self._max_step or occupied_landmarks >= self._n_agent \ - # or occupied_landmarks >= self._num_landmarks: - # done_n = True - # else: - # done_n = False - done_n = reduce(lambda x, y: x and y, done.values()) or self._step_count >= self._max_cycles - - # for agent in self._agents: - # self._eval_episode_return[agent] += rew[agent] - if self._replay_path is not None: - self.frame_list.append(Image.fromarray(self._env.render())) - if done_n: # or reduce(lambda x, y: x and y, done.values()) - info['eval_episode_return'] = self._eval_episode_return - if self._replay_path is not None: - self.frame_list[0].save('out.gif', save_all=True, append_images=self.frame_list[1:], duration=3, loop=0) - # for agent in rew: - # rew[agent] = to_ndarray([rew[agent]]) - return BaseEnvTimestep(obs_n, rew_n, done_n, info) - - def enable_save_replay(self, replay_path: Optional[str] = None) -> None: - if replay_path is None: - replay_path = './video' - self._replay_path = replay_path - - def _process_obs(self, obs: 'torch.Tensor') -> np.ndarray: # noqa - obs = np.array([obs[agent] for agent in self._agents]).astype(np.float32) - if self._cfg.get('agent_obs_only', False): - return obs - ret = {} - # Raw agent observation structure is -- - # [self_vel, self_pos, landmark_rel_positions, other_agent_rel_positions, communication] - # where `communication` are signals from other agents (two for each agent in `simple_spread_v2`` env) - - # agent_state: Shape (n_agent, 2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2). - # Stacked observation. Contains - # - agent itself's state(velocity + position) - # - position of items that the agent can observe(e.g. other agents, landmarks) - # - communication - ret['agent_state'] = obs - # global_state: Shape (n_agent * (2 + 2) + n_landmark * 2 + n_agent * (n_agent - 1) * 2, ). - # 1-dim vector. 
Contains - # - all agents' state(velocity + position) + - # - all landmarks' position + - # - all agents' communication - ret['global_state'] = np.concatenate( - [ - obs[0, 2:-(self._num_agents - 1) * 2], # all agents' position + all landmarks' position - obs[:, 0:2].flatten(), # all agents' velocity - obs[:, -(self._num_agents - 1) * 2:].flatten() # all agents' communication - ] - ) - # agent_specific_global_state: Shape (n_agent, 2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2 + n_agent * (2 + 2) + n_landmark * 2 + n_agent * (n_agent - 1) * 2). - # 2-dim vector. contains - # - agent_state info - # - global_state info - if self._agent_specific_global_state: - ret['global_state'] = np.concatenate( - [ret['agent_state'], - np.expand_dims(ret['global_state'], axis=0).repeat(self._num_agents, axis=0)], - axis=1 - ) - # agent_alone_state: Shape (n_agent, 2 + 2 + n_landmark * 2 + (n_agent - 1) * 2). - # Stacked observation. Exclude other agents' positions from agent_state. Contains - # - agent itself's state(velocity + position) + - # - landmarks' positions (do not include other agents' positions) - # - communication - ret['agent_alone_state'] = np.concatenate( - [ - obs[:, 0:(4 + self._num_agents * 2)], # agent itself's state + landmarks' position - obs[:, -(self._num_agents - 1) * 2:], # communication - ], - 1 - ) - # agent_alone_padding_state: Shape (n_agent, 2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2). - # Contains the same information as agent_alone_state; - # But 0-padding other agents' positions. - ret['agent_alone_padding_state'] = np.concatenate( - [ - obs[:, 0:(4 + self._num_agents * 2)], # agent itself's state + landmarks' position - np.zeros((self._num_agents, - (self._num_agents - 1) * 2), np.float32), # Other agents' position(0-padding) - obs[:, -(self._num_agents - 1) * 2:] # communication - ], - 1 - ) - # action_mask: All actions are of use(either 1 for discrete or 5 for continuous). Thus all 1. 
- # action_mask = np.ones((self._num_agents, *self._action_dim)).astype(np.float32) - action_mask = [[1 for _ in range(*self._action_dim)] for _ in range(self._num_agents)] - to_play = [-1 for _ in range(self._num_agents)] # Moot, for alignment with other environments - - ret_transform = [] - for i in range(self._num_agents): - tmp = {} - for k,v in ret.items(): - tmp[k] = v[i] - tmp['action_mask'] = [1 for _ in range(*self._action_dim)] - ret_transform.append(tmp) - return {'observation': ret_transform, 'action_mask': action_mask, 'to_play': to_play} - - def _process_action(self, action: 'torch.Tensor') -> Dict[str, np.ndarray]: # noqa - dict_action = {} - for i, agent in enumerate(self._agents): - agent_action = action[i] - if agent_action.shape == (1, ): - agent_action = agent_action.squeeze() # 0-dim array - dict_action[agent] = agent_action - return dict_action - - def random_action(self) -> np.ndarray: - random_action = self.action_space.sample() - for k in random_action: - if isinstance(random_action[k], np.ndarray): - pass - elif isinstance(random_action[k], int): - random_action[k] = to_ndarray([random_action[k]], dtype=np.int64) - return random_action - - def __repr__(self) -> str: - return "DI-engine PettingZoo Env" - - @property - def agents(self) -> List[str]: - return self._agents - - @property - def observation_space(self) -> gym.spaces.Space: - return self._observation_space - - @property - def action_space(self) -> gym.spaces.Space: - return self._action_space - - @property - def reward_space(self) -> gym.spaces.Space: - return self._reward_space - - -class simple_spread_raw_env(SimpleEnv): - - def __init__(self, N=3, local_ratio=0.5, max_cycles=25, continuous_actions=False): - assert 0. <= local_ratio <= 1., "local_ratio is a proportion. Must be between 0 and 1." - scenario = Scenario() - world = scenario.make_world(N) - super().__init__(scenario, world, max_cycles, continuous_actions=continuous_actions, local_ratio=local_ratio) - self.render_mode = 'rgb_array' - self.metadata['name'] = "simple_spread_v2" - - def _execute_world_step(self): - # set action for each agent - for i, agent in enumerate(self.world.agents): - action = self.current_actions[i] - scenario_action = [] - if agent.movable: - mdim = self.world.dim_p * 2 + 1 - if self.continuous_actions: - scenario_action.append(action[0:mdim]) - action = action[mdim:] - else: - scenario_action.append(action % mdim) - action //= mdim - if not agent.silent: - scenario_action.append(action) - self._set_action(scenario_action, agent, self.action_spaces[agent.name]) - - self.world.step() - - global_reward = 0. - if self.local_ratio is not None: - global_reward = float(self.scenario.global_reward(self.world)) - - for agent in self.world.agents: - agent_reward = float(self.scenario.reward(agent, self.world)) - if self.local_ratio is not None: - # we changed reward calc way to keep same with mpe - # reward = global_reward * (1 - self.local_ratio) + agent_reward * self.local_ratio - reward = global_reward + agent_reward - else: - reward = agent_reward - - self.rewards[agent.name] = reward - - def render(self): - if self.render_mode is None: - gym.logger.warn( - "You are calling render method without specifying any render mode." 
- ) - return - - self.enable_render(self.render_mode) - - self.draw() - observation = np.array(pygame.surfarray.pixels3d(self.screen)) - if self.render_mode == "human": - pygame.display.flip() - return ( - np.transpose(observation, axes=(1, 0, 2)) - if self.render_mode == "rgb_array" - else None - ) diff --git a/zoo/petting_zoo/envs/test_petting_zoo_simple_spread_env.py b/zoo/petting_zoo/envs/test_petting_zoo_simple_spread_env.py deleted file mode 100644 index 22117cf85..000000000 --- a/zoo/petting_zoo/envs/test_petting_zoo_simple_spread_env.py +++ /dev/null @@ -1,133 +0,0 @@ -from easydict import EasyDict -import pytest -import numpy as np -import pettingzoo -from ding.utils import import_module - -from dizoo.petting_zoo.envs.petting_zoo_simple_spread_env import PettingZooEnv - - -@pytest.mark.envtest -class TestPettingZooEnv: - - def test_agent_obs_only(self): - n_agent = 5 - n_landmark = n_agent - env = PettingZooEnv( - EasyDict( - dict( - env_family='mpe', - env_id='simple_spread_v2', - n_agent=n_agent, - n_landmark=n_landmark, - max_step=100, - agent_obs_only=True, - continuous_actions=True, - ) - ) - ) - env.seed(123) - assert env._seed == 123 - obs = env.reset() - assert obs.shape == (n_agent, 2 + 2 + (n_agent - 1) * 2 + n_agent * 2 + (n_agent - 1) * 2) - for i in range(10): - random_action = env.random_action() - random_action = np.array([random_action[agent] for agent in random_action]) - timestep = env.step(random_action) - print(timestep) - assert isinstance(timestep.obs, np.ndarray), timestep.obs - assert timestep.obs.shape == (n_agent, 2 + 2 + (n_agent - 1) * 2 + n_agent * 2 + (n_agent - 1) * 2) - assert isinstance(timestep.done, bool), timestep.done - assert isinstance(timestep.reward, np.ndarray), timestep.reward - assert timestep.reward.dtype == np.float32 - print(env.observation_space, env.action_space, env.reward_space) - env.close() - - def test_dict_obs(self): - n_agent = 5 - n_landmark = n_agent - env = PettingZooEnv( - EasyDict( - dict( - env_family='mpe', - env_id='simple_spread_v2', - n_agent=n_agent, - n_landmark=n_landmark, - max_step=100, - agent_obs_only=False, - continuous_actions=True, - ) - ) - ) - env.seed(123) - assert env._seed == 123 - obs = env.reset() - for k, v in obs.items(): - print(k, v.shape) - for i in range(10): - random_action = env.random_action() - random_action = np.array([random_action[agent] for agent in random_action]) - timestep = env.step(random_action) - print(timestep) - assert isinstance(timestep.obs, dict), timestep.obs - assert isinstance(timestep.obs['agent_state'], np.ndarray), timestep.obs - assert timestep.obs['agent_state'].shape == ( - n_agent, 2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2 - ) - assert timestep.obs['global_state'].shape == ( - n_agent * (2 + 2) + n_landmark * 2 + n_agent * (n_agent - 1) * 2, - ) - assert timestep.obs['agent_alone_state'].shape == (n_agent, 2 + 2 + n_landmark * 2 + (n_agent - 1) * 2) - assert timestep.obs['agent_alone_padding_state'].shape == ( - n_agent, 2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2 - ) - assert timestep.obs['action_mask'].dtype == np.float32 - assert isinstance(timestep.done, bool), timestep.done - assert isinstance(timestep.reward, np.ndarray), timestep.reward - print(env.observation_space, env.action_space, env.reward_space) - env.close() - - def test_agent_specific_global_state(self): - n_agent = 5 - n_landmark = n_agent - env = PettingZooEnv( - EasyDict( - dict( - env_family='mpe', - env_id='simple_spread_v2', - n_agent=n_agent, - 
n_landmark=n_landmark, - max_step=100, - agent_obs_only=False, - agent_specific_global_state=True, - continuous_actions=True, - ) - ) - ) - env.seed(123) - assert env._seed == 123 - obs = env.reset() - for k, v in obs.items(): - print(k, v.shape) - for i in range(10): - random_action = env.random_action() - random_action = np.array([random_action[agent] for agent in random_action]) - timestep = env.step(random_action) - print(timestep) - assert isinstance(timestep.obs, dict), timestep.obs - assert isinstance(timestep.obs['agent_state'], np.ndarray), timestep.obs - assert timestep.obs['agent_state'].shape == ( - n_agent, 2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2 - ) - assert timestep.obs['global_state'].shape == ( - n_agent, 2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2 + n_agent * (2 + 2) + - n_landmark * 2 + n_agent * (n_agent - 1) * 2 - ) - assert timestep.obs['agent_alone_state'].shape == (n_agent, 2 + 2 + n_landmark * 2 + (n_agent - 1) * 2) - assert timestep.obs['agent_alone_padding_state'].shape == ( - n_agent, 2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2 - ) - assert isinstance(timestep.done, bool), timestep.done - assert isinstance(timestep.reward, np.ndarray), timestep.reward - print(env.observation_space, env.action_space, env.reward_space) - env.close() diff --git a/zoo/petting_zoo/model/__init__.py b/zoo/petting_zoo/model/__init__.py deleted file mode 100644 index 821e014a8..000000000 --- a/zoo/petting_zoo/model/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .model import PettingZooEncoder \ No newline at end of file diff --git a/zoo/petting_zoo/model/model.py b/zoo/petting_zoo/model/model.py deleted file mode 100644 index 273a423d3..000000000 --- a/zoo/petting_zoo/model/model.py +++ /dev/null @@ -1,13 +0,0 @@ -import torch.nn as nn -from ding.model.common import FCEncoder - -class PettingZooEncoder(nn.Module): - - def __init__(self): - super().__init__() - self.encoder = FCEncoder(obs_shape=6, hidden_size_list=[256, 256], activation=nn.ReLU(), norm_type=None) - - def forward(self, x): - x = x['agent_state'] - x = self.encoder(x) - return x \ No newline at end of file