@@ -9,12 +9,12 @@ use operators::{
99 add:: { self , Add } ,
1010 all_reduce:: { self , AllReduce , ReduceOp } ,
1111 attention:: { self , Attention } ,
12- attention_kv_cached:: { AttnKVCached } ,
12+ attention_kv_cached:: AttnKVCached ,
1313 fuesd_softmax:: AttnMask ,
1414 mat_mul:: { self , MatMul } ,
1515 rearrange:: { self , Rearrange } ,
1616 rms_norm:: { self , RmsNorm } ,
17- rope:: { self , Rope , SinCosTable } ,
17+ rope:: { self , Rope , Seq , SinCosTable } ,
1818 swiglu:: { self , Swiglu } ,
1919 ByteOf , Hardware , LaunchError , Operator , QueueAlloc , QueueOf , TopoNode , Workspace ,
2020} ;
@@ -151,7 +151,6 @@ where
151151 dkv_lora,
152152 dv,
153153 dt_embd,
154-
155154 ..
156155 } = self . meta ;
157156 // llama.cpp 定义死
@@ -175,8 +174,22 @@ where
175174 let attn = tensor ( & [ nt, nh, dv] ) ;
176175 let ( buf, workspace) = workspace. split_at_mut ( * attn. get ( ) ) ;
177176 let mut attn = attn. map ( |_| buf) ;
178-
179177 let queue = queue_alloc. queue ( ) ;
178+
179+ let sin = sin_cos. clone ( ) . index ( 0 , 0 ) ;
180+ let cos = sin_cos. index ( 0 , 1 ) ;
181+
182+ let pos = Tensor :: new ( self . dt_pos , & [ nt] ) . map ( |_| {
183+ Ops :: Rope :: build_pos (
184+ self . dt_pos ,
185+ nt,
186+ requests. iter ( ) . map ( |req| Seq {
187+ pos : req. pos ,
188+ len : req. seq_len ,
189+ } ) ,
190+ queue_alloc,
191+ )
192+ } ) ;
180193 // 缩放
181194 let inplace = unsafe { x. map_slice_static ( ) } ;
182195 self . scale ( & mut x, & inplace, scale_emb, workspace, queue_alloc) ?;
@@ -232,95 +245,26 @@ where
232245
233246 split_mut ! ( kv => k_nope , v ; [ dnope , dv ] @ 2 ) ;
234247
235- /// LongRoPE: rotates each adjacent `[f32; 2]` pair of `embd` in place by the angle `pos * theta^(-i / dh) / factor[i]` (where `dh = embd.len() / 2`), using `long_factor` when `pos > origin_max_pos` and `short_factor` otherwise, with sin/cos multiplied by `1 + sqrt(ln(max_pos / origin_max_pos) / ln(origin_max_pos))`; panics if the selected factor slice has fewer than `dh` entries.
236- pub fn longrope (
237- embd : & mut [ f32 ] ,
238- pos : f32 ,
239- theta : f32 ,
240- long_factor : & [ f32 ] ,
241- short_factor : & [ f32 ] ,
242- max_pos : f32 ,
243- origin_max_pos : f32 ,
244- ) {
245- use std:: slice:: from_raw_parts_mut;
246- // Compute scaling_factor from the ratio of max_pos to origin_max_pos
247- let scaling_factor =
248- 1.0 + ( ( max_pos / origin_max_pos) . ln ( ) / origin_max_pos. ln ( ) ) . sqrt ( ) ;
249- let factor = if pos > origin_max_pos {
250- long_factor
251- } else {
252- short_factor
253- } ;
254- let dh = embd. len ( ) / 2 ;
255- let embd =
256- unsafe { from_raw_parts_mut ( embd. as_mut_ptr ( ) . cast :: < [ f32 ; 2 ] > ( ) , dh) } ;
257- for ( i, pair) in embd. iter_mut ( ) . enumerate ( ) {
258- let theta = theta. powf ( -( i as f32 / dh as f32 ) ) ;
259- let freq = pos * theta * factor. get ( i) . unwrap ( ) . recip ( ) ;
260- let ( sin, cos) = freq. sin_cos ( ) ;
261- let ( sin, cos) = ( sin * scaling_factor, cos * scaling_factor) ;
262- let [ a, b] = * pair;
263- * pair = [ a * cos - b * sin, a * sin + b * cos] ;
264- }
265- }
266- let cast = |t : * const f32 | -> & ' static [ f32 ] {
267- unsafe { std:: slice:: from_raw_parts ( t, dh / 2 ) }
268- } ;
269- let [ long_factor, short_factor] = self . weights . factor ( queue) ;
270- let long_factor = cast ( long_factor. base ( ) . cast ( ) ) ;
271- let short_factor = cast ( short_factor. base ( ) . cast ( ) ) ;
272-
273248 // k [1, 3840]
274249 let k = tensor ( & [ nt, nh, dk] ) ;
275250 let ( buf, workspace) = workspace. split_at_mut ( * k. get ( ) ) ;
276251 let k = k. map ( |_| buf) ;
277252
278253 split_mut ! ( k => k_nope_r , k_rope_r ; [ dnope, dh] @ 2 ) ;
279254
280- let pos = requests. last ( ) . unwrap ( ) . pos as f32 ;
281- let ( max_pos, origin_max_pos) = ( 100f32 , 100f32 ) ;
282-
283- // q 嵌入
284- ( 0 ..nh) . for_each ( |i| {
285- let tmp_q = unsafe {
286- std:: slice:: from_raw_parts_mut (
287- q_rope. base_mut ( ) . cast :: < f32 > ( ) . add ( i * 32 ) ,
288- 32 ,
289- )
290- } ;
291- longrope (
292- tmp_q,
293- pos,
294- self . meta . theta ,
295- long_factor,
296- short_factor,
297- max_pos,
298- origin_max_pos,
299- ) ;
300- } ) ;
301- // k 嵌入
302-
303- let k_rope_1 =
304- unsafe { std:: slice:: from_raw_parts_mut ( k_rope. base_mut ( ) . cast :: < f32 > ( ) , 32 ) } ;
305- longrope (
306- k_rope_1,
307- pos,
308- self . meta . theta ,
309- long_factor,
310- short_factor,
311- max_pos,
312- origin_max_pos,
313- ) ;
314-
315- // 经行广播和拷贝
316- let k_rope = k_rope. tile ( 1 , & [ 1 , dh] ) . broadcast ( 1 , nh) ;
255+ self . rope ( & mut q_rope, & pos, & sin, & cos, workspace, queue_alloc) ?;
256+ let mut k_rope = k_rope. tile ( 1 , & [ 1 , dh] ) ;
257+ self . rope ( & mut k_rope, & pos, & sin, & cos, workspace, queue_alloc) ?;
258+ let k_rope = k_rope. broadcast ( 1 , nh) ;
317259 self . rearrange ( & mut k_rope_r, & k_rope, workspace, queue_alloc) ?;
318260 self . rearrange ( & mut k_nope_r, & k_nope, workspace, queue_alloc) ?;
319261
262+ let pos = requests. last ( ) . unwrap ( ) . pos as f32 ;
320263 let mut q = q3. transpose ( & [ 1 , 0 ] ) ;
321264 let k = k. map_slice ( ) . transpose ( & [ 1 , 0 ] ) ;
322265 let v = v. map_slice_mut ( ) . transpose ( & [ 1 , 0 ] ) ;
323266 let mut attn = unsafe { attn. map_slice_mut ( ) . transpose ( & [ 1 , 0 ] ) } ;
267+ let pos = requests. last ( ) . unwrap ( ) . pos as f32 ;
324268 self . attnention (
325269 & mut q,
326270 & k,
@@ -490,6 +434,7 @@ where
490434 Cos : Deref < Target = [ ByteOf < Ops :: Hardware > ] > ,
491435 QA : QueueAlloc < Hardware = Ops :: Hardware > ,
492436 {
437+ let [ long, short] = self . weights . factor ( queue_alloc. queue ( ) ) ;
493438 self . rope . launch (
494439 & rope:: Args {
495440 t_layout : t. layout ( ) ,
@@ -501,6 +446,12 @@ where
501446 cos_layout : cos. layout ( ) ,
502447 cos_base : cos. base ( ) ,
503448 theta : self . meta . theta ,
449+ rope_type : rope:: RopeType :: Long {
450+ long : long. base ( ) ,
451+ short : short. base ( ) ,
452+ max_pos : 100 ,
453+ origin_pos : 100 ,
454+ } ,
504455 } ,
505456 workspace,
506457 queue_alloc,
0 commit comments