@@ -232,6 +232,7 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device:
         rope_max_pos_emb = config.max_position_embeddings
         rope_base = config.rope_theta
         rope_scaling = config.rope_scaling
+        partial_rotary_factor = getattr(config, 'partial_rotary_factor', None)
         if rope_scaling is not None:
             scaling_type = rope_scaling['type']
             assert scaling_type in ['longrope', 'su']
@@ -246,13 +247,15 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device:
                 rope_base,
                 longrope_params=longrope_params,
                 emb_type=emb_type,
+                partial_rotary_factor=partial_rotary_factor,
             )
         else:
             self.rotary_emb = build_rotary_embedding(
                 rope_dim,
                 rope_max_pos_emb,
                 rope_base,
                 emb_type=emb_type,
+                partial_rotary_factor=partial_rotary_factor,
             )
 
     def forward(
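The new `partial_rotary_factor` argument read from the config controls how much of each head's dimension receives rotary position embedding. The sketch below is a hypothetical illustration of that idea, not the `build_rotary_embedding` implementation: the function name `apply_partial_rotary` and the tensor layout are assumptions made for the example.

```python
# Hypothetical sketch: with a partial rotary factor, only the first
# rotary_dim = int(head_dim * partial_rotary_factor) channels of each head are
# rotated; the remaining channels pass through unchanged.
import torch


def apply_partial_rotary(q: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
                         partial_rotary_factor: float = 1.0) -> torch.Tensor:
    """Rotate only a leading slice of the head dimension of `q`.

    q:       [batch, seq, heads, head_dim]
    cos/sin: [seq, rotary_dim] precomputed rotary tables.
    """
    head_dim = q.shape[-1]
    rotary_dim = int(head_dim * partial_rotary_factor)
    q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]

    # standard rotate-half formulation, applied to the rotary slice only
    half = rotary_dim // 2
    q1, q2 = q_rot[..., :half], q_rot[..., half:]
    rotated = torch.cat((-q2, q1), dim=-1)
    cos = cos[None, :, None, :]
    sin = sin[None, :, None, :]
    q_rot = q_rot * cos + rotated * sin

    return torch.cat((q_rot, q_pass), dim=-1)


# usage: rotate half of an 8-dim head (rotary_dim = 4)
q = torch.randn(1, 4, 2, 8)
pos = torch.arange(4).float()
inv_freq = 1.0 / (10000 ** (torch.arange(0, 4, 2).float() / 4))
freqs = torch.einsum('i,j->ij', pos, inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)  # [seq, rotary_dim]
out = apply_partial_rotary(q, emb.cos(), emb.sin(), partial_rotary_factor=0.5)
```

With `partial_rotary_factor=1.0` this reduces to the standard rotary embedding, which is why the diff can pass `None` through unchanged for configs that do not define the field.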
@@ -348,6 +351,11 @@ def get_logits(self, hidden_states: torch.Tensor):
         """compute logits of the model output."""
         return self.lm_head(hidden_states)
 
+    def update_weights(self):
+        """update weights."""
+        if self.config.tie_word_embeddings:
+            self.lm_head.weight = self.model.embed_tokens.weight
+
     def get_input_embeddings(self):
         """get input embeddings."""
         return self.model.get_input_embeddings()
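`update_weights` re-establishes weight tying between the output projection and the input embedding after weights have been loaded. The snippet below is a minimal, self-contained sketch of that pattern; the class `TinyTiedLM` and its sizes are made up for illustration and are not part of this change.

```python
# Hypothetical illustration: when tie_word_embeddings is set, the output
# projection reuses the input embedding matrix, so the tie must be
# re-established whenever the weights are (re)loaded.
import torch
from torch import nn


class TinyTiedLM(nn.Module):

    def __init__(self, vocab_size: int, hidden_size: int, tie_word_embeddings: bool = True):
        super().__init__()
        self.tie_word_embeddings = tie_word_embeddings
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        self.update_weights()

    def update_weights(self):
        """Re-point lm_head at the embedding weight so both stay in sync."""
        if self.tie_word_embeddings:
            self.lm_head.weight = self.embed_tokens.weight


model = TinyTiedLM(vocab_size=32, hidden_size=8)
# both modules now share the same underlying storage
assert model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr()
```

Assigning the `nn.Parameter` itself, rather than copying its data, keeps the two modules pointing at the same storage, so any later update of the embedding weights is reflected in the logits head as well.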