release v0.7.0
hiyouga committed Apr 26, 2024
1 parent 031775a commit 168f566
Showing 13 changed files with 163 additions and 44 deletions.
2 changes: 1 addition & 1 deletion data/dataset_info.json
@@ -60,7 +60,7 @@
},
"mllm_demo": {
"file_name": "mllm_demo.json",
"file_sha1": "b6709b23657d5c42a701f1c5574f3a6edaa40a20",
"file_sha1": "d626cc0ad88a26d0dc9fcb47336821cf486d8bcc",
"formatting": "sharegpt",
"columns": {
"messages": "messages",
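The new checksum tracks the rewrite of data/mllm_demo.json below. As a minimal sketch (assuming the repository root as the working directory, and that file_sha1 is the SHA-1 of the raw dataset file, as the key name suggests), the new value can be reproduced with:

```python
# Recompute the dataset checksum; assumes file_sha1 is the SHA-1 digest
# of the raw bytes of data/mllm_demo.json.
import hashlib

with open("data/mllm_demo.json", "rb") as f:
    print(hashlib.sha1(f.read()).hexdigest())  # expected: d626cc0ad88a26d0dc9fcb47336821cf486d8bcc
```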
83 changes: 76 additions & 7 deletions data/mllm_demo.json
@@ -2,7 +2,7 @@
{
"messages": [
{
"content": "Who are they?<image>",
"content": "Who are they?",
"role": "user"
},
{
@@ -14,18 +14,18 @@
"role": "user"
},
{
"content": "They are celebrating on the soccer field",
"content": "They are celebrating on the soccer field.",
"role": "assistant"
}
],
"images": [
"images/1.jpg"
"mllm_demo_data/1.jpg"
]
},
{
"messages": [
{
"content": "Who is he?<image>",
"content": "Who is he?",
"role": "user"
},
{
@@ -42,13 +42,13 @@
}
],
"images": [
"images/2.jpg"
"mllm_demo_data/2.jpg"
]
},
{
"messages": [
{
"content": "Please describe this image<image>",
"content": "Please describe this image",
"role": "user"
},
{
@@ -65,7 +65,76 @@
}
],
"images": [
"images/3.jpg"
"mllm_demo_data/3.jpg"
]
},
+ {
+ "messages": [
+ {
+ "content": "他们是谁?",
+ "role": "user"
+ },
+ {
+ "content": "他们是拜仁慕尼黑的凯恩和格雷茨卡。",
+ "role": "assistant"
+ },
+ {
+ "content": "他们在做什么?",
+ "role": "user"
+ },
+ {
+ "content": "他们在足球场上庆祝。",
+ "role": "assistant"
+ }
+ ],
+ "images": [
+ "mllm_demo_data/1.jpg"
+ ]
+ },
+ {
+ "messages": [
+ {
+ "content": "他是谁?",
+ "role": "user"
+ },
+ {
+ "content": "他是来自拜仁慕尼黑的托马斯·穆勒。",
+ "role": "assistant"
+ },
+ {
+ "content": "他为什么在地上?",
+ "role": "user"
+ },
+ {
+ "content": "因为他正在双膝跪地滑行庆祝。",
+ "role": "assistant"
+ }
+ ],
+ "images": [
+ "mllm_demo_data/2.jpg"
+ ]
+ },
+ {
+ "messages": [
+ {
+ "content": "请描述这张图片",
+ "role": "user"
+ },
+ {
+ "content": "中国宇航员桂海潮正在讲话。",
+ "role": "assistant"
+ },
+ {
+ "content": "他取得过哪些成就?",
+ "role": "user"
+ },
+ {
+ "content": "他于2022年6月被任命为神舟十六号任务的有效载荷专家,从而成为2023年5月30日进入太空的首位平民宇航员。他负责在轨操作空间科学实验有效载荷。",
+ "role": "assistant"
+ }
+ ],
+ "images": [
+ "mllm_demo_data/3.jpg"
+ ]
+ }
]
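The rewritten demo keeps the three English conversations, strips the inline <image> placeholder from user turns (the chat engines now insert it themselves; see hf_engine.py and vllm_engine.py below), repoints the image paths to mllm_demo_data/, and adds three Chinese conversations mirroring the English ones ("他们是谁?" is "Who are they?", "他是谁?" is "Who is he?", "请描述这张图片" is "Please describe this image", with matching answers). A quick sanity check of the new conventions, assuming the repository root as the working directory:

```python
# Verify the two format changes: no inline "<image>" token in the first
# user turn, and image paths under the renamed mllm_demo_data/ directory.
import json

with open("data/mllm_demo.json", encoding="utf-8") as f:
    examples = json.load(f)

assert len(examples) == 6  # three English + three Chinese conversations
for example in examples:
    assert "<image>" not in example["messages"][0]["content"]
    assert example["images"][0].startswith("mllm_demo_data/")
```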
data/images/1.jpg → data/mllm_demo_data/1.jpg (file renamed without changes)
data/images/2.jpg → data/mllm_demo_data/2.jpg (file renamed without changes)
data/images/3.jpg → data/mllm_demo_data/3.jpg (file renamed without changes)
2 changes: 1 addition & 1 deletion setup.py
@@ -25,7 +25,7 @@ def get_requires():
"unsloth": ["torch==2.2.0", "unsloth[cu121-ampere-torch220]"],
"galore": ["galore-torch"],
"badam": ["badam"],
"vllm": ["vllm>=0.3.3"],
"vllm": ["vllm>=0.4.0"],
"bitsandbytes": ["bitsandbytes>=0.39.0"],
"gptq": ["optimum>=1.16.0", "auto-gptq>=0.5.0"],
"awq": ["autoawq"],
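The bump from vllm>=0.3.3 to vllm>=0.4.0 underpins the multimodal changes below: vLLM's LLaVA support (MultiModalData and the image_* engine arguments, introduced via vLLM PR #3042, linked in vllm_engine.py) first shipped in the 0.4 line, so installing the extra with pip install -e ".[vllm]" must resolve at least that release.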
2 changes: 1 addition & 1 deletion src/llmtuner/__init__.py
@@ -7,5 +7,5 @@
from .webui import create_ui, create_web_demo


__version__ = "0.6.4.dev0"
__version__ = "0.7.0"
__all__ = ["create_app", "ChatModel", "Evaluator", "export_model", "run_exp", "create_ui", "create_web_demo"]
2 changes: 1 addition & 1 deletion src/llmtuner/chat/hf_engine.py
@@ -56,7 +56,7 @@ def _process_args(
input_kwargs: Optional[Dict[str, Any]] = {},
) -> Tuple[Dict[str, Any], int]:
if processor is not None and image is not None and "<image>" not in messages[0]["content"]:
messages[0]["content"] = messages[0]["content"] + "<image>"
messages[0]["content"] = "<image>" + messages[0]["content"]

paired_messages = messages + [{"role": "assistant", "content": ""}]
prompt_ids, _ = template.encode_oneturn(
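The only change here is placement: the <image> placeholder now leads the first user message instead of trailing it, matching the training-side convention added in preprocess.py below. A standalone sketch of the fixed behavior, with hypothetical stand-ins for the loaded processor and input image:

```python
# Sketch of the placeholder fix; `processor` and `image` are hypothetical
# stand-ins for the multimodal processor and the user-supplied image.
processor, image = object(), object()
messages = [{"role": "user", "content": "Who are they?"}]

if processor is not None and image is not None and "<image>" not in messages[0]["content"]:
    messages[0]["content"] = "<image>" + messages[0]["content"]  # prepend, not append

assert messages[0]["content"] == "<image>Who are they?"
```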
53 changes: 39 additions & 14 deletions src/llmtuner/chat/vllm_engine.py
@@ -11,10 +11,13 @@
if is_vllm_available():
from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
from vllm.lora.request import LoRARequest
+ from vllm.sequence import MultiModalData


if TYPE_CHECKING:
import torch
+ from numpy.typing import NDArray
+ from transformers.image_processing_utils import BaseImageProcessor

from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments

@@ -39,20 +39,30 @@ def __init__(
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
self.generating_args = generating_args.to_dict()

- engine_args = AsyncEngineArgs(
-     model=model_args.model_name_or_path,
-     trust_remote_code=True,
-     download_dir=model_args.cache_dir,
-     dtype=infer_dtype,
-     max_model_len=model_args.vllm_maxlen,
-     tensor_parallel_size=get_device_count() or 1,
-     gpu_memory_utilization=model_args.vllm_gpu_util,
-     disable_log_stats=True,
-     disable_log_requests=True,
-     enforce_eager=model_args.vllm_enforce_eager,
-     enable_lora=model_args.adapter_name_or_path is not None,
- )
- self.model = AsyncLLMEngine.from_engine_args(engine_args)
+ engine_args = {
+     "model": model_args.model_name_or_path,
+     "trust_remote_code": True,
+     "download_dir": model_args.cache_dir,
+     "dtype": infer_dtype,
+     "max_model_len": model_args.vllm_maxlen,
+     "tensor_parallel_size": get_device_count() or 1,
+     "gpu_memory_utilization": model_args.vllm_gpu_util,
+     "disable_log_stats": True,
+     "disable_log_requests": True,
+     "enforce_eager": model_args.vllm_enforce_eager,
+     "enable_lora": model_args.adapter_name_or_path is not None,
+ }
+
+ if model_args.visual_inputs:
+     # TODO: auto derive from config
+     # https://github.com/vllm-project/vllm/pull/3042#issuecomment-1984893549
+     self.image_feature_size = 576
+     engine_args["image_input_type"] = "pixel_values"
+     engine_args["image_token_id"] = self.tokenizer.convert_tokens_to_ids("<image>")
+     engine_args["image_input_shape"] = "1,3,336,336"
+     engine_args["image_feature_size"] = self.image_feature_size
+
+ self.model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args))
if model_args.adapter_name_or_path is not None:
self.lora_request = LoRARequest("default", 1, model_args.adapter_name_or_path[0])
else:
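The hard-coded values match LLaVA-1.5's vision tower, a CLIP ViT-L/14 running at 336x336 resolution: a 336-pixel image cut into 14-pixel patches yields 24 x 24 = 576 visual tokens, hence image_feature_size = 576 and image_input_shape = "1,3,336,336". The TODO is about deriving these from the model config instead of hard-coding them. A worked check of the arithmetic:

```python
# Worked check of the constants above (LLaVA-1.5 vision tower geometry).
image_size = 336  # matches image_input_shape "1,3,336,336"
patch_size = 14   # CLIP ViT-L/14
assert (image_size // patch_size) ** 2 == 576  # image_feature_size
```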
@@ -67,6 +80,9 @@ async def _generate(
**input_kwargs,
) -> AsyncIterator["RequestOutput"]:
request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
+ if self.processor is not None and image is not None and "<image>" not in messages[0]["content"]:
+     messages[0]["content"] = "<image>" * self.image_feature_size + messages[0]["content"]

paired_messages = messages + [{"role": "assistant", "content": ""}]
prompt_ids, _ = self.template.encode_oneturn(
tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools
@@ -110,12 +126,21 @@ async def _generate(
max_tokens=generating_args["max_new_tokens"],
skip_special_tokens=True,
)

+ if self.processor is not None and image is not None:
+     image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
+     pixel_values: "torch.Tensor" = image_processor(image, return_tensors="pt")["pixel_values"]
+     multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
+ else:
+     multi_modal_data = None

result_generator = self.model.generate(
prompt=None,
sampling_params=sampling_params,
request_id=request_id,
prompt_token_ids=prompt_ids,
lora_request=self.lora_request,
+ multi_modal_data=multi_modal_data,
)
return result_generator

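Together, the two _generate additions follow vLLM 0.4.x's LLaVA request protocol: the prompt carries one <image> placeholder per visual feature (hence "<image>" * self.image_feature_size), and the pixel tensor rides along as MultiModalData. A hedged end-to-end sketch, assuming vLLM 0.4.x (later releases reworked this API) and llava-hf/llava-1.5-7b-hf as a stand-in checkpoint:

```python
# Build the multimodal request payload; assumes vLLM 0.4.x and a local image.
from PIL import Image
from transformers import AutoProcessor
from vllm.sequence import MultiModalData

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")  # assumed checkpoint
pixel_values = processor.image_processor(
    Image.open("data/mllm_demo_data/1.jpg"), return_tensors="pt"
)["pixel_values"]

prompt = "<image>" * 576 + "Who are they?"  # one placeholder per visual token
multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
```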
50 changes: 37 additions & 13 deletions src/llmtuner/data/preprocess.py
@@ -1,14 +1,20 @@
from functools import partial
from itertools import chain
- from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple
+ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple

from ..extras.constants import IGNORE_INDEX
from ..extras.logging import get_logger
+ from ..extras.packages import is_pillow_available
from .utils import Role


+ if is_pillow_available():
+     from PIL import Image


if TYPE_CHECKING:
- from PIL.Image import Image
+ from numpy.typing import NDArray
+ from PIL.Image import Image as ImageObject
from transformers import ProcessorMixin, Seq2SeqTrainingArguments
from transformers.image_processing_utils import BaseImageProcessor
from transformers.tokenization_utils import PreTrainedTokenizer
@@ -20,12 +26,11 @@
logger = get_logger(__name__)


- def _preprocess_visual_inputs(model_inputs: Dict[str, Any], processor: "ProcessorMixin", image: "Image") -> None:
+ def _preprocess_visual_inputs(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "NDArray":
+     # process visual inputs (currently only supports a single image)
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
- pixel_values = image_processor(image, return_tensors="pt")["pixel_values"][0]
- if "pixel_values" not in model_inputs:
-     model_inputs["pixel_values"] = []
- model_inputs["pixel_values"].append(pixel_values)
+ image = images[0] if len(images) != 0 else Image.new("RGB", (100, 100), (255, 255, 255))
+ return image_processor(image, return_tensors="pt")["pixel_values"][0]


def preprocess_pretrain_dataset(
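The refactored helper takes the example's whole images column and returns the pixel tensor instead of appending to model_inputs in place; an example that ships no image falls back to a blank 100x100 white canvas, so every row still yields a pixel_values entry. A minimal sketch of that fallback, assuming Pillow is installed:

```python
# Text-only examples get a plain white placeholder image so the
# pixel_values column stays populated for every row.
from PIL import Image

images = []  # a text-only example
image = images[0] if len(images) != 0 else Image.new("RGB", (100, 100), (255, 255, 255))
print(image.size, image.mode)  # (100, 100) RGB
```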
@@ -66,11 +71,17 @@ def preprocess_supervised_dataset(
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
+ if processor is not None:
+     model_inputs["pixel_values"] = []
+     preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)

for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
continue

+ if processor is not None:
+     examples["prompt"][i][0]["content"] = "<image>" + examples["prompt"][i][0]["content"]

messages = examples["prompt"][i] + examples["response"][i]
input_ids, labels = [], []
for turn_idx, (source_ids, target_ids) in enumerate(
@@ -100,8 +111,8 @@ def preprocess_supervised_dataset(
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
- if processor is not None and "images" in examples:
-     _preprocess_visual_inputs(model_inputs, processor, examples["images"][i][0])
+ if processor is not None:
+     model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))

return model_inputs
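The same three-part pattern repeats in the unsupervised and pairwise preprocessors below: pre-allocate model_inputs["pixel_values"] and bind the processor with functools.partial, prepend <image> to the first user turn of each example, and append one pixel tensor per kept row so the visual batch stays aligned with the token batch. In miniature, with hypothetical stand-ins:

```python
# Miniature of the recurring layout; encode() and to_pixels() are
# hypothetical stand-ins for template encoding and image processing.
from functools import partial

def preprocess(examples, processor, encode, to_pixels):
    model_inputs = {"input_ids": [], "pixel_values": []}
    if processor is not None:
        visual_inputs = partial(to_pixels, processor=processor)
    for i in range(len(examples["prompt"])):
        if processor is not None:
            # prepend the placeholder once, on the first user turn
            examples["prompt"][i][0]["content"] = "<image>" + examples["prompt"][i][0]["content"]
        model_inputs["input_ids"].append(encode(examples["prompt"][i]))
        if processor is not None:
            # one pixel tensor per kept row keeps the batches aligned
            model_inputs["pixel_values"].append(visual_inputs(examples["images"][i]))
    return model_inputs
```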

@@ -161,11 +172,17 @@ def preprocess_unsupervised_dataset(
) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X` and labels with format `Y <eos>`
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
+ if processor is not None:
+     model_inputs["pixel_values"] = []
+     preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)

for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1:
continue

+ if processor is not None:
+     examples["prompt"][i][0]["content"] = "<image>" + examples["prompt"][i][0]["content"]

if len(examples["response"][i]) == 1:
messages = examples["prompt"][i] + examples["response"][i]
else:
@@ -186,8 +203,8 @@ def preprocess_unsupervised_dataset(
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
- if processor is not None and "images" in examples:
-     _preprocess_visual_inputs(model_inputs, processor, examples["images"][i][0])
+ if processor is not None:
+     model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))

return model_inputs

@@ -201,10 +218,17 @@ def preprocess_pairwise_dataset(
) -> Dict[str, List[List[int]]]:
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
+ if processor is not None:
+     model_inputs["pixel_values"] = []
+     preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)

for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
continue

+ if processor is not None:
+     examples["prompt"][i][0]["content"] = "<image>" + examples["prompt"][i][0]["content"]

chosen_messages = examples["prompt"][i] + [examples["response"][i][0]]
rejected_messages = examples["prompt"][i] + [examples["response"][i][1]]
prompt_ids, chosen_ids = template.encode_oneturn(
@@ -231,8 +255,8 @@
model_inputs["prompt_ids"].append(prompt_ids)
model_inputs["chosen_ids"].append(chosen_ids)
model_inputs["rejected_ids"].append(rejected_ids)
- if processor is not None and "images" in examples:
-     _preprocess_visual_inputs(model_inputs, processor, examples["images"][i][0])
+ if processor is not None:
+     model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))

return model_inputs

4 changes: 4 additions & 0 deletions src/llmtuner/extras/packages.py
@@ -48,6 +48,10 @@ def is_nltk_available():
return _is_package_available("nltk")


+ def is_pillow_available():
+     return _is_package_available("PIL")


def is_requests_available():
return _is_package_available("requests")

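is_pillow_available() mirrors the neighboring helpers so preprocess.py can import PIL lazily, only when multimodal preprocessing actually runs. A sketch of the probe it delegates to, assuming _is_package_available is a find_spec-style check (its implementation is not shown in this diff):

```python
# Assumed shape of the availability probe; the real _is_package_available
# in this module may differ.
import importlib.util

def _is_package_available(name: str) -> bool:
    return importlib.util.find_spec(name) is not None

def is_pillow_available() -> bool:
    return _is_package_available("PIL")
```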
(2 more changed files not shown)
