@@ -39,7 +39,7 @@ fn test_infer() {
3939 } = meta;
4040
4141 let time = Instant :: now ( ) ;
42- let image = Image :: load ( picture) ;
42+ let image = Image :: load ( & picture) ;
4343 println ! ( "load image {:?}" , time. elapsed( ) ) ;
4444
4545 let time = Instant :: now ( ) ;
@@ -48,38 +48,55 @@ fn test_infer() {
4848 . normalize ( dt, image_mean, image_std) ;
4949 println ! ( "slice image {:?}" , time. elapsed( ) ) ;
5050
51+ let batch = slices. batch ( ) ;
52+ let mut img_embd = meta. projector . img_embd ( meta. dt , batch) . map ( Blob :: new) ;
53+ let d = img_embd. shape ( ) [ 2 ] ;
54+
5155 let weights = Weights :: new ( & storage) ;
5256 let mut worker = Worker :: new ( & Cpu , meta. clone ( ) , weights) ;
5357
54- let whole = slices. whole ( ) ;
55- worker
56- . launch (
57- ClipArgs {
58- raw : whole. to_nchw ( ) ,
59- pos : pos70 ( whole. shape ( ) , d_patch) . map_slice ( ) ,
60- pos_resampler : pos_resampler ( 3584 , whole. shape ( ) , d_patch) . map_slice ( ) ,
61- } ,
62- & mut [ ] ,
63- & ThisThread ,
64- )
65- . unwrap ( ) ;
58+ {
59+ let whole = slices. whole ( ) ;
60+ let img_embd = img_embd. map_slice_mut ( ) . slice ( 0 , 0 , 1 , 1 ) ;
61+ worker
62+ . launch (
63+ ClipArgs {
64+ img_embd,
65+ raw : whole. to_nchw ( ) ,
66+ pos : pos70 ( whole. shape ( ) , d_patch) . map_slice ( ) ,
67+ pos_resampler : pos_resampler ( d, whole. shape ( ) , d_patch) . map_slice ( ) ,
68+ } ,
69+ & mut [ ] ,
70+ & ThisThread ,
71+ )
72+ . unwrap ( ) ;
73+ }
6674
6775 if let Some ( patches) = slices. patches_nchw ( ) {
6876 let & [ _, 3 , h, w] = patches. shape ( ) else {
6977 unreachable ! ( )
7078 } ;
79+ let img_embd = img_embd. map_slice_mut ( ) . slice ( 0 , 1 , 1 , batch - 1 ) ;
7180 worker
7281 . launch (
7382 ClipArgs {
83+ img_embd,
7484 raw : patches. map_slice ( ) ,
7585 pos : pos70 ( [ w, h] , d_patch) . map_slice ( ) ,
76- pos_resampler : pos_resampler ( 3584 , [ w, h] , d_patch) . map_slice ( ) ,
86+ pos_resampler : pos_resampler ( d , [ w, h] , d_patch) . map_slice ( ) ,
7787 } ,
7888 & mut [ ] ,
7989 & ThisThread ,
8090 )
8191 . unwrap ( ) ;
8292 }
93+
94+ println ! (
95+ "create {} x {} tokens from {}" ,
96+ img_embd. shape( ) [ 0 ] ,
97+ img_embd. shape( ) [ 1 ] ,
98+ picture. display( ) ,
99+ ) ;
83100}
84101
85102fn pos70 ( [ w, h] : [ usize ; 2 ] , d_patch : usize ) -> Tensor < Blob > {
0 commit comments