From 16e10c1ffcdcfb3c359077d37182f6dd4896718a Mon Sep 17 00:00:00 2001 From: "Klimenko, Mikhail" Date: Thu, 21 Aug 2025 15:38:22 +0200 Subject: [PATCH] Fix bfloat16 pass to work with OrtValues --- .../qdq_transformations/qdq_scales_fix.cpp | 27 +++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp index f1ce230387565..997537866e740 100644 --- a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp +++ b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp @@ -966,9 +966,32 @@ void replace_bf16_with_fp16(qdq_scales_fix::CustomGraph& gen_graph) { auto tensor_proto = const_cast(const_tensor_proto); auto dt = tensor_proto->data_type(); if (dt == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16) { - auto raw_data = tensor_proto->has_raw_data() ? reinterpret_cast(tensor_proto->mutable_raw_data()->data()) : nullptr; + std::uint16_t* raw_data = nullptr; + if (tensor_proto->data_location() == ONNX_NAMESPACE::TensorProto_DataLocation::TensorProto_DataLocation_EXTERNAL) { + auto external_data = tensor_proto->mutable_external_data(); + std::size_t address = 0; + for (auto i = 0; i < external_data->size(); ++i) { + auto& data = external_data->at(i); + char* end = nullptr; + const auto& mkey = *data.mutable_key(); + if (mkey.find("offset") != mkey.npos) { + address = (size_t)std::strtoull(data.mutable_value()->data(), &end, 10); + } + } + if (address) + raw_data = reinterpret_cast(address); + } else if (tensor_proto->has_raw_data()) { + raw_data = reinterpret_cast(tensor_proto->mutable_raw_data()->data()); + } if (raw_data) { - tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); + if (tensor_proto->has_raw_data()) { + tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); + } + else { + auto ort_value = OrtValue(); + gen_graph.original_graph.GetOrtValueInitializer(key, ort_value); + ort_value.GetMutable()->SetElementType(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); + } std::int64_t size = 1; for (int i = 0; i < tensor_proto->dims_size(); ++i) size *= tensor_proto->dims()[i];