Commit
Added Link Predictor
AlbertoFormaggio1 committed Nov 19, 2023
1 parent 88fc58e commit 713ee69
Showing 4 changed files with 141 additions and 21 deletions.
50 changes: 48 additions & 2 deletions engine.py
@@ -1,6 +1,8 @@
import torch
from tqdm.auto import tqdm
from torch_geometric.loader import NeighborLoader
from torch_geometric.utils import negative_sampling
from sklearn.metrics import roc_auc_score

def train(model, train_ds, val_ds, loss_fn: torch.nn.Module,
opt: torch.optim.Optimizer, epochs: int, batch_generation: bool = False):
@@ -45,7 +47,6 @@ def train(model, train_ds, val_ds, loss_fn: torch.nn.Module,

return results


def eval(model, loss_fn, ds, mask):
with torch.no_grad():
out = model(ds.x, ds.edge_index)
@@ -58,7 +59,8 @@ def eval(model, loss_fn, ds, mask):

def train_step(model: torch.nn.Module, ds, loss_fn: torch.nn.Module,
opt: torch.optim.Optimizer):
model.train() # Set the model in training phase
model.train() # Set the model in training mode
opt.zero_grad() # Reset the gradient
out = model(ds.x, ds.edge_index) # Compute the response of the model
loss = loss_fn(out[ds.train_mask], ds.y[ds.train_mask]) # Compute the loss based on training nodes
loss.backward() # Propagate the gradient
@@ -70,3 +72,47 @@ def train_step(model: torch.nn.Module, ds, loss_fn: torch.nn.Module,
train_acc = torch.sum(train_cls[ds.train_mask] == ds.y[ds.train_mask]) / torch.sum(ds.train_mask)

return train_loss, train_acc.item()


def train_link_prediction(model, train_ds, loss_fn: torch.nn.Module,
opt: torch.optim.Optimizer, epochs: int):

for epoch in range(epochs):
model.train()
opt.zero_grad()
# First compute the node embeddings via message passing over the edges
# already present in the graph
z = model(train_ds.x, train_ds.edge_index)

# Perform a new round of negative sampling at every epoch.
# The returned edges are guaranteed not to be present in edge_index;
# num_nodes gives the number of nodes to sample from.
# We sample as many negative edges as there are positive supervision edges,
# so the predictor is trained on a balanced (unbiased) set.
neg_edge_index = negative_sampling(
edge_index=train_ds.edge_index, num_nodes=train_ds.num_nodes,
num_neg_samples=train_ds.edge_label_index.size(1), method='sparse')

# The edge_label for edges already present in the graph is 1,
# while the edge_label for the negative edges we just sampled is 0.

# Concatenate along the last dimension, since we are appending more edge columns
edge_label_index = torch.cat([train_ds.edge_label_index, neg_edge_index], dim=1)
# Concatenate the labels of the negative edges (all 0) along the first (and only) dimension
edge_label = torch.cat([train_ds.edge_label,
train_ds.edge_label.new_zeros(neg_edge_index.size(1))], dim=0)

# out = model.decode(z, edge_label_index).view(-1)
# Score every candidate edge (positive + sampled negative) with the decoder;
# decode returns one raw logit per edge in edge_label_index
out = model.decode(z, edge_label_index)
loss = loss_fn(out, edge_label)
loss.backward()
opt.step()

return loss.item()
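For intuition, a minimal sketch of the negative-sampling step above, using a hypothetical 4-node graph (the exact pairs returned vary between calls):

import torch
from torch_geometric.utils import negative_sampling

edge_index = torch.tensor([[0, 1, 2],
                           [1, 2, 3]])  # three existing (positive) edges
neg_edge_index = negative_sampling(edge_index=edge_index, num_nodes=4,
                                   num_neg_samples=3, method='sparse')
# neg_edge_index has shape [2, 3] and only contains pairs absent from edge_index,
# e.g. tensor([[0, 1, 3], [2, 3, 0]]) (the output is random)
edge_label = torch.cat([torch.ones(3), torch.zeros(3)])  # 1 = real edge, 0 = sampled non-edge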

@torch.no_grad()
def test(model, data):
model.eval()
z = model(data.x, data.edge_index)
out = model.decode(z, data.edge_label_index).view(-1).sigmoid()
return roc_auc_score(data.edge_label.cpu().numpy(), out.cpu().numpy())
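The ROC AUC used in test() measures how well positive edges are ranked above negative ones; a tiny illustrative call with hypothetical labels and scores:

from sklearn.metrics import roc_auc_score

# labels: 1 = true edge, 0 = non-edge; scores: predicted probabilities
roc_auc_score([1, 1, 0, 0], [0.9, 0.4, 0.35, 0.1])  # -> 1.0, every positive outranks every negative
roc_auc_score([1, 1, 0, 0], [0.9, 0.2, 0.35, 0.1])  # -> 0.75, one positive is ranked below one negative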
1 change: 0 additions & 1 deletion load_dataset.py
@@ -2,7 +2,6 @@

def load_ds(dataset_name, transform):
dataset = Planetoid(root='data/Planetoid', name=dataset_name, transform=transform)

return dataset

def print_ds_info(ds : Planetoid):
68 changes: 53 additions & 15 deletions main.py
@@ -1,28 +1,66 @@
import torch
import os
from torch_geometric.transforms import NormalizeFeatures
import load_dataset
import engine
import model
import torch_geometric.transforms as T

datasets = {}
device = 'cuda' if torch.cuda.is_available() else 'cpu'

datasets['cora'] = load_dataset.load_ds('Cora', NormalizeFeatures())
datasets['citeseer'] = load_dataset.load_ds('CiteSeer', NormalizeFeatures())
datasets['pubmed'] = load_dataset.load_ds('PubMed', NormalizeFeatures())
classification = False

for ds in datasets.values():
load_dataset.print_ds_info(ds)
print('\n#################################\n')
if classification:

dataset = datasets['cora']
transform_classification = T.Compose([
T.NormalizeFeatures(),
T.ToDevice(device)
])

model = model.GAT(dataset.num_features, dataset.num_classes)
datasets = {}

criterion = torch.nn.CrossEntropyLoss() # Define loss criterion => CrossEntropyLoss in the case of classification
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
datasets['cora'] = load_dataset.load_ds('Cora', transform_classification)
datasets['citeseer'] = load_dataset.load_ds('CiteSeer', transform_classification)
datasets['pubmed'] = load_dataset.load_ds('PubMed', transform_classification)

results = engine.train(model, dataset.data, dataset.data, criterion, optimizer, 10, False)
for ds in datasets.values():
load_dataset.print_ds_info(ds)
print('\n#################################\n')

for k, r in results.items():
print(k, r)
dataset = datasets['cora']

model = model.GAT(dataset.num_features, dataset.num_classes)

criterion = torch.nn.CrossEntropyLoss() # Define loss criterion => CrossEntropyLoss in the case of classification
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

results = engine.train(model, dataset.data, dataset.data, criterion, optimizer, 10, False)

for k, r in results.items():
print(k, r)

else:
transform_prediction = T.Compose([
T.NormalizeFeatures(),
T.ToDevice(device),
T.RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True,
add_negative_train_samples=False)
])

datasets = {}

datasets['cora'] = load_dataset.load_ds('Cora', transform_prediction)
datasets['citeseer'] = load_dataset.load_ds('CiteSeer', transform_prediction)
datasets['pubmed'] = load_dataset.load_ds('PubMed', transform_prediction)

dataset = datasets['cora']
train_ds, val_ds, test_ds = dataset[0]

model = model.GCN_Predictor(dataset.num_features, dataset.num_classes)

criterion = torch.nn.BCEWithLogitsLoss() # Define loss criterion => Binary Cross Entropy for link prediction
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

engine.train_link_prediction(model, train_ds, criterion, optimizer, 101)

auc = engine.test(model, val_ds)  # test() returns the ROC AUC on the validation split
print(auc)
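For reference, a minimal standalone sketch of what the RandomLinkSplit transform above produces (attribute names as documented by PyG; exact tensor sizes depend on the dataset):

import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid

transform = T.Compose([
    T.NormalizeFeatures(),
    T.RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True,
                      add_negative_train_samples=False)
])
dataset = Planetoid(root='data/Planetoid', name='Cora', transform=transform)
train_ds, val_ds, test_ds = dataset[0]
# Each split is a Data object with:
#   edge_index        edges used for message passing (training edges only)
#   edge_label_index  supervision edges to be scored by the decoder
#   edge_label        1 for positive edges; val/test also contain sampled negatives labelled 0
# (with add_negative_train_samples=False the training split holds no negatives:
#  they are sampled on the fly at every epoch, as in train_link_prediction above)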
43 changes: 40 additions & 3 deletions model.py
@@ -5,6 +5,20 @@
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import SAGEConv


class LinkPredictor(nn.Module):
def decode(self, embeddings, edge_label_index):
# Dot-product similarity between the embeddings of the two endpoints of each candidate edge
# (the positive training edges plus the negative examples returned by the sampling function)
simil = embeddings[edge_label_index[0]] * embeddings[edge_label_index[1]]
return simil.sum(dim=-1)

def decode_all(self, embedding):
# Compute the similarity as ZZ^T
prob_adj = embedding @ embedding.t()
# Return, as an edge_index of shape [2, num_predicted], every node pair whose score is
# positive, i.e. whose predicted probability after the sigmoid is above 0.5
return (prob_adj > 0).nonzero(as_tuple=False).t()
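A rough worked example of the dot-product decoder above, with hypothetical 2-dimensional embeddings for three nodes (purely illustrative):

import torch

z = torch.tensor([[1., 0.],
                  [1., 1.],
                  [0., 1.]])                 # toy node embeddings
edge_label_index = torch.tensor([[0, 1],
                                 [1, 2]])    # candidate edges (0,1) and (1,2)
scores = (z[edge_label_index[0]] * z[edge_label_index[1]]).sum(dim=-1)
# scores = tensor([1., 1.]): the higher the dot product, the more likely the edge
prob_adj = z @ z.t()                         # decode_all: score every node pair at once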


########## QUESTION: SHOULD DROPOUT BE ADDED?
########## https://dl.acm.org/doi/pdf/10.1145/3487553.3524725
class MLP(nn.Module):
@@ -23,10 +37,11 @@ def __init__(self, input_size: int, num_classes: int, hidden_sizes: list[int], d
def forward(self, x):
return self.MLP(x)


# Check the parameters of GCN to find the best configuration.
# https://arxiv.org/abs/1609.02907
class GCN(nn.Module):
def __init__(self, input_size: int, hidden_channels: int, embedding_size: int, dropout: float = 0.5):
def __init__(self, input_size: int, embedding_size: int, hidden_channels: int = 16, dropout: float = 0.5):
super().__init__()
# Should the parameter improved=True be set?
# cached=True should be used for transductive learning, which is the case for our link prediction.
@@ -37,11 +52,31 @@ def __init__(self, input_size: int, hidden_channels: int, embedding_size: int, d

def forward(self, x, edge_index):
x = self.conv1(x, edge_index)
x = nn.ReLU(x)
x = nn.ELU()(x)
x = self.dropout(x)
x = self.conv2(x, edge_index)
return x


# TO ERASE, only to test the link prediction
class GCN_Predictor(LinkPredictor):
def __init__(self, input_size: int, embedding_size: int, hidden_channels: int = 16, dropout: float = 0.5):
super().__init__()
# Should the parameter improved=True be set?
# cached=True should be used for transductive learning, which is the case for our link prediction;
# we still need to check whether it can be changed when switching task.
self.conv1 = GCNConv(input_size, hidden_channels, improved=True)
self.conv2 = GCNConv(hidden_channels, embedding_size, improved=True)
self.dropout = nn.Dropout(p=dropout)

def forward(self, x, edge_index):
x = self.conv1(x, edge_index)
x = nn.ELU()(x)
x = self.dropout(x)
x = self.conv2(x, edge_index)
return x


# https://arxiv.org/abs/2105.14491
class GAT(nn.Module):
def __init__(self, input_size: int, embedding_size: int, hidden_channels: int = 16, heads:int = 8):
@@ -62,6 +97,7 @@ def forward(self, x, edge_index):
x = self.conv2(x, edge_index)
return x


# https://arxiv.org/pdf/1706.02216v4.pdf
class Graph_SAGE(nn.Module):
def __init__(self, input_size: int, embedding_dim: int, hidden_size: int = 512, dropout: float = 0.5):
@@ -74,14 +110,14 @@ def __init__(self, input_size: int, embedding_dim: int, hidden_size: int = 512,
self.sage2 = SAGEConv(hidden_size, embedding_dim, aggr='max')
self.dropout = nn.Dropout(p=dropout)


def forward(self, x, edge_index):
x = self.sage1(x, edge_index)
x = nn.ELU()(x)
x = self.dropout(x)
x = self.sage2(x, edge_index)
return x


class SAGE_MLP(nn.Module):
def __init__(self, sage, mlp):
super().__init__()
@@ -94,6 +130,7 @@ def forward(self, x):
x = self.mlp(x)
return x


class GAT_MLP(nn.Module):
"""
Please note that the class returns logits. They should be processed according to the graph task (e.g. softmax for
