Skip to content

Commit 94eb61c

Browse files
authored
refactor: Refactor core input size checks (#382)
1 parent 623d0a5 commit 94eb61c

File tree

3 files changed

+136
-47
lines changed

3 files changed

+136
-47
lines changed

src/infer_request.cc

Lines changed: 45 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -910,6 +910,7 @@ Status
910910
InferenceRequest::Normalize()
911911
{
912912
const inference::ModelConfig& model_config = model_raw_->Config();
913+
const std::string& model_name = ModelName();
913914

914915
// Fill metadata for raw input
915916
if (!raw_input_name_.empty()) {
@@ -922,7 +923,7 @@ InferenceRequest::Normalize()
922923
std::to_string(original_inputs_.size()) +
923924
") to be deduced but got " +
924925
std::to_string(model_config.input_size()) + " inputs in '" +
925-
ModelName() + "' model configuration");
926+
model_name + "' model configuration");
926927
}
927928
auto it = original_inputs_.begin();
928929
if (raw_input_name_ != it->first) {
@@ -1055,7 +1056,7 @@ InferenceRequest::Normalize()
10551056
Status::Code::INVALID_ARG,
10561057
LogRequest() + "input '" + input.Name() +
10571058
"' has no shape but model requires batch dimension for '" +
1058-
ModelName() + "'");
1059+
model_name + "'");
10591060
}
10601061

10611062
if (batch_size_ == 0) {
@@ -1064,7 +1065,7 @@ InferenceRequest::Normalize()
10641065
return Status(
10651066
Status::Code::INVALID_ARG,
10661067
LogRequest() + "input '" + input.Name() +
1067-
"' batch size does not match other inputs for '" + ModelName() +
1068+
"' batch size does not match other inputs for '" + model_name +
10681069
"'");
10691070
}
10701071

@@ -1080,7 +1081,7 @@ InferenceRequest::Normalize()
10801081
Status::Code::INVALID_ARG,
10811082
LogRequest() + "inference request batch-size must be <= " +
10821083
std::to_string(model_config.max_batch_size()) + " for '" +
1083-
ModelName() + "'");
1084+
model_name + "'");
10841085
}
10851086

10861087
// Verify that each input shape is valid for the model, make
@@ -1089,17 +1090,17 @@ InferenceRequest::Normalize()
10891090
const inference::ModelInput* input_config;
10901091
RETURN_IF_ERROR(model_raw_->GetInput(pr.second.Name(), &input_config));
10911092

1092-
auto& input_id = pr.first;
1093+
auto& input_name = pr.first;
10931094
auto& input = pr.second;
10941095
auto shape = input.MutableShape();
10951096

10961097
if (input.DType() != input_config->data_type()) {
10971098
return Status(
10981099
Status::Code::INVALID_ARG,
1099-
LogRequest() + "inference input '" + input_id + "' data-type is '" +
1100+
LogRequest() + "inference input '" + input_name + "' data-type is '" +
11001101
std::string(
11011102
triton::common::DataTypeToProtocolString(input.DType())) +
1102-
"', but model '" + ModelName() + "' expects '" +
1103+
"', but model '" + model_name + "' expects '" +
11031104
std::string(triton::common::DataTypeToProtocolString(
11041105
input_config->data_type())) +
11051106
"'");
@@ -1119,7 +1120,7 @@ InferenceRequest::Normalize()
11191120
Status::Code::INVALID_ARG,
11201121
LogRequest() +
11211122
"All input dimensions should be specified for input '" +
1122-
input_id + "' for model '" + ModelName() + "', got " +
1123+
input_name + "' for model '" + model_name + "', got " +
11231124
triton::common::DimsListToString(input.OriginalShape()));
11241125
} else if (
11251126
(config_dims[i] != triton::common::WILDCARD_DIM) &&
@@ -1148,8 +1149,8 @@ InferenceRequest::Normalize()
11481149
}
11491150
return Status(
11501151
Status::Code::INVALID_ARG,
1151-
LogRequest() + "unexpected shape for input '" + input_id +
1152-
"' for model '" + ModelName() + "'. Expected " +
1152+
LogRequest() + "unexpected shape for input '" + input_name +
1153+
"' for model '" + model_name + "'. Expected " +
11531154
triton::common::DimsListToString(full_dims) + ", got " +
11541155
triton::common::DimsListToString(input.OriginalShape()) + ". " +
11551156
implicit_batch_note);
@@ -1201,32 +1202,25 @@ InferenceRequest::Normalize()
12011202
// TensorRT backend.
12021203
if (!input.IsNonLinearFormatIo()) {
12031204
TRITONSERVER_MemoryType input_memory_type;
1204-
// Because Triton expects STRING type to be in special format
1205-
// (prepend 4 bytes to specify string length), so need to add all the
1206-
// first 4 bytes for each element to find expected byte size
12071205
if (data_type == inference::DataType::TYPE_STRING) {
1208-
RETURN_IF_ERROR(
1209-
ValidateBytesInputs(input_id, input, &input_memory_type));
1210-
1211-
// FIXME: Temporarily skips byte size checks for GPU tensors. See
1212-
// DLIS-6820.
1206+
RETURN_IF_ERROR(ValidateBytesInputs(
1207+
input_name, input, model_name, &input_memory_type));
12131208
} else {
12141209
// Shape tensor with dynamic batching does not introduce a new
12151210
// dimension to the tensor but adds an additional value to the 1-D
12161211
// array.
12171212
const std::vector<int64_t>& input_dims =
12181213
input.IsShapeTensor() ? input.OriginalShape()
12191214
: input.ShapeWithBatchDim();
1220-
int64_t expected_byte_size = INT_MAX;
1221-
expected_byte_size =
1215+
int64_t expected_byte_size =
12221216
triton::common::GetByteSize(data_type, input_dims);
12231217
const size_t& byte_size = input.Data()->TotalByteSize();
1224-
if ((byte_size > INT_MAX) ||
1218+
if ((byte_size > LLONG_MAX) ||
12251219
(static_cast<int64_t>(byte_size) != expected_byte_size)) {
12261220
return Status(
12271221
Status::Code::INVALID_ARG,
12281222
LogRequest() + "input byte size mismatch for input '" +
1229-
input_id + "' for model '" + ModelName() + "'. Expected " +
1223+
input_name + "' for model '" + model_name + "'. Expected " +
12301224
std::to_string(expected_byte_size) + ", got " +
12311225
std::to_string(byte_size));
12321226
}
@@ -1300,7 +1294,8 @@ InferenceRequest::ValidateRequestInputs()
13001294

13011295
Status
13021296
InferenceRequest::ValidateBytesInputs(
1303-
const std::string& input_id, const Input& input,
1297+
const std::string& input_name, const Input& input,
1298+
const std::string& model_name,
13041299
TRITONSERVER_MemoryType* buffer_memory_type) const
13051300
{
13061301
const auto& input_dims = input.ShapeWithBatchDim();
@@ -1325,27 +1320,48 @@ InferenceRequest::ValidateBytesInputs(
13251320
buffer_next_idx++, (const void**)(&buffer), &remaining_buffer_size,
13261321
buffer_memory_type, &buffer_memory_id));
13271322

1323+
// GPU tensors are validated at platform backends to avoid additional
1324+
// data copying. Check "ValidateStringBuffer" in backend_common.cc.
13281325
if (*buffer_memory_type == TRITONSERVER_MEMORY_GPU) {
13291326
return Status::Success;
13301327
}
13311328
}
13321329

1333-
constexpr size_t kElementSizeIndicator = sizeof(uint32_t);
13341330
// Get the next element if not currently processing one.
13351331
if (!remaining_element_size) {
1332+
// Triton expects STRING type to be in special format
1333+
// (prepend 4 bytes to specify string length), so need to add the
1334+
// first 4 bytes for each element to find expected byte size.
1335+
constexpr size_t kElementSizeIndicator = sizeof(uint32_t);
1336+
13361337
// FIXME: Assume the string element's byte size indicator is not spread
13371338
// across buffer boundaries for simplicity.
13381339
if (remaining_buffer_size < kElementSizeIndicator) {
13391340
return Status(
13401341
Status::Code::INVALID_ARG,
13411342
LogRequest() +
1342-
"element byte size indicator exceeds the end of the buffer.");
1343+
"incomplete string length indicator for inference input '" +
1344+
input_name + "' for model '" + model_name + "', expecting " +
1345+
std::to_string(sizeof(uint32_t)) + " bytes but only " +
1346+
std::to_string(remaining_buffer_size) +
1347+
" bytes available. Please make sure the string length "
1348+
"indicator is in one buffer.");
13431349
}
13441350

13451351
// Start the next element and reset the remaining element size.
13461352
remaining_element_size = *(reinterpret_cast<const uint32_t*>(buffer));
13471353
element_checked++;
13481354

1355+
// Early stop
1356+
if (element_checked > element_count) {
1357+
return Status(
1358+
Status::Code::INVALID_ARG,
1359+
LogRequest() + "unexpected number of string elements " +
1360+
std::to_string(element_checked) + " for inference input '" +
1361+
input_name + "' for model '" + model_name + "', expecting " +
1362+
std::to_string(element_count));
1363+
}
1364+
13491365
// Advance pointer and remainder by the indicator size.
13501366
buffer += kElementSizeIndicator;
13511367
remaining_buffer_size -= kElementSizeIndicator;
@@ -1371,16 +1387,17 @@ InferenceRequest::ValidateBytesInputs(
13711387
return Status(
13721388
Status::Code::INVALID_ARG,
13731389
LogRequest() + "expected " + std::to_string(buffer_count) +
1374-
" buffers for inference input '" + input_id + "', got " +
1375-
std::to_string(buffer_next_idx));
1390+
" buffers for inference input '" + input_name + "' for model '" +
1391+
model_name + "', got " + std::to_string(buffer_next_idx));
13761392
}
13771393

13781394
// Validate the number of processed elements exactly match expectations.
13791395
if (element_checked != element_count) {
13801396
return Status(
13811397
Status::Code::INVALID_ARG,
13821398
LogRequest() + "expected " + std::to_string(element_count) +
1383-
" string elements for inference input '" + input_id + "', got " +
1399+
" string elements for inference input '" + input_name +
1400+
"' for model '" + model_name + "', got " +
13841401
std::to_string(element_checked));
13851402
}
13861403

src/infer_request.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -775,6 +775,7 @@ class InferenceRequest {
775775

776776
Status ValidateBytesInputs(
777777
const std::string& input_id, const Input& input,
778+
const std::string& model_name,
778779
TRITONSERVER_MemoryType* buffer_memory_type) const;
779780

780781
// Helpers for pending request metrics

0 commit comments

Comments
 (0)