feature(wrh): add RoPE for unizero #263

Closed · wants to merge 4 commits
62 changes: 54 additions & 8 deletions lzero/model/unizero_world_models/transformer.py
@@ -4,7 +4,7 @@

import math
from dataclasses import dataclass
from typing import Optional
from typing import Optional, Tuple

import torch
import torch.nn as nn
@@ -28,7 +28,10 @@ class TransformerConfig:
embed_pdrop: float
resid_pdrop: float
attn_pdrop: float


# for RoPE
rope_theta: float
max_seq_len: int

@property
def max_tokens(self):
return self.tokens_per_block * self.max_blocks
@@ -55,6 +58,13 @@ def __init__(self, config: TransformerConfig) -> None:
self.blocks = nn.ModuleList([Block(config) for _ in range(config.num_layers)])
self.ln_f = nn.LayerNorm(config.embed_dim)


self.freqs_cis = precompute_freqs_cis(
self.config.embed_dim // self.config.num_heads,
self.config.max_seq_len * 2,
self.config.rope_theta,
)

def generate_empty_keys_values(self, n: int, max_tokens: int) -> KeysValues:
"""
Generate a placeholder for keys and values.
@@ -70,7 +80,7 @@ def generate_empty_keys_values(self, n: int, max_tokens: int) -> KeysValues:
return KeysValues(n, self.config.num_heads, max_tokens, self.config.embed_dim, self.config.num_layers, device)

def forward(self, sequences: torch.Tensor, past_keys_values: Optional[KeysValues] = None,
valid_context_lengths: Optional[torch.Tensor] = None) -> torch.Tensor:
valid_context_lengths: Optional[torch.Tensor] = None, start_pos: int = 0) -> torch.Tensor:
"""
Forward pass of the Transformer model.

@@ -82,11 +92,15 @@ def forward(self, sequences: torch.Tensor, past_keys_values: Optional[KeysValues
Returns:
- torch.Tensor: Output tensor of shape (batch_size, seq_length, embed_dim).
"""
seqlen = sequences.shape[1]
self.freqs_cis = self.freqs_cis.to(sequences.device)
freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]

assert past_keys_values is None or len(past_keys_values) == len(self.blocks)
x = self.drop(sequences)
for i, block in enumerate(self.blocks):
x = block(x, None if past_keys_values is None else past_keys_values[i], valid_context_lengths)

x = block(x, None if past_keys_values is None else past_keys_values[i], valid_context_lengths, start_pos, freqs_cis)
# TODO: pass the index into start_pos here
x = self.ln_f(x)
return x

@@ -129,7 +143,7 @@ def __init__(self, config: TransformerConfig) -> None:
)

def forward(self, x: torch.Tensor, past_keys_values: Optional[KeysValues] = None,
valid_context_lengths: Optional[torch.Tensor] = None) -> torch.Tensor:
valid_context_lengths: Optional[torch.Tensor] = None, start_pos: int = 0, freqs_cis: torch.Tensor = None) -> torch.Tensor:
"""
Forward pass of the Transformer block.

@@ -141,7 +155,7 @@ def forward(self, x: torch.Tensor, past_keys_values: Optional[KeysValues] = None,
Returns:
- torch.Tensor: Output tensor of shape (batch_size, seq_length, embed_dim).
"""
x_attn = self.attn(self.ln1(x), past_keys_values, valid_context_lengths)
x_attn = self.attn(self.ln1(x), past_keys_values, valid_context_lengths, start_pos, freqs_cis)
if self.gru_gating:
x = self.gate1(x, x_attn)
x = self.gate2(x, self.mlp(self.ln2(x)))
@@ -152,6 +166,35 @@ def forward(self, x: torch.Tensor, past_keys_values: Optional[KeysValues] = None
return x


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
# Complex rotation factors e^(i * m * theta_j) for every position m in [0, end) and each of the dim // 2 frequency pairs
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device, dtype=torch.float32)
freqs = torch.outer(t, freqs)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return freqs_cis

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
# Reshape freqs_cis from (seq_len, head_dim // 2) so it broadcasts over the batch and head dims of x.
ndim = x.ndim
print(f"freqs_cis shape is {freqs_cis.shape}, while x shape is {x.shape}")
assert 0 <= 1 < ndim
assert freqs_cis.shape == (x.shape[2], x.shape[-1])
shape = [d if i == 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)


def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2)
return xq_out.type_as(xq), xk_out.type_as(xk)
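
A rough usage sketch of the two helpers above (illustrative only, not part of this diff): it assumes q/k are laid out as (B, num_heads, T, head_dim), matching the SelfAttention module below, and that the helpers are importable from this module.

import torch
from lzero.model.unizero_world_models.transformer import precompute_freqs_cis, apply_rotary_emb

head_dim = 64
freqs_cis = precompute_freqs_cis(head_dim, 4096)  # (4096, head_dim // 2) complex factors

# Prefill: rotate a 10-token chunk at absolute positions 0..9.
q = torch.randn(2, 8, 10, head_dim)
k = torch.randn(2, 8, 10, head_dim)
q_rot, k_rot = apply_rotary_emb(q, k, freqs_cis=freqs_cis[0:10])

# Single decode step with a KV cache: the new token sits at absolute position 10,
# so only freqs_cis[10:11] is applied (this is what the start_pos slicing in
# Transformer.forward does).
q_new = torch.randn(2, 8, 1, head_dim)
k_new = torch.randn(2, 8, 1, head_dim)
q_new_rot, k_new_rot = apply_rotary_emb(q_new, k_new, freqs_cis=freqs_cis[10:11])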


class SelfAttention(nn.Module):
"""
Implements self-attention mechanism for transformers.
@@ -189,7 +232,7 @@ def __init__(self, config: TransformerConfig) -> None:
self.register_buffer('mask', causal_mask)

def forward(self, x: torch.Tensor, kv_cache: Optional[KeysValues] = None,
valid_context_lengths: Optional[torch.Tensor] = None) -> torch.Tensor:
valid_context_lengths: Optional[torch.Tensor] = None, start_pos: int = 0, freqs_cis: torch.Tensor = None) -> torch.Tensor:
"""
Forward pass for the self-attention mechanism.

@@ -212,6 +255,9 @@ def forward(self, x: torch.Tensor, kv_cache: Optional[KeysValues] = None,
q = self.query(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, num_heads, T, head_size)
k = self.key(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, num_heads, T, head_size)
v = self.value(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, num_heads, T, head_size)

if self.config.rotary_emb:
q, k = apply_rotary_emb(q, k, freqs_cis=freqs_cis)

if kv_cache is not None:
kv_cache.update(k, v)
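For reference, here is a hypothetical config fragment for enabling RoPE (a sketch, not taken from this PR's configs): rope_theta and max_seq_len are the new TransformerConfig fields above, while the rotary_emb switch is an assumption — it is read via self.config.rotary_emb in SelfAttention and in world_model.py below, but is not part of the TransformerConfig hunk shown here.

# Sketch only -- assumed field names.
world_model_rope_cfg = dict(
    rotary_emb=True,     # assumed boolean switch checked by SelfAttention and WorldModel
    rope_theta=10000.0,  # base frequency passed to precompute_freqs_cis
    max_seq_len=2048,    # freqs_cis is precomputed for max_seq_len * 2 positions
)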
54 changes: 34 additions & 20 deletions lzero/model/unizero_world_models/world_model.py
@@ -56,9 +56,11 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None:
self._initialize_patterns()

# Position embedding
self.pos_emb = nn.Embedding(config.max_tokens, config.embed_dim, device=self.device)
self.precompute_pos_emb_diff_kv()
print(f"self.pos_emb.weight.device: {self.pos_emb.weight.device}")
if not self.config.rotary_emb:
self.pos_emb = nn.Embedding(config.max_tokens, config.embed_dim, device=self.device)
self.precompute_pos_emb_diff_kv()

print(f"self.pos_emb.weight.device: {self.pos_emb.weight.device}")

# Initialize action embedding table
self.act_embedding_table = nn.Embedding(config.action_space_size, config.embed_dim, device=self.device)
@@ -185,7 +187,8 @@ def precompute_pos_emb_diff_kv(self):
if self.context_length <= 2:
# If context length is 2 or less, no context is present
return

if self.config.rotary_emb:
# RoPE is applied inside the attention op, so there are no additive positional embeddings to precompute
return
# Precompute positional embedding matrices for inference in collect/eval stages, not for training
self.positional_embedding_k = [
self._get_positional_embedding(layer, 'key')
@@ -271,8 +274,11 @@ def forward(self, obs_embeddings_or_act_tokens: Dict[str, Union[torch.Tensor, tu
if len(obs_embeddings.shape) == 2:
obs_embeddings = obs_embeddings.unsqueeze(1)
num_steps = obs_embeddings.size(1)
sequences = self._add_position_embeddings(obs_embeddings, prev_steps, num_steps, kvcache_independent,
if not self.config.rotary_emb:
sequences = self._add_position_embeddings(obs_embeddings, prev_steps, num_steps, kvcache_independent,
is_init_infer, valid_context_lengths)
else:
sequences = obs_embeddings

# Process action tokens
elif 'act_tokens' in obs_embeddings_or_act_tokens:
@@ -281,8 +287,11 @@ def forward(self, obs_embeddings_or_act_tokens: Dict[str, Union[torch.Tensor, tu
act_tokens = act_tokens.squeeze(1)
num_steps = act_tokens.size(1)
act_embeddings = self.act_embedding_table(act_tokens)
sequences = self._add_position_embeddings(act_embeddings, prev_steps, num_steps, kvcache_independent,
if not self.config.rotary_emb:
sequences = self._add_position_embeddings(act_embeddings, prev_steps, num_steps, kvcache_independent,
is_init_infer, valid_context_lengths)
else:
sequences = act_embeddings

# Process combined observation embeddings and action tokens
else:
@@ -354,8 +363,11 @@ def _process_obs_act_combined(self, obs_embeddings_or_act_tokens, prev_steps):
act = act_embeddings[:, i, 0, :].unsqueeze(1)
obs_act = torch.cat([obs, act], dim=1)
obs_act_embeddings[:, i * (K + 1):(i + 1) * (K + 1), :] = obs_act

return obs_act_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=self.device)), num_steps

return_result = obs_act_embeddings
if not self.config.rotary_emb:
return_result += self.pos_emb(prev_steps + torch.arange(num_steps, device=self.device))
return return_result, num_steps

def _transformer_pass(self, sequences, past_keys_values, kvcache_independent, valid_context_lengths):
"""
@@ -734,12 +746,13 @@ def update_cache_context(self, latent_state, is_init_infer=True, simulation_inde
v_cache_trimmed = v_cache_current[:, :, 2:context_length - 1, :].squeeze(0)

# Index pre-computed positional encoding differences
pos_emb_diff_k = self.pos_emb_diff_k[layer][(2, context_length - 1)]
pos_emb_diff_v = self.pos_emb_diff_v[layer][(2, context_length - 1)]
# ============ NOTE: Very Important ============
# Apply positional encoding correction to k and v
k_cache_trimmed += pos_emb_diff_k.squeeze(0)
v_cache_trimmed += pos_emb_diff_v.squeeze(0)
if not self.config.rotary_emb:
pos_emb_diff_k = self.pos_emb_diff_k[layer][(2, context_length - 1)]
pos_emb_diff_v = self.pos_emb_diff_v[layer][(2, context_length - 1)]
# ============ NOTE: Very Important ============
# Apply positional encoding correction to k and v
k_cache_trimmed += pos_emb_diff_k.squeeze(0)
v_cache_trimmed += pos_emb_diff_v.squeeze(0)

# Pad the last 3 steps along the third dimension with zeros
# F.pad parameters (0, 0, 0, 3) specify padding amounts for each dimension: (left, right, top, bottom). For 3D tensor, they correspond to (dim2 left, dim2 right, dim1 left, dim1 right).
@@ -775,12 +788,13 @@ def update_cache_context(self, latent_state, is_init_infer=True, simulation_inde
v_cache_trimmed = v_cache_current[:, 2:context_length - 1, :]

# Index pre-computed positional encoding differences
pos_emb_diff_k = self.pos_emb_diff_k[layer][(2, context_length - 1)]
pos_emb_diff_v = self.pos_emb_diff_v[layer][(2, context_length - 1)]
# ============ NOTE: Very Important ============
# Apply positional encoding correction to k and v
k_cache_trimmed += pos_emb_diff_k.squeeze(0)
v_cache_trimmed += pos_emb_diff_v.squeeze(0)
if not self.config.rotary_emb:
pos_emb_diff_k = self.pos_emb_diff_k[layer][(2, context_length - 1)]
pos_emb_diff_v = self.pos_emb_diff_v[layer][(2, context_length - 1)]
# ============ NOTE: Very Important ============
# Apply positional encoding correction to k and v
k_cache_trimmed += pos_emb_diff_k.squeeze(0)
v_cache_trimmed += pos_emb_diff_v.squeeze(0)

# Pad the last 3 steps along the third dimension with zeros
# F.pad parameters (0, 0, 0, 3) specify padding amounts for each dimension: (left, right, top, bottom). For 3D tensor, they correspond to (dim2 left, dim2 right, dim1 left, dim1 right).
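The trimming logic above previously had to correct cached k/v for the additive positional embedding they were computed with; with rotary embeddings that correction is skipped, presumably because RoPE is applied inside the attention op and the q·k logit depends only on the relative offset between positions. A small sanity-check sketch of that property (not part of this PR; it assumes the helpers in transformer.py are importable as shown):

import torch
from lzero.model.unizero_world_models.transformer import precompute_freqs_cis, apply_rotary_emb

torch.manual_seed(0)
head_dim = 8
freqs_cis = precompute_freqs_cis(head_dim, 64)

q = torch.randn(1, 1, 1, head_dim)  # (B, num_heads, T=1, head_dim)
k = torch.randn(1, 1, 1, head_dim)

def rope_score(pos_q: int, pos_k: int) -> float:
    # Rotate q and k by their absolute positions and take the dot product.
    q_rot, _ = apply_rotary_emb(q, q, freqs_cis=freqs_cis[pos_q:pos_q + 1])
    k_rot, _ = apply_rotary_emb(k, k, freqs_cis=freqs_cis[pos_k:pos_k + 1])
    return (q_rot * k_rot).sum().item()

# Same relative offset (3) at different absolute positions gives the same logit,
# up to float error.
print(rope_score(3, 0), rope_score(13, 10))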
10 changes: 7 additions & 3 deletions lzero/policy/unizero.py
@@ -553,11 +553,12 @@ def _init_collect(self) -> None:
def _forward_collect(
self,
data: torch.Tensor,
action_mask: list = None,
action_mask: List = None,
temperature: float = 1,
to_play: List = [-1],
epsilon: float = 0.25,
ready_env_id: np.array = None
ready_env_id: np.ndarray = None,
step_index: List = [0]
) -> Dict:
"""
Overview:
@@ -569,6 +570,7 @@ def _forward_collect(
- temperature (:obj:`float`): The temperature of the policy.
- to_play (:obj:`int`): The player to play.
- ready_env_id (:obj:`list`): The id of the env that is ready to collect.
- step_index (:obj:`list`): The step index of each env within its current episode.
Shape:
- data (:obj:`torch.Tensor`):
- For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \
@@ -578,6 +580,7 @@ def _forward_collect(
- temperature: :math:`(1, )`.
- to_play: :math:`(N, 1)`, where N is the number of collect_env.
- ready_env_id: None
- step_index: :math:`(N, 1)`, where N is the number of collect_env.
Returns:
- output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \
``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``.
@@ -655,6 +658,7 @@ def _forward_collect(
'searched_value': value,
'predicted_value': pred_values[i],
'predicted_policy_logits': policy_logits[i],
'step_index': step_index[i]
}
batch_action.append(action)

@@ -683,7 +687,7 @@ def _init_eval(self) -> None:
self.last_batch_action = [-1 for _ in range(self.evaluator_env_num)]

def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1,
ready_env_id: np.array = None) -> Dict:
ready_env_id: np.array = None, step_index: int = 0) -> Dict:
"""
Overview:
The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search.
16 changes: 13 additions & 3 deletions lzero/worker/muzero_collector.py
@@ -352,6 +352,7 @@ def collect(self,

action_mask_dict = {i: to_ndarray(init_obs[i]['action_mask']) for i in range(env_nums)}
to_play_dict = {i: to_ndarray(init_obs[i]['to_play']) for i in range(env_nums)}
step_index_dict = {i: to_ndarray(init_obs[i]['step_index']) for i in range(env_nums)}
if self.policy_config.use_ture_chance_label_in_chance_encoder:
chance_dict = {i: to_ndarray(init_obs[i]['chance']) for i in range(env_nums)}

@@ -409,8 +410,12 @@

action_mask_dict = {env_id: action_mask_dict[env_id] for env_id in ready_env_id}
to_play_dict = {env_id: to_play_dict[env_id] for env_id in ready_env_id}
step_index_dict = {env_id: step_index_dict[env_id] for env_id in ready_env_id}

action_mask = [action_mask_dict[env_id] for env_id in ready_env_id]
to_play = [to_play_dict[env_id] for env_id in ready_env_id]
step_index = [step_index_dict[env_id] for env_id in ready_env_id]

if self.policy_config.use_ture_chance_label_in_chance_encoder:
chance_dict = {env_id: chance_dict[env_id] for env_id in ready_env_id}

@@ -423,12 +428,13 @@
# Key policy forward step
# ==============================================================
# print(f'ready_env_id:{ready_env_id}')
policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon, ready_env_id=ready_env_id)
policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon, ready_env_id=ready_env_id, step_index=step_index)

# Extract relevant policy outputs
actions_with_env_id = {k: v['action'] for k, v in policy_output.items()}
value_dict_with_env_id = {k: v['searched_value'] for k, v in policy_output.items()}
pred_value_dict_with_env_id = {k: v['predicted_value'] for k, v in policy_output.items()}
step_index_dict_with_env_id = {k: v['step_index'] for k, v in policy_output.items()}

if self.policy_config.sampled_algo:
root_sampled_actions_dict_with_env_id = {
@@ -450,6 +456,7 @@ def collect(self,
actions = {}
value_dict = {}
pred_value_dict = {}
step_index_dict = {}

if not collect_with_pure_policy:
distributions_dict = {}
@@ -467,6 +474,7 @@
actions[env_id] = actions_with_env_id.pop(env_id)
value_dict[env_id] = value_dict_with_env_id.pop(env_id)
pred_value_dict[env_id] = pred_value_dict_with_env_id.pop(env_id)
step_index_dict[env_id] = step_index_dict_with_env_id.pop(env_id)

if not collect_with_pure_policy:
distributions_dict[env_id] = distributions_dict_with_env_id.pop(env_id)
@@ -517,18 +525,19 @@
if self.policy_config.use_ture_chance_label_in_chance_encoder:
game_segments[env_id].append(
actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id],
to_play_dict[env_id], chance_dict[env_id]
to_play_dict[env_id], chance_dict[env_id], step_index_dict[env_id]
)
else:
game_segments[env_id].append(
actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id],
to_play_dict[env_id]
to_play_dict[env_id], step_index_dict[env_id]
)

# NOTE: the placement of this code snippet is very important.
# obs['action_mask'] and obs['to_play'] correspond to the next action
action_mask_dict[env_id] = to_ndarray(obs['action_mask'])
to_play_dict[env_id] = to_ndarray(obs['to_play'])
step_index_dict[env_id] = to_ndarray(obs['step_index'])
if self.policy_config.use_ture_chance_label_in_chance_encoder:
chance_dict[env_id] = to_ndarray(obs['chance'])

@@ -659,6 +668,7 @@ def collect(self,

action_mask_dict[env_id] = to_ndarray(init_obs[env_id]['action_mask'])
to_play_dict[env_id] = to_ndarray(init_obs[env_id]['to_play'])
step_index_dict[env_id] = to_ndarray(init_obs[env_id]['step_index'])
if self.policy_config.use_ture_chance_label_in_chance_encoder:
chance_dict[env_id] = to_ndarray(init_obs[env_id]['chance'])

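One assumption this collector change introduces: every environment observation dict must now expose a step_index field, since the code reads init_obs[i]['step_index'] at reset and obs['step_index'] after each step, and passes the value on to the policy and to GameSegment.append. A hedged sketch of the expected layout (field shapes and values are placeholders):

import numpy as np

example_obs = {
    'observation': np.zeros((3, 64, 64), dtype=np.float32),  # placeholder shape
    'action_mask': np.ones(6, dtype=np.int8),
    'to_play': -1,
    'step_index': 0,  # episode step, assumed to be maintained by the env wrapper and reset to 0 each episode
}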