Commit ad2ff0b

feat: add MLA
1 parent c469a04 commit ad2ff0b
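
This commit moves MiniCPM3's attention from the decompress-then-attend path (materializing per-head K and V from the KV-LoRA latents) to the absorbed MLA form: the kv_b up-projection is split into a q_absorb half folded into the query and an out_absorb half applied after attention, so both the scores and the weighted sum run directly over the compressed latent cache. The identity behind the split, written in DeepSeek-V2 MLA notation ($W^{UK}$, $W^{UV}$, $c_{kv}$ are the paper's symbols, not identifiers from this repo):

$$
q_{\text{nope}}^{\top} k_{\text{nope}}
  = q_{\text{nope}}^{\top} \left( W^{UK} c_{kv} \right)
  = \left( W^{UK\top} q_{\text{nope}} \right)^{\top} c_{kv},
\qquad
o = \sum_{t} a_t \, W^{UV} c_{kv,t} = W^{UV} \sum_{t} a_t \, c_{kv,t}.
$$

In the diff below, q_absorb plays the role of $W^{UK}$ and out_absorb the role of $W^{UV}$.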

File tree

3 files changed: +101 -74 lines

models/minicpm3/common-cpu/src/lib.rs
models/minicpm3/common/src/compute.rs
tensor/src/split.rs


models/minicpm3/common-cpu/src/lib.rs

Lines changed: 2 additions & 1 deletion
@@ -35,14 +35,15 @@ where
     type Hardware = Cpu;
     type TopoNode = N;
     type Rope = op!(rope);
-    type Attention = op!(attention);
+    type AttentionMLA = op!(attention_mla);
     type RmsNorm = op!(rms_norm);
     type Add = op!(add);
     type MatMul = op!(mat_mul);
     type Swiglu = op!(swiglu);
     type Rearrange = op!(rearrange);
     type Scale = op!(scale);
     type AttnKVCached = op!(attention_kv_cached);
+    type FuesdSoftmax = op!(fuesd_softmax);
     type AllReduce = R;
 
     fn debug<T>(tensor: &Tensor<T>, _queue: &QueueOf<Self::Hardware>)
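
For orientation, here is a toy sketch of the associated-type table this hunk edits: each `type X = op!(x);` line presumably binds one operator slot of the `Operators` trait (defined in compute.rs below) to a concrete CPU kernel, so renaming the slot from Attention to AttentionMLA swaps the whole attention backend at compile time. The trait, the stand-in CpuAttentionMla type, and the string argument are all illustrative, not the real operators-crate API:

// Illustrative stand-ins only; the real crate uses typed Args structs
// and an op! macro rather than hand-written impls.
trait Operator {
    fn launch(&self, args: &str);
}

struct CpuAttentionMla; // what op!(attention_mla) might select on CPU

impl Operator for CpuAttentionMla {
    fn launch(&self, args: &str) {
        println!("attention_mla launched with {args}");
    }
}

trait Operators {
    type AttentionMLA: Operator; // the slot this commit renames from `Attention`
}

struct CpuOps;

impl Operators for CpuOps {
    type AttentionMLA = CpuAttentionMla;
}

// The worker is generic over Operators, so it only ever names the slot.
fn run<O: Operators>(attn: &O::AttentionMLA) {
    attn.launch("q/kv/absorb/rope layouts");
}

fn main() {
    run::<CpuOps>(&CpuAttentionMla);
}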

models/minicpm3/common/src/compute.rs

Lines changed: 99 additions & 72 deletions
@@ -3,13 +3,15 @@ use gguf::ggml_quants::digit_layout::types as ty;
 use gguf::ggml_quants::digit_layout::DigitLayout;
 use half::f16;
 use itertools::Itertools;
+use operators::fuesd_softmax;
+use operators::fuesd_softmax::FusedSoftmax;
 use operators::scale;
 use operators::scale::Scale;
 use operators::{
     add::{self, Add},
     all_reduce::{self, AllReduce, ReduceOp},
-    attention::{self, Attention},
     attention_kv_cached::AttnKVCached,
+    attention_mla::{self, AttentionMLA},
     fuesd_softmax::AttnMask,
     mat_mul::{self, MatMul},
     rearrange::{self, Rearrange},
@@ -19,20 +21,22 @@ use operators::{
     ByteOf, Hardware, LaunchError, Operator, QueueAlloc, QueueOf, TopoNode, Workspace,
 };
 use std::ops::{Deref, DerefMut};
+use std::process::Output;
 use tensor::split_mut;
 use tensor::{split, Tensor};
 
 pub trait Operators {
     type Hardware: Hardware;
     type TopoNode: TopoNode<Self::Hardware>;
-    type Attention: Attention<Self::Hardware>;
+    type AttentionMLA: AttentionMLA<Self::Hardware>;
     type AttnKVCached: AttnKVCached<Self::Hardware>;
     type Rope: Rope<Self::Hardware>;
     type RmsNorm: RmsNorm<Self::Hardware>;
     type Add: Add<Self::Hardware>;
     type MatMul: MatMul<Self::Hardware>;
     type Swiglu: Swiglu<Self::Hardware>;
     type Scale: Scale<Self::Hardware>;
+    type FuesdSoftmax: FusedSoftmax<Self::Hardware>;
     type Rearrange: Rearrange<Self::Hardware>;
     type AllReduce: AllReduce<Self::Hardware, Self::TopoNode>;
 
@@ -81,12 +85,13 @@ pub struct Minicpm3Worker<Ops: Operators, W> {
     dt_pos: DigitLayout,
     add: Ops::Add,
     attn_kv_cached: Ops::AttnKVCached,
-    attention: Ops::Attention,
+    attention_mla: Ops::AttentionMLA,
     rope: Ops::Rope,
     rms_norm: Ops::RmsNorm,
     mat_mul: Ops::MatMul,
     scale: Ops::Scale,
     swiglu: Ops::Swiglu,
+    fuesd_softmax: Ops::FuesdSoftmax,
     rearrange: Ops::Rearrange,
     all_reduce: Ops::AllReduce,
 }
@@ -108,7 +113,8 @@ impl<Ops: Operators, W> Minicpm3Worker<Ops, W> {
             add: Ops::Add::new(processor),
             all_reduce: Ops::AllReduce::new(node),
             dt_pos: ty::U64,
-            attention: Ops::Attention::new(processor),
+            attention_mla: Ops::AttentionMLA::new(processor),
+            fuesd_softmax: Ops::FuesdSoftmax::new(processor),
         }
     }
 
@@ -165,12 +171,11 @@ where
 
         let gate_up = tensor(&[nt, di * 2]);
         // workspace: x + x1 + q (can probably be removed) + q3 + kv_pe + attn
-        let workspace_size = *x1.get() * 3 + *gate_up.get();
+        let workspace_size = *x1.get() * 20 + *gate_up.get();
         let mut workspace = Workspace::new(queue_alloc, workspace, workspace_size);
         let (buf, workspace) = workspace.split_at_mut(*x1.get());
         let mut x1 = x1.map(|_| buf);
 
-
         let queue = queue_alloc.queue();
 
         let sin = sin_cos.clone().index(0, 0);
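
The bump from `* 3` to `* 20` reserves room for the extra MLA intermediates (q1, q3, kv_pe, kv, the absorbed q_nope, attn_output) in one flat allocation. A minimal sketch of the split_at_mut carving pattern the worker relies on, with toy byte counts standing in for the real `*tensor.get()` sizes:

fn main() {
    let x1_len = 4; // stand-in for *x1.get()
    let gate_up_len = 8; // stand-in for *gate_up.get()
    let mut workspace = vec![0u8; x1_len * 20 + gate_up_len];

    // Each intermediate takes a prefix of the buffer and hands the tail
    // on to the next one, mirroring workspace.split_at_mut(*x1.get()).
    let (x1_buf, rest) = workspace.split_at_mut(x1_len);
    let (gate_up_buf, rest) = rest.split_at_mut(gate_up_len);
    println!("x1: {}, gate_up: {}, left: {}", x1_buf.len(), gate_up_buf.len(), rest.len());
}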
@@ -205,17 +210,15 @@ where
         let w = self.weights.attn_qa_norm(iblk, queue);
         self.rms_norm(&mut q, &inplace, &w, workspace, queue_alloc)?;
         {
-            // q [1, 768] q1 [1, 3840] kv_pe [1,288] kv [1, 5120] k [1, 3840] attn [1, 2560]
             let q1 = tensor(&[nt, nh * dk]);
             let (buf, workspace) = workspace.split_at_mut(*q1.get());
             let mut q1 = q1.map(|_| buf);
             let w = self.weights.attn_qb(iblk, queue).transpose(&[1, 0]);
             self.mat_mul(&mut q1, 0., &q, &w, 1., workspace, queue_alloc)?;
-            drop(q);
-            // q3 is the data attention consumes, but part of q3 still needs the rotary embedding applied
+
             let mut q3 = q1.tile(1, &[nh, dk]);
             let q2 = unsafe { q3.map_slice_static_mut() };
-            split_mut!(q2=>_q, q_rope;[dnope, dh]@ 2);
+            split_mut!(q2=>q_nope, q_rope;[dnope, dh]@ 2);
 
             // kv_pe [1,288]
             let kv_pe = tensor(&[nt, dkv_lora + dh]);
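
The split_mut! change above is more than a rename: previously only the RoPE half of each head's query needed a name (the old path handed all of q3 to attention), while the new absorbed path multiplies the first dnope features by q_absorb on their own, so they now get the q_nope binding. A slice-level equivalent for one head, with toy sizes:

fn main() {
    let (dnope, dh) = (4, 2); // illustrative, not MiniCPM3's real head split
    let mut head: Vec<f32> = (0..(dnope + dh)).map(|i| i as f32).collect();
    // Mirrors split_mut!(q2 => q_nope, q_rope; [dnope, dh] @ 2) for one head.
    let (q_nope, q_rope) = head.split_at_mut(dnope);
    println!("q_nope = {q_nope:?}, q_rope (gets RoPE) = {q_rope:?}");
}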
@@ -224,62 +227,69 @@ where
 
             let w = self.weights.attn_kva(iblk, queue).transpose(&[1, 0]);
             self.mat_mul(&mut kv_pe, 0., &x1, &w, 1., workspace, queue_alloc)?;
-
+            drop(q);
             split_mut!(kv_pe => kv_lora, k_rope; [dkv_lora, dh] @ 1);
 
+            self.rope(&mut q_rope, &pos, &sin, &cos, workspace, queue_alloc)?;
+            let mut k_rope = k_rope.tile(1, &[1, dh]);
+            self.rope(&mut k_rope, &pos, &sin, &cos, workspace, queue_alloc)?;
+            let k_rope = k_rope.broadcast(1, nh);
+
             let inplace = unsafe { kv_lora.map_slice_static() };
             let w = self.weights.attn_kva_norm(iblk, queue);
             self.rms_norm(&mut kv_lora, &inplace, &w, workspace, queue_alloc)?;
             // kv X[1, 5120]
             let kv = tensor(&[nt, nh * (dnope + dv)]);
             let (buf, workspace) = workspace.split_at_mut(*kv.get());
             let mut kv = kv.map(|_| buf);
-            let w = self.weights.attn_kvb(iblk, queue).transpose(&[1, 0]);
 
-            self.mat_mul(&mut kv, 0., &kv_lora, &w, 1., workspace, queue_alloc)?;
-
-            let kv = kv.tile(1, &[nh, dnope + dv]);
-
-            split_mut!(kv => k_nope ,v ; [dnope , dv ] @ 2);
-
-            // k [1, 3840]
-            let k = tensor(&[nt, nh, dk]);
-            let (buf, workspace) = workspace.split_at_mut(*k.get());
-            let k = k.map(|_| buf);
-
-            split_mut!(k => k_nope_r ,k_rope_r ; [dnope, dh] @ 2);
+            let kv_b_proj = unsafe {
+                self.weights
+                    .attn_kvb(iblk, queue)
+                    .tile(0, &[nh, dnope + dv])
+                    .map_slice_static()
+            };
+            split!(kv_b_proj=> q_absorb , out_absorb ; [dnope, dv] @ 1);
+            let inplace = unsafe { q_nope.map_slice_static() };
+
+            let q_nope_0 = q_nope.map_slice().transpose(&[1, 0]);
+            let q_nope_1 = tensor(&[nh, nt, dkv_lora]);
+            let (buf, workspace) = workspace.split_at_mut(*q_nope_1.get());
+            let mut q_nope = q_nope_1.map(|_| buf);
+            self.mat_mul(
+                &mut q_nope,
+                0.,
+                &q_nope_0,
+                &q_absorb,
+                1.,
+                workspace,
+                queue_alloc,
+            )?;
 
-            self.rope(&mut q_rope, &pos, &sin, &cos, workspace, queue_alloc)?;
-            let mut k_rope = k_rope.tile(1, &[1, dh]);
-            self.rope(&mut k_rope, &pos, &sin, &cos, workspace, queue_alloc)?;
-            let k_rope = k_rope.broadcast(1, nh);
-            self.rearrange(&mut k_rope_r, &k_rope, workspace, queue_alloc)?;
-            self.rearrange(&mut k_nope_r, &k_nope, workspace, queue_alloc)?;
-
-            let pos = requests.last().unwrap().pos as f32;
-            let mut q = q3.transpose(&[1, 0]);
-            let k = k.map_slice().transpose(&[1, 0]);
-            let v = v.map_slice_mut().transpose(&[1, 0]);
-            // run attention
-            let attn = tensor(&[nt, nh, dv]);
-            let (buf, workspace) = workspace.split_at_mut(*attn.get());
-            let mut attn = attn.map(|_| buf);
-
-            let mut attn = unsafe { attn.map_slice_mut().transpose(&[1, 0]) };
-            let pos = requests.last().unwrap().pos as f32;
+            drop(q3);
+            // attn_output
+            let attn_output = tensor(&[nt, nh, dv]);
+            let (buf, workspace) = workspace.split_at_mut(*attn_output.get());
+            let mut attn_output = attn_output.map(|_| buf);
+            let q_rope = q_rope.transpose(&[1, 0]);
+            let k_rope = k_rope.transpose(&[1, 0]);
+            let kv_lora = kv_lora.map_slice().tile(0, &[1, 1]).broadcast(0, nh);
+            let mut o = unsafe {
+                attn_output.map_slice_static_mut().transpose(&[1, 0])
+            };
             self.attnention(
-                &mut q,
-                &k,
-                &v,
-                &mut attn,
-                pos as usize,
+                &mut q_nope,
+                &kv_lora,
+                &out_absorb,
+                &q_rope,
+                &k_rope,
+                &mut o,
+                1,
                 workspace,
                 queue_alloc,
             )?;
-
-            let o = attn.transpose(&[1, 0]).merge(1..3).unwrap();
+            let o = attn_output.map_slice().merge(1..3).unwrap();
             let w = self.weights.attn_o(iblk, queue);
-
             self.mat_mul(&mut x1, 0., &o, &w, s, workspace, queue_alloc)?;
             let inplace = unsafe { x.map_slice_static() };
             self.add(&mut x, &inplace, &x1, workspace, queue_alloc)?;
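
Taken together, the new call runs attention entirely in latent space: q_nope has already been multiplied by q_absorb, kv_lora is broadcast across heads as the shared key/value cache, and out_absorb decompresses the result afterward. A self-contained single-head, single-query sketch of what attention_mla plausibly computes with these arguments — plain Vec<f32> instead of this repo's Tensor, an illustrative scale, and no claim to match the operator's exact kernel:

// All dimensions and values are toy stand-ins. q_absorb/out_absorb play the
// roles of the two halves of kv_b_proj in the diff above.
fn absorbed_mla_head(
    q_nope: &[f32],          // [dnope] non-positional query features
    q_rope: &[f32],          // [dh] rotary query features
    q_absorb: &[Vec<f32>],   // [dnope][dkv_lora], i.e. W_UK
    out_absorb: &[Vec<f32>], // [dv][dkv_lora], i.e. W_UV
    kv_lora: &[Vec<f32>],    // [seq][dkv_lora] compressed latent cache
    k_rope: &[Vec<f32>],     // [seq][dh] shared rotary key cache
    scale: f32,
) -> Vec<f32> {
    let dkv_lora = kv_lora[0].len();
    // Absorb W_UK into the query (the `q_nope x q_absorb` mat_mul above):
    // q_lat = W_UK^T q_nope, so scores come straight from the latent cache.
    let mut q_lat = vec![0.0f32; dkv_lora];
    for (i, &qi) in q_nope.iter().enumerate() {
        for j in 0..dkv_lora {
            q_lat[j] += qi * q_absorb[i][j];
        }
    }
    // Scores = latent dot product + rotary correction, then softmax.
    let mut scores: Vec<f32> = kv_lora
        .iter()
        .zip(k_rope)
        .map(|(c, kr)| {
            let lat: f32 = q_lat.iter().zip(c).map(|(a, b)| a * b).sum();
            let rot: f32 = q_rope.iter().zip(kr).map(|(a, b)| a * b).sum();
            (lat + rot) * scale
        })
        .collect();
    let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let sum: f32 = scores
        .iter_mut()
        .map(|s| {
            *s = (*s - max).exp();
            *s
        })
        .sum();
    scores.iter_mut().for_each(|s| *s /= sum);
    // Attend in latent space, then decompress once through W_UV.
    let mut ctx = vec![0.0f32; dkv_lora];
    for (a, c) in scores.iter().zip(kv_lora) {
        for j in 0..dkv_lora {
            ctx[j] += a * c[j];
        }
    }
    out_absorb
        .iter()
        .map(|row| row.iter().zip(&ctx).map(|(w, c)| w * c).sum::<f32>())
        .collect()
}

fn main() {
    let (dnope, dh, dkv, dv, seq) = (4, 2, 3, 4, 5);
    let o = absorbed_mla_head(
        &vec![0.1; dnope],
        &vec![0.2; dh],
        &vec![vec![0.05; dkv]; dnope],
        &vec![vec![0.07; dkv]; dv],
        &vec![vec![0.3; dkv]; seq],
        &vec![vec![0.4; dh]; seq],
        0.5,
    );
    println!("per-head output (dv = {dv}): {o:?}");
}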
@@ -301,17 +311,6 @@ where
 
         self.swiglu(&mut gate, &up, workspace, queue_alloc)?;
 
-            fn print_first_10_elements(ptr: *const f16) {
-                assert!(!ptr.is_null(), "Pointer must not be null");
-
-                unsafe {
-                    for i in 0..10 {
-                        // visit and print the first 10 elements one by one
-                        let element = ptr.offset(i as isize).read();
-                        println!("Element {}: {:?}", i, element);
-                    }
-                }
-            }
 
         let w = self.weights.ffn_down(iblk, queue);
         self.mat_mul(&mut x1, 0., &gate, &w, s, workspace, queue_alloc)?;
@@ -460,31 +459,39 @@ where
             queue_alloc,
         )
     }
-    fn attnention<Q, K, V, O, QA>(
+    fn attnention<Q, KV, A, QR, KR, O, QA>(
         &self,
         q: &mut Tensor<Q>,
-        k: &Tensor<K>,
-        v: &Tensor<V>,
+        kv: &Tensor<KV>,
+        a: &Tensor<A>,
+        qr: &Tensor<QR>,
+        kr: &Tensor<KR>,
         o: &mut Tensor<O>,
         pos: usize,
         workspace: &mut [ByteOf<Ops::Hardware>],
         queue_alloc: &QA,
     ) -> Result<(), LaunchError>
     where
         Q: DerefMut<Target = [ByteOf<Ops::Hardware>]>,
-        K: Deref<Target = [ByteOf<Ops::Hardware>]>,
-        V: Deref<Target = [ByteOf<Ops::Hardware>]>,
+        KV: Deref<Target = [ByteOf<Ops::Hardware>]>,
+        A: Deref<Target = [ByteOf<Ops::Hardware>]>,
+        QR: Deref<Target = [ByteOf<Ops::Hardware>]>,
+        KR: Deref<Target = [ByteOf<Ops::Hardware>]>,
         O: DerefMut<Target = [ByteOf<Ops::Hardware>]>,
         QA: QueueAlloc<Hardware = Ops::Hardware>,
     {
-        self.attention.launch(
-            &attention::Args {
+        self.attention_mla.launch(
+            &attention_mla::Args {
                 q_layout: q.layout(),
                 q_base: q.base_mut(),
-                k_layout: k.layout(),
-                k_base: k.base(),
-                v_layout: v.layout(),
-                v_base: v.base(),
+                kv_layout: kv.layout(),
+                kv_base: kv.base(),
+                absorb_layout: a.layout(),
+                absorb_base: a.base(),
+                qr_layout: qr.layout(),
+                qr_base: qr.base(),
+                kr_layout: kr.layout(),
+                kr_base: kr.base(),
                 o_layout: o.layout(),
                 o_base: o.base_mut(),
                 mask: AttnMask::Causal,
@@ -594,6 +601,26 @@ where
             queue_alloc,
         )
     }
+    fn softmax<A, QA>(
+        &self,
+        a: &mut Tensor<A>,
+        workspace: &mut [ByteOf<Ops::Hardware>],
+        queue_alloc: &QA,
+    ) -> Result<(), LaunchError>
+    where
+        A: DerefMut<Target = [ByteOf<Ops::Hardware>]>,
+        QA: QueueAlloc<Hardware = Ops::Hardware>,
+    {
+        self.fuesd_softmax.launch(
+            &fuesd_softmax::Args {
+                att_mask: AttnMask::Causal,
+                att_layout: a.layout(),
+                att_base: a.base_mut(),
+            },
+            workspace,
+            queue_alloc,
+        )
+    }
     fn all_reduce<X, QA>(
         &self,
         x: &mut Tensor<X>,
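
The new softmax helper wires AttnMask::Causal straight into the operator. As a reference for what a fused causal softmax computes, here is a plain-Rust sketch over an [nt, total] score matrix; the masking rule (query i sees key j iff j < total - nt + i + 1) is the usual cached-decoding convention and an assumption about the operator's semantics, not code from this repo:

fn causal_softmax(att: &mut [Vec<f32>]) {
    let nt = att.len(); // query rows
    let total = att[0].len(); // cached + new key columns
    for (i, row) in att.iter_mut().enumerate() {
        let visible = total - nt + i + 1; // keys this query may attend to
        for s in row[visible..].iter_mut() {
            *s = f32::NEG_INFINITY; // mask future positions
        }
        let max = row[..visible]
            .iter()
            .cloned()
            .fold(f32::NEG_INFINITY, f32::max);
        let mut sum = 0.0;
        for s in row.iter_mut() {
            *s = (*s - max).exp(); // exp(-inf) == 0, so masked entries vanish
            sum += *s;
        }
        for s in row.iter_mut() {
            *s /= sum;
        }
    }
}

fn main() {
    let mut att = vec![vec![1.0f32; 4]; 2]; // nt = 2 queries over 4 total keys
    causal_softmax(&mut att);
    println!("{att:?}"); // each row sums to 1; the masked tail is 0
}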

tensor/src/split.rs

Lines changed: 0 additions & 1 deletion
@@ -14,7 +14,6 @@ impl<T> Splitable for &[T] {
         self
     }
 }
-
 impl<T> Splitable for &mut [T] {
     #[inline]
     fn split(&self) -> Self {
