@@ -3,13 +3,15 @@ use gguf::ggml_quants::digit_layout::types as ty;
 use gguf::ggml_quants::digit_layout::DigitLayout;
 use half::f16;
 use itertools::Itertools;
+use operators::fuesd_softmax;
+use operators::fuesd_softmax::FusedSoftmax;
 use operators::scale;
 use operators::scale::Scale;
 use operators::{
     add::{self, Add},
     all_reduce::{self, AllReduce, ReduceOp},
-    attention::{self, Attention},
     attention_kv_cached::AttnKVCached,
+    attention_mla::{self, AttentionMLA},
     fuesd_softmax::AttnMask,
     mat_mul::{self, MatMul},
     rearrange::{self, Rearrange},
@@ -19,20 +21,22 @@ use operators::{
     ByteOf, Hardware, LaunchError, Operator, QueueAlloc, QueueOf, TopoNode, Workspace,
 };
 use std::ops::{Deref, DerefMut};
+use std::process::Output;
 use tensor::split_mut;
 use tensor::{split, Tensor};
 
 pub trait Operators {
     type Hardware: Hardware;
     type TopoNode: TopoNode<Self::Hardware>;
-    type Attention: Attention<Self::Hardware>;
+    type AttentionMLA: AttentionMLA<Self::Hardware>;
     type AttnKVCached: AttnKVCached<Self::Hardware>;
     type Rope: Rope<Self::Hardware>;
     type RmsNorm: RmsNorm<Self::Hardware>;
     type Add: Add<Self::Hardware>;
     type MatMul: MatMul<Self::Hardware>;
     type Swiglu: Swiglu<Self::Hardware>;
     type Scale: Scale<Self::Hardware>;
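+    /// Fused softmax operator; the "fuesd" spelling follows the module name in the operators crate.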
+    type FuesdSoftmax: FusedSoftmax<Self::Hardware>;
     type Rearrange: Rearrange<Self::Hardware>;
     type AllReduce: AllReduce<Self::Hardware, Self::TopoNode>;
 
@@ -81,12 +85,13 @@ pub struct Minicpm3Worker<Ops: Operators, W> {
     dt_pos: DigitLayout,
     add: Ops::Add,
     attn_kv_cached: Ops::AttnKVCached,
-    attention: Ops::Attention,
+    attention_mla: Ops::AttentionMLA,
     rope: Ops::Rope,
     rms_norm: Ops::RmsNorm,
     mat_mul: Ops::MatMul,
     scale: Ops::Scale,
     swiglu: Ops::Swiglu,
+    fuesd_softmax: Ops::FuesdSoftmax,
     rearrange: Ops::Rearrange,
     all_reduce: Ops::AllReduce,
 }
@@ -108,7 +113,8 @@ impl<Ops: Operators, W> Minicpm3Worker<Ops, W> {
             add: Ops::Add::new(processor),
             all_reduce: Ops::AllReduce::new(node),
             dt_pos: ty::U64,
-            attention: Ops::Attention::new(processor),
+            attention_mla: Ops::AttentionMLA::new(processor),
+            fuesd_softmax: Ops::FuesdSoftmax::new(processor),
         }
     }
 
@@ -165,12 +171,11 @@ where
 
         let gate_up = tensor(&[nt, di * 2]);
         // space for x + x1 + q (can probably be removed) + q3 + kv_pe + attn
-        let workspace_size = *x1.get() * 3 + *gate_up.get();
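+        // NOTE: the multiplier grows from 3 to 20, presumably to leave room for the
+        // extra MLA intermediates (q1/q3, kv_pe, absorbed q_nope, attn_output).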
+        let workspace_size = *x1.get() * 20 + *gate_up.get();
         let mut workspace = Workspace::new(queue_alloc, workspace, workspace_size);
         let (buf, workspace) = workspace.split_at_mut(*x1.get());
         let mut x1 = x1.map(|_| buf);
 
-
         let queue = queue_alloc.queue();
 
         let sin = sin_cos.clone().index(0, 0);
@@ -205,17 +210,15 @@ where
         let w = self.weights.attn_qa_norm(iblk, queue);
         self.rms_norm(&mut q, &inplace, &w, workspace, queue_alloc)?;
         {
-            // q [1, 768] q1 [1, 3840] kv_pe [1,288] kv [1, 5120] k [1, 3840] attn [1, 2560]
             let q1 = tensor(&[nt, nh * dk]);
             let (buf, workspace) = workspace.split_at_mut(*q1.get());
             let mut q1 = q1.map(|_| buf);
             let w = self.weights.attn_qb(iblk, queue).transpose(&[1, 0]);
             self.mat_mul(&mut q1, 0., &q, &w, 1., workspace, queue_alloc)?;
-            drop(q);
-            // q3 feeds the attention computation, but part of q3 still needs the rotary embedding applied
+
             let mut q3 = q1.tile(1, &[nh, dk]);
             let q2 = unsafe { q3.map_slice_static_mut() };
-            split_mut!(q2 => _q, q_rope; [dnope, dh] @ 2);
+            split_mut!(q2 => q_nope, q_rope; [dnope, dh] @ 2);
 
             // kv_pe [1,288]
             let kv_pe = tensor(&[nt, dkv_lora + dh]);
@@ -224,62 +227,69 @@ where
 
             let w = self.weights.attn_kva(iblk, queue).transpose(&[1, 0]);
             self.mat_mul(&mut kv_pe, 0., &x1, &w, 1., workspace, queue_alloc)?;
-
+            drop(q);
             split_mut!(kv_pe => kv_lora, k_rope; [dkv_lora, dh] @ 1);
 
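+            // Decoupled rotary embedding (as MLA appears to require): only the rope halves
+            // carry positions; the single-head k_rope is broadcast across all nh heads.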
+            self.rope(&mut q_rope, &pos, &sin, &cos, workspace, queue_alloc)?;
+            let mut k_rope = k_rope.tile(1, &[1, dh]);
+            self.rope(&mut k_rope, &pos, &sin, &cos, workspace, queue_alloc)?;
+            let k_rope = k_rope.broadcast(1, nh);
+
             let inplace = unsafe { kv_lora.map_slice_static() };
             let w = self.weights.attn_kva_norm(iblk, queue);
             self.rms_norm(&mut kv_lora, &inplace, &w, workspace, queue_alloc)?;
             // kv X[1, 5120]
             let kv = tensor(&[nt, nh * (dnope + dv)]);
             let (buf, workspace) = workspace.split_at_mut(*kv.get());
             let mut kv = kv.map(|_| buf);
-            let w = self.weights.attn_kvb(iblk, queue).transpose(&[1, 0]);
 
-            self.mat_mul(&mut kv, 0., &kv_lora, &w, 1., workspace, queue_alloc)?;
-
-            let kv = kv.tile(1, &[nh, dnope + dv]);
-
-            split_mut!(kv => k_nope, v; [dnope, dv] @ 2);
-
-            // k [1, 3840]
-            let k = tensor(&[nt, nh, dk]);
-            let (buf, workspace) = workspace.split_at_mut(*k.get());
-            let k = k.map(|_| buf);
-
-            split_mut!(k => k_nope_r, k_rope_r; [dnope, dh] @ 2);
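+            // Weight absorption (appears to follow the MLA "matrix absorption" trick):
+            // the per-head kv_b projection is split into q_absorb (maps the nope part of q
+            // into the dkv_lora latent space) and out_absorb (maps latent values back to dv),
+            // so attention can run on the compressed kv_lora cache without materializing K/V.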
+            let kv_b_proj = unsafe {
+                self.weights
+                    .attn_kvb(iblk, queue)
+                    .tile(0, &[nh, dnope + dv])
+                    .map_slice_static()
+            };
+            split!(kv_b_proj => q_absorb, out_absorb; [dnope, dv] @ 1);
+            let inplace = unsafe { q_nope.map_slice_static() };
+
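+            // Project q_nope through q_absorb: [nt, nh, dnope] -> [nh, nt, dkv_lora]
+            // (shapes inferred from the surrounding code), so scores are taken against kv_lora.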
+            let q_nope_0 = q_nope.map_slice().transpose(&[1, 0]);
+            let q_nope_1 = tensor(&[nh, nt, dkv_lora]);
+            let (buf, workspace) = workspace.split_at_mut(*q_nope_1.get());
+            let mut q_nope = q_nope_1.map(|_| buf);
+            self.mat_mul(
+                &mut q_nope,
+                0.,
+                &q_nope_0,
+                &q_absorb,
+                1.,
+                workspace,
+                queue_alloc,
+            )?;
 
-            self.rope(&mut q_rope, &pos, &sin, &cos, workspace, queue_alloc)?;
-            let mut k_rope = k_rope.tile(1, &[1, dh]);
-            self.rope(&mut k_rope, &pos, &sin, &cos, workspace, queue_alloc)?;
-            let k_rope = k_rope.broadcast(1, nh);
-            self.rearrange(&mut k_rope_r, &k_rope, workspace, queue_alloc)?;
-            self.rearrange(&mut k_nope_r, &k_nope, workspace, queue_alloc)?;
-
-            let pos = requests.last().unwrap().pos as f32;
-            let mut q = q3.transpose(&[1, 0]);
-            let k = k.map_slice().transpose(&[1, 0]);
-            let v = v.map_slice_mut().transpose(&[1, 0]);
-            // run attention
-            let attn = tensor(&[nt, nh, dv]);
-            let (buf, workspace) = workspace.split_at_mut(*attn.get());
-            let mut attn = attn.map(|_| buf);
-
-            let mut attn = unsafe { attn.map_slice_mut().transpose(&[1, 0]) };
-            let pos = requests.last().unwrap().pos as f32;
+            drop(q3);
+            // attn_output
+            let attn_output = tensor(&[nt, nh, dv]);
+            let (buf, workspace) = workspace.split_at_mut(*attn_output.get());
+            let mut attn_output = attn_output.map(|_| buf);
+            let q_rope = q_rope.transpose(&[1, 0]);
+            let k_rope = k_rope.transpose(&[1, 0]);
+            let kv_lora = kv_lora.map_slice().tile(0, &[1, 1]).broadcast(0, nh);
+            let mut o = unsafe { attn_output.map_slice_static_mut().transpose(&[1, 0]) };
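+            // attention_mla arguments (mapping assumed from the Args fields below):
+            // q = absorbed q_nope [nh, nt, dkv_lora], kv = per-head view of kv_lora,
+            // a = out_absorb (latent -> dv), qr/kr = the rotary halves, o = output view
+            // of attn_output; the literal 1 fills the pos parameter.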
             self.attnention(
-                &mut q,
-                &k,
-                &v,
-                &mut attn,
-                pos as usize,
+                &mut q_nope,
+                &kv_lora,
+                &out_absorb,
+                &q_rope,
+                &k_rope,
+                &mut o,
+                1,
                 workspace,
                 queue_alloc,
             )?;
-
-            let o = attn.transpose(&[1, 0]).merge(1..3).unwrap();
+            let o = attn_output.map_slice().merge(1..3).unwrap();
             let w = self.weights.attn_o(iblk, queue);
-
             self.mat_mul(&mut x1, 0., &o, &w, s, workspace, queue_alloc)?;
             let inplace = unsafe { x.map_slice_static() };
             self.add(&mut x, &inplace, &x1, workspace, queue_alloc)?;
@@ -301,17 +311,6 @@ where
 
         self.swiglu(&mut gate, &up, workspace, queue_alloc)?;
 
-        fn print_first_10_elements(ptr: *const f16) {
-            assert!(!ptr.is_null(), "Pointer must not be null");
-
-            unsafe {
-                for i in 0..10 {
-                    // access and print the first 10 elements one by one
-                    let element = ptr.offset(i as isize).read();
-                    println!("Element {}: {:?}", i, element);
-                }
-            }
-        }
 
         let w = self.weights.ffn_down(iblk, queue);
         self.mat_mul(&mut x1, 0., &gate, &w, s, workspace, queue_alloc)?;
@@ -460,31 +459,39 @@ where
             queue_alloc,
         )
     }
-    fn attnention<Q, K, V, O, QA>(
+    fn attnention<Q, KV, A, QR, KR, O, QA>(
         &self,
         q: &mut Tensor<Q>,
-        k: &Tensor<K>,
-        v: &Tensor<V>,
+        kv: &Tensor<KV>,
+        a: &Tensor<A>,
+        qr: &Tensor<QR>,
+        kr: &Tensor<KR>,
         o: &mut Tensor<O>,
         pos: usize,
         workspace: &mut [ByteOf<Ops::Hardware>],
         queue_alloc: &QA,
     ) -> Result<(), LaunchError>
     where
         Q: DerefMut<Target = [ByteOf<Ops::Hardware>]>,
-        K: Deref<Target = [ByteOf<Ops::Hardware>]>,
-        V: Deref<Target = [ByteOf<Ops::Hardware>]>,
+        KV: Deref<Target = [ByteOf<Ops::Hardware>]>,
+        A: Deref<Target = [ByteOf<Ops::Hardware>]>,
+        QR: Deref<Target = [ByteOf<Ops::Hardware>]>,
+        KR: Deref<Target = [ByteOf<Ops::Hardware>]>,
         O: DerefMut<Target = [ByteOf<Ops::Hardware>]>,
         QA: QueueAlloc<Hardware = Ops::Hardware>,
     {
-        self.attention.launch(
-            &attention::Args {
+        self.attention_mla.launch(
+            &attention_mla::Args {
                 q_layout: q.layout(),
                 q_base: q.base_mut(),
-                k_layout: k.layout(),
-                k_base: k.base(),
-                v_layout: v.layout(),
-                v_base: v.base(),
+                kv_layout: kv.layout(),
+                kv_base: kv.base(),
+                absorb_layout: a.layout(),
+                absorb_base: a.base(),
+                qr_layout: qr.layout(),
+                qr_base: qr.base(),
+                kr_layout: kr.layout(),
+                kr_base: kr.base(),
                 o_layout: o.layout(),
                 o_base: o.base_mut(),
                 mask: AttnMask::Causal,
@@ -594,6 +601,26 @@ where
             queue_alloc,
         )
     }
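+    /// Causal fused softmax over an attention score tensor (thin wrapper around the
+    /// fuesd_softmax operator); it is not called elsewhere in this change.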
+    fn softmax<A, QA>(
+        &self,
+        a: &mut Tensor<A>,
+        workspace: &mut [ByteOf<Ops::Hardware>],
+        queue_alloc: &QA,
+    ) -> Result<(), LaunchError>
+    where
+        A: DerefMut<Target = [ByteOf<Ops::Hardware>]>,
+        QA: QueueAlloc<Hardware = Ops::Hardware>,
+    {
+        self.fuesd_softmax.launch(
+            &fuesd_softmax::Args {
+                att_mask: AttnMask::Causal,
+                att_layout: a.layout(),
+                att_base: a.base_mut(),
+            },
+            workspace,
+            queue_alloc,
+        )
+    }
     fn all_reduce<X, QA>(
         &self,
         x: &mut Tensor<X>,