Skip to content

Commit e634654

Browse files
OVInferRequest::SetTensor: Set tensor upon cached_binding shape mismatch (#783)
1 parent aea572d commit e634654

File tree

1 file changed

+8
-21
lines changed

1 file changed

+8
-21
lines changed

onnxruntime/core/providers/openvino/ov_interface.h

Lines changed: 8 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -126,29 +126,16 @@ class OVInferRequest {
126126
OVTensorPtr GetTensor(const std::string& name);
127127
std::string GetInputTensorName(uint32_t index);
128128

129-
// Set tensor described param_info and ort_ptr. Overrides shape in param_info with shape_override. Call infer req tensor if ort_ptr is last set.
129+
// Set tensor call infer req tensor if ort_ptr differs from last set ptr.
130130
void SetTensor(const std::string& name, const ov::element::Type& type, const ov::Shape& shape, void* ort_ptr) {
131131
auto& cached_binding = bindings_cache_[name];
132-
if (cached_binding.ort_ptr != ort_ptr) {
133-
auto tensor_ptr = std::make_shared<ov::Tensor>(type, shape, const_cast<void*>(ort_ptr));
134-
SetTensor(name, tensor_ptr);
135-
cached_binding = {tensor_ptr, ort_ptr};
136-
} else if (ort_ptr == nullptr) {
137-
// a null ort_ptr is expected for a tensor that has 0 elements.
138-
// for example, a tensor of shape=[1, 8, 0, 64], which is valid.
139-
// So, we check to see if at least one shape entry is 0.
140-
auto contains_zero = [](const ov::Shape& shape) {
141-
for (auto& s : shape)
142-
if (s == 0) return true;
143-
return false;
144-
};
145-
if (contains_zero(shape)) {
146-
// if there are zero elements (i.e. at least one shape entry is 0),
147-
// then create and set the tensor anyway.
148-
auto tensor_ptr = std::make_shared<ov::Tensor>(type, shape);
149-
SetTensor(name, tensor_ptr);
150-
cached_binding = {tensor_ptr, ort_ptr};
151-
}
132+
if (cached_binding.ort_ptr != ort_ptr ||
133+
!cached_binding.tensor_ptr ||
134+
cached_binding.tensor_ptr->get_shape() != shape) {
135+
cached_binding.tensor_ptr.reset();
136+
auto ov_tensor = std::make_shared<ov::Tensor>(type, shape, const_cast<void*>(ort_ptr));
137+
ovInfReq.set_tensor(name, *ov_tensor);
138+
cached_binding = {ov_tensor, ort_ptr};
152139
}
153140
}
154141

0 commit comments

Comments
 (0)