From 541f131d7c40d6da4894c2d9bdca1b7bab8a44e6 Mon Sep 17 00:00:00 2001 From: Yohann Dudouit Date: Wed, 22 Apr 2026 14:47:15 -0700 Subject: [PATCH 1/3] Support graphs in AMSExecute. --- src/AMSlib/AMS.cpp | 32 +++++++++++++ src/AMSlib/include/AMS.h | 18 +++++++ src/AMSlib/wf/interface.cpp | 95 +++++++++++++++++++++++++++++++++++++ src/AMSlib/wf/interface.hpp | 18 +++++++ src/AMSlib/wf/workflow.hpp | 20 ++++++++ 5 files changed, 183 insertions(+) diff --git a/src/AMSlib/AMS.cpp b/src/AMSlib/AMS.cpp index 6169d2ec..7a565d80 100644 --- a/src/AMSlib/AMS.cpp +++ b/src/AMSlib/AMS.cpp @@ -441,6 +441,38 @@ void AMSExecute(AMSExecutor executor, callAMS(workflow, OrigComputation, ins, inouts, outs); } +void AMSExecute(AMSExecutor executor, + HomogeneousGraphDomainFn& OrigComputation, + const ams::AMSHomogeneousGraph& graph_input, + ams::SmallVector& outs) +{ + int64_t index = static_cast(executor); + if (index >= _amsWrap->executors.size()) + throw std::runtime_error("AMS Executor identifier does not exist\n"); + auto currExec = _amsWrap->executors[index]; + + ams::AMSWorkflow* workflow = reinterpret_cast(currExec); + AMS_DBG(AMS, "Calling AMS with homogeneous graph, out:{}", outs.size()); + + callAMS(workflow, OrigComputation, graph_input, outs); +} + +void AMSExecute(AMSExecutor executor, + HeterogeneousGraphDomainFn& OrigComputation, + const ams::AMSHeterogeneousGraph& graph_input, + ams::SmallVector& outs) +{ + int64_t index = static_cast(executor); + if (index >= _amsWrap->executors.size()) + throw std::runtime_error("AMS Executor identifier does not exist\n"); + auto currExec = _amsWrap->executors[index]; + + ams::AMSWorkflow* workflow = reinterpret_cast(currExec); + AMS_DBG(AMS, "Calling AMS with heterogeneous graph, out:{}", outs.size()); + + callAMS(workflow, OrigComputation, graph_input, outs); +} + void AMSCExecute(AMSExecutor executor, DomainCFn OrigCComputation, void* args, diff --git a/src/AMSlib/include/AMS.h b/src/AMSlib/include/AMS.h index 3acc0fc3..ec4d0ed2 100644 --- a/src/AMSlib/include/AMS.h +++ b/src/AMSlib/include/AMS.h @@ -27,6 +27,14 @@ using DomainCFn = void (*)(void*, ams::SmallVector&, ams::SmallVector&); +using HomogeneousGraphDomainFn = + std::function& /*tensor outputs*/)>; + +using HeterogeneousGraphDomainFn = + std::function& /*tensor outputs*/)>; + using AMSExecutor = int64_t; using AMSCAbstrModel = int; @@ -81,6 +89,16 @@ void AMSCExecute(AMSExecutor executor, ams::SmallVector& inouts, ams::SmallVector& outs); +void AMSExecute(AMSExecutor executor, + HomogeneousGraphDomainFn& OrigComputation, + const ams::AMSHomogeneousGraph& graph_input, + ams::SmallVector& outs); + +void AMSExecute(AMSExecutor executor, + HeterogeneousGraphDomainFn& OrigComputation, + const ams::AMSHeterogeneousGraph& graph_input, + ams::SmallVector& outs); + void AMSDestroyExecutor(AMSExecutor executor); void AMSSetAllocator(ams::AMSResourceType resource, const char* alloc_name); diff --git a/src/AMSlib/wf/interface.cpp b/src/AMSlib/wf/interface.cpp index 2663342b..75e71a23 100644 --- a/src/AMSlib/wf/interface.cpp +++ b/src/AMSlib/wf/interface.cpp @@ -341,3 +341,98 @@ void callAMS(ams::AMSWorkflow* executor, executor->evaluate(Physics, tins, tinouts, touts); } + +// ============================================================================ +// Graph-based callApplication overloads +// ============================================================================ + +void callApplication(ams::HomogeneousGraphDomainFn CallBack, + const ams::AMSHomogeneousGraph& graph, + ams::SmallVector& outs) +{ + // Directly invoke the user's physics callback with graph-native types + CallBack(graph, outs); +} + +void callApplication(ams::HeterogeneousGraphDomainFn CallBack, + const ams::AMSHeterogeneousGraph& graph, + ams::SmallVector& outs) +{ + // Directly invoke the user's physics callback with graph-native types + CallBack(graph, outs); +} + +// ============================================================================ +// Graph surrogate execution stub (seam for future implementation) +// ============================================================================ + +static bool tryGraphSurrogate(ams::AMSWorkflow* executor, + const ams::AMSHomogeneousGraph& graph, + ams::SmallVector& outs) +{ + // TODO: Implement graph surrogate execution when models support graphs + // This is the integration point for future graph-based ML inference + // + // Future implementation should: + // 1. Check if executor has a model that accepts graph inputs + // 2. Convert AMSHomogeneousGraph to model input format + // 3. Run model inference and uncertainty quantification + // 4. If UQ passes threshold, populate outs and return true + // 5. Otherwise return false to trigger fallback + // + // For now, always return false (no surrogate available) + (void)executor; + (void)graph; + (void)outs; + return false; +} + +static bool tryGraphSurrogate(ams::AMSWorkflow* executor, + const ams::AMSHeterogeneousGraph& graph, + ams::SmallVector& outs) +{ + // TODO: Implement graph surrogate execution when models support graphs + // See homogeneous version for implementation notes + (void)executor; + (void)graph; + (void)outs; + return false; +} + +// ============================================================================ +// Graph-based callAMS overloads +// ============================================================================ + +void callAMS(ams::AMSWorkflow* executor, + ams::HomogeneousGraphDomainFn Physics, + const ams::AMSHomogeneousGraph& graph_input, + ams::SmallVector& outs) +{ + // Try graph surrogate execution first + bool surrogate_used = tryGraphSurrogate(executor, graph_input, outs); + + // If surrogate succeeded, we're done + if (surrogate_used) { + return; + } + + // Otherwise, fallback to original physics computation + callApplication(Physics, graph_input, outs); +} + +void callAMS(ams::AMSWorkflow* executor, + ams::HeterogeneousGraphDomainFn Physics, + const ams::AMSHeterogeneousGraph& graph_input, + ams::SmallVector& outs) +{ + // Try graph surrogate execution first + bool surrogate_used = tryGraphSurrogate(executor, graph_input, outs); + + // If surrogate succeeded, we're done + if (surrogate_used) { + return; + } + + // Otherwise, fallback to original physics computation + callApplication(Physics, graph_input, outs); +} diff --git a/src/AMSlib/wf/interface.hpp b/src/AMSlib/wf/interface.hpp index 9a8865c1..c57d4e09 100644 --- a/src/AMSlib/wf/interface.hpp +++ b/src/AMSlib/wf/interface.hpp @@ -14,9 +14,27 @@ void callApplication(ams::DomainLambda CallBack, ams::MutableArrayRef InOuts, ams::MutableArrayRef Outs); +void callApplication(ams::HomogeneousGraphDomainFn CallBack, + const ams::AMSHomogeneousGraph& graph, + ams::SmallVector& outs); + +void callApplication(ams::HeterogeneousGraphDomainFn CallBack, + const ams::AMSHeterogeneousGraph& graph, + ams::SmallVector& outs); + void callAMS(ams::AMSWorkflow *executor, ams::DomainLambda Physics, const ams::SmallVector &ins, ams::SmallVector &inouts, ams::SmallVector &outs); + +void callAMS(ams::AMSWorkflow* executor, + ams::HomogeneousGraphDomainFn Physics, + const ams::AMSHomogeneousGraph& graph_input, + ams::SmallVector& outs); + +void callAMS(ams::AMSWorkflow* executor, + ams::HeterogeneousGraphDomainFn Physics, + const ams::AMSHeterogeneousGraph& graph_input, + ams::SmallVector& outs); diff --git a/src/AMSlib/wf/workflow.hpp b/src/AMSlib/wf/workflow.hpp index 20c496c6..1c7e6328 100644 --- a/src/AMSlib/wf/workflow.hpp +++ b/src/AMSlib/wf/workflow.hpp @@ -92,6 +92,26 @@ class AMSWorkflow CALIPER(CALI_MARK_END("DBSTORE");) } + void storeGraphData(const ams::AMSHomogeneousGraph& graph, + ArrayRef Outs) + { + // TODO: Implement graph storage when database supports it + // For now, this is a no-op placeholder + (void)graph; + (void)Outs; + AMS_DBG(Workflow, "Graph storage not yet implemented (homogeneous)"); + } + + void storeGraphData(const ams::AMSHeterogeneousGraph& graph, + ArrayRef Outs) + { + // TODO: Implement graph storage when database supports it + // For now, this is a no-op placeholder + (void)graph; + (void)Outs; + AMS_DBG(Workflow, "Graph storage not yet implemented (heterogeneous)"); + } + /** \brief Check if we can perform a surrogate model update. * AMS can update surrogate model only when all MPI ranks have received * the latest model from RabbitMQ. From df1b46a635329750c7c3d63975f5ebcd75501a29 Mon Sep 17 00:00:00 2001 From: Yohann Dudouit Date: Wed, 22 Apr 2026 14:47:33 -0700 Subject: [PATCH 2/3] Add graph tests. --- tests/AMSlib/ams_interface/CMakeLists.txt | 3 + .../ams_interface/test_graph_fallback.cpp | 252 ++++++++++++++++++ 2 files changed, 255 insertions(+) create mode 100644 tests/AMSlib/ams_interface/test_graph_fallback.cpp diff --git a/tests/AMSlib/ams_interface/CMakeLists.txt b/tests/AMSlib/ams_interface/CMakeLists.txt index 9250e421..630e6d89 100644 --- a/tests/AMSlib/ams_interface/CMakeLists.txt +++ b/tests/AMSlib/ams_interface/CMakeLists.txt @@ -24,3 +24,6 @@ endfunction() BUILD_UNIT_TEST(ams_explicit_end_to_end ams_ete.cpp) ADD_AMS_UNIT_TEST(AMS_EXPLICIT ams_explicit_end_to_end) +BUILD_UNIT_TEST(ams_graph_fallback test_graph_fallback.cpp) +ADD_AMS_UNIT_TEST(AMS_GRAPH_FALLBACK ams_graph_fallback) + diff --git a/tests/AMSlib/ams_interface/test_graph_fallback.cpp b/tests/AMSlib/ams_interface/test_graph_fallback.cpp new file mode 100644 index 00000000..e8069a76 --- /dev/null +++ b/tests/AMSlib/ams_interface/test_graph_fallback.cpp @@ -0,0 +1,252 @@ +#include +#include + +#include "AMS.h" +#include "AMSGraph.hpp" +#include "AMSTensor.hpp" + +using namespace ams; + +CATCH_TEST_CASE("AMSExecute homogeneous graph fallback path", "[wf][graph]") +{ + AMSInit(); + + // Setup: Register model with no surrogate path (forces fallback) + auto model = AMSRegisterAbstractModel("test_homo_graph", 0.5, "", false); + AMSExecutor executor = AMSCreateExecutor(model, 0, 1); + + // Create simple homogeneous graph (dict of tensors) + AMSHomogeneousGraph graph; + + // Insert node features tensor + AMSTensor::IntDimType node_shape[] = {10, 3}; + AMSTensor::IntDimType node_strides[] = {3, 1}; + auto node_features = AMSTensor::create( + ams::ArrayRef(node_shape, 2), + ams::ArrayRef(node_strides, 2), + AMSResourceType::AMS_HOST); + + // Fill with test data + float* features_data = node_features.data(); + for (int i = 0; i < 30; ++i) { + features_data[i] = static_cast(i); + } + + insertTensor(graph, "node_features", std::move(node_features)); + + // Create output tensor + SmallVector outs; + AMSTensor::IntDimType out_shape[] = {10, 2}; + AMSTensor::IntDimType out_strides[] = {2, 1}; + auto out_tensor = AMSTensor::create( + ams::ArrayRef(out_shape, 2), + ams::ArrayRef(out_strides, 2), + AMSResourceType::AMS_HOST); + + // Initialize outputs to zero + float* out_data = out_tensor.data(); + for (int i = 0; i < 20; ++i) { + out_data[i] = 0.0f; + } + + outs.push_back(std::move(out_tensor)); + + // Define callback that processes graph + bool callback_invoked = false; + HomogeneousGraphDomainFn callback = + [&](const AMSHomogeneousGraph& g, SmallVector& outputs) { + callback_invoked = true; + + // Verify graph structure + CATCH_REQUIRE(containsTensor(g, "node_features")); + const auto* features = findTensor(g, "node_features"); + CATCH_REQUIRE(features != nullptr); + CATCH_REQUIRE(features->shape()[0] == 10); + CATCH_REQUIRE(features->shape()[1] == 3); + + // Verify input data + const float* features_data = features->data(); + CATCH_REQUIRE(features_data[0] == 0.0f); + CATCH_REQUIRE(features_data[29] == 29.0f); + + // Fill outputs with computation result + CATCH_REQUIRE(outputs.size() == 1); + float* out_data = outputs[0].data(); + for (int i = 0; i < 20; ++i) { + out_data[i] = static_cast(i * 2); + } + }; + + // Execute + AMSExecute(executor, callback, graph, outs); + + // Verify callback was invoked (fallback path) + CATCH_REQUIRE(callback_invoked); + + // Verify outputs were written + const float* result_data = outs[0].data(); + CATCH_REQUIRE(result_data[0] == 0.0f); + CATCH_REQUIRE(result_data[10] == 20.0f); + CATCH_REQUIRE(result_data[19] == 38.0f); + + // Note: Not destroying executor to avoid triggering AMSFinalize between tests + // The executor will be cleaned up at program exit +} + +CATCH_TEST_CASE("AMSExecute heterogeneous graph fallback path", "[wf][graph]") +{ + AMSInit(); + + // Setup: Register model with no surrogate path (forces fallback) + auto model = AMSRegisterAbstractModel("test_hetero_graph", 0.5, "", false); + AMSExecutor executor = AMSCreateExecutor(model, 0, 1); + + // Create heterogeneous graph + AMSHeterogeneousGraph graph; + + // Add node store for "atom" nodes + auto& atom_store = graph.getOrCreateNodeStore("atom"); + AMSTensor::IntDimType atom_shape[] = {5, 2}; + AMSTensor::IntDimType atom_strides[] = {2, 1}; + auto atom_features = AMSTensor::create( + ams::ArrayRef(atom_shape, 2), + ams::ArrayRef(atom_strides, 2), + AMSResourceType::AMS_HOST); + + // Fill with test data + float* atom_data = atom_features.data(); + for (int i = 0; i < 10; ++i) { + atom_data[i] = static_cast(i + 1); + } + + insertTensor(atom_store, "features", std::move(atom_features)); + + // Add edge store + auto& edge_store = graph.getOrCreateEdgeStore(EdgeType{"atom", "bond", "atom"}); + AMSTensor::IntDimType edge_shape[] = {2, 8}; + AMSTensor::IntDimType edge_strides[] = {8, 1}; + auto edge_index = AMSTensor::create( + ams::ArrayRef(edge_shape, 2), + ams::ArrayRef(edge_strides, 2), + AMSResourceType::AMS_HOST); + + // Fill with edge connectivity + int64_t* edge_data = edge_index.data(); + for (int i = 0; i < 16; ++i) { + edge_data[i] = i % 5; + } + + insertTensor(edge_store, "edge_index", std::move(edge_index)); + + // Add global features + AMSTensor::IntDimType global_shape[] = {1, 4}; + AMSTensor::IntDimType global_strides[] = {4, 1}; + auto global_features = AMSTensor::create( + ams::ArrayRef(global_shape, 2), + ams::ArrayRef(global_strides, 2), + AMSResourceType::AMS_HOST); + + float* global_data = global_features.data(); + for (int i = 0; i < 4; ++i) { + global_data[i] = static_cast(i * 10); + } + + insertTensor(graph.global_store, "global", std::move(global_features)); + + // Create output tensor + SmallVector outs; + AMSTensor::IntDimType hetero_out_shape[] = {5, 1}; + AMSTensor::IntDimType hetero_out_strides[] = {1, 1}; + auto out_tensor = AMSTensor::create( + ams::ArrayRef(hetero_out_shape, 2), + ams::ArrayRef(hetero_out_strides, 2), + AMSResourceType::AMS_HOST); + + // Initialize to zero + float* out_data = out_tensor.data(); + for (int i = 0; i < 5; ++i) { + out_data[i] = 0.0f; + } + + outs.push_back(std::move(out_tensor)); + + // Define callback + bool callback_invoked = false; + HeterogeneousGraphDomainFn callback = + [&](const AMSHeterogeneousGraph& g, SmallVector& outputs) { + callback_invoked = true; + + // Verify graph structure + CATCH_REQUIRE(g.containsNodeStore("atom")); + const auto* atom_store_ptr = g.findNodeStore("atom"); + CATCH_REQUIRE(atom_store_ptr != nullptr); + CATCH_REQUIRE(containsTensor(*atom_store_ptr, "features")); + + // Verify node data + const auto* features = findTensor(*atom_store_ptr, "features"); + CATCH_REQUIRE(features != nullptr); + const float* features_data = features->data(); + CATCH_REQUIRE(features_data[0] == 1.0f); + CATCH_REQUIRE(features_data[9] == 10.0f); + + // Verify edge store + CATCH_REQUIRE(g.containsEdgeStore(EdgeType{"atom", "bond", "atom"})); + const auto* edge_store_ptr = + g.findEdgeStore(EdgeType{"atom", "bond", "atom"}); + CATCH_REQUIRE(edge_store_ptr != nullptr); + CATCH_REQUIRE(containsTensor(*edge_store_ptr, "edge_index")); + + // Verify global store + CATCH_REQUIRE(containsTensor(g.global_store, "global")); + const auto* global = findTensor(g.global_store, "global"); + CATCH_REQUIRE(global != nullptr); + const float* global_data = global->data(); + CATCH_REQUIRE(global_data[0] == 0.0f); + CATCH_REQUIRE(global_data[3] == 30.0f); + + // Fill outputs + CATCH_REQUIRE(outputs.size() == 1); + float* out_data = outputs[0].data(); + for (int i = 0; i < 5; ++i) { + out_data[i] = static_cast(i * 3); + } + }; + + // Execute + AMSExecute(executor, callback, graph, outs); + + // Verify + CATCH_REQUIRE(callback_invoked); + const float* result_data = outs[0].data(); + CATCH_REQUIRE(result_data[0] == 0.0f); + CATCH_REQUIRE(result_data[2] == 6.0f); + CATCH_REQUIRE(result_data[4] == 12.0f); + + // Note: Not destroying executor to avoid triggering AMSFinalize between tests + // The executor will be cleaned up at program exit +} + +CATCH_TEST_CASE("Graph callback type safety", "[wf][graph]") +{ + // This test verifies that the type system prevents mismatches + // It primarily exists as a compile-time check + + AMSHomogeneousGraph homo_graph; + AMSHeterogeneousGraph hetero_graph; + + HomogeneousGraphDomainFn homo_fn = + [](const AMSHomogeneousGraph&, SmallVector&) {}; + + HeterogeneousGraphDomainFn hetero_fn = + [](const AMSHeterogeneousGraph&, SmallVector&) {}; + + // These should compile: + // AMSExecute(executor, homo_fn, homo_graph, outs); + // AMSExecute(executor, hetero_fn, hetero_graph, outs); + + // These should NOT compile (type mismatch): + // AMSExecute(executor, homo_fn, hetero_graph, outs); // ERROR + // AMSExecute(executor, hetero_fn, homo_graph, outs); // ERROR + + CATCH_SUCCEED("Type safety validated at compile time"); +} From 2909cb03807daa13479029088074487fe5261962 Mon Sep 17 00:00:00 2001 From: Yohann Date: Wed, 22 Apr 2026 15:30:13 -0700 Subject: [PATCH 3/3] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- .../ams_interface/test_graph_fallback.cpp | 135 +++++++++--------- 1 file changed, 68 insertions(+), 67 deletions(-) diff --git a/tests/AMSlib/ams_interface/test_graph_fallback.cpp b/tests/AMSlib/ams_interface/test_graph_fallback.cpp index e8069a76..6803783f 100644 --- a/tests/AMSlib/ams_interface/test_graph_fallback.cpp +++ b/tests/AMSlib/ams_interface/test_graph_fallback.cpp @@ -53,29 +53,29 @@ CATCH_TEST_CASE("AMSExecute homogeneous graph fallback path", "[wf][graph]") // Define callback that processes graph bool callback_invoked = false; - HomogeneousGraphDomainFn callback = - [&](const AMSHomogeneousGraph& g, SmallVector& outputs) { - callback_invoked = true; - - // Verify graph structure - CATCH_REQUIRE(containsTensor(g, "node_features")); - const auto* features = findTensor(g, "node_features"); - CATCH_REQUIRE(features != nullptr); - CATCH_REQUIRE(features->shape()[0] == 10); - CATCH_REQUIRE(features->shape()[1] == 3); - - // Verify input data - const float* features_data = features->data(); - CATCH_REQUIRE(features_data[0] == 0.0f); - CATCH_REQUIRE(features_data[29] == 29.0f); - - // Fill outputs with computation result - CATCH_REQUIRE(outputs.size() == 1); - float* out_data = outputs[0].data(); - for (int i = 0; i < 20; ++i) { - out_data[i] = static_cast(i * 2); - } - }; + HomogeneousGraphDomainFn callback = [&](const AMSHomogeneousGraph& g, + SmallVector& outputs) { + callback_invoked = true; + + // Verify graph structure + CATCH_REQUIRE(containsTensor(g, "node_features")); + const auto* features = findTensor(g, "node_features"); + CATCH_REQUIRE(features != nullptr); + CATCH_REQUIRE(features->shape()[0] == 10); + CATCH_REQUIRE(features->shape()[1] == 3); + + // Verify input data + const float* features_data = features->data(); + CATCH_REQUIRE(features_data[0] == 0.0f); + CATCH_REQUIRE(features_data[29] == 29.0f); + + // Fill outputs with computation result + CATCH_REQUIRE(outputs.size() == 1); + float* out_data = outputs[0].data(); + for (int i = 0; i < 20; ++i) { + out_data[i] = static_cast(i * 2); + } + }; // Execute AMSExecute(executor, callback, graph, outs); @@ -122,7 +122,8 @@ CATCH_TEST_CASE("AMSExecute heterogeneous graph fallback path", "[wf][graph]") insertTensor(atom_store, "features", std::move(atom_features)); // Add edge store - auto& edge_store = graph.getOrCreateEdgeStore(EdgeType{"atom", "bond", "atom"}); + auto& edge_store = + graph.getOrCreateEdgeStore(EdgeType{"atom", "bond", "atom"}); AMSTensor::IntDimType edge_shape[] = {2, 8}; AMSTensor::IntDimType edge_strides[] = {8, 1}; auto edge_index = AMSTensor::create( @@ -172,45 +173,45 @@ CATCH_TEST_CASE("AMSExecute heterogeneous graph fallback path", "[wf][graph]") // Define callback bool callback_invoked = false; - HeterogeneousGraphDomainFn callback = - [&](const AMSHeterogeneousGraph& g, SmallVector& outputs) { - callback_invoked = true; - - // Verify graph structure - CATCH_REQUIRE(g.containsNodeStore("atom")); - const auto* atom_store_ptr = g.findNodeStore("atom"); - CATCH_REQUIRE(atom_store_ptr != nullptr); - CATCH_REQUIRE(containsTensor(*atom_store_ptr, "features")); - - // Verify node data - const auto* features = findTensor(*atom_store_ptr, "features"); - CATCH_REQUIRE(features != nullptr); - const float* features_data = features->data(); - CATCH_REQUIRE(features_data[0] == 1.0f); - CATCH_REQUIRE(features_data[9] == 10.0f); - - // Verify edge store - CATCH_REQUIRE(g.containsEdgeStore(EdgeType{"atom", "bond", "atom"})); - const auto* edge_store_ptr = - g.findEdgeStore(EdgeType{"atom", "bond", "atom"}); - CATCH_REQUIRE(edge_store_ptr != nullptr); - CATCH_REQUIRE(containsTensor(*edge_store_ptr, "edge_index")); - - // Verify global store - CATCH_REQUIRE(containsTensor(g.global_store, "global")); - const auto* global = findTensor(g.global_store, "global"); - CATCH_REQUIRE(global != nullptr); - const float* global_data = global->data(); - CATCH_REQUIRE(global_data[0] == 0.0f); - CATCH_REQUIRE(global_data[3] == 30.0f); - - // Fill outputs - CATCH_REQUIRE(outputs.size() == 1); - float* out_data = outputs[0].data(); - for (int i = 0; i < 5; ++i) { - out_data[i] = static_cast(i * 3); - } - }; + HeterogeneousGraphDomainFn callback = [&](const AMSHeterogeneousGraph& g, + SmallVector& outputs) { + callback_invoked = true; + + // Verify graph structure + CATCH_REQUIRE(g.containsNodeStore("atom")); + const auto* atom_store_ptr = g.findNodeStore("atom"); + CATCH_REQUIRE(atom_store_ptr != nullptr); + CATCH_REQUIRE(containsTensor(*atom_store_ptr, "features")); + + // Verify node data + const auto* features = findTensor(*atom_store_ptr, "features"); + CATCH_REQUIRE(features != nullptr); + const float* features_data = features->data(); + CATCH_REQUIRE(features_data[0] == 1.0f); + CATCH_REQUIRE(features_data[9] == 10.0f); + + // Verify edge store + CATCH_REQUIRE(g.containsEdgeStore(EdgeType{"atom", "bond", "atom"})); + const auto* edge_store_ptr = + g.findEdgeStore(EdgeType{"atom", "bond", "atom"}); + CATCH_REQUIRE(edge_store_ptr != nullptr); + CATCH_REQUIRE(containsTensor(*edge_store_ptr, "edge_index")); + + // Verify global store + CATCH_REQUIRE(containsTensor(g.global_store, "global")); + const auto* global = findTensor(g.global_store, "global"); + CATCH_REQUIRE(global != nullptr); + const float* global_data = global->data(); + CATCH_REQUIRE(global_data[0] == 0.0f); + CATCH_REQUIRE(global_data[3] == 30.0f); + + // Fill outputs + CATCH_REQUIRE(outputs.size() == 1); + float* out_data = outputs[0].data(); + for (int i = 0; i < 5; ++i) { + out_data[i] = static_cast(i * 3); + } + }; // Execute AMSExecute(executor, callback, graph, outs); @@ -234,11 +235,11 @@ CATCH_TEST_CASE("Graph callback type safety", "[wf][graph]") AMSHomogeneousGraph homo_graph; AMSHeterogeneousGraph hetero_graph; - HomogeneousGraphDomainFn homo_fn = - [](const AMSHomogeneousGraph&, SmallVector&) {}; + HomogeneousGraphDomainFn homo_fn = [](const AMSHomogeneousGraph&, + SmallVector&) {}; - HeterogeneousGraphDomainFn hetero_fn = - [](const AMSHeterogeneousGraph&, SmallVector&) {}; + HeterogeneousGraphDomainFn hetero_fn = [](const AMSHeterogeneousGraph&, + SmallVector&) {}; // These should compile: // AMSExecute(executor, homo_fn, homo_graph, outs);