@@ -229,6 +229,9 @@ def export_to_huggingface_llama(
         config.num_key_value_heads,
         config.n_embd // config.n_head,
         AttentionHeadType(config.attention_head_type),
+        m_emb=config.m_emb,
+        m_residual=config.m_residual,
+        # m_width=config.m_width,
     )

     SafeTensorsWeightsManager.save_state_dict(state_dict, save_path)
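Between the two hunks, for context: `m_emb` and `m_residual` are scalar multipliers read off the config, and the export folds them into the weights so the saved checkpoint runs on stock HuggingFace Llama code with no custom scaling. A minimal sketch of the embedding half of that equivalence, assuming `torch`; shapes and values are illustrative, not the model's:

```python
# Illustrative only: baking m_emb into the embedding matrix at export time
# is equivalent to multiplying the embedding lookup's output at run time.
import torch

m_emb = 2.0
wte = torch.randn(10, 4)                # vocab_size x hidden_size (made up)
input_ids = torch.tensor([1, 3, 3, 7])

folded = (wte * m_emb)[input_ids]       # what the exported checkpoint stores
runtime = wte[input_ids] * m_emb        # what the original model computes

assert torch.allclose(folded, runtime)
```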
@@ -285,11 +288,27 @@ def _export_state_dict_to_huggingface(
     num_key_value_heads: int,
     head_dim: int,
     attention_head_type: AttentionHeadType,
+    m_residual: float = None,
+    m_emb: float = None,
+    m_width: float = None,
 ) -> None:
+    if m_residual is None:
+        m_residual = 1.
+    if m_emb is None:
+        m_emb = 1.
+    if m_width is None:
+        m_width = 1.
+
+    # NOTE: this will not work since the norms are tied
+    # has_m_width = False
+    # if m_width is None:
+    #     has_m_width = True
+    #     m_width = 1.
+
     state_dict = {
         "model.embed_tokens.weight": safetensors_weight_manager.get_tensor(
             "transformer.wte.weight"
-        ),
+        ) * m_emb,
         "model.norm.weight": safetensors_weight_manager.get_tensor(
             "transformer.ln_f.weight"
         ),
@@ -298,7 +317,12 @@ def _export_state_dict_to_huggingface(
     if safetensors_weight_manager.has_tensor("lm_head.weight"):
         state_dict["lm_head.weight"] = safetensors_weight_manager.get_tensor(
             "lm_head.weight"
-        )
+        ) / m_width
+    # elif has_m_width:
+    #     # in this case we cannot tie
+    #     state_dict["lm_head.weight"] = safetensors_weight_manager.get_tensor(
+    #         "transformer.wte.weight"
+    #     ) / m_width

     for layer_idx in range(num_layers):
         state_dict[f"model.layers.{layer_idx}.input_layernorm.weight"] = (
for layer_idx in range (num_layers ):
304
326
state_dict [f"model.layers.{ layer_idx } .input_layernorm.weight" ] = (
@@ -332,13 +356,13 @@ def _export_state_dict_to_huggingface(
         state_dict[f"model.layers.{layer_idx}.mlp.down_proj.weight"] = (
             safetensors_weight_manager.get_tensor(
                 f"transformer.h.{layer_idx}.mlp.c_proj.weight"
-            )
+            ) * m_residual
         )
         if f"transformer.h.{layer_idx}.mlp.c_proj.bias" in safetensors_weight_manager:
             state_dict[f"model.layers.{layer_idx}.mlp.down_proj.bias"] = (
                 safetensors_weight_manager.get_tensor(
                     f"transformer.h.{layer_idx}.mlp.c_proj.bias"
-                )
+                ) * m_residual
             )

         query_weight, key_weight, value_weight = (
query_weight , key_weight , value_weight = (
@@ -376,12 +400,12 @@ def _export_state_dict_to_huggingface(
             safetensors_weight_manager.get_tensor(
                 f"transformer.h.{layer_idx}.attn.c_proj.weight"
             )
-        )
+        ) * m_residual
         if f"transformer.h.{layer_idx}.attn.c_proj.bias" in safetensors_weight_manager:
             state_dict[f"model.layers.{layer_idx}.self_attn.o_proj.bias"] = (
                 safetensors_weight_manager.get_tensor(
                     f"transformer.h.{layer_idx}.attn.c_proj.bias"
                 )
-            )
+            ) * m_residual

     return state_dict
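The `* m_residual` edits rely on the same linearity argument: for a linear layer `y = x @ W.T + b`, scaling both `W` and `b` by a constant scales the whole output by that constant, so the residual-branch multiplier can be baked into each `c_proj` projection at export time. A minimal sketch, again assuming `torch` with made-up shapes:

```python
# Illustrative only: folding m_residual into c_proj's weight and bias
# reproduces scaling the residual branch's output by m_residual.
import torch

torch.manual_seed(0)
m_residual = 0.25
x = torch.randn(4, 16)   # batch x in_features (made up)
w = torch.randn(8, 16)   # c_proj weight, (out_features, in_features)
b = torch.randn(8)       # c_proj bias

scaled_output = (x @ w.T + b) * m_residual        # original forward pass
folded = x @ (w * m_residual).T + b * m_residual  # exported weights

assert torch.allclose(scaled_output, folded, atol=1e-6)
```

The commented-out `m_width` path points at the one case this folding cannot express: with tied embeddings, a single tensor would need `* m_emb` on the input side and `/ m_width` on the logit side at once, so the export would first have to untie `lm_head.weight` from `transformer.wte.weight`.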