Skip to content

Add initial graph-native AMS wrappers for homogeneous and heterogeneous graph inputs#191

Open
YohannDudouit wants to merge 3 commits into
yohann/graph-supportfrom
yohann/graph-execute
Open

Add initial graph-native AMS wrappers for homogeneous and heterogeneous graph inputs#191
YohannDudouit wants to merge 3 commits into
yohann/graph-supportfrom
yohann/graph-execute

Conversation

@YohannDudouit
Copy link
Copy Markdown
Collaborator

This PR introduces a first graph-native execution path in AMS for applications that want to use graph-structured inputs, in particular for ND surrogate modeling use cases where the natural unit of work is a whole graph per call rather than a batch of independent 0D samples.

The main goal of this branch is to generalize the wrappers only. It adds public and internal plumbing for homogeneous and heterogeneous graph inputs without trying to retrofit the existing tensor-centric AMS workflow to graphs.

In particular, this PR adds graph-native callback and AMSExecute overloads and keeps the graph as a graph throughout the wrapper path. It does not force graph execution through the existing tensor callback ABI, and it does not yet implement graph-aware surrogate batching, partitioning, slicing, scattering, or DB storage.

@YohannDudouit YohannDudouit self-assigned this Apr 22, 2026
Copy link
Copy Markdown

@github-actions github-actions Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cpp-linter Review

Used clang-format v18.1.8

Click here for the full clang-format patch
diff --git a/tests/AMSlib/ams_interface/test_graph_fallback.cpp b/tests/AMSlib/ams_interface/test_graph_fallback.cpp
index e8069a7..6803783 100644
--- a/tests/AMSlib/ams_interface/test_graph_fallback.cpp
+++ b/tests/AMSlib/ams_interface/test_graph_fallback.cpp
@@ -56,23 +56,23 @@ CATCH_TEST_CASE("AMSExecute homogeneous graph fallback path", "[wf][graph]")
-  HomogeneousGraphDomainFn callback =
-      [&](const AMSHomogeneousGraph& g, SmallVector<AMSTensor>& outputs) {
-        callback_invoked = true;
-
-        // Verify graph structure
-        CATCH_REQUIRE(containsTensor(g, "node_features"));
-        const auto* features = findTensor(g, "node_features");
-        CATCH_REQUIRE(features != nullptr);
-        CATCH_REQUIRE(features->shape()[0] == 10);
-        CATCH_REQUIRE(features->shape()[1] == 3);
-
-        // Verify input data
-        const float* features_data = features->data<float>();
-        CATCH_REQUIRE(features_data[0] == 0.0f);
-        CATCH_REQUIRE(features_data[29] == 29.0f);
-
-        // Fill outputs with computation result
-        CATCH_REQUIRE(outputs.size() == 1);
-        float* out_data = outputs[0].data<float>();
-        for (int i = 0; i < 20; ++i) {
-          out_data[i] = static_cast<float>(i * 2);
-        }
-      };
+  HomogeneousGraphDomainFn callback = [&](const AMSHomogeneousGraph& g,
+                                          SmallVector<AMSTensor>& outputs) {
+    callback_invoked = true;
+
+    // Verify graph structure
+    CATCH_REQUIRE(containsTensor(g, "node_features"));
+    const auto* features = findTensor(g, "node_features");
+    CATCH_REQUIRE(features != nullptr);
+    CATCH_REQUIRE(features->shape()[0] == 10);
+    CATCH_REQUIRE(features->shape()[1] == 3);
+
+    // Verify input data
+    const float* features_data = features->data<float>();
+    CATCH_REQUIRE(features_data[0] == 0.0f);
+    CATCH_REQUIRE(features_data[29] == 29.0f);
+
+    // Fill outputs with computation result
+    CATCH_REQUIRE(outputs.size() == 1);
+    float* out_data = outputs[0].data<float>();
+    for (int i = 0; i < 20; ++i) {
+      out_data[i] = static_cast<float>(i * 2);
+    }
+  };
@@ -125 +125,2 @@ CATCH_TEST_CASE("AMSExecute heterogeneous graph fallback path", "[wf][graph]")
-  auto& edge_store = graph.getOrCreateEdgeStore(EdgeType{"atom", "bond", "atom"});
+  auto& edge_store =
+      graph.getOrCreateEdgeStore(EdgeType{"atom", "bond", "atom"});
@@ -175,39 +176,39 @@ CATCH_TEST_CASE("AMSExecute heterogeneous graph fallback path", "[wf][graph]")
-  HeterogeneousGraphDomainFn callback =
-      [&](const AMSHeterogeneousGraph& g, SmallVector<AMSTensor>& outputs) {
-        callback_invoked = true;
-
-        // Verify graph structure
-        CATCH_REQUIRE(g.containsNodeStore("atom"));
-        const auto* atom_store_ptr = g.findNodeStore("atom");
-        CATCH_REQUIRE(atom_store_ptr != nullptr);
-        CATCH_REQUIRE(containsTensor(*atom_store_ptr, "features"));
-
-        // Verify node data
-        const auto* features = findTensor(*atom_store_ptr, "features");
-        CATCH_REQUIRE(features != nullptr);
-        const float* features_data = features->data<float>();
-        CATCH_REQUIRE(features_data[0] == 1.0f);
-        CATCH_REQUIRE(features_data[9] == 10.0f);
-
-        // Verify edge store
-        CATCH_REQUIRE(g.containsEdgeStore(EdgeType{"atom", "bond", "atom"}));
-        const auto* edge_store_ptr =
-            g.findEdgeStore(EdgeType{"atom", "bond", "atom"});
-        CATCH_REQUIRE(edge_store_ptr != nullptr);
-        CATCH_REQUIRE(containsTensor(*edge_store_ptr, "edge_index"));
-
-        // Verify global store
-        CATCH_REQUIRE(containsTensor(g.global_store, "global"));
-        const auto* global = findTensor(g.global_store, "global");
-        CATCH_REQUIRE(global != nullptr);
-        const float* global_data = global->data<float>();
-        CATCH_REQUIRE(global_data[0] == 0.0f);
-        CATCH_REQUIRE(global_data[3] == 30.0f);
-
-        // Fill outputs
-        CATCH_REQUIRE(outputs.size() == 1);
-        float* out_data = outputs[0].data<float>();
-        for (int i = 0; i < 5; ++i) {
-          out_data[i] = static_cast<float>(i * 3);
-        }
-      };
+  HeterogeneousGraphDomainFn callback = [&](const AMSHeterogeneousGraph& g,
+                                            SmallVector<AMSTensor>& outputs) {
+    callback_invoked = true;
+
+    // Verify graph structure
+    CATCH_REQUIRE(g.containsNodeStore("atom"));
+    const auto* atom_store_ptr = g.findNodeStore("atom");
+    CATCH_REQUIRE(atom_store_ptr != nullptr);
+    CATCH_REQUIRE(containsTensor(*atom_store_ptr, "features"));
+
+    // Verify node data
+    const auto* features = findTensor(*atom_store_ptr, "features");
+    CATCH_REQUIRE(features != nullptr);
+    const float* features_data = features->data<float>();
+    CATCH_REQUIRE(features_data[0] == 1.0f);
+    CATCH_REQUIRE(features_data[9] == 10.0f);
+
+    // Verify edge store
+    CATCH_REQUIRE(g.containsEdgeStore(EdgeType{"atom", "bond", "atom"}));
+    const auto* edge_store_ptr =
+        g.findEdgeStore(EdgeType{"atom", "bond", "atom"});
+    CATCH_REQUIRE(edge_store_ptr != nullptr);
+    CATCH_REQUIRE(containsTensor(*edge_store_ptr, "edge_index"));
+
+    // Verify global store
+    CATCH_REQUIRE(containsTensor(g.global_store, "global"));
+    const auto* global = findTensor(g.global_store, "global");
+    CATCH_REQUIRE(global != nullptr);
+    const float* global_data = global->data<float>();
+    CATCH_REQUIRE(global_data[0] == 0.0f);
+    CATCH_REQUIRE(global_data[3] == 30.0f);
+
+    // Fill outputs
+    CATCH_REQUIRE(outputs.size() == 1);
+    float* out_data = outputs[0].data<float>();
+    for (int i = 0; i < 5; ++i) {
+      out_data[i] = static_cast<float>(i * 3);
+    }
+  };
@@ -237,2 +238,2 @@ CATCH_TEST_CASE("Graph callback type safety", "[wf][graph]")
-  HomogeneousGraphDomainFn homo_fn =
-      [](const AMSHomogeneousGraph&, SmallVector<AMSTensor>&) {};
+  HomogeneousGraphDomainFn homo_fn = [](const AMSHomogeneousGraph&,
+                                        SmallVector<AMSTensor>&) {};
@@ -240,2 +241,2 @@ CATCH_TEST_CASE("Graph callback type safety", "[wf][graph]")
-  HeterogeneousGraphDomainFn hetero_fn =
-      [](const AMSHeterogeneousGraph&, SmallVector<AMSTensor>&) {};
+  HeterogeneousGraphDomainFn hetero_fn = [](const AMSHeterogeneousGraph&,
+                                            SmallVector<AMSTensor>&) {};

Have any feedback or feature suggestions? Share it here.

Comment thread tests/AMSlib/ams_interface/test_graph_fallback.cpp Outdated
Comment thread tests/AMSlib/ams_interface/test_graph_fallback.cpp Outdated
Comment thread tests/AMSlib/ams_interface/test_graph_fallback.cpp Outdated
Comment thread tests/AMSlib/ams_interface/test_graph_fallback.cpp Outdated
Comment thread tests/AMSlib/ams_interface/test_graph_fallback.cpp Outdated
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant