
Commit b36f713

Author: daixu
Commit message: add result
Parent: e289aed

File tree: 4 files changed (+21 -14 lines)


README.md (+8 -1)
@@ -12,19 +12,26 @@ pip install transformers >= 4.35.2
 
 ```
 
-## excute
+## excute & result
 
 ```bash
 git clone https://github.com/silencelamb/naked_llama.git
 
 # default model_size is 7b
 python naked_llama.py
 
+```
+
+![llama2 7B](llama2_7b_image.png)
+
+```bash
 # run 70 b
 python naked_llama.py --model_size 70b
 
 ```
 
+![llama2 70B](llama2_70b_image.png)
+
 ## references
 
 - [llama in huggingface transformers](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py)

layers/attention.py (+13 -13)
@@ -7,13 +7,13 @@
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     """
     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
-    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    num_kv_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
     """
-    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    batch, num_kv_heads, slen, head_dim = hidden_states.shape
     if n_rep == 1:
         return hidden_states
-    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
-    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim)
+    return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim)
 
 
 
@@ -42,8 +42,8 @@ def multi_head_attention(hidden_states, w_q, w_k, w_v, w_o, config: LlamaConfig,
     """
 
     num_heads = config.num_attention_heads
-    num_key_value_heads = config.num_key_value_heads
-    num_key_value_groups = num_heads // num_key_value_heads
+    num_kv_heads = config.num_key_value_heads
+    num_kv_groups = num_heads // num_kv_heads
 
 
     batch_size, seq_len, hidden_size = hidden_states.shape[0:3]
@@ -58,13 +58,13 @@ def multi_head_attention(hidden_states, w_q, w_k, w_v, w_o, config: LlamaConfig,
     # Split Q, K, V into multiple heads, [batch_size, heads, seq_len, head_dim]
     # First view, then transpose
     query = query.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)
-    key = key.view(batch_size, seq_len, num_key_value_heads, head_dim).transpose(1, 2)
-    value = value.view(batch_size, seq_len, num_key_value_heads, head_dim).transpose(1, 2)
+    key = key.view(batch_size, seq_len, num_kv_heads, head_dim).transpose(1, 2)
+    value = value.view(batch_size, seq_len, num_kv_heads, head_dim).transpose(1, 2)
 
     # Split Q, K, V into multiple heads, an alternative way
     # query = query.view(batch_size, seq_len, num_heads, head_dim).permute(0, 2, 1, 3)
-    # key = key.view(batch_size, seq_len, num_key_value_heads, head_dim).permute(0, 2, 1, 3)
-    # value = value.view(batch_size, seq_len, num_key_value_heads, head_dim).permute(0, 2, 1, 3)
+    # key = key.view(batch_size, seq_len, num_kv_heads, head_dim).permute(0, 2, 1, 3)
+    # value = value.view(batch_size, seq_len, num_kv_heads, head_dim).permute(0, 2, 1, 3)
 
     # RoPE computation
     cos, sin = get_rope_embeddings(value, seq_len=seq_len)
@@ -77,9 +77,9 @@ def multi_head_attention(hidden_states, w_q, w_k, w_v, w_o, config: LlamaConfig,
         position_ids = position_ids.unsqueeze(0)
     query, key = apply_rotary_pos_emb(query, key, cos, sin, position_ids=position_ids)
 
-    # Repeat KV heads: for 7B, num_key_value_groups = 32/32 = 1; for 70B, num_key_value_groups = 64/8 = 8
-    key = repeat_kv(key, num_key_value_groups)
-    value = repeat_kv(value, num_key_value_groups)
+    # Repeat KV heads: for 7B, num_kv_groups = 32/32 = 1; for 70B, num_kv_groups = 64/8 = 8
+    key = repeat_kv(key, num_kv_groups)
+    value = repeat_kv(value, num_kv_groups)
 
     # Attention mechanism
     attention_output = scaled_dot_product_attention(query, key, value, mask)
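
The renamed `repeat_kv` helper is the grouped-query-attention expansion step: each key/value head is duplicated `n_rep` times so K and V match the number of query heads (32/32 = 1 group for the 7B model, 64/8 = 8 groups for 70B, as the comment in the diff notes). Below is a minimal standalone sketch of the same shape logic in plain PyTorch, checked against the `torch.repeat_interleave` equivalence claimed in the docstring; the batch size, sequence length, and head_dim are illustrative values, not taken from this repo's config.

```python
import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, num_kv_heads, seqlen, head_dim) -> (batch, num_kv_heads * n_rep, seqlen, head_dim)
    batch, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim)


# 70B-style head counts from the comment above: 64 query heads, 8 KV heads -> 8 groups
num_heads, num_kv_heads = 64, 8
n_rep = num_heads // num_kv_heads

# Illustrative tensor sizes: batch=1, seqlen=16, head_dim=128
kv = torch.randn(1, num_kv_heads, 16, 128)

expanded = repeat_kv(kv, n_rep)
reference = torch.repeat_interleave(kv, repeats=n_rep, dim=1)

print(expanded.shape)                    # torch.Size([1, 64, 16, 128])
print(torch.equal(expanded, reference))  # True
```

The expand-then-reshape form matches the `repeat_kv` in Hugging Face's `modeling_llama.py`, which the README lists under references; only the local variable names differ after this commit.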

llama2_70b_image.png (70.3 KB)

llama2_7b_image.png (68.8 KB)
