[Model] Fix cached buffer allocation and fix dynamic shape in Phi3v (#3082)

This PR fixes the cached buffer allocation so that the cached image buffer is re-allocated when a new image is larger than the currently cached one, and adds support for dynamic input shapes in the vision encoder.
mengshyu authored Jan 5, 2025
1 parent 8243b2b commit 6059763
Showing 9 changed files with 434 additions and 75 deletions.
7 changes: 7 additions & 0 deletions cpp/serve/function_table.cc
@@ -321,6 +321,13 @@ ObjectRef FunctionTable::CopyToWorker0(const NDArray& host_array, String buffer_
  NDArray buffer{nullptr};
  if (it != this->cached_buffers.end()) {
    buffer = Downcast<NDArray>((*it).second);
    if (buffer_cache_key == "image") {
      if (runtime::GetDataSize(*buffer.operator->()) <
          runtime::GetDataSize(*host_array.operator->())) {
        buffer = NDArray::Empty(max_reserved_shape, host_array->dtype, local_gpu_device);
        this->cached_buffers.Set(buffer_cache_key, buffer);
      }
    }
  } else {
    buffer = NDArray::Empty(max_reserved_shape, host_array->dtype, local_gpu_device);
    this->cached_buffers.Set(buffer_cache_key, buffer);
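For context, the hunk above only re-allocates the cached "image" buffer when the incoming host array is larger than what is currently cached; otherwise the existing buffer is reused. Below is a minimal standalone sketch of that grow-on-demand policy, with a std::unordered_map of std::vector<uint8_t> standing in for the NDArray cache (the map, helper name, and byte counts are illustrative and not the mlc-llm API):

#include <cstddef>
#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

// Illustrative stand-in for FunctionTable::cached_buffers.
static std::unordered_map<std::string, std::vector<uint8_t>> cached_buffers;

// Return the cached buffer for `key`, growing it to `reserved_bytes` only when the
// incoming payload (`needed_bytes`) no longer fits, mirroring the "image" branch above.
std::vector<uint8_t>& GetOrGrowBuffer(const std::string& key, size_t needed_bytes,
                                      size_t reserved_bytes) {
  auto it = cached_buffers.find(key);
  if (it != cached_buffers.end()) {
    if (key == "image" && it->second.size() < needed_bytes) {
      it->second.assign(reserved_bytes, 0);  // new image is larger: re-allocate at reserved size
    }
    return it->second;
  }
  auto& buffer = cached_buffers[key];
  buffer.assign(reserved_bytes, 0);  // first request for this key: allocate reserved size
  return buffer;
}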
16 changes: 15 additions & 1 deletion cpp/serve/model.cc
@@ -13,6 +13,7 @@
#include <fstream>

#include "../support/json_parser.h"
#include "../support/vlm_utils.h"
#include "config.h"
#include "logit_processor.h"

@@ -113,8 +114,19 @@ class ModelImpl : public ModelObj {
  ObjectRef ImageEmbed(const NDArray& image, ObjectRef* dst, int offset) final {
    NVTXScopedRange nvtx_scope("ImageEmbed");
    CHECK(ft_.image_embed_func_.defined()) << "`image_embed` function is not found in the model. ";

    int tmp_h = 0, tmp_w = 0;
    CalculateResizeShape(image, this->model_type_, &tmp_h, &tmp_w);
    ShapeTuple resize_h = {tmp_h};
    ShapeTuple resize_w = {tmp_w};

    CalculateCropShape(image, this->model_type_, &tmp_h, &tmp_w);
    ShapeTuple crop_h = {tmp_h};
    ShapeTuple crop_w = {tmp_w};

    auto image_dref_or_nd = ft_.CopyToWorker0(image, "image", image.Shape());
    ObjectRef embeddings = ft_.image_embed_func_(image_dref_or_nd, params_);
    ObjectRef embeddings =
        ft_.image_embed_func_(image_dref_or_nd, resize_h, resize_w, crop_h, crop_w, params_);
    if (dst != nullptr) {
      CHECK(dst->defined());
      ft_.nd_copy_embedding_to_offset_func_(embeddings, *dst, offset);
@@ -1003,6 +1015,7 @@ class ModelImpl : public ModelObj {
        json::LookupOrDefault<int64_t>(config, "attention_sink_size", this->attention_sink_size_);
    this->attention_sink_size_ = std::max(this->attention_sink_size_, 0);
    this->vocab_size_ = json::Lookup<int64_t>(config, "vocab_size");
    this->model_type_ = json::Lookup<std::string>(config, "model_type");
  }

  //----------------------------
@@ -1019,6 +1032,7 @@ class ModelImpl : public ModelObj {
  DLDataType hidden_states_dtype_;
  int vocab_size_ = -1;
  int image_embed_size_ = -1;
  std::string model_type_;
  //----------------------------
  // TVM related states
  //----------------------------
58 changes: 58 additions & 0 deletions cpp/support/vlm_utils.cc
@@ -0,0 +1,58 @@
/*!
* Copyright (c) 2023-2024 by Contributors
 * \file support/vlm_utils.cc
*/
#include "vlm_utils.h"

#include <cmath>

namespace mlc {
namespace llm {

void CalculateResizeShape(tvm::runtime::NDArray image_data, std::string model_type,
                          int* p_target_height, int* p_target_width) {
  ICHECK_EQ(image_data->shape[3], 3) << "Image format must be NHWC";
  int height = image_data->shape[1];
  int width = image_data->shape[2];

  if ("phi3_v" == model_type) {
    const int hd_num = 4;
    double ratio = static_cast<double>(width) / height;
    int scale = 1;
    while (scale * std::ceil(scale / ratio) <= hd_num) {
      scale += 1;
    }
    scale -= 1;
    *p_target_width = static_cast<int>(scale * 336);
    *p_target_height = static_cast<int>(*p_target_width / ratio);
  }
}

void CalculatePadShape(tvm::runtime::NDArray image_data, std::string model_type, int* p_pad_height,
                       int* p_pad_width) {
  ICHECK_EQ(image_data->shape[3], 3) << "Image format must be NHWC";
  if ("phi3_v" == model_type) {
    int resized_height = 0, resized_width = 0;
    CalculateResizeShape(image_data, model_type, &resized_height, &resized_width);
    int tar = (int)(ceil(resized_height / 336.0) * 336);
    int top_padding = (int)((tar - resized_height) / 2);
    int bottom_padding = tar - resized_height - top_padding;
    ICHECK_EQ(tar, resized_height + top_padding + bottom_padding) << "Padding size not equal!";
    *p_pad_height = tar;
    *p_pad_width = resized_width;
  }
}

void CalculateCropShape(tvm::runtime::NDArray image_data, std::string model_type,
                        int* p_crop_height, int* p_crop_width) {
  ICHECK_EQ(image_data->shape[3], 3) << "Image format must be NHWC";
  if ("phi3_v" == model_type) {
    int pad_h = 0, pad_w = 0;
    CalculatePadShape(image_data, model_type, &pad_h, &pad_w);
    *p_crop_height = pad_h / 336;
    *p_crop_width = pad_w / 336;
  }
}

} // namespace llm
} // namespace mlc
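To make the arithmetic concrete, here is a standalone walkthrough of the phi3_v shape pipeline for a hypothetical 540x1080 (height x width) input. It re-implements the same formulas as above without TVM so it compiles in isolation; the input size and the values noted in the comments are purely illustrative:

#include <cmath>
#include <cstdio>

int main() {
  // Hypothetical phi3_v input: 540 (H) x 1080 (W), NHWC with N = 1 and C = 3.
  const int height = 540, width = 1080, hd_num = 4;
  double ratio = static_cast<double>(width) / height;  // 2.0

  // Resize: the largest `scale` with scale * ceil(scale / ratio) <= hd_num, here 2.
  int scale = 1;
  while (scale * std::ceil(scale / ratio) <= hd_num) scale += 1;
  scale -= 1;
  int resize_w = scale * 336;                         // 672
  int resize_h = static_cast<int>(resize_w / ratio);  // 336

  // Pad the resized height up to the next multiple of 336.
  int pad_h = static_cast<int>(std::ceil(resize_h / 336.0) * 336);  // 336
  int pad_w = resize_w;                                             // 672

  // Crop into a grid of 336 x 336 tiles.
  int crop_h = pad_h / 336;  // 1
  int crop_w = pad_w / 336;  // 2

  std::printf("resize %dx%d, pad %dx%d, crop grid %dx%d\n",
              resize_h, resize_w, pad_h, pad_w, crop_h, crop_w);
  return 0;
}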
48 changes: 48 additions & 0 deletions cpp/support/vlm_utils.h
@@ -0,0 +1,48 @@
/*!
* Copyright (c) 2023-2024 by Contributors
* \file support/vlm_utils.h
 * \brief Utility functions for vision-language models (image resize/pad/crop shape calculation).
*/
#ifndef MLC_LLM_SUPPORT_VLM_UTILS_H_
#define MLC_LLM_SUPPORT_VLM_UTILS_H_

#include <tvm/runtime/ndarray.h>

#include <string>

namespace mlc {
namespace llm {

/*!
 * \brief Calculate the target height and width for resizing an image based on the input data and
 * model type.
 * \param image_data The input image data as a TVM NDArray.
 * \param model_type The type of the model influencing the resizing parameters (e.g., "phi3_v").
 * \param p_target_height Pointer to the variable where the calculated target height will be stored.
 * \param p_target_width Pointer to the variable where the calculated target width will be stored.
 */
void CalculateResizeShape(tvm::runtime::NDArray image_data, std::string model_type,
                          int* p_target_height, int* p_target_width);
/*!
 * \brief Calculate the padding height and width for an image based on the input data and model
 * type.
 * \param image_data The input image data as a TVM NDArray.
 * \param model_type The type of the model influencing the padding parameters (e.g., "phi3_v").
 * \param p_pad_height Pointer to the variable where the calculated padding height will be stored.
 * \param p_pad_width Pointer to the variable where the calculated padding width will be stored.
 */
void CalculatePadShape(tvm::runtime::NDArray image_data, std::string model_type, int* p_pad_height,
                       int* p_pad_width);

/*!
 * \brief Calculate the cropping height and width for an image based on the input data and model
 * type.
 * \param image_data The input image data as a TVM NDArray.
 * \param model_type The type of the model influencing the cropping parameters (e.g., "phi3_v").
 * \param p_crop_height Pointer to the variable where the calculated cropping height will be stored.
 * \param p_crop_width Pointer to the variable where the calculated cropping width will be stored.
 */
void CalculateCropShape(tvm::runtime::NDArray image_data, std::string model_type,
                        int* p_crop_height, int* p_crop_width);

} // namespace llm
} // namespace mlc

#endif  // MLC_LLM_SUPPORT_VLM_UTILS_H_
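A brief usage sketch of the declarations above, assuming it is compiled inside the mlc-llm tree (the include path, the helper function name, and the shape-only empty tensor are assumptions for illustration; the real call site is ImageEmbed in model.cc):

#include <tvm/runtime/ndarray.h>

#include "support/vlm_utils.h"  // assumed include path within the mlc-llm source tree

void ExamplePhi3vShapes() {
  // Only the NHWC shape is read by these helpers, so an empty CPU tensor suffices here.
  tvm::runtime::NDArray image = tvm::runtime::NDArray::Empty(
      {1, 540, 1080, 3}, DLDataType{kDLUInt, 8, 1}, DLDevice{kDLCPU, 0});
  int resize_h = 0, resize_w = 0, crop_h = 0, crop_w = 0;
  mlc::llm::CalculateResizeShape(image, "phi3_v", &resize_h, &resize_w);  // 336 x 672
  mlc::llm::CalculateCropShape(image, "phi3_v", &crop_h, &crop_w);        // 1 x 2 tiles
}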
3 changes: 2 additions & 1 deletion python/mlc_llm/compiler_pass/low_batch_specialization.py
@@ -40,7 +40,8 @@ def transform_module(
symbolic_vars = set(
    expr for shape in shapes for expr in shape if isinstance(expr, tir.Var)
)
assert len(symbolic_vars) == 1, symbolic_vars
if len(symbolic_vars) != 1:
    continue
gemm_mod = IRModule({})
gemm_mod["main"] = func
gemm_mod = dl.ApplyDefaultSchedule(