# pytorch_geometric_temporal_models.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
### MPNN LSTM ###
class MPNNLSTM(nn.Module):
r"""An implementation of the Message Passing Neural Network with Long Short Term Memory.
For details see this paper: `"Transfer Graph Neural Networks for Pandemic Forecasting." <https://arxiv.org/abs/2009.08388>`_
Args:
in_channels (int): Number of input features.
hidden_size (int): Dimension of hidden representations.
num_nodes (int): Number of nodes in the network.
window (int): Number of past samples included in the input.
dropout (float): Dropout rate.
"""
def __init__(
self,
in_channels: int,
hidden_size: int,
num_nodes: int,
window: int,
dropout: float,
):
super(MPNNLSTM, self).__init__()
self.window = window
self.num_nodes = num_nodes
self.hidden_size = hidden_size
self.dropout = dropout
self.in_channels = in_channels
self._create_parameters_and_layers()
def _create_parameters_and_layers(self):
self._convolution_1 = GCNConv(self.in_channels, self.hidden_size)
self._convolution_2 = GCNConv(self.hidden_size, self.hidden_size)
self._batch_norm_1 = nn.BatchNorm1d(self.hidden_size)
self._batch_norm_2 = nn.BatchNorm1d(self.hidden_size)
self._recurrent_1 = nn.LSTM(2 * self.hidden_size, self.hidden_size, 1)
self._recurrent_2 = nn.LSTM(self.hidden_size, self.hidden_size, 1)
def _graph_convolution_1(self, X, edge_index, edge_weight):
X = F.relu(self._convolution_1(X, edge_index, edge_weight))
X = self._batch_norm_1(X)
X = F.dropout(X, p=self.dropout, training=self.training)
return X
def _graph_convolution_2(self, X, edge_index, edge_weight):
X = F.relu(self._convolution_2(X, edge_index, edge_weight))
X = self._batch_norm_2(X)
X = F.dropout(X, p=self.dropout, training=self.training)
return X
def forward(
self,
X: torch.FloatTensor,
edge_index: torch.LongTensor,
edge_weight: torch.FloatTensor = None,
) -> torch.FloatTensor:
"""
Making a forward pass through the whole architecture.
Arg types:
* **X** *(PyTorch FloatTensor)* - Node features.
* **edge_index** *(PyTorch LongTensor)* - Graph edge indices.
            * **edge_weight** *(PyTorch FloatTensor, optional)* - Edge weight vector.
Return types:
            * **H** *(PyTorch FloatTensor)* - The hidden representation of size 2*hidden_size+in_channels+window-1 for each node.
"""
R = list()
S = X.view(-1, self.window, self.num_nodes, self.in_channels)
S = torch.transpose(S, 1, 2)
S = S.reshape(-1, self.window, self.in_channels)
O = [S[:, 0, :]]
for l in range(1, self.window):
O.append(S[:, l, self.in_channels - 1].unsqueeze(1))
S = torch.cat(O, dim=1)
X = self._graph_convolution_1(X, edge_index, edge_weight)
R.append(X)
X = self._graph_convolution_2(X, edge_index, edge_weight)
R.append(X)
X = torch.cat(R, dim=1)
X = X.view(-1, self.window, self.num_nodes, X.size(1))
X = torch.transpose(X, 0, 1)
X = X.contiguous().view(self.window, -1, X.size(3))
X, (H_1, _) = self._recurrent_1(X)
X, (H_2, _) = self._recurrent_2(X)
H = torch.cat([H_1[0, :, :], H_2[0, :, :], S], dim=1)
return H
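# A minimal usage sketch (illustrative, not part of the original model code):
# a smoke test for MPNNLSTM on random data. Shapes are hypothetical; X stacks
# `window` snapshots row-wise, so it has shape (window * num_nodes,
# in_channels), and edge_index indexes into that flattened node set.
def _demo_mpnn_lstm():
    num_nodes, window, in_channels, hidden = 20, 4, 8, 32
    model = MPNNLSTM(in_channels, hidden, num_nodes, window, dropout=0.5)
    X = torch.randn(window * num_nodes, in_channels)
    edge_index = torch.randint(0, window * num_nodes, (2, 100))
    H = model(X, edge_index)
    # Per the docstring, each node gets a representation of size
    # 2 * hidden_size + in_channels + window - 1.
    assert H.shape == (num_nodes, 2 * hidden + in_channels + window - 1)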
### EVOLVEGCN-H ###
from torch.nn import GRU
from torch_geometric.nn import TopKPooling
from torch_geometric.nn.inits import glorot
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.conv.gcn_conv import gcn_norm
class GCNConv_Fixed_W(MessagePassing):
r"""The graph convolutional operator adapted from the `"Semi-supervised
Classification with Graph Convolutional Networks"
<https://arxiv.org/abs/1609.02907>`_ paper, with weights not trainable.
.. math::
\mathbf{X}^{\prime} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}}
\mathbf{\hat{D}}^{-1/2} \mathbf{X} \mathbf{\Theta},
where :math:`\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}` denotes the
adjacency matrix with inserted self-loops and
:math:`\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}` its diagonal degree matrix.
The adjacency matrix can include other values than :obj:`1` representing
edge weights via the optional :obj:`edge_weight` tensor.
Its node-wise formulation is given by:
.. math::
        \mathbf{x}^{\prime}_i = \mathbf{\Theta} \sum_{j \in \mathcal{N}(i) \cup
        \{ i \}} \frac{e_{j,i}}{\sqrt{\hat{d}_j \hat{d}_i}} \mathbf{x}_j
with :math:`\hat{d}_i = 1 + \sum_{j \in \mathcal{N}(i)} e_{j,i}`, where
:math:`e_{j,i}` denotes the edge weight from source node :obj:`j` to target
node :obj:`i` (default: :obj:`1.0`)
Args:
in_channels (int): Size of each input sample, or :obj:`-1` to derive
the size from the first input(s) to the forward method.
out_channels (int): Size of each output sample.
improved (bool, optional): If set to :obj:`True`, the layer computes
:math:`\mathbf{\hat{A}}` as :math:`\mathbf{A} + 2\mathbf{I}`.
(default: :obj:`False`)
cached (bool, optional): If set to :obj:`True`, the layer will cache
the computation of :math:`\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}}
\mathbf{\hat{D}}^{-1/2}` on first execution, and will use the
cached version for further executions.
This parameter should only be set to :obj:`True` in transductive
learning scenarios. (default: :obj:`False`)
add_self_loops (bool, optional): If set to :obj:`False`, will not add
self-loops to the input graph. (default: :obj:`True`)
normalize (bool, optional): Whether to add self-loops and compute
symmetric normalization coefficients on the fly.
(default: :obj:`True`)
**kwargs (optional): Additional arguments of
:class:`torch_geometric.nn.conv.MessagePassing`.
"""
def __init__(self, in_channels: int, out_channels: int,
improved: bool = False, cached: bool = False,
add_self_loops: bool = True, normalize: bool = True,
**kwargs):
kwargs.setdefault('aggr', 'add')
super(GCNConv_Fixed_W, self).__init__(**kwargs)
self.in_channels = in_channels
self.out_channels = out_channels
self.improved = improved
self.cached = cached
self.add_self_loops = add_self_loops
self.normalize = normalize
self._cached_edge_index = None
self._cached_adj_t = None
self.reset_parameters()
def reset_parameters(self):
self._cached_edge_index = None
self._cached_adj_t = None
    def forward(self, W, x, edge_index, edge_weight=None):
        if self.normalize:
            cache = self._cached_edge_index
            if cache is None:
                edge_index, edge_weight = gcn_norm(  # yapf: disable
                    edge_index, edge_weight, x.size(self.node_dim),
                    self.improved, self.add_self_loops)
                if self.cached:
                    self._cached_edge_index = (edge_index, edge_weight)
            else:
                edge_index, edge_weight = cache[0], cache[1]
        x = torch.matmul(x, W)
        # propagate_type: (x: Tensor, edge_weight: OptTensor)
        out = self.propagate(edge_index, x=x, edge_weight=edge_weight,
                             size=None)
        return out
def message(self, x_j, edge_weight):
return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j
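# A small sanity sketch (hypothetical shapes): GCNConv_Fixed_W propagates with
# a caller-supplied weight matrix W instead of trainable layer weights.
def _demo_gcn_fixed_w():
    conv = GCNConv_Fixed_W(in_channels=4, out_channels=4)
    x = torch.randn(10, 4)
    edge_index = torch.randint(0, 10, (2, 30))
    W = torch.eye(4)  # identity weight: output is the normalized aggregation
    out = conv(W, x, edge_index, None)
    assert out.shape == (10, 4)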
class EvolveGCNH(torch.nn.Module):
r"""An implementation of the Evolving Graph Convolutional Hidden Layer.
For details see this paper: `"EvolveGCN: Evolving Graph Convolutional
    Networks for Dynamic Graphs." <https://arxiv.org/abs/1902.10191>`_
Args:
num_of_nodes (int): Number of vertices.
in_channels (int): Number of filters.
improved (bool, optional): If set to :obj:`True`, the layer computes
:math:`\mathbf{\hat{A}}` as :math:`\mathbf{A} + 2\mathbf{I}`.
(default: :obj:`False`)
cached (bool, optional): If set to :obj:`True`, the layer will cache
the computation of :math:`\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}}
\mathbf{\hat{D}}^{-1/2}` on first execution, and will use the
cached version for further executions.
This parameter should only be set to :obj:`True` in transductive
learning scenarios. (default: :obj:`False`)
normalize (bool, optional): Whether to add self-loops and apply
symmetric normalization. (default: :obj:`True`)
add_self_loops (bool, optional): If set to :obj:`False`, will not add
self-loops to the input graph. (default: :obj:`True`)
"""
def __init__(
self,
num_of_nodes: int,
in_channels: int,
improved: bool = False,
cached: bool = False,
normalize: bool = True,
add_self_loops: bool = True,
):
super(EvolveGCNH, self).__init__()
self.num_of_nodes = num_of_nodes
self.in_channels = in_channels
self.improved = improved
self.cached = cached
self.normalize = normalize
self.add_self_loops = add_self_loops
self.weight = None
self.initial_weight = torch.nn.Parameter(torch.Tensor(in_channels, in_channels))
self._create_layers()
self.reset_parameters()
def reset_parameters(self):
glorot(self.initial_weight)
def _create_layers(self):
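        # Pooling keeps in_channels of the num_of_nodes nodes, so the pooled
        # summary is an (in_channels x in_channels) matrix that matches the
        # shape of the GRU hidden state (the evolving GCN weight).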
self.ratio = self.in_channels / self.num_of_nodes
self.pooling_layer = TopKPooling(self.in_channels, self.ratio)
self.recurrent_layer = GRU(
input_size=self.in_channels, hidden_size=self.in_channels, num_layers=1
)
self.conv_layer = GCNConv_Fixed_W(
in_channels=self.in_channels,
out_channels=self.in_channels,
improved=self.improved,
cached=self.cached,
normalize=self.normalize,
add_self_loops=self.add_self_loops
)
def forward(
self,
X: torch.FloatTensor,
edge_index: torch.LongTensor,
edge_weight: torch.FloatTensor = None,
) -> torch.FloatTensor:
"""
Making a forward pass.
Arg types:
* **X** *(PyTorch Float Tensor)* - Node embedding.
* **edge_index** *(PyTorch Long Tensor)* - Graph edge indices.
* **edge_weight** *(PyTorch Float Tensor, optional)* - Edge weight vector.
Return types:
* **X** *(PyTorch Float Tensor)* - Output matrix for all nodes.
"""
X_tilde = self.pooling_layer(X, edge_index)
X_tilde = X_tilde[0][None, :, :]
if self.weight is None:
self.weight = self.initial_weight.data
W = self.weight[None, :, :]
X_tilde, W = self.recurrent_layer(X_tilde, W)
X = self.conv_layer(W.squeeze(dim=0), X, edge_index, edge_weight)
return X
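# A minimal usage sketch (hypothetical shapes) for the EvolveGCN-H layer:
# one snapshot of a 100-node graph with 16-dimensional node embeddings.
def _demo_evolve_gcn_h():
    model = EvolveGCNH(num_of_nodes=100, in_channels=16)
    X = torch.randn(100, 16)
    edge_index = torch.randint(0, 100, (2, 300))
    out = model(X, edge_index)  # called once per snapshot in a temporal sequence
    assert out.shape == (100, 16)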
### A3TGCN ###
class TGCN(torch.nn.Module):
r"""An implementation of the Temporal Graph Convolutional Gated Recurrent Cell.
    For details see this paper: `"T-GCN: A Temporal Graph Convolutional Network
    for Traffic Prediction." <https://arxiv.org/abs/1811.05320>`_
Args:
in_channels (int): Number of input features.
out_channels (int): Number of output features.
improved (bool): Stronger self loops. Default is False.
cached (bool): Caching the message weights. Default is False.
add_self_loops (bool): Adding self-loops for smoothing. Default is True.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
improved: bool = False,
cached: bool = False,
add_self_loops: bool = True,
):
super(TGCN, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.improved = improved
self.cached = cached
self.add_self_loops = add_self_loops
self._create_parameters_and_layers()
def _create_update_gate_parameters_and_layers(self):
self.conv_z = GCNConv(
in_channels=self.in_channels,
out_channels=self.out_channels,
improved=self.improved,
cached=self.cached,
add_self_loops=self.add_self_loops,
)
self.linear_z = torch.nn.Linear(2 * self.out_channels, self.out_channels)
def _create_reset_gate_parameters_and_layers(self):
self.conv_r = GCNConv(
in_channels=self.in_channels,
out_channels=self.out_channels,
improved=self.improved,
cached=self.cached,
add_self_loops=self.add_self_loops,
)
self.linear_r = torch.nn.Linear(2 * self.out_channels, self.out_channels)
def _create_candidate_state_parameters_and_layers(self):
self.conv_h = GCNConv(
in_channels=self.in_channels,
out_channels=self.out_channels,
improved=self.improved,
cached=self.cached,
add_self_loops=self.add_self_loops,
)
self.linear_h = torch.nn.Linear(2 * self.out_channels, self.out_channels)
def _create_parameters_and_layers(self):
self._create_update_gate_parameters_and_layers()
self._create_reset_gate_parameters_and_layers()
self._create_candidate_state_parameters_and_layers()
def _set_hidden_state(self, X, H):
if H is None:
H = torch.zeros(X.shape[0], self.out_channels).to(X.device)
return H
def _calculate_update_gate(self, X, edge_index, edge_weight, H):
        Z = torch.cat([self.conv_z(X, edge_index, edge_weight), H], dim=1)
Z = self.linear_z(Z)
Z = torch.sigmoid(Z)
return Z
def _calculate_reset_gate(self, X, edge_index, edge_weight, H):
        R = torch.cat([self.conv_r(X, edge_index, edge_weight), H], dim=1)
R = self.linear_r(R)
R = torch.sigmoid(R)
return R
def _calculate_candidate_state(self, X, edge_index, edge_weight, H, R):
        H_tilde = torch.cat([self.conv_h(X, edge_index, edge_weight), H * R], dim=1)
H_tilde = self.linear_h(H_tilde)
H_tilde = torch.tanh(H_tilde)
return H_tilde
def _calculate_hidden_state(self, Z, H, H_tilde):
H = Z * H + (1 - Z) * H_tilde
return H
def forward(
self,
X: torch.FloatTensor,
edge_index: torch.LongTensor,
edge_weight: torch.FloatTensor = None,
H: torch.FloatTensor = None,
) -> torch.FloatTensor:
"""
Making a forward pass. If edge weights are not present the forward pass
defaults to an unweighted graph. If the hidden state matrix is not present
when the forward pass is called it is initialized with zeros.
Arg types:
* **X** *(PyTorch Float Tensor)* - Node features.
* **edge_index** *(PyTorch Long Tensor)* - Graph edge indices.
            * **edge_weight** *(PyTorch Float Tensor, optional)* - Edge weight vector.
* **H** *(PyTorch Float Tensor, optional)* - Hidden state matrix for all nodes.
Return types:
* **H** *(PyTorch Float Tensor)* - Hidden state matrix for all nodes.
"""
H = self._set_hidden_state(X, H)
Z = self._calculate_update_gate(X, edge_index, edge_weight, H)
R = self._calculate_reset_gate(X, edge_index, edge_weight, H)
H_tilde = self._calculate_candidate_state(X, edge_index, edge_weight, H, R)
H = self._calculate_hidden_state(Z, H, H_tilde)
return H
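# A minimal usage sketch (hypothetical shapes) for the T-GCN cell: two
# recurrent steps on a 10-node graph, threading the hidden state through.
def _demo_tgcn():
    cell = TGCN(in_channels=8, out_channels=16)
    X = torch.randn(10, 8)
    edge_index = torch.randint(0, 10, (2, 40))
    H = cell(X, edge_index)        # H is initialized to zeros internally
    H = cell(X, edge_index, H=H)   # carry the hidden state to the next step
    assert H.shape == (10, 16)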
class A3TGCN(torch.nn.Module):
r"""An implementation of the Attention Temporal Graph Convolutional Cell.
For details see this paper: `"A3T-GCN: Attention Temporal Graph Convolutional
Network for Traffic Forecasting." <https://arxiv.org/abs/2006.11583>`_
Args:
in_channels (int): Number of input features.
out_channels (int): Number of output features.
periods (int): Number of time periods.
improved (bool): Stronger self loops (default :obj:`False`).
cached (bool): Caching the message weights (default :obj:`False`).
add_self_loops (bool): Adding self-loops for smoothing (default :obj:`True`).
"""
def __init__(
self,
in_channels: int,
out_channels: int,
periods: int,
improved: bool = False,
cached: bool = False,
add_self_loops: bool = True
):
super(A3TGCN, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.periods = periods
self.improved = improved
self.cached = cached
self.add_self_loops = add_self_loops
self._setup_layers()
def _setup_layers(self):
self._base_tgcn = TGCN(
in_channels=self.in_channels,
out_channels=self.out_channels,
improved=self.improved,
cached=self.cached,
add_self_loops=self.add_self_loops,
)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self._attention = torch.nn.Parameter(torch.empty(self.periods, device=device))
torch.nn.init.uniform_(self._attention)
def forward(
self,
X: torch.FloatTensor,
edge_index: torch.LongTensor,
edge_weight: torch.FloatTensor = None,
H: torch.FloatTensor = None,
) -> torch.FloatTensor:
"""
Making a forward pass. If edge weights are not present the forward pass
defaults to an unweighted graph. If the hidden state matrix is not present
when the forward pass is called it is initialized with zeros.
Arg types:
* **X** (PyTorch Float Tensor): Node features for T time periods.
* **edge_index** (PyTorch Long Tensor): Graph edge indices.
            * **edge_weight** (PyTorch Float Tensor, optional): Edge weight vector.
* **H** (PyTorch Float Tensor, optional): Hidden state matrix for all nodes.
Return types:
* **H** (PyTorch Float Tensor): Hidden state matrix for all nodes.
"""
H_accum = 0
probs = torch.nn.functional.softmax(self._attention, dim=0)
for period in range(self.periods):
H_accum = H_accum + probs[period] * self._base_tgcn(
X[:, :, period], edge_index, edge_weight, H
)
return H_accum
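# A minimal usage sketch (hypothetical shapes) for A3TGCN: features for
# `periods` past steps are stacked along the last axis, so X has shape
# (num_nodes, in_channels, periods). The attention parameter is created on
# the detected device, so the inputs are moved there as well.
def _demo_a3tgcn():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = A3TGCN(in_channels=8, out_channels=16, periods=12).to(device)
    X = torch.randn(10, 8, 12, device=device)
    edge_index = torch.randint(0, 10, (2, 40), device=device)
    H = model(X, edge_index)
    assert H.shape == (10, 16)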
### GATR ###
from torch_geometric.nn import GATv2Conv
from torch_geometric.nn import PositionalEncoding
from torch_geometric.nn import BatchNorm
from torch.nn import Transformer
class GATEncoder(torch.nn.Module):
"""
GNN Encoder module for creating static node embeddings.
Args:
hidden_channels (int): The number of hidden channels.
num_heads_GAT (int): The number of attention heads.
dropout_p (float): Dropout probability.
edge_dim (int): Dimensionality of edge features.
"""
def __init__(self, in_channels, hidden_channels, num_heads_GAT,
dropout_p, edge_dim, momentum_GAT):
super(GATEncoder, self).__init__()
self.encoder = GATv2Conv(in_channels=in_channels, out_channels=hidden_channels,
add_self_loops=False, heads=num_heads_GAT,
edge_dim=edge_dim)
self.norm = BatchNorm(hidden_channels, momentum=momentum_GAT,
affine=False, track_running_stats=False)
self.dropout = torch.nn.Dropout(dropout_p)
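        # Note: self.norm is instantiated here but not applied in forward()
        # below.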
def forward(self, x, edge_index):
"""
Forward pass of the GNNEncoder.
Args:
x_dict (torch.Tensor): node types as keys and node features
for each node as values.
edge_index (torch.Tensor): see previous section.
Returns:
torch.Tensor: Static node embeddings for one snapshot.
"""
x = self.dropout(x)
# nodes_embedds = self.encoder(x, edge_index, edge_attrs)
nodes_embedds = self.encoder(x, edge_index)
nodes_embedds = F.leaky_relu(nodes_embedds, negative_slope=0.1)
return nodes_embedds
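# A minimal usage sketch (hypothetical shapes) for GATEncoder. With the
# default concatenation behaviour of GATv2Conv, the output width is
# hidden_channels * num_heads_GAT.
def _demo_gat_encoder():
    enc = GATEncoder(in_channels=8, hidden_channels=16, num_heads_GAT=2,
                     dropout_p=0.1, edge_dim=None, momentum_GAT=0.1)
    x = torch.randn(10, 8)
    edge_index = torch.randint(0, 10, (2, 40))
    z = enc(x, edge_index)
    assert z.shape == (10, 32)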
class TransEncoder(torch.nn.Module):
"""
Transformer-based module for creating temporal node embeddings.
Args:
dim_model (int): The dimension of the model's hidden states.
num_heads_TR (int): The number of attention heads.
num_encoder_layers_TR (int): The number of encoder layers.
num_decoder_layers_TR (int): The number of decoder layers.
dropout_p_TR (float): Dropout probability.
"""
    def __init__(
            self, dim_model, num_heads_TR, num_encoder_layers_TR,
            num_decoder_layers_TR, dropout_p_TR):
        super(TransEncoder, self).__init__()
        self.pos_encoder = PositionalEncoding(dim_model)
        self.transformer = Transformer(
            d_model=dim_model,
            nhead=num_heads_TR,
            num_encoder_layers=num_encoder_layers_TR,
            num_decoder_layers=num_decoder_layers_TR,
            dropout=dropout_p_TR,
            batch_first=True)
def forward(self, src, trg):
"""
Forward pass of the Transformer module.
Args:
src (torch.Tensor): Input sequence with dimensions
(seq_len, num_of_nodes, node_embedds_size).
trg (torch.Tensor): Last element of src, with dimensions
(1, num_of_nodes, node_embedds_size).
Returns:
            torch.Tensor: Temporal node embeddings for the snapshot
                under prediction.
"""
src = self.pos_encoder(src)
trg = self.pos_encoder(trg)
temporal_node_embeddings = self.transformer(src, trg)
return temporal_node_embeddings
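# Illustrative entry point (not part of the original module): runs the
# hypothetical smoke tests defined above.
if __name__ == "__main__":
    _demo_mpnn_lstm()
    _demo_gcn_fixed_w()
    _demo_evolve_gcn_h()
    _demo_tgcn()
    _demo_a3tgcn()
    _demo_gat_encoder()
    print("All demo forward passes completed.")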