diff --git a/xtuner/configs/internvl/v2/internvl_v2_internlm2_26b_finetune_balance.py b/xtuner/configs/internvl/v2/internvl_v2_internlm2_26b_finetune_balance.py
new file mode 100644
index 000000000..fe7e71c9a
--- /dev/null
+++ b/xtuner/configs/internvl/v2/internvl_v2_internlm2_26b_finetune_balance.py
@@ -0,0 +1,179 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
+                            LoggerHook, ParamSchedulerHook)
+from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
+from torch.optim import AdamW
+from transformers import AutoTokenizer
+
+from xtuner.dataset import InternVL_V1_5_Dataset, BalancedDataset
+from xtuner.dataset.collate_fns import default_collate_fn
+from xtuner.dataset.samplers import LengthGroupedSampler
+from xtuner.engine.hooks import DatasetInfoHook, VarlenAttnArgsToMessageHubHook
+from xtuner.engine.runner import TrainLoop
+from xtuner.model import InternVL_V1_5
+from xtuner.utils import PROMPT_TEMPLATE
+
+#######################################################################
+#                          PART 1  Settings                           #
+#######################################################################
+# Model
+path = '/model/internvl'
+
+# Data
+data_path = 'pack_internvl_sft_1.2M.json'
+'''About pack_internvl_sft_1.2M.json:
+use the script data_preprocess_stastics.sh in xtuner/tools
+to generate it (an example invocation is given right after this config).
+'''
+prompt_template = PROMPT_TEMPLATE.internlm2_chat
+max_length = 4096
+
+# Scheduler & Optimizer
+batch_size = 1  # per_device
+accumulative_counts = 4
+dataloader_num_workers = 8
+max_epochs = 1
+optim_type = AdamW
+# official 1024 -> 4e-5
+lr = 1e-6
+betas = (0.9, 0.999)
+weight_decay = 0.05
+max_norm = 1  # grad clip
+warmup_ratio = 0.03
+
+# Save
+save_steps = 1000
+save_total_limit = 1  # Maximum checkpoints to keep (-1 means unlimited)
+
+#######################################################################
+#            PART 2  Model & Tokenizer & Image Processor              #
+#######################################################################
+model = dict(
+    type=InternVL_V1_5,
+    model_path=path,
+    freeze_llm=False,
+    freeze_visual_encoder=False,  # or True
+    use_varlen_attn=True
+)
+
+evaluation_freq = 0
+#######################################################################
+#                      PART 3  Dataset & Dataloader                   #
+#######################################################################
+llava_dataset = dict(
+    type=BalancedDataset,
+    model_path=path,
+    data_path=data_path,
+    vit_packed_length=9,  # max number of ViT patches packed into one sample
+    llm_packed_length=4096,  # max number of text tokens packed into one sample
+    llm_thresh=4068,  # token count at which a pack is considered full
+    template=prompt_template,
+    max_length=max_length)
+
+train_dataloader = dict(
+    batch_size=batch_size,
+    num_workers=dataloader_num_workers,
+    dataset=llava_dataset,
+    sampler=dict(
+        type=LengthGroupedSampler,
+        length_property='modality_length',
+        per_device_batch_size=batch_size * accumulative_counts),
+    collate_fn=dict(type=default_collate_fn, use_varlen_attn=True, balance_data=True))
+
+#######################################################################
+#                    PART 4  Scheduler & Optimizer                    #
+#######################################################################
+# optimizer
+optim_wrapper = dict(
+    type=AmpOptimWrapper,
+    optimizer=dict(
+        type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
+    clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
+    accumulative_counts=accumulative_counts,
+    loss_scale='dynamic',
+    dtype='float16')
+
+# learning policy
+# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md  # noqa: E501
+param_scheduler = [
+    dict(
+        type=LinearLR,
+        start_factor=1e-5,
+        by_epoch=True,
+        begin=0,
+        end=warmup_ratio * max_epochs,
+        convert_to_iter_based=True),
+    dict(
+        type=CosineAnnealingLR,
+        eta_min=0.0,
+        by_epoch=True,
+        begin=warmup_ratio * max_epochs,
+        end=max_epochs,
+        convert_to_iter_based=True)
+]
+
+# train, val, test setting
+train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
+
+#######################################################################
+#                           PART 5  Runtime                           #
+#######################################################################
+# Log the dialogue periodically during the training process, optional
+tokenizer = dict(
+    type=AutoTokenizer.from_pretrained,
+    pretrained_model_name_or_path=path,
+    trust_remote_code=True)
+
+custom_hooks = [
+    dict(type=DatasetInfoHook, tokenizer=tokenizer),
+]
+
+custom_hooks += [dict(type=VarlenAttnArgsToMessageHubHook)]  # hook required by varlen attention
+
+# configure default hooks
+default_hooks = dict(
+    # record the time of every iteration.
+    timer=dict(type=IterTimerHook),
+    # print log every iteration.
+    logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=1),
+    # enable the parameter scheduler.
+    param_scheduler=dict(type=ParamSchedulerHook),
+    # save checkpoint per `save_steps`.
+    checkpoint=dict(
+        type=CheckpointHook,
+        save_optimizer=False,
+        by_epoch=False,
+        interval=save_steps,
+        max_keep_ckpts=save_total_limit),
+    # set sampler seed in distributed environment.
+    sampler_seed=dict(type=DistSamplerSeedHook),
+)
+
+# configure environment
+env_cfg = dict(
+    # whether to enable cudnn benchmark
+    cudnn_benchmark=False,
+    # set multi process parameters
+    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
+    # set distributed parameters
+    dist_cfg=dict(backend='nccl'),
+)
+
+# set visualizer
+visualizer = None
+
+# set log level
+log_level = 'INFO'
+
+# load from which checkpoint
+load_from = None
+
+# whether to resume training from the loaded checkpoint
+resume = False
+
+# Defaults to use random seed and disable `deterministic`
+randomness = dict(seed=None, deterministic=False)
+
+# set log processor
+log_processor = dict(by_epoch=False)
+
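(Note: the `data_path` above is the packed meta file produced by the statistics tool added at the end of this patch. Assuming the script is run from `xtuner/tools` and the paths below are purely illustrative, a typical invocation would look like:

    bash data_preprocess_stastics.sh internvl_sft_1.2M.json /abs/path/token_lengths pack_internvl_sft_1.2M.json

where $1 is the InternVL-format meta file, $2 is an absolute folder for the per-dataset token-length jsons, and $3 is the packed meta file consumed by `BalancedDataset`.)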
diff --git a/xtuner/dataset/__init__.py b/xtuner/dataset/__init__.py
index 2ad3d7bd9..6219ba6f1 100644
--- a/xtuner/dataset/__init__.py
+++ b/xtuner/dataset/__init__.py
@@ -14,6 +14,7 @@
 from .refcoco_json import (InvRefCOCOJsonDataset, RefCOCOJsonDataset,
                            RefCOCOJsonEvalDataset)
 from .utils import decode_base64_to_image, expand2square, load_image
+from .fast_dataset import BalancedDataset
 
 # ignore FutureWarning in hf datasets
 warnings.simplefilter(action='ignore', category=FutureWarning)
@@ -25,5 +26,5 @@
     'load_intern_repo_tokenized_dataset',
     'load_intern_repo_untokenized_dataset', 'build_packed_dataset',
     'RefCOCOJsonDataset', 'RefCOCOJsonEvalDataset', 'InvRefCOCOJsonDataset',
-    'load_json_file', 'InternVL_V1_5_Dataset'
+    'load_json_file', 'InternVL_V1_5_Dataset', 'BalancedDataset'
 ]
diff --git a/xtuner/dataset/collate_fns/default_collate_fn.py b/xtuner/dataset/collate_fns/default_collate_fn.py
index 3d9fe18fb..51437c002 100644
--- a/xtuner/dataset/collate_fns/default_collate_fn.py
+++ b/xtuner/dataset/collate_fns/default_collate_fn.py
@@ -12,7 +12,8 @@
 def default_collate_fn(instances: Sequence[Dict],
                        pad_index: int = DEFAULT_PAD_TOKEN_INDEX,
                        return_hf_format: bool = False,
-                       use_varlen_attn: bool = False):
+                       use_varlen_attn: bool = False,
+                       balance_data: bool = False):
     seq_parallel_world_size = get_sequence_parallel_world_size()
 
     input_ids, labels = [], []
@@ -22,7 +23,7 @@
         assert len(instances) == 1, (
             f'If utilizing varlen attention, the batch size should be'
             f' set to 1, but got {len(instances)}')
-        assert not has_image, 'Currently, it is not configured to '
+        assert not has_image or balance_data, 'Currently, it is not configured to '
         'accommodate the use of varlen Attention in multimodal training'
 
     if has_image:
@@ -39,6 +40,7 @@
         pixel_values.append(example['pixel_values'])
 
     ori_length = [len(ids) for ids in input_ids]
+
     if len(instances) > 1:
         input_ids = pad_sequence(
             input_ids, batch_first=True, padding_value=pad_index)
diff --git a/xtuner/dataset/fast_dataset.py b/xtuner/dataset/fast_dataset.py
new file mode 100644
index 000000000..652ca6f4e
--- /dev/null
+++ b/xtuner/dataset/fast_dataset.py
@@ -0,0 +1,341 @@
+import json
+import copy
+
+import numpy as np
+from multiprocessing.pool import ThreadPool as Pool
+
+import os
+import torch
+from torch.utils.data import Dataset, ConcatDataset
+from transformers import AutoTokenizer
+
+from .internvl_dataset import InternVL_V1_5_Dataset
+
+
+def get_token_sum(g):
+    # total number of LLM tokens in a packed group
+    total = 0
+    for i in g:
+        total += i[2]
+    return total
+
+
+def get_vit_num(g):
+    # total number of ViT patches in a packed group
+    vit_num = 0
+    for _ in g:
+        vit_num += _[1]
+    return vit_num
+
+
+DEFAULT_SEED = 1024
+
+
+class BalancedDataset(Dataset):
+    def __init__(self,
+                 data_path,
+                 model_path,
+                 template,
+                 max_length,
+                 vit_packed_length=15,
+                 llm_packed_length=4096,
+                 llm_thresh=None,
+                 worker=64,
+                 iter_time=100):
+        cfg_dataset_base = {}
+        cfg_dataset_base['template'] = template
+        cfg_dataset_base['model_path'] = model_path
+        cfg_dataset_base['max_length'] = max_length
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model_path, trust_remote_code=True)
+        self.vit_packed_length = vit_packed_length
+        self.llm_packed_length = llm_packed_length
+        # if llm_thresh is None, `get_packed_groups` will search for the
+        # best threshold via `find_best_groups`
+        self.llm_thresh = {'thresh': llm_thresh}
+
+        self.vit_lengths, self.llm_lengths = [], []
+        self.worker = worker
+        self.pad_token_id = len(self.tokenizer) - 1
+        self.iter_time = iter_time
+
+        # prepare concat dataset
+        ds_collections = json.loads(open(data_path).read())
+        datasets = []
+        for ds_name in ds_collections.keys():
+            cfg_dataset = copy.deepcopy(cfg_dataset_base)
+            cfg_dataset['repeat_times'] = ds_collections[ds_name]['repeat_time']
+            cfg_dataset['data_paths'] = ds_collections[ds_name]['annotation']
+            cfg_dataset['image_folders'] = ds_collections[ds_name]['root']
+
+            dataset = InternVL_V1_5_Dataset(**cfg_dataset)
+            dataset.meta = ds_collections[ds_name]
+            datasets.append(dataset)
+
+        self.dataset = ConcatDataset(datasets)
+        print("Begin preprocess dataset", flush=True)
+        self.preprocess()
+        print("Preprocess dataset succeeded", flush=True)
+        self.seed = DEFAULT_SEED
+        self.pack_groups = self.get_packed_groups()
+
+        group_length = []
+        for g in self.pack_groups:
+            num_img = 0
+            length_input = 0
+            for item in g:
+                num_img += item[3]
+                length_input += item[2]
+
+            if num_img != 0:
+                group_length.append(length_input)
+            else:
+                group_length.append(-length_input)
+
+        self.group_length = group_length
+
+    @property
+    def modality_length(self):
+        return self.group_length
+
+    @property
+    def length(self):
+        group_length = np.array(self.group_length)
+        group_length = np.abs(group_length).tolist()
+        return group_length
+
+    def preprocess(self):
+        dict_num_tokens = {}
+        num_datasets = len(self.dataset.datasets)
+        for dataset_idx in range(num_datasets):
+            sub_dataset = self.dataset.datasets[dataset_idx]
+            if "token_lengths" in sub_dataset.meta:
+                print(f"Load from cache for dataset {dataset_idx}", flush=True)
+                assert os.path.exists(
+                    sub_dataset.meta["token_lengths"]), f"Dataset {dataset_idx} token_lengths file does not exist."
+                with open(sub_dataset.meta["token_lengths"], "r") as f:
+                    token_lengths = json.load(f)
+                dict_num_tokens[dataset_idx] = {
+                    "lengths": len(sub_dataset),
+                    "token_lengths": token_lengths
+                }
+            else:
+                print(
+                    f"Generate length json for dataset {dataset_idx}", flush=True)
+                token_lengths = []
+                origin_indexs = list(range(len(sub_dataset)))
+                token_lengths_dict = dict()
+
+                def decode_text(idx):
+                    meta = sub_dataset.__getitem__(idx)
+                    if meta['pixel_values'].sum().item() == 0:
+                        image_flags = 0
+                    else:
+                        image_flags = meta['pixel_values'].shape[0]
+                    token_lengths_dict[idx] = {
+                        "vit_num": meta['pixel_values'].shape[0],
+                        "token_num": len(meta['input_ids']),
+                        "image_flags": image_flags
+                    }
+
+                with Pool(self.worker) as p:
+                    _ = p.map(decode_text, origin_indexs[:])
+                for idx in range(len(sub_dataset)):
+                    token_lengths.append(token_lengths_dict[idx])
+                dict_num_tokens[dataset_idx] = {
+                    "lengths": len(sub_dataset),
+                    "token_lengths": token_lengths
+                }
+                print(
+                    f"Finish length json for dataset {dataset_idx}", flush=True)
+        self.dict_num_tokens = dict_num_tokens
+
+    def _random_groups(self, token_lengths, seed=None):
+        """
+        token_lengths: [(idx, vit_num, llm_token_num, image_flags)]
+        """
+        rng = np.random.RandomState(seed)
+        index = list(range(len(token_lengths)))
+        rng.shuffle(index)
+
+        pack_groups = []
+        vit_token_length_sum, llm_token_length_sum = 0, 0
+        each_group = []
+        for idx, sample_id in enumerate(index):
+            vit_sample_length, llm_sample_length = token_lengths[
+                sample_id][1], token_lengths[sample_id][2]
+            # drop samples that could never fit into a single pack
+            if vit_sample_length > self.vit_packed_length or llm_sample_length > self.llm_packed_length:
+                continue
+            vit_token_length_sum += vit_sample_length
+            llm_token_length_sum += llm_sample_length
+            if vit_token_length_sum > self.vit_packed_length or llm_token_length_sum > self.llm_packed_length:
+                pack_groups.append(each_group)
+                vit_token_length_sum = vit_sample_length
+                llm_token_length_sum = llm_sample_length
+                each_group = [token_lengths[sample_id]]
+            else:
+                each_group.append(token_lengths[sample_id])
+        # flush the last (possibly partial) group
+        if len(each_group) > 0:
+            pack_groups.append(each_group)
+        return pack_groups
+
+    def process_random_groups_input(self, groups, accu_length=0):
+        new_groups = []
+        for idx, item in enumerate(groups):
+            if item["vit_num"] == -1:
+                print(f"item {idx} was filtered.", flush=True)
+                continue
+            new_groups.append(
+                (idx + accu_length, item['vit_num'], item['token_num'], item["image_flags"]))
+        return new_groups
+
+    def iter_random_groups(self, groups, llm_thresh=None, seed=None, iter_time=300):
+        if llm_thresh is None:
+            llm_thresh = self.llm_packed_length
+        if seed is None:
+            seed = self.seed
+        groups = self._random_groups(groups, seed=seed)
+        if iter_time == 1:
+            return groups
+        output = []
+        for i in range(iter_time - 1):
+            print(f"iter_random_groups {i} / {iter_time - 1}", flush=True)
+            need_process_groups = []
+            for g in groups:
+                vit_num = get_vit_num(g)
+                llm_num = get_token_sum(g)
+                # keep groups that are already full; re-pack the rest
+                if vit_num == self.vit_packed_length or llm_num >= llm_thresh:
+                    output.append(g)
+                else:
+                    need_process_groups.extend(g)
+            if len(need_process_groups) > 0:
+                groups = self._random_groups(need_process_groups, seed + i)
+            else:
+                break
+        if len(need_process_groups) > 0:
+            output.extend(self._random_groups(need_process_groups, seed + i))
+        return output
+
+    def collect_packed_info(self, packed_groups):
+        info_dict = {}
+        info_dict['vit_num_info'] = {}
+        vit_num_min = 10000000
+        vit_num_max = 0
+        llm_num_min = 10000000
+        llm_num_max = 0
+        vit_ave_num = 0
+        llm_ave_num = 0
+        sample_num = 0
+        for group in packed_groups:
+            vit_num = get_vit_num(group)
+            llm_num = get_token_sum(group)
+            if vit_num not in info_dict['vit_num_info']:
+                info_dict['vit_num_info'][vit_num] = 0
+            info_dict['vit_num_info'][vit_num] += 1
+            vit_num_min = min(vit_num_min, vit_num)
+            vit_num_max = max(vit_num_max, vit_num)
+            llm_num_min = min(llm_num_min, llm_num)
+            llm_num_max = max(llm_num_max, llm_num)
+            vit_ave_num += vit_num
+            llm_ave_num += llm_num
+            sample_num += len(group)
+        info_dict['vit_num_min'] = vit_num_min
+        info_dict['vit_num_max'] = vit_num_max
+        info_dict['llm_num_min'] = llm_num_min
+        info_dict['llm_num_max'] = llm_num_max
+        info_dict['vit_ave_num'] = vit_ave_num / float(len(packed_groups))
+        info_dict['llm_ave_num'] = llm_ave_num / float(len(packed_groups))
+        info_dict['sample_num'] = sample_num
+        info_dict['packed_group_num'] = len(packed_groups)
+        return info_dict
+
+    def find_best_groups(self, input_groups, step=4, step_num=20):
+        best_group_num = 10000000000000
+        best_groups = []
+        best_info_dict = {}
+        best_llm_thresh = 0
+        llm_thresh = self.llm_packed_length
+        for step_id in range(step_num):
+            print(f"find_best_groups {step_id} / {step_num}", flush=True)
+            groups = self.iter_random_groups(
+                input_groups, llm_thresh, seed=self.seed, iter_time=self.iter_time)
+            cur_info_dict = self.collect_packed_info(groups)
+            if cur_info_dict['packed_group_num'] < best_group_num:
+                best_group_num = cur_info_dict['packed_group_num']
+                best_groups = groups
+                best_info_dict = cur_info_dict
+                best_llm_thresh = llm_thresh
+            llm_thresh -= step
+        print(f"llm thresh {best_llm_thresh} best info dict",
+              best_info_dict, flush=True)
+        return best_groups
+
+    def get_packed_groups(self):
+        num_datasets = len(list(self.dict_num_tokens.keys()))
+        accu_length = 0
+        input_groups = []
+        for d_idx in range(num_datasets):
+            dict_item = self.dict_num_tokens[d_idx]
+            token_lengths = dict_item["token_lengths"]
+            groups = self.process_random_groups_input(
+                token_lengths, accu_length)
+            print(f"get_packed_groups {d_idx}.", flush=True)
+            input_groups.extend(groups)
+            accu_length += len(token_lengths)
+        if self.llm_thresh.get('thresh', None) is not None:
+            groups = self.iter_random_groups(
+                input_groups, llm_thresh=self.llm_thresh['thresh'], seed=self.seed, iter_time=self.iter_time)
+        else:
+            groups = self.find_best_groups(input_groups, self.llm_thresh.get(
+                'step', 4), self.llm_thresh.get('step_num', 10))
+        print(self.collect_packed_info(groups), flush=True)
+        print("get_packed_groups done!", flush=True)
+        return groups
+
+    def __getitem__(self, item: int):
+        item = item % len(self.pack_groups)
+        while True:
+            try:
+                groups = self.pack_groups[item]
+
+                input_ids, pixel_values = [], []
+                labels, position_ids, image_flags = [], [], []
+                cu_seqlens = [0]
+                for g in groups:
+                    idx, _, llm_length, image_flag = g
+                    meta = self.dataset.__getitem__(idx)
+                    assert len(meta["input_ids"]) == llm_length
+                    image_flag_ = (
+                        torch.sum(meta['pixel_values'], dim=(1, 2, 3)) != 0).sum()
+                    assert image_flag == image_flag_.item()
+                    input_ids.extend(meta['input_ids'])
+                    pixel_values.append(meta['pixel_values'])
+                    labels.extend(meta['labels'])
+                    cu_seqlens.append(len(meta['input_ids']))
+                    position_ids.extend(list(range(len(meta['input_ids']))))
+                    image_flags.append(image_flag)
+
+                cu_seqlens = np.cumsum(np.array(cu_seqlens)).tolist()
+                input_ids = input_ids[:self.llm_packed_length]
+                pixel_values = torch.cat(pixel_values)[:self.vit_packed_length]
+                labels = labels[:self.llm_packed_length]
+                position_ids = position_ids[:self.llm_packed_length]
+                cu_seqlens[-1] = len(position_ids)
+
+                ret = {
+                    "input_ids": input_ids,
+                    "labels": labels,
+                    "cumulative_len": cu_seqlens,
+                    "position_ids": position_ids,
+                    "pixel_values": pixel_values,
+                }
+                break
+            except Exception as e:
+                print(f"{e}", flush=True)
+                # retry with a different pack if this one is broken
+                item = (item + 100) % len(self.pack_groups)
+        return ret
+
+    def __len__(self):
+        n_packs = len(self.pack_groups)
+        return n_packs
diff --git a/xtuner/dataset/internvl_dataset.py b/xtuner/dataset/internvl_dataset.py
index 82904ae87..74aed8277 100644
--- a/xtuner/dataset/internvl_dataset.py
+++ b/xtuner/dataset/internvl_dataset.py
@@ -309,6 +309,11 @@ def prepare_data(self, index):
                 print_log(f'Error: {e}', logger='current')
                 return None
 
+        # Ensure the first conversation contains an image placeholder
+        if '<image>' not in data_dict['conversations'][0]['value']:
+            data_dict['conversations'][0]['value'] = \
+                '<image>\n' + data_dict['conversations'][0]['value']
+
         images = dynamic_preprocess(image, self.min_dynamic_patch,
                                     self.max_dynamic_patch,
                                     self.image_size, self.use_thumbnail)
diff --git a/xtuner/model/internvl.py b/xtuner/model/internvl.py
index 0358266a9..f14fb040a 100644
--- a/xtuner/model/internvl.py
+++ b/xtuner/model/internvl.py
@@ -15,6 +15,7 @@
 from xtuner.registry import BUILDER
 from .utils import (find_all_linear_names, get_peft_model_state_dict,
                     guess_load_checkpoint, make_inputs_require_grad)
+from .modules import dispatch_modules
 
 
 class InternVL_V1_5(BaseModel):
@@ -27,7 +28,8 @@ def __init__(self,
                  visual_encoder_lora=None,
                  quantization_vit=False,
                  quantization_llm=False,
-                 pretrained_pth=None):
+                 pretrained_pth=None,
+                 use_varlen_attn=False):
         print_log('Start to load InternVL_V1_5 model.', logger='current')
         super().__init__()
         self.freeze_llm = freeze_llm
@@ -110,6 +112,8 @@ def __init__(self,
         self._count = 0
         print_log(self, logger='current')
         print_log('InternVL_V1_5 construction is complete', logger='current')
+        dispatch_modules(self.model.language_model,
+                         use_varlen_attn=use_varlen_attn)
 
     def _parse_lora_config(self, lora_config):
         if isinstance(lora_config, dict) or isinstance(
@@ -194,13 +198,13 @@ def forward(self, data, data_samples=None, mode='loss'):
                 image.to(self.model.vision_model.dtype)
                 for image in pixel_values
             ],
-                dim=0)
+            dim=0)
         else:
             raise NotImplementedError()
 
         input_ids = data['input_ids']
         position_ids = data['position_ids']
-        attention_mask = data['attention_mask']
+        attention_mask = data.get('attention_mask', None)
 
         # sum is 0 are text
         image_flags = torch.sum(concat_images, dim=(1, 2, 3)) != 0
         image_flags = image_flags.long()
diff --git a/xtuner/tools/data_preprocess_stastics.py b/xtuner/tools/data_preprocess_stastics.py
new file mode 100755
index 000000000..441262608
--- /dev/null
+++ b/xtuner/tools/data_preprocess_stastics.py
@@ -0,0 +1,547 @@
+from xtuner.utils import IGNORE_INDEX
+from transformers import AutoConfig, AutoTokenizer
+from torchvision.transforms.functional import InterpolationMode
+from PIL import Image
+from mmengine.fileio import get
+from mmengine import print_log
+import torchvision.transforms as T
+import warnings
+import random
+import io
+import copy
+import json
+from multiprocessing import Manager
+import multiprocessing
+import argparse
+from tqdm import tqdm
+from functools import partial
+import os
+import numpy as np
+from copy import deepcopy
+import torch
+
+from torch.utils.data import Dataset
+PROCESSES = 64
+
+
+# Referenced from InternVL
+def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
+                              image_size):
+    best_ratio_diff = float('inf')
+    best_ratio = (1, 1)
+    area = width * height
+    for ratio in target_ratios:
+        target_aspect_ratio = ratio[0] / ratio[1]
+        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
+        if ratio_diff < best_ratio_diff:
+            best_ratio_diff = ratio_diff
+            best_ratio = ratio
+        elif ratio_diff == best_ratio_diff:
+            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
+                best_ratio = ratio
+    return best_ratio
+
+
+def dynamic_preprocess(image,
+                       min_num=1,
+                       max_num=6,
+                       image_size=448,
+                       use_thumbnail=False):
+    orig_width, orig_height = image.size
+    aspect_ratio = orig_width / orig_height
+
+    # calculate the existing image aspect ratio
+    target_ratios = {(i, j)
+                     for n in range(min_num, max_num + 1)
+                     for i in range(1, n + 1) for j in range(1, n + 1)
+                     if i * j <= max_num and i * j >= min_num}
+    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+    # find the closest aspect ratio to the target
+    target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio,
+                                                    target_ratios, orig_width,
+                                                    orig_height, image_size)
+
+    # calculate the target width and height
+    target_width = image_size * target_aspect_ratio[0]
+    target_height = image_size * target_aspect_ratio[1]
+    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+
+    # resize the image
+    resized_img = image.resize((target_width, target_height))
+    processed_images = []
+    for i in range(blocks):
+        box = ((i % (target_width // image_size)) * image_size,
+               (i // (target_width // image_size)) * image_size,
+               ((i % (target_width // image_size)) + 1) * image_size,
+               ((i // (target_width // image_size)) + 1) * image_size)
+        # split the image
+        split_img = resized_img.crop(box)
+        processed_images.append(split_img)
+    assert len(processed_images) == blocks
+    if use_thumbnail and len(processed_images) != 1:
+        thumbnail_img = image.resize((image_size, image_size))
+        processed_images.append(thumbnail_img)
+    return processed_images
+
+
+def total_image_token(orig_size,
+                      min_num=1,
+                      max_num=12,
+                      image_size=448,
+                      use_thumbnail=True):
+    orig_width, orig_height = orig_size
+
+    aspect_ratio = orig_width / orig_height
+
+    # calculate the existing image aspect ratio
+    target_ratios = {(i, j)
+                     for n in range(min_num, max_num + 1)
+                     for i in range(1, n + 1) for j in range(1, n + 1)
+                     if max_num >= i * j >= min_num}
+    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+    # find the closest aspect ratio to the target
+    target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio,
+                                                    target_ratios, orig_width,
+                                                    orig_height, image_size)
+    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+
+    if use_thumbnail and blocks != 1:
+        blocks += 1
+
+    return blocks
+
+
+def load_json_or_jsonl(json_path):
+    if json_path.endswith('.json'):
+        with open(json_path) as f:
+            data = json.load(f)
+    elif json_path.endswith('.jsonl'):
+        with open(json_path) as f:
+            data = [json.loads(line) for line in f]
+    else:
+        raise ValueError(f'Unsupported file format: {json_path}, '
+                         f'only support .json and .jsonl.')
+    return data
+
+
+class InternVL_V1_5_Dataset(Dataset):
+    os.environ['TOKENIZERS_PARALLELISM'] = 'true'
+    IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
+    IMG_START_TOKEN = '<img>'
+    IMG_END_TOKEN = '</img>'
+
+    IMAGENET_MEAN = (0.485, 0.456, 0.406)
+    IMAGENET_STD = (0.229, 0.224, 0.225)
+
+    def __init__(self,
+                 model_path,
+                 template,
+                 data_paths,
+                 image_folders=None,
+                 repeat_times=1,
+                 max_length=8192):
+        self.template = template
+        self.max_length = max_length
+
+        self.cfg = AutoConfig.from_pretrained(
+            model_path, trust_remote_code=True)
+
+        # The following modifications are only to ensure full
+        # consistency with the official template,
+        # without investigating the impact on performance.
+        if self.cfg.llm_config.architectures[0] == 'Phi3ForCausalLM':
+            self._system = 'You are an AI assistant whose name is Phi-3.'
+            self.template[
+                'INSTRUCTION'] = '<|user|>\n{input}<|end|><|assistant|>\n'
+        elif self.cfg.llm_config.architectures[0] == 'InternLM2ForCausalLM':
+            self._system = 'You are an AI assistant whose name ' \
+                           'is InternLM (书生·浦语).'
+            self.template['SYSTEM'] = '<|im_start|>system\n{system}<|im_end|>'
+            self.template[
+                'INSTRUCTION'] = '<|im_start|>user\n{input}' \
+                                 '<|im_end|><|im_start|>assistant\n'
+        else:
+            raise NotImplementedError
+
+        self.min_dynamic_patch = self.cfg.min_dynamic_patch
+        self.max_dynamic_patch = self.cfg.max_dynamic_patch
+        self.downsample_ratio = self.cfg.downsample_ratio
+        self.image_size = self.cfg.force_image_size
+        self.use_thumbnail = self.cfg.use_thumbnail
+        patch_size = self.cfg.vision_config.patch_size
+        self.patch_token = int(
+            (self.image_size // patch_size)**2 * (self.downsample_ratio**2))
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model_path, trust_remote_code=True)
+        self.transformer = T.Compose([
+            T.Lambda(lambda img: img.convert('RGB')
+                     if img.mode != 'RGB' else img),
+            T.Resize((self.image_size, self.image_size),
+                     interpolation=InterpolationMode.BICUBIC),
+            T.ToTensor(),
+            T.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)
+        ])
+
+        if not isinstance(data_paths, (list, tuple)):
+            data_paths = [data_paths]
+        if not isinstance(image_folders, (list, tuple)):
+            image_folders = [image_folders]
+        if not isinstance(repeat_times, (list, tuple)):
+            repeat_times = [repeat_times]
+        assert len(data_paths) == len(image_folders) == len(repeat_times)
+
+        print_log('Starting to load data and calculate lengths',
+                  logger='current')
+        self.data = []
+        self.image_folder = []
+        self.group_length = []
+        self.conv2length_text = {
+        }  # using dict to speedup the calculation of token length
+
+        for data_file, image_folder, repeat_time in zip(
+                data_paths, image_folders, repeat_times):
+            print_log(
+                f'=======Starting to process {data_file} =======',
+                logger='current')
+            assert repeat_time > 0
+            json_data = load_json_or_jsonl(data_file)
+            if repeat_time < 1:
+                json_data = random.sample(json_data,
+                                          int(len(json_data) * repeat_time))
+            elif repeat_time > 1:
+                int_repeat_time = int(repeat_time)
+                remaining_repeat_time = repeat_time - int_repeat_time
+                if remaining_repeat_time > 0:
+                    remaining_json_data = random.sample(
+                        json_data, int(len(json_data) * remaining_repeat_time))
+                    json_data = json_data * int_repeat_time
+                    json_data.extend(remaining_json_data)
+                else:
+                    json_data = json_data * int_repeat_time
+
+            self.data.extend(json_data)
+            self.image_folder.extend([image_folder] * len(json_data))
+
+            # TODO: multi process
+            for data_item in json_data:
+                if 'length' in data_item:
+                    token_length = data_item['length']  # include image token
+                else:
+                    conversations = '\n'.join(
+                        [temp['value'] for temp in data_item['conversations']])
+                    str_length = len(conversations)
+
+                    if str_length not in self.conv2length_text:
+                        token_length = self.tokenizer(
+                            conversations,
+                            return_tensors='pt',
+                            padding=False,
+                            truncation=False,
+                        ).input_ids.size(1)
+                        self.conv2length_text[str_length] = token_length
+                    else:
+                        token_length = self.conv2length_text[str_length]
+
+                if 'image' in data_item and data_item['image'] is not None:
+                    if 'image_wh' in data_item and data_item[
+                            'image_wh'] is not None:
+                        # more accurate calculation of image token
+                        image_wh = data_item['image_wh']
+                        if isinstance(image_wh[0], list):
+                            image_wh = image_wh[0]
+                        image_token = total_image_token(
+                            image_wh, self.min_dynamic_patch,
+                            self.max_dynamic_patch, self.image_size,
+                            self.use_thumbnail)
+                        image_token = self.patch_token * image_token
+                    else:
+                        # max_dynamic_patch + use_thumbnail
+                        image_token = self.patch_token * (
+                            self.max_dynamic_patch + self.use_thumbnail)
+
+                    token_length = token_length + image_token
+                else:
+                    token_length = -token_length
+
+                self.group_length.append(token_length)
+            print_log(
+                f'=======total {len(json_data)} samples of {data_file}=======',
+                logger='current')
+
+        assert len(self.group_length) == len(self.data)
+        print_log('Finished loading data and calculating lengths',
+                  logger='current')
+        print_log(
+            f'=======total {len(self.data)} samples=======', logger='current')
+        self._max_refetch = 1000
+
+    def __getitem__(self, index):
+        for _ in range(self._max_refetch + 1):
+            data = self.prepare_data(index)
+            # Broken images may cause the returned data to be None
+            if data is None:
+                index = self._rand_another()
+                continue
+            return data
+
+    def __len__(self):
+        return len(self.data)
+
+    @property
+    def modality_length(self):
+        return self.group_length
+
+    @property
+    def length(self):
+        group_length = np.array(self.group_length)
+        group_length = np.abs(group_length).tolist()
+        return group_length
+
+    def prepare_data(self, index):
+        data_dict: dict = self.data[index]
+        image_folder = self.image_folder[index]
+
+        out_data_dict = {}
+        if data_dict.get('image', None) is not None:
+            image_wh = data_dict["width"], data_dict["height"]
+            # Ensure the first conversation contains an image placeholder
+            if '<image>' not in data_dict['conversations'][0]['value']:
+                data_dict['conversations'][0]['value'] = \
+                    '<image>\n' + data_dict['conversations'][0]['value']
+            image_token = total_image_token(
+                image_wh, self.min_dynamic_patch,
+                self.max_dynamic_patch, self.image_size,
+                self.use_thumbnail)
+            num_image_tokens = self.patch_token * image_token
+
+            image_token_str = f'{self.IMG_START_TOKEN}' \
+                              f'{self.IMG_CONTEXT_TOKEN * num_image_tokens}' \
+                              f'{self.IMG_END_TOKEN}'
+            token_dict = self.get_inputid_labels(data_dict['conversations'],
+                                                 image_token_str)
+            out_data_dict['num_patches'] = image_token
+            out_data_dict['num_tokens'] = len(token_dict['input_ids'])
+            out_data_dict['image_flags'] = torch.tensor(
+                [1] * image_token, dtype=torch.long)
+        else:
+            token_dict = self.get_inputid_labels(data_dict['conversations'],
+                                                 None)
+            out_data_dict['num_patches'] = 1
+            out_data_dict['num_tokens'] = len(token_dict['input_ids'])
+            out_data_dict['image_flags'] = torch.tensor([0], dtype=torch.long)
+        return out_data_dict
+
+    def _rand_another(self) -> int:
+        return np.random.randint(0, len(self.data))
+
+    def get_image(self, path):
+        if 's3://' in path:
+            img_bytes = get(path)
+            with io.BytesIO(img_bytes) as buff:
+                img = Image.open(buff).convert('RGB')
+            return img
+        else:
+            return Image.open(path).convert('RGB')
+
+    def get_inputid_labels(self, conversations, image_token_str) -> dict:
+        input = ''
+        out_conversation = []
+        while conversations and conversations[0]['from'] == 'gpt':
+            # Skip the first one if it is from gpt
+            conversations = conversations[1:]
+        for msg in conversations:
+            if msg['from'] == 'human':
+                if image_token_str is None and '<image>' in msg['value']:
+                    warnings.warn(
+                        f'The current data << {msg["value"]} >> is '
+                        f'in plain text mode, but '
+                        'there are <image> tags present in the data. '
+                        'We need to remove the <image> tags.')
+                    msg['value'] = msg['value'].replace('<image>', '')
+                if '<image>' in msg['value']:
+                    msg['value'] = msg['value'].replace('<image>', '').strip()
+                    msg['value'] = image_token_str + '\n' + msg['value']
+                    msg['value'] = msg['value'].strip()
+                input += msg['value'].strip()
+            elif msg['from'] == 'gpt':
+                out_conversation.append({
+                    'input': input,
+                    'output': msg['value'].strip()
+                })
+                input = ''
+            else:
+                raise NotImplementedError
+
+        input_ids, labels = [], []
+        for i, single_turn_conversation in enumerate(out_conversation):
+            input = single_turn_conversation.get('input', '')
+            if input is None:
+                input = ''
+            input_text = self.template.INSTRUCTION.format(
+                input=input, round=i + 1)
+
+            if i == 0:
+                system = self.template.SYSTEM.format(system=self._system)
+                input_text = system + input_text
+                input_encode = self.tokenizer.encode(
+                    input_text, add_special_tokens=True)
+            else:
+                input_encode = self.tokenizer.encode(
+                    input_text, add_special_tokens=False)
+            input_ids += input_encode
+            labels += [IGNORE_INDEX] * len(input_encode)
+
+            output_text = single_turn_conversation.get('output', '')
+            if self.template.get('SUFFIX', None):
+                output_text += self.template.SUFFIX
+            output_encode = self.tokenizer.encode(
+                output_text, add_special_tokens=False)
+            input_ids += output_encode
+            labels += copy.deepcopy(output_encode)
+
+        if len(input_ids) > self.max_length:
+            input_ids = input_ids[:self.max_length]
+            labels = labels[:self.max_length]
+            print_log(
+                f'Warning: input_ids length({len(input_ids)}) '
+                f'is longer than max_length, cut to {self.max_length}',
+                logger='current')
+        return {'input_ids': input_ids, 'labels': labels}
+
+
+def decode_text(args):
+    cfg_dataset, inds = args
+    dataset = InternVL_V1_5_Dataset(**cfg_dataset)
+    dataset.ds_name = "dummy"
+    token_lengths = []
+    for idx in inds:
+        item = dataset.__getitem__(idx)
+        flag = item['image_flags'].sum().item()
+        if flag == 0:
+            num_vit_patch = item['num_patches']
+            num_token = item['num_tokens']
+            image_flags = 0
+        elif flag == -1:
+            num_vit_patch = -1
+            num_token = -1
+            image_flags = -1
+        else:
+            num_vit_patch = flag
+            num_token = item['num_tokens']
+            image_flags = flag
+
+        token_lengths.append(
+            {
+                "vit_num": num_vit_patch,
+                "token_num": num_token,
+                "image_flags": image_flags
+            }
+        )
+
+    return token_lengths
+
+
+def worker(cfg_dataset, ds_name, token_lengths_path, ds_info):
+    dataset = InternVL_V1_5_Dataset(**cfg_dataset)
+    with multiprocessing.Pool(PROCESSES) as pool:
+        token_lengths_all = pool.map(decode_text, [(
+            cfg_dataset, inds) for inds in np.array_split(range(len(dataset)), PROCESSES)])
+    l_token_lengths = []
+    for tmp in token_lengths_all:
+        l_token_lengths.extend(tmp)
+
+    length_save_path = os.path.join(
+        token_lengths_path, f"{ds_name}"+"_token_lengths.json")
+
+    with open(length_save_path, "w") as f:
+        json.dump(l_token_lengths, f, indent=4)
+    if "max_dynamic_patch" in ds_info:
+        info = {
+            "root": ds_info["root"],
+            "annotation": ds_info["annotation"],
ds_info["data_augment"], + "repeat_time": ds_info["repeat_time"], + "length": len(dataset), + "token_lengths": length_save_path, + "max_dynamic_patch": ds_info["max_dynamic_patch"] + } + else: + info = { + "root": ds_info["root"], + "annotation": ds_info["annotation"], + "data_augment": ds_info["data_augment"], + "repeat_time": ds_info["repeat_time"], + "length": len(dataset), + "token_lengths": length_save_path + } + return info + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--json_file", + default=None, + help="json file to statistics" + ) + parser.add_argument( + "--worker", + default=64, type=int, + help="worker num", + ) + parser.add_argument( + "--token_lengths_path", + default=None, + help="token_lengths_path", + ) + parser.add_argument( + "--output_path", + default=None, + help="token_lengths_path", + ) + args = parser.parse_args() + + token_lengths_path = args.token_lengths_path + + # setting + data_path = args.json_file + from xtuner.utils import PROMPT_TEMPLATE + cfg_dataset_base = { + 'template': PROMPT_TEMPLATE.internlm2_chat, + 'model_path': '/model/path', + 'max_length': 4096, + } + + ds_collections = json.loads(open(data_path).read()) + import time + t_1 = time.time() + meta = {} + idx = 0 + + datasets = [] + for ds_name in tqdm(ds_collections.keys()): + print(ds_name) + cfg_dataset = copy.deepcopy(cfg_dataset_base) + cfg_dataset['repeat_times'] = ds_collections[ds_name]['repeat_time'] + cfg_dataset['data_paths'] = ds_collections[ds_name]['annotation'] + cfg_dataset['image_folders'] = ds_collections[ds_name]['root'] + + ds_info = {} + ds_info["root"] = ds_collections[ds_name]["root"] + ds_info["annotation"] = ds_collections[ds_name]["annotation"] + ds_info["data_augment"] = ds_collections[ds_name].get( + "data_augment", False) + ds_info["repeat_time"] = ds_collections[ds_name]['repeat_time'] + if 'max_dynamic_patch' in ds_collections[ds_name]: + ds_info['max_dynamic_patch'] = ds_collections[ds_name]['max_dynamic_patch'] + + meta[ds_name] = worker(cfg_dataset, ds_name, + token_lengths_path, ds_info) + + with open(args.output_path, "w") as f: + json.dump(meta.copy(), f, indent=4) + + t_2 = time.time() + print(f"time: {t_2-t_1}") diff --git a/xtuner/tools/data_preprocess_stastics.sh b/xtuner/tools/data_preprocess_stastics.sh new file mode 100755 index 000000000..de38489b7 --- /dev/null +++ b/xtuner/tools/data_preprocess_stastics.sh @@ -0,0 +1,9 @@ +ROOT=/path/to/xtuner +export PYTHONPATH=$ROOT:$PYTHONPATH + +export OMP_NUM_THREADS=1 + +# $1: internvl_sft_1.2M.json, data format as https://internvl.readthedocs.io/en/latest/get_started/chat_data_format.html#meta-file +# $2: the folder of the results of token stastics which is absolute path +# $3: pack_internvl_sft_1.2M.json, results file +python data_preprocess_stastics.py --json_file $1 --token_lengths_path $2 --output_path $3 2>&1 | tee -a log_statistics.txt