@@ -533,6 +533,85 @@ def copy_weights_qwen_2_5(
                pbar.update(progress_per_file)


+def copy_weights_olmo2(
+    config: Config,
+    qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]],
+    state_dict: Dict[str, torch.Tensor],
+    hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
+    saver: Optional[incremental_save] = None,
+    dtype: Optional[torch.dtype] = None,
+    pbar: Optional[tqdm] = None,
+    progress_per_file: Optional[float] = None,
+    debug_mode: Optional[bool] = False,
+) -> None:
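+    # Hugging Face OLMo-2 parameter names -> litgpt names; entries mapped to None
+    # (the split q/k/v projections and the rotary inv_freq buffer) are not copied
+    # directly: q/k/v are collected and fused into a single qkv matrix below.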
+    weight_map = {
+        "model.embed_tokens.weight": "transformer.wte.weight",
+        "model.layers.{}.self_attn.q_norm.weight": "transformer.h.{}.attn.norm_q.weight",
+        "model.layers.{}.self_attn.q_proj.weight": None,
+        "model.layers.{}.self_attn.k_norm.weight": "transformer.h.{}.attn.norm_k.weight",
+        "model.layers.{}.self_attn.k_proj.weight": None,
+        "model.layers.{}.self_attn.v_proj.weight": None,
+        "model.layers.{}.self_attn.o_proj.weight": "transformer.h.{}.attn.proj.weight",
+        "model.layers.{}.self_attn.rotary_emb.inv_freq": None,
+        "model.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.post_attention_norm.weight",
+        "model.layers.{}.post_attention_layernorm.bias": "transformer.h.{}.post_attention_norm.bias",
+        "model.layers.{}.post_feedforward_layernorm.weight": "transformer.h.{}.post_mlp_norm.weight",
+        "model.norm.weight": "transformer.ln_f.weight",
+        "model.norm.bias": "transformer.ln_f.bias",
+        "lm_head.weight": "lm_head.weight",
+    }
+    if config.mlp_class_name in ("LLaMAMLP", "GemmaMLP"):
+        weight_map.update(
+            {
+                "model.layers.{}.mlp.gate_proj.weight": "transformer.h.{}.mlp.fc_1.weight",
+                "model.layers.{}.mlp.up_proj.weight": "transformer.h.{}.mlp.fc_2.weight",
+                "model.layers.{}.mlp.down_proj.weight": "transformer.h.{}.mlp.proj.weight",
+            }
+        )
+    else:
+        raise NotImplementedError
+
+    if progress_per_file is not None:
+        progress_per_file = progress_per_file / max(1, len(hf_weights) + len(qkv_weights))
+
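+    # Copy each weight from the HF shard under its litgpt name; the split q/k/v
+    # projections are stashed in qkv_weights so they can be fused once all three
+    # parts of a layer are available.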
+    for from_name, param in hf_weights.items():
+        name_template, *ids = layer_template(from_name, num_matches=2)
+        to_name = weight_map[name_template]
+        param = load_param(param, from_name, dtype, verbose=debug_mode)
+        if any(w in from_name for w in ("q_proj", "k_proj", "v_proj")):
+            qkv = qkv_weights.setdefault(ids[0], defaultdict(dict))
+            weight_name, weight_type = from_name.split(".")[-2:]
+            qkv[weight_type][weight_name] = param
+        if to_name is None:
+            continue
+        to_name = to_name.format(*ids)
+        if saver is not None:
+            param = saver.store_early(param)
+        state_dict[to_name] = param
+
+        if progress_per_file is not None:
+            pbar.update(progress_per_file)
+
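+    # If the checkpoint ships no separate lm_head, reuse the (tied) token-embedding
+    # weights copied above.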
595+ if "lm_head.weight" not in state_dict :
596+ state_dict ["lm_head.weight" ] = state_dict ["transformer.wte.weight" ]
597+
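+    # Fuse the collected q, k and v projections of each layer into litgpt's single
+    # attn.qkv matrix by concatenating them along the output dimension.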
+    for i in list(qkv_weights):
+        for weight_type in list(qkv_weights[i]):
+            qkv = qkv_weights[i][weight_type]
+            if len(qkv) != 3:
+                # qkv is split across different .bin files
+                continue
+            q = load_param(qkv["q_proj"], f"layer {i} q {weight_type}", dtype, verbose=debug_mode)
+            k = load_param(qkv["k_proj"], f"layer {i} k {weight_type}", dtype, verbose=debug_mode)
+            v = load_param(qkv["v_proj"], f"layer {i} v {weight_type}", dtype, verbose=debug_mode)
+            qkv = torch.cat((q, k, v))
+            state_dict[f"transformer.h.{i}.attn.qkv.{weight_type}"] = qkv
+            del qkv_weights[i][weight_type]
+
+            if progress_per_file is not None:
+                pbar.update(progress_per_file)
+
+
def copy_weights_qwen_3(
    config: Config,
    qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]],
@@ -693,6 +772,10 @@ def convert_hf_checkpoint(
        # holder to reconstitute the split q, k, v
        qkv_weights = {}
        copy_fn = partial(copy_weights_qwen_2_5, config, qkv_weights)
+    elif model_name.lower().startswith("olmo-2-"):
+        # holder to reconstitute the split q, k, v
+        qkv_weights = {}
+        copy_fn = partial(copy_weights_olmo2, config, qkv_weights)
    elif model_name.lower().startswith("qwen3"):
        # holder to reconstitute the split q, k, v
        qkv_weights = {}