Skip to content

Commit f74f283

Browse files
authored
feat: support PreFetchWeight and IntraAddNorm for qwen3-dense model. (#304)
1 parent 7792c30 commit f74f283

File tree

3 files changed

+55
-5
lines changed

3 files changed

+55
-5
lines changed

xllm/core/common/global_flags.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,3 +389,10 @@ DEFINE_string(reasoning_parser,
389389

390390
// --- qwen3 reranker config ---
391391
DEFINE_bool(enable_qwen3_reranker, false, "Whether to enable qwen3 reranker.");
392+
393+
DEFINE_bool(
394+
enable_prefetch_weight,
395+
false,
396+
"Whether to enable prefetch weight, only applicable to Qwen3-dense model."
397+
"The default prefetching ratio for gateup weight is 40%."
398+
"If adjustments are needed, e.g. export PREFETCH_COEFFOCIENT=0.5");

xllm/core/common/global_flags.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,3 +202,5 @@ DECLARE_bool(enable_qwen3_reranker);
202202
DECLARE_string(reasoning_parser);
203203

204204
DECLARE_bool(enable_shm);
205+
206+
DECLARE_bool(enable_prefetch_weight);

xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.cpp

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -89,11 +89,15 @@ enum DecoderLayerTensorId : int {
8989
IN_MLP_CPROJ_SCALE = 48, // scale
9090
IN_MLP_CPROJ_COMPRESS_IDX = 49,
9191

92-
Q_NORM_WEIGHT = 50,
93-
K_NORM_WEIGHT = 51,
92+
IN_QKV_SCALE_FILL = 50,
93+
IN_QKV_OFFSET_FILL = 51,
94+
IN_MLP_SCALE_FILL = 52,
95+
IN_MLP_OFFSET_FILL = 53,
96+
Q_NORM_WEIGHT = 54,
97+
K_NORM_WEIGHT = 55,
9498
};
9599

96-
const uint64_t WEIGHT_COUNT_PER_LAYER = 52;
100+
const uint64_t WEIGHT_COUNT_PER_LAYER = 56;
97101

98102
static std::vector<std::pair<int, std::string>> WEIGHT_MAPPING = {
99103
{IN_NORM_WEIGHT, "input_layernorm.weight"},
@@ -207,11 +211,16 @@ void NpuQwen3DecoderLayerImpl::param_from_args(
207211
param.useQKNorm = true;
208212

209213
param.numHiddenLayers = args.n_layers();
210-
214+
param.enableIntraLayerAddNorm = true;
215+
param.enableInterLayerAddNorm = false;
216+
param.enablePreFetchWeight = FLAGS_enable_prefetch_weight;
211217
initialize_quantization_parameters(param);
212218

213219
if (isPrefill) {
214-
param.enableAclnnRmsNorm = quantize_type_.empty();
220+
param.enableAclnnRmsNorm =
221+
param.enableIntraLayerAddNorm && quantize_type_.empty()
222+
? false
223+
: quantize_type_.empty();
215224
// for prefix cache without chunked prefill.
216225
if (FLAGS_enable_prefix_cache && !FLAGS_enable_chunked_prefill &&
217226
FLAGS_block_size != 128) {
@@ -383,6 +392,38 @@ void NpuQwen3DecoderLayerImpl::merge_loaded_weights() {
383392
at_weight_tensors_[idx] = at_placeholder_;
384393
}
385394

395+
if (prefill_param_.enableIntraLayerAddNorm ||
396+
prefill_param_.enableInterLayerAddNorm) {
397+
if (quantize_type_.compare("w8a8") == 0) {
398+
// quantize
399+
torch::ScalarType weight_fill_dtype = torch::kBFloat16;
400+
int64_t weight_attn_shape = at_weight_tensors_[IN_Q_WEIGHT].size(-1);
401+
int64_t weight_mlp_shape = at_weight_tensors_[IN_MLP_W2_WEIGHT].size(-1);
402+
at_weight_tensors_[IN_QKV_SCALE_FILL] = at_weight_tensors_[IN_Q_SCALE]
403+
.repeat(weight_attn_shape)
404+
.to(weight_fill_dtype);
405+
at_weight_tensors_[IN_MLP_SCALE_FILL] =
406+
at_weight_tensors_[IN_MLP_W2_SCALE]
407+
.repeat(weight_mlp_shape)
408+
.to(weight_fill_dtype);
409+
at_weight_tensors_[IN_QKV_OFFSET_FILL] = at_weight_tensors_[IN_Q_OFFSET]
410+
.repeat(weight_attn_shape)
411+
.to(weight_fill_dtype);
412+
at_weight_tensors_[IN_MLP_OFFSET_FILL] =
413+
at_weight_tensors_[IN_MLP_W2_OFFSET]
414+
.repeat(weight_mlp_shape)
415+
.to(weight_fill_dtype);
416+
} else {
417+
// bfloat16 or float16
418+
for (auto idx : {IN_QKV_SCALE_FILL,
419+
IN_QKV_OFFSET_FILL,
420+
IN_MLP_SCALE_FILL,
421+
IN_MLP_OFFSET_FILL}) {
422+
at_weight_tensors_[idx] = at_placeholder_;
423+
}
424+
}
425+
}
426+
386427
c10_npu::NPUCachingAllocator::emptyCache();
387428
for (int i = 0; i < WEIGHT_COUNT_PER_LAYER; ++i) {
388429
atb_weight_tensors_[i] =

0 commit comments

Comments
 (0)