diff --git a/lzero/model/unizero_world_models/transformer.py b/lzero/model/unizero_world_models/transformer.py
index 714bc13d6..bfb77a48c 100644
--- a/lzero/model/unizero_world_models/transformer.py
+++ b/lzero/model/unizero_world_models/transformer.py
@@ -4,7 +4,7 @@
 import math
 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, Tuple
 
 import torch
 import torch.nn as nn
@@ -28,7 +28,10 @@ class TransformerConfig:
     embed_pdrop: float
     resid_pdrop: float
    attn_pdrop: float
-
+
+    # for RoPE
+    rope_theta: float
+    max_seq_len: int
     @property
     def max_tokens(self):
         return self.tokens_per_block * self.max_blocks
@@ -55,6 +58,13 @@ def __init__(self, config: TransformerConfig) -> None:
         self.blocks = nn.ModuleList([Block(config) for _ in range(config.num_layers)])
         self.ln_f = nn.LayerNorm(config.embed_dim)
 
+
+        self.freqs_cis = precompute_freqs_cis(
+            self.config.embed_dim // self.config.num_heads,
+            self.config.max_seq_len * 2,
+            self.config.rope_theta,
+        )
+
     def generate_empty_keys_values(self, n: int, max_tokens: int) -> KeysValues:
         """
         Generate a placeholder for keys and values.
@@ -70,7 +80,7 @@ def generate_empty_keys_values(self, n: int, max_tokens: int) -> KeysValues:
         return KeysValues(n, self.config.num_heads, max_tokens, self.config.embed_dim, self.config.num_layers, device)
 
     def forward(self, sequences: torch.Tensor, past_keys_values: Optional[KeysValues] = None,
-                valid_context_lengths: Optional[torch.Tensor] = None) -> torch.Tensor:
+                valid_context_lengths: Optional[torch.Tensor] = None, start_pos: int = 0) -> torch.Tensor:
         """
         Forward pass of the Transformer model.
 
@@ -82,11 +92,15 @@ def forward(self, sequences: torch.Tensor, past_keys_values: Optional[KeysValues
         Returns:
             - torch.Tensor: Output tensor of shape (batch_size, seq_length, embed_dim).
         """
+        seqlen = sequences.shape[1]
+        self.freqs_cis = self.freqs_cis.to(sequences.device)
+        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
+
         assert past_keys_values is None or len(past_keys_values) == len(self.blocks)
         x = self.drop(sequences)
         for i, block in enumerate(self.blocks):
-            x = block(x, None if past_keys_values is None else past_keys_values[i], valid_context_lengths)
-
+            x = block(x, None if past_keys_values is None else past_keys_values[i], valid_context_lengths, start_pos, freqs_cis)
+            # TODO: pass the index into start_pos here
         x = self.ln_f(x)
         return x
 
@@ -129,7 +143,7 @@ def __init__(self, config: TransformerConfig) -> None:
         )
 
     def forward(self, x: torch.Tensor, past_keys_values: Optional[KeysValues] = None,
-                valid_context_lengths: Optional[torch.Tensor] = None) -> torch.Tensor:
+                valid_context_lengths: Optional[torch.Tensor] = None, start_pos: int = 0, freqs_cis: torch.Tensor = None) -> torch.Tensor:
         """
         Forward pass of the Transformer block.
 
@@ -141,7 +155,7 @@ def forward(self, x: torch.Tensor, past_keys_values: Optional[KeysValues] = None
         Returns:
             - torch.Tensor: Output tensor of shape (batch_size, seq_length, embed_dim).
""" - x_attn = self.attn(self.ln1(x), past_keys_values, valid_context_lengths) + x_attn = self.attn(self.ln1(x), past_keys_values, valid_context_lengths, start_pos, freqs_cis) if self.gru_gating: x = self.gate1(x, x_attn) x = self.gate2(x, self.mlp(self.ln2(x))) @@ -152,6 +166,35 @@ def forward(self, x: torch.Tensor, past_keys_values: Optional[KeysValues] = None return x +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device, dtype=torch.float32) + freqs = torch.outer(t, freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + print(f"freqs_cis 的 shape是{freqs_cis.shape}, 而 x 的 shape 是{x.shape}") + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[2], x.shape[-1]) + shape = [d if i == 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2) + return xq_out.type_as(xq), xk_out.type_as(xk) + + class SelfAttention(nn.Module): """ Implements self-attention mechanism for transformers. @@ -189,7 +232,7 @@ def __init__(self, config: TransformerConfig) -> None: self.register_buffer('mask', causal_mask) def forward(self, x: torch.Tensor, kv_cache: Optional[KeysValues] = None, - valid_context_lengths: Optional[torch.Tensor] = None) -> torch.Tensor: + valid_context_lengths: Optional[torch.Tensor] = None, start_pos: int = 0, freqs_cis: torch.Tensor = None) -> torch.Tensor: """ Forward pass for the self-attention mechanism. 
@@ -212,6 +255,9 @@ def forward(self, x: torch.Tensor, kv_cache: Optional[KeysValues] = None,
         q = self.query(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2)  # (B, num_heads, T, head_size)
         k = self.key(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2)  # (B, num_heads, T, head_size)
         v = self.value(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2)  # (B, num_heads, T, head_size)
+
+        if self.config.rotary_emb:
+            q, k = apply_rotary_emb(q, k, freqs_cis=freqs_cis)
 
         if kv_cache is not None:
             kv_cache.update(k, v)
diff --git a/lzero/model/unizero_world_models/world_model.py b/lzero/model/unizero_world_models/world_model.py
index ef31d951c..8c9e70a99 100644
--- a/lzero/model/unizero_world_models/world_model.py
+++ b/lzero/model/unizero_world_models/world_model.py
@@ -56,9 +56,11 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None:
         self._initialize_patterns()
 
         # Position embedding
-        self.pos_emb = nn.Embedding(config.max_tokens, config.embed_dim, device=self.device)
-        self.precompute_pos_emb_diff_kv()
-        print(f"self.pos_emb.weight.device: {self.pos_emb.weight.device}")
+        if not self.config.rotary_emb:
+            self.pos_emb = nn.Embedding(config.max_tokens, config.embed_dim, device=self.device)
+            self.precompute_pos_emb_diff_kv()
+
+            print(f"self.pos_emb.weight.device: {self.pos_emb.weight.device}")
 
         # Initialize action embedding table
         self.act_embedding_table = nn.Embedding(config.action_space_size, config.embed_dim, device=self.device)
@@ -185,7 +187,8 @@ def precompute_pos_emb_diff_kv(self):
         if self.context_length <= 2:
             # If context length is 2 or less, no context is present
             return
-
+        if self.config.rotary_emb:
+            return
         # Precompute positional embedding matrices for inference in collect/eval stages, not for training
         self.positional_embedding_k = [
             self._get_positional_embedding(layer, 'key')
@@ -271,8 +274,11 @@ def forward(self, obs_embeddings_or_act_tokens: Dict[str, Union[torch.Tensor, tu
             if len(obs_embeddings.shape) == 2:
                 obs_embeddings = obs_embeddings.unsqueeze(1)
             num_steps = obs_embeddings.size(1)
-            sequences = self._add_position_embeddings(obs_embeddings, prev_steps, num_steps, kvcache_independent,
+            if not self.config.rotary_emb:
+                sequences = self._add_position_embeddings(obs_embeddings, prev_steps, num_steps, kvcache_independent,
                                                       is_init_infer, valid_context_lengths)
+            else:
+                sequences = obs_embeddings
 
         # Process action tokens
         elif 'act_tokens' in obs_embeddings_or_act_tokens:
@@ -281,8 +287,11 @@ def forward(self, obs_embeddings_or_act_tokens: Dict[str, Union[torch.Tensor, tu
                 act_tokens = act_tokens.squeeze(1)
             num_steps = act_tokens.size(1)
             act_embeddings = self.act_embedding_table(act_tokens)
-            sequences = self._add_position_embeddings(act_embeddings, prev_steps, num_steps, kvcache_independent,
+            if not self.config.rotary_emb:
+                sequences = self._add_position_embeddings(act_embeddings, prev_steps, num_steps, kvcache_independent,
                                                       is_init_infer, valid_context_lengths)
+            else:
+                sequences = act_embeddings
 
         # Process combined observation embeddings and action tokens
         else:
@@ -354,8 +363,11 @@ def _process_obs_act_combined(self, obs_embeddings_or_act_tokens, prev_steps):
             act = act_embeddings[:, i, 0, :].unsqueeze(1)
             obs_act = torch.cat([obs, act], dim=1)
             obs_act_embeddings[:, i * (K + 1):(i + 1) * (K + 1), :] = obs_act
-
-        return obs_act_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=self.device)), num_steps
+
+        return_result = obs_act_embeddings
+        if not self.config.rotary_emb:
+            return_result += self.pos_emb(prev_steps + torch.arange(num_steps, device=self.device))
+        return return_result, num_steps
 
     def _transformer_pass(self, sequences, past_keys_values, kvcache_independent, valid_context_lengths):
         """
@@ -734,12 +746,13 @@ def update_cache_context(self, latent_state, is_init_infer=True, simulation_inde
                         v_cache_trimmed = v_cache_current[:, :, 2:context_length - 1, :].squeeze(0)
 
                         # Index pre-computed positional encoding differences
-                        pos_emb_diff_k = self.pos_emb_diff_k[layer][(2, context_length - 1)]
-                        pos_emb_diff_v = self.pos_emb_diff_v[layer][(2, context_length - 1)]
-                        # ============ NOTE: Very Important ============
-                        # Apply positional encoding correction to k and v
-                        k_cache_trimmed += pos_emb_diff_k.squeeze(0)
-                        v_cache_trimmed += pos_emb_diff_v.squeeze(0)
+                        if not self.config.rotary_emb:
+                            pos_emb_diff_k = self.pos_emb_diff_k[layer][(2, context_length - 1)]
+                            pos_emb_diff_v = self.pos_emb_diff_v[layer][(2, context_length - 1)]
+                            # ============ NOTE: Very Important ============
+                            # Apply positional encoding correction to k and v
+                            k_cache_trimmed += pos_emb_diff_k.squeeze(0)
+                            v_cache_trimmed += pos_emb_diff_v.squeeze(0)
 
                         # Pad the last 3 steps along the third dimension with zeros
                         # F.pad parameters (0, 0, 0, 3) specify padding amounts for each dimension: (left, right, top, bottom). For 3D tensor, they correspond to (dim2 left, dim2 right, dim1 left, dim1 right).
@@ -775,12 +788,13 @@ def update_cache_context(self, latent_state, is_init_infer=True, simulation_inde
                         v_cache_trimmed = v_cache_current[:, 2:context_length - 1, :]
 
                         # Index pre-computed positional encoding differences
-                        pos_emb_diff_k = self.pos_emb_diff_k[layer][(2, context_length - 1)]
-                        pos_emb_diff_v = self.pos_emb_diff_v[layer][(2, context_length - 1)]
-                        # ============ NOTE: Very Important ============
-                        # Apply positional encoding correction to k and v
-                        k_cache_trimmed += pos_emb_diff_k.squeeze(0)
-                        v_cache_trimmed += pos_emb_diff_v.squeeze(0)
+                        if not self.config.rotary_emb:
+                            pos_emb_diff_k = self.pos_emb_diff_k[layer][(2, context_length - 1)]
+                            pos_emb_diff_v = self.pos_emb_diff_v[layer][(2, context_length - 1)]
+                            # ============ NOTE: Very Important ============
+                            # Apply positional encoding correction to k and v
+                            k_cache_trimmed += pos_emb_diff_k.squeeze(0)
+                            v_cache_trimmed += pos_emb_diff_v.squeeze(0)
 
                         # Pad the last 3 steps along the third dimension with zeros
                         # F.pad parameters (0, 0, 0, 3) specify padding amounts for each dimension: (left, right, top, bottom). For 3D tensor, they correspond to (dim2 left, dim2 right, dim1 left, dim1 right).
diff --git a/lzero/policy/unizero.py b/lzero/policy/unizero.py
index 73af201c3..e44cf9a8d 100644
--- a/lzero/policy/unizero.py
+++ b/lzero/policy/unizero.py
@@ -553,11 +553,12 @@ def _init_collect(self) -> None:
     def _forward_collect(
             self,
             data: torch.Tensor,
-            action_mask: list = None,
+            action_mask: List = None,
             temperature: float = 1,
             to_play: List = [-1],
             epsilon: float = 0.25,
-            ready_env_id: np.array = None
+            ready_env_id: np.ndarray = None,
+            step_index: List = [0]
     ) -> Dict:
         """
         Overview:
@@ -569,6 +570,7 @@ def _forward_collect(
             - temperature (:obj:`float`): The temperature of the policy.
             - to_play (:obj:`int`): The player to play.
             - ready_env_id (:obj:`list`): The id of the env that is ready to collect.
+            - step_index (:obj:`list`): The step index of the env in one episode.
         Shape:
             - data (:obj:`torch.Tensor`):
                 - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \
@@ -578,6 +580,7 @@ def _forward_collect(
             - temperature: :math:`(1, )`.
             - to_play: :math:`(N, 1)`, where N is the number of collect_env.
             - ready_env_id: None
+            - step_index: :math:`(N, 1)`, where N is the number of collect_env.
         Returns:
             - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \
                 ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``.
@@ -655,6 +658,7 @@ def _forward_collect(
                     'searched_value': value,
                     'predicted_value': pred_values[i],
                     'predicted_policy_logits': policy_logits[i],
+                    'step_index': step_index[i]
                 }
                 batch_action.append(action)
 
@@ -683,7 +687,7 @@ def _init_eval(self) -> None:
         self.last_batch_action = [-1 for _ in range(self.evaluator_env_num)]
 
     def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1,
-                      ready_env_id: np.array = None) -> Dict:
+                      ready_env_id: np.array = None, step_index: int = 0) -> Dict:
         """
         Overview:
             The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search.
diff --git a/lzero/worker/muzero_collector.py b/lzero/worker/muzero_collector.py
index 9933f816e..97c3842d3 100644
--- a/lzero/worker/muzero_collector.py
+++ b/lzero/worker/muzero_collector.py
@@ -352,6 +352,7 @@ def collect(self,
         action_mask_dict = {i: to_ndarray(init_obs[i]['action_mask']) for i in range(env_nums)}
         to_play_dict = {i: to_ndarray(init_obs[i]['to_play']) for i in range(env_nums)}
+        step_index_dict = {i: to_ndarray(init_obs[i]['step_index']) for i in range(env_nums)}
         if self.policy_config.use_ture_chance_label_in_chance_encoder:
             chance_dict = {i: to_ndarray(init_obs[i]['chance']) for i in range(env_nums)}
 
@@ -409,8 +410,12 @@ def collect(self,
                 action_mask_dict = {env_id: action_mask_dict[env_id] for env_id in ready_env_id}
                 to_play_dict = {env_id: to_play_dict[env_id] for env_id in ready_env_id}
+                step_index_dict = {env_id: step_index_dict[env_id] for env_id in ready_env_id}
+
                 action_mask = [action_mask_dict[env_id] for env_id in ready_env_id]
                 to_play = [to_play_dict[env_id] for env_id in ready_env_id]
+                step_index = [step_index_dict[env_id] for env_id in ready_env_id]
+
                 if self.policy_config.use_ture_chance_label_in_chance_encoder:
                     chance_dict = {env_id: chance_dict[env_id] for env_id in ready_env_id}
 
@@ -423,12 +428,13 @@ def collect(self,
                 # Key policy forward step
                 # ==============================================================
                 # print(f'ready_env_id:{ready_env_id}')
-                policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon, ready_env_id=ready_env_id)
+                policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon, ready_env_id=ready_env_id, step_index=step_index)
 
                 # Extract relevant policy outputs
                 actions_with_env_id = {k: v['action'] for k, v in policy_output.items()}
                 value_dict_with_env_id = {k: v['searched_value'] for k, v in policy_output.items()}
                 pred_value_dict_with_env_id = {k: v['predicted_value'] for k, v in policy_output.items()}
+                step_index_dict_with_env_id = {k: v['step_index'] for k, v in policy_output.items()}
 
                 if self.policy_config.sampled_algo:
                     root_sampled_actions_dict_with_env_id = {
@@ -450,6 +456,7 @@ def collect(self,
                 actions = {}
                 value_dict = {}
                 pred_value_dict = {}
+                step_index_dict = {}
 
                 if not collect_with_pure_policy:
                     distributions_dict = {}
@@ -467,6 +474,7 @@ def collect(self,
                     actions[env_id] = actions_with_env_id.pop(env_id)
                     value_dict[env_id] = value_dict_with_env_id.pop(env_id)
                     pred_value_dict[env_id] = pred_value_dict_with_env_id.pop(env_id)
+                    step_index_dict[env_id] = step_index_dict_with_env_id.pop(env_id)
 
                     if not collect_with_pure_policy:
                         distributions_dict[env_id] = distributions_dict_with_env_id.pop(env_id)
@@ -517,18 +525,19 @@ def collect(self,
                     if self.policy_config.use_ture_chance_label_in_chance_encoder:
                         game_segments[env_id].append(
                             actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id],
-                            to_play_dict[env_id], chance_dict[env_id]
+                            to_play_dict[env_id], chance_dict[env_id], step_index_dict[env_id]
                         )
                     else:
                         game_segments[env_id].append(
                             actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id],
-                            to_play_dict[env_id]
+                            to_play_dict[env_id], step_index_dict[env_id]
                         )
 
                     # NOTE: the position of code snippet is very important.
                    # the obs['action_mask'] and obs['to_play'] are corresponding to the next action
                    action_mask_dict[env_id] = to_ndarray(obs['action_mask'])
                    to_play_dict[env_id] = to_ndarray(obs['to_play'])
+                    step_index_dict[env_id] = to_ndarray(obs['step_index'])
                    if self.policy_config.use_ture_chance_label_in_chance_encoder:
                        chance_dict[env_id] = to_ndarray(obs['chance'])
 
@@ -659,6 +668,7 @@ def collect(self,
                        action_mask_dict[env_id] = to_ndarray(init_obs[env_id]['action_mask'])
                        to_play_dict[env_id] = to_ndarray(init_obs[env_id]['to_play'])
+                        step_index_dict[env_id] = to_ndarray(init_obs[env_id]['step_index'])
                        if self.policy_config.use_ture_chance_label_in_chance_encoder:
                            chance_dict[env_id] = to_ndarray(init_obs[env_id]['chance'])
 
diff --git a/zoo/atari/config/atari_unizero_config.py b/zoo/atari/config/atari_unizero_config.py
index 1c549010f..fe4ea8ed9 100644
--- a/zoo/atari/config/atari_unizero_config.py
+++ b/zoo/atari/config/atari_unizero_config.py
@@ -51,21 +51,25 @@
             observation_shape=(3, 64, 64),
             action_space_size=action_space_size,
             world_model_cfg=dict(
+                rotary_emb=False,
                 max_blocks=num_unroll_steps,
                 max_tokens=2 * num_unroll_steps,  # NOTE: each timestep has 2 tokens: obs and action
                 context_length=2 * infer_context_length,
                 device='cuda',
                 # device='cpu',
                 action_space_size=action_space_size,
-                num_layers=4,
+                num_layers=2,
                 num_heads=8,
                 embed_dim=768,
+                # for RoPE
+                rope_theta=10000,
+                max_seq_len=2048,
                 obs_type='image',
                 env_num=max(collector_env_num, evaluator_env_num),
             ),
         ),
         # (str) The path of the pretrained model. If None, the model will be initialized by the default model.
-        model_path=None,
+        model_path=None, #
         num_unroll_steps=num_unroll_steps,
         update_per_collect=update_per_collect,
         replay_ratio=replay_ratio,
@@ -101,6 +105,6 @@
     seeds = [0]  # You can add more seed values here
     for seed in seeds:
         # Update exp_name to include the current seed
-        main_config.exp_name = f'data_unizero/{env_id[:-14]}_stack1_unizero_upc{update_per_collect}-rr{replay_ratio}_H{num_unroll_steps}_bs{batch_size}_seed{seed}'
+        main_config.exp_name = f'data_unizero_debug/{env_id[:-14]}_stack1_unizero_upc{update_per_collect}-rr{replay_ratio}_H{num_unroll_steps}_bs{batch_size}_seed{seed}'
         from lzero.entry import train_unizero
         train_unizero([main_config, create_config], seed=seed, model_path=main_config.policy.model_path, max_env_step=max_env_step)
diff --git a/zoo/atari/envs/atari_lightzero_env.py b/zoo/atari/envs/atari_lightzero_env.py
index 84288feb5..f09e98165 100644
--- a/zoo/atari/envs/atari_lightzero_env.py
+++ b/zoo/atari/envs/atari_lightzero_env.py
@@ -100,6 +100,7 @@ def __init__(self, cfg: EasyDict) -> None:
         self.clip_rewards = cfg.clip_rewards
         self.episode_life = cfg.episode_life
+        self.step_index = 0
 
     def reset(self) -> dict:
         """
         Overview:
@@ -133,6 +134,7 @@ def reset(self) -> dict:
         self.obs = to_ndarray(obs)
         self._eval_episode_return = 0.
         obs = self.observe()
+        self.step_index = 0
         return obs
 
     def step(self, action: int) -> BaseEnvTimestep:
@@ -151,7 +153,8 @@ def step(self, action: int) -> BaseEnvTimestep:
         observation = self.observe()
         if done:
             info['eval_episode_return'] = self._eval_episode_return
-
+        else:
+            self.step_index += 1
         return BaseEnvTimestep(observation, self.reward, done, info)
 
     def observe(self) -> dict:
@@ -169,7 +172,7 @@ def observe(self) -> dict:
         observation = np.transpose(observation, (2, 0, 1))
 
         action_mask = np.ones(self._action_space.n, 'int8')
-        return {'observation': observation, 'action_mask': action_mask, 'to_play': -1}
+        return {'observation': observation, 'action_mask': action_mask, 'to_play': -1, 'step_index': self.step_index}
 
     @property
     def legal_actions(self):
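
Illustrative usage (not part of the patch): a minimal shape check of the RoPE helpers added to transformer.py above. It assumes the patched lzero package is importable; otherwise copy precompute_freqs_cis and apply_rotary_emb into a local script. The numbers mirror atari_unizero_config.py (embed_dim=768, num_heads=8, max_seq_len=2048, rope_theta=10000); batch size, sequence length, and start_pos are arbitrary example values.

import torch
from lzero.model.unizero_world_models.transformer import precompute_freqs_cis, apply_rotary_emb

batch, num_heads, seq_len, head_dim = 2, 8, 16, 96  # head_dim = embed_dim // num_heads = 768 // 8
start_pos = 4  # decoding offset, as used by Transformer.forward when slicing freqs_cis

# Same call as Transformer.__init__: one complex rotation per position and per head-dim pair.
freqs_cis = precompute_freqs_cis(head_dim, 2048 * 2, theta=10000.0)
freqs_slice = freqs_cis[start_pos: start_pos + seq_len]  # (seq_len, head_dim // 2), complex64

# q/k in the (B, num_heads, T, head_size) layout produced by SelfAttention.forward.
q = torch.randn(batch, num_heads, seq_len, head_dim)
k = torch.randn(batch, num_heads, seq_len, head_dim)
q_rot, k_rot = apply_rotary_emb(q, k, freqs_cis=freqs_slice)
assert q_rot.shape == q.shape and k_rot.shape == k.shape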