diff --git a/lzero/entry/train_muzero.py b/lzero/entry/train_muzero.py index 5e3976063..f18681b53 100644 --- a/lzero/entry/train_muzero.py +++ b/lzero/entry/train_muzero.py @@ -50,9 +50,9 @@ def train_muzero( assert create_cfg.policy.type in ['efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero', 'stochastic_muzero'], \ "train_muzero entry now only support the following algo.: 'efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero'" - if create_cfg.policy.type == 'muzero': + if create_cfg.policy.type == 'muzero' or create_cfg.policy.type == 'multi_agent_muzero': from lzero.mcts import MuZeroGameBuffer as GameBuffer - elif create_cfg.policy.type == 'efficientzero': + elif create_cfg.policy.type == 'efficientzero' or create_cfg.policy.type == 'multi_agent_efficientzero': from lzero.mcts import EfficientZeroGameBuffer as GameBuffer elif create_cfg.policy.type == 'sampled_efficientzero': from lzero.mcts import SampledEfficientZeroGameBuffer as GameBuffer @@ -125,7 +125,11 @@ def train_muzero( # Exploration: Collecting random data helps the agent explore the environment and avoid getting stuck in a suboptimal policy prematurely. # Comparison: By observing the agent's performance during random action-taking, we can establish a baseline to evaluate the effectiveness of reinforcement learning algorithms. if cfg.policy.random_collect_episode_num > 0: - random_collect(cfg.policy, policy, LightZeroRandomPolicy, collector, collector_env, replay_buffer) + if policy_config.multi_agent: + from lzero.policy.multi_agent_random_policy import MultiAgentLightZeroRandomPolicy as RandomPolicy + else: + from lzero.policy.random_policy import LightZeroRandomPolicy as RandomPolicy + random_collect(cfg.policy, policy, RandomPolicy, collector, collector_env, replay_buffer) while True: log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger) @@ -192,4 +196,4 @@ def train_muzero( # Learner's after_run hook. 
learner.call_hook('after_run') - return policy + return policy \ No newline at end of file diff --git a/lzero/mcts/buffer/game_buffer_efficientzero.py b/lzero/mcts/buffer/game_buffer_efficientzero.py index 4ab12259e..cad35a658 100644 --- a/lzero/mcts/buffer/game_buffer_efficientzero.py +++ b/lzero/mcts/buffer/game_buffer_efficientzero.py @@ -9,6 +9,8 @@ from lzero.mcts.utils import prepare_observation from lzero.policy import to_detach_cpu_numpy, concat_output, concat_output_value, inverse_scalar_transform from .game_buffer_muzero import MuZeroGameBuffer +from ding.torch_utils import to_device, to_tensor, to_ndarray +from ding.utils.data import default_collate @BUFFER_REGISTRY.register('game_buffer_efficientzero') @@ -44,6 +46,8 @@ def __init__(self, cfg: dict): self.base_idx = 0 self.clear_time = 0 + self.tmp_obs = None # since value obs list [46 + 4(td_step)] >= 50(game_segment), need pad + def sample(self, batch_size: int, policy: Any) -> List[Any]: """ Overview: @@ -100,7 +104,6 @@ def _prepare_reward_value_context( - reward_value_context (:obj:`list`): value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, td_steps_list, action_mask_segment, to_play_segment """ - zero_obs = game_segment_list[0].zero_obs() value_obs_list = [] # the value is valid or not (out of trajectory) value_mask = [] @@ -148,11 +151,12 @@ def _prepare_reward_value_context( end_index = beg_index + self._cfg.model.frame_stack_num # the stacked obs in time t obs = game_obs[beg_index:end_index] + self.tmp_obs = obs # will be masked else: value_mask.append(0) - obs = zero_obs + obs = self.tmp_obs # will be masked - value_obs_list.append(obs) + value_obs_list.append(obs.tolist()) reward_value_context = [ value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, td_steps_list, @@ -196,7 +200,13 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A beg_index = self._cfg.mini_infer_size * i end_index = self._cfg.mini_infer_size * (i + 1) - m_obs = torch.from_numpy(value_obs_list[beg_index:end_index]).to(self._cfg.device).float() + if self._cfg.model.model_type and self._cfg.model.model_type in ['conv', 'mlp']: + m_obs = torch.from_numpy(value_obs_list[beg_index:end_index]).to(self._cfg.device).float() + elif self._cfg.model.model_type and self._cfg.model.model_type == 'structure': + m_obs = value_obs_list[beg_index:end_index] + m_obs = sum(m_obs, []) + m_obs = default_collate(m_obs) + m_obs = to_device(m_obs, self._cfg.device) # calculate the target value m_output = model.initial_inference(m_obs) diff --git a/lzero/mcts/buffer/game_buffer_muzero.py b/lzero/mcts/buffer/game_buffer_muzero.py index daddf6f9f..9243b0958 100644 --- a/lzero/mcts/buffer/game_buffer_muzero.py +++ b/lzero/mcts/buffer/game_buffer_muzero.py @@ -9,6 +9,8 @@ from lzero.mcts.utils import prepare_observation from lzero.policy import to_detach_cpu_numpy, concat_output, concat_output_value, inverse_scalar_transform from .game_buffer import GameBuffer +from ding.torch_utils import to_device, to_tensor +from ding.utils.data import default_collate if TYPE_CHECKING: from lzero.policy import MuZeroPolicy, EfficientZeroPolicy, SampledEfficientZeroPolicy @@ -48,6 +50,8 @@ def __init__(self, cfg: dict): self.game_pos_priorities = [] self.game_segment_game_pos_look_up = [] + self.tmp_obs = None # a tmp value which records obs when value obs list [current_index + 4(td_step)] > 50(game_segment) + def sample( self, batch_size: int, policy: Union["MuZeroPolicy", 
"EfficientZeroPolicy", "SampledEfficientZeroPolicy"] ) -> List[Any]: @@ -198,7 +202,6 @@ def _prepare_reward_value_context( - reward_value_context (:obj:`list`): value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, td_steps_list, action_mask_segment, to_play_segment """ - zero_obs = game_segment_list[0].zero_obs() value_obs_list = [] # the value is valid or not (out of game_segment) value_mask = [] @@ -238,11 +241,12 @@ def _prepare_reward_value_context( end_index = beg_index + self._cfg.model.frame_stack_num # the stacked obs in time t obs = game_obs[beg_index:end_index] + self.tmp_obs = obs # will be masked else: value_mask.append(0) - obs = zero_obs + obs = self.tmp_obs # will be masked - value_obs_list.append(obs) + value_obs_list.append(obs.tolist()) reward_value_context = [ value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, td_steps_list, @@ -376,8 +380,14 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A for i in range(slices): beg_index = self._cfg.mini_infer_size * i end_index = self._cfg.mini_infer_size * (i + 1) - - m_obs = torch.from_numpy(value_obs_list[beg_index:end_index]).to(self._cfg.device).float() + + if self._cfg.model.model_type and self._cfg.model.model_type in ['conv', 'mlp']: + m_obs = torch.from_numpy(value_obs_list[beg_index:end_index]).to(self._cfg.device).float() + elif self._cfg.model.model_type and self._cfg.model.model_type == 'structure': + m_obs = value_obs_list[beg_index:end_index] + m_obs = sum(m_obs, []) + m_obs = default_collate(m_obs) + m_obs = to_device(m_obs, self._cfg.device) # calculate the target value m_output = model.initial_inference(m_obs) diff --git a/lzero/mcts/utils.py b/lzero/mcts/utils.py index 40811fa08..ede8e57a0 100644 --- a/lzero/mcts/utils.py +++ b/lzero/mcts/utils.py @@ -92,7 +92,7 @@ def prepare_observation(observation_list, model_type='conv'): - observation_list (:obj:`List`): list of observations. - model_type (:obj:`str`): type of the model. 
(default is 'conv') """ - assert model_type in ['conv', 'mlp'] + assert model_type in ['conv', 'mlp', 'structure'] observation_array = np.array(observation_list) if model_type == 'conv': @@ -127,6 +127,9 @@ def prepare_observation(observation_list, model_type='conv'): observation_array = observation_array.reshape(observation_array.shape[0], -1) # print(observation_array.shape) + elif model_type == 'structure': + return observation_list + return observation_array diff --git a/lzero/model/efficientzero_model_mlp.py b/lzero/model/efficientzero_model_mlp.py index a491cdb75..de852350e 100644 --- a/lzero/model/efficientzero_model_mlp.py +++ b/lzero/model/efficientzero_model_mlp.py @@ -4,6 +4,8 @@ import torch.nn as nn from ding.torch_utils import MLP from ding.utils import MODEL_REGISTRY, SequenceType +from ding.utils.default_helper import get_shape0 + from numpy import ndarray from .common import EZNetworkOutput, RepresentationNetworkMLP, PredictionNetworkMLP @@ -36,6 +38,7 @@ def __init__( norm_type: Optional[str] = 'BN', discrete_action_encoding_type: str = 'one_hot', res_connection_in_dynamics: bool = False, + state_encoder=None, *args, **kwargs, ): @@ -104,9 +107,12 @@ def __init__( self.state_norm = state_norm self.res_connection_in_dynamics = res_connection_in_dynamics - self.representation_network = RepresentationNetworkMLP( - observation_shape=observation_shape, hidden_channels=latent_state_dim, norm_type=norm_type - ) + if state_encoder == None: + self.representation_network = RepresentationNetworkMLP( + observation_shape=observation_shape, hidden_channels=latent_state_dim, norm_type=norm_type + ) + else: + self.representation_network = state_encoder self.dynamics_network = DynamicsNetworkMLP( action_encoding_dim=self.action_encoding_dim, @@ -171,15 +177,16 @@ def initial_inference(self, obs: torch.Tensor) -> EZNetworkOutput: - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The shape of each element is :math:`(1, B, lstm_hidden_size)`, where B is batch_size. """ - batch_size = obs.size(0) + batch_size = get_shape0(obs) latent_state = self._representation(obs) + device = latent_state.device policy_logits, value = self._prediction(latent_state) # zero initialization for reward hidden states # (hn, cn), each element shape is (layer_num=1, batch_size, lstm_hidden_size) reward_hidden_state = ( torch.zeros(1, batch_size, - self.lstm_hidden_size).to(obs.device), torch.zeros(1, batch_size, - self.lstm_hidden_size).to(obs.device) + self.lstm_hidden_size).to(device), torch.zeros(1, batch_size, + self.lstm_hidden_size).to(device) ) return EZNetworkOutput(value, [0. 
for _ in range(batch_size)], policy_logits, latent_state, reward_hidden_state) diff --git a/lzero/model/muzero_model_mlp.py b/lzero/model/muzero_model_mlp.py index caf1df15d..cc1707909 100644 --- a/lzero/model/muzero_model_mlp.py +++ b/lzero/model/muzero_model_mlp.py @@ -4,6 +4,8 @@ import torch.nn as nn from ding.torch_utils import MLP from ding.utils import MODEL_REGISTRY, SequenceType +from ding.utils.default_helper import get_shape0 + from .common import MZNetworkOutput, RepresentationNetworkMLP, PredictionNetworkMLP from .utils import renormalize, get_params_mean, get_dynamic_mean, get_reward_mean @@ -34,6 +36,7 @@ def __init__( discrete_action_encoding_type: str = 'one_hot', norm_type: Optional[str] = 'BN', res_connection_in_dynamics: bool = False, + state_encoder=None, *args, **kwargs ): @@ -66,6 +69,7 @@ def __init__( - discrete_action_encoding_type (:obj:`str`): The encoding type of discrete action, which can be 'one_hot' or 'not_one_hot'. - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. - res_connection_in_dynamics (:obj:`bool`): Whether to use residual connection for dynamics network, default set it to False. + - state_encoder (:obj:`Optional[nn.Module]`): The state encoder network, which is used to encode the raw observation to latent state. """ super(MuZeroModelMLP, self).__init__() self.categorical_distribution = categorical_distribution @@ -101,9 +105,12 @@ def __init__( self.state_norm = state_norm self.res_connection_in_dynamics = res_connection_in_dynamics - self.representation_network = RepresentationNetworkMLP( - observation_shape=observation_shape, hidden_channels=self.latent_state_dim, norm_type=norm_type - ) + if state_encoder == None: + self.representation_network = RepresentationNetworkMLP( + observation_shape=observation_shape, hidden_channels=latent_state_dim, norm_type=norm_type + ) + else: + self.representation_network = state_encoder self.dynamics_network = DynamicsNetwork( action_encoding_dim=self.action_encoding_dim, @@ -166,7 +173,7 @@ def initial_inference(self, obs: torch.Tensor) -> MZNetworkOutput: - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. """ - batch_size = obs.size(0) + batch_size = get_shape0(obs) latent_state = self._representation(obs) policy_logits, value = self._prediction(latent_state) return MZNetworkOutput( diff --git a/lzero/policy/efficientzero.py b/lzero/policy/efficientzero.py index 685e473d7..67a720f77 100644 --- a/lzero/policy/efficientzero.py +++ b/lzero/policy/efficientzero.py @@ -17,7 +17,13 @@ DiscreteSupport, select_action, to_torch_float_tensor, ez_network_output_unpack, negative_cosine_similarity, \ prepare_obs, \ configure_optimizers +from collections import defaultdict +from ding.torch_utils import to_device, to_tensor +from ding.utils.data import default_collate from lzero.policy.muzero import MuZeroPolicy @POLICY_REGISTRY.register('efficientzero') @@ -191,6 +197,9 @@ class EfficientZeroPolicy(MuZeroPolicy): # (int) The decay steps from start to end eps. decay=int(1e5), ), + + # (bool) Whether it is a multi-agent environment.
+ multi_agent=False, ) def default_model(self) -> Tuple[str, List[str]]: @@ -309,7 +318,7 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: target_value_prefix = target_value_prefix.view(self._cfg.batch_size, -1) target_value = target_value.view(self._cfg.batch_size, -1) - assert obs_batch.size(0) == self._cfg.batch_size == target_value_prefix.size(0) + assert self._cfg.batch_size == target_value_prefix.size(0) # ``scalar_transform`` to transform the original value to the scaled value, # i.e. h(.) function in paper https://arxiv.org/pdf/1805.11593.pdf. @@ -395,9 +404,40 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: # calculate consistency loss for the next ``num_unroll_steps`` unroll steps. # ============================================================== if self._cfg.ssl_loss_weight > 0: - # obtain the oracle latent states from representation function. - beg_index, end_index = self._get_target_obs_index_in_step_k(step_k) - network_output = self._learn_model.initial_inference(obs_target_batch[:, beg_index:end_index]) + # obtain the oracle hidden states from representation function. + if self._cfg.model.model_type == 'conv': + beg_index = self._cfg.model.image_channel * step_k + end_index = self._cfg.model.image_channel * (step_k + self._cfg.model.frame_stack_num) + network_output = self._learn_model.initial_inference(obs_target_batch[:, beg_index:end_index, :, :]) + elif self._cfg.model.model_type == 'mlp': + beg_index = self._cfg.model.observation_shape * step_k + end_index = self._cfg.model.observation_shape * (step_k + self._cfg.model.frame_stack_num) + network_output = self._learn_model.initial_inference(obs_target_batch[:, beg_index:end_index]) + elif self._cfg.model.model_type == 'structure': + obs_target_batch_new = {} + for k, v in obs_target_batch.items(): + if k == 'action_mask': + obs_target_batch_new[k] = v + continue + if isinstance(v, dict): + obs_target_batch_new[k] = {} + for k1, v1 in v.items(): + if len(v1.shape) == 1: + observation_shape = v1.shape[0]//self._cfg.num_unroll_steps + beg_index = observation_shape * step_k + end_index = observation_shape * (step_k + self._cfg.model.frame_stack_num) + obs_target_batch_new[k][k1] = v1[beg_index:end_index] + else: + observation_shape = v1.shape[1]//self._cfg.num_unroll_steps + beg_index = observation_shape * step_k + end_index = observation_shape * (step_k + self._cfg.model.frame_stack_num) + obs_target_batch_new[k][k1] = v1[:, beg_index:end_index] + else: + observation_shape = v.shape[1]//self._cfg.num_unroll_steps + beg_index = observation_shape * step_k + end_index = observation_shape * (step_k + self._cfg.model.frame_stack_num) + obs_target_batch_new[k] = v[:, beg_index:end_index] + network_output = self._learn_model.initial_inference(obs_target_batch_new) latent_state = to_tensor(latent_state) representation_state = to_tensor(network_output.latent_state) @@ -735,6 +775,7 @@ def _monitor_vars_learn(self) -> List[str]: """ return [ 'collect_mcts_temperature', + 'collect_epsilon', 'cur_lr', 'weighted_total_loss', 'total_loss', diff --git a/lzero/policy/multi_agent_efficientzero.py b/lzero/policy/multi_agent_efficientzero.py new file mode 100644 index 000000000..49426816e --- /dev/null +++ b/lzero/policy/multi_agent_efficientzero.py @@ -0,0 +1,225 @@ +from typing import List, Dict, Any, Tuple, Union + +import numpy as np +import torch +from .efficientzero import EfficientZeroPolicy +from ding.utils import POLICY_REGISTRY + +from lzero.mcts import
EfficientZeroMCTSCtree as MCTSCtree +from lzero.mcts import EfficientZeroMCTSPtree as MCTSPtree +from lzero.policy import scalar_transform, InverseScalarTransform, cross_entropy_loss, phi_transform, \ + DiscreteSupport, select_action, to_torch_float_tensor, ez_network_output_unpack, negative_cosine_similarity, prepare_obs, \ + configure_optimizers +from collections import defaultdict +from ding.utils.data import default_collate +from ding.torch_utils import to_device, to_tensor + + +@POLICY_REGISTRY.register('multi_agent_efficientzero') +class MultiAgentEfficientZeroPolicy(EfficientZeroPolicy): + """ + Overview: + The policy class for Multi Agent EfficientZero. + Independent Learning mode is a method in which each agent learns and adapts to the environment independently \ + without directly considering the learning and strategies of other agents. + """ + + def _forward_collect( + self, + data: torch.Tensor, + action_mask: list = None, + temperature: float = 1, + to_play: List = [-1], + epsilon: float = 0.25, + ready_env_id = None + ): + """ + Overview: + The forward function for collecting data in collect mode. Use model to execute MCTS search. + Choosing the action through sampling during the collect mode. + Arguments: + - data (:obj:`torch.Tensor`): The input data, i.e. the observation. + - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. + - temperature (:obj:`float`): The temperature of the policy. + - to_play (:obj:`int`): The player to play. + - ready_env_id (:obj:`list`): The id of the env that is ready to collect. + Shape: + - data (:obj:`torch.Tensor`): + - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ + S is the number of stacked frames, H is the height of the image, W is the width of the image. + - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size. + - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env. + - temperature: :math:`(1, )`. + - to_play: :math:`(N, 1)`, where N is the number of collect_env. + - ready_env_id: None + Returns: + - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ + ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. + """ + self._collect_model.eval() + self.collect_mcts_temperature = temperature + self.collect_epsilon = epsilon + + active_collect_env_num = len(data) + data = sum(sum(data, []), []) + batch_size = len(data) + data = default_collate(data) + data = to_device(data, self._device) + agent_num = self._cfg['model']['agent_num'] + action_mask = sum(action_mask, []) + to_play = np.array(to_play).reshape(-1).tolist() + + with torch.no_grad(): + # data shape [B, S x C, W, H], e.g. 
{Tensor:(B, 12, 96, 96)} + network_output = self._collect_model.initial_inference(data) + latent_state_roots, value_prefix_roots, reward_hidden_state_roots, pred_values, policy_logits = ez_network_output_unpack( + network_output + ) + + pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() + latent_state_roots = latent_state_roots.detach().cpu().numpy() + reward_hidden_state_roots = ( + reward_hidden_state_roots[0].detach().cpu().numpy(), + reward_hidden_state_roots[1].detach().cpu().numpy() + ) + policy_logits = policy_logits.detach().cpu().numpy().tolist() + + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(batch_size)] + # the only difference between collect and eval is the dirichlet noise. + noises = [ + np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j])) + ).astype(np.float32).tolist() for j in range(batch_size) + ] + if self._cfg.mcts_ctree: + # cpp mcts_tree + roots = MCTSCtree.roots(batch_size, legal_actions) + else: + # python mcts_tree + roots = MCTSPtree.roots(batch_size, legal_actions) + roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_roots, policy_logits, to_play) + self._mcts_collect.search( + roots, self._collect_model, latent_state_roots, reward_hidden_state_roots, to_play + ) + + roots_visit_count_distributions = roots.get_distributions( + ) # shape: ``{list: batch_size} ->{list: action_space_size}`` + roots_values = roots.get_values() # shape: {list: batch_size} + + data_id = [i for i in range(active_collect_env_num)] + output = {i: defaultdict(list) for i in data_id} + if ready_env_id is None: + ready_env_id = np.arange(active_collect_env_num) + + for i in range(batch_size): + distributions, value = roots_visit_count_distributions[i], roots_values[i] + if self._cfg.eps.eps_greedy_exploration_in_collect: + # eps-greedy collect + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=self.collect_mcts_temperature, deterministic=True + ) + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + if np.random.rand() < self.collect_epsilon: + action = np.random.choice(legal_actions[i]) + else: + # normal collect + # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents + # the index within the legal action set, rather than the index in the entire action set. + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=self.collect_mcts_temperature, deterministic=False + ) + # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the entire action set. + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + output[i // agent_num]['action'].append(action) + output[i // agent_num]['distributions'].append(distributions) + output[i // agent_num]['visit_count_distribution_entropy'].append(visit_count_distribution_entropy) + output[i // agent_num]['value'].append(value) + output[i // agent_num]['pred_value'].append(pred_values[i]) + output[i // agent_num]['policy_logits'].append(policy_logits[i]) + + return output + + def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, ready_env_id=None): + """ + Overview: + The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search. + Choosing the action with the highest value (argmax) rather than sampling during the eval mode. 
+ Arguments: + - data (:obj:`torch.Tensor`): The input data, i.e. the observation. + - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. + - to_play (:obj:`int`): The player to play. + - ready_env_id (:obj:`list`): The id of the env that is ready to collect. + Shape: + - data (:obj:`torch.Tensor`): + - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ + S is the number of stacked frames, H is the height of the image, W is the width of the image. + - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size. + - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env. + - to_play: :math:`(N, 1)`, where N is the number of collect_env. + - ready_env_id: None + Returns: + - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ + ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. + """ + self._eval_model.eval() + active_eval_env_num = len(data) + data = sum(sum(data, []), []) + batch_size = len(data) + data = default_collate(data) + data = to_device(data, self._device) + agent_num = self._cfg['model']['agent_num'] + action_mask = sum(action_mask, []) + to_play = np.array(to_play).reshape(-1).tolist() + + with torch.no_grad(): + # data shape [B, S x C, W, H], e.g. {Tensor:(B, 12, 96, 96)} + network_output = self._eval_model.initial_inference(data) + latent_state_roots, value_prefix_roots, reward_hidden_state_roots, pred_values, policy_logits = ez_network_output_unpack( + network_output + ) + + pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape(B, 1) + latent_state_roots = latent_state_roots.detach().cpu().numpy() + reward_hidden_state_roots = ( + reward_hidden_state_roots[0].detach().cpu().numpy(), + reward_hidden_state_roots[1].detach().cpu().numpy() + ) + policy_logits = policy_logits.detach().cpu().numpy().tolist() # list shape(B, A) + + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(batch_size)] + if self._cfg.mcts_ctree: + # cpp mcts_tree + roots = MCTSCtree.roots(batch_size, legal_actions) + else: + # python mcts_tree + roots = MCTSPtree.roots(batch_size, legal_actions) + roots.prepare_no_noise(value_prefix_roots, policy_logits, to_play) + self._mcts_eval.search(roots, self._eval_model, latent_state_roots, reward_hidden_state_roots, to_play) + + roots_visit_count_distributions = roots.get_distributions( + ) # shape: ``{list: batch_size} ->{list: action_space_size}`` + roots_values = roots.get_values() # shape: {list: batch_size} + data_id = [i for i in range(active_eval_env_num)] + output = {i: defaultdict(list) for i in data_id} + + if ready_env_id is None: + ready_env_id = np.arange(active_eval_env_num) + + for i in range(batch_size): + distributions, value = roots_visit_count_distributions[i], roots_values[i] + # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents + # the index within the legal action set, rather than the index in the entire action set. + # Setting deterministic=True implies choosing the action with the highest value (argmax) rather than sampling during the evaluation phase. 
+ action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=1, deterministic=True + ) + # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the entire action set. + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + # Save according to agent dimension + output[i // agent_num]['action'].append(action) + output[i // agent_num]['distributions'].append(distributions) + output[i // agent_num]['visit_count_distribution_entropy'].append(visit_count_distribution_entropy) + output[i // agent_num]['value'].append(value) + output[i // agent_num]['pred_value'].append(pred_values[i]) + output[i // agent_num]['policy_logits'].append(policy_logits[i]) + + return output \ No newline at end of file diff --git a/lzero/policy/multi_agent_muzero.py b/lzero/policy/multi_agent_muzero.py new file mode 100644 index 000000000..65fd3e2cf --- /dev/null +++ b/lzero/policy/multi_agent_muzero.py @@ -0,0 +1,221 @@ +from typing import List, Dict, Any, Tuple, Union + +import numpy as np +import torch +import torch.optim as optim +from ding.model import model_wrap +from ding.torch_utils import to_tensor +from ding.utils import POLICY_REGISTRY +from torch.nn import L1Loss + +from lzero.mcts import MuZeroMCTSCtree as MCTSCtree +from lzero.mcts import MuZeroMCTSPtree as MCTSPtree +from lzero.policy import scalar_transform, InverseScalarTransform, cross_entropy_loss, phi_transform, \ + DiscreteSupport, to_torch_float_tensor, mz_network_output_unpack, select_action, negative_cosine_similarity, prepare_obs, \ + configure_optimizers + +from collections import defaultdict +from ding.utils.data import default_collate +from .muzero import MuZeroPolicy +from ding.torch_utils import to_device, to_tensor + + +@POLICY_REGISTRY.register('multi_agent_muzero') +class MultiAgentMuZeroPolicy(MuZeroPolicy): + """ + Overview: + The policy class for Multi Agent MuZero. + Independent Learning mode is a method in which each agent learns and adapts to the environment independently \ + without directly considering the learning and strategies of other agents. + """ + + def _forward_collect( + self, + data: torch.Tensor, + action_mask: list = None, + temperature: float = 1, + to_play: List = [-1], + epsilon: float = 0.25, + ready_env_id = None + ) -> Dict: + """ + Overview: + The forward function for collecting data in collect mode. Use model to execute MCTS search. + Choosing the action through sampling during the collect mode. + Arguments: + - data (:obj:`torch.Tensor`): The input data, i.e. the observation. + - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. + - temperature (:obj:`float`): The temperature of the policy. + - to_play (:obj:`int`): The player to play. + - ready_env_id (:obj:`list`): The id of the env that is ready to collect. + Shape: + - data (:obj:`torch.Tensor`): + - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ + S is the number of stacked frames, H is the height of the image, W is the width of the image. + - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size. + - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env. + - temperature: :math:`(1, )`. + - to_play: :math:`(N, 1)`, where N is the number of collect_env. 
+ - ready_env_id: None + Returns: + - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ + ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. + """ + self._collect_model.eval() + self.collect_mcts_temperature = temperature + self.collect_epsilon = epsilon + active_collect_env_num = len(data) + data = sum(sum(data, []), []) + batch_size = len(data) + data = default_collate(data) + data = to_device(data, self._device) + agent_num = self._cfg['model']['agent_num'] + action_mask = sum(action_mask, []) + to_play = np.array(to_play).reshape(-1).tolist() + + with torch.no_grad(): + # data shape [B, S x C, W, H], e.g. {Tensor:(B, 12, 96, 96)} + network_output = self._collect_model.initial_inference(data) + latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) + + pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() + latent_state_roots = latent_state_roots.detach().cpu().numpy() + policy_logits = policy_logits.detach().cpu().numpy().tolist() + + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(batch_size)] + # the only difference between collect and eval is the dirichlet noise + noises = [ + np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j])) + ).astype(np.float32).tolist() for j in range(batch_size) + ] + if self._cfg.mcts_ctree: + # cpp mcts_tree + roots = MCTSCtree.roots(batch_size, legal_actions) + else: + # python mcts_tree + roots = MCTSPtree.roots(batch_size, legal_actions) + + roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits, to_play) + self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play) + + roots_visit_count_distributions = roots.get_distributions( + ) # shape: ``{list: batch_size} ->{list: action_space_size}`` + roots_values = roots.get_values() # shape: {list: batch_size} + + data_id = [i for i in range(active_collect_env_num)] + output = {i: defaultdict(list) for i in data_id} + + if ready_env_id is None: + ready_env_id = np.arange(active_collect_env_num) + + for i in range(batch_size): + distributions, value = roots_visit_count_distributions[i], roots_values[i] + if self._cfg.eps.eps_greedy_exploration_in_collect: + # eps greedy collect + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=self.collect_mcts_temperature, deterministic=True + ) + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + if np.random.rand() < self.collect_epsilon: + action = np.random.choice(legal_actions[i]) + else: + # normal collect + # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents + # the index within the legal action set, rather than the index in the entire action set. + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=self.collect_mcts_temperature, deterministic=False + ) + # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the entire action set. 
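+ # For example (hypothetical values): if action_mask[i] = [0, 1, 0, 1], the legal action set is [1, 3], + # so action_index_in_legal_action_set = 1 maps back to action 3 in the full action space.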
+ action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + output[i // agent_num]['action'].append(action) + output[i // agent_num]['distributions'].append(distributions) + output[i // agent_num]['visit_count_distribution_entropy'].append(visit_count_distribution_entropy) + output[i // agent_num]['value'].append(value) + output[i // agent_num]['pred_value'].append(pred_values[i]) + output[i // agent_num]['policy_logits'].append(policy_logits[i]) + + return output + + def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, ready_env_id=None) -> Dict: + """ + Overview: + The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search. + Choosing the action with the highest value (argmax) rather than sampling during the eval mode. + Arguments: + - data (:obj:`torch.Tensor`): The input data, i.e. the observation. + - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. + - to_play (:obj:`int`): The player to play. + - ready_env_id (:obj:`list`): The id of the env that is ready to collect. + Shape: + - data (:obj:`torch.Tensor`): + - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ + S is the number of stacked frames, H is the height of the image, W is the width of the image. + - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size. + - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env. + - to_play: :math:`(N, 1)`, where N is the number of collect_env. + - ready_env_id: None + Returns: + - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ + ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. + """ + self._eval_model.eval() + active_eval_env_num = len(data) + data = sum(sum(data, []), []) + batch_size = len(data) + data = default_collate(data) + data = to_device(data, self._device) + agent_num = self._cfg['model']['agent_num'] + action_mask = sum(action_mask, []) + to_play = np.array(to_play).reshape(-1).tolist() + + with torch.no_grad(): + # data shape [B, S x C, W, H], e.g. 
{Tensor:(B, 12, 96, 96)} + network_output = self._collect_model.initial_inference(data) + latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) + + if not self._eval_model.training: + # if not in training, obtain the scalars of the value/reward + pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape(B, 1) + latent_state_roots = latent_state_roots.detach().cpu().numpy() + policy_logits = policy_logits.detach().cpu().numpy().tolist() # list shape(B, A) + + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(batch_size)] + if self._cfg.mcts_ctree: + # cpp mcts_tree + roots = MCTSCtree.roots(batch_size, legal_actions) + else: + # python mcts_tree + roots = MCTSPtree.roots(batch_size, legal_actions) + roots.prepare_no_noise(reward_roots, policy_logits, to_play) + self._mcts_eval.search(roots, self._eval_model, latent_state_roots, to_play) + + roots_visit_count_distributions = roots.get_distributions( + ) # shape: ``{list: batch_size} ->{list: action_space_size}`` + roots_values = roots.get_values() # shape: {list: batch_size} + + data_id = [i for i in range(active_eval_env_num)] + output = {i: defaultdict(list) for i in data_id} + + if ready_env_id is None: + ready_env_id = np.arange(active_eval_env_num) + + for i in range(batch_size): + distributions, value = roots_visit_count_distributions[i], roots_values[i] + # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents + # the index within the legal action set, rather than the index in the entire action set. + # Setting deterministic=True implies choosing the action with the highest value (argmax) rather than + # sampling during the evaluation phase. + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=1, deterministic=True + ) + # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the + # entire action set. + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + output[i // agent_num]['action'].append(action) + output[i // agent_num]['distributions'].append(distributions) + output[i // agent_num]['visit_count_distribution_entropy'].append(visit_count_distribution_entropy) + output[i // agent_num]['value'].append(value) + output[i // agent_num]['pred_value'].append(pred_values[i]) + output[i // agent_num]['policy_logits'].append(policy_logits[i]) + + return output diff --git a/lzero/policy/multi_agent_random_policy.py b/lzero/policy/multi_agent_random_policy.py new file mode 100644 index 000000000..8d3e7083f --- /dev/null +++ b/lzero/policy/multi_agent_random_policy.py @@ -0,0 +1,138 @@ +from typing import List, Dict, Any, Tuple, Union + +import numpy as np +import torch +from ding.policy.base_policy import Policy +from ding.utils import POLICY_REGISTRY + +from lzero.policy import InverseScalarTransform, select_action, ez_network_output_unpack, mz_network_output_unpack +from .random_policy import LightZeroRandomPolicy +from collections import defaultdict +from ding.torch_utils import to_device, to_tensor +from ding.utils.data import default_collate + + +@POLICY_REGISTRY.register('multi_agent_lightzero_random_policy') +class MultiAgentLightZeroRandomPolicy(LightZeroRandomPolicy): + """ + Overview: + The policy class for Multi Agent LightZero Random Policy. 
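+ Note (assumed data layout): observations arrive as nested lists indexed by env and agent; \ + ``_forward_collect`` flattens them with ``sum(sum(data, []), [])`` and collates them into one batch, \ + so MCTS root ``i`` belongs to env ``i // agent_num`` and the outputs are regrouped per env on return.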
+ """ + + def _forward_collect( + self, + data: torch.Tensor, + action_mask: list = None, + temperature: float = 1, + to_play: List = [-1], + epsilon: float = 0.25, + ready_env_id = None + ): + """ + Overview: + The forward function for collecting data in collect mode. Use model to execute MCTS search. + Choosing the action through sampling during the collect mode. + Arguments: + - data (:obj:`torch.Tensor`): The input data, i.e. the observation. + - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. + - temperature (:obj:`float`): The temperature of the policy. + - to_play (:obj:`int`): The player to play. + - ready_env_id (:obj:`list`): The id of the env that is ready to collect. + Shape: + - data (:obj:`torch.Tensor`): + - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ + S is the number of stacked frames, H is the height of the image, W is the width of the image. + - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size. + - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env. + - temperature: :math:`(1, )`. + - to_play: :math:`(N, 1)`, where N is the number of collect_env. + - ready_env_id: None + Returns: + - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ + ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. + """ + self._collect_model.eval() + self.collect_mcts_temperature = temperature + self.collect_epsilon = epsilon + + active_collect_env_num = len(data) + data = to_tensor(data) + data = sum(sum(data, []), []) + batch_size = len(data) + data = default_collate(data) + agent_num = batch_size // active_collect_env_num + to_play = np.array(to_play).reshape(-1).tolist() + + with torch.no_grad(): + # data shape [B, S x C, W, H], e.g. {Tensor:(B, 12, 96, 96)} + network_output = self._collect_model.initial_inference(data) + if 'efficientzero' in self._cfg.type: # efficientzero or multi_agent_efficientzero + latent_state_roots, value_prefix_roots, reward_hidden_state_roots, pred_values, policy_logits = ez_network_output_unpack( + network_output + ) + elif 'muzero' in self._cfg.type: # muzero or multi_agent_muzero + latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) + else: + raise NotImplementedError("need to implement pipeline: {}".format(self._cfg.type)) + + pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() + latent_state_roots = latent_state_roots.detach().cpu().numpy() + if 'efficientzero' in self._cfg.type: + reward_hidden_state_roots = ( + reward_hidden_state_roots[0].detach().cpu().numpy(), + reward_hidden_state_roots[1].detach().cpu().numpy() + ) + policy_logits = policy_logits.detach().cpu().numpy().tolist() + + action_mask = sum(action_mask, []) + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(batch_size)] + # the only difference between collect and eval is the dirichlet noise. 
+ noises = [ + np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j])) + ).astype(np.float32).tolist() for j in range(batch_size) + ] + if self._cfg.mcts_ctree: + # cpp mcts_tree + roots = self.MCTSCtree.roots(batch_size, legal_actions) + else: + # python mcts_tree + roots = self.MCTSPtree.roots(batch_size, legal_actions) + if 'efficientzero' in self._cfg.type: # efficientzero or multi_agent_efficientzero + roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_roots, policy_logits, to_play) + self._mcts_collect.search( + roots, self._collect_model, latent_state_roots, reward_hidden_state_roots, to_play + ) + elif 'muzero' in self._cfg.type: # muzero or multi_agent_muzero + roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits, to_play) + self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play) + else: + raise NotImplementedError("need to implement pipeline: {}".format(self._cfg.type)) + + roots_visit_count_distributions = roots.get_distributions( + ) # shape: ``{list: batch_size} ->{list: action_space_size}`` + roots_values = roots.get_values() # shape: {list: batch_size} + + data_id = [i for i in range(active_collect_env_num)] + output = {i: defaultdict(list) for i in data_id} + if ready_env_id is None: + ready_env_id = np.arange(active_collect_env_num) + + for i in range(batch_size): + distributions, value = roots_visit_count_distributions[i], roots_values[i] + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=self.collect_mcts_temperature, deterministic=False + ) + + # ****** sample a random action from the legal action set ******** + # all items except action are formally obtained from MCTS + random_action = int(np.random.choice(legal_actions[i], 1)) + # **************************************************************** + + output[i // agent_num]['action'].append(random_action) + output[i // agent_num]['distributions'].append(distributions) + output[i // agent_num]['visit_count_distribution_entropy'].append(visit_count_distribution_entropy) + output[i // agent_num]['value'].append(value) + output[i // agent_num]['pred_value'].append(pred_values[i]) + output[i // agent_num]['policy_logits'].append(policy_logits[i]) + + return output diff --git a/lzero/policy/muzero.py b/lzero/policy/muzero.py index b1c0cbaf6..fc8897598 100644 --- a/lzero/policy/muzero.py +++ b/lzero/policy/muzero.py @@ -15,8 +15,16 @@ from lzero.mcts import MuZeroMCTSPtree as MCTSPtree from lzero.model import ImageTransforms from lzero.policy import scalar_transform, InverseScalarTransform, cross_entropy_loss, phi_transform, \ - DiscreteSupport, to_torch_float_tensor, mz_network_output_unpack, select_action, negative_cosine_similarity, \ - prepare_obs + DiscreteSupport, to_torch_float_tensor, mz_network_output_unpack, select_action, negative_cosine_similarity, prepare_obs, \ + configure_optimizers +from collections import defaultdict +from ding.torch_utils import to_device +from ding.utils.data import default_collate @POLICY_REGISTRY.register('muzero') @@ -195,6 +203,9 @@ class MuZeroPolicy(Policy): # (int) The decay steps from start to end eps. decay=int(1e5), ), + + # (bool) Whether it is a multi-agent environment.
+ multi_agent=False, ) def default_model(self) -> Tuple[str, List[str]]: @@ -323,7 +334,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in target_reward = target_reward.view(self._cfg.batch_size, -1) target_value = target_value.view(self._cfg.batch_size, -1) - assert obs_batch.size(0) == self._cfg.batch_size == target_reward.size(0) + assert self._cfg.batch_size == target_reward.size(0) # ``scalar_transform`` to transform the original value to the scaled value, # i.e. h(.) function in paper https://arxiv.org/pdf/1805.11593.pdf. @@ -392,9 +403,28 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in # calculate consistency loss for the next ``num_unroll_steps`` unroll steps. # ============================================================== if self._cfg.ssl_loss_weight > 0: - # obtain the oracle latent states from representation function. - beg_index, end_index = self._get_target_obs_index_in_step_k(step_k) - network_output = self._learn_model.initial_inference(obs_target_batch[:, beg_index:end_index]) + # obtain the oracle hidden states from representation function. + if self._cfg.model.model_type == 'conv': + beg_index = self._cfg.model.image_channel * step_k + end_index = self._cfg.model.image_channel * (step_k + self._cfg.model.frame_stack_num) + network_output = self._learn_model.initial_inference( + obs_target_batch[:, beg_index:end_index, :, :] + ) + elif self._cfg.model.model_type == 'mlp': + beg_index = self._cfg.model.observation_shape * step_k + end_index = self._cfg.model.observation_shape * (step_k + self._cfg.model.frame_stack_num) + network_output = self._learn_model.initial_inference(obs_target_batch[:, beg_index:end_index]) + elif self._cfg.model.model_type == 'structure': + obs_target_batch_new = {} + for k, v in obs_target_batch.items(): + if k == 'action_mask': + obs_target_batch_new[k] = v + continue + observation_shape = v.shape[1]//self._cfg.num_unroll_steps + beg_index = observation_shape * step_k + end_index = observation_shape * (step_k + self._cfg.model.frame_stack_num) + obs_target_batch_new[k] = v[:, beg_index:end_index] + network_output = self._learn_model.initial_inference(obs_target_batch_new) latent_state = to_tensor(latent_state) representation_state = to_tensor(network_output.latent_state) @@ -733,6 +763,7 @@ def _monitor_vars_learn(self) -> List[str]: """ return [ 'collect_mcts_temperature', + 'collect_epsilon', 'cur_lr', 'weighted_total_loss', 'total_loss', diff --git a/lzero/policy/random_policy.py b/lzero/policy/random_policy.py index 735a4122d..4165a60bc 100644 --- a/lzero/policy/random_policy.py +++ b/lzero/policy/random_policy.py @@ -22,10 +22,10 @@ def __init__( enable_field: Optional[List[str]] = None, action_space: Any = None, ): - if cfg.type == 'muzero': + if 'muzero' in cfg.type: from lzero.mcts import MuZeroMCTSCtree as MCTSCtree from lzero.mcts import MuZeroMCTSPtree as MCTSPtree - elif cfg.type == 'efficientzero': + elif 'efficientzero' in cfg.type: from lzero.mcts import EfficientZeroMCTSCtree as MCTSCtree from lzero.mcts import EfficientZeroMCTSPtree as MCTSPtree elif cfg.type == 'sampled_efficientzero': @@ -68,6 +68,13 @@ def default_model(self) -> Tuple[str, List[str]]: return 'SampledEfficientZeroModelMLP', ['lzero.model.sampled_efficientzero_model_mlp'] else: raise NotImplementedError("need to implement pipeline: {}".format(self._cfg.type)) + elif self._cfg.model.model_type == 'structure': + if 'efficientzero' in self._cfg.type: # efficientzero or multi_agent_efficientzero
+ return 'EfficientZeroModelStructure', ['lzero.model.efficientzero_model_structure'] + elif 'muzero' in self._cfg.type: # muzero or multi_agent_muzero + return 'MuZeroModelStructure', ['lzero.model.muzero_model_structure'] + else: + raise NotImplementedError("need to implement pipeline: {}".format(self._cfg.type)) def _init_collect(self) -> None: """ diff --git a/lzero/policy/sampled_efficientzero.py b/lzero/policy/sampled_efficientzero.py index 184ab8c42..dec81f546 100644 --- a/lzero/policy/sampled_efficientzero.py +++ b/lzero/policy/sampled_efficientzero.py @@ -1155,4 +1155,4 @@ def _process_transition(self, obs, policy_output, timestep): def _get_train_sample(self, data): # be compatible with DI-engine Policy class - pass + pass \ No newline at end of file diff --git a/lzero/policy/utils.py b/lzero/policy/utils.py index 1323fbf89..1e984c29c 100644 --- a/lzero/policy/utils.py +++ b/lzero/policy/utils.py @@ -9,6 +9,8 @@ from easydict import EasyDict from scipy.stats import entropy from torch.nn import functional as F +from ding.torch_utils import to_device, to_tensor +from ding.utils.data import default_collate def pad_and_get_lengths(inputs, num_of_sampled_actions): @@ -325,6 +327,7 @@ def prepare_obs(obs_batch_ori: np.ndarray, cfg: EasyDict) -> Tuple[torch.Tensor, # ``obs_target_batch`` is only used for calculate consistency loss, which take the all obs other than # timestep t1, and is only performed in the last 8 timesteps in the second dim in ``obs_batch_ori``. obs_target_batch = obs_batch_ori[:, cfg.model.image_channel:, :, :] + return obs_batch, obs_target_batch elif cfg.model.model_type == 'mlp': # for 1-dimensional vector obs """ @@ -348,8 +351,50 @@ def prepare_obs(obs_batch_ori: np.ndarray, cfg: EasyDict) -> Tuple[torch.Tensor, # ``obs_target_batch`` is only used for calculate consistency loss, which take the all obs other than # timestep t1, and is only performed in the last 8 timesteps in the second dim in ``obs_batch_ori``. 
obs_target_batch = obs_batch_ori[:, cfg.model.observation_shape:] - - return obs_batch, obs_target_batch + return obs_batch, obs_target_batch + elif cfg.model.model_type == 'structure': + # dict obs_shape = 1 + batch_size = obs_batch_ori.shape[0] + obs_batch = obs_batch_ori[:, 0:cfg.model.frame_stack_num] + if cfg.model.self_supervised_learning_loss: + obs_target_batch = obs_batch_ori[:, cfg.model.frame_stack_num:] + + # obs_batch + obs_batch = obs_batch.tolist() + obs_batch = sum(obs_batch, []) + obs_batch = default_collate(obs_batch) + obs_batch_new = {} + for k, v in obs_batch.items(): + if isinstance(v, dict): # espaecially for gobigger obs, { {'k':{'k1':[], 'k2':[]},} + obs_batch_new[k] = {} + for k1, v1 in v.items(): + if len(v1.shape) == 1: + obs_batch_new[k][k1] = v1 + else: + obs_batch_new[k][k1] = v1.reshape(batch_size, -1) + else: # espaecially for ptz obs, {'k1':[], 'k2':[]} + obs_batch_new[k] = v.reshape(batch_size, -1) + obs_batch_new = to_device(obs_batch_new, device=cfg.device) + + # obs_target_batch + obs_target_batch_new = None + if cfg.model.self_supervised_learning_loss: + obs_target_batch = obs_target_batch.tolist() + obs_target_batch = sum(obs_target_batch, []) + obs_target_batch = default_collate(obs_target_batch) + obs_target_batch_new = {} + for k, v in obs_target_batch.items(): + if isinstance(v, dict): # espaecially for gobigger obs, { {'k':{'k1':[], 'k2':[]},} + obs_target_batch_new[k] = {} + for k1, v1 in v.items(): + if len(v1.shape) == 1: + obs_target_batch_new[k][k1] = v1 + else: + obs_target_batch_new[k][k1] = v1.reshape(batch_size, -1) + else: # espaecially for ptz obs, {'k1':[], 'k2':[]} + obs_target_batch_new[k] = v.reshape(batch_size, -1) + obs_target_batch_new = to_device(obs_target_batch_new, device=cfg.device) + return obs_batch_new, obs_target_batch_new def negative_cosine_similarity(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: diff --git a/lzero/worker/__init__.py b/lzero/worker/__init__.py index 3000ee23f..7e81f3e72 100644 --- a/lzero/worker/__init__.py +++ b/lzero/worker/__init__.py @@ -1,4 +1,5 @@ from .alphazero_collector import AlphaZeroCollector from .alphazero_evaluator import AlphaZeroEvaluator from .muzero_collector import MuZeroCollector -from .muzero_evaluator import MuZeroEvaluator \ No newline at end of file +from .muzero_evaluator import MuZeroEvaluator +from .gobigger_muzero_evaluator import GoBiggerMuZeroEvaluator \ No newline at end of file diff --git a/lzero/worker/gobigger_muzero_evaluator.py b/lzero/worker/gobigger_muzero_evaluator.py new file mode 100644 index 000000000..1694c8880 --- /dev/null +++ b/lzero/worker/gobigger_muzero_evaluator.py @@ -0,0 +1,421 @@ +import copy +import time +from collections import namedtuple +from typing import Optional, Callable, Tuple, Any, List, Dict + +import numpy as np +import torch +from ding.envs import BaseEnvManager +from ding.torch_utils import to_ndarray, to_item +from ding.utils import build_logger, EasyTimer +from ding.utils import get_world_size, get_rank, broadcast_object_list +from ding.worker.collector.base_serial_evaluator import ISerialEvaluator, VectorEvalMonitor +from easydict import EasyDict + +from lzero.mcts.buffer.game_segment import GameSegment +from lzero.mcts.utils import prepare_observation +from collections import defaultdict + +from zoo.gobigger.env.gobigger_rule_bot import GoBiggerBot +from collections import namedtuple, deque +from .muzero_evaluator import MuZeroEvaluator + + +class GoBiggerMuZeroEvaluator(MuZeroEvaluator): + + def _add_info(self, 
last_timestep, info): + # add eat info + for i in range(len(last_timestep.info['eats']) // 2): + for k, v in last_timestep.info['eats'][i].items(): + info['agent_{}_{}'.format(i, k)] = v + return info + + def eval_vsbot( + self, + save_ckpt_fn: Callable = None, + train_iter: int = -1, + envstep: int = -1, + n_episode: Optional[int] = None, + ) -> Tuple[bool, float]: + """ + Overview: + Evaluate policy and store the best policy based on whether it reaches the highest historical reward. + Arguments: + - save_ckpt_fn (:obj:`Callable`): Saving ckpt function, which will be triggered by getting the best reward. + - train_iter (:obj:`int`): Current training iteration. + - envstep (:obj:`int`): Current env interaction step. + - n_episode (:obj:`int`): Number of evaluation episodes. + Returns: + - stop_flag (:obj:`bool`): Whether this training program can be ended. + - eval_reward (:obj:`float`): Current eval_reward. + """ + episode_info = None + stop_flag = False + if get_rank() == 0: + if n_episode is None: + n_episode = self._default_n_episode + assert n_episode is not None, "please indicate eval n_episode" + envstep_count = 0 + # specifically for vs bot + eval_monitor = GoBiggerVectorEvalMonitor(self._env.env_num, n_episode) + env_nums = self._env.env_num + + self._env.reset() + self._policy.reset() + + # initializations + init_obs = self._env.ready_obs + + retry_waiting_time = 0.001 + while len(init_obs.keys()) != self._env_num: + # In order to be compatible with subprocess env_manager, in which sometimes self._env_num is not equal to + # len(self._env.ready_obs), especially in tictactoe env. + self._logger.info('The current init_obs.keys() is {}'.format(init_obs.keys())) + self._logger.info('Before sleeping, the _env_states is {}'.format(self._env._env_states)) + time.sleep(retry_waiting_time) + self._logger.info('=' * 10 + 'Wait for all environments (subprocess) to finish resetting.' 
+ '=' * 10) + self._logger.info( + 'After sleeping {}s, the current _env_states is {}'.format(retry_waiting_time, self._env._env_states) + ) + init_obs = self._env.ready_obs + + # specifically for vs bot + agent_num = self.policy_config['model']['agent_num'] + team_num = self.policy_config['model']['team_num'] + self._bot_policy = GoBiggerBot(env_nums, agent_id=[i for i in range(agent_num//team_num, agent_num)]) #TODO only support t2p2 + self._bot_policy.reset() + + # specifically for vs bot + for i in range(env_nums): + for k, v in init_obs[i].items(): + if k != 'raw_obs': + init_obs[i][k] = v[:agent_num] + + action_mask_dict = {i: to_ndarray(init_obs[i]['action_mask']) for i in range(env_nums)} + + to_play_dict = {i: to_ndarray(init_obs[i]['to_play']) for i in range(env_nums)} + dones = np.array([False for _ in range(env_nums)]) + + + if self._multi_agent: + agent_num = len(init_obs[0]['action_mask']) + assert agent_num == self.policy_config.model.agent_num, "Please make sure agent_num == env.agent_num" + game_segments = [ + [ + GameSegment( + self._env.action_space, + game_segment_length=self.policy_config.game_segment_length, + config=self.policy_config + ) for _ in range(agent_num) + ] for _ in range(env_nums) + ] + for env_id in range(env_nums): + for agent_id in range(agent_num): + game_segments[env_id][agent_id].reset( + [ + to_ndarray(init_obs[env_id]['observation'][agent_id]) + for _ in range(self.policy_config.model.frame_stack_num) + ] + ) + else: + game_segments = [ + GameSegment( + self._env.action_space, + game_segment_length=self.policy_config.game_segment_length, + config=self.policy_config + ) for _ in range(env_nums) + ] + for i in range(env_nums): + game_segments[i].reset( + [to_ndarray(init_obs[i]['observation']) for _ in range(self.policy_config.model.frame_stack_num)] + ) + + ready_env_id = set() + remain_episode = n_episode + # specifically for vs bot + eat_info = defaultdict() + + with self._timer: + while not eval_monitor.is_finished(): + # Get current ready env obs. 
+ obs = self._env.ready_obs + # specifically for vs bot + raw_obs = [v['raw_obs'] for k, v in obs.items()] + new_available_env_id = set(obs.keys()).difference(ready_env_id) + ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode])) + remain_episode -= min(len(new_available_env_id), remain_episode) + + if self._multi_agent: + stack_obs = defaultdict(list) + for env_id in ready_env_id: + for agent_id in range(agent_num): + stack_obs[env_id].append(game_segments[env_id][agent_id].get_obs()) + else: + stack_obs = {env_id: game_segments[env_id].get_obs() for env_id in ready_env_id} + stack_obs = list(stack_obs.values()) + + action_mask_dict = {env_id: action_mask_dict[env_id] for env_id in ready_env_id} + to_play_dict = {env_id: to_play_dict[env_id] for env_id in ready_env_id} + action_mask = [action_mask_dict[env_id] for env_id in ready_env_id] + to_play = [to_play_dict[env_id] for env_id in ready_env_id] + + stack_obs = to_ndarray(stack_obs) + if self.policy_config.model.model_type and self.policy_config.model.model_type in ['conv', 'mlp']: + stack_obs = prepare_observation(stack_obs, self.policy_config.model.model_type) + stack_obs = torch.from_numpy(stack_obs).to(self.policy_config.device).float() + + # ============================================================== + # bot forward + # ============================================================== + bot_actions = self._bot_policy.forward(raw_obs) + + # ============================================================== + # policy forward + # ============================================================== + policy_output = self._policy.forward(stack_obs, action_mask, to_play) + if self._multi_agent: + actions_no_env_id = defaultdict(dict) + for k, v in policy_output.items(): + for agent_id, act in enumerate(v['action']): + actions_no_env_id[k][agent_id] = act + else: + actions_no_env_id = {k: v['action'] for k, v in policy_output.items()} + distributions_dict_no_env_id = {k: v['distributions'] for k, v in policy_output.items()} + if self.policy_config.sampled_algo: + root_sampled_actions_dict_no_env_id = { + k: v['root_sampled_actions'] + for k, v in policy_output.items() + } + + value_dict_no_env_id = {k: v['value'] for k, v in policy_output.items()} + pred_value_dict_no_env_id = {k: v['pred_value'] for k, v in policy_output.items()} + visit_entropy_dict_no_env_id = { + k: v['visit_count_distribution_entropy'] + for k, v in policy_output.items() + } + + actions = {} + distributions_dict = {} + if self.policy_config.sampled_algo: + root_sampled_actions_dict = {} + value_dict = {} + pred_value_dict = {} + visit_entropy_dict = {} + for index, env_id in enumerate(ready_env_id): + actions[env_id] = actions_no_env_id.pop(index) + distributions_dict[env_id] = distributions_dict_no_env_id.pop(index) + if self.policy_config.sampled_algo: + root_sampled_actions_dict[env_id] = root_sampled_actions_dict_no_env_id.pop(index) + value_dict[env_id] = value_dict_no_env_id.pop(index) + pred_value_dict[env_id] = pred_value_dict_no_env_id.pop(index) + visit_entropy_dict[env_id] = visit_entropy_dict_no_env_id.pop(index) + + # ============================================================== + # Interact with env. 
+ # ============================================================== + # specifically for vs bot + for env_id, v in bot_actions.items(): + actions[env_id].update(v) + + timesteps = self._env.step(actions) + + for env_id, t in timesteps.items(): + obs, reward, done, info = t.obs, t.reward, t.done, t.info + if self._multi_agent: + for agent_id in range(agent_num): + game_segments[env_id][agent_id].append( + actions[env_id][agent_id], to_ndarray(obs['observation'][agent_id]), reward[agent_id] if isinstance(reward, list) else reward, + action_mask_dict[env_id][agent_id], to_play_dict[env_id] + ) + else: + game_segments[env_id].append( + actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id], + to_play_dict[env_id] + ) + + # NOTE: in evaluator, we only need save the ``o_{t+1} = obs['observation']`` + # game_segments[env_id].obs_segment.append(to_ndarray(obs['observation'])) + + # NOTE: the position of code snippet is very important. + # the obs['action_mask'] and obs['to_play'] is corresponding to next action + action_mask_dict[env_id] = to_ndarray(obs['action_mask']) + to_play_dict[env_id] = to_ndarray(obs['to_play']) + + dones[env_id] = done + if t.done: + # Env reset is done by env_manager automatically. + self._policy.reset([env_id]) + reward = t.info['eval_episode_return'] + # specifically for vs bot + bot_reward = t.info['eval_bot_episode_return'] + eat_info[env_id] = t.info['eats'] + if 'episode_info' in t.info: + eval_monitor.update_info(env_id, t.info['episode_info']) + eval_monitor.update_reward(env_id, reward) + # specifically for vs bot + eval_monitor.update_bot_reward(env_id, bot_reward) + self._logger.info( + "[EVALUATOR vsbot]env {} finish episode, final reward: {}, current episode: {}".format( + env_id, eval_monitor.get_latest_reward(env_id), eval_monitor.get_current_episode() + ) + ) + + # reset the finished env and init game_segments + if n_episode > self._env_num: + # Get current ready env obs. + init_obs = self._env.ready_obs + retry_waiting_time = 0.001 + while len(init_obs.keys()) != self._env_num: + # In order to be compatible with subprocess env_manager, in which sometimes self._env_num is not equal to + # len(self._env.ready_obs), especially in tictactoe env. + self._logger.info('The current init_obs.keys() is {}'.format(init_obs.keys())) + self._logger.info( + 'Before sleeping, the _env_states is {}'.format(self._env._env_states) + ) + time.sleep(retry_waiting_time) + self._logger.info( + '=' * 10 + 'Wait for all environments (subprocess) to finish resetting.' 
+ '=' * 10 + ) + self._logger.info( + 'After sleeping {}s, the current _env_states is {}'.format( + retry_waiting_time, self._env._env_states + ) + ) + init_obs = self._env.ready_obs + + new_available_env_id = set(init_obs.keys()).difference(ready_env_id) + ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode])) + remain_episode -= min(len(new_available_env_id), remain_episode) + + action_mask_dict[env_id] = to_ndarray(init_obs[env_id]['action_mask']) + to_play_dict[env_id] = to_ndarray(init_obs[env_id]['to_play']) + + if self._multi_agent: + for agent_id in range(agent_num): + game_segments[env_id][agent_id] = GameSegment( + self._env.action_space, + game_segment_length=self.policy_config.game_segment_length, + config=self.policy_config + ) + + game_segments[env_id][agent_id].reset( + [ + init_obs[env_id]['observation'][agent_id] + for _ in range(self.policy_config.model.frame_stack_num) + ] + ) + else: + game_segments[env_id] = GameSegment( + self._env.action_space, + game_segment_length=self.policy_config.game_segment_length, + config=self.policy_config + ) + + game_segments[env_id].reset( + [ + init_obs[env_id]['observation'] + for _ in range(self.policy_config.model.frame_stack_num) + ] + ) + + # Env reset is done by env_manager automatically. + self._policy.reset([env_id]) + # specifically for vs bot + self._bot_policy.reset([env_id]) + # TODO(pu): subprocess mode, when n_episode > self._env_num, occasionally the ready_env_id=() + # and the stack_obs is np.array(None, dtype=object) + ready_env_id.remove(env_id) + + envstep_count += 1 + duration = self._timer.value + episode_return = eval_monitor.get_episode_return() + # specifically for vs bot + bot_episode_return = eval_monitor.get_bot_episode_return() + info = { + 'train_iter': train_iter, + 'ckpt_name': 'iteration_{}.pth.tar'.format(train_iter), + 'episode_count': n_episode, + 'envstep_count': envstep_count, + 'avg_envstep_per_episode': envstep_count / n_episode, + 'evaluate_time': duration, + 'avg_envstep_per_sec': envstep_count / duration, + 'avg_time_per_episode': n_episode / duration, + 'reward_mean': np.mean(episode_return), + 'reward_std': np.std(episode_return), + 'reward_max': np.max(episode_return), + 'reward_min': np.min(episode_return), + # specifically for vs bot + 'bot_reward_mean': np.mean(bot_episode_return), + 'bot_reward_std': np.std(bot_episode_return), + 'bot_reward_max': np.max(bot_episode_return), + 'bot_reward_min': np.min(bot_episode_return), + } + # specifically for vs bot + # add eat info + for k, v in eat_info.items(): + for i in range(len(v)): + for k1, v1 in v[i].items(): + info['agent_{}_{}'.format(i, k1)] = info.get('agent_{}_{}'.format(i, k1), []) + [v1] + + for k, v in info.items(): + if 'agent' in k: + info[k] = np.mean(v) + + episode_info = eval_monitor.get_episode_info() + if episode_info is not None: + info.update(episode_info) + self._logger.info(self._logger.get_tabulate_vars_hor(info)) + # self._logger.info(self._logger.get_tabulate_vars(info)) + for k, v in info.items(): + if k in ['train_iter', 'ckpt_name', 'each_reward']: + continue + if not np.isscalar(v): + continue + self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter) + self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, envstep) + episode_return = np.mean(episode_return) + if episode_return > self._max_episode_return: + if save_ckpt_fn: + save_ckpt_fn('ckpt_best.pth.tar') + self._max_episode_return = episode_return + stop_flag = episode_return >= 
self._stop_value and train_iter > 0
+                if stop_flag:
+                    self._logger.info(
+                        "[LightZero serial pipeline] " +
+                        "Current episode_return: {} is greater than stop_value: {}".format(episode_return, self._stop_value) +
+                        ", so your MCTS/RL agent is converged, you can refer to 'log/evaluator/evaluator_logger.txt' for details."
+                    )
+
+        if get_world_size() > 1:
+            objects = [stop_flag, episode_info]
+            broadcast_object_list(objects, src=0)
+            stop_flag, episode_info = objects
+
+        episode_info = to_item(episode_info)
+        return stop_flag, episode_info
+
+class GoBiggerVectorEvalMonitor(VectorEvalMonitor):
+
+    def __init__(self, env_num: int, n_episode: int) -> None:
+        super().__init__(env_num, n_episode)
+        each_env_episode = [n_episode // env_num for _ in range(env_num)]
+        self._bot_reward = {env_id: deque(maxlen=maxlen) for env_id, maxlen in enumerate(each_env_episode)}
+
+    def get_bot_episode_return(self) -> list:
+        """
+        Overview:
+            Concatenate the bot episode returns recorded for every environment into a single list.
+        """
+        return sum([list(v) for v in self._bot_reward.values()], [])  # sum(iterable, start) flattens the per-env lists
+
+    def update_bot_reward(self, env_id: int, reward: Any) -> None:
+        """
+        Overview:
+            Update the bot reward recorded for the environment indicated by env_id.
+        Arguments:
+            - env_id: (:obj:`int`): the id of the environment whose bot reward needs to be updated
+            - reward: (:obj:`Any`): the bot reward used for the update
+        """
+        if isinstance(reward, torch.Tensor):
+            reward = reward.item()
+        self._bot_reward[env_id].append(reward)
diff --git a/lzero/worker/muzero_collector.py b/lzero/worker/muzero_collector.py
index aca581c47..4788bf29b 100644
--- a/lzero/worker/muzero_collector.py
+++ b/lzero/worker/muzero_collector.py
@@ -1,5 +1,5 @@
 import time
-from collections import deque, namedtuple
+from collections import deque, namedtuple, defaultdict
 from typing import Optional, Any, List
 
 import numpy as np
@@ -78,6 +78,10 @@ def __init__(
         self._tb_logger = None
 
         self.policy_config = policy_config
+        if 'multi_agent' in self.policy_config.keys() and self.policy_config.multi_agent:
+            self._multi_agent = self.policy_config.multi_agent
+        else:
+            self._multi_agent = False
 
         self.reset(policy, env)
 
@@ -290,6 +294,89 @@ def pad_and_save_last_trajectory(self, i, last_game_segments, last_game_prioriti
         last_game_priorities[i] = None
 
         return None
+
+    def _compute_priorities_for_agent(self, i, agent_id, pred_values_lst, search_values_lst):
+        """
+        Overview:
+            Obtain the priorities of the transitions collected by agent ``agent_id`` in environment ``i``.
+        Arguments:
+            - i: environment index.
+            - agent_id: agent index within the environment.
+            - pred_values_lst: The list of values predicted by the model.
+            - search_values_lst: The list of values obtained through search.
+        """
+        if self.policy_config.use_priority:
+            pred_values = torch.from_numpy(np.array(pred_values_lst[i][agent_id])).to(self.policy_config.device
+                                                                                      ).float().view(-1)
+            search_values = torch.from_numpy(np.array(search_values_lst[i][agent_id])).to(self.policy_config.device
+                                                                                          ).float().view(-1)
+            priorities = L1Loss(reduction='none'
+                                )(pred_values,
+                                  search_values).detach().cpu().numpy() + 1e-6  # avoid zero priority
+        else:
+            # priorities is None -> use the max priority for all newly collected data
+            priorities = None
+
+        return priorities
+
+    def pad_and_save_last_trajectory_for_agent(
+            self, i, agent_id, last_game_segments, last_game_priorities, game_segments, done
+    ) -> None:
+        """
+        Overview:
+            Put the last game segment of agent ``agent_id`` in environment ``i`` into the pool if the current game is finished.
+        Arguments:
+            - i: environment index.
+            - agent_id: agent index within the environment.
+            - last_game_segments (:obj:`list`): list of the last game segments
+            - last_game_priorities (:obj:`list`): list of the last game priorities
+            - game_segments (:obj:`list`): list of the current game segments
+            - done (:obj:`list`): list of done flags of each environment
+        Note:
+            (last_game_segments[i].obs_segment[-4:][j] == game_segments[i].obs_segment[:4][j]).all() is True
+        """
+        # pad over the trajectory of the last game segment
+        beg_index = self.policy_config.model.frame_stack_num
+        end_index = beg_index + self.policy_config.num_unroll_steps
+
+        # the start obs is init zero obs, so we take the [frame_stack_num : frame_stack_num + num_unroll_steps] obs as the pad obs
+        # e.g. the start 4 obs is init zero obs, the num_unroll_steps is 5, so we take the [4:9] obs as the pad obs
+        pad_obs_lst = game_segments[i][agent_id].obs_segment[beg_index:end_index]
+        pad_child_visits_lst = game_segments[i][agent_id].child_visit_segment[:self.policy_config.num_unroll_steps]
+        # EfficientZero original repo bug:
+        # pad_child_visits_lst = game_segments[i].child_visit_segment[beg_index:end_index]
+
+        beg_index = 0
+        # self.unroll_plus_td_steps = self.policy_config.num_unroll_steps + self.policy_config.td_steps
+        end_index = beg_index + self.unroll_plus_td_steps - 1
+
+        pad_reward_lst = game_segments[i][agent_id].reward_segment[beg_index:end_index]
+
+        beg_index = 0
+        end_index = beg_index + self.unroll_plus_td_steps
+
+        pad_root_values_lst = game_segments[i][agent_id].root_value_segment[beg_index:end_index]
+
+        # pad over and save
+        last_game_segments[i][agent_id].pad_over(pad_obs_lst, pad_reward_lst, pad_root_values_lst, pad_child_visits_lst)
+        """
+        Note:
+            game_segment element shape:
+            obs: game_segment_length + stack + num_unroll_steps, 20+4 +5
+            rew: game_segment_length + stack + num_unroll_steps + td_steps -1  20 +5+3-1
+            action: game_segment_length -> 20
+            root_values:  game_segment_length + num_unroll_steps + td_steps -> 20 +5+3
+            child_visits: game_segment_length + num_unroll_steps -> 20 +5
+            to_play: game_segment_length -> 20
+            action_mask: game_segment_length -> 20
+        """
+
+        last_game_segments[i][agent_id].game_segment_to_array()
+
+        # put the game segment into the pool
+        self.game_segment_pool.append((last_game_segments[i][agent_id], last_game_priorities[i][agent_id], done[i]))
+
+        # reset last game_segments
+        last_game_segments[i][agent_id] = None
+        last_game_priorities[i][agent_id] = None
+
+        return None
 
     def collect(self,
                 n_episode: Optional[int] = None,
@@ -341,34 +428,69 @@ def collect(self,
         if self.policy_config.use_ture_chance_label_in_chance_encoder:
             chance_dict = {i: to_ndarray(init_obs[i]['chance']) for i in range(env_nums)}
 
-        game_segments = [
-            GameSegment(
-                self._env.action_space,
-                game_segment_length=self.policy_config.game_segment_length,
-                config=self.policy_config
-            ) for _ in range(env_nums)
-        ]
-        # stacked observation windows
in reset stage for init game_segments - observation_window_stack = [[] for _ in range(env_nums)] - for env_id in range(env_nums): - observation_window_stack[env_id] = deque( - [to_ndarray(init_obs[env_id]['observation']) for _ in range(self.policy_config.model.frame_stack_num)], - maxlen=self.policy_config.model.frame_stack_num - ) - - game_segments[env_id].reset(observation_window_stack[env_id]) - - dones = np.array([False for _ in range(env_nums)]) - last_game_segments = [None for _ in range(env_nums)] - last_game_priorities = [None for _ in range(env_nums)] - # for priorities in self-play - search_values_lst = [[] for _ in range(env_nums)] - pred_values_lst = [[] for _ in range(env_nums)] + if self._multi_agent: + agent_num = len(init_obs[0]['action_mask']) + assert agent_num == self.policy_config.model.agent_num, "Please make sure agent_num == env.agent_num" + game_segments = [ + [ + GameSegment( + self._env.action_space, + game_segment_length=self.policy_config.game_segment_length, + config=self.policy_config + ) for _ in range(agent_num) + ] for _ in range(env_nums) + ] + # stacked observation windows in reset stage for init game_segments + observation_window_stack = [[[] for _ in range(agent_num)] for _ in range(env_nums)] + for env_id in range(env_nums): + for agent_id in range(agent_num): + observation_window_stack[env_id][agent_id] = deque( + [ + to_ndarray(init_obs[env_id]['observation'][agent_id]) + for _ in range(self.policy_config.model.frame_stack_num) + ], + maxlen=self.policy_config.model.frame_stack_num + ) + game_segments[env_id][agent_id].reset(observation_window_stack[env_id][agent_id]) + + dones = np.array([False for _ in range(env_nums)]) + last_game_segments = [[None for _ in range(agent_num)] for _ in range(env_nums)] + last_game_priorities = [[None for _ in range(agent_num)] for _ in range(env_nums)] + # for priorities in self-play + search_values_lst = [[[] for _ in range(agent_num)] for _ in range(env_nums)] + pred_values_lst = [[[] for _ in range(agent_num)] for _ in range(env_nums)] + else: + # some logs + game_segments = [ + GameSegment( + self._env.action_space, + game_segment_length=self.policy_config.game_segment_length, + config=self.policy_config + ) for _ in range(env_nums) + ] + # stacked observation windows in reset stage for init game_segments + observation_window_stack = [[] for _ in range(env_nums)] + for env_id in range(env_nums): + observation_window_stack[env_id] = deque( + [to_ndarray(init_obs[env_id]['observation']) for _ in range(self.policy_config.model.frame_stack_num)], + maxlen=self.policy_config.model.frame_stack_num + ) + game_segments[env_id].reset(observation_window_stack[env_id]) + + dones = np.array([False for _ in range(env_nums)]) + last_game_segments = [None for _ in range(env_nums)] + last_game_priorities = [None for _ in range(env_nums)] + # for priorities in self-play + search_values_lst = [[] for _ in range(env_nums)] + pred_values_lst = [[] for _ in range(env_nums)] if self.policy_config.gumbel_algo: improved_policy_lst = [[] for _ in range(env_nums)] # some logs - eps_steps_lst, visit_entropies_lst = np.zeros(env_nums), np.zeros(env_nums) + if self._multi_agent: + eps_steps_lst, visit_entropies_lst = np.zeros(env_nums), np.zeros((env_nums, agent_num)) + else: + eps_steps_lst, visit_entropies_lst = np.zeros(env_nums), np.zeros(env_nums) if self.policy_config.gumbel_algo: completed_value_lst = np.zeros(env_nums) self_play_moves = 0. 
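# A minimal, self-contained sketch (not part of the patch) of the bookkeeping that the
# multi-agent branch above relies on: every per-environment structure gains an extra agent
# dimension, so game segments, frame-stack windows, priority lists and entropy stats are
# indexed as [env_id][agent_id]. The observation shape and frame_stack_num below are
# illustrative assumptions, not values taken from the repo.
from collections import deque

import numpy as np

env_nums, agent_num, frame_stack_num = 2, 4, 4
init_obs = {e: {'observation': [np.zeros(3) for _ in range(agent_num)]} for e in range(env_nums)}

observation_window_stack = [
    [
        deque(
            [init_obs[e]['observation'][a] for _ in range(frame_stack_num)],
            maxlen=frame_stack_num,
        ) for a in range(agent_num)
    ] for e in range(env_nums)
]

# appending a new observation only rolls the window of that particular (env, agent) pair
observation_window_stack[0][2].append(np.ones(3))
assert len(observation_window_stack[0][2]) == frame_stack_num
assert all(len(w) == frame_stack_num for windows in observation_window_stack for w in windows)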
@@ -387,8 +509,13 @@ def collect(self, new_available_env_id = set(obs.keys()).difference(ready_env_id) ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode])) remain_episode -= min(len(new_available_env_id), remain_episode) - - stack_obs = {env_id: game_segments[env_id].get_obs() for env_id in ready_env_id} + if self._multi_agent: + stack_obs = defaultdict(list) + for env_id in ready_env_id: + for agent_id in range(agent_num): + stack_obs[env_id].append(game_segments[env_id][agent_id].get_obs()) + else: + stack_obs = {env_id: game_segments[env_id].get_obs() for env_id in ready_env_id} stack_obs = list(stack_obs.values()) action_mask_dict = {env_id: action_mask_dict[env_id] for env_id in ready_env_id} @@ -399,19 +526,29 @@ def collect(self, chance_dict = {env_id: chance_dict[env_id] for env_id in ready_env_id} chance = [chance_dict[env_id] for env_id in ready_env_id] - stack_obs = to_ndarray(stack_obs) - - stack_obs = prepare_observation(stack_obs, self.policy_config.model.model_type) - - stack_obs = torch.from_numpy(stack_obs).to(self.policy_config.device).float() + if self.policy_config.model.model_type: + if self.policy_config.model.model_type in ['conv', 'mlp']: + stack_obs = to_ndarray(stack_obs) + stack_obs = prepare_observation(stack_obs, self.policy_config.model.model_type) + stack_obs = torch.from_numpy(stack_obs).to(self.policy_config.device).float() + elif self.policy_config.model.model_type == 'structure': + stack_obs = prepare_observation(stack_obs, self.policy_config.model.model_type) + else: + raise ValueError('model_type must be one of [conv, mlp, structure]') # ============================================================== # policy forward # ============================================================== policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon) - actions_no_env_id = {k: v['action'] for k, v in policy_output.items()} - distributions_dict_no_env_id = {k: v['visit_count_distributions'] for k, v in policy_output.items()} + if self._multi_agent: + actions_no_env_id = defaultdict(dict) + for k, v in policy_output.items(): + for agent_id, act in enumerate(v['action']): + actions_no_env_id[k][agent_id] = act + else: + actions_no_env_id = {k: v['action'] for k, v in policy_output.items()} + distributions_dict_no_env_id = {k: v['distributions'] for k, v in policy_output.items()} if self.policy_config.sampled_algo: root_sampled_actions_dict_no_env_id = { k: v['root_sampled_actions'] @@ -473,20 +610,35 @@ def collect(self, obs, reward, done, info = timestep.obs, timestep.reward, timestep.done, timestep.info if self.policy_config.sampled_algo: - game_segments[env_id].store_search_stats( - distributions_dict[env_id], value_dict[env_id], root_sampled_actions_dict[env_id] - ) + if self._multi_agent: + for agent_id in range(agent_num): + game_segments[env_id][agent_id].store_search_stats( + distributions_dict[env_id][agent_id], value_dict[env_id][agent_id], + root_sampled_actions_dict[env_id][agent_id] + ) + else: + game_segments[env_id].store_search_stats( + distributions_dict[env_id], value_dict[env_id], root_sampled_actions_dict[env_id] + ) elif self.policy_config.gumbel_algo: game_segments[env_id].store_search_stats(distributions_dict[env_id], value_dict[env_id], improved_policy = improved_policy_dict[env_id]) else: - game_segments[env_id].store_search_stats(distributions_dict[env_id], value_dict[env_id]) + if self._multi_agent: + for agent_id in range(agent_num): + 
game_segments[env_id][agent_id].store_search_stats( + distributions_dict[env_id][agent_id], value_dict[env_id][agent_id] + ) + else: + game_segments[env_id].store_search_stats(distributions_dict[env_id], value_dict[env_id]) # append a transition tuple, including a_t, o_{t+1}, r_{t}, action_mask_{t}, to_play_{t} # in ``game_segments[env_id].init``, we have append o_{t} in ``self.obs_segment`` - if self.policy_config.use_ture_chance_label_in_chance_encoder: - game_segments[env_id].append( - actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id], - to_play_dict[env_id], chance_dict[env_id] - ) + if self._multi_agent: + for agent_id in range(agent_num): + game_segments[env_id][agent_id].append( + actions[env_id][agent_id], + to_ndarray(obs['observation'][agent_id]), reward[agent_id] if isinstance(reward, list) else reward, + action_mask_dict[env_id][agent_id], to_play_dict[env_id] + ) else: game_segments[env_id].append( actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id], @@ -504,54 +656,96 @@ def collect(self, dones[env_id] = False else: dones[env_id] = done - - visit_entropies_lst[env_id] += visit_entropy_dict[env_id] - if self.policy_config.gumbel_algo: - completed_value_lst[env_id] += np.mean(np.array(completed_value_dict[env_id])) + + if self._multi_agent: + for agent_id in range(agent_num): + visit_entropies_lst[env_id][agent_id] += visit_entropy_dict[env_id][agent_id] + else: + visit_entropies_lst[env_id] += visit_entropy_dict[env_id] + if self.policy_config.gumbel_algo: + completed_value_lst[env_id] += np.mean(np.array(completed_value_dict[env_id])) eps_steps_lst[env_id] += 1 total_transitions += 1 if self.policy_config.use_priority: - pred_values_lst[env_id].append(pred_value_dict[env_id]) - search_values_lst[env_id].append(value_dict[env_id]) - if self.policy_config.gumbel_algo: - improved_policy_lst[env_id].append(improved_policy_dict[env_id]) + if self._multi_agent: + for agent_id in range(agent_num): + pred_values_lst[env_id][agent_id].append(pred_value_dict[env_id][agent_id]) + search_values_lst[env_id][agent_id].append(value_dict[env_id][agent_id]) + else: + pred_values_lst[env_id].append(pred_value_dict[env_id]) + search_values_lst[env_id].append(value_dict[env_id]) + if self.policy_config.gumbel_algo: + improved_policy_lst[env_id].append(improved_policy_dict[env_id]) # append the newest obs - observation_window_stack[env_id].append(to_ndarray(obs['observation'])) + if self._multi_agent: + for agent_id in range(agent_num): + observation_window_stack[env_id][agent_id].append(to_ndarray(obs['observation'][agent_id])) + else: + observation_window_stack[env_id].append(to_ndarray(obs['observation'])) # ============================================================== # we will save a game segment if it is the end of the game or the next game segment is finished. 
# ============================================================== - # if game segment is full, we will save the last game segment - if game_segments[env_id].is_full(): - # pad over last segment trajectory - if last_game_segments[env_id] is not None: - # TODO(pu): return the one game segment - self.pad_and_save_last_trajectory( - env_id, last_game_segments, last_game_priorities, game_segments, dones - ) - - # calculate priority - priorities = self._compute_priorities(env_id, pred_values_lst, search_values_lst) - pred_values_lst[env_id] = [] - search_values_lst[env_id] = [] - if self.policy_config.gumbel_algo: - improved_policy_lst[env_id] = [] - - # the current game_segments become last_game_segment - last_game_segments[env_id] = game_segments[env_id] - last_game_priorities[env_id] = priorities + # if game block is full, we will save the last game block + if self._multi_agent: + for agent_id in range(agent_num): + if game_segments[env_id][agent_id].is_full(): + # pad over last block trajectory + if last_game_segments[env_id][agent_id] is not None: + # TODO(pu): return the one game block + self.pad_and_save_last_trajectory_for_agent( + env_id, agent_id, last_game_segments, last_game_priorities, game_segments, dones + ) + + # calculate priority + priorities = self._compute_priorities_for_agent(env_id, agent_id, pred_values_lst, search_values_lst) + # pred_values_lst[env_id] = [] + # search_values_lst[env_id] = [] + search_values_lst = [[[] for _ in range(agent_num)] for _ in range(env_nums)] + pred_values_lst = [[[] for _ in range(agent_num)] for _ in range(env_nums)] + + # the current game_segments become last_game_segment + last_game_segments[env_id][agent_id] = game_segments[env_id][agent_id] + last_game_priorities[env_id][agent_id] = priorities + + # create new GameSegment + game_segments[env_id][agent_id] = GameSegment( + self._env.action_space, + game_segment_length=self.policy_config.game_segment_length, + config=self.policy_config + ) + game_segments[env_id][agent_id].reset(observation_window_stack[env_id][agent_id]) + else: + if game_segments[env_id].is_full(): + # pad over last block trajectory + if last_game_segments[env_id] is not None: + # TODO(pu): return the one game block + self.pad_and_save_last_trajectory( + env_id, last_game_segments, last_game_priorities, game_segments, dones + ) - # create new GameSegment - game_segments[env_id] = GameSegment( - self._env.action_space, - game_segment_length=self.policy_config.game_segment_length, - config=self.policy_config - ) - game_segments[env_id].reset(observation_window_stack[env_id]) + # calculate priority + priorities = self._compute_priorities(env_id, pred_values_lst, search_values_lst) + pred_values_lst[env_id] = [] + search_values_lst[env_id] = [] + if self.policy_config.gumbel_algo: + improved_policy_lst[env_id] = [] + + # the current game_segments become last_game_segment + last_game_segments[env_id] = game_segments[env_id] + last_game_priorities[env_id] = priorities + + # create new GameSegment + game_segments[env_id] = GameSegment( + self._env.action_space, + game_segment_length=self.policy_config.game_segment_length, + config=self.policy_config + ) + game_segments[env_id].reset(observation_window_stack[env_id]) self._env_info[env_id]['step'] += 1 collected_step += 1 @@ -577,21 +771,39 @@ def collect(self, # NOTE: put the penultimate game segment in one episode into the trajectory_pool # pad over 2th last game_segment using the last game_segment - if last_game_segments[env_id] is not None: - 
self.pad_and_save_last_trajectory( - env_id, last_game_segments, last_game_priorities, game_segments, dones - ) + if self._multi_agent: + for agent_id in range(agent_num): + if last_game_segments[env_id][agent_id] is not None: + self.pad_and_save_last_trajectory_for_agent( + env_id, agent_id, last_game_segments, last_game_priorities, game_segments, dones + ) - # store current segment trajectory - priorities = self._compute_priorities(env_id, pred_values_lst, search_values_lst) + # store current block trajectory + priorities = self._compute_priorities_for_agent(env_id, agent_id, pred_values_lst, search_values_lst) - # NOTE: put the last game segment in one episode into the trajectory_pool - game_segments[env_id].game_segment_to_array() + # NOTE: put the last game block in one episode into the trajectory_pool + game_segments[env_id][agent_id].game_segment_to_array() - # assert len(game_segments[env_id]) == len(priorities) - # NOTE: save the last game segment in one episode into the trajectory_pool if it's not null - if len(game_segments[env_id].reward_segment) != 0: - self.game_segment_pool.append((game_segments[env_id], priorities, dones[env_id])) + # assert len(game_segments[env_id]) == len(priorities) + # NOTE: save the last game block in one episode into the trajectory_pool if it's not null + if len(game_segments[env_id][agent_id].reward_segment) != 0: + self.game_segment_pool.append((game_segments[env_id][agent_id], priorities, dones[env_id])) + else: + if last_game_segments[env_id] is not None: + self.pad_and_save_last_trajectory( + env_id, last_game_segments, last_game_priorities, game_segments, dones + ) + + # store current segment trajectory + priorities = self._compute_priorities(env_id, pred_values_lst, search_values_lst) + + # NOTE: put the last game segment in one episode into the trajectory_pool + game_segments[env_id].game_segment_to_array() + + # assert len(game_segments[env_id]) == len(priorities) + # NOTE: save the last game segment in one episode into the trajectory_pool if it's not null + if len(game_segments[env_id].reward_segment) != 0: + self.game_segment_pool.append((game_segments[env_id], priorities, dones[env_id])) # print(game_segments[env_id].reward_segment) # reset the finished env and init game_segments @@ -624,18 +836,37 @@ def collect(self, if self.policy_config.use_ture_chance_label_in_chance_encoder: chance_dict[env_id] = to_ndarray(init_obs[env_id]['chance']) - game_segments[env_id] = GameSegment( - self._env.action_space, - game_segment_length=self.policy_config.game_segment_length, - config=self.policy_config - ) - observation_window_stack[env_id] = deque( - [init_obs[env_id]['observation'] for _ in range(self.policy_config.model.frame_stack_num)], - maxlen=self.policy_config.model.frame_stack_num - ) - game_segments[env_id].reset(observation_window_stack[env_id]) - last_game_segments[env_id] = None - last_game_priorities[env_id] = None + if self._multi_agent: + for agent_id in range(agent_num): + game_segments[env_id][agent_id] = GameSegment( + self._env.action_space, + game_segment_length=self.policy_config.game_segment_length, + config=self.policy_config + ) + observation_window_stack[env_id][agent_id] = deque( + [ + init_obs[env_id]['observation'][agent_id] + for _ in range(self.policy_config.model.frame_stack_num) + ], + maxlen=self.policy_config.model.frame_stack_num + ) + game_segments[env_id][agent_id].reset(observation_window_stack[env_id][agent_id]) + last_game_segments[env_id] = [None for _ in range(agent_num)] + last_game_priorities[env_id] = 
[None for _ in range(agent_num)] + + else: + game_segments[env_id] = GameSegment( + self._env.action_space, + game_segment_length=self.policy_config.game_segment_length, + config=self.policy_config + ) + observation_window_stack[env_id] = deque( + [init_obs[env_id]['observation'] for _ in range(self.policy_config.model.frame_stack_num)], + maxlen=self.policy_config.model.frame_stack_num + ) + game_segments[env_id].reset(observation_window_stack[env_id]) + last_game_segments[env_id] = None + last_game_priorities[env_id] = None # log self_play_moves_max = max(self_play_moves_max, eps_steps_lst[env_id]) diff --git a/lzero/worker/muzero_evaluator.py b/lzero/worker/muzero_evaluator.py index 313a07e07..88d4d211b 100644 --- a/lzero/worker/muzero_evaluator.py +++ b/lzero/worker/muzero_evaluator.py @@ -14,6 +14,7 @@ from lzero.mcts.buffer.game_segment import GameSegment from lzero.mcts.utils import prepare_observation +from collections import defaultdict class MuZeroEvaluator(ISerialEvaluator): @@ -100,6 +101,11 @@ def __init__( # ============================================================== self.policy_config = policy_config + if 'multi_agent' in self.policy_config.keys() and self.policy_config.multi_agent: + self._multi_agent = self.policy_config.multi_agent + else: + self._multi_agent = False + def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None: """ Overview: @@ -192,6 +198,9 @@ def should_eval(self, train_iter: int) -> bool: return False self._last_eval_iter = train_iter return True + + def _add_info(self, last_timestep, info): + return info def eval( self, @@ -248,17 +257,38 @@ def eval( to_play_dict = {i: to_ndarray(init_obs[i]['to_play']) for i in range(env_nums)} dones = np.array([False for _ in range(env_nums)]) - game_segments = [ - GameSegment( - self._env.action_space, - game_segment_length=self.policy_config.game_segment_length, - config=self.policy_config - ) for _ in range(env_nums) - ] - for i in range(env_nums): - game_segments[i].reset( - [to_ndarray(init_obs[i]['observation']) for _ in range(self.policy_config.model.frame_stack_num)] - ) + if self._multi_agent: + agent_num = len(init_obs[0]['action_mask']) + assert agent_num == self.policy_config.model.agent_num, "Please make sure agent_num == env.agent_num" + game_segments = [ + [ + GameSegment( + self._env.action_space, + game_segment_length=self.policy_config.game_segment_length, + config=self.policy_config + ) for _ in range(agent_num) + ] for _ in range(env_nums) + ] + for env_id in range(env_nums): + for agent_id in range(agent_num): + game_segments[env_id][agent_id].reset( + [ + to_ndarray(init_obs[env_id]['observation'][agent_id]) + for _ in range(self.policy_config.model.frame_stack_num) + ] + ) + else: + game_segments = [ + GameSegment( + self._env.action_space, + game_segment_length=self.policy_config.game_segment_length, + config=self.policy_config + ) for _ in range(env_nums) + ] + for i in range(env_nums): + game_segments[i].reset( + [to_ndarray(init_obs[i]['observation']) for _ in range(self.policy_config.model.frame_stack_num)] + ) ready_env_id = set() remain_episode = n_episode @@ -271,7 +301,13 @@ def eval( ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode])) remain_episode -= min(len(new_available_env_id), remain_episode) - stack_obs = {env_id: game_segments[env_id].get_obs() for env_id in ready_env_id} + if self._multi_agent: + stack_obs = defaultdict(list) + for env_id in ready_env_id: + for agent_id in range(agent_num): + 
stack_obs[env_id].append(game_segments[env_id][agent_id].get_obs()) + else: + stack_obs = {env_id: game_segments[env_id].get_obs() for env_id in ready_env_id} stack_obs = list(stack_obs.values()) action_mask_dict = {env_id: action_mask_dict[env_id] for env_id in ready_env_id} @@ -280,16 +316,22 @@ def eval( to_play = [to_play_dict[env_id] for env_id in ready_env_id] stack_obs = to_ndarray(stack_obs) - stack_obs = prepare_observation(stack_obs, self.policy_config.model.model_type) - stack_obs = torch.from_numpy(stack_obs).to(self.policy_config.device).float() + if self.policy_config.model.model_type and self.policy_config.model.model_type in ['conv', 'mlp']: + stack_obs = prepare_observation(stack_obs, self.policy_config.model.model_type) + stack_obs = torch.from_numpy(stack_obs).to(self.policy_config.device).float() # ============================================================== # policy forward # ============================================================== policy_output = self._policy.forward(stack_obs, action_mask, to_play) - - actions_no_env_id = {k: v['action'] for k, v in policy_output.items()} - distributions_dict_no_env_id = {k: v['visit_count_distributions'] for k, v in policy_output.items()} + if self._multi_agent: + actions_no_env_id = defaultdict(dict) + for k, v in policy_output.items(): + for agent_id, act in enumerate(v['action']): + actions_no_env_id[k][agent_id] = act + else: + actions_no_env_id = {k: v['action'] for k, v in policy_output.items()} + distributions_dict_no_env_id = {k: v['distributions'] for k, v in policy_output.items()} if self.policy_config.sampled_algo: root_sampled_actions_dict_no_env_id = { k: v['root_sampled_actions'] @@ -326,11 +368,17 @@ def eval( timesteps = to_tensor(timesteps, dtype=torch.float32) for env_id, t in timesteps.items(): obs, reward, done, info = t.obs, t.reward, t.done, t.info - - game_segments[env_id].append( - actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id], - to_play_dict[env_id] - ) + if self._multi_agent: + for agent_id in range(agent_num): + game_segments[env_id][agent_id].append( + actions[env_id][agent_id], to_ndarray(obs['observation'][agent_id]), reward[agent_id] if isinstance(reward, list) else reward, + action_mask_dict[env_id][agent_id], to_play_dict[env_id] + ) + else: + game_segments[env_id].append( + actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id], + to_play_dict[env_id] + ) # NOTE: in evaluator, we only need save the ``o_{t+1} = obs['observation']`` # game_segments[env_id].obs_segment.append(to_ndarray(obs['observation'])) @@ -386,18 +434,33 @@ def eval( action_mask_dict[env_id] = to_ndarray(init_obs[env_id]['action_mask']) to_play_dict[env_id] = to_ndarray(init_obs[env_id]['to_play']) - game_segments[env_id] = GameSegment( - self._env.action_space, - game_segment_length=self.policy_config.game_segment_length, - config=self.policy_config - ) + if self._multi_agent: + for agent_id in range(agent_num): + game_segments[env_id][agent_id] = GameSegment( + self._env.action_space, + game_segment_length=self.policy_config.game_segment_length, + config=self.policy_config + ) - game_segments[env_id].reset( - [ - init_obs[env_id]['observation'] - for _ in range(self.policy_config.model.frame_stack_num) - ] - ) + game_segments[env_id][agent_id].reset( + [ + init_obs[env_id]['observation'][agent_id] + for _ in range(self.policy_config.model.frame_stack_num) + ] + ) + else: + game_segments[env_id] = GameSegment( + self._env.action_space, + 
game_segment_length=self.policy_config.game_segment_length, + config=self.policy_config + ) + + game_segments[env_id].reset( + [ + init_obs[env_id]['observation'] + for _ in range(self.policy_config.model.frame_stack_num) + ] + ) # Env reset is done by env_manager automatically. self._policy.reset([env_id]) @@ -423,6 +486,9 @@ def eval( 'reward_min': np.min(episode_return), # 'each_reward': episode_return, } + + # add other info + info = self._add_info(t, info) episode_info = eval_monitor.get_episode_info() if episode_info is not None: info.update(episode_info) @@ -455,4 +521,4 @@ def eval( stop_flag, episode_info = objects episode_info = to_item(episode_info) - return stop_flag, episode_info + return stop_flag, episode_info \ No newline at end of file diff --git a/zoo/gobigger/__init__.py b/zoo/gobigger/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/zoo/gobigger/config/__init__.py b/zoo/gobigger/config/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/zoo/gobigger/config/gobigger_efficientzero_config.py b/zoo/gobigger/config/gobigger_efficientzero_config.py new file mode 100644 index 000000000..ce58c26d2 --- /dev/null +++ b/zoo/gobigger/config/gobigger_efficientzero_config.py @@ -0,0 +1,108 @@ +from easydict import EasyDict + +env_name = 'gobigger' +multi_agent = True + +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +seed = 0 +collector_env_num = 32 +n_episode = 32 +evaluator_env_num = 5 +num_simulations = 50 +update_per_collect = 1000 +batch_size = 256 +reanalyze_ratio = 0. +action_space_size = 27 +direction_num = 12 +eps_greedy_exploration_in_collect = True +player_num_per_team = 2 +team_num = 2 +agent_num = player_num_per_team*team_num # default is GoBigger T2P2 +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +gobigger_efficientzero_config = dict( + exp_name= + f'data_ez_ctree/{env_name}_efficientzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed{seed}', + env=dict( + env_name=env_name, + player_num_per_team=player_num_per_team, + team_num=team_num, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + multi_agent=multi_agent, + ignore_done=True, + model=dict( + model_type='structure', + agent_num=agent_num, + team_num=team_num, + latent_state_dim=176, + frame_stack_num=1, + action_space_size=action_space_size, + discrete_action_encoding_type='one_hot', + norm_type='BN', + ), + cuda=True, + mcts_ctree=True, + gumbel_algo=False, + env_type='not_board_games', + game_segment_length=500, + random_collect_episode_num=0, + eps=dict( + eps_greedy_exploration_in_collect=eps_greedy_exploration_in_collect, + type='linear', + start=1., + end=0.05, + decay=int(1e5), + ), + use_augmentation=False, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='SGD', + lr_piecewise_constant_decay=True, + learning_rate=0.2, + ssl_loss_weight=0, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + eval_freq=int(2e3), + replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. 
+ collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), + learn=dict(learner=dict( + log_policy=True, + hook=dict(log_show_after_iter=10, ), + ), ), +) +gobigger_efficientzero_config = EasyDict(gobigger_efficientzero_config) +main_config = gobigger_efficientzero_config + +gobigger_efficientzero_create_config = dict( + env=dict( + type='gobigger_lightzero', + import_names=['zoo.gobigger.env.gobigger_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='multi_agent_efficientzero', + import_names=['lzero.policy.multi_agent_efficientzero'], + ), + collector=dict( + type='episode_muzero', + import_names=['lzero.worker.muzero_collector'], + ) +) +gobigger_efficientzero_create_config = EasyDict(gobigger_efficientzero_create_config) +create_config = gobigger_efficientzero_create_config + +if __name__ == "__main__": + from zoo.gobigger.entry import train_muzero_gobigger + train_muzero_gobigger([main_config, create_config], seed=seed) diff --git a/zoo/gobigger/config/gobigger_eval_config.py b/zoo/gobigger/config/gobigger_eval_config.py new file mode 100644 index 000000000..8b98c1a1f --- /dev/null +++ b/zoo/gobigger/config/gobigger_eval_config.py @@ -0,0 +1,38 @@ +# According to the model you want to evaluate, import the corresponding config. +from zoo.gobigger.entry import eval_muzero_gobigger +import numpy as np + +if __name__ == "__main__": + """ + model_path (:obj:`Optional[str]`): The pretrained model path, which should + point to the ckpt file of the pretrained model, and an absolute path is recommended. + In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``. + """ + # ez + # from gobigger_efficientzero_config import main_config, create_config + + # sez + # from gobigger_sampled_efficientzero_config import main_config, create_config + + # mz + from gobigger_muzero_config import main_config, create_config + model_path = "exp_name/ckpt/ckpt_best.pth.tar" + + returns_mean_seeds = [] + returns_seeds = [] + seeds = [0] + create_config.env_manager.type = 'base' # when visualize must set as base + main_config.env.evaluator_env_num = 1 # when visualize must set as 1 + main_config.env.n_evaluator_episode = 2 # each seed eval episodes num + main_config.env.playback_settings.by_frame.save_frame = True + main_config.env.playback_settings.by_frame.save_name_prefix = 'gobigger' + + for seed in seeds: + returns_selfplay_mean, returns_vsbot_mean = eval_muzero_gobigger( + [main_config, create_config], + seed=seed, + model_path=model_path, + ) + print('seed: {}'.format(seed)) + print('returns_selfplay_mean: {}'.format(returns_selfplay_mean)) + print('returns_vsbot_mean: {}'.format(returns_vsbot_mean)) diff --git a/zoo/gobigger/config/gobigger_muzero_config.py b/zoo/gobigger/config/gobigger_muzero_config.py new file mode 100644 index 000000000..90d1680d4 --- /dev/null +++ b/zoo/gobigger/config/gobigger_muzero_config.py @@ -0,0 +1,108 @@ +from easydict import EasyDict + +env_name = 'gobigger' +multi_agent = True + +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +seed = 0 +collector_env_num = 32 +n_episode = 32 +evaluator_env_num = 5 +num_simulations = 50 +update_per_collect = 1000 +batch_size = 256 +reanalyze_ratio = 0. 
+action_space_size = 27 +direction_num = 12 +eps_greedy_exploration_in_collect = True +player_num_per_team = 2 +team_num = 2 +agent_num = player_num_per_team*team_num # default is GoBigger T2P2 +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +gobigger_muzero_config = dict( + exp_name=f'data_mz_ctree/{env_name}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed{seed}', + env=dict( + env_name=env_name, + player_num_per_team=player_num_per_team, + team_num=team_num, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + multi_agent=multi_agent, + ignore_done=True, + model=dict( + model_type='structure', + agent_num=agent_num, + team_num=team_num, + latent_state_dim=176, + frame_stack_num=1, + action_space_size=action_space_size, + self_supervised_learning_loss=False, + discrete_action_encoding_type='one_hot', + norm_type='BN', + ), + cuda=True, + mcts_ctree=True, + gumbel_algo=False, + env_type='not_board_games', + game_segment_length=500, + random_collect_episode_num=0, + eps=dict( + eps_greedy_exploration_in_collect=eps_greedy_exploration_in_collect, + type='linear', + start=1., + end=0.05, + decay=int(1e5), + ), + use_augmentation=False, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='SGD', + lr_piecewise_constant_decay=True, + learning_rate=0.2, + ssl_loss_weight=0, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + eval_freq=int(2e3), + replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. 
+ collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), + learn=dict(learner=dict( + log_policy=True, + hook=dict(log_show_after_iter=10, ), + ), ), +) +gobigger_muzero_config = EasyDict(gobigger_muzero_config) +main_config = gobigger_muzero_config + +gobigger_muzero_create_config = dict( + env=dict( + type='gobigger_lightzero', + import_names=['zoo.gobigger.env.gobigger_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='multi_agent_muzero', + import_names=['lzero.policy.multi_agent_muzero'], + ), + collector=dict( + type='episode_muzero', + import_names=['lzero.worker.muzero_collector'], + ) +) +gobigger_muzero_create_config = EasyDict(gobigger_muzero_create_config) +create_config = gobigger_muzero_create_config + +if __name__ == "__main__": + from zoo.gobigger.entry import train_muzero_gobigger + train_muzero_gobigger([main_config, create_config], seed=seed) diff --git a/zoo/gobigger/entry/__init__.py b/zoo/gobigger/entry/__init__.py new file mode 100644 index 000000000..641457eff --- /dev/null +++ b/zoo/gobigger/entry/__init__.py @@ -0,0 +1,2 @@ +from .train_muzero_gobigger import train_muzero_gobigger +from .eval_muzero_gobigger import eval_muzero_gobigger diff --git a/zoo/gobigger/entry/eval_muzero_gobigger.py b/zoo/gobigger/entry/eval_muzero_gobigger.py new file mode 100644 index 000000000..d19a20986 --- /dev/null +++ b/zoo/gobigger/entry/eval_muzero_gobigger.py @@ -0,0 +1,117 @@ +import logging +import os +from functools import partial +from typing import Optional, Tuple +import numpy as np +import torch +from tensorboardX import SummaryWriter +import copy + +from ding.config import compile_config +from ding.envs import create_env_manager +from ding.envs import get_vec_env_setting +from ding.policy import create_policy +from ding.utils import set_pkg_seed +from ding.worker import BaseLearner +from lzero.worker import GoBiggerMuZeroEvaluator + + +def eval_muzero_gobigger( + input_cfg: Tuple[dict, dict], + seed: int = 0, + model: Optional[torch.nn.Module] = None, + model_path: Optional[str] = None, +) -> 'Policy': # noqa + """ + Overview: + The eval entry for GoBigger MCTS+RL algorithms, including MuZero, EfficientZero, Sampled EfficientZero. + Arguments: + - input_cfg (:obj:`Tuple[dict, dict]`): Config in dict type. + ``Tuple[dict, dict]`` type means [user_config, create_cfg]. + - seed (:obj:`int`): Random seed. + - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module. + - model_path (:obj:`Optional[str]`): The pretrained model path, which should + point to the ckpt file of the pretrained model, and an absolute path is recommended. + In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``. + Returns: + - reward_sp (:obj:`List`): reward of self-play mode. + - reward_vsbot (:obj:`List`): reward of vsbot mode. 
+ """ + cfg, create_cfg = input_cfg + assert create_cfg.policy.type in ['gobigger_efficientzero', 'gobigger_muzero', 'gobigger_sampled_efficientzero'], \ + "train_muzero entry now only support the following algo.: 'gobigger_efficientzero', 'gobigger_muzero', 'gobigger_sampled_efficientzero'" + + if create_cfg.policy.type == 'gobigger_efficientzero': + from lzero.mcts import GoBiggerEfficientZeroGameBuffer as GameBuffer + elif create_cfg.policy.type == 'gobigger_muzero': + from lzero.mcts import GoBiggerMuZeroGameBuffer as GameBuffer + elif create_cfg.policy.type == 'gobigger_sampled_efficientzero': + from lzero.mcts import GoBiggerSampledEfficientZeroGameBuffer as GameBuffer + + if cfg.policy.cuda and torch.cuda.is_available(): + cfg.policy.device = 'cuda' + else: + cfg.policy.device = 'cpu' + + cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) + # Create main components: env, policy + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) + collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) + evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + + env_cfg = copy.deepcopy(evaluator_env_cfg[0]) + env_cfg.contain_raw_obs = True + vsbot_evaluator_env_cfg = [env_cfg for _ in range(len(evaluator_env_cfg))] + vsbot_evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in vsbot_evaluator_env_cfg]) + + collector_env.seed(cfg.seed) + evaluator_env.seed(cfg.seed, dynamic_seed=False) + vsbot_evaluator_env.seed(cfg.seed, dynamic_seed=False) + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + + policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) + + # load pretrained model + if model_path is not None: + policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device)) + + # Create worker components: learner, collector, evaluator, replay buffer, commander. + tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) + learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) + + # ============================================================== + # MCTS+RL algorithms related core code + # ============================================================== + policy_config = cfg.policy + evaluator = GoBiggerMuZeroEvaluator( + eval_freq=cfg.policy.eval_freq, + n_evaluator_episode=cfg.env.n_evaluator_episode, + stop_value=cfg.env.stop_value, + env=evaluator_env, + policy=policy.eval_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config + ) + vsbot_evaluator = GoBiggerMuZeroEvaluator( + eval_freq=cfg.policy.eval_freq, + n_evaluator_episode=cfg.env.n_evaluator_episode, + stop_value=cfg.env.stop_value, + env=vsbot_evaluator_env, + policy=policy.eval_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config, + instance_name='vsbot_evaluator' + ) + # ============================================================== + # Main loop + # ============================================================== + # Learner's before_run hook. 
+ learner.call_hook('before_run') + # ============================================================== + # eval trained model + # ============================================================== + _, reward_sp = evaluator.eval(learner.save_checkpoint, learner.train_iter) + _, reward_vsbot = vsbot_evaluator.eval_vsbot(learner.save_checkpoint, learner.train_iter) + return reward_sp, reward_vsbot diff --git a/zoo/gobigger/entry/train_muzero_gobigger.py b/zoo/gobigger/entry/train_muzero_gobigger.py new file mode 100644 index 000000000..31c194bab --- /dev/null +++ b/zoo/gobigger/entry/train_muzero_gobigger.py @@ -0,0 +1,222 @@ +import logging +import os +from functools import partial +from typing import Optional, Tuple + +import torch +from ding.config import compile_config +from ding.envs import create_env_manager +from ding.envs import get_vec_env_setting +from ding.policy import create_policy +from ding.utils import set_pkg_seed, get_rank +from ding.rl_utils import get_epsilon_greedy_fn +from ding.worker import BaseLearner +from tensorboardX import SummaryWriter + +from lzero.entry.utils import log_buffer_memory_usage +from lzero.policy import visit_count_temperature +from lzero.policy.random_policy import LightZeroRandomPolicy +from lzero.worker import MuZeroCollector as Collector +from lzero.worker import MuZeroEvaluator as Evaluator +from lzero.entry.utils import random_collect + +import copy +from lzero.worker import GoBiggerMuZeroEvaluator +from zoo.gobigger.model import GoBiggerEncoder + +def train_muzero_gobigger( + input_cfg: Tuple[dict, dict], + seed: int = 0, + model: Optional[torch.nn.Module] = None, + model_path: Optional[str] = None, + max_train_iter: Optional[int] = int(1e10), + max_env_step: Optional[int] = int(1e10), +) -> 'Policy': # noqa + """ + Overview: + The train entry for GoBigger MCTS+RL algorithms, including MuZero, EfficientZero, Sampled EfficientZero. + Arguments: + - input_cfg (:obj:`Tuple[dict, dict]`): Config in dict type. + ``Tuple[dict, dict]`` type means [user_config, create_cfg]. + - seed (:obj:`int`): Random seed. + - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module. + - model_path (:obj:`Optional[str]`): The pretrained model path, which should + point to the ckpt file of the pretrained model, and an absolute path is recommended. + In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``. + - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training. + - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps. + Returns: + - policy (:obj:`Policy`): Converged policy. 
+ """ + + cfg, create_cfg = input_cfg + assert create_cfg.policy.type in ['efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero', 'multi_agent_efficientzero', 'multi_agent_muzero'], \ + "train_muzero entry now only support the following algo.: 'efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero', 'multi_agent_efficientzero', 'multi_agent_muzero'" + + if create_cfg.policy.type == 'muzero' or create_cfg.policy.type == 'multi_agent_muzero': + from lzero.mcts import MuZeroGameBuffer as GameBuffer + from lzero.model.muzero_model_mlp import MuZeroModelMLP as Encoder + elif create_cfg.policy.type == 'efficientzero' or create_cfg.policy.type == 'multi_agent_efficientzero': + from lzero.mcts import EfficientZeroGameBuffer as GameBuffer + from lzero.model.efficientzero_model_mlp import EfficientZeroModelMLP as Encoder + elif create_cfg.policy.type == 'sampled_efficientzero': + from lzero.mcts import SampledEfficientZeroGameBuffer as GameBuffer + elif create_cfg.policy.type == 'gumbel_muzero': + from lzero.mcts import GumbelMuZeroGameBuffer as GameBuffer + + if cfg.policy.cuda and torch.cuda.is_available(): + cfg.policy.device = 'cuda' + else: + cfg.policy.device = 'cpu' + + cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) + # Create main components: env, policy + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) + + collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) + evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + + env_cfg = copy.deepcopy(evaluator_env_cfg[0]) + env_cfg.contain_raw_obs = True + vsbot_evaluator_env_cfg = [env_cfg for _ in range(len(evaluator_env_cfg))] + vsbot_evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in vsbot_evaluator_env_cfg]) + + collector_env.seed(cfg.seed) + evaluator_env.seed(cfg.seed, dynamic_seed=False) + vsbot_evaluator_env.seed(cfg.seed, dynamic_seed=False) + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + + model = Encoder(**cfg.policy.model, state_encoder=GoBiggerEncoder()) + policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) + + # load pretrained model + if model_path is not None: + policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device)) + + # Create worker components: learner, collector, evaluator, replay buffer, commander. 
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) + learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) + + # ============================================================== + # MCTS+RL algorithms related core code + # ============================================================== + policy_config = cfg.policy + batch_size = policy_config.batch_size + # specific game buffer for MCTS+RL algorithms + replay_buffer = GameBuffer(policy_config) + collector = Collector( + env=collector_env, + policy=policy.collect_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config + ) + evaluator = GoBiggerMuZeroEvaluator( + eval_freq=cfg.policy.eval_freq, + n_evaluator_episode=cfg.env.n_evaluator_episode, + stop_value=cfg.env.stop_value, + env=evaluator_env, + policy=policy.eval_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config + ) + + vsbot_evaluator = GoBiggerMuZeroEvaluator( + eval_freq=cfg.policy.eval_freq, + n_evaluator_episode=cfg.env.n_evaluator_episode, + stop_value=cfg.env.stop_value, + env=vsbot_evaluator_env, + policy=policy.eval_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config, + instance_name='vsbot_evaluator' + ) + + # ============================================================== + # Main loop + # ============================================================== + # Learner's before_run hook. + learner.call_hook('before_run') + + if cfg.policy.update_per_collect is not None: + update_per_collect = cfg.policy.update_per_collect + + # The purpose of collecting random data before training: + # Exploration: The collection of random data aids the agent in exploring the environment and prevents premature convergence to a suboptimal policy. + # Comparation: The agent's performance during random action-taking can be used as a reference point to evaluate the efficacy of reinforcement learning algorithms. + if cfg.policy.random_collect_episode_num > 0: + if policy_config.multi_agent: + from lzero.policy.multi_agent_random_policy import MultiAgentLightZeroRandomPolicy as RandomPolicy + else: + from lzero.policy.random_policy import LightZeroRandomPolicy as RandomPolicy + random_collect(cfg.policy, policy, RandomPolicy, collector, collector_env, replay_buffer) + + while True: + log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger) + collect_kwargs = {} + # set temperature for visit count distributions according to the train_iter, + # please refer to Appendix D in MuZero paper for details. + collect_kwargs['temperature'] = visit_count_temperature( + policy_config.manual_temperature_decay, + policy_config.fixed_temperature_value, + policy_config.threshold_training_steps_for_final_temperature, + trained_steps=learner.train_iter + ) + + if policy_config.eps.eps_greedy_exploration_in_collect: + epsilon_greedy_fn = get_epsilon_greedy_fn( + start=policy_config.eps.start, + end=policy_config.eps.end, + decay=policy_config.eps.decay, + type_=policy_config.eps.type + ) + collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep) + else: + collect_kwargs['epsilon'] = 0.0 + + # Evaluate policy performance. 
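When eps_greedy_exploration_in_collect is enabled, the epsilon schedule above uses the standard DI-engine helper; a small illustrative sketch follows (the start/end/decay values are placeholders, not taken from any config in this patch):

from ding.rl_utils import get_epsilon_greedy_fn

# Maps the current collector envstep to an exploration rate, as done for collect_kwargs above.
eps_fn = get_epsilon_greedy_fn(start=0.95, end=0.05, decay=int(1e5), type_='exp')
for envstep in (0, 10_000, 100_000):
    print(envstep, round(eps_fn(envstep), 3))  # decays from ~0.95 toward 0.05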
+ if evaluator.should_eval(learner.train_iter): + stop, reward = evaluator.eval(None, learner.train_iter, collector.envstep) # save_ckpt_fn = None + stop, reward = vsbot_evaluator.eval_vsbot(learner.save_checkpoint, learner.train_iter, collector.envstep) + if stop: + break + + # Collect data by default config n_sample/n_episode. + new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) + if cfg.policy.update_per_collect is None: + # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the model_update_ratio. + collected_transitions_num = sum([len(game_segment) for game_segment in new_data[0]]) + update_per_collect = int(collected_transitions_num * cfg.policy.model_update_ratio) + # save returned new_data collected by the collector + replay_buffer.push_game_segments(new_data) + # remove the oldest data if the replay buffer is full. + replay_buffer.remove_oldest_data_to_fit() + + # Learn policy from collected data. + for i in range(update_per_collect): + # Learner will train ``update_per_collect`` times in one iteration. + if replay_buffer.get_num_of_transitions() > batch_size: + train_data = replay_buffer.sample(batch_size, policy) + else: + logging.warning( + f'The data in replay_buffer is not sufficient to sample a mini-batch: ' + f'batch_size: {batch_size}, ' + f'{replay_buffer} ' + f'continue to collect now ....' + ) + break + + # The core train steps for MCTS+RL algorithms. + log_vars = learner.train(train_data, collector.envstep) + + if cfg.policy.use_priority: + replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig']) + + if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter: + break + + # Learner's after_run hook. 
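The adaptive update_per_collect rule above scales the number of gradient steps with the amount of freshly collected data; a small worked example with hypothetical numbers:

# Hypothetical segment lengths and ratio, purely for illustration.
new_segment_lengths = [400, 320]      # lengths of the game segments just collected
model_update_ratio = 0.25             # cfg.policy.model_update_ratio (illustrative value)
collected_transitions_num = sum(new_segment_lengths)                       # 720
update_per_collect = int(collected_transitions_num * model_update_ratio)   # 180
# The training loop then runs up to `update_per_collect` mini-batch updates, breaking out early
# (with a warning) whenever the replay buffer holds no more than batch_size transitions.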
+ learner.call_hook('after_run') + return policy diff --git a/zoo/gobigger/env/__init__.py b/zoo/gobigger/env/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/zoo/gobigger/env/gobigger_env.py b/zoo/gobigger/env/gobigger_env.py new file mode 100644 index 000000000..b9a2b0f63 --- /dev/null +++ b/zoo/gobigger/env/gobigger_env.py @@ -0,0 +1,548 @@ +import gym +import numpy as np +from ditk import logging +from ding.envs import BaseEnv, BaseEnvTimestep +from ding.utils import ENV_REGISTRY, deep_merge_dicts +import math +from easydict import EasyDict +try: + from gobigger.envs import GoBiggerEnv +except ImportError: + import sys + logging.warning("not found gobigger package, please install it through `pip install git+https://github.com/opendilab/GoBigger.git") + sys.exit(1) + + +default_t2p2_config = dict( + team_num=2, + player_num_per_team=2, + direction_num=12, + step_mul=8, + map_width=64, + map_height=64, + frame_limit=3600, + action_space_size=27, + use_action_mask=False, + reward_div_value=0.1, + reward_type='log_reward', + contain_raw_obs=False, # False on collect mode, True on eval vsbot mode, because bot need raw obs + start_spirit_progress=0.2, + end_spirit_progress=0.8, + manager_settings=dict( + food_manager=dict( + num_init=260, + num_min=260, + num_max=300, + ), + thorns_manager=dict( + num_init=3, + num_min=3, + num_max=4, + ), + player_manager=dict(ball_settings=dict(score_init=13000, ), ), + ), + playback_settings=dict( + playback_type='by_frame', + by_frame=dict( + save_frame=False, # when training should set as False + save_dir='./', + save_name_prefix='gobigger', + ), + ), + ) + + +@ENV_REGISTRY.register('gobigger_lightzero') +class GoBiggerLightZeroEnv(BaseEnv): + + def __init__(self, cfg: dict) -> None: + self._cfg = deep_merge_dicts(default_t2p2_config, cfg) + self._cfg = EasyDict(self._cfg) + # ding env info + self._init_flag = False + self._observation_space = None + self._action_space = None + self._reward_space = None + # gobigger env info + self.team_num = self._cfg.team_num + self.player_num_per_team = self._cfg.player_num_per_team + self.direction_num = self._cfg.direction_num + self.use_action_mask = self._cfg.use_action_mask + self.action_space_size = self._cfg.action_space_size # discrete action space size + self.step_mul = self._cfg.get('step_mul', 8) + self.setup_action() + self.setup_feature() + self.contain_raw_obs = self._cfg.contain_raw_obs # for save memory + + def setup_feature(self): + self.second_per_frame = 0.05 + self.spatial_x = 64 + self.spatial_y = 64 + self.max_ball_num = 80 + self.max_food_num = 256 + self.max_spore_num = 64 + self.max_player_num = self.player_num_per_team + self.reward_div_value = self._cfg.reward_div_value + self.reward_type = self._cfg.reward_type + self.player_init_score = self._cfg.manager_settings.player_manager.ball_settings.score_init + self.start_spirit_progress = self._cfg.start_spirit_progress + self.end_spirit_progress = self._cfg.end_spirit_progress + + def reset(self) -> np.ndarray: + if not self._init_flag: + self._env = GoBiggerEnv(self._cfg, step_mul=self.step_mul) + self._init_flag = True + self.last_action_types = { + player_id: self.direction_num * 2 + for player_id in range(self.player_num_per_team * self.team_num) + } + raw_obs = self._env.reset() + obs = self.observation(raw_obs) + self.last_action_types = { + player_id: self.direction_num * 2 + for player_id in range(self.player_num_per_team * self.team_num) + } + self.last_leaderboard = { + team_idx: self.player_init_score * 
self.player_num_per_team + for team_idx in range(self.team_num) + } + self.last_player_scores = { + player_id: self.player_init_score + for player_id in range(self.player_num_per_team * self.team_num) + } + return obs + + def observation(self, raw_obs): + obs = self.preprocess_obs(raw_obs) + # for alignment with other environments, reverse the action mask + action_mask = [np.logical_not(o['action_mask']) for o in obs] + to_play = [-1 for _ in range(len(obs))] # Moot, for alignment with other environments + if self.contain_raw_obs: + obs = {'observation': obs, 'action_mask': action_mask, 'to_play': to_play, 'raw_obs': raw_obs} + else: + obs = {'observation': obs, 'action_mask': action_mask, 'to_play': to_play} + return obs + + def postproecess(self, action_dict): + for k, v in action_dict.items(): + if np.isscalar(v): + self.last_action_types[k] = v + else: + self.last_action_types[k] = self.direction_num * 2 + + def step(self, action_dict: dict) -> BaseEnvTimestep: + action = {k: self.transform_action(v) if np.isscalar(v) else v for k, v in action_dict.items()} + raw_obs, raw_rew, done, info = self._env.step(action) + # print('current_frame={}'.format(raw_obs[0]['last_time'])) + # print('action={}'.format(action)) + # print('raw_rew={}, done={}'.format(raw_rew, done)) + rew = self.transform_reward(raw_obs) + obs = self.observation(raw_obs) + # postprocess + self.postproecess(action_dict) + if done: + info['eval_episode_return'] = raw_obs[0]['leaderboard'][0] + info['eval_bot_episode_return'] = raw_obs[0]['leaderboard'][1] #TODO only support t2p2 + return BaseEnvTimestep(obs, rew, done, info) + + def seed(self, seed: int, dynamic_seed: bool = True) -> None: + self._seed = seed + self._dynamic_seed = dynamic_seed + np.random.seed(self._seed) + + def close(self) -> None: + if self._init_flag: + self._env.close() + self._init_flag = False + + @property + def observation_space(self) -> gym.spaces.Space: + # The following ensures compatibility with the DI-engine Env class. + return self._observation_space + + @property + def action_space(self) -> gym.spaces.Space: + # The following ensures compatibility with the DI-engine Env class. + return self._action_space + + @property + def reward_space(self) -> gym.spaces.Space: + # The following ensures compatibility with the DI-engine Env class. 
+ return self._reward_space + + def __repr__(self) -> str: + return "LightZero Env({})".format(self.cfg.env_name) + + def transform_obs( + self, + obs, + own_player_id=1, + padding=True, + last_action_type=None, + ): + global_state, player_observations = obs + player2team = self.get_player2team() + leaderboard = global_state['leaderboard'] + team2rank = {key: rank for rank, key in enumerate(sorted(leaderboard, key=leaderboard.get, reverse=True), )} + + own_player_obs = player_observations[own_player_id] + own_team_id = player2team[own_player_id] + + # =========== + # scalar info + # =========== + scene_size = global_state['border'][0] + own_left_top_x, own_left_top_y, own_right_bottom_x, own_right_bottom_y = own_player_obs['rectangle'] + own_view_center = [ + (own_left_top_x + own_right_bottom_x - scene_size) / 2, + (own_left_top_y + own_right_bottom_y - scene_size) / 2 + ] + # own_view_width == own_view_height + own_view_width = float(own_right_bottom_x - own_left_top_x) + + own_score = own_player_obs['score'] / 100 + own_team_score = global_state['leaderboard'][own_team_id] / 100 + own_rank = team2rank[own_team_id] + + scalar_info = { + 'view_x': np.round(np.array(own_view_center[0])).astype(np.int64), + 'view_y': np.round(np.array(own_view_center[1])).astype(np.int64), + 'view_width': np.round(np.array(own_view_width)).astype(np.int64), + 'score': np.clip(np.round(np.log(np.array(own_score) / 10)).astype(np.int64), a_min=None, a_max=9), + 'team_score': np.clip( + np.round(np.log(np.array(own_team_score / 10))).astype(np.int64), a_min=None, a_max=9 + ), + 'time': np.array(global_state['last_time'] // 20, dtype=np.int64), + 'rank': np.array(own_rank, dtype=np.int64), + 'last_action_type': np.array(last_action_type, dtype=np.int64) + } + + # =========== + # team_info + # =========== + + all_players = [] + scene_size = global_state['border'][0] + + for game_player_id in player_observations.keys(): + game_team_id = player2team[game_player_id] + game_player_left_top_x, game_player_left_top_y, game_player_right_bottom_x, game_player_right_bottom_y = \ + player_observations[game_player_id]['rectangle'] + if game_player_id == own_player_id: + alliance = 0 + elif game_team_id == own_team_id: + alliance = 1 + else: + alliance = 2 + if alliance != 2: + game_player_view_x = (game_player_right_bottom_x + game_player_left_top_x - scene_size) / 2 + game_player_view_y = (game_player_right_bottom_y + game_player_left_top_y - scene_size) / 2 + + all_players.append([ + alliance, + game_player_view_x, + game_player_view_y, + ]) + + all_players = np.array(all_players) + player_padding_num = self.max_player_num - len(all_players) + player_num = len(all_players) + if player_padding_num < 0: + all_players = all_players[:self.max_player_num, :] + else: + all_players = np.pad(all_players, pad_width=((0, player_padding_num), (0, 0)), mode='constant') + team_info = { + 'alliance': all_players[:, 0].astype(np.int64), + 'view_x': np.round(all_players[:, 1]).astype(np.int64), + 'view_y': np.round(all_players[:, 2]).astype(np.int64), + 'player_num': np.array(player_num, dtype=np.int64), + } + + # =========== + # ball info + # =========== + ball_type_map = {'clone': 1, 'food': 2, 'thorns': 3, 'spore': 4} + clone = own_player_obs['overlap']['clone'] + thorns = own_player_obs['overlap']['thorns'] + food = own_player_obs['overlap']['food'] + spore = own_player_obs['overlap']['spore'] + + neutral_team_id = self.team_num + neutral_player_id = self.team_num * self.player_num_per_team + neutral_team_rank = self.team_num + + # 
clone = [type, score, player_id, team_id, team_rank, x, y, next_x, next_y] + clone = [ + [ + ball_type_map['clone'], bl[3], bl[-2], bl[-1], team2rank[bl[-1]], bl[0], bl[1], + *self.next_position(bl[0], bl[1], bl[4], bl[5]) + ] for bl in clone + ] + + # thorn = [type, score, player_id, team_id, team_rank, x, y, next_x, next_y] + thorns = [ + [ + ball_type_map['thorns'], bl[3], neutral_player_id, neutral_team_id, neutral_team_rank, bl[0], bl[1], + *self.next_position(bl[0], bl[1], bl[4], bl[5]) + ] for bl in thorns + ] + + # thorn = [type, score, player_id, team_id, team_rank, x, y, next_x, next_y] + food = [ + [ + ball_type_map['food'], bl[3], neutral_player_id, neutral_team_id, neutral_team_rank, bl[0], bl[1], + bl[0], bl[1] + ] for bl in food + ] + + # spore = [type, score, player_id, team_id, team_rank, x, y, next_x, next_y] + spore = [ + [ + ball_type_map['spore'], bl[3], bl[-1], player2team[bl[-1]], team2rank[player2team[bl[-1]]], bl[0], + bl[1], *self.next_position(bl[0], bl[1], bl[4], bl[5]) + ] for bl in spore + ] + + all_balls = clone + thorns + food + spore + + # Particularly handle balls outside the field of view + for b in all_balls: + if b[2] == own_player_id and b[0] == 1: + if b[5] < own_left_top_x or b[5] > own_right_bottom_x or \ + b[6] < own_left_top_y or b[6] > own_right_bottom_y: + b[5] = int((own_left_top_x + own_right_bottom_x) / 2) + b[6] = int((own_left_top_y + own_right_bottom_y) / 2) + b[7], b[8] = b[5], b[6] + all_balls = np.array(all_balls) + + origin_x = own_left_top_x + origin_y = own_left_top_y + + all_balls[:, -4] = ((all_balls[:, -4] - origin_x) / own_view_width * self.spatial_x) + all_balls[:, -3] = ((all_balls[:, -3] - origin_y) / own_view_width * self.spatial_y) + all_balls[:, -2] = ((all_balls[:, -2] - origin_x) / own_view_width * self.spatial_x) + all_balls[:, -1] = ((all_balls[:, -1] - origin_y) / own_view_width * self.spatial_y) + + # ball + ball_indices = np.logical_and( + all_balls[:, 0] != 2, all_balls[:, 0] != 4 + ) # include player balls and thorn balls + balls = all_balls[ball_indices] + + balls_num = len(balls) + + # consider position of thorns ball + if balls_num > self.max_ball_num: # filter small balls + own_indices = balls[:, 3] == own_player_id + teammate_indices = (balls[:, 4] == own_team_id) & ~own_indices + enemy_indices = balls[:, 4] != own_team_id + + own_balls = balls[own_indices] + teammate_balls = balls[teammate_indices] + enemy_balls = balls[enemy_indices] + + if own_balls.shape[0] + teammate_balls.shape[0] >= self.max_ball_num: + remain_ball_num = self.max_ball_num - own_balls.shape[0] + teammate_ball_score = teammate_balls[:, 1] + teammate_high_score_indices = teammate_ball_score.sort(descending=True)[1][:remain_ball_num] + teammate_remain_balls = teammate_balls[teammate_high_score_indices] + balls = np.concatenate([own_balls, teammate_remain_balls], axis=0) + else: + remain_ball_num = self.max_ball_num - own_balls.shape[0] - teammate_balls.shape[0] + enemy_ball_score = enemy_balls[:, 1] + enemy_high_score_ball_indices = enemy_ball_score.sort(descending=True)[1][:remain_ball_num] + remain_enemy_balls = enemy_balls[enemy_high_score_ball_indices] + + balls = np.concatenate([own_balls, teammate_balls, remain_enemy_balls], axis=0) + balls_num = len(balls) + ball_padding_num = self.max_ball_num - len(balls) + if ball_padding_num < 0: + balls = balls[:self.max_ball_num, :] + alliance = np.zeros(self.max_ball_num) + balls_num = self.max_ball_num + elif padding: + balls = np.pad(balls, ((0, ball_padding_num), (0, 0)), 'constant', 
constant_values=0) + alliance = np.zeros(self.max_ball_num) + balls_num = min(self.max_ball_num, balls_num) + else: + alliance = np.zeros(balls_num) + alliance[balls[:, 3] == own_team_id] = 2 + alliance[balls[:, 2] == own_player_id] = 1 + alliance[balls[:, 3] != own_team_id] = 3 + alliance[balls[:, 0] == 3] = 0 + + ## score&radius + scale_score = balls[:, 1] / 100 + radius = np.clip(np.sqrt(scale_score * 0.042 + 0.15) / own_view_width, a_max=1, a_min=None) + score = np.clip( + np.round(np.clip(np.sqrt(scale_score * 0.042 + 0.15) / own_view_width, a_max=1, a_min=None) * 50 + ).astype(int), + a_max=49, + a_min=None + ) + ## rank: + ball_rank = balls[:, 4] + + ## coordinates relative to the center of [spatial_x, spatial_y] + x = balls[:, -4] - self.spatial_x // 2 + y = balls[:, -3] - self.spatial_y // 2 + next_x = balls[:, -2] - self.spatial_x // 2 + next_y = balls[:, -1] - self.spatial_y // 2 + + ball_info = { + 'alliance': alliance.astype(np.int64), + 'score': score.astype(np.int64), + 'radius': radius, + 'rank': ball_rank.astype(np.int64), + 'x': np.round(x).astype(np.int64), + 'y': np.round(y).astype(np.int64), + 'next_x': np.round(next_x).astype(np.int64), + 'next_y': np.round(next_y).astype(np.int64), + 'ball_num': np.array(balls_num).astype(np.int64), + } + + # ============ + # spatial info + # ============ + # ball coordinate for scatter connection + # coordinates relative to the upper left corner of [spatial_x, spatial_y] + ball_x = balls[:, -4] + ball_y = balls[:, -3] + + food_indices = all_balls[:, 0] == 2 + food_x = all_balls[food_indices, -4] + food_y = all_balls[food_indices, -3] + food_num = len(food_x) + food_padding_num = self.max_food_num - len(food_x) + if food_padding_num < 0: + food_x = food_x[:self.max_food_num] + food_y = food_y[:self.max_food_num] + elif padding: + food_x = np.pad(food_x, (0, food_padding_num), 'constant', constant_values=0) + food_y = np.pad(food_y, (0, food_padding_num), 'constant', constant_values=0) + food_num = min(food_num, self.max_food_num) + + spore_indices = all_balls[:, 0] == 4 + spore_x = all_balls[spore_indices, -4] + spore_y = all_balls[spore_indices, -3] + spore_num = len(spore_x) + spore_padding_num = self.max_spore_num - len(spore_x) + if spore_padding_num < 0: + spore_x = spore_x[:self.max_spore_num] + spore_y = spore_y[:self.max_spore_num] + elif padding: + spore_x = np.pad(spore_x, (0, spore_padding_num), 'constant', constant_values=0) + spore_y = np.pad(spore_y, (0, spore_padding_num), 'constant', constant_values=0) + spore_num = min(spore_num, self.max_spore_num) + + spatial_info = { + 'food_x': np.clip(np.round(food_x), 0, self.spatial_x - 1).astype(np.int64), + 'food_y': np.clip(np.round(food_y), 0, self.spatial_y - 1).astype(np.int64), + 'spore_x': np.clip(np.round(spore_x), 0, self.spatial_x - 1).astype(np.int64), + 'spore_y': np.clip(np.round(spore_y), 0, self.spatial_y - 1).astype(np.int64), + 'ball_x': np.clip(np.round(ball_x), 0, self.spatial_x - 1).astype(np.int64), + 'ball_y': np.clip(np.round(ball_y), 0, self.spatial_y - 1).astype(np.int64), + 'food_num': np.array(food_num).astype(np.int64), + 'spore_num': np.array(spore_num).astype(np.int64), + } + + output_obs = { + 'scalar_info': scalar_info, + 'team_info': team_info, + 'ball_info': ball_info, + 'spatial_info': spatial_info, + } + return output_obs + + def preprocess_obs(self, raw_obs): + env_player_obs = [] + for game_player_id in range(self.player_num_per_team * self.team_num): + last_action_type = self.last_action_types[game_player_id] + if self.use_action_mask: + 
can_eject = raw_obs[1][game_player_id]['can_eject'] + can_split = raw_obs[1][game_player_id]['can_split'] + action_mask = self.generate_action_mask(can_eject=can_eject, can_split=can_split) + else: + action_mask = self.generate_action_mask(can_eject=True, can_split=True) + game_player_obs = self.transform_obs( + raw_obs, own_player_id=game_player_id, padding=True, last_action_type=last_action_type + ) + game_player_obs['action_mask'] = action_mask + env_player_obs.append(game_player_obs) + return env_player_obs + + def generate_action_mask(self, can_eject, can_split): + # action mask + # 1 represent can not do this action + # 0 represent can do this action + action_mask = np.zeros((self.action_space_size, ), dtype=np.bool_) + if not can_eject: + action_mask[self.direction_num * 2 + 1] = True + if not can_split: + action_mask[self.direction_num * 2 + 2] = True + return action_mask + + def get_player2team(self, ): + player2team = {} + for player_id in range(self.player_num_per_team * self.team_num): + player2team[player_id] = player_id // self.player_num_per_team + return player2team + + def next_position(self, x, y, vel_x, vel_y): + next_x = x + self.second_per_frame * vel_x * self.step_mul + next_y = y + self.second_per_frame * vel_y * self.step_mul + return next_x, next_y + + def transform_action(self, action_idx): + return self.x_y_action_List[int(action_idx)] + + def setup_action(self): + theta = math.pi * 2 / self.direction_num + self.x_y_action_List = [[0.3 * math.cos(theta * i), 0.3 * math.sin(theta * i), 0] for i in range(self.direction_num)] + \ + [[math.cos(theta * i), math.sin(theta * i), 0] for i in range(self.direction_num)] + \ + [[0, 0, 0], [0, 0, 1], [0, 0, 2]] + + def get_spirit(self, progress): + if progress < self.start_spirit_progress: + return 0 + elif progress <= self.end_spirit_progress: + spirit = (progress - self.start_spirit_progress) / (self.end_spirit_progress - self.start_spirit_progress) + return spirit + else: + return 1 + + def transform_reward(self, next_obs): + last_time = next_obs[0]['last_time'] + total_frame = next_obs[0]['total_frame'] + progress = last_time / total_frame + spirit = self.get_spirit(progress) + score_rewards_list = [] + for game_player_id in range(self.player_num_per_team * self.team_num): + game_team_id = game_player_id // self.player_num_per_team + player_score = next_obs[1][game_player_id]['score'] + team_score = next_obs[0]['leaderboard'][game_team_id] + if self.reward_type == 'log_reward': + player_reward = math.log(player_score) - math.log(self.last_player_scores[game_player_id]) + team_reward = math.log(team_score) - math.log(self.last_leaderboard[game_team_id]) + score_reward = (1 - spirit) * player_reward + spirit * team_reward / self.player_num_per_team + score_reward = score_reward / self.reward_div_value + score_rewards_list.append(score_reward) + elif self.reward_type == 'score': + player_reward = player_score - self.last_player_scores[game_player_id] + team_reward = team_score - self.last_leaderboard[game_team_id] + score_reward = (1 - spirit) * player_reward + spirit * team_reward / self.player_num_per_team + score_reward = score_reward / self.reward_div_value + score_rewards_list.append(score_reward) + elif self.reward_type == 'sqrt_player': + player_reward = player_score - self.last_player_scores[game_player_id] + reward_sign = (player_reward > 0) - (player_reward < 0) # np.sign + score_rewards_list.append(reward_sign * math.sqrt(abs(player_reward)) / 2) + elif self.reward_type == 'sqrt_team': + team_reward = team_score 
- self.last_leaderboard[game_team_id] + reward_sign = (team_reward > 0) - (team_reward < 0) # np.sign + score_rewards_list.append(reward_sign * math.sqrt(abs(team_reward)) / 2) + else: + raise NotImplementedError + self.last_player_scores[game_player_id] = player_score + self.last_leaderboard = next_obs[0]['leaderboard'] + return score_rewards_list diff --git a/zoo/gobigger/env/gobigger_rule_bot.py b/zoo/gobigger/env/gobigger_rule_bot.py new file mode 100644 index 000000000..916d57bae --- /dev/null +++ b/zoo/gobigger/env/gobigger_rule_bot.py @@ -0,0 +1,216 @@ +import copy +from ding.policy.base_policy import Policy +from ding.utils import POLICY_REGISTRY +import torch +import math +import queue +import random +import numpy as np +from typing import List, Dict, Any, Optional, Tuple, Union +from collections import namedtuple +from collections import defaultdict + + +@POLICY_REGISTRY.register('gobigger_bot') +class GoBiggerBot(Policy): + + def __init__(self, env_num, agent_id: List[int]): + self.env_num = env_num + self.agent_id = agent_id + self.bot = [[BotAgent(i) for i in self.agent_id] for _ in range(self.env_num)] + + def forward(self, raw_obs): + action = defaultdict(dict) + for env_id in range(self.env_num): + obs = raw_obs[env_id] + for agent in self.bot[env_id]: + action[env_id].update(agent.step(obs)) + return action + + def reset(self, env_id_lst=None): + if env_id_lst is None: + env_id_lst = range(self.env_num) + for env_id in env_id_lst: + for agent in self.bot[env_id]: + agent.reset() + + # The following ensures compatibility with the DI-engine Policy class. + def _init_learn(self) -> None: + pass + + def _init_collect(self) -> None: + pass + + def _init_eval(self) -> None: + pass + + def _forward_learn(self, data: dict) -> dict: + pass + + def _forward_collect(self, envs: Dict, obs: Dict, temperature: float = 1) -> Dict[str, torch.Tensor]: + pass + + def _forward_eval(self, data: dict) -> dict: + pass + + def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict: + pass + + def _get_train_sample(self, data: list) -> Union[None, List[Any]]: + pass + + def default_model(self) -> Tuple[str, List[str]]: + return 'bot_model', ['lzero.model.bot_model'] + + def _monitor_vars_learn(self) -> List[str]: + pass + + +class BotAgent(): + + def __init__(self, game_player_id): + self.game_player_id = game_player_id # start from 0 + self.actions_queue = queue.Queue() + + def step(self, obs): + obs = obs[1][self.game_player_id] + if self.actions_queue.qsize() > 0: + return {self.game_player_id: self.actions_queue.get()} + overlap = obs['overlap'] + overlap = self.preprocess(overlap) + food_balls = overlap['food'] + thorns_balls = overlap['thorns'] + spore_balls = overlap['spore'] + clone_balls = overlap['clone'] + + my_clone_balls, others_clone_balls = self.process_clone_balls(clone_balls) + + if len(my_clone_balls) >= 9 and my_clone_balls[4]['radius'] > 4: + self.actions_queue.put([0, 0, 0]) + self.actions_queue.put([0, 0, 0]) + self.actions_queue.put([0, 0, 0]) + self.actions_queue.put([0, 0, 0]) + self.actions_queue.put([0, 0, 0]) + self.actions_queue.put([0, 0, 0]) + self.actions_queue.put([0, 0, 0]) + self.actions_queue.put([None, None, 1]) + self.actions_queue.put([None, None, 1]) + self.actions_queue.put([None, None, 1]) + self.actions_queue.put([None, None, 1]) + self.actions_queue.put([None, None, 1]) + self.actions_queue.put([None, None, 1]) + self.actions_queue.put([None, None, 1]) + self.actions_queue.put([None, None, 1]) + action_ret = 
self.actions_queue.get() + return {self.game_player_id: action_ret} + + if len(others_clone_balls) > 0 and self.can_eat(others_clone_balls[0]['radius'], my_clone_balls[0]['radius']): + direction = (my_clone_balls[0]['position'] - others_clone_balls[0]['position']) + action_type = 0 + else: + min_distance, min_thorns_ball = self.process_thorns_balls(thorns_balls, my_clone_balls[0]) + if min_thorns_ball is not None: + direction = (min_thorns_ball['position'] - my_clone_balls[0]['position']) + else: + min_distance, min_food_ball = self.process_food_balls(food_balls, my_clone_balls[0]) + if min_food_ball is not None: + direction = (min_food_ball['position'] - my_clone_balls[0]['position']) + else: + direction = (np.array([0, 0]) - my_clone_balls[0]['position']) + action_random = random.random() + if action_random < 0.02: + action_type = 1 + if action_random < 0.04 and action_random > 0.02: + action_type = 2 + else: + action_type = 0 + if np.linalg.norm(direction) > 0: + direction = direction / np.linalg.norm(direction) + else: + direction = np.array([1,1]) / np.linalg.norm(np.array([1,1])) + direction = self.add_noise_to_direction(direction) + direction = direction / np.linalg.norm(direction) + self.actions_queue.put([direction[0], direction[1], action_type]) + action_ret = self.actions_queue.get() + return {self.game_player_id: action_ret} + + def process_clone_balls(self, clone_balls): + my_clone_balls = [] + others_clone_balls = [] + for clone_ball in clone_balls: + if clone_ball['player'] == self.game_player_id: + my_clone_balls.append(copy.deepcopy(clone_ball)) + my_clone_balls.sort(key=lambda a: a['radius'], reverse=True) + + for clone_ball in clone_balls: + if clone_ball['player'] != self.game_player_id: + others_clone_balls.append(copy.deepcopy(clone_ball)) + others_clone_balls.sort(key=lambda a: a['radius'], reverse=True) + return my_clone_balls, others_clone_balls + + def process_thorns_balls(self, thorns_balls, my_max_clone_ball): + min_distance = 10000 + min_thorns_ball = None + for thorns_ball in thorns_balls: + if self.can_eat(my_max_clone_ball['radius'], thorns_ball['radius']): + distance = np.linalg.norm((thorns_ball['position'] - my_max_clone_ball['position'])) + if distance < min_distance: + min_distance = distance + min_thorns_ball = copy.deepcopy(thorns_ball) + return min_distance, min_thorns_ball + + def process_food_balls(self, food_balls, my_max_clone_ball): + min_distance = 10000 + min_food_ball = None + for food_ball in food_balls: + distance = np.linalg.norm(food_ball['position'] - my_max_clone_ball['position']) + if distance < min_distance: + min_distance = distance + min_food_ball = copy.deepcopy(food_ball) + return min_distance, min_food_ball + + def preprocess(self, overlap): + new_overlap = {} + for k, v in overlap.items(): + if k == 'clone': + new_overlap[k] = [] + for index, vv in enumerate(v): + tmp = {} + tmp['position'] = np.array([vv[0], vv[1]]) + tmp['radius'] = vv[2] + tmp['player'] = int(vv[-2]) + tmp['team'] = int(vv[-1]) + new_overlap[k].append(tmp) + else: + new_overlap[k] = [] + for index, vv in enumerate(v): + tmp = {} + tmp['position'] = np.array([vv[0], vv[1]]) + tmp['radius'] = vv[2] + new_overlap[k].append(tmp) + return new_overlap + + def preprocess_tuple2vector(self, overlap): + new_overlap = {} + for k, v in overlap.items(): + new_overlap[k] = [] + for index, vv in enumerate(v): + new_overlap[k].append(vv) + new_overlap[k][index]['position'] = np.array(*vv['position']) + return new_overlap + + def add_noise_to_direction(self, direction, 
noise_ratio=0.1): + direction = direction + np.array( + ((random.random() * 2 - 1) * noise_ratio) * direction[0], + ((random.random() * 2 - 1) * noise_ratio) * direction[1] + ) + return direction + + def radius_to_score(self, radius): + return (math.pow(radius, 2) - 0.15) / 0.042 * 100 + + def can_eat(self, radius1, radius2): + return self.radius_to_score(radius1) > 1.3 * self.radius_to_score(radius2) + + def reset(self, ): + self.actions_queue.queue.clear() diff --git a/zoo/gobigger/env/test_gobigger_env.py b/zoo/gobigger/env/test_gobigger_env.py new file mode 100644 index 000000000..416bf4c43 --- /dev/null +++ b/zoo/gobigger/env/test_gobigger_env.py @@ -0,0 +1,65 @@ +import pytest +from easydict import EasyDict +from gobigger_env import GoBiggerLightZeroEnv +from gobigger_rule_bot import BotAgent + +env_cfg = EasyDict( + dict( + env_name='gobigger', + team_num=2, + player_num_per_team=2, + direction_num=12, + step_mul=8, + map_width=64, + map_height=64, + frame_limit=3600, + action_space_size=27, + use_action_mask=False, + reward_div_value=0.1, + reward_type='log_reward', + contain_raw_obs=True, # False on collect mode, True on eval vsbot mode, because bot need raw obs + start_spirit_progress=0.2, + end_spirit_progress=0.8, + manager_settings=dict( + food_manager=dict( + num_init=260, + num_min=260, + num_max=300, + ), + thorns_manager=dict( + num_init=3, + num_min=3, + num_max=4, + ), + player_manager=dict(ball_settings=dict(score_init=13000, ), ), + ), + playback_settings=dict( + playback_type='by_frame', + by_frame=dict( + # save_frame=False, + save_frame=True, + save_dir='./', + save_name_prefix='test', + ), + ), + ) +) + + +@pytest.mark.envtest +class TestGoBiggerLightZeroEnv: + + def test_env(self): + env = GoBiggerLightZeroEnv(env_cfg) + obs = env.reset() + from gobigger_rule_bot import BotAgent + bot = [BotAgent(i) for i in range(4)] + while True: + actions = {} + for i in range(4): + # bot[i].step(obs['raw_obs'] is dict + actions.update(bot[i].step(obs['raw_obs'])) + obs, rew, done, info = env.step(actions) + print(rew, info) + if done: + break diff --git a/zoo/gobigger/model/__init__.py b/zoo/gobigger/model/__init__.py new file mode 100644 index 000000000..c7ba3e2e0 --- /dev/null +++ b/zoo/gobigger/model/__init__.py @@ -0,0 +1 @@ +from .model import GoBiggerEncoder \ No newline at end of file diff --git a/zoo/gobigger/model/encoder.py b/zoo/gobigger/model/encoder.py new file mode 100644 index 000000000..a42de3282 --- /dev/null +++ b/zoo/gobigger/model/encoder.py @@ -0,0 +1,311 @@ +import numpy as np +import torch +import torch.nn as nn + + +class OnehotEncoder(nn.Module): + """ + Overview: + For encoding integers into one-hot vectors using PyTorch's Embedding layer. + """ + def __init__(self, num_embeddings: int): + """ + Overview: + The initializer for OnehotEncoder. It initializes an Embedding layer with an identity matrix as its weights. + Arguments: + - num_embeddings (int): The size of the dictionary of embeddings, i.e., the number of rows in the embedding matrix. + """ + super(OnehotEncoder, self).__init__() + self.num_embeddings = num_embeddings + self.main = nn.Embedding.from_pretrained(torch.eye(self.num_embeddings), freeze=True, padding_idx=None) + + def forward(self, x: torch.Tensor): + """ + Overview: + The common computation graph of the OnehotEncoder. It encodes the input tensor into one-hot vectors. + Arguments: + - x (torch.Tensor): The input tensor should be integers and its maximum value should be less than the 'num_embeddings' specified in the initializer. 
+ Returns: + - (torch.Tensor): Return the x-th row of the embedding layer. + """ + x = x.long().clamp_(max=self.num_embeddings - 1) + return self.main(x) + + +class OnehotEmbedding(nn.Module): + """ + Overview: + For encoding integers into higher-dimensional embedding vectors using PyTorch's Embedding layer. + """ + def __init__(self, num_embeddings: int, embedding_dim: int): + """ + Overview: + The initializer for OnehotEmbedding. It initializes an Embedding layer with 'num_embeddings' rows and 'embedding_dim' columns. + Arguments: + - num_embeddings (int): The size of the dictionary of embeddings, i.e., the number of rows in the embedding matrix. + - embedding_dim (int): The size of each embedding vector, i.e., the number of columns in the embedding matrix. + """ + super(OnehotEmbedding, self).__init__() + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.main = nn.Embedding(num_embeddings=self.num_embeddings, embedding_dim=self.embedding_dim) + + def forward(self, x: torch.Tensor): + """ + Overview: + The common computation graph of the OnehotEmbedding. + It encodes the input tensor into higher-dimensional embedding vectors. + Arguments: + - x (torch.Tensor): The input tensor should be integers, and its maximum value should be less than the 'num_embeddings' specified in the initializer. + Returns: + - (torch.Tensor): Return the x-th row of the embedding layer. + """ + x = x.long().clamp_(max=self.num_embeddings - 1) + return self.main(x) + + +class BinaryEncoder(nn.Module): + """ + Overview: + For encoding integers into binary vectors using PyTorch's Embedding layer. + """ + def __init__(self, num_embeddings: int): + """ + Overview: + The initializer for BinaryEncoder. It initializes an Embedding layer with a binary embedding matrix + as its weights. The binary embedding matrix is constructed by representing each integer (from 0 to 2^bit_num-1) + as a binary vector. + Arguments: + - num_embeddings (int): The number of bits in the binary representation. It determines the size of the dictionary of embeddings, + i.e., the number of rows in the embedding matrix (2^bit_num), and the size of each embedding vector, + i.e., the number of columns in the embedding matrix (bit_num). + """ + super(BinaryEncoder, self).__init__() + self.bit_num = num_embeddings + self.main = nn.Embedding.from_pretrained( + self.get_binary_embed_matrix(self.bit_num), freeze=True, padding_idx=None + ) + + @staticmethod + def get_binary_embed_matrix(bit_num): + """ + Overview: + A helper function that generates the binary embedding matrix. + Arguments: + - bit_num (int): The number of bits in the binary representation. + Returns: + - (torch.Tensor): A tensor of shape (2^bit_num, bit_num), where each row is the binary representation of the row index. + """ + embedding_matrix = [] + for n in range(2 ** bit_num): + embedding = [n >> d & 1 for d in range(bit_num)][::-1] + embedding_matrix.append(embedding) + return torch.tensor(embedding_matrix, dtype=torch.float) + + def forward(self, x: torch.Tensor): + """ + Overview: + The common computation graph of the BinaryEncoder. It encodes the input tensor into binary vectors. + Arguments: + - x (torch.Tensor): The input tensor should be integers, and its maximum value should be less than 2^bit_num. + Returns: + - (torch.Tensor): Return the x-th row of the embedding layer. 
+ """ + x = x.long().clamp_(max=2 ** self.bit_num - 1) + return self.main(x) + + +class SignBinaryEncoder(nn.Module): + """ + Overview: + For encoding integers into signed binary vectors using PyTorch's Embedding layer. + """ + def __init__(self, num_embeddings: int): + """ + Overview: + The initializer for SignBinaryEncoder. It initializes an Embedding layer with a signed binary embedding matrix + as its weights. The signed binary embedding matrix is constructed by representing each integer (from -2^(bit_num-1) to 2^(bit_num-1)-1) + as a signed binary vector. The first bit is the sign bit, with 1 representing negative and 0 representing nonnegative. + Arguments: + - num_embeddings (int): The number of bits in the signed binary representation. It determines the size of the dictionary of embeddings, + i.e., the number of rows in the embedding matrix (2^bit_num), and the size of each embedding vector, + i.e., the number of columns in the embedding matrix (bit_num). + """ + super(SignBinaryEncoder, self).__init__() + self.bit_num = num_embeddings + self.main = nn.Embedding.from_pretrained( + self.get_sign_binary_matrix(self.bit_num), freeze=True, padding_idx=None + ) + self.max_val = 2 ** (self.bit_num - 1) - 1 + + @staticmethod + def get_sign_binary_matrix(bit_num): + """ + Overview: + A helper function that generates the signed binary embedding matrix. + Arguments: + - bit_num (int): The number of bits in the signed binary representation. + Returns: + - (torch.Tensor): A tensor of shape (2^bit_num, bit_num), where each row is the signed binary representation of the row index minus 2^(bit_num-1). + The first column is the sign bit, with 1 representing negative and 0 representing nonnegative. + """ + neg_embedding_matrix = [] + pos_embedding_matrix = [] + for n in range(1, 2 ** (bit_num - 1)): + embedding = [n >> d & 1 for d in range(bit_num - 1)][::-1] + neg_embedding_matrix.append([1] + embedding) + pos_embedding_matrix.append([0] + embedding) + embedding_matrix = neg_embedding_matrix[::-1] + [[0 for _ in range(bit_num)]] + pos_embedding_matrix + return torch.tensor(embedding_matrix, dtype=torch.float) + + def forward(self, x: torch.Tensor): + """ + Overview: + The common computation graph of the SignBinaryEncoder. It encodes the input tensor into signed binary vectors. + Arguments: + - x (torch.Tensor): The input tensor. Its data type should be integers, and its maximum absolute value should be less than 2^(bit_num-1). + Returns: + - (torch.Tensor): Return the x-th row of the embedding layer. + """ + x = x.long().clamp_(max=self.max_val, min=-self.max_val) + return self.main(x + self.max_val) + + +class PositionEncoder(nn.Module): + """ + Overview: + For encoding the position of elements into higher-dimensional vectors using PyTorch's Embedding layer. + This is typically used in Transformer models to add positional information to the input tokens. + The position encoding is initialized using a sinusoidal formula, as proposed in the "Attention is All You Need" paper. + """ + def __init__(self, num_embeddings: int, embedding_dim: int = None): + """ + Overview: + The initializer for PositionEncoder. It initializes an Embedding layer with a sinusoidal position encoding matrix + as its weights. + Arguments: + - num_embeddings (int): The maximum number of positions to be encoded, i.e., the number of rows in the position encoding matrix. + - embedding_dim (int, optional): The size of each position encoding vector, i.e., the number of columns in the position encoding matrix. 
+ If not provided, it is set equal to 'num_embeddings'. + """ + super(PositionEncoder, self).__init__() + self.n_position = num_embeddings + self.embedding_dim = self.n_position if embedding_dim is None else embedding_dim + self.position_enc = nn.Embedding.from_pretrained( + self.position_encoding_init(self.n_position, self.embedding_dim), freeze=True, padding_idx=None + ) + + @staticmethod + def position_encoding_init(n_position, embedding_dim): + """ + Overview: + A helper function that generates the sinusoidal position encoding matrix. + Arguments: + - n_position (int): The maximum number of positions to be encoded. + - embedding_dim (int): The size of each position encoding vector. + Returns: + - (torch.Tensor): A tensor of shape (n_position, embedding_dim), where each row is the sinusoidal position encoding of the row index. + """ + position_enc = np.array( + [ + [pos / np.power(10000, 2 * (j // 2) / embedding_dim) for j in range(embedding_dim)] + for pos in range(n_position) + ] + ) + position_enc[:, 0::2] = np.sin(position_enc[:, 0::2]) # apply sin on 0th,2nd,4th...embedding_dim + position_enc[:, 1::2] = np.cos(position_enc[:, 1::2]) # apply cos on 1st,3rd,5th...embedding_dim + return torch.from_numpy(position_enc).type(torch.FloatTensor) + + def forward(self, x: torch.Tensor): + """ + Overview: + The common computation graph of the PositionEncoder. It encodes the input tensor into positional vectors. + Arguments: + - x (torch.Tensor): The input tensor should be integers, and its maximum value should be less than 'num_embeddings' specified in the initializer. Each value in 'x' represents a position to be encoded. + Returns: + - (torch.Tensor): Return the x-th row of the embedding layer. + """ + return self.position_enc(x) + + +class TimeEncoder(nn.Module): + """ + Overview: + For encoding temporal or sequential data into higher-dimensional vectors using the sinusoidal position + encoding mechanism used in Transformer models. This is useful when working with time series data or sequences where + the position of each element (in this case time) is important. + """ + def __init__(self, embedding_dim: int): + """ + Overview: + The initializer for TimeEncoder. It initializes the position array which is used to scale the input data in the sinusoidal encoding function. + Arguments: + - embedding_dim (int): The size of each position encoding vector, i.e., the number of features in the encoded representation. + """ + super(TimeEncoder, self).__init__() + self.embedding_dim = embedding_dim + self.position_array = torch.nn.Parameter(self.get_position_array(), requires_grad=False) + + def get_position_array(self): + """ + Overview: + A helper function that generates the position array used in the sinusoidal encoding function. Each element in the array is 1 / (10000^(2i/d)), + where i is the position in the array and d is the embedding dimension. This array is used to scale the input data in the encoding function. + Returns: + - (torch.Tensor): A tensor of shape (embedding_dim,) containing the position array. + """ + x = torch.arange(0, self.embedding_dim, dtype=torch.float) + x = x // 2 * 2 + x = torch.div(x, self.embedding_dim) + x = torch.pow(10000., x) + x = torch.div(1., x) + return x + + def forward(self, x: torch.Tensor): + """ + Overview: + The common computation graph of the TimeEncoder. It encodes the input tensor into temporal vectors. + Arguments: + - x (torch.Tensor): The input tensor. 
Its data type should be one-dimensional, and each value in 'x' represents a timestamp or a position in sequence to be encoded. + Returns: + - (torch.Tensor): A tensor containing the temporal encoded vectors of the input tensor 'x'. + """ + v = torch.zeros(size=(x.shape[0], self.embedding_dim), dtype=torch.float, device=x.device) + assert len(x.shape) == 1 + x = x.unsqueeze(dim=1) + v[:, 0::2] = torch.sin(x * self.position_array[0::2]) # apply sin on even-indexed positions + v[:, 1::2] = torch.cos(x * self.position_array[1::2]) # apply cos on odd-indexed positions + return v + + +class UnsqueezeEncoder(nn.Module): + """ + Overview: + For unsqueezes a tensor along the specified dimension and then optionally normalizes the tensor. + This is useful when we want to add an extra dimension to the input tensor and potentially scale its values. + """ + def __init__(self, unsqueeze_dim: int = -1, norm_value: float = 1): + """ + Overview: + The initializer for UnsqueezeEncoder. + Arguments: + - unsqueeze_dim (int, optional): The dimension to unsqueeze. Default is -1, which unsqueezes at the last dimension. + - norm_value (float, optional): The value to normalize the tensor by. Default is 1, which means no normalization. + """ + super(UnsqueezeEncoder, self).__init__() + self.unsqueeze_dim = unsqueeze_dim + self.norm_value = norm_value + + def forward(self, x: torch.Tensor): + """ + Overview: + The common computation graph of the UnsqueezeEncoder. It unsqueezes the input tensor along the specified dimension and then normalizes the tensor. + Arguments: + - x (torch.Tensor): The input tensor. + Returns: + - (torch.Tensor): The unsqueezed and normalized tensor. Its shape is the same as the input tensor, but with an extra dimension at the position specified by 'unsqueeze_dim'. Its values are the values of the input tensor divided by 'norm_value'. 
+ """ + x = x.float().unsqueeze(dim=self.unsqueeze_dim) + if self.norm_value != 1: + x = x / self.norm_value + return x diff --git a/zoo/gobigger/model/model.py b/zoo/gobigger/model/model.py new file mode 100644 index 000000000..6fda88c7e --- /dev/null +++ b/zoo/gobigger/model/model.py @@ -0,0 +1,432 @@ +from typing import Dict, Optional +import torch +import torch.nn as nn +from torch import Tensor +from ding.utils.default_helper import deep_merge_dicts +from ding.torch_utils import MLP, fc_block, conv2d_block, ResBlock +from ding.torch_utils import Transformer, ScatterConnection +from easydict import EasyDict + +from .encoder import SignBinaryEncoder, BinaryEncoder, OnehotEncoder, TimeEncoder, UnsqueezeEncoder + + +def sequence_mask(lengths: torch.Tensor, max_len: Optional[int] = None): + r""" + Overview: + create a mask for a batch sequences with different lengths + Arguments: + - lengths (:obj:`tensor`): lengths in each different sequences, shape could be (n, 1) or (n) + - max_len (:obj:`int`): the padding size, if max_len is None, the padding size is the + max length of sequences + Returns: + - masks (:obj:`torch.BoolTensor`): mask has the same device as lengths + """ + if len(lengths.shape) == 1: + lengths = lengths.unsqueeze(dim=1) + bz = lengths.numel() + if max_len is None: + max_len = lengths.max() + return torch.arange(0, max_len).type_as(lengths).repeat(bz, 1).lt(lengths).to(lengths.device) + + +class ScalarEncoder(nn.Module): + + def __init__(self, cfg): + super(ScalarEncoder, self).__init__() + self.whole_cfg = cfg + self.cfg = self.whole_cfg.scalar_encoder + self.encode_modules = nn.ModuleDict() + for k, item in self.cfg.modules.items(): + if item['arc'] == 'time': + self.encode_modules[k] = TimeEncoder(embedding_dim=item['embedding_dim']) + elif item['arc'] == 'one_hot': + self.encode_modules[k] = OnehotEncoder(num_embeddings=item['num_embeddings'], ) + elif item['arc'] == 'binary': + self.encode_modules[k] = BinaryEncoder(num_embeddings=item['num_embeddings'], ) + elif item['arc'] == 'sign_binary': + self.encode_modules[k] = SignBinaryEncoder(num_embeddings=item['num_embeddings'], ) + else: + print(f'cant implement {k} for arc {item["arc"]}') + raise NotImplementedError + + self.layers = MLP( + in_channels=self.cfg.mlp.input_dim, + hidden_channels=self.cfg.mlp.hidden_dim, + out_channels=self.cfg.mlp.output_dim, + layer_num=self.cfg.mlp.layer_num, + layer_fn=fc_block, + activation=self.cfg.mlp.activation, + norm_type=self.cfg.mlp.norm_type, + use_dropout=False, + output_activation=True, + output_norm=True, + last_linear_layer_init_zero=False + ) + + def forward(self, x: Dict[str, Tensor]): + embeddings = [] + for key, item in self.cfg.modules.items(): + assert key in x, key + embeddings.append(self.encode_modules[key](x[key])) + + out = torch.cat(embeddings, dim=-1) + out = self.layers(out) + return out + + +class TeamEncoder(nn.Module): + + def __init__(self, cfg): + super(TeamEncoder, self).__init__() + self.whole_cfg = cfg + self.cfg = self.whole_cfg.team_encoder + self.encode_modules = nn.ModuleDict() + + for k, item in self.cfg.modules.items(): + if item['arc'] == 'one_hot': + self.encode_modules[k] = OnehotEncoder(num_embeddings=item['num_embeddings'], ) + elif item['arc'] == 'binary': + self.encode_modules[k] = BinaryEncoder(num_embeddings=item['num_embeddings'], ) + elif item['arc'] == 'sign_binary': + self.encode_modules[k] = SignBinaryEncoder(num_embeddings=item['num_embeddings'], ) + else: + print(f'cant implement {k} for arc {item["arc"]}') + raise 
NotImplementedError + + self.encode_layers = MLP( + in_channels=self.cfg.mlp.input_dim, + hidden_channels=self.cfg.mlp.hidden_dim, + out_channels=self.cfg.mlp.output_dim, + layer_num=self.cfg.mlp.layer_num, + layer_fn=fc_block, + activation=self.cfg.mlp.activation, + norm_type=self.cfg.mlp.norm_type, + use_dropout=False, + output_activation=True, + output_norm=True, + last_linear_layer_init_zero=False + ) + + self.transformer = Transformer( + input_dim=self.cfg.transformer.input_dim, + output_dim=self.cfg.transformer.output_dim, + head_num=self.cfg.transformer.head_num, + head_dim=self.cfg.transformer.embedding_dim, + hidden_dim=self.cfg.transformer.ffn_size, + layer_num=self.cfg.transformer.layer_num, + activation=self.cfg.transformer.activation, + ) + self.output_fc = fc_block( + self.cfg.fc_block.input_dim, + self.cfg.fc_block.output_dim, + norm_type=self.cfg.fc_block.norm_type, + activation=self.cfg.fc_block.activation + ) + + def forward(self, x): + embeddings = [] + player_num = x['player_num'] + mask = sequence_mask(player_num, max_len=x['view_x'].shape[1]) + for key, item in self.cfg.modules.items(): + assert key in x, f"{key} not implemented" + x_input = x[key] + embeddings.append(self.encode_modules[key](x_input)) + + x = torch.cat(embeddings, dim=-1) + x = self.encode_layers(x) + x = self.transformer(x, mask=mask) + team_info = self.output_fc(x.sum(dim=1) / player_num.unsqueeze(dim=-1)) + return team_info + + +class BallEncoder(nn.Module): + + def __init__(self, cfg): + super(BallEncoder, self).__init__() + self.whole_cfg = cfg + self.cfg = self.whole_cfg.ball_encoder + self.encode_modules = nn.ModuleDict() + for k, item in self.cfg.modules.items(): + if item['arc'] == 'one_hot': + self.encode_modules[k] = OnehotEncoder(num_embeddings=item['num_embeddings'], ) + elif item['arc'] == 'binary': + self.encode_modules[k] = BinaryEncoder(num_embeddings=item['num_embeddings'], ) + elif item['arc'] == 'sign_binary': + self.encode_modules[k] = SignBinaryEncoder(num_embeddings=item['num_embeddings'], ) + elif item['arc'] == 'unsqueeze': + self.encode_modules[k] = UnsqueezeEncoder() + else: + print(f'cant implement {k} for arc {item["arc"]}') + raise NotImplementedError + self.encode_layers = MLP( + in_channels=self.cfg.mlp.input_dim, + hidden_channels=self.cfg.mlp.hidden_dim, + out_channels=self.cfg.mlp.output_dim, + layer_num=self.cfg.mlp.layer_num, + layer_fn=fc_block, + activation=self.cfg.mlp.activation, + norm_type=self.cfg.mlp.norm_type, + use_dropout=False, + output_activation=True, + output_norm=True, + last_linear_layer_init_zero=False + ) + + self.transformer = Transformer( + input_dim=self.cfg.transformer.input_dim, + output_dim=self.cfg.transformer.output_dim, + head_num=self.cfg.transformer.head_num, + head_dim=self.cfg.transformer.embedding_dim, + hidden_dim=self.cfg.transformer.ffn_size, + layer_num=self.cfg.transformer.layer_num, + activation=self.cfg.transformer.activation, + ) + self.output_fc = fc_block( + self.cfg.fc_block.input_dim, + self.cfg.fc_block.output_dim, + norm_type=self.cfg.fc_block.norm_type, + activation=self.cfg.fc_block.activation + ) + + def forward(self, x): + ball_num = x['ball_num'] + embeddings = [] + mask = sequence_mask(ball_num, max_len=x['x'].shape[1]) + for key, item in self.cfg.modules.items(): + assert key in x, key + x_input = x[key] + embeddings.append(self.encode_modules[key](x_input)) + x = torch.cat(embeddings, dim=-1) + x = self.encode_layers(x) + x = self.transformer(x, mask=mask) + + ball_info = x.sum(dim=1) / 
+
+
+class SpatialEncoder(nn.Module):

+    def __init__(self, cfg):
+        super(SpatialEncoder, self).__init__()
+        self.whole_cfg = cfg
+        self.cfg = self.whole_cfg.spatial_encoder
+
+        # scatter related: project per-ball embeddings and scatter them onto a 64x64 grid
+        self.spatial_x = 64
+        self.spatial_y = 64
+        self.scatter_fc = fc_block(
+            in_channels=self.cfg.scatter.input_dim,
+            out_channels=self.cfg.scatter.output_dim,
+            activation=self.cfg.scatter.activation,
+            norm_type=self.cfg.scatter.norm_type
+        )
+        self.scatter_connection = ScatterConnection(self.cfg.scatter.scatter_type)
+
+        # resnet related
+        self.get_resnet_blocks()
+
+        self.output_fc = fc_block(
+            in_channels=self.spatial_x // 8 * self.spatial_y // 8 * self.cfg.resnet.down_channels[-1],
+            out_channels=self.cfg.fc_block.output_dim,
+            norm_type=self.cfg.fc_block.norm_type,
+            activation=self.cfg.fc_block.activation
+        )
+
+    def get_resnet_blocks(self):
+        # +2 input channels for the scattered food/spore occupancy maps
+        project = conv2d_block(
+            in_channels=self.cfg.scatter.output_dim + 2,
+            out_channels=self.cfg.resnet.project_dim,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            activation=self.cfg.resnet.activation,
+            norm_type=self.cfg.resnet.norm_type,
+            bias=False,
+        )
+
+        layers = [project]
+        dims = [self.cfg.resnet.project_dim] + self.cfg.resnet.down_channels
+        for i in range(len(dims) - 1):
+            layer = conv2d_block(
+                in_channels=dims[i],
+                out_channels=dims[i + 1],
+                kernel_size=4,
+                stride=2,
+                padding=1,
+                activation=self.cfg.resnet.activation,
+                norm_type=self.cfg.resnet.norm_type,
+                bias=False,
+            )
+            layers.append(layer)
+            layers.append(
+                ResBlock(
+                    res_type='basic',
+                    in_channels=dims[i + 1],
+                    activation=self.cfg.resnet.activation,
+                    norm_type=self.cfg.resnet.norm_type
+                )
+            )
+        self.resnet = torch.nn.Sequential(*layers)
+
+    def get_background_embedding(
+        self,
+        coord_x,
+        coord_y,
+        num,
+    ):
+
+        background_ones = torch.ones(size=(coord_x.shape[0], coord_x.shape[1]), device=coord_x.device)
+        background_mask = sequence_mask(num, max_len=coord_x.shape[1])
+        background_ones = (background_ones * background_mask).unsqueeze(-1)
+        background_embedding = self.scatter_connection.xy_forward(
+            background_ones, spatial_size=[self.spatial_x, self.spatial_y], coord_x=coord_x, coord_y=coord_y
+        )
+
+        return background_embedding
+
+    def forward(
+        self,
+        inputs,
+        ball_embeddings,
+    ):
+        spatial_info = inputs['spatial_info']
+        # food and spore
+        food_embedding = self.get_background_embedding(
+            coord_x=spatial_info['food_x'],
+            coord_y=spatial_info['food_y'],
+            num=spatial_info['food_num'],
+        )
+
+        spore_embedding = self.get_background_embedding(
+            coord_x=spatial_info['spore_x'],
+            coord_y=spatial_info['spore_y'],
+            num=spatial_info['spore_num'],
+        )
+        # scatter ball embeddings
+        ball_info = inputs['ball_info']
+        ball_num = ball_info['ball_num']
+        ball_mask = sequence_mask(ball_num, max_len=ball_embeddings.shape[1])
+        ball_embedding = self.scatter_fc(ball_embeddings) * ball_mask.unsqueeze(dim=2)
+
+        ball_embedding = self.scatter_connection.xy_forward(
+            ball_embedding,
+            spatial_size=[self.spatial_x, self.spatial_y],
+            coord_x=spatial_info['ball_x'],
+            coord_y=spatial_info['ball_y']
+        )
+
+        x = torch.cat([food_embedding, spore_embedding, ball_embedding], dim=1)
+
+        x = self.resnet(x)
+
+        x = torch.flatten(x, start_dim=1, end_dim=-1)
+        x = self.output_fc(x)
+        return x
+
+
+class GoBiggerEncoder(nn.Module):
+    config = dict(
+        scalar_encoder=dict(
+            modules=dict(
+                view_x=dict(arc='sign_binary', num_embeddings=7),
+                view_y=dict(arc='sign_binary', num_embeddings=7),
+                view_width=dict(arc='binary', num_embeddings=7),
+                score=dict(arc='one_hot', num_embeddings=10),
+                team_score=dict(arc='one_hot', num_embeddings=10),
+                rank=dict(arc='one_hot', num_embeddings=4),
+                time=dict(arc='time', embedding_dim=8),
+                last_action_type=dict(arc='one_hot', num_embeddings=27),
+            ),
+            mlp=dict(
+                input_dim=80,
+                hidden_dim=64,
+                layer_num=2,
+                norm_type='BN',
+                output_dim=32,
+                activation=nn.ReLU(inplace=True)
+            ),
+        ),
+        team_encoder=dict(
+            modules=dict(
+                alliance=dict(arc='one_hot', num_embeddings=2),
+                view_x=dict(arc='sign_binary', num_embeddings=7),
+                view_y=dict(arc='sign_binary', num_embeddings=7),
+            ),
+            mlp=dict(
+                input_dim=16,
+                hidden_dim=32,
+                layer_num=2,
+                norm_type=None,
+                output_dim=16,
+                activation=nn.ReLU(inplace=True)
+            ),
+            transformer=dict(
+                input_dim=16,
+                output_dim=16,
+                head_num=4,
+                ffn_size=32,
+                layer_num=2,
+                embedding_dim=16,
+                activation=nn.ReLU(inplace=True),
+                variant='postnorm'
+            ),
+            fc_block=dict(input_dim=16, output_dim=16, activation=nn.ReLU(inplace=True), norm_type='BN'),
+        ),
+        ball_encoder=dict(
+            modules=dict(
+                alliance=dict(arc='one_hot', num_embeddings=4),
+                score=dict(arc='one_hot', num_embeddings=50),
+                radius=dict(arc='unsqueeze', ),
+                rank=dict(arc='one_hot', num_embeddings=5),
+                x=dict(arc='sign_binary', num_embeddings=8),
+                y=dict(arc='sign_binary', num_embeddings=8),
+                next_x=dict(arc='sign_binary', num_embeddings=8),
+                next_y=dict(arc='sign_binary', num_embeddings=8),
+            ),
+            mlp=dict(
+                input_dim=92,
+                hidden_dim=128,
+                layer_num=2,
+                norm_type=None,
+                output_dim=64,
+                activation=nn.ReLU(inplace=True)
+            ),
+            transformer=dict(
+                input_dim=64,
+                output_dim=64,
+                head_num=4,
+                ffn_size=64,
+                layer_num=3,
+                embedding_dim=64,
+                activation=nn.ReLU(inplace=True),
+                variant='postnorm'
+            ),
+            fc_block=dict(input_dim=64, output_dim=64, activation=nn.ReLU(inplace=True), norm_type='BN'),
+        ),
+        spatial_encoder=dict(
+            scatter=dict(
+                input_dim=64, output_dim=16, scatter_type='add', activation=nn.ReLU(inplace=True), norm_type=None
+            ),
+            resnet=dict(project_dim=12, down_channels=[32, 32, 16], activation=nn.ReLU(inplace=True), norm_type='BN'),
+            fc_block=dict(output_dim=64, activation=nn.ReLU(inplace=True), norm_type='BN'),
+        ),
+    )
+
+    def __init__(self, cfg=None):
+        super(GoBiggerEncoder, self).__init__()
+        self._cfg = deep_merge_dicts(self.config, cfg)
+        self._cfg = EasyDict(self._cfg)
+        self.scalar_encoder = ScalarEncoder(self._cfg)
+        self.team_encoder = TeamEncoder(self._cfg)
+        self.ball_encoder = BallEncoder(self._cfg)
+        self.spatial_encoder = SpatialEncoder(self._cfg)
+
+    def forward(self, x):
+        scalar_info = self.scalar_encoder(x['scalar_info'])
+        team_info = self.team_encoder(x['team_info'])
+        ball_embeddings, ball_info = self.ball_encoder(x['ball_info'])
+        spatial_info = self.spatial_encoder(x, ball_embeddings)
+        x = torch.cat([scalar_info, team_info, ball_info, spatial_info], dim=1)
+        return x
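Under the default config above, GoBiggerEncoder.forward expects an observation dict with 'scalar_info', 'team_info', 'ball_info' and 'spatial_info' sub-dicts (field names as declared in the modules entries) and returns the concatenation of the four branch embeddings. A minimal sketch of the resulting width, assuming only that this file lands at zoo/gobigger/model/model.py as in the diff; it reads the class-level defaults and does no forward pass, so it does not require constructing valid observation tensors:

    from zoo.gobigger.model.model import GoBiggerEncoder

    cfg = GoBiggerEncoder.config
    fused_dim = (
        cfg['scalar_encoder']['mlp']['output_dim']          # 32
        + cfg['team_encoder']['fc_block']['output_dim']     # 16
        + cfg['ball_encoder']['fc_block']['output_dim']     # 64
        + cfg['spatial_encoder']['fc_block']['output_dim']  # 64
    )
    print(fused_dim)  # 176: per-agent width of the tensor returned by GoBiggerEncoder.forward

The per-branch widths come from the mlp/fc_block output_dim fields, so overriding any of them in a user config changes the fused embedding width accordingly.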