import math

import torch
from torch import nn


class Decoder(nn.Module):
    def __init__(self, in_feats, h_feats, n_head, dropout_rate, n_layers):
        super().__init__()

        # Stack of self-attention decoder layers.
        self.layers = nn.ModuleList([DecoderLayer(in_feats=in_feats, h_feats=h_feats, n_head=n_head,
                                                  dropout_rate=dropout_rate)
                                     for _ in range(n_layers)])
        self.act_fn = nn.ReLU()
        # Final projection from the input dimension to the hidden dimension.
        self.mlp = nn.Sequential(nn.Linear(in_feats, h_feats))

    def forward(self, x, edge_index):
        # Residual connection around the whole layer stack.
        _x = x
        for layer in self.layers:
            x = layer(x, edge_index)
        x = x + _x

        output = self.mlp(x)
        return output


class DecoderLayer(nn.Module):

    def __init__(self, in_feats, h_feats, n_head, dropout_rate):
        super(DecoderLayer, self).__init__()
        self.self_attention = MultiHeadAttention(in_channels=in_feats, hid_channels=h_feats, n_head=n_head)
        self.linear = nn.Linear(in_feats, h_feats)
        self.norm1 = LayerNorm(hid_channels=h_feats)
        self.dropout1 = nn.Dropout(p=dropout_rate)
        self.norm3 = LayerNorm(hid_channels=h_feats)
        self.linear1 = nn.Linear(h_feats, in_feats)

    def forward(self, x, edge_index):
        # edge_index is accepted for interface compatibility but is not used here:
        # the layer applies dense self-attention over all node features.
        _x = x
        x = self.self_attention(q=x, k=x, v=x)
        x = self.dropout1(x)
        # Project the residual branch to the hidden dimension before adding.
        _x = self.linear(_x)

        x = self.norm1(x + _x)

        # Map back to the input feature dimension.
        x = self.linear1(x)
        return x


class LayerNorm(nn.Module):
    def __init__(self, hid_channels, eps=1e-12):
        super(LayerNorm, self).__init__()
        self.gamma = nn.Parameter(torch.ones(hid_channels))
        self.beta = nn.Parameter(torch.zeros(hid_channels))
        self.eps = eps

    def forward(self, x):
        # Normalize over the last (feature) dimension, then apply a learnable
        # affine transform.
        mean = x.mean(-1, keepdim=True)
        var = x.var(-1, unbiased=False, keepdim=True)

        out = (x - mean) / torch.sqrt(var + self.eps)
        out = self.gamma * out + self.beta
        return out


class MultiHeadAttention(nn.Module):

    def __init__(self, in_channels, hid_channels, n_head):
        super(MultiHeadAttention, self).__init__()
        self.n_head = n_head
        self.attention = ScaleDotProductAttention()
        self.w_q = nn.Linear(in_channels, hid_channels)
        self.w_k = nn.Linear(in_channels, hid_channels)
        self.w_v = nn.Linear(in_channels, hid_channels)
        self.w_concat = nn.Linear(hid_channels, hid_channels)

    def forward(self, q, k, v, mask=None):
        # Project inputs, split into heads, attend, then merge the heads back.
        q, k, v = self.w_q(q), self.w_k(k), self.w_v(v)
        q, k, v = self.split(q), self.split(k), self.split(v)
        out, attention = self.attention(q, k, v, mask=mask)
        out = self.concat(out)
        out = self.w_concat(out)
        return out

    def split(self, tensor):
        # (length, d_model) -> (n_head, length, d_tensor)
        length, d_model = tensor.size()
        d_tensor = d_model // self.n_head
        tensor = tensor.view(length, self.n_head, d_tensor).transpose(0, 1)
        return tensor

    def concat(self, tensor):
        # (n_head, length, d_tensor) -> (length, n_head * d_tensor)
        head, length, d_tensor = tensor.size()
        d_model = head * d_tensor
        tensor = tensor.transpose(0, 1).contiguous().view(length, d_model)
        return tensor


class PositionwiseFeedForward(nn.Module):

    def __init__(self, in_channels, hid_channels, drop_prob=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.linear1 = nn.Linear(hid_channels, hid_channels)
        self.linear2 = nn.Linear(hid_channels, in_channels)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=drop_prob)

    def forward(self, x):
        # Two-layer feed-forward block with ReLU and dropout.
        x = self.linear1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.linear2(x)
        return x


class ScaleDotProductAttention(nn.Module):
    def __init__(self):
        super(ScaleDotProductAttention, self).__init__()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, q, k, v, mask=None, e=1e-12):
        # q, k, v: (n_head, length, d_tensor)
        head, length, d_tensor = k.size()
        k_t = k.transpose(1, 2)
        score = (q @ k_t) / math.sqrt(d_tensor)
        # Mask out disallowed positions before the softmax, if a mask is given.
        if mask is not None:
            score = score.masked_fill(mask == 0, -10000)
        score = self.softmax(score)
        v = score @ v

        return v, score
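

if __name__ == "__main__":
    # Minimal usage sketch: the node count, feature sizes, head count, and
    # random inputs below are illustrative assumptions, not values taken from
    # the surrounding code.
    num_nodes, in_feats, h_feats = 8, 16, 32
    decoder = Decoder(in_feats=in_feats, h_feats=h_feats, n_head=4,
                      dropout_rate=0.1, n_layers=2)
    x = torch.randn(num_nodes, in_feats)
    edge_index = torch.zeros(2, 0, dtype=torch.long)  # placeholder; the decoder ignores it
    out = decoder(x, edge_index)
    print(out.shape)  # expected: torch.Size([8, 32])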