feature(wrh): add RoPE for unizero #263

Status: Closed. Wants to merge 4 commits.

Changes from 1 commit
feature(pu): add rope in unizero's transformer
dyyoungg committed Aug 8, 2024
commit 1f1df62493f45370e4fcd8a649424211317bd5c7
56 changes: 50 additions & 6 deletions lzero/model/unizero_world_models/transformer.py
@@ -4,7 +4,7 @@

import math
from dataclasses import dataclass
from typing import Optional
from typing import Optional, Tuple

import torch
import torch.nn as nn
@@ -55,6 +55,15 @@ def __init__(self, config: TransformerConfig) -> None:
self.blocks = nn.ModuleList([Block(config) for _ in range(config.num_layers)])
self.ln_f = nn.LayerNorm(config.embed_dim)

self.config.rope_theta = 500000
self.config.max_seq_len = 2048
Collaborator:

Please double-check this parameter: should it be set to match the actual training sequence length?

Collaborator (reply):

rope_theta affects the frequencies of the positional encoding, so the default value should be fine. max_seq_len is the maximum sequence length; it determines the length of the precomputed frequency tensor. If we want to support longer sequences at test time, max_seq_len should be set to cover the longest test sequence we expect. For example, if the longest sequence we test on is 2048, this value should be set to 2048; 10 is too small and will raise an error once the test length exceeds 10.
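As an illustration of this point, a minimal sketch (with hypothetical sizes) of how max_seq_len bounds the precomputed frequency table and how start_pos slices it, using the precompute_freqs_cis helper added in this PR:

import torch

# Hypothetical sizes: embed_dim=768, num_heads=8 -> head_dim=96; max_seq_len=2048.
head_dim, max_seq_len = 96, 2048
freqs_cis = precompute_freqs_cis(head_dim, max_seq_len * 2)  # complex tensor of shape (4096, 48)

# In Transformer.forward the table is sliced per call:
start_pos, seqlen = 0, 16
freqs = freqs_cis[start_pos : start_pos + seqlen]  # (16, 48)
# If start_pos + seqlen ever exceeded the table length, the slice would come back
# shorter than seqlen and the shape assertion in reshape_for_broadcast would fail.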


self.freqs_cis = precompute_freqs_cis(
self.config.embed_dim // self.config.num_heads,
self.config.max_seq_len * 2,
self.config.rope_theta,
)

def generate_empty_keys_values(self, n: int, max_tokens: int) -> KeysValues:
"""
Generate a placeholder for keys and values.
@@ -70,7 +79,7 @@ def generate_empty_keys_values(self, n: int, max_tokens: int) -> KeysValues:
return KeysValues(n, self.config.num_heads, max_tokens, self.config.embed_dim, self.config.num_layers, device)

def forward(self, sequences: torch.Tensor, past_keys_values: Optional[KeysValues] = None,
valid_context_lengths: Optional[torch.Tensor] = None) -> torch.Tensor:
valid_context_lengths: Optional[torch.Tensor] = None, start_pos: int = 0) -> torch.Tensor:
"""
Forward pass of the Transformer model.

@@ -82,10 +91,14 @@ def forward(self, sequences: torch.Tensor, past_keys_values: Optional[KeysValues
Returns:
- torch.Tensor: Output tensor of shape (batch_size, seq_length, embed_dim).
"""
seqlen = sequences.shape[1]
self.freqs_cis = self.freqs_cis.to(sequences.device)
freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]

assert past_keys_values is None or len(past_keys_values) == len(self.blocks)
x = self.drop(sequences)
for i, block in enumerate(self.blocks):
x = block(x, None if past_keys_values is None else past_keys_values[i], valid_context_lengths)
x = block(x, None if past_keys_values is None else past_keys_values[i], valid_context_lengths, start_pos, freqs_cis)

x = self.ln_f(x)
return x
@@ -129,7 +142,7 @@ def __init__(self, config: TransformerConfig) -> None:
)

def forward(self, x: torch.Tensor, past_keys_values: Optional[KeysValues] = None,
valid_context_lengths: Optional[torch.Tensor] = None) -> torch.Tensor:
valid_context_lengths: Optional[torch.Tensor] = None, start_pos: int = 0, freqs_cis: torch.Tensor = None) -> torch.Tensor:
"""
Forward pass of the Transformer block.

@@ -141,7 +154,7 @@ def forward(self, x: torch.Tensor, past_keys_values: Optional[KeysValues] = None
Returns:
- torch.Tensor: Output tensor of shape (batch_size, seq_length, embed_dim).
"""
x_attn = self.attn(self.ln1(x), past_keys_values, valid_context_lengths)
x_attn = self.attn(self.ln1(x), past_keys_values, valid_context_lengths, start_pos, freqs_cis)
if self.gru_gating:
x = self.gate1(x, x_attn)
x = self.gate2(x, self.mlp(self.ln2(x)))
@@ -152,6 +165,34 @@ def forward(self, x: torch.Tensor, past_keys_values: Optional[KeysValues] = None
return x


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device, dtype=torch.float32)
freqs = torch.outer(t, freqs)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return freqs_cis

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
ndim = x.ndim
assert 0 <= 1 < ndim
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)


def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)


class SelfAttention(nn.Module):
"""
Implements self-attention mechanism for transformers.
@@ -189,7 +230,7 @@ def __init__(self, config: TransformerConfig) -> None:
self.register_buffer('mask', causal_mask)

def forward(self, x: torch.Tensor, kv_cache: Optional[KeysValues] = None,
valid_context_lengths: Optional[torch.Tensor] = None) -> torch.Tensor:
valid_context_lengths: Optional[torch.Tensor] = None, start_pos: int = 0, freqs_cis: torch.Tensor = None) -> torch.Tensor:
"""
Forward pass for the self-attention mechanism.

@@ -212,6 +253,9 @@ def forward(self, x: torch.Tensor, kv_cache: Optional[KeysValues] = None,
q = self.query(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, num_heads, T, head_size)
k = self.key(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, num_heads, T, head_size)
v = self.value(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, num_heads, T, head_size)

if self.config.rotary_emb:
q, k = apply_rotary_emb(q, k, freqs_cis=freqs_cis)

if kv_cache is not None:
kv_cache.update(k, v)
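For reference, a minimal standalone sketch of how the new RoPE helpers compose, assuming the (batch, seq_len, num_heads, head_dim) layout that reshape_for_broadcast expects; all sizes below are hypothetical:

import torch

B, T, H, D = 2, 16, 8, 96  # hypothetical batch size, sequence length, heads, head dim
freqs_cis = precompute_freqs_cis(D, T)  # complex tensor of shape (T, D // 2)

xq = torch.randn(B, T, H, D)
xk = torch.randn(B, T, H, D)

# Each consecutive pair of channels is rotated in the complex plane; shapes are preserved.
xq_rot, xk_rot = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
assert xq_rot.shape == (B, T, H, D) and xk_rot.shape == (B, T, H, D)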
18 changes: 13 additions & 5 deletions lzero/model/unizero_world_models/world_model.py
@@ -56,9 +56,11 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None:
self._initialize_patterns()

# Position embedding
self.pos_emb = nn.Embedding(config.max_tokens, config.embed_dim, device=self.device)
self.precompute_pos_emb_diff_kv()
print(f"self.pos_emb.weight.device: {self.pos_emb.weight.device}")
if not self.config.rotary_emb:
self.pos_emb = nn.Embedding(config.max_tokens, config.embed_dim, device=self.device)
self.precompute_pos_emb_diff_kv()

print(f"self.pos_emb.weight.device: {self.pos_emb.weight.device}")

# Initialize action embedding table
self.act_embedding_table = nn.Embedding(config.action_space_size, config.embed_dim, device=self.device)
@@ -271,8 +273,11 @@ def forward(self, obs_embeddings_or_act_tokens: Dict[str, Union[torch.Tensor, tu
if len(obs_embeddings.shape) == 2:
obs_embeddings = obs_embeddings.unsqueeze(1)
num_steps = obs_embeddings.size(1)
sequences = self._add_position_embeddings(obs_embeddings, prev_steps, num_steps, kvcache_independent,
if not self.config.rotary_emb:
sequences = self._add_position_embeddings(obs_embeddings, prev_steps, num_steps, kvcache_independent,
is_init_infer, valid_context_lengths)
else:
sequences = obs_embeddings

# Process action tokens
elif 'act_tokens' in obs_embeddings_or_act_tokens:
@@ -281,8 +286,11 @@ def forward(self, obs_embeddings_or_act_tokens: Dict[str, Union[torch.Tensor, tu
act_tokens = act_tokens.squeeze(1)
num_steps = act_tokens.size(1)
act_embeddings = self.act_embedding_table(act_tokens)
sequences = self._add_position_embeddings(act_embeddings, prev_steps, num_steps, kvcache_independent,
if not self.config.rotary_emb:
sequences = self._add_position_embeddings(act_embeddings, prev_steps, num_steps, kvcache_independent,
is_init_infer, valid_context_lengths)
else:
sequences = act_embeddings

# Process combined observation embeddings and action tokens
else:
21 changes: 11 additions & 10 deletions zoo/atari/config/atari_unizero_config.py
@@ -20,14 +20,14 @@
infer_context_length = 4

# ====== only for debug =====
# collector_env_num = 2
# n_episode = 2
# evaluator_env_num = 2
# num_simulations = 5
# max_env_step = int(5e5)
# reanalyze_ratio = 0.
# batch_size = 2
# num_unroll_steps = 10
collector_env_num = 2
n_episode = 2
evaluator_env_num = 2
num_simulations = 5
max_env_step = int(5e5)
reanalyze_ratio = 0.
batch_size = 2
num_unroll_steps = 10
Collaborator:

You didn't use the debug config for training, right?

Contributor (author):

No. Only the committed version is the debug config; it was not the one used during training.

# ==============================================================
# end of the most frequently changed config specified by the user
# ==============================================================
@@ -51,13 +51,14 @@
observation_shape=(3, 64, 64),
action_space_size=action_space_size,
world_model_cfg=dict(
rotary_emb=True,
max_blocks=num_unroll_steps,
max_tokens=2 * num_unroll_steps, # NOTE: each timestep has 2 tokens: obs and action
context_length=2 * infer_context_length,
device='cuda',
# device='cpu',
action_space_size=action_space_size,
num_layers=4,
num_layers=2,
num_heads=8,
embed_dim=768,
obs_type='image',
@@ -101,6 +102,6 @@
seeds = [0] # You can add more seed values here
for seed in seeds:
# Update exp_name to include the current seed
main_config.exp_name = f'data_unizero/{env_id[:-14]}_stack1_unizero_upc{update_per_collect}-rr{replay_ratio}_H{num_unroll_steps}_bs{batch_size}_seed{seed}'
main_config.exp_name = f'data_unizero_debug/{env_id[:-14]}_stack1_unizero_upc{update_per_collect}-rr{replay_ratio}_H{num_unroll_steps}_bs{batch_size}_seed{seed}'
from lzero.entry import train_unizero
train_unizero([main_config, create_config], seed=seed, model_path=main_config.policy.model_path, max_env_step=max_env_step)