diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..e7fe8d8
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,4 @@
+.envs
+.idea
+.huggingface
+models
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..261eeb9
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..e27e5e1
--- /dev/null
+++ b/README.md
@@ -0,0 +1,150 @@
+# Joy Caption CLI
+A Python-based CLI tool for tagging images with the [joy-caption-pre-alpha](https://huggingface.co/spaces/fancyfeast/joy-caption-pre-alpha) models.
+
+### Currently only CUDA devices are supported.
+
+## Introduction
+
+I made this repo because I want to caption images cross-platform (on my old MBP, my Windows gaming PC, or a Docker-based Linux cloud server such as Google Colab).
+
+But I don't want to install a huge webui just for this little task, and some cloud services are unfriendly to Gradio-based UIs.
+
+So this repo was born.
+
+
+## Model source
+
+Hugging Face hosts the original sources; the ModelScope repos are pure forks of them (because Hugging Face is blocked in some places).
+
+| Model | HuggingFace Link | ModelScope Link |
+|:---------------------------------:|:------------------------------------------------------------------------------:|:--------------------------------------------------------------------------------------------:|
+| joy-caption-pre-alpha | [HuggingFace](https://huggingface.co/spaces/fancyfeast/joy-caption-pre-alpha) | [ModelScope](https://www.modelscope.cn/models/fireicewolf/joy-caption-pre-alpha) |
+| siglip-so400m-patch14-384(Google) | [HuggingFace](https://huggingface.co/google/siglip-so400m-patch14-384) | [ModelScope](https://www.modelscope.cn/models/fireicewolf/siglip-so400m-patch14-384) |
+| Meta-Llama-3.1-8B | [HuggingFace](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B) | [ModelScope](https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B) |
+
+## TO-DO
+
+Make a simple UI with Jupyter widgets (once my laziness is cured).
+
+## Installation
+Python 3.10 works fine.
+
+Open a shell terminal and follow the steps below:
+```shell
+# Clone this repo
+git clone https://github.com/fireicewolf/joy-caption-cli.git
+cd joy-caption-cli
+
+# Create a Python venv
+python -m venv .venv
+
+# Activate it (Windows)
+.venv\Scripts\activate
+# Activate it (Linux/macOS)
+# source .venv/bin/activate
+
+# Install torch
+# Pick the build matching your GPU driver / CUDA version, e.g.
+pip install torch==2.4.0 --index-url https://download.pytorch.org/whl/cu124
+
+# Base dependencies; inference models will be downloaded via the Python requests library.
+pip install -U -r requirements.txt
+
+# If you want to download or cache models via the Hugging Face hub, install this.
+pip install -U -r huggingface-requirements.txt
+
+# If you want to download or cache models via the ModelScope hub, install this.
+pip install -U -r modelscope-requirements.txt
+```
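+
+### Hugging Face token
+`meta-llama/Meta-Llama-3.1-8B` is a gated repo on Hugging Face. The downloader in `utils/download.py` reads an `HF_TOKEN` environment variable and sends it as a bearer token for Hugging Face URL downloads (the huggingface_hub SDK also respects this variable), so you may need to set one. The token value below is a placeholder:
+```shell
+# Linux/macOS
+export HF_TOKEN=<your_huggingface_token>
+# Windows (PowerShell)
+$env:HF_TOKEN="<your_huggingface_token>"
+```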
+
+### Note
+This project uses llama-cpp-python as a base library, and it needs to be compiled.
+
+## Simple usage
+__Make sure your Python venv has been activated first!__
+```shell
+python caption.py your_datasets_path
+```
+To run with more options, get help with the following command or see [Options](#options):
+```shell
+python caption.py -h
+```
+
+## Options
+
+### Advanced options
+`data_path`
+
+Path to your image dataset (a directory, or a single image file).
+
+`--recursive`
+
+Include all supported image formats in the dataset path and its sub-paths.
+
+`--config CONFIG`
+
+Config JSON for the LLaVA models, default is `default.json`.
+
+[//]: # (`--use_cpu`)
+
+[//]: # ()
+[//]: # (Use cpu for inference.)
+
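+`--image_size IMAGE_SIZE`
+
+Resize images to a suitable size before inference, default is 1024.
+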
+`--model_name MODEL_NAME`
+
+Model name for inference, default is `Joy-Caption-Pre-Alpha`; see `configs/default.json`.
+
+`--model_site MODEL_SITE`
+
+Model site to download models from (`huggingface` or `modelscope`), default is `huggingface`.
+
+`--models_save_path MODELS_SAVE_PATH`
+
+Path to save models, default is `models` (under the project folder).
+
+`--download_method DOWNLOAD_METHOD`
+
+Download models via `SDK` or `URL`, default is `SDK`.
+
+If the huggingface_hub or modelscope SDK is not installed or the download fails, it will automatically retry via URL download.
+
+`--use_sdk_cache`
+
+Use the Hugging Face or ModelScope SDK cache to store models; this option needs huggingface_hub or the modelscope SDK installed.
+
+If enabled, `--models_save_path` will be ignored.
+
+`--custom_caption_save_path CUSTOM_CAPTION_SAVE_PATH`
+
+Save caption files to a custom path instead of alongside the images (their directory structure is preserved). For example, a caption for `dataset/sub/img.jpg` would be written to `<custom path>/sub/img.txt`.
+
+`--log_level LOG_LEVEL`
+
+Log level for the terminal console and log file, default is `INFO` (choices: `DEBUG`, `INFO`, `WARNING`, `ERROR`, `CRITICAL`).
+
+`--save_logs`
+
+Save logs to a file; the log will be saved at the same level as `data_path`.
+
+`--caption_extension CAPTION_EXTENSION`
+
+Caption file extension, default is `.txt`.
+
+`--not_overwrite`
+
+Do not overwrite a caption file if it already exists.
+
+`--user_prompt USER_PROMPT`
+
+User prompt for captioning.
+
+`--temperature TEMPERATURE`
+
+Temperature for the Llama model, default is 0.5.
+
+`--max_tokens MAX_TOKENS`
+
+Max new tokens for output, default is 300.
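+
+A combined example (the dataset path is illustrative):
+```shell
+python caption.py /path/to/dataset --recursive \
+    --model_site modelscope --save_logs \
+    --temperature 0.5 --max_tokens 300
+```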
+
+
+
+## Credits
+Based on [joy-caption-pre-alpha](https://huggingface.co/spaces/fancyfeast/joy-caption-pre-alpha).
+
+Without their work, this repo wouldn't exist.
diff --git a/caption.py b/caption.py
new file mode 100644
index 0000000..d38a595
--- /dev/null
+++ b/caption.py
@@ -0,0 +1,206 @@
+import argparse
+import os
+from datetime import datetime
+from pathlib import Path
+
+from utils.download import download
+from utils.joy import Joy
+from utils.logger import Logger
+
+DEFAULT_USER_PROMPT = """
+A descriptive caption for this image:\n
+"""
+
+
+def main(args):
+ # Set logger
+ workspace_path = os.getcwd()
+ data_dir_path = Path(args.data_path)
+
+ log_file_path = data_dir_path.parent if os.path.exists(data_dir_path.parent) else workspace_path
+
+ if args.custom_caption_save_path:
+ log_file_path = Path(args.custom_caption_save_path)
+
+ log_time = datetime.now().strftime('%Y%m%d_%H%M%S')
+ # caption_failed_list_file = f'Caption_failed_list_{log_time}.txt'
+
+ if os.path.exists(data_dir_path):
+ log_name = os.path.basename(data_dir_path)
+
+ else:
+ print(f'{data_dir_path} NOT FOUND!!!')
+ raise FileNotFoundError
+
+ if args.save_logs:
+ log_file = f'Caption_{log_name}_{log_time}.log' if log_name else f'test_{log_time}.log'
+ log_file = os.path.join(log_file_path, log_file) \
+ if os.path.exists(log_file_path) else os.path.join(os.getcwd(), log_file)
+ else:
+ log_file = None
+
+    if str(args.log_level).lower() in ('debug', 'info', 'warning', 'error', 'critical'):
+ my_logger = Logger(args.log_level, log_file).logger
+ my_logger.info(f'Set log level to "{args.log_level}"')
+
+ else:
+ my_logger = Logger('INFO', log_file).logger
+ my_logger.warning('Invalid log level, set log level to "INFO"!')
+
+ if args.save_logs:
+ my_logger.info(f'Log file will be saved as "{log_file}".')
+
+ # Check custom models path
+ config_file = os.path.join(Path(__file__).parent, 'configs', 'default.json') \
+ if args.config == "default.json" else Path(args.config)
+
+ # Download models
+ if os.path.exists(Path(args.models_save_path)):
+ models_save_path = Path(args.models_save_path)
+ else:
+ models_save_path = Path(os.path.join(Path(__file__).parent, args.models_save_path))
+
+ image_adapter_path, clip_path, llm_path = download(
+ logger=my_logger,
+ config_file=config_file,
+ model_name=str(args.model_name),
+ model_site=str(args.model_site),
+ models_save_path=models_save_path,
+        use_sdk_cache=args.use_sdk_cache,
+ download_method=str(args.download_method)
+ )
+
+ # Load models
+ my_joy = Joy(
+ logger=my_logger,
+ args=args,
+ image_adapter_path=image_adapter_path,
+ clip_path=clip_path,
+ llm_path=llm_path
+ )
+ my_joy.load_model()
+
+ # Inference
+ my_joy.inference()
+
+ # Unload models
+ my_joy.unload_model()
+
+
+def setup_args() -> argparse.ArgumentParser:
+ args = argparse.ArgumentParser()
+
+ args.add_argument(
+ 'data_path',
+ type=str,
+ help='path for data.'
+ )
+ args.add_argument(
+ '--recursive',
+ action='store_true',
+ help='Include recursive dirs'
+ )
+ args.add_argument(
+ '--config',
+ type=str,
+ default='default.json',
+ help='config json for llava models, default is "default.json"'
+ )
+ # args.add_argument(
+ # '--use_cpu',
+ # action='store_true',
+ # help='use cpu for inference.'
+ # )
+ args.add_argument(
+ '--image_size',
+ type=int,
+ default=1024,
+        help='resize image to a suitable size, default is 1024.'
+ )
+ args.add_argument(
+ '--model_name',
+ type=str,
+ default='Joy-Caption-Pre-Alpha',
+ help='model name for inference, default is "Joy-Caption-Pre-Alpha", please check configs/default.json'
+ )
+ args.add_argument(
+ '--model_site',
+ type=str,
+ choices=['huggingface', 'modelscope'],
+ default='huggingface',
+ help='download model from model site huggingface or modelscope, default is "huggingface".'
+ )
+ args.add_argument(
+ '--models_save_path',
+ type=str,
+ default="models",
+ help='path to save models, default is "models".'
+ )
+ args.add_argument(
+ '--use_sdk_cache',
+ action='store_true',
+        help='use sdk\'s cache dir to store models. \
+            if this option is enabled, "--models_save_path" will be ignored.'
+ )
+ args.add_argument(
+ '--download_method',
+ type=str,
+ choices=["SDK", "URL"],
+ default='SDK',
+ help='download method via SDK or URL, default is "SDK".'
+ )
+ args.add_argument(
+ '--custom_caption_save_path',
+ type=str,
+ default=None,
+        help='Custom caption file save path.'
+ )
+ args.add_argument(
+ '--log_level',
+ type=str,
+ choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
+ default='INFO',
+ help='set log level, default is "INFO"'
+ )
+ args.add_argument(
+ '--save_logs',
+ action='store_true',
+ help='save log file.'
+ )
+ args.add_argument(
+ '--caption_extension',
+ type=str,
+ default='.txt',
+ help='extension of caption file, default is ".txt"'
+ )
+ args.add_argument(
+ '--not_overwrite',
+ action='store_true',
+        help='do not overwrite caption file if it exists.'
+ )
+ args.add_argument(
+ '--user_prompt',
+ type=str,
+ default=DEFAULT_USER_PROMPT,
+ help='user prompt for caption.'
+ )
+ args.add_argument(
+ '--temperature',
+ type=float,
+ default=0.5,
+ help='temperature for Llama model.'
+ )
+ args.add_argument(
+ '--max_tokens',
+ type=int,
+ default=300,
+ help='max tokens for output.'
+ )
+
+ return args
+
+
+if __name__ == "__main__":
+    parser = setup_args()
+    args = parser.parse_args()
+    main(args)
diff --git a/configs/default.json b/configs/default.json
new file mode 100644
index 0000000..1043068
--- /dev/null
+++ b/configs/default.json
@@ -0,0 +1,87 @@
+{
+ "Joy-Caption-Pre-Alpha": {
+ "huggingface": {
+ "image_adapter": {
+ "repo_id": "fancyfeast/joy-caption-pre-alpha",
+ "revision": "main",
+ "repo_type": "space",
+ "subfolder": "wpkklhc6",
+ "file_list": {
+ "image_adapter.pt": "https://huggingface.co/spaces/fancyfeast/joy-caption-pre-alpha/resolve/main/wpkklhc6/image_adapter.pt"
+ }
+ },
+ "clip": {
+ "repo_id": "google/siglip-so400m-patch14-384",
+ "revision": "main",
+ "repo_type": "model",
+ "subfolder": "",
+ "file_list": {
+ "config.json": "https://huggingface.co/google/siglip-so400m-patch14-384/resolve/main/config.json",
+ "tokenizer_config.json": "https://huggingface.co/google/siglip-so400m-patch14-384/resolve/main/tokenizer_config.json",
+ "special_tokens_map.json": "https://huggingface.co/google/siglip-so400m-patch14-384/resolve/main/preprocessor_config.json",
+ "preprocessor_config.json": "https://huggingface.co/google/siglip-so400m-patch14-384/resolve/main/special_tokens_map.json",
+ "spiece.model": "https://huggingface.co/google/siglip-so400m-patch14-384/resolve/main/spiece.model",
+ "model.safetensors": "https://huggingface.co/google/siglip-so400m-patch14-384/resolve/main/model.safetensors"
+ }
+ },
+ "llm": {
+ "repo_id": "meta-llama/Meta-Llama-3.1-8B",
+ "revision": "main",
+ "repo_type": "model",
+ "subfolder": "",
+ "file_list": {
+ "config.json":"https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/config.json",
+ "generation_config.json": "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/generation_config.json",
+ "tokenizer.json": "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/tokenizer.json",
+ "tokenizer_config.json": "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/tokenizer_config.json",
+ "special_tokens_map.json": "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/special_tokens_map.json",
+ "model.safetensors.index.json": "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/model.safetensors.index.json",
+ "model-00001-of-00004.safetensors": "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/model-00001-of-00004.safetensors",
+ "model-00002-of-00004.safetensors": "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/model-00002-of-00004.safetensors",
+ "model-00003-of-00004.safetensors": "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/model-00003-of-00004.safetensors",
+ "model-00004-of-00004.safetensors": "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/model-00004-of-00004.safetensors"
+ }
+ }
+ },
+ "modelscope": {
+ "image_adapter": {
+ "repo_id": "fireicewolf/joy-caption-pre-alpha",
+ "revision": "master",
+ "subfolder": "wpkklhc6",
+ "file_list": {
+ "image_adapter.pt": "https://www.modelscope.cn/models/fireicewolf/joy-caption-pre-alpha/resolve/master/wpkklhc6/image_adapter.pt"
+ }
+ },
+ "clip": {
+ "repo_id": "fireicewolf/siglip-so400m-patch14-384",
+ "revision": "master",
+ "subfolder": "",
+ "file_list": {
+ "config.json": "https://www.modelscope.cn/models/fireicewolf/siglip-so400m-patch14-384/resolve/master/config.json",
+ "tokenizer_config.json": "https://www.modelscope.cn/models/fireicewolf/siglip-so400m-patch14-384/resolve/master/tokenizer_config.json",
+ "special_tokens_map.json": "https://www.modelscope.cn/models/fireicewolf/siglip-so400m-patch14-384/resolve/master/preprocessor_config.json",
+ "preprocessor_config.json": "https://www.modelscope.cn/models/fireicewolf/siglip-so400m-patch14-384/resolve/master/special_tokens_map.json",
+ "spiece.model": "https://www.modelscope.cn/models/fireicewolf/siglip-so400m-patch14-384/resolve/master/spiece.model",
+ "model.safetensors": "https://www.modelscope.cn/models/fireicewolf/siglip-so400m-patch14-384/resolve/master/model.safetensors"
+ }
+ },
+ "llm": {
+ "repo_id": "fireicewolf/Meta-Llama-3.1-8B",
+ "revision": "master",
+ "subfolder": "",
+ "file_list": {
+ "config.json": "https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/config.json",
+ "generation_config.json": "https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/generation_config.json",
+ "tokenizer.json": "https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/tokenizer.json",
+ "tokenizer_config.json": "https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/tokenizer_config.json",
+ "special_tokens_map.json": "https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/special_tokens_map.json",
+ "model.safetensors.index.json": "https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/model.safetensors.index.json",
+ "model-00001-of-00004.safetensors": "https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/model-00001-of-00004.safetensors",
+ "model-00002-of-00004.safetensors": "https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/model-00002-of-00004.safetensors",
+ "model-00003-of-00004.safetensors": "https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/model-00003-of-00004.safetensors",
+ "model-00004-of-00004.safetensors": "https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/model-00004-of-00004.safetensors"
+ }
+ }
+ }
+ }
+}
diff --git a/huggingface-requirements.txt b/huggingface-requirements.txt
new file mode 100644
index 0000000..2afcdef
--- /dev/null
+++ b/huggingface-requirements.txt
@@ -0,0 +1,2 @@
+huggingface_hub==0.24.6
+-r requirements.txt
\ No newline at end of file
diff --git a/modelscope-requirements.txt b/modelscope-requirements.txt
new file mode 100644
index 0000000..a4775a9
--- /dev/null
+++ b/modelscope-requirements.txt
@@ -0,0 +1,2 @@
+modelscope==1.17.1
+-r requirements.txt
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..e0890fa
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,8 @@
+accelerate==0.33.0
+transformers==4.44.0
+sentencepiece==0.2.0
+numpy==1.26.4
+opencv-python-headless==4.10.0.84
+pillow==10.4.0
+requests==2.32.3
+tqdm==4.66.5
\ No newline at end of file
diff --git a/utils/download.py b/utils/download.py
new file mode 100644
index 0000000..38a1152
--- /dev/null
+++ b/utils/download.py
@@ -0,0 +1,219 @@
+import json
+import os
+import shutil
+from pathlib import Path
+from typing import Union, Optional
+
+import requests
+from tqdm import tqdm
+
+from utils.logger import Logger
+
+
+def download(
+ logger: Logger,
+ config_file: Path,
+ model_name: str,
+ model_site: str,
+ models_save_path: Path,
+ use_sdk_cache: bool = False,
+ download_method: str = "sdk",
+ force_download: bool = False
+) -> tuple[Path, Path, Path]:
+ if os.path.isfile(config_file):
+ logger.info(f'Using config: {str(config_file)}')
+ else:
+ logger.error(f'{str(config_file)} NOT FOUND!')
+ raise FileNotFoundError
+
+    def read_json(config_file, model_name) -> dict:
+ with open(config_file, 'r', encoding='utf-8') as config_json:
+ datas = json.load(config_json)
+ if model_name not in datas.keys():
+ logger.error(f'"{str(model_name)}" NOT FOUND IN CONFIG!')
+ raise FileNotFoundError
+ return datas[model_name]
+
+ model_info = read_json(config_file, model_name)
+
+ models_save_path = Path(os.path.join(models_save_path, model_name))
+
+ if use_sdk_cache:
+        logger.warning('use_sdk_cache ENABLED! download_method is forced to "SDK" and models_save_path will be ignored.')
+ download_method = 'sdk'
+ else:
+ logger.info(f'Models will be stored in {str(models_save_path)}.')
+
+ def url_download(
+ url: str,
+ local_dir: Union[str, Path],
+ force_download: bool = False,
+ force_filename: Optional[str] = None
+ ) -> Path:
+ # Download file via url by requests library
+ filename = os.path.basename(url) if not force_filename else force_filename
+ local_file = os.path.join(local_dir, filename)
+
+        hf_token = os.environ.get("HF_TOKEN")
+        headers = {"Authorization": f"Bearer {hf_token}"} if "huggingface.co" in url and hf_token else None
+
+        response = requests.get(url, stream=True, headers=headers)
+        total_size = int(response.headers.get('content-length', 0))
+
+ def download_progress():
+ desc = 'Downloading {}'.format(filename)
+
+ if total_size > 0:
+ pbar = tqdm(total=total_size, initial=0, unit='B', unit_divisor=1024, unit_scale=True,
+ dynamic_ncols=True,
+ desc=desc)
+ else:
+ pbar = tqdm(initial=0, unit='B', unit_divisor=1024, unit_scale=True, dynamic_ncols=True, desc=desc)
+
+ if not os.path.exists(local_dir):
+ os.makedirs(local_dir, exist_ok=True)
+
+            # Open in 'wb': downloads are not resumed, so appending to a stale partial file would corrupt it
+            with open(local_file, 'wb') as download_file:
+ for data in response.iter_content(chunk_size=1024):
+ if data:
+ download_file.write(data)
+ pbar.update(len(data))
+ pbar.close()
+
+ if not force_download and os.path.isfile(local_file):
+ if total_size == 0:
+                logger.info(
+                    f'"{local_file}" already exists, but its size can\'t be determined from "{url}". Won\'t download it.')
+            elif os.path.getsize(local_file) == total_size:
+                logger.info(f'"{local_file}" already exists, and its size matches "{url}".')
+            else:
+                logger.info(
+                    f'"{local_file}" already exists, but its size doesn\'t match "{url}"!\nWill download this file '
+                    f'again...')
+ download_progress()
+ else:
+ download_progress()
+
+ return Path(os.path.join(local_dir, filename))
+
+ def download_choice(
+        model_info: dict,
+ model_site: str,
+ models_save_path: Path,
+ download_method: str = "sdk",
+ use_sdk_cache: bool = False,
+ force_download: bool = False
+ ):
+ if download_method.lower() == 'sdk':
+ if model_site == "huggingface":
+ model_hf_info = model_info["huggingface"]
+ try:
+ from huggingface_hub import hf_hub_download
+
+ models_path = []
+ for sub_model_name in model_hf_info:
+ sub_model_info = model_hf_info[sub_model_name]
+ sub_model_path = ""
+
+ for filename in sub_model_info["file_list"]:
+ logger.info(f'Will download "{filename}" from huggingface repo: "{sub_model_info["repo_id"]}".')
+ sub_model_path = hf_hub_download(
+ repo_id=sub_model_info["repo_id"],
+ filename=filename,
+ subfolder=sub_model_info["subfolder"] if sub_model_info["subfolder"] != "" else None,
+ repo_type=sub_model_info["repo_type"],
+ revision=sub_model_info["revision"],
+ local_dir=os.path.join(models_save_path,sub_model_name) if not use_sdk_cache else None,
+ # local_dir_use_symlinks=False if not use_sdk_cache else "auto",
+ # resume_download=True,
+ force_download=force_download
+ )
+ models_path.append(sub_model_path)
+ return models_path
+
+            except Exception:
+ logger.warning('huggingface_hub not installed or download via it failed, '
+ 'retrying with URL method to download...')
+ models_path = download_choice(
+ model_info,
+ model_site,
+ models_save_path,
+ use_sdk_cache=False,
+ download_method="url",
+ force_download=force_download
+ )
+ return models_path
+
+ elif model_site == "modelscope":
+ model_ms_info = model_info["modelscope"]
+ try:
+                if force_download:
+                    logger.warning(
+                        'modelscope api does not support force download, '
+                        'trying to remove model path before download!')
+                    shutil.rmtree(models_save_path, ignore_errors=True)
+
+ from modelscope.hub.file_download import model_file_download
+
+ models_path = []
+ for sub_model_name in model_ms_info:
+ sub_model_info = model_ms_info[sub_model_name]
+ sub_model_path = ""
+
+ for filename in sub_model_info["file_list"]:
+ logger.info(f'Will download "{filename}" from modelscope repo: "{sub_model_info["repo_id"]}".')
+ sub_model_path = model_file_download(
+ model_id=sub_model_info["repo_id"],
+ file_path=filename if sub_model_info["subfolder"] == ""
+ else os.path.join(sub_model_info["subfolder"],filename),
+ revision=sub_model_info["revision"],
+ local_dir=os.path.join(models_save_path,sub_model_name) if not use_sdk_cache else None,
+ )
+ models_path.append(sub_model_path)
+ return models_path
+
+            except Exception:
+ logger.warning('modelscope not installed or download via it failed, '
+ 'retrying with URL method to download...')
+ models_path = download_choice(
+ model_info,
+ model_site,
+ models_save_path,
+ use_sdk_cache=False,
+ download_method="url",
+ force_download=force_download
+ )
+ return models_path
+
+ else:
+ logger.error('Invalid model site!')
+ raise ValueError
+
+ else:
+ model_site_info = model_info[model_site]
+ models_path = []
+ for sub_model_name in model_site_info:
+ sub_model_info = model_site_info[sub_model_name]
+ sub_model_path = ""
+ for filename in sub_model_info["file_list"]:
+ model_url = sub_model_info["file_list"][filename]
+ logger.info(f'Will download model from url: {model_url}')
+ sub_model_path = url_download(
+ url=model_url,
+ local_dir=os.path.join(models_save_path,sub_model_name) if sub_model_info["subfolder"] == ""
+ else os.path.join(models_save_path,sub_model_name,sub_model_info["subfolder"]),
+ force_filename=filename,
+ force_download=force_download
+ )
+ models_path.append(sub_model_path)
+ return models_path
+
+ models_path = download_choice(
+ model_info=model_info,
+ model_site=model_site,
+ models_save_path=models_save_path,
+ download_method=download_method,
+ use_sdk_cache=use_sdk_cache,
+ force_download=force_download
+ )
+
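+    # models_path holds the last downloaded file of each sub-model, in config order: [image_adapter, clip, llm].
+    # Return the adapter file path directly, and the parent directories of the CLIP and LLM files.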
+ return Path(models_path[0]), Path(os.path.dirname(models_path[1])), Path(os.path.dirname(models_path[2]))
diff --git a/utils/image.py b/utils/image.py
new file mode 100644
index 0000000..825f4f1
--- /dev/null
+++ b/utils/image.py
@@ -0,0 +1,65 @@
+import base64
+from pathlib import Path
+
+import cv2
+import numpy
+from io import BytesIO
+from PIL import Image
+
+
+def image_process(image: Image.Image, target_size: int) -> Image.Image:
+    # Composite the alpha channel onto a white background
+ image = image.convert('RGBA')
+ new_image = Image.new('RGBA', image.size, 'WHITE')
+ new_image.alpha_composite(image)
+ image = new_image.convert('RGB')
+ del new_image
+
+ # Pad image to square
+ original_size = image.size
+ desired_size = max(max(original_size), target_size)
+
+ delta_width = desired_size - original_size[0]
+ delta_height = desired_size - original_size[1]
+ top_padding, bottom_padding = delta_height // 2, delta_height - (delta_height // 2)
+ left_padding, right_padding = delta_width // 2, delta_width - (delta_width // 2)
+
+ # Convert image data to numpy float32 data
+ image = numpy.asarray(image)
+
+ padded_image = cv2.copyMakeBorder(
+ src=image,
+ top=top_padding,
+ bottom=bottom_padding,
+ left=left_padding,
+ right=right_padding,
+ borderType=cv2.BORDER_CONSTANT,
+ value=[255, 255, 255] # WHITE
+ )
+
+ # USE INTER_AREA downscale
+ if padded_image.shape[0] > target_size:
+ padded_image = cv2.resize(
+ src=padded_image,
+ dsize=(target_size, target_size),
+ interpolation=cv2.INTER_AREA
+ )
+
+ # USE INTER_LANCZOS4 upscale
+ elif padded_image.shape[0] < target_size:
+ padded_image = cv2.resize(
+ src=padded_image,
+ dsize=(target_size, target_size),
+ interpolation=cv2.INTER_LANCZOS4
+ )
+
+ return Image.fromarray(padded_image)
+
+
+def encode_image_to_base64(image: Image.Image):
+ with BytesIO() as bytes_output:
+ image.save(bytes_output, format="PNG")
+ image_bytes = bytes_output.getvalue()
+ base64_image = base64.b64encode(image_bytes).decode("utf-8")
+ image_url = f"data:image/png;base64,{base64_image}"
+ return image_url
diff --git a/utils/joy.py b/utils/joy.py
new file mode 100644
index 0000000..72952d1
--- /dev/null
+++ b/utils/joy.py
@@ -0,0 +1,236 @@
+import glob
+import os
+import time
+from argparse import Namespace
+from pathlib import Path
+
+import torch
+import torch.amp.autocast_mode
+from PIL import Image
+from torch import nn
+from tqdm import tqdm
+from transformers import (AutoModel, AutoProcessor, AutoTokenizer,
+ PreTrainedTokenizer, PreTrainedTokenizerFast, AutoModelForCausalLM)
+
+from utils.image import image_process
+from utils.logger import Logger
+
+SUPPORT_IMAGE_FORMATS = ("bmp", "jpg", "jpeg", "png")
+
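+
+# Two-layer MLP (Linear -> GELU -> Linear) that projects CLIP vision features into the LLM's embedding space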
+class ImageAdapter(nn.Module):
+ def __init__(self, input_features: int, output_features: int):
+ super().__init__()
+ self.linear1 = nn.Linear(input_features, output_features)
+ self.activation = nn.GELU()
+ self.linear2 = nn.Linear(output_features, output_features)
+
+ def forward(self, vision_outputs: torch.Tensor):
+ x = self.linear1(vision_outputs)
+ x = self.activation(x)
+ x = self.linear2(x)
+ return x
+
+class Joy:
+ def __init__(
+ self,
+ logger: Logger,
+ args: Namespace,
+ image_adapter_path: Path,
+ clip_path: Path,
+ llm_path: Path,
+ use_gpu: bool = True,
+ ):
+ self.logger = logger
+ self.args = args
+ self.model_name = self.args.model_name
+ self.image_adapter_path = image_adapter_path
+ self.clip_path = clip_path
+ self.llm_path = llm_path
+ self.image_adapter = None
+ self.clip_processor = None
+ self.clip_model = None
+ self.llm_tokenizer = None
+ self.llm = None
+ self.use_gpu = use_gpu
+
+ def load_model(self):
+ # Load CLIP
+ self.logger.info(f'Loading CLIP with {"GPU" if self.use_gpu else "CPU"}...')
+ start_time = time.monotonic()
+ self.clip_processor = AutoProcessor.from_pretrained(self.clip_path)
+ self.clip_model = AutoModel.from_pretrained(self.clip_path)
+ self.clip_model = self.clip_model.vision_model
+ self.clip_model.eval()
+ self.clip_model.requires_grad_(False)
+ self.clip_model.to("cuda")
+ self.logger.info(f'CLIP Loaded in {time.monotonic() - start_time:.1f}s.')
+
+ # Load LLM
+ self.logger.info(f'Loading LLM with {"GPU" if self.use_gpu else "CPU"}...')
+ start_time = time.monotonic()
+ self.llm_tokenizer = AutoTokenizer.from_pretrained(self.llm_path, use_fast=False)
+ assert (isinstance(self.llm_tokenizer, PreTrainedTokenizer) or
+ isinstance(self.llm_tokenizer, PreTrainedTokenizerFast)), \
+ f"Tokenizer is of type {type(self.llm_tokenizer)}"
+ self.llm = AutoModelForCausalLM.from_pretrained(self.llm_path, device_map="auto", torch_dtype=torch.bfloat16)
+ self.llm.eval()
+ self.logger.info(f'LLM Loaded in {time.monotonic() - start_time:.1f}s.')
+
+ # Load Image Adapter
+        self.logger.info(f'Loading Image Adapter with {"GPU" if self.use_gpu else "CPU"}...')
+        start_time = time.monotonic()  # reset the timer; it would otherwise still hold the LLM load start
+        self.image_adapter = ImageAdapter(self.clip_model.config.hidden_size, self.llm.config.hidden_size)
+ self.image_adapter.load_state_dict(torch.load(self.image_adapter_path, map_location="cpu"))
+ self.image_adapter.eval()
+ self.image_adapter.to("cuda")
+ self.logger.info(f'Image Adapter Loaded in {time.monotonic() - start_time:.1f}s.')
+
+ def inference(self):
+ # Get image paths
+        if os.path.isfile(self.args.data_path):
+            # A single image file was passed in; wrap it in a list so len() and iteration work as expected
+            image_paths = [str(self.args.data_path)] \
+                if str(self.args.data_path).lower().endswith(SUPPORT_IMAGE_FORMATS) else None
+        else:
+            path_to_find = os.path.join(self.args.data_path, '**') \
+                if self.args.recursive else os.path.join(self.args.data_path, '*')
+            image_paths = sorted(
+                {image for image in glob.glob(path_to_find, recursive=self.args.recursive)
+                 if image.lower().endswith(SUPPORT_IMAGE_FORMATS)},
+                key=lambda filename: os.path.splitext(filename)[0]
+            )
+
+ if image_paths is None:
+ self.logger.error('Invalid dir or image path!')
+ raise FileNotFoundError
+
+ self.logger.info(f'Found {len(image_paths)} image(s).')
+
+ def get_caption(
+            image: Image.Image,
+ user_prompt: str,
+ temperature: float = 0.5,
+ max_new_tokens: int = 300,
+ ) -> str:
+ # Cleaning VRAM cache
+ torch.cuda.empty_cache()
+ # Preprocess image
+ image = self.clip_processor(images=image, return_tensors='pt').pixel_values
+ image = image.to('cuda')
+ # Tokenize the prompt
+ prompt = self.llm_tokenizer.encode(user_prompt,
+ return_tensors='pt',
+ padding=False,
+ truncation=False,
+ add_special_tokens=False)
+ # Embed image
+ with torch.amp.autocast_mode.autocast('cuda', enabled=True):
+ vision_outputs = self.clip_model(pixel_values=image, output_hidden_states=True)
+ image_features = vision_outputs.hidden_states[-2]
+ embedded_images = self.image_adapter(image_features)
+ embedded_images = embedded_images.to('cuda')
+ # Embed prompt
+ prompt_embeds = self.llm.model.embed_tokens(prompt.to('cuda'))
+ assert prompt_embeds.shape == (1, prompt.shape[1],
+ self.llm.config.hidden_size), \
+ f"Prompt shape is {prompt_embeds.shape}, expected {(1, prompt.shape[1], self.llm.config.hidden_size)}"
+ embedded_bos = self.llm.model.embed_tokens(torch.tensor([[self.llm_tokenizer.bos_token_id]],
+ device=self.llm.device,
+ dtype=torch.int64))
+ # Construct prompts
+ inputs_embeds = torch.cat([
+ embedded_bos.expand(embedded_images.shape[0], -1, -1),
+ embedded_images.to(dtype=embedded_bos.dtype),
+ prompt_embeds.expand(embedded_images.shape[0], -1, -1),
+ ], dim=1)
+
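+            # input_ids mirrors inputs_embeds: a BOS token, zero placeholders for the image embeddings, then the prompt tokens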
+ input_ids = torch.cat([
+ torch.tensor([[self.llm_tokenizer.bos_token_id]], dtype=torch.long),
+ torch.zeros((1, embedded_images.shape[1]), dtype=torch.long),
+ prompt,
+ ], dim=1).to('cuda')
+ attention_mask = torch.ones_like(input_ids)
+ # Generate caption
+ generate_ids = self.llm.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask,
+ max_new_tokens=max_new_tokens, do_sample=True, top_k=10,
+ temperature=temperature, suppress_tokens=None)
+ # Trim off the prompt
+ generate_ids = generate_ids[:, input_ids.shape[1]:]
+ if generate_ids[0][-1] == self.llm_tokenizer.eos_token_id:
+ generate_ids = generate_ids[:, :-1]
+
+ content = self.llm_tokenizer.batch_decode(generate_ids,
+ skip_special_tokens=False,
+ clean_up_tokenization_spaces=False)[0]
+
+ return content.strip()
+
+ pbar = tqdm(total=len(image_paths), smoothing=0.0)
+ for image_path in image_paths:
+ try:
+                pbar.set_description('Processing: {}'.format(
+                    image_path if len(image_path) <= 40 else image_path[:15] + ' ... ' + image_path[-20:]))
+ image = Image.open(image_path)
+ image = image_process(image, int(self.args.image_size))
+ caption = get_caption(
+ image=image,
+ user_prompt=self.args.user_prompt,
+ temperature=self.args.temperature,
+ max_new_tokens=self.args.max_tokens
+ )
+ if self.args.custom_caption_save_path is not None:
+ if not os.path.exists(self.args.custom_caption_save_path):
+ self.logger.error(f'{self.args.custom_caption_save_path} NOT FOUND!')
+ raise FileNotFoundError
+
+ self.logger.debug(f'Caption file(s) will be saved in {self.args.custom_caption_save_path}')
+
+ if os.path.isfile(self.args.data_path):
+ caption_file = str(os.path.splitext(os.path.basename(image_path))[0])
+
+ else:
+ caption_file = os.path.splitext(str(image_path)[len(str(self.args.data_path)):])[0]
+
+                    caption_file = caption_file.lstrip('/\\')  # strip the leading path separator (POSIX or Windows)
+ caption_file = os.path.join(self.args.custom_caption_save_path, caption_file)
+ # Make dir if not exist.
+ os.makedirs(os.path.join(str(caption_file)[:-len(os.path.basename(caption_file))]), exist_ok=True)
+ caption_file = Path(str(caption_file) + self.args.caption_extension)
+
+ else:
+ caption_file = os.path.splitext(image_path)[0] + self.args.caption_extension
+
+                if self.args.not_overwrite and os.path.isfile(caption_file):
+                    self.logger.warning(f'Caption file {caption_file} already exists! Skipping this caption.')
+                    pbar.update(1)
+                    continue
+
+ with open(caption_file, "wt", encoding="utf-8") as f:
+ f.write(caption + "\n")
+ self.logger.debug(f"\tImage path: {image_path}")
+ self.logger.debug(f"\tCaption path: {caption_file}")
+ self.logger.debug(f"\tCaption content: {caption}")
+
+ pbar.update(1)
+
+            except Exception as e:
+                self.logger.error(f"Failed to process image: {image_path}, skipping it.\nError info: {e}")
+                pbar.update(1)
+                continue
+
+ pbar.close()
+
+ def unload_model(self):
+ # Unload Image Adapter
+ if self.image_adapter is not None:
+ self.logger.info(f'Unloading Image Adapter...')
+ start = time.monotonic()
+ del self.image_adapter
+ self.logger.info(f'Image Adapter unloaded in {time.monotonic() - start:.1f}s.')
+ # Unload LLM
+ if self.llm is not None:
+ self.logger.info(f'Unloading LLM...')
+ start = time.monotonic()
+ del self.llm
+ del self.llm_tokenizer
+ self.logger.info(f'LLM unloaded in {time.monotonic() - start:.1f}s.')
+ # Unload CLIP
+ if self.clip_model is not None:
+ self.logger.info(f'Unloading CLIP...')
+ start = time.monotonic()
+ del self.clip_model
+ del self.clip_processor
+ self.logger.info(f'CLIP unloaded in {time.monotonic() - start:.1f}s.')
diff --git a/utils/logger.py b/utils/logger.py
new file mode 100644
index 0000000..3aab52b
--- /dev/null
+++ b/utils/logger.py
@@ -0,0 +1,65 @@
+import logging
+from logging import handlers
+from typing import Optional
+
+
+class Logger:
+
+ def __init__(self, level="INFO", log_file: Optional[str] = None):
+ self.logger = logging.getLogger()
+ self.logger.setLevel(level)
+
+ formatter = logging.Formatter('%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s')
+
+ console_handler = logging.StreamHandler()
+ console_handler.setLevel(level)
+ console_handler.setFormatter(formatter)
+ self.logger.addHandler(console_handler)
+
+ if log_file is not None:
+ file_handler = handlers.TimedRotatingFileHandler(filename=log_file,
+ when='D',
+ interval=1,
+ backupCount=5,
+ encoding='utf-8')
+ file_handler.setLevel(level)
+ file_handler.setFormatter(formatter)
+ self.logger.addHandler(file_handler)
+
+ else:
+ self.logger.warning("save_log not enable or log file path not exist, log will only output in console.")
+
+    def set_level(self, level):
+        level_map = {
+            "debug": logging.DEBUG,
+            "info": logging.INFO,
+            "warning": logging.WARNING,
+            "error": logging.ERROR,
+            "critical": logging.CRITICAL,
+        }
+        if level.lower() not in level_map:
+            error_message = "Invalid log level"
+            self.logger.critical(error_message)
+            raise ValueError(error_message)
+
+        level = level_map[level.lower()]
+        self.logger.setLevel(level)
+        for handler in self.logger.handlers:
+            handler.setLevel(level)
+
+ def debug(self, message):
+ self.logger.debug(message)
+
+ def info(self, message):
+ self.logger.info(message)
+
+ def warning(self, message):
+ self.logger.warning(message)
+
+ def error(self, message):
+ self.logger.error(message)
+
+ def critical(self, message):
+ self.logger.critical(message)