
Commit 539de82

addressed m_emb and m_residual
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
1 parent ddfc6b1 commit 539de82

File tree

  • src/instructlab/dolomite/hf_models/model_conversion/llama.py

1 file changed, +28 -6 lines changed


src/instructlab/dolomite/hf_models/model_conversion/llama.py

+28 -6
@@ -229,6 +229,9 @@ def export_to_huggingface_llama(
         config.num_key_value_heads,
         config.n_embd // config.n_head,
         AttentionHeadType(config.attention_head_type),
+        m_emb=config.m_emb,
+        m_residual=config.m_residual,
+        # m_width=config.m_width,
     )

     SafeTensorsWeightsManager.save_state_dict(state_dict, save_path)
@@ -285,11 +288,25 @@ def _export_state_dict_to_huggingface(
     num_key_value_heads: int,
     head_dim: int,
     attention_head_type: AttentionHeadType,
+    m_residual: float = None,
+    m_emb: float = None,
+    m_width: float = None,
 ) -> None:
+    if m_residual is None:
+        m_residual = 1.
+    if m_emb is None:
+        m_emb = 1.
+
+    # NOTE: this will not work since the norms are tied
+    # has_m_width = False
+    # if m_width is None:
+    #     has_m_width = True
+    #     m_width = 1.
+
     state_dict = {
         "model.embed_tokens.weight": safetensors_weight_manager.get_tensor(
             "transformer.wte.weight"
-        ),
+        ) * m_emb,
         "model.norm.weight": safetensors_weight_manager.get_tensor(
             "transformer.ln_f.weight"
         ),
@@ -298,7 +315,12 @@ def _export_state_dict_to_huggingface(
     if safetensors_weight_manager.has_tensor("lm_head.weight"):
         state_dict["lm_head.weight"] = safetensors_weight_manager.get_tensor(
             "lm_head.weight"
-        )
+        ) / m_width
+    # elif has_m_width:
+    #     # in this case we cannot tie
+    #     state_dict["lm_head.weight"] = safetensors_weight_manager.get_tensor(
+    #         "transformer.wte.weight"
+    #     ) / m_width

     for layer_idx in range(num_layers):
         state_dict[f"model.layers.{layer_idx}.input_layernorm.weight"] = (
@@ -332,13 +354,13 @@ def _export_state_dict_to_huggingface(
         state_dict[f"model.layers.{layer_idx}.mlp.down_proj.weight"] = (
             safetensors_weight_manager.get_tensor(
                 f"transformer.h.{layer_idx}.mlp.c_proj.weight"
-            )
+            ) * m_residual
         )
         if f"transformer.h.{layer_idx}.mlp.c_proj.bias" in safetensors_weight_manager:
             state_dict[f"model.layers.{layer_idx}.mlp.down_proj.bias"] = (
                 safetensors_weight_manager.get_tensor(
                     f"transformer.h.{layer_idx}.mlp.c_proj.bias"
-                )
+                ) * m_residual
             )

         query_weight, key_weight, value_weight = (
@@ -376,12 +398,12 @@ def _export_state_dict_to_huggingface(
             safetensors_weight_manager.get_tensor(
                 f"transformer.h.{layer_idx}.attn.c_proj.weight"
             )
-        )
+        ) * m_residual
         if f"transformer.h.{layer_idx}.attn.c_proj.bias" in safetensors_weight_manager:
             state_dict[f"model.layers.{layer_idx}.self_attn.o_proj.bias"] = (
                 safetensors_weight_manager.get_tensor(
                     f"transformer.h.{layer_idx}.attn.c_proj.bias"
                 )
-            )
+            ) * m_residual

     return state_dict
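
Background sketch: the conversion above folds the width multipliers into the exported weights instead of applying them at runtime. Below is a minimal sketch of why that is equivalent for a single linear path, assuming the dolomite runtime multiplies the embedding output by m_emb and the residual-branch projection outputs by m_residual (the tensor names, shapes, and multiplier values here are illustrative only, not taken from the repository):

import torch

torch.manual_seed(0)

# Hypothetical sizes and multipliers, for illustration only.
m_emb, m_residual = 12.0, 0.22
wte = torch.randn(32, 16)        # embedding table, vocab_size x hidden_size
c_proj = torch.randn(16, 16)     # a residual-branch output projection
x = torch.randint(0, 32, (4,))   # a batch of token ids

# dolomite-style forward: multipliers applied at runtime
h = wte[x] * m_emb
out_runtime = (h @ c_proj.T) * m_residual

# HF-style forward after export: multipliers folded into the weights,
# mirroring "wte.weight * m_emb" and "c_proj.weight * m_residual" in the diff
wte_exported = wte * m_emb
c_proj_exported = c_proj * m_residual
out_exported = wte_exported[x] @ c_proj_exported.T

print(torch.allclose(out_runtime, out_exported))  # True, up to float rounding

This per-path equivalence is what lets the exporter drop the multipliers from the HuggingFace checkpoint; the commented-out m_width handling marks the case the NOTE above flags as problematic when weights are tied.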
