diff --git a/onnxruntime/core/providers/openvino/ov_versions/capability.cc b/onnxruntime/core/providers/openvino/ov_versions/capability.cc index 4ad8f8dd85f4d..4be4bff039df4 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/capability.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/capability.cc @@ -166,33 +166,24 @@ std::vector> GetCapability::Execute() { auto connected_clusters = GetConnectedClusters(graph_viewer_, ng_clusters); int no_of_clusters = 0; - std::vector prev_cluster; - bool try_next_cluster = false; - + size_t cluster_index = 0; + size_t total_clusters = connected_clusters.size(); for (auto this_cluster : connected_clusters) { bool omit_subgraph = false; - if (try_next_cluster) { - // no need to check previous cluster - for (auto idx : prev_cluster) { - if ((std::find(this_cluster.begin(), this_cluster.end(), idx)) == this_cluster.end()) { - this_cluster.emplace_back(idx); - } - } - try_next_cluster = false; - } - // If subgraph has less then three, graph is considered trivial unless its an epctx cluster - if (!try_next_cluster && this_cluster.size() < 3) { - bool is_epctx_node = false; - for (auto node_idx : this_cluster) { - if (graph_viewer_.GetNode(node_idx)->OpType() == "EPContext") - is_epctx_node = true; - } - if (!is_epctx_node) { - omit_subgraph = true; - prev_cluster = this_cluster; - try_next_cluster = true; - } + //auto id = this_cluster.at(0); + if (this_cluster.size() == 1) { + //check next cluster + auto index = this_cluster.at(0); + if (graph_viewer_.GetNode(index)->OpType() == "EPContext") { + omit_subgraph=false; + } else if(cluster_index < total_clusters-1) { + bool append_node = AddTrivialClusterToNextClusterIfConnected(graph_viewer_, index, connected_clusters[cluster_index+1]); + if(append_node) { + connected_clusters[cluster_index+1].emplace_back(index); + } + omit_subgraph=true; + } } std::vector cluster_graph_inputs, cluster_inputs, cluster_outputs; @@ -233,15 +224,17 @@ std::vector> GetCapability::Execute() { } } } - if (omit_subgraph) - continue; /* In scenarios, when there are no inputs or all inputs being initializers, ConstantFolding optimization in onnxruntime pre-computes the value.*/ - if (!cluster_inputs.empty()) { - AppendClusterToSubGraph(this_cluster, cluster_inputs, cluster_outputs, result); - no_of_clusters++; + if (!omit_subgraph) { + if (!cluster_inputs.empty()) { + AppendClusterToSubGraph(this_cluster, cluster_inputs, cluster_outputs, result); + no_of_clusters++; + } } + + cluster_index = cluster_index+1; } LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Supported subgraphs on OpenVINO: " << no_of_clusters; } diff --git a/onnxruntime/core/providers/openvino/ov_versions/utils.cc b/onnxruntime/core/providers/openvino/ov_versions/utils.cc index f924fa0c8205c..814378eab47d5 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/utils.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/utils.cc @@ -153,6 +153,26 @@ GetConnectedClusters(const GraphViewer& graph_viewer, const std::vector& search_cluster) { + + for(auto index: search_cluster) { + auto curr_node = graph_viewer.GetNode(index); + for (auto node = curr_node->InputNodesBegin(); node != curr_node->InputNodesEnd(); ++node) { + if((*node).Index() == curr_node_index) + return true; + } + + for (auto node = curr_node->OutputNodesBegin(); node != curr_node->OutputNodesEnd(); ++node) { + if((*node).Index() == curr_node_index) + return true; + } + } + return false; +} + + void GetInputsOutputsOfCluster(const GraphViewer& graph_viewer, const std::vector& cluster, const std::unordered_set& ng_required_initializers, diff --git a/onnxruntime/core/providers/openvino/ov_versions/utils.h b/onnxruntime/core/providers/openvino/ov_versions/utils.h index 34aa762ba9b67..bdad047a422c1 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/utils.h +++ b/onnxruntime/core/providers/openvino/ov_versions/utils.h @@ -40,6 +40,10 @@ void IdentifyConnectedNodes( std::vector> GetConnectedClusters(const GraphViewer& graph_viewer, const std::vector>& clusters); +bool AddTrivialClusterToNextClusterIfConnected(const GraphViewer& graph_viewer, + const NodeIndex index, + const std::vector& search_cluster); + void GetInputsOutputsOfCluster(const GraphViewer& graph_viewer, const std::vector& cluster, const std::unordered_set& ng_required_initializers, diff --git a/onnxruntime/test/providers/cpu/controlflow/loop_test.cc b/onnxruntime/test/providers/cpu/controlflow/loop_test.cc index a5fd37361a255..a92c1ed47f69b 100644 --- a/onnxruntime/test/providers/cpu/controlflow/loop_test.cc +++ b/onnxruntime/test/providers/cpu/controlflow/loop_test.cc @@ -1162,7 +1162,7 @@ TEST(Loop, SequenceAsLoopCarriedDependency) { test.AddSeqOutput("loop_var_0_final", seq_output); // Disable TensorRT on unsupported data type BOOL - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); } #if !defined(DISABLE_OPTIONAL_TYPE)