[LLaVa] Follow-up for TODOs in LLaVa model (#2010)
LLaVa:
1. Added base64 image support.
2. Merged as_prompt and as_prompt_list into a single as_prompt.
3. get_image_from_url now reads the image input and embed sizes from the model config.
anibohara2000 authored Mar 27, 2024
1 parent 0a23af5 commit 47c8350
Showing 5 changed files with 94 additions and 93 deletions.
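Before the per-file diffs, here is a minimal sketch (all values hypothetical, not part of this commit) of the message shapes the merged as_prompt() below accepts: plain string content, plus OpenAI-style content lists whose image_url entries may be an http(s) URL, a base64 data URL, or a dict with a "url" field.

messages = [
    ("user", "Describe the picture."),
    ("user", [
        {"type": "text", "text": "What is in this image?"},
        {"type": "image_url", "image_url": "https://example.com/cat.png"},
        {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBORw0KGgo="}},
    ]),
]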
2 changes: 2 additions & 0 deletions python/mlc_llm/conversation_template.py
@@ -291,6 +291,8 @@ def get_conv_template(name: str) -> Optional[Conversation]:
role_empty_sep=":",
stop_str=["</s>"],
stop_token_ids=[2],
+ system_prefix_token_ids=[1],
+ add_role_after_system_message=False,
)
)

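The new system_prefix_token_ids=[1] entry takes effect in openai_entrypoints.py further down, where the ids are prepended to the first encoded prompt segment. A tiny sketch, assuming id 1 is the tokenizer's BOS token and using made-up token ids:

system_prefix_token_ids = [1]
prompts = [[319, 13563, 1546]]  # hypothetical tokenizer output for the first segment
prompts[0] = system_prefix_token_ids + prompts[0]
print(prompts[0])  # [1, 319, 13563, 1546]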
130 changes: 56 additions & 74 deletions python/mlc_llm/protocol/conversation_protocol.py
@@ -61,9 +61,7 @@ class Conversation(BaseModel):
# The conversation history messages.
# Each message is a pair of strings, denoting "(role, content)".
# The content can be None.
- messages: List[Tuple[str, Optional[Union[str, List[Dict[str, str]]]]]] = Field(
- default_factory=lambda: []
- )
+ messages: List[Tuple[str, Optional[Union[str, List[Dict]]]]] = Field(default_factory=lambda: [])

# The separators between messages when concatenating into a single prompt.
# List size should be either 1 or 2.
@@ -114,7 +112,8 @@ def from_json_dict(cls: Type[T], json_dict: Dict[str, Any]) -> T:
"""Convert from a json dictionary"""
return Conversation.model_validate(json_dict)

- def as_prompt(self) -> str:
+ # pylint: disable=too-many-branches
+ def as_prompt(self, config=None) -> List[Union[str, data.ImageData]]:
"""Convert the conversation template and history messages to
a single prompt.
"""
@@ -124,91 +123,38 @@ def as_prompt(self) -> str:
)

# - Get the message strings.
- message_list: List[str] = []
+ message_list: List[Union[str, data.ImageData]] = []
separators = list(self.seps)
if len(separators) == 1:
separators.append(separators[0])

if system_msg != "":
system_msg += separators[0]
message_list.append(system_msg)

for i, (role, content) in enumerate(self.messages): # pylint: disable=not-an-iterable
if role not in self.roles.keys():
raise ValueError(f'Role "{role}" is not a supported role in {self.roles.keys()}')
separator = separators[role == "assistant"] # check assistant role
if content is not None:
- assert isinstance(content, str)
role_prefix = (
""
# Do not append role prefix if this is the first message and there
# is already a system message
if (not self.add_role_after_system_message and system_msg != "" and i == 0)
else self.roles[role] + self.role_content_sep
)
- message_string = (
- role_prefix
- + self.role_templates[role].replace(
- MessagePlaceholders[role.upper()].value, content
- )
- + separator
- )
- else:
- message_string = self.roles[role] + self.role_empty_sep
- message_list.append(message_string)
-
- if system_msg != "":
- system_msg += separators[0]
-
- prompt = system_msg + "".join(message_list)
-
- # Replace the last function string placeholder with actual function string
- prompt = self.function_string.join(prompt.rsplit(MessagePlaceholders.FUNCTION.value, 1))
- # Replace with remaining function string placeholders with empty string
- prompt = prompt.replace(MessagePlaceholders.FUNCTION.value, "")
-
- return prompt
-
- def as_prompt_list(self, image_embed_size=None) -> List[Union[str, data.ImageData]]:
- """Convert the conversation template and history messages to
- a list of prompts.
- Returns:
- List[Union[str, data.ImageData]]: The list of prompts.
- """
- # TODO: Unify this function with as_prompt() # pylint: disable=fixme
-
- # pylint: disable=import-outside-toplevel
- from ..serve.entrypoints.entrypoint_utils import get_image_from_url
-
- # - Get the system message.
- system_msg = self.system_template.replace(
- MessagePlaceholders.SYSTEM.value, self.system_message
- )
-
- # - Get the message strings.
- message_list: List[Union[str, data.ImageData]] = []
- separators = list(self.seps)
- if len(separators) == 1:
- separators.append(separators[0])
- if system_msg != "":
- system_msg += separators[0]
- message_list.append(system_msg)
- for role, content in self.messages: # pylint: disable=not-an-iterable
- if role not in self.roles.keys():
- raise ValueError(f'Role "{role}" is not a supported role in {self.roles.keys()}')
- separator = separators[role == "assistant"] # check assistant role
- if content is not None:
if isinstance(content, str):
message_string = (
- self.roles[role]
- + self.role_content_sep
+ role_prefix
+ self.role_templates[role].replace(
MessagePlaceholders[role.upper()].value, content
)
+ separator
)
message_list.append(message_string)
else:
assert isinstance(
content, list
), "Content should be a string or a list of dicts"
- message_list.append(self.roles[role] + self.role_content_sep)
+ message_list.append(role_prefix)
- for item in content:
- assert isinstance(
- item, dict
@@ -221,23 +167,59 @@ def as_prompt_list(self, image_embed_size=None) -> List[Union[str, data.ImageData]]:
)
)
elif item["type"] == "image_url":
- assert image_embed_size is not None, "Image embed size is required"
- message_list.append(
- data.ImageData(
- image=get_image_from_url(item["image_url"]),
- embed_size=image_embed_size,
- )
+ assert config is not None, "Model config is required"
+
+ # pylint: disable=import-outside-toplevel
+ from ..serve.entrypoints.entrypoint_utils import (
+ get_image_from_url,
)

+ image_url = _get_url_from_item(item)
+ message_list.append(get_image_from_url(image_url, config))
else:
raise ValueError(f"Unsupported content type: {item['type']}")
- message_list.append(separator)

+ message_list.append(separator)
else:
message_string = self.roles[role] + self.role_empty_sep
message_list.append(message_string)

- prompt = message_list
+ prompt = _combine_consecutive_strings(message_list)

## TODO: Support function calling # pylint: disable=fixme
+ if not any(isinstance(item, data.ImageData) for item in message_list):
+ # Replace the last function string placeholder with actual function string
+ prompt[0] = self.function_string.join(
+ prompt[0].rsplit(MessagePlaceholders.FUNCTION.value, 1)
+ )
+ # Replace with remaining function string placeholders with empty string
+ prompt[0] = prompt[0].replace(MessagePlaceholders.FUNCTION.value, "")

return prompt


+ def _get_url_from_item(item: Dict) -> str:
+ image_url: str
+ assert "image_url" in item, "Content item should have an image_url field"
+ if isinstance(item["image_url"], str):
+ image_url = item["image_url"]
+ elif isinstance(item["image_url"], dict):
+ assert (
+ "url" in item["image_url"]
+ ), "Content image_url item should be a string or a dict with a url field"  # pylint: disable=line-too-long
+ image_url = item["image_url"]["url"]
+ else:
+ raise ValueError(
+ "Content image_url item type not supported. "
+ "Should be a string or a dict with a url field."
+ )
+ return image_url
+
+
+ def _combine_consecutive_strings(lst):
+ result = []
+ for item in lst:
+ if isinstance(item, str) and result and isinstance(result[-1], str):
+ result[-1] += item
+ else:
+ result.append(item)
+ return result
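A self-contained illustration of the helper above (body copied from this diff): consecutive strings are merged, so the returned prompt list alternates between text chunks and non-string entries such as data.ImageData. The stand-in object is hypothetical.

def _combine_consecutive_strings(lst):
    result = []
    for item in lst:
        if isinstance(item, str) and result and isinstance(result[-1], str):
            result[-1] += item
        else:
            result.append(item)
    return result

image = object()  # stand-in for a data.ImageData entry
parts = ["USER", ": ", image, " What is this?", " ", "ASSISTANT", ":"]
print(_combine_consecutive_strings(parts))
# ['USER: ', <object object at ...>, ' What is this? ASSISTANT:']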
2 changes: 1 addition & 1 deletion python/mlc_llm/protocol/openai_api_protocol.py
@@ -184,7 +184,7 @@ class ChatToolCall(BaseModel):


class ChatCompletionMessage(BaseModel):
- content: Optional[Union[str, List[Dict[str, str]]]] = None
+ content: Optional[Union[str, List[Dict]]] = None
role: Literal["system", "user", "assistant", "tool"]
name: Optional[str] = None
tool_calls: Optional[List[ChatToolCall]] = None
31 changes: 26 additions & 5 deletions python/mlc_llm/serve/entrypoints/entrypoint_utils.py
@@ -98,27 +98,42 @@ def process_prompts(
return output_prompts


- def get_image_from_url(url: str):
+ def get_image_from_url(url: str, config: Dict) -> data.ImageData:
"""Get the image from the given URL, process and return the image tensor as TVM NDArray."""

# pylint: disable=import-outside-toplevel, import-error
+ import base64
+
import requests
import tvm
from PIL import Image
from transformers import CLIPImageProcessor

- response = requests.get(url, timeout=5)
- image_tensor = Image.open(BytesIO(response.content)).convert("RGB")
if url.startswith("data:image"):
# The image is encoded in base64 format
base64_image = url.split(",")[1]
image_data = base64.b64decode(base64_image)
image_tensor = Image.open(BytesIO(image_data)).convert("RGB")
elif url.startswith("http"):
response = requests.get(url, timeout=5)
image_tensor = Image.open(BytesIO(response.content)).convert("RGB")
else:
raise ValueError(f"Unsupported image URL format: {url}")

+ image_input_size = get_image_input_size(config)
+ image_embed_size = get_image_embed_size(config)

image_processor = CLIPImageProcessor(
size={"shortest_edge": 336}, crop_size={"height": 336, "width": 336}
size={"shortest_edge": image_input_size},
crop_size={"height": image_input_size, "width": image_input_size},
)
image_features = tvm.nd.array(
image_processor.preprocess(image_tensor, return_tensors="np")["pixel_values"].astype(
"float16"
)
)
- return image_features
+ image_data = data.ImageData(image_features, image_embed_size)
+ return image_data


def get_image_embed_size(config: Dict) -> int:
Expand All @@ -127,3 +142,9 @@ def get_image_embed_size(config: Dict) -> int:
patch_size = config["model_config"]["vision_config"]["patch_size"]
embed_size = (image_size // patch_size) ** 2
return embed_size


+ def get_image_input_size(config: Dict) -> int:
+ """Get the image input size from the model config file."""
+ image_size = config["model_config"]["vision_config"]["image_size"]
+ return image_size
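A worked example of the two config helpers above, using a hypothetical LLaVa-style config (336-pixel CLIP input, 14-pixel patches): the per-image embed size is (336 // 14) ** 2 = 576 positions.

config = {"model_config": {"vision_config": {"image_size": 336, "patch_size": 14}}}

image_size = config["model_config"]["vision_config"]["image_size"]  # 336, fed to CLIPImageProcessor
patch_size = config["model_config"]["vision_config"]["patch_size"]  # 14
print((image_size // patch_size) ** 2)  # 576, matches get_image_embed_size(config)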
22 changes: 9 additions & 13 deletions python/mlc_llm/serve/entrypoints/openai_entrypoints.py
@@ -389,7 +389,6 @@ async def request_chat_completion(
if error_msg is not None:
return entrypoint_utils.create_error_response(HTTPStatus.BAD_REQUEST, message=error_msg)

- content_has_list = any(isinstance(message.content, list) for message in request.messages)
for message in request.messages:
role = message.role
content = message.content
@@ -406,17 +405,12 @@
# - Check prompt length
async_engine.record_event(request_id, event="start tokenization")

- if content_has_list:
- model_config = ServerContext.get_model_config(request.model)
- image_embed_size = entrypoint_utils.get_image_embed_size(model_config)
- prompts = entrypoint_utils.process_prompts(
- conv_template.as_prompt_list(image_embed_size=image_embed_size),
- async_engine.tokenizer.encode,
- )
- else:
- prompts = entrypoint_utils.process_prompts(
- conv_template.as_prompt(), async_engine.tokenizer.encode
- )
+ model_config = ServerContext.get_model_config(request.model)
+ prompts = entrypoint_utils.process_prompts(
+ conv_template.as_prompt(model_config),
+ async_engine.tokenizer.encode,
+ )

async_engine.record_event(request_id, event="finish tokenization")
if conv_template.system_prefix_token_ids is not None:
prompts[0] = conv_template.system_prefix_token_ids + prompts[0]
@@ -581,5 +575,7 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
],
model=request.model,
system_fingerprint="",
- usage=UsageInfo(prompt_tokens=len(prompt), completion_tokens=num_completion_tokens),
+ usage=UsageInfo(
+ prompt_tokens=sum(len(item) for item in prompt), completion_tokens=num_completion_tokens
+ ),
)
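Client-side, the changes above let a chat completion request mix text and image_url content items, including base64 data URLs. A sketch assuming a locally running server; the port, file name, and model id are placeholders.

import base64

import requests

with open("cat.png", "rb") as f:
    b64 = base64.b64encode(f.read()).decode()

payload = {
    "model": "llava",
    "messages": [{
        "role": "user",
        "content": [
            {"type": "text", "text": "What is shown in this image?"},
            {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{b64}"}},
        ],
    }],
}
response = requests.post("http://127.0.0.1:8000/v1/chat/completions", json=payload, timeout=60)
print(response.json())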
