|
10 | 10 | from typing import Any, List, Optional
|
11 | 11 | import importlib
|
12 | 12 | import inspect
|
13 |
| -import json |
14 | 13 | import logging
|
15 | 14 | import os
|
16 | 15 | import random
|
@@ -62,17 +61,10 @@ def check_valid_train_args(train_args: TrainingArgs):
|
62 | 61 | f"Provided path to model does not exist. Please make sure that you've passed a valid model and that it has appropriate permissions: {train_args.model_path}"
|
63 | 62 | )
|
64 | 63 |
|
65 |
| - if train_args.use_dolomite: |
66 |
| - with open(Path(train_args.model_path) / "config.json") as conf_json: |
67 |
| - model_conf = json.load(conf_json) |
68 |
| - if model_conf["model_type"] == "granite": |
69 |
| - raise RuntimeError( |
70 |
| - "Converting Granite models to Dolomite format is currently unsupported." |
71 |
| - ) |
72 |
| - if train_args.disable_flash_attn: |
73 |
| - raise RuntimeError( |
74 |
| - "ERROR: Trying to use dolomite padding-free transformer without flash attention is not supported" |
75 |
| - ) |
| 64 | + if train_args.use_dolomite and train_args.disable_flash_attn: |
| 65 | + raise RuntimeError( |
| 66 | + "ERROR: Trying to use dolomite padding-free transformer without flash attention is not supported" |
| 67 | + ) |
76 | 68 |
|
77 | 69 | if train_args.is_padding_free:
|
78 | 70 | print(
|
@@ -229,7 +221,7 @@ def pad_collate_fn(batch):
|
229 | 221 |
|
230 | 222 | input_ids.extend(item["input_ids"].tolist())
|
231 | 223 | labels.extend(item["labels"].tolist())
|
232 |
| - position_ids.extend(range(total_len, total_len + item_len)) |
| 224 | + position_ids.extend(range(item_len)) |
233 | 225 |
|
234 | 226 | total_len += item_len
|
235 | 227 | num_loss_counted_tokens += (item["labels"] != -100).sum().item()
|
@@ -802,10 +794,21 @@ def _get_state_dict_patched(model, unwrap=False):
|
802 | 794 |
|
803 | 795 | output_dir.mkdir(parents=True, exist_ok=True)
|
804 | 796 | if not model.module.config.architectures and convert_dolomite:
|
805 |
| - model.module.config.architectures = ["LlamaForCausalLM"] |
806 |
| - warnings.warn( |
807 |
| - f"Adding architectures to ckpt: {model.module.config.architectures}", |
808 |
| - ) |
| 797 | + arch_added = False |
| 798 | + if args.model_type == "llama": |
| 799 | + model.module.config.architectures = ["LlamaForCausalLM"] |
| 800 | + arch_added = True |
| 801 | + elif args.model_type == "granite": |
| 802 | + model.module.config.architectures = ["GraniteForCausalLM"] |
| 803 | + arch_added = True |
| 804 | + if arch_added: |
| 805 | + warnings.warn( |
| 806 | + f"Adding architectures to ckpt: {model.module.config.architectures}", |
| 807 | + ) |
| 808 | + else: |
| 809 | + warnings.warn( |
| 810 | + f"Converting from dolomite, but no architecture field added to config.json", |
| 811 | + ) |
809 | 812 | model.module.config.to_json_file(output_config_file)
|
810 | 813 | tokenizer.save_pretrained(output_dir)
|
811 | 814 |
|
@@ -834,7 +837,7 @@ def _get_state_dict_patched(model, unwrap=False):
|
834 | 837 | export_to_huggingface(
|
835 | 838 | pretrained_model_name_or_path=tmpdir.name,
|
836 | 839 | save_path=final_output_dir,
|
837 |
| - model_type="llama", |
| 840 | + model_type=args.model_type, |
838 | 841 | )
|
839 | 842 | tmpdir.cleanup()
|
840 | 843 |
|
|
0 commit comments