Commit b20a649

update rope
1 parent af99fe6 commit b20a649

2 files changed (+122, -2 lines)

hf_example/hf_llama2_7b.py

+2, -1
@@ -28,7 +28,8 @@

# ModeSwitch = PROMPT_FORWARD | HF_GENERATE | HF_STREAM_GENERATE | FORWARD | FORWARD_KV_CACHE | FORWARD_HF_DYNAMIC_CACHE | FORWARD_HF_STATIC_CACHE
# ModeSwitch = FORWARD_KV_CACHE | FORWARD_HF_DYNAMIC_CACHE
-ModeSwitch = FORWARD_KV_CACHE | FORWARD_HF_STATIC_CACHE
+# ModeSwitch = FORWARD_KV_CACHE | FORWARD_HF_STATIC_CACHE
+ModeSwitch = FORWARD_HF_STATIC_CACHE

if __name__ == '__main__':
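
ModeSwitch combines mode constants with | to select which paths the script runs; this commit narrows it from the combined KV-cache + static-cache run to the HF static-cache path only. A minimal sketch of how such a flag switch can be wired, assuming the constants are bit flags — the names come from the diff, but enum.IntFlag and the checks below are purely illustrative, since the real definitions live elsewhere in the repo:

# Hypothetical sketch: the real flag definitions are not shown in this diff.
from enum import IntFlag, auto

class Mode(IntFlag):
    PROMPT_FORWARD = auto()
    HF_GENERATE = auto()
    HF_STREAM_GENERATE = auto()
    FORWARD = auto()
    FORWARD_KV_CACHE = auto()
    FORWARD_HF_DYNAMIC_CACHE = auto()
    FORWARD_HF_STATIC_CACHE = auto()

ModeSwitch = Mode.FORWARD_HF_STATIC_CACHE  # after this commit, only the static-cache path is selected

if ModeSwitch & Mode.FORWARD_KV_CACHE:
    print("would run the custom KV-cache forward test")
if ModeSwitch & Mode.FORWARD_HF_STATIC_CACHE:
    print("would run the HF static-cache forward test")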

layers/rope.py

+120, -1
@@ -1,8 +1,10 @@
import torch
+from typing import Optional, Tuple

cos_cached = None
sin_cached = None

+### huggingface implementation ###

def init_rope_embeddings(dim, max_position_embeddings=4096, base=10000, device=None, scaling_factor=1.0):
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
@@ -64,5 +66,122 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

+### meta implementation ###
+
+def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
+    """
+    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
+
+    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
+    and the end index 'end'. The 'theta' parameter scales the frequencies.
+    The returned tensor contains complex values in complex64 data type.
+
+    Args:
+        dim (int): Dimension of the frequency tensor.
+        end (int): End index for precomputing frequencies.
+        theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
+
+    Returns:
+        torch.Tensor: Precomputed frequency tensor with complex exponentials.
+    """
+    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+    t = torch.arange(end, device=freqs.device)  # type: ignore
+    freqs = torch.outer(t, freqs).float()  # type: ignore
+    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
+    return freqs_cis
+
+
+def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
+    """
+    Reshape frequency tensor for broadcasting it with another tensor.
+
+    This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
+    for the purpose of broadcasting the frequency tensor during element-wise operations.
+
+    Args:
+        freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
+        x (torch.Tensor): Target tensor for broadcasting compatibility.
+
+    Returns:
+        torch.Tensor: Reshaped frequency tensor.
+
+    Raises:
+        AssertionError: If the frequency tensor doesn't match the expected shape.
+        AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
+    """
+    ndim = x.ndim
+    assert 0 <= 1 < ndim
+    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
+    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+    return freqs_cis.view(*shape)
+
+
+def apply_rotary_emb(
+    xq: torch.Tensor,
+    xk: torch.Tensor,
+    freqs_cis: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Apply rotary embeddings to input tensors using the given frequency tensor.
+
+    This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
+    frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
+    is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
+    returned as real tensors.
+
+    Args:
+        xq (torch.Tensor): Query tensor to apply rotary embeddings.
+        xk (torch.Tensor): Key tensor to apply rotary embeddings.
+        freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
+    """
+    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
+    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
+    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
+    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
+    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
+    return xq_out.type_as(xq), xk_out.type_as(xk)
+
if __name__ == '__main__':
-    init_rope_embeddings(dim=128)
+    torch.set_printoptions(linewidth=200)  # so printed tensors are not wrapped across lines
+    # batch_size, seq_len, head_num, head_dim = 1, 13, 32, 128
+    batch_size, seq_len, head_num, head_dim = 1, 6, 1, 8
+    max_position_embeddings = 4096
+
+    # test hf implementation
+    cos_cached, sin_cached = init_rope_embeddings(dim=head_dim, max_position_embeddings=max_position_embeddings)
+
+    xq = torch.randn(batch_size, head_num, seq_len, head_dim)
+    import copy
+    xk = copy.deepcopy(xq)
+    # import pdb; pdb.set_trace()
+    cos, sin = get_rope_embeddings(xq, seq_len=seq_len)
+    position_ids = torch.arange(0, seq_len, dtype=torch.long).unsqueeze(0)
+    hf_xq_new, hf_xk_new = apply_rotary_pos_emb(xq, xk, cos, sin, position_ids)
+    # import pdb; pdb.set_trace()
+
+    # test meta implementation
+    xq_t = xq.transpose(1, 2)
+    xk_t = xk.transpose(1, 2)
+    # import pdb; pdb.set_trace()
+    freqs_cis = precompute_freqs_cis(dim=head_dim, end=max_position_embeddings)
+    freqs_cis = freqs_cis[:seq_len]
+    meta_xq_new, meta_xk_new = apply_rotary_emb(xq_t, xk_t, freqs_cis)
+    # import pdb; pdb.set_trace()
+    meta_xq_new = meta_xq_new.transpose(1, 2)
+    meta_xk_new = meta_xk_new.transpose(1, 2)
+
+    error = torch.abs(meta_xq_new - hf_xq_new)
+    print(f"Compare xq_new error sum: {torch.sum(error)}")
+    import pdb; pdb.set_trace()
+    error = torch.abs(meta_xk_new - hf_xk_new)
+    print(f"Compare xk_new error sum: {torch.sum(error)}")
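
For reference, the complex-number path above implements the standard RoPE rotation: for one channel pair (x1, x2) at position m with per-pair angle θ, multiplying (x1 + i·x2) by exp(i·m·θ) — which is what torch.polar plus view_as_complex computes — gives the 2D rotation (x1·cos(mθ) − x2·sin(mθ), x1·sin(mθ) + x2·cos(mθ)). A minimal self-contained check of that identity on a single pair (illustration only, independent of the functions in this diff):

import torch

m, theta = 3, 0.25                      # position index and per-pair angle (arbitrary values)
pair = torch.tensor([0.7, -1.2])        # one (x1, x2) channel pair

# complex form, mirroring precompute_freqs_cis / apply_rotary_emb on a single pair
angle = torch.tensor([m * theta])
cis = torch.polar(torch.ones_like(angle), angle)  # exp(i*m*theta)
rotated = torch.view_as_real(torch.view_as_complex(pair.reshape(1, 2)) * cis).reshape(2)

# explicit cos/sin rotation of the same pair
cos, sin = angle.cos(), angle.sin()
expected = torch.cat((pair[0] * cos - pair[1] * sin,
                      pair[0] * sin + pair[1] * cos))

print(torch.allclose(rotated, expected))  # expected to print True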

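One detail the printed error sums probe is channel grouping, which the two implementations handle differently: apply_rotary_emb pairs consecutive channels (2i, 2i+1) through reshape(..., -1, 2) and view_as_complex, while the HF-style apply_rotary_pos_emb — assuming rotate_half follows the usual Hugging Face definition of splitting the head dimension into two halves, since rotate_half itself is not shown in this diff — rotates channel i together with channel i + head_dim // 2. A tiny sketch of the two pairings on an 8-channel head:

import torch

head_dim = 8
x = torch.arange(head_dim)

# interleaved pairing used by view_as_complex in apply_rotary_emb
meta_pairs = x.reshape(-1, 2)
print(meta_pairs.tolist())  # [[0, 1], [2, 3], [4, 5], [6, 7]]

# half-split pairing implied by a standard rotate_half (assumption: not shown in this diff)
hf_pairs = torch.stack((x[: head_dim // 2], x[head_dim // 2:]), dim=-1)
print(hf_pairs.tolist())    # [[0, 4], [1, 5], [2, 6], [3, 7]]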