[Serve] Separate callback invocation to another thread in AsyncEngine (#2046)

This PR enhances the AsyncThreadedEngine by moving callback invocation onto a
separate thread, in order to reduce the CPU time overhead of invoking the
Python callback on the engine loop.
MasterJH5574 authored Mar 29, 2024
1 parent 2b82091 commit 522db05
Showing 8 changed files with 299 additions and 156 deletions.
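Before the file-by-file diff, here is a minimal, self-contained C++ sketch of the producer-consumer pattern this PR adopts: the engine thread only enqueues callback inputs under a mutex and notifies a condition variable, while a dedicated thread drains the queue and runs the (potentially slow) callback. All names here (`CallbackPump`, `Push`) are simplified stand-ins, not the actual MLC classes, which use TVM's `PackedFunc` and `Array<RequestStreamOutput>`.

#include <condition_variable>
#include <functional>
#include <iostream>
#include <mutex>
#include <thread>
#include <vector>

// Sketch: the producer thread enqueues outputs cheaply; a dedicated
// consumer thread invokes the slow callback, off the producer's path.
class CallbackPump {
 public:
  explicit CallbackPump(std::function<void(std::vector<int>)> callback)
      : callback_(std::move(callback)), worker_([this] { Loop(); }) {}

  ~CallbackPump() {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      exit_now_ = true;
    }
    cv_.notify_one();
    worker_.join();
  }

  // Called from the producer thread; only enqueue and notify.
  void Push(int delta_output) {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      pending_.push_back(delta_output);
    }
    cv_.notify_one();
  }

 private:
  void Loop() {
    std::vector<int> local;
    while (true) {
      {
        std::unique_lock<std::mutex> lock(mutex_);
        cv_.wait(lock, [this] { return exit_now_ || !pending_.empty(); });
        // A real implementation would drain remaining items before exiting.
        if (exit_now_) return;
        std::swap(local, pending_);  // hold the lock only for the exchange
      }
      callback_(local);  // slow work happens outside the critical section
      local.clear();
    }
  }

  std::function<void(std::vector<int>)> callback_;
  std::mutex mutex_;
  std::condition_variable cv_;
  std::vector<int> pending_;
  bool exit_now_ = false;
  std::thread worker_;  // declared last so the thread starts after all members exist

};

int main() {
  CallbackPump pump([](std::vector<int> batch) {
    std::cout << "callback got " << batch.size() << " outputs\n";
  });
  for (int i = 0; i < 8; ++i) pump.Push(i);
}

The key design point, mirrored in the diff below, is that the lock is held only for the queue exchange, never while the callback runs.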
131 changes: 113 additions & 18 deletions cpp/serve/async_threaded_engine.cc
@@ -30,6 +30,8 @@ class AsyncThreadedEngineImpl : public AsyncThreadedEngine, public ModuleNode {
TVM_MODULE_VTABLE_ENTRY("add_request", &AsyncThreadedEngineImpl::AddRequest);
TVM_MODULE_VTABLE_ENTRY("abort_request", &AsyncThreadedEngineImpl::AbortRequest);
TVM_MODULE_VTABLE_ENTRY("run_background_loop", &AsyncThreadedEngineImpl::RunBackgroundLoop);
TVM_MODULE_VTABLE_ENTRY("run_background_stream_back_loop",
&AsyncThreadedEngineImpl::RunBackgroundStreamBackLoop);
TVM_MODULE_VTABLE_ENTRY("exit_background_loop", &AsyncThreadedEngineImpl::ExitBackgroundLoop);
if (_name == "init_background_engine") {
return PackedFunc([_self](TVMArgs args, TVMRetValue* rv) -> void {
@@ -39,44 +41,87 @@ class AsyncThreadedEngineImpl : public AsyncThreadedEngine, public ModuleNode {
   }
   TVM_MODULE_VTABLE_END();
 
-  void InitBackgroundEngine(TVMArgs args) { background_engine_ = CreateEnginePacked(args); }
+  void InitBackgroundEngine(TVMArgs args) {
+    Optional<PackedFunc> request_stream_callback;
+    try {
+      request_stream_callback = args.At<Optional<PackedFunc>>(4);
+    } catch (const dmlc::Error& e) {
+      LOG(FATAL) << "ValueError: " << e.what() << kEngineCreationErrorMessage;
+    }
+
+    CHECK(request_stream_callback.defined())
+        << "AsyncThreadedEngine requires request stream callback function, but it is not given.";
+    request_stream_callback_ = request_stream_callback.value();
+
+    auto frequest_stream_callback_wrapper = [this](TVMArgs args, TVMRetValue* ret) {
+      ICHECK_EQ(args.size(), 1);
+      Array<RequestStreamOutput> delta_outputs = args[0];
+      bool need_notify = false;
+      {
+        std::lock_guard<std::mutex> lock(request_stream_callback_mutex_);
+        request_stream_callback_inputs_.push_back(std::move(delta_outputs));
+        ++pending_request_stream_callback_cnt_;
+        need_notify = stream_callback_waiting_;
+      }
+      if (need_notify) {
+        request_stream_callback_cv_.notify_one();
+      }
+    };
+
+    std::vector<TVMValue> values{args.values, args.values + args.size()};
+    std::vector<int> type_codes{args.type_codes, args.type_codes + args.size()};
+    TVMArgsSetter setter(values.data(), type_codes.data());
+    request_stream_callback = PackedFunc(frequest_stream_callback_wrapper);
+    setter(4, request_stream_callback);
+    background_engine_ = CreateEnginePacked(TVMArgs(values.data(), type_codes.data(), args.size()));
+  }
 
   void AddRequest(Request request) final {
+    bool need_notify = false;
     {
-      std::lock_guard<std::mutex> lock(mutex_);
+      std::lock_guard<std::mutex> lock(background_loop_mutex_);
       requests_to_add_.push_back(request);
-      ++pending_operation_cnt_;
+      ++pending_request_operation_cnt_;
+      need_notify = engine_waiting_;
     }
-    cv_.notify_one();
+    if (need_notify) {
+      background_loop_cv_.notify_one();
+    }
   }
 
   void AbortRequest(const String& request_id) final {
+    bool need_notify = false;
     {
-      std::lock_guard<std::mutex> lock(mutex_);
+      std::lock_guard<std::mutex> lock(background_loop_mutex_);
       requests_to_abort_.push_back(request_id);
-      ++pending_operation_cnt_;
+      ++pending_request_operation_cnt_;
+      need_notify = engine_waiting_;
     }
-    cv_.notify_one();
+    if (need_notify) {
+      background_loop_cv_.notify_one();
+    }
   }
 
   void RunBackgroundLoop() final {
-    // The local vectors that load the requests in critical regions.
+    // The local vectors that load the requests from critical regions.
     std::vector<Request> local_requests_to_add;
     std::vector<String> local_requests_to_abort;
 
     while (!exit_now_.load(std::memory_order_relaxed)) {
       {
-        std::unique_lock<std::mutex> lock(mutex_);
-        cv_.wait(lock, [this] {
-          return !background_engine_->Empty() || pending_operation_cnt_.load() > 0 ||
+        std::unique_lock<std::mutex> lock(background_loop_mutex_);
+        engine_waiting_ = true;
+        background_loop_cv_.wait(lock, [this] {
+          return !background_engine_->Empty() || pending_request_operation_cnt_.load() > 0 ||
                  exit_now_.load(std::memory_order_relaxed);
         });
+        engine_waiting_ = false;
 
         local_requests_to_add = requests_to_add_;
         local_requests_to_abort = requests_to_abort_;
         requests_to_add_.clear();
         requests_to_abort_.clear();
-        pending_operation_cnt_ = 0;
+        pending_request_operation_cnt_ = 0;
       }
       for (Request request : local_requests_to_add) {
         background_engine_->AddRequest(request);
@@ -88,22 +133,57 @@ class AsyncThreadedEngineImpl : public AsyncThreadedEngine, public ModuleNode {
     }
   }
 
+  void RunBackgroundStreamBackLoop() final {
+    // The local vectors that load the request stream callback inputs from critical regions.
+    std::vector<Array<RequestStreamOutput>> local_request_stream_callback_inputs;
+    std::vector<RequestStreamOutput> flattened_callback_inputs;
+
+    while (!exit_now_.load(std::memory_order_relaxed)) {
+      {
+        std::unique_lock<std::mutex> lock(request_stream_callback_mutex_);
+        stream_callback_waiting_ = true;
+        request_stream_callback_cv_.wait(lock, [this] {
+          return pending_request_stream_callback_cnt_.load() > 0 ||
+                 exit_now_.load(std::memory_order_relaxed);
+        });
+        stream_callback_waiting_ = false;
+
+        local_request_stream_callback_inputs = request_stream_callback_inputs_;
+        request_stream_callback_inputs_.clear();
+        pending_request_stream_callback_cnt_ = 0;
+      }
+      for (const Array<RequestStreamOutput>& callback_inputs :
+           local_request_stream_callback_inputs) {
+        for (const RequestStreamOutput& callback_input : callback_inputs) {
+          flattened_callback_inputs.push_back(callback_input);
+        }
+      }
+      request_stream_callback_(Array<RequestStreamOutput>(flattened_callback_inputs));
+      flattened_callback_inputs.clear();
+    }
+  }
+
   void ExitBackgroundLoop() final {
     {
-      std::lock_guard<std::mutex> lock(mutex_);
+      std::lock_guard<std::mutex> lock(background_loop_mutex_);
       exit_now_.store(true);
     }
-    cv_.notify_one();
+    background_loop_cv_.notify_one();
+    request_stream_callback_cv_.notify_one();
   }
 
  private:
   /*! \brief The background normal engine for request processing. */
   std::unique_ptr<Engine> background_engine_;
+  /*! \brief The request stream callback. */
+  PackedFunc request_stream_callback_;
 
   /*! \brief The mutex ensuring only one thread can access critical regions. */
-  std::mutex mutex_;
+  std::mutex background_loop_mutex_;
+  std::mutex request_stream_callback_mutex_;
   /*! \brief The condition variable preventing threaded engine from spinning. */
-  std::condition_variable cv_;
+  std::condition_variable background_loop_cv_;
+  std::condition_variable request_stream_callback_cv_;
   /*! \brief A boolean flag denoting if the engine needs to exit background loop. */
   std::atomic<bool> exit_now_ = false;
 
@@ -121,10 +201,25 @@ class AsyncThreadedEngineImpl : public AsyncThreadedEngine, public ModuleNode {
    */
   std::vector<String> requests_to_abort_;
   /*!
-   * \brief Number of pending operations, should be the size of
+   * \brief The delta outputs to pass through callback.
+   * Elements are sent from the background loop thread and
+   * consumed by the foreground thread.
+   */
+  std::vector<Array<RequestStreamOutput>> request_stream_callback_inputs_;
+  /*!
+   * \brief Number of pending request operations, should be the size of
    * `requests_to_add_` and `requests_to_abort_`.
    */
-  std::atomic<int> pending_operation_cnt_ = 0;
+  std::atomic<int> pending_request_operation_cnt_ = 0;
+  /*!
+   * \brief Number of pending request stream callback invocations.
+   * It should be the size of `request_stream_callback_inputs_`.
+   */
+  std::atomic<int> pending_request_stream_callback_cnt_ = 0;
+  /*! \brief A boolean flag indicating if the engine is waiting for new requests/aborts. */
+  bool engine_waiting_ = false;
+  /*! \brief A boolean flag indicating if the stream callback loop is waiting. */
+  bool stream_callback_waiting_ = false;
 };
 
 TVM_REGISTER_GLOBAL("mlc.serve.create_threaded_engine").set_body_typed([]() {
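A note for reading the InitBackgroundEngine change above: per the engine-creation contract quoted in engine.h below, argument index 4 is the optional request stream callback slot, which the wrapper replaces via setter(4, ...). An illustrative layout for n = 1 model follows; the slot names are descriptive stand-ins, not the real variable names.

// Illustrative packed-argument layout for n = 1 model (6 + 4*1 = 10 args):
//   args[0]  (int)                  max sequence length
//   args[1]  (string)               tokenizer path
//   args[2]  (string)               KVCache JSON config
//   args[3]  (string)               engine mode JSON
//   args[4]  (PackedFunc, optional) request stream callback  <- replaced by the wrapper
//   args[5]  (EventTraceRecorder, optional) trace recorder
//   args[6]  (tvm.runtime.Module)   model library
//   args[7]  (string)               model path
//   args[8]  (int)                  device type
//   args[9]  (int)                  device id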
3 changes: 3 additions & 0 deletions cpp/serve/async_threaded_engine.h
@@ -33,6 +33,9 @@ class AsyncThreadedEngine {
   /*! \brief Starts the background request processing loop. */
   virtual void RunBackgroundLoop() = 0;
 
+  /*! \brief Starts the request stream callback loop. */
+  virtual void RunBackgroundStreamBackLoop() = 0;
+
   /*!
    * \brief Notify the AsyncThreadedEngine to exit the background
    * request processing loop. This method is invoked by threads
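Both loops declared here block until ExitBackgroundLoop is called. A plausible driver, sketched under the assumption that the frontend dedicates one thread to each loop (this is not the actual MLC frontend code):

#include <thread>

// Hypothetical driver; EngineT is any type exposing the three loop methods above.
template <typename EngineT>
void ServeWithTwoLoops(EngineT* engine) {
  std::thread engine_thread([&] { engine->RunBackgroundLoop(); });
  std::thread stream_back_thread([&] { engine->RunBackgroundStreamBackLoop(); });
  // ... foreground: AddRequest / AbortRequest via the engine's module functions ...
  engine->ExitBackgroundLoop();  // notifies both condition variables so each loop returns
  engine_thread.join();
  stream_back_thread.join();
}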
19 changes: 1 addition & 18 deletions cpp/serve/engine.cc
@@ -305,23 +305,6 @@ void ClearGlobalMemoryManager() {
 }
 
 std::unique_ptr<Engine> CreateEnginePacked(TVMArgs args) {
-  static const char* kErrorMessage =
-      "With `n` models, engine initialization "
-      "takes (6 + 4 * n) arguments. The first 6 arguments should be: "
-      "1) (int) maximum length of a sequence, which must be equal or smaller than the context "
-      "window size of each model; "
-      "2) (string) path to tokenizer configuration files, which in MLC LLM, usually in a model "
-      "weights directory; "
-      "3) (string) JSON configuration for the KVCache; "
-      "4) (string) JSON mode for Engine;"
-      "5) (packed function, optional) global request stream callback function. "
-      "6) (EventTraceRecorder, optional) the event trace recorder for requests."
-      "The following (4 * n) arguments, 4 for each model, should be: "
-      "1) (tvm.runtime.Module) The model library loaded into TVM's RelaxVM; "
-      "2) (string) Model path which includes weights and mlc-chat-config.json; "
-      "3) (int, enum DLDeviceType) Device type, e.g. CUDA, ROCm, etc; "
-      "4) (int) Device id, i.e. the ordinal index of the device that exists locally.";
-
   ClearGlobalMemoryManager();
   const int num_non_model_args = 6;
   const int num_model_args = 4;
@@ -352,7 +335,7 @@ std::unique_ptr<Engine> CreateEnginePacked(TVMArgs args) {
       model_infos.emplace_back(model_lib, model_path, DLDevice{device_type, device_id});
     }
   } catch (const dmlc::Error& e) {
-    LOG(FATAL) << "ValueError: " << e.what() << kErrorMessage;
+    LOG(FATAL) << "ValueError: " << e.what() << kEngineCreationErrorMessage;
   }
   return Engine::Create(max_single_sequence_length, tokenizer_path, kv_cache_config_json_str,
                         engine_mode_json_str, request_stream_callback, std::move(trace_recorder),
17 changes: 17 additions & 0 deletions cpp/serve/engine.h
@@ -115,6 +115,23 @@ class Engine {
  */
 std::unique_ptr<Engine> CreateEnginePacked(TVMArgs args);
 
+constexpr const char* kEngineCreationErrorMessage =
+    "With `n` models, engine initialization "
+    "takes (6 + 4 * n) arguments. The first 6 arguments should be: "
+    "1) (int) maximum length of a sequence, which must be equal or smaller than the context "
+    "window size of each model; "
+    "2) (string) path to tokenizer configuration files, which in MLC LLM, usually in a model "
+    "weights directory; "
+    "3) (string) JSON configuration for the KVCache; "
+    "4) (string) JSON mode for Engine;"
+    "5) (packed function, optional) global request stream callback function. "
+    "6) (EventTraceRecorder, optional) the event trace recorder for requests."
+    "The following (4 * n) arguments, 4 for each model, should be: "
+    "1) (tvm.runtime.Module) The model library loaded into TVM's RelaxVM; "
+    "2) (string) Model path which includes weights and mlc-chat-config.json; "
+    "3) (int, enum DLDeviceType) Device type, e.g. CUDA, ROCm, etc; "
+    "4) (int) Device id, i.e. the ordinal index of the device that exists locally.";
+
 }  // namespace serve
 }  // namespace llm
 }  // namespace mlc
