@@ -89,11 +89,15 @@ enum DecoderLayerTensorId : int {
   IN_MLP_CPROJ_SCALE = 48,  // scale
   IN_MLP_CPROJ_COMPRESS_IDX = 49,
 
-  Q_NORM_WEIGHT = 50,
-  K_NORM_WEIGHT = 51,
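+  // Four new per-layer slots; merge_loaded_weights() below fills them with
+  // broadcast copies of the w8a8 dequant scales/offsets (or a placeholder).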
+  IN_QKV_SCALE_FILL = 50,
+  IN_QKV_OFFSET_FILL = 51,
+  IN_MLP_SCALE_FILL = 52,
+  IN_MLP_OFFSET_FILL = 53,
+  Q_NORM_WEIGHT = 54,
+  K_NORM_WEIGHT = 55,
 };
 
-const uint64_t WEIGHT_COUNT_PER_LAYER = 52;
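+// 52 existing weights plus the 4 fill tensors added above.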
+const uint64_t WEIGHT_COUNT_PER_LAYER = 56;
 
 static std::vector<std::pair<int, std::string>> WEIGHT_MAPPING = {
     {IN_NORM_WEIGHT, "input_layernorm.weight"},
@@ -207,11 +211,16 @@ void NpuQwen3DecoderLayerImpl::param_from_args(
   param.useQKNorm = true;
 
   param.numHiddenLayers = args.n_layers();
-
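+  // AddNorm fusion: intra-layer on, inter-layer off; weight prefetch
+  // follows the existing gflag.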
+  param.enableIntraLayerAddNorm = true;
+  param.enableInterLayerAddNorm = false;
+  param.enablePreFetchWeight = FLAGS_enable_prefetch_weight;
   initialize_quantization_parameters(param);
 
   if (isPrefill) {
-    param.enableAclnnRmsNorm = quantize_type_.empty();
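+    // Equivalent to: !param.enableIntraLayerAddNorm && quantize_type_.empty()
+    // (presumably the fused AddNorm already covers the RmsNorm step).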
+    param.enableAclnnRmsNorm =
+        param.enableIntraLayerAddNorm && quantize_type_.empty()
+            ? false
+            : quantize_type_.empty();
     // for prefix cache without chunked prefill.
     if (FLAGS_enable_prefix_cache && !FLAGS_enable_chunked_prefill &&
         FLAGS_block_size != 128) {
@@ -383,6 +392,38 @@ void NpuQwen3DecoderLayerImpl::merge_loaded_weights() {
     at_weight_tensors_[idx] = at_placeholder_;
   }
 
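+  // Populate the new fill slots: for w8a8, broadcast each per-tensor scale/
+  // offset across the matching weight's last dimension and cast to bf16;
+  // unquantized (bf16/fp16) paths just reuse the shared placeholder.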
+  if (prefill_param_.enableIntraLayerAddNorm ||
+      prefill_param_.enableInterLayerAddNorm) {
+    if (quantize_type_.compare("w8a8") == 0) {
+      // w8a8: expand per-tensor quant params to full-width vectors.
+      torch::ScalarType weight_fill_dtype = torch::kBFloat16;
+      int64_t weight_attn_shape = at_weight_tensors_[IN_Q_WEIGHT].size(-1);
+      int64_t weight_mlp_shape = at_weight_tensors_[IN_MLP_W2_WEIGHT].size(-1);
+      at_weight_tensors_[IN_QKV_SCALE_FILL] = at_weight_tensors_[IN_Q_SCALE]
+                                                  .repeat(weight_attn_shape)
+                                                  .to(weight_fill_dtype);
+      at_weight_tensors_[IN_MLP_SCALE_FILL] =
+          at_weight_tensors_[IN_MLP_W2_SCALE]
+              .repeat(weight_mlp_shape)
+              .to(weight_fill_dtype);
+      at_weight_tensors_[IN_QKV_OFFSET_FILL] = at_weight_tensors_[IN_Q_OFFSET]
+                                                  .repeat(weight_attn_shape)
+                                                  .to(weight_fill_dtype);
+      at_weight_tensors_[IN_MLP_OFFSET_FILL] =
+          at_weight_tensors_[IN_MLP_W2_OFFSET]
+              .repeat(weight_mlp_shape)
+              .to(weight_fill_dtype);
+    } else {
+      // bfloat16 or float16: no dequant params, reuse the placeholder.
+      for (auto idx : {IN_QKV_SCALE_FILL,
+                       IN_QKV_OFFSET_FILL,
+                       IN_MLP_SCALE_FILL,
+                       IN_MLP_OFFSET_FILL}) {
+        at_weight_tensors_[idx] = at_placeholder_;
+      }
+    }
+  }
+
   c10_npu::NPUCachingAllocator::emptyCache();
   for (int i = 0; i < WEIGHT_COUNT_PER_LAYER; ++i) {
     atb_weight_tensors_[i] =