from data_processing_pandemic import movement_dfs_dict, cases_dfs_dict
from torch_geometric.data import Data, TemporalData
import torch
import numpy as np
from datetime import timedelta
## Extract data
movement_ita_df = movement_dfs_dict['movement_ita_df']
movement_spa_df = movement_dfs_dict['movement_spa_df']
movement_fra_df = movement_dfs_dict['movement_fra_df']
movement_eng_df = movement_dfs_dict['movement_eng_df']
cases_ita_df = cases_dfs_dict['cases_ita_df']
cases_spa_df = cases_dfs_dict['cases_spa_df']
cases_fra_df = cases_dfs_dict['cases_fra_df']
cases_eng_df = cases_dfs_dict['cases_eng_df']
def generate_dict_graph_snapshots_conn(movement_df, cases_df):
    """Build one PyG Data snapshot per date: edges carry movement lagged by 7 days,
    node features hold the 7-day window of cases ending at the snapshot date."""
    dates = movement_df['date'].unique()
    snapshots_dict = {}
    for i in range(7, len(dates)):
        # Edges come from the movement observed 7 days before the snapshot date.
        df = movement_df[movement_df['date'] == dates[i - 7]]
        edge_index_arr = np.vstack((df['src'].values, df['trg'].values))
        edge_index = torch.tensor(edge_index_arr, dtype=torch.long)
        edge_weights = torch.tensor(df['movement'].values[:, None], dtype=torch.float)
        edge_weights_index = edge_index.detach().clone()
        edge_attr = torch.tensor(df['positive_freq'].values[:, None], dtype=torch.float)
        num_nodes = df['src'].max() + 1
        # Node features: cases from day i-6 through day i (7 columns).
        node_features = torch.tensor(cases_df.iloc[:, i - 6:i + 1].values, dtype=torch.float)
        graph = Data(x=node_features, edge_index=edge_index, edge_weights=edge_weights,
                     edge_weights_index=edge_weights_index, edge_attr=edge_attr, num_nodes=num_nodes)
        snapshots_dict[dates[i]] = graph
    return snapshots_dict
# Dictionaries with connected daily graphs for each country. Connected => every node is linked in each snapshot (zero-movement dummy edges are kept), hence highly sparse edge weights.
snapshots_graphs_conn_ita_dict = generate_dict_graph_snapshots_conn(movement_ita_df, cases_ita_df)
snapshots_graphs_conn_spa_dict = generate_dict_graph_snapshots_conn(movement_spa_df, cases_spa_df)
snapshots_graphs_conn_fra_dict = generate_dict_graph_snapshots_conn(movement_fra_df, cases_fra_df)
snapshots_graphs_conn_eng_dict = generate_dict_graph_snapshots_conn(movement_eng_df, cases_eng_df)
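# Illustrative inspection (a minimal sketch, assuming the Italian data yields at
# least one snapshot): each dictionary value is a PyG Data object whose x holds a
# 7-day case window per node and whose edge_weights hold the lag-7 movement.
if __name__ == "__main__":
    _date, _graph = next(iter(snapshots_graphs_conn_ita_dict.items()))
    print(f"Connected snapshot {_date}: num_nodes={_graph.num_nodes}, "
          f"num_edges={_graph.edge_index.size(1)}, x_shape={tuple(_graph.x.shape)}")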
def generate_dict_graph_snapshots_unco(movement_df, cases_df):
    """Same as the connected variant, but zero-movement dummy edges are dropped
    and no edge_attr is attached."""
    movement_df = movement_df[movement_df['movement'] > 0].copy()
    dates = movement_df['date'].unique()
    snapshots_dict = {}
    for i in range(7, len(dates)):
        df = movement_df[movement_df['date'] == dates[i - 7]]
        edge_index_arr = np.vstack((df['src'].values, df['trg'].values))
        edge_index = torch.tensor(edge_index_arr, dtype=torch.long)
        edge_weights = torch.tensor(df['movement'].values[:, None], dtype=torch.float)
        edge_weights_index = edge_index.detach().clone()
        num_nodes = df['src'].max() + 1
        node_features = torch.tensor(cases_df.iloc[:, i - 6:i + 1].values, dtype=torch.float)
        graph = Data(x=node_features, edge_index=edge_index, edge_weights=edge_weights,
                     edge_weights_index=edge_weights_index, num_nodes=num_nodes)
        snapshots_dict[dates[i]] = graph
    return snapshots_dict
# Dictionaries with unconnected daily graphs for each country. Unconnected => the zero-movement dummy edges are dropped, hence no sparsity.
snapshots_graphs_unco_ita_dict = generate_dict_graph_snapshots_unco(movement_ita_df, cases_ita_df)
snapshots_graphs_unco_spa_dict = generate_dict_graph_snapshots_unco(movement_spa_df, cases_spa_df)
snapshots_graphs_unco_fra_dict = generate_dict_graph_snapshots_unco(movement_fra_df, cases_fra_df)
snapshots_graphs_unco_eng_dict = generate_dict_graph_snapshots_unco(movement_eng_df, cases_eng_df)
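# Illustrative comparison (a minimal sketch, assuming the connected and unconnected
# dictionaries share a common date key): dropping zero-movement dummy edges should
# never increase the edge count for a given snapshot.
if __name__ == "__main__":
    _date = next(iter(snapshots_graphs_unco_ita_dict))
    _n_conn = snapshots_graphs_conn_ita_dict[_date].edge_index.size(1)
    _n_unco = snapshots_graphs_unco_ita_dict[_date].edge_index.size(1)
    print(f"{_date}: {_n_conn} edges with dummy edges vs {_n_unco} without")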
# Generation of temporal data (for the unconnected data only)
def generate_td_movement_df(movement_df):
    """Keep positive-movement edges only and shift dates forward by 7 days,
    so each row supplies the lag-7 movement for its (shifted) date."""
    movement_df = movement_df[movement_df['movement'] > 0].copy()
    movement_df = movement_df.reset_index(drop=True)
    movement_df = movement_df.drop(columns='positive_freq')
    movement_df = movement_df.rename(columns={'movement': 'movement_lag7'})
    movement_df['date'] = movement_df['date'] + timedelta(days=7)
    return movement_df
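# Illustrative sanity check (a minimal sketch, assuming 'date' holds datetime-like
# values): after the shift, the earliest lagged date is 7 days after the earliest
# positive-movement date in the raw frame.
if __name__ == "__main__":
    _lagged = generate_td_movement_df(movement_ita_df)
    _raw_min = movement_ita_df.loc[movement_ita_df['movement'] > 0, 'date'].min()
    print(f"Lag-7 shift for Italy: {_raw_min} -> {_lagged['date'].min()}")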
def generate_temporal_data(movement_df, cases_df):
    """Build a PyG TemporalData event stream plus per-snapshot node features."""
    movement_df = generate_td_movement_df(movement_df)
    dates = movement_df['date'].unique()
    # One 7-day window of case features per snapshot index.
    node_features_list = []
    for i in range(7, len(dates)):
        x = torch.tensor(cases_df.iloc[:, i - 6:i + 1].values, dtype=torch.float)
        node_features_list.append(x)
    node_features = {i: features for i, features in enumerate(node_features_list)}
    # Keep only as many dates as there are available case-feature columns.
    dates = dates[:cases_df.iloc[:, 8:].shape[1]]
    movement_df = movement_df[movement_df['date'].isin(dates)]
    src = torch.tensor(movement_df['src'].values, dtype=torch.long)
    trg = torch.tensor(movement_df['trg'].values, dtype=torch.long)
    edge_weights = torch.tensor(movement_df['movement_lag7'].values, dtype=torch.float32).unsqueeze(1)
    msg = torch.ones_like(edge_weights)
    # Map dates to consecutive integer timestamps for TemporalData.
    movement_df['date'] = movement_df['date'].astype('datetime64[s]').astype('int64')
    dates_mapping = {date: i for i, date in enumerate(movement_df['date'].unique())}
    movement_df['date'] = movement_df['date'].map(dates_mapping)
    t = torch.tensor(movement_df['date'].values, dtype=torch.int64)
    return TemporalData(src=src, dst=trg, t=t, msg=msg), node_features, edge_weights
# TemporalData objects for each country
data, node_features, edge_weights = generate_temporal_data(movement_ita_df, cases_ita_df)
temporal_data_ita = {'TemporalData': data,
                     'node_features': node_features,
                     'edge_weights': edge_weights}
data, node_features, edge_weights = generate_temporal_data(movement_spa_df, cases_spa_df)
temporal_data_spa = {'TemporalData': data,
                     'node_features': node_features,
                     'edge_weights': edge_weights}
data, node_features, edge_weights = generate_temporal_data(movement_fra_df, cases_fra_df)
temporal_data_fra = {'TemporalData': data,
                     'node_features': node_features,
                     'edge_weights': edge_weights}
data, node_features, edge_weights = generate_temporal_data(movement_eng_df, cases_eng_df)
temporal_data_eng = {'TemporalData': data,
                     'node_features': node_features,
                     'edge_weights': edge_weights}
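# Illustrative consumption of the TemporalData objects (a minimal sketch, not part
# of the original pipeline): TemporalDataLoader from PyTorch Geometric batches the
# event stream in time order; the batch size of 200 below is an arbitrary choice.
if __name__ == "__main__":
    from torch_geometric.loader import TemporalDataLoader
    loader = TemporalDataLoader(temporal_data_ita['TemporalData'], batch_size=200)
    batch = next(iter(loader))
    print(f"First batch: {batch.num_events} events, "
          f"t range [{int(batch.t.min())}, {int(batch.t.max())}]")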