internvl_chat/internvl/train/internvl_chat_finetune.py (39 additions, 8 deletions)
@@ -264,7 +264,10 @@ class DataTrainingArguments:
         default=False,
         metadata={'help': 'Whether to gather all during loss reduction. Default is False.'},
     )
-
+    eval_meta_path: str = field(
+        default=None,
+        metadata={'help': 'The path of the eval meta file of datasets.'},
+    )
 
 class LazySupervisedDataset(Dataset):
     """Dataset for supervised fine-tuning."""
@@ -703,6 +706,7 @@ def build_datasets(
     tokenizer,
     tcs_loader,
     model,
+    meta_file_path,
     group_by_length=False,
     dynamic_image_size=False,
     use_thumbnail=False,
@@ -716,7 +720,7 @@
     lengths = []
     data_rank = dist.get_rank()
     data_world_size = dist.get_world_size()
-    ds_collections = json.loads(open(data_args.meta_path).read())
+    ds_collections = json.loads(open(meta_file_path).read())
     for ds_idx, ds_name in enumerate(ds_collections.keys()):
         repeat_time = ds_collections[ds_name]['repeat_time']
         if 'max_dynamic_patch' in ds_collections[ds_name]:
@@ -979,11 +983,38 @@ def main():
         model.language_model._set_gradient_checkpointing()
 
     train_dataset = build_datasets(
-        data_args, tokenizer, tcs_loader, model, group_by_length=training_args.group_by_length,
-        dynamic_image_size=data_args.dynamic_image_size, use_thumbnail=data_args.use_thumbnail,
-        min_dynamic_patch=data_args.min_dynamic_patch, max_dynamic_patch=data_args.max_dynamic_patch,
-        normalize_type=data_args.normalize_type, min_num_frame=data_args.min_num_frame,
-        max_num_frame=data_args.max_num_frame)
+        data_args,
+        tokenizer,
+        tcs_loader,
+        model,
+        meta_file_path=data_args.meta_path,
+        group_by_length=training_args.group_by_length,
+        dynamic_image_size=data_args.dynamic_image_size,
+        use_thumbnail=data_args.use_thumbnail,
+        min_dynamic_patch=data_args.min_dynamic_patch,
+        max_dynamic_patch=data_args.max_dynamic_patch,
+        normalize_type=data_args.normalize_type,
+        min_num_frame=data_args.min_num_frame,
+        max_num_frame=data_args.max_num_frame
+    )
+
+    eval_dataset = None
+    if training_args.do_eval and data_args.eval_meta_path is not None:
+        eval_dataset = build_datasets(
+            data_args,
+            tokenizer,
+            tcs_loader,
+            model,
+            meta_file_path=data_args.eval_meta_path,
+            group_by_length=False,
+            dynamic_image_size=data_args.dynamic_image_size,
+            use_thumbnail=data_args.use_thumbnail,
+            min_dynamic_patch=data_args.min_dynamic_patch,
+            max_dynamic_patch=data_args.max_dynamic_patch,
+            normalize_type=data_args.normalize_type,
+            min_num_frame=data_args.min_num_frame,
+            max_num_frame=data_args.max_num_frame
+        )
 
     def _freeze_params(module):
         for param in module.parameters():
@@ -1042,7 +1073,7 @@ def _freeze_params(module):
         model=model,
         args=training_args,
         train_dataset=train_dataset if training_args.do_train else None,
-        eval_dataset=None,
+        eval_dataset=eval_dataset,
         tokenizer=tokenizer,
         data_collator=collator,
     )
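Both files funnel through the same reader: `build_datasets` parses the meta file with `json.loads(open(meta_file_path).read())`, and the loop over `ds_collections` requires a `repeat_time` key per dataset entry and honors an optional `max_dynamic_patch` override. Below is a minimal sketch of what an eval meta file could look like, written from Python. Only `repeat_time` and `max_dynamic_patch` appear in the hunks above; the `root`, `annotation`, and `length` keys follow the usual InternVL meta schema, and every name and path is a placeholder, not taken from this diff:

```python
import json

# Hypothetical eval meta file. Only 'repeat_time' (and the optional
# 'max_dynamic_patch') are referenced in the hunks above; 'root',
# 'annotation', and 'length' are assumed from the usual InternVL
# meta schema, and all paths are placeholders.
eval_meta = {
    'my_eval_set': {
        'root': 'data/my_eval_set/images/',
        'annotation': 'data/my_eval_set/eval.jsonl',
        'repeat_time': 1,
        'length': 1000,           # number of samples in the annotation file
        'max_dynamic_patch': 6,   # optional per-dataset override
    },
}

with open('shell/data/eval_meta.json', 'w') as f:
    json.dump(eval_meta, f, indent=2)
```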
internvl_chat_gpt_oss/internvl/train/internvl_chat_finetune.py (25 additions, 3 deletions)
@@ -252,7 +252,10 @@ class DataTrainingArguments:
         default=False,
         metadata={'help': 'Whether to split annotations to save memory usage. Default is False.'},
     )
-
+    eval_meta_path: str = field(
+        default=None,
+        metadata={'help': 'The path of the eval meta file of datasets.'},
+    )
 
 class LazySupervisedDataset(Dataset):
     """Dataset for supervised fine-tuning."""
@@ -771,6 +774,7 @@ def build_datasets(
     tokenizer,
     tcs_loader,
     model,
+    meta_file_path,
     group_by_length=False,
     dynamic_image_size=False,
     use_thumbnail=False,
@@ -785,7 +789,7 @@
     lengths = []
     data_rank = 0 if split_annotations else dist.get_rank()
     data_world_size = 1 if split_annotations else dist.get_world_size()
-    ds_collections = json.loads(open(data_args.meta_path).read())
+    ds_collections = json.loads(open(meta_file_path).read())
     for ds_idx, ds_name in enumerate(ds_collections.keys()):
         repeat_time = ds_collections[ds_name]['repeat_time']
         if 'max_dynamic_patch' in ds_collections[ds_name]:
@@ -1058,6 +1062,7 @@ def main():
         tokenizer,
         tcs_loader,
         model,
+        meta_file_path=data_args.meta_path,
         group_by_length=training_args.group_by_length,
         dynamic_image_size=data_args.dynamic_image_size,
         use_thumbnail=data_args.use_thumbnail,
@@ -1069,6 +1074,23 @@
         split_annotations=data_args.split_annotations,
     )
 
+    eval_dataset = None
+    if training_args.do_eval and data_args.eval_meta_path is not None:
+        eval_dataset = build_datasets(
+            data_args,
+            tokenizer,
+            tcs_loader, model,
+            meta_file_path=data_args.eval_meta_path,
+            group_by_length=False,
+            dynamic_image_size=data_args.dynamic_image_size,
+            use_thumbnail=data_args.use_thumbnail,
+            min_dynamic_patch=data_args.min_dynamic_patch,
+            max_dynamic_patch=data_args.max_dynamic_patch,
+            normalize_type=data_args.normalize_type,
+            min_num_frame=data_args.min_num_frame,
+            max_num_frame=data_args.max_num_frame
+        )
+
     def _freeze_params(module):
         for param in module.parameters():
             param.requires_grad = False
@@ -1126,7 +1148,7 @@ def _freeze_params(module):
         model=model,
         args=training_args,
         train_dataset=train_dataset if training_args.do_train else None,
-        eval_dataset=None,
+        eval_dataset=eval_dataset,
         tokenizer=tokenizer,
         data_collator=collator,
     )
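With both trainers now receiving a real `eval_dataset`, evaluation is opt-in: the eval set is built only when `--do_eval` is passed and `--eval_meta_path` is set. A minimal sketch of that gate through `HfArgumentParser`, using a stand-in dataclass since the repo's actual `DataTrainingArguments` has many more fields; `output_dir` and the meta-file path are placeholders:

```python
from dataclasses import dataclass, field

from transformers import HfArgumentParser, TrainingArguments

@dataclass
class EvalArgs:
    # Stand-in for the eval_meta_path field added to DataTrainingArguments above.
    eval_meta_path: str = field(default=None)

parser = HfArgumentParser((EvalArgs, TrainingArguments))
eval_args, training_args = parser.parse_args_into_dataclasses(args=[
    '--output_dir', 'work_dirs/demo',                 # placeholder
    '--do_eval', 'True',
    '--eval_meta_path', 'shell/data/eval_meta.json',  # placeholder
])

# Mirrors the gate in main(): the eval dataset is built only when both hold.
assert training_args.do_eval and eval_args.eval_meta_path is not None
```

Once the `Trainer` holds a non-None `eval_dataset`, when evaluation runs is governed by the standard `TrainingArguments` knobs (`evaluation_strategy`, renamed `eval_strategy` in newer transformers, plus `eval_steps`), which this patch leaves untouched; and since no `compute_metrics` hook is wired in, evaluation reports `eval_loss` only.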