import torch
+from typing import Optional, Tuple

cos_cached = None
sin_cached = None

+### huggingface implementation ###

def init_rope_embeddings(dim, max_position_embeddings=4096, base=10000, device=None, scaling_factor=1.0):
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
@@ -64,5 +66,122 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

+### meta implementation ###
+
+def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
+    """
+    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
+
+    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
+    and the end index 'end'. The 'theta' parameter scales the frequencies.
+    The returned tensor contains complex values in complex64 data type.
+
+    Args:
+        dim (int): Dimension of the frequency tensor.
+        end (int): End index for precomputing frequencies.
+        theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
+
+    Returns:
+        torch.Tensor: Precomputed frequency tensor with complex exponentials.
+    """
+    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+    t = torch.arange(end, device=freqs.device)  # type: ignore
+    freqs = torch.outer(t, freqs).float()  # type: ignore
+    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
+    return freqs_cis
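+# Example shapes (illustrative, using the test configuration below): precompute_freqs_cis(dim=8, end=6)
+# returns a complex64 tensor of shape (end, dim // 2) = (6, 4), whose entry [t, i] equals
+# exp(1j * t / theta ** (2 * i / dim)), i.e. one rotation angle per position and frequency.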
+
+
+def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
+    """
+    Reshape frequency tensor for broadcasting it with another tensor.
+
+    This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
+    for the purpose of broadcasting the frequency tensor during element-wise operations.
+
+    Args:
+        freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
+        x (torch.Tensor): Target tensor for broadcasting compatibility.
+
+    Returns:
+        torch.Tensor: Reshaped frequency tensor.
+
+    Raises:
+        AssertionError: If the frequency tensor doesn't match the expected shape.
+        AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
+    """
+    ndim = x.ndim
+    assert 0 <= 1 < ndim
+    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
+    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+    return freqs_cis.view(*shape)
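+# In apply_rotary_emb below, x has the complex-view shape (bsz, seqlen, n_heads, head_dim // 2),
+# so freqs_cis of shape (seqlen, head_dim // 2) is reshaped to (1, seqlen, 1, head_dim // 2)
+# and broadcasts over the batch and head dimensions.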
+
+
+def apply_rotary_emb(
+    xq: torch.Tensor,
+    xk: torch.Tensor,
+    freqs_cis: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Apply rotary embeddings to input tensors using the given frequency tensor.
+
+    This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
+    frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
+    is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
+    returned as real tensors.
+
+    Args:
+        xq (torch.Tensor): Query tensor to apply rotary embeddings.
+        xk (torch.Tensor): Key tensor to apply rotary embeddings.
+        freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
+    """
+    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
+    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
+    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
+    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
+    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
+    return xq_out.type_as(xq), xk_out.type_as(xk)
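+# Expected call shapes (as used in the test below): xq and xk are (bsz, seqlen, n_heads, head_dim)
+# and freqs_cis is (seqlen, head_dim // 2); each adjacent pair (x[..., 2i], x[..., 2i+1]) is treated
+# as one complex number and rotated by the corresponding angle in freqs_cis.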
150
+
67
151
if __name__ == '__main__' :
68
- init_rope_embeddings (dim = 128 )
152
+ torch .set_printoptions (linewidth = 200 ) # 这样打印不会存在折叠的问题
153
+ # batch_size, seq_len, head_num, head_dim = 1, 13, 32, 128
154
+ batch_size , seq_len , head_num , head_dim = 1 , 6 , 1 , 8
155
+ max_position_embeddings = 4096
156
+
157
+ # test hf implementation
158
+ cos_cached ,sin_cached = init_rope_embeddings (dim = head_dim , max_position_embeddings = max_position_embeddings )
159
+
160
+ xq = torch .randn (batch_size , head_num , seq_len , head_dim )
161
+ import copy
162
+ xk = copy .deepcopy (xq )
163
+ # import pdb; pdb.set_trace()
164
+ cos , sin = get_rope_embeddings (xq , seq_len = seq_len )
165
+ position_ids = torch .arange (0 , seq_len , dtype = torch .long ).unsqueeze (0 )
166
+ hf_xq_new , hf_xk_new = apply_rotary_pos_emb (xq , xk , cos , sin , position_ids )
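+    # Note: the HF-style functions above operate on (bsz, n_heads, seq_len, head_dim) tensors,
+    # while apply_rotary_emb expects (bsz, seq_len, n_heads, head_dim), hence the transposes below.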
+    # import pdb; pdb.set_trace()
+
+
+    # test meta implementation
+    xq_t = xq.transpose(1, 2)
+    xk_t = xk.transpose(1, 2)
+    # import pdb; pdb.set_trace()
+    freqs_cis = precompute_freqs_cis(dim=head_dim, end=max_position_embeddings)
+    freqs_cis = freqs_cis[:seq_len]
+    meta_xq_new, meta_xk_new = apply_rotary_emb(xq_t, xk_t, freqs_cis)
+    # import pdb; pdb.set_trace()
+    meta_xq_new = meta_xq_new.transpose(1, 2)
+    meta_xk_new = meta_xk_new.transpose(1, 2)
+
+    error = torch.abs(meta_xq_new - hf_xq_new)
+    print(f"Compare xq_new error sum: {torch.sum(error)}")
+    import pdb; pdb.set_trace()
+    error = torch.abs(meta_xk_new - hf_xk_new)
+    print(f"Compare xk_new error sum: {torch.sum(error)}")