[Vision] Support image input in DebugChat (#2913)
This PR supports feeding an image URL to DebugChat by generalizing
its tokenization and embedding implementation. Previous pure-text
usage remains the same, and users can pass in `--image-url`
followed by the image URL.

We also add `--disable-instrument` to disable dumping kernel
input/output details, for faster generation when instrumentation is
not needed.
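
For reference, a minimal sketch of driving the extended API from Python; the model path, library path, and image URL below are illustrative placeholders, not values from this PR:

    from pathlib import Path

    from mlc_llm.testing.debug_chat import DebugChat

    # Placeholder paths/URL; substitute a compiled vision model and its library.
    dc = DebugChat(
        model="./dist/model-q4f16_1-MLC",
        model_lib="./dist/libs/model-q4f16_1-cuda.so",
        debug_dir=Path("./debug"),
        device="auto",
        is_image_model=True,  # look up the model's `image_embed` function
        disable_instrument=True,  # skip dumping kernel input/output for speed
    )
    dc.generate(
        "What is in this image?",
        generate_length=128,
        image_url="https://example.com/image.png",
    )

The equivalent command-line flow passes `--image-url <URL>`, and optionally `--disable-instrument`, alongside the existing DebugChat arguments.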
CharlieFRuan authored Sep 18, 2024
1 parent 7b53664 commit 571d380
Showing 1 changed file with 111 additions and 30 deletions.
141 changes: 111 additions & 30 deletions python/mlc_llm/testing/debug_chat.py
@@ -4,7 +4,7 @@
 import json
 import random
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import tvm
@@ -16,7 +16,7 @@
 from mlc_llm.conversation_template import ConvTemplateRegistry
 from mlc_llm.interface.help import HELP
 from mlc_llm.protocol.mlc_chat_config import MLCChatConfig
-from mlc_llm.serve import engine_utils
+from mlc_llm.serve import data, engine_utils
 from mlc_llm.support.argparse import ArgumentParser
 from mlc_llm.support.auto_device import detect_device
 from mlc_llm.support.style import green, red
@@ -41,11 +41,15 @@ def _load_params(
 
 
 def _get_tvm_module(
-    model_weight_path: str, lib_path: str, device: Device, instrument: tvm.runtime.PackedFunc
+    model_weight_path: str,
+    lib_path: str,
+    device: Device,
+    instrument: Union[tvm.runtime.PackedFunc, None],
 ):
     ex = tvm.runtime.load_module(lib_path)
     vm = relax.VirtualMachine(ex, device)
-    vm.set_instrument(instrument)
+    if instrument is not None:
+        vm.set_instrument(instrument)
     metadata = _extract_metadata(ex)
     params = _load_params(model_weight_path, device, metadata)
     return vm.module, params, metadata
@@ -151,6 +155,8 @@ def __init__( # pylint: disable=too-many-arguments
         debug_dir: Path,
         device: Optional[str] = "auto",
         debug_instrument: Optional[Any] = None,
+        is_image_model: Optional[bool] = False,
+        disable_instrument: Optional[bool] = False,
     ):
         """_summary_
 
@@ -200,12 +206,24 @@ def instrument(
             - before_run: whether it is before or after call.
             - ret_value: the return value of the call, only valid after run.
             - args: the arguments being passed to call.
+
+        is_image_model: Optional[bool]
+            Whether the model supports image input. If so, will look for the image embedding method.
+            Defaults to False.
+
+        disable_instrument: Optional[bool]
+            If true, the debug instrument is not used, making generation faster. Defaults to False.
         """
         self.debug_dir = debug_dir
         self.device = detect_device(device)
-        self.instrument = (
-            debug_instrument if debug_instrument else DefaultDebugInstrument(debug_dir / "prefill")
-        )
+        if disable_instrument:
+            self.instrument = None
+        else:
+            self.instrument = (
+                debug_instrument
+                if debug_instrument
+                else DefaultDebugInstrument(debug_dir / "prefill")
+            )
         self.mod, self.params, self.metadata = _get_tvm_module(
             model, model_lib, self.device, self.instrument
         )
@@ -234,6 +252,14 @@ def instrument(
         except AttributeError as exc:
             raise RuntimeError("DebugChat only supports separate embedding layer") from exc
 
+        if is_image_model:
+            try:
+                self.embed_image_func = self.mod["image_embed"]
+            except AttributeError as exc:
+                raise RuntimeError(
+                    "Expect the model to be an image model, but cannot find `image_embed`."
+                ) from exc
+
         self.prefill_func = self.mod["prefill"]
         self.decode_func = self.mod["decode"]
         self.create_kv_cache_func = None
@@ -247,11 +273,25 @@ def instrument(
 
         self.appeared_token_freq: Dict[int, int] = {}
 
-    def _tokenize(self, prompt: str) -> tvm.nd.array:
+    def _preprocess_prompts(
+        self, prompt: str, image_url: Optional[str] = None
+    ) -> List[Union[List[int], data.ImageData]]:
+        print("======================= Starts Tokenization & Embedding =======================")
         # Step 0. Generate prompt string using conversation template
-        self.conversation.messages.append(("user", prompt))
+        if image_url is None:
+            self.conversation.messages.append(("user", prompt))
+        else:
+            self.conversation.messages.append(
+                (
+                    "user",
+                    [
+                        {"type": "image_url", "image_url": image_url},
+                        {"type": "text", "text": prompt},
+                    ],
+                )
+            )
         self.conversation.messages.append(("assistant", None))
 
         with open(self.config_file_path, "r", encoding="utf-8") as file:
             config = json.load(file)
         parsed_prompt = self.conversation.as_prompt(config)
@@ -261,19 +301,41 @@ def _tokenize(self, prompt: str) -> tvm.nd.array:
         )
         tokens = engine_utils.process_prompts(parsed_prompt, self.tokenizer.encode)  # type: ignore
 
-        # TODO: Handle ImageData in DebugChat # pylint: disable=fixme
-        assert len(tokens) == 1, "DebugChat will only handle TextData for now"
         if self.conversation.system_prefix_token_ids is not None:
             tokens[0] = self.conversation.system_prefix_token_ids + tokens[0]
 
-        tokens = tvm.nd.array(np.array(tokens[0]).astype("int32"), device=self.device)
         return tokens
 
-    def _embed(self, tokens: tvm.nd.array) -> Tuple[tvm.nd.NDArray, int]:
-        input_len = tokens.shape[0]
-        embedding = self.embed_func(tokens, self.params)
-        embedding = self.nd_view_func(embedding, ShapeTuple([1, input_len, embedding.shape[1]]))
-        return embedding, input_len
+    def _embed(
+        self, data_inputs: List[Union[List[int], data.ImageData]]
+    ) -> Tuple[tvm.nd.NDArray, int]:
+        # We currently convert to numpy after embedding, concatenate in numpy, then convert
+        # back to a TVM NDArray; this could be optimized, but it suffices for debug purposes.
+        embeddings = []
+        for data_input in data_inputs:
+            if isinstance(data_input, data.ImageData):
+                # Embed image data, moving it to the target device first if needed
+                # print(f"data_input.get_embed_size(): {data_input.embed_size}")
+                image_input = data_input.image
+                if data_input.image.device != self.device:
+                    image_input = data_input.image.copyto(self.device)
+                embeddings.append(self.embed_image_func(image_input, self.params).asnumpy())
+            else:
+                # Embed token data
+                data_input = tvm.nd.array(np.array(data_input).astype("int32"), device=self.device)
+                embeddings.append(self.embed_func(data_input, self.params).asnumpy())
+        # for embedding in embeddings:
+        #     print(f"embedding.shape: {embedding.shape}")
+
+        # Concatenate embeddings along the sequence dimension and add a batch dimension
+        concat_embeddings = tvm.nd.array(np.concatenate(embeddings, axis=0), device=self.device)
+        concat_embeddings = self.nd_view_func(
+            concat_embeddings,
+            ShapeTuple([1, concat_embeddings.shape[0], concat_embeddings.shape[1]]),
+        )
+        input_len = concat_embeddings.shape[1]
+
+        return concat_embeddings, input_len
 
     def _prefill(self, embedding: tvm.nd.NDArray, input_len: int):
         print("======================= Starts Prefill =======================")
@@ -314,9 +376,7 @@ def _prefill(self, embedding: tvm.nd.NDArray, input_len: int):
         return logits, kv_caches
 
     def _decode(self, token: int, kv_caches: Object):
-        embedding, _ = self._embed(
-            tvm.nd.array(np.array([token]).astype("int32"), device=self.device)
-        )
+        embedding, _ = self._embed([[token]])
         self.begin_forward_func(kv_caches, ShapeTuple([0]), ShapeTuple([1]))
         logits, kv_caches = self.decode_func(embedding, kv_caches, self.params)
         self.end_forward_func(kv_caches)
@@ -352,7 +412,8 @@ def _sample_token_from_logits(
             self._apply_presence_and_freq_penalty(logits_np, presence_penalty, frequency_penalty)
 
         logits_np = self._softmax_with_temperature(logits_np, temperature)
-        np.savez(self.instrument.debug_out / "logits.npz", logits_np)
+        if self.instrument is not None:
+            np.savez(self.instrument.debug_out / "logits.npz", logits_np)
 
         logits = logits.copyfrom(logits_np)
         next_token = self.sample_topp_from_prob_func(logits, top_p, random.random())
@@ -362,6 +423,7 @@ def generate(
         self,
         prompt: str,
         generate_length: int,
+        image_url: Optional[str] = None,
     ):
         """Generates the response from the model given a user prompt. User will need to
         specify the generation length for debugging purpose. For example, a generation
@@ -377,23 +439,26 @@
         """
         out_tokens = []
 
-        input_tokens = self._tokenize(prompt)
-        print(f"{green('Input tokens')}: {input_tokens.numpy()}")
-        embedding, input_len = self._embed(input_tokens)
+        data_inputs = self._preprocess_prompts(prompt, image_url)
+        print(f"{green('Data inputs')}: {data_inputs}")
+        embedding, input_len = self._embed(data_inputs)
         logits, kv_caches = self._prefill(embedding, input_len)
         next_token = self._sample_token_from_logits(logits)
         out_tokens.append(next_token)
-        path_str = (self.debug_dir / "prefill").as_posix()
-        print(f"Debug instrument output dumped to {green(path_str)}")
+        if self.instrument is not None:
+            path_str = (self.debug_dir / "prefill").as_posix()
+            print(f"Debug instrument output dumped to {green(path_str)}")
 
         print("======================= Starts Decode =======================")
         for i in range(generate_length - 1):
-            self.instrument.reset(self.debug_dir / f"decode_{i}")
+            if self.instrument is not None:
+                self.instrument.reset(self.debug_dir / f"decode_{i}")
             logits = self._decode(next_token, kv_caches)
             next_token = self._sample_token_from_logits(logits)
             out_tokens.append(next_token)
-            path_str = (self.debug_dir / f"decode_{i}").as_posix()
-            print(f"Debug instrument output dumped to {green(path_str)}")
+            if self.instrument is not None:
+                path_str = (self.debug_dir / f"decode_{i}").as_posix()
+                print(f"Debug instrument output dumped to {green(path_str)}")
 
             if next_token in self.conversation.stop_token_ids:
                 break
@@ -440,15 +505,31 @@ def main():
         default="auto",
         help=HELP["device_compile"] + ' (default: "%(default)s")',
     )
+    parser.add_argument(
+        "--image-url",
+        type=str,
+        required=False,
+        help="Image to prefill into the model; can only be set for image models",
+    )
+    parser.add_argument(
+        "--disable-instrument",
+        action="store_true",
+        help=(
+            "Disable dumping customizable detailed information of kernel input "
+            + "and output, hence making generation faster."
+        ),
+    )
    parsed = parser.parse_args()
     dc = DebugChat(
         model=parsed.model,
         model_lib=parsed.model_lib,
         debug_dir=Path(parsed.debug_dir),
         device=parsed.device,
+        is_image_model=parsed.image_url is not None,
+        disable_instrument=parsed.disable_instrument,
     )
 
-    dc.generate(parsed.prompt, parsed.generate_len)
+    dc.generate(parsed.prompt, parsed.generate_len, parsed.image_url)
 
 
 if __name__ == "__main__":
