Skip to content

1057 data generation for GNNs #1090

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 34 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
e50eff8
data generation graphODE without dampings
AgathaSchmidt Jul 4, 2024
21e6df5
GNN data generator with dampings - graphODE
AgathaSchmidt Jul 9, 2024
95a9d15
add damping information to output
AgathaSchmidt Jul 9, 2024
300b956
adjust to code guidelines
AgathaSchmidt Aug 7, 2024
75ce80d
delete file
AgathaSchmidt Aug 7, 2024
10b6e92
adjust to code guidelines
AgathaSchmidt Aug 7, 2024
9115678
create a make_graph function
AgathaSchmidt Aug 12, 2024
c9b26cf
create make graph function
AgathaSchmidt Aug 12, 2024
d97a1f7
fix population number at 400
AgathaSchmidt Aug 14, 2024
e20fe2d
fix number of population to 400 and rename damping information
AgathaSchmidt Aug 14, 2024
d9110e5
add "return data" for case without saving
AgathaSchmidt Aug 15, 2024
8823aa1
tests for simulation run and datageneration for models with and witho…
AgathaSchmidt Aug 15, 2024
74da6d8
Apply suggestions from code review
AgathaSchmidt Aug 26, 2024
47f99eb
apply suggestions from code review
AgathaSchmidt Aug 29, 2024
58fc8e8
a python file with functions frequently used for our GNNs
AgathaSchmidt Aug 29, 2024
ffa076e
add get population and get minimum matrix to utils
AgathaSchmidt Sep 2, 2024
0cbda2b
remove get population function from this file and import it from utils
AgathaSchmidt Sep 2, 2024
e8fb580
remove functions and import them from utils
AgathaSchmidt Sep 2, 2024
de35ca6
add comments as proposed by review
AgathaSchmidt Sep 2, 2024
3d5a644
pre commit
AgathaSchmidt Sep 16, 2024
2dfbe4f
mock transform mobility function
AgathaSchmidt Sep 18, 2024
86a2e89
adjust utils: add surrogate utils and add scaling to GNN utils
AgathaSchmidt Sep 19, 2024
77af0ab
add test for saving mechanism
AgathaSchmidt Sep 19, 2024
ce4a494
import functions from new utils file
AgathaSchmidt Sep 19, 2024
63980d9
adjust import
AgathaSchmidt Sep 19, 2024
cabf5a1
adjust imports
AgathaSchmidt Sep 19, 2024
06fb8dd
put function which reads files outside of the run_simulation function
AgathaSchmidt Sep 24, 2024
f2345c7
add directory as parameter
AgathaSchmidt Sep 24, 2024
d26528e
set edges only one time
AgathaSchmidt Sep 25, 2024
86c76d5
Merge branch 'main' into 1057-GNN-datageneration
HenrZu Oct 2, 2024
e39fd71
new structure for no damp
HenrZu Oct 2, 2024
ff6c571
timing graph sim (Delete before Merge!)
HenrZu Oct 2, 2024
ac705a4
with_dampings
HenrZu Oct 2, 2024
10d5a98
[ci skip] damping correctly set and reset after each run
HenrZu Oct 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion cpp/memilio/mobility/graph_simulation.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

#include "memilio/mobility/graph.h"
#include "memilio/utils/random_number_generator.h"
#include <chrono>

namespace mio
{
Expand Down Expand Up @@ -60,7 +61,9 @@ class GraphSimulationBase

void advance(double t_max = 1.0)
{
auto dt = m_dt;
auto dt = m_dt;
auto start_time = std::chrono::high_resolution_clock::now(); // Startzeit erfassen

while (m_t < t_max) {
if (m_t + dt > t_max) {
dt = t_max - m_t;
Expand All @@ -77,6 +80,12 @@ class GraphSimulationBase
m_graph.nodes()[e.end_node_idx].property);
}
}

auto end_time = std::chrono::high_resolution_clock::now(); // Endzeit erfassen
std::chrono::duration<double> execution_time = end_time - start_time; // Ausführungszeit berechnen

std::cout << "t = " << m_t << " execution time (Graph Simulation): " << execution_time.count() << "sec"
<< std::endl;
}

double get_t() const
Expand Down
166 changes: 166 additions & 0 deletions pycode/memilio-surrogatemodel/memilio/surrogatemodel/GNN/GNN_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
import numpy as np
import pandas as pd
import os
from sklearn.preprocessing import FunctionTransformer

from memilio.epidata import transformMobilityData as tmd
from memilio.epidata import getDataIntoPandasDataFrame as gd
from memilio.simulation.osecir import (ModelGraph, set_edges)
from memilio.epidata import modifyDataframeSeries as mdfs


def remove_confirmed_compartments(dataset_entries, num_groups):
    """! The compartments which contain confirmed cases are not needed and are
    therefore omitted by summarizing the confirmed compartment with the
    original compartment.
    @param dataset_entries Array that contains the compartmental data with
        confirmed compartments.
    @param num_groups Number of age groups.
    @return List of 1-D arrays containing the compartmental data without
        confirmed compartments. The input data is left unmodified.
    """
    # Copy once: reshaping the original rows yields writable views, so
    # summing in place would silently mutate the caller's data.
    entries = np.array(dataset_entries)
    num_compartments = entries.shape[1] // num_groups
    reshaped = entries.reshape(entries.shape[0], num_groups, num_compartments)

    # Merge the confirmed compartments (indices 3 and 5) into the
    # corresponding unconfirmed compartments (indices 2 and 4).
    reshaped[:, :, 2] += reshaped[:, :, 3]
    reshaped[:, :, 4] += reshaped[:, :, 5]

    # Drop the now redundant confirmed compartments and flatten each entry.
    reduced = np.delete(reshaped, [3, 5], axis=2)
    return list(reduced.reshape(entries.shape[0], -1))


def getBaselineMatrix(data_dir="./data/contacts"):
    """! Loads the baseline contact matrix as the element-wise sum of the
    four location-specific baseline matrices (home, school, work, other).

    @param data_dir Directory containing the baseline contact matrix files.
        Defaults to "./data/contacts", the previously hard-coded location,
        so existing callers keep working.
    @return Summed baseline contact matrix as a numpy array.
    """
    location_files = [
        "baseline_home.txt",
        "baseline_school_pf_eig.txt",
        "baseline_work.txt",
        "baseline_other.txt",
    ]
    return sum(np.loadtxt(os.path.join(data_dir, f)) for f in location_files)


def getMinimumMatrix(data_dir="./data/contacts"):
    """! Loads the minimum contact matrix as the element-wise sum of the
    four location-specific minimum matrices (home, school, work, other).

    @param data_dir Directory containing the minimum contact matrix files.
        Defaults to "./data/contacts", the previously hard-coded location,
        so existing callers keep working.
    @return Summed minimum contact matrix as a numpy array.
    """
    location_files = [
        "minimum_home.txt",
        "minimum_school_pf_eig.txt",
        "minimum_work.txt",
        "minimum_other.txt",
    ]
    return sum(np.loadtxt(os.path.join(data_dir, f)) for f in location_files)


def make_graph(directory, num_regions, countykey_list, models):
    """! Builds a graph-ODE model with one node per county.

    @param directory Directory with mobility data.
    @param num_regions Number (int) of counties that should be added to the
        graph-ODE model. Equals 400 for whole Germany.
    @param countykey_list List of keys/IDs for each county.
    @param models List of osecir models, one model per population.
    @return graph Graph-ODE model.
    """
    graph = ModelGraph()

    # One node per region, identified by its county key.
    for region_idx in range(num_regions):
        county_id = int(countykey_list[region_idx])
        graph.add_node(county_id, models[region_idx])

    # Mobility edges are read from the parent directory of `directory`.
    num_locations = 4
    mobility_dir = os.path.abspath(os.path.join(directory, os.pardir))
    set_edges(mobility_dir, graph, num_locations)

    return graph


def transform_mobility_directory():
    """! Transforms the mobility data by merging Eisenach and Wartburgkreis.

    @return Path of the directory containing the transformed mobility data.
    """
    # Locate the mobility directory relative to the configured output folder.
    arg_dict = gd.cli("commuter_official")
    base_dir = arg_dict['out_folder'].split('/pydata')[0]
    directory = os.path.join(base_dir, 'mobility/')

    # Merge Eisenach and Wartburgkreis in both mobility input files.
    for mobility_file in ('twitter_scaled_1252', 'commuter_migration_scaled'):
        tmd.updateMobility2022(directory, mobility_file=mobility_file)

    return directory


def get_population(path='data/pydata/Germany/county_population.json'):
    """! Reads the county population data and aggregates it into the six
    age groups used by the model.

    @param path Path to the county population json file. Defaults to the
        previously hard-coded location, so existing callers keep working.
    @return List with one entry per region: region id followed by the
        population per age group.
    """
    df_population = pd.read_json(path)
    age_groups = ['0-4', '5-14', '15-34', '35-59', '60-79', '80-130']

    df_population_agegroups = pd.DataFrame(
        columns=[df_population.columns[0]] + age_groups)
    for region_id in df_population.iloc[:, 0]:
        # fit_age_group_intervals redistributes the raw age resolution
        # (columns from index 2 onwards) into the model's age groups.
        df_population_agegroups.loc[len(df_population_agegroups.index), :] = [int(region_id)] + list(
            mdfs.fit_age_group_intervals(df_population[df_population.iloc[:, 0] == int(region_id)].iloc[:, 2:], age_groups))

    population = df_population_agegroups.values.tolist()
    return population


def scale_data(data):
    """! Applies a log(1+x) transformation to inputs and labels and permutes
    their axes from order (0, 1, 2, 3) to (0, 3, 1, 2).

    np.log1p operates element-wise, so the former round trip of
    transpose -> reshape -> FunctionTransformer -> reshape -> transpose is
    numerically identical to transforming the arrays directly; this keeps
    the result unchanged while avoiding several intermediate copies and the
    in-function sklearn dependency.

    @param data Dict with 4-dimensional arrays under the keys 'inputs'
        and 'labels'.
    @return Tuple (scaled_inputs, scaled_labels) of transformed arrays.
    """
    scaled_inputs = np.log1p(np.asarray(data['inputs'])).transpose(0, 3, 1, 2)
    scaled_labels = np.log1p(np.asarray(data['labels'])).transpose(0, 3, 1, 2)
    return scaled_inputs, scaled_labels
Loading