Commit 512949f

Merge pull request #312 from instructlab/granite-dolomite
Adding Dolomite Support and Bringing HF Padding-Free into Performance Parity
2 parents dc7c97d + 2a9626f commit 512949f

3 files changed: +27 -19 lines changed

requirements.txt

+1 -1

@@ -17,7 +17,7 @@ numba
 numpy>=1.23.5,<2.0.0 ; python_version == '3.10'
 numpy>=1.26.4,<2.0.0 ; python_version != '3.10'
 rich
-instructlab-dolomite>=0.1.1
+instructlab-dolomite>=0.2.0
 trl>=0.9.4
 peft
 pydantic>=2.7.0

src/instructlab/training/main_ds.py

+5

@@ -4,6 +4,7 @@
 from copy import deepcopy
 from pathlib import Path
 import argparse
+import json
 import math
 import os
 import re
@@ -528,6 +529,10 @@ def main(args):
     tokenizer = setup_tokenizer(args.model_name_or_path, SPECIAL_TOKENS, CHAT_TEMPLATE)
     # device = torch.device("cuda", args.local_rank)

+    with open(Path(args.model_name_or_path) / "config.json") as conf_json:
+        model_conf = json.load(conf_json)
+    args.model_type = model_conf["model_type"]
+
     #### distributed init #####
     torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
     args.local_rank = int(os.environ["LOCAL_RANK"])
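
For context, the new lookup in main() just reads the model_type field that Hugging Face checkpoints already carry in config.json; that value is what later selects the architectures entry and the dolomite export type. A minimal standalone sketch of the same idea, with an illustrative checkpoint path and config contents that are not from this PR:

import json
from pathlib import Path

# Illustrative checkpoint directory; a HF-style config.json there might contain
# {"model_type": "granite", "vocab_size": 49152, ...}
ckpt_dir = Path("/models/granite-7b-base")

with open(ckpt_dir / "config.json") as conf_json:
    model_conf = json.load(conf_json)

# "granite" later maps to GraniteForCausalLM, "llama" to LlamaForCausalLM.
model_type = model_conf["model_type"]
print(model_type)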

src/instructlab/training/utils.py

+21 -18

@@ -10,7 +10,6 @@
 from typing import Any, List, Optional
 import importlib
 import inspect
-import json
 import logging
 import os
 import random
@@ -62,17 +61,10 @@ def check_valid_train_args(train_args: TrainingArgs):
             f"Provided path to model does not exist. Please make sure that you've passed a valid model and that it has appropriate permissions: {train_args.model_path}"
         )

-    if train_args.use_dolomite:
-        with open(Path(train_args.model_path) / "config.json") as conf_json:
-            model_conf = json.load(conf_json)
-        if model_conf["model_type"] == "granite":
-            raise RuntimeError(
-                "Converting Granite models to Dolomite format is currently unsupported."
-            )
-        if train_args.disable_flash_attn:
-            raise RuntimeError(
-                "ERROR: Trying to use dolomite padding-free transformer without flash attention is not supported"
-            )
+    if train_args.use_dolomite and train_args.disable_flash_attn:
+        raise RuntimeError(
+            "ERROR: Trying to use dolomite padding-free transformer without flash attention is not supported"
+        )

     if train_args.is_padding_free:
         print(
@@ -229,7 +221,7 @@ def pad_collate_fn(batch):

             input_ids.extend(item["input_ids"].tolist())
             labels.extend(item["labels"].tolist())
-            position_ids.extend(range(total_len, total_len + item_len))
+            position_ids.extend(range(item_len))

             total_len += item_len
             num_loss_counted_tokens += (item["labels"] != -100).sum().item()
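
The one-line collator change is easier to see with concrete lengths: position ids now restart at zero for every sample packed into the flattened batch instead of continuing across sample boundaries. A toy sketch, using made-up sequence lengths:

# Toy comparison of the old and new position_ids for a packed batch.
seq_lens = [3, 2, 4]  # illustrative per-sample token counts

old_position_ids, new_position_ids = [], []
total_len = 0
for item_len in seq_lens:
    old_position_ids.extend(range(total_len, total_len + item_len))  # before: cumulative
    new_position_ids.extend(range(item_len))                         # after: per-sequence
    total_len += item_len

print(old_position_ids)  # [0, 1, 2, 3, 4, 5, 6, 7, 8]
print(new_position_ids)  # [0, 1, 2, 0, 1, 0, 1, 2, 3]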
@@ -802,10 +794,21 @@ def _get_state_dict_patched(model, unwrap=False):

         output_dir.mkdir(parents=True, exist_ok=True)
         if not model.module.config.architectures and convert_dolomite:
-            model.module.config.architectures = ["LlamaForCausalLM"]
-            warnings.warn(
-                f"Adding architectures to ckpt: {model.module.config.architectures}",
-            )
+            arch_added = False
+            if args.model_type == "llama":
+                model.module.config.architectures = ["LlamaForCausalLM"]
+                arch_added = True
+            elif args.model_type == "granite":
+                model.module.config.architectures = ["GraniteForCausalLM"]
+                arch_added = True
+            if arch_added:
+                warnings.warn(
+                    f"Adding architectures to ckpt: {model.module.config.architectures}",
+                )
+            else:
+                warnings.warn(
+                    f"Converting from dolomite, but no architecture field added to config.json",
+                )
         model.module.config.to_json_file(output_config_file)
         tokenizer.save_pretrained(output_dir)
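
In effect, the checkpoint's config.json gains an architectures entry keyed off the detected model type, which downstream tooling can use to resolve the matching *ForCausalLM class when loading the converted checkpoint. A rough sketch of that mapping as a standalone helper; the helper and lookup table are illustrative, not part of the PR:

import warnings

# Hypothetical helper mirroring the branch above; not part of the PR.
ARCH_BY_MODEL_TYPE = {
    "llama": ["LlamaForCausalLM"],
    "granite": ["GraniteForCausalLM"],
}

def pick_architectures(model_type):
    arch = ARCH_BY_MODEL_TYPE.get(model_type)
    if arch is None:
        warnings.warn(
            "Converting from dolomite, but no architecture field added to config.json"
        )
    return arch

print(pick_architectures("granite"))  # ['GraniteForCausalLM']
print(pick_architectures("mixtral"))  # None, with a warning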

@@ -834,7 +837,7 @@ def _get_state_dict_patched(model, unwrap=False):
         export_to_huggingface(
             pretrained_model_name_or_path=tmpdir.name,
             save_path=final_output_dir,
-            model_type="llama",
+            model_type=args.model_type,
         )
         tmpdir.cleanup()
