diff --git a/src/AMSlib/AMS.cpp b/src/AMSlib/AMS.cpp index 6169d2ec..7a565d80 100644 --- a/src/AMSlib/AMS.cpp +++ b/src/AMSlib/AMS.cpp @@ -441,6 +441,38 @@ void AMSExecute(AMSExecutor executor, callAMS(workflow, OrigComputation, ins, inouts, outs); } +void AMSExecute(AMSExecutor executor, + HomogeneousGraphDomainFn& OrigComputation, + const ams::AMSHomogeneousGraph& graph_input, + ams::SmallVector& outs) +{ + int64_t index = static_cast(executor); + if (index >= _amsWrap->executors.size()) + throw std::runtime_error("AMS Executor identifier does not exist\n"); + auto currExec = _amsWrap->executors[index]; + + ams::AMSWorkflow* workflow = reinterpret_cast(currExec); + AMS_DBG(AMS, "Calling AMS with homogeneous graph, out:{}", outs.size()); + + callAMS(workflow, OrigComputation, graph_input, outs); +} + +void AMSExecute(AMSExecutor executor, + HeterogeneousGraphDomainFn& OrigComputation, + const ams::AMSHeterogeneousGraph& graph_input, + ams::SmallVector& outs) +{ + int64_t index = static_cast(executor); + if (index >= _amsWrap->executors.size()) + throw std::runtime_error("AMS Executor identifier does not exist\n"); + auto currExec = _amsWrap->executors[index]; + + ams::AMSWorkflow* workflow = reinterpret_cast(currExec); + AMS_DBG(AMS, "Calling AMS with heterogeneous graph, out:{}", outs.size()); + + callAMS(workflow, OrigComputation, graph_input, outs); +} + void AMSCExecute(AMSExecutor executor, DomainCFn OrigCComputation, void* args, diff --git a/src/AMSlib/include/AMS.h b/src/AMSlib/include/AMS.h index 3acc0fc3..ec4d0ed2 100644 --- a/src/AMSlib/include/AMS.h +++ b/src/AMSlib/include/AMS.h @@ -27,6 +27,14 @@ using DomainCFn = void (*)(void*, ams::SmallVector&, ams::SmallVector&); +using HomogeneousGraphDomainFn = + std::function& /*tensor outputs*/)>; + +using HeterogeneousGraphDomainFn = + std::function& /*tensor outputs*/)>; + using AMSExecutor = int64_t; using AMSCAbstrModel = int; @@ -81,6 +89,16 @@ void AMSCExecute(AMSExecutor executor, ams::SmallVector& inouts, ams::SmallVector& outs); +void AMSExecute(AMSExecutor executor, + HomogeneousGraphDomainFn& OrigComputation, + const ams::AMSHomogeneousGraph& graph_input, + ams::SmallVector& outs); + +void AMSExecute(AMSExecutor executor, + HeterogeneousGraphDomainFn& OrigComputation, + const ams::AMSHeterogeneousGraph& graph_input, + ams::SmallVector& outs); + void AMSDestroyExecutor(AMSExecutor executor); void AMSSetAllocator(ams::AMSResourceType resource, const char* alloc_name); diff --git a/src/AMSlib/wf/interface.cpp b/src/AMSlib/wf/interface.cpp index 2663342b..75e71a23 100644 --- a/src/AMSlib/wf/interface.cpp +++ b/src/AMSlib/wf/interface.cpp @@ -341,3 +341,98 @@ void callAMS(ams::AMSWorkflow* executor, executor->evaluate(Physics, tins, tinouts, touts); } + +// ============================================================================ +// Graph-based callApplication overloads +// ============================================================================ + +void callApplication(ams::HomogeneousGraphDomainFn CallBack, + const ams::AMSHomogeneousGraph& graph, + ams::SmallVector& outs) +{ + // Directly invoke the user's physics callback with graph-native types + CallBack(graph, outs); +} + +void callApplication(ams::HeterogeneousGraphDomainFn CallBack, + const ams::AMSHeterogeneousGraph& graph, + ams::SmallVector& outs) +{ + // Directly invoke the user's physics callback with graph-native types + CallBack(graph, outs); +} + +// ============================================================================ +// Graph surrogate execution stub (seam for future implementation) +// ============================================================================ + +static bool tryGraphSurrogate(ams::AMSWorkflow* executor, + const ams::AMSHomogeneousGraph& graph, + ams::SmallVector& outs) +{ + // TODO: Implement graph surrogate execution when models support graphs + // This is the integration point for future graph-based ML inference + // + // Future implementation should: + // 1. Check if executor has a model that accepts graph inputs + // 2. Convert AMSHomogeneousGraph to model input format + // 3. Run model inference and uncertainty quantification + // 4. If UQ passes threshold, populate outs and return true + // 5. Otherwise return false to trigger fallback + // + // For now, always return false (no surrogate available) + (void)executor; + (void)graph; + (void)outs; + return false; +} + +static bool tryGraphSurrogate(ams::AMSWorkflow* executor, + const ams::AMSHeterogeneousGraph& graph, + ams::SmallVector& outs) +{ + // TODO: Implement graph surrogate execution when models support graphs + // See homogeneous version for implementation notes + (void)executor; + (void)graph; + (void)outs; + return false; +} + +// ============================================================================ +// Graph-based callAMS overloads +// ============================================================================ + +void callAMS(ams::AMSWorkflow* executor, + ams::HomogeneousGraphDomainFn Physics, + const ams::AMSHomogeneousGraph& graph_input, + ams::SmallVector& outs) +{ + // Try graph surrogate execution first + bool surrogate_used = tryGraphSurrogate(executor, graph_input, outs); + + // If surrogate succeeded, we're done + if (surrogate_used) { + return; + } + + // Otherwise, fallback to original physics computation + callApplication(Physics, graph_input, outs); +} + +void callAMS(ams::AMSWorkflow* executor, + ams::HeterogeneousGraphDomainFn Physics, + const ams::AMSHeterogeneousGraph& graph_input, + ams::SmallVector& outs) +{ + // Try graph surrogate execution first + bool surrogate_used = tryGraphSurrogate(executor, graph_input, outs); + + // If surrogate succeeded, we're done + if (surrogate_used) { + return; + } + + // Otherwise, fallback to original physics computation + callApplication(Physics, graph_input, outs); +} diff --git a/src/AMSlib/wf/interface.hpp b/src/AMSlib/wf/interface.hpp index 9a8865c1..c57d4e09 100644 --- a/src/AMSlib/wf/interface.hpp +++ b/src/AMSlib/wf/interface.hpp @@ -14,9 +14,27 @@ void callApplication(ams::DomainLambda CallBack, ams::MutableArrayRef InOuts, ams::MutableArrayRef Outs); +void callApplication(ams::HomogeneousGraphDomainFn CallBack, + const ams::AMSHomogeneousGraph& graph, + ams::SmallVector& outs); + +void callApplication(ams::HeterogeneousGraphDomainFn CallBack, + const ams::AMSHeterogeneousGraph& graph, + ams::SmallVector& outs); + void callAMS(ams::AMSWorkflow *executor, ams::DomainLambda Physics, const ams::SmallVector &ins, ams::SmallVector &inouts, ams::SmallVector &outs); + +void callAMS(ams::AMSWorkflow* executor, + ams::HomogeneousGraphDomainFn Physics, + const ams::AMSHomogeneousGraph& graph_input, + ams::SmallVector& outs); + +void callAMS(ams::AMSWorkflow* executor, + ams::HeterogeneousGraphDomainFn Physics, + const ams::AMSHeterogeneousGraph& graph_input, + ams::SmallVector& outs); diff --git a/src/AMSlib/wf/workflow.hpp b/src/AMSlib/wf/workflow.hpp index 20c496c6..1c7e6328 100644 --- a/src/AMSlib/wf/workflow.hpp +++ b/src/AMSlib/wf/workflow.hpp @@ -92,6 +92,26 @@ class AMSWorkflow CALIPER(CALI_MARK_END("DBSTORE");) } + void storeGraphData(const ams::AMSHomogeneousGraph& graph, + ArrayRef Outs) + { + // TODO: Implement graph storage when database supports it + // For now, this is a no-op placeholder + (void)graph; + (void)Outs; + AMS_DBG(Workflow, "Graph storage not yet implemented (homogeneous)"); + } + + void storeGraphData(const ams::AMSHeterogeneousGraph& graph, + ArrayRef Outs) + { + // TODO: Implement graph storage when database supports it + // For now, this is a no-op placeholder + (void)graph; + (void)Outs; + AMS_DBG(Workflow, "Graph storage not yet implemented (heterogeneous)"); + } + /** \brief Check if we can perform a surrogate model update. * AMS can update surrogate model only when all MPI ranks have received * the latest model from RabbitMQ. diff --git a/tests/AMSlib/ams_interface/CMakeLists.txt b/tests/AMSlib/ams_interface/CMakeLists.txt index 9250e421..630e6d89 100644 --- a/tests/AMSlib/ams_interface/CMakeLists.txt +++ b/tests/AMSlib/ams_interface/CMakeLists.txt @@ -24,3 +24,6 @@ endfunction() BUILD_UNIT_TEST(ams_explicit_end_to_end ams_ete.cpp) ADD_AMS_UNIT_TEST(AMS_EXPLICIT ams_explicit_end_to_end) +BUILD_UNIT_TEST(ams_graph_fallback test_graph_fallback.cpp) +ADD_AMS_UNIT_TEST(AMS_GRAPH_FALLBACK ams_graph_fallback) + diff --git a/tests/AMSlib/ams_interface/test_graph_fallback.cpp b/tests/AMSlib/ams_interface/test_graph_fallback.cpp new file mode 100644 index 00000000..6803783f --- /dev/null +++ b/tests/AMSlib/ams_interface/test_graph_fallback.cpp @@ -0,0 +1,253 @@ +#include +#include + +#include "AMS.h" +#include "AMSGraph.hpp" +#include "AMSTensor.hpp" + +using namespace ams; + +CATCH_TEST_CASE("AMSExecute homogeneous graph fallback path", "[wf][graph]") +{ + AMSInit(); + + // Setup: Register model with no surrogate path (forces fallback) + auto model = AMSRegisterAbstractModel("test_homo_graph", 0.5, "", false); + AMSExecutor executor = AMSCreateExecutor(model, 0, 1); + + // Create simple homogeneous graph (dict of tensors) + AMSHomogeneousGraph graph; + + // Insert node features tensor + AMSTensor::IntDimType node_shape[] = {10, 3}; + AMSTensor::IntDimType node_strides[] = {3, 1}; + auto node_features = AMSTensor::create( + ams::ArrayRef(node_shape, 2), + ams::ArrayRef(node_strides, 2), + AMSResourceType::AMS_HOST); + + // Fill with test data + float* features_data = node_features.data(); + for (int i = 0; i < 30; ++i) { + features_data[i] = static_cast(i); + } + + insertTensor(graph, "node_features", std::move(node_features)); + + // Create output tensor + SmallVector outs; + AMSTensor::IntDimType out_shape[] = {10, 2}; + AMSTensor::IntDimType out_strides[] = {2, 1}; + auto out_tensor = AMSTensor::create( + ams::ArrayRef(out_shape, 2), + ams::ArrayRef(out_strides, 2), + AMSResourceType::AMS_HOST); + + // Initialize outputs to zero + float* out_data = out_tensor.data(); + for (int i = 0; i < 20; ++i) { + out_data[i] = 0.0f; + } + + outs.push_back(std::move(out_tensor)); + + // Define callback that processes graph + bool callback_invoked = false; + HomogeneousGraphDomainFn callback = [&](const AMSHomogeneousGraph& g, + SmallVector& outputs) { + callback_invoked = true; + + // Verify graph structure + CATCH_REQUIRE(containsTensor(g, "node_features")); + const auto* features = findTensor(g, "node_features"); + CATCH_REQUIRE(features != nullptr); + CATCH_REQUIRE(features->shape()[0] == 10); + CATCH_REQUIRE(features->shape()[1] == 3); + + // Verify input data + const float* features_data = features->data(); + CATCH_REQUIRE(features_data[0] == 0.0f); + CATCH_REQUIRE(features_data[29] == 29.0f); + + // Fill outputs with computation result + CATCH_REQUIRE(outputs.size() == 1); + float* out_data = outputs[0].data(); + for (int i = 0; i < 20; ++i) { + out_data[i] = static_cast(i * 2); + } + }; + + // Execute + AMSExecute(executor, callback, graph, outs); + + // Verify callback was invoked (fallback path) + CATCH_REQUIRE(callback_invoked); + + // Verify outputs were written + const float* result_data = outs[0].data(); + CATCH_REQUIRE(result_data[0] == 0.0f); + CATCH_REQUIRE(result_data[10] == 20.0f); + CATCH_REQUIRE(result_data[19] == 38.0f); + + // Note: Not destroying executor to avoid triggering AMSFinalize between tests + // The executor will be cleaned up at program exit +} + +CATCH_TEST_CASE("AMSExecute heterogeneous graph fallback path", "[wf][graph]") +{ + AMSInit(); + + // Setup: Register model with no surrogate path (forces fallback) + auto model = AMSRegisterAbstractModel("test_hetero_graph", 0.5, "", false); + AMSExecutor executor = AMSCreateExecutor(model, 0, 1); + + // Create heterogeneous graph + AMSHeterogeneousGraph graph; + + // Add node store for "atom" nodes + auto& atom_store = graph.getOrCreateNodeStore("atom"); + AMSTensor::IntDimType atom_shape[] = {5, 2}; + AMSTensor::IntDimType atom_strides[] = {2, 1}; + auto atom_features = AMSTensor::create( + ams::ArrayRef(atom_shape, 2), + ams::ArrayRef(atom_strides, 2), + AMSResourceType::AMS_HOST); + + // Fill with test data + float* atom_data = atom_features.data(); + for (int i = 0; i < 10; ++i) { + atom_data[i] = static_cast(i + 1); + } + + insertTensor(atom_store, "features", std::move(atom_features)); + + // Add edge store + auto& edge_store = + graph.getOrCreateEdgeStore(EdgeType{"atom", "bond", "atom"}); + AMSTensor::IntDimType edge_shape[] = {2, 8}; + AMSTensor::IntDimType edge_strides[] = {8, 1}; + auto edge_index = AMSTensor::create( + ams::ArrayRef(edge_shape, 2), + ams::ArrayRef(edge_strides, 2), + AMSResourceType::AMS_HOST); + + // Fill with edge connectivity + int64_t* edge_data = edge_index.data(); + for (int i = 0; i < 16; ++i) { + edge_data[i] = i % 5; + } + + insertTensor(edge_store, "edge_index", std::move(edge_index)); + + // Add global features + AMSTensor::IntDimType global_shape[] = {1, 4}; + AMSTensor::IntDimType global_strides[] = {4, 1}; + auto global_features = AMSTensor::create( + ams::ArrayRef(global_shape, 2), + ams::ArrayRef(global_strides, 2), + AMSResourceType::AMS_HOST); + + float* global_data = global_features.data(); + for (int i = 0; i < 4; ++i) { + global_data[i] = static_cast(i * 10); + } + + insertTensor(graph.global_store, "global", std::move(global_features)); + + // Create output tensor + SmallVector outs; + AMSTensor::IntDimType hetero_out_shape[] = {5, 1}; + AMSTensor::IntDimType hetero_out_strides[] = {1, 1}; + auto out_tensor = AMSTensor::create( + ams::ArrayRef(hetero_out_shape, 2), + ams::ArrayRef(hetero_out_strides, 2), + AMSResourceType::AMS_HOST); + + // Initialize to zero + float* out_data = out_tensor.data(); + for (int i = 0; i < 5; ++i) { + out_data[i] = 0.0f; + } + + outs.push_back(std::move(out_tensor)); + + // Define callback + bool callback_invoked = false; + HeterogeneousGraphDomainFn callback = [&](const AMSHeterogeneousGraph& g, + SmallVector& outputs) { + callback_invoked = true; + + // Verify graph structure + CATCH_REQUIRE(g.containsNodeStore("atom")); + const auto* atom_store_ptr = g.findNodeStore("atom"); + CATCH_REQUIRE(atom_store_ptr != nullptr); + CATCH_REQUIRE(containsTensor(*atom_store_ptr, "features")); + + // Verify node data + const auto* features = findTensor(*atom_store_ptr, "features"); + CATCH_REQUIRE(features != nullptr); + const float* features_data = features->data(); + CATCH_REQUIRE(features_data[0] == 1.0f); + CATCH_REQUIRE(features_data[9] == 10.0f); + + // Verify edge store + CATCH_REQUIRE(g.containsEdgeStore(EdgeType{"atom", "bond", "atom"})); + const auto* edge_store_ptr = + g.findEdgeStore(EdgeType{"atom", "bond", "atom"}); + CATCH_REQUIRE(edge_store_ptr != nullptr); + CATCH_REQUIRE(containsTensor(*edge_store_ptr, "edge_index")); + + // Verify global store + CATCH_REQUIRE(containsTensor(g.global_store, "global")); + const auto* global = findTensor(g.global_store, "global"); + CATCH_REQUIRE(global != nullptr); + const float* global_data = global->data(); + CATCH_REQUIRE(global_data[0] == 0.0f); + CATCH_REQUIRE(global_data[3] == 30.0f); + + // Fill outputs + CATCH_REQUIRE(outputs.size() == 1); + float* out_data = outputs[0].data(); + for (int i = 0; i < 5; ++i) { + out_data[i] = static_cast(i * 3); + } + }; + + // Execute + AMSExecute(executor, callback, graph, outs); + + // Verify + CATCH_REQUIRE(callback_invoked); + const float* result_data = outs[0].data(); + CATCH_REQUIRE(result_data[0] == 0.0f); + CATCH_REQUIRE(result_data[2] == 6.0f); + CATCH_REQUIRE(result_data[4] == 12.0f); + + // Note: Not destroying executor to avoid triggering AMSFinalize between tests + // The executor will be cleaned up at program exit +} + +CATCH_TEST_CASE("Graph callback type safety", "[wf][graph]") +{ + // This test verifies that the type system prevents mismatches + // It primarily exists as a compile-time check + + AMSHomogeneousGraph homo_graph; + AMSHeterogeneousGraph hetero_graph; + + HomogeneousGraphDomainFn homo_fn = [](const AMSHomogeneousGraph&, + SmallVector&) {}; + + HeterogeneousGraphDomainFn hetero_fn = [](const AMSHeterogeneousGraph&, + SmallVector&) {}; + + // These should compile: + // AMSExecute(executor, homo_fn, homo_graph, outs); + // AMSExecute(executor, hetero_fn, hetero_graph, outs); + + // These should NOT compile (type mismatch): + // AMSExecute(executor, homo_fn, hetero_graph, outs); // ERROR + // AMSExecute(executor, hetero_fn, homo_graph, outs); // ERROR + + CATCH_SUCCEED("Type safety validated at compile time"); +}