From dd54371834a162b632aa43099c9345c35182791d Mon Sep 17 00:00:00 2001
From: DukeG
Date: Sat, 5 Oct 2024 12:20:37 +0800
Subject: [PATCH] Sync codes with repo https://github.com/fireicewolf/wd-llm-caption-cli

---
 README.md                    |  37 +++--
 caption.py                   | 138 +++++++++------
 configs/default.json         | 251 +++++++++++++++++++---------
 huggingface-requirements.txt |   2 +-
 modelscope-requirements.txt  |   2 +-
 requirements.txt             |  10 +-
 utils/download.py            | 314 +++++++++++++++++------------------
 utils/image.py               |  49 +++++-
 utils/joy.py                 |  63 +++----
 9 files changed, 523 insertions(+), 343 deletions(-)

diff --git a/README.md b/README.md
index 2551b1e..2b1e190 100644
--- a/README.md
+++ b/README.md
@@ -1,35 +1,41 @@
 # Joy Caption Cli
-A Python base cli tool for tagging images with [joy-caption-pre-alpha](https://huggingface.co/spaces/fancyfeast/joy-caption-pre-alpha) models.
+
+A Python-based CLI tool for tagging images
+with [joy-caption-pre-alpha](https://huggingface.co/spaces/fancyfeast/joy-caption-pre-alpha) models.
 
 ### Only CUDA devices are supported at the moment.
 
 ## Introduction
 
-I make this repo because I want to caption some images cross-platform (On My old MBP, my game win pc or docker base linux cloud-server(like Google colab))
+I made this repo because I want to caption images cross-platform (on my old MBP, my Windows gaming PC, or a Docker-based
+Linux cloud server such as Google Colab).
 
-But I don't want to install a huge webui just for this little work. And some cloud-service are unfriendly to gradio base ui.
+But I don't want to install a huge WebUI just for this small job, and some cloud services are unfriendly to Gradio-based
+UIs.
 
 So this repo was born.
 
-
 ## Model source
 
-Huggingface are original sources, modelscope are pure forks from Huggingface(Because HuggingFace was blocked in Some place).
+Hugging Face hosts the original sources; the ModelScope repos are pure forks from Hugging Face (because Hugging Face is
+blocked in some places).
-| Model | HuggingFace Link | ModelScope Link |
-|:---------------------------------:|:------------------------------------------------------------------------------:|:--------------------------------------------------------------------------------------------:|
-| joy-caption-pre-alpha | [HuggingFace](https://huggingface.co/spaces/fancyfeast/joy-caption-pre-alpha) | [ModelScope](https://www.modelscope.cn/models/fireicewolf/joy-caption-pre-alpha) |
-| siglip-so400m-patch14-384(Google) | [HuggingFace](https://huggingface.co/google/siglip-so400m-patch14-384) | [ModelScope](https://www.modelscope.cn/models/fireicewolf/siglip-so400m-patch14-384) |
-| Meta-Llama-3.1-8B | [HuggingFace](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B) | [ModelScope](https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B) |
+| Model | HuggingFace Link | ModelScope Link |
+|:---------------------------------:|:-----------------------------------------------------------------------------:|:------------------------------------------------------------------------------------:|
+| joy-caption-pre-alpha | [HuggingFace](https://huggingface.co/spaces/fancyfeast/joy-caption-pre-alpha) | [ModelScope](https://www.modelscope.cn/models/fireicewolf/joy-caption-pre-alpha) |
+| siglip-so400m-patch14-384(Google) | [HuggingFace](https://huggingface.co/google/siglip-so400m-patch14-384) | [ModelScope](https://www.modelscope.cn/models/fireicewolf/siglip-so400m-patch14-384) |
+| Meta-Llama-3.1-8B | [HuggingFace](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B) | [ModelScope](https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B) |
 
 ## TO-DO
 
 Make a simple UI with Jupyter widgets (once I shake off my laziness 😊)
 
 ## Installation
+
 Python 3.10 works fine.
 
 Open a shell terminal and follow the steps below:
+
 ```shell
 # Clone this repo
 git clone https://github.com/fireicewolf/joy-caption-cli.git
@@ -54,16 +60,21 @@ pip install -U -r modelscope-requirements.txt
 ```
 
 ## Simple usage
+
 __Make sure your Python venv is activated first!__
+
 ```shell
 python caption.py your_datasets_path
 ```
+
 To run with more options, get help with the command below or see [Options](#options):
+
 ```shell
 python caption.py -h
 ```
 
-## Options
+## Options
+
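+For example, a typical run combining a few of the options below (an illustrative sketch; the dataset path and option
+values are placeholders, not recommendations):
+
+```shell
+python caption.py your_datasets_path \
+  --recursive \
+  --model_site modelscope \
+  --save_logs
+```
+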
 Advanced options
 
 `data_path`
 
 path where your datasets are placed
 
 `--recursive`
 
 Include all supported image formats under your datasets path and its sub-paths.
 
 `--config CONFIG`
 
 config json for llava models, default is "default.json"
 
 [//]: # (`--use_cpu`)
 
 [//]: # ()
+
 [//]: # (Use cpu for inference.)
 
 `--model_name MODEL_NAME`
 
-model name for inference, default is "Joy-Caption-Pre-Alpha", please check configs/default.json)
+model name for inference, default is "Joy-Caption-Pre-Alpha", please check configs/default.json
 
 `--model_site MODEL_SITE`
 
@@ -142,6 +154,7 @@
 
 max tokens for output, default is 300.
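+
+For example, a run that tunes caption generation with the inference options above (an illustrative sketch; the prompt
+text and values are placeholders, not recommendations):
+
+```shell
+python caption.py your_datasets_path \
+  --user_prompt "Write a long descriptive caption for this image." \
+  --temperature 0.7 \
+  --max_tokens 256
+```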
 ## Credits
+
 Based on [joy-caption-pre-alpha](https://huggingface.co/spaces/fancyfeast/joy-caption-pre-alpha)
 
 Without their work (👏👏), this repo wouldn't exist.

diff --git a/caption.py b/caption.py
index d38a595..4c0e1bd 100644
--- a/caption.py
+++ b/caption.py
@@ -1,9 +1,10 @@
 import argparse
 import os
+import time
 from datetime import datetime
 from pathlib import Path
 
-from utils.download import download
+from utils.download import download_models
 from utils.joy import Joy
 from utils.logger import Logger
 
@@ -60,147 +61,186 @@ def main(args):
     else:
         models_save_path = Path(os.path.join(Path(__file__).parent, args.models_save_path))
 
-    image_adapter_path, clip_path, llm_path = download(
+    image_adapter_path, clip_path, llm_path = download_models(
         logger=my_logger,
+        args=args,
         config_file=config_file,
-        model_name=str(args.model_name),
-        model_site=str(args.model_site),
         models_save_path=models_save_path,
-        use_sdk_cache=True if args.use_sdk_cache else False,
-        download_method=str(args.download_method)
     )
 
-    # Load models
     my_joy = Joy(
         logger=my_logger,
         args=args,
         image_adapter_path=image_adapter_path,
         clip_path=clip_path,
-        llm_path=llm_path
+        llm_path=llm_path,
+        use_gpu=not args.llm_use_cpu
     )
     my_joy.load_model()
 
     # Inference
+    start_inference_time = time.monotonic()
     my_joy.inference()
+    total_inference_time = time.monotonic() - start_inference_time
+    days = int(total_inference_time // (24 * 3600))
+    total_inference_time %= (24 * 3600)
+    hours = int(total_inference_time // 3600)
+    total_inference_time %= 3600
+    minutes = int(total_inference_time // 60)
+    seconds = total_inference_time % 60
+    days = f"{days} Day(s) " if days > 0 else ""
+    hours = f"{hours} Hour(s) " if hours > 0 or (days and hours == 0) else ""
+    minutes = f"{minutes} Min(s) " if minutes > 0 or (hours and minutes == 0) else ""
+    seconds = f"{seconds:.1f} Sec(s)"
+    my_logger.info(f"All work done in {days}{hours}{minutes}{seconds}.")
 
     # Unload models
     my_joy.unload_model()
 
 
 def setup_args() -> argparse.ArgumentParser:
-    args = argparse.ArgumentParser()
-
-    args.add_argument(
+    parsed_args = argparse.ArgumentParser()
+    base_args = parsed_args.add_argument_group("Base")
+    base_args.add_argument(
         'data_path',
         type=str,
         help='path for data.'
     )
-    args.add_argument(
+    base_args.add_argument(
         '--recursive',
         action='store_true',
         help='Include recursive dirs'
     )
-    args.add_argument(
+
+    log_args = parsed_args.add_argument_group("Logs")
+    log_args.add_argument(
+        '--log_level',
+        type=str,
+        choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
+        default='INFO',
+        help='set log level, default is "INFO"'
+    )
+    log_args.add_argument(
+        '--save_logs',
+        action='store_true',
+        help='save log file.'
+    )
+
+    download_args = parsed_args.add_argument_group("Download")
+    download_args.add_argument(
         '--config',
         type=str,
         default='default.json',
         help='config json for llava models, default is "default.json"'
     )
-    # args.add_argument(
-    #     '--use_cpu',
-    #     action='store_true',
-    #     help='use cpu for inference.'
-    # )
-    args.add_argument(
-        '--image_size',
-        type=int,
-        default=1024,
-        help='resize image to suitable, default is 1024.'
-    )
-    args.add_argument(
+    download_args.add_argument(
         '--model_name',
         type=str,
         default='Joy-Caption-Pre-Alpha',
         help='model name for inference, default is "Joy-Caption-Pre-Alpha", please check configs/default.json'
     )
-    args.add_argument(
+    download_args.add_argument(
         '--model_site',
         type=str,
         choices=['huggingface', 'modelscope'],
         default='huggingface',
         help='download model from model site huggingface or modelscope, default is "huggingface".'
     )
-    args.add_argument(
+    download_args.add_argument(
         '--models_save_path',
         type=str,
         default="models",
         help='path to save models, default is "models".'
     )
-    args.add_argument(
+    download_args.add_argument(
         '--use_sdk_cache',
         action='store_true',
         help='use sdk\'s cache dir to store models. \
         if this option is enabled, "--models_save_path" will be ignored.'
     )
-    args.add_argument(
+    download_args.add_argument(
         '--download_method',
         type=str,
         choices=["SDK", "URL"],
         default='SDK',
         help='download method via SDK or URL, default is "SDK".'
     )
-    args.add_argument(
+    download_args.add_argument(
+        '--force_download',
+        action='store_true',
+        help='force download even if the file exists.'
+    )
+    download_args.add_argument(
+        '--skip_download',
+        action='store_true',
+        help='skip download if the file exists.'
+    )
+    download_args.add_argument(
         '--custom_caption_save_path',
         type=str,
         default=None,
         help='custom caption file save path.'
     )
-    args.add_argument(
-        '--log_level',
+
+    inference_args = parsed_args.add_argument_group("Inference")
+    inference_args.add_argument(
+        '--llm_use_cpu',
+        action='store_true',
+        help='use cpu for inference.'
+    )
+    inference_args.add_argument(
+        '--llm_dtype',
         type=str,
-        choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
-        default='INFO',
-        help='set log level, default is "INFO"'
+        choices=["auto", "fp16", "bf16", "fp32"],
+        default='fp16',
+        help='choose LLM load dtype, default is "fp16".'
     )
-    args.add_argument(
-        '--save_logs',
-        action='store_true',
-        help='save log file.'
+    inference_args.add_argument(
+        '--llm_qnt',
+        type=str,
+        choices=["none", "4bit", "8bit"],
+        default='none',
+        help='enable LLM quantization: "none", "4bit" or "8bit", default is "none".'
+    )
+    inference_args.add_argument(
+        '--image_size',
+        type=int,
+        default=1024,
+        help='resize image to suitable size, default is 1024.'
     )
-    args.add_argument(
+    inference_args.add_argument(
         '--caption_extension',
         type=str,
         default='.txt',
         help='extension of caption file, default is ".txt"'
     )
-    args.add_argument(
+    inference_args.add_argument(
         '--not_overwrite',
         action='store_true',
         help='do not overwrite caption file if it exists.'
     )
-    args.add_argument(
+    inference_args.add_argument(
         '--user_prompt',
         type=str,
         default=DEFAULT_USER_PROMPT,
         help='user prompt for caption.'
     )
-    args.add_argument(
+    inference_args.add_argument(
         '--temperature',
         type=float,
         default=0.5,
         help='temperature for Llama model.'
     )
-    args.add_argument(
+    inference_args.add_argument(
         '--max_tokens',
         type=int,
         default=300,
         help='max tokens for output.'
) - - return args + return parsed_args if __name__ == "__main__": - args = setup_args() - args = args.parse_args() - main(args) + get_args = setup_args() + get_args = get_args.parse_args() + main(get_args) diff --git a/configs/default.json b/configs/default.json index 1043068..566e6c5 100644 --- a/configs/default.json +++ b/configs/default.json @@ -1,87 +1,172 @@ { - "Joy-Caption-Pre-Alpha": { - "huggingface": { - "image_adapter": { - "repo_id": "fancyfeast/joy-caption-pre-alpha", - "revision": "main", - "repo_type": "space", - "subfolder": "wpkklhc6", - "file_list": { - "image_adapter.pt": "https://huggingface.co/spaces/fancyfeast/joy-caption-pre-alpha/resolve/main/wpkklhc6/image_adapter.pt" - } - }, - "clip": { - "repo_id": "google/siglip-so400m-patch14-384", - "revision": "main", - "repo_type": "model", - "subfolder": "", - "file_list": { - "config.json": "https://huggingface.co/google/siglip-so400m-patch14-384/resolve/main/config.json", - "tokenizer_config.json": "https://huggingface.co/google/siglip-so400m-patch14-384/resolve/main/tokenizer_config.json", - "special_tokens_map.json": "https://huggingface.co/google/siglip-so400m-patch14-384/resolve/main/preprocessor_config.json", - "preprocessor_config.json": "https://huggingface.co/google/siglip-so400m-patch14-384/resolve/main/special_tokens_map.json", - "spiece.model": "https://huggingface.co/google/siglip-so400m-patch14-384/resolve/main/spiece.model", - "model.safetensors": "https://huggingface.co/google/siglip-so400m-patch14-384/resolve/main/model.safetensors" - } - }, - "llm": { - "repo_id": "meta-llama/Meta-Llama-3.1-8B", - "revision": "main", - "repo_type": "model", - "subfolder": "", - "file_list": { - "config.json":"https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/config.json", - "generation_config.json": "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/generation_config.json", - "tokenizer.json": "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/tokenizer.json", - "tokenizer_config.json": "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/tokenizer_config.json", - "special_tokens_map.json": "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/special_tokens_map.json", - "model.safetensors.index.json": "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/model.safetensors.index.json", - "model-00001-of-00004.safetensors": "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/model-00001-of-00004.safetensors", - "model-00002-of-00004.safetensors": "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/model-00002-of-00004.safetensors", - "model-00003-of-00004.safetensors": "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/model-00003-of-00004.safetensors", - "model-00004-of-00004.safetensors": "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/model-00004-of-00004.safetensors" - } - } - }, - "modelscope": { - "image_adapter": { - "repo_id": "fireicewolf/joy-caption-pre-alpha", - "revision": "master", - "subfolder": "wpkklhc6", - "file_list": { - "image_adapter.pt": "https://www.modelscope.cn/models/fireicewolf/joy-caption-pre-alpha/resolve/master/wpkklhc6/image_adapter.pt" - } - }, - "clip": { - "repo_id": "fireicewolf/siglip-so400m-patch14-384", - "revision": "master", - "subfolder": "", - "file_list": { - "config.json": "https://www.modelscope.cn/models/fireicewolf/siglip-so400m-patch14-384/resolve/master/config.json", - "tokenizer_config.json": 
"https://www.modelscope.cn/models/fireicewolf/siglip-so400m-patch14-384/resolve/master/tokenizer_config.json", - "special_tokens_map.json": "https://www.modelscope.cn/models/fireicewolf/siglip-so400m-patch14-384/resolve/master/preprocessor_config.json", - "preprocessor_config.json": "https://www.modelscope.cn/models/fireicewolf/siglip-so400m-patch14-384/resolve/master/special_tokens_map.json", - "spiece.model": "https://www.modelscope.cn/models/fireicewolf/siglip-so400m-patch14-384/resolve/master/spiece.model", - "model.safetensors": "https://www.modelscope.cn/models/fireicewolf/siglip-so400m-patch14-384/resolve/master/model.safetensors" - } - }, - "llm": { - "repo_id": "fireicewolf/Meta-Llama-3.1-8B", - "revision": "master", - "subfolder": "", - "file_list": { - "config.json": "https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/config.json", - "generation_config.json": "https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/generation_config.json", - "tokenizer.json": "https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/tokenizer.json", - "tokenizer_config.json": "https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/tokenizer_config.json", - "special_tokens_map.json": "https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/special_tokens_map.json", - "model.safetensors.index.json": "https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/model.safetensors.index.json", - "model-00001-of-00004.safetensors": "https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/model-00001-of-00004.safetensors", - "model-00002-of-00004.safetensors": "https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/model-00002-of-00004.safetensors", - "model-00003-of-00004.safetensors": "https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/model-00003-of-00004.safetensors", - "model-00004-of-00004.safetensors": "https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/model-00004-of-00004.safetensors" - } - } + "Joy-Caption-Pre-Alpha": { + "huggingface": { + "image_adapter": { + "repo_id": "fancyfeast/joy-caption-pre-alpha", + "revision": "main", + "repo_type": "space", + "subfolder": "wpkklhc6", + "file_list": { + "image_adapter.pt": "https://huggingface.co/spaces/fancyfeast/joy-caption-pre-alpha/resolve/main/wpkklhc6/image_adapter.pt" } + }, + "clip": { + "repo_id": "google/siglip-so400m-patch14-384", + "revision": "main", + "repo_type": "model", + "subfolder": "", + "file_list": { + "config.json": "https://huggingface.co/google/siglip-so400m-patch14-384/resolve/main/config.json", + "tokenizer_config.json": "https://huggingface.co/google/siglip-so400m-patch14-384/resolve/main/tokenizer_config.json", + "special_tokens_map.json": "https://huggingface.co/google/siglip-so400m-patch14-384/resolve/main/preprocessor_config.json", + "preprocessor_config.json": "https://huggingface.co/google/siglip-so400m-patch14-384/resolve/main/special_tokens_map.json", + "spiece.model": "https://huggingface.co/google/siglip-so400m-patch14-384/resolve/main/spiece.model", + "model.safetensors": "https://huggingface.co/google/siglip-so400m-patch14-384/resolve/main/model.safetensors" + } + }, + "llm": { + "repo_id": "meta-llama/Meta-Llama-3.1-8B", + "revision": "main", + "repo_type": "model", + "subfolder": "", + "file_list": { + "config.json": 
"https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/config.json", + "generation_config.json": "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/generation_config.json", + "tokenizer.json": "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/tokenizer.json", + "tokenizer_config.json": "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/tokenizer_config.json", + "special_tokens_map.json": "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/special_tokens_map.json", + "model.safetensors.index.json": "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/model.safetensors.index.json", + "model-00001-of-00004.safetensors": "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/model-00001-of-00004.safetensors", + "model-00002-of-00004.safetensors": "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/model-00002-of-00004.safetensors", + "model-00003-of-00004.safetensors": "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/model-00003-of-00004.safetensors", + "model-00004-of-00004.safetensors": "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/model-00004-of-00004.safetensors" + } + } + }, + "modelscope": { + "image_adapter": { + "repo_id": "fireicewolf/joy-caption-pre-alpha", + "revision": "master", + "subfolder": "wpkklhc6", + "file_list": { + "image_adapter.pt": "https://www.modelscope.cn/models/fireicewolf/joy-caption-pre-alpha/resolve/master/wpkklhc6/image_adapter.pt" + } + }, + "clip": { + "repo_id": "fireicewolf/siglip-so400m-patch14-384", + "revision": "master", + "subfolder": "", + "file_list": { + "config.json": "https://www.modelscope.cn/models/fireicewolf/siglip-so400m-patch14-384/resolve/master/config.json", + "tokenizer_config.json": "https://www.modelscope.cn/models/fireicewolf/siglip-so400m-patch14-384/resolve/master/tokenizer_config.json", + "special_tokens_map.json": "https://www.modelscope.cn/models/fireicewolf/siglip-so400m-patch14-384/resolve/master/preprocessor_config.json", + "preprocessor_config.json": "https://www.modelscope.cn/models/fireicewolf/siglip-so400m-patch14-384/resolve/master/special_tokens_map.json", + "spiece.model": "https://www.modelscope.cn/models/fireicewolf/siglip-so400m-patch14-384/resolve/master/spiece.model", + "model.safetensors": "https://www.modelscope.cn/models/fireicewolf/siglip-so400m-patch14-384/resolve/master/model.safetensors" + } + }, + "llm": { + "repo_id": "fireicewolf/Meta-Llama-3.1-8B", + "revision": "master", + "subfolder": "", + "file_list": { + "config.json": "https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/config.json", + "generation_config.json": "https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/generation_config.json", + "tokenizer.json": "https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/tokenizer.json", + "tokenizer_config.json": "https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/tokenizer_config.json", + "special_tokens_map.json": "https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/special_tokens_map.json", + "model.safetensors.index.json": "https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/model.safetensors.index.json", + "model-00001-of-00004.safetensors": "https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/model-00001-of-00004.safetensors", + "model-00002-of-00004.safetensors": 
"https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/model-00002-of-00004.safetensors", + "model-00003-of-00004.safetensors": "https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/model-00003-of-00004.safetensors", + "model-00004-of-00004.safetensors": "https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/model-00004-of-00004.safetensors" + } + } + } + }, + "Joy-Caption-Uncensored": { + "huggingface": { + "image_adapter": { + "repo_id": "fancyfeast/joy-caption-pre-alpha", + "revision": "main", + "repo_type": "space", + "subfolder": "wpkklhc6", + "file_list": { + "image_adapter.pt": "https://huggingface.co/spaces/fancyfeast/joy-caption-pre-alpha/resolve/main/wpkklhc6/image_adapter.pt" + } + }, + "clip": { + "repo_id": "google/siglip-so400m-patch14-384", + "revision": "main", + "repo_type": "model", + "subfolder": "", + "file_list": { + "config.json": "https://huggingface.co/google/siglip-so400m-patch14-384/resolve/main/config.json", + "tokenizer_config.json": "https://huggingface.co/google/siglip-so400m-patch14-384/resolve/main/tokenizer_config.json", + "special_tokens_map.json": "https://huggingface.co/google/siglip-so400m-patch14-384/resolve/main/preprocessor_config.json", + "preprocessor_config.json": "https://huggingface.co/google/siglip-so400m-patch14-384/resolve/main/special_tokens_map.json", + "spiece.model": "https://huggingface.co/google/siglip-so400m-patch14-384/resolve/main/spiece.model", + "model.safetensors": "https://huggingface.co/google/siglip-so400m-patch14-384/resolve/main/model.safetensors" + } + }, + "llm": { + "repo_id": "Orenguteng/Llama-3.1-8B-Lexi-Uncensored-V2", + "revision": "main", + "repo_type": "model", + "subfolder": "", + "file_list": { + "config.json": "https://huggingface.co/Orenguteng/Llama-3.1-8B-Lexi-Uncensored-V2/resolve/main/config.json", + "generation_config.json": "https://huggingface.co/Orenguteng/Llama-3.1-8B-Lexi-Uncensored-V2/resolve/main/generation_config.json", + "tokenizer.json": "https://huggingface.co/Orenguteng/Llama-3.1-8B-Lexi-Uncensored-V2/resolve/main/tokenizer.json", + "tokenizer_config.json": "https://huggingface.co/Orenguteng/Llama-3.1-8B-Lexi-Uncensored-V2/resolve/main/tokenizer_config.json", + "special_tokens_map.json": "https://huggingface.co/Orenguteng/Llama-3.1-8B-Lexi-Uncensored-V2/resolve/main/special_tokens_map.json", + "model.safetensors.index.json": "https://huggingface.co/Orenguteng/Llama-3.1-8B-Lexi-Uncensored-V2/resolve/main/model.safetensors.index.json", + "model-00001-of-00004.safetensors": "https://huggingface.co/Orenguteng/Llama-3.1-8B-Lexi-Uncensored-V2/resolve/main/model-00001-of-00004.safetensors", + "model-00002-of-00004.safetensors": "https://huggingface.co/Orenguteng/Llama-3.1-8B-Lexi-Uncensored-V2/resolve/main/model-00002-of-00004.safetensors", + "model-00003-of-00004.safetensors": "https://huggingface.co/Orenguteng/Llama-3.1-8B-Lexi-Uncensored-V2/resolve/main/model-00003-of-00004.safetensors", + "model-00004-of-00004.safetensors": "https://huggingface.co/Orenguteng/Llama-3.1-8B-Lexi-Uncensored-V2/resolve/main/model-00004-of-00004.safetensors" + } + } + }, + "modelscope": { + "image_adapter": { + "repo_id": "fireicewolf/joy-caption-pre-alpha", + "revision": "master", + "subfolder": "wpkklhc6", + "file_list": { + "image_adapter.pt": "https://www.modelscope.cn/models/fireicewolf/joy-caption-pre-alpha/resolve/master/wpkklhc6/image_adapter.pt" + } + }, + "clip": { + "repo_id": "fireicewolf/siglip-so400m-patch14-384", + "revision": "master", 
+        "subfolder": "",
+        "file_list": {
+          "config.json": "https://www.modelscope.cn/models/fireicewolf/siglip-so400m-patch14-384/resolve/master/config.json",
+          "tokenizer_config.json": "https://www.modelscope.cn/models/fireicewolf/siglip-so400m-patch14-384/resolve/master/tokenizer_config.json",
+          "special_tokens_map.json": "https://www.modelscope.cn/models/fireicewolf/siglip-so400m-patch14-384/resolve/master/special_tokens_map.json",
+          "preprocessor_config.json": "https://www.modelscope.cn/models/fireicewolf/siglip-so400m-patch14-384/resolve/master/preprocessor_config.json",
+          "spiece.model": "https://www.modelscope.cn/models/fireicewolf/siglip-so400m-patch14-384/resolve/master/spiece.model",
+          "model.safetensors": "https://www.modelscope.cn/models/fireicewolf/siglip-so400m-patch14-384/resolve/master/model.safetensors"
+        }
+      },
+      "llm": {
+        "repo_id": "fireicewolf/Llama-3.1-8B-Lexi-Uncensored-V2",
+        "revision": "master",
+        "subfolder": "",
+        "file_list": {
+          "config.json": "https://www.modelscope.cn/models/fireicewolf/Llama-3.1-8B-Lexi-Uncensored-V2/resolve/master/config.json",
+          "generation_config.json": "https://www.modelscope.cn/models/fireicewolf/Llama-3.1-8B-Lexi-Uncensored-V2/resolve/master/generation_config.json",
+          "tokenizer.json": "https://www.modelscope.cn/models/fireicewolf/Llama-3.1-8B-Lexi-Uncensored-V2/resolve/master/tokenizer.json",
+          "tokenizer_config.json": "https://www.modelscope.cn/models/fireicewolf/Llama-3.1-8B-Lexi-Uncensored-V2/resolve/master/tokenizer_config.json",
+          "special_tokens_map.json": "https://www.modelscope.cn/models/fireicewolf/Llama-3.1-8B-Lexi-Uncensored-V2/resolve/master/special_tokens_map.json",
+          "model.safetensors.index.json": "https://www.modelscope.cn/models/fireicewolf/Llama-3.1-8B-Lexi-Uncensored-V2/resolve/master/model.safetensors.index.json",
+          "model-00001-of-00004.safetensors": "https://www.modelscope.cn/models/fireicewolf/Llama-3.1-8B-Lexi-Uncensored-V2/resolve/master/model-00001-of-00004.safetensors",
+          "model-00002-of-00004.safetensors": "https://www.modelscope.cn/models/fireicewolf/Llama-3.1-8B-Lexi-Uncensored-V2/resolve/master/model-00002-of-00004.safetensors",
+          "model-00003-of-00004.safetensors": "https://www.modelscope.cn/models/fireicewolf/Llama-3.1-8B-Lexi-Uncensored-V2/resolve/master/model-00003-of-00004.safetensors",
+          "model-00004-of-00004.safetensors": "https://www.modelscope.cn/models/fireicewolf/Llama-3.1-8B-Lexi-Uncensored-V2/resolve/master/model-00004-of-00004.safetensors"
+        }
+      }
+    }
+  }
 }
diff --git a/huggingface-requirements.txt b/huggingface-requirements.txt
index 2afcdef..d292e9d 100644
--- a/huggingface-requirements.txt
+++ b/huggingface-requirements.txt
@@ -1,2 +1,2 @@
-huggingface_hub==0.24.6
+huggingface_hub==0.25.1
 -r requirements.txt
\ No newline at end of file
diff --git a/modelscope-requirements.txt b/modelscope-requirements.txt
index a4775a9..fac4583 100644
--- a/modelscope-requirements.txt
+++ b/modelscope-requirements.txt
@@ -1,2 +1,2 @@
-modelscope==1.17.1
+modelscope==1.18.1
 -r requirements.txt
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index f0c1d55..27f883c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,8 +1,8 @@
-accelerate==0.33.0
-transformers==4.44.2
-sentencepiece==0.2.0
-numpy==1.26.4
+accelerate==0.34.2
+bitsandbytes==0.44.1
 opencv-python-headless==4.10.0.84
 pillow==10.4.0
 requests==2.32.3
-tqdm==4.66.5
\ No newline at end of file
+sentencepiece==0.2.0
+tqdm==4.66.5
+transformers==4.45.1
\ No newline at end of file
diff --git a/utils/download.py
b/utils/download.py
index 38a1152..a33bc29 100644
--- a/utils/download.py
+++ b/utils/download.py
@@ -1,6 +1,6 @@
+import argparse
 import json
 import os
-import shutil
 from pathlib import Path
 from typing import Union, Optional
 
@@ -10,15 +10,69 @@ from utils.logger import Logger
 
 
-def download(
+def url_download(
     logger: Logger,
+    url: str,
+    local_dir: Union[str, Path],
+    skip_local_file_exist: bool = True,
+    force_download: bool = False,
+    force_filename: Optional[str] = None
+) -> Path:
+    # Download file via url by requests library
+    filename = os.path.basename(url) if not force_filename else force_filename
+    local_file = os.path.join(local_dir, filename)
+
+    hf_token = os.environ.get("HF_TOKEN")
+    if hf_token:
+        logger.info("Loading Hugging Face token from environment variable")
+    response = requests.get(url, stream=True, headers={
+        "Authorization": f"Bearer {hf_token}"} if "huggingface.co" in url and hf_token else None)
+    total_size = int(response.headers.get('content-length', 0))
+
+    def download_progress():
+        desc = f'Downloading {filename}'
+
+        if total_size > 0:
+            pbar = tqdm(total=total_size, initial=0, unit='B', unit_divisor=1024, unit_scale=True,
+                        dynamic_ncols=True,
+                        desc=desc)
+        else:
+            pbar = tqdm(initial=0, unit='B', unit_divisor=1024, unit_scale=True, dynamic_ncols=True, desc=desc)
+
+        if not os.path.exists(local_dir):
+            os.makedirs(local_dir, exist_ok=True)
+
+        with open(local_file, 'ab') as download_file:
+            for data in response.iter_content(chunk_size=1024):
+                if data:
+                    download_file.write(data)
+                    pbar.update(len(data))
+        pbar.close()
+
+    if not force_download and os.path.isfile(local_file):
+        if skip_local_file_exist:
+            logger.info(f"`skip_local_file_exist` is enabled, skipping download of {local_file}...")
+        else:
+            if total_size == 0:
+                logger.info(
+                    f'"{local_file}" already exists, but can\'t get its size from "{url}". Won\'t download it.')
+            elif os.path.getsize(local_file) == total_size:
+                logger.info(f'"{local_file}" already exists, and its size matches "{url}".')
+            else:
+                logger.info(
+                    f'"{local_file}" already exists, but its size doesn\'t match "{url}"!\nWill download this file '
+                    f'again...')
+                download_progress()
+    else:
+        download_progress()
+    return Path(os.path.join(local_dir, filename))
+
+
+def download_models(
+    logger: Logger,
+    args: argparse.Namespace,
     config_file: Path,
-    model_name: str,
-    model_site: str,
     models_save_path: Path,
-    use_sdk_cache: bool = False,
-    download_method: str = "sdk",
-    force_download: bool = False
 ) -> tuple[Path, Path, Path]:
     if os.path.isfile(config_file):
         logger.info(f'Using config: {str(config_file)}')
@@ -26,194 +80,130 @@
     else:
         logger.error(f'{str(config_file)} NOT FOUND!')
         raise FileNotFoundError
 
-    def read_json(config_file, model_name) -> dict[str]:
+    def read_json(config_file) -> tuple[str, dict]:
         with open(config_file, 'r', encoding='utf-8') as config_json:
             datas = json.load(config_json)
+            model_name = list(datas.keys())[0] if args.model_name is None else args.model_name
+            args.model_name = model_name
+
             if model_name not in datas.keys():
                 logger.error(f'"{str(model_name)}" NOT FOUND IN CONFIG!')
                 raise FileNotFoundError
-            return datas[model_name]
-
-    model_info = read_json(config_file, model_name)
+            return model_name, datas[model_name]
 
+    model_name, model_info = read_json(config_file)
     models_save_path = Path(os.path.join(models_save_path, model_name))
 
-    if use_sdk_cache:
+    if args.use_sdk_cache:
        logger.warning('use_sdk_cache ENABLED! 
download_method forced to "SDK" and models_save_path will be ignored')
-        download_method = 'sdk'
+        args.download_method = 'sdk'
     else:
         logger.info(f'Models will be stored in {str(models_save_path)}.')
 
-    def url_download(
-        url: str,
-        local_dir: Union[str, Path],
-        force_download: bool = False,
-        force_filename: Optional[str] = None
-    ) -> Path:
-        # Download file via url by requests library
-        filename = os.path.basename(url) if not force_filename else force_filename
-        local_file = os.path.join(local_dir, filename)
-
-        hf_token = os.environ.get("HF_TOKEN")
-
-        response = requests.get(url, stream=True, headers={"Authorization": f"Bearer {hf_token}"} if "huggingface.co" in url and hf_token else None)
-        total_size = int(response.headers.get('content-length', 0))
-
-        def download_progress():
-            desc = 'Downloading {}'.format(filename)
-
-            if total_size > 0:
-                pbar = tqdm(total=total_size, initial=0, unit='B', unit_divisor=1024, unit_scale=True,
-                            dynamic_ncols=True,
-                            desc=desc)
-            else:
-                pbar = tqdm(initial=0, unit='B', unit_divisor=1024, unit_scale=True, dynamic_ncols=True, desc=desc)
-
-            if not os.path.exists(local_dir):
-                os.makedirs(local_dir, exist_ok=True)
-
-            with open(local_file, 'ab') as download_file:
-                for data in response.iter_content(chunk_size=1024):
-                    if data:
-                        download_file.write(data)
-                        pbar.update(len(data))
-            pbar.close()
-
-        if not force_download and os.path.isfile(local_file):
-            if total_size == 0:
-                logger.info(
-                    f'"{local_file}" already exist, but can\'t get its size from "{url}". Won\'t download it.')
-            elif os.path.getsize(local_file) == total_size:
-                logger.info(f'"{local_file}" already exist, and its size match with "{url}".')
-            else:
-                logger.info(
-                    f'"{local_file}" already exist, but its size not match with "{url}"!\nWill download this file '
-                    f'again...')
-                download_progress()
-        else:
-            download_progress()
-
-        return Path(os.path.join(local_dir, filename))
-
     def download_choice(
+        args: argparse.Namespace,
         model_info: dict,
         model_site: str,
         models_save_path: Path,
         download_method: str = "sdk",
         use_sdk_cache: bool = False,
+        skip_local_file_exist: bool = True,
         force_download: bool = False
     ):
-        if download_method.lower() == 'sdk':
-            if model_site == "huggingface":
-                model_hf_info = model_info["huggingface"]
-                try:
+        if model_site not in ["huggingface", "modelscope"]:
+            logger.error('Invalid model site!')
+            raise ValueError
+
+        model_site_info = model_info[model_site]
+        try:
+            if download_method == "sdk":
+                if model_site == "huggingface":
                     from huggingface_hub import hf_hub_download
-
-                    models_path = []
-                    for sub_model_name in model_hf_info:
-                        sub_model_info = model_hf_info[sub_model_name]
-                        sub_model_path = ""
-
-                        for filename in sub_model_info["file_list"]:
-                            logger.info(f'Will download "{filename}" from huggingface repo: "{sub_model_info["repo_id"]}".')
-                            sub_model_path = hf_hub_download(
-                                repo_id=sub_model_info["repo_id"],
-                                filename=filename,
-                                subfolder=sub_model_info["subfolder"] if sub_model_info["subfolder"] != "" else None,
-                                repo_type=sub_model_info["repo_type"],
-                                revision=sub_model_info["revision"],
-                                local_dir=os.path.join(models_save_path,sub_model_name) if not use_sdk_cache else None,
-                                # local_dir_use_symlinks=False if not use_sdk_cache else "auto",
-                                # resume_download=True,
-                                force_download=force_download
-                            )
-                            models_path.append(sub_model_path)
-                    return models_path
-
-                except:
-                    logger.warning('huggingface_hub not installed or download via it failed, '
-                                   'retrying with URL method to download...')
-                    models_path = download_choice(
-                        model_info,
-                        model_site,
-                        models_save_path,
-                        use_sdk_cache=False,
-                        download_method="url",
-                        force_download=force_download
-                    )
-                    return models_path
-
-            elif model_site == "modelscope":
-                model_ms_info = model_info["modelscope"]
-                try:
-                    if force_download:
-                        logger.warning(
-                            'modelscope api not support force download, '
-                            'trying to remove model path before download!')
-                        shutil.rmtree(models_save_path)
-
+            elif model_site == "modelscope":
                 from modelscope.hub.file_download import model_file_download
-
-                    models_path = []
-                    for sub_model_name in model_ms_info:
-                        sub_model_info = model_ms_info[sub_model_name]
-                        sub_model_path = ""
-
-                        for filename in sub_model_info["file_list"]:
-                            logger.info(f'Will download "{filename}" from modelscope repo: "{sub_model_info["repo_id"]}".')
-                            sub_model_path = model_file_download(
-                                model_id=sub_model_info["repo_id"],
-                                file_path=filename if sub_model_info["subfolder"] == ""
-                                else os.path.join(sub_model_info["subfolder"],filename),
-                                revision=sub_model_info["revision"],
-                                local_dir=os.path.join(models_save_path,sub_model_name) if not use_sdk_cache else None,
-                            )
-                            models_path.append(sub_model_path)
-                    return models_path
-
-                except:
-                    logger.warning('modelscope not installed or download via it failed, '
-                                   'retrying with URL method to download...')
-                    models_path = download_choice(
-                        model_info,
-                        model_site,
-                        models_save_path,
-                        use_sdk_cache=False,
-                        download_method="url",
-                        force_download=force_download
-                    )
-                    return models_path
-
-            else:
-                logger.error('Invalid model site!')
-                raise ValueError
+        except ModuleNotFoundError:
+            if model_site == "huggingface":
+                logger.warning('huggingface_hub not installed or download via it failed, '
+                               'retrying with URL method to download...')
+            elif model_site == "modelscope":
+                logger.warning('modelscope not installed or download via it failed, '
+                               'retrying with URL method to download...')
+
+            models_path = download_choice(
+                args,
+                model_info,
+                model_site,
+                models_save_path,
+                use_sdk_cache=False,
+                download_method="url",
+                skip_local_file_exist=skip_local_file_exist,
+                force_download=force_download
+            )
+            return models_path
 
-        else:
-            model_site_info = model_info[model_site]
-            models_path = []
-            for sub_model_name in model_site_info:
-                sub_model_info = model_site_info[sub_model_name]
-                sub_model_path = ""
-                for filename in sub_model_info["file_list"]:
+        models_path = []
+        for sub_model_name in model_site_info:
+            sub_model_info = model_site_info[sub_model_name]
+            if sub_model_name == "patch" and not getattr(args, "llm_patch", False):
+                logger.warning("Found LLM patch, but llm_patch is not enabled, won't download it.")
+                continue
+            sub_model_path = ""
+
+            for filename in sub_model_info["file_list"]:
+                if download_method.lower() == 'sdk':
+                    if model_site == "huggingface":
+                        logger.info(f'Will download "{filename}" from huggingface repo: "{sub_model_info["repo_id"]}".')
+                        sub_model_path = hf_hub_download(
+                            repo_id=sub_model_info["repo_id"],
+                            filename=filename,
+                            subfolder=sub_model_info["subfolder"] if sub_model_info["subfolder"] != "" else None,
+                            repo_type=sub_model_info["repo_type"],
+                            revision=sub_model_info["revision"],
+                            local_dir=os.path.join(models_save_path, sub_model_name) if not use_sdk_cache else None,
+                            local_files_only=skip_local_file_exist \
+                                if os.path.exists(os.path.join(models_save_path, sub_model_name, filename)) else False,
+                            # local_dir_use_symlinks=False if not use_sdk_cache else "auto",
+                            # resume_download=True,
+                            force_download=force_download
+                        )
+                    elif model_site == "modelscope":
+                        logger.info(f'Will download 
"{filename}" from modelscope repo: "{sub_model_info["repo_id"]}".') + sub_model_path = model_file_download( + model_id=sub_model_info["repo_id"], + file_path=filename if sub_model_info["subfolder"] == "" + else os.path.join(sub_model_info["subfolder"], filename), + revision=sub_model_info["revision"], + local_files_only=skip_local_file_exist, + local_dir=os.path.join(models_save_path, sub_model_name) if not use_sdk_cache else None, + ) + else: model_url = sub_model_info["file_list"][filename] logger.info(f'Will download model from url: {model_url}') sub_model_path = url_download( + logger=logger, url=model_url, - local_dir=os.path.join(models_save_path,sub_model_name) if sub_model_info["subfolder"] == "" - else os.path.join(models_save_path,sub_model_name,sub_model_info["subfolder"]), + local_dir=os.path.join(models_save_path, sub_model_name) if sub_model_info["subfolder"] == "" + else os.path.join(models_save_path, sub_model_name, sub_model_info["subfolder"]), force_filename=filename, + skip_local_file_exist=skip_local_file_exist, force_download=force_download ) - models_path.append(sub_model_path) - return models_path + models_path.append(sub_model_path) + return models_path models_path = download_choice( + args=args, model_info=model_info, - model_site=model_site, - models_save_path=models_save_path, - download_method=download_method, - use_sdk_cache=use_sdk_cache, - force_download=force_download + model_site=str(args.model_site), + models_save_path=Path(models_save_path), + download_method=str(args.download_method).lower(), + use_sdk_cache=args.use_sdk_cache, + skip_local_file_exist=args.skip_download, + force_download=args.force_download ) - return Path(models_path[0]), Path(os.path.dirname(models_path[1])), Path(os.path.dirname(models_path[2])) + image_adapter_path = Path(models_path[0]) + clip_path = Path(os.path.dirname(models_path[1])) + llm_path = Path(os.path.dirname(models_path[2])) + return image_adapter_path, clip_path, llm_path diff --git a/utils/image.py b/utils/image.py index 825f4f1..be2e8ae 100644 --- a/utils/image.py +++ b/utils/image.py @@ -1,13 +1,43 @@ import base64 +import glob +import os +from io import BytesIO from pathlib import Path +from typing import List import cv2 import numpy -from io import BytesIO from PIL import Image +from utils.logger import Logger + +SUPPORT_IMAGE_FORMATS = ("bmp", "jpg", "jpeg", "png", "webp") + + +def get_image_paths( + logger: Logger, + path: Path, + recursive: bool = False, +) -> List[str]: + # Get image paths + path_to_find = os.path.join(path, '**') if recursive else os.path.join(path, '*') + image_paths = sorted(set( + [image for image in glob.glob(path_to_find, recursive=recursive) + if image.lower().endswith(SUPPORT_IMAGE_FORMATS)]), key=lambda filename: (os.path.splitext(filename)[0]) + ) if not os.path.isfile(path) else [str(path)] \ + if str(path).lower().endswith(SUPPORT_IMAGE_FORMATS) else None + + logger.debug(f"Path for inference: \"{path}\"") + + if image_paths is None: + logger.error('Invalid dir or image path!') + raise FileNotFoundError -def image_process(image: Image.Image, target_size: int) -> Image.Image: + logger.info(f'Found {len(image_paths)} image(s).') + return image_paths + + +def image_process(image: Image.Image, target_size: int) -> numpy.ndarray: # make alpha to white image = image.convert('RGBA') new_image = Image.new('RGBA', image.size, 'WHITE') @@ -53,9 +83,24 @@ def image_process(image: Image.Image, target_size: int) -> Image.Image: interpolation=cv2.INTER_LANCZOS4 ) + return padded_image + + +def 
image_process_image(
+        padded_image: numpy.ndarray
+) -> Image.Image:
     return Image.fromarray(padded_image)
 
 
+def image_process_gbr(
+        padded_image: numpy.ndarray
+) -> numpy.ndarray:
+    # From PIL RGB to OpenCV BGR
+    padded_image = padded_image[:, :, ::-1]
+    padded_image = padded_image.astype(numpy.float32)
+    return padded_image
+
+
 def encode_image_to_base64(image: Image.Image):
     with BytesIO() as bytes_output:
         image.save(bytes_output, format="PNG")
diff --git a/utils/joy.py b/utils/joy.py
index d4dfb8f..5d5b06f 100644
--- a/utils/joy.py
+++ b/utils/joy.py
@@ -1,4 +1,3 @@
-import glob
 import os
 import time
 from argparse import Namespace
@@ -9,13 +8,12 @@
 from PIL import Image
 from torch import nn
 from tqdm import tqdm
-from transformers import (AutoModel, AutoProcessor, AutoTokenizer,
-                          PreTrainedTokenizer, PreTrainedTokenizerFast, AutoModelForCausalLM)
+from transformers import (AutoModel, AutoProcessor, AutoTokenizer, AutoModelForCausalLM,
+                          BitsAndBytesConfig, PreTrainedTokenizer, PreTrainedTokenizerFast)
 
-from utils.image import image_process
+from utils.image import image_process, get_image_paths, image_process_image
 from utils.logger import Logger
 
-SUPPORT_IMAGE_FORMATS = ("bmp", "jpg", "jpeg", "png")
 
 class ImageAdapter(nn.Module):
     def __init__(self, input_features: int, output_features: int):
@@ -30,6 +28,7 @@ def forward(self, vision_outputs: torch.Tensor):
         x = self.linear2(x)
         return x
 
+
 class Joy:
     def __init__(
             self,
@@ -62,7 +61,7 @@ def load_model(self):
         self.clip_model = self.clip_model.vision_model
         self.clip_model.eval()
         self.clip_model.requires_grad_(False)
-        self.clip_model.to("cuda")
+        self.clip_model.to("cuda" if self.use_gpu else "cpu")
         self.logger.info(f'CLIP Loaded in {time.monotonic() - start_time:.1f}s.')
 
         # Load LLM
@@ -72,7 +71,27 @@ def load_model(self):
         assert (isinstance(self.llm_tokenizer, PreTrainedTokenizer) or
                 isinstance(self.llm_tokenizer, PreTrainedTokenizerFast)), \
             f"Tokenizer is of type {type(self.llm_tokenizer)}"
-        self.llm = AutoModelForCausalLM.from_pretrained(self.llm_path, device_map="auto", torch_dtype=torch.bfloat16)
+        # LLM dtype
+        llm_dtype = torch.float32 if self.args.llm_use_cpu or self.args.llm_dtype == "fp32" else torch.float16 \
+            if self.args.llm_dtype == "fp16" else torch.bfloat16 if self.args.llm_dtype == "bf16" else "auto"
+        self.logger.info(f'LLM dtype: {llm_dtype}')
+        # LLM BNB quantization config
+        if self.args.llm_qnt == "4bit":
+            qnt_config = BitsAndBytesConfig(load_in_4bit=True,
+                                            bnb_4bit_quant_type="nf4",
+                                            bnb_4bit_compute_dtype=llm_dtype,
+                                            bnb_4bit_use_double_quant=True)
+            self.logger.info('LLM 4bit quantization: Enabled')
+        elif self.args.llm_qnt == "8bit":
+            qnt_config = BitsAndBytesConfig(load_in_8bit=True,
+                                            llm_int8_enable_fp32_cpu_offload=True)
+            self.logger.info('LLM 8bit quantization: Enabled')
+        else:
+            qnt_config = None
+        self.llm = AutoModelForCausalLM.from_pretrained(self.llm_path,
+                                                        device_map="auto" if not self.args.llm_use_cpu else "cpu",
+                                                        torch_dtype=llm_dtype if self.args.llm_qnt == "none" else None,
+                                                        quantization_config=qnt_config)
         self.llm.eval()
         self.logger.info(f'LLM Loaded in {time.monotonic() - start_time:.1f}s.')
 
@@ -81,28 +100,15 @@ def load_model(self):
         self.image_adapter = ImageAdapter(self.clip_model.config.hidden_size, self.llm.config.hidden_size)
         self.image_adapter.load_state_dict(torch.load(self.image_adapter_path, map_location="cpu"))
         self.image_adapter.eval()
-        self.image_adapter.to("cuda")
+        self.image_adapter.to(self.llm.device)
         self.logger.info(f'Image Adapter Loaded in {time.monotonic() - 
start_time:.1f}s.') def inference(self): # Get image paths - path_to_find = os.path.join(self.args.data_path, '**') \ - if self.args.recursive else os.path.join(self.args.data_path, '*') - image_paths = sorted(set( - [image for image in glob.glob(path_to_find, recursive=self.args.recursive) - if image.lower().endswith(SUPPORT_IMAGE_FORMATS)]), - key=lambda filename: (os.path.splitext(filename)[0]) - ) if not os.path.isfile(self.args.data_path) else [str(self.args.data_path)] \ - if str(self.args.data_path).lower().endswith(SUPPORT_IMAGE_FORMATS) else None - - if image_paths is None: - self.logger.error('Invalid dir or image path!') - raise FileNotFoundError - - self.logger.info(f'Found {len(image_paths)} image(s).') + image_paths = get_image_paths(logger=self.logger, path=Path(self.args.data_path), recursive=self.args.recursive) def get_caption( - image: Image, + image: Image.Image, user_prompt: str, temperature: float = 0.5, max_new_tokens: int = 300, @@ -111,7 +117,7 @@ def get_caption( torch.cuda.empty_cache() # Preprocess image image = self.clip_processor(images=image, return_tensors='pt').pixel_values - image = image.to('cuda') + image = image.to(self.llm.device) # Tokenize the prompt prompt = self.llm_tokenizer.encode(user_prompt, return_tensors='pt', @@ -119,13 +125,13 @@ def get_caption( truncation=False, add_special_tokens=False) # Embed image - with torch.amp.autocast_mode.autocast('cuda', enabled=True): + with torch.amp.autocast_mode.autocast("cuda" if self.use_gpu else "cpu", enabled=True): vision_outputs = self.clip_model(pixel_values=image, output_hidden_states=True) image_features = vision_outputs.hidden_states[-2] embedded_images = self.image_adapter(image_features) - embedded_images = embedded_images.to('cuda') + embedded_images = embedded_images.to(self.llm.device) # Embed prompt - prompt_embeds = self.llm.model.embed_tokens(prompt.to('cuda')) + prompt_embeds = self.llm.model.embed_tokens(prompt.to(self.llm.device)) assert prompt_embeds.shape == (1, prompt.shape[1], self.llm.config.hidden_size), \ f"Prompt shape is {prompt_embeds.shape}, expected {(1, prompt.shape[1], self.llm.config.hidden_size)}" @@ -143,7 +149,7 @@ def get_caption( torch.tensor([[self.llm_tokenizer.bos_token_id]], dtype=torch.long), torch.zeros((1, embedded_images.shape[1]), dtype=torch.long), prompt, - ], dim=1).to('cuda') + ], dim=1).to(self.llm.device) attention_mask = torch.ones_like(input_ids) # Generate caption generate_ids = self.llm.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, @@ -167,6 +173,7 @@ def get_caption( image_path[:15]) + ' ... ' + image_path[-20:]) image = Image.open(image_path) image = image_process(image, int(self.args.image_size)) + image = image_process_image(image) caption = get_caption( image=image, user_prompt=self.args.user_prompt,
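
With this sync applied, the new LLM inference flags wired through `caption.py` and `utils/joy.py` above can be exercised like so (an illustrative example; the dataset path and flag values are placeholders, not recommendations):

```shell
# 4-bit quantized LLM with bf16 compute dtype
python caption.py your_datasets_path --llm_qnt 4bit --llm_dtype bf16

# CPU-only inference (the loader falls back to fp32 on CPU)
python caption.py your_datasets_path --llm_use_cpu
```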