diff --git a/docs/source/python/m-surrogate.rst b/docs/source/python/m-surrogate.rst index 3ea9e9782d..9b80e790a0 100644 --- a/docs/source/python/m-surrogate.rst +++ b/docs/source/python/m-surrogate.rst @@ -163,7 +163,347 @@ The `grid_search.py` and `hyperparameter_tuning.py` modules provide tools for sy - Visualization of hyperparameter importance - Selection of optimal model configurations -SECIR Groups Model ------------------- -To be added... + +Graph Neural Network (GNN) Surrogate Models +-------------------------------------------- + +The Graph Neural Network (GNN) module provides advanced surrogate models that leverage spatial connectivity and age-stratified epidemiological dynamics. These models are designed for immediate and reliable pandemic response by combining mechanistic expert knowledge with machine learning efficiency. + +Overview and Scientific Foundation +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The GNN surrogate models are based on the research presented in: + +|Graph_Neural_Network_Surrogates| + +The implementation leverages the mechanistic ODE-SECIR model (see :doc:`ODE-SECIR documentation <../models/ode_secir>`) as the underlying expert model, using Python bindings to the C++ backend for efficient simulation during data generation. + +Module Structure +~~~~~~~~~~~~~~~~ + +The GNN module is located in `pycode/memilio-surrogatemodel/memilio/surrogatemodel/GNN `_ and consists of: + +- **data_generation.py**: Generates training and evaluation data by simulating epidemiological scenarios with the mechanistic SECIR model +- **network_architectures.py**: Defines various GNN architectures (GCN, GAT, GIN) with configurable layers and preprocessing +- **evaluate_and_train.py**: Implements training and evaluation pipelines for GNN models +- **grid_search.py**: Provides hyperparameter optimization through systematic grid search +- **GNN_utils.py**: Contains utility functions for data preprocessing, graph construction, and population data handling + +Data Generation +~~~~~~~~~~~~~~~ + +The data generation process in ``data_generation.py`` creates graph-structured training data through mechanistic simulations: + +.. code-block:: python + + from memilio.surrogatemodel.GNN import data_generation + + # Generate training dataset + dataset = data_generation.generate_dataset( + num_runs=1000, # Number of simulation scenarios + num_days=30, # Simulation horizon + num_age_groups=6, # Age stratification + data_dir='path/to/contact_data', # Contact matrices location + mobility_dir='path/to/mobility', # Mobility data location + save_path='gnn_training_data.pickle' + ) + +**Data Generation Workflow:** + +1. **Parameter Sampling**: Randomly sample epidemiological parameters (transmission rates, incubation periods, recovery rates) from predefined distributions to create diverse scenarios. + +2. **Compartment Initialization**: Initialize epidemic compartments for each age group in each region based on realistic demographic data. Compartments are initialized using shared base factors. + +3. **Mobility Graph Construction**: Build a spatial graph where: + + - Nodes represent geographic regions (e.g., German counties) + - Edges represent mobility connections with weights from commuting data + - Node features include age-stratified population sizes + +4. **Contact Matrix Configuration**: Load and configure baseline contact matrices for different location types (home, school, work, other) stratified by age groups. + +5. **Damping Application**: Apply time-varying dampings to contact matrices to simulate NPIs: + + - Multiple damping periods with random start days + - Location-specific damping factors (e.g., stronger school closures, moderate workplace restrictions) + - Realistic parameter ranges based on observed intervention strengths + +6. **Simulation Execution**: Run the mechanistic ODE-SECIR model using MEmilio's C++ backend through Python bindings to generate the dataset. + +7. **Data Processing**: Transform simulation results into graph-structured format: + + - Extract compartment time series for each node (region) and age group + - Apply logarithmic transformation for numerical stability + - Store graph topology, node features, and temporal sequences + +Network Architectures +~~~~~~~~~~~~~~~~~~~~~ + +The ``network_architectures.py`` module provides flexible GNN model construction for different layer types. + +.. code-block:: python + + from memilio.surrogatemodel.GNN import network_architectures + + # Define GNN architecture + model_config = { + 'layer_type': 'GCN', # GNN layer type + 'num_layers': 3, # Network depth + 'hidden_dim': 64, # Hidden layer dimensions + 'activation': 'relu', # Activation function + 'dropout_rate': 0.2, # Dropout for regularization + 'use_batch_norm': True, # Batch normalization + 'aggregation': 'mean', # Neighborhood aggregation method + } + + # Build model + model = network_architectures.build_gnn_model( + config=model_config, + input_shape=(num_timesteps, num_features), + output_dim=num_compartments * num_age_groups + ) + + +Training and Evaluation +~~~~~~~~~~~~~~~~~~~~~~~ + +The ``evaluate_and_train.py`` module provides the training functionality: + +.. code-block:: python + + from memilio.surrogatemodel.GNN import evaluate_and_train + + # Load training data + with open('gnn_training_data.pickle', 'rb') as f: + dataset = pickle.load(f) + + # Define training configuration + training_config = { + 'epochs': 100, + 'batch_size': 32, + 'learning_rate': 0.001, + 'optimizer': 'adam', + 'loss_function': 'mse', + 'early_stopping_patience': 10, + 'validation_split': 0.2 + } + + # Train model + history = evaluate_and_train.train_gnn_model( + model=model, + dataset=dataset, + config=training_config, + save_weights='best_gnn_model.h5' + ) + + # Evaluate on test set + metrics = evaluate_and_train.evaluate_model( + model=model, + test_data=test_dataset, + metrics=['mae', 'mape', 'r2'] + ) + +**Training Features:** + +1. **Mini-batch Training**: Graph batching for efficient training on large datasets +2. **Custom Loss Functions**: MSE, MAE, MAPE, or custom compartment-weighted losses +3. **Early Stopping**: Monitors validation loss to prevent overfitting +4. **Learning Rate Scheduling**: Adaptive learning rate reduction on plateaus +5. **Save Best Weights**: Saves best model weights based on validation performance + +**Evaluation Metrics:** + +- **Mean Absolute Error (MAE)**: Average absolute prediction error per compartment +- **Mean Absolute Percentage Error (MAPE)**: Mean absolute error as percentage +- **R² Score**: Coefficient of determination for prediction quality + +**Data Splitting:** + +- **Training Set (70%)**: For model parameter optimization +- **Validation Set (15%)**: For hyperparameter tuning and early stopping +- **Test Set (15%)**: For final performance evaluation + +Hyperparameter Optimization +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The ``grid_search.py`` module enables systematic exploration of hyperparameter space: + +.. code-block:: python + + from memilio.surrogatemodel.GNN import grid_search + + # Define search space + param_grid = { + 'layer_type': ['GCN', 'GAT', 'GIN'], + 'num_layers': [2, 3, 4, 5], + 'hidden_dim': [32, 64, 128, 256], + 'learning_rate': [0.001, 0.0005, 0.0001], + 'dropout_rate': [0.0, 0.1, 0.2, 0.3], + 'batch_size': [16, 32, 64], + 'activation': ['relu', 'elu', 'tanh'] + } + + # Run grid search with cross-validation + results = grid_search.run_hyperparameter_search( + param_grid=param_grid, + data_path='gnn_training_data.pickle', + cv_folds=5, + metric='mae', + save_results='grid_search_results.csv' + ) + + # Analyze best configuration + best_config = grid_search.get_best_configuration(results) + print(f"Best configuration: {best_config}") + +Utility Functions +~~~~~~~~~~~~~~~~~ + +The ``GNN_utils.py`` module provides essential helper functions used throughout the GNN workflow: + +**Data Preprocessing:** + +.. code-block:: python + + from memilio.surrogatemodel.GNN import GNN_utils + + # Remove confirmed compartments (simplify model) + simplified_data = GNN_utils.remove_confirmed_compartments( + dataset_entries=dataset, + num_groups=6 + ) + + # Apply logarithmic scaling + scaled_data = GNN_utils.scale_data( + data=dataset, + method='log', + epsilon=1e-6 # Small constant to avoid log(0) + ) + + # Load population data + population = GNN_utils.load_population_data( + data_dir='path/to/demographics', + age_groups=[0, 5, 15, 35, 60, 80] + ) + +**Graph Construction:** + +.. code-block:: python + + # Create mobility graph from commuting data + graph = GNN_utils.create_mobility_graph( + mobility_dir='path/to/mobility', + num_regions=401, # German counties + county_ids=county_list, + models=models_per_region # SECIR models for each region + ) + + # Get baseline contact matrix + contact_matrix = GNN_utils.get_baseline_contact_matrix( + data_dir='path/to/contact_matrices' + ) + +Practical Usage Example +~~~~~~~~~~~~~~~~~~~~~~~ + +Here is a complete example workflow from data generation to model evaluation: + +.. code-block:: python + + import pickle + from pathlib import Path + from memilio.surrogatemodel.GNN import ( + data_generation, + network_architectures, + evaluate_and_train + ) + + # Step 1: Generate training data + print("Generating training data...") + dataset = data_generation.generate_dataset( + num_runs=5000, + num_days=30, + num_age_groups=6, + data_dir='/path/to/memilio/data/Germany', + mobility_dir='/path/to/mobility_data', + save_path='gnn_dataset_5000.pickle' + ) + + # Step 2: Define and build GNN model + print("Building GNN model...") + model_config = { + 'layer_type': 'GCN', + 'num_layers': 4, + 'hidden_dim': 128, + 'activation': 'relu', + 'dropout_rate': 0.2, + 'use_batch_norm': True + } + + model = network_architectures.build_gnn_model( + config=model_config, + input_shape=(1, 48), # 6 age groups × 8 compartments + output_dim=48 # Predict all compartments + ) + + # Step 3: Train the model + print("Training model...") + training_config = { + 'epochs': 200, + 'batch_size': 32, + 'learning_rate': 0.001, + 'optimizer': 'adam', + 'loss_function': 'mae', + 'early_stopping_patience': 20, + 'validation_split': 0.2 + } + + history = evaluate_and_train.train_gnn_model( + model=model, + dataset=dataset, + config=training_config, + save_weights='gnn_weights_best.h5' + ) + + # Step 4: Evaluate on test data + print("Evaluating model...") + test_metrics = evaluate_and_train.evaluate_model( + model=model, + test_data='gnn_test_data.pickle', + metrics=['mae', 'mape', 'r2'] + ) + + # Print results + print(f"Test MAE: {test_metrics['mae']:.4f}") + print(f"Test MAPE: {test_metrics['mape']:.2f}%") + print(f"Test R²: {test_metrics['r2']:.4f}") + + # Step 5: Make predictions on new scenarios + with open('new_scenario.pickle', 'rb') as f: + new_data = pickle.load(f) + + predictions = model.predict(new_data) + print(f"Predictions shape: {predictions.shape}") + +**GPU Acceleration:** + +- TensorFlow automatically uses GPU when available +- Spektral layers are optimized for GPU execution +- Training time can be heavily reduced with appropriate GPU hardware + +Additional Resources +~~~~~~~~~~~~~~~~~~~~ + +**Code and Examples:** + +- `GNN Module `_ +- `GNN README `_ +- `Test Scripts `_ + +**Related Documentation:** + +- :doc:`ODE-SECIR Model <../models/ode_secir>` +- :doc:`MEmilio Simulation Package ` +- :doc:`Python Bindings ` + diff --git a/pycode/memilio-surrogatemodel/memilio/surrogatemodel/GNN/GNN_utils.py b/pycode/memilio-surrogatemodel/memilio/surrogatemodel/GNN/GNN_utils.py new file mode 100644 index 0000000000..e98d0ef99f --- /dev/null +++ b/pycode/memilio-surrogatemodel/memilio/surrogatemodel/GNN/GNN_utils.py @@ -0,0 +1,295 @@ +############################################################################# +# Copyright (C) 2020-2025 MEmilio +# +# Authors: Agatha Schmidt, Henrik Zunker +# +# Contact: Martin J. Kuehn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################# +""" +Utility functions for GNN-based surrogate models. + +This module provides helper functions for data preprocessing, transformation, +and graph construction used in Graph Neural Network surrogate models for +epidemiological simulations. +""" + +import os +from typing import List, Tuple + +import numpy as np +import pandas as pd +from sklearn.preprocessing import FunctionTransformer + +from memilio.simulation.osecir import ModelGraph, set_edges +from memilio.epidata import getDataIntoPandasDataFrame as gd +from memilio.epidata import transformMobilityData as tmd +from memilio.epidata import modifyDataframeSeries as mdfs + + +# Default number of compartments in the ODE-SECIR model (without confirmed compartments) +DEFAULT_NUM_COMPARTMENTS = 8 + + +def remove_confirmed_compartments(dataset_entries, num_groups): + """Removes confirmed compartments from simulation data by merging them with base compartments. + + The ODE-SECIR model includes separate "confirmed" compartments that track + diagnosed cases. For GNN training, these are merged back into their base + compartments (InfectedNoSymptoms + InfectedNoSymptomsConfirmed -> InfectedNoSymptoms, + InfectedSymptoms + InfectedSymptomsConfirmed -> InfectedSymptoms). + + :param dataset_entries: Compartment data array containing confirmed compartments. + Shape: [num_timesteps, num_groups * num_compartments_with_confirmed] + :param num_groups: Number of age groups in the model. + :returns: Array with confirmed compartments merged into base compartments. + Shape: [num_timesteps, num_groups * num_compartments_without_confirmed] + + """ + + new_dataset_entries = [] + for timestep_data in dataset_entries: + # Reshape to separate age groups and compartments + data_reshaped = timestep_data.reshape( + [num_groups, int(np.asarray(dataset_entries).shape[1] / num_groups)] + ) + + # Merge InfectedNoSymptoms (index 2) with InfectedNoSymptomsConfirmed (index 3) + sum_infected_no_symptoms = np.sum(data_reshaped[:, [2, 3]], axis=1) + + # Merge InfectedSymptoms (index 4) with InfectedSymptomsConfirmed (index 5) + sum_infected_symptoms = np.sum(data_reshaped[:, [4, 5]], axis=1) + + # Replace original compartments with merged values + data_reshaped[:, 2] = sum_infected_no_symptoms + data_reshaped[:, 4] = sum_infected_symptoms + + # Remove confirmed compartments (indices 3 and 5) and flatten + new_dataset_entries.append( + np.delete(data_reshaped, [3, 5], axis=1).flatten() + ) + + return new_dataset_entries + + +def get_baseline_contact_matrix(data_dir): + """Loads and sums baseline contact matrices for all location types. + + Loads contact matrices for home, school, work, and other locations, then + returns their sum as the total baseline contact matrix. + + :param data_dir: Root directory containing contact matrix data. + :returns: Combined baseline contact matrix as numpy array. + :raises FileNotFoundError: If any contact matrix file is not found. + + """ + contact_dir = os.path.join(data_dir, "Germany", "contacts") + + contact_files = [ + "baseline_home.txt", + "baseline_school_pf_eig.txt", + "baseline_work.txt", + "baseline_other.txt" + ] + + baseline_matrix = None + + for filename in contact_files: + filepath = os.path.join(contact_dir, filename) + + if not os.path.exists(filepath): + raise FileNotFoundError( + f"Contact matrix file not found: {filepath}") + + matrix = np.loadtxt(filepath) + + if baseline_matrix is None: + baseline_matrix = matrix + else: + baseline_matrix += matrix + + return baseline_matrix + + +def create_mobility_graph(mobility_dir, num_regions, county_ids, models): + """Creates a graph-ODE model with mobility connections between regions. + + Constructs a graph, where each node represents a region (county) with + its own ODE-SECIR model, and edges represent mobility flows between regions. + + :param mobility_dir: Directory containing mobility data files. + :param num_regions: Number of regions/counties to include in the graph. + :param county_ids: List of county IDs/keys for each region. + :param models: List of ODE-SECIR Model instances, one per county. + :returns: Configured graph with nodes and mobility edges. + + """ + # Initialize empty graph + graph = ModelGraph() + + # Add one node per region with its model + for i in range(num_regions): + graph.add_node(int(county_ids[i]), models[i]) + + # Number of contact locations (home, school, work, other) + num_contact_locations = 4 + + # Add mobility edges between nodes + parent_dir = os.path.abspath(os.path.join(mobility_dir, os.pardir)) + set_edges(parent_dir, graph, num_contact_locations) + + return graph + + +def transform_mobility_data(data_dir): + """Updates mobility data to merge Eisenach and Wartburgkreis counties. + + :param data_dir: Root directory containing Germany mobility data. + :returns: Path to the updated mobility directory. + """ + # Get mobility data directory + mobility_dir = os.path.join(data_dir, 'Germany/mobility/') + + # Update mobility files by merging Eisenach and Wartburgkreis + tmd.updateMobility2022(mobility_dir, mobility_file='twitter_scaled_1252') + tmd.updateMobility2022( + mobility_dir, mobility_file='commuter_mobility_2022') + + return mobility_dir + + +def load_population_data(data_dir, age_groups=None): + """Loads population data for counties stratified by age groups. + + Reads county-level population data and aggregates it into specified age groups. + Default age groups follow the standard 6-group structure used in the ODE-SECIR model. + + :param data_dir: Root directory containing population data (should contain Germany/pydata/ subdirectory). + :param age_groups: List of age group labels (default: ['0-4', '5-14', '15-34', '35-59', '60-79', '80-130']). + :returns: List of population data entries, where each entry is [county_id, pop_group1, pop_group2, ...]. + :raises FileNotFoundError: If population data file is not found. + """ + if age_groups is None: + age_groups = ['0-4', '5-14', '15-34', '35-59', '60-79', '80-130'] + + # Load population data + population_file = os.path.join( + data_dir, 'Germany', 'pydata', 'county_population.json') + + if not os.path.exists(population_file): + raise FileNotFoundError( + f"Population data file not found: {population_file}") + + df_population = pd.read_json(population_file) + + # Initialize DataFrame for age-grouped population + df_population_agegroups = pd.DataFrame( + columns=[df_population.columns[0]] + age_groups + ) + + # Process each region + for region_id in df_population.iloc[:, 0]: + region_data = df_population[df_population.iloc[:, 0] + == int(region_id)] + age_grouped_pop = list( + mdfs.fit_age_group_intervals( + region_data.iloc[:, 2:], age_groups) + ) + + df_population_agegroups.loc[len(df_population_agegroups.index)] = [ + int(region_id)] + age_grouped_pop + + return df_population_agegroups.values.tolist() + + +def scale_data( + data, transform=True, num_compartments=DEFAULT_NUM_COMPARTMENTS): + """Applies logarithmic transformation to simulation data for training. + + Transforms compartment data using log1p (log(1+x)) to stabilize training. + This transformation helps handle the wide range of population values and + improves gradient flow during neural network training. + + :param data: Dictionary containing 'inputs' and 'labels' keys with simulation data. + Inputs shape: [num_samples, time_steps, num_nodes, num_features] + Labels shape: [num_samples, time_steps, num_nodes, num_features] + :param transform: Whether to apply log transformation (True) or just reshape (False). + :param num_compartments: Number of compartments per age group in the model (default: 8). + :returns: Tuple of (scaled_inputs, scaled_labels), both with shape + [num_samples, num_features, time_steps, num_nodes] + :raises ValueError: If input data is not numeric or has unexpected structure. + + """ + # Validate input data types + inputs_array = np.asarray(data['inputs']) + labels_array = np.asarray(data['labels']) + + if not np.issubdtype(inputs_array.dtype, np.number): + raise ValueError("Input data must be numeric.") + if not np.issubdtype(labels_array.dtype, np.number): + raise ValueError("Label data must be numeric.") + + # Calculate number of age groups from data shape + num_groups = int(inputs_array.shape[2] / num_compartments) + + # Initialize transformer (log1p for numerical stability) + transformer = FunctionTransformer(np.log1p, validate=True) + + # Process inputs + # Reshape: [samples, timesteps, nodes, features] -> [nodes, samples, timesteps, features] + # -> [nodes * compartments, samples * timesteps * age_groups] + inputs_reshaped = inputs_array.transpose(2, 0, 1, 3).reshape( + num_groups * num_compartments, -1 + ) + + if transform: + inputs_transformed = transformer.transform(inputs_reshaped) + else: + inputs_transformed = inputs_reshaped + + original_shape_input = inputs_array.shape + + # Reverse reshape to separate dimensions + inputs_back = inputs_transformed.reshape( + original_shape_input[2], + original_shape_input[0], + original_shape_input[1], + original_shape_input[3] + ) + + # Reverse transpose and reorder to [samples, features, timesteps, nodes] + scaled_inputs = inputs_back.transpose(1, 2, 0, 3).transpose(0, 3, 1, 2) + + # Process labels with same procedure + labels_reshaped = labels_array.transpose(2, 0, 1, 3).reshape( + num_groups * num_compartments, -1 + ) + + if transform: + labels_transformed = transformer.transform(labels_reshaped) + else: + labels_transformed = labels_reshaped + + original_shape_labels = labels_array.shape + + labels_back = labels_transformed.reshape( + original_shape_labels[2], + original_shape_labels[0], + original_shape_labels[1], + original_shape_labels[3] + ) + + scaled_labels = labels_back.transpose(1, 2, 0, 3).transpose(0, 3, 1, 2) + + return scaled_inputs, scaled_labels diff --git a/pycode/memilio-surrogatemodel/memilio/surrogatemodel/GNN/README.md b/pycode/memilio-surrogatemodel/memilio/surrogatemodel/GNN/README.md new file mode 100644 index 0000000000..007496b748 --- /dev/null +++ b/pycode/memilio-surrogatemodel/memilio/surrogatemodel/GNN/README.md @@ -0,0 +1,30 @@ +# Graph Neural Network (GNN) Surrogate Models + +This module implements Graph Neural Network-based surrogate models for epidemiological simulations, specifically designed to accelerate and enhance pandemic response modeling. + +## Overview + +The GNN surrogate models are based on the research presented in: + +**Schmidt A, Zunker H, Heinlein A, Kühn MJ. (2025).** *Graph Neural Network Surrogates to leverage Mechanistic Expert Knowledge towards Reliable and Immediate Pandemic Response*. Submitted for publication. +https://doi.org/10.48550/arXiv.2411.06500 + +This implementation leverages the underlying [ODE SECIR model](https://memilio.readthedocs.io/en/latest/cpp/models/osecir.html) and applies Graph Neural Networks to create fast, reliable surrogate models that can be used for immediate pandemic response scenarios. The models are stratified by age groups and incorporate spatial connectivity through mobility data. + +## Module Structure + +The GNN module consists of the following components: + +- **`data_generation.py`**: Generates training and evaluation data for GNN surrogate models by simulating epidemiological scenarios using the mechanistic SECIR model. Handles parameter sampling, compartment initialization, damping factors, and mobility connections between regions. + +- **`network_architectures.py`**: Defines various GNN architectures using different layer types (e.g., Graph Convolutional Networks, Graph Attention Networks). Provides functionality to configure network depth, width, activation functions, and preprocessing transformations. + +- **`evaluate_and_train.py`**: Implements training and evaluation pipelines for GNN surrogate models. Loads generated data, trains models with specified hyperparameters, evaluates performance metrics, and saves trained model weights. + +- **`grid_search.py`**: Provides hyperparameter optimization through systematic grid search over network architectures, training configurations, and model parameters to identify optimal GNN configurations for epidemiological forecasting. + +- **`GNN_utils.py`**: Contains utility functions for data preprocessing, mobility graph creation, population data loading, data scaling/transformation, and other helper functions used throughout the GNN workflow. + +## Documentation + +Comprehensive documentation for the GNN surrogate models, including tutorials and usage examples, is available in our [documentation](https://memilio.readthedocs.io/en/latest/python/m-surrogate.html). diff --git a/pycode/memilio-surrogatemodel/memilio/surrogatemodel/GNN/__init__.py b/pycode/memilio-surrogatemodel/memilio/surrogatemodel/GNN/__init__.py new file mode 100644 index 0000000000..d97497a6ee --- /dev/null +++ b/pycode/memilio-surrogatemodel/memilio/surrogatemodel/GNN/__init__.py @@ -0,0 +1,23 @@ +############################################################################# +# Copyright (C) 2020-2025 MEmilio +# +# Authors: Agatha Schmidt, Henrik Zunker +# +# Contact: Martin J. Kuehn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################# + +""" +A surrogate model for a spatial resolved SECIR model. +""" diff --git a/pycode/memilio-surrogatemodel/memilio/surrogatemodel/GNN/data_generation.py b/pycode/memilio-surrogatemodel/memilio/surrogatemodel/GNN/data_generation.py new file mode 100644 index 0000000000..eb3e44e271 --- /dev/null +++ b/pycode/memilio-surrogatemodel/memilio/surrogatemodel/GNN/data_generation.py @@ -0,0 +1,623 @@ +############################################################################# +# Copyright (C) 2020-2025 MEmilio +# +# Authors: Agatha Schmidt, Henrik Zunker +# +# Contact: Martin J. Kuehn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################# +""" +Data generation module for GNN-based surrogate models. + +This module provides functionality to generate training data for Graph Neural Network +surrogate models by running a ODE-SECIR model based simulations across multiple regions and +time periods with varying damping interventions. +""" + +import copy +import os +import pickle +import random +import time +from enum import Enum +from typing import Dict, List, Tuple + +import numpy as np +from progress.bar import Bar + +import memilio.simulation as mio +import memilio.simulation.osecir as osecir +from memilio.simulation import AgeGroup, set_log_level, Damping +from memilio.simulation.osecir import ( + Index_InfectionState, InfectionState, ParameterStudy, + interpolate_simulation_result) + +from memilio.surrogatemodel.GNN.GNN_utils import ( + scale_data, + remove_confirmed_compartments +) +import memilio.surrogatemodel.utils.dampings as dampings + + +class Location(Enum): + """Contact location types for the model.""" + Home = 0 + School = 1 + Work = 2 + Other = 3 + + +# Default number of age groups +DEFAULT_NUM_AGE_GROUPS = 6 + +# Contact location names corresponding to provided contact files +CONTACT_LOCATIONS = ["home", "school_pf_eig", "work", "other"] + + +def set_covid_parameters( + model, start_date, num_groups=DEFAULT_NUM_AGE_GROUPS): + """Sets COVID-19 specific parameters for all age groups. + + Parameters are based on Kühn et al. (2021): https://doi.org/10.1016/j.mbs.2021.108648 + The function sets age-stratified parameters (when available) the following age groups: + 0-4, 5-14, 15-34, 35-59, 60-79, 80+ + + :param model: MEmilio ODE SECIR model to configure. + :param start_date: Start date of the simulation (used to set StartDay parameter). + :param num_groups: Number of age groups (default: 6). + + """ + # Age-specific transmission and progression parameters + transmission_probability = [0.03, 0.06, 0.06, 0.06, 0.09, 0.175] + recovered_per_infected_no_symptoms = [0.25, 0.25, 0.2, 0.2, 0.2, 0.2] + severe_per_infected_symptoms = [ + 0.0075, 0.0075, 0.019, 0.0615, 0.165, 0.225] + critical_per_severe = [0.075, 0.075, 0.075, 0.15, 0.3, 0.4] + deaths_per_critical = [0.05, 0.05, 0.14, 0.14, 0.4, 0.6] + + # Age-specific compartment transition times (in days) + time_infected_no_symptoms = [2.74, 2.74, 2.565, 2.565, 2.565, 2.565] + time_infected_symptoms = [7.02625, 7.02625, 7.0665, 6.9385, 6.835, 6.775] + time_infected_severe = [5, 5, 5.925, 7.55, 8.5, 11] + time_infected_critical = [6.95, 6.95, 6.86, 17.36, 17.1, 11.6] + + # Apply parameters for each age group + for i in range(num_groups): + age_group = AgeGroup(i) + + model.parameters.TimeExposed[age_group] = 3.335 + model.parameters.TimeInfectedNoSymptoms[age_group] = time_infected_no_symptoms[i] + model.parameters.TimeInfectedSymptoms[age_group] = time_infected_symptoms[i] + model.parameters.TimeInfectedSevere[age_group] = time_infected_severe[i] + model.parameters.TimeInfectedCritical[age_group] = time_infected_critical[i] + + model.parameters.RelativeTransmissionNoSymptoms[age_group] = 1.0 + model.parameters.TransmissionProbabilityOnContact[age_group] = transmission_probability[i] + model.parameters.RiskOfInfectionFromSymptomatic[age_group] = 0.25 + model.parameters.MaxRiskOfInfectionFromSymptomatic[age_group] = 0.5 + model.parameters.RecoveredPerInfectedNoSymptoms[age_group] = recovered_per_infected_no_symptoms[i] + model.parameters.SeverePerInfectedSymptoms[age_group] = severe_per_infected_symptoms[i] + model.parameters.CriticalPerSevere[age_group] = critical_per_severe[i] + model.parameters.DeathsPerCritical[age_group] = deaths_per_critical[i] + + # Set simulation start day + model.parameters.StartDay = start_date.day_in_year + + +def set_contact_matrices(model, data_dir, num_groups=DEFAULT_NUM_AGE_GROUPS): + """Loads and configures contact matrices for different location types. + + Contact matrices are loaded for the locations defined in CONTACT_LOCATIONS. + + :param model: MEmilio ODE SECIR model to configure. + :param data_dir: Root directory containing contact matrix data (should contain Germany/contacts/ subdirectory). + :param num_groups: Number of age groups (default: 6). + + """ + contact_matrices = mio.ContactMatrixGroup( + len(CONTACT_LOCATIONS), num_groups) + + # Load contact matrices for each location + for location_idx, location_name in enumerate(CONTACT_LOCATIONS): + contact_file = os.path.join( + data_dir, "Germany", "contacts", f"baseline_{location_name}.txt" + ) + + if not os.path.exists(contact_file): + raise FileNotFoundError( + f"Contact matrix file not found: {contact_file}" + ) + + contact_matrices[location_idx] = mio.ContactMatrix( + mio.read_mobility_plain(contact_file) + ) + + model.parameters.ContactPatterns.cont_freq_mat = contact_matrices + + +def get_graph(num_groups, data_dir, mobility_directory, start_date, end_date): + """Creates a graph with mobility connections. + + Creates a graph where each node represents a geographic region (here county) with its own ODE model, + and edges represent mobility/commuter connections between these regions. Each node is initialized with + population data and COVID parameters. Edges are weighted by commuter mobility patterns. + + :param num_groups: Number of age groups to model. + :param data_dir: Root directory containing population and contact data. + :param mobility_directory: Path to mobility/commuter data file. + :param start_date: Simulation start date. + :param end_date: Simulation end date (used for data loading). + :returns: Configured ModelGraph with nodes for each region and mobility edges. + + """ + model = osecir.Model(num_groups) + set_covid_parameters(model, start_date, num_groups) + set_contact_matrices(model, data_dir, num_groups) + + # Initialize empty graph + graph = osecir.ModelGraph() + + # To account for underreporting + scaling_factor_infected = [2.5] * num_groups + scaling_factor_icu = 1.0 + # Test & trace capacity as in Kühn et al. (2021): https://doi.org/10.1016/j.mbs.2021.108648 + tnt_capacity_factor = 7.5 / 100000.0 + + # Paths to population data + pydata_dir = os.path.join(data_dir, "Germany", "pydata") + path_population_data = os.path.join( + pydata_dir, "county_current_population.json") + + # Verify data files exist + if not os.path.exists(path_population_data): + raise FileNotFoundError( + f"Population data not found: {path_population_data}" + ) + + # Create one node per county + is_node_for_county = True + + # Populate graph nodes with county level data + osecir.set_nodes( + model.parameters, + start_date, + end_date, + pydata_dir, + path_population_data, + is_node_for_county, + graph, + scaling_factor_infected, + scaling_factor_icu, + tnt_capacity_factor, + 0, # Zero days extrapolating the data + False # No extrapolation of data + ) + + # Add mobility edges between regions + osecir.set_edges(mobility_directory, graph, len(CONTACT_LOCATIONS)) + + return graph + + +def get_compartment_factors(): + """Draw base factors for compartment initialization. + + Factors follow the sampling strategy from Schmidt et al. + (2024): symptomatic individuals between 0.01% and 5% of the population, + exposed and asymptomatic compartments proportional to that symptomatic + proportion, and hospital/ICU/deaths sampled hierarchically. + Recovered is taken from the remaining feasible proportion and susceptibles + fill the residual. + """ + p_infected = random.uniform(0.0001, 0.05) + p_exposed = p_infected * random.uniform(0.1, 5.0) + p_ins = p_infected * random.uniform(0.1, 5.0) + p_hosp = p_infected * random.uniform(0.001, 1.0) + p_icu = p_hosp * random.uniform(0.001, 1.0) + p_dead = p_icu * random.uniform(0.001, 1.0) + + sum_randoms = ( + p_infected + p_exposed + p_ins + p_hosp + p_icu + p_dead + ) + + if sum_randoms >= 1.0: + raise RuntimeError( + "Sampled compartment factors exceed total population. Adjust bounds." + ) + + p_recovered = random.uniform(0.0, 1.0 - sum_randoms) + p_susceptible = max(1.0 - (sum_randoms + p_recovered), 0.0) + + return { + "infected": p_infected, + "exposed": p_exposed, + "infected_no_symptoms": p_ins, + "hospitalized": p_hosp, + "critical": p_icu, + "dead": p_dead, + "recovered": p_recovered, + "susceptible": p_susceptible + } + + +def _initialize_compartments_for_node( + model, factors, num_groups, within_group_variation): + """Initializes epidemic compartments using shared base factors. + + :param model: Model instance for the specific node. + :param factors: Compartment factors obtained from get_compartment_factors. + :param num_groups: Number of age groups. + :param within_group_variation: Whether to apply additional random scaling per age group. + """ + + def _variation(): + return random.uniform(0.1, 1.0) if within_group_variation else 1.0 + + for age_idx in range(num_groups): + age_group = AgeGroup(age_idx) + total_population = model.populations.get_group_total_AgeGroup( + age_group) + + infected_symptoms = total_population * \ + factors["infected"] * _variation() + + exposed = infected_symptoms * factors["exposed"] * _variation() + infected_no_symptoms = infected_symptoms * \ + factors["infected_no_symptoms"] * _variation() + infected_severe = infected_symptoms * \ + factors["hospitalized"] * _variation() + infected_critical = infected_severe * \ + factors["critical"] * _variation() + dead = infected_critical * factors["dead"] * _variation() + recovered = total_population * factors["recovered"] * _variation() + + model.populations[age_group, Index_InfectionState( + InfectionState.InfectedSymptoms)] = infected_symptoms + model.populations[age_group, Index_InfectionState( + InfectionState.Exposed)] = exposed + model.populations[age_group, Index_InfectionState( + InfectionState.InfectedNoSymptoms)] = infected_no_symptoms + model.populations[age_group, Index_InfectionState( + InfectionState.InfectedSevere)] = infected_severe + model.populations[age_group, Index_InfectionState( + InfectionState.InfectedCritical)] = infected_critical + model.populations[age_group, Index_InfectionState( + InfectionState.Dead)] = dead + model.populations[age_group, Index_InfectionState( + InfectionState.Recovered)] = recovered + + # Susceptibles are the remainder + model.populations.set_difference_from_group_total_AgeGroup( + (age_group, InfectionState.Susceptible), total_population + ) + + +def _apply_dampings_to_model(model, damping_days, damping_factors, num_groups): + """Applies contact dampings (NPIs) to model at specified days. + + Currently only supports global dampings (same for all age groups and spatial units). + + :param model: Model to apply dampings to. + :param damping_days: Days at which to apply dampings. + :param damping_factors: Multiplicative factors for contact reduction. + :param num_groups: Number of age groups. + :returns: Tuple of (damped_contact_matrices, damping_coefficients). + + """ + damped_matrices = [] + damping_coefficients = [] + + for day, factor in zip(damping_days, damping_factors): + # Create uniform damping matrix for all age groups + damping_matrix = np.ones((num_groups, num_groups)) * factor + + # Add damping to model + model.parameters.ContactPatterns.cont_freq_mat.add_damping( + Damping(coeffs=damping_matrix, t=day, level=0, type=0) + ) + + # Store resulting contact matrix and coefficients + damped_matrices.append( + model.parameters.ContactPatterns.cont_freq_mat.get_matrix_at( + day + 1)) + damping_coefficients.append(damping_matrix) + + return damped_matrices, damping_coefficients + + +def run_secir_groups_simulation( + days, damping_days, damping_factors, graph, within_group_variation, + num_groups=DEFAULT_NUM_AGE_GROUPS): + """Runs a multi-region ODE SECIR simulation with age groups and NPIs. + + Performs the following steps: + 1. Initialize each node with compartment specific (random) factors + 2. Apply contact dampings (NPIs) at specified days + 3. Run the simulation + 4. Post-process and return results + + :param days: Total number of days to simulate. + :param damping_days: List of days when NPIs are applied. + :param damping_factors: List of contact reduction factors for each damping. + :param graph: Pre-configured ModelGraph with nodes and edges. + :param within_group_variation: Whether to apply per-age random variation during initialization. + :param num_groups: Number of age groups (default: 6). + :returns: Tuple containing dataset_entry (simulation results for each node [num_nodes, time_steps, compartments]), + damped_matrices (contact matrices at each damping day), damping_coefficients (damping coefficient matrices), + and runtime (execution time in seconds). + :raises ValueError: If lengths of damping_days and damping_factors don't match. + + """ + if len(damping_days) != len(damping_factors): + raise ValueError( + f"Length mismatch: damping_days ({len(damping_days)}) != " + f"damping_factors ({len(damping_factors)})" + ) + + # Initialize each node in the graph + damped_matrices = [] + damping_coefficients = [] + + factors = get_compartment_factors() + + for node_idx in range(graph.num_nodes): + model = graph.get_node(node_idx).property + + # Initialize compartment populations + _initialize_compartments_for_node( + model, factors, num_groups, within_group_variation) + + # Apply dampings/NPIs + if damping_days: + node_damped_mats, node_damping_coeffs = _apply_dampings_to_model( + model, damping_days, damping_factors, num_groups + ) + # Store only from first node, since dampings are global at the moment + if node_idx == 0: + damped_matrices = node_damped_mats + damping_coefficients = node_damping_coeffs + + model.apply_constraints() + + # Update graph with initialized populations + graph.get_node(node_idx).property.populations = model.populations + + # Run simulation and measure runtime + study = ParameterStudy(graph, t0=0, tmax=days, dt=0.5, num_runs=1) + + start_time = time.perf_counter() + study_results = study.run() + runtime = time.perf_counter() - start_time + + # Interpolate results to daily values + graph_run = study_results[0] + results = interpolate_simulation_result(graph_run) + + # Remove confirmed compartments (not used in GNN) + for result_idx in range(len(results)): + results[result_idx] = remove_confirmed_compartments( + np.asarray(results[result_idx]), num_groups + ) + + dataset_entry = copy.deepcopy(results) + + return dataset_entry, damped_matrices, damping_coefficients, runtime + + +def generate_data(num_runs, data_dir, output_path, input_width, label_width, + start_date, end_date, save_data=True, transform=True, + damping_method="classic", max_number_damping=3, + mobility_file="commuter_mobility.txt", + num_groups=DEFAULT_NUM_AGE_GROUPS, + within_group_variation=True): + """Generates training dataset for GNN surrogate model. + + Runs num_runs-ODE SECIR simulations with random initial conditions and damping patterns to create a training dataset. + If save_data=True, saves a pickle file named: + 'GNN_data_{label_width}days_{max_number_damping}dampings_{damping_method}{num_runs}.pickle' + + :param num_runs: Number of simulation runs to generate. + :param data_dir: Root directory containing all required input data (population, contacts, mobility). + :param output_path: Directory where generated dataset will be saved. + :param input_width: Number of time steps for model input. + :param label_width: Number of time steps for model output/labels. + :param start_date: Simulation start date. + :param end_date: Simulation end date (used for set_nodes function). + :param save_data: Whether to save dataset (default: True). + :param transform: Whether to apply scaling transformation to data (default: True). + :param damping_method: Method for generating damping patterns: "classic", "active", "random". + :param max_number_damping: Maximum number of damping events per simulation. + :param mobility_file: Filename of mobility file (in data_dir/Germany/mobility/). + :param num_groups: Number of age groups (default: 6). + :param within_group_variation: Whether to apply random variation per spatial unit/age group when initializing compartments. + :returns: Dictionary with keys: "inputs" ([num_runs, input_width, num_nodes, features]), + "labels" ([num_runs, label_width, num_nodes, features]), + "contact_matrix" (List of damped contact matrices), "damping_days" (List of damping day arrays), + "damping_factors" (List of damping factor arrays). + + """ + set_log_level(mio.LogLevel.Error) + + # Calculate total simulation days + total_days = label_width + input_width - 1 + + # Initialize output dictionary + data = { + "inputs": [], + "labels": [], + "contact_matrix": [], + "damping_days": [], + "damping_factors": [] + } + + # Build mobility file path + mobility_path = os.path.join( + data_dir, "Germany", "mobility", mobility_file) + + # Verify mobility file exists + if not os.path.exists(mobility_path): + raise FileNotFoundError(f"Mobility file not found: {mobility_path}") + + # Create graph (reused for all runs with different initial conditions) + graph = get_graph(num_groups, data_dir, + mobility_path, start_date, end_date) + + print(f"\nGenerating {num_runs} simulation runs...") + bar = Bar( + 'Progress', max=num_runs, + suffix='%(percent)d%% [%(elapsed_td)s / %(eta_td)s]') + + runtimes = [] + + for _ in range(num_runs): + # Generate random damping pattern + if max_number_damping > 0: + damping_days, damping_factors = dampings.generate_dampings( + total_days, + max_number_damping, + method=damping_method, + min_distance=2, + min_damping_day=2 + ) + else: + damping_days = [] + damping_factors = [] + + # Run simulation + simulation_result, damped_mats, damping_coeffs, runtime = \ + run_secir_groups_simulation( + total_days, damping_days, damping_factors, graph, + within_group_variation, num_groups + ) + + runtimes.append(runtime) + + # Split into inputs and labels + # Shape: [num_nodes, time_steps, features] -> transpose to [time_steps, num_nodes, features] + result_transposed = np.asarray(simulation_result).transpose(1, 0, 2) + inputs = result_transposed[:input_width] + labels = result_transposed[input_width:] + + # Store results + data["inputs"].append(inputs) + data["labels"].append(labels) + data["contact_matrix"].append(np.array(damped_mats)) + data["damping_factors"].append(damping_coeffs) + data["damping_days"].append(damping_days) + + bar.next() + + bar.finish() + + # Print performance statistics + print(f"\nSimulation Statistics:") + print(f" Total days simulated: {total_days}") + print(f" Average runtime: {np.mean(runtimes):.3f}s") + print(f" Median runtime: {np.median(runtimes):.3f}s") + print(f" Total time: {np.sum(runtimes):.1f}s") + + # Save dataset if requested + if save_data: + # Apply scaling transformation + inputs_scaled, labels_scaled = scale_data(data, transform) + + all_data = { + "inputs": inputs_scaled, + "labels": labels_scaled, + "damping_day": data["damping_days"], + "contact_matrix": data["contact_matrix"], + "damping_coeff": data["damping_factors"] + } + + # Create output directory if needed + os.makedirs(output_path, exist_ok=True) + + # Generate filename + if num_runs < 1000: + filename = f'GNN_data_{label_width}days_{max_number_damping}dampings_{damping_method}{num_runs}.pickle' + else: + filename = f'GNN_data_{label_width}days_{max_number_damping}dampings_{damping_method}{num_runs//1000}k.pickle' + + # Save to pickle file + output_file = os.path.join(output_path, filename) + with open(output_file, 'wb') as f: + pickle.dump(all_data, f) + + print(f"\nDataset saved to: {output_file}") + + return data + + +def main(): + """Main function for dataset generation. + + Example configuration for generating GNN training data. + + """ + # Set random seed for reproducibility + random.seed(10) + + # Configuration + data_dir = os.path.join(os.getcwd(), 'data') + output_path = os.path.join(os.getcwd(), 'generated_datasets') + + # Simulation parameters + input_width = 5 # Days of history used as input + num_runs = 1 # Number of simulation runs + max_dampings = 0 # Number of NPI dampings per simulation + + # Prediction horizons to generate data for + prediction_horizons = [30] # Days to predict into the future + + # Simulation time period + start_date = mio.Date(2020, 10, 1) + end_date = mio.Date(2021, 10, 31) + + # Generate datasets + print("=" * 70) + print("GNN Surrogate Model - Dataset Generation") + print("=" * 70) + print(f"Data directory: {data_dir}") + print(f"Output directory: {output_path}") + print(f"Simulation period: {start_date} to {end_date}") + print(f"Number of runs per configuration: {num_runs}") + print(f"Input width: {input_width} days") + print(f"Max dampings: {max_dampings}") + print("=" * 70) + + for label_width in prediction_horizons: + print(f"\n{'='*70}") + print(f"Generating data for {label_width}-day predictions") + print(f"{'='*70}") + + generate_data( + num_runs=num_runs, + data_dir=data_dir, + output_path=output_path, + input_width=input_width, + label_width=label_width, + start_date=start_date, + end_date=end_date, + save_data=True, + damping_method="active", + max_number_damping=max_dampings + ) + + print(f"\n{'='*70}") + print("Dataset generation complete!") + print(f"{'='*70}") + + +if __name__ == "__main__": + main() diff --git a/pycode/memilio-surrogatemodel/memilio/surrogatemodel/GNN/evaluate_and_train.py b/pycode/memilio-surrogatemodel/memilio/surrogatemodel/GNN/evaluate_and_train.py new file mode 100644 index 0000000000..75e8fce937 --- /dev/null +++ b/pycode/memilio-surrogatemodel/memilio/surrogatemodel/GNN/evaluate_and_train.py @@ -0,0 +1,430 @@ +############################################################################# +# Copyright (C) 2020-2025 MEmilio +# +# Authors: Agatha Schmidt, Henrik Zunker, Manuel Heger +# +# Contact: Martin J. Kuehn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################# +""" +Training and evaluation module for GNN-based surrogate models. + +This module loads GNN training data, trains/evaluates surrogate models, and saves weights plus metrics. +""" + +import pickle +import time +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Dict, Iterable, List, Tuple + +import numpy as np +import pandas as pd +import spektral + + +from tensorflow.keras.optimizers import Adam +from tensorflow.keras.losses import MeanAbsolutePercentageError +import tensorflow.keras.initializers as initializers + +import tensorflow as tf +import memilio.surrogatemodel.GNN.network_architectures as network_architectures +from memilio.surrogatemodel.utils.helper_functions import (calc_split_index) + +from spektral.data import MixedLoader +from spektral.layers import ARMAConv +from spektral.utils.convolution import normalized_laplacian, rescale_laplacian + + +def _iterate_batches(loader): + """Yield all batches from a Spektral loader for one epoch.""" + for _ in range(loader.steps_per_epoch): + yield next(loader) + + +class StaticGraphDataset(spektral.data.Dataset): + """Spektral dataset wrapper for samples that share a single adjacency matrix.""" + + def __init__( + self, + node_features, + node_labels, + adjacency): + self._node_features = node_features + self._node_labels = node_labels + self._adjacency = adjacency.astype(np.float32) + super().__init__() + # This must be set AFTER calling super().__init__() + self.a = self._adjacency + + def read(self): + """Create one Graph object per sample. + + Note: For MixedLoader, individual Graph objects should not have an + adjacency matrix (a). The adjacency matrix is stored at the dataset level. + """ + return [ + spektral.data.Graph( + x=feature.astype(np.float32), + y=label.astype(np.float32), + ) + for feature, label in zip(self._node_features, self._node_labels) + ] + + +@dataclass +class TrainingSummary: + """Container for aggregated training and evaluation metrics.""" + + model_name: str + mean_train_loss: float + mean_val_loss: float + mean_test_loss: float + mean_test_loss_orig: float + training_time: float + train_losses: List[List[float]] + val_losses: List[List[float]] + + +def load_gnn_dataset( + dataset_path, + mobility_dir, + number_of_nodes=400, + mobility_filename="commuter_mobility_2022.txt"): + """Load serialized samples and mobility data into a Spektral dataset. + + Args: + dataset_path: Pickle file containing the generated training samples. + mobility_dir: Directory containing the commuter mobility file. + number_of_nodes: Number of spatial nodes to retain from the mobility data. + mobility_filename: Mobility file name to use. + + Returns: + Spektral dataset with one `Graph` per sample sharing a common adjacency matrix. + """ + dataset_path = Path(dataset_path) + mobility_dir = Path(mobility_dir) + + if not dataset_path.exists(): + raise FileNotFoundError(f"Dataset not found: {dataset_path}") + + mobility_path = mobility_dir / mobility_filename + if not mobility_path.exists(): + raise FileNotFoundError(f"Mobility file not found: {mobility_path}") + + with dataset_path.open("rb") as fp: + data = pickle.load(fp) + + if "inputs" not in data or "labels" not in data: + raise KeyError( + f"Dataset at {dataset_path} must contain 'inputs' and 'labels' keys." + ) + + inputs = np.asarray(data["inputs"]) + labels = np.asarray(data["labels"]) + if inputs.shape[0] == 0: + raise ValueError( + "Loaded dataset is empty; expected at least one sample.") + + # Flatten temporal dimensions into feature vectors per node. + num_samples, input_width, num_nodes, num_features = inputs.shape + _, label_width, _, label_features = labels.shape + + if num_nodes != number_of_nodes: + raise ValueError( + f"Number of nodes in dataset ({num_nodes}) does not match expected " + f"value ({number_of_nodes}).") + + node_features = inputs.transpose(0, 2, 1, 3).reshape( + num_samples, number_of_nodes, input_width * num_features) + node_labels = labels.transpose(0, 2, 1, 3).reshape( + num_samples, number_of_nodes, label_width * label_features) + + commuter_data = pd.read_csv(mobility_path, sep=" ", header=None) + adjacency_matrix = commuter_data.iloc[ + :number_of_nodes, :number_of_nodes + ].to_numpy() + adjacency_matrix = (adjacency_matrix > 0).astype(np.float32) + adjacency_matrix = np.maximum(adjacency_matrix, adjacency_matrix.T) + + return StaticGraphDataset(node_features, node_labels, adjacency_matrix) + + +def create_dataset(path_cases, path_mobility, number_of_nodes=400): + """Compatibility wrapper around `load_gnn_dataset`.""" + return load_gnn_dataset( + Path(path_cases), + Path(path_mobility), + number_of_nodes=number_of_nodes) + + +def _train_step_impl(model, optimizer, loss_fn, inputs, target): + """Perform one optimization step.""" + with tf.GradientTape() as tape: + predictions = model(inputs, training=True) + loss = loss_fn(target, predictions) + if model.losses: + loss += tf.add_n(model.losses) + + gradients = tape.gradient(loss, model.trainable_variables) + optimizer.apply_gradients(zip(gradients, model.trainable_variables)) + + metric = tf.reduce_mean(loss_fn(target, predictions)) + return loss, metric + + +def train_step(*args, **kwargs): + """Wrapper that accepts both old and new call signatures.""" + if len(args) == 5 and not kwargs: + first, second, third, fourth, fifth = args + if isinstance(first, (tuple, list)): + inputs, target = first, second + loss_fn = third + model = fourth + optimizer = fifth + else: + model, optimizer, loss_fn, inputs, target = args + else: + raise TypeError( + "train_step expects either (inputs, target, loss_fn, model, optimizer) " + "or (model, optimizer, loss_fn, inputs, target).") + + return _train_step_impl(model, optimizer, loss_fn, inputs, target) + + +def evaluate(loader, model, loss_fn, retransform=False): + """Evaluate the model on the dataset provided by loader.""" + total_loss = 0.0 + total_metric = 0.0 + total_samples = 0 + + for inputs, target in _iterate_batches(loader): + predictions = model(inputs, training=False) + + target_tensor = tf.convert_to_tensor(target, dtype=tf.float32) + prediction_tensor = tf.cast(predictions, tf.float32) + if retransform: + target_tensor = tf.math.expm1(target_tensor) + prediction_tensor = tf.math.expm1(prediction_tensor) + + batch_losses = loss_fn(target_tensor, prediction_tensor) + batch_loss = tf.reduce_mean(batch_losses) + batch_metric = tf.reduce_mean(batch_losses) + + batch_size_tensor = tf.cast(tf.shape(target_tensor)[0], tf.float32) + batch_size = float(batch_size_tensor.numpy()) + total_loss += float(batch_loss.numpy()) * batch_size + total_metric += float(batch_metric.numpy()) * batch_size + total_samples += int(round(batch_size)) + + if total_samples == 0: + return 0.0, 0.0 + + mean_loss = total_loss / total_samples + mean_metric = total_metric / total_samples + return mean_loss, mean_metric + + +def train_and_evaluate( + data, batch_size, epochs, model, loss_fn, optimizer, es_patience, + save_dir=None, save_name="model"): + """Train the provided GNN model.""" + dataset_size = len(data) + if dataset_size == 0: + raise ValueError("Dataset must contain at least one sample.") + + n_train, n_valid, n_test = calc_split_index( + dataset_size, split_train=0.7, split_valid=0.2, split_test=0.1) + if n_train == 0 or n_valid == 0 or n_test == 0: + raise ValueError( + "Dataset split produced empty partitions. Provide a larger dataset " + "or adjust the split configuration.") + + train_data = data[:n_train] + valid_data = data[n_train:n_train + n_valid] + test_data = data[n_train + n_valid:] + + # Build the model by passing a single batch through it. + build_loader = MixedLoader(train_data, batch_size=min( + batch_size, max(1, n_train)), epochs=1, shuffle=False) + build_inputs, _ = next(build_loader) + model(build_inputs) + + def _make_loader(dataset, *, batch_size, shuffle=False): + return MixedLoader( + dataset, batch_size=batch_size, epochs=1, shuffle=shuffle) + + best_val_loss = np.inf + best_weights = model.get_weights() + patience_counter = es_patience + + epoch_train_losses: List[float] = [] + epoch_val_losses: List[float] = [] + + start_time = time.perf_counter() + + for epoch in range(1, epochs + 1): + train_loader = _make_loader( + train_data, batch_size=batch_size, shuffle=True) + batch_losses = [] + for inputs, target in _iterate_batches(train_loader): + loss, _ = train_step(model, optimizer, loss_fn, inputs, target) + batch_losses.append(float(loss.numpy())) + + epoch_train_loss = float(np.mean(batch_losses) + ) if batch_losses else 0.0 + epoch_train_losses.append(epoch_train_loss) + + val_loader = _make_loader(valid_data, batch_size=min( + batch_size, max(1, n_valid)), shuffle=False) + val_loss, _ = evaluate(val_loader, model, loss_fn) + epoch_val_losses.append(val_loss) + + print( + f"Epoch {epoch:02d} | train_loss={epoch_train_loss:.4f} " + f"| val_loss={val_loss:.4f}" + ) + + if val_loss < best_val_loss: + best_val_loss = val_loss + best_weights = model.get_weights() + patience_counter = es_patience + print(f" ↳ New best validation loss: {best_val_loss:.4f}") + else: + patience_counter -= 1 + if patience_counter == 0: + print("Early stopping triggered.") + break + + elapsed = time.perf_counter() - start_time + + # Restore best weights and evaluate on the test set. + model.set_weights(best_weights) + test_loader = _make_loader( + test_data, batch_size=min(batch_size, max(1, n_test)), shuffle=False) + test_loss, _ = evaluate(test_loader, model, loss_fn) + + test_loader_retransform = _make_loader( + test_data, batch_size=min(batch_size, max(1, n_test)), shuffle=False) + test_loss_orig, _ = evaluate( + test_loader_retransform, model, loss_fn, retransform=True) + + print(f"Test loss (log space): {test_loss:.4f}") + print(f"Test loss (original scale): {test_loss_orig:.4f}") + print(f"Training runtime: {elapsed:.2f}s ({elapsed / 60:.2f} min)") + + summary = TrainingSummary( + model_name=save_name, mean_train_loss=float( + np.min(epoch_train_losses)) + if epoch_train_losses else float("nan"), + mean_val_loss=float(np.min(epoch_val_losses)) + if epoch_val_losses else float("nan"), mean_test_loss=float(test_loss), + mean_test_loss_orig=float(test_loss_orig), + training_time=elapsed / 60, train_losses=[epoch_train_losses], + val_losses=[epoch_val_losses]) + + if save_dir: + save_dir = Path(save_dir) + save_dir.mkdir(parents=True, exist_ok=True) + + metrics_df = pd.DataFrame(columns=[ + "train_loss", "val_loss", "test_loss", + "test_loss_orig", "training_time", + "loss_history", "val_loss_history"]) + metrics_df.loc[len(metrics_df.index)] = [ + summary.mean_train_loss, + summary.mean_val_loss, + summary.mean_test_loss, + summary.mean_test_loss_orig, + summary.training_time, + summary.train_losses, + summary.val_losses + ] + + weights_dir = save_dir / "saved_weights" + weights_dir.mkdir(parents=True, exist_ok=True) + + weights_filename = save_name if save_name.endswith( + ".pickle") else f"{save_name}.pickle" + weights_path = weights_dir / weights_filename + with weights_path.open("wb") as fp: + pickle.dump(best_weights, fp) + + results_dir = save_dir / "model_evaluations_paper" + results_dir.mkdir(parents=True, exist_ok=True) + results_path = results_dir / \ + weights_filename.replace(".pickle", ".csv") + metrics_df.to_csv(results_path, index=False) + print(f"Saved weights to {weights_path}") + print(f"Saved evaluation metrics to {results_path}") + + return asdict(summary) + + +if __name__ == "__main__": + start_hyper = time.perf_counter() + epochs = 10 + batch_size = 2 + es_patience = 10 + optimizer = Adam(learning_rate=0.001) + loss_fn = MeanAbsolutePercentageError() + + repo_root = Path(__file__).resolve().parents[4] + artifacts_root = repo_root / "artifacts" + + dataset_path = artifacts_root / \ + "generated_datasets" / "GNN_data_30days_3dampings_classic5.pickle" + + mobility_dir = repo_root / "data" / "Germany" / "mobility" + data = load_gnn_dataset(dataset_path, mobility_dir) + + # Define the model architecture + def transform_a(adjacency_matrix): + a = adjacency_matrix.numpy() + a = rescale_laplacian(normalized_laplacian(a)) + return tf.convert_to_tensor(a, dtype=tf.float32) + + layer_types = [ + # Dense layer (only uses x) + lambda: ARMAConv(512, activation='elu', + kernel_initializer=initializers.GlorotUniform(seed=None)) + ] + num_repeat = [7] + + model_class = network_architectures.generate_model_class( + "ARMA", layer_types, num_repeat, num_output=1440, transform=transform_a) + + model = model_class() + + save_name = 'GNN_30days' # name for model + save_dir = artifacts_root / "model_results" + + train_and_evaluate( + data, + batch_size, + epochs, + model, + loss_fn, + optimizer, + es_patience, + save_dir=save_dir, + save_name=save_name) + + elapsed_hyper = time.perf_counter() - start_hyper + print( + "Time for hyperparameter testing: {:.4f} minutes".format( + elapsed_hyper / 60)) + print( + "Time for hyperparameter testing: {:.4f} hours".format( + elapsed_hyper / 60 / 60)) diff --git a/pycode/memilio-surrogatemodel/memilio/surrogatemodel/GNN/grid_search.py b/pycode/memilio-surrogatemodel/memilio/surrogatemodel/GNN/grid_search.py new file mode 100644 index 0000000000..b6f78db540 --- /dev/null +++ b/pycode/memilio-surrogatemodel/memilio/surrogatemodel/GNN/grid_search.py @@ -0,0 +1,326 @@ +############################################################################# +# Copyright (C) 2020-2025 MEmilio +# +# Authors: Agatha Schmidt, Henrik Zunker, Manuel Heger +# +# Contact: Martin J. Kuehn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################# +""" +Grid search module for GNN-based surrogate model hyperparameter optimization. + +This module provides functionality to perform systematic hyperparameter search +over GNN architectures, training configurations, and model parameters to identify +optimal configurations for epidemiological surrogate models. +""" + +from pathlib import Path + +import pandas as pd +import tensorflow as tf +from spektral.data import MixedLoader +from tensorflow.keras.losses import MeanAbsolutePercentageError +from tensorflow.keras.metrics import MeanAbsolutePercentageError as MeanAbsolutePercentageErrorMetric +from tensorflow.keras.optimizers import Adam + +from memilio.surrogatemodel.GNN.evaluate_and_train import ( + create_dataset, + train_and_evaluate +) +from memilio.surrogatemodel.GNN.network_architectures import get_model + + +# Default hyperparameter grid +DEFAULT_LAYER_TYPES = [ + "ARMAConv", + "GCSConv", + "GATConv", + "GCNConv", + "APPNPConv" +] + +DEFAULT_NUM_LAYERS = [2, 3, 4, 5, 6, 7] +DEFAULT_NUM_CHANNELS = [2, 3, 4, 5, 6, 7] +DEFAULT_ACTIVATION_FUNCTIONS = ["elu", "relu", "tanh", "sigmoid"] + +# Default training configuration +DEFAULT_BATCH_SIZE = 32 +DEFAULT_MAX_EPOCHS = 100 +DEFAULT_ES_PATIENCE = 30 + + +def generate_parameter_grid( + layer_types=None, + num_layers_options=None, + num_channels_options=None, + activation_functions=None): + """Generates a grid of model parameter combinations for hyperparameter search. + + :param layer_types: List of GNN layer types to test (default: DEFAULT_LAYER_TYPES). + :param num_layers_options: List of layer counts to test (default: DEFAULT_NUM_LAYERS). + :param num_channels_options: List of channel counts to test (default: DEFAULT_NUM_CHANNELS). + :param activation_functions: List of activation functions to test (default: DEFAULT_ACTIVATION_FUNCTIONS). + :returns: List of tuples, each containing (layer_type, num_layers, num_channels, activation). + + """ + if layer_types is None: + layer_types = DEFAULT_LAYER_TYPES + if num_layers_options is None: + num_layers_options = DEFAULT_NUM_LAYERS + if num_channels_options is None: + num_channels_options = DEFAULT_NUM_CHANNELS + if activation_functions is None: + activation_functions = DEFAULT_ACTIVATION_FUNCTIONS + + parameter_grid = [ + (layer_type, num_layers, num_channels, activation) + for layer_type in layer_types + for num_layers in num_layers_options + for num_channels in num_channels_options + for activation in activation_functions + ] + + return parameter_grid + + +def perform_grid_search( + data, + parameter_grid, + save_dir, + batch_size=DEFAULT_BATCH_SIZE, + max_epochs=DEFAULT_MAX_EPOCHS, + es_patience=DEFAULT_ES_PATIENCE, + learning_rate=0.001): + """Performs systematic grid search over GNN hyperparameters. + + Trains and evaluates models for each parameter combination in the grid, + tracking performance metrics and saving results. + Results are stored in a CSV file for later analysis. + + :param data: Spektral dataset containing training, validation, and test samples. + :param parameter_grid: List of tuples with (layer_type, num_layers, num_channels, activation) combinations. + :param save_dir: Directory to save results. + :param batch_size: Batch size for training (default: 32). + :param max_epochs: Maximum number of training epochs per configuration (default: 100). + :param es_patience: Early stopping in epochs (default: 30). + :param learning_rate: Learning rate for Adam optimizer (default: 0.001). + :returns: DataFrame containing all grid search results. + :raises ValueError: If data is empty or parameter_grid is invalid. + + """ + if not data or len(data) == 0: + raise ValueError("Dataset must contain at least one sample.") + + if not parameter_grid or len(parameter_grid) == 0: + raise ValueError( + "Parameter grid must contain at least one configuration.") + + # Convert save_dir to Path if it's a string + save_dir = Path(save_dir) / "saves" + + # Determine output dimension from data + output_dim = data[0].y.shape[-1] + + # Initialize loss function and optimizer + loss_function = MeanAbsolutePercentageError() + + # Initialize results DataFrame + results_df = pd.DataFrame( + columns=[ + 'model', + 'optimizer', + 'number_of_hidden_layers', + 'number_of_channels', + 'activation', + 'mean_train_loss', + 'mean_validation_loss', + 'mean_test_loss', + 'mean_test_loss_orig', + 'training_time', + 'train_losses', + 'val_losses' + ] + ) + + save_dir.mkdir(parents=True, exist_ok=True) + results_file = save_dir / 'grid_search_results.csv' + + print(f"\n{'=' * 70}") + print("GNN Grid Search - Hyperparameter Optimization") + print(f"{'=' * 70}") + print(f"Total configurations to test: {len(parameter_grid)}") + print(f"Results will be saved to: {results_file}") + print(f"{'=' * 70}\n") + + # Iterate through all parameter combinations + for idx, (layer_type, num_layers, num_channels, activation) in enumerate( + parameter_grid, 1): + print(f"\n[{idx}/{len(parameter_grid)}] Training configuration:") + print(f" Layer type: {layer_type}") + print(f" Number of layers: {num_layers}") + print(f" Number of channels: {num_channels}") + print(f" Activation function: {activation}") + + try: + # Create model instance + model = get_model( + layer_type=layer_type, + num_layers=num_layers, + num_channels=num_channels, + activation=activation, + num_output=output_dim + ) + + # Initialize optimizer for this configuration + optimizer = Adam(learning_rate=learning_rate) + + # Build model by passing a sample batch through it + build_loader = MixedLoader( + data, batch_size=batch_size, epochs=1, shuffle=False) + build_inputs, _ = next(build_loader) + model(build_inputs) + + # Initialize optimizer variables + optimizer.build(model.trainable_variables) + + model.compile( + optimizer=optimizer, + loss=MeanAbsolutePercentageError(), + metrics=[MeanAbsolutePercentageErrorMetric()] + ) + + # Train and evaluate model + results = train_and_evaluate( + data=data, + batch_size=batch_size, + epochs=max_epochs, + model=model, + loss_fn=loss_function, + optimizer=optimizer, + es_patience=es_patience, + save_name="", # Dont save individual models during grid search + save_dir=None # Dont save individual results + ) + + # Store results + results_df.loc[len(results_df)] = [ + layer_type, + optimizer.__class__.__name__, + num_layers, + num_channels, + activation, + results["mean_train_loss"], + results["mean_val_loss"], + results["mean_test_loss"], + results["mean_test_loss_orig"], + results["training_time"], + results["train_losses"], + results["val_losses"] + ] + + # Save intermediate results after each configuration + results_df.to_csv(results_file, index=False) + print( + f"Configuration complete. Results saved to {results_file}") + + except Exception as e: + print(f"Error training configuration: {e}") + # Continue with next configuration rather than failing entire search + + finally: + # Clear TensorFlow session to free memory + tf.keras.backend.clear_session() + + print(f"\n{'=' * 70}") + print("Grid Search Complete!") + print(f"{'=' * 70}") + print(f"Total configurations tested: {len(results_df)}") + print(f"Results saved to: {results_file}") + + # Print best configuration + if len(results_df) > 0: + best_idx = results_df['mean_validation_loss'].idxmin() + best_config = results_df.loc[best_idx] + print(f"\nBest Configuration:") + print(f" Model: {best_config['model']}") + print(f" Layers: {best_config['number_of_hidden_layers']}") + print(f" Channels: {best_config['number_of_channels']}") + print(f" Activation: {best_config['activation']}") + print(f" Validation Loss: {best_config['mean_validation_loss']:.4f}") + print(f" Test Loss: {best_config['mean_test_loss']:.4f}") + + print(f"{'=' * 70}\n") + + return results_df + + +def main(): + """Main function demonstrating grid search usage. + """ + # Dataset path + dataset_path = Path.cwd() / "generated_datasets" / \ + "GNN_data_30days_3dampings_classic5.pickle" + mobility_dir = Path.cwd() / "data" / "Germany" / "mobility" + + # Output directory for results + output_dir = Path.cwd() / "grid_search_results" + + print("=" * 70) + print("GNN Grid Search - Configuration") + print("=" * 70) + print(f"Dataset: {dataset_path}") + print(f"Mobility data: {mobility_dir}") + print(f"Output directory: {output_dir}") + print("=" * 70) + + # Verify files exist + if not dataset_path.exists(): + raise FileNotFoundError(f"Dataset not found: {dataset_path}") + if not mobility_dir.exists(): + raise FileNotFoundError( + f"Mobility directory not found: {mobility_dir}") + + # Load dataset + print("\nLoading dataset...") + data = create_dataset(str(dataset_path), str(mobility_dir)) + print(f"Dataset loaded: {len(data)} samples") + + # Generate parameter grid + # For demonstration, use a smaller grid. Remove restrictions for full search. + parameter_grid = generate_parameter_grid( + layer_types=["ARMAConv", "GCNConv"], + num_layers_options=[3, 5], + num_channels_options=[4, 6], + activation_functions=["elu", "relu"] + ) + + print( + f"Generated parameter grid with {len(parameter_grid)} configurations") + + # Perform grid search + results = perform_grid_search( + data=data, + parameter_grid=parameter_grid, + save_dir=str(output_dir), + batch_size=32, + max_epochs=100, + es_patience=30, + learning_rate=0.001 + ) + + print(f"\nGrid search complete. Results shape: {results.shape}") + + +if __name__ == "__main__": + main() diff --git a/pycode/memilio-surrogatemodel/memilio/surrogatemodel/GNN/network_architectures.py b/pycode/memilio-surrogatemodel/memilio/surrogatemodel/GNN/network_architectures.py new file mode 100644 index 0000000000..55321d5b00 --- /dev/null +++ b/pycode/memilio-surrogatemodel/memilio/surrogatemodel/GNN/network_architectures.py @@ -0,0 +1,427 @@ +############################################################################# +# Copyright (C) 2020-2025 MEmilio +# +# Authors: Agatha Schmidt, Henrik Zunker, Manuel Heger +# +# Contact: Martin J. Kuehn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################# +""" +Network architecture module for GNN-based surrogate models. + +This file provides functionality to generate Graph Neural Network +architectures with various layer types, configurations, and preprocessing +transformations. +""" + +import numpy as np +import tensorflow as tf +from spektral import layers as spektral_layers +from spektral.utils import convolution as spektral_convolution + + +# Supported GNN layer types +SUPPORTED_LAYER_TYPES = [ + "ARMAConv", + "GCSConv", + "GATConv", + "GCNConv", + "APPNPConv" +] + + +def _rescale_laplacian(laplacian): + """Rescales a Laplacian matrix while remaining compatible with newer SciPy.""" + laplacian = laplacian.toarray() if hasattr( + laplacian, "toarray") else np.asarray(laplacian) + try: + return spektral_convolution.rescale_laplacian(laplacian) + except TypeError as err: + if "eigvals" not in str(err): + raise + + lmax = np.linalg.eigvalsh(laplacian)[-1] + if lmax <= 0: + lmax = 2.0 + identity = np.eye(laplacian.shape[0], dtype=laplacian.dtype) + return (2.0 / lmax) * laplacian - identity + + +def _apply_graph_transform(adjacency, single_transform): + """Applies a transform function to adjacency tensors with rank 2 or 3.""" + adj_array = adjacency.numpy() + + if adj_array.ndim == 2: + transformed = np.asarray(single_transform(adj_array), dtype=np.float32) + return tf.convert_to_tensor(transformed, dtype=tf.float32) + + if adj_array.ndim == 3: + transformed = [ + np.asarray(single_transform(graph_adj), dtype=np.float32) + for graph_adj in adj_array + ] + return tf.convert_to_tensor(np.stack(transformed), dtype=tf.float32) + + raise ValueError( + f"Adjacency tensor must be rank 2 or 3, got rank {adj_array.ndim}." + ) + + +def generate_model_class( + name, + layer_types, + num_repeat, + num_output, + transform=None): + """Dynamically generates a custom Keras GNN model class with specified layer configuration. + + This function creates a new Keras Model class with a configurable sequence of layers. + Each layer type can be repeated multiple times, allowing for flexible architecture design. + The generated model supports both graph-based layers (Spektral) and standard layers. + + :param name: Name for the generated model class. + :param layer_types: List of layer constructors. + Each element should be a callable that instantiates a layer. + :param num_repeat: List of integers specifying repetition count for each layer type. + Must have same length as layer_types. + :param num_output: Number of output units in the final dense layer. + :param transform: Optional function to preprocess the adjacency matrix before passing + it to graph layers. Should accept a tensor and return a tensor. + :returns: Created Keras Model class. + :raises ValueError: If layer_types and num_repeat have different lengths, or if + any num_repeat value is less than 1. + + """ + + def __init__(self): + """Initializes the model with the specified layer sequence.""" + if len(layer_types) != len(num_repeat): + raise ValueError( + f"layer_types and num_repeat must have the same length. " + f"Got {len(layer_types)} and {len(num_repeat)}." + ) + if any(n < 1 for n in num_repeat): + raise ValueError( + "All values in num_repeat must be at least 1." + ) + + super(type(self), self).__init__() + + # Build sequence of hidden layers + self.layer_seq = [] + for layer_idx, layer_type in enumerate(layer_types): + for _ in range(num_repeat[layer_idx]): + # Instantiate layer from callable + layer = layer_type() if callable(layer_type) else layer_type + self.layer_seq.append(layer) + + # Final output layer with ReLU activation + self.output_layer = tf.keras.layers.Dense( + num_output, activation="relu" + ) + + def call(self, inputs, mask=None): + """Forward pass through the model. + + :param inputs: Tuple of (node_features, adjacency_matrix) where: + - node_features: [batch_size, num_nodes, num_features] + - adjacency_matrix: [num_nodes, num_nodes] or [batch_size, num_nodes, num_nodes] + :param mask: Optional node mask from data loader. Can be None or a list [node_mask, None]. + :returns: Model output tensor with shape [batch_size, num_nodes, num_output]. + + """ + # Unpack inputs + x, a = inputs + + # Extract and prepare node mask for masking operations + node_mask = None + if mask is not None: + # Spektral typically provides masks as [node_mask, None] for [x, a] + node_mask = mask[0] if isinstance( + mask, (tuple, list)) else mask + + if node_mask is not None: + # Ensure mask has shape [batch_size, num_nodes, 1] + if tf.rank(node_mask) == 2: + node_mask = tf.expand_dims(node_mask, axis=-1) + + # Create default mask if none provided + if node_mask is None: + # x has shape [batch_size, num_nodes, features] + x_shape = tf.shape(x) + node_mask = tf.ones([x_shape[0], x_shape[1], 1], dtype=tf.float32) + + # Apply adjacency matrix transformation if provided + if not tf.is_symbolic_tensor(a): + if transform is not None: + a = transform(a) + + # Forward pass through layer sequence + for layer in self.layer_seq: + # Check if layer is a Spektral graph layer + if type(layer).__module__.startswith("spektral.layers"): + # Graph layers need both features and adjacency matrix + x = layer([x, a], mask=[node_mask, None]) + else: + # Standard layers only need features + x = layer(x) + + # Apply final output layer + output = self.output_layer(x) + return output + + # Create class dictionary with methods + class_dict = { + '__init__': __init__, + 'call': call + } + + return type(name, (tf.keras.Model,), class_dict) + + +def _get_layer_config(layer_type): + """Returns layer class and transformation function for a given layer type. + + Each GNN layer type requires specific preprocessing of the adjacency matrix. + This function encapsulates the layer-specific configuration. + + :param layer_type: String identifier for the GNN layer type. + :returns: Tuple of (layer_class, transform_function). + :raises ValueError: If layer_type is not supported. + + """ + if layer_type == "ARMAConv": + layer_class = spektral_layers.ARMAConv + + def transform(adjacency): + """Applies rescaled Laplacian transformation for ARMA convolution.""" + + def single_transform(adj_array): + laplacian = spektral_convolution.normalized_laplacian( + adj_array) + return _rescale_laplacian(laplacian) + + return _apply_graph_transform(adjacency, single_transform) + + return layer_class, transform + + elif layer_type == "GCSConv": + layer_class = spektral_layers.GCSConv + + def transform(adjacency): + """Applies normalized adjacency for GCS convolution.""" + return _apply_graph_transform( + adjacency, + spektral_convolution.normalized_adjacency + ) + + return layer_class, transform + + elif layer_type == "GATConv": + layer_class = spektral_layers.GATConv + + def transform(adjacency): + """Applies normalized adjacency for GAT convolution.""" + return _apply_graph_transform( + adjacency, + spektral_convolution.normalized_adjacency + ) + + return layer_class, transform + + elif layer_type == "GCNConv": + layer_class = spektral_layers.GCNConv + + def transform(adjacency): + """Applies GCN filter for GCN convolution.""" + return _apply_graph_transform( + adjacency, + spektral_convolution.gcn_filter + ) + + return layer_class, transform + + elif layer_type == "APPNPConv": + layer_class = spektral_layers.APPNPConv + + def transform(adjacency): + """Applies GCN filter for APPNP convolution.""" + return _apply_graph_transform( + adjacency, + spektral_convolution.gcn_filter + ) + + return layer_class, transform + + else: + raise ValueError( + f"Unsupported layer_type: '{layer_type}'. " + f"Supported types are: {', '.join(SUPPORTED_LAYER_TYPES)}" + ) + + +def get_model( + layer_type, + num_layers, + num_channels, + activation, + num_output=1): + """Generates a GNN model instance with specified architecture. + + Creates a Graph Neural Network model with a layer type repeated + multiple times. The model includes appropriate preprocessing for the adjacency + matrix based on the layer type. + + :param layer_type: Type of GNN layer to use. Must be one of: 'ARMAConv', 'GCSConv', + 'GATConv', 'GCNConv', 'APPNPConv'. + :param num_layers: Number of hidden GNN layers to stack. + :param num_channels: Number of channels (units/features) in each hidden layer. + :param activation: Activation function for hidden layers (e.g., 'relu', 'elu', 'tanh', 'sigmoid'). + :param num_output: Number of output units in the final dense layer (default: 1). + :returns: Instantiated Keras Model ready for training. + :raises ValueError: If parameters are invalid or layer_type is not supported. + + """ + # Validate inputs + if layer_type not in SUPPORTED_LAYER_TYPES: + raise ValueError( + f"Unsupported layer_type: '{layer_type}'. " + f"Supported types are: {', '.join(SUPPORTED_LAYER_TYPES)}" + ) + + if num_layers < 1: + raise ValueError( + f"num_layers must be at least 1, got {num_layers}." + ) + + if num_channels < 1: + raise ValueError( + f"num_channels must be at least 1, got {num_channels}." + ) + + if not isinstance(activation, str): + raise ValueError( + f"activation must be a string, got {type(activation).__name__}." + ) + + if num_output < 1: + raise ValueError( + f"num_output must be at least 1, got {num_output}." + ) + + # Get layer configuration for the specified type + layer_class, transform_fn = _get_layer_config(layer_type) + + # Create layer constructor with specified parameters + def create_layer(): + return layer_class( + num_channels, + activation=activation, + kernel_initializer=tf.keras.initializers.GlorotUniform() + ) + + # Configure model structure + layer_types = [create_layer] + num_repeat = [num_layers] + + # Generate model class + model_class = generate_model_class( + name="CustomGNNModel", + layer_types=layer_types, + num_repeat=num_repeat, + num_output=num_output, + transform=transform_fn + ) + + return model_class() + + +def main(): + """Main function demonstrating network architecture usage. + """ + print("=" * 70) + print("GNN Network Architecture - Examples") + print("=" * 70) + + # Example 1: Custom model with mixed layer types + print("\nExample 1: Custom model with mixed Dense layers") + layer_types = [ + lambda: tf.keras.layers.Dense(10, activation="relu"), + lambda: tf.keras.layers.Dense(20, activation="relu"), + lambda: tf.keras.layers.Dense(30, activation="relu") + ] + num_repeat = [2, 3, 1] # 2 layers of 10 units, 3 of 20, 1 of 30 + + custom_model_class = generate_model_class( + name="CustomMixedModel", + layer_types=layer_types, + num_repeat=num_repeat, + num_output=2 + ) + model1 = custom_model_class() + + # Example inputs for dense layers (no graph structure) + batch_size = 8 + num_nodes = 20 + num_features = 5 + + # Node features: [batch_size, num_nodes, features] + x1 = tf.random.normal([batch_size, num_nodes, num_features]) + # Adjacency matrix per sample (not used by dense layers but required for input signature) + a1_single = tf.random.normal([num_nodes, num_nodes]) + a1 = tf.tile(tf.expand_dims(a1_single, axis=0), [batch_size, 1, 1]) + # Labels: [batch_size, num_nodes, outputs] + labels1 = tf.random.normal([batch_size, num_nodes, 2]) + + print("Compiling and training custom model...") + # Build the model by calling it once + _ = model1([x1, a1]) + model1.compile(optimizer="adam", loss="mse") + model1.fit([x1, a1], labels1, epochs=5, verbose=0) + print("Custom model trained successfully") + + # Example 2: Pre-configured GNN model + print("\nExample 2: Pre-configured ARMA GNN model") + model2 = get_model( + layer_type="ARMAConv", + num_layers=3, + num_channels=16, + activation="relu", + num_output=2 + ) + + # Example inputs for GNN (proper graph structure) + # Node features: [batch_size, num_nodes, features] + x2 = tf.random.normal([batch_size, num_nodes, num_features]) + # Adjacency matrix per sample + a2_single = tf.eye( + num_nodes) + 0.1 * tf.random.normal([num_nodes, num_nodes]) + a2 = tf.tile(tf.expand_dims(a2_single, axis=0), [batch_size, 1, 1]) + # Labels: [batch_size, num_nodes, outputs] + labels2 = tf.random.normal([batch_size, num_nodes, 2]) + + print("Compiling and training ARMA model...") + # Build the model by calling it once + _ = model2([x2, a2]) + model2.compile(optimizer="adam", loss="mse") + model2.fit([x2, a2], labels2, epochs=5, verbose=0) + print("ARMA model trained successfully") + + print("\n" + "=" * 70) + print("Examples completed successfully!") + print("=" * 70 + "\n") + + +if __name__ == "__main__": + main() diff --git a/pycode/memilio-surrogatemodel/memilio/surrogatemodel_test/test_surrogatemodel_GNN.py b/pycode/memilio-surrogatemodel/memilio/surrogatemodel_test/test_surrogatemodel_GNN.py new file mode 100644 index 0000000000..58419294fb --- /dev/null +++ b/pycode/memilio-surrogatemodel/memilio/surrogatemodel_test/test_surrogatemodel_GNN.py @@ -0,0 +1,488 @@ +############################################################################# +# Copyright (C) 2020-2025 MEmilio +# +# Authors: Manuel Heger, Henrik Zunker +# +# Contact: Martin J. Kuehn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################# + +from pyfakefs import fake_filesystem_unittest +from unittest.mock import patch +import os +import unittest +import pickle +import numpy as np +import pandas as pd +import tensorflow as tf +import logging +import spektral + +import memilio.surrogatemodel.GNN.network_architectures as gnn_arch +import memilio.surrogatemodel.GNN.GNN_utils as utils + +from memilio.surrogatemodel.GNN.evaluate_and_train import ( + create_dataset, train_and_evaluate, evaluate, train_step, MixedLoader) +from memilio.surrogatemodel.GNN.grid_search import perform_grid_search +from memilio.surrogatemodel.GNN.network_architectures import generate_model_class, get_model +from tensorflow.keras.losses import MeanAbsolutePercentageError +# suppress all autograph warnings from tensorflow + +logging.getLogger("tensorflow").setLevel(logging.ERROR) + + +class TestSurrogatemodelGNN(fake_filesystem_unittest.TestCase): + path = "/home/" + + def create_dummy_data( + self, num_samples, num_nodes, num_node_features, output_dim): + """ + Create dummy data for testing. + + :param num_samples: Number of samples in the dataset. + :param num_nodes: Number of nodes in each graph. + :param num_node_features: Number of features per node. + :param output_dim: Number of output dimensions per node. + :return: A dictionary containing inputs, adjacency matrix, and labels. + """ + # Shape should be (num_samples, input_width, num_nodes, num_features) + X = np.random.rand(num_samples, 1, num_nodes, + num_node_features).astype(np.float32) + A = np.random.randint(0, 2, (num_nodes, num_nodes)).astype(np.float32) + # Shape should be (num_samples, label_width, num_nodes, label_features) + y = np.random.rand(num_samples, 1, num_nodes, + output_dim).astype(np.float32) + return {"inputs": X, "adjacency": A, "labels": y} + + def setup_fake_filesystem(self, fs, path, data): + """ + Save dummy data to the fake file system. + + :param fs: The fake file system object. + :param path: The base path in the fake file system. + :param data: The dummy data dictionary containing inputs, adjacency, and labels. + :return: Paths to the cases and mobility files. + """ + path_cases_dir = os.path.join(path, "cases") + fs.create_dir(path_cases_dir) + path_cases = os.path.join(path_cases_dir, "cases.pickle") + with open(path_cases, 'wb') as f: + pickle.dump({"inputs": data["inputs"], + "labels": data["labels"]}, f) + + path_mobility = os.path.join(path, "mobility") + mobility_file = os.path.join( + path_mobility, "commuter_mobility_2022.txt") + fs.create_dir(path_mobility) + fs.create_file(mobility_file) + with open(mobility_file, 'w') as f: + np.savetxt(f, data["adjacency"], delimiter=" ") + + return path_cases, path_mobility + + def setUp(self): + self.setUpPyfakefs() + + def test_generate_model_class(self): + + # Test parameters + layer_types = [ + lambda: tf.keras.layers.Dense(10, activation="relu"), + ] + num_layers = [3] + num_output = 2 + + # Generate the model class + ModelClass = generate_model_class( + "TestModel", layer_types, num_layers, num_output) + # Check if the generated class is a subclass of tf.keras.Model + self.assertTrue(issubclass(ModelClass, tf.keras.Model)) + + # Instantiate the model + model = ModelClass() + + # Check if the model has the expected number of layers + expected_num_layers = num_layers[0] + 1 # +1 for the output layer + self.assertEqual(len(model.layers), expected_num_layers) + self.assertIsInstance(model.layers[-1], tf.keras.layers.Dense) + self.assertEqual(model.layers[-1].units, num_output) + self.assertEqual( + model.layers[-1].activation.__name__, "relu") + + # Test with invalid parameters + layer_types = [ + lambda: tf.keras.layers.Dense(10, activation="relu"), + ] + num_layers = [0] + num_output = 2 + with self.assertRaises(ValueError) as error: + ModelClass = generate_model_class( + "TestModel", layer_types, num_layers, num_output) + model = ModelClass() + self.assertEqual(str( + error.exception), "All values in num_repeat must be at least 1.") + + num_layers = [3, 2] + with self.assertRaises(ValueError) as error: + ModelClass = generate_model_class( + "TestModel", layer_types, num_layers, num_output) + model = ModelClass() + self.assertEqual( + str(error.exception), + "layer_types and num_repeat must have the same length. " + "Got 1 and 2.") + + # Test with multiple layer types + layer_types = [ + lambda: tf.keras.layers.Dense(10, activation="relu"), + lambda: tf.keras.layers.Dense(20, activation="relu"), + ] + num_repeat = [2, 3] + num_output = 4 + + ModelClass = generate_model_class( + "TestModel", layer_types, num_repeat, num_output) + model = ModelClass() + + # Check the number of layers + self.assertEqual(len(model.layer_seq), sum(num_repeat)) + self.assertEqual(model.output_layer.units, num_output) + + def test_get_model(self): + + # Test parameters + layer_type = "GCNConv" + num_layers = 2 + num_channels = 16 + activation = "relu" + num_output = 3 + + # Generate the model + model = get_model(layer_type, num_layers, + num_channels, activation, num_output) + + # Check if the model is an instance of tf.keras.Model + self.assertIsInstance(model, tf.keras.Model) + + # Check if the model has the expected number of layers + expected_num_layers = num_layers + 1 # +1 for the output layer + self.assertEqual(len(model.layers), expected_num_layers) + + # Check handling of invalid parameters + # Test with invalid layer type + layer_type = "MonvConv" + with self.assertRaises(ValueError) as error: + model = get_model(layer_type, num_layers, + num_channels, activation, num_output) + self.assertEqual( + str(error.exception), + "Unsupported layer_type: 'MonvConv'. " + "Supported types are: ARMAConv, GCSConv, GATConv, GCNConv, APPNPConv") + # Test with invalud num_layers + layer_type = "GATConv" + num_layers = 0 + with self.assertRaises(ValueError) as error: + model = get_model(layer_type, num_layers, + num_channels, activation, num_output) + self.assertEqual(str( + error.exception), "num_layers must be at least 1, got 0.") + # Test with invalid num_output + num_layers = 2 + num_output = 0 + with self.assertRaises(ValueError) as error: + model = get_model(layer_type, num_layers, + num_channels, activation, num_output) + self.assertEqual(str( + error.exception), "num_output must be at least 1, got 0.") + # Test with invalid num_channels + num_output = 2 + num_channels = 0 + with self.assertRaises(ValueError) as error: + model = get_model(layer_type, num_layers, + num_channels, activation, num_output) + self.assertEqual(str( + error.exception), "num_channels must be at least 1, got 0.") + # Test with invalid activation + num_channels = 16 + activation = 5 + with self.assertRaises(ValueError) as error: + model = get_model(layer_type, num_layers, + num_channels, activation, num_output) + self.assertEqual( + str(error.exception), + "activation must be a string, got int.") + + def test_create_dataset(self): + + # Create dummy data in the fake filesystem for testing + num_samples = 10 + num_nodes = 5 + num_node_features = 3 + output_dim = 4 + data = self.create_dummy_data( + num_samples, num_nodes, num_node_features, output_dim) + # Save dummy data to the fake file system + path_cases, path_mobility = self.setup_fake_filesystem( + self.fs, self.path, data) + + # Create dataset + dataset = create_dataset( + path_cases, path_mobility, number_of_nodes=num_nodes) + self.assertEqual(len(dataset), num_samples) + for graph in dataset: + self.assertEqual(graph.x.shape, (num_nodes, num_node_features)) + self.assertEqual(dataset.a.shape, (num_nodes, num_nodes)) + self.assertEqual(graph.y.shape, (num_nodes, output_dim)) + # Clean up + self.fs.remove_object(path_cases) + self.fs.remove_object(os.path.join( + path_mobility, "commuter_mobility_2022.txt")) + self.fs.remove_object(path_mobility) + + def test_train_step(self): + + # Create a simple model for testing + model = get_model( + layer_type="GCNConv", num_layers=2, num_channels=16, + activation="relu", num_output=2) + optimizer = tf.keras.optimizers.Adam() + loss_fn = MeanAbsolutePercentageError() + + # Create dummy data + num_samples = 10 + num_nodes = 5 + num_node_features = 3 + output_dim = 2 + data = self.create_dummy_data( + num_samples, num_nodes, num_node_features, output_dim) + # Save dummy data to the fake file system + path_cases, path_mobility = self.setup_fake_filesystem( + self.fs, self.path, data) + # Create dataset + dataset = create_dataset( + path_cases, path_mobility, number_of_nodes=num_nodes) + # Build the model by calling it on a batch of data + loader = MixedLoader(dataset, batch_size=4, epochs=1) + inputs, y = next(loader) + model(inputs) + + # Perform a training step + loss, acc = train_step(inputs, y, loss_fn, model, optimizer) + + # Check if the loss is a scalar tensor + self.assertIsInstance(loss, tf.Tensor) + self.assertEqual(loss.shape, ()) + self.assertIsInstance(acc, tf.Tensor) + self.assertEqual(acc.shape, ()) + self.assertGreaterEqual(acc.numpy(), 0) + self.assertGreaterEqual(loss.numpy(), 0) + + def test_evaluate(self): + + # Create a simple model for testing + model = get_model( + layer_type="GCNConv", num_layers=2, num_channels=16, + activation="relu", num_output=4) + loss_fn = MeanAbsolutePercentageError() + + # Create dummy data in the fake filesystem for testing + num_samples = 10 + num_nodes = 5 + num_node_features = 3 + output_dim = 4 + data = self.create_dummy_data( + num_samples, num_nodes, num_node_features, output_dim) + # Save dummy data to the fake file system + path_cases, path_mobility = self.setup_fake_filesystem( + self.fs, self.path, data) + # Create dataset + dataset = create_dataset( + path_cases, path_mobility, number_of_nodes=num_nodes) + # Build the model by calling it on a batch of data + loader = MixedLoader(dataset, batch_size=2, epochs=1) + inputs, _ = loader.__next__() + model(inputs) + # Redefine the loader + loader = MixedLoader(dataset, batch_size=2, epochs=1) + res = evaluate(loader, model, loss_fn) + # Check if the result is a tuple of (loss, accuracy) + self.assertEqual(len(res), 2) + self.assertGreaterEqual(res[0], 0) + self.assertGreaterEqual(res[1], 0) + + # Test with retransformation + loader = MixedLoader(dataset, batch_size=2, epochs=1) + res = evaluate(loader, model, loss_fn, True) + + self.assertEqual(len(res), 2) + self.assertGreaterEqual(res[0], 0) + self.assertGreaterEqual(res[1], 0) + # Clean up + self.fs.remove_object(path_cases) + self.fs.remove_object(os.path.join( + path_mobility, "commuter_mobility_2022.txt")) + self.fs.remove_object(path_mobility) + + @patch("os.path.realpath", return_value="/home/") + def test_train_and_evaluate(self, mock_realpath): + + number_of_epochs = 2 + # Create a simple model for testing + model = get_model( + layer_type="GCNConv", num_layers=2, num_channels=16, + activation="relu", num_output=4) + + # Create dummy data in the fake filesystem for testing + num_samples = 20 + num_nodes = 5 + num_node_features = 3 + output_dim = 4 + data = self.create_dummy_data( + num_samples, num_nodes, num_node_features, output_dim) + # Save dummy data to the fake file system + path_cases, path_mobility = self.setup_fake_filesystem( + self.fs, self.path, data) + # Create dataset + dataset = create_dataset( + path_cases, path_mobility, number_of_nodes=num_nodes) + + res = train_and_evaluate( + dataset, + batch_size=2, + epochs=number_of_epochs, + model=model, + loss_fn=MeanAbsolutePercentageError(), + optimizer=tf.keras.optimizers.Adam(), + es_patience=100) + + self.assertEqual(len(res["train_losses"][0]), number_of_epochs) + self.assertEqual(len(res["val_losses"][0]), number_of_epochs) + self.assertGreater(res["mean_test_loss"], 0) + + # Testing with saving the results + res = train_and_evaluate( + dataset, + batch_size=2, + epochs=number_of_epochs, + model=model, + loss_fn=MeanAbsolutePercentageError(), + optimizer=tf.keras.optimizers.Adam(), + es_patience=100, + save_dir=self.path) + save_results_path = os.path.join(self.path, "model_evaluations_paper") + save_model_path = os.path.join(self.path, "saved_weights") + self.assertTrue(os.path.exists(save_results_path)) + self.assertTrue(os.path.exists(save_model_path)) + + file_path_df = save_results_path+"/model.csv" + df = pd.read_csv(file_path_df) + self.assertEqual(len(df), 1) + for item in [ + "train_loss", "val_loss", "test_loss", + "test_loss_orig", "training_time", + "loss_history", "val_loss_history"]: + self.assertIn(item, df.columns) + + file_path_model = save_model_path+"/model.pickle" + with open(file_path_model, 'rb') as f: + weights_loaded = pickle.load(f) + weights = model.get_weights() + for w1, w2 in zip(weights_loaded, weights): + np.testing.assert_array_equal(w1, w2) + # Clean up + self.fs.remove_object(path_cases) + self.fs.remove_object(os.path.join( + path_mobility, "commuter_mobility_2022.txt")) + self.fs.remove_object(path_mobility) + self.fs.remove_object(save_results_path) + self.fs.remove_object(save_model_path) + + def test_perform_grid_search(self): + + # Create dummy data in the fake filesystem for testing + num_samples = 10 + num_nodes = 4 + num_node_features = 3 + output_dim = 4 + data = self.create_dummy_data( + num_samples, num_nodes, num_node_features, output_dim) + # Save dummy data to the fake file system + path_cases, path_mobility = self.setup_fake_filesystem( + self.fs, self.path, data) + # Create dataset + dataset = create_dataset( + path_cases, path_mobility, number_of_nodes=num_nodes) + + # Define model parameters for grid search + layers = ["GCNConv"] + num_layers = [1] + num_channels = [8] + activations = ["relu"] + parameter_grid = [(layer, n_layer, channel, activation) + for layer in layers for n_layer in num_layers + for channel in num_channels + for activation in activations] + batch_size = 2 + es_patience = 5 + max_epochs = 2 + # Perform grid search with explicit save_dir to avoid os.path.realpath issues + perform_grid_search(dataset, parameter_grid, self.path, + batch_size=batch_size, max_epochs=max_epochs, + es_patience=es_patience) + + # Check if the results file is created + results_file = os.path.join( + self.path, "saves", "grid_search_results.csv") + self.assertTrue(os.path.exists(results_file)) + + # Check if the results file has the expected number of rows + df_results = pd.read_csv(results_file) + self.assertEqual(len(df_results), len(parameter_grid)) + self.assertEqual(len(df_results.columns), 12) + + def test_scale_data_valid_data(self): + """Test utils.scale_data with valid input and label data.""" + data = { + # 10 samples, 1 day, 5 nodes, 8 groups + "inputs": np.random.rand(10, 1, 8, 5), + "labels": np.random.rand(10, 1, 8, 5) + } + + scaled_inputs, scaled_labels = utils.scale_data(data, True) + + # Check that the scaled data is not equal to the original data + assert not np.allclose( + data["inputs"].transpose(0, 3, 1, 2), scaled_inputs) + assert not np.allclose( + data["labels"].transpose(0, 3, 1, 2), scaled_labels) + + # Check that the scaled data is log-transformed + assert np.allclose(scaled_inputs, np.log1p( + data["inputs"]).transpose(0, 3, 1, 2)) + assert np.allclose(scaled_labels, np.log1p( + data["labels"]).transpose(0, 3, 1, 2)) + + def test_scale_data_invalid_data(self): + """Test utils.scale_data with invalid (non-numeric) data.""" + data = { + "inputs": np.array([["a", "b"], ["c", "d"]]), # Non-numeric data + "labels": np.array([["e", "f"], ["g", "h"]]) + } + + with self.assertRaises(ValueError): + utils.scale_data(data) + + +if __name__ == '__main__': + unittest.main() diff --git a/pycode/memilio-surrogatemodel/requirements-dev.txt b/pycode/memilio-surrogatemodel/requirements-dev.txt index d96c76b022..8594db7ac1 100644 --- a/pycode/memilio-surrogatemodel/requirements-dev.txt +++ b/pycode/memilio-surrogatemodel/requirements-dev.txt @@ -1,3 +1,5 @@ # first support of python 3.11 pyfakefs>=4.6 coverage>=7.0.1 +spektral>=1.2 +tensorflow>=2.12.0