1212
1313#include < fstream>
1414
15+ #include " ../support/vlm_utils.h"
1516#include " ../support/json_parser.h"
1617#include " config.h"
1718#include " logit_processor.h"
@@ -113,8 +114,19 @@ class ModelImpl : public ModelObj {
113114 ObjectRef ImageEmbed (const NDArray& image, ObjectRef* dst, int offset) final {
114115 NVTXScopedRange nvtx_scope (" ImageEmbed" );
115116 CHECK (ft_.image_embed_func_ .defined ()) << " `image_embed` function is not found in the model. " ;
117+
118+ int tmp_h = 0 , tmp_w = 0 ;
119+ CalculateResizeShape (image, this ->model_type_ , tmp_h, tmp_w);
120+ ShapeTuple resize_h = {tmp_h};
121+ ShapeTuple resize_w = {tmp_w};
122+
123+ CalculateCropShape (image, this ->model_type_ , tmp_h, tmp_w);
124+ ShapeTuple crop_h = {tmp_h};
125+ ShapeTuple crop_w = {tmp_w};
126+
116127 auto image_dref_or_nd = ft_.CopyToWorker0 (image, " image" , image.Shape ());
117- ObjectRef embeddings = ft_.image_embed_func_ (image_dref_or_nd, params_);
128+ ObjectRef embeddings =
129+ ft_.image_embed_func_ (image_dref_or_nd, resize_h, resize_w, crop_h, crop_w, params_);
118130 if (dst != nullptr ) {
119131 CHECK (dst->defined ());
120132 ft_.nd_copy_embedding_to_offset_func_ (embeddings, *dst, offset);
@@ -1003,6 +1015,7 @@ class ModelImpl : public ModelObj {
10031015 json::LookupOrDefault<int64_t >(config, " attention_sink_size" , this ->attention_sink_size_ );
10041016 this ->attention_sink_size_ = std::max (this ->attention_sink_size_ , 0 );
10051017 this ->vocab_size_ = json::Lookup<int64_t >(config, " vocab_size" );
1018+ this ->model_type_ = json::Lookup<std::string>(config, " model_type" );
10061019 }
10071020
10081021 // ----------------------------
@@ -1019,6 +1032,7 @@ class ModelImpl : public ModelObj {
10191032 DLDataType hidden_states_dtype_;
10201033 int vocab_size_ = -1 ;
10211034 int image_embed_size_ = -1 ;
1035+ std::string model_type_;
10221036 // ----------------------------
10231037 // TVM related states
10241038 // ----------------------------
0 commit comments