Skip to content

Commit 06a84e2

Browse files
authored
Add files via upload
1 parent 9ff0fce commit 06a84e2

File tree

2 files changed

+191
-0
lines changed

2 files changed

+191
-0
lines changed

models/detector.py

+95
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
from models.gnn import *
2+
3+
from sklearn.metrics import roc_auc_score, average_precision_score
4+
import numpy as np
5+
6+
import time
7+
8+
class BaseDetector(object):
9+
def __init__(self, train_config, model_config, data):
10+
self.model_config = model_config
11+
self.train_config = train_config
12+
self.data = data
13+
model_config['in_feats'] = self.data.graph.ndata['feature'].shape[1]
14+
graph = self.data.graph.to(self.train_config['device'])
15+
self.labels = graph.ndata['label']
16+
self.train_mask = graph.ndata['train_mask'].bool()
17+
self.val_mask = graph.ndata['val_mask'].bool()
18+
self.test_mask = graph.ndata['test_mask'].bool()
19+
self.weight = (1 - self.labels[self.train_mask]).sum().item() / self.labels[self.train_mask].sum().item()
20+
self.source_graph = graph
21+
print(train_config['inductive'])
22+
if train_config['inductive'] == False:
23+
self.train_graph = graph
24+
self.val_graph = graph
25+
else:
26+
self.train_graph = graph.subgraph(self.train_mask)
27+
self.val_graph = graph.subgraph(self.train_mask+self.val_mask)
28+
self.best_score = -1
29+
# self.patience_knt = 0
30+
31+
def train(self):
32+
pass
33+
34+
def eval(self, labels, probs):
35+
score = {}
36+
with torch.no_grad():
37+
if torch.is_tensor(labels):
38+
labels = labels.cpu().numpy()
39+
if torch.is_tensor(probs):
40+
probs = probs.cpu().numpy()
41+
score['AUROC'] = roc_auc_score(labels, probs)
42+
score['AUPRC'] = average_precision_score(labels, probs)
43+
labels = np.array(labels)
44+
k = labels.sum()
45+
score['RecK'] = sum(labels[probs.argsort()[-k:]]) / sum(labels)
46+
return score
47+
48+
49+
class BaseGNNDetector(BaseDetector):
50+
def __init__(self, train_config, model_config, data):
51+
super().__init__(train_config, model_config, data)
52+
gnn = globals()[model_config['model']]
53+
model_config['in_feats'] = self.data.graph.ndata['feature'].shape[1]
54+
self.model = gnn(**model_config).to(train_config['device'])
55+
56+
57+
def train(self):
58+
optimizer = torch.optim.Adam(self.model.parameters(), lr=self.model_config['lr'])
59+
train_labels, val_labels, test_labels = self.labels[self.train_mask], self.labels[self.val_mask], self.labels[self.test_mask]
60+
61+
for e in range(self.train_config['epochs']):
62+
self.model.train()
63+
logits = self.model(self.train_graph)
64+
loss = F.cross_entropy(logits[self.train_graph.ndata['train_mask']], train_labels,
65+
weight=torch.tensor([1., self.weight], device=self.labels.device))
66+
67+
optimizer.zero_grad()
68+
loss.backward()
69+
optimizer.step()
70+
71+
if self.model_config['drop_rate'] > 0 or self.train_config['inductive']:
72+
self.model.eval()
73+
logits = self.model(self.val_graph)
74+
probs = logits.softmax(1)[:, 1]
75+
val_score = self.eval(val_labels, probs[self.val_graph.ndata['val_mask']])
76+
77+
if val_score[self.train_config['metric']] > self.best_score:
78+
if self.train_config['inductive']:
79+
logits = self.model(self.source_graph)
80+
probs = logits.softmax(1)[:, 1]
81+
self.patience_knt = 0
82+
self.best_score = val_score[self.train_config['metric']]
83+
test_score = self.eval(test_labels, probs[self.test_mask])
84+
print('Epoch {}, Loss {:.4f}, Val AUC {:.4f}, PRC {:.4f}, RecK {:.4f}, test AUC {:.4f}, PRC {:.4f}, RecK {:.4f}'.format(
85+
e, loss, val_score['AUROC'], val_score['AUPRC'], val_score['RecK'],
86+
test_score['AUROC'], test_score['AUPRC'], test_score['RecK']))
87+
88+
else:
89+
self.patience_knt += 1
90+
if self.patience_knt > self.train_config['patience']:
91+
break
92+
93+
return test_score
94+
95+

models/gnn.py

+96
Original file line numberDiff line numberDiff line change
@@ -1 +1,97 @@
1+
import torch
2+
import torch.nn.functional as F
3+
import dgl.function as fn
4+
import dgl.nn.pytorch.conv as dglnn
5+
from torch import nn
6+
from transformer import Decoder
7+
from dgl.nn.pytorch.conv import EdgeConv
18

9+
class GCN(nn.Module):
10+
def __init__(self, in_feats, h_feats=64, num_classes=2, num_layers=2, mlp_layers=1, dropout_rate=0.,
11+
activation='ReLU', **kwargs):
12+
super().__init__()
13+
self.h_feats = h_feats
14+
self.layers = nn.ModuleList()
15+
self.act = getattr(nn, activation)()
16+
self.layers.append(dglnn.GraphConv(in_feats, h_feats, activation=self.act))
17+
for i in range(num_layers-1):
18+
self.layers.append(dglnn.GraphConv(h_feats, h_feats, activation=self.act))
19+
self.mlp = MLP(h_feats, h_feats, num_classes, mlp_layers, dropout_rate)
20+
self.dropout = nn.Dropout(dropout_rate) if dropout_rate > 0 else nn.Identity()
21+
22+
def forward(self, graph):
23+
h = graph.ndata['feature']
24+
for i, layer in enumerate(self.layers):
25+
if i != 0:
26+
h = self.dropout(h)
27+
h = layer(graph, h)
28+
h = self.mlp(h, False)
29+
return h
30+
31+
class MLP(nn.Module):
32+
def __init__(self, in_feats, h_feats=32, num_classes=2, num_layers=2, dropout_rate=0, activation='ReLU', **kwargs):
33+
super(MLP, self).__init__()
34+
self.layers = nn.ModuleList()
35+
self.act = getattr(nn, activation)()
36+
if num_layers == 0:
37+
return
38+
if num_layers == 1:
39+
self.layers.append(nn.Linear(in_feats, num_classes))
40+
else:
41+
self.layers.append(nn.Linear(in_feats, h_feats))
42+
for i in range(1, num_layers-1):
43+
self.layers.append(nn.Linear(h_feats, h_feats))
44+
self.layers.append(nn.Linear(h_feats, num_classes))
45+
self.dropout = nn.Dropout(dropout_rate) if dropout_rate > 0 else nn.Identity()
46+
47+
def forward(self, h, is_graph=True):
48+
if is_graph:
49+
h = h.ndata['feature']
50+
for i, layer in enumerate(self.layers):
51+
if i != 0:
52+
h = self.dropout(h)
53+
h = layer(h)
54+
if i != len(self.layers)-1:
55+
h = self.act(h)
56+
return h
57+
58+
59+
class MGADN(nn.Module):
60+
def __init__(self, in_feats, h_feats=64, n_head=8, n_layers=4, dropout_rate=0., **kwargs):
61+
super().__init__()
62+
self.attn_fn = nn.Tanh()
63+
self.act_fn = nn.ReLU()
64+
self.decoder = Decoder(in_feats=in_feats, h_feats=h_feats, n_head=n_head, dropout_rate=0., n_layers=n_layers)
65+
self.filters3 = GCN(in_feats, h_feats=h_feats, num_classes=h_feats, num_layers=2, mlp_layers=2, dropout_rate=0., activation='ReLU')
66+
67+
self.DMGNN = EdgeConv(in_feats, out_feat=h_feats)
68+
69+
self.linear1 = nn.Linear(h_feats*2, h_feats)
70+
71+
self.linear = nn.Sequential(nn.Linear(h_feats, h_feats),
72+
self.attn_fn,
73+
nn.Linear(h_feats, 2))
74+
75+
self.gate_layer = nn.Linear(h_feats, h_feats)
76+
77+
def forward(self, graph):
78+
x = graph.ndata['feature']
79+
h_list = []
80+
x = x.to(torch.float32)
81+
82+
out1 = self.decoder(x, graph)
83+
out2 = self.filters3(graph)
84+
85+
F = self.DMGNN(graph,x)
86+
87+
h_list.append(out1)
88+
h_list.append(out2)
89+
90+
res = torch.cat((h_list[0], h_list[1]), dim=1)
91+
output = self.linear1(res)
92+
93+
gate = torch.sigmoid(self.gate_layer(output))
94+
95+
out = gate * output + (1 - gate) * F
96+
result = self.linear(out)
97+
return result

0 commit comments

Comments
 (0)