
Commit d19df7a

ysjprojects, Borda, pre-commit-ci[bot], lantiga, and shijie.yu authored
OLMo 2 (#1897)
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Jirka B <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Luca Antiga <[email protected]>
Co-authored-by: shijie.yu <[email protected]>
1 parent f99ca4e commit d19df7a

6 files changed: +290 additions, -5 deletions


litgpt/config.py

Lines changed: 65 additions & 0 deletions
@@ -38,6 +38,7 @@ class Config:
     norm_class_name: Literal["LayerNorm", "RMSNorm"] = "LayerNorm"
     norm_eps: float = 1e-5
     norm_qk: bool = False
+    norm_qk_type: Literal["default", "olmo2"] = "default"
     post_attention_norm: bool = False
     post_mlp_norm: bool = False
     parallel_residual: bool = True
@@ -91,6 +92,8 @@ class Config:
     scale_embeddings: bool = False
     lm_head_bias: bool = False
     final_logit_softcapping: Optional[float] = None
+    norm_1: bool = True
+    norm_2: bool = True
     # The base period of the RoPE embeddings for local attention.
     # If not provided, rope_theta will be used for both local and global attention.
     rope_local_base_freq: Optional[float] = None
@@ -930,6 +933,68 @@ def norm_class(self) -> Type:
 
 configs.extend(olmo)
 
+olmo2 = [
+    # https://huggingface.co/allenai/OLMo-2-1124-7B/blob/main/config.json
+    dict(
+        name="OLMo-2-1124-7B{}",
+        hf_config=dict(org="allenai", name="OLMo-2-1124-7B{}"),
+        vocab_size=100278,
+        padded_vocab_size=100352,
+        block_size=4096,
+        n_embd=4096,
+        n_layer=32,
+        n_head=32,
+        n_query_groups=32,
+        rotary_percentage=1.0,
+        parallel_residual=False,
+        bias=False,
+        norm_class_name="RMSNorm",
+        mlp_class_name="LLaMAMLP",
+        norm_eps=1e-06,
+        intermediate_size=11008,
+        rope_base=500000,
+        norm_qk=True,
+        post_mlp_norm=True,
+        norm_1=False,
+        norm_2=False,
+        norm_qk_type="olmo2",
+        post_attention_norm=True,
+    ),
+    # https://huggingface.co/allenai/OLMo-2-1124-13B/blob/main/config.json
+    dict(
+        name="OLMo-2-1124-13B{}",
+        hf_config=dict(org="allenai", name="OLMo-2-1124-13B{}"),
+        vocab_size=100278,
+        padded_vocab_size=100352,
+        block_size=4096,
+        n_embd=5120,
+        n_layer=40,
+        n_head=40,
+        n_query_groups=40,
+        rotary_percentage=1.0,
+        parallel_residual=False,
+        bias=False,
+        norm_class_name="RMSNorm",
+        mlp_class_name="LLaMAMLP",
+        norm_eps=1e-06,
+        intermediate_size=13824,
+        rope_base=500000,
+        norm_qk=True,
+        post_mlp_norm=True,
+        norm_1=False,
+        norm_2=False,
+        norm_qk_type="olmo2",
+        post_attention_norm=True,
+    ),
+]
+
+for c in olmo2:
+    for kind in ("", "-SFT", "-DPO", "-Instruct"):
+        copy = deepcopy(c)
+        copy["name"] = c["name"].format(kind)
+        copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
+        configs.append(copy)
+
 ###############
 # Google Gemma
 ###############
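Note: the "{}" placeholder in name and hf_config.name is filled in by the loop at the end of the hunk, so each OLMo 2 entry registers four configs (base, -SFT, -DPO, -Instruct), eight in total for the two sizes. A minimal standalone sketch of that expansion, using a trimmed-down dict rather than the full entries above:

from copy import deepcopy

# Trimmed-down stand-in for one olmo2 entry; only the templated fields are shown.
base = dict(name="OLMo-2-1124-7B{}", hf_config=dict(org="allenai", name="OLMo-2-1124-7B{}"))

expanded = []
for kind in ("", "-SFT", "-DPO", "-Instruct"):
    cfg = deepcopy(base)  # deepcopy so the nested hf_config dict is not shared between variants
    cfg["name"] = base["name"].format(kind)
    cfg["hf_config"]["name"] = base["hf_config"]["name"].format(kind)
    expanded.append(cfg)

print([c["name"] for c in expanded])
# ['OLMo-2-1124-7B', 'OLMo-2-1124-7B-SFT', 'OLMo-2-1124-7B-DPO', 'OLMo-2-1124-7B-Instruct']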

litgpt/model.py

Lines changed: 18 additions & 5 deletions
@@ -271,12 +271,16 @@ def __init__(
                 " (non-parallel residual and shared attention norm)."
             )
 
-        self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps)
+        self.norm_1 = nn.Identity() if not config.norm_1 else config.norm_class(config.n_embd, eps=config.norm_eps)
         self.attn = CausalSelfAttention(config, block_idx)
         self.post_attention_norm = (
             config.norm_class(config.n_embd, eps=config.norm_eps) if config.post_attention_norm else nn.Identity()
         )
-        self.norm_2 = None if config.shared_attention_norm else config.norm_class(config.n_embd, eps=config.norm_eps)
+        self.norm_2 = (
+            nn.Identity()
+            if not config.norm_2
+            else (None if config.shared_attention_norm else config.norm_class(config.n_embd, eps=config.norm_eps))
+        )
         self.mlp = config.mlp_class(config)
         self.post_mlp_norm = (
             config.norm_class(config.n_embd, eps=config.norm_eps) if config.post_mlp_norm else nn.Identity()
@@ -325,6 +329,7 @@ def forward(
         else:
             x = attention_output + x
         x_normed = self.norm_2(x)
+
         return self.post_mlp_norm(self.mlp(x_normed)) + x
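Note: with norm_1=False and norm_2=False the two pre-norms above become nn.Identity(), so an OLMo 2 block normalizes only the sublayer outputs (post_attention_norm and post_mlp_norm) instead of the pre-norm Llama-style layout. A simplified sketch of that data flow, with toy Linear layers standing in for litgpt's CausalSelfAttention and LLaMAMLP, and PyTorch's nn.RMSNorm (torch >= 2.4) standing in for litgpt's RMSNorm; this illustrates the layout only and is not the actual Block class:

import torch
import torch.nn as nn

class Olmo2StyleBlock(nn.Module):
    # Hypothetical toy block: identity pre-norms, RMSNorm applied to each sublayer's output.
    def __init__(self, n_embd: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.norm_1 = nn.Identity()                         # norm_1=False
        self.attn = nn.Linear(n_embd, n_embd, bias=False)   # stand-in for CausalSelfAttention
        self.post_attention_norm = nn.RMSNorm(n_embd, eps=eps)
        self.norm_2 = nn.Identity()                         # norm_2=False
        self.mlp = nn.Linear(n_embd, n_embd, bias=False)    # stand-in for LLaMAMLP
        self.post_mlp_norm = nn.RMSNorm(n_embd, eps=eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Attention branch: the pre-norm is a no-op, the output is normalized before the residual add.
        x = self.post_attention_norm(self.attn(self.norm_1(x))) + x
        # MLP branch: same pattern, matching the "post_mlp_norm(mlp(x_normed)) + x" line above.
        return self.post_mlp_norm(self.mlp(self.norm_2(x))) + x

x = torch.randn(2, 8, 16)
print(Olmo2StyleBlock(16)(x).shape)  # torch.Size([2, 8, 16])

The attention-side changes for OLMo 2's full-width QK normalization follow in the next hunks.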

@@ -346,8 +351,12 @@ def __init__(self, config: Config, block_idx: int) -> None:
         self.apply_sliding_window_attention = config.sliding_window_indices[block_idx]
 
         if config.norm_qk:
-            self.norm_q = config.norm_class(config.head_size, eps=config.norm_eps)
-            self.norm_k = config.norm_class(config.head_size, eps=config.norm_eps)
+            norm_q_size = config.n_head * config.head_size if config.norm_qk_type == "olmo2" else config.head_size
+            norm_k_size = (
+                config.n_query_groups * config.head_size if config.norm_qk_type == "olmo2" else config.head_size
+            )
+            self.norm_q = config.norm_class(norm_q_size, eps=config.norm_eps)
+            self.norm_k = config.norm_class(norm_k_size, eps=config.norm_eps)
         else:
             self.norm_q = self.norm_k = None
 
@@ -387,6 +396,10 @@ def forward(
         # Split qkv into query, key and value matrices.
         q, k, v = qkv.split((query_size, key_size, value_size), dim=-1)  # 3x(B, T, C*)
 
+        if self.config.norm_qk and self.config.norm_qk_type == "olmo2":
+            q = self.norm_q(q)
+            k = self.norm_k(k)
+
         # To place the num_heads (nh) dimension right after the batch (B) dimension, the first step is to decouple the
         # embedding size (C) into num_heads (nh) and head_size (hs).
         q = q.view(B, T, n_head, head_size)  # (B, T, nh_q, hs)
@@ -400,7 +413,7 @@ def forward(
         k = k.transpose(1, 2)  # (B, nh_k, T, hs)
         v = v.transpose(1, 2)  # (B, nh_v, T, hs)
 
-        if self.config.norm_qk:
+        if self.config.norm_qk and self.config.norm_qk_type == "default":
             q = self.norm_q(q)
             k = self.norm_k(k)
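Note: norm_qk_type changes both the size of the Q/K norms and where they run. With "default" the norms are built over head_size and applied after the (B, nh, T, hs) reshape, i.e. per head; with "olmo2" they are built over the full projection width (n_head * head_size for Q, n_query_groups * head_size for K) and applied right after the fused qkv split, before the heads are separated. A shape-level sketch with toy dimensions, using PyTorch's nn.RMSNorm as a stand-in for config.norm_class:

import torch
import torch.nn as nn

B, T, n_head, n_query_groups, head_size = 2, 8, 4, 4, 16

q = torch.randn(B, T, n_head * head_size)           # query slice of the fused qkv projection
k = torch.randn(B, T, n_query_groups * head_size)   # key slice

# norm_qk_type="olmo2": normalize the whole projection, then split into heads.
norm_q = nn.RMSNorm(n_head * head_size, eps=1e-6)
norm_k = nn.RMSNorm(n_query_groups * head_size, eps=1e-6)
q_olmo2 = norm_q(q).view(B, T, n_head, head_size).transpose(1, 2)          # (B, nh_q, T, hs)
k_olmo2 = norm_k(k).view(B, T, n_query_groups, head_size).transpose(1, 2)  # (B, nh_k, T, hs)

# norm_qk_type="default": split into heads first, then normalize each head over head_size.
norm_q_per_head = nn.RMSNorm(head_size, eps=1e-6)
q_default = norm_q_per_head(q.view(B, T, n_head, head_size).transpose(1, 2))  # (B, nh_q, T, hs)

print(q_olmo2.shape, q_default.shape)  # both torch.Size([2, 4, 8, 16])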

litgpt/prompts.py

Lines changed: 2 additions & 0 deletions
@@ -472,6 +472,8 @@ def model_name_to_prompt_style(model_name: str) -> PromptStyle:
         return Llama3()
     if re.search("Llama-3.*-Instruct-*", model_name):
         return Llama3()
+    if re.search("OLMo-2.*-(Instruct|SFT|DPO)", model_name):
+        return Llama3()
     if re.search("R1", model_name):
         return R1Base()
     if re.search("FreeWilly2", model_name):

litgpt/scripts/convert_hf_checkpoint.py

Lines changed: 83 additions & 0 deletions
@@ -533,6 +533,85 @@ def copy_weights_qwen_2_5(
         pbar.update(progress_per_file)
 
 
+def copy_weights_olmo2(
+    config: Config,
+    qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]],
+    state_dict: Dict[str, torch.Tensor],
+    hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
+    saver: Optional[incremental_save] = None,
+    dtype: Optional[torch.dtype] = None,
+    pbar: Optional[tqdm] = None,
+    progress_per_file: Optional[float] = None,
+    debug_mode: Optional[bool] = False,
+) -> None:
+    weight_map = {
+        "model.embed_tokens.weight": "transformer.wte.weight",
+        "model.layers.{}.self_attn.q_norm.weight": "transformer.h.{}.attn.norm_q.weight",
+        "model.layers.{}.self_attn.q_proj.weight": None,
+        "model.layers.{}.self_attn.k_norm.weight": "transformer.h.{}.attn.norm_k.weight",
+        "model.layers.{}.self_attn.k_proj.weight": None,
+        "model.layers.{}.self_attn.v_proj.weight": None,
+        "model.layers.{}.self_attn.o_proj.weight": "transformer.h.{}.attn.proj.weight",
+        "model.layers.{}.self_attn.rotary_emb.inv_freq": None,
+        "model.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.post_attention_norm.weight",
+        "model.layers.{}.post_attention_layernorm.bias": "transformer.h.{}.post_attention_norm.bias",
+        "model.layers.{}.post_feedforward_layernorm.weight": "transformer.h.{}.post_mlp_norm.weight",
+        "model.norm.weight": "transformer.ln_f.weight",
+        "model.norm.bias": "transformer.ln_f.bias",
+        "lm_head.weight": "lm_head.weight",
+    }
+    if config.mlp_class_name in ("LLaMAMLP", "GemmaMLP"):
+        weight_map.update(
+            {
+                "model.layers.{}.mlp.gate_proj.weight": "transformer.h.{}.mlp.fc_1.weight",
+                "model.layers.{}.mlp.up_proj.weight": "transformer.h.{}.mlp.fc_2.weight",
+                "model.layers.{}.mlp.down_proj.weight": "transformer.h.{}.mlp.proj.weight",
+            }
+        )
+    else:
+        raise NotImplementedError
+
+    if progress_per_file is not None:
+        progress_per_file = progress_per_file / max(1, len(hf_weights) + len(qkv_weights))
+
+    for from_name, param in hf_weights.items():
+        name_template, *ids = layer_template(from_name, num_matches=2)
+        to_name = weight_map[name_template]
+        param = load_param(param, from_name, dtype, verbose=debug_mode)
+        if any(w in from_name for w in ("q_proj", "k_proj", "v_proj")):
+            qkv = qkv_weights.setdefault(ids[0], defaultdict(dict))
+            weight_name, weight_type = from_name.split(".")[-2:]
+            qkv[weight_type][weight_name] = param
+        if to_name is None:
+            continue
+        to_name = to_name.format(*ids)
+        if saver is not None:
+            param = saver.store_early(param)
+        state_dict[to_name] = param
+
+        if progress_per_file is not None:
+            pbar.update(progress_per_file)
+
+    if "lm_head.weight" not in state_dict:
+        state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"]
+
+    for i in list(qkv_weights):
+        for weight_type in list(qkv_weights[i]):
+            qkv = qkv_weights[i][weight_type]
+            if len(qkv) != 3:
+                # qkv is split across different .bin files
+                continue
+            q = load_param(qkv["q_proj"], f"layer {i} q {weight_type}", dtype, verbose=debug_mode)
+            k = load_param(qkv["k_proj"], f"layer {i} k {weight_type}", dtype, verbose=debug_mode)
+            v = load_param(qkv["v_proj"], f"layer {i} v {weight_type}", dtype, verbose=debug_mode)
+            qkv = torch.cat((q, k, v))
+            state_dict[f"transformer.h.{i}.attn.qkv.{weight_type}"] = qkv
+            del qkv_weights[i][weight_type]
+
+    if progress_per_file is not None:
+        pbar.update(progress_per_file)
+
+
 def copy_weights_qwen_3(
     config: Config,
     qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]],
@@ -693,6 +772,10 @@ def convert_hf_checkpoint(
         # holder to reconstitute the split q, k, v
         qkv_weights = {}
         copy_fn = partial(copy_weights_qwen_2_5, config, qkv_weights)
+    elif model_name.lower().startswith("olmo-2-"):
+        # holder to reconstitute the split q, k, v
+        qkv_weights = {}
+        copy_fn = partial(copy_weights_olmo2, config, qkv_weights)
     elif model_name.lower().startswith("qwen3"):
         # holder to reconstitute the split q, k, v
         qkv_weights = {}
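Note: copy_weights_olmo2 renames HF tensors through weight_map and buffers the per-layer q_proj/k_proj/v_proj matrices (which may be spread across checkpoint shards) in qkv_weights; once all three pieces of a layer are present they are concatenated into litgpt's fused qkv weight. The HF q_norm/k_norm weights map directly onto attn.norm_q/attn.norm_k, which lines up because the olmo2 QK norms cover the full projection width. A toy sketch of just the fusion step, with random tensors and deliberately small dimensions:

import torch

# Toy dimensions for illustration (OLMo-2-7B itself uses n_embd=4096, n_head=n_query_groups=32, head_size=128).
n_embd, n_head, n_query_groups, head_size = 64, 4, 4, 16

q_proj = torch.randn(n_head * head_size, n_embd)          # model.layers.{i}.self_attn.q_proj.weight
k_proj = torch.randn(n_query_groups * head_size, n_embd)  # model.layers.{i}.self_attn.k_proj.weight
v_proj = torch.randn(n_query_groups * head_size, n_embd)  # model.layers.{i}.self_attn.v_proj.weight

# Concatenate along dim 0 into the single fused projection litgpt expects.
qkv = torch.cat((q_proj, k_proj, v_proj))
print(qkv.shape)  # torch.Size([192, 64]) -> stored as transformer.h.{i}.attn.qkv.weight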

litgpt/scripts/convert_lit_checkpoint.py

Lines changed: 60 additions & 0 deletions
@@ -393,6 +393,64 @@ def copy_weights_qwen_2_5(
         state_dict[to_name] = param
 
 
+def copy_weights_olmo2(
+    config: Config,
+    state_dict: Dict[str, torch.Tensor],
+    lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
+    untie_weights: bool = False,
+    saver: Optional[incremental_save] = None,
+) -> None:
+    weight_map = {
+        "transformer.wte.weight": "model.embed_tokens.weight",
+        "transformer.h.{}.attn.proj.weight": "model.layers.{}.self_attn.o_proj.weight",
+        "transformer.h.{}.attn.norm_q.weight": "model.layers.{}.self_attn.q_norm.weight",
+        "transformer.h.{}.attn.norm_k.weight": "model.layers.{}.self_attn.k_norm.weight",
+        "transformer.h.{}.norm_2.weight": "model.layers.{}.post_attention_layernorm.weight",
+        "transformer.h.{}.norm_2.bias": "model.layers.{}.post_attention_layernorm.bias",
+        "transformer.h.{}.post_mlp_norm.weight": "model.layers.{}.post_feedforward_layernorm.weight",
+        "transformer.ln_f.weight": "model.norm.weight",
+        "transformer.ln_f.bias": "model.norm.bias",
+        "lm_head.weight": "lm_head.weight",
+    }
+    if config.mlp_class_name in ("LLaMAMLP", "GemmaMLP"):
+        weight_map.update(
+            {
+                "transformer.h.{}.mlp.fc_1.weight": "model.layers.{}.mlp.gate_proj.weight",
+                "transformer.h.{}.mlp.fc_2.weight": "model.layers.{}.mlp.up_proj.weight",
+                "transformer.h.{}.mlp.proj.weight": "model.layers.{}.mlp.down_proj.weight",
+            }
+        )
+    else:
+        raise NotImplementedError
+
+    for from_name, param in lit_weights.items():
+        if from_name == "lm_head.weight" and untie_weights:
+            continue
+        name_template, *ids = layer_template(from_name, num_matches=2)
+        param = load_param(param, from_name, None)
+        if from_name.endswith(".attn.qkv.weight"):
+            to_names = (
+                "model.layers.{}.self_attn.q_proj.weight".format(*ids),
+                "model.layers.{}.self_attn.k_proj.weight".format(*ids),
+                "model.layers.{}.self_attn.v_proj.weight".format(*ids),
+            )
+            params = param.split(
+                (
+                    config.n_head * config.head_size,
+                    config.n_query_groups * config.head_size,
+                    config.n_query_groups * config.head_size,
+                )
+            )
+        else:
+            to_names = (weight_map[name_template].format(*ids),)
+            params = (param,)
+
+        for to_name, param in zip(to_names, params):
+            if saver is not None:
+                param = saver.store_early(param)
+            state_dict[to_name] = param
+
+
 def copy_weights_qwen_3(
     config: Config,
     state_dict: Dict[str, torch.Tensor],
@@ -487,6 +545,8 @@ def convert_lit_checkpoint(checkpoint_dir: Path, output_dir: Path) -> None:
         copy_fn = partial(copy_weights_phi, config)
     elif config.name.lower().startswith(("qwen2.5", "qwq")):
         copy_fn = partial(copy_weights_qwen_2_5, config)
+    elif config.name.lower().startswith("olmo-2-"):
+        copy_fn = partial(copy_weights_olmo2, config)
     elif config.name.lower().startswith("qwen3"):
         copy_fn = partial(copy_weights_qwen_3, config)
     elif config.mlp_class_name in ("LLaMAMLP", "GemmaMLP", "LLaMAMoE"):
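Note: the lit-to-HF direction inverts the fusion: each transformer.h.{i}.attn.qkv.weight is split along dim 0 into pieces sized n_head * head_size, n_query_groups * head_size, and n_query_groups * head_size (all equal for OLMo 2, where n_head == n_query_groups), and everything else is renamed through the inverse weight_map. A toy sketch of the split, mirroring the fusion example above:

import torch

n_embd, n_head, n_query_groups, head_size = 64, 4, 4, 16
qkv = torch.randn((n_head + 2 * n_query_groups) * head_size, n_embd)  # fused transformer.h.{i}.attn.qkv.weight

# Inverse of the HF -> litgpt concatenation: split along dim 0 back into the three HF projections.
q, k, v = qkv.split((n_head * head_size, n_query_groups * head_size, n_query_groups * head_size))
print(q.shape, k.shape, v.shape)  # each torch.Size([64, 64])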
