diff --git a/README.md b/README.md
index 2551b1e..2b1e190 100644
--- a/README.md
+++ b/README.md
@@ -1,35 +1,41 @@
# Joy Caption Cli
-A Python base cli tool for tagging images with [joy-caption-pre-alpha](https://huggingface.co/spaces/fancyfeast/joy-caption-pre-alpha) models.
+
+A Python-based CLI tool for tagging images
+with [joy-caption-pre-alpha](https://huggingface.co/spaces/fancyfeast/joy-caption-pre-alpha) models.
### Only CUDA devices are currently supported.
## Introduction
-I make this repo because I want to caption some images cross-platform (On My old MBP, my game win pc or docker base linux cloud-server(like Google colab))
+I made this repo because I wanted to caption images cross-platform (on my old MBP, my Windows gaming PC, or a
+Docker-based Linux cloud server such as Google Colab).
-But I don't want to install a huge webui just for this little work. And some cloud-service are unfriendly to gradio base ui.
+But I didn't want to install a huge WebUI just for this small task, and some cloud services are unfriendly to
+Gradio-based UIs.
So this repo was born.
-
## Model source
-Huggingface are original sources, modelscope are pure forks from Huggingface(Because HuggingFace was blocked in Some place).
+Hugging Face hosts the original models; the ModelScope repos are pure mirrors of them (because Hugging Face is
+blocked in some regions).
-| Model | HuggingFace Link | ModelScope Link |
-|:---------------------------------:|:------------------------------------------------------------------------------:|:--------------------------------------------------------------------------------------------:|
-| joy-caption-pre-alpha | [HuggingFace](https://huggingface.co/spaces/fancyfeast/joy-caption-pre-alpha) | [ModelScope](https://www.modelscope.cn/models/fireicewolf/joy-caption-pre-alpha) |
-| siglip-so400m-patch14-384(Google) | [HuggingFace](https://huggingface.co/google/siglip-so400m-patch14-384) | [ModelScope](https://www.modelscope.cn/models/fireicewolf/siglip-so400m-patch14-384) |
-| Meta-Llama-3.1-8B | [HuggingFace](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B) | [ModelScope](https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B) |
+| Model | HuggingFace Link | ModelScope Link |
+|:---------------------------------:|:-----------------------------------------------------------------------------:|:------------------------------------------------------------------------------------:|
+| joy-caption-pre-alpha | [HuggingFace](https://huggingface.co/spaces/fancyfeast/joy-caption-pre-alpha) | [ModelScope](https://www.modelscope.cn/models/fireicewolf/joy-caption-pre-alpha) |
+| siglip-so400m-patch14-384(Google) | [HuggingFace](https://huggingface.co/google/siglip-so400m-patch14-384) | [ModelScope](https://www.modelscope.cn/models/fireicewolf/siglip-so400m-patch14-384) |
+| Meta-Llama-3.1-8B | [HuggingFace](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B) | [ModelScope](https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B) |
## TO-DO
Make a simple UI with Jupyter widgets (when my laziness is cured).
## Installation
+
Python 3.10 works fine.
Open a shell terminal and follow the steps below:
+
```shell
# Clone this repo
git clone https://github.com/fireicewolf/joy-caption-cli.git
@@ -54,16 +60,21 @@ pip install -U -r modelscope-requirements.txt
```
## Simple usage
+
__Make sure your python venv has been activated first!__
+
```shell
python caption.py your_datasets_path
```
+
To run with more options, get help by running the command below or see [Options](#options):
+
```shell
python caption.py -h
```
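+
+For example, a run that downloads the models from ModelScope, loads the LLM with 4-bit quantization, and keeps a log
+file might look like this (an illustrative invocation; adjust the path and options to your setup):
+
+```shell
+python caption.py your_datasets_path \
+  --model_site modelscope \
+  --llm_qnt 4bit \
+  --save_logs
+```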
-## Options
+## Options
+
Advanced options
`data_path`
@@ -81,11 +92,12 @@ config json for llava models, default is "default.json"
[//]: # (`--use_cpu`)
[//]: # ()
+
[//]: # (Use cpu for inference.)
`--model_name MODEL_NAME`
-model name for inference, default is "Joy-Caption-Pre-Alpha", please check configs/default.json)
+model name for inference, default is "Joy-Caption-Pre-Alpha", please check configs/default.json
`--model_site MODEL_SITE`
@@ -142,6 +154,7 @@ max tokens for output, default is 300.
## Credits
+
Based on [joy-caption-pre-alpha](https://huggingface.co/spaces/fancyfeast/joy-caption-pre-alpha).
Without their work, this repo wouldn't exist.
diff --git a/caption.py b/caption.py
index d38a595..4c0e1bd 100644
--- a/caption.py
+++ b/caption.py
@@ -1,9 +1,10 @@
import argparse
import os
+import time
from datetime import datetime
from pathlib import Path
-from utils.download import download
+from utils.download import download_models
from utils.joy import Joy
from utils.logger import Logger
@@ -60,147 +61,186 @@ def main(args):
else:
models_save_path = Path(os.path.join(Path(__file__).parent, args.models_save_path))
- image_adapter_path, clip_path, llm_path = download(
+ image_adapter_path, clip_path, llm_path = download_models(
logger=my_logger,
+ args=args,
config_file=config_file,
- model_name=str(args.model_name),
- model_site=str(args.model_site),
models_save_path=models_save_path,
- use_sdk_cache=True if args.use_sdk_cache else False,
- download_method=str(args.download_method)
)
-
# Load models
my_joy = Joy(
logger=my_logger,
args=args,
image_adapter_path=image_adapter_path,
clip_path=clip_path,
- llm_path=llm_path
+ llm_path=llm_path,
+        use_gpu=not args.llm_use_cpu
)
my_joy.load_model()
# Inference
+ start_inference_time = time.monotonic()
my_joy.inference()
+ total_inference_time = time.monotonic() - start_inference_time
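+    # Break the total inference time into days/hours/minutes/seconds for the summary log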
+ days = total_inference_time // (24 * 3600)
+ total_inference_time %= (24 * 3600)
+ hours = total_inference_time // 3600
+ total_inference_time %= 3600
+ minutes = total_inference_time // 60
+ seconds = total_inference_time % 60
+    days = f"{int(days)} Day(s) " if days > 0 else ""
+    hours = f"{int(hours)} Hour(s) " if hours > 0 or (days and hours == 0) else ""
+    minutes = f"{int(minutes)} Min(s) " if minutes > 0 or (hours and minutes == 0) else ""
+    seconds = f"{seconds:.1f} Sec(s)"
+    my_logger.info(f"All work done in {days}{hours}{minutes}{seconds}.")
# Unload models
my_joy.unload_model()
def setup_args() -> argparse.ArgumentParser:
- args = argparse.ArgumentParser()
-
- args.add_argument(
+ parsed_args = argparse.ArgumentParser()
+ base_args = parsed_args.add_argument_group("Base")
+ base_args.add_argument(
'data_path',
type=str,
help='path for data.'
)
- args.add_argument(
+ base_args.add_argument(
'--recursive',
action='store_true',
help='Include recursive dirs'
)
- args.add_argument(
+
+ log_args = parsed_args.add_argument_group("Logs")
+ log_args.add_argument(
+ '--log_level',
+ type=str,
+ choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
+ default='INFO',
+ help='set log level, default is "INFO"'
+ )
+ log_args.add_argument(
+ '--save_logs',
+ action='store_true',
+ help='save log file.'
+ )
+
+ download_args = parsed_args.add_argument_group("Download")
+ download_args.add_argument(
'--config',
type=str,
default='default.json',
help='config json for llava models, default is "default.json"'
)
- # args.add_argument(
- # '--use_cpu',
- # action='store_true',
- # help='use cpu for inference.'
- # )
- args.add_argument(
- '--image_size',
- type=int,
- default=1024,
- help='resize image to suitable, default is 1024.'
- )
- args.add_argument(
+ download_args.add_argument(
'--model_name',
type=str,
default='Joy-Caption-Pre-Alpha',
help='model name for inference, default is "Joy-Caption-Pre-Alpha", please check configs/default.json'
)
- args.add_argument(
+ download_args.add_argument(
'--model_site',
type=str,
choices=['huggingface', 'modelscope'],
default='huggingface',
help='download model from model site huggingface or modelscope, default is "huggingface".'
)
- args.add_argument(
+ download_args.add_argument(
'--models_save_path',
type=str,
default="models",
help='path to save models, default is "models".'
)
- args.add_argument(
+ download_args.add_argument(
'--use_sdk_cache',
action='store_true',
help='use sdk\'s cache dir to store models. \
         if this option is enabled, "--models_save_path" will be ignored.'
)
- args.add_argument(
+ download_args.add_argument(
'--download_method',
type=str,
choices=["SDK", "URL"],
default='SDK',
help='download method via SDK or URL, default is "SDK".'
)
- args.add_argument(
+ download_args.add_argument(
+ '--force_download',
+ action='store_true',
+        help='force download even if the file exists.'
+ )
+ download_args.add_argument(
+ '--skip_download',
+ action='store_true',
+        help='skip download if the file already exists.'
+ )
+ download_args.add_argument(
'--custom_caption_save_path',
type=str,
default=None,
help='Input custom caption file save path.'
)
- args.add_argument(
- '--log_level',
+
+ inference_args = parsed_args.add_argument_group("Inference")
+ inference_args.add_argument(
+ '--llm_use_cpu',
+ action='store_true',
+ help='use cpu for inference.'
+ )
+ inference_args.add_argument(
+ '--llm_dtype',
type=str,
- choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
- default='INFO',
- help='set log level, default is "INFO"'
+ choices=["auto", "fp16", "bf16", "fp32"],
+ default='fp16',
+        help='choose the LLM load dtype, default is "fp16".'
)
- args.add_argument(
- '--save_logs',
- action='store_true',
- help='save log file.'
+ inference_args.add_argument(
+ '--llm_qnt',
+ type=str,
+ choices=["none", "4bit", "8bit"],
+ default='none',
+        help='enable quantization for the LLM ("none", "4bit", "8bit"), default is "none".'
+ )
+ inference_args.add_argument(
+ '--image_size',
+ type=int,
+ default=1024,
+        help='resize input images to this size, default is 1024.'
)
- args.add_argument(
+ inference_args.add_argument(
'--caption_extension',
type=str,
default='.txt',
help='extension of caption file, default is ".txt"'
)
- args.add_argument(
+ inference_args.add_argument(
'--not_overwrite',
action='store_true',
         help='do not overwrite the caption file if it exists.'
)
- args.add_argument(
+ inference_args.add_argument(
'--user_prompt',
type=str,
default=DEFAULT_USER_PROMPT,
help='user prompt for caption.'
)
- args.add_argument(
+ inference_args.add_argument(
'--temperature',
type=float,
default=0.5,
help='temperature for Llama model.'
)
- args.add_argument(
+ inference_args.add_argument(
'--max_tokens',
type=int,
default=300,
help='max tokens for output.'
)
-
- return args
+ return parsed_args
if __name__ == "__main__":
- args = setup_args()
- args = args.parse_args()
- main(args)
+ get_args = setup_args()
+ get_args = get_args.parse_args()
+ main(get_args)
diff --git a/configs/default.json b/configs/default.json
index 1043068..566e6c5 100644
--- a/configs/default.json
+++ b/configs/default.json
@@ -1,87 +1,172 @@
{
- "Joy-Caption-Pre-Alpha": {
- "huggingface": {
- "image_adapter": {
- "repo_id": "fancyfeast/joy-caption-pre-alpha",
- "revision": "main",
- "repo_type": "space",
- "subfolder": "wpkklhc6",
- "file_list": {
- "image_adapter.pt": "https://huggingface.co/spaces/fancyfeast/joy-caption-pre-alpha/resolve/main/wpkklhc6/image_adapter.pt"
- }
- },
- "clip": {
- "repo_id": "google/siglip-so400m-patch14-384",
- "revision": "main",
- "repo_type": "model",
- "subfolder": "",
- "file_list": {
- "config.json": "https://huggingface.co/google/siglip-so400m-patch14-384/resolve/main/config.json",
- "tokenizer_config.json": "https://huggingface.co/google/siglip-so400m-patch14-384/resolve/main/tokenizer_config.json",
- "special_tokens_map.json": "https://huggingface.co/google/siglip-so400m-patch14-384/resolve/main/preprocessor_config.json",
- "preprocessor_config.json": "https://huggingface.co/google/siglip-so400m-patch14-384/resolve/main/special_tokens_map.json",
- "spiece.model": "https://huggingface.co/google/siglip-so400m-patch14-384/resolve/main/spiece.model",
- "model.safetensors": "https://huggingface.co/google/siglip-so400m-patch14-384/resolve/main/model.safetensors"
- }
- },
- "llm": {
- "repo_id": "meta-llama/Meta-Llama-3.1-8B",
- "revision": "main",
- "repo_type": "model",
- "subfolder": "",
- "file_list": {
- "config.json":"https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/config.json",
- "generation_config.json": "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/generation_config.json",
- "tokenizer.json": "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/tokenizer.json",
- "tokenizer_config.json": "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/tokenizer_config.json",
- "special_tokens_map.json": "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/special_tokens_map.json",
- "model.safetensors.index.json": "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/model.safetensors.index.json",
- "model-00001-of-00004.safetensors": "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/model-00001-of-00004.safetensors",
- "model-00002-of-00004.safetensors": "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/model-00002-of-00004.safetensors",
- "model-00003-of-00004.safetensors": "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/model-00003-of-00004.safetensors",
- "model-00004-of-00004.safetensors": "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/model-00004-of-00004.safetensors"
- }
- }
- },
- "modelscope": {
- "image_adapter": {
- "repo_id": "fireicewolf/joy-caption-pre-alpha",
- "revision": "master",
- "subfolder": "wpkklhc6",
- "file_list": {
- "image_adapter.pt": "https://www.modelscope.cn/models/fireicewolf/joy-caption-pre-alpha/resolve/master/wpkklhc6/image_adapter.pt"
- }
- },
- "clip": {
- "repo_id": "fireicewolf/siglip-so400m-patch14-384",
- "revision": "master",
- "subfolder": "",
- "file_list": {
- "config.json": "https://www.modelscope.cn/models/fireicewolf/siglip-so400m-patch14-384/resolve/master/config.json",
- "tokenizer_config.json": "https://www.modelscope.cn/models/fireicewolf/siglip-so400m-patch14-384/resolve/master/tokenizer_config.json",
- "special_tokens_map.json": "https://www.modelscope.cn/models/fireicewolf/siglip-so400m-patch14-384/resolve/master/preprocessor_config.json",
- "preprocessor_config.json": "https://www.modelscope.cn/models/fireicewolf/siglip-so400m-patch14-384/resolve/master/special_tokens_map.json",
- "spiece.model": "https://www.modelscope.cn/models/fireicewolf/siglip-so400m-patch14-384/resolve/master/spiece.model",
- "model.safetensors": "https://www.modelscope.cn/models/fireicewolf/siglip-so400m-patch14-384/resolve/master/model.safetensors"
- }
- },
- "llm": {
- "repo_id": "fireicewolf/Meta-Llama-3.1-8B",
- "revision": "master",
- "subfolder": "",
- "file_list": {
- "config.json": "https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/config.json",
- "generation_config.json": "https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/generation_config.json",
- "tokenizer.json": "https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/tokenizer.json",
- "tokenizer_config.json": "https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/tokenizer_config.json",
- "special_tokens_map.json": "https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/special_tokens_map.json",
- "model.safetensors.index.json": "https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/model.safetensors.index.json",
- "model-00001-of-00004.safetensors": "https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/model-00001-of-00004.safetensors",
- "model-00002-of-00004.safetensors": "https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/model-00002-of-00004.safetensors",
- "model-00003-of-00004.safetensors": "https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/model-00003-of-00004.safetensors",
- "model-00004-of-00004.safetensors": "https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/model-00004-of-00004.safetensors"
- }
- }
+ "Joy-Caption-Pre-Alpha": {
+ "huggingface": {
+ "image_adapter": {
+ "repo_id": "fancyfeast/joy-caption-pre-alpha",
+ "revision": "main",
+ "repo_type": "space",
+ "subfolder": "wpkklhc6",
+ "file_list": {
+ "image_adapter.pt": "https://huggingface.co/spaces/fancyfeast/joy-caption-pre-alpha/resolve/main/wpkklhc6/image_adapter.pt"
}
+ },
+ "clip": {
+ "repo_id": "google/siglip-so400m-patch14-384",
+ "revision": "main",
+ "repo_type": "model",
+ "subfolder": "",
+ "file_list": {
+ "config.json": "https://huggingface.co/google/siglip-so400m-patch14-384/resolve/main/config.json",
+ "tokenizer_config.json": "https://huggingface.co/google/siglip-so400m-patch14-384/resolve/main/tokenizer_config.json",
+          "special_tokens_map.json": "https://huggingface.co/google/siglip-so400m-patch14-384/resolve/main/special_tokens_map.json",
+          "preprocessor_config.json": "https://huggingface.co/google/siglip-so400m-patch14-384/resolve/main/preprocessor_config.json",
+ "spiece.model": "https://huggingface.co/google/siglip-so400m-patch14-384/resolve/main/spiece.model",
+ "model.safetensors": "https://huggingface.co/google/siglip-so400m-patch14-384/resolve/main/model.safetensors"
+ }
+ },
+ "llm": {
+ "repo_id": "meta-llama/Meta-Llama-3.1-8B",
+ "revision": "main",
+ "repo_type": "model",
+ "subfolder": "",
+ "file_list": {
+ "config.json": "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/config.json",
+ "generation_config.json": "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/generation_config.json",
+ "tokenizer.json": "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/tokenizer.json",
+ "tokenizer_config.json": "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/tokenizer_config.json",
+ "special_tokens_map.json": "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/special_tokens_map.json",
+ "model.safetensors.index.json": "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/model.safetensors.index.json",
+ "model-00001-of-00004.safetensors": "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/model-00001-of-00004.safetensors",
+ "model-00002-of-00004.safetensors": "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/model-00002-of-00004.safetensors",
+ "model-00003-of-00004.safetensors": "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/model-00003-of-00004.safetensors",
+ "model-00004-of-00004.safetensors": "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/resolve/main/model-00004-of-00004.safetensors"
+ }
+ }
+ },
+ "modelscope": {
+ "image_adapter": {
+ "repo_id": "fireicewolf/joy-caption-pre-alpha",
+ "revision": "master",
+ "subfolder": "wpkklhc6",
+ "file_list": {
+ "image_adapter.pt": "https://www.modelscope.cn/models/fireicewolf/joy-caption-pre-alpha/resolve/master/wpkklhc6/image_adapter.pt"
+ }
+ },
+ "clip": {
+ "repo_id": "fireicewolf/siglip-so400m-patch14-384",
+ "revision": "master",
+ "subfolder": "",
+ "file_list": {
+ "config.json": "https://www.modelscope.cn/models/fireicewolf/siglip-so400m-patch14-384/resolve/master/config.json",
+ "tokenizer_config.json": "https://www.modelscope.cn/models/fireicewolf/siglip-so400m-patch14-384/resolve/master/tokenizer_config.json",
+          "special_tokens_map.json": "https://www.modelscope.cn/models/fireicewolf/siglip-so400m-patch14-384/resolve/master/special_tokens_map.json",
+          "preprocessor_config.json": "https://www.modelscope.cn/models/fireicewolf/siglip-so400m-patch14-384/resolve/master/preprocessor_config.json",
+ "spiece.model": "https://www.modelscope.cn/models/fireicewolf/siglip-so400m-patch14-384/resolve/master/spiece.model",
+ "model.safetensors": "https://www.modelscope.cn/models/fireicewolf/siglip-so400m-patch14-384/resolve/master/model.safetensors"
+ }
+ },
+ "llm": {
+ "repo_id": "fireicewolf/Meta-Llama-3.1-8B",
+ "revision": "master",
+ "subfolder": "",
+ "file_list": {
+ "config.json": "https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/config.json",
+ "generation_config.json": "https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/generation_config.json",
+ "tokenizer.json": "https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/tokenizer.json",
+ "tokenizer_config.json": "https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/tokenizer_config.json",
+ "special_tokens_map.json": "https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/special_tokens_map.json",
+ "model.safetensors.index.json": "https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/model.safetensors.index.json",
+ "model-00001-of-00004.safetensors": "https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/model-00001-of-00004.safetensors",
+ "model-00002-of-00004.safetensors": "https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/model-00002-of-00004.safetensors",
+ "model-00003-of-00004.safetensors": "https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/model-00003-of-00004.safetensors",
+ "model-00004-of-00004.safetensors": "https://www.modelscope.cn/models/fireicewolf/Meta-Llama-3.1-8B/resolve/master/model-00004-of-00004.safetensors"
+ }
+ }
+ }
+ },
+ "Joy-Caption-Uncensored": {
+ "huggingface": {
+ "image_adapter": {
+ "repo_id": "fancyfeast/joy-caption-pre-alpha",
+ "revision": "main",
+ "repo_type": "space",
+ "subfolder": "wpkklhc6",
+ "file_list": {
+ "image_adapter.pt": "https://huggingface.co/spaces/fancyfeast/joy-caption-pre-alpha/resolve/main/wpkklhc6/image_adapter.pt"
+ }
+ },
+ "clip": {
+ "repo_id": "google/siglip-so400m-patch14-384",
+ "revision": "main",
+ "repo_type": "model",
+ "subfolder": "",
+ "file_list": {
+ "config.json": "https://huggingface.co/google/siglip-so400m-patch14-384/resolve/main/config.json",
+ "tokenizer_config.json": "https://huggingface.co/google/siglip-so400m-patch14-384/resolve/main/tokenizer_config.json",
+          "special_tokens_map.json": "https://huggingface.co/google/siglip-so400m-patch14-384/resolve/main/special_tokens_map.json",
+          "preprocessor_config.json": "https://huggingface.co/google/siglip-so400m-patch14-384/resolve/main/preprocessor_config.json",
+ "spiece.model": "https://huggingface.co/google/siglip-so400m-patch14-384/resolve/main/spiece.model",
+ "model.safetensors": "https://huggingface.co/google/siglip-so400m-patch14-384/resolve/main/model.safetensors"
+ }
+ },
+ "llm": {
+ "repo_id": "Orenguteng/Llama-3.1-8B-Lexi-Uncensored-V2",
+ "revision": "main",
+ "repo_type": "model",
+ "subfolder": "",
+ "file_list": {
+ "config.json": "https://huggingface.co/Orenguteng/Llama-3.1-8B-Lexi-Uncensored-V2/resolve/main/config.json",
+ "generation_config.json": "https://huggingface.co/Orenguteng/Llama-3.1-8B-Lexi-Uncensored-V2/resolve/main/generation_config.json",
+ "tokenizer.json": "https://huggingface.co/Orenguteng/Llama-3.1-8B-Lexi-Uncensored-V2/resolve/main/tokenizer.json",
+ "tokenizer_config.json": "https://huggingface.co/Orenguteng/Llama-3.1-8B-Lexi-Uncensored-V2/resolve/main/tokenizer_config.json",
+ "special_tokens_map.json": "https://huggingface.co/Orenguteng/Llama-3.1-8B-Lexi-Uncensored-V2/resolve/main/special_tokens_map.json",
+ "model.safetensors.index.json": "https://huggingface.co/Orenguteng/Llama-3.1-8B-Lexi-Uncensored-V2/resolve/main/model.safetensors.index.json",
+ "model-00001-of-00004.safetensors": "https://huggingface.co/Orenguteng/Llama-3.1-8B-Lexi-Uncensored-V2/resolve/main/model-00001-of-00004.safetensors",
+ "model-00002-of-00004.safetensors": "https://huggingface.co/Orenguteng/Llama-3.1-8B-Lexi-Uncensored-V2/resolve/main/model-00002-of-00004.safetensors",
+ "model-00003-of-00004.safetensors": "https://huggingface.co/Orenguteng/Llama-3.1-8B-Lexi-Uncensored-V2/resolve/main/model-00003-of-00004.safetensors",
+ "model-00004-of-00004.safetensors": "https://huggingface.co/Orenguteng/Llama-3.1-8B-Lexi-Uncensored-V2/resolve/main/model-00004-of-00004.safetensors"
+ }
+ }
+ },
+ "modelscope": {
+ "image_adapter": {
+ "repo_id": "fireicewolf/joy-caption-pre-alpha",
+ "revision": "master",
+ "subfolder": "wpkklhc6",
+ "file_list": {
+ "image_adapter.pt": "https://www.modelscope.cn/models/fireicewolf/joy-caption-pre-alpha/resolve/master/wpkklhc6/image_adapter.pt"
+ }
+ },
+ "clip": {
+ "repo_id": "fireicewolf/siglip-so400m-patch14-384",
+ "revision": "master",
+ "subfolder": "",
+ "file_list": {
+ "config.json": "https://www.modelscope.cn/models/fireicewolf/siglip-so400m-patch14-384/resolve/master/config.json",
+ "tokenizer_config.json": "https://www.modelscope.cn/models/fireicewolf/siglip-so400m-patch14-384/resolve/master/tokenizer_config.json",
+          "special_tokens_map.json": "https://www.modelscope.cn/models/fireicewolf/siglip-so400m-patch14-384/resolve/master/special_tokens_map.json",
+          "preprocessor_config.json": "https://www.modelscope.cn/models/fireicewolf/siglip-so400m-patch14-384/resolve/master/preprocessor_config.json",
+ "spiece.model": "https://www.modelscope.cn/models/fireicewolf/siglip-so400m-patch14-384/resolve/master/spiece.model",
+ "model.safetensors": "https://www.modelscope.cn/models/fireicewolf/siglip-so400m-patch14-384/resolve/master/model.safetensors"
+ }
+ },
+ "llm": {
+ "repo_id": "fireicewolf/Llama-3.1-8B-Lexi-Uncensored-V2",
+ "revision": "master",
+ "subfolder": "",
+ "file_list": {
+ "config.json": "https://www.modelscope.cn/models/fireicewolf/Llama-3.1-8B-Lexi-Uncensored-V2/resolve/master/config.json",
+ "generation_config.json": "https://www.modelscope.cn/models/fireicewolf/Llama-3.1-8B-Lexi-Uncensored-V2/resolve/master/generation_config.json",
+ "tokenizer.json": "https://www.modelscope.cn/models/fireicewolf/Llama-3.1-8B-Lexi-Uncensored-V2/resolve/master/tokenizer.json",
+ "tokenizer_config.json": "https://www.modelscope.cn/models/fireicewolf/Llama-3.1-8B-Lexi-Uncensored-V2/resolve/master/tokenizer_config.json",
+ "special_tokens_map.json": "https://www.modelscope.cn/models/fireicewolf/Llama-3.1-8B-Lexi-Uncensored-V2/resolve/master/special_tokens_map.json",
+ "model.safetensors.index.json": "https://www.modelscope.cn/models/fireicewolf/Llama-3.1-8B-Lexi-Uncensored-V2/resolve/master/model.safetensors.index.json",
+ "model-00001-of-00004.safetensors": "https://www.modelscope.cn/models/fireicewolf/Llama-3.1-8B-Lexi-Uncensored-V2/resolve/master/model-00001-of-00004.safetensors",
+ "model-00002-of-00004.safetensors": "https://www.modelscope.cn/models/fireicewolf/Llama-3.1-8B-Lexi-Uncensored-V2/resolve/master/model-00002-of-00004.safetensors",
+ "model-00003-of-00004.safetensors": "https://www.modelscope.cn/models/fireicewolf/Llama-3.1-8B-Lexi-Uncensored-V2/resolve/master/model-00003-of-00004.safetensors",
+ "model-00004-of-00004.safetensors": "https://www.modelscope.cn/models/fireicewolf/Llama-3.1-8B-Lexi-Uncensored-V2/resolve/master/model-00004-of-00004.safetensors"
+ }
+ }
}
+ }
}
diff --git a/huggingface-requirements.txt b/huggingface-requirements.txt
index 2afcdef..d292e9d 100644
--- a/huggingface-requirements.txt
+++ b/huggingface-requirements.txt
@@ -1,2 +1,2 @@
-huggingface_hub==0.24.6
+huggingface_hub==0.25.1
-r requirements.txt
\ No newline at end of file
diff --git a/modelscope-requirements.txt b/modelscope-requirements.txt
index a4775a9..fac4583 100644
--- a/modelscope-requirements.txt
+++ b/modelscope-requirements.txt
@@ -1,2 +1,2 @@
-modelscope==1.17.1
+modelscope==1.18.1
-r requirements.txt
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index f0c1d55..27f883c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,8 +1,8 @@
-accelerate==0.33.0
-transformers==4.44.2
-sentencepiece==0.2.0
-numpy==1.26.4
+accelerate==0.34.2
+bitsandbytes==0.44.1
opencv-python-headless==4.10.0.84
pillow==10.4.0
requests==2.32.3
-tqdm==4.66.5
\ No newline at end of file
+sentencepiece==0.2.0
+tqdm==4.66.5
+transformers==4.45.1
\ No newline at end of file
diff --git a/utils/download.py b/utils/download.py
index 38a1152..a33bc29 100644
--- a/utils/download.py
+++ b/utils/download.py
@@ -1,6 +1,6 @@
+import argparse
import json
import os
-import shutil
from pathlib import Path
from typing import Union, Optional
@@ -10,15 +10,69 @@
from utils.logger import Logger
-def download(
+def url_download(
logger: Logger,
+ url: str,
+ local_dir: Union[str, Path],
+ skip_local_file_exist: bool = True,
+ force_download: bool = False,
+ force_filename: Optional[str] = None
+) -> Path:
+ # Download file via url by requests library
+ filename = os.path.basename(url) if not force_filename else force_filename
+ local_file = os.path.join(local_dir, filename)
+
+ hf_token = os.environ.get("HF_TOKEN")
+ if hf_token:
+        logger.info("Loading Hugging Face token from environment variable")
+ response = requests.get(url, stream=True, headers={
+ "Authorization": f"Bearer {hf_token}"} if "huggingface.co" in url and hf_token else None)
+ total_size = int(response.headers.get('content-length', 0))
+
+ def download_progress():
+ desc = f'Downloading {filename}'
+
+ if total_size > 0:
+ pbar = tqdm(total=total_size, initial=0, unit='B', unit_divisor=1024, unit_scale=True,
+ dynamic_ncols=True,
+ desc=desc)
+ else:
+ pbar = tqdm(initial=0, unit='B', unit_divisor=1024, unit_scale=True, dynamic_ncols=True, desc=desc)
+
+ if not os.path.exists(local_dir):
+ os.makedirs(local_dir, exist_ok=True)
+
+ with open(local_file, 'ab') as download_file:
+ for data in response.iter_content(chunk_size=1024):
+ if data:
+ download_file.write(data)
+ pbar.update(len(data))
+ pbar.close()
+
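+    # Reuse, skip, or re-download an existing local file depending on the skip/force flags and a size check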
+ if not force_download and os.path.isfile(local_file):
+ if os.path.exists(local_file) and skip_local_file_exist:
+            logger.info(f"`skip_local_file_exist` is enabled, skipping download of {local_file}...")
+ else:
+ if total_size == 0:
+ logger.info(
+ f'"{local_file}" already exist, but can\'t get its size from "{url}". Won\'t download it.')
+                    f'"{local_file}" already exists, but can\'t get its size from "{url}". Won\'t download it.')
+ logger.info(f'"{local_file}" already exist, and its size match with "{url}".')
+                logger.info(f'"{local_file}" already exists, and its size matches "{url}".')
+ logger.info(
+ f'"{local_file}" already exist, but its size not match with "{url}"!\nWill download this file '
+                    f'"{local_file}" already exists, but its size does not match "{url}"!\nWill download this file '
+ download_progress()
+ else:
+ download_progress()
+ return Path(os.path.join(local_dir, filename))
+
+
+def download_models(
+ logger: Logger,
+ args: argparse.Namespace,
config_file: Path,
- model_name: str,
- model_site: str,
models_save_path: Path,
- use_sdk_cache: bool = False,
- download_method: str = "sdk",
- force_download: bool = False
) -> tuple[Path, Path, Path]:
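+    # Returns (image adapter weight file path, CLIP model dir, LLM model dir)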
if os.path.isfile(config_file):
logger.info(f'Using config: {str(config_file)}')
@@ -26,194 +80,130 @@ def download(
logger.error(f'{str(config_file)} NOT FOUND!')
raise FileNotFoundError
- def read_json(config_file, model_name) -> dict[str]:
+ def read_json(config_file) -> tuple[str, dict[str]]:
with open(config_file, 'r', encoding='utf-8') as config_json:
datas = json.load(config_json)
+ model_name = list(datas.keys())[0] if args.model_name is None else args.model_name
+ args.model_name = model_name
+
if model_name not in datas.keys():
logger.error(f'"{str(model_name)}" NOT FOUND IN CONFIG!')
raise FileNotFoundError
- return datas[model_name]
-
- model_info = read_json(config_file, model_name)
+ return model_name, datas[model_name]
+ model_name, model_info = read_json(config_file)
models_save_path = Path(os.path.join(models_save_path, model_name))
- if use_sdk_cache:
+ if args.use_sdk_cache:
logger.warning('use_sdk_cache ENABLED! download_method force to use "SDK" and models_save_path will be ignored')
- download_method = 'sdk'
+ args.download_method = 'sdk'
else:
logger.info(f'Models will be stored in {str(models_save_path)}.')
- def url_download(
- url: str,
- local_dir: Union[str, Path],
- force_download: bool = False,
- force_filename: Optional[str] = None
- ) -> Path:
- # Download file via url by requests library
- filename = os.path.basename(url) if not force_filename else force_filename
- local_file = os.path.join(local_dir, filename)
-
- hf_token = os.environ.get("HF_TOKEN")
-
- response = requests.get(url, stream=True, headers={"Authorization": f"Bearer {hf_token}"} if "huggingface.co" in url and hf_token else None)
- total_size = int(response.headers.get('content-length', 0))
-
- def download_progress():
- desc = 'Downloading {}'.format(filename)
-
- if total_size > 0:
- pbar = tqdm(total=total_size, initial=0, unit='B', unit_divisor=1024, unit_scale=True,
- dynamic_ncols=True,
- desc=desc)
- else:
- pbar = tqdm(initial=0, unit='B', unit_divisor=1024, unit_scale=True, dynamic_ncols=True, desc=desc)
-
- if not os.path.exists(local_dir):
- os.makedirs(local_dir, exist_ok=True)
-
- with open(local_file, 'ab') as download_file:
- for data in response.iter_content(chunk_size=1024):
- if data:
- download_file.write(data)
- pbar.update(len(data))
- pbar.close()
-
- if not force_download and os.path.isfile(local_file):
- if total_size == 0:
- logger.info(
- f'"{local_file}" already exist, but can\'t get its size from "{url}". Won\'t download it.')
- elif os.path.getsize(local_file) == total_size:
- logger.info(f'"{local_file}" already exist, and its size match with "{url}".')
- else:
- logger.info(
- f'"{local_file}" already exist, but its size not match with "{url}"!\nWill download this file '
- f'again...')
- download_progress()
- else:
- download_progress()
-
- return Path(os.path.join(local_dir, filename))
-
def download_choice(
+ args: argparse.Namespace,
model_info: dict[str],
model_site: str,
models_save_path: Path,
download_method: str = "sdk",
use_sdk_cache: bool = False,
+ skip_local_file_exist: bool = True,
force_download: bool = False
):
- if download_method.lower() == 'sdk':
- if model_site == "huggingface":
- model_hf_info = model_info["huggingface"]
- try:
+ if model_site not in ["huggingface", "modelscope"]:
+ logger.error('Invalid model site!')
+ raise ValueError
+
+ model_site_info = model_info[model_site]
+ try:
+ if download_method == "sdk":
+ if model_site == "huggingface":
from huggingface_hub import hf_hub_download
-
- models_path = []
- for sub_model_name in model_hf_info:
- sub_model_info = model_hf_info[sub_model_name]
- sub_model_path = ""
-
- for filename in sub_model_info["file_list"]:
- logger.info(f'Will download "{filename}" from huggingface repo: "{sub_model_info["repo_id"]}".')
- sub_model_path = hf_hub_download(
- repo_id=sub_model_info["repo_id"],
- filename=filename,
- subfolder=sub_model_info["subfolder"] if sub_model_info["subfolder"] != "" else None,
- repo_type=sub_model_info["repo_type"],
- revision=sub_model_info["revision"],
- local_dir=os.path.join(models_save_path,sub_model_name) if not use_sdk_cache else None,
- # local_dir_use_symlinks=False if not use_sdk_cache else "auto",
- # resume_download=True,
- force_download=force_download
- )
- models_path.append(sub_model_path)
- return models_path
-
- except:
- logger.warning('huggingface_hub not installed or download via it failed, '
- 'retrying with URL method to download...')
- models_path = download_choice(
- model_info,
- model_site,
- models_save_path,
- use_sdk_cache=False,
- download_method="url",
- force_download=force_download
- )
- return models_path
-
- elif model_site == "modelscope":
- model_ms_info = model_info["modelscope"]
- try:
- if force_download:
- logger.warning(
- 'modelscope api not support force download, '
- 'trying to remove model path before download!')
- shutil.rmtree(models_save_path)
-
+ elif model_site == "modelscope":
from modelscope.hub.file_download import model_file_download
- models_path = []
- for sub_model_name in model_ms_info:
- sub_model_info = model_ms_info[sub_model_name]
- sub_model_path = ""
-
- for filename in sub_model_info["file_list"]:
- logger.info(f'Will download "{filename}" from modelscope repo: "{sub_model_info["repo_id"]}".')
- sub_model_path = model_file_download(
- model_id=sub_model_info["repo_id"],
- file_path=filename if sub_model_info["subfolder"] == ""
- else os.path.join(sub_model_info["subfolder"],filename),
- revision=sub_model_info["revision"],
- local_dir=os.path.join(models_save_path,sub_model_name) if not use_sdk_cache else None,
- )
- models_path.append(sub_model_path)
- return models_path
-
- except:
- logger.warning('modelscope not installed or download via it failed, '
- 'retrying with URL method to download...')
- models_path = download_choice(
- model_info,
- model_site,
- models_save_path,
- use_sdk_cache=False,
- download_method="url",
- force_download=force_download
- )
- return models_path
-
- else:
- logger.error('Invalid model site!')
- raise ValueError
+ except ModuleNotFoundError:
+ if model_site == "huggingface":
+            logger.warning('huggingface_hub not installed, '
+                           'retrying with URL method to download...')
+        elif model_site == "modelscope":
+            logger.warning('modelscope not installed, '
+                           'retrying with URL method to download...')
+
+ models_path = download_choice(
+ args,
+ model_info,
+ model_site,
+ models_save_path,
+ use_sdk_cache=False,
+ download_method="url",
+ skip_local_file_exist=skip_local_file_exist,
+ force_download=force_download
+ )
+ return models_path
- else:
- model_site_info = model_info[model_site]
- models_path = []
- for sub_model_name in model_site_info:
- sub_model_info = model_site_info[sub_model_name]
- sub_model_path = ""
- for filename in sub_model_info["file_list"]:
+ models_path = []
+ for sub_model_name in model_site_info:
+ sub_model_info = model_site_info[sub_model_name]
+        if sub_model_name == "patch" and not getattr(args, "llm_patch", False):
+            logger.warning("Found an LLM patch entry in config, but llm_patch is not enabled, won't download it.")
+ continue
+ sub_model_path = ""
+
+ for filename in sub_model_info["file_list"]:
+ if download_method.lower() == 'sdk':
+ if model_site == "huggingface":
+ logger.info(f'Will download "{filename}" from huggingface repo: "{sub_model_info["repo_id"]}".')
+ sub_model_path = hf_hub_download(
+ repo_id=sub_model_info["repo_id"],
+ filename=filename,
+ subfolder=sub_model_info["subfolder"] if sub_model_info["subfolder"] != "" else None,
+ repo_type=sub_model_info["repo_type"],
+ revision=sub_model_info["revision"],
+ local_dir=os.path.join(models_save_path, sub_model_name) if not use_sdk_cache else None,
+ local_files_only=skip_local_file_exist \
+ if os.path.exists(os.path.join(models_save_path, sub_model_name, filename)) else False,
+ # local_dir_use_symlinks=False if not use_sdk_cache else "auto",
+ # resume_download=True,
+ force_download=force_download
+ )
+ elif model_site == "modelscope":
+ logger.info(f'Will download "{filename}" from modelscope repo: "{sub_model_info["repo_id"]}".')
+ sub_model_path = model_file_download(
+ model_id=sub_model_info["repo_id"],
+ file_path=filename if sub_model_info["subfolder"] == ""
+ else os.path.join(sub_model_info["subfolder"], filename),
+ revision=sub_model_info["revision"],
+ local_files_only=skip_local_file_exist,
+ local_dir=os.path.join(models_save_path, sub_model_name) if not use_sdk_cache else None,
+ )
+ else:
model_url = sub_model_info["file_list"][filename]
logger.info(f'Will download model from url: {model_url}')
sub_model_path = url_download(
+ logger=logger,
url=model_url,
- local_dir=os.path.join(models_save_path,sub_model_name) if sub_model_info["subfolder"] == ""
- else os.path.join(models_save_path,sub_model_name,sub_model_info["subfolder"]),
+ local_dir=os.path.join(models_save_path, sub_model_name) if sub_model_info["subfolder"] == ""
+ else os.path.join(models_save_path, sub_model_name, sub_model_info["subfolder"]),
force_filename=filename,
+ skip_local_file_exist=skip_local_file_exist,
force_download=force_download
)
- models_path.append(sub_model_path)
- return models_path
+ models_path.append(sub_model_path)
+ return models_path
models_path = download_choice(
+ args=args,
model_info=model_info,
- model_site=model_site,
- models_save_path=models_save_path,
- download_method=download_method,
- use_sdk_cache=use_sdk_cache,
- force_download=force_download
+ model_site=str(args.model_site),
+ models_save_path=Path(models_save_path),
+ download_method=str(args.download_method).lower(),
+ use_sdk_cache=args.use_sdk_cache,
+ skip_local_file_exist=args.skip_download,
+ force_download=args.force_download
)
- return Path(models_path[0]), Path(os.path.dirname(models_path[1])), Path(os.path.dirname(models_path[2]))
+ image_adapter_path = Path(models_path[0])
+ clip_path = Path(os.path.dirname(models_path[1]))
+ llm_path = Path(os.path.dirname(models_path[2]))
+ return image_adapter_path, clip_path, llm_path
diff --git a/utils/image.py b/utils/image.py
index 825f4f1..be2e8ae 100644
--- a/utils/image.py
+++ b/utils/image.py
@@ -1,13 +1,43 @@
import base64
+import glob
+import os
+from io import BytesIO
from pathlib import Path
+from typing import List
import cv2
import numpy
-from io import BytesIO
from PIL import Image
+from utils.logger import Logger
+
+SUPPORT_IMAGE_FORMATS = ("bmp", "jpg", "jpeg", "png", "webp")
+
+
+def get_image_paths(
+ logger: Logger,
+ path: Path,
+ recursive: bool = False,
+) -> List[str]:
+ # Get image paths
+    if os.path.isfile(path):
+        image_paths = [str(path)] if str(path).lower().endswith(SUPPORT_IMAGE_FORMATS) else None
+    else:
+        path_to_find = os.path.join(path, '**') if recursive else os.path.join(path, '*')
+        image_paths = sorted(set(
+            image for image in glob.glob(path_to_find, recursive=recursive)
+            if image.lower().endswith(SUPPORT_IMAGE_FORMATS)), key=lambda filename: os.path.splitext(filename)[0])
+
+ logger.debug(f"Path for inference: \"{path}\"")
+
+ if image_paths is None:
+ logger.error('Invalid dir or image path!')
+ raise FileNotFoundError
-def image_process(image: Image.Image, target_size: int) -> Image.Image:
+ logger.info(f'Found {len(image_paths)} image(s).')
+ return image_paths
+
+
+def image_process(image: Image.Image, target_size: int) -> numpy.ndarray:
# make alpha to white
image = image.convert('RGBA')
new_image = Image.new('RGBA', image.size, 'WHITE')
@@ -53,9 +83,24 @@ def image_process(image: Image.Image, target_size: int) -> Image.Image:
interpolation=cv2.INTER_LANCZOS4
)
+ return padded_image
+
+
+def image_process_image(
+ padded_image: numpy.ndarray
+) -> Image.Image:
return Image.fromarray(padded_image)
+def image_process_gbr(
+ padded_image: numpy.ndarray
+) -> numpy.ndarray:
+    # Convert PIL RGB to OpenCV BGR channel order
+ padded_image = padded_image[:, :, ::-1]
+ padded_image = padded_image.astype(numpy.float32)
+ return padded_image
+
+
def encode_image_to_base64(image: Image.Image):
with BytesIO() as bytes_output:
image.save(bytes_output, format="PNG")
diff --git a/utils/joy.py b/utils/joy.py
index d4dfb8f..5d5b06f 100644
--- a/utils/joy.py
+++ b/utils/joy.py
@@ -1,4 +1,3 @@
-import glob
import os
import time
from argparse import Namespace
@@ -9,13 +8,12 @@
from PIL import Image
from torch import nn
from tqdm import tqdm
-from transformers import (AutoModel, AutoProcessor, AutoTokenizer,
- PreTrainedTokenizer, PreTrainedTokenizerFast, AutoModelForCausalLM)
+from transformers import (AutoModel, AutoProcessor, AutoTokenizer, AutoModelForCausalLM,
+ BitsAndBytesConfig, PreTrainedTokenizer, PreTrainedTokenizerFast)
-from utils.image import image_process
+from utils.image import image_process, get_image_paths, image_process_image
from utils.logger import Logger
-SUPPORT_IMAGE_FORMATS = ("bmp", "jpg", "jpeg", "png")
class ImageAdapter(nn.Module):
def __init__(self, input_features: int, output_features: int):
@@ -30,6 +28,7 @@ def forward(self, vision_outputs: torch.Tensor):
x = self.linear2(x)
return x
+
class Joy:
def __init__(
self,
@@ -62,7 +61,7 @@ def load_model(self):
self.clip_model = self.clip_model.vision_model
self.clip_model.eval()
self.clip_model.requires_grad_(False)
- self.clip_model.to("cuda")
+ self.clip_model.to("cuda" if self.use_gpu else "cpu")
self.logger.info(f'CLIP Loaded in {time.monotonic() - start_time:.1f}s.')
# Load LLM
@@ -72,7 +71,27 @@ def load_model(self):
assert (isinstance(self.llm_tokenizer, PreTrainedTokenizer) or
isinstance(self.llm_tokenizer, PreTrainedTokenizerFast)), \
f"Tokenizer is of type {type(self.llm_tokenizer)}"
- self.llm = AutoModelForCausalLM.from_pretrained(self.llm_path, device_map="auto", torch_dtype=torch.bfloat16)
+ # LLM dType
+ llm_dtype = torch.float32 if self.args.llm_use_cpu or self.args.llm_dtype == "fp32" else torch.float16 \
+ if self.args.llm_dtype == "fp16" else torch.bfloat16 if self.args.llm_dtype == "bf16" else "auto"
+ self.logger.info(f'LLM dtype: {llm_dtype}')
+ # LLM BNB quantization config
+ if self.args.llm_qnt == "4bit":
+ qnt_config = BitsAndBytesConfig(load_in_4bit=True,
+ bnb_4bit_quant_type="nf4",
+ bnb_4bit_compute_dtype=llm_dtype,
+ bnb_4bit_use_double_quant=True)
+ self.logger.info(f'LLM 4bit quantization: Enabled')
+ elif self.args.llm_qnt == "8bit":
+ qnt_config = BitsAndBytesConfig(load_in_8bit=True,
+ llm_int8_enable_fp32_cpu_offload=True)
+ self.logger.info(f'LLM 8bit quantization: Enabled')
+ else:
+ qnt_config = None
+ self.llm = AutoModelForCausalLM.from_pretrained(self.llm_path,
+ device_map="auto" if not self.args.llm_use_cpu else "cpu",
+ torch_dtype=llm_dtype if self.args.llm_qnt == "none" else None,
+ quantization_config=qnt_config)
self.llm.eval()
self.logger.info(f'LLM Loaded in {time.monotonic() - start_time:.1f}s.')
@@ -81,28 +100,15 @@ def load_model(self):
self.image_adapter = ImageAdapter(self.clip_model.config.hidden_size, self.llm.config.hidden_size)
self.image_adapter.load_state_dict(torch.load(self.image_adapter_path, map_location="cpu"))
self.image_adapter.eval()
- self.image_adapter.to("cuda")
+ self.image_adapter.to(self.llm.device)
self.logger.info(f'Image Adapter Loaded in {time.monotonic() - start_time:.1f}s.')
def inference(self):
# Get image paths
- path_to_find = os.path.join(self.args.data_path, '**') \
- if self.args.recursive else os.path.join(self.args.data_path, '*')
- image_paths = sorted(set(
- [image for image in glob.glob(path_to_find, recursive=self.args.recursive)
- if image.lower().endswith(SUPPORT_IMAGE_FORMATS)]),
- key=lambda filename: (os.path.splitext(filename)[0])
- ) if not os.path.isfile(self.args.data_path) else [str(self.args.data_path)] \
- if str(self.args.data_path).lower().endswith(SUPPORT_IMAGE_FORMATS) else None
-
- if image_paths is None:
- self.logger.error('Invalid dir or image path!')
- raise FileNotFoundError
-
- self.logger.info(f'Found {len(image_paths)} image(s).')
+ image_paths = get_image_paths(logger=self.logger, path=Path(self.args.data_path), recursive=self.args.recursive)
def get_caption(
- image: Image,
+ image: Image.Image,
user_prompt: str,
temperature: float = 0.5,
max_new_tokens: int = 300,
@@ -111,7 +117,7 @@ def get_caption(
torch.cuda.empty_cache()
# Preprocess image
image = self.clip_processor(images=image, return_tensors='pt').pixel_values
- image = image.to('cuda')
+ image = image.to(self.llm.device)
# Tokenize the prompt
prompt = self.llm_tokenizer.encode(user_prompt,
return_tensors='pt',
@@ -119,13 +125,13 @@ def get_caption(
truncation=False,
add_special_tokens=False)
# Embed image
- with torch.amp.autocast_mode.autocast('cuda', enabled=True):
+ with torch.amp.autocast_mode.autocast("cuda" if self.use_gpu else "cpu", enabled=True):
vision_outputs = self.clip_model(pixel_values=image, output_hidden_states=True)
image_features = vision_outputs.hidden_states[-2]
embedded_images = self.image_adapter(image_features)
- embedded_images = embedded_images.to('cuda')
+ embedded_images = embedded_images.to(self.llm.device)
# Embed prompt
- prompt_embeds = self.llm.model.embed_tokens(prompt.to('cuda'))
+ prompt_embeds = self.llm.model.embed_tokens(prompt.to(self.llm.device))
assert prompt_embeds.shape == (1, prompt.shape[1],
self.llm.config.hidden_size), \
f"Prompt shape is {prompt_embeds.shape}, expected {(1, prompt.shape[1], self.llm.config.hidden_size)}"
@@ -143,7 +149,7 @@ def get_caption(
torch.tensor([[self.llm_tokenizer.bos_token_id]], dtype=torch.long),
torch.zeros((1, embedded_images.shape[1]), dtype=torch.long),
prompt,
- ], dim=1).to('cuda')
+ ], dim=1).to(self.llm.device)
attention_mask = torch.ones_like(input_ids)
# Generate caption
generate_ids = self.llm.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask,
@@ -167,6 +173,7 @@ def get_caption(
image_path[:15]) + ' ... ' + image_path[-20:])
image = Image.open(image_path)
image = image_process(image, int(self.args.image_size))
+ image = image_process_image(image)
caption = get_caption(
image=image,
user_prompt=self.args.user_prompt,