Skip to content

Commit ed61c13

Browse files
author
psp3dcg
committed
init project
1 parent 8615c67 commit ed61c13

File tree

7 files changed

+335
-0
lines changed

7 files changed

+335
-0
lines changed

README.md

+27
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,29 @@
11
# GSAPool
22
PyTorch implementation of GSAPool (WWW 2020)
3+
====
4+
5+
PyTorch implementation of [Structure-Feature based Graph Self-adaptive Pooling](https://arxiv.org/pdf/2002.00848)
6+
7+
8+
9+
10+
## Requirements
11+
* pytorch
12+
* torch_geometric
13+
14+
## Usage
15+
16+
```python control_shell.py```
17+
18+
19+
## Cite
20+
```
21+
@InProceedings{GSAPool2020,
22+
title = {Structure-Feature based Graph Self-adaptive Pooling},
23+
author = {Liang Zhang and Xudong Wang and Hongsheng Li and Guangming Zhu and Peiyi Shen and Ping Li and Xiaoyuan Lu and Syed Afaq Ali Shah and Mohammed Bennamoun},
24+
booktitle = {Proceedings of the Web Conference 2020},
25+
year = {2020},
26+
month = {20-25 April}
27+
}
28+
```
29+

control_shell.py

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# -*- coding: utf-8 -*-
"""Training control shell: launches main.py repeatedly and separates runs in result.txt."""
import os
import sys

import util

parser = util.parser
args = parser.parse_args()
# Run from the save path so main.py's relative outputs (latest.pth, result.txt) land there.
os.chdir(args.save_path)
for i in range(args.training_times):
    print('------------------------------')
    print("GSAPool Training Control Shell")
    print('------------------------------')
    print('Training Dataset: ', args.dataset)
    print('Pooling Layer Type: ',args.pooling_layer_type)
    print('Feature Fusion Type:',args.feature_fusion_type)
    print('------------------------------')
    # BUG FIX: use the interpreter running this script rather than whatever
    # "python" happens to resolve to on PATH (may be Python 2 or absent).
    os.system('"{0}" main.py'.format(sys.executable))
    # Blank separator line between runs in the accumulated results file.
    with open(os.path.join(args.save_path, 'result.txt'), 'a') as f:
        f.write('\r\n')
19+
20+
21+
22+

latest.pth

540 KB
Binary file not shown.

layers.py

+79
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.nn.functional as F
4+
5+
from torch.nn import Parameter
6+
from torch_geometric.nn.pool.topk_pool import topk,filter_adj
7+
from torch_geometric.nn import GCNConv, SAGEConv, GATConv, ChebConv, GraphConv
8+
9+
class GSAPool(torch.nn.Module):
    """Graph Self-Adaptive Pooling layer (GSAPool, WWW 2020).

    Each node gets a score that is a convex combination of
    - a structure-based score (SBTL): a one-channel graph convolution, and
    - a feature-based score (FBTL): a one-channel linear projection,
    weighted by ``alpha`` and ``1 - alpha`` respectively.  The top
    ``pooling_ratio`` fraction of nodes per graph is kept; node features may
    optionally be re-aggregated with a fusion convolution before selection.
    """

    def __init__(self, in_channels, pooling_ratio=0.5, alpha=0.6, pooling_conv="GCNConv", fusion_conv="false",
                 min_score=None, multiplier=1, non_linearity=torch.tanh):
        super(GSAPool, self).__init__()
        self.in_channels = in_channels

        self.ratio = pooling_ratio
        # Weight of the structure-based score; (1 - alpha) weights the feature score.
        self.alpha = alpha

        # SBTL: structure-based scoring (graph conv -> 1 channel per node).
        self.sbtl_layer = self.conv_selection(pooling_conv, in_channels)
        # FBTL: feature-based scoring (linear -> 1 channel per node).
        self.fbtl_layer = nn.Linear(in_channels, 1)

        # BUG FIX: the original built the fusion conv unconditionally, so the
        # default fusion_conv="false" reached conv_selection's ValueError and
        # __init__ always crashed.  Only build it when fusion is requested.
        self.fusion_flag = 0
        self.fusion = None
        if fusion_conv != "false":
            self.fusion = self.conv_selection(fusion_conv, in_channels, conv_type=1)
            self.fusion_flag = 1

        self.min_score = min_score
        self.multiplier = multiplier
        self.non_linearity = non_linearity

    def conv_selection(self, conv, in_channels, conv_type=0):
        """Build a convolution layer by name.

        conv_type 0 -> one output channel (node scoring);
        conv_type 1 -> in_channels output channels (feature fusion).
        Raises ValueError for an unknown layer name or conv_type.
        """
        if conv_type == 0:
            out_channels = 1
        elif conv_type == 1:
            out_channels = in_channels
        else:
            # BUG FIX: the original left out_channels unbound here (NameError).
            raise ValueError("unknown conv_type: {0}".format(conv_type))
        if conv == "GCNConv":
            return GCNConv(in_channels, out_channels)
        elif conv == "ChebConv":
            return ChebConv(in_channels, out_channels, 1)
        elif conv == "SAGEConv":
            return SAGEConv(in_channels, out_channels)
        elif conv == "GATConv":
            return GATConv(in_channels, out_channels, heads=1, concat=True)
        elif conv == "GraphConv":
            return GraphConv(in_channels, out_channels)
        else:
            raise ValueError("unknown conv layer: {0}".format(conv))

    def forward(self, x, edge_index, edge_attr=None, batch=None):
        """Score nodes, keep the top fraction, and return the pooled sub-graph.

        Returns (x, edge_index, edge_attr, batch, perm) in the usual
        torch_geometric top-k pooling format.
        """
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))
        x = x.unsqueeze(-1) if x.dim() == 1 else x

        # SBTL: structure-based score.
        score_s = self.sbtl_layer(x, edge_index).squeeze()
        # FBTL: feature-based score.
        score_f = self.fbtl_layer(x).squeeze()
        # Hyperparameter alpha mixes the two scores.
        score = score_s * self.alpha + score_f * (1 - self.alpha)

        # Keep score 1-D even for a single-node input (squeeze made it 0-dim).
        score = score.unsqueeze(-1) if score.dim() == 0 else score

        if self.min_score is None:
            score = self.non_linearity(score)
        else:
            # BUG FIX: softmax was referenced but never imported in the
            # original, so this branch raised NameError.  Per-graph softmax
            # over the batch segments.
            from torch_geometric.utils import softmax
            score = softmax(score, batch)
        perm = topk(score, self.ratio, batch)

        # Optional feature fusion before gathering the selected nodes.
        if self.fusion_flag == 1:
            x = self.fusion(x, edge_index)

        x = x[perm] * score[perm].view(-1, 1)
        x = self.multiplier * x if self.multiplier != 1 else x

        batch = batch[perm]
        edge_index, edge_attr = filter_adj(
            edge_index, edge_attr, perm, num_nodes=score.size(0))

        return x, edge_index, edge_attr, batch, perm

main.py

+101
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
import os
2+
import torch
3+
import torch.nn.functional as F
4+
from torch.utils.data import random_split
5+
from torch_geometric.datasets import TUDataset
6+
from torch_geometric.data import DataLoader
7+
8+
9+
import util
10+
from networks import Net
11+
12+
13+
14+
# ---- parameter initialization ----
parser = util.parser
args = parser.parse_args()
torch.manual_seed(args.seed)

# ---- device selection ----
_cuda_ok = torch.cuda.is_available()
args.device = 'cuda:0' if _cuda_ok else 'cpu'
if _cuda_ok:
    torch.cuda.manual_seed(args.seed)
25+
26+
#dataset split
def data_builder(args):
    """Load the TUDataset named by args.dataset and build train/val/test loaders.

    Side effect: records the dataset's class and feature counts on args.
    Split is 80% / 10% / 10%, with any rounding remainder going to test.
    """
    dataset = TUDataset(os.path.join('data', args.dataset), name=args.dataset)
    args.num_classes = dataset.num_classes
    args.num_features = dataset.num_features

    n_total = len(dataset)
    n_train = int(n_total * 0.8)
    n_val = int(n_total * 0.1)
    n_test = n_total - (n_train + n_val)
    training_set, validation_set, test_set = random_split(dataset, [n_train, n_val, n_test])

    train_loader = DataLoader(training_set, batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(validation_set, batch_size=args.batch_size, shuffle=False)
    # Test graphs are evaluated one at a time.
    test_loader = DataLoader(test_set, batch_size=1, shuffle=False)

    return train_loader, val_loader, test_loader
42+
43+
#test function
def test(model, loader):
    """Evaluate model on loader; return (accuracy, mean NLL loss) over the dataset."""
    model.eval()
    correct = 0.
    loss = 0.
    # Evaluation needs no gradients: no_grad avoids building the autograd
    # graph (the original leaked memory/time by tracking gradients here).
    with torch.no_grad():
        for data in loader:
            data = data.to(args.device)
            out = model(data)
            pred = out.max(dim=1)[1]
            correct += pred.eq(data.y).sum().item()
            loss += F.nll_loss(out, data.y, reduction='sum').item()
    return correct / len(loader.dataset), loss / len(loader.dataset)
55+
56+
#save result in txt
def save_result(test_acc, save_path):
    """Append one result line (dataset, layer types, accuracy in %) to result.txt."""
    record = (args.dataset + ";"
              + "pooling_layer_type:" + args.pooling_layer_type + ";"
              + "feature_fusion_type:" + args.feature_fusion_type + ";"
              + str(test_acc * 100)
              + '\r\n')
    with open(os.path.join(save_path, 'result.txt'), 'a') as f:
        f.write(record)
65+
66+
#training configuration
train_loader, val_loader, test_loader = data_builder(args)
model = Net(args).to(args.device)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

#training steps
# BUG FIX: the original compared val_loss against args.min_loss (a constant
# 1e10) and assigned the new minimum to a throwaway local, so the checkpoint
# was overwritten every epoch and patience never incremented -> early
# stopping never triggered.  Track the running best in min_loss instead.
min_loss = args.min_loss
patience = 0
for epoch in range(args.epochs):
    model.train()
    for i, data in enumerate(train_loader):
        data = data.to(args.device)
        out = model(data)
        loss = F.nll_loss(out, data.y)
        print("Training loss:{}".format(loss.item()))
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    # Validate once per epoch and checkpoint on improvement.
    val_acc, val_loss = test(model, val_loader)
    print("Validation loss:{}\taccuracy:{}".format(val_loss, val_acc))
    print("Epoch{}".format(epoch))
    if val_loss < min_loss:
        torch.save(model.state_dict(), 'latest.pth')
        print("Model saved at epoch{}".format(epoch))
        min_loss = val_loss
        patience = 0
    else:
        patience += 1
    if patience > args.patience:
        break

#test step: reload the best checkpoint and evaluate on the held-out set.
model = Net(args).to(args.device)
model.load_state_dict(torch.load('latest.pth'))
test_acc, test_loss = test(model, test_loader)
# BUG FIX: "accuarcy" typo in the user-facing message.
print("Test accuracy:{}".format(test_acc))
save_result(test_acc, args.save_path)

networks.py

+66
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import torch
2+
import numpy as np
3+
import torch.nn.functional as F
4+
from torch_geometric.nn import GCNConv
5+
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
6+
7+
from layers import GSAPool
8+
9+
class Net(torch.nn.Module):
    """GSAPool-based graph classifier.

    Three GCN + GSAPool stages each produce a [max-pool || mean-pool]
    readout; the three readouts are summed and fed through a three-layer
    MLP head with dropout, ending in log-softmax class scores.
    """

    def __init__(self, args):
        super(Net, self).__init__()

        self.args = args
        self.nhid = args.nhid

        self.num_features = args.num_features
        self.num_classes = args.num_classes

        self.alpha = args.alpha
        self.pooling_ratio = args.pooling_ratio
        self.dropout_ratio = args.dropout_ratio

        self.pooling_layer_type = args.pooling_layer_type
        self.feature_fusion_type = args.feature_fusion_type

        def make_pool():
            # All three pooling stages share the same configuration.
            return GSAPool(self.nhid, pooling_ratio=self.pooling_ratio, alpha=self.alpha,
                           pooling_conv=self.pooling_layer_type, fusion_conv=self.feature_fusion_type)

        self.conv1 = GCNConv(self.num_features, self.nhid)
        self.pool1 = make_pool()
        self.conv2 = GCNConv(self.nhid, self.nhid)
        self.pool2 = make_pool()
        self.conv3 = GCNConv(self.nhid, self.nhid)
        self.pool3 = make_pool()

        # Readout is [gmp || gap], hence 2 * nhid inputs to the head.
        self.lin1 = torch.nn.Linear(self.nhid * 2, self.nhid)
        self.lin2 = torch.nn.Linear(self.nhid, self.nhid // 2)
        self.lin3 = torch.nn.Linear(self.nhid // 2, self.num_classes)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch

        stages = ((self.conv1, self.pool1),
                  (self.conv2, self.pool2),
                  (self.conv3, self.pool3))

        summed = None
        for conv, pool in stages:
            x = F.relu(conv(x, edge_index))
            x, edge_index, _, batch, _ = pool(x, edge_index, None, batch)
            # Graph-level readout: concatenated max and mean pooling.
            readout = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)
            summed = readout if summed is None else summed + readout

        out = F.relu(self.lin1(summed))
        out = F.dropout(out, p=self.dropout_ratio, training=self.training)
        out = F.relu(self.lin2(out))
        return F.log_softmax(self.lin3(out), dim=-1)
65+
66+

util.py

+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import argparse
#Parameter Configuration

parser = argparse.ArgumentParser()

# One row per option: (flag, type, default, help).  Behaviour is identical
# to a sequence of individual parser.add_argument calls.
for _flag, _type, _default, _help in (
    ('--seed', int, 777, 'seed'),
    ('--batch_size', int, 128, 'batch size'),
    ('--lr', float, 0.0005, 'learning rate'),
    ('--weight_decay', float, 0.0001, 'weight decay'),
    ('--min_loss', float, 1e10, 'min loss value'),
    ('--nhid', int, 128, 'hidden size'),
    ('--pooling_ratio', float, 0.5, 'pooling ratio'),
    ('--alpha', float, 0.6, 'combination_ratio'),
    ('--dropout_ratio', float, 0.5, 'dropout ratio'),
    ('--dataset', str, 'DD', 'DD/NCI1/NCI109/Mutagenicity'),
    ('--epochs', int, 100000, 'maximum number of epochs'),
    ('--patience', int, 50, 'patience for earlystopping'),
    ('--pooling_layer_type', str, 'GCNConv', 'GCNConv'),
    ('--feature_fusion_type', str, 'GATConv', 'GCNConv/SAGEConv/ChebConv/GATConv/GraphConv'),
    ('--save_path', str, '/home/baoke/workspace_wxd/GSAPool', 'path to save result'),
    ('--training_times', int, 20, ''),
):
    parser.add_argument(_flag, type=_type, default=_default, help=_help)
38+
39+
40+

0 commit comments

Comments
 (0)