From 04b7bd991d48224bd64bcf35afd5141238cbfb05 Mon Sep 17 00:00:00 2001
From: "wangyunlong.115"
Date: Wed, 29 Oct 2025 16:11:39 +0800
Subject: [PATCH] feat: support Qwen3-VL.

---
 CMakeLists.txt                                |   6 +-
 third_party/xllm_ops                          |   2 +-
 xllm/core/framework/hf_model_loader.cpp       |   6 +
 xllm/core/framework/model/model_args.h        |  13 +
 .../core/framework/model/model_input_params.h |   8 +
 xllm/core/framework/quant_args.h              |   2 +-
 xllm/core/layers/CMakeLists.txt               |   1 +
 xllm/core/layers/base_layer.cpp               |   0
 xllm/core/layers/npu/CMakeLists.txt           |   2 +
 .../npu/npu_qwen3_moe_decoder_layer_impl.cpp  |   6 +-
 .../npu_qwen3_vision_encoder_layer_impl.cpp   | 285 +++
 .../npu/npu_qwen3_vision_encoder_layer_impl.h | 125 +++
 xllm/core/layers/qwen3_vision_encode_layer.h  |  39 +
 xllm/core/runtime/vlm_engine.cpp              |   0
 xllm/models/llm/llm_model_base.h              |  14 +-
 xllm/models/llm/qwen3.h                       | 167 ++++
 xllm/models/llm/qwen3_moe.h                   | 119 ++-
 xllm/models/models.h                          |   1 +
 xllm/models/vlm/qwen3_vl.h                    | 800 ++++++
 xllm/processors/CMakeLists.txt                |   0
 xllm/processors/qwen2_vl_image_processor.cpp  |  17 +-
 21 files changed, 1580 insertions(+), 33 deletions(-)
 mode change 100644 => 100755 CMakeLists.txt
 mode change 100644 => 100755 xllm/core/framework/model/model_args.h
 mode change 100644 => 100755 xllm/core/framework/model/model_input_params.h
 mode change 100644 => 100755 xllm/core/layers/base_layer.cpp
 mode change 100644 => 100755 xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.cpp
 create mode 100755 xllm/core/layers/npu/npu_qwen3_vision_encoder_layer_impl.cpp
 create mode 100755 xllm/core/layers/npu/npu_qwen3_vision_encoder_layer_impl.h
 create mode 100644 xllm/core/layers/qwen3_vision_encode_layer.h
 mode change 100644 => 100755 xllm/core/runtime/vlm_engine.cpp
 mode change 100644 => 100755 xllm/models/llm/qwen3.h
 mode change 100644 => 100755 xllm/models/models.h
 create mode 100755 xllm/models/vlm/qwen3_vl.h
 mode change 100644 => 100755 xllm/processors/CMakeLists.txt
 mode change 100755 => 100644 xllm/processors/qwen2_vl_image_processor.cpp

diff --git a/CMakeLists.txt b/CMakeLists.txt
old mode 100644
new mode 100755
index c7765ee7..3514372d
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -28,20 +28,20 @@ if(USE_NPU)
   if(DEVICE_TYPE STREQUAL "USE_A3")
     message("downloading a3 arm xllm kernels")
     file(DOWNLOAD
-      "https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.6.0/xllm_kernels-1.3.1-Linux.a3.arm.rpm"
+      "https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.6.0/xllm_kernels-1.3.2-Linux.a3.arm.rpm"
       "${CMAKE_BINARY_DIR}/xllm_kernels.rpm"
     )
   else()
     if(DEVICE_ARCH STREQUAL "ARM")
       message("downloading a2 arm xllm_kernels")
       file(DOWNLOAD
-        "https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.6.0/xllm_kernels-1.3.1-Linux.a2.arm.rpm"
+        "https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.6.0/xllm_kernels-1.3.2-Linux.a2.arm.rpm"
         "${CMAKE_BINARY_DIR}/xllm_kernels.rpm"
       )
     else()
       message("downloading a2 x86 xllm_kernels")
       file(DOWNLOAD
-        "https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.6.0/xllm_kernels-1.3.1-Linux.a2.x86.rpm"
+        "https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.6.0/xllm_kernels-1.3.2-Linux.a2.x86.rpm"
         "${CMAKE_BINARY_DIR}/xllm_kernels.rpm"
       )
     endif()
diff --git a/third_party/xllm_ops b/third_party/xllm_ops
index 2cda9bfb..797a0cb1 160000
--- a/third_party/xllm_ops
+++ b/third_party/xllm_ops
@@ -1 +1 @@
-Subproject commit 2cda9bfbf2fd827972591137a411c0ab79b644d3
+Subproject commit 797a0cb195d33edbc3033744f5ca6a36981a8a3f
diff --git a/xllm/core/framework/hf_model_loader.cpp b/xllm/core/framework/hf_model_loader.cpp
index 443ba388..a01fc3c5 100644
--- a/xllm/core/framework/hf_model_loader.cpp
+++ b/xllm/core/framework/hf_model_loader.cpp
@@ -360,6 +360,12 @@ bool HFModelLoader::load_image_preprocessor_args(
         image_prerocess_data["norm_std"].get<std::vector<double>>();
   }
 
+  args_.mm_image_shortest_edge() =
+      image_preprocess_reader.value_or("size.shortest_edge", 0);
+
+  args_.mm_image_longest_edge() =
+      image_preprocess_reader.value_or("size.longest_edge", 0);
+
   args_.mm_image_min_pixels() =
       image_preprocess_reader.value_or("min_pixels", 0);
diff --git a/xllm/core/framework/model/model_args.h b/xllm/core/framework/model/model_args.h
old mode 100644
new mode 100755
index 8bb49591..ff67b22b
--- a/xllm/core/framework/model/model_args.h
+++ b/xllm/core/framework/model/model_args.h
@@ -242,12 +242,15 @@ struct ModelArgs {
   PROPERTY(int, mm_window_size) = 0;
   PROPERTY(std::vector<int>, mm_fullatt_block_indexes);
+  PROPERTY(std::vector<int>, mm_deepstack_visual_indexes);
   PROPERTY(int, mm_tokens_per_second) = 0;
   PROPERTY(int, mm_temporal_patch_size) = 0;
 
   // VLM model projector's mm_projector_type
   PROPERTY(std::string, mm_projector_type);
 
+  // VLM vision model's number of position embeddings
+  PROPERTY(int64_t, mm_num_position_embeddings);
   // VLM model projector's mm_projector_hidden_act
   PROPERTY(std::string, mm_projector_hidden_act);
@@ -284,6 +287,9 @@ struct ModelArgs {
   PROPERTY(int, mm_image_min_pixels) = 0;
   PROPERTY(int, mm_image_max_pixels) = 0;
 
+  PROPERTY(int64_t, mm_image_shortest_edge) = 0;
+  PROPERTY(int64_t, mm_image_longest_edge) = 0;
+
   PROPERTY(int, mm_image_patch_size) = 0;
   PROPERTY(int, mm_image_temporal_patch_size) = 0;
   PROPERTY(int, mm_image_merge_size) = 0;
@@ -447,6 +453,11 @@ inline std::ostream& operator<<(std::ostream& os, const ModelArgs& args) {
     os << index << ",";
   }
   os << "]";
+  os << ", mm_deepstack_visual_indexes: [";
+  for (auto& index : args.mm_deepstack_visual_indexes()) {
+    os << index << ",";
+  }
+  os << "]";
   os << ", mm_tokens_per_second: " << args.mm_tokens_per_second();
   os << ", mm_temporal_patch_size: " << args.mm_temporal_patch_size();
   os << ", mm_projector_type: " << args.mm_projector_type();
@@ -474,6 +485,8 @@ inline std::ostream& operator<<(std::ostream& os, const ModelArgs& args) {
     os << std << ", ";
   }
   os << "]";
+  os << ", mm_image_shortest_edge: " << args.mm_image_shortest_edge();
+  os << ", mm_image_longest_edge: " << args.mm_image_longest_edge();
   os << ", mm_image_min_pixels: " << args.mm_image_min_pixels();
   os << ", mm_image_max_pixels: " << args.mm_image_max_pixels();
   os << ", mm_image_patch_size: " << args.mm_image_patch_size();
diff --git a/xllm/core/framework/model/model_input_params.h b/xllm/core/framework/model/model_input_params.h
old mode 100644
new mode 100755
index aaaae36d..6669baaa
--- a/xllm/core/framework/model/model_input_params.h
+++ b/xllm/core/framework/model/model_input_params.h
@@ -67,6 +67,9 @@ struct ModelInputParams {
     params.input_embedding = safe_to(input_embedding, device);
 
+    params.deep_stacks = deep_stacks;
+    params.visual_pos_masks = visual_pos_masks;
+
     params.mm_data = MMData::to(mm_data, device);
     params.dp_global_token_nums = dp_global_token_nums;
     params.prefill_seq_len = prefill_seq_len;
@@ -149,6 +152,11 @@ struct ModelInputParams {
   // multimodal
   MMData mm_data;
 
+  // deep_stack for Qwen3-VL
+  mutable std::vector<torch::Tensor> deep_stacks;
+  // visual pos mask for Qwen3-VL
+  mutable torch::Tensor visual_pos_masks;
+
   // num tokens of all workers,mainly used for dp case
   std::vector<int32_t> dp_global_token_nums;
   // whether the kv-cache is empty for all sequences,mainly used for dp case
diff --git a/xllm/core/framework/quant_args.h b/xllm/core/framework/quant_args.h
index efd41e5d..ec1506ee 100644
--- a/xllm/core/framework/quant_args.h
+++ b/xllm/core/framework/quant_args.h
@@ -27,7 +27,7 @@ struct QuantArgs {
   PROPERTY(std::string, quant_method);
 
   PROPERTY(std::string, quantize_type);
 
-  PROPERTY(std::string, torch_dtype);
+  PROPERTY(std::string, torch_dtype) = "bfloat16";
 
   // quantization bits
   PROPERTY(int64_t, bits) = 0;
diff --git a/xllm/core/layers/CMakeLists.txt b/xllm/core/layers/CMakeLists.txt
index 6ad3d0c7..53b987ac 100644
--- a/xllm/core/layers/CMakeLists.txt
+++ b/xllm/core/layers/CMakeLists.txt
@@ -53,6 +53,7 @@ cc_library(
     multi_head_attention.h
     qwen2_decoder_layer.h
     qwen2dot5_vision_decode_layer.h
+    qwen3_vision_encode_layer.h
     qwen3_decoder_layer.h
     qwen3_moe_decoder_layer.h
     rms_norm.h
diff --git a/xllm/core/layers/base_layer.cpp b/xllm/core/layers/base_layer.cpp
old mode 100644
new mode 100755
diff --git a/xllm/core/layers/npu/CMakeLists.txt b/xllm/core/layers/npu/CMakeLists.txt
index ba662659..61f7759d 100644
--- a/xllm/core/layers/npu/CMakeLists.txt
+++ b/xllm/core/layers/npu/CMakeLists.txt
@@ -10,6 +10,7 @@ cc_library(
     npu_pos_embedding_impl.h
     npu_lm_head_impl.h
     npu_qwen2dot5_vision_encoder_layer_impl.h
+    npu_qwen3_vision_encoder_layer_impl.h
     npu_qwen3_moe_decoder_layer_impl.h
     # atb_parallel_linear.h
     npu_block_copy_impl.h
@@ -29,6 +30,7 @@ cc_library(
     npu_pos_embedding_impl.cpp
     npu_lm_head_impl.cpp
     npu_qwen2dot5_vision_encoder_layer_impl.cpp
+    npu_qwen3_vision_encoder_layer_impl.cpp
    npu_qwen3_moe_decoder_layer_impl.cpp
     # atb_parallel_linear.cpp
     npu_block_copy_impl.cpp
diff --git a/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.cpp b/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.cpp
old mode 100644
new mode 100755
index 3aefc3a3..9805c3c5
--- a/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.cpp
+++ b/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.cpp
@@ -376,7 +376,7 @@ void NpuQwen3MoeDecoderLayerImpl::initialize_mlp_parameters(
     const ModelArgs& args,
     const ParallelArgs& parallel_args) {
   param.hasSharedExpert = (args.n_shared_experts() > 0);
-  param.hasSharedExpertGate = true;
+  param.hasSharedExpertGate = false;
   param.processLogits = "normalization";
   param.numOfSelectedExperts = {args.num_experts_per_tok()};
@@ -492,7 +492,6 @@ void NpuQwen3MoeDecoderLayerImpl::process_expert_weights(
   const int local_index = expert_index % num_experts_per_partition_;
   const bool is_sharded = shard_map.count(index);
 
-  std::lock_guard<std::mutex> lock(experts_mutex_);
   torch::Tensor tmp_tensor =
       is_sharded ? get_sharded_tensor(state_dict,
                                       name,
@@ -517,8 +516,6 @@ void NpuQwen3MoeDecoderLayerImpl::process_mlp_common_weights(
   const int index = get_mapped_index(name, weight_mapping);
   const bool is_sharded = shard_map.count(index);
 
-  std::lock_guard<std::mutex> lock(shared_experts_mutex_);
-
   torch::Tensor tmp_tensor =
       is_sharded ? get_sharded_tensor(state_dict,
                                       name,
@@ -650,7 +647,6 @@ void NpuQwen3MoeDecoderLayerImpl::verify_loaded_weights(
 
 void NpuQwen3MoeDecoderLayerImpl::merge_loaded_weights() {
   merge_experts_weights();
-
   at_weight_tensors_[IN_QKV_WEIGHT_0] =
       torch::cat({at_weight_tensors_[IN_QKV_WEIGHT_0],
                   at_weight_tensors_[IN_QKV_WEIGHT_1],
diff --git a/xllm/core/layers/npu/npu_qwen3_vision_encoder_layer_impl.cpp b/xllm/core/layers/npu/npu_qwen3_vision_encoder_layer_impl.cpp
new file mode 100755
index 00000000..3a4bb674
--- /dev/null
+++ b/xllm/core/layers/npu/npu_qwen3_vision_encoder_layer_impl.cpp
@@ -0,0 +1,285 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "npu_qwen3_vision_encoder_layer_impl.h"
+
+#include
+#include
+
+#include
+#include
+
+#include "torch_npu/csrc/core/npu/NPUCachingAllocator.h"
+#include "torch_npu/csrc/core/npu/NPUException.h"
+#include "xllm_kernels/models/qwen3_vl/qwen3_vl_encoder.h"
+
+namespace xllm {
+namespace layer {
+
+enum VisionEncoderLayerTensorId : int {
+  IN_INPUT_NORM_WEIGHT = 0,
+  IN_INPUT_NORM_BIAS,
+  IN_POST_NORM_WEIGHT,
+  IN_POST_NORM_BIAS,
+  IN_QKV_WEIGHT,
+  IN_QKV_BIAS,
+  IN_WATTENTION_OUT_WEIGHT,
+  IN_WATTENTION_OUT_BIAS,
+  IN_LINEAR_FC1_WEIGHT,
+  IN_LINEAR_FC1_BIAS,
+  IN_LINEAR_FC2_WEIGHT,
+  IN_LINEAR_FC2_BIAS,
+  IN_VISION_Q_WEIGHT,
+  IN_VISION_Q_BIAS,
+  IN_VISION_K_WEIGHT,
+  IN_VISION_K_BIAS,
+  IN_VISION_V_WEIGHT,
+  IN_VISION_V_BIAS
+};
+
+const uint64_t WEIGHT_COUNT_PER_LAYER = 18;
+
+static std::vector<std::pair<int, std::string>> WEIGHT_MAPPING = {
+    {IN_INPUT_NORM_WEIGHT, "norm1.weight"},
+    {IN_INPUT_NORM_BIAS, "norm1.bias"},
+    {IN_POST_NORM_WEIGHT, "norm2.weight"},
+    {IN_POST_NORM_BIAS, "norm2.bias"},
+    {IN_QKV_WEIGHT, "attn.qkv.weight"},
+    {IN_QKV_BIAS, "attn.qkv.bias"},
+    {IN_WATTENTION_OUT_WEIGHT, "attn.proj.weight"},
+    {IN_WATTENTION_OUT_BIAS, "attn.proj.bias"},
+    {IN_LINEAR_FC1_WEIGHT, "mlp.linear_fc1.weight"},
+    {IN_LINEAR_FC1_BIAS, "mlp.linear_fc1.bias"},
+    {IN_LINEAR_FC2_WEIGHT, "mlp.linear_fc2.weight"},
+    {IN_LINEAR_FC2_BIAS, "mlp.linear_fc2.bias"}};
+
+// {weight, dim}
+static std::map<int, int> WEIGHT_SHARD = {
+    {IN_WATTENTION_OUT_WEIGHT, 1},
+    {IN_LINEAR_FC1_WEIGHT, 0},
+    {IN_LINEAR_FC1_BIAS, 0},
+    {IN_LINEAR_FC2_WEIGHT, 1},
+};
+
+void NpuQwen3VisionEncoderLayerImpl::param_from_args(
+    atb_speed::qwen::VisionEncoderLayerParam& param,
+    const ModelArgs& args,
+    const ParallelArgs& parallel_args) {
+  param.isBF16 = args.dtype() == "bfloat16";
+  param.rmsNormEps = args.rms_norm_eps();
+  param.worldSize = parallel_args.world_size();
+  param.numAttentionHeadsPerRank =
+      args.mm_num_attention_heads() / param.worldSize;
+  param.hiddenSizePerAttentionHead =
+      args.mm_hidden_size() / args.mm_num_attention_heads();
+  std::optional<int64_t> optionalValue = args.mm_num_attention_heads();
+  param.numKeyValueHeadsPerRank =
+      static_cast<int>(optionalValue.value()) / param.worldSize;
+  param.rank = parallel_args.rank();
+  param.backend = "lccl";
+  param.enableLogN = false;
+}
+
+NpuQwen3VisionEncoderLayerImpl::NpuQwen3VisionEncoderLayerImpl(
+    const ModelContext& context)
+    : NpuBaseLayer(context) {
+  auto model_args = context.get_model_args();
+  auto parallel_args = context.get_parallel_args();
+  auto options = context.get_tensor_options();
+  param_from_args(encode_param_, model_args, parallel_args);
+  at_weight_tensors_.resize(WEIGHT_COUNT_PER_LAYER);
+  atb_weight_tensors_.resize(WEIGHT_COUNT_PER_LAYER);
+  dtype_ = c10::typeMetaToScalarType(options.dtype());
+  device_id_ = options.device().index();
+  placeholder_ = atb_speed::Utils::AtTensor2Tensor(
+      torch::zeros({1}).to(device_).to(dtype_));
+  at_placeholder_ = torch::zeros({1}).to(device_).to(dtype_);
+  for (int i = 0; i < WEIGHT_COUNT_PER_LAYER; ++i) {
+    at_weight_tensors_[i] = torch::zeros({1}).to(options);
+  }
+}
+
+void NpuQwen3VisionEncoderLayerImpl::verify_loaded_weights() const {
+  for (const auto& [index, name] : WEIGHT_MAPPING) {
+    CHECK(at_weight_tensors_[index].sizes() != std::vector<int64_t>({1}))
+        << "weight is not loaded for " << name;
+  }
+}
+
+void NpuQwen3VisionEncoderLayerImpl::merge_loaded_weights() {
+  // split the packed qkv weight when TP is enabled
+  get_weights_col_packed_qkv();
+  if (encode_param_.worldSize > 1) {
+    // merge qkv weight
+    auto new_qkv_weight = torch::cat({at_weight_tensors_[IN_VISION_Q_WEIGHT],
+                                      at_weight_tensors_[IN_VISION_K_WEIGHT],
+                                      at_weight_tensors_[IN_VISION_V_WEIGHT]},
+                                     0);
+    at_weight_tensors_[IN_QKV_WEIGHT] = new_qkv_weight;
+    at_weight_tensors_[IN_VISION_Q_WEIGHT] = torch::zeros({1}).to(device_);
+    at_weight_tensors_[IN_VISION_K_WEIGHT] = torch::zeros({1}).to(device_);
+    at_weight_tensors_[IN_VISION_V_WEIGHT] = torch::zeros({1}).to(device_);
+
+    // merge qkv bias
+    auto new_qkv_bias = torch::cat({at_weight_tensors_[IN_VISION_Q_BIAS],
+                                    at_weight_tensors_[IN_VISION_K_BIAS],
+                                    at_weight_tensors_[IN_VISION_V_BIAS]},
+                                   0);
+    at_weight_tensors_[IN_QKV_BIAS] = new_qkv_bias;
+    at_weight_tensors_[IN_VISION_Q_BIAS] = torch::zeros({1}).to(device_);
+    at_weight_tensors_[IN_VISION_K_BIAS] = torch::zeros({1}).to(device_);
+    at_weight_tensors_[IN_VISION_V_BIAS] = torch::zeros({1}).to(device_);
+  }
+  c10_npu::NPUCachingAllocator::emptyCache();
+  for (int i = 0; i < WEIGHT_COUNT_PER_LAYER; ++i) {
+    atb_weight_tensors_[i] =
+        atb_speed::Utils::AtTensor2Tensor(at_weight_tensors_[i]);
+  }
+
+  init_layer();
+}
+
+// split the fused qkv weight across TP ranks
+void NpuQwen3VisionEncoderLayerImpl::get_weights_col_packed_qkv() {
+  int rank = encode_param_.rank;
+  int worldSize = encode_param_.worldSize;
+  // split qkv weight
+  qkv_weight = torch::chunk(at_weight_tensors_[IN_QKV_WEIGHT], 3, 0);
+  qkv_bias = torch::chunk(at_weight_tensors_[IN_QKV_BIAS], 3, 0);
+  // weight
+  at_weight_tensors_[IN_VISION_Q_WEIGHT] =
+      (qkv_weight[0].chunk(worldSize, 0))[rank];
+  at_weight_tensors_[IN_VISION_K_WEIGHT] =
+      (qkv_weight[1].chunk(worldSize, 0))[rank];
+  at_weight_tensors_[IN_VISION_V_WEIGHT] =
+      (qkv_weight[2].chunk(worldSize, 0))[rank];
+  // bias
+  at_weight_tensors_[IN_VISION_Q_BIAS] =
+      (qkv_bias[0].chunk(worldSize, 0))[rank];
+  at_weight_tensors_[IN_VISION_K_BIAS] =
+      (qkv_bias[1].chunk(worldSize, 0))[rank];
+  at_weight_tensors_[IN_VISION_V_BIAS] =
+      (qkv_bias[2].chunk(worldSize, 0))[rank];
+}
+
+void NpuQwen3VisionEncoderLayerImpl::load_state_dict(
+    const StateDict& state_dict) {
+  for (const auto& [index, name] : WEIGHT_MAPPING) {
+    if (WEIGHT_SHARD.find(index) != WEIGHT_SHARD.end()) {
+      set_weight(state_dict, name, index, WEIGHT_SHARD[index]);
+    } else {
+      set_weight(state_dict, name, index);
+    }
+  }
+}
+
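+// The encoder's variant pack is laid out as the 18 per-layer weight tensors
+// (indices 0..17) first, followed by the activation, the cos/sin rotary
+// position embeddings, and the cu_seqlen tensor, whose host-side copy is
+// bound in build_node_variant_pack() below.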
+int64_t NpuQwen3VisionEncoderLayerImpl::init_layer() {
+  name_ = "qwen3_encoder_layer";
+  model_name_ = "qwen3_vl";
+  CHECK_OPERATION_STATUS_RETURN(init_node(encode_node_, encode_param_));
+  return atb::NO_ERROR;
+}
+
+int64_t NpuQwen3VisionEncoderLayerImpl::init_node(
+    atb_speed::Model::Node& node,
+    atb_speed::qwen::VisionEncoderLayerParam& param) {
+  atb::Operation* operation = nullptr;
+  atb_speed::qwen::Qwen3VL_EncoderLayer(param, &operation);
+  node.operation.reset(operation);
+  if (node.operation == nullptr) {
+    LOG(ERROR) << "node.operation is null";
+    return -1;
+  }
+  if (node.operation->GetInputNum() < 1) {
+    LOG(ERROR) << "Can not resize number which is smaller than 1";
+    return -1;
+  }
+  node.inTensors.resize(node.operation->GetInputNum());
+  node.outTensors.resize(1);
+
+  for (size_t weightTensorId = 0; weightTensorId < WEIGHT_COUNT_PER_LAYER;
+       ++weightTensorId) {
+    node.inTensors.at(weightTensorId) = &atb_weight_tensors_[weightTensorId];
+  }
+
+  node.variantPack.inTensors.reserve(node.inTensors.size());
+  node.variantPack.inTensors.resize(node.inTensors.size());
+  node.variantPack.outTensors.reserve(1);
+  node.variantPack.outTensors.resize(1);
+  return atb::NO_ERROR;
+}
+
+torch::Tensor NpuQwen3VisionEncoderLayerImpl::forward(
+    torch::Tensor& x,
+    torch::Tensor& cos_pos,
+    torch::Tensor& sin_pos,
+    torch::Tensor& cu_seqlen,
+    std::vector<int>& cu_seqlen_vec,
+    ModelInputParams& input_params,
+    int node_id,
+    aclrtEvent* event,
+    std::atomic<bool>* event_flag) {
+  atb::Status st;
+
+  build_node_variant_pack(encode_node_,
+                          x,
+                          cos_pos,
+                          sin_pos,
+                          cu_seqlen,
+                          cu_seqlen_vec,
+                          input_params,
+                          true);
+  st = execute_node(encode_node_, node_id);
+  LOG_IF(FATAL, st != 0) << model_name_
+                         << " execute encoder layer failed, error code: " << st;
+  return x;
+}
+
+void NpuQwen3VisionEncoderLayerImpl::build_node_variant_pack(
+    atb_speed::Model::Node& node,
+    torch::Tensor& x,
+    torch::Tensor& cos_pos,
+    torch::Tensor& sin_pos,
+    torch::Tensor& cu_seqlen,
+    std::vector<int>& cu_seqlen_vec,
+    ModelInputParams& input_params,
+    bool is_prefill) {
+  internal_tensors_ = atb_speed::Utils::AtTensor2Tensor(x);
+
+  node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER) = internal_tensors_;
+  node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 1) =
+      atb_speed::Utils::AtTensor2Tensor(cos_pos);
+  node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 2) =
+      atb_speed::Utils::AtTensor2Tensor(sin_pos);
+  node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 3) =
+      atb_speed::Utils::AtTensor2Tensor(cu_seqlen);
+  node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 3).hostData =
+      cu_seqlen_vec.data();
+
+  for (size_t i = 0; i < WEIGHT_COUNT_PER_LAYER; ++i) {
+    CHECK_THROW(node.inTensors.at(i) == nullptr,
+                model_name_ << " inTensor " << i << " is NULL");
+    node.variantPack.inTensors.at(i) = *node.inTensors.at(i);
+  }
+
+  node.variantPack.outTensors.at(0) = internal_tensors_;
+}
+
+}  // namespace layer
+}  // namespace xllm
diff --git a/xllm/core/layers/npu/npu_qwen3_vision_encoder_layer_impl.h b/xllm/core/layers/npu/npu_qwen3_vision_encoder_layer_impl.h
new file mode 100755
index 00000000..6000ed7d
--- /dev/null
+++ b/xllm/core/layers/npu/npu_qwen3_vision_encoder_layer_impl.h
@@ -0,0 +1,125 @@
+
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#pragma once
+#ifdef TORCH_HIGHER_THAN_PTA6
+#include
+#include
+#else
+#include
+#include
+#endif
+
+#include
+
+#include
+
+#include "atb/atb_infer.h"
+#include "atb_speed/base/hosttensor_binder.h"
+#include "atb_speed/base/model.h"
+#include "atb_speed/log.h"
+#include "atb_speed/utils/model_factory.h"
+#include "core/framework/model/model_args.h"
+#include "core/framework/model/model_input_params.h"
+#include "core/framework/state_dict/state_dict.h"
+#include "nlohmann/json.hpp"
+#include "npu_base_layer.h"
+#include "pytorch/adapter/utils/utils.h"
+#include "xllm_kernels/models/qwen3_vl/qwen3_vl_encoder.h"
+
+namespace xllm {
+namespace layer {
+
+class NpuQwen3VisionEncoderLayerImpl : public NpuBaseLayer {
+ public:
+  explicit NpuQwen3VisionEncoderLayerImpl(const ModelContext& context);
+
+  ~NpuQwen3VisionEncoderLayerImpl() {};
+
+  void load_state_dict(const StateDict& state_dict) override;
+
+  void verify_loaded_weights() const override;
+
+  void merge_loaded_weights() override;
+
+  int64_t init_layer() override;
+
+  torch::Tensor forward(torch::Tensor& x,
+                        torch::Tensor& cos_pos,
+                        torch::Tensor& sin_pos,
+                        torch::Tensor& cu_seqlen,
+                        std::vector<int>& cu_seqlen_vec,
+                        ModelInputParams& input_params,
+                        int node_id = 0,
+                        aclrtEvent* event = nullptr,
+                        std::atomic<bool>* event_flag = nullptr);
+
+ private:
+  void build_node_variant_pack(atb_speed::Model::Node& node,
+                               torch::Tensor& x,
+                               torch::Tensor& cos_pos,
+                               torch::Tensor& sin_pos,
+                               torch::Tensor& cu_seqlen,
+                               std::vector<int>& cu_seqlen_vec,
+                               ModelInputParams& input_params,
+                               bool is_prefill);
+
+  void get_weights_col_packed_qkv();
+
+  void param_from_args(atb_speed::qwen::VisionEncoderLayerParam& param,
+                       const ModelArgs& args,
+                       const ParallelArgs& parallel_args);
+
+  int64_t init_node(atb_speed::Model::Node& node,
+                    atb_speed::qwen::VisionEncoderLayerParam& param);
+
+  void pad_qkv_weights();
+
+  void pad_mlp_weights();
+
+  torch::Tensor pad_tensor(const torch::Tensor& tensor,
+                           int64_t target_shape,
+                           int64_t dim = 0) {
+    int64_t pad_size = target_shape - tensor.size(dim);
+    if (tensor.dim() == 1) {
+      return torch::nn::functional::pad(
+          tensor, torch::nn::functional::PadFuncOptions({0, pad_size}));
+    } else if (tensor.dim() == 2) {
+      if (1 == dim)
+        return torch::nn::functional::pad(
+            tensor, torch::nn::functional::PadFuncOptions({0, pad_size, 0, 0}));
+      else
+        return torch::nn::functional::pad(
+            tensor, torch::nn::functional::PadFuncOptions({0, 0, 0, pad_size}));
+    }
+    return tensor;
+  }
+
+  atb_speed::Model::Node encode_node_;
+  std::string model_name_;
+
+  atb_speed::qwen::VisionEncoderLayerParam encode_param_;
+  atb::Tensor internal_tensors_;
+  atb::Tensor placeholder_;
+  at::Tensor cu_seqlen_;
+  at::Tensor at_placeholder_;
+  std::vector<at::Tensor> qkv_weight;
+  std::vector<at::Tensor> qkv_bias;
+  int device_id_;
+};
+
+}  // namespace layer
+}  // namespace xllm
diff --git a/xllm/core/layers/qwen3_vision_encode_layer.h b/xllm/core/layers/qwen3_vision_encode_layer.h
new file mode 100644
index 00000000..f67c426a
--- /dev/null
+++ b/xllm/core/layers/qwen3_vision_encode_layer.h
@@ -0,0 +1,39 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#pragma once
+
+#if defined(USE_NPU)
+#include "npu/npu_qwen3_vision_encoder_layer_impl.h"
+#endif
+
+namespace xllm {
+namespace layer {
+
+#if defined(USE_NPU)
+class Qwen3VisionEncoderLayer
+    : public torch::nn::ModuleHolder<NpuQwen3VisionEncoderLayerImpl> {
+ public:
+  using torch::nn::ModuleHolder<NpuQwen3VisionEncoderLayerImpl>::ModuleHolder;
+  using Impl __attribute__((__unused__)) = NpuQwen3VisionEncoderLayerImpl;
+
+  Qwen3VisionEncoderLayer(const ModelContext& context)
+      : ModuleHolder(
+            std::make_shared<NpuQwen3VisionEncoderLayerImpl>(context)) {}
+};
+#endif
+
+}  // namespace layer
+}  // namespace xllm
diff --git a/xllm/core/runtime/vlm_engine.cpp b/xllm/core/runtime/vlm_engine.cpp
old mode 100644
new mode 100755
diff --git a/xllm/models/llm/llm_model_base.h b/xllm/models/llm/llm_model_base.h
index 7b4212be..156f8169 100644
--- a/xllm/models/llm/llm_model_base.h
+++ b/xllm/models/llm/llm_model_base.h
@@ -455,22 +455,20 @@ class LlmForCausalLMImplBase : public torch::nn::Module {
   }
 
   void load_model(std::unique_ptr<ModelLoader> loader,
-                  std::string prefix = "" /*llm model weight prefix*/) {
+                  std::string prefix = "model." /*llm model weight prefix*/) {
     for (const auto& state_dict : loader->get_state_dicts()) {
-      model_->load_state_dict(
-          state_dict->get_dict_with_prefix(prefix + "model."));
+      model_->load_state_dict(state_dict->get_dict_with_prefix(prefix));
       if (tie_word_embeddings) {
         lm_head_->load_state_dict(
-            state_dict->get_dict_with_prefix(prefix + "model.embed_tokens."));
+            state_dict->get_dict_with_prefix(prefix + "embed_tokens."));
       } else {
-        lm_head_->load_state_dict(
-            state_dict->get_dict_with_prefix(prefix + "lm_head."));
+        lm_head_->load_state_dict(state_dict->get_dict_with_prefix("lm_head."));
       }
     }
 
 #if defined(USE_NPU)
     // verify
-    model_->verify_loaded_weights(prefix + "model.");
-    lm_head_->verify_loaded_weights(prefix + "lm_head.");
+    model_->verify_loaded_weights(prefix);
+    lm_head_->verify_loaded_weights("lm_head.");
 
     model_->merge_loaded_weights();
     // test
diff --git a/xllm/models/llm/qwen3.h b/xllm/models/llm/qwen3.h
old mode 100644
new mode 100755
index 4fb7023a..8b26d71f
--- a/xllm/models/llm/qwen3.h
+++ b/xllm/models/llm/qwen3.h
@@ -75,6 +75,173 @@ class QWen3ModelImpl : public LlmModelImplBase {
       blocks_->push_back(block);
     }
   }
+
+  torch::Tensor deepstack_process(torch::Tensor hidden_states,
+                                  torch::Tensor visual_pos_masks,
+                                  torch::Tensor visual_embeds) {
+    visual_pos_masks = visual_pos_masks.to(hidden_states.device());
+    auto selected = hidden_states.index({visual_pos_masks});
+    auto local_this = selected + visual_embeds;
+    hidden_states.index_put_({visual_pos_masks}, local_this);
+    return hidden_states;
+  }
+
+  virtual torch::Tensor forward(
+      std::vector<torch::Tensor> tokens,
+      std::vector<torch::Tensor> positions,
+      std::vector<KVCache>& kv_caches,
+      const std::vector<ModelInputParams>& input_params) {
+    auto micro_batch_num = tokens.size();
+    std::vector<torch::Tensor> hs;
+    hs.reserve(micro_batch_num);
+    std::vector<std::vector<torch::Tensor>> deep_stacks;
+    deep_stacks.reserve(micro_batch_num);
+    bool use_deepstack = input_params[0].deep_stacks.size() > 0;
+    std::vector<torch::Tensor> cos_poss;
+    cos_poss.reserve(micro_batch_num);
+    std::vector<torch::Tensor> sin_poss;
+    sin_poss.reserve(micro_batch_num);
+    std::vector<torch::Tensor> attn_masks;
+    attn_masks.reserve(micro_batch_num);
+    std::vector<ModelInputParams>& input_params_news =
+        const_cast<std::vector<ModelInputParams>&>(input_params);
+
+    for (auto i = 0; i < micro_batch_num; ++i) {
+      if (tokens[i].numel() == 0) {
+        tokens[i] = torch::tensor({1}).to(torch::kInt32).to(tokens[0].device());
+        positions[i] =
+            torch::tensor({0}).to(torch::kInt32).to(tokens[0].device());
+      }
+      auto inputs_embeds = input_params[i].input_embedding;
+      torch::Tensor h;
+      if (inputs_embeds.defined()) {
+        h = inputs_embeds;
+      } else {
+#if defined(USE_NPU)
+        h = embed_tokens_[i](tokens[i], 0);
+#elif defined(USE_MLU)
+        h = embed_tokens_[i](tokens[i]);
+#endif
+      }
+      hs.push_back(std::move(h));
+#if defined(USE_NPU)
+      if (use_deepstack) {
+        deep_stacks.push_back(
+            input_params[i].deep_stacks);  // [num_deepstack, hidden_size]
+      }
+      auto target_cos_sin = atb_pos_embeds_[i](cos_sin_, positions[i], 0);
+      auto target_cos_sin_chunks =
+          target_cos_sin.chunk(/*chunks=*/2, /*dim=*/-1);
+      auto cos_pos = target_cos_sin_chunks[0].contiguous();
+      auto sin_pos = target_cos_sin_chunks[1].contiguous();
+
+      if (positions[i].dim() == 2) {  // mrope
+        auto apply = [this](torch::Tensor x) {
+          auto freqs_t = x[0].clone();
+          for (int dim_idx = 1; dim_idx <= 2; ++dim_idx) {
+            int64_t offset = dim_idx;
+            int64_t section_len = mrope_section_[dim_idx];
+            int64_t length = section_len * 3;
+            auto idx_first_half =
+                torch::arange(offset, length, 3, torch::kLong);
+            auto idx_second_half =
+                torch::arange(offset, length, 3, torch::kLong);
+            auto idx_tensor =
+                torch::cat({idx_first_half, idx_second_half}, 0).to(x.device());
+            // freqs_t[..., idx] = freqs[dim_idx][..., idx]
+            auto src = x[dim_idx].index_select(-1, idx_tensor);
+            freqs_t.index_copy_(-1, idx_tensor, src);
+          }
+          return freqs_t;
+        };
+        cos_pos = apply(cos_pos.reshape(
+            {positions[i].sizes().front(), -1, cos_pos.sizes().back()}));
+        sin_pos = apply(sin_pos.reshape(
+            {positions[i].sizes().front(), -1, sin_pos.sizes().back()}));
+      }
+
+      torch::Tensor attn_mask;
+
+      torch::Tensor max_of_seq = torch::max(input_params[i].kv_seq_lens);
+      max_seq_len_ = FLAGS_enable_chunked_prefill
+                         ? std::max(max_of_seq.item<int>(), max_seq_len_)
+                         : 128;
+      attn_mask = attn_mask_.get_attn_mask(
+          max_seq_len_, cos_pos.dtype().toScalarType(), cos_pos.device());
+
+      if (FLAGS_enable_chunked_prefill) {
+        int batch_size = input_params[i].q_seq_lens_vec.size();
+        if (batch_size > 0) {
+          std::vector<torch::Tensor> req_mask_vec;
+          req_mask_vec.reserve(batch_size);
+
+          for (int j = 0; j < batch_size; j++) {
+            int start = input_params[i].kv_seq_lens_vec[j] -
+                        input_params[i].q_seq_lens_vec[j];
+            int end = input_params[i].kv_seq_lens_vec[j];
+
+            auto req_mask_slice = attn_mask.slice(0, start, end);
+            req_mask_vec.emplace_back(req_mask_slice);
+          }
+          attn_mask = torch::cat(req_mask_vec, 0);
+        }
+      }
+
+      cos_poss.push_back(std::move(cos_pos));
+      sin_poss.push_back(std::move(sin_pos));
+      attn_masks.push_back(std::move(attn_mask));
+#endif
+    }
+#if defined(USE_NPU)
+    for (size_t i = 0; i < layers_.size(); i++) {
+      std::vector<aclrtEvent*> events(micro_batch_num, nullptr);
+      std::vector<std::atomic<bool>*> event_flags(micro_batch_num, nullptr);
+      for (auto j = 0; j < micro_batch_num; ++j) {
+        if (input_params[j].layer_synchronizer != nullptr) {
+          events[j] = input_params[j].layer_synchronizer->get_event(i);
+          event_flags[j] =
+              input_params[j].layer_synchronizer->get_event_flag(i);
+        }
+      }
+      auto& layer = layers_[i];
+
+      layer(hs,
+            cos_poss,
+            sin_poss,
+            attn_masks,
+            kv_caches[i],
+            input_params_news,
+            i,
+            events,
+            event_flags);
+      if (use_deepstack) {
+        for (auto j = 0; j < micro_batch_num; ++j) {
+          if (deep_stacks[j].size() > 0 && i < deep_stacks[j].size()) {
+            hs[j] = deepstack_process(
+                hs[j], input_params[j].visual_pos_masks, deep_stacks[j][i]);
+          }
+        }
+      }
+    }
+    auto concated_h = torch::cat(hs, 0);
+    return norm_(concated_h, 0);
+#elif defined(USE_MLU)
+    bool is_prefill = input_params[0].q_max_seq_len > 1;
+    auto attn_metadata =
+        layer::AttentionMetadata::build(input_params[0], is_prefill);
+
+    torch::Tensor h;
+    for (size_t i = 0; i < layers_.size(); i++) {
+      auto& layer = layers_[i];
+      h = layer(
+          hs[0], positions[0], attn_metadata, kv_caches[i], input_params[0]);
+    }
+    return norm_(h);
+#endif
+  }
+
+ private:
+  torch::Tensor visual_pos_mask_;
 };
 TORCH_MODULE(QWen3Model);
 
diff --git a/xllm/models/llm/qwen3_moe.h b/xllm/models/llm/qwen3_moe.h
index 16771fb9..d540f94d 100644
--- a/xllm/models/llm/qwen3_moe.h
+++ b/xllm/models/llm/qwen3_moe.h
@@ -15,8 +15,9 @@ limitations under the License.
#pragma once -#include +#include +#include #if defined(USE_NPU) #include "core/framework/model/npu_dp_ep_padding.h" #endif @@ -68,7 +69,45 @@ class Qwen3MoeDecoderLayerImpl : public torch::nn::Module { } #endif void load_state_dict(const StateDict& state_dict) { - decoder_layer_->load_state_dict(state_dict); + auto experts_state_dict = state_dict.get_dict_with_prefix("mlp.experts."); + auto fused_gate_up = experts_state_dict.get_tensor("gate_up_proj"); + auto fused_down = experts_state_dict.get_tensor("down_proj"); + + bool is_fused = fused_gate_up.defined() && fused_down.defined(); + + if (is_fused) { + torch::Tensor expert_gate_up = fused_gate_up; + torch::Tensor expert_down = fused_down; + + const int num_experts = expert_gate_up.size(0); + + auto chunks = expert_gate_up.chunk(2, /*dim=*/-1); + auto expert_gate = chunks[0].contiguous(); + auto expert_up = chunks[1].contiguous(); + + std::unordered_map out_state_dict; + for (const auto& [name, tensor] : state_dict) { + if (name.find("self_attn.") == 0 || name.find("mlp.gate.") == 0 || + name.find("input_layernorm.") == 0 || + name.find("post_attention_layernorm.") == 0) { + out_state_dict.emplace(name, tensor); + } + } + + for (int i = 0; i < num_experts; ++i) { + auto gate_i = expert_gate[i].transpose(0, 1); + auto up_i = expert_up[i].transpose(0, 1); + auto down_i = expert_down[i].transpose(0, 1); + + const std::string base = "mlp.experts." + std::to_string(i) + "."; + out_state_dict.emplace(base + "gate_proj.weight", gate_i); + out_state_dict.emplace(base + "up_proj.weight", up_i); + out_state_dict.emplace(base + "down_proj.weight", down_i); + } + decoder_layer_->load_state_dict(StateDict(std::move(out_state_dict))); + } else { + decoder_layer_->load_state_dict(state_dict); + } } #if defined(USE_NPU) @@ -99,7 +138,7 @@ class Qwen3MoeModelImpl : public torch::nn::Module { auto options = context.get_tensor_options(); auto model_args = context.get_model_args(); auto parallel_args = context.get_parallel_args(); - + mrope_section_ = model_args.rope_scaling_mrope_section(); blocks_ = register_module("layers", torch::nn::ModuleList()); layers_.reserve(model_args.n_layers()); // register submodules @@ -153,6 +192,16 @@ class Qwen3MoeModelImpl : public torch::nn::Module { } } + torch::Tensor deepstack_process(torch::Tensor hidden_states, + torch::Tensor visual_pos_masks, + torch::Tensor visual_embeds) { + visual_pos_masks = visual_pos_masks.to(hidden_states.device()); + auto selected = hidden_states.index({visual_pos_masks}); + auto local_this = selected + visual_embeds; + hidden_states.index_put_({visual_pos_masks}, local_this); + return hidden_states; + } + // tokens: [num_tokens] // positions: [num_tokens] token pos in the sequence torch::Tensor forward(torch::Tensor tokens, @@ -166,8 +215,14 @@ class Qwen3MoeModelImpl : public torch::nn::Module { } } #if defined(USE_NPU) - auto h = embed_tokens_(tokens, 0); - int64_t input_length = tokens.size(0); + auto inputs_embeds = input_params.input_embedding; + torch::Tensor h; + if (inputs_embeds.defined()) { + h = inputs_embeds; + } else { + h = embed_tokens_(tokens, 0); + } + int64_t input_length = h.size(0); torch::Tensor expert_array = torch::arange( 0, input_length * num_experts_per_tok_, @@ -176,6 +231,31 @@ class Qwen3MoeModelImpl : public torch::nn::Module { auto target_cos_sin_chunks = target_cos_sin.chunk(/*chunks=*/2, /*dim=*/-1); auto cos_pos = target_cos_sin_chunks[0].contiguous(); auto sin_pos = target_cos_sin_chunks[1].contiguous(); + if (positions.dim() == 2) { // mrope + auto apply 
= [this](torch::Tensor x) { + // auto sections = mrope_section_; + auto freqs_t = x[0].clone(); + for (int dim_idx = 1; dim_idx <= 2; ++dim_idx) { + int64_t offset = dim_idx; // H -> offset=1, W -> offset=2 + int64_t section_len = mrope_section_[dim_idx]; + int64_t length = section_len * 3; + + // indices: [offset, offset+3, offset+6, ..., < length] + auto idx_first_half = torch::arange(offset, length, 3, torch::kLong); + auto idx_second_half = torch::arange(offset, length, 3, torch::kLong); + auto idx_tensor = + torch::cat({idx_first_half, idx_second_half}, 0).to(x.device()); + // freqs_t[..., idx] = freqs[dim_idx][..., idx] + auto src = x[dim_idx].index_select(-1, idx_tensor); + freqs_t.index_copy_(-1, idx_tensor, src); + } + return freqs_t; + }; + cos_pos = apply(cos_pos.reshape( + {positions.sizes().front(), -1, cos_pos.sizes().back()})); + sin_pos = apply(sin_pos.reshape( + {positions.sizes().front(), -1, sin_pos.sizes().back()})); + } torch::Tensor attn_mask; if (num_speculative_tokens_ == 0 || input_params.global_empty_kv_cache) { @@ -184,13 +264,16 @@ class Qwen3MoeModelImpl : public torch::nn::Module { attn_mask = attn_mask_.gen_free_mask( num_speculative_tokens_ + 1, dtype_, device_); } - + auto deep_stacks = input_params.deep_stacks; + int deep_stack_size = deep_stacks.size(); for (size_t i = 0; i < layers_.size(); i++) { aclrtEvent* event = nullptr; std::atomic* event_flag = nullptr; if (input_params.layer_synchronizer != nullptr) { event = input_params.layer_synchronizer->get_event(i); event_flag = input_params.layer_synchronizer->get_event_flag(i); + } else { + LOG(INFO) << "layer_synchronizer is nullptr"; } auto& layer = layers_[i]; layer(h, @@ -202,6 +285,9 @@ class Qwen3MoeModelImpl : public torch::nn::Module { expert_array, event, event_flag); + if (deep_stack_size && i < deep_stack_size) { + h = deepstack_process(h, input_params.visual_pos_masks, deep_stacks[i]); + } } return norm_(h, 0); #elif defined(USE_MLU) @@ -258,6 +344,15 @@ class Qwen3MoeModelImpl : public torch::nn::Module { void set_word_embedding(std::vector& word_embedding) { embed_tokens_ = word_embedding[0]; } + torch::Tensor get_input_embeddings(torch::Tensor input_ids) { +#if defined(USE_NPU) + return embed_tokens_(input_ids, 0); +#elif defined(USE_MLU) + return embed_tokens_(input_ids); +#else + LOG(FATAL) << "Backend not supported: enable USE_NPU or USE_MLU."; +#endif + } private: torch::nn::ModuleList blocks_{nullptr}; @@ -279,6 +374,7 @@ class Qwen3MoeModelImpl : public torch::nn::Module { torch::Tensor cos_sin_; layer::PosEmbedding atb_pos_emb_{nullptr}; #endif + std::vector mrope_section_; }; TORCH_MODULE(Qwen3MoeModel); @@ -329,15 +425,20 @@ class Qwen3MoeForCausalLMImpl : public torch::nn::Module { #endif } - void load_model(std::unique_ptr loader) { + torch::Tensor get_input_embeddings(torch::Tensor input_ids) { + return model_->get_input_embeddings(input_ids); + } + + void load_model(std::unique_ptr loader, + std::string prefix = "model." 
/*llm model weight prefix*/) { for (const auto& state_dict : loader->get_state_dicts()) { - model_->load_state_dict(state_dict->get_dict_with_prefix("model.")); + model_->load_state_dict(state_dict->get_dict_with_prefix(prefix)); lm_head_->load_state_dict(state_dict->get_dict_with_prefix("lm_head.")); } #if defined(USE_NPU) // verify - model_->verify_loaded_weights("model."); + model_->verify_loaded_weights(prefix); lm_head_->verify_loaded_weights("lm_head."); model_->merge_loaded_weights(); diff --git a/xllm/models/models.h b/xllm/models/models.h old mode 100644 new mode 100755 index 5c77ce86..92b2bf4a --- a/xllm/models/models.h +++ b/xllm/models/models.h @@ -31,6 +31,7 @@ limitations under the License. #include "llm/qwen3_embedding.h" // IWYU pragma: keep #include "vlm/minicpmv.h" // IWYU pragma: keep #include "vlm/qwen2_5_vl.h" // IWYU pragma: keep +#include "vlm/qwen3_vl.h" // IWYU pragma: keep #endif #include "llm/llm_model_base.h" // IWYU pragma: keep diff --git a/xllm/models/vlm/qwen3_vl.h b/xllm/models/vlm/qwen3_vl.h new file mode 100755 index 00000000..dae43c2d --- /dev/null +++ b/xllm/models/vlm/qwen3_vl.h @@ -0,0 +1,800 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include +#include +#include +#include + +#include + +#include "core/framework/kv_cache/kv_cache.h" +#include "core/framework/model/model_input_params.h" +#include "core/layers/lm_head.h" +#include "core/layers/qwen3_vision_encode_layer.h" +#include "core/layers/rms_norm.h" +#include "models/llm/qwen3.h" +#include "models/model_registry.h" +#include "processors/input_processor.h" +#include "processors/qwen2_vl_image_processor.h" +#include "qwen2_5_vl.h" +#include "xllm_kernels/core/include/atb_speed/log.h" + +namespace xllm { + +#define PrintTensor(tensor) print_tensor(tensor, #tensor, 10, true, false); + +class Qwen3_VisionPatchEmbedImpl : public torch::nn::Module { + public: + Qwen3_VisionPatchEmbedImpl(const ModelContext& context) { + auto model_args = context.get_model_args(); + auto options = context.get_tensor_options(); + + auto in_features = model_args.mm_num_channels() * + model_args.mm_temporal_patch_size() * + model_args.mm_patch_size() * model_args.mm_patch_size(); + + auto out_features = model_args.mm_hidden_size(); + + proj_ = register_module( + "proj", + torch::nn::Linear( + torch::nn::LinearOptions(in_features, out_features).bias(true))); + + proj_->weight.set_data(proj_->weight.to(options)); + proj_->bias.set_data(proj_->bias.to(options)); + } + + torch::Tensor forward(torch::Tensor x) { return proj_(x); } + + void load_state_dict(const StateDict& state_dict) { + auto weight = state_dict.get_tensor("proj.weight"); + if (weight.defined()) { + weight = weight.reshape({weight.size(0), -1}); + DCHECK_EQ(proj_->weight.sizes(), weight.sizes()) + << "proj weight size mismatch for " << name(); + proj_->weight.data().copy_(weight); + proj_weight_loaded_ = 
true; + } + auto bias = state_dict.get_tensor("proj.bias"); + if (bias.defined()) { + bias = bias.reshape({bias.size(0)}); + DCHECK_EQ(proj_->bias.sizes(), bias.sizes()) + << "proj bias size mismatch for " << name(); + proj_->bias.data().copy_(bias); + proj_bias_loaded_ = true; + } + } + + void verify_loaded_weights(const std::string& prefix) const { + CHECK(proj_weight_loaded_) + << "weight is not loaded for " << prefix + "proj.weight"; + CHECK(proj_bias_loaded_) + << "bias is not loaded for " << prefix + "proj.bias"; + } + + private: + bool proj_weight_loaded_ = false; + bool proj_bias_loaded_ = false; + torch::nn::Linear proj_{nullptr}; +}; +TORCH_MODULE(Qwen3_VisionPatchEmbed); + +class Qwen3_VisionBlockImpl : public torch::nn::Module { + public: + Qwen3_VisionBlockImpl(const ModelContext& context) { + // register submodules + encoder_layer_ = register_module("encoder_layer", + layer::Qwen3VisionEncoderLayer(context)); + } + + torch::Tensor forward(torch::Tensor& x, + torch::Tensor& m_cos_pos, + torch::Tensor& m_sin_pos, + torch::Tensor& cu_seq_len, + std::vector& cu_seq_len_vec, + ModelInputParams& input_params, + int node_id) { + return encoder_layer_(x, + m_cos_pos, + m_sin_pos, + cu_seq_len, + cu_seq_len_vec, + input_params, + node_id); + } + + // load the weight from the checkpoint + void load_state_dict(const StateDict& state_dict) { + // call each submodule's load_state_dict function + encoder_layer_->load_state_dict(state_dict); + } + + void verify_loaded_weights(const std::string& prefix) const { + encoder_layer_->verify_loaded_weights(); + } + void merge_loaded_weights() { encoder_layer_->merge_loaded_weights(); } + + private: + layer::Qwen3VisionEncoderLayer encoder_layer_{nullptr}; +}; +TORCH_MODULE(Qwen3_VisionBlock); + +class Qwen3_VisionRotaryEmbeddingImpl : public torch::nn::Module { + public: + Qwen3_VisionRotaryEmbeddingImpl(const ModelContext& context) { + auto model_args = context.get_model_args(); + auto options = context.get_tensor_options(); + + dim_ = model_args.mm_head_dim() / 2; + theta_ = 10000.0; + + auto opts = options.dtype(torch::kFloat32); + auto inv_freq = + 1.0 / torch::pow(theta_, torch::arange(0, dim_, 2, opts) / dim_); + inv_freq_ = register_buffer("inv_freq", inv_freq); + } + + void update_freqs_cache(int64_t seqlen) { + if (seqlen <= seq_len_cached_) return; + + seqlen *= 2; + seq_len_cached_ = seqlen; + + auto options = torch::TensorOptions() + .dtype(torch::kFloat32) + .device(inv_freq_.device()); + inv_freq_ = + 1.0 / torch::pow(theta_, torch::arange(0, dim_, 2, options) / dim_); + auto seq = torch::arange(seqlen, options); + freqs_cached_ = torch::outer(seq, inv_freq_); + } + + torch::Tensor forward(int seqlen) { + update_freqs_cache(seqlen); + return freqs_cached_.slice(0, 0, seqlen); + } + + private: + int dim_ = 0; + double theta_ = 0.0; + + int64_t seq_len_cached_ = 0; + torch::Tensor inv_freq_; + torch::Tensor freqs_cached_; +}; +TORCH_MODULE(Qwen3_VisionRotaryEmbedding); + +class Qwen3_VisionPatchMergerImpl : public torch::nn::Module { + public: + Qwen3_VisionPatchMergerImpl(const ModelContext& context, + bool use_postshuffle_norm = false) { + auto model_args = context.get_model_args(); + auto options = context.get_tensor_options(); + auto quant_args = context.get_quant_args(); + auto parallel_args = context.get_parallel_args(); + int64_t d_model = model_args.mm_projection_dim(); + int context_dim = model_args.mm_hidden_size(); + int spatial_merge_size = model_args.mm_spatial_merge_size(); + hidden_size_ = + context_dim * 
static_cast(std::pow(spatial_merge_size, 2)); + use_postshuffle_norm_ = use_postshuffle_norm; + if (use_postshuffle_norm) + norm_ = register_module( + "norm", + torch::nn::LayerNorm(torch::nn::LayerNormOptions({hidden_size_}) + .elementwise_affine(true) + .eps(1e-6))); + else + norm_ = register_module( + "norm", + torch::nn::LayerNorm(torch::nn::LayerNormOptions({context_dim}) + .elementwise_affine(true) + .eps(1e-6))); + norm_->weight.set_data(norm_->weight.to(options)); + norm_->bias.set_data(norm_->bias.to(options)); + + auto fc1 = torch::nn::Linear( + torch::nn::LinearOptions(hidden_size_, hidden_size_).bias(true)); + fc1->weight.set_data(fc1->weight.to(options)); + fc1->bias.set_data(fc1->bias.to(options)); + auto act = torch::nn::GELU(); + auto fc2 = torch::nn::Linear( + torch::nn::LinearOptions(hidden_size_, d_model).bias(true)); + fc2->weight.set_data(fc2->weight.to(options)); + fc2->bias.set_data(fc2->bias.to(options)); + mlp_ = register_module("mlp", torch::nn::Sequential(fc1, act, fc2)); + layers_ = std::make_tuple(fc1, act, fc2); + } + + torch::Tensor forward(torch::Tensor x) { + if (use_postshuffle_norm_) + x = norm_(x.view({-1, hidden_size_})); + else + x = norm_(x).view({-1, hidden_size_}); + return mlp_->forward(x); + } + + void load_state_dict(const StateDict& state_dict) { + // norm + const auto& norm_dict = state_dict.get_dict_with_prefix("norm."); + const auto& norm_weight = norm_dict.get_tensor("weight"); + if (norm_weight.defined()) { + CHECK_EQ(norm_->weight.sizes(), norm_weight.sizes()) + << "weight size mismatch for " << name(); + norm_->weight.data().copy_(norm_weight); + is_norm_weight_loaded = true; + } + const auto norm_bias = norm_dict.get_tensor("bias"); + if (norm_bias.defined()) { + CHECK_EQ(norm_->bias.sizes(), norm_bias.sizes()) + << "bias size mismatch for " << name(); + norm_->bias.data().copy_(norm_bias); + is_norm_bias_loaded = true; + } + + const auto& fc1_dict = state_dict.get_dict_with_prefix("linear_fc1."); + const auto& fc1_weight = fc1_dict.get_tensor("weight"); + if (fc1_weight.defined()) { + CHECK_EQ(std::get<0>(layers_)->weight.sizes(), fc1_weight.sizes()) + << "weight size mismatch for " << name(); + std::get<0>(layers_)->weight.data().copy_(fc1_weight); + is_fc1_weight_loaded = true; + } + const auto fc1_bias = fc1_dict.get_tensor("bias"); + if (fc1_bias.defined()) { + CHECK_EQ(std::get<0>(layers_)->bias.sizes(), fc1_bias.sizes()) + << "bias size mismatch for " << name(); + std::get<0>(layers_)->bias.data().copy_(fc1_bias); + is_fc1_bias_loaded = true; + } + + const auto& fc2_dict = state_dict.get_dict_with_prefix("linear_fc2."); + const auto& fc2_weight = fc2_dict.get_tensor("weight"); + if (fc2_weight.defined()) { + CHECK_EQ(std::get<2>(layers_)->weight.sizes(), fc2_weight.sizes()) + << "weight size mismatch for " << name(); + std::get<2>(layers_)->weight.data().copy_(fc2_weight); + is_fc2_weight_loaded = true; + } + const auto fc2_bias = fc2_dict.get_tensor("bias"); + if (fc2_bias.defined()) { + CHECK_EQ(std::get<2>(layers_)->bias.sizes(), fc2_bias.sizes()) + << "bias size mismatch for " << name(); + std::get<2>(layers_)->bias.data().copy_(fc2_bias); + is_fc2_bias_loaded = true; + } + } + + void verify_loaded_weights(const std::string& prefix) const { + CHECK(is_fc1_weight_loaded) + << "weight is not loaded for " << prefix + "linear_fc1" + ".weight"; + CHECK(is_fc1_bias_loaded) + << "bias is not loaded for " << prefix + "linear_fc1" + ".bias"; + CHECK(is_fc2_weight_loaded) + << "weight is not loaded for " << prefix + "linear_fc2" + 
".weight"; + CHECK(is_fc2_bias_loaded) + << "bias is not loaded for " << prefix + "linear_fc2" + ".bias"; + CHECK(is_norm_weight_loaded) + << "weight is not loaded for " << prefix + "norm" + ".weight"; + CHECK(is_norm_bias_loaded) + << "bias is not loaded for " << prefix + "norm" + ".bias"; + } + + private: + int hidden_size_; + bool use_postshuffle_norm_; + torch::nn::LayerNorm norm_{nullptr}; + torch::nn::Sequential mlp_{nullptr}; + std::tuple layers_ = { + nullptr, + nullptr, + nullptr}; + bool is_fc1_weight_loaded = false; + bool is_fc1_bias_loaded = false; + bool is_fc2_weight_loaded = false; + bool is_fc2_bias_loaded = false; + bool is_norm_weight_loaded = false; + bool is_norm_bias_loaded = false; +}; +TORCH_MODULE(Qwen3_VisionPatchMerger); + +class Qwen3_VisionTransformerImpl : public torch::nn::Module { + public: + Qwen3_VisionTransformerImpl(const ModelContext& context) + : options_(context.get_tensor_options()) { + auto model_args = context.get_model_args(); + hidden_size_ = model_args.mm_hidden_size(); + num_heads_ = model_args.mm_num_attention_heads(); + window_size_ = model_args.mm_window_size(); + patch_size_ = model_args.mm_patch_size(); + spatial_merge_size_ = model_args.mm_spatial_merge_size(); + auto& visual_indexes = model_args.mm_deepstack_visual_indexes(); + deepstack_visual_indexes_.insert(deepstack_visual_indexes_.end(), + visual_indexes.begin(), + visual_indexes.end()); + num_position_embeddings_ = model_args.mm_num_position_embeddings(); + spatial_merge_unit_ = + static_cast(spatial_merge_size_ * spatial_merge_size_); + num_grid_per_side_ = static_cast(std::sqrt(num_position_embeddings_)); + + patch_embed_ = + register_module("patch_embed", Qwen3_VisionPatchEmbed(context)); + rotary_pos_emb_ = + register_module("rotary_pos_emb", Qwen3_VisionRotaryEmbedding(context)); + + blocks_ = register_module("blocks", torch::nn::ModuleList()); + deepstack_mergers_ = + register_module("deepstack_mergers", torch::nn::ModuleList()); + + emb_ = register_module( + "embedding", + torch::nn::Embedding(num_position_embeddings_, hidden_size_)); + emb_->weight.set_data(emb_->weight.to(options_)); + + merger_ = register_module("merger", Qwen3_VisionPatchMerger(context)); + + for (int32_t idx = 0; idx < model_args.mm_num_hidden_layers(); idx++) { + auto block = Qwen3_VisionBlock(context); + blocks_->push_back(block); + layers_.push_back(block); + } + for (int32_t idx = 0; idx < deepstack_visual_indexes_.size(); idx++) { + auto merger = Qwen3_VisionPatchMerger(context, true); + deepstack_mergers_->push_back(merger); + deepstack_merger_layers_.push_back(merger); + } + } + + torch::Tensor rot_pos_emb(torch::Tensor grid_thw) { + std::vector pos_ids_vec; + auto count = grid_thw.sizes()[0]; + pos_ids_vec.reserve(count); + // int merge_size = + + auto grid_thw_cpu = grid_thw.cpu(); + auto options = + torch::TensorOptions().dtype(torch::kLong).device(grid_thw.device()); + + for (int idx = 0; idx < count; ++idx) { + auto t = grid_thw_cpu[idx][0].item(); + auto h = grid_thw_cpu[idx][1].item(); + auto w = grid_thw_cpu[idx][2].item(); + + auto hpos_ids = torch::arange(h, options).unsqueeze(1).expand({-1, w}); + hpos_ids = hpos_ids + .reshape({h / spatial_merge_size_, + spatial_merge_size_, + w / spatial_merge_size_, + spatial_merge_size_}) + .permute({0, 2, 1, 3}) + .flatten(); + + auto wpos_ids = torch::arange(w, options).unsqueeze(0).expand({h, -1}); + wpos_ids = wpos_ids + .reshape({h / spatial_merge_size_, + spatial_merge_size_, + w / spatial_merge_size_, + spatial_merge_size_}) + .permute({0, 
2, 1, 3}) + .flatten(); + + pos_ids_vec.push_back( + torch::stack({hpos_ids, wpos_ids}, -1).repeat({t, 1})); + } + + auto pos_ids = torch::cat(pos_ids_vec, 0); + auto max_grid_size = + grid_thw + .index({torch::indexing::Slice(), + torch::indexing::Slice(1, torch::indexing::None)}) + .max(); + + auto rotary_pos_emb_full = rotary_pos_emb_(max_grid_size.item()); + auto rotary_pos_emb = rotary_pos_emb_full.index({pos_ids}).flatten(1); + + return rotary_pos_emb; + } + + torch::Tensor fast_pos_embed_interpolate(const torch::Tensor& grid_thw) { + auto device = grid_thw.device(); + int64_t hidden_dim = hidden_size_; + int64_t m_size = spatial_merge_size_; + + auto grid_cpu = grid_thw.to(torch::kCPU); + int64_t count = grid_thw.size(0); + + std::vector outputs; + outputs.reserve(count); + + for (int64_t idx = 0; idx < count; ++idx) { + int64_t t = grid_cpu[idx][0].item(); + int64_t h = grid_cpu[idx][1].item(); + int64_t w = grid_cpu[idx][2].item(); + + auto h_idxs = + torch::linspace( + 0, static_cast(num_grid_per_side_ - 1), h, torch::kFloat32) + .to(device); + auto w_idxs = + torch::linspace( + 0, static_cast(num_grid_per_side_ - 1), w, torch::kFloat32) + .to(device); + + auto h_floor = h_idxs.to(torch::kLong); + auto w_floor = w_idxs.to(torch::kLong); + auto h_ceil = torch::clamp(h_floor + 1, 0, num_grid_per_side_ - 1); + auto w_ceil = torch::clamp(w_floor + 1, 0, num_grid_per_side_ - 1); + + auto dh = h_idxs - h_floor; + auto dw = w_idxs - w_floor; + + auto mesh_d = torch::meshgrid({dh, dw}, "ij"); + auto dh_grid = mesh_d[0], dw_grid = mesh_d[1]; + + auto mesh_floor = torch::meshgrid({h_floor, w_floor}, "ij"); + auto h_floor_grid = mesh_floor[0]; + auto w_floor_grid = mesh_floor[1]; + + auto mesh_ceil = torch::meshgrid({h_ceil, w_ceil}, "ij"); + auto h_ceil_grid = mesh_ceil[0]; + auto w_ceil_grid = mesh_ceil[1]; + + auto h_floor_grid_idx = h_floor_grid * num_grid_per_side_; + auto h_ceil_grid_idx = h_ceil_grid * num_grid_per_side_; + + auto w11 = dh_grid * dw_grid; + auto w10 = dh_grid - w11; + auto w01 = dw_grid - w11; + auto w00 = 1.0f - dh_grid - dw_grid + w11; + + auto idx00 = h_floor_grid_idx + w_floor_grid; + auto idx01 = h_floor_grid_idx + w_ceil_grid; + auto idx10 = h_ceil_grid_idx + w_floor_grid; + auto idx11 = h_ceil_grid_idx + w_ceil_grid; + + auto indices = torch::stack({idx00, idx01, idx10, idx11}, 0) + .reshape({4, -1}) + .to(torch::kLong); + auto weights = torch::stack({w00, w01, w10, w11}, 0) + .reshape({4, -1, 1}) + .to(options_); + + auto embeds = emb_(indices); + + auto combined = (embeds * weights).sum(0); // [h*w, hidden_dim] + + auto repeated = combined.unsqueeze(0).expand({t, -1, -1}).contiguous(); + repeated = repeated.view( + {t, h / m_size, m_size, w / m_size, m_size, hidden_dim}); + repeated = repeated.permute({0, 1, 3, 2, 4, 5}).reshape({-1, hidden_dim}); + + outputs.push_back(repeated); + } + + return torch::cat(outputs, 0); + } + + std::tuple> forward( + torch::Tensor hidden_states, + torch::Tensor grid_thw, // [batch,thw] + const ModelInputParams& input_params) { + hidden_states = patch_embed_(hidden_states); + auto pos_embeds = fast_pos_embed_interpolate(grid_thw); + hidden_states = hidden_states + pos_embeds; + // compute position embedding + auto rotary_pos_emb = rot_pos_emb(grid_thw); + // compute cu_seqlens + auto cu_seqlens = torch::repeat_interleave( + grid_thw.index({torch::indexing::Slice(), 1}) * + grid_thw.index({torch::indexing::Slice(), 2}), + grid_thw.index({torch::indexing::Slice(), 0})) + .cumsum(0, torch::kInt32); + namespace F = 
+    cu_seqlens = F::pad(
+        cu_seqlens, F::PadFuncOptions({1, 0}).mode(torch::kConstant).value(0));
+
+    // recover per-sequence lengths from the zero-padded cumulative sum
+    cu_seqlens = torch::diff(cu_seqlens);
+
+    m_cos = rotary_pos_emb.cos().type_as(hidden_states);
+    m_cos = m_cos.repeat({1, 2});
+    m_sin = rotary_pos_emb.sin().type_as(hidden_states);
+    m_sin = m_sin.repeat({1, 2});
+
+    ModelInputParams& input_params_new =
+        const_cast<ModelInputParams&>(input_params);
+    torch::Tensor cu_seqlens_cpu = cu_seqlens.cpu();
+    std::vector<int32_t> cu_seqlens_vec(
+        cu_seqlens_cpu.data_ptr<int32_t>(),
+        cu_seqlens_cpu.data_ptr<int32_t>() + cu_seqlens_cpu.numel());
+    std::vector<torch::Tensor> deepstack_feature_lists;
+    deepstack_feature_lists.reserve(deepstack_visual_indexes_.size());
+    for (int idx = 0; idx < blocks_->size(); ++idx) {
+      hidden_states = layers_[idx](hidden_states,
+                                   m_cos,
+                                   m_sin,
+                                   cu_seqlens,
+                                   cu_seqlens_vec,
+                                   input_params_new,
+                                   idx);
+      auto it = std::find(deepstack_visual_indexes_.begin(),
+                          deepstack_visual_indexes_.end(),
+                          idx);
+      // collect deepstack features at the configured encoder layers
+      if (it != deepstack_visual_indexes_.end()) {
+        int index = std::distance(deepstack_visual_indexes_.begin(), it);
+        deepstack_feature_lists.push_back(
+            deepstack_merger_layers_[index](hidden_states));
+      }
+    }
+    // adapter
+    hidden_states = merger_(hidden_states);
+    return std::make_tuple(hidden_states, deepstack_feature_lists);
+  }
+
+  void load_state_dict(const StateDict& state_dict) {
+    patch_embed_->load_state_dict(
+        state_dict.get_dict_with_prefix("patch_embed."));
+
+    for (int idx = 0; idx < layers_.size(); ++idx) {
+      layers_[idx]->load_state_dict(state_dict.get_dict_with_prefix(
+          "blocks." + std::to_string(idx) + "."));
+    }
+
+    merger_->load_state_dict(state_dict.get_dict_with_prefix("merger."));
+
+    for (int idx = 0; idx < deepstack_merger_layers_.size(); ++idx) {
+      deepstack_merger_layers_[idx]->load_state_dict(
+          state_dict.get_dict_with_prefix("deepstack_merger_list." +
+                                          std::to_string(idx) + "."));
+    }
+
+    const auto& emb_dict = state_dict.get_dict_with_prefix("pos_embed.");
+    const auto& emb_weight = emb_dict.get_tensor("weight");
+    if (emb_weight.defined()) {
+      CHECK_EQ(emb_->weight.sizes(), emb_weight.sizes())
+          << "weight size mismatch for " << name();
+      emb_->weight.data().copy_(emb_weight);
+      is_emb_weight_loaded = true;
+    }
+  }
+
+  void verify_loaded_weights(const std::string& prefix) const {
+    patch_embed_->verify_loaded_weights(prefix + "patch_embed.");
+    for (int idx = 0; idx < blocks_->size(); ++idx) {
+      layers_[idx]->verify_loaded_weights(prefix + "blocks." +
+                                          std::to_string(idx) + ".");
+    }
+    merger_->verify_loaded_weights(prefix + "merger.");
+
+    for (int idx = 0; idx < deepstack_merger_layers_.size(); ++idx) {
+      deepstack_merger_layers_[idx]->verify_loaded_weights(
+          prefix + "deepstack_merger_list." + std::to_string(idx) + ".");
+    }
+    CHECK(is_emb_weight_loaded)
+        << "weight is not loaded for " << prefix + "pos_embed" + ".weight";
+  }
+
+  void merge_loaded_weights() {
+    for (int idx = 0; idx < layers_.size(); ++idx) {
+      layers_[idx]->merge_loaded_weights();
+    }
+  }
+
+ private:
+  int hidden_size_ = 0;
+  int num_heads_ = 0;
+  int window_size_ = 0;
+  int patch_size_ = 0;
+  int spatial_merge_size_ = 0;
+  std::vector<int> deepstack_visual_indexes_;
+  int spatial_merge_unit_ = 0;
+  int64_t num_position_embeddings_ = 0;
+  int num_grid_per_side_ = 0;
+
+  Qwen3_VisionPatchEmbed patch_embed_{nullptr};
+  Qwen3_VisionRotaryEmbedding rotary_pos_emb_{nullptr};
+  torch::nn::Embedding emb_{nullptr};
+
+  torch::nn::ModuleList blocks_{nullptr};
+  std::vector<Qwen3_VisionBlock> layers_;
+
+  torch::nn::ModuleList deepstack_mergers_{nullptr};
+  std::vector<Qwen3_VisionPatchMerger> deepstack_merger_layers_;
+  Qwen3_VisionPatchMerger merger_{nullptr};
+
+  torch::Tensor m_cos;
+  torch::Tensor m_sin;
+  int device_id = 0;
+  bool is_emb_weight_loaded = false;
+  torch::TensorOptions options_;
+};
+TORCH_MODULE(Qwen3_VisionTransformer);
+
+struct Qwen3_VLImageInputs {
+  torch::Tensor pixel_values;
+  torch::Tensor image_grid_thw;
+};
+
+struct Qwen3_VLVideoInputs {
+  torch::Tensor pixel_values_videos;
+  torch::Tensor video_grid_thw;
+  torch::Tensor second_per_grid_ts;
+};
+
+class Qwen3_VLForConditionalGenerationImpl : public torch::nn::Module {
+ public:
+  Qwen3_VLForConditionalGenerationImpl(const ModelContext& context)
+      : model_args_(context.get_model_args()),
+        options_(context.get_tensor_options()) {
+    visual_ = register_module("visual", Qwen3_VisionTransformer(context));
+    language_model_ =
+        register_module("language_model", QWen3ForCausalLM(context));
+  }
+
+  torch::Tensor get_input_embeddings(
+      torch::Tensor input_ids,
+      const std::optional<Qwen3_VLImageInputs>& image_input,
+      const std::optional<Qwen3_VLVideoInputs>& video_input,
+      const ModelInputParams& input_params) {
+    auto inputs_embeds = language_model_->get_input_embeddings(input_ids);
+    if (image_input) {
+      // encode image patches and collect deepstack features
+      auto [image_embeds, deep_stacks] =
+          visual_(image_input->pixel_values.to(options_),
+                  image_input->image_grid_thw,
+                  input_params);
+      input_params.deep_stacks = deep_stacks;
+      // scatter image embeddings into the image-token positions
+      auto is_multimodal = torch::isin(input_ids, model_args_.image_token_id());
+      input_params.visual_pos_masks = is_multimodal;
+      inputs_embeds.index_put_({is_multimodal}, image_embeds);
+    }
+    return inputs_embeds;
+  }
+
+  torch::Tensor forward(const std::vector<torch::Tensor>& tokens,
+                        const std::vector<torch::Tensor>& positions,
+                        std::vector<KVCache>& kv_caches,
+                        const std::vector<ModelInputParams>& input_params) {
+    torch::NoGradGuard no_grad;
+    const auto& mm_data = input_params[0].mm_data;
+    torch::Tensor pixel_values;
+    if (const auto& res = mm_data.get<torch::Tensor>("pixel_values")) {
+      pixel_values = res.value();
+    }
+
+    torch::Tensor image_grid_thw;
+    if (const auto& res = mm_data.get<torch::Tensor>("image_grid_thw")) {
+      image_grid_thw = res.value();
+    }
+    std::optional<Qwen3_VLImageInputs> image_inputs;
+    std::optional<Qwen3_VLVideoInputs> video_inputs;
+
+    if (pixel_values.defined() && image_grid_thw.defined()) {
+      image_inputs = Qwen3_VLImageInputs{pixel_values, image_grid_thw};
+    }
+
+    auto inputs_embeds = get_input_embeddings(
+        tokens[0], image_inputs, video_inputs, input_params[0]);
+    input_params[0].input_embedding = inputs_embeds;
+    auto emb = language_model_(tokens, positions, kv_caches, input_params);
+
+    return emb;
+  }
+
+  torch::Tensor logits(const torch::Tensor& hidden_states,
+                       const torch::Tensor& seleted_idxes) {
+    return language_model_->logits(hidden_states, seleted_idxes);
+  }
+
+  void load_model(std::unique_ptr<ModelLoader> loader) {
+    for (const auto& state_dict : loader->get_state_dicts()) {
+      visual_->load_state_dict(
+          state_dict->get_dict_with_prefix("model.visual."));
+    }
+    // verify and finalize the vision weights
+    visual_->verify_loaded_weights("model.visual.");
+    visual_->merge_loaded_weights();
+    if (!model_args_.image_embedding_mode()) {
+      language_model_->load_model(std::move(loader), "model.language_model.");
+    }
+  }
+
+  layer::LmHead get_lm_head() { return language_model_->get_lm_head(); }
+  void set_lm_head(layer::LmHead& head) { language_model_->set_lm_head(head); }
+
+  std::vector<layer::WordEmbedding> get_word_embedding() {
+    return language_model_->get_word_embedding();
+  }
+
+  void set_word_embedding(std::vector<layer::WordEmbedding>& word_embedding) {
+    language_model_->set_word_embedding(word_embedding);
+  }
+
+ private:
+  ModelArgs model_args_;
+  torch::TensorOptions options_;
+
+  Qwen3_VisionTransformer visual_{nullptr};
+  QWen3ForCausalLM language_model_{nullptr};
+};
+TORCH_MODULE(Qwen3_VLForConditionalGeneration);
+
+REGISTER_INPUT_PROCESSOR(qwen3_vl, Qwen2_5_VLInputProcessor);
+REGISTER_CAUSAL_VLM_MODEL(qwen3_vl, Qwen3_VLForConditionalGeneration);
+REGISTER_IMAGE_PROCESSOR(qwen3_vl, Qwen2VLImageProcessor);
+
+REGISTER_MODEL_ARGS(qwen3_vl, [&] {
+  // text config
+  LOAD_ARG_OR(model_type, "model_type", "qwen3_vl");
+  LOAD_ARG_OR(bos_token_id, "text_config.bos_token_id", 151643);
+  LOAD_ARG_OR(eos_token_id, "text_config.eos_token_id", 151645);
+  LOAD_ARG_OR(
+      vision_start_token_id, "text_config.vision_start_token_id", 151652);
+  LOAD_ARG_OR(vision_end_token_id, "text_config.vision_end_token_id", 151653);
+  LOAD_ARG_OR(vision_token_id, "text_config.vision_token_id", 151654);
+  LOAD_ARG_OR(image_token_id, "text_config.image_token_id", 151655);
+  LOAD_ARG_OR(video_token_id, "text_config.video_token_id", 151656);
+  LOAD_ARG_OR(hidden_act, "text_config.hidden_act", "silu");
+  LOAD_ARG_OR(hidden_size, "text_config.hidden_size", 3584);
+  LOAD_ARG_OR(intermediate_size, "text_config.intermediate_size", 18944);
+  LOAD_ARG_OR(
+      max_position_embeddings, "text_config.max_position_embeddings", 128000);
+  LOAD_ARG_OR(max_window_layers, "text_config.max_window_layers", 28);
+  LOAD_ARG_OR(n_heads, "text_config.num_attention_heads", 32);
+  LOAD_ARG_OR(n_layers, "text_config.num_hidden_layers", 48);
+  LOAD_ARG_OR(n_kv_heads, "text_config.num_key_value_heads", 4);
+  LOAD_ARG_OR(rms_norm_eps, "text_config.rms_norm_eps", 1e-06);
+  LOAD_ARG_OR(rope_theta, "text_config.rope_theta", 5000000.0f);
+  LOAD_ARG_OR(sliding_window, "text_config.sliding_window", 32768);
+  LOAD_ARG_OR(tie_word_embeddings, "text_config.tie_word_embeddings", false);
+  LOAD_ARG(rope_scaling_mrope_section,
+           "text_config.rope_scaling.mrope_section");
+  LOAD_ARG_OR(dtype, "text_config.dtype", "bfloat16");
+  LOAD_ARG_OR(use_sliding_window, "use_sliding_window", false);
+  LOAD_ARG_OR_FUNC(head_dim, "head_dim", [&] {
+    return args->hidden_size() / args->n_heads();
+  });
+  // vision config
+  LOAD_ARG_OR(mm_num_hidden_layers, "vision_config.depth", 27);
+  LOAD_ARG_OR(mm_hidden_act, "vision_config.hidden_act", "gelu_pytorch_tanh");
+  LOAD_ARG_OR(mm_hidden_size, "vision_config.hidden_size", 1152);
+  LOAD_ARG_OR(mm_intermediate_size, "vision_config.intermediate_size", 4304);
+  LOAD_ARG_OR(mm_num_attention_heads, "vision_config.num_heads", 16);
+  LOAD_ARG_OR(mm_num_channels, "vision_config.in_channels", 3);
+  LOAD_ARG_OR(mm_projection_dim, "vision_config.out_hidden_size", 4096);
+  LOAD_ARG_OR(mm_patch_size, "vision_config.patch_size", 16);
+  LOAD_ARG_OR(mm_num_position_embeddings,
+              "vision_config.num_position_embeddings",
+              2304);
+  LOAD_ARG_OR(mm_spatial_merge_size, "vision_config.spatial_merge_size", 2);
+  LOAD_ARG(mm_deepstack_visual_indexes,
+           "vision_config.deepstack_visual_indexes");
+  LOAD_ARG_OR(mm_temporal_patch_size, "vision_config.temporal_patch_size", 2);
+  LOAD_ARG_OR_FUNC(mm_head_dim, "head_dim", [&] {
+    return args->mm_hidden_size() / args->mm_num_attention_heads();
+  });
+
+  LOAD_ARG_OR(
+      rope_scaling_rope_type, "vision_config.rope_scaling.type", "mrope");
+
+  LOAD_ARG_OR(vocab_size, "vocab_size", 152064);
+});
+}  // namespace xllm
diff --git a/xllm/processors/CMakeLists.txt b/xllm/processors/CMakeLists.txt
old mode 100644
new mode 100755
diff --git a/xllm/processors/qwen2_vl_image_processor.cpp b/xllm/processors/qwen2_vl_image_processor.cpp
old mode 100755
new mode 100644
index 796068ab..16adc17d
--- a/xllm/processors/qwen2_vl_image_processor.cpp
+++ b/xllm/processors/qwen2_vl_image_processor.cpp
@@ -63,10 +63,15 @@ std::optional smart_resize(int height,
 Qwen2VLImageProcessor::Qwen2VLImageProcessor(const ModelArgs& args) {
   image_mean_ = args.mm_image_normalize_mean();
   image_std_ = args.mm_image_normalize_std();
-
-  min_pixels_ = args.mm_image_min_pixels();
-  max_pixels_ = args.mm_image_max_pixels();
-
+  // Prefer explicit min/max pixel bounds; fall back to the
+  // size.shortest_edge / size.longest_edge preprocessor fields.
+  if (args.mm_image_max_pixels() && args.mm_image_min_pixels()) {
+    min_pixels_ = args.mm_image_min_pixels();
+    max_pixels_ = args.mm_image_max_pixels();
+  } else if (args.mm_image_shortest_edge() && args.mm_image_longest_edge()) {
+    min_pixels_ = args.mm_image_shortest_edge();
+    max_pixels_ = args.mm_image_longest_edge();
+  }
   patch_size_ = args.mm_image_patch_size();
   temporal_patch_size_ = args.mm_image_temporal_patch_size();
 
@@ -139,8 +144,8 @@ bool Qwen2VLImageProcessor::process_image(
   auto size = smart_resize(resized_height,
                            resized_width,
                            patch_size_ * merge_size_,
-                           size_["shortest_edge"],
-                           size_["longest_edge"]);
+                           min_pixels_,
+                           max_pixels_);
   if (!size) {
     return false;
   }
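Reviewer sketch (not part of the patch): fast_pos_embed_interpolate above resamples the learned square position-embedding table with bilinear gather-and-blend weights. The standalone libtorch sketch below shows only that weighting scheme; the table, side, and function names are illustrative assumptions, and the patch's emb_ lookup table is replaced by a random tensor.

// pos_embed_bilinear_sketch.cpp -- illustrative only; names are assumptions.
#include <torch/torch.h>
#include <iostream>

// Resample a learned [side*side, dim] position table to an h x w grid.
torch::Tensor interpolate_pos_embed(const torch::Tensor& table,
                                    int64_t side,
                                    int64_t h,
                                    int64_t w) {
  // fractional source coordinates for each target row/column
  auto h_idx = torch::linspace(0, side - 1, h);
  auto w_idx = torch::linspace(0, side - 1, w);
  auto h_floor = h_idx.to(torch::kLong);  // truncation == floor (coords >= 0)
  auto w_floor = w_idx.to(torch::kLong);
  auto h_ceil = torch::clamp(h_floor + 1, 0, side - 1);
  auto w_ceil = torch::clamp(w_floor + 1, 0, side - 1);
  auto dh = (h_idx - h_floor).unsqueeze(1);  // [h, 1], broadcasts over w
  auto dw = (w_idx - w_floor).unsqueeze(0);  // [1, w]

  // bilinear weights of the four surrounding grid points
  auto w11 = dh * dw;              // (h_ceil,  w_ceil)
  auto w10 = dh - w11;             // (h_ceil,  w_floor)
  auto w01 = dw - w11;             // (h_floor, w_ceil)
  auto w00 = 1.0 - dh - dw + w11;  // (h_floor, w_floor)

  // flat [h*w] indices into the row-major side x side table
  auto flat = [&](const torch::Tensor& r, const torch::Tensor& c) {
    return (r.unsqueeze(1) * side + c.unsqueeze(0)).reshape({-1});
  };
  // gather one corner's embeddings and scale by its weight
  auto blend = [&](const torch::Tensor& idx, const torch::Tensor& wt) {
    return table.index({idx}) * wt.reshape({-1, 1});
  };
  return blend(flat(h_floor, w_floor), w00) +
         blend(flat(h_floor, w_ceil), w01) +
         blend(flat(h_ceil, w_floor), w10) +
         blend(flat(h_ceil, w_ceil), w11);
}

int main() {
  const int64_t side = 48, dim = 8;  // 48 * 48 = 2304 learned positions
  auto table = torch::randn({side * side, dim});
  auto out = interpolate_pos_embed(table, side, /*h=*/6, /*w=*/10);
  std::cout << out.sizes() << std::endl;  // prints [60, 8]
}

The .to(torch::kLong) floor trick mirrors the patch: linspace outputs are non-negative, so truncation and floor coincide and no separate torch::floor call is needed.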