Unify schema for conversation template and embed into mlc-chat-config.json #1965

Merged (3 commits) on Mar 15, 2024
Changes from 2 commits
File renamed without changes.
132 changes: 131 additions & 1 deletion cpp/conversation.cc
@@ -11,6 +11,130 @@ namespace llm {
void Conversation::LoadJSONOverride(const picojson::value& config_json, bool partial_update) {
std::string err_templ = " in conversation template json file.";
picojson::object config = config_json.get<picojson::object>();

if (config.count("name")) {
CHECK(config["name"].is<std::string>()) << "Invalid name" << err_templ;
this->name = config["name"].get<std::string>();
} else {
CHECK(partial_update) << "Key \"name\" not found.";
}

if (config.count("system_template") && config.count("system_message")) {
std::string system_placeholder = "{system_message}";
CHECK(config["system_template"].is<std::string>()) << "Invalid system template" << err_templ;
CHECK(config["system_message"].is<std::string>()) << "Invalid system message" << err_templ;
std::string system_template = config["system_template"].get<std::string>();
std::string system_msg = config["system_message"].get<std::string>();
std::string system = system_template.replace(system_template.find(system_placeholder),
system_placeholder.length(), system_msg);
this->system = system;
} else {
CHECK(partial_update) << "Key \"system_template\" or \"system_message\" not found.";
}

if (config.count("system_prefix_token_ids")) {
CHECK(config["system_prefix_token_ids"].is<picojson::array>())
<< "Invalid system_prefix_token_ids" << err_templ;
picojson::array prefix_tokens_arr = config["system_prefix_token_ids"].get<picojson::array>();
std::vector<int32_t> prefix_tokens;
for (const picojson::value& prefix_token : prefix_tokens_arr) {
CHECK(prefix_token.is<int64_t>()) << "Invalid prefix_tokens" << err_templ;
prefix_tokens.push_back(prefix_token.get<int64_t>());
}
this->prefix_tokens = prefix_tokens;
}

if (config.count("roles")) {
CHECK(config["roles"].is<picojson::object>()) << "Invalid roles" << err_templ;
picojson::object roles_json = config["roles"].get<picojson::object>();
std::vector<std::string> roles(2);
for (auto [role, role_name] : roles_json) {
CHECK(role_name.is<std::string>());
if (role == "user") {
roles.at(0) = role_name.get<std::string>();
}
if (role == "assistant") {
roles.at(1) = role_name.get<std::string>();
}
}
this->roles = roles;
}

if (config.count("messages")) {
CHECK(config["messages"].is<picojson::array>()) << "Invalid messages" << err_templ;
std::vector<std::vector<std::string>> messages;
picojson::array msgs_arr = config["messages"].get<picojson::array>();
for (const picojson::value& msgs_i : msgs_arr) {
CHECK(msgs_i.is<picojson::array>()) << "Invalid messages" << err_templ;
picojson::array msgs_i_arr = msgs_i.get<picojson::array>();
std::vector<std::string> messages_i;
for (const picojson::value& msg_v : msgs_i_arr) {
CHECK(msg_v.is<std::string>()) << "Invalid messages" << err_templ;
messages_i.push_back(msg_v.get<std::string>());
}
messages.push_back(messages_i);
}
this->messages = messages;
this->offset = messages.size();
} else {
this->offset = 0;
}

if (config.count("seps")) {
std::vector<std::string> seps;
CHECK(config["seps"].is<picojson::array>()) << "Invalid seps" << err_templ;
picojson::array seps_arr = config["seps"].get<picojson::array>();
for (const picojson::value& sep : seps_arr) {
CHECK(sep.is<std::string>()) << "Invalid seps" << err_templ;
seps.push_back(sep.get<std::string>());
}
this->seps = seps;
} else {
CHECK(partial_update) << "Key \"seps\" not found.";
}

if (config.count("role_content_sep")) {
CHECK(config["role_content_sep"].is<std::string>()) << "Invalid role_content_sep" << err_templ;
this->role_msg_sep = config["role_content_sep"].get<std::string>();
} else {
CHECK(partial_update) << "Key \"role_msg_sep\" not found.";
}
if (config.count("role_empty_sep")) {
CHECK(config["role_empty_sep"].is<std::string>()) << "Invalid role_empty_sep" << err_templ;
this->role_empty_sep = config["role_empty_sep"].get<std::string>();
} else {
CHECK(partial_update) << "Key \"role_empty_sep\" not found.";
}

if (config.count("stop_str")) {
CHECK(config["stop_str"].is<picojson::array>()) << "Invalid stop_str" << err_templ;
picojson::array stop_str_arr = config["stop_str"].get<picojson::array>();
if (stop_str_arr.size() >= 1) {
picojson::value stop_str = stop_str_arr.at(0);
CHECK(stop_str.is<std::string>());
this->stop_str = stop_str.get<std::string>();
}
} else {
CHECK(partial_update) << "Key \"stop_str\" not found.";
}

if (config.count("stop_token_ids")) {
CHECK(config["stop_token_ids"].is<picojson::array>()) << "Invalid stop_token_ids" << err_templ;
picojson::array stop_tokens_arr = config["stop_token_ids"].get<picojson::array>();
std::vector<int32_t> stop_tokens;
for (const picojson::value& stop_token : stop_tokens_arr) {
CHECK(stop_token.is<int64_t>()) << "Invalid stop_tokens" << err_templ;
stop_tokens.push_back(stop_token.get<int64_t>());
}
this->stop_tokens = stop_tokens;
} else {
CHECK(partial_update) << "Key \"stop_token_ids\" not found.";
}
}
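For reference while reading the parser above, here is a minimal sketch of a conversation template in the unified schema. The key names mirror the fields parsed by `Conversation::LoadJSONOverride`; the template name and concrete values are made up for illustration only.

```python
import json

# Hypothetical template in the unified schema; key names follow the parser above,
# values are illustrative only.
conv_template = {
    "name": "demo_chat",
    "system_template": "[SYS] {system_message} [/SYS]",
    "system_message": "You are a helpful assistant.",
    "roles": {"user": "USER", "assistant": "ASSISTANT"},
    "messages": [],
    "seps": ["\n"],
    "role_content_sep": ": ",
    "role_empty_sep": ":",
    "stop_str": ["</s>"],
    "stop_token_ids": [2],
}

# Serialize to JSON, e.g. for embedding into mlc-chat-config.json.
print(json.dumps(conv_template, indent=2))
```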

void Conversation::LoadJSONOverrideLegacy(const picojson::value& config_json, bool partial_update) {
std::string err_templ = " in conversation template json file.";
picojson::object config = config_json.get<picojson::object>();
if (config.count("name")) {
CHECK(config["name"].is<std::string>()) << "Invalid name" << err_templ;
this->name = config["name"].get<std::string>();
@@ -134,7 +258,13 @@ void Conversation::LoadJSONOverride(const std::string& config_str, bool partial_
LOG(FATAL) << err;
return;
}

picojson::object config = config_json.get<picojson::object>();
try {
LoadJSONOverride(config_json, partial_update);
} catch (...) {
LoadJSONOverrideLegacy(config_json, partial_update);
}
}

picojson::value Conversation::SerializeToJSON() const {
12 changes: 12 additions & 0 deletions cpp/conversation.h
@@ -154,6 +154,18 @@ class Conversation {
*/
void LoadJSONOverride(const picojson::value& config_json, bool partial_update = false);

/*!
* \brief Load a legacy JSON config value and override options.
*
* \param config_json A json config in picojson type that partially specifies
* some of the options.
* \param partial_update Whether it's a partial update or full update, if set to true,
* we perform a partial update on some of the provided options; if set to false, all
* options must be provided.
* \note DEPRECATED. This function loads the legacy JSON config value.
*/
void LoadJSONOverrideLegacy(const picojson::value& config_json, bool partial_update = false);
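As a rough, non-authoritative sketch of the difference between the two formats, assuming the legacy keys follow the old `ConvConfig` fields (all values are illustrative):

```python
# Legacy-style override, handled by LoadJSONOverrideLegacy (illustrative values).
legacy_override = {
    "system": "You are a helpful assistant.",
    "roles": ["USER", "ASSISTANT"],
    "seps": ["\n"],
    "stop_str": "</s>",
    "stop_tokens": [2],
}

# Unified-schema equivalent, handled by LoadJSONOverride (illustrative values).
unified_override = {
    "system_template": "{system_message}",
    "system_message": "You are a helpful assistant.",
    "roles": {"user": "USER", "assistant": "ASSISTANT"},
    "seps": ["\n"],
    "stop_str": ["</s>"],
    "stop_token_ids": [2],
}
```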

/*!
* \brief Serialize the Conversation to JSON.
* \return Serialized conversation in JSON format.
25 changes: 20 additions & 5 deletions cpp/llm_chat.cc
@@ -558,16 +558,31 @@ class LLMChat {
CHECK(partial_update) << "Key \"shift_fill_factor\" not found.";
}
if (config.count("conv_template")) {
ICHECK(config["conv_template"].is<std::string>());
std::string conv_template = config["conv_template"].get<std::string>();
this->conversation_ = Conversation::FromTemplate(conv_template);
if (config["conv_template"].is<picojson::object>()) {
this->conversation_.LoadJSONOverride(config["conv_template"], false);
} else {
ICHECK(config["conv_template"].is<std::string>());
LOG(WARNING)
<< "Legacy conversation template detected. It will be deprecated in the future. "
"Please regenerate mlc-chat-config.json with the latest version";
std::string conv_template = config["conv_template"].get<std::string>();
this->conversation_ = Conversation::FromTemplate(conv_template);
}
if (config.count("conv_config")) {
// conv_config can override conv_template
this->conversation_.LoadJSONOverride(config["conv_config"], true);
try {
this->conversation_.LoadJSONOverride(config["conv_config"], true);
} catch (...) {
this->conversation_.LoadJSONOverrideLegacy(config["conv_config"], true);
}
}
} else if (config.count("conv_config")) {
// without conv template, conv_config needs to be a complete config
this->conversation_.LoadJSONOverride(config["conv_config"], false);
try {
this->conversation_.LoadJSONOverride(config["conv_config"], false);
} catch (...) {
this->conversation_.LoadJSONOverrideLegacy(config["conv_config"], false);
}
} else {
CHECK(partial_update) << "Key \"conv_template\" and \"conv_config\" not found.";
}
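To illustrate what this branch accepts, here is a sketch of the two shapes `mlc-chat-config.json` may take after this change: the legacy string name versus the embedded conversation object, optionally refined by `conv_config`. The names and values below are illustrative, not taken from a real config.

```python
import json

# Legacy config: conv_template is just a registered template name.
legacy_config = {"conv_template": "llama-2"}

# Unified config: conv_template is the full conversation object, and conv_config
# can still partially override individual fields of it.
unified_config = {
    "conv_template": {
        "name": "demo_chat",
        "system_template": "{system_message}",
        "system_message": "You are a helpful assistant.",
        "roles": {"user": "USER", "assistant": "ASSISTANT"},
        "messages": [],
        "seps": ["\n"],
        "role_content_sep": ": ",
        "role_empty_sep": ":",
        "stop_str": ["</s>"],
        "stop_token_ids": [2],
    },
    "conv_config": {"system_message": "You are a terse assistant."},
}

print(json.dumps(unified_config, indent=2))
```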
92 changes: 54 additions & 38 deletions python/mlc_llm/chat_module.py
@@ -16,6 +16,7 @@
import tvm
from tvm.runtime import disco # pylint: disable=unused-import

from mlc_llm.protocol.conversation_protocol import Conversation
from mlc_llm.support import logging
from mlc_llm.support.auto_device import detect_device
from mlc_llm.support.config import ConfigBase
@@ -44,58 +45,61 @@ class ConvConfig: # pylint: disable=too-many-instance-attributes

Since the configuration is partial, everything will be ``Optional``.

The parameters are the same as :class:`mlc_llm.protocol.conversation_protocol.Conversation`.

Parameters
----------
name : Optional[str]
Name of the conversation.
system_template : Optional[str]
The system prompt template. It optionally contains the system message
placeholder, which will be replaced with the system message below.
system_message : Optional[str]
The content of the system prompt (without the template format).
system_prefix_token_ids : Optional[List[int]]
The system token ids to be prepended at the beginning of the tokenized
generated prompt.
roles : Optional[Dict[str, str]]
The conversation roles.
role_templates : Optional[Dict[str, str]]
The role prompt templates; they optionally contain default message
placeholders that are replaced by the actual message content.
messages : Optional[List[Tuple[str, Optional[str]]]]
The conversation history messages.
Each message is a pair of strings, denoting "(role, content)".
The content can be None.
seps : Optional[List[str]]
An array of strings indicating the separators to be used after a user
message and a model message respectively.
role_content_sep : Optional[str]
The separator between the role and the content in a message.
role_empty_sep : Optional[str]
The separator between the role and empty contents.
stop_str : Optional[List[str]]
When the ``stop_str`` is encountered, the model will stop generating output.
stop_token_ids : Optional[List[int]]
A list of token IDs that act as stop tokens.
function_string : Optional[str]
The function calling string.
use_function_calling : Optional[bool]
Whether function calling is used; helps check the output message format in API calls.
"""

name: Optional[str] = None
system_template: Optional[str] = None
system_message: Optional[str] = None
system_prefix_token_ids: Optional[List[int]] = None
roles: Optional[Dict[str, str]] = None
role_templates: Optional[Dict[str, str]] = None
messages: Optional[List[Tuple[str, Optional[str]]]] = None
seps: Optional[List[str]] = None
role_content_sep: Optional[str] = None
role_empty_sep: Optional[str] = None
stop_str: Optional[List[str]] = None
stop_token_ids: Optional[List[int]] = None
function_string: Optional[str] = None
use_function_calling: Optional[bool] = None
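A rough usage sketch of the updated dataclass, assuming the surrounding `ChatConfig.conv_config` plumbing is unchanged; the field values are illustrative:

```python
from mlc_llm.chat_module import ChatConfig, ConvConfig

# Partial override: only the fields provided here replace those of the
# conversation template embedded in mlc-chat-config.json.
conv_config = ConvConfig(
    system_message="You are a concise assistant.",
    stop_str=["</s>"],
)
chat_config = ChatConfig(conv_config=conv_config, temperature=0.7)
```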


@dataclass
@@ -192,7 +196,7 @@ class ChatConfig(ConfigBase): # pylint: disable=too-many-instance-attributes

model_lib: Optional[str] = None
local_id: Optional[str] = None
conv_template: Optional[Union[str, Conversation]] = None
temperature: Optional[float] = None
presence_penalty: Optional[float] = 0.0
frequency_penalty: Optional[float] = 0.0
@@ -217,6 +221,8 @@ class ChatConfig(ConfigBase): # pylint: disable=too-many-instance-attributes

@classmethod
def _from_json(cls, json_obj: dict):
if "conv_template" in json_obj and isinstance(json_obj["conv_template"], dict):
json_obj["conv_template"] = Conversation.from_json_dict(json_obj["conv_template"])
return cls(**{k: v for k, v in json_obj.items() if k in inspect.signature(cls).parameters})


@@ -440,6 +446,13 @@ def _get_chat_config(config_file_path: str, user_chat_config: Optional[ChatConfi
"override the full model library path instead."
)
warnings.warn(warn_msg)
elif field_name == "conv_template" and isinstance(field_value, Conversation):
warn_msg = (
'WARNING: Do not override "conv_template" in ChatConfig. '
'Please override "conv_config" instead. '
"This override will be ignored."
)
warnings.warn(warn_msg)
else:
setattr(final_chat_config, field_name, field_value)
return final_chat_config
@@ -613,6 +626,9 @@ def _convert_chat_config_to_json_str(
conv_dict[conv_k] = conv_v
chat_dict[key] = conv_dict
continue
if key == "conv_template" and isinstance(value, Conversation):
chat_dict[key] = Conversation.to_json_dict(value)
continue
if value is not None:
chat_dict[key] = value

18 changes: 15 additions & 3 deletions python/mlc_llm/interface/gen_config.py
@@ -4,8 +4,9 @@
import json
import shutil
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

from mlc_llm.conversation_template import ConvTemplateRegistry
from mlc_llm.model import Model
from mlc_llm.quantization import Quantization
from mlc_llm.support import convert_tiktoken, logging
@@ -45,7 +46,7 @@ class MLCChatConfig: # pylint: disable=too-many-instance-attributes
repetition_penalty: float = None
top_p: float = None
# Conversation template
conv_template: Union[str, Dict[str, Any]] = None
pad_token_id: int = None
bos_token_id: int = None
eos_token_id: int = None
@@ -89,6 +90,17 @@ def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-b
):
"""Entrypoint of MLC Chat configuration generation."""
# Step 1. Initialize `mlc-chat-config.json` using `config.json`
conversation_reg = ConvTemplateRegistry.get_conv_template(conv_template)
if conversation_reg is None:
logger.warning(
"%s: Conversation template is not registered in ConvTemplateRegistry: %s",
red("Warning"),
conv_template,
)
conversation = conv_template # type: ignore
else:
conversation = conversation_reg.to_json_dict() # type: ignore
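A minimal sketch of the registry lookup used above, assuming "llama-2" is one of the registered template names:

```python
from mlc_llm.conversation_template import ConvTemplateRegistry

# Resolve a registered template name to a full Conversation object; an
# unregistered name returns None and is kept as a plain string instead.
conv = ConvTemplateRegistry.get_conv_template("llama-2")
if conv is not None:
    conv_template_dict = conv.to_json_dict()
    print(sorted(conv_template_dict.keys()))
```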

model_config = ModelConfigOverride(
context_window_size=context_window_size,
sliding_window_size=sliding_window_size,
@@ -107,7 +119,7 @@ def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-b
prefill_chunk_size=model_config.prefill_chunk_size,
attention_sink_size=getattr(model_config, "attention_sink_size", -1),
tensor_parallel_shards=model_config.tensor_parallel_shards,
conv_template=conversation,
)
# Step 2. Load `generation_config.json` and `config.json` for text-generation related configs
for generation_config_filename in ["generation_config.json", "config.json"]: