diff --git a/Cargo.toml b/Cargo.toml index cf3d17c7..dd549368 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,9 @@ members = [ "models/gpt2/common", "models/gpt2/common-cpu", "models/gpt2/cuda", + + "models/minicpm3/common", + "models/minicpm3/common-cpu", ] resolver = "2" @@ -38,7 +41,7 @@ itertools = "0.13" env_logger = "0.11" build-script-cfg = "0.0" -operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "61789f7", default-features = false } +operators = { path = "/home/ztf/operators-rs/operators"} search-cl-tools = { git = "https://github.com/InfiniTensor/clrt", rev = "f69b160" } search-infini-tools = { git = "https://github.com/InfiniTensor/infini-rt", rev = "e8362c3" } diff --git a/models/llama/common/src/compute.rs b/models/llama/common/src/compute.rs index e15da0b7..87976e5c 100644 --- a/models/llama/common/src/compute.rs +++ b/models/llama/common/src/compute.rs @@ -480,6 +480,7 @@ where cos_layout: cos.layout(), cos_base: cos.base(), theta: self.meta.theta, + rope_type: rope::RopeType::Rope, }, workspace, queue_alloc, diff --git a/models/minicpm3/common-cpu/Cargo.toml b/models/minicpm3/common-cpu/Cargo.toml new file mode 100644 index 00000000..00ad4763 --- /dev/null +++ b/models/minicpm3/common-cpu/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "minicpm3-cpu" +version = "0.0.0" +edition = "2021" +authors = ["onenewcode ", "YdrMaster "] + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +minicpm3.path = "../common" +common.workspace = true +operators = { workspace = true, features = ["common-cpu"] } + +[dev-dependencies] +test-utils.workspace = true +gguf.workspace = true +regex.workspace = true diff --git a/models/minicpm3/common-cpu/src/infer.rs b/models/minicpm3/common-cpu/src/infer.rs new file mode 100644 index 00000000..4ef4cf41 --- /dev/null +++ b/models/minicpm3/common-cpu/src/infer.rs @@ -0,0 +1,165 @@ +use crate::{Operators, RandomSample, Weights}; +use common::Distribution; +use gguf::GGufModel; +use minicpm3::{ext::ggml_quants::f16, MiniCPM3Request, MiniCPM3Storage, Minicpm3Worker, Tensor}; +use operators::{ + all_reduce::common_cpu::Operator as AllReduce, + common_cpu::{InprocNode, ThisThread}, + random_sample::{KVPair, SampleArgs}, + Blob, +}; +use regex::Regex; +use std::{ + iter::zip, + ptr::copy_nonoverlapping, + slice::from_raw_parts_mut, + sync::{Arc, Barrier}, + thread, +}; +use test_utils::{test_infer_paralle, Inference, Task, TokenizerAndPrompt, WorkerSeed}; + +type Worker<'w> = Minicpm3Worker, AllReduce>, Weights<'w>>; + +#[test] +fn test_infer() { + std::env::set_var( + "TEST_MODEL", + "/home/ztf/cpm/Origin-MiniCPM3-4B-v0.0-F16.gguf", + ); + let Some(Inference { + model, + devices, + prompt, + as_user, + temperature, + top_p, + top_k, + max_steps, + }) = Inference::load() + else { + return; + }; + let gguf = GGufModel::read(model.iter().map(|s| &**s)); + + let TokenizerAndPrompt { + eos, + tokenizer, + prompt, + } = TokenizerAndPrompt::new(&gguf, prompt, as_user); + + let model = MiniCPM3Storage::from_gguf(&gguf); + println!("{:?}", model.meta); + + let sample_args = SampleArgs::new(temperature, top_p, top_k).expect("invalid sample args"); + println!("{sample_args:?}"); + + let lens = devices + .map(|devices| { + Regex::new(r"\d+") + .unwrap() + .find_iter(&devices) + .map(|c| c.as_str().parse().unwrap()) + .collect() + }) + .unwrap_or_else(|| vec![1]); + let dist = lens.iter().sum(); + println!("distribution: {lens:?}"); + + let (seeds, senders) = 
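+    // One `WorkerSeed` per worker thread; the matching `senders` are the
+    // channels used to feed `Task`s into each worker below.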
WorkerSeed::new(InprocNode::new(lens.len())); + let barrier = Arc::new(Barrier::new(dist + 1)); + thread::scope(|s| { + let _workers = zip(lens, seeds) + .enumerate() + .scan(0, |start, (id, (len, seed))| { + let dist = Distribution::new(*start, len, dist); + *start += len; + + let meta = model.meta.distribute(dist); + let model = &model; + let barrier = barrier.clone(); + Some(s.spawn(move || { + let WorkerSeed { node, tasks } = seed; + let weights = Weights::new(model, dist); + let mut worker = Worker::new(id, &node, meta.clone(), weights); + let mut cache = meta.kv_cache(meta.nctx).map(Blob::new); + let sin_cos = ::build_sin_cos( + meta.dt_embd, + meta.nctx, + meta.dh, + &ThisThread, + ); + + let sample = RandomSample::new(&node); + let indices = RandomSample::build_indices(model.meta.nvoc, &ThisThread); + let mut pair = KVPair::new(0, f16::ZERO); + let mut pairs = Tensor::kv_pair_vec(1, |_| unsafe { + from_raw_parts_mut(&mut pair as *mut _ as *mut u8, size_of_val(&pair)) + }); + + barrier.wait(); + for task in tasks { + let Task { + nt, + pos, + embd, + next, + } = task; + let mut embd = meta.embd(nt).map(|size| { + let mut blob = Blob::new(size); + unsafe { copy_nonoverlapping(embd, blob.as_mut_ptr(), size) }; + blob + }); + let mut logits = meta.logits(if id == 0 { 1 } else { 0 }).map(Blob::new); + worker + .launch( + minicpm3::MiniCPM3Args { + embd: embd.map_slice_mut(), + logits: logits.map_slice_mut(), + sin_cos: sin_cos.map_slice(), + requests: vec![MiniCPM3Request { + cache: cache.map_slice_mut(), + seq_len: nt, + out_len: if id == 0 { 1 } else { 0 }, + pos, + }], + num_tokens: nt, + max_seq_len: nt, + max_att_len: nt + pos, + }, + &mut [], + &ThisThread, + ) + .unwrap(); + if id == 0 { + sample + .launch( + &mut pairs, + &logits, + &indices, + sample_args, + &mut [], + &ThisThread, + ) + .unwrap(); + next.send(pair.idx() as _).unwrap() + } + } + })) + }) + .collect::>(); + + let senders = senders.into_boxed_slice(); + barrier.wait(); + test_infer_paralle( + senders, + test_utils::AboutToken { + tokenizer, + token_embd: model.token_embd, + nvoc: model.meta.nvoc, + eos, + }, + &prompt, + max_steps, + ) + }) +} diff --git a/models/minicpm3/common-cpu/src/lib.rs b/models/minicpm3/common-cpu/src/lib.rs new file mode 100644 index 00000000..deb4bdca --- /dev/null +++ b/models/minicpm3/common-cpu/src/lib.rs @@ -0,0 +1,159 @@ +use common::{Contiguous, Distribution}; +use minicpm3::{MiniCPM3BlkStorage, MiniCPM3BlkWeight, MiniCPM3Storage, Tensor, WeightLoader}; +use operators::{ + all_reduce::{AllReduce, NonAllReduce}, + common_cpu::Cpu, + random_sample::common_cpu::Operator as RandomSampleCpu, + rearrange::common_cpu::Operator as Rearrange, + Blob, ByteOf, QueueOf, TopoNode, +}; +use std::{marker::PhantomData, ops::Deref}; + +pub struct Operators>(PhantomData<(N, R)>); + +pub type RandomSample = minicpm3::RandomSample; + +pub struct Weights<'w> { + blks: Box<[MiniCPM3BlkStorage>]>, + output_norm: &'w [u8], + output: &'w [u8], + long_factor: &'w [u8], + sort_factor: &'w [u8], +} + +macro_rules! 
op { + ($name:ident) => { + operators::$name::common_cpu::Operator + }; +} + +impl minicpm3::Operators for Operators +where + N: TopoNode, + R: AllReduce, +{ + type Hardware = Cpu; + type TopoNode = N; + type Rope = op!(rope); + type AttentionMLA = op!(attention_mla); + type RmsNorm = op!(rms_norm); + type Add = op!(add); + type MatMul = op!(mat_mul); + type Swiglu = op!(swiglu); + type Rearrange = op!(rearrange); + type Scale = op!(scale); + type AttMLACached = op!(attention_mla_cached); + type FuesdSoftmax = op!(fuesd_softmax); + type AllReduce = R; + + fn debug(tensor: &Tensor, _queue: &QueueOf) + where + T: Deref]>, + { + println!("{tensor}") + } +} + +impl<'w> Weights<'w> { + pub fn new(model: &'w MiniCPM3Storage<&'w [u8]>, dist: Distribution) -> Self { + let MiniCPM3Storage { + meta, + output_norm, + output, + blocks, + rope_long, + rope_short, + .. + } = model; + + let blks = blocks + .iter() + .map(|blk| { + blk.clone() + .into_vec() + .into_iter() + .map(|(which, data)| { + (which, meta.distribute_data(which, data, dist, Blob::new)) + }) + .collect::>() + }) + .collect(); + + Self { + blks, + output_norm, + output, + long_factor: rope_long, + sort_factor: rope_short, + } + } +} + +impl WeightLoader for Weights<'_> { + type Hardware = Cpu; + type Weight<'s> + = &'s [u8] + where + Self: 's; + + #[inline] + fn load_blk( + &self, + which: MiniCPM3BlkWeight, + iblk: usize, + _queue: &QueueOf, + ) -> Self::Weight<'_> { + let MiniCPM3BlkStorage { + attn_norm, + attn_qb, + attn_qa, + attn_kvb, + attn_kva, + attn_qa_norm, + attn_kva_norm, + attn_o, + ffn_norm, + ffn_gate_up, + ffn_down, + ffn_gate, + ffn_up, + } = &self.blks[iblk]; + use MiniCPM3BlkWeight as W; + match which { + W::AttnNorm => attn_norm, + W::AttnQB => attn_qb, + W::AttnQA => attn_qa, + W::AttnKvB => attn_kvb, + W::AttnKvA => attn_kva, + W::AttnQANorm => attn_qa_norm, + W::AttnKvANorm => attn_kva_norm, + W::AttnO => attn_o, + W::FfnNorm => ffn_norm, + W::FfnGateUp => ffn_gate_up, + W::FfnDown => ffn_down, + W::FfnGate => ffn_gate, + W::FfnUp => ffn_up, + } + } + + #[inline] + fn output_norm(&self, _queue: &QueueOf) -> Self::Weight<'_> { + self.output_norm + } + + #[inline] + fn output(&self, _queue: &QueueOf) -> Self::Weight<'_> { + self.output + } + #[inline] + fn long_factor<'a>(&'a self, _queue: &'a QueueOf) -> Self::Weight<'_> { + self.long_factor + } + #[inline] + fn short_factor<'a>(&'a self, _queue: &'a QueueOf) -> Self::Weight<'_> { + self.sort_factor + } +} + +#[cfg(test)] +mod infer; diff --git a/models/minicpm3/common/Cargo.toml b/models/minicpm3/common/Cargo.toml new file mode 100644 index 00000000..c90fea1d --- /dev/null +++ b/models/minicpm3/common/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "minicpm3" +version = "0.0.0" +edition = "2021" +authors = ["onenewcode ", "YdrMaster "] + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +common.workspace = true +gguf.workspace = true +tensor.workspace = true +operators.workspace = true +itertools.workspace = true +half = "2.4" + +[dev-dependencies] +test-utils.workspace = true diff --git a/models/minicpm3/common/src/args.rs b/models/minicpm3/common/src/args.rs new file mode 100644 index 00000000..bc512965 --- /dev/null +++ b/models/minicpm3/common/src/args.rs @@ -0,0 +1,25 @@ +use operators::Hardware; +use tensor::Tensor; + +pub struct Args<'a, H: Hardware> { + /// shape: [nt, d] + pub embd: Tensor<&'a mut [H::Byte]>, + /// shape: [nout, nvoc] + pub logits: Tensor<&'a mut [H::Byte]>, + /// shape: 
[2, _, dh] + pub sin_cos: Tensor<&'a [H::Byte]>, + + pub requests: Vec>, + + pub num_tokens: usize, + pub max_seq_len: usize, + pub max_att_len: usize, +} + +pub struct Request<'a, H: Hardware> { + /// shape: [buf, nblk, nh, dkv+dr] + pub cache: Tensor<&'a mut [H::Byte]>, + pub seq_len: usize, + pub out_len: usize, + pub pos: usize, +} diff --git a/models/minicpm3/common/src/compute.rs b/models/minicpm3/common/src/compute.rs new file mode 100644 index 00000000..0cac757c --- /dev/null +++ b/models/minicpm3/common/src/compute.rs @@ -0,0 +1,882 @@ +use super::{args::Args, MiniCPM3BlkWeight, MiniCPM3Meta}; +use gguf::ggml_quants::digit_layout::types as ty; +use gguf::ggml_quants::digit_layout::DigitLayout; +use half::f16; +use itertools::izip; +use itertools::Itertools; +use operators::attention_mla_cached; +use operators::fuesd_softmax; +use operators::fuesd_softmax::FusedSoftmax; +use operators::scale; +use operators::scale::Scale; +use operators::{ + add::{self, Add}, + all_reduce::{self, AllReduce, ReduceOp}, + attention_kv_cached::AttnKVCached, + attention_mla::{self, AttentionMLA}, + attention_mla_cached::AttMLACached, + fuesd_softmax::AttnMask, + mat_mul::{self, MatMul}, + rearrange::{self, Rearrange}, + rms_norm::{self, RmsNorm}, + rope::{self, Rope, Seq, SinCosTable}, + swiglu::{self, Swiglu}, + ByteOf, Hardware, LaunchError, Operator, QueueAlloc, QueueOf, TopoNode, Workspace, +}; +use std::ops::{Deref, DerefMut}; +use std::process::Output; +use tensor::split_mut; +use tensor::{split, Tensor}; + +pub trait Operators { + type Hardware: Hardware; + type TopoNode: TopoNode; + type AttentionMLA: AttentionMLA; + type AttMLACached: AttMLACached; + type Rope: Rope; + type RmsNorm: RmsNorm; + type Add: Add; + type MatMul: MatMul; + type Swiglu: Swiglu; + type Scale: Scale; + type FuesdSoftmax: FusedSoftmax; + type Rearrange: Rearrange; + type AllReduce: AllReduce; + + fn debug(tensor: &Tensor, queue: &QueueOf) + where + T: Deref]>; + + fn build_sin_cos( + dt: DigitLayout, + nctx: usize, + dh: usize, + queue_alloc: &QA, + ) -> Tensor + where + QA: QueueAlloc, + { + let SinCosTable { nctx, mem } = + >::build_sincos(dt, nctx, dh, queue_alloc); + Tensor::new(dt, &[2, nctx, dh]).map(|_| mem) + } +} + +pub trait WeightLoader { + type Hardware: Hardware; + type Weight<'s>: Deref]> + 's + where + Self: 's; + + fn load_blk<'a>( + &'a self, + which: MiniCPM3BlkWeight, + iblk: usize, + queue: &'a QueueOf, + ) -> Self::Weight<'a>; + + fn output_norm<'a>(&'a self, queue: &'a QueueOf) -> Self::Weight<'a>; + fn output<'a>(&'a self, queue: &'a QueueOf) -> Self::Weight<'a>; + fn long_factor<'a>(&'a self, _queue: &'a QueueOf) -> Self::Weight<'a>; + fn short_factor<'a>(&'a self, _queue: &'a QueueOf) -> Self::Weight<'a>; +} + +pub struct Minicpm3Worker { + id: usize, + meta: MiniCPM3Meta, + weights: WeightDecorator, + dt_pos: DigitLayout, + add: Ops::Add, + attn_mla_cached: Ops::AttMLACached, + attention_mla: Ops::AttentionMLA, + rope: Ops::Rope, + rms_norm: Ops::RmsNorm, + mat_mul: Ops::MatMul, + scale: Ops::Scale, + swiglu: Ops::Swiglu, + fuesd_softmax: Ops::FuesdSoftmax, + rearrange: Ops::Rearrange, + all_reduce: Ops::AllReduce, +} + +impl Minicpm3Worker { + pub fn new(id: usize, node: &Ops::TopoNode, meta: MiniCPM3Meta, weights: W) -> Self { + let processor = node.processor(); + Self { + id, + weights: meta.decorator(weights), + meta, + attn_mla_cached: Ops::AttMLACached::new(processor), + rope: Ops::Rope::new(processor), + rms_norm: Ops::RmsNorm::new(processor), + mat_mul: Ops::MatMul::new(processor), + 
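+            // Every operator is instantiated from this node's processor;
+            // `all_reduce` below is the exception, built from the topo node.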
scale: Ops::Scale::new(processor),
+            swiglu: Ops::Swiglu::new(processor),
+            rearrange: Ops::Rearrange::new(processor),
+            add: Ops::Add::new(processor),
+            all_reduce: Ops::AllReduce::new(node),
+            dt_pos: ty::U64,
+            attention_mla: Ops::AttentionMLA::new(processor),
+            fuesd_softmax: Ops::FuesdSoftmax::new(processor),
+        }
+    }
+
+    #[inline]
+    pub const fn meta(&self) -> &MiniCPM3Meta {
+        &self.meta
+    }
+}
+
+impl<Ops, W> Minicpm3Worker<Ops, W>
+where
+    Ops: Operators,
+    W: WeightLoader<Hardware = Ops::Hardware>,
+    ByteOf<Ops::Hardware>: 'static,
+{
+    pub fn launch<QA>(
+        &mut self,
+        args: Args<Ops::Hardware>,
+        workspace: &mut [ByteOf<Ops::Hardware>],
+        queue_alloc: &QA,
+    ) -> Result<(), LaunchError>
+    where
+        QA: QueueAlloc<Hardware = Ops::Hardware>,
+    {
+        let Args {
+            embd: mut x,
+            mut logits,
+            mut requests,
+            num_tokens: nt,
+            sin_cos,
+            ..
+        } = args;
+        let MiniCPM3Meta {
+            nblk,
+            di,
+            dq_lora,
+            nh,
+            dk,
+            dh,
+            dkv_lora,
+            dv,
+            dt_embd,
+            ..
+        } = self.meta;
+        // Constants hard-coded to match llama.cpp
+        let scale_emb = 12f32;
+        let scale_depth = 1.4f32;
+        // Weight scaling applied on the residual connection
+        let s = scale_depth / (nblk as f32).sqrt();
+
+        let dnope = dk - dh;
+        let tensor = |shape: &[usize]| Tensor::new(dt_embd, shape);
+        let x1 = tensor(x.shape());
+
+        let gate_up = tensor(&[nt, di * 2]);
+        // Workspace budget: x + x1 + q (possibly removable) + q3 + kv_pe + attn
+        let workspace_size = *x1.get() * 20 + *gate_up.get();
+        let mut workspace = Workspace::new(queue_alloc, workspace, workspace_size);
+        let (buf, workspace) = workspace.split_at_mut(*x1.get());
+        let mut x1 = x1.map(|_| buf);
+
+        let queue = queue_alloc.queue();
+
+        let sin = sin_cos.clone().index(0, 0);
+        let cos = sin_cos.index(0, 1);
+
+        let pos = Tensor::new(self.dt_pos, &[nt]).map(|_| {
+            Ops::Rope::build_pos(
+                self.dt_pos,
+                nt,
+                requests.iter().map(|req| Seq {
+                    pos: req.pos,
+                    len: req.seq_len,
+                }),
+                queue_alloc,
+            )
+        });
+        let req_split = requests.iter().map(|req| req.seq_len).collect::<Vec<_>>();
+        // Scale the input embeddings
+        let inplace = unsafe { x.map_slice_static() };
+        self.scale(&mut x, &inplace, scale_emb, workspace, queue_alloc)?;
+        for iblk in 0..nblk {
+            // norm
+            let w = self.weights.attn_norm(iblk, queue);
+            self.rms_norm(&mut x1, &x, &w, workspace, queue_alloc)?;
+            drop(w);
+            let q = tensor(&[nt, dq_lora]);
+            let (buf, workspace) = workspace.split_at_mut(*q.get());
+            let mut q = q.map(|_| buf);
+            let w = self.weights.attn_qa(iblk, queue).transpose(&[1, 0]);
+            self.mat_mul(&mut q, 0., &x1, &w, 1., workspace, queue_alloc)?;
+
+            let inplace = unsafe { q.map_slice_static() };
+            let w = self.weights.attn_qa_norm(iblk, queue);
+            self.rms_norm(&mut q, &inplace, &w, workspace, queue_alloc)?;
+            {
+                let q1 = tensor(&[nt, nh * dk]);
+                let (buf, workspace) = workspace.split_at_mut(*q1.get());
+                let mut q1 = q1.map(|_| buf);
+                let w = self.weights.attn_qb(iblk, queue).transpose(&[1, 0]);
+                self.mat_mul(&mut q1, 0., &q, &w, 1., workspace, queue_alloc)?;
+
+                let mut q3 = q1.tile(1, &[nh, dk]);
+                let q2 = unsafe { q3.map_slice_static_mut() };
+                split_mut!(q2 => q_nope, q_rope; [dnope, dh] @ 2);
+
+                // kv_pe [1, 288]
+                let kv_pe = tensor(&[nt, dkv_lora + dh]);
+                let (buf, workspace) = workspace.split_at_mut(*kv_pe.get());
+                let mut kv_pe = kv_pe.map(|_| buf);
+
+                let w = self.weights.attn_kva(iblk, queue).transpose(&[1, 0]);
+                self.mat_mul(&mut kv_pe, 0., &x1, &w, 1., workspace, queue_alloc)?;
+                drop(q);
+                split_mut!(kv_pe => kv_lora, k_rope; [dkv_lora, dh] @ 1);
+                let mut k_rope = k_rope.tile(1, &[1, dh]);
+                self.rope(&mut k_rope, &pos, &sin, &cos, workspace, queue_alloc)?;
+                self.rope(&mut q_rope, &pos, &sin, &cos, workspace, queue_alloc)?;
+                Ops::debug(&k_rope, queue);
+                todo!();
+                let k_rope = k_rope.broadcast(1, nh);
+
+                let inplace 
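+                // Normalize the compressed kv_lora activation (attn_kv_a_norm)
+                // before expanding it through attn_kv_b.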
= unsafe { kv_lora.map_slice_static() }; + let w = self.weights.attn_kva_norm(iblk, queue); + self.rms_norm(&mut kv_lora, &inplace, &w, workspace, queue_alloc)?; + // kv X[1, 5120] + let kv = tensor(&[nt, nh * (dnope + dv)]); + let (buf, workspace) = workspace.split_at_mut(*kv.get()); + let mut kv = kv.map(|_| buf); + + let kv_b_proj = unsafe { + self.weights + .attn_kvb(iblk, queue) + .tile(0, &[nh, dnope + dv]) + .map_slice_static() + }; + split!(kv_b_proj=> q_absorb , out_absorb ; [dnope, dv] @ 1); + let inplace = unsafe { q_nope.map_slice_static() }; + + + let q_nope_0 = q_nope.map_slice().transpose(&[1, 0]); + let q_nope_1 = tensor(&[nh, nt, dkv_lora]); + let (buf, workspace) = workspace.split_at_mut(*q_nope_1.get()); + let mut q_nope = q_nope_1.map(|_| buf); + self.mat_mul( + &mut q_nope, + 0., + &q_nope_0, + &q_absorb, + 1., + workspace, + queue_alloc, + )?; + + // attn_output + let attn_output = tensor(&[nt, nh, dv]); + let (buf, workspace) = workspace.split_at_mut(*attn_output.get()); + let mut attn_output = attn_output.map(|_| buf); + let mut o = unsafe { attn_output.map_slice_static_mut().transpose(&[1, 0]) }; + + { + let q_rope = q_rope.transpose(&[1, 0]); + let k_rope = k_rope.transpose(&[1, 0]); + let kv_lora = kv_lora.map_slice().tile(0, &[1, nt]).broadcast(0, nh); + let q_nope = q_nope.split(1, &req_split); + let q_rope = q_rope.split(1, &req_split); + let k_rope = k_rope.split(1, &req_split); + let kv_lora = kv_lora.split(1, &req_split); + + for (mut qn, kv, qr, kr, req) in izip!(q_nope, kv_lora, q_rope, k_rope, &mut requests) + { + let cache = req + .cache + .as_mut() // [buf, nblk, nh, dkv+dr] + .index(1, iblk) // [buf, nh, dkv+dr] + .transpose(&[1, 0]) // [ nh,buf, dkv+dr] + .map(|t| &mut t[..]); + + split_mut!(cache => kvc, krc; [dkv_lora, dh] @ 2); + + self.attn_mla_cache(&mut qn, &kv, &kr, &qr, &out_absorb, &mut o, &mut kvc, &mut krc, req.pos, workspace, queue_alloc)?; + } + } + + let o = attn_output.map_slice().merge(1..3).unwrap(); + let w = self.weights.attn_o(iblk, queue); + + self.mat_mul(&mut x1, 0., &o, &w, 1., workspace, queue_alloc)?; + } + let inplace = unsafe { x.map_slice_static() }; + + self.add(&mut x, &inplace, &x1, workspace, queue_alloc)?; + self.all_reduce(&mut x, workspace, queue_alloc)?; + let w = self.weights.ffn_norm(iblk, queue); + self.rms_norm(&mut x1, &x, &w, workspace, queue_alloc)?; + + drop(w); + + let (buf, workspace) = workspace.split_at_mut(*gate_up.get()); + let gate_up = gate_up.clone().map(|_| buf); + split!(gate_up => gate, up; [di, di] @ 1); + let mut gate = gate; + let mut up = up; + let w = self.weights.ffn_gate(iblk, queue); + self.mat_mul(&mut gate, 0., &x1, &w, 1., workspace, queue_alloc)?; + + let w = self.weights.ffn_up(iblk, queue); + self.mat_mul(&mut up, 0., &x1, &w, 1., workspace, queue_alloc)?; + + self.swiglu(&mut gate, &up, workspace, queue_alloc)?; + + let w = self.weights.ffn_down(iblk, queue); + self.mat_mul(&mut x1, 0., &gate, &w, s, workspace, queue_alloc)?; + + let inplace = unsafe { x.map_slice_static() }; + self.add(&mut x, &inplace, &x1, workspace, queue_alloc)?; + if iblk==0{ + Ops::debug(&x, queue); + todo!(); + } + + self.all_reduce(&mut x, workspace, queue_alloc)? 
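+            // The all_reduce above merges the partial block outputs across
+            // tensor-parallel workers; with a single worker it degenerates to
+            // a no-op.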
+        }
+        if logits.shape()[0] == 0 {
+            return Ok(());
+        }
+
+        // Gather the tokens to be sampled
+        // NOTICE: sorting requests by seq_len in ascending order before input
+        // reduces copy overhead
+        let mut dst = 0;
+        let mut src = 0;
+        for req in &requests {
+            src += req.seq_len;
+            for src in src - req.out_len..src {
+                if src != dst {
+                    let src = unsafe { x.map_slice_static() }.index(0, src);
+                    let mut dst = x.map_slice_mut().index(0, dst);
+                    self.rearrange(&mut dst, &src, workspace, queue_alloc)?
+                }
+                dst += 1
+            }
+        }
+        assert_eq!(dst, logits.shape()[0]);
+
+        let mut x = x.map_slice_mut().slice(0, 0, 1, dst);
+        {
+            let inplace = unsafe { x.map_slice_static() };
+            let w = self.weights.output_norm(queue);
+            self.rms_norm(&mut x, &inplace, &w, workspace, queue_alloc)?
+        }
+        let w = self.weights.output(queue);
+
+        self.mat_mul(&mut logits, 0., &x, &w, 1., workspace, queue_alloc)
+    }
+}
+
+#[allow(clippy::too_many_arguments)]
+impl<Ops, W> Minicpm3Worker<Ops, W>
+where
+    Ops: Operators,
+    W: WeightLoader<Hardware = Ops::Hardware>,
+{
+    fn rms_norm<Y, X, W_, QA>(
+        &self,
+        y: &mut Tensor<Y>,
+        x: &Tensor<X>,
+        w: &Tensor<W_>,
+        workspace: &mut [ByteOf<Ops::Hardware>],
+        queue_alloc: &QA,
+    ) -> Result<(), LaunchError>
+    where
+        Y: DerefMut<Target = [ByteOf<Ops::Hardware>]>,
+        X: Deref<Target = [ByteOf<Ops::Hardware>]>,
+        W_: Deref<Target = [ByteOf<Ops::Hardware>]>,
+        QA: QueueAlloc<Hardware = Ops::Hardware>,
+    {
+        self.rms_norm.launch(
+            &rms_norm::Args {
+                y_layout: y.layout(),
+                y_base: y.base_mut(),
+                x_layout: x.layout(),
+                x_base: x.base(),
+                w_layout: w.layout(),
+                w_base: w.base(),
+                epsilon: self.meta.epsilon,
+            },
+            workspace,
+            queue_alloc,
+        )
+    }
+
+    fn mat_mul<C, A, B, QA>(
+        &self,
+        c: &mut Tensor<C>,
+        beta: f32,
+        a: &Tensor<A>,
+        b: &Tensor<B>,
+        alpha: f32,
+        workspace: &mut [ByteOf<Ops::Hardware>],
+        queue_alloc: &QA,
+    ) -> Result<(), LaunchError>
+    where
+        C: DerefMut<Target = [ByteOf<Ops::Hardware>]>,
+        A: Deref<Target = [ByteOf<Ops::Hardware>]>,
+        B: Deref<Target = [ByteOf<Ops::Hardware>]>,
+        QA: QueueAlloc<Hardware = Ops::Hardware>,
+    {
+        self.mat_mul.launch(
+            &mat_mul::Args {
+                c_layout: c.layout(),
+                c_base: c.base_mut(),
+                beta,
+                a_layout: a.layout(),
+                a_base: a.base(),
+                b_layout: b.layout(),
+                b_base: b.base(),
+                alpha,
+            },
+            workspace,
+            queue_alloc,
+        )
+    }
+    fn rope<T, P, Sin, Cos, QA>(
+        &self,
+        t: &mut Tensor<T>,
+        p: &Tensor<P>

, + sin: &Tensor, + cos: &Tensor, + workspace: &mut [ByteOf], + queue_alloc: &QA, + ) -> Result<(), LaunchError> + where + T: DerefMut]>, + P: Deref]>, + Sin: Deref]>, + Cos: Deref]>, + QA: QueueAlloc, + { + let [long, short] = self.weights.factor(queue_alloc.queue()); + self.rope.launch( + &rope::Args { + t_layout: t.layout(), + t_base: t.base_mut(), + p_layout: p.layout(), + p_base: p.base(), + sin_layout: sin.layout(), + sin_base: sin.base(), + cos_layout: cos.layout(), + cos_base: cos.base(), + theta: self.meta.theta, + // TODO + rope_type: rope::RopeType::Long { + long: long.base(), + short: short.base(), + max_pos: 100, + origin_pos: 100, + }, + }, + workspace, + queue_alloc, + ) + } + fn attn_mla_cache( + &self, + q: &mut Tensor, + kv: &Tensor, + kr: &Tensor, + qr: &Tensor, + a: &Tensor, + o: &mut Tensor, + kvc: &mut Tensor, + krc: &mut Tensor, + pos: usize, + workspace: &mut [ByteOf], + queue_alloc: &QA, + ) -> Result<(), LaunchError> + where + Q: DerefMut]>, + KV: Deref]>, + KR: Deref]>, + QR: Deref]>, + A: Deref]>, + O: DerefMut]>, + KVC: DerefMut]>, + KRC: DerefMut]>, + QA: QueueAlloc, + { + self.attn_mla_cached.launch( + &attention_mla_cached::Args { + q_layout: q.layout(), + q_base: q.base_mut(), + kv_layout: kv.layout(), + kv_base: kv.base(), + absorb_layout: a.layout(), + absorb_base: a.base(), + qr_layout: qr.layout(), + qr_base: qr.base(), + kr_layout: kr.layout(), + kr_base: kr.base(), + o_layout: o.layout(), + o_base: o.base_mut(), + kv_cache_layout: kvc.layout(), + kv_cache_base: kvc.base_mut(), + kr_cache_layout: krc.layout(), + kr_cache_base: krc.base_mut(), + mask: AttnMask::Causal, + pos: pos.into(), + }, + workspace, + queue_alloc, + ) + } + fn swiglu( + &self, + gate: &mut Tensor, + up: &Tensor, + workspace: &mut [ByteOf], + queue_alloc: &QA, + ) -> Result<(), LaunchError> + where + Gate: DerefMut]>, + Up: DerefMut]>, + QA: QueueAlloc, + { + self.swiglu.launch( + &swiglu::Args { + gate_layout: gate.layout(), + gate_base: gate.base_mut(), + up_layout: up.layout(), + up_base: up.base(), + }, + workspace, + queue_alloc, + ) + } + + fn rearrange( + &self, + dst: &mut Tensor, + src: &Tensor, + workspace: &mut [ByteOf], + queue_alloc: &QA, + ) -> Result<(), LaunchError> + where + Y: DerefMut]>, + X: Deref]>, + QA: QueueAlloc, + { + self.rearrange.launch( + &rearrange::Args { + dst_layout: dst.layout(), + dst_base: dst.base_mut(), + src_layout: src.layout(), + src_base: src.base(), + }, + workspace, + queue_alloc, + ) + } + + fn add( + &self, + c: &mut Tensor, + a: &Tensor, + b: &Tensor, + workspace: &mut [ByteOf], + queue_alloc: &QA, + ) -> Result<(), LaunchError> + where + C: DerefMut]>, + A: Deref]>, + B: Deref]>, + QA: QueueAlloc, + { + self.add.launch( + &add::Args { + c_layout: c.layout(), + c_base: c.base_mut(), + a_layout: a.layout(), + a_base: a.base(), + b_layout: b.layout(), + b_base: b.base(), + }, + workspace, + queue_alloc, + ) + } + fn scale( + &self, + c: &mut Tensor, + a: &Tensor, + s: f32, + workspace: &mut [ByteOf], + queue_alloc: &QA, + ) -> Result<(), LaunchError> + where + C: DerefMut]>, + A: Deref]>, + QA: QueueAlloc, + { + self.scale.launch( + &scale::Args { + c_layout: c.layout(), + c_base: c.base_mut(), + a_layout: a.layout(), + a_base: a.base(), + s, + }, + workspace, + queue_alloc, + ) + } + fn softmax( + &self, + a: &mut Tensor, + workspace: &mut [ByteOf], + queue_alloc: &QA, + ) -> Result<(), LaunchError> + where + A: DerefMut]>, + QA: QueueAlloc, + { + self.fuesd_softmax.launch( + &fuesd_softmax::Args { + att_mask: AttnMask::Causal, + 
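+                // Causal mask: each token attends only to itself and earlier
+                // positions.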
att_layout: a.layout(), + att_base: a.base_mut(), + }, + workspace, + queue_alloc, + ) + } + fn all_reduce( + &self, + x: &mut Tensor, + workspace: &mut [ByteOf], + queue_alloc: &QA, + ) -> Result<(), LaunchError> + where + X: DerefMut]>, + QA: QueueAlloc, + { + self.all_reduce.launch( + &all_reduce::Args { + pair: rearrange::Args { + dst_layout: x.layout(), + dst_base: x.base_mut(), + src_layout: x.layout(), + src_base: x.base(), + }, + op: ReduceOp::Sum, + }, + workspace, + queue_alloc, + ) + } +} + +struct WeightDecorator { + norm: Tensor, + attn_qb: Tensor, + attn_qa: Tensor, + attn_kvb: Tensor, + attn_kva_mqa: Tensor, + attn_qa_norm: Tensor, + attn_kva_norm: Tensor, + attn_o: Tensor, + ffn_gate_up: Tensor, + ffn_down: Tensor, + factor: [Tensor; 2], + ffn_gate: Tensor, + ffn_up: Tensor, + output: Tensor, + weights: W, +} + +impl MiniCPM3Meta { + fn decorator(&self, weights: W) -> WeightDecorator { + use crate::TensorUsage::Computation; + WeightDecorator { + norm: self.norm(), + attn_qa: self.attn_qa(Computation), + attn_qb: self.attn_qb(Computation), + attn_kvb: self.attn_kvb(Computation), + attn_kva_mqa: self.attn_kva(Computation), + attn_qa_norm: self.attn_qa_norm(), + attn_kva_norm: self.attn_kva_norm(), + attn_o: self.attn_o(Computation), + ffn_gate_up: self.ffn_gate_up(Computation), + ffn_down: self.ffn_down(Computation), + ffn_gate: self.ffn(Computation), + ffn_up: self.ffn(Computation), + factor: self.factor(), + output: self.output(), + weights, + } + } +} + +impl WeightDecorator { + #[inline] + pub fn attn_norm<'a>( + &'a self, + iblk: usize, + queue: &'a QueueOf, + ) -> Tensor> { + let w = self + .weights + .load_blk(MiniCPM3BlkWeight::AttnNorm, iblk, queue); + self.norm.clone().map(|_| w) + } + #[inline] + pub fn attn_qa<'a>( + &'a self, + iblk: usize, + queue: &'a QueueOf, + ) -> Tensor> { + let w = self + .weights + .load_blk(MiniCPM3BlkWeight::AttnQA, iblk, queue); + self.attn_qa.clone().map(|_| w) + } + #[inline] + pub fn attn_qb<'a>( + &'a self, + iblk: usize, + queue: &'a QueueOf, + ) -> Tensor> { + let w = self + .weights + .load_blk(MiniCPM3BlkWeight::AttnQB, iblk, queue); + self.attn_qb.clone().map(|_| w) + } + + #[inline] + pub fn attn_kvb<'a>( + &'a self, + iblk: usize, + queue: &'a QueueOf, + ) -> Tensor> { + let w = self + .weights + .load_blk(MiniCPM3BlkWeight::AttnKvB, iblk, queue); + self.attn_kvb.clone().map(|_| w) + } + pub fn attn_kva<'a>( + &'a self, + iblk: usize, + queue: &'a QueueOf, + ) -> Tensor> { + let w = self + .weights + .load_blk(MiniCPM3BlkWeight::AttnKvA, iblk, queue); + self.attn_kva_mqa.clone().map(|_| w) + } + #[inline] + pub fn attn_qa_norm<'a>( + &'a self, + iblk: usize, + queue: &'a QueueOf, + ) -> Tensor> { + let w = self + .weights + .load_blk(MiniCPM3BlkWeight::AttnQANorm, iblk, queue); + self.attn_qa_norm.clone().map(|_| w) + } + #[inline] + pub fn attn_kva_norm<'a>( + &'a self, + iblk: usize, + queue: &'a QueueOf, + ) -> Tensor> { + let w = self + .weights + .load_blk(MiniCPM3BlkWeight::AttnKvANorm, iblk, queue); + self.attn_kva_norm.clone().map(|_| w) + } + #[inline] + pub fn attn_o<'a>( + &'a self, + iblk: usize, + queue: &'a QueueOf, + ) -> Tensor> { + let w = self.weights.load_blk(MiniCPM3BlkWeight::AttnO, iblk, queue); + self.attn_o.clone().map(|_| w) + } + + #[inline] + pub fn ffn_norm<'a>( + &'a self, + iblk: usize, + queue: &'a QueueOf, + ) -> Tensor> { + let w = self + .weights + .load_blk(MiniCPM3BlkWeight::FfnNorm, iblk, queue); + self.norm.clone().map(|_| w) + } + + #[inline] + pub fn ffn_gate_up<'a>( + &'a self, + 
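+        // NOTE: the tensor behind FfnGateUp is currently a placeholder (see
+        // storage.rs); the forward pass uses ffn_gate/ffn_up separately.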
iblk: usize, + queue: &'a QueueOf, + ) -> Tensor> { + const WHICH: MiniCPM3BlkWeight = MiniCPM3BlkWeight::FfnGateUp; + let w = self.weights.load_blk(WHICH, iblk, queue); + self.ffn_gate_up.clone().map(|_| w) + } + #[inline] + pub fn ffn_gate<'a>( + &'a self, + iblk: usize, + queue: &'a QueueOf, + ) -> Tensor> { + const WHICH: MiniCPM3BlkWeight = MiniCPM3BlkWeight::FfnGate; + let w = self.weights.load_blk(WHICH, iblk, queue); + self.ffn_gate.clone().map(|_| w) + } + #[inline] + pub fn ffn_up<'a>( + &'a self, + iblk: usize, + queue: &'a QueueOf, + ) -> Tensor> { + const WHICH: MiniCPM3BlkWeight = MiniCPM3BlkWeight::FfnUp; + let w = self.weights.load_blk(WHICH, iblk, queue); + self.ffn_up.clone().map(|_| w) + } + #[inline] + pub fn ffn_down<'a>( + &'a self, + iblk: usize, + queue: &'a QueueOf, + ) -> Tensor> { + const WHICH: MiniCPM3BlkWeight = MiniCPM3BlkWeight::FfnDown; + let w = self.weights.load_blk(WHICH, iblk, queue); + self.ffn_down.clone().map(|_| w) + } + + #[inline] + pub fn output_norm<'a>(&'a self, queue: &'a QueueOf) -> Tensor> { + self.norm.clone().map(|_| self.weights.output_norm(queue)) + } + + #[inline] + pub fn output<'a>(&'a self, queue: &'a QueueOf) -> Tensor> { + self.output.clone().map(|_| self.weights.output(queue)) + } + #[inline] + fn factor<'a>(&'a self, queue: &'a QueueOf) -> [Tensor>; 2] { + [ + self.factor[0] + .clone() + .map(|_| self.weights.long_factor(queue)), + self.factor[1] + .clone() + .map(|_| self.weights.short_factor(queue)), + ] + } +} diff --git a/models/minicpm3/common/src/lib.rs b/models/minicpm3/common/src/lib.rs new file mode 100644 index 00000000..ea190d52 --- /dev/null +++ b/models/minicpm3/common/src/lib.rs @@ -0,0 +1,230 @@ +mod args; +mod compute; +mod storage; + +use common::Distribution; +use gguf::ggml_quants::digit_layout::DigitLayout; + +pub use args::{Args as MiniCPM3Args, Request as MiniCPM3Request}; +pub use compute::{Minicpm3Worker, Operators, WeightLoader}; +pub use storage::{BlkStorage as MiniCPM3BlkStorage, Storage as MiniCPM3Storage}; +pub use tensor::{RandomSample, Tensor}; +pub mod ext { + pub use gguf::{ + ext::{utok, Mmap}, + ggml_quants, + }; +} + +#[derive(Clone, Debug)] +pub struct MiniCPM3Meta { + pub dt_embd: DigitLayout, + pub dt_norm: DigitLayout, + pub dt_linear: DigitLayout, + + pub nblk: usize, + pub nctx: usize, + pub nvoc: usize, + pub nh: usize, + pub nkvh: usize, + pub d: usize, + pub dh: usize, + pub di: usize, + + pub dq_lora: usize, + pub dkv_lora: usize, + pub dk: usize, + pub dv: usize, + pub dnope: usize, + + pub epsilon: f32, + pub theta: f32, +} + +#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] +pub enum TensorUsage { + Storage, + Computation, +} + +#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] +pub enum MiniCPM3BlkWeight { + AttnNorm, + AttnQB, + AttnQA, + AttnKvB, + AttnKvA, + AttnQANorm, + AttnKvANorm, + AttnO, + FfnNorm, + FfnGateUp, + FfnGate, + FfnUp, + FfnDown, +} + +impl MiniCPM3Meta { + pub fn distribute(&self, dist: Distribution) -> Self { + let [_, len, total] = dist.info(); + assert_eq!(self.nkvh % total, 0); + assert_eq!(self.di % total, 0); + + Self { + nh: self.nh / total * len, + nkvh: self.nkvh / total * len, + di: self.di / total * len, + ..self.clone() + } + } + + pub fn blk(&self) -> MiniCPM3BlkStorage { + use TensorUsage::Storage as TensorMem; + let norm = self.norm().take(); + MiniCPM3BlkStorage { + attn_norm: norm, + attn_qa: self.attn_qa(TensorMem).take(), + attn_qb: self.attn_qb(TensorMem).take(), + attn_kvb: self.attn_kvb(TensorMem).take(), + attn_kva: 
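+            // attn_kv_a_mqa in the GGUF: the joint down-projection producing
+            // kv_lora plus the shared rope key ([dkv_lora + dh, d]).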
self.attn_kva(TensorMem).take(),
+            attn_qa_norm: self.attn_qa_norm().take(),
+            attn_kva_norm: self.attn_kva_norm().take(),
+            attn_o: self.attn_o(TensorMem).take(),
+            ffn_norm: norm,
+            ffn_gate_up: self.ffn_gate_up(TensorMem).take(),
+            ffn_down: self.ffn_down(TensorMem).take(),
+            ffn_gate: self.ffn(TensorMem).take(),
+            ffn_up: self.ffn(TensorMem).take(),
+        }
+    }
+
+    pub fn kv_cache(&self, buf: usize) -> Tensor<usize> {
+        let &Self {
+            dt_embd,
+            nblk,
+            nh,
+            dk,
+            dkv_lora,
+            dh,
+            ..
+        } = self;
+        Tensor::new(dt_embd, &[buf, nblk, nh, dkv_lora + dh])
+    }
+
+    pub fn embd(&self, nt: usize) -> Tensor<usize> {
+        let &Self { dt_embd, d, .. } = self;
+        Tensor::new(dt_embd, &[nt, d])
+    }
+
+    pub fn logits(&self, nt: usize) -> Tensor<usize> {
+        let &Self { dt_embd, nvoc, .. } = self;
+        Tensor::new(dt_embd, &[nt, nvoc])
+    }
+    pub fn token_embd(&self) -> Tensor<usize> {
+        self.embd(self.nvoc)
+    }
+    pub fn norm(&self) -> Tensor<usize> {
+        let &Self { dt_norm, d, .. } = self;
+        Tensor::new(dt_norm, &[d])
+    }
+    // TODO: distribution not implemented yet
+    pub fn attn_qa(&self, usage: TensorUsage) -> Tensor<usize> {
+        let &Self {
+            dt_embd,
+            dq_lora,
+            d,
+            ..
+        } = self;
+        Tensor::new(dt_embd, &[dq_lora, d])
+    }
+    pub fn attn_qa_norm(&self) -> Tensor<usize> {
+        let &Self {
+            dt_norm, dq_lora, ..
+        } = self;
+        Tensor::new(dt_norm, &[dq_lora])
+    }
+    // TODO: distribution not implemented yet
+    pub fn attn_qb(&self, usage: TensorUsage) -> Tensor<usize> {
+        let &Self {
+            dt_embd,
+            nh,
+            dk,
+            dq_lora,
+            ..
+        } = self;
+        Tensor::new(dt_embd, &[nh * dk, dq_lora])
+    }
+
+    // TODO: distribution not implemented yet
+    pub fn attn_kvb(&self, usage: TensorUsage) -> Tensor<usize> {
+        let &Self {
+            dt_embd,
+            nh,
+            dkv_lora,
+            dnope,
+            dv,
+            ..
+        } = self;
+        Tensor::new(dt_embd, &[nh * (dnope + dv), dkv_lora])
+    }
+    // TODO: distribution not implemented yet
+    pub fn attn_kva(&self, usage: TensorUsage) -> Tensor<usize> {
+        let &Self {
+            dt_embd,
+            dkv_lora,
+            dh,
+            d,
+            ..
+        } = self;
+        Tensor::new(dt_embd, &[dkv_lora + dh, d])
+    }
+    pub fn attn_kva_norm(&self) -> Tensor<usize> {
+        let &Self {
+            dt_norm, dkv_lora, ..
+        } = self;
+        Tensor::new(dt_norm, &[dkv_lora])
+    }
+    pub fn attn_o(&self, usage: TensorUsage) -> Tensor<usize> {
+        let &Self { nh, d, dh, .. } = self;
+        self.mat(d, d, usage)
+    }
+
+    pub fn ffn_gate_up(&self, usage: TensorUsage) -> Tensor<usize> {
+        let &Self { d, di, .. } = self;
+        self.mat(di + di, d, usage)
+    }
+
+    pub fn ffn_down(&self, usage: TensorUsage) -> Tensor<usize> {
+        let &Self { d, di, .. } = self;
+        self.mat(d, di, usage)
+    }
+    pub fn ffn(&self, usage: TensorUsage) -> Tensor<usize> {
+        let &Self { d, di, .. } = self;
+        self.mat(di, d, usage)
+    }
+    // TODO
+    pub fn factor(&self) -> [Tensor<usize>; 2] {
+        let &Self { dt_norm, dh, .. } = self;
+        [
+            Tensor::new(dt_norm, &[dh / 2]),
+            Tensor::new(dt_norm, &[dh / 2]),
+        ]
+    }
+    pub fn output(&self) -> Tensor<usize> {
+        self.token_embd().transpose(&[1, 0])
+    }
+
+    fn mat(&self, row: usize, col: usize, usage: TensorUsage) -> Tensor<usize> {
+        let &Self {
+            dt_embd, dt_linear, ..
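+            // dt_linear is the (possibly quantized) storage dtype; dt_embd is
+            // the unquantized dtype used for computation.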
+ } = self; + // NOTICE: 权重矩阵以 mat 类型存储但以 embd 类型参与计算 + match usage { + TensorUsage::Storage => Tensor::new(dt_linear, &[row, col / dt_linear.group_size()]), + TensorUsage::Computation => { + assert_eq!(dt_embd.group_size(), 1); + Tensor::new(dt_embd, &[row, col]).transpose(&[1, 0]) + } + } + } +} diff --git a/models/minicpm3/common/src/storage.rs b/models/minicpm3/common/src/storage.rs new file mode 100644 index 00000000..7f2e94c4 --- /dev/null +++ b/models/minicpm3/common/src/storage.rs @@ -0,0 +1,308 @@ +use crate::{MiniCPM3BlkWeight, MiniCPM3Meta}; +use common::{borrow, own, Contiguous, Distribution}; +use gguf::{GGufMetaMapExt, GGufModel}; +use std::ops::DerefMut; +use tensor::{rearrange, split, Tensor}; + +#[derive(Clone)] +pub struct Storage { + pub meta: MiniCPM3Meta, + pub token_embd: T, + pub output_norm: T, + pub output: T, + pub rope_long: T, + pub rope_short: T, + pub blocks: Box<[BlkStorage]>, +} + +#[derive(Clone)] +pub struct BlkStorage { + pub attn_norm: T, + pub attn_qb: T, + pub attn_qa: T, + pub attn_kvb: T, + pub attn_kva: T, + pub attn_qa_norm: T, + pub attn_kva_norm: T, + pub attn_o: T, + pub ffn_norm: T, + pub ffn_gate_up: T, + pub ffn_gate: T, + pub ffn_up: T, + pub ffn_down: T, +} + +impl<'a> Storage<&'a [u8]> { + pub fn from_gguf(gguf: &GGufModel<'a>) -> Self { + use gguf::{meta, tensor}; + assert_eq!(meta!(gguf => general_architecture), "minicpm3"); + let token_embd = tensor![gguf => "token_embd.weight"]; + let output_norm = tensor![gguf => "output_norm.weight"]; + let rope_long = tensor![gguf => "rope_factors_long.weight"]; + let rope_short = tensor![gguf => "rope_factors_short.weight"]; + let attn_o0 = tensor![gguf => "blk.0.attn_output.weight"]; + let dv = tensor![gguf => format!("blk.0.attn_kv_b.weight" )].shape[1]; + let dk = meta![gguf => (usize) "minicpm3.attention.key_length" ]; + let d = meta![gguf => llm_embedding_length]; + let nh = meta![gguf => llm_attention_head_count]; + + let dh = meta![gguf => llm_rope_dimension_count; d / nh]; + #[rustfmt::skip] + let meta = MiniCPM3Meta { + dt_embd : token_embd.ty, + dt_norm : output_norm.ty, + dt_linear : attn_o0.ty, + + nctx : meta![gguf => llm_context_length ], + nvoc : meta![gguf => tokenizer_ggml_tokens].len(), + d, nh,dh, + nblk: meta![gguf => llm_block_count ], + nkvh: meta![gguf => llm_attention_head_count_kv; nh], + di : meta![gguf => llm_feed_forward_length ], + + dq_lora: meta![gguf => (usize) "minicpm3.attention.q_lora_rank" ], + dkv_lora:meta![gguf => (usize) "minicpm3.attention.kv_lora_rank"], + dk, + dv:(dv/nh)-dk+dh, + dnope: dk-dh, + epsilon: meta!(gguf => llm_attention_layer_norm_rms_epsilon; 1e-5), + theta : meta!(gguf => llm_rope_freq_base ; 1e4 ), + }; + #[rustfmt::skip] + let blocks = (0..meta.nblk) + .map(|i| BlkStorage { + attn_norm : tensor![gguf => format!("blk.{i}.attn_norm.weight" )].data, + attn_qb : tensor![gguf => format!("blk.{i}.attn_q_b.weight" )].data, + attn_qa : tensor![gguf => format!("blk.{i}.attn_q_a.weight" )].data, + attn_kvb : tensor![gguf => format!("blk.{i}.attn_kv_b.weight" )].data, + attn_kva : tensor![gguf => format!("blk.{i}.attn_kv_a_mqa.weight" )].data, + attn_qa_norm : tensor![gguf => format!("blk.{i}.attn_q_a_norm.weight" )].data, + attn_kva_norm: tensor![gguf => format!("blk.{i}.attn_kv_a_norm.weight" )].data, + attn_o : tensor![gguf => format!("blk.{i}.attn_output.weight" )].data, + ffn_norm : tensor![gguf => format!("blk.{i}.ffn_norm.weight" )].data, + // TODO 待修改 gguf 字段名称应该为 ffn_gate_up + ffn_gate_up : tensor![gguf => 
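+                // Placeholder: reuses attn_output.weight until the GGUF gains
+                // a fused ffn_gate_up tensor (see the TODO above).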
format!("blk.{i}.attn_output.weight" )].data, + ffn_gate: tensor![gguf => format!("blk.{i}.ffn_gate.weight" )].data, + ffn_up: tensor![gguf => format!("blk.{i}.ffn_up.weight" )].data, + ffn_down : tensor![gguf => format!("blk.{i}.ffn_down.weight" )].data, + }) + .collect(); + + Self { + meta, + token_embd: token_embd.data, + output_norm: output_norm.data, + output: gguf.tensors.get("output.weight").unwrap_or(token_embd).data, + blocks, + rope_long: rope_long.data, + rope_short: rope_short.data, + } + } +} + +impl BlkStorage { + #[rustfmt::skip] + pub fn into_vec(self) -> Vec<(MiniCPM3BlkWeight, T)> { + use MiniCPM3BlkWeight as W; + vec![ + (W::AttnNorm , self.attn_norm ), + (W::AttnQB , self.attn_qb ), + (W::AttnQA , self.attn_qa ), + (W::AttnKvB , self.attn_kvb ), + (W::AttnKvA , self.attn_kva ), + (W::AttnQANorm , self.attn_qa_norm ), + (W::AttnKvANorm , self.attn_kva_norm), + (W::AttnO , self.attn_o ), + (W::FfnNorm , self.ffn_norm ), + (W::FfnGateUp , self.ffn_gate_up ), + (W::FfnDown , self.ffn_down ), + (W::FfnGate , self.ffn_gate ), + (W::FfnUp , self.ffn_up ), + ] + } +} + +impl FromIterator<(MiniCPM3BlkWeight, T)> for BlkStorage { + #[rustfmt::skip] + fn from_iter(iter: U) -> Self + where + U: IntoIterator, + { + let mut collector: BlkStorage> = BlkStorage { + attn_norm : None, + attn_o : None, + ffn_norm : None, + ffn_gate_up : None, + ffn_gate: None, + ffn_up: None, + ffn_down : None, + attn_qb : None, + attn_qa : None, + attn_kvb : None, + attn_kva : None, + attn_qa_norm : None, + attn_kva_norm: None, + }; + for (which, data) in iter { + use MiniCPM3BlkWeight as W; + match which { + W::AttnNorm => collector.attn_norm = Some(data), + W::AttnQB => collector.attn_qb = Some(data), + W::AttnQA => collector.attn_qa = Some(data), + W::AttnKvB => collector.attn_kvb = Some(data), + W::AttnKvA => collector.attn_kva = Some(data), + W::AttnQANorm => collector.attn_qa_norm = Some(data), + W::AttnKvANorm => collector.attn_kva_norm = Some(data), + W::AttnO => collector.attn_o = Some(data), + W::FfnNorm => collector.ffn_norm = Some(data), + W::FfnGateUp => collector.ffn_gate_up = Some(data), + W::FfnDown => collector.ffn_down = Some(data), + W::FfnGate => collector.ffn_gate = Some(data), + W::FfnUp => collector.ffn_up = Some(data), + }; + } + BlkStorage { + attn_norm : collector.attn_norm .unwrap(), + attn_qb : collector.attn_qb .unwrap(), + attn_qa : collector.attn_qa .unwrap(), + attn_kvb : collector.attn_kvb .unwrap(), + attn_kva : collector.attn_kva .unwrap(), + attn_qa_norm : collector.attn_qa_norm .unwrap(), + attn_kva_norm: collector.attn_kva_norm.unwrap(), + attn_o : collector.attn_o .unwrap(), + ffn_norm : collector.ffn_norm .unwrap(), + ffn_gate_up : collector.ffn_gate_up .unwrap(), + ffn_down : collector.ffn_down .unwrap(), + ffn_gate: collector.ffn_gate .unwrap(), + ffn_up: collector.ffn_up .unwrap(), + } + } +} + +impl MiniCPM3Meta { + pub fn distribute_data<'w, U>( + &self, + which: MiniCPM3BlkWeight, + data: &'w [u8], + dist: Distribution, + mut f: impl FnMut(usize) -> U, + ) -> Contiguous<'w, U> + where + U: DerefMut, + { + use crate::TensorUsage::Storage as TensorMem; + use MiniCPM3BlkWeight as W; + match which { + W::AttnQB + | W::AttnQA + | W::AttnKvB + | W::AttnKvA + | W::AttnQANorm + | W::AttnKvANorm + | W::FfnGate + | W::FfnUp => borrow(data), + W::AttnNorm | W::FfnNorm => borrow(data), + _ if dist.is_mono() || data.is_empty() => borrow(data), + W::AttnO => { + let [start, len, total] = dist.info(); + let o = self.attn_o(TensorMem).map(|_| data); + + let d = o.shape()[1] 
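+                // Tensor-parallel split: attn_o is sliced along axis 1, each
+                // worker keeping its `len / total` share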
/ total; + let o = o.slice(1, d * start, 1, d * len); + + let mut o_ = Tensor::new(o.dt(), o.shape()).map(&mut f); + rearrange(&mut o_, &o); + own(o_.take()) + } + W::FfnGateUp => { + let &MiniCPM3Meta { di, .. } = self; + let [start, len, total] = dist.info(); + let dist = self.distribute(dist); + + let gu = self.ffn_gate_up(TensorMem).map(|_| data); + split!(gu => g, u; [di, di] @ 1); + + let di = di / total; + + let g = g.slice(1, di * start, 1, di * len); + let u = u.slice(1, di * start, 1, di * len); + + let mut ans = dist.ffn_gate_up(TensorMem).map(&mut f); + { + let ans = ans.map_slice_mut(); + split!(ans => g_, u_; [di * len , di * len] @ 1); + let mut g_ = g_; + let mut u_ = u_; + rearrange(&mut g_, &g); + rearrange(&mut u_, &u); + } + own(ans.take()) + } + W::FfnDown => { + let [start, len, total] = dist.info(); + let down = self.ffn_down(TensorMem).map(|_| data); + + let d = down.shape()[2] / total; + let down = down.slice(2, d * start, 1, d * len); + + let mut down_ = Tensor::new(down.dt(), down.shape()).map(&mut f); + rearrange(&mut down_, &down); + own(down_.take()) + } + } + } + + pub fn distribute_qkv<'w, U>( + &self, + dist: Distribution, + dst: Tensor, + src: Tensor<&'w [u8]>, + ) -> Contiguous<'w, U> + where + U: DerefMut, + { + let &MiniCPM3Meta { nh, nkvh, dh, .. } = self; + let [start, len, total] = dist.info(); + + let dq = nh * dh; + let dkv = nkvh * dh; + + let qkv = src; + split!(qkv => q, k, v; [dq, dkv, dkv] @ 0); + + let dq = dq / total; + let dkv = dkv / total; + + let q = q.slice(0, dq * start, 1, dq * len); + let k = k.slice(0, dkv * start, 1, dkv * len); + let v = v.slice(0, dkv * start, 1, dkv * len); + debug_assert!(q.is_contiguous() && k.is_contiguous() && v.is_contiguous()); + + let mut ans = dst; + { + let ans = ans.map_slice_mut(); + split!(ans => q_, k_, v_; [dq * len , dkv * len, dkv * len] @ 0); + let mut q_ = q_; + let mut k_ = k_; + let mut v_ = v_; + rearrange(&mut q_, &q); + rearrange(&mut k_, &k); + rearrange(&mut v_, &v); + } + own(ans.take()) + } +} +#[test] +fn test() { + use test_utils::Inference; + std::env::set_var( + "TEST_MODEL", + "/home/ztf/cpm/Origin-MiniCPM3-4B-v0.0-F16.gguf", + ); + let Some(Inference { model, .. }) = Inference::load() else { + return; + }; + let gguf = GGufModel::read(model.iter().map(|s| &**s)); + let storage = Storage::from_gguf(&gguf); + println!("{:#?}", storage.meta); +} diff --git a/tensor/src/split.rs b/tensor/src/split.rs index ebf775e4..1841c41c 100644 --- a/tensor/src/split.rs +++ b/tensor/src/split.rs @@ -14,7 +14,6 @@ impl Splitable for &[T] { self } } - impl Splitable for &mut [T] { #[inline] fn split(&self) -> Self { @@ -68,3 +67,12 @@ macro_rules! split { assert!(parts.next().is_none()); }; } +#[macro_export] +macro_rules! split_mut { + ($tensor:expr => $( $name:ident ),+; [$( $part:expr ),+] @ $axis:expr) => { + let parts = [$($part),+]; + let mut parts = $tensor.split($axis, &parts); + $( let mut $name = parts.next().unwrap(); )+ + assert!(parts.next().is_none()); + }; +}
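A minimal usage sketch for the new `split_mut!` macro (the `kv_pe` names mirror
its use in compute.rs; the shapes are illustrative): like `split!`, it splits a
tensor along one axis, but binds each part as a mutable view.

    // `kv_pe` is a mutable tensor view with shape [nt, dkv_lora + dh]:
    split_mut!(kv_pe => kv_lora, k_rope; [dkv_lora, dh] @ 1);
    // which expands to roughly:
    //     let parts = [dkv_lora, dh];
    //     let mut parts = kv_pe.split(1, &parts);
    //     let mut kv_lora = parts.next().unwrap();
    //     let mut k_rope = parts.next().unwrap();
    //     assert!(parts.next().is_none());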