Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Model] Fix cached buffer allocation and fix dynamic shape in Phi3v #3082

Merged
merged 2 commits into from
Jan 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions cpp/serve/function_table.cc
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,13 @@ ObjectRef FunctionTable::CopyToWorker0(const NDArray& host_array, String buffer_
NDArray buffer{nullptr};
if (it != this->cached_buffers.end()) {
buffer = Downcast<NDArray>((*it).second);
if (buffer_cache_key == "image") {
if (runtime::GetDataSize(*buffer.operator->()) <
runtime::GetDataSize(*host_array.operator->())) {
buffer = NDArray::Empty(max_reserved_shape, host_array->dtype, local_gpu_device);
this->cached_buffers.Set(buffer_cache_key, buffer);
}
}
} else {
buffer = NDArray::Empty(max_reserved_shape, host_array->dtype, local_gpu_device);
this->cached_buffers.Set(buffer_cache_key, buffer);
Expand Down
16 changes: 15 additions & 1 deletion cpp/serve/model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <fstream>

#include "../support/json_parser.h"
#include "../support/vlm_utils.h"
#include "config.h"
#include "logit_processor.h"

Expand Down Expand Up @@ -113,8 +114,19 @@ class ModelImpl : public ModelObj {
ObjectRef ImageEmbed(const NDArray& image, ObjectRef* dst, int offset) final {
NVTXScopedRange nvtx_scope("ImageEmbed");
CHECK(ft_.image_embed_func_.defined()) << "`image_embed` function is not found in the model. ";

int tmp_h = 0, tmp_w = 0;
CalculateResizeShape(image, this->model_type_, &tmp_h, &tmp_w);
ShapeTuple resize_h = {tmp_h};
ShapeTuple resize_w = {tmp_w};

CalculateCropShape(image, this->model_type_, &tmp_h, &tmp_w);
ShapeTuple crop_h = {tmp_h};
ShapeTuple crop_w = {tmp_w};

auto image_dref_or_nd = ft_.CopyToWorker0(image, "image", image.Shape());
ObjectRef embeddings = ft_.image_embed_func_(image_dref_or_nd, params_);
ObjectRef embeddings =
ft_.image_embed_func_(image_dref_or_nd, resize_h, resize_w, crop_h, crop_w, params_);
if (dst != nullptr) {
CHECK(dst->defined());
ft_.nd_copy_embedding_to_offset_func_(embeddings, *dst, offset);
Expand Down Expand Up @@ -1003,6 +1015,7 @@ class ModelImpl : public ModelObj {
json::LookupOrDefault<int64_t>(config, "attention_sink_size", this->attention_sink_size_);
this->attention_sink_size_ = std::max(this->attention_sink_size_, 0);
this->vocab_size_ = json::Lookup<int64_t>(config, "vocab_size");
this->model_type_ = json::Lookup<std::string>(config, "model_type");
}

//----------------------------
Expand All @@ -1019,6 +1032,7 @@ class ModelImpl : public ModelObj {
DLDataType hidden_states_dtype_;
int vocab_size_ = -1;
int image_embed_size_ = -1;
std::string model_type_;
//----------------------------
// TVM related states
//----------------------------
Expand Down
58 changes: 58 additions & 0 deletions cpp/support/vlm_utils.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*!
* Copyright (c) 2023-2024 by Contributors
* \file support/image_utils.cc
*/
#include "vlm_utils.h"

#include <cmath>

namespace mlc {
namespace llm {

void CalculateResizeShape(tvm::runtime::NDArray image_data, std::string model_type,
int* p_target_height, int* p_target_width) {
ICHECK_EQ(image_data->shape[3], 3) << "Image format must be NHWC";
int height = image_data->shape[1];
int width = image_data->shape[2];

if ("phi3_v" == model_type) {
const int hd_num = 4;
double ratio = static_cast<double>(width) / height;
int scale = 1;
while (scale * std::ceil(scale / ratio) <= hd_num) {
scale += 1;
}
scale -= 1;
*p_target_width = static_cast<int>(scale * 336);
*p_target_height = static_cast<int>(*p_target_width / ratio);
}
}

void CalculatePadShape(tvm::runtime::NDArray image_data, std::string model_type, int* p_pad_height,
int* p_pad_width) {
ICHECK_EQ(image_data->shape[3], 3) << "Image format must be NHWC";
if ("phi3_v" == model_type) {
int resized_height = 0, resized_width = 0;
CalculateResizeShape(image_data, model_type, &resized_height, &resized_width);
int tar = (int)(ceil(resized_height / 336.0) * 336);
int top_padding = (int)((tar - resized_height) / 2);
int bottom_padding = tar - resized_height - top_padding;
ICHECK_EQ(tar, resized_height + top_padding + bottom_padding) << "Padding size not equal!";
*p_pad_height = tar;
*p_pad_width = resized_width;
}
}

void CalculateCropShape(tvm::runtime::NDArray image_data, std::string model_type,
int* p_crop_height, int* p_crop_width) {
ICHECK_EQ(image_data->shape[3], 3) << "Image format must be NHWC";
if ("phi3_v" == model_type) {
int pad_h = 0, pad_w = 0;
CalculatePadShape(image_data, model_type, &pad_h, &pad_w);
*p_crop_height = pad_h / 336;
*p_crop_width = pad_w / 336;
}
}

} // namespace llm
} // namespace mlc
48 changes: 48 additions & 0 deletions cpp/support/vlm_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*!
* Copyright (c) 2023-2024 by Contributors
* \file support/vlm_utils.h
* \brief Tools for debug purposes.
*/
#ifndef MLC_LLM_SUPPORT_VLM_UTILS_H_
#define MLC_LLM_SUPPORT_VLM_UTILS_H_

#include <tvm/runtime/ndarray.h>

#include <string>

namespace mlc {
namespace llm {

/*!
* \brief Calculate the target height and width for resizing an image based on the input data and
* model type. \param image_data The input image data as a TVM NDArray. \param model_type The type
* of the model influencing the resizing parameters (e.g., phi3v). \param target_height Reference to
* the variable where the calculated target height will be stored. \param target_width Reference to
* the variable where the calculated target width will be stored.
*/
void CalculateResizeShape(tvm::runtime::NDArray image_data, std::string model_type,
int* p_target_height, int* p_target_width);
/*!
* \brief Calculate the padding height and width for an image based on the input data and model
* type. \param image_data The input image data as a TVM NDArray. \param model_type The type of the
* model influencing the padding parameters (e.g., phi3v). \param pad_height Reference to the
* variable where the calculated padding height will be stored. \param pad_width Reference to the
* variable where the calculated padding width will be stored.
*/
void CalculatePadShape(tvm::runtime::NDArray image_data, std::string model_type, int* p_pad_height,
int* p_pad_width);

/*!
* \brief Calculate the cropping height and width for an image based on the input data and model
* type. \param image_data The input image data as a TVM NDArray. \param model_type The type of the
* model influencing the cropping parameters (e.g., phi3v). \param crop_height Reference to the
* variable where the calculated cropping height will be stored. \param crop_width Reference to the
* variable where the calculated cropping width will be stored.
*/
void CalculateCropShape(tvm::runtime::NDArray image_data, std::string model_type,
int* p_crop_height, int* p_crop_width);

} // namespace llm
} // namespace mlc

#endif // MLC_LLM_SUPPORT_IMAGE_UTILS_H_
3 changes: 2 additions & 1 deletion python/mlc_llm/compiler_pass/low_batch_specialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ def transform_module(
symbolic_vars = set(
expr for shape in shapes for expr in shape if isinstance(expr, tir.Var)
)
assert len(symbolic_vars) == 1, symbolic_vars
if len(symbolic_vars) != 1:
continue
gemm_mod = IRModule({})
gemm_mod["main"] = func
gemm_mod = dl.ApplyDefaultSchedule(
Expand Down
Loading
Loading