Skip to content

Commit ab51424

Browse files
fix: Add missing function (#172)
* fix: Add missing functions This change add a pair of missing function definitions that got lost in the refactor. * Put back missing line --------- Co-authored-by: Dmitry Tokarev <[email protected]>
1 parent 7e9d0f9 commit ab51424

File tree

1 file changed

+27
-0
lines changed

1 file changed

+27
-0
lines changed

src/model_instance_state.cc

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,27 @@ ModelInstanceState::Create(
240240
return nullptr; // success
241241
}
242242

243+
void
244+
ModelInstanceState::CreateCudaEvents(const int32_t& device_id)
245+
{
246+
#ifdef TRITON_ENABLE_GPU
247+
// Need to set the CUDA context so that the context that events are
248+
// created on match with contexts that events are recorded with.
249+
THROW_IF_BACKEND_INSTANCE_ERROR(ConvertCUDAStatusToTritonError(
250+
cudaSetDevice(device_id), TRITONSERVER_ERROR_INTERNAL,
251+
"Failed to set the device"));
252+
THROW_IF_BACKEND_INSTANCE_ERROR(ConvertCUDAStatusToTritonError(
253+
cudaEventCreate(&compute_input_start_event_), TRITONSERVER_ERROR_INTERNAL,
254+
"Failed to create cuda event"));
255+
THROW_IF_BACKEND_INSTANCE_ERROR(ConvertCUDAStatusToTritonError(
256+
cudaEventCreate(&compute_infer_start_event_), TRITONSERVER_ERROR_INTERNAL,
257+
"Failed to create cuda event"));
258+
THROW_IF_BACKEND_INSTANCE_ERROR(ConvertCUDAStatusToTritonError(
259+
cudaEventCreate(&compute_output_start_event_),
260+
TRITONSERVER_ERROR_INTERNAL, "Failed to create cuda event"));
261+
#endif
262+
}
263+
243264
void
244265
ModelInstanceState::Execute(
245266
std::vector<TRITONBACKEND_Response*>* responses,
@@ -1230,6 +1251,12 @@ ModelInstanceState::SetInputTensors(
12301251
return nullptr;
12311252
}
12321253

1254+
ModelState*
1255+
ModelInstanceState::StateForModel() const
1256+
{
1257+
return model_state_;
1258+
}
1259+
12331260
TRITONSERVER_Error*
12341261
ModelInstanceState::ValidateBooleanSequenceControl(
12351262
triton::common::TritonJson::Value& sequence_batching,

0 commit comments

Comments
 (0)