@@ -232,6 +232,7 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device:
         rope_max_pos_emb = config.max_position_embeddings
         rope_base = config.rope_theta
         rope_scaling = config.rope_scaling
+        partial_rotary_factor = getattr(config, 'partial_rotary_factor', None)
         if rope_scaling is not None:
             scaling_type = rope_scaling['type']
             assert scaling_type in ['longrope', 'su']
@@ -246,13 +247,15 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device:
                 rope_base,
                 longrope_params=longrope_params,
                 emb_type=emb_type,
+                partial_rotary_factor=partial_rotary_factor,
             )
         else:
             self.rotary_emb = build_rotary_embedding(
                 rope_dim,
                 rope_max_pos_emb,
                 rope_base,
                 emb_type=emb_type,
+                partial_rotary_factor=partial_rotary_factor,
             )
 
     def forward(
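
Note: `partial_rotary_factor` means only a fraction of each attention head's channels receive rotary position encoding; the remaining channels pass through unrotated. A minimal sketch of the idea, assuming `cos`/`sin` are already built for the reduced rotary dimension (the helper names below are illustrative, not the library's API):

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Swap and negate the two halves of the last dimension."""
    half = x.shape[-1] // 2
    return torch.cat((-x[..., half:], x[..., :half]), dim=-1)


def apply_partial_rope(q: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
                       partial_rotary_factor: float = 1.0) -> torch.Tensor:
    """Apply RoPE to the first `head_dim * partial_rotary_factor` channels only."""
    head_dim = q.shape[-1]
    rotary_dim = int(head_dim * partial_rotary_factor)
    q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
    q_rot = q_rot * cos + rotate_half(q_rot) * sin
    return torch.cat((q_rot, q_pass), dim=-1)
```

With `partial_rotary_factor=None`, the builder presumably falls back to rotating the full head dimension, so existing models are unaffected.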
@@ -348,6 +351,11 @@ def get_logits(self, hidden_states: torch.Tensor):
         """compute logits of the model output."""
         return self.lm_head(hidden_states)
 
+    def update_weights(self):
+        """update weights."""
+        if self.config.tie_word_embeddings:
+            self.lm_head.weight = self.model.embed_tokens.weight
+
     def get_input_embeddings(self):
         """get input embeddings."""
         return self.model.get_input_embeddings()
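
The new `update_weights` hook re-ties the output projection to the embedding table for configs with `tie_word_embeddings=True`, so the two modules share a single parameter after weights are loaded or updated. A toy standalone illustration of what that assignment does (sizes are made up for the example):

```python
import torch.nn as nn

embed_tokens = nn.Embedding(num_embeddings=128, embedding_dim=16)
lm_head = nn.Linear(16, 128, bias=False)

# Tie the weights: both modules now reference the same nn.Parameter,
# so updating one automatically updates the other.
lm_head.weight = embed_tokens.weight
assert lm_head.weight.data_ptr() == embed_tokens.weight.data_ptr()
```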