Add initial graph-native AMS wrappers for homogeneous and heterogeneous graph inputs #191
Status: Open
YohannDudouit wants to merge 3 commits into the base branch.
Conversation
There was a problem hiding this comment.
Cpp-linter Review
Used clang-format v18.1.8
Click here for the full clang-format patch
diff --git a/tests/AMSlib/ams_interface/test_graph_fallback.cpp b/tests/AMSlib/ams_interface/test_graph_fallback.cpp
index e8069a7..6803783 100644
--- a/tests/AMSlib/ams_interface/test_graph_fallback.cpp
+++ b/tests/AMSlib/ams_interface/test_graph_fallback.cpp
@@ -56,23 +56,23 @@ CATCH_TEST_CASE("AMSExecute homogeneous graph fallback path", "[wf][graph]")
- HomogeneousGraphDomainFn callback =
- [&](const AMSHomogeneousGraph& g, SmallVector<AMSTensor>& outputs) {
- callback_invoked = true;
-
- // Verify graph structure
- CATCH_REQUIRE(containsTensor(g, "node_features"));
- const auto* features = findTensor(g, "node_features");
- CATCH_REQUIRE(features != nullptr);
- CATCH_REQUIRE(features->shape()[0] == 10);
- CATCH_REQUIRE(features->shape()[1] == 3);
-
- // Verify input data
- const float* features_data = features->data<float>();
- CATCH_REQUIRE(features_data[0] == 0.0f);
- CATCH_REQUIRE(features_data[29] == 29.0f);
-
- // Fill outputs with computation result
- CATCH_REQUIRE(outputs.size() == 1);
- float* out_data = outputs[0].data<float>();
- for (int i = 0; i < 20; ++i) {
- out_data[i] = static_cast<float>(i * 2);
- }
- };
+ HomogeneousGraphDomainFn callback = [&](const AMSHomogeneousGraph& g,
+ SmallVector<AMSTensor>& outputs) {
+ callback_invoked = true;
+
+ // Verify graph structure
+ CATCH_REQUIRE(containsTensor(g, "node_features"));
+ const auto* features = findTensor(g, "node_features");
+ CATCH_REQUIRE(features != nullptr);
+ CATCH_REQUIRE(features->shape()[0] == 10);
+ CATCH_REQUIRE(features->shape()[1] == 3);
+
+ // Verify input data
+ const float* features_data = features->data<float>();
+ CATCH_REQUIRE(features_data[0] == 0.0f);
+ CATCH_REQUIRE(features_data[29] == 29.0f);
+
+ // Fill outputs with computation result
+ CATCH_REQUIRE(outputs.size() == 1);
+ float* out_data = outputs[0].data<float>();
+ for (int i = 0; i < 20; ++i) {
+ out_data[i] = static_cast<float>(i * 2);
+ }
+ };
@@ -125 +125,2 @@ CATCH_TEST_CASE("AMSExecute heterogeneous graph fallback path", "[wf][graph]")
- auto& edge_store = graph.getOrCreateEdgeStore(EdgeType{"atom", "bond", "atom"});
+ auto& edge_store =
+ graph.getOrCreateEdgeStore(EdgeType{"atom", "bond", "atom"});
@@ -175,39 +176,39 @@ CATCH_TEST_CASE("AMSExecute heterogeneous graph fallback path", "[wf][graph]")
- HeterogeneousGraphDomainFn callback =
- [&](const AMSHeterogeneousGraph& g, SmallVector<AMSTensor>& outputs) {
- callback_invoked = true;
-
- // Verify graph structure
- CATCH_REQUIRE(g.containsNodeStore("atom"));
- const auto* atom_store_ptr = g.findNodeStore("atom");
- CATCH_REQUIRE(atom_store_ptr != nullptr);
- CATCH_REQUIRE(containsTensor(*atom_store_ptr, "features"));
-
- // Verify node data
- const auto* features = findTensor(*atom_store_ptr, "features");
- CATCH_REQUIRE(features != nullptr);
- const float* features_data = features->data<float>();
- CATCH_REQUIRE(features_data[0] == 1.0f);
- CATCH_REQUIRE(features_data[9] == 10.0f);
-
- // Verify edge store
- CATCH_REQUIRE(g.containsEdgeStore(EdgeType{"atom", "bond", "atom"}));
- const auto* edge_store_ptr =
- g.findEdgeStore(EdgeType{"atom", "bond", "atom"});
- CATCH_REQUIRE(edge_store_ptr != nullptr);
- CATCH_REQUIRE(containsTensor(*edge_store_ptr, "edge_index"));
-
- // Verify global store
- CATCH_REQUIRE(containsTensor(g.global_store, "global"));
- const auto* global = findTensor(g.global_store, "global");
- CATCH_REQUIRE(global != nullptr);
- const float* global_data = global->data<float>();
- CATCH_REQUIRE(global_data[0] == 0.0f);
- CATCH_REQUIRE(global_data[3] == 30.0f);
-
- // Fill outputs
- CATCH_REQUIRE(outputs.size() == 1);
- float* out_data = outputs[0].data<float>();
- for (int i = 0; i < 5; ++i) {
- out_data[i] = static_cast<float>(i * 3);
- }
- };
+ HeterogeneousGraphDomainFn callback = [&](const AMSHeterogeneousGraph& g,
+ SmallVector<AMSTensor>& outputs) {
+ callback_invoked = true;
+
+ // Verify graph structure
+ CATCH_REQUIRE(g.containsNodeStore("atom"));
+ const auto* atom_store_ptr = g.findNodeStore("atom");
+ CATCH_REQUIRE(atom_store_ptr != nullptr);
+ CATCH_REQUIRE(containsTensor(*atom_store_ptr, "features"));
+
+ // Verify node data
+ const auto* features = findTensor(*atom_store_ptr, "features");
+ CATCH_REQUIRE(features != nullptr);
+ const float* features_data = features->data<float>();
+ CATCH_REQUIRE(features_data[0] == 1.0f);
+ CATCH_REQUIRE(features_data[9] == 10.0f);
+
+ // Verify edge store
+ CATCH_REQUIRE(g.containsEdgeStore(EdgeType{"atom", "bond", "atom"}));
+ const auto* edge_store_ptr =
+ g.findEdgeStore(EdgeType{"atom", "bond", "atom"});
+ CATCH_REQUIRE(edge_store_ptr != nullptr);
+ CATCH_REQUIRE(containsTensor(*edge_store_ptr, "edge_index"));
+
+ // Verify global store
+ CATCH_REQUIRE(containsTensor(g.global_store, "global"));
+ const auto* global = findTensor(g.global_store, "global");
+ CATCH_REQUIRE(global != nullptr);
+ const float* global_data = global->data<float>();
+ CATCH_REQUIRE(global_data[0] == 0.0f);
+ CATCH_REQUIRE(global_data[3] == 30.0f);
+
+ // Fill outputs
+ CATCH_REQUIRE(outputs.size() == 1);
+ float* out_data = outputs[0].data<float>();
+ for (int i = 0; i < 5; ++i) {
+ out_data[i] = static_cast<float>(i * 3);
+ }
+ };
@@ -237,2 +238,2 @@ CATCH_TEST_CASE("Graph callback type safety", "[wf][graph]")
- HomogeneousGraphDomainFn homo_fn =
- [](const AMSHomogeneousGraph&, SmallVector<AMSTensor>&) {};
+ HomogeneousGraphDomainFn homo_fn = [](const AMSHomogeneousGraph&,
+ SmallVector<AMSTensor>&) {};
@@ -240,2 +241,2 @@ CATCH_TEST_CASE("Graph callback type safety", "[wf][graph]")
- HeterogeneousGraphDomainFn hetero_fn =
- [](const AMSHeterogeneousGraph&, SmallVector<AMSTensor>&) {};
+ HeterogeneousGraphDomainFn hetero_fn = [](const AMSHeterogeneousGraph&,
+ SmallVector<AMSTensor>&) {};
Have any feedback or feature suggestions? Share it here.
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit. This suggestion is invalid because no changes were made to the code. Suggestions cannot be applied while the pull request is closed. Suggestions cannot be applied while viewing a subset of changes. Only one suggestion per line can be applied in a batch. Add this suggestion to a batch that can be applied as a single commit. Applying suggestions on deleted lines is not supported. You must change the existing code in this line in order to create a valid suggestion. Outdated suggestions cannot be applied. This suggestion has been applied or marked resolved. Suggestions cannot be applied from pending reviews. Suggestions cannot be applied on multi-line comments. Suggestions cannot be applied while the pull request is queued to merge. Suggestion cannot be applied right now. Please check back later.
This PR introduces a first graph-native execution path in AMS for applications that want to use graph-structured inputs, in particular for ND surrogate modeling use cases where the natural unit of work is a whole graph per call rather than a batch of independent 0D samples.
The scope of this branch is intentionally limited to generalizing the wrappers. It adds public and internal plumbing for homogeneous and heterogeneous graph inputs without trying to retrofit the existing tensor-centric AMS workflow to graphs.
In particular, this PR adds graph-native callback and AMSExecute overloads and keeps the graph as a graph throughout the wrapper path. It does not force graph execution through the existing tensor callback ABI, and it does not yet implement graph-aware surrogate batching, partitioning, slicing, scattering, or DB storage.