Open
34 commits
c92bdb8
Extrapolating and saving real world data, initialize model in data ge…
mhheger Aug 11, 2025
0ceefd4
reducing extrapolate gnn to its core functionality, adding comments i…
mhheger Aug 11, 2025
fde4b3b
Add description of workflow
mhheger Aug 11, 2025
6b79df8
Merge remote-tracking branch 'origin/main' into 1139-introduce-improv…
HenrZu Aug 11, 2025
ba90b76
init file gnn
HenrZu Aug 11, 2025
3c9bf9b
[ci skip] Updated data generation and extrapolation, code for trainin…
mhheger Aug 18, 2025
306c752
[ci skip] Include different network architectures and introduce grid …
mhheger Sep 3, 2025
bf6c5f5
[ci skip] Improve model creation and implement grid search functiona…
mhheger Sep 8, 2025
5b24fd9
Add validation checks for dataset and model parameters; enhance test …
mhheger Sep 10, 2025
614082b
Enhance data scaling functionality and validation; add tests for scal…
mhheger Sep 22, 2025
6b32716
Ading necessary imports
mhheger Sep 22, 2025
b451ded
Fix imports?
mhheger Sep 22, 2025
b486d10
Add spektral to requirements
mhheger Sep 22, 2025
ea4cd31
undo scale_data-tests
mhheger Sep 22, 2025
20274ab
Refactor imports in GNN_utils.py and update test_surrogatemodel_GNN.p…
mhheger Sep 24, 2025
35fb822
Add model building and training step to evaluate_and_train.py; update…
mhheger Sep 25, 2025
9f37f74
Update requirements
mhheger Sep 25, 2025
b51623f
Merge branch 'main' into 1139-introduce-improved-gnn-surrogate-models
HenrZu Oct 29, 2025
68fa1c0
formating and fix tests
HenrZu Oct 29, 2025
a47f11e
.
HenrZu Oct 29, 2025
b6f40bb
.
HenrZu Oct 30, 2025
5bc242e
support for py 3.8
HenrZu Oct 30, 2025
70a9b11
.
HenrZu Oct 31, 2025
4d42a2a
[ci skip] start rework data gemeration gnn
HenrZu Nov 4, 2025
8c77ef2
[ci skip] .
HenrZu Nov 4, 2025
4d72f23
[ci skip] complete rework data generation
HenrZu Nov 5, 2025
304a702
.
HenrZu Nov 5, 2025
d19c6ab
[ci skip] rework evaluate and trian
HenrZu Nov 5, 2025
d39917d
.
HenrZu Nov 5, 2025
cf59960
[ci skip] rm extrapolate gnn
HenrZu Nov 5, 2025
d177e4c
[ci skip] rework gnn utils + grid_seach
HenrZu Nov 5, 2025
223b3d5
.
HenrZu Nov 5, 2025
d3f6ca4
add rtd for gnn
HenrZu Nov 5, 2025
e706a61
readme
HenrZu Nov 5, 2025
346 changes: 343 additions & 3 deletions docs/source/python/m-surrogate.rst
@@ -163,7 +163,347 @@ The `grid_search.py` and `hyperparameter_tuning.py` modules provide tools for sy
- Visualization of hyperparameter importance
- Selection of optimal model configurations

SECIR Groups Model
------------------

To be added...

Graph Neural Network (GNN) Surrogate Models
--------------------------------------------

The Graph Neural Network (GNN) module provides advanced surrogate models that leverage spatial connectivity and age-stratified epidemiological dynamics. These models are designed for immediate and reliable pandemic response by combining mechanistic expert knowledge with machine learning efficiency.

Overview and Scientific Foundation
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The GNN surrogate models are based on the research presented in:

|Graph_Neural_Network_Surrogates|

The implementation leverages the mechanistic ODE-SECIR model (see :doc:`ODE-SECIR documentation <../models/ode_secir>`) as the underlying expert model, using Python bindings to the C++ backend for efficient simulation during data generation.

Module Structure
~~~~~~~~~~~~~~~~

The GNN module is located in `pycode/memilio-surrogatemodel/memilio/surrogatemodel/GNN <https://github.com/SciCompMod/memilio/tree/main/pycode/memilio-surrogatemodel/memilio/surrogatemodel/GNN>`_ and consists of:

- **data_generation.py**: Generates training and evaluation data by simulating epidemiological scenarios with the mechanistic SECIR model
- **network_architectures.py**: Defines various GNN architectures (GCN, GAT, GIN) with configurable layers and preprocessing
- **evaluate_and_train.py**: Implements training and evaluation pipelines for GNN models
- **grid_search.py**: Provides hyperparameter optimization through systematic grid search
- **GNN_utils.py**: Contains utility functions for data preprocessing, graph construction, and population data handling

Data Generation
~~~~~~~~~~~~~~~

The data generation process in ``data_generation.py`` creates graph-structured training data through mechanistic simulations:

.. code-block:: python

    from memilio.surrogatemodel.GNN import data_generation

    # Generate training dataset
    dataset = data_generation.generate_dataset(
        num_runs=1000,                      # Number of simulation scenarios
        num_days=30,                        # Simulation horizon
        num_age_groups=6,                   # Age stratification
        data_dir='path/to/contact_data',    # Contact matrices location
        mobility_dir='path/to/mobility',    # Mobility data location
        save_path='gnn_training_data.pickle'
    )

**Data Generation Workflow:**

1. **Parameter Sampling**: Randomly sample epidemiological parameters (transmission rates, incubation periods, recovery rates) from predefined distributions to create diverse scenarios.

2. **Compartment Initialization**: Initialize epidemic compartments for each age group in each region based on realistic demographic data. Compartments are initialized using shared base factors.

3. **Mobility Graph Construction**: Build a spatial graph where:

   - Nodes represent geographic regions (e.g., German counties)
   - Edges represent mobility connections with weights from commuting data
   - Node features include age-stratified population sizes

4. **Contact Matrix Configuration**: Load and configure baseline contact matrices for different location types (home, school, work, other) stratified by age groups.

5. **Damping Application**: Apply time-varying dampings to contact matrices to simulate non-pharmaceutical interventions (NPIs):

   - Multiple damping periods with random start days
   - Location-specific damping factors (e.g., stronger school closures, moderate workplace restrictions)
   - Realistic parameter ranges based on observed intervention strengths

6. **Simulation Execution**: Run the mechanistic ODE-SECIR model using MEmilio's C++ backend through Python bindings to generate the dataset.

7. **Data Processing**: Transform simulation results into graph-structured format:

   - Extract compartment time series for each node (region) and age group
   - Apply logarithmic transformation for numerical stability
   - Store graph topology, node features, and temporal sequences
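
Step 5 of the workflow can be illustrated with a small, self-contained sketch. It is not the code in ``data_generation.py``: the function names ``sample_dampings`` and ``damped_contacts``, the minimum gap between dampings, and the factor ranges per location are assumptions chosen only for this example.

```python
import random

# Hypothetical ranges for the damping factor per location
# (0 = contacts unchanged, 1 = contacts fully suppressed). Illustrative only.
FACTOR_RANGES = {
    "home": (0.0, 0.2),
    "school": (0.5, 0.9),
    "work": (0.2, 0.6),
    "other": (0.3, 0.7),
}


def sample_dampings(num_days, num_dampings, min_gap=7, seed=None):
    """Draw random damping start days and one factor per location for each."""
    rng = random.Random(seed)
    # Candidate start days are min_gap apart so damping periods stay separated.
    candidates = list(range(min_gap, num_days, min_gap))
    if num_dampings > len(candidates):
        raise ValueError("too many dampings for the simulation horizon")
    start_days = sorted(rng.sample(candidates, num_dampings))
    dampings = []
    for day in start_days:
        factors = {
            location: rng.uniform(low, high)
            for location, (low, high) in FACTOR_RANGES.items()
        }
        dampings.append({"start_day": day, "factors": factors})
    return dampings


def damped_contacts(baseline, factor):
    """Scale every entry of a baseline contact matrix by (1 - factor)."""
    return [[value * (1.0 - factor) for value in row] for row in baseline]


dampings = sample_dampings(num_days=30, num_dampings=2, seed=0)
school_baseline = [[4.0, 1.0], [1.0, 2.0]]  # toy 2x2 age-group contact matrix
reduced = damped_contacts(school_baseline, dampings[0]["factors"]["school"])
```

Fixing the seed makes a sampled scenario reproducible, which helps when a generated dataset has to be regenerated or debugged.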

Network Architectures
~~~~~~~~~~~~~~~~~~~~~

The ``network_architectures.py`` module provides flexible GNN model construction for different layer types.

.. code-block:: python

    from memilio.surrogatemodel.GNN import network_architectures

    # Define GNN architecture
    model_config = {
        'layer_type': 'GCN',          # GNN layer type
        'num_layers': 3,              # Network depth
        'hidden_dim': 64,             # Hidden layer dimensions
        'activation': 'relu',         # Activation function
        'dropout_rate': 0.2,          # Dropout for regularization
        'use_batch_norm': True,       # Batch normalization
        'aggregation': 'mean',        # Neighborhood aggregation method
    }

    # Build model
    model = network_architectures.build_gnn_model(
        config=model_config,
        input_shape=(num_timesteps, num_features),
        output_dim=num_compartments * num_age_groups
    )
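
To give an intuition for what a graph layer with mean aggregation computes, the following sketch performs one message-passing step on a three-node graph in plain Python. It is a conceptual illustration only, not Spektral code and not the implementation in ``network_architectures.py``; the helper names ``mean_aggregate`` and ``dense_relu`` are hypothetical, and real GCN layers typically use a normalized adjacency matrix rather than a plain mean.

```python
def mean_aggregate(features, neighbors):
    """Replace each node's features by the mean over itself and its neighbors."""
    aggregated = []
    for node, own in enumerate(features):
        group = [own] + [features[n] for n in neighbors[node]]
        aggregated.append([sum(column) / len(group) for column in zip(*group)])
    return aggregated


def dense_relu(features, weights):
    """Apply one shared linear map followed by ReLU to every node."""
    transformed = []
    for vector in features:
        row = [sum(x * w for x, w in zip(vector, column)) for column in zip(*weights)]
        transformed.append([max(0.0, value) for value in row])
    return transformed


# Three regions connected in a line (0 - 1 - 2), two features per region.
features = [[1.0, 0.0], [3.0, 2.0], [5.0, 4.0]]
neighbors = {0: [1], 1: [0, 2], 2: [1]}
weights = [[1.0, -1.0], [0.0, 2.0]]  # shape: (in_features, out_features)

hidden = dense_relu(mean_aggregate(features, neighbors), weights)
print(hidden)  # [[2.0, 0.0], [3.0, 1.0], [4.0, 2.0]]
```

Stacking several such steps lets information travel across more than one mobility edge, which is what the ``num_layers`` setting controls conceptually.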


Training and Evaluation
~~~~~~~~~~~~~~~~~~~~~~~

The ``evaluate_and_train.py`` module provides the training functionality:

.. code-block:: python

    import pickle

    from memilio.surrogatemodel.GNN import evaluate_and_train

    # Load training data
    with open('gnn_training_data.pickle', 'rb') as f:
        dataset = pickle.load(f)

    # Define training configuration
    training_config = {
        'epochs': 100,
        'batch_size': 32,
        'learning_rate': 0.001,
        'optimizer': 'adam',
        'loss_function': 'mse',
        'early_stopping_patience': 10,
        'validation_split': 0.2
    }

    # Train model (`model` is the network built in the previous section)
    history = evaluate_and_train.train_gnn_model(
        model=model,
        dataset=dataset,
        config=training_config,
        save_weights='best_gnn_model.h5'
    )

    # Evaluate on test set (`test_dataset` is a held-out dataset loaded beforehand)
    metrics = evaluate_and_train.evaluate_model(
        model=model,
        test_data=test_dataset,
        metrics=['mae', 'mape', 'r2']
    )

**Training Features:**

1. **Mini-batch Training**: Graph batching for efficient training on large datasets
2. **Custom Loss Functions**: MSE, MAE, MAPE, or custom compartment-weighted losses
3. **Early Stopping**: Monitors validation loss to prevent overfitting
4. **Learning Rate Scheduling**: Adaptive learning rate reduction on plateaus
5. **Save Best Weights**: Saves best model weights based on validation performance
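
The interplay of early stopping and saving the best weights can be sketched independently of any framework. The function below is a hypothetical helper written for this illustration, not part of ``evaluate_and_train.py``; it only tracks which epoch would be kept and when training would stop.

```python
def early_stopping_epoch(val_losses, patience):
    """Return (stop_epoch, best_epoch) for a sequence of validation losses.

    Training stops once `patience` consecutive epochs bring no new best loss.
    """
    best_loss = float("inf")
    best_epoch = -1
    waited = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            # New best validation loss: this is where weights would be saved.
            best_loss, best_epoch, waited = loss, epoch, 0
        else:
            waited += 1
            if waited >= patience:
                return epoch, best_epoch
    return len(val_losses) - 1, best_epoch


losses = [0.9, 0.7, 0.6, 0.65, 0.62, 0.61, 0.5]
stop_epoch, best_epoch = early_stopping_epoch(losses, patience=3)
print(stop_epoch, best_epoch)  # 5 2
```

In this toy run the loss would have improved again at epoch 6, which shows the trade-off behind the patience value: a small patience saves time but can stop before a late improvement.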

**Evaluation Metrics:**

- **Mean Absolute Error (MAE)**: Average absolute prediction error per compartment
- **Mean Absolute Percentage Error (MAPE)**: Average absolute prediction error relative to the true value, expressed as a percentage
- **R² Score**: Coefficient of determination for prediction quality
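
For reference, the three metrics can be written out in a few lines of plain Python. This is a sketch of the standard definitions, not the module's own implementation; the ``eps`` guard in ``mape`` is an assumption added here to avoid division by zero for empty compartments.

```python
def mae(y_true, y_pred):
    """Mean absolute error."""
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


def mape(y_true, y_pred, eps=1e-8):
    """Mean absolute percentage error; eps guards against true values of zero."""
    ratios = [abs(t - p) / max(abs(t), eps) for t, p in zip(y_true, y_pred)]
    return 100.0 * sum(ratios) / len(ratios)


def r2_score(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean_true = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    if ss_tot == 0:
        raise ValueError("R² is undefined when all true values are identical")
    return 1.0 - ss_res / ss_tot


y_true = [100.0, 200.0, 400.0]
y_pred = [110.0, 190.0, 400.0]
print(f"MAE:  {mae(y_true, y_pred):.4f}")
print(f"MAPE: {mape(y_true, y_pred):.2f}%")
print(f"R²:   {r2_score(y_true, y_pred):.4f}")
```

MAE is expressed in the units of the data, so its value depends on whether it is computed on log-scaled or original-scale compartment sizes; MAPE and R² are scale-free.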

**Data Splitting:**

- **Training Set (70%)**: For model parameter optimization
- **Validation Set (15%)**: For hyperparameter tuning and early stopping
- **Test Set (15%)**: For final performance evaluation
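
A 70/15/15 split can be sketched as follows. The helper ``split_dataset`` is hypothetical and only illustrates the idea of shuffling once with a fixed seed and then slicing; the module may split its data differently.

```python
import random


def split_dataset(samples, train_frac=0.7, val_frac=0.15, seed=42):
    """Shuffle sample indices once and cut them into train/validation/test."""
    indices = list(range(len(samples)))
    random.Random(seed).shuffle(indices)
    n_train = round(train_frac * len(samples))
    n_val = round(val_frac * len(samples))
    train = [samples[i] for i in indices[:n_train]]
    val = [samples[i] for i in indices[n_train:n_train + n_val]]
    test = [samples[i] for i in indices[n_train + n_val:]]  # remainder
    return train, val, test


train, val, test = split_dataset(list(range(1000)))
print(len(train), len(val), len(test))  # 700 150 150
```

Assigning the remainder to the test set guarantees that every sample lands in exactly one subset, even when the fractions do not divide the dataset size evenly.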

Hyperparameter Optimization
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The ``grid_search.py`` module enables systematic exploration of hyperparameter space:

.. code-block:: python

    from memilio.surrogatemodel.GNN import grid_search

    # Define search space
    param_grid = {
        'layer_type': ['GCN', 'GAT', 'GIN'],
        'num_layers': [2, 3, 4, 5],
        'hidden_dim': [32, 64, 128, 256],
        'learning_rate': [0.001, 0.0005, 0.0001],
        'dropout_rate': [0.0, 0.1, 0.2, 0.3],
        'batch_size': [16, 32, 64],
        'activation': ['relu', 'elu', 'tanh']
    }

    # Run grid search with cross-validation
    results = grid_search.run_hyperparameter_search(
        param_grid=param_grid,
        data_path='gnn_training_data.pickle',
        cv_folds=5,
        metric='mae',
        save_results='grid_search_results.csv'
    )

    # Analyze best configuration
    best_config = grid_search.get_best_configuration(results)
    print(f"Best configuration: {best_config}")
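
The cost of a grid search grows multiplicatively with every added option. The sketch below enumerates a grid with ``itertools.product`` and counts the combinations of the search space shown above. The helpers ``expand_grid`` and ``toy_score`` are hypothetical names for this illustration; ``toy_score`` is a stand-in for a cross-validated error, not a real training run.

```python
from itertools import product
from math import prod


def expand_grid(param_grid):
    """Return one configuration dict per combination of parameter values."""
    keys = list(param_grid)
    value_lists = [param_grid[key] for key in keys]
    return [dict(zip(keys, values)) for values in product(*value_lists)]


full_grid = {
    'layer_type': ['GCN', 'GAT', 'GIN'],
    'num_layers': [2, 3, 4, 5],
    'hidden_dim': [32, 64, 128, 256],
    'learning_rate': [0.001, 0.0005, 0.0001],
    'dropout_rate': [0.0, 0.1, 0.2, 0.3],
    'batch_size': [16, 32, 64],
    'activation': ['relu', 'elu', 'tanh'],
}
num_configs = prod(len(values) for values in full_grid.values())
print(num_configs)      # 5184 configurations
print(num_configs * 5)  # 25920 training runs with 5-fold cross-validation

# A reduced grid keeps the example fast.
small_grid = {'layer_type': ['GCN', 'GAT'], 'num_layers': [2, 3]}
configs = expand_grid(small_grid)


def toy_score(config):
    """Placeholder for a validation error; lower is better."""
    return abs(config['num_layers'] - 3) + (0.1 if config['layer_type'] == 'GAT' else 0.0)


best = min(configs, key=toy_score)
print(best)  # {'layer_type': 'GCN', 'num_layers': 3}
```

With several thousand configurations, it is usually worth narrowing the grid to the most influential parameters first or running the search in stages.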

Utility Functions
~~~~~~~~~~~~~~~~~

The ``GNN_utils.py`` module provides essential helper functions used throughout the GNN workflow:

**Data Preprocessing:**

.. code-block:: python

    from memilio.surrogatemodel.GNN import GNN_utils

    # Remove confirmed compartments (simplify model)
    simplified_data = GNN_utils.remove_confirmed_compartments(
        dataset_entries=dataset,
        num_groups=6
    )

    # Apply logarithmic scaling
    scaled_data = GNN_utils.scale_data(
        data=dataset,
        method='log',
        epsilon=1e-6    # Small constant to avoid log(0)
    )

    # Load population data
    population = GNN_utils.load_population_data(
        data_dir='path/to/demographics',
        age_groups=[0, 5, 15, 35, 60, 80]
    )
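
The two preprocessing steps can be sketched on a flat state vector. This is an illustration under stated assumptions, not the code in ``GNN_utils.py``: it assumes ten compartments per age group with the two confirmed compartments at indices 3 and 5, assumes that each confirmed compartment is added to its unconfirmed counterpart before being dropped, and uses hypothetical function names.

```python
import math

NUM_COMPARTMENTS = 10
# Assumed layout: confirmed compartment index -> unconfirmed counterpart index.
MERGE_INTO = {3: 2, 5: 4}


def merge_confirmed(values, num_groups):
    """Fold confirmed compartments into their counterparts and drop them."""
    if len(values) != num_groups * NUM_COMPARTMENTS:
        raise ValueError("unexpected state vector length")
    reduced = []
    for group in range(num_groups):
        start = group * NUM_COMPARTMENTS
        block = list(values[start:start + NUM_COMPARTMENTS])
        for source, target in MERGE_INTO.items():
            block[target] += block[source]
        reduced.extend(v for i, v in enumerate(block) if i not in MERGE_INTO)
    return reduced


def log_scale(values, epsilon=1e-6):
    """Logarithmic scaling; epsilon keeps empty compartments finite."""
    return [math.log(v + epsilon) for v in values]


def inverse_log_scale(scaled, epsilon=1e-6):
    """Map log-scaled values back to the original scale."""
    return [math.exp(s) - epsilon for s in scaled]


state = [90.0, 5.0, 2.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0] * 2  # two age groups
reduced = merge_confirmed(state, num_groups=2)
scaled = log_scale(reduced)
restored = inverse_log_scale(scaled)
print(len(reduced))  # 16 values: 2 age groups x 8 compartments
```

Predictions made in log space must be passed through the inverse transform before they are compared with simulation output on the original scale.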

**Graph Construction:**

.. code-block:: python

    # Create mobility graph from commuting data
    # (county_list and models_per_region must be defined beforehand)
    graph = GNN_utils.create_mobility_graph(
        mobility_dir='path/to/mobility',
        num_regions=401,             # German counties
        county_ids=county_list,
        models=models_per_region     # SECIR models for each region
    )

    # Get baseline contact matrix
    contact_matrix = GNN_utils.get_baseline_contact_matrix(
        data_dir='path/to/contact_matrices'
    )
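
Conceptually, a mobility graph is a weighted directed graph derived from a commuter matrix. The following sketch shows that conversion for three toy regions; ``build_edges`` and ``to_adjacency`` are hypothetical helpers for illustration and do not reflect the signature or internals of ``create_mobility_graph``.

```python
def build_edges(commuters, threshold=0.0):
    """Turn a square commuter matrix into a weighted directed edge list."""
    n = len(commuters)
    edges = []
    for source in range(n):
        if len(commuters[source]) != n:
            raise ValueError("commuter matrix must be square")
        for target in range(n):
            weight = commuters[source][target]
            # Skip self-loops and connections at or below the threshold.
            if source != target and weight > threshold:
                edges.append((source, target, weight))
    return edges


def to_adjacency(edges, num_nodes):
    """Build a dense weighted adjacency matrix from an edge list."""
    adjacency = [[0.0] * num_nodes for _ in range(num_nodes)]
    for source, target, weight in edges:
        adjacency[source][target] = weight
    return adjacency


# commuters[i][j]: number of people commuting from region i to region j.
commuters = [
    [0.0, 120.0, 5.0],
    [80.0, 0.0, 0.0],
    [10.0, 40.0, 0.0],
]
edges = build_edges(commuters, threshold=8.0)
adjacency = to_adjacency(edges, num_nodes=3)
print(edges)  # [(0, 1, 120.0), (1, 0, 80.0), (2, 0, 10.0), (2, 1, 40.0)]
```

A threshold on commuter counts is one simple way to keep the graph sparse, which matters when the number of regions is in the hundreds.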

Practical Usage Example
~~~~~~~~~~~~~~~~~~~~~~~

Here is a complete example workflow from data generation to model evaluation:

.. code-block:: python

    import pickle

    from memilio.surrogatemodel.GNN import (
        data_generation,
        network_architectures,
        evaluate_and_train
    )

    # Step 1: Generate training data
    print("Generating training data...")
    dataset = data_generation.generate_dataset(
        num_runs=5000,
        num_days=30,
        num_age_groups=6,
        data_dir='/path/to/memilio/data/Germany',
        mobility_dir='/path/to/mobility_data',
        save_path='gnn_dataset_5000.pickle'
    )

    # Step 2: Define and build GNN model
    print("Building GNN model...")
    model_config = {
        'layer_type': 'GCN',
        'num_layers': 4,
        'hidden_dim': 128,
        'activation': 'relu',
        'dropout_rate': 0.2,
        'use_batch_norm': True
    }

    model = network_architectures.build_gnn_model(
        config=model_config,
        input_shape=(1, 48),  # 6 age groups × 8 compartments
        output_dim=48         # Predict all compartments
    )

    # Step 3: Train the model
    print("Training model...")
    training_config = {
        'epochs': 200,
        'batch_size': 32,
        'learning_rate': 0.001,
        'optimizer': 'adam',
        'loss_function': 'mae',
        'early_stopping_patience': 20,
        'validation_split': 0.2
    }

    history = evaluate_and_train.train_gnn_model(
        model=model,
        dataset=dataset,
        config=training_config,
        save_weights='gnn_weights_best.h5'
    )

    # Step 4: Evaluate on test data
    print("Evaluating model...")
    test_metrics = evaluate_and_train.evaluate_model(
        model=model,
        test_data='gnn_test_data.pickle',
        metrics=['mae', 'mape', 'r2']
    )

    # Print results
    print(f"Test MAE: {test_metrics['mae']:.4f}")
    print(f"Test MAPE: {test_metrics['mape']:.2f}%")
    print(f"Test R²: {test_metrics['r2']:.4f}")

    # Step 5: Make predictions on new scenarios
    with open('new_scenario.pickle', 'rb') as f:
        new_data = pickle.load(f)

    predictions = model.predict(new_data)
    print(f"Predictions shape: {predictions.shape}")

**GPU Acceleration:**

- TensorFlow automatically uses a GPU when one is available
- Spektral layers are optimized for GPU execution
- Training time can be reduced substantially with appropriate GPU hardware

Additional Resources
~~~~~~~~~~~~~~~~~~~~

**Code and Examples:**

- `GNN Module <https://github.com/SciCompMod/memilio/tree/main/pycode/memilio-surrogatemodel/memilio/surrogatemodel/GNN>`_
- `GNN README <https://github.com/SciCompMod/memilio/blob/main/pycode/memilio-surrogatemodel/memilio/surrogatemodel/GNN/README.md>`_
- `Test Scripts <https://github.com/SciCompMod/memilio/tree/main/pycode/memilio-surrogatemodel/memilio/surrogatemodel_test>`_

**Related Documentation:**

- :doc:`ODE-SECIR Model <../models/ode_secir>`
- :doc:`MEmilio Simulation Package <m-simulation>`
- :doc:`Python Bindings <python_bindings>`
