Skip to content

Commit cf78ed5

Browse files
committed
feat(clip): 完善 clip 结构并为 softmax 添加并行加速
Signed-off-by: YdrMaster <[email protected]>
1 parent c9b6bdf commit cf78ed5

File tree

8 files changed

+63
-26
lines changed

8 files changed

+63
-26
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ itertools = "0.13"
3838
env_logger = "0.11"
3939
build-script-cfg = "0.0"
4040

41-
operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "0bd4107", default-features = false }
41+
operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "8712870", default-features = false }
4242

4343
search-cl-tools = { git = "https://github.com/InfiniTensor/clrt", rev = "f69b160" }
4444
search-infini-tools = { git = "https://github.com/InfiniTensor/infini-rt", rev = "e8362c3" }

models/clip/common-cpu/src/infer.rs

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ fn test_infer() {
3939
} = meta;
4040

4141
let time = Instant::now();
42-
let image = Image::load(picture);
42+
let image = Image::load(&picture);
4343
println!("load image {:?}", time.elapsed());
4444

4545
let time = Instant::now();
@@ -48,38 +48,55 @@ fn test_infer() {
4848
.normalize(dt, image_mean, image_std);
4949
println!("slice image {:?}", time.elapsed());
5050

51+
let batch = slices.batch();
52+
let mut img_embd = meta.projector.img_embd(meta.dt, batch).map(Blob::new);
53+
let d = img_embd.shape()[2];
54+
5155
let weights = Weights::new(&storage);
5256
let mut worker = Worker::new(&Cpu, meta.clone(), weights);
5357

54-
let whole = slices.whole();
55-
worker
56-
.launch(
57-
ClipArgs {
58-
raw: whole.to_nchw(),
59-
pos: pos70(whole.shape(), d_patch).map_slice(),
60-
pos_resampler: pos_resampler(3584, whole.shape(), d_patch).map_slice(),
61-
},
62-
&mut [],
63-
&ThisThread,
64-
)
65-
.unwrap();
58+
{
59+
let whole = slices.whole();
60+
let img_embd = img_embd.map_slice_mut().slice(0, 0, 1, 1);
61+
worker
62+
.launch(
63+
ClipArgs {
64+
img_embd,
65+
raw: whole.to_nchw(),
66+
pos: pos70(whole.shape(), d_patch).map_slice(),
67+
pos_resampler: pos_resampler(d, whole.shape(), d_patch).map_slice(),
68+
},
69+
&mut [],
70+
&ThisThread,
71+
)
72+
.unwrap();
73+
}
6674

6775
if let Some(patches) = slices.patches_nchw() {
6876
let &[_, 3, h, w] = patches.shape() else {
6977
unreachable!()
7078
};
79+
let img_embd = img_embd.map_slice_mut().slice(0, 1, 1, batch - 1);
7180
worker
7281
.launch(
7382
ClipArgs {
83+
img_embd,
7484
raw: patches.map_slice(),
7585
pos: pos70([w, h], d_patch).map_slice(),
76-
pos_resampler: pos_resampler(3584, [w, h], d_patch).map_slice(),
86+
pos_resampler: pos_resampler(d, [w, h], d_patch).map_slice(),
7787
},
7888
&mut [],
7989
&ThisThread,
8090
)
8191
.unwrap();
8292
}
93+
94+
println!(
95+
"create {} x {} tokens from {}",
96+
img_embd.shape()[0],
97+
img_embd.shape()[1],
98+
picture.display(),
99+
);
83100
}
84101

85102
fn pos70([w, h]: [usize; 2], d_patch: usize) -> Tensor<Blob> {

models/clip/common/src/args.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
use tensor::Tensor;
33

44
pub struct Args<'a, H: Hardware> {
5+
/// shape: [batch, projector_dp, projector_d]
6+
pub img_embd: Tensor<&'a mut [H::Byte]>,
57
/// shape: [n, c, h, w]
68
pub raw: Tensor<&'a [H::Byte]>,
79
/// shape: [h x w]

models/clip/common/src/compute.rs

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ where
149149
{
150150
let time = Instant::now();
151151
let Args {
152+
img_embd: proj_q,
152153
raw,
153154
pos,
154155
pos_resampler,
@@ -317,10 +318,7 @@ where
317318
let [w, b] = weights.resampler_attn_o(queue);
318319
let attn_o = (attn_w.clone().map(|_| w), Some(attn_b.clone().map(|_| b)));
319320

320-
let qo = Tensor::new(dt, &[batch * dq, d]);
321-
322-
let (buf, workspace) = workspace.split_at_mut(*qo.get());
323-
let mut q_ = qo.clone().map(|_| buf);
321+
let mut q_ = proj_q.merge(0..2).unwrap();
324322
{
325323
let mut q_ = q_.map_slice_mut().tile(0, &[batch, dq]);
326324
{
@@ -363,18 +361,19 @@ where
363361
}
364362
let o = q_;
365363

366-
let (buf, workspace) = workspace.split_at_mut(*qo.get());
367-
let mut o_ = qo.map(|_| buf);
364+
let o_ = Tensor::new(o.dt(), o.shape());
365+
let (buf, workspace) = workspace.split_at_mut(*o_.get());
366+
let mut o_ = o_.map(|_| buf);
368367
self.mat_mul(&mut o_, &o, attn_o, workspace, queue_alloc)?;
369368

370369
let [w, b] = weights.resampler_ln_post(queue);
371370
let ln_post = [ln.clone().map(|_| w), ln.clone().map(|_| b)];
372371
let inplace = unsafe { o_.map_slice_static() };
373372
self.layer_norm(&mut o_, &inplace, ln_post, workspace, queue_alloc)?;
374373

375-
let mut out = o;
374+
let mut img_embd = o;
376375
let w = attn_w.map(|_| weights.resampler_proj(queue));
377-
self.mat_mul(&mut out, &o_, (w, None), workspace, queue_alloc)?
376+
self.mat_mul(&mut img_embd, &o_, (w, None), workspace, queue_alloc)?
378377
}
379378
}
380379

models/clip/common/src/image.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,12 @@ impl ImageGrid {
184184
}
185185
}
186186

187+
#[inline]
188+
pub fn batch(&self) -> usize {
189+
let [x, y] = self.grid();
190+
x * y + 1
191+
}
192+
187193
pub fn patch(&self, x: usize, y: usize) -> Image<&[u8]> {
188194
Image(
189195
self.grid

models/clip/common/src/projector/mod.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
pub(crate) mod resampler;
22

3-
use gguf::{GGufMetaMapExt, GGufModel};
3+
use gguf::{ggml_quants::digit_layout::DigitLayout, GGufMetaMapExt, GGufModel};
4+
use tensor::Tensor;
45

56
#[derive(Clone, Debug)]
67
pub enum ProjectorMeta {
@@ -14,6 +15,12 @@ impl ProjectorMeta {
1415
projector => todo!("unsupported projector type: {projector}"),
1516
}
1617
}
18+
19+
pub fn img_embd(&self, dt: DigitLayout, batch: usize) -> Tensor<usize> {
20+
match self {
21+
ProjectorMeta::Resampler(meta) => meta.img_embd(dt, batch),
22+
}
23+
}
1724
}
1825

1926
#[derive(Clone)]

models/clip/common/src/projector/resampler.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
use gguf::{tensor, GGufMetaMapExt, GGufModel};
1+
use gguf::{ggml_quants::digit_layout::DigitLayout, tensor, GGufMetaMapExt, GGufModel};
2+
use tensor::Tensor;
23

34
#[derive(Clone, Debug)]
45
pub struct Meta {
@@ -23,6 +24,11 @@ impl Meta {
2324
version => todo!("Unsupported MiniCPM version: {version}"),
2425
}
2526
}
27+
28+
#[inline]
29+
pub fn img_embd(&self, dt: DigitLayout, batch: usize) -> Tensor<usize> {
30+
Tensor::new(dt, &[batch, self.dq, self.d])
31+
}
2632
}
2733

2834
#[derive(Clone)]

tensor/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ impl<T> Tensor<T> {
8484

8585
let merged = self
8686
.layout
87-
.merge_be(0, self.layout.ndim())
87+
.merge_free(0, self.layout.ndim())
8888
.expect("dense tensor is castable");
8989
let &[d] = merged.shape() else { unreachable!() };
9090
let &[s] = merged.strides() else {

0 commit comments

Comments
 (0)