Commit 21942bb

Fix the multi-instance overlap in TRT backend (#21)

* Fix the binding initialization logic to correctly detect repeated tensors
* Fix the multi-instance overlap in the TRT backend

1 parent: 4f1190b

File tree: 3 files changed (+111, -23 lines)

CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -108,6 +108,7 @@ add_library(
   src/tensorrt_model_instance.h
   src/tensorrt_utils.cc
   src/tensorrt_utils.h
+  src/semaphore.h
   src/loader.cc
   src/loader.h
   src/logging.cc

src/semaphore.h

Lines changed: 57 additions & 0 deletions (new file)

@@ -0,0 +1,57 @@
+// Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions
+// are met:
+//  * Redistributions of source code must retain the above copyright
+//    notice, this list of conditions and the following disclaimer.
+//  * Redistributions in binary form must reproduce the above copyright
+//    notice, this list of conditions and the following disclaimer in the
+//    documentation and/or other materials provided with the distribution.
+//  * Neither the name of NVIDIA CORPORATION nor the names of its
+//    contributors may be used to endorse or promote products derived
+//    from this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#include <condition_variable>
+#include <mutex>
+
+namespace triton { namespace backend { namespace tensorrt {
+
+class Semaphore {
+ public:
+  explicit Semaphore(const int count) : count_(count) {}
+
+  void Release()
+  {
+    std::unique_lock<std::mutex> lck(mtx_);
+    count_++;
+    cv_.notify_one();
+  }
+
+  void Acquire()
+  {
+    std::unique_lock<std::mutex> lck(mtx_);
+    cv_.wait(lck, [this]() { return (count_ > 0); });
+    count_--;
+  }
+
+ private:
+  int count_;
+
+  std::mutex mtx_;
+  std::condition_variable cv_;
+};
+
+}}}  // namespace triton::backend::tensorrt
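
For reference, this Semaphore is a standard counting semaphore built from std::mutex and std::condition_variable (std::counting_semaphore only arrived in C++20, so the backend rolls its own). A minimal usage sketch follows; it is not part of the commit, and everything outside the Semaphore class itself is illustrative. Constructed with a count of 2, the semaphore admits two workers at once while the third blocks in Acquire() until one of them calls Release():

// Illustrative sketch, not part of the commit: exercise the Semaphore
// added above with three workers competing for two permits.
#include <iostream>
#include <thread>
#include <vector>

#include "semaphore.h"

int
main()
{
  triton::backend::tensorrt::Semaphore sem(2);

  std::vector<std::thread> workers;
  for (int id = 0; id < 3; ++id) {
    workers.emplace_back([&sem, id]() {
      sem.Acquire();  // blocks while both permits are held
      std::cout << "worker " << id << " holds a permit\n";
      sem.Release();  // wakes one waiter, if any
    });
  }
  for (auto& w : workers) {
    w.join();
  }
  return 0;
}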

src/tensorrt.cc

Lines changed: 53 additions & 23 deletions

@@ -27,6 +27,7 @@
 #include <future>
 #include "loader.h"
 #include "logging.h"
+#include "semaphore.h"
 #include "tensorrt_model.h"
 #include "tensorrt_model_instance.h"
 #include "tensorrt_utils.h"
@@ -221,6 +222,23 @@ class ModelState : public TensorRTModel {
   void DisableEngineSharing() { engine_sharing_ = false; }
   bool IsEngineSharingEnabled() { return engine_sharing_; }

+  struct SemaphoreContext {
+    SemaphoreContext() : next_sem_idx_(0) {}
+
+    std::vector<std::unique_ptr<Semaphore>> semaphore_list_;
+    int next_sem_idx_;
+  };
+
+  std::map<int, std::unique_ptr<SemaphoreContext>>& SemaphoreMap()
+  {
+    return semaphore_map_;
+  }
+
+  std::unique_ptr<SemaphoreContext>& SemaphoreDeviceContext(const int device_id)
+  {
+    return semaphore_map_[device_id];
+  }
+
  private:
   ModelState(TRITONBACKEND_Model* triton_model);

@@ -264,6 +282,9 @@ class ModelState : public TensorRTModel {
       std::shared_ptr<nvinfer1::ICudaEngine>>>
       device_engines_;
   bool engine_sharing_;
+
+  // A map between device id to its semaphore context
+  std::map<int, std::unique_ptr<SemaphoreContext>> semaphore_map_;
 };

 TRITONSERVER_Error*
@@ -976,7 +997,7 @@ class ModelInstanceState : public TensorRTModelInstance {
       ModelState* model_state,
       TRITONBACKEND_ModelInstance* triton_model_instance);

-  void RegisterContexts();
+  void RegisterSemaphore();
   TRITONSERVER_Error* InitStreamsAndEvents();
   TRITONSERVER_Error* InitEventSet(bool busy_wait_events);
   TRITONSERVER_Error* DestroyEventSet();
@@ -1203,19 +1224,16 @@ class ModelInstanceState : public TensorRTModelInstance {
   // executions' event states.
   std::thread completion_thread_;

-  triton::common::SyncQueue<size_t> context_queue_;
-  size_t next_context_idx_;
-
   // The details needed by the completion thread to finalize the
   // response for a model execution.
   struct Payload {
     explicit Payload(
         size_t event_set_idx, TRITONBACKEND_Request** requests,
-        uint32_t request_count, size_t context_idx)
+        uint32_t request_count, size_t sem_idx)
         : event_set_idx_(event_set_idx), total_batch_size_(0),
           compute_start_ns_(0), compute_input_end_ns_(0),
           compute_output_start_ns_(0), requests_(requests),
-          request_count_(request_count), context_idx_(context_idx)
+          request_count_(request_count), sem_idx_(sem_idx)
     {
     }

@@ -1234,7 +1252,7 @@ class ModelInstanceState : public TensorRTModelInstance {
     std::vector<TRITONBACKEND_Request*> requests_list_;
     TRITONBACKEND_Request** requests_;
     uint32_t request_count_;
-    size_t context_idx_;
+    size_t sem_idx_;

     // All the generated InferenceResponse objects
     std::vector<TRITONBACKEND_Response*> responses_;
@@ -1360,7 +1378,7 @@ ModelInstanceState::Create(
         "' for model instance '" + (*state)->Name() + "'");
   }

-  (*state)->RegisterContexts();
+  (*state)->RegisterSemaphore();
   RETURN_IF_ERROR((*state)->InitStreamsAndEvents());
   RETURN_IF_ERROR(model_state->CreateEngine(
       (*state)->DeviceId(), (*state)->DLACoreId(), model_path,
@@ -1579,9 +1597,11 @@ ModelInstanceState::ProcessRequests(
       std::to_string(request_count) + " requests")
          .c_str());

-  auto context_idx = next_context_idx_;
+  auto& sem_context = (model_state_->SemaphoreDeviceContext(DeviceId()));
+
+  auto sem_idx = sem_context->next_sem_idx_;

-  Run(requests, request_count, context_idx);
+  Run(requests, request_count, sem_idx);

   bool run_failed = true;
   for (size_t i = 0; i < request_count; ++i) {
@@ -1597,7 +1617,7 @@ ModelInstanceState::ProcessRequests(
   if (run_failed) {
     // On inference error, place the slot back to the queue
     // immediately as all works for the slot should be ignored.
-    context_queue_.Put(context_idx);
+    sem_context->semaphore_list_[sem_idx]->Release();
   } else {
     auto event_set_idx = next_set_;
     next_set_ = (event_set_idx + 1) % EVENT_SET_COUNT;
@@ -1620,7 +1640,9 @@ ModelInstanceState::ProcessRequests(
   }

   // Block the execution if there are no available contexts.
-  next_context_idx_ = context_queue_.Get();
+  sem_context->next_sem_idx_ =
+      (sem_idx + 1) % sem_context->semaphore_list_.size();
+  sem_context->semaphore_list_[sem_idx]->Acquire();
 }

 void
@@ -2528,7 +2550,9 @@ ModelInstanceState::ProcessResponse()
     // slots so that it can begin enqueuing new memcpys into the input
     // buffers
     cudaEventSynchronize(event_set.ready_for_input_);
-    context_queue_.Put(payload->context_idx_);
+    (model_state_->SemaphoreDeviceContext(DeviceId()))
+        ->semaphore_list_[payload->sem_idx_]
+        ->Release();
     NVTX_MARKER("plan_input_available");

     // Call Finalize() here to defer CUDA synchronization as much as
@@ -2963,22 +2987,28 @@ ModelInstanceState::DestroyEventSet()
 }

 void
-ModelInstanceState::RegisterContexts()
+ModelInstanceState::RegisterSemaphore()
 {
-  size_t context_idx = 0;
-  context_queue_.Put(context_idx++);
-  // If eager batching is set, we add additional slots per device
+  // If eager batching is set, we add to the semaphore resource count
   // which allows to start preparing next batch before the previous
-  // batch has completed. The number of duplicates are limitedby
+  // batch has completed. The number of duplicates are limited by
   // number of event sets to prevent too many iterations are run
   // ahead and to avoid interference of the event communication in
   // the previous execution
-  if (model_state_->EagerBatching()) {
-    for (int count = 1; count < EVENT_SET_COUNT; ++count) {
-      context_queue_.Put(context_idx++);
-    }
+  int sem_count = (model_state_->EagerBatching()) ? EVENT_SET_COUNT : 1;
+  auto it = (model_state_->SemaphoreMap()).find(DeviceId());
+  if (it == (model_state_->SemaphoreMap()).end()) {
+    it = (model_state_->SemaphoreMap())
+             .emplace(
+                 std::make_pair(DeviceId(), new ModelState::SemaphoreContext()))
+             .first;
+  }
+  it->second->semaphore_list_.emplace_back(new Semaphore(sem_count));
+
+  if (it->second->semaphore_list_.size() == 1) {
+    // Need to acquire a semaphore for first inference request
+    it->second->semaphore_list_[it->second->next_sem_idx_]->Acquire();
   }
-  next_context_idx_ = context_queue_.Get();
 }

 TRITONSERVER_Error*
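
Taken together, the tensorrt.cc changes replace the per-instance SyncQueue of context slots with a per-device map of SemaphoreContext objects, so model instances that share a GPU now coordinate through the same semaphores instead of overlapping freely. The sketch below, which is illustrative and not part of the commit, condenses the resulting handoff: the execution path blocks in Acquire() before preparing the next batch, and the completion path calls Release() once the input buffers can be reused, mirroring ProcessRequests() and ProcessResponse() above.

// Illustrative sketch, not part of the commit: the Acquire()/Release()
// handoff between the execution path and the completion thread.
#include <chrono>
#include <iostream>
#include <thread>

#include "semaphore.h"

int
main()
{
  // Count 0 models a slot whose input buffers are still in use.
  triton::backend::tensorrt::Semaphore sem(0);

  // Stand-in for the completion thread: once the (simulated) input
  // buffers are consumed, hand the slot back.
  std::thread completion([&sem]() {
    std::this_thread::sleep_for(std::chrono::milliseconds(50));
    std::cout << "completion: input buffers reusable, releasing slot\n";
    sem.Release();
  });

  // Stand-in for the execution path: block until the slot is free
  // before preparing the next batch.
  sem.Acquire();
  std::cout << "execution: slot acquired, next batch may be prepared\n";

  completion.join();
  return 0;
}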
