import torch
import torch.nn.functional as F
import dgl.function as fn
import dgl.nn.pytorch.conv as dglnn
from torch import nn
from transformer import Decoder
from dgl.nn.pytorch.conv import EdgeConv


class GCN(nn.Module):
    """Stack of DGL GraphConv layers followed by an MLP classification head."""

    def __init__(self, in_feats, h_feats=64, num_classes=2, num_layers=2, mlp_layers=1,
                 dropout_rate=0., activation='ReLU', **kwargs):
        super().__init__()
        self.h_feats = h_feats
        self.layers = nn.ModuleList()
        self.act = getattr(nn, activation)()
        self.layers.append(dglnn.GraphConv(in_feats, h_feats, activation=self.act))
        for _ in range(num_layers - 1):
            self.layers.append(dglnn.GraphConv(h_feats, h_feats, activation=self.act))
        self.mlp = MLP(h_feats, h_feats, num_classes, mlp_layers, dropout_rate)
        self.dropout = nn.Dropout(dropout_rate) if dropout_rate > 0 else nn.Identity()

    def forward(self, graph):
        h = graph.ndata['feature']
        for i, layer in enumerate(self.layers):
            if i != 0:
                h = self.dropout(h)
            h = layer(graph, h)
        # The MLP head receives node embeddings directly, not the graph.
        h = self.mlp(h, False)
        return h


class MLP(nn.Module):
    """Plain feed-forward network; accepts either a DGL graph or a feature tensor."""

    def __init__(self, in_feats, h_feats=32, num_classes=2, num_layers=2, dropout_rate=0.,
                 activation='ReLU', **kwargs):
        super().__init__()
        self.layers = nn.ModuleList()
        self.act = getattr(nn, activation)()
        # Defined before the early return so the module is fully initialized
        # even when num_layers == 0.
        self.dropout = nn.Dropout(dropout_rate) if dropout_rate > 0 else nn.Identity()
        if num_layers == 0:
            return
        if num_layers == 1:
            self.layers.append(nn.Linear(in_feats, num_classes))
        else:
            self.layers.append(nn.Linear(in_feats, h_feats))
            for _ in range(1, num_layers - 1):
                self.layers.append(nn.Linear(h_feats, h_feats))
            self.layers.append(nn.Linear(h_feats, num_classes))

    def forward(self, h, is_graph=True):
        if is_graph:
            h = h.ndata['feature']
        for i, layer in enumerate(self.layers):
            if i != 0:
                h = self.dropout(h)
            h = layer(h)
            # No activation after the final (output) layer.
            if i != len(self.layers) - 1:
                h = self.act(h)
        return h


class MGADN(nn.Module):
    """Fuses a transformer Decoder branch, a GCN branch, and an EdgeConv branch via a learned gate."""

    def __init__(self, in_feats, h_feats=64, n_head=8, n_layers=4, dropout_rate=0., **kwargs):
        super().__init__()
        self.attn_fn = nn.Tanh()
        self.act_fn = nn.ReLU()
        self.decoder = Decoder(in_feats=in_feats, h_feats=h_feats, n_head=n_head,
                               dropout_rate=0., n_layers=n_layers)
        self.filters3 = GCN(in_feats, h_feats=h_feats, num_classes=h_feats, num_layers=2,
                            mlp_layers=2, dropout_rate=0., activation='ReLU')

        self.DMGNN = EdgeConv(in_feats, out_feat=h_feats)

        self.linear1 = nn.Linear(h_feats * 2, h_feats)

        self.linear = nn.Sequential(nn.Linear(h_feats, h_feats),
                                    self.attn_fn,
                                    nn.Linear(h_feats, 2))

        self.gate_layer = nn.Linear(h_feats, h_feats)

    def forward(self, graph):
        x = graph.ndata['feature']
        x = x.to(torch.float32)

        # Transformer decoder branch and GCN branch.
        out1 = self.decoder(x, graph)
        out2 = self.filters3(graph)

        # EdgeConv branch (renamed from `F` to avoid shadowing torch.nn.functional).
        edge_out = self.DMGNN(graph, x)

        # Concatenate the decoder and GCN outputs, then project back to h_feats.
        res = torch.cat((out1, out2), dim=1)
        output = self.linear1(res)

        # Learned gate blends the fused representation with the EdgeConv branch.
        gate = torch.sigmoid(self.gate_layer(output))
        out = gate * output + (1 - gate) * edge_out

        result = self.linear(out)
        return result
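

# Hedged usage sketch (not part of the original module): builds a small random
# DGL graph with node features stored under 'feature', as the forward passes
# above expect, and runs MGADN on it. Assumes the local `transformer.Decoder`
# module is importable; the graph size and feature dimension are illustrative only.
if __name__ == "__main__":
    import dgl

    num_nodes, in_feats = 16, 32
    src = torch.randint(0, num_nodes, (64,))
    dst = torch.randint(0, num_nodes, (64,))
    g = dgl.graph((src, dst), num_nodes=num_nodes)
    # Self-loops avoid zero-in-degree errors in GraphConv/EdgeConv.
    g = dgl.add_self_loop(g)
    g.ndata['feature'] = torch.randn(num_nodes, in_feats)

    model = MGADN(in_feats, h_feats=64, n_head=8, n_layers=4)
    logits = model(g)  # expected shape: (num_nodes, 2)
    print(logits.shape)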