Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions src/AMSlib/AMS.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,38 @@ void AMSExecute(AMSExecutor executor,
callAMS(workflow, OrigComputation, ins, inouts, outs);
}

void AMSExecute(AMSExecutor executor,
HomogeneousGraphDomainFn& OrigComputation,
const ams::AMSHomogeneousGraph& graph_input,
ams::SmallVector<ams::AMSTensor>& outs)
{
int64_t index = static_cast<int64_t>(executor);
if (index >= _amsWrap->executors.size())
throw std::runtime_error("AMS Executor identifier does not exist\n");
auto currExec = _amsWrap->executors[index];

ams::AMSWorkflow* workflow = reinterpret_cast<ams::AMSWorkflow*>(currExec);
AMS_DBG(AMS, "Calling AMS with homogeneous graph, out:{}", outs.size());

callAMS(workflow, OrigComputation, graph_input, outs);
}

void AMSExecute(AMSExecutor executor,
HeterogeneousGraphDomainFn& OrigComputation,
const ams::AMSHeterogeneousGraph& graph_input,
ams::SmallVector<ams::AMSTensor>& outs)
{
int64_t index = static_cast<int64_t>(executor);
if (index >= _amsWrap->executors.size())
throw std::runtime_error("AMS Executor identifier does not exist\n");
auto currExec = _amsWrap->executors[index];

ams::AMSWorkflow* workflow = reinterpret_cast<ams::AMSWorkflow*>(currExec);
AMS_DBG(AMS, "Calling AMS with heterogeneous graph, out:{}", outs.size());

callAMS(workflow, OrigComputation, graph_input, outs);
}

void AMSCExecute(AMSExecutor executor,
DomainCFn OrigCComputation,
void* args,
Expand Down
18 changes: 18 additions & 0 deletions src/AMSlib/include/AMS.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@ using DomainCFn = void (*)(void*,
ams::SmallVector<ams::AMSTensor>&,
ams::SmallVector<ams::AMSTensor>&);

using HomogeneousGraphDomainFn =
std::function<void(const ams::AMSHomogeneousGraph& /*graph input*/,
ams::SmallVector<ams::AMSTensor>& /*tensor outputs*/)>;

using HeterogeneousGraphDomainFn =
std::function<void(const ams::AMSHeterogeneousGraph& /*graph input*/,
ams::SmallVector<ams::AMSTensor>& /*tensor outputs*/)>;

using AMSExecutor = int64_t;
using AMSCAbstrModel = int;

Expand Down Expand Up @@ -81,6 +89,16 @@ void AMSCExecute(AMSExecutor executor,
ams::SmallVector<ams::AMSTensor>& inouts,
ams::SmallVector<ams::AMSTensor>& outs);

void AMSExecute(AMSExecutor executor,
HomogeneousGraphDomainFn& OrigComputation,
const ams::AMSHomogeneousGraph& graph_input,
ams::SmallVector<ams::AMSTensor>& outs);

void AMSExecute(AMSExecutor executor,
HeterogeneousGraphDomainFn& OrigComputation,
const ams::AMSHeterogeneousGraph& graph_input,
ams::SmallVector<ams::AMSTensor>& outs);

void AMSDestroyExecutor(AMSExecutor executor);

void AMSSetAllocator(ams::AMSResourceType resource, const char* alloc_name);
Expand Down
95 changes: 95 additions & 0 deletions src/AMSlib/wf/interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -341,3 +341,98 @@ void callAMS(ams::AMSWorkflow* executor,

executor->evaluate(Physics, tins, tinouts, touts);
}

// ============================================================================
// Graph-based callApplication overloads
// ============================================================================

void callApplication(ams::HomogeneousGraphDomainFn CallBack,
const ams::AMSHomogeneousGraph& graph,
ams::SmallVector<ams::AMSTensor>& outs)
{
// Directly invoke the user's physics callback with graph-native types
CallBack(graph, outs);
}

void callApplication(ams::HeterogeneousGraphDomainFn CallBack,
const ams::AMSHeterogeneousGraph& graph,
ams::SmallVector<ams::AMSTensor>& outs)
{
// Directly invoke the user's physics callback with graph-native types
CallBack(graph, outs);
}

// ============================================================================
// Graph surrogate execution stub (seam for future implementation)
// ============================================================================

static bool tryGraphSurrogate(ams::AMSWorkflow* executor,
const ams::AMSHomogeneousGraph& graph,
ams::SmallVector<ams::AMSTensor>& outs)
{
// TODO: Implement graph surrogate execution when models support graphs
// This is the integration point for future graph-based ML inference
//
// Future implementation should:
// 1. Check if executor has a model that accepts graph inputs
// 2. Convert AMSHomogeneousGraph to model input format
// 3. Run model inference and uncertainty quantification
// 4. If UQ passes threshold, populate outs and return true
// 5. Otherwise return false to trigger fallback
//
// For now, always return false (no surrogate available)
(void)executor;
(void)graph;
(void)outs;
return false;
}

static bool tryGraphSurrogate(ams::AMSWorkflow* executor,
const ams::AMSHeterogeneousGraph& graph,
ams::SmallVector<ams::AMSTensor>& outs)
{
// TODO: Implement graph surrogate execution when models support graphs
// See homogeneous version for implementation notes
(void)executor;
(void)graph;
(void)outs;
return false;
}

// ============================================================================
// Graph-based callAMS overloads
// ============================================================================

void callAMS(ams::AMSWorkflow* executor,
ams::HomogeneousGraphDomainFn Physics,
const ams::AMSHomogeneousGraph& graph_input,
ams::SmallVector<ams::AMSTensor>& outs)
{
// Try graph surrogate execution first
bool surrogate_used = tryGraphSurrogate(executor, graph_input, outs);

// If surrogate succeeded, we're done
if (surrogate_used) {
return;
}

// Otherwise, fallback to original physics computation
callApplication(Physics, graph_input, outs);
}

void callAMS(ams::AMSWorkflow* executor,
ams::HeterogeneousGraphDomainFn Physics,
const ams::AMSHeterogeneousGraph& graph_input,
ams::SmallVector<ams::AMSTensor>& outs)
{
// Try graph surrogate execution first
bool surrogate_used = tryGraphSurrogate(executor, graph_input, outs);

// If surrogate succeeded, we're done
if (surrogate_used) {
return;
}

// Otherwise, fallback to original physics computation
callApplication(Physics, graph_input, outs);
}
18 changes: 18 additions & 0 deletions src/AMSlib/wf/interface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,27 @@ void callApplication(ams::DomainLambda CallBack,
ams::MutableArrayRef<torch::Tensor> InOuts,
ams::MutableArrayRef<torch::Tensor> Outs);

void callApplication(ams::HomogeneousGraphDomainFn CallBack,
const ams::AMSHomogeneousGraph& graph,
ams::SmallVector<ams::AMSTensor>& outs);

void callApplication(ams::HeterogeneousGraphDomainFn CallBack,
const ams::AMSHeterogeneousGraph& graph,
ams::SmallVector<ams::AMSTensor>& outs);


void callAMS(ams::AMSWorkflow *executor,
ams::DomainLambda Physics,
const ams::SmallVector<ams::AMSTensor> &ins,
ams::SmallVector<ams::AMSTensor> &inouts,
ams::SmallVector<ams::AMSTensor> &outs);

void callAMS(ams::AMSWorkflow* executor,
ams::HomogeneousGraphDomainFn Physics,
const ams::AMSHomogeneousGraph& graph_input,
ams::SmallVector<ams::AMSTensor>& outs);

void callAMS(ams::AMSWorkflow* executor,
ams::HeterogeneousGraphDomainFn Physics,
const ams::AMSHeterogeneousGraph& graph_input,
ams::SmallVector<ams::AMSTensor>& outs);
20 changes: 20 additions & 0 deletions src/AMSlib/wf/workflow.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,26 @@ class AMSWorkflow
CALIPER(CALI_MARK_END("DBSTORE");)
}

void storeGraphData(const ams::AMSHomogeneousGraph& graph,
ArrayRef<ams::AMSTensor> Outs)
{
// TODO: Implement graph storage when database supports it
// For now, this is a no-op placeholder
(void)graph;
(void)Outs;
AMS_DBG(Workflow, "Graph storage not yet implemented (homogeneous)");
}

void storeGraphData(const ams::AMSHeterogeneousGraph& graph,
ArrayRef<ams::AMSTensor> Outs)
{
// TODO: Implement graph storage when database supports it
// For now, this is a no-op placeholder
(void)graph;
(void)Outs;
AMS_DBG(Workflow, "Graph storage not yet implemented (heterogeneous)");
}

/** \brief Check if we can perform a surrogate model update.
* AMS can update surrogate model only when all MPI ranks have received
* the latest model from RabbitMQ.
Expand Down
3 changes: 3 additions & 0 deletions tests/AMSlib/ams_interface/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,6 @@ endfunction()
BUILD_UNIT_TEST(ams_explicit_end_to_end ams_ete.cpp)
ADD_AMS_UNIT_TEST(AMS_EXPLICIT ams_explicit_end_to_end)

BUILD_UNIT_TEST(ams_graph_fallback test_graph_fallback.cpp)
ADD_AMS_UNIT_TEST(AMS_GRAPH_FALLBACK ams_graph_fallback)

Loading
Loading