Skip to content

Commit 70c9a74

Browse files
committed
[Model] Fix cached buffer allocation and fix dynamic shape in Phi3v
This PR fixes the cached buffer allocation for larger new images and supports dynamic shapes in the vision encoder.
1 parent 1825fed commit 70c9a74

File tree

9 files changed

+436
-75
lines changed

9 files changed

+436
-75
lines changed

cpp/serve/function_table.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,13 @@ ObjectRef FunctionTable::CopyToWorker0(const NDArray& host_array, String buffer_
321321
NDArray buffer{nullptr};
322322
if (it != this->cached_buffers.end()) {
323323
buffer = Downcast<NDArray>((*it).second);
324+
if (buffer_cache_key == "image") {
325+
if (runtime::GetDataSize(*buffer.operator->()) <
326+
runtime::GetDataSize(*host_array.operator->())) {
327+
buffer = NDArray::Empty(max_reserved_shape, host_array->dtype, local_gpu_device);
328+
this->cached_buffers.Set(buffer_cache_key, buffer);
329+
}
330+
}
324331
} else {
325332
buffer = NDArray::Empty(max_reserved_shape, host_array->dtype, local_gpu_device);
326333
this->cached_buffers.Set(buffer_cache_key, buffer);

cpp/serve/model.cc

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include <fstream>
1414

15+
#include "../support/vlm_utils.h"
1516
#include "../support/json_parser.h"
1617
#include "config.h"
1718
#include "logit_processor.h"
@@ -113,8 +114,19 @@ class ModelImpl : public ModelObj {
113114
ObjectRef ImageEmbed(const NDArray& image, ObjectRef* dst, int offset) final {
114115
NVTXScopedRange nvtx_scope("ImageEmbed");
115116
CHECK(ft_.image_embed_func_.defined()) << "`image_embed` function is not found in the model. ";
117+
118+
int tmp_h = 0, tmp_w = 0;
119+
CalculateResizeShape(image, this->model_type_, tmp_h, tmp_w);
120+
ShapeTuple resize_h = {tmp_h};
121+
ShapeTuple resize_w = {tmp_w};
122+
123+
CalculateCropShape(image, this->model_type_, tmp_h, tmp_w);
124+
ShapeTuple crop_h = {tmp_h};
125+
ShapeTuple crop_w = {tmp_w};
126+
116127
auto image_dref_or_nd = ft_.CopyToWorker0(image, "image", image.Shape());
117-
ObjectRef embeddings = ft_.image_embed_func_(image_dref_or_nd, params_);
128+
ObjectRef embeddings =
129+
ft_.image_embed_func_(image_dref_or_nd, resize_h, resize_w, crop_h, crop_w, params_);
118130
if (dst != nullptr) {
119131
CHECK(dst->defined());
120132
ft_.nd_copy_embedding_to_offset_func_(embeddings, *dst, offset);
@@ -1003,6 +1015,7 @@ class ModelImpl : public ModelObj {
10031015
json::LookupOrDefault<int64_t>(config, "attention_sink_size", this->attention_sink_size_);
10041016
this->attention_sink_size_ = std::max(this->attention_sink_size_, 0);
10051017
this->vocab_size_ = json::Lookup<int64_t>(config, "vocab_size");
1018+
this->model_type_ = json::Lookup<std::string>(config, "model_type");
10061019
}
10071020

10081021
//----------------------------
@@ -1019,6 +1032,7 @@ class ModelImpl : public ModelObj {
10191032
DLDataType hidden_states_dtype_;
10201033
int vocab_size_ = -1;
10211034
int image_embed_size_ = -1;
1035+
std::string model_type_;
10221036
//----------------------------
10231037
// TVM related states
10241038
//----------------------------

cpp/support/vlm_utils.cc

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
/*!
2+
* Copyright (c) 2023-2024 by Contributors
3+
* \file support/image_utils.cc
4+
*/
5+
#include "vlm_utils.h"
6+
7+
#include <cmath>
8+
9+
namespace mlc {
10+
namespace llm {
11+
12+
void CalculateResizeShape(tvm::runtime::NDArray image_data, std::string model_type,
13+
int& target_height, int& target_width) {
14+
ICHECK_EQ(image_data->shape[3], 3) << "Image format must be NHWC";
15+
int height = image_data->shape[1];
16+
int width = image_data->shape[2];
17+
18+
if ("phi3_v" == model_type) {
19+
const int hd_num = 4;
20+
double ratio = static_cast<double>(width) / height;
21+
int scale = 1;
22+
while (scale * std::ceil(scale / ratio) <= hd_num) {
23+
scale += 1;
24+
}
25+
scale -= 1;
26+
target_width = static_cast<int>(scale * 336);
27+
target_height = static_cast<int>(target_width / ratio);
28+
}
29+
}
30+
31+
void CalculatePadShape(tvm::runtime::NDArray image_data, std::string model_type, int& pad_height,
32+
int& pad_width) {
33+
ICHECK_EQ(image_data->shape[3], 3) << "Image format must be NHWC";
34+
if ("phi3_v" == model_type) {
35+
int resized_height = 0, resized_width = 0;
36+
CalculateResizeShape(image_data, model_type, resized_height, resized_width);
37+
int tar = (int)(ceil(resized_height / 336.0) * 336);
38+
int top_padding = (int)((tar - resized_height) / 2);
39+
int bottom_padding = tar - resized_height - top_padding;
40+
ICHECK_EQ(tar, resized_height + top_padding + bottom_padding) << "Padding size not equal!";
41+
pad_height = tar;
42+
pad_width = resized_width;
43+
}
44+
}
45+
46+
void CalculateCropShape(tvm::runtime::NDArray image_data, std::string model_type, int& crop_height,
47+
int& crop_width) {
48+
ICHECK_EQ(image_data->shape[3], 3) << "Image format must be NHWC";
49+
if ("phi3_v" == model_type) {
50+
int pad_h = 0, pad_w = 0;
51+
CalculatePadShape(image_data, model_type, pad_h, pad_w);
52+
crop_height = pad_h / 336;
53+
crop_width = pad_w / 336;
54+
}
55+
}
56+
57+
} // namespace llm
58+
} // namespace mlc

cpp/support/vlm_utils.h

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
/*!
2+
* Copyright (c) 2023-2024 by Contributors
3+
* \file support/vlm_utils.h
4+
* \brief Tools for debug purposes.
5+
*/
6+
#ifndef MLC_LLM_SUPPORT_VLM_UTILS_H_
7+
#define MLC_LLM_SUPPORT_VLM_UTILS_H_
8+
9+
#include <tvm/runtime/ndarray.h>
10+
11+
#include <string>
12+
13+
namespace mlc {
14+
namespace llm {
15+
16+
/*!
17+
* \brief Calculate the target height and width for resizing an image based on the input data and model type.
18+
* \param image_data The input image data as a TVM NDArray.
19+
* \param model_type The type of the model influencing the resizing parameters (e.g., phi3v).
20+
* \param target_height Reference to the variable where the calculated target height will be stored.
21+
* \param target_width Reference to the variable where the calculated target width will be stored.
22+
*/
23+
void CalculateResizeShape(tvm::runtime::NDArray image_data, std::string model_type,
24+
int& target_height, int& target_width);
25+
/*!
26+
* \brief Calculate the padding height and width for an image based on the input data and model type.
27+
* \param image_data The input image data as a TVM NDArray.
28+
* \param model_type The type of the model influencing the padding parameters (e.g., phi3v).
29+
* \param pad_height Reference to the variable where the calculated padding height will be stored.
30+
* \param pad_width Reference to the variable where the calculated padding width will be stored.
31+
*/
32+
void CalculatePadShape(tvm::runtime::NDArray image_data, std::string model_type, int& pad_height,
33+
int& pad_width);
34+
35+
/*!
36+
* \brief Calculate the cropping height and width for an image based on the input data and model type.
37+
* \param image_data The input image data as a TVM NDArray.
38+
* \param model_type The type of the model influencing the cropping parameters (e.g., phi3v).
39+
* \param crop_height Reference to the variable where the calculated cropping height will be stored.
40+
* \param crop_width Reference to the variable where the calculated cropping width will be stored.
41+
*/
42+
void CalculateCropShape(tvm::runtime::NDArray image_data, std::string model_type, int& crop_height,
43+
int& crop_width);
44+
45+
} // namespace llm
46+
} // namespace mlc
47+
48+
#endif // MLC_LLM_SUPPORT_IMAGE_UTILS_H_

python/mlc_llm/compiler_pass/low_batch_specialization.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ def transform_module(
4040
symbolic_vars = set(
4141
expr for shape in shapes for expr in shape if isinstance(expr, tir.Var)
4242
)
43-
assert len(symbolic_vars) == 1, symbolic_vars
43+
if len(symbolic_vars) != 1:
44+
continue
4445
gemm_mod = IRModule({})
4546
gemm_mod["main"] = func
4647
gemm_mod = dl.ApplyDefaultSchedule(

0 commit comments

Comments
 (0)