@@ -126,29 +126,16 @@ class OVInferRequest {
126
126
OVTensorPtr GetTensor (const std::string& name);
127
127
std::string GetInputTensorName (uint32_t index);
128
128
129
- // Set tensor described param_info and ort_ptr. Overrides shape in param_info with shape_override. Call infer req tensor if ort_ptr is last set.
129
+ // Set tensor call infer req tensor if ort_ptr differs from last set ptr .
130
130
void SetTensor (const std::string& name, const ov::element::Type& type, const ov::Shape& shape, void * ort_ptr) {
131
131
auto & cached_binding = bindings_cache_[name];
132
- if (cached_binding.ort_ptr != ort_ptr) {
133
- auto tensor_ptr = std::make_shared<ov::Tensor>(type, shape, const_cast <void *>(ort_ptr));
134
- SetTensor (name, tensor_ptr);
135
- cached_binding = {tensor_ptr, ort_ptr};
136
- } else if (ort_ptr == nullptr ) {
137
- // a null ort_ptr is expected for a tensor that has 0 elements.
138
- // for example, a tensor of shape=[1, 8, 0, 64], which is valid.
139
- // So, we check to see if at least one shape entry is 0.
140
- auto contains_zero = [](const ov::Shape& shape) {
141
- for (auto & s : shape)
142
- if (s == 0 ) return true ;
143
- return false ;
144
- };
145
- if (contains_zero (shape)) {
146
- // if there are zero elements (i.e. at least one shape entry is 0),
147
- // then create and set the tensor anyway.
148
- auto tensor_ptr = std::make_shared<ov::Tensor>(type, shape);
149
- SetTensor (name, tensor_ptr);
150
- cached_binding = {tensor_ptr, ort_ptr};
151
- }
132
+ if (cached_binding.ort_ptr != ort_ptr ||
133
+ !cached_binding.tensor_ptr ||
134
+ cached_binding.tensor_ptr ->get_shape () != shape) {
135
+ cached_binding.tensor_ptr .reset ();
136
+ auto ov_tensor = std::make_shared<ov::Tensor>(type, shape, const_cast <void *>(ort_ptr));
137
+ ovInfReq.set_tensor (name, *ov_tensor);
138
+ cached_binding = {ov_tensor, ort_ptr};
152
139
}
153
140
}
154
141
0 commit comments