 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     """
     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
-    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    num_kv_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
     """
-    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    batch, num_kv_heads, slen, head_dim = hidden_states.shape
     if n_rep == 1:
         return hidden_states
-    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
-    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim)
+    return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim)

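As a quick sanity check, the renamed repeat_kv can be compared against the torch.repeat_interleave call its docstring mentions; a minimal sketch using the definition above, with arbitrary illustrative shapes:

    import torch

    # arbitrary illustrative shapes: batch=2, num_kv_heads=4, seqlen=5, head_dim=8
    kv = torch.randn(2, 4, 5, 8)
    expanded = repeat_kv(kv, n_rep=3)                           # -> (2, 12, 5, 8)
    reference = torch.repeat_interleave(kv, repeats=3, dim=1)   # same head ordering
    assert torch.equal(expanded, reference)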
@@ -42,8 +42,8 @@ def multi_head_attention(hidden_states, w_q, w_k, w_v, w_o, config: LlamaConfig,
     """

     num_heads = config.num_attention_heads
-    num_key_value_heads = config.num_key_value_heads
-    num_key_value_groups = num_heads // num_key_value_heads
+    num_kv_heads = config.num_key_value_heads
+    num_kv_groups = num_heads // num_kv_heads


     batch_size, seq_len, hidden_size = hidden_states.shape[0:3]
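The two config fields read here come from the model config object; a minimal sketch of the group arithmetic, assuming the transformers LlamaConfig named in the type hint and illustrative 70B-style head counts:

    from transformers import LlamaConfig

    # illustrative GQA configuration (not a real checkpoint): 64 query heads share 8 KV heads
    config = LlamaConfig(num_attention_heads=64, num_key_value_heads=8)
    num_kv_groups = config.num_attention_heads // config.num_key_value_heads
    assert num_kv_groups == 8   # each KV head serves 8 query heads; 32/32 gives 1 for the 7B model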
@@ -58,13 +58,13 @@ def multi_head_attention(hidden_states, w_q, w_k, w_v, w_o, config: LlamaConfig,
     # Split the Q, K, V matrices into multiple heads: [batch_size, heads, seq_len, head_dim]
     # First view, then transpose
     query = query.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)
-    key = key.view(batch_size, seq_len, num_key_value_heads, head_dim).transpose(1, 2)
-    value = value.view(batch_size, seq_len, num_key_value_heads, head_dim).transpose(1, 2)
+    key = key.view(batch_size, seq_len, num_kv_heads, head_dim).transpose(1, 2)
+    value = value.view(batch_size, seq_len, num_kv_heads, head_dim).transpose(1, 2)

     # Splitting Q, K, V into multiple heads, an alternative way
     # query = query.view(batch_size, seq_len, num_heads, head_dim).permute(0, 2, 1, 3)
-    # key = key.view(batch_size, seq_len, num_key_value_heads, head_dim).permute(0, 2, 1, 3)
-    # value = value.view(batch_size, seq_len, num_key_value_heads, head_dim).permute(0, 2, 1, 3)
+    # key = key.view(batch_size, seq_len, num_kv_heads, head_dim).permute(0, 2, 1, 3)
+    # value = value.view(batch_size, seq_len, num_kv_heads, head_dim).permute(0, 2, 1, 3)

     # RoPE computation
     cos, sin = get_rope_embeddings(value, seq_len=seq_len)
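The view + transpose split above is pure reshaping; a minimal shape check with assumed Llama-7B-like sizes (32 heads, head_dim 128), which also confirms that the commented-out permute variant produces the same layout:

    import torch

    batch_size, seq_len, num_heads, num_kv_heads, head_dim = 2, 10, 32, 32, 128
    query = torch.randn(batch_size, seq_len, num_heads * head_dim)
    key = torch.randn(batch_size, seq_len, num_kv_heads * head_dim)

    q = query.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)
    k = key.view(batch_size, seq_len, num_kv_heads, head_dim).transpose(1, 2)
    assert q.shape == (batch_size, num_heads, seq_len, head_dim)
    assert k.shape == (batch_size, num_kv_heads, seq_len, head_dim)

    # transpose(1, 2) and permute(0, 2, 1, 3) give the same [batch, heads, seq, head_dim] layout
    q_alt = query.view(batch_size, seq_len, num_heads, head_dim).permute(0, 2, 1, 3)
    assert torch.equal(q, q_alt)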
@@ -77,9 +77,9 @@ def multi_head_attention(hidden_states, w_q, w_k, w_v, w_o, config: LlamaConfig,
     position_ids = position_ids.unsqueeze(0)
     query, key = apply_rotary_pos_emb(query, key, cos, sin, position_ids=position_ids)

-    # Repeat the KV heads; for 7B num_key_value_groups = 32/32 = 1, for 70B num_key_value_groups = 64/8 = 8
-    key = repeat_kv(key, num_key_value_groups)
-    value = repeat_kv(value, num_key_value_groups)
+    # Repeat the KV heads; for 7B num_kv_groups = 32/32 = 1, for 70B num_kv_groups = 64/8 = 8
+    key = repeat_kv(key, num_kv_groups)
+    value = repeat_kv(value, num_kv_groups)

     # Attention mechanism
     attention_output = scaled_dot_product_attention(query, key, value, mask)
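For context, a minimal sketch of what a scaled_dot_product_attention helper such as the one called here typically computes, i.e. softmax(Q K^T / sqrt(head_dim) + mask) V; this is an assumed reference implementation, not the repository's actual function:

    import math
    import torch

    def sdpa_reference(query, key, value, mask=None):
        # query/key/value: [batch, heads, seq_len, head_dim]
        head_dim = query.shape[-1]
        scores = query @ key.transpose(-2, -1) / math.sqrt(head_dim)
        if mask is not None:
            scores = scores + mask  # additive mask, e.g. -inf on future positions
        weights = torch.softmax(scores, dim=-1)
        return weights @ value  # [batch, heads, seq_len, head_dim]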