Skip to content

Commit

Permalink
Replace mutable references with pointers for vlm util functions
Browse files Browse the repository at this point in the history
  • Loading branch information
mengshyu committed Jan 5, 2025
1 parent 868092b commit 950190a
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 20 deletions.
4 changes: 2 additions & 2 deletions cpp/serve/model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,11 @@ class ModelImpl : public ModelObj {
CHECK(ft_.image_embed_func_.defined()) << "`image_embed` function is not found in the model. ";

int tmp_h = 0, tmp_w = 0;
CalculateResizeShape(image, this->model_type_, tmp_h, tmp_w);
CalculateResizeShape(image, this->model_type_, &tmp_h, &tmp_w);
ShapeTuple resize_h = {tmp_h};
ShapeTuple resize_w = {tmp_w};

CalculateCropShape(image, this->model_type_, tmp_h, tmp_w);
CalculateCropShape(image, this->model_type_, &tmp_h, &tmp_w);
ShapeTuple crop_h = {tmp_h};
ShapeTuple crop_w = {tmp_w};

Expand Down
26 changes: 13 additions & 13 deletions cpp/support/vlm_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ namespace mlc {
namespace llm {

void CalculateResizeShape(tvm::runtime::NDArray image_data, std::string model_type,
int& target_height, int& target_width) {
int* p_target_height, int* p_target_width) {
ICHECK_EQ(image_data->shape[3], 3) << "Image format must be NHWC";
int height = image_data->shape[1];
int width = image_data->shape[2];
Expand All @@ -23,34 +23,34 @@ void CalculateResizeShape(tvm::runtime::NDArray image_data, std::string model_ty
scale += 1;
}
scale -= 1;
target_width = static_cast<int>(scale * 336);
target_height = static_cast<int>(target_width / ratio);
*p_target_width = static_cast<int>(scale * 336);
*p_target_height = static_cast<int>(*p_target_width / ratio);
}
}

void CalculatePadShape(tvm::runtime::NDArray image_data, std::string model_type, int& pad_height,
int& pad_width) {
void CalculatePadShape(tvm::runtime::NDArray image_data, std::string model_type, int* p_pad_height,
int* p_pad_width) {
ICHECK_EQ(image_data->shape[3], 3) << "Image format must be NHWC";
if ("phi3_v" == model_type) {
int resized_height = 0, resized_width = 0;
CalculateResizeShape(image_data, model_type, resized_height, resized_width);
CalculateResizeShape(image_data, model_type, &resized_height, &resized_width);
int tar = (int)(ceil(resized_height / 336.0) * 336);
int top_padding = (int)((tar - resized_height) / 2);
int bottom_padding = tar - resized_height - top_padding;
ICHECK_EQ(tar, resized_height + top_padding + bottom_padding) << "Padding size not equal!";
pad_height = tar;
pad_width = resized_width;
*p_pad_height = tar;
*p_pad_width = resized_width;
}
}

void CalculateCropShape(tvm::runtime::NDArray image_data, std::string model_type, int& crop_height,
int& crop_width) {
void CalculateCropShape(tvm::runtime::NDArray image_data, std::string model_type,
int* p_crop_height, int* p_crop_width) {
ICHECK_EQ(image_data->shape[3], 3) << "Image format must be NHWC";
if ("phi3_v" == model_type) {
int pad_h = 0, pad_w = 0;
CalculatePadShape(image_data, model_type, pad_h, pad_w);
crop_height = pad_h / 336;
crop_width = pad_w / 336;
CalculatePadShape(image_data, model_type, &pad_h, &pad_w);
*p_crop_height = pad_h / 336;
*p_crop_width = pad_w / 336;
}
}

Expand Down
10 changes: 5 additions & 5 deletions cpp/support/vlm_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,16 @@ namespace llm {
* the variable where the calculated target width will be stored.
*/
void CalculateResizeShape(tvm::runtime::NDArray image_data, std::string model_type,
int& target_height, int& target_width);
int* p_target_height, int* p_target_width);
/*!
* \brief Calculate the padding height and width for an image based on the input data and model
* type. \param image_data The input image data as a TVM NDArray. \param model_type The type of the
* model influencing the padding parameters (e.g., phi3v). \param pad_height Reference to the
* variable where the calculated padding height will be stored. \param pad_width Reference to the
* variable where the calculated padding width will be stored.
*/
void CalculatePadShape(tvm::runtime::NDArray image_data, std::string model_type, int& pad_height,
int& pad_width);
void CalculatePadShape(tvm::runtime::NDArray image_data, std::string model_type, int* p_pad_height,
int* p_pad_width);

/*!
* \brief Calculate the cropping height and width for an image based on the input data and model
Expand All @@ -39,8 +39,8 @@ void CalculatePadShape(tvm::runtime::NDArray image_data, std::string model_type,
* variable where the calculated cropping height will be stored. \param crop_width Reference to the
* variable where the calculated cropping width will be stored.
*/
void CalculateCropShape(tvm::runtime::NDArray image_data, std::string model_type, int& crop_height,
int& crop_width);
void CalculateCropShape(tvm::runtime::NDArray image_data, std::string model_type,
int* p_crop_height, int* p_crop_width);

} // namespace llm
} // namespace mlc
Expand Down

0 comments on commit 950190a

Please sign in to comment.