diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e7fe8d8 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +.envs +.idea +.huggingface +models \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..261eeb9 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." 
+ + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. 
+ + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. 
We also recommend that a
   file or class name and description of purpose be included on the
   same "printed page" as the copyright notice for easier
   identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..e27e5e1
--- /dev/null
+++ b/README.md
@@ -0,0 +1,150 @@
+# Joy Caption Cli
+A Python-based CLI tool for tagging images with [joy-caption-pre-alpha](https://huggingface.co/spaces/fancyfeast/joy-caption-pre-alpha) models.
+
+### Only CUDA devices are supported at the moment.
+
+## Introduction
+
+I made this repo because I want to caption images across platforms (on my old MBP, my Windows gaming PC, or a Docker-based Linux cloud server such as Google Colab),
+but I don't want to install a huge WebUI just for this small job, and some cloud services are unfriendly to Gradio-based UIs.
+
+So this repo was born.
+
+
+## Model source
+
+HuggingFace hosts the original models; the ModelScope repos are plain mirrors of the HuggingFace ones (useful where HuggingFace is blocked).
+
+| Model | HuggingFace Link | ModelScope Link |
+|:---------------------------------:|:------------------------------------------------------------------------------:|:--------------------------------------------------------------------------------------------:|
+| joy-caption-pre-alpha | [HuggingFace](https://huggingface.co/spaces/fancyfeast/joy-caption-pre-alpha) | [ModelScope](https://www.modelscope.cn/models/fireicewolf/joy-caption-pre-alpha) |
+| siglip-so400m-patch14-384(Google) | [HuggingFace](https://huggingface.co/google/siglip-so400m-patch14-384) | [ModelScope](https://www.modelscope.cn/models/fireicewolf/siglip-so400m-patch14-384) |
+| Meta-Llama-3.1-8B | [HuggingFace](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B) | [ModelScope](https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B) |
+
+## TO-DO
+
+Make a simple UI with Jupyter widgets (once my laziness is cured 😊).
+
+## Installation
+Python 3.10 works fine.
+
+Open a shell terminal and follow the steps below:
+```shell
+# Clone this repo
+git clone https://github.com/fireicewolf/joy-caption-cli.git
+cd joy-caption-cli
+
+# Create a Python venv and activate it
+python -m venv .venv
+.\.venv\Scripts\activate    # Windows
+# source .venv/bin/activate  # Linux/macOS
+
+# Install a torch build that matches your CUDA version, e.g.
+pip install torch==2.4.0 --index-url https://download.pytorch.org/whl/cu124
+
+# Base dependencies; inference models are downloaded via the requests library.
+pip install -U -r requirements.txt
+
+# If you want to download or cache models via the huggingface hub, install this.
+pip install -U -r huggingface-requirements.txt
+
+# If you want to download or cache models via the modelscope hub, install this.
+pip install -U -r modelscope-requirements.txt
+```
+
+### Take notice
+This project uses llama-cpp-python as a base library, and it needs to be compiled, so make sure a working C/C++ toolchain is available.
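+
+### Gated model access
+[Meta-Llama-3.1-8B](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B) is a gated repository on HuggingFace, and the URL downloader in `utils/download.py` reads an `HF_TOKEN` environment variable to send as a bearer token for huggingface.co downloads. If a model download fails with a 401/403 error, exporting an access token may help (a sketch; the token value is a placeholder):
+```shell
+# Request access to meta-llama/Meta-Llama-3.1-8B on HuggingFace first,
+# then expose your access token so the downloader can authenticate.
+export HF_TOKEN=hf_xxxxxxxxxxxxxxxx    # Linux/macOS
+# set HF_TOKEN=hf_xxxxxxxxxxxxxxxx     # Windows cmd
+```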
+
+## Simple usage
+__Make sure your Python venv has been activated first!__
+```shell
+python caption.py your_datasets_path
+```
+To run with more options, see the built-in help or the [Options](#options) section below:
+```shell
+python caption.py -h
+```
+
+## Options
+
+Advanced options:
+
+`data_path`
+
+Path to the data to caption.
+
+`--recursive`
+
+Include all supported image formats in the input datasets path and its sub-paths.
+
+`--config CONFIG`
+
+Config JSON for the models, default is "default.json".
+
+[//]: # (`--use_cpu`)

[//]: # ()
[//]: # (Use cpu for inference.)
+
+`--image_size IMAGE_SIZE`
+
+Resize images to this size before inference, default is 1024.
+
+`--model_name MODEL_NAME`
+
+Model name for inference, default is "Joy-Caption-Pre-Alpha"; please check configs/default.json.
+
+`--model_site MODEL_SITE`
+
+Model site to download models from (huggingface or modelscope), default is huggingface.
+
+`--models_save_path MODELS_SAVE_PATH`
+
+Path to save models to, default is "models" (under the project folder).
+
+`--download_method DOWNLOAD_METHOD`
+
+Download models via SDK or URL, default is SDK.
+
+If the huggingface hub or modelscope SDK is not installed or the download fails, the tool automatically retries with a URL download.
+
+`--use_sdk_cache`
+
+Use the huggingface or modelscope SDK cache to store models; this option needs the huggingface_hub or modelscope SDK installed.
+
+If this is enabled, `--models_save_path` will be ignored.
+
+`--custom_caption_save_path CUSTOM_CAPTION_SAVE_PATH`
+
+Save caption files to a custom path instead of next to the images (their directory structure is preserved).
+
+`--log_level LOG_LEVEL`
+
+Log level for the terminal console and log file, default is `INFO` (`DEBUG`, `INFO`, `WARNING`, `ERROR`, `CRITICAL`).
+
+`--save_logs`
+
+Save logs to a file; the log will be saved at the same level as `data_path`.
+
+`--caption_extension CAPTION_EXTENSION`
+
+Caption file extension, default is `.txt`.
+
+`--not_overwrite`
+
+Do not overwrite a caption file if it already exists.
+
+`--user_prompt USER_PROMPT`
+
+User prompt for captioning.
+
+`--temperature TEMPERATURE`
+
+Temperature for the Llama model, default is 0.5.
+
+`--max_tokens MAX_TOKENS`
+
+Max tokens for output, default is 300.
+
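+For example, a typical run combining several of these options (the dataset path is a placeholder):
+```shell
+python caption.py /path/to/your/datasets \
+    --recursive \
+    --model_site huggingface \
+    --caption_extension .txt \
+    --not_overwrite \
+    --save_logs \
+    --temperature 0.5 \
+    --max_tokens 300
+```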
+
+## Credits
+Based on [joy-caption-pre-alpha](https://huggingface.co/spaces/fancyfeast/joy-caption-pre-alpha).
+
+Without their work (👍👍), this repo wouldn't exist.
diff --git a/caption.py b/caption.py
new file mode 100644
index 0000000..d38a595
--- /dev/null
+++ b/caption.py
@@ -0,0 +1,206 @@
+import argparse
+import os
+from datetime import datetime
+from pathlib import Path
+
+from utils.download import download
+from utils.joy import Joy
+from utils.logger import Logger
+
+DEFAULT_USER_PROMPT = """
+A descriptive caption for this image:\n
+"""
+
+
+def main(args):
+    # Set logger
+    workspace_path = os.getcwd()
+    data_dir_path = Path(args.data_path)
+
+    log_file_path = data_dir_path.parent if os.path.exists(data_dir_path.parent) else workspace_path
+
+    if args.custom_caption_save_path:
+        log_file_path = Path(args.custom_caption_save_path)
+
+    log_time = datetime.now().strftime('%Y%m%d_%H%M%S')
+    # caption_failed_list_file = f'Caption_failed_list_{log_time}.txt'
+
+    if os.path.exists(data_dir_path):
+        log_name = os.path.basename(data_dir_path)
+    else:
+        print(f'{data_dir_path} NOT FOUND!!!')
+        raise FileNotFoundError
+
+    if args.save_logs:
+        log_file = f'Caption_{log_name}_{log_time}.log' if log_name else f'test_{log_time}.log'
+        log_file = os.path.join(log_file_path, log_file) \
+            if os.path.exists(log_file_path) else os.path.join(os.getcwd(), log_file)
+    else:
+        log_file = None
+
+    # Membership test against a list, not a substring check on one string
+    if str(args.log_level).lower() in ['debug', 'info', 'warning', 'error', 'critical']:
+        my_logger = Logger(args.log_level, log_file).logger
+        my_logger.info(f'Set log level to "{args.log_level}"')
+    else:
+        my_logger = Logger('INFO', log_file).logger
+        my_logger.warning('Invalid log level, set log level to "INFO"!')
+
+    if args.save_logs:
+        my_logger.info(f'Log file will be saved as "{log_file}".')
+
+    # Check custom models path
+    config_file = os.path.join(Path(__file__).parent, 'configs', 'default.json') \
+        if args.config == "default.json" else Path(args.config)
+
+    # Download models
+    if os.path.exists(Path(args.models_save_path)):
+        models_save_path = Path(args.models_save_path)
+    else:
+        models_save_path = Path(os.path.join(Path(__file__).parent, args.models_save_path))
+
+    image_adapter_path, clip_path, llm_path = download(
+        logger=my_logger,
+        config_file=config_file,
+        model_name=str(args.model_name),
+        model_site=str(args.model_site),
+        models_save_path=models_save_path,
+        use_sdk_cache=args.use_sdk_cache,
+        download_method=str(args.download_method)
+    )
+
+    # Load models
+    my_joy = Joy(
+        logger=my_logger,
+        args=args,
+        image_adapter_path=image_adapter_path,
+        clip_path=clip_path,
+        llm_path=llm_path
+    )
+    my_joy.load_model()
+
+    # Inference
+    my_joy.inference()
+
+    # Unload models
+    my_joy.unload_model()
+
+
+def setup_args() -> argparse.ArgumentParser:
+    args = argparse.ArgumentParser()
+
+    args.add_argument(
+        'data_path',
+        type=str,
+        help='path for data.'
+    )
+    args.add_argument(
+        '--recursive',
+        action='store_true',
+        help='include recursive sub-dirs.'
+    )
+    args.add_argument(
+        '--config',
+        type=str,
+        default='default.json',
+        help='config json for models, default is "default.json"'
+    )
+    # args.add_argument(
+    #     '--use_cpu',
+    #     action='store_true',
+    #     help='use cpu for inference.'
+    # )
+    args.add_argument(
+        '--image_size',
+        type=int,
+        default=1024,
+        help='resize image to this size before inference, default is 1024.'
+ ) + args.add_argument( + '--model_name', + type=str, + default='Joy-Caption-Pre-Alpha', + help='model name for inference, default is "Joy-Caption-Pre-Alpha", please check configs/default.json' + ) + args.add_argument( + '--model_site', + type=str, + choices=['huggingface', 'modelscope'], + default='huggingface', + help='download model from model site huggingface or modelscope, default is "huggingface".' + ) + args.add_argument( + '--models_save_path', + type=str, + default="models", + help='path to save models, default is "models".' + ) + args.add_argument( + '--use_sdk_cache', + action='store_true', + help='use sdk\'s cache dir to store models. \ + if this option enabled, "--models_save_path" will be ignored.' + ) + args.add_argument( + '--download_method', + type=str, + choices=["SDK", "URL"], + default='SDK', + help='download method via SDK or URL, default is "SDK".' + ) + args.add_argument( + '--custom_caption_save_path', + type=str, + default=None, + help='Input custom caption file save path.' + ) + args.add_argument( + '--log_level', + type=str, + choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], + default='INFO', + help='set log level, default is "INFO"' + ) + args.add_argument( + '--save_logs', + action='store_true', + help='save log file.' + ) + args.add_argument( + '--caption_extension', + type=str, + default='.txt', + help='extension of caption file, default is ".txt"' + ) + args.add_argument( + '--not_overwrite', + action='store_true', + help='not overwrite caption file if exist.' + ) + args.add_argument( + '--user_prompt', + type=str, + default=DEFAULT_USER_PROMPT, + help='user prompt for caption.' + ) + args.add_argument( + '--temperature', + type=float, + default=0.5, + help='temperature for Llama model.' + ) + args.add_argument( + '--max_tokens', + type=int, + default=300, + help='max tokens for output.' 
+    )
+
+    return args
+
+
+if __name__ == "__main__":
+    args = setup_args()
+    args = args.parse_args()
+    main(args)
diff --git a/configs/default.json b/configs/default.json
new file mode 100644
index 0000000..1043068
--- /dev/null
+++ b/configs/default.json
@@ -0,0 +1,87 @@
+{
+  "Joy-Caption-Pre-Alpha": {
+    "huggingface": {
+      "image_adapter": {
+        "repo_id": "fancyfeast/joy-caption-pre-alpha",
+        "revision": "main",
+        "repo_type": "space",
+        "subfolder": "wpkklhc6",
+        "file_list": {
+          "image_adapter.pt": "https://huggingface.co/spaces/fancyfeast/joy-caption-pre-alpha/resolve/main/wpkklhc6/image_adapter.pt"
+        }
+      },
+      "clip": {
+        "repo_id": "google/siglip-so400m-patch14-384",
+        "revision": "main",
+        "repo_type": "model",
+        "subfolder": "",
+        "file_list": {
+          "config.json": "https://huggingface.co/google/siglip-so400m-patch14-384/resolve/main/config.json",
+          "tokenizer_config.json": "https://huggingface.co/google/siglip-so400m-patch14-384/resolve/main/tokenizer_config.json",
+          "special_tokens_map.json": "https://huggingface.co/google/siglip-so400m-patch14-384/resolve/main/special_tokens_map.json",
+          "preprocessor_config.json": "https://huggingface.co/google/siglip-so400m-patch14-384/resolve/main/preprocessor_config.json",
+          "spiece.model": "https://huggingface.co/google/siglip-so400m-patch14-384/resolve/main/spiece.model",
+          "model.safetensors": "https://huggingface.co/google/siglip-so400m-patch14-384/resolve/main/model.safetensors"
+        }
+      },
+      "llm": {
+        "repo_id": "meta-llama/Meta-Llama-3.1-8B",
+        "revision": "main",
+        "repo_type": "model",
+        "subfolder": "",
+        "file_list": {
+          "config.json": "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/config.json",
+          "generation_config.json": "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/generation_config.json",
+          "tokenizer.json": "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/tokenizer.json",
+          "tokenizer_config.json": "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/tokenizer_config.json",
+          "special_tokens_map.json": "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/special_tokens_map.json",
+          "model.safetensors.index.json": "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/model.safetensors.index.json",
+          "model-00001-of-00004.safetensors": "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/model-00001-of-00004.safetensors",
+          "model-00002-of-00004.safetensors": "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/model-00002-of-00004.safetensors",
+          "model-00003-of-00004.safetensors": "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/model-00003-of-00004.safetensors",
+          "model-00004-of-00004.safetensors": "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/model-00004-of-00004.safetensors"
+        }
+      }
+    },
+    "modelscope": {
+      "image_adapter": {
+        "repo_id": "fireicewolf/joy-caption-pre-alpha",
+        "revision": "master",
+        "subfolder": "wpkklhc6",
+        "file_list": {
+          "image_adapter.pt": "https://www.modelscope.cn/models/fireicewolf/joy-caption-pre-alpha/resolve/master/wpkklhc6/image_adapter.pt"
+        }
+      },
+      "clip": {
+        "repo_id": "fireicewolf/siglip-so400m-patch14-384",
+        "revision": "master",
+        "subfolder": "",
+        "file_list": {
+          "config.json": "https://www.modelscope.cn/models/fireicewolf/siglip-so400m-patch14-384/resolve/master/config.json",
+          "tokenizer_config.json": "https://www.modelscope.cn/models/fireicewolf/siglip-so400m-patch14-384/resolve/master/tokenizer_config.json",
"special_tokens_map.json": "https://www.modelscope.cn/models/fireicewolf/siglip-so400m-patch14-384/resolve/master/preprocessor_config.json", + "preprocessor_config.json": "https://www.modelscope.cn/models/fireicewolf/siglip-so400m-patch14-384/resolve/master/special_tokens_map.json", + "spiece.model": "https://www.modelscope.cn/models/fireicewolf/siglip-so400m-patch14-384/resolve/master/spiece.model", + "model.safetensors": "https://www.modelscope.cn/models/fireicewolf/siglip-so400m-patch14-384/resolve/master/model.safetensors" + } + }, + "llm": { + "repo_id": "fireicewolf/Meta-Llama-3.1-8B", + "revision": "master", + "subfolder": "", + "file_list": { + "config.json": "https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/config.json", + "generation_config.json": "https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/generation_config.json", + "tokenizer.json": "https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/tokenizer.json", + "tokenizer_config.json": "https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/tokenizer_config.json", + "special_tokens_map.json": "https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/special_tokens_map.json", + "model.safetensors.index.json": "https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/model.safetensors.index.json", + "model-00001-of-00004.safetensors": "https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/model-00001-of-00004.safetensors", + "model-00002-of-00004.safetensors": "https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/model-00002-of-00004.safetensors", + "model-00003-of-00004.safetensors": "https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/model-00003-of-00004.safetensors", + "model-00004-of-00004.safetensors": "https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/model-00004-of-00004.safetensors" + } + } + } + } +} diff --git a/huggingface-requirements.txt b/huggingface-requirements.txt new file mode 100644 index 0000000..2afcdef --- /dev/null +++ b/huggingface-requirements.txt @@ -0,0 +1,2 @@ +huggingface_hub==0.24.6 +-r requirements.txt \ No newline at end of file diff --git a/modelscope-requirements.txt b/modelscope-requirements.txt new file mode 100644 index 0000000..a4775a9 --- /dev/null +++ b/modelscope-requirements.txt @@ -0,0 +1,2 @@ +modelscope==1.17.1 +-r requirements.txt \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..e0890fa --- /dev/null +++ b/requirements.txt @@ -0,0 +1,8 @@ +accelerate==0.33.0 +transformers==4.44.0 +sentencepiece==0.2.0 +numpy==1.26.4 +opencv-python-headless==4.10.0.84 +pillow==10.4.0 +requests==2.32.3 +tqdm==4.66.5 \ No newline at end of file diff --git a/utils/download.py b/utils/download.py new file mode 100644 index 0000000..38a1152 --- /dev/null +++ b/utils/download.py @@ -0,0 +1,219 @@ +import json +import os +import shutil +from pathlib import Path +from typing import Union, Optional + +import requests +from tqdm import tqdm + +from utils.logger import Logger + + +def download( + logger: Logger, + config_file: Path, + model_name: str, + model_site: str, + models_save_path: Path, + use_sdk_cache: bool = False, + download_method: str = "sdk", + force_download: bool = False +) -> tuple[Path, Path, Path]: + if os.path.isfile(config_file): + logger.info(f'Using config: {str(config_file)}') + 
+    else:
+        logger.error(f'{str(config_file)} NOT FOUND!')
+        raise FileNotFoundError
+
+    def read_json(config_file, model_name) -> dict:
+        with open(config_file, 'r', encoding='utf-8') as config_json:
+            datas = json.load(config_json)
+            if model_name not in datas.keys():
+                logger.error(f'"{str(model_name)}" NOT FOUND IN CONFIG!')
+                raise FileNotFoundError
+            return datas[model_name]
+
+    model_info = read_json(config_file, model_name)
+
+    models_save_path = Path(os.path.join(models_save_path, model_name))
+
+    if use_sdk_cache:
+        logger.warning('use_sdk_cache ENABLED! download_method is forced to "SDK" and models_save_path will be ignored.')
+        download_method = 'sdk'
+    else:
+        logger.info(f'Models will be stored in {str(models_save_path)}.')
+
+    def url_download(
+            url: str,
+            local_dir: Union[str, Path],
+            force_download: bool = False,
+            force_filename: Optional[str] = None
+    ) -> Path:
+        # Download a file from a URL via the requests library
+        filename = os.path.basename(url) if not force_filename else force_filename
+        local_file = os.path.join(local_dir, filename)
+
+        hf_token = os.environ.get("HF_TOKEN")
+
+        response = requests.get(
+            url,
+            stream=True,
+            headers={"Authorization": f"Bearer {hf_token}"} if "huggingface.co" in url and hf_token else None
+        )
+        total_size = int(response.headers.get('content-length', 0))
+
+        def download_progress():
+            desc = 'Downloading {}'.format(filename)
+
+            if total_size > 0:
+                pbar = tqdm(total=total_size, initial=0, unit='B', unit_divisor=1024, unit_scale=True,
+                            dynamic_ncols=True,
+                            desc=desc)
+            else:
+                pbar = tqdm(initial=0, unit='B', unit_divisor=1024, unit_scale=True, dynamic_ncols=True, desc=desc)
+
+            if not os.path.exists(local_dir):
+                os.makedirs(local_dir, exist_ok=True)
+
+            # Open in write mode ('wb'), not append, so a stale partial file is overwritten on re-download.
+            with open(local_file, 'wb') as download_file:
+                for data in response.iter_content(chunk_size=1024):
+                    if data:
+                        download_file.write(data)
+                        pbar.update(len(data))
+            pbar.close()
+
+        if not force_download and os.path.isfile(local_file):
+            if total_size == 0:
+                logger.info(
+                    f'"{local_file}" already exists, but its size can\'t be fetched from "{url}". Won\'t download it.')
+            elif os.path.getsize(local_file) == total_size:
+                logger.info(f'"{local_file}" already exists, and its size matches "{url}".')
+            else:
+                logger.info(
+                    f'"{local_file}" already exists, but its size does not match "{url}"!\nWill download this file '
+                    f'again...')
+                download_progress()
+        else:
+            download_progress()
+
+        return Path(os.path.join(local_dir, filename))
+
+    def download_choice(
+            model_info: dict,
+            model_site: str,
+            models_save_path: Path,
+            download_method: str = "sdk",
+            use_sdk_cache: bool = False,
+            force_download: bool = False
+    ):
+        if download_method.lower() == 'sdk':
+            if model_site == "huggingface":
+                model_hf_info = model_info["huggingface"]
+                try:
+                    from huggingface_hub import hf_hub_download
+
+                    models_path = []
+                    for sub_model_name in model_hf_info:
+                        sub_model_info = model_hf_info[sub_model_name]
+                        sub_model_path = ""
+
+                        for filename in sub_model_info["file_list"]:
+                            logger.info(f'Will download "{filename}" from huggingface repo: "{sub_model_info["repo_id"]}".')
+                            sub_model_path = hf_hub_download(
+                                repo_id=sub_model_info["repo_id"],
+                                filename=filename,
+                                subfolder=sub_model_info["subfolder"] if sub_model_info["subfolder"] != "" else None,
+                                repo_type=sub_model_info["repo_type"],
+                                revision=sub_model_info["revision"],
+                                local_dir=os.path.join(models_save_path, sub_model_name) if not use_sdk_cache else None,
+                                # local_dir_use_symlinks=False if not use_sdk_cache else "auto",
+                                # resume_download=True,
+                                force_download=force_download
+                            )
+                        models_path.append(sub_model_path)
+                    return models_path
+
+                except Exception:
+                    logger.warning('huggingface_hub not installed or download via it failed, '
+                                   'retrying with URL method to download...')
+                    models_path = download_choice(
+                        model_info,
+                        model_site,
+                        models_save_path,
+                        use_sdk_cache=False,
+                        download_method="url",
+                        force_download=force_download
+                    )
+                    return models_path
+
+            elif model_site == "modelscope":
+                model_ms_info = model_info["modelscope"]
+                try:
+                    if force_download:
+                        logger.warning(
+                            'modelscope api does not support force download, '
+                            'trying to remove the model path before downloading!')
+                        shutil.rmtree(models_save_path, ignore_errors=True)
+
+                    from modelscope.hub.file_download import model_file_download
+
+                    models_path = []
+                    for sub_model_name in model_ms_info:
+                        sub_model_info = model_ms_info[sub_model_name]
+                        sub_model_path = ""
+
+                        for filename in sub_model_info["file_list"]:
+                            logger.info(f'Will download "{filename}" from modelscope repo: "{sub_model_info["repo_id"]}".')
+                            sub_model_path = model_file_download(
+                                model_id=sub_model_info["repo_id"],
+                                file_path=filename if sub_model_info["subfolder"] == ""
+                                else os.path.join(sub_model_info["subfolder"], filename),
+                                revision=sub_model_info["revision"],
+                                local_dir=os.path.join(models_save_path, sub_model_name) if not use_sdk_cache else None,
+                            )
+                        models_path.append(sub_model_path)
+                    return models_path
+
+                except Exception:
+                    logger.warning('modelscope not installed or download via it failed, '
+                                   'retrying with URL method to download...')
+                    models_path = download_choice(
+                        model_info,
+                        model_site,
+                        models_save_path,
+                        use_sdk_cache=False,
+                        download_method="url",
+                        force_download=force_download
+                    )
+                    return models_path
+
+            else:
+                logger.error('Invalid model site!')
+                raise ValueError
+
+        else:
+            model_site_info = model_info[model_site]
+            models_path = []
+            for sub_model_name in model_site_info:
+                sub_model_info = model_site_info[sub_model_name]
+                sub_model_path = ""
+                for filename in sub_model_info["file_list"]:
+                    model_url = sub_model_info["file_list"][filename]
+                    logger.info(f'Will download model from url: {model_url}')
+                    sub_model_path = url_download(
+                        url=model_url,
+                        local_dir=os.path.join(models_save_path, sub_model_name) if sub_model_info["subfolder"] == ""
+                        else os.path.join(models_save_path, sub_model_name, sub_model_info["subfolder"]),
+                        force_filename=filename,
+                        force_download=force_download
+                    )
+                models_path.append(sub_model_path)
+            return models_path
+
+    models_path = download_choice(
+        model_info=model_info,
+        model_site=model_site,
+        models_save_path=models_save_path,
+        download_method=download_method,
+        use_sdk_cache=use_sdk_cache,
+        force_download=force_download
+    )
+
+    # The image adapter is a single file; clip and llm paths are the directories containing their files.
+    return Path(models_path[0]), Path(os.path.dirname(models_path[1])), Path(os.path.dirname(models_path[2]))
diff --git a/utils/image.py b/utils/image.py
new file mode 100644
index 0000000..825f4f1
--- /dev/null
+++ b/utils/image.py
@@ -0,0 +1,65 @@
+import base64
+from io import BytesIO
+from pathlib import Path
+
+import cv2
+import numpy
+from PIL import Image
+
+
+def image_process(image: Image.Image, target_size: int) -> Image.Image:
+    # Composite any alpha channel over white
+    image = image.convert('RGBA')
+    new_image = Image.new('RGBA', image.size, 'WHITE')
+    new_image.alpha_composite(image)
+    image = new_image.convert('RGB')
+    del new_image
+
+    # Pad image to square
+    original_size = image.size
+    desired_size = max(max(original_size), target_size)
+
+    delta_width = desired_size - original_size[0]
+    delta_height = desired_size - original_size[1]
+    top_padding, bottom_padding = delta_height // 2, delta_height - (delta_height // 2)
+    left_padding, right_padding = delta_width // 2, delta_width - (delta_width // 2)
+
+    # Convert image data to a numpy array
+    image = numpy.asarray(image)
+
+    padded_image = cv2.copyMakeBorder(
+        src=image,
+        top=top_padding,
+        bottom=bottom_padding,
+        left=left_padding,
+        right=right_padding,
+        borderType=cv2.BORDER_CONSTANT,
+        value=[255, 255, 255]  # WHITE
+    )
+
+    # Use INTER_AREA to downscale
+    if padded_image.shape[0] > target_size:
+        padded_image = cv2.resize(
+            src=padded_image,
+            dsize=(target_size, target_size),
+            interpolation=cv2.INTER_AREA
+        )
+
+    # Use INTER_LANCZOS4 to upscale
+    elif padded_image.shape[0] < target_size:
+        padded_image = cv2.resize(
+            src=padded_image,
+            dsize=(target_size, target_size),
+            interpolation=cv2.INTER_LANCZOS4
+        )
+
+    return Image.fromarray(padded_image)
+
+
+def encode_image_to_base64(image: Image.Image):
+    with BytesIO() as bytes_output:
+        image.save(bytes_output, format="PNG")
+        image_bytes = bytes_output.getvalue()
+        base64_image = base64.b64encode(image_bytes).decode("utf-8")
+        image_url = f"data:image/png;base64,{base64_image}"
+        return image_url
diff --git a/utils/joy.py b/utils/joy.py
new file mode 100644
index 0000000..72952d1
--- /dev/null
+++ b/utils/joy.py
@@ -0,0 +1,236 @@
+import glob
+import os
+import time
+from argparse import Namespace
+from pathlib import Path
+
+import torch
+import torch.amp.autocast_mode
+from PIL import Image
+from torch import nn
+from tqdm import tqdm
+from transformers import (AutoModel, AutoProcessor, AutoTokenizer,
+                          PreTrainedTokenizer, PreTrainedTokenizerFast, AutoModelForCausalLM)
+
+from utils.image import image_process
+from utils.logger import Logger
+
+SUPPORT_IMAGE_FORMATS = ("bmp", "jpg", "jpeg", "png")
+
+
+class ImageAdapter(nn.Module):
+    # Two-layer MLP that projects CLIP vision features into the LLM embedding space
+    def __init__(self, input_features: int, output_features: int):
+        super().__init__()
+        self.linear1 = nn.Linear(input_features, output_features)
+        self.activation = nn.GELU()
+        self.linear2 = nn.Linear(output_features, output_features)
+
+    def forward(self, vision_outputs: torch.Tensor):
+        x = self.linear1(vision_outputs)
+        x = self.activation(x)
+        x = self.linear2(x)
+        return x
+
+
+class Joy:
+    def __init__(
+            self,
+            logger: Logger,
+            args: Namespace,
+            image_adapter_path: Path,
+            clip_path: Path,
+            llm_path: Path,
+            use_gpu: bool = True,
+    ):
+        self.logger = logger
+        self.args = args
+        self.model_name = self.args.model_name
+        self.image_adapter_path = image_adapter_path
+        self.clip_path = clip_path
+        self.llm_path = llm_path
+        self.image_adapter = None
+        self.clip_processor = None
+        self.clip_model = None
+        self.llm_tokenizer = None
+        self.llm = None
+        self.use_gpu = use_gpu
+
+    def load_model(self):
+        # Load CLIP
+        self.logger.info(f'Loading CLIP with {"GPU" if self.use_gpu else "CPU"}...')
+        start_time = time.monotonic()
+        self.clip_processor = AutoProcessor.from_pretrained(self.clip_path)
+        self.clip_model = AutoModel.from_pretrained(self.clip_path)
+        self.clip_model = self.clip_model.vision_model
+        self.clip_model.eval()
+        self.clip_model.requires_grad_(False)
+        self.clip_model.to("cuda")
+        self.logger.info(f'CLIP Loaded in {time.monotonic() - start_time:.1f}s.')
+
+        # Load LLM
+        self.logger.info(f'Loading LLM with {"GPU" if self.use_gpu else "CPU"}...')
+        start_time = time.monotonic()
+        self.llm_tokenizer = AutoTokenizer.from_pretrained(self.llm_path, use_fast=False)
+        assert (isinstance(self.llm_tokenizer, PreTrainedTokenizer) or
+                isinstance(self.llm_tokenizer, PreTrainedTokenizerFast)), \
+            f"Tokenizer is of type {type(self.llm_tokenizer)}"
+        self.llm = AutoModelForCausalLM.from_pretrained(self.llm_path, device_map="auto", torch_dtype=torch.bfloat16)
+        self.llm.eval()
+        self.logger.info(f'LLM Loaded in {time.monotonic() - start_time:.1f}s.')
+
+        # Load Image Adapter
+        self.logger.info(f'Loading Image Adapter with {"GPU" if self.use_gpu else "CPU"}...')
+        start_time = time.monotonic()
+        self.image_adapter = ImageAdapter(self.clip_model.config.hidden_size, self.llm.config.hidden_size)
+        self.image_adapter.load_state_dict(torch.load(self.image_adapter_path, map_location="cpu"))
+        self.image_adapter.eval()
+        self.image_adapter.to("cuda")
+        self.logger.info(f'Image Adapter Loaded in {time.monotonic() - start_time:.1f}s.')
+
+    def inference(self):
+        # Get image paths; a single image file becomes a one-element list
+        if os.path.isfile(self.args.data_path):
+            image_paths = [str(self.args.data_path)] \
+                if str(self.args.data_path).lower().endswith(SUPPORT_IMAGE_FORMATS) else None
+        else:
+            path_to_find = os.path.join(self.args.data_path, '**') \
+                if self.args.recursive else os.path.join(self.args.data_path, '*')
+            image_paths = sorted(set(
+                [image for image in glob.glob(path_to_find, recursive=self.args.recursive)
+                 if image.lower().endswith(SUPPORT_IMAGE_FORMATS)]),
+                key=lambda filename: (os.path.splitext(filename)[0])
+            )
+
+        if image_paths is None:
+            self.logger.error('Invalid dir or image path!')
+            raise FileNotFoundError
+
+        self.logger.info(f'Found {len(image_paths)} image(s).')
+
+        def get_caption(
+                image: Image.Image,
+                user_prompt: str,
+                temperature: float = 0.5,
+                max_new_tokens: int = 300,
+        ) -> str:
+            # Clear the VRAM cache
+            torch.cuda.empty_cache()
+            # Preprocess image
+            image = self.clip_processor(images=image, return_tensors='pt').pixel_values
+            image = image.to('cuda')
+            # Tokenize the prompt
+            prompt = self.llm_tokenizer.encode(user_prompt,
+                                               return_tensors='pt',
+                                               padding=False,
+                                               truncation=False,
+                                               add_special_tokens=False)
+            # Embed image
+            with torch.amp.autocast_mode.autocast('cuda', enabled=True):
+                vision_outputs = self.clip_model(pixel_values=image, output_hidden_states=True)
+                image_features = vision_outputs.hidden_states[-2]
+                embedded_images = self.image_adapter(image_features)
+                embedded_images = embedded_images.to('cuda')
+            # Embed prompt
+            prompt_embeds = self.llm.model.embed_tokens(prompt.to('cuda'))
+            assert prompt_embeds.shape == (1, prompt.shape[1],
+                                           self.llm.config.hidden_size), \
+                f"Prompt shape is {prompt_embeds.shape}, expected {(1, prompt.shape[1], self.llm.config.hidden_size)}"
+            embedded_bos = self.llm.model.embed_tokens(torch.tensor([[self.llm_tokenizer.bos_token_id]],
+                                                                    device=self.llm.device,
+                                                                    dtype=torch.int64))
+            # Construct prompts
+            inputs_embeds = torch.cat([
+                embedded_bos.expand(embedded_images.shape[0], -1, -1),
+                embedded_images.to(dtype=embedded_bos.dtype),
+                prompt_embeds.expand(embedded_images.shape[0], -1, -1),
+            ], dim=1)
+
+            input_ids = torch.cat([
+                torch.tensor([[self.llm_tokenizer.bos_token_id]], dtype=torch.long),
+                torch.zeros((1, embedded_images.shape[1]), dtype=torch.long),
+                prompt,
+            ], dim=1).to('cuda')
+            attention_mask = torch.ones_like(input_ids)
+            # Generate caption
+            generate_ids = self.llm.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask,
+                                             max_new_tokens=max_new_tokens, do_sample=True, top_k=10,
+                                             temperature=temperature, suppress_tokens=None)
+            # Trim off the prompt
+            generate_ids = generate_ids[:, input_ids.shape[1]:]
+            if generate_ids[0][-1] == self.llm_tokenizer.eos_token_id:
+                generate_ids = generate_ids[:, :-1]
+
+            content = self.llm_tokenizer.batch_decode(generate_ids,
+                                                      skip_special_tokens=False,
+                                                      clean_up_tokenization_spaces=False)[0]
+
+            return content.strip()
+
+        pbar = tqdm(total=len(image_paths), smoothing=0.0)
+        for image_path in image_paths:
+            try:
+                # Shorten very long paths in the progress bar description
+                pbar.set_description('Processing: {}'.format(image_path if len(image_path) <= 40
+                                                             else image_path[:15] + ' ... ' + image_path[-20:]))
+                image = Image.open(image_path)
+                image = image_process(image, int(self.args.image_size))
+                caption = get_caption(
+                    image=image,
+                    user_prompt=self.args.user_prompt,
+                    temperature=self.args.temperature,
+                    max_new_tokens=self.args.max_tokens
+                )
+                if self.args.custom_caption_save_path is not None:
+                    if not os.path.exists(self.args.custom_caption_save_path):
+                        self.logger.error(f'{self.args.custom_caption_save_path} NOT FOUND!')
+                        raise FileNotFoundError
+
+                    self.logger.debug(f'Caption file(s) will be saved in {self.args.custom_caption_save_path}')
+
+                    if os.path.isfile(self.args.data_path):
+                        caption_file = str(os.path.splitext(os.path.basename(image_path))[0])
+                    else:
+                        caption_file = os.path.splitext(str(image_path)[len(str(self.args.data_path)):])[0]
+
+                    # Strip a leading path separator (either Unix or Windows style)
+                    caption_file = caption_file[1:] if caption_file[0] in '/\\' else caption_file
+                    caption_file = os.path.join(self.args.custom_caption_save_path, caption_file)
+                    # Make dirs if they don't exist.
+                    os.makedirs(os.path.join(str(caption_file)[:-len(os.path.basename(caption_file))]), exist_ok=True)
+                    caption_file = Path(str(caption_file) + self.args.caption_extension)
+
+                else:
+                    caption_file = os.path.splitext(image_path)[0] + self.args.caption_extension
+
+                if self.args.not_overwrite and os.path.isfile(caption_file):
+                    self.logger.warning(f'Caption file {caption_file} already exists! Skipping it.')
+                    continue
+
+                with open(caption_file, "wt", encoding="utf-8") as f:
+                    f.write(caption + "\n")
+                    self.logger.debug(f"\tImage path: {image_path}")
+                    self.logger.debug(f"\tCaption path: {caption_file}")
+                    self.logger.debug(f"\tCaption content: {caption}")
+
+                pbar.update(1)
+
+            except Exception as e:
+                self.logger.error(f"Failed to caption image: {image_path}, skipping it.\nError info: {e}")
+                continue
+
+        pbar.close()
+
+    def unload_model(self):
+        # Unload Image Adapter
+        if self.image_adapter is not None:
+            self.logger.info('Unloading Image Adapter...')
+            start = time.monotonic()
+            del self.image_adapter
+            self.logger.info(f'Image Adapter unloaded in {time.monotonic() - start:.1f}s.')
+        # Unload LLM
+        if self.llm is not None:
+            self.logger.info('Unloading LLM...')
+            start = time.monotonic()
+            del self.llm
+            del self.llm_tokenizer
+            self.logger.info(f'LLM unloaded in {time.monotonic() - start:.1f}s.')
+        # Unload CLIP
+        if self.clip_model is not None:
+            self.logger.info('Unloading CLIP...')
+            start = time.monotonic()
+            del self.clip_model
+            del self.clip_processor
+            self.logger.info(f'CLIP unloaded in {time.monotonic() - start:.1f}s.')
diff --git a/utils/logger.py b/utils/logger.py
new file mode 100644
index 0000000..3aab52b
--- /dev/null
+++ b/utils/logger.py
@@ -0,0 +1,65 @@
+import logging
+from logging import handlers
+from typing import Optional
+
+
+class Logger:
+
+    def __init__(self, level="INFO", log_file: Optional[str] = None):
+        self.logger = logging.getLogger()
+        self.logger.setLevel(level)
+
+        formatter = logging.Formatter('%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s')
+
+        console_handler = logging.StreamHandler()
+        console_handler.setLevel(level)
+        console_handler.setFormatter(formatter)
+        self.logger.addHandler(console_handler)
+
+        if log_file is not None:
+            file_handler = handlers.TimedRotatingFileHandler(filename=log_file,
+                                                             when='D',
+                                                             interval=1,
+                                                             backupCount=5,
+                                                             encoding='utf-8')
+            file_handler.setLevel(level)
+            file_handler.setFormatter(formatter)
+            self.logger.addHandler(file_handler)
+
+        else:
+            self.logger.warning("save_logs not enabled or log file path does not exist; "
+                                "logs will only be output to the console.")
+
+    def set_level(self, level):
+        if level.lower() == "debug":
+            level = logging.DEBUG
+        elif level.lower() == "info":
+            level = logging.INFO
+        elif level.lower() == "warning":
+            level = logging.WARNING
+        elif level.lower() == "error":
+            level = logging.ERROR
+        elif level.lower() == "critical":
+            level = logging.CRITICAL
+        else:
+            error_message = "Invalid log level"
+            self.logger.critical(error_message)
+            raise ValueError(error_message)
+
+        self.logger.setLevel(level)
+        for handler in self.logger.handlers:
+            handler.setLevel(level)
+
+    def debug(self, message):
+        self.logger.debug(message)
+
+    def info(self, message):
+        self.logger.info(message)
+
+    def warning(self, message):
+        self.logger.warning(message)
+
+    def error(self, message):
+        self.logger.error(message)
+
+    def critical(self, message):
+        self.logger.critical(message)