Skip to content

Commit

Permalink
Unify schema for conversation template and embed into mlc-chat-config…
Browse files Browse the repository at this point in the history
….json (mlc-ai#1965)
  • Loading branch information
rickzx authored and Animesh Bohara committed Mar 18, 2024
1 parent cca4c51 commit e3c3d15
Show file tree
Hide file tree
Showing 11 changed files with 410 additions and 181 deletions.
File renamed without changes.
132 changes: 131 additions & 1 deletion cpp/conversation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,130 @@ namespace llm {
/*!
 * \brief Load a conversation template in the unified JSON schema and override
 *        the corresponding fields of this Conversation.
 *
 * Keys read: "name", "system_template" + "system_message",
 * "system_prefix_token_ids", "roles", "messages", "seps", "role_content_sep",
 * "role_empty_sep", "stop_str" and "stop_token_ids".
 *
 * \param config_json Parsed JSON value; it must hold a JSON object.
 * \param partial_update When true, absent keys leave the current values
 *        untouched. When false, every key guarded by a CHECK(partial_update)
 *        below must be present ("system_prefix_token_ids", "roles" and
 *        "messages" are optional in both modes).
 *
 * \note Validation failures are reported through CHECK. Callers in this
 *       codebase wrap this function in try/catch, so CHECK presumably throws
 *       -- confirm against the logging backend in use.
 * \note Fields are assigned as they are parsed, so a failure part-way through
 *       leaves the earlier fields already updated.
 */
void Conversation::LoadJSONOverride(const picojson::value& config_json, bool partial_update) {
  const std::string err_templ = " in conversion template json file.";
  // Bind by reference: the parsed object is only read, so no deep copy is needed.
  const picojson::object& config = config_json.get<picojson::object>();
  // Returns the value stored under `key`, or nullptr when the key is absent.
  const auto lookup = [&config](const char* key) -> const picojson::value* {
    const auto it = config.find(key);
    return it == config.end() ? nullptr : &it->second;
  };

  if (const picojson::value* name_v = lookup("name")) {
    CHECK(name_v->is<std::string>()) << "Invalid name" << err_templ;
    this->name = name_v->get<std::string>();
  } else {
    CHECK(partial_update) << "Key \"name\" not found.";
  }

  const picojson::value* system_template_v = lookup("system_template");
  const picojson::value* system_message_v = lookup("system_message");
  if (system_template_v != nullptr && system_message_v != nullptr) {
    const std::string system_placeholder = "{system_message}";
    CHECK(system_template_v->is<std::string>()) << "Invalid system template" << err_templ;
    CHECK(system_message_v->is<std::string>()) << "Invalid system message" << err_templ;
    std::string system_text = system_template_v->get<std::string>();
    const std::string& system_msg = system_message_v->get<std::string>();
    // Substitute the first placeholder only when it is present. Passing npos to
    // std::string::replace would throw std::out_of_range, so a template without
    // the placeholder is used verbatim.
    const std::string::size_type placeholder_pos = system_text.find(system_placeholder);
    if (placeholder_pos != std::string::npos) {
      system_text.replace(placeholder_pos, system_placeholder.length(), system_msg);
    }
    this->system = system_text;
  } else {
    CHECK(partial_update) << "Key \"system_template\" or \"system_message\" not found.";
  }

  // Optional in both modes: no CHECK when the key is absent.
  if (const picojson::value* prefix_v = lookup("system_prefix_token_ids")) {
    CHECK(prefix_v->is<picojson::array>()) << "Invalid system_prefix_token_ids" << err_templ;
    const picojson::array& prefix_tokens_arr = prefix_v->get<picojson::array>();
    std::vector<int32_t> parsed_prefix_tokens;
    parsed_prefix_tokens.reserve(prefix_tokens_arr.size());
    for (const picojson::value& prefix_token : prefix_tokens_arr) {
      CHECK(prefix_token.is<int64_t>()) << "Invalid prefix_tokens" << err_templ;
      // JSON integers are read as int64_t; token ids are stored as int32_t.
      parsed_prefix_tokens.push_back(static_cast<int32_t>(prefix_token.get<int64_t>()));
    }
    this->prefix_tokens = parsed_prefix_tokens;
  }

  // Optional in both modes: no CHECK when the key is absent.
  if (const picojson::value* roles_v = lookup("roles")) {
    CHECK(roles_v->is<picojson::object>()) << "Invalid roles" << err_templ;
    const picojson::object& roles_json = roles_v->get<picojson::object>();
    // Slot 0 holds the "user" role name, slot 1 the "assistant" role name.
    // Other role keys are type-checked but not stored; a missing one stays "".
    std::vector<std::string> parsed_roles(2);
    for (const auto& [role, role_name] : roles_json) {
      CHECK(role_name.is<std::string>()) << "Invalid roles" << err_templ;
      if (role == "user") {
        parsed_roles.at(0) = role_name.get<std::string>();
      } else if (role == "assistant") {
        parsed_roles.at(1) = role_name.get<std::string>();
      }
    }
    this->roles = parsed_roles;
  }

  if (const picojson::value* messages_v = lookup("messages")) {
    CHECK(messages_v->is<picojson::array>()) << "Invalid messages" << err_templ;
    const picojson::array& msgs_arr = messages_v->get<picojson::array>();
    std::vector<std::vector<std::string>> parsed_messages;
    parsed_messages.reserve(msgs_arr.size());
    for (const picojson::value& msgs_i : msgs_arr) {
      CHECK(msgs_i.is<picojson::array>()) << "Invalid messages" << err_templ;
      const picojson::array& msgs_i_arr = msgs_i.get<picojson::array>();
      std::vector<std::string> messages_i;
      messages_i.reserve(msgs_i_arr.size());
      for (const picojson::value& msg_v : msgs_i_arr) {
        CHECK(msg_v.is<std::string>()) << "Invalid messages" << err_templ;
        messages_i.push_back(msg_v.get<std::string>());
      }
      parsed_messages.push_back(messages_i);
    }
    this->messages = parsed_messages;
    // The offset is set to the number of messages supplied by the template.
    this->offset = parsed_messages.size();
  } else {
    // Without a "messages" key the offset is reset, in both modes.
    this->offset = 0;
  }

  if (const picojson::value* seps_v = lookup("seps")) {
    CHECK(seps_v->is<picojson::array>()) << "Invalid seps" << err_templ;
    const picojson::array& seps_arr = seps_v->get<picojson::array>();
    std::vector<std::string> parsed_seps;
    parsed_seps.reserve(seps_arr.size());
    for (const picojson::value& sep : seps_arr) {
      CHECK(sep.is<std::string>()) << "Invalid seps" << err_templ;
      parsed_seps.push_back(sep.get<std::string>());
    }
    this->seps = parsed_seps;
  } else {
    CHECK(partial_update) << "Key \"seps\" not found.";
  }

  if (const picojson::value* role_content_sep_v = lookup("role_content_sep")) {
    CHECK(role_content_sep_v->is<std::string>()) << "Invalid role_content_sep" << err_templ;
    // The JSON key "role_content_sep" is stored in the member `role_msg_sep`.
    this->role_msg_sep = role_content_sep_v->get<std::string>();
  } else {
    // Report the key that was actually looked up.
    CHECK(partial_update) << "Key \"role_content_sep\" not found.";
  }
  if (const picojson::value* role_empty_sep_v = lookup("role_empty_sep")) {
    CHECK(role_empty_sep_v->is<std::string>()) << "Invalid role_empty_sep" << err_templ;
    this->role_empty_sep = role_empty_sep_v->get<std::string>();
  } else {
    CHECK(partial_update) << "Key \"role_empty_sep\" not found.";
  }

  if (const picojson::value* stop_str_v = lookup("stop_str")) {
    CHECK(stop_str_v->is<picojson::array>()) << "Invalid stop_str" << err_templ;
    const picojson::array& stop_str_arr = stop_str_v->get<picojson::array>();
    // Only the first entry is used; an empty array leaves stop_str unchanged.
    if (!stop_str_arr.empty()) {
      const picojson::value& first_stop_str = stop_str_arr.front();
      CHECK(first_stop_str.is<std::string>()) << "Invalid stop_str" << err_templ;
      this->stop_str = first_stop_str.get<std::string>();
    }
  } else {
    CHECK(partial_update) << "Key \"stop_str\" not found.";
  }

  if (const picojson::value* stop_tokens_v = lookup("stop_token_ids")) {
    CHECK(stop_tokens_v->is<picojson::array>()) << "Invalid stop_token_ids" << err_templ;
    const picojson::array& stop_tokens_arr = stop_tokens_v->get<picojson::array>();
    std::vector<int32_t> parsed_stop_tokens;
    parsed_stop_tokens.reserve(stop_tokens_arr.size());
    for (const picojson::value& stop_token : stop_tokens_arr) {
      CHECK(stop_token.is<int64_t>()) << "Invalid stop_tokens" << err_templ;
      // JSON integers are read as int64_t; token ids are stored as int32_t.
      parsed_stop_tokens.push_back(static_cast<int32_t>(stop_token.get<int64_t>()));
    }
    this->stop_tokens = parsed_stop_tokens;
  } else {
    CHECK(partial_update) << "Key \"stop_token_ids\" not found.";
  }
}

void Conversation::LoadJSONOverrideLegacy(const picojson::value& config_json, bool partial_update) {
std::string err_templ = " in conversion template json file.";
picojson::object config = config_json.get<picojson::object>();
if (config.count("name")) {
CHECK(config["name"].is<std::string>()) << "Invalid name" << err_templ;
this->name = config["name"].get<std::string>();
Expand Down Expand Up @@ -134,7 +258,13 @@ void Conversation::LoadJSONOverride(const std::string& config_str, bool partial_
LOG(FATAL) << err;
return;
}
LoadJSONOverride(config_json, partial_update);

picojson::object config = config_json.get<picojson::object>();
try {
LoadJSONOverride(config_json, partial_update);
} catch (...) {
LoadJSONOverrideLegacy(config_json, partial_update);
}
}

picojson::value Conversation::SerializeToJSON() const {
Expand Down
12 changes: 12 additions & 0 deletions cpp/conversation.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,18 @@ class Conversation {
*/
void LoadJSONOverride(const picojson::value& config_json, bool partial_update = false);

/*!
* \brief Load a legacy-format JSON config and override options.
*
* \param config_json A JSON config, as a picojson value, that specifies some or
* all of the options.
* \param partial_update Whether this is a partial or a full update. If true,
* only the provided options are updated; if false, all options must be
* provided.
* \note DEPRECATED. This function loads the legacy JSON config value. Callers
* in this codebase invoke it as a fallback when LoadJSONOverride throws.
*/
void LoadJSONOverrideLegacy(const picojson::value& config_json, bool partial_update = false);

/*!
* \brief Serialize the Conversation to JSON.
* \return Serialized conversation in JSON format.
Expand Down
25 changes: 20 additions & 5 deletions cpp/llm_chat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -558,16 +558,31 @@ class LLMChat {
CHECK(partial_update) << "Key \"shift_fill_factor\" not found.";
}
if (config.count("conv_template")) {
ICHECK(config["conv_template"].is<std::string>());
std::string conv_template = config["conv_template"].get<std::string>();
this->conversation_ = Conversation::FromTemplate(conv_template);
if (config["conv_template"].is<picojson::object>()) {
this->conversation_.LoadJSONOverride(config["conv_template"], false);
} else {
ICHECK(config["conv_template"].is<std::string>());
LOG(WARNING)
<< "Legacy conversation template detected. It will be deprecated in the future. "
"Please regenerate mlc-chat-config.json with the latest version";
std::string conv_template = config["conv_template"].get<std::string>();
this->conversation_ = Conversation::FromTemplate(conv_template);
}
if (config.count("conv_config")) {
// conv_config can override conv_template
this->conversation_.LoadJSONOverride(config["conv_config"], true);
try {
this->conversation_.LoadJSONOverride(config["conv_config"], true);
} catch (...) {
this->conversation_.LoadJSONOverrideLegacy(config["conv_config"], true);
}
}
} else if (config.count("conv_config")) {
// without conv template, conv_config needs to be a complete config
this->conversation_.LoadJSONOverride(config["conv_config"], false);
try {
this->conversation_.LoadJSONOverride(config["conv_config"], false);
} catch (...) {
this->conversation_.LoadJSONOverrideLegacy(config["conv_config"], false);
}
} else {
CHECK(partial_update) << "Key \"conv_template\" and \"conv_config\" not found.";
}
Expand Down
2 changes: 1 addition & 1 deletion docs/deploy/python.rst
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ We provide an example below.
# Using a `ConvConfig`, we modify `system_message`, a field in the conversation template
# `system_message` refers to the prompt encoded before starting the chat
conv_config = ConvConfig(system='Please show as much happiness as you can when talking to me.')
conv_config = ConvConfig(system_message='Please show as much happiness as you can when talking to me.')
# We then include the `ConvConfig` instance in `ChatConfig` while overriding `max_gen_len`
# Note that `conv_config` is an optional subfield of `chat_config`
Expand Down
Loading

0 comments on commit e3c3d15

Please sign in to comment.