 #include <future>
 #include "loader.h"
 #include "logging.h"
+#include "semaphore.h"
 #include "tensorrt_model.h"
 #include "tensorrt_model_instance.h"
 #include "tensorrt_utils.h"
@@ -221,6 +222,23 @@ class ModelState : public TensorRTModel {
   void DisableEngineSharing() { engine_sharing_ = false; }
   bool IsEngineSharingEnabled() { return engine_sharing_; }
 
+  struct SemaphoreContext {
+    SemaphoreContext() : next_sem_idx_(0) {}
+
+    std::vector<std::unique_ptr<Semaphore>> semaphore_list_;
+    int next_sem_idx_;
+  };
+
+  std::map<int, std::unique_ptr<SemaphoreContext>>& SemaphoreMap()
+  {
+    return semaphore_map_;
+  }
+
+  std::unique_ptr<SemaphoreContext>& SemaphoreDeviceContext(const int device_id)
+  {
+    return semaphore_map_[device_id];
+  }
+
  private:
   ModelState(TRITONBACKEND_Model* triton_model);
 
@@ -264,6 +282,9 @@ class ModelState : public TensorRTModel {
       std::shared_ptr<nvinfer1::ICudaEngine>>>
       device_engines_;
   bool engine_sharing_;
+
+  // A map from device id to its semaphore context
+  std::map<int, std::unique_ptr<SemaphoreContext>> semaphore_map_;
 };
 
 TRITONSERVER_Error*
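
For context on the new dependency: this change relies on Semaphore only through a resource-count constructor, a blocking Acquire(), and a Release(). The semaphore.h header is not part of this diff, so the snippet below is just a minimal sketch of a counting semaphore consistent with that usage, not the backend's actual implementation.

#include <condition_variable>
#include <mutex>

// Minimal counting-semaphore sketch (assumed interface, see note above).
class Semaphore {
 public:
  explicit Semaphore(const int count) : count_(count) {}

  // Block until a resource is available, then consume it.
  void Acquire()
  {
    std::unique_lock<std::mutex> lk(mtx_);
    cv_.wait(lk, [this] { return count_ > 0; });
    --count_;
  }

  // Return a resource and wake one waiter, if any.
  void Release()
  {
    {
      std::lock_guard<std::mutex> lk(mtx_);
      ++count_;
    }
    cv_.notify_one();
  }

 private:
  std::condition_variable cv_;
  std::mutex mtx_;
  int count_;
};
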
@@ -976,7 +997,7 @@ class ModelInstanceState : public TensorRTModelInstance {
       ModelState* model_state,
       TRITONBACKEND_ModelInstance* triton_model_instance);
 
-  void RegisterContexts();
+  void RegisterSemaphore();
   TRITONSERVER_Error* InitStreamsAndEvents();
   TRITONSERVER_Error* InitEventSet(bool busy_wait_events);
   TRITONSERVER_Error* DestroyEventSet();
@@ -1203,19 +1224,16 @@ class ModelInstanceState : public TensorRTModelInstance {
   // executions' event states.
   std::thread completion_thread_;
 
-  triton::common::SyncQueue<size_t> context_queue_;
-  size_t next_context_idx_;
-
   // The details needed by the completion thread to finalize the
   // response for a model execution.
   struct Payload {
     explicit Payload(
         size_t event_set_idx, TRITONBACKEND_Request** requests,
-        uint32_t request_count, size_t context_idx)
+        uint32_t request_count, size_t sem_idx)
         : event_set_idx_(event_set_idx), total_batch_size_(0),
           compute_start_ns_(0), compute_input_end_ns_(0),
           compute_output_start_ns_(0), requests_(requests),
-          request_count_(request_count), context_idx_(context_idx)
+          request_count_(request_count), sem_idx_(sem_idx)
     {
     }
@@ -1234,7 +1252,7 @@ class ModelInstanceState : public TensorRTModelInstance {
     std::vector<TRITONBACKEND_Request*> requests_list_;
     TRITONBACKEND_Request** requests_;
     uint32_t request_count_;
-    size_t context_idx_;
+    size_t sem_idx_;
 
     // All the generated InferenceResponse objects
     std::vector<TRITONBACKEND_Response*> responses_;
@@ -1360,7 +1378,7 @@ ModelInstanceState::Create(
         "' for model instance '" + (*state)->Name() + "'");
   }
 
-  (*state)->RegisterContexts();
+  (*state)->RegisterSemaphore();
   RETURN_IF_ERROR((*state)->InitStreamsAndEvents());
   RETURN_IF_ERROR(model_state->CreateEngine(
       (*state)->DeviceId(), (*state)->DLACoreId(), model_path,
@@ -1579,9 +1597,11 @@ ModelInstanceState::ProcessRequests(
        std::to_string(request_count) + " requests")
           .c_str());
 
-  auto context_idx = next_context_idx_;
+  auto& sem_context = (model_state_->SemaphoreDeviceContext(DeviceId()));
+
+  auto sem_idx = sem_context->next_sem_idx_;
 
-  Run(requests, request_count, context_idx);
+  Run(requests, request_count, sem_idx);
 
   bool run_failed = true;
   for (size_t i = 0; i < request_count; ++i) {
@@ -1597,7 +1617,7 @@ ModelInstanceState::ProcessRequests(
   if (run_failed) {
     // On inference error, place the slot back to the queue
     // immediately as all works for the slot should be ignored.
-    context_queue_.Put(context_idx);
+    sem_context->semaphore_list_[sem_idx]->Release();
   } else {
     auto event_set_idx = next_set_;
     next_set_ = (event_set_idx + 1) % EVENT_SET_COUNT;
@@ -1620,7 +1640,9 @@ ModelInstanceState::ProcessRequests(
   }
 
   // Block the execution if there are no available contexts.
-  next_context_idx_ = context_queue_.Get();
+  sem_context->next_sem_idx_ =
+      (sem_idx + 1) % sem_context->semaphore_list_.size();
+  sem_context->semaphore_list_[sem_idx]->Acquire();
 }
 
 void
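
The tail of ProcessRequests() above is what bounds how far ahead an instance can run: it advances the device's round-robin next_sem_idx_ and then acquires the semaphore it just issued on, so the call blocks whenever the completion side has not yet released the matching input-buffer slot (ProcessResponse() performs that Release(), as shown in the next hunk). A toy, standalone illustration of this flow control, reusing the Semaphore sketch above and an assumed resource count of 2 (e.g. eager batching with two event sets), could look like:

#include <chrono>
#include <cstdio>
#include <thread>

int main()
{
  // Hypothetical resource count; reuses the Semaphore sketch shown earlier.
  Semaphore slots(2);

  // Completion side: frees a slot once a finished iteration's input
  // buffers can be reused (ProcessResponse() plays this role above).
  std::thread completion([&] {
    for (int i = 0; i < 8; ++i) {
      std::this_thread::sleep_for(std::chrono::milliseconds(10));
      slots.Release();
    }
  });

  // Issue side: launch a batch, then acquire a slot; this blocks once
  // too many batches are in flight (ProcessRequests() plays this role).
  for (int i = 0; i < 8; ++i) {
    std::printf("issue batch %d\n", i);
    slots.Acquire();
  }

  completion.join();
  return 0;
}
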
@@ -2528,7 +2550,9 @@ ModelInstanceState::ProcessResponse()
     // slots so that it can begin enqueuing new memcpys into the input
     // buffers
     cudaEventSynchronize(event_set.ready_for_input_);
-    context_queue_.Put(payload->context_idx_);
+    (model_state_->SemaphoreDeviceContext(DeviceId()))
+        ->semaphore_list_[payload->sem_idx_]
+        ->Release();
     NVTX_MARKER("plan_input_available");
 
     // Call Finalize() here to defer CUDA synchronization as much as
@@ -2963,22 +2987,28 @@ ModelInstanceState::DestroyEventSet()
 }
 
 void
-ModelInstanceState::RegisterContexts()
+ModelInstanceState::RegisterSemaphore()
 {
-  size_t context_idx = 0;
-  context_queue_.Put(context_idx++);
-  // If eager batching is set, we add additional slots per device
+  // If eager batching is set, we add to the semaphore resource count
   // which allows to start preparing next batch before the previous
-  // batch has completed. The number of duplicates are limitedby
+  // batch has completed. The number of duplicates are limited by
   // number of event sets to prevent too many iterations are run
   // ahead and to avoid interference of the event communication in
   // the previous execution
-  if (model_state_->EagerBatching()) {
-    for (int count = 1; count < EVENT_SET_COUNT; ++count) {
-      context_queue_.Put(context_idx++);
-    }
+  int sem_count = (model_state_->EagerBatching()) ? EVENT_SET_COUNT : 1;
+  auto it = (model_state_->SemaphoreMap()).find(DeviceId());
+  if (it == (model_state_->SemaphoreMap()).end()) {
+    it = (model_state_->SemaphoreMap())
+             .emplace(
+                 std::make_pair(DeviceId(), new ModelState::SemaphoreContext()))
+             .first;
+  }
+  it->second->semaphore_list_.emplace_back(new Semaphore(sem_count));
+
+  if (it->second->semaphore_list_.size() == 1) {
+    // Need to acquire a semaphore for first inference request
+    it->second->semaphore_list_[it->second->next_sem_idx_]->Acquire();
   }
-  next_context_idx_ = context_queue_.Get();
 }
 
 TRITONSERVER_Error*
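
To walk through the new registration with a hypothetical configuration (two instances of the model on GPU 0, eager batching enabled): the first instance to call RegisterSemaphore() creates the device's SemaphoreContext, appends a Semaphore with EVENT_SET_COUNT resources, and acquires one resource up front, the counterpart of the context_queue_.Get() that the old RegisterContexts() performed; the second instance finds the existing context and simply appends its own Semaphore. The device then holds one semaphore per instance, and ProcessRequests() rotates next_sem_idx_ across them as batches are issued.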