Skip to content

Commit db22ce2

Browse files
committed
finish
1 parent 3d8b052 commit db22ce2

File tree

10 files changed

+1806
-1
lines changed

10 files changed

+1806
-1
lines changed

Cargo.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ members = [
1717
"models/gpt2/common",
1818
"models/gpt2/common-cpu",
1919
"models/gpt2/cuda",
20+
21+
"models/minicpm3/common",
22+
"models/minicpm3/common-cpu",
2023
]
2124
resolver = "2"
2225

@@ -38,7 +41,7 @@ itertools = "0.13"
3841
env_logger = "0.11"
3942
build-script-cfg = "0.0"
4043

41-
operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "61789f7", default-features = false }
44+
operators = { git = "https://github.com/onenewcode/operators-rs", rev = "f4a83f7", default-features = false }
4245

4346
search-cl-tools = { git = "https://github.com/InfiniTensor/clrt", rev = "f69b160" }
4447
search-infini-tools = { git = "https://github.com/InfiniTensor/infini-rt", rev = "e8362c3" }
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
[package]
2+
name = "minicpm3-cpu"
3+
version = "0.0.0"
4+
edition = "2021"
5+
authors = ["onenewcode <[email protected]>", "YdrMaster <[email protected]>"]
6+
7+
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
8+
9+
[dependencies]
10+
minicpm3.path = "../common"
11+
common.workspace = true
12+
operators = { workspace = true, features = ["common-cpu"] }
13+
14+
[dev-dependencies]
15+
test-utils.workspace = true
16+
gguf.workspace = true
17+
regex.workspace = true
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
use crate::{Operators, RandomSample, Weights};
2+
use common::Distribution;
3+
use gguf::GGufModel;
4+
use minicpm3::{ext::ggml_quants::f16, MiniCPM3Request, MiniCPM3Storage, Minicpm3Worker, Tensor};
5+
use operators::{
6+
all_reduce::common_cpu::Operator as AllReduce,
7+
common_cpu::{InprocNode, ThisThread},
8+
random_sample::{KVPair, SampleArgs},
9+
Blob,
10+
};
11+
use regex::Regex;
12+
use std::{
13+
iter::zip,
14+
ptr::copy_nonoverlapping,
15+
slice::from_raw_parts_mut,
16+
sync::{Arc, Barrier},
17+
thread,
18+
};
19+
use test_utils::{test_infer_paralle, Inference, Task, TokenizerAndPrompt, WorkerSeed};
20+
21+
type Worker<'w> = Minicpm3Worker<Operators<InprocNode<usize>, AllReduce>, Weights<'w>>;
22+
23+
#[test]
24+
fn test_infer() {
25+
std::env::set_var(
26+
"TEST_MODEL",
27+
"/home/ztf/cpm/Origin-MiniCPM3-4B-v0.0-F16.gguf",
28+
);
29+
let Some(Inference {
30+
model,
31+
devices,
32+
mut prompt,
33+
as_user,
34+
temperature,
35+
top_p,
36+
top_k,
37+
max_steps,
38+
}) = Inference::load()
39+
else {
40+
return;
41+
};
42+
prompt = "我".to_owned();
43+
let gguf = GGufModel::read(model.iter().map(|s| &**s));
44+
45+
let TokenizerAndPrompt {
46+
eos,
47+
tokenizer,
48+
prompt,
49+
} = TokenizerAndPrompt::new(&gguf, prompt, as_user);
50+
51+
let model = MiniCPM3Storage::from_gguf(&gguf);
52+
println!("{:?}", model.meta);
53+
54+
let sample_args = SampleArgs::new(temperature, top_p, top_k).expect("invalid sample args");
55+
println!("{sample_args:?}");
56+
57+
let lens = devices
58+
.map(|devices| {
59+
Regex::new(r"\d+")
60+
.unwrap()
61+
.find_iter(&devices)
62+
.map(|c| c.as_str().parse().unwrap())
63+
.collect()
64+
})
65+
.unwrap_or_else(|| vec![1]);
66+
let dist = lens.iter().sum();
67+
println!("distribution: {lens:?}");
68+
69+
let (seeds, senders) = WorkerSeed::new(InprocNode::new(lens.len()));
70+
let barrier = Arc::new(Barrier::new(dist + 1));
71+
thread::scope(|s| {
72+
let _workers = zip(lens, seeds)
73+
.enumerate()
74+
.scan(0, |start, (id, (len, seed))| {
75+
let dist = Distribution::new(*start, len, dist);
76+
*start += len;
77+
78+
let meta = model.meta.distribute(dist);
79+
let model = &model;
80+
let barrier = barrier.clone();
81+
Some(s.spawn(move || {
82+
let WorkerSeed { node, tasks } = seed;
83+
let weights = Weights::new(model, dist);
84+
let mut worker = Worker::new(id, &node, meta.clone(), weights);
85+
let mut cache = meta.kv_cache(meta.nctx).map(Blob::new);
86+
let sin_cos = <Operators as minicpm3::Operators>::build_sin_cos(
87+
meta.dt_embd,
88+
meta.nctx,
89+
meta.dh,
90+
&ThisThread,
91+
);
92+
93+
let sample = RandomSample::new(&node);
94+
let indices = RandomSample::build_indices(model.meta.nvoc, &ThisThread);
95+
let mut pair = KVPair::new(0, f16::ZERO);
96+
let mut pairs = Tensor::kv_pair_vec(1, |_| unsafe {
97+
from_raw_parts_mut(&mut pair as *mut _ as *mut u8, size_of_val(&pair))
98+
});
99+
100+
barrier.wait();
101+
for task in tasks {
102+
let Task {
103+
nt,
104+
pos,
105+
embd,
106+
next,
107+
} = task;
108+
let mut embd = meta.embd(nt).map(|size| {
109+
let mut blob = Blob::new(size);
110+
unsafe { copy_nonoverlapping(embd, blob.as_mut_ptr(), size) };
111+
blob
112+
});
113+
let mut logits = meta.logits(if id == 0 { 1 } else { 0 }).map(Blob::new);
114+
worker
115+
.launch(
116+
minicpm3::MiniCPM3Args {
117+
embd: embd.map_slice_mut(),
118+
logits: logits.map_slice_mut(),
119+
sin_cos: sin_cos.map_slice(),
120+
requests: vec![MiniCPM3Request {
121+
cache: cache.map_slice_mut(),
122+
seq_len: nt,
123+
out_len: if id == 0 { 1 } else { 0 },
124+
pos,
125+
}],
126+
num_tokens: nt,
127+
max_seq_len: nt,
128+
max_att_len: nt + pos,
129+
},
130+
&mut [],
131+
&ThisThread,
132+
)
133+
.unwrap();
134+
if id == 0 {
135+
sample
136+
.launch(
137+
&mut pairs,
138+
&logits,
139+
&indices,
140+
sample_args,
141+
&mut [],
142+
&ThisThread,
143+
)
144+
.unwrap();
145+
next.send(pair.idx() as _).unwrap()
146+
}
147+
}
148+
}))
149+
})
150+
.collect::<Vec<_>>();
151+
152+
let senders = senders.into_boxed_slice();
153+
barrier.wait();
154+
test_infer_paralle(
155+
senders,
156+
test_utils::AboutToken {
157+
tokenizer,
158+
token_embd: model.token_embd,
159+
nvoc: model.meta.nvoc,
160+
eos,
161+
},
162+
&prompt,
163+
max_steps,
164+
)
165+
})
166+
}
Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
use common::{Contiguous, Distribution};
2+
use minicpm3::{MiniCPM3BlkStorage, MiniCPM3BlkWeight, MiniCPM3Storage, Tensor, WeightLoader};
3+
use operators::{
4+
all_reduce::{AllReduce, NonAllReduce},
5+
common_cpu::Cpu,
6+
random_sample::common_cpu::Operator as RandomSampleCpu,
7+
rearrange::common_cpu::Operator as Rearrange,
8+
Blob, ByteOf, QueueOf, TopoNode,
9+
};
10+
use std::{marker::PhantomData, ops::Deref};
11+
12+
pub struct Operators<N = Cpu, R = NonAllReduce<Cpu, Rearrange>>(PhantomData<(N, R)>);
13+
14+
pub type RandomSample = minicpm3::RandomSample<Cpu, RandomSampleCpu>;
15+
16+
pub struct Weights<'w> {
17+
blks: Box<[MiniCPM3BlkStorage<Contiguous<'w, Blob>>]>,
18+
output_norm: &'w [u8],
19+
output: &'w [u8],
20+
long_factor: &'w [u8],
21+
sort_factor: &'w [u8],
22+
}
23+
24+
macro_rules! op {
25+
($name:ident) => {
26+
operators::$name::common_cpu::Operator
27+
};
28+
}
29+
30+
impl<N, R> minicpm3::Operators for Operators<N, R>
31+
where
32+
N: TopoNode<Cpu>,
33+
R: AllReduce<Cpu, N>,
34+
{
35+
type Hardware = Cpu;
36+
type TopoNode = N;
37+
type Rope = op!(rope);
38+
type Attention = op!(attention);
39+
type RmsNorm = op!(rms_norm);
40+
type Add = op!(add);
41+
type MatMul = op!(mat_mul);
42+
type Swiglu = op!(swiglu);
43+
type Rearrange = op!(rearrange);
44+
type Scale = op!(scale);
45+
type AttnKVCached = op!(attention_kv_cached);
46+
type AllReduce = R;
47+
48+
fn debug<T>(tensor: &Tensor<T>, _queue: &QueueOf<Self::Hardware>)
49+
where
50+
T: Deref<Target = [ByteOf<Self::Hardware>]>,
51+
{
52+
println!("{tensor}")
53+
}
54+
}
55+
56+
impl<'w> Weights<'w> {
57+
pub fn new(model: &'w MiniCPM3Storage<&'w [u8]>, dist: Distribution) -> Self {
58+
let MiniCPM3Storage {
59+
meta,
60+
output_norm,
61+
output,
62+
blocks,
63+
rope_long,
64+
rope_short,
65+
..
66+
} = model;
67+
68+
let blks = blocks
69+
.iter()
70+
.map(|blk| {
71+
blk.clone()
72+
.into_vec()
73+
.into_iter()
74+
.map(|(which, data)| {
75+
(which, meta.distribute_data(which, data, dist, Blob::new))
76+
})
77+
.collect::<MiniCPM3BlkStorage<_>>()
78+
})
79+
.collect();
80+
81+
Self {
82+
blks,
83+
output_norm,
84+
output,
85+
long_factor: rope_long,
86+
sort_factor: rope_short,
87+
}
88+
}
89+
}
90+
91+
impl WeightLoader for Weights<'_> {
92+
type Hardware = Cpu;
93+
type Weight<'s>
94+
= &'s [u8]
95+
where
96+
Self: 's;
97+
98+
#[inline]
99+
fn load_blk(
100+
&self,
101+
which: MiniCPM3BlkWeight,
102+
iblk: usize,
103+
_queue: &QueueOf<Self::Hardware>,
104+
) -> Self::Weight<'_> {
105+
let MiniCPM3BlkStorage {
106+
attn_norm,
107+
attn_qb,
108+
attn_qa,
109+
attn_kvb,
110+
attn_kva,
111+
attn_qa_norm,
112+
attn_kva_norm,
113+
attn_o,
114+
ffn_norm,
115+
ffn_gate_up,
116+
ffn_down,
117+
ffn_gate,
118+
ffn_up,
119+
} = &self.blks[iblk];
120+
use MiniCPM3BlkWeight as W;
121+
match which {
122+
W::AttnNorm => attn_norm,
123+
W::AttnQB => attn_qb,
124+
W::AttnQA => attn_qa,
125+
W::AttnKvB => attn_kvb,
126+
W::AttnKvA => attn_kva,
127+
W::AttnQANorm => attn_qa_norm,
128+
W::AttnKvANorm => attn_kva_norm,
129+
W::AttnO => attn_o,
130+
W::FfnNorm => ffn_norm,
131+
W::FfnGateUp => ffn_gate_up,
132+
W::FfnDown => ffn_down,
133+
W::FfnGate => ffn_gate,
134+
W::FfnUp => ffn_up,
135+
}
136+
}
137+
138+
#[inline]
139+
fn output_norm(&self, _queue: &QueueOf<Self::Hardware>) -> Self::Weight<'_> {
140+
self.output_norm
141+
}
142+
143+
#[inline]
144+
fn output(&self, _queue: &QueueOf<Self::Hardware>) -> Self::Weight<'_> {
145+
self.output
146+
}
147+
#[inline]
148+
fn long_factor<'a>(&'a self, _queue: &'a QueueOf<Self::Hardware>) -> Self::Weight<'_> {
149+
self.long_factor
150+
}
151+
#[inline]
152+
fn short_factor<'a>(&'a self, _queue: &'a QueueOf<Self::Hardware>) -> Self::Weight<'_> {
153+
self.sort_factor
154+
}
155+
}
156+
157+
#[cfg(test)]
158+
mod infer;

models/minicpm3/common/Cargo.toml

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
[package]
2+
name = "minicpm3"
3+
version = "0.0.0"
4+
edition = "2021"
5+
authors = ["onenewcode <[email protected]>", "YdrMaster <[email protected]>"]
6+
7+
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
8+
9+
[dependencies]
10+
common.workspace = true
11+
gguf.workspace = true
12+
tensor.workspace = true
13+
operators.workspace = true
14+
itertools.workspace = true
15+
half = "2.4"
16+
17+
[dev-dependencies]
18+
test-utils.workspace = true

0 commit comments

Comments
 (0)