[CPP_CLI] MLC Cli App over JSONEngine interface #3114

Open
wants to merge 2 commits into base: main
3 changes: 3 additions & 0 deletions CMakeLists.txt
@@ -158,6 +158,8 @@ if(NOT CARGO_EXECUTABLE)
message(FATAL_ERROR "Cargo is not found! Please install cargo.")
endif()

add_subdirectory(apps/mlc_cli_chat)
Contributor

Suggest using a USE_CLI_CHAT option (default value: OFF) to determine whether to compile the executable.
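
A minimal sketch of that suggestion (the USE_CLI_CHAT option name and its OFF default are the reviewer's proposal, not something this PR implements):

option(USE_CLI_CHAT "Build the mlc_cli_chat executable" OFF)
if(USE_CLI_CHAT)
  add_subdirectory(apps/mlc_cli_chat)
endif()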


# when this option is on, we install all static lib deps into lib
if(MLC_LLM_INSTALL_STATIC_LIB)
install(TARGETS mlc_llm_static tokenizers_cpp sentencepiece-static tvm_runtime
@@ -178,6 +180,7 @@ else()
mlc_llm_static
tokenizers_cpp
sentencepiece-static
mlc_cli_chat
RUNTIME_DEPENDENCY_SET
tokenizers_c
RUNTIME DESTINATION bin
31 changes: 31 additions & 0 deletions apps/mlc_cli_chat/CMakeLists.txt
@@ -0,0 +1,31 @@
cmake_policy(SET CMP0069 NEW) # suppress cmake warning about IPO

set(MLC_CLI_SOURCES
mlc_cli_chat.cc
chat_state.cc
engine.cc
)
set(MLC_CLI_LINKER_LIBS "")

set(
MLC_CLI_CHAT_INCLUDES
../../3rdparty/tvm/include
../../3rdparty/tvm/3rdparty/dlpack/include
../../3rdparty/tvm/3rdparty/dmlc-core/include
../../3rdparty/tvm/3rdparty/picojson
../../3rdparty/tokenizers-cpp/include
../../3rdparty/xgrammar/include
)

add_executable(mlc_cli_chat ${MLC_CLI_SOURCES})
target_include_directories(mlc_cli_chat PUBLIC ${CMAKE_CURRENT_SOURCE_DIR} ${MLC_CLI_CHAT_INCLUDES} ${PROJECT_SOURCE_DIR}/cpp)
target_link_libraries(mlc_cli_chat PUBLIC mlc_llm ${TVM_RUNTIME_LINKER_LIBS})

if(USE_CUDA)
include(../../3rdparty/tvm/cmake/utils/Utils.cmake)
include(../../3rdparty/tvm/cmake/utils/FindCUDA.cmake)
find_cuda(${USE_CUDA} ${USE_CUDNN})
target_link_libraries(mlc_cli_chat PUBLIC ${CUDA_NVRTC_LIBRARY})
target_link_libraries(mlc_cli_chat PUBLIC ${CUDA_CUDART_LIBRARY})
target_link_libraries(mlc_cli_chat PUBLIC ${CUDA_CUDA_LIBRARY})
endif()
3 changes: 3 additions & 0 deletions apps/mlc_cli_chat/README.md
@@ -0,0 +1,3 @@
# MLC Chat CLI Application

A native application that can load and run MLC models from the command line.
18 changes: 18 additions & 0 deletions apps/mlc_cli_chat/base.h
@@ -0,0 +1,18 @@
/*!
* Copyright (c) 2023-2025 by Contributors
* \file base.h
*/

#ifndef MLC_CLI_CHAT_BASE_H
#define MLC_CLI_CHAT_BASE_H

#include <dlpack/dlpack.h>

#include <string>
#include <unordered_map>

struct Message {
std::unordered_map<std::string, std::string> content;
};

#endif
110 changes: 110 additions & 0 deletions apps/mlc_cli_chat/chat_state.cc
@@ -0,0 +1,110 @@
/*!
* Copyright (c) 2023-2025 by Contributors
* \file chat_state.cc
*/

#include "chat_state.h"

#include <iostream>

#include "base.h"
#include "engine.h"

void print_help_str() {
std::string help_string = R"("""You can use the following special commands:
/help print the special commands
/exit quit the cli
/stats print out stats of last request (token/sec)
Multi-line input: Use escape+enter to start a new line.
""")";

std::cout << help_string << std::endl;
}

ChatState::ChatState(std::string model_path, std::string model_lib_path, std::string mode,
std::string device, int device_id) {
history_window_begin = 0;
__json_wrapper =
std::make_shared<JSONFFIEngineWrapper>(model_path, model_lib_path, mode, device, device_id);
}

void ChatState::slide_history() {
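  // Advance the window start past roughly half of the current window, rounded
  // up to an even number of messages so user/assistant turns stay paired.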
size_t history_window_size = history.size() - history_window_begin;
history_window_begin += ((history_window_size + 3) / 4) * 2;
}

std::vector<Message> ChatState::get_current_history_window() {
return std::vector<Message>(history.begin() + history_window_begin, history.end());
}

int ChatState::generate(const std::string& prompt) {
  // Tracks whether generation stopped because the context-length limit was hit.
bool finish_reason_length = false;

// User Message
Message new_message;
new_message.content["role"] = "user";
new_message.content["content"] = prompt;
history.push_back(new_message);

auto curr_window = get_current_history_window();

std::string output_text{""};

output_text = (*__json_wrapper).chat.completions.create(curr_window);

if (__json_wrapper->engine_state->finish_reason == "length") {
finish_reason_length = true;
}

if (finish_reason_length) {
std::cout << "\n[output truncated due to context length limit...]";
}

Message assistant_response;
assistant_response.content["role"] = "assistant";

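  // JSON-encode the raw reply (adds surrounding quotes and escapes special
  // characters) before storing it in the chat history.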
picojson::value val(output_text);

std::string output_json_str = val.serialize();

assistant_response.content["content"] = output_json_str;
history.push_back(assistant_response);

if (finish_reason_length) {
slide_history();
}
return 0;
}

void ChatState::reset() {
history.clear();
history_window_begin = 0;
}

int ChatState::chat(std::string prompt) {
print_help_str();
// Get the prompt message
if (!prompt.empty()) {
int ret = generate(prompt);
__json_wrapper->background_loops->terminate();
this->__json_wrapper->engine_state->getStats();
return ret;
}
std::string cin_prompt;
while (true) {
std::cout << ">>> ";
std::getline(std::cin, cin_prompt);
if (std::cin.eof() || cin_prompt == "/exit") {
__json_wrapper->background_loops->terminate();
break;
} else if (cin_prompt == "/help") {
print_help_str();
} else if (cin_prompt == "/stats") {
this->__json_wrapper->engine_state->getStats();
} else {
generate(cin_prompt);
}
}
return 0;
}
30 changes: 30 additions & 0 deletions apps/mlc_cli_chat/chat_state.h
@@ -0,0 +1,30 @@
/*!
* Copyright (c) 2023-2025 by Contributors
* \file chat_state.h
*/

#ifndef MLC_CLI_CHAT_CHAT_STATE_H
#define MLC_CLI_CHAT_CHAT_STATE_H

#include "base.h"
#include "engine.h"

void print_help_str();

class ChatState {
public:
std::vector<Message> history;
size_t history_window_begin;
std::shared_ptr<JSONFFIEngineWrapper> __json_wrapper;

ChatState(std::string model_path, std::string model_lib_path, std::string mode,
std::string device, int device_id = 0);

void slide_history();
std::vector<Message> get_current_history_window();
int generate(const std::string& prompt);
void reset();
int chat(std::string prompt = "");
};

#endif
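
For context, a minimal driver for this class could look like the following sketch. The real entry point lives in mlc_cli_chat.cc (not shown in this diff); the paths, mode, and device strings below are placeholder assumptions, not values taken from this PR:

#include "chat_state.h"

int main() {
  // Placeholder arguments; the real mlc_cli_chat.cc derives these from argv.
  ChatState chat_state("/path/to/model", "/path/to/model_lib.so",
                       "interactive", "cuda", /*device_id=*/0);
  // With no prompt argument, chat() prints the help text and enters the
  // interactive REPL (/help, /stats, /exit).
  return chat_state.chat();
}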