Minor changes to GAT architecture and results output
Manuel383 committed Jan 21, 2024
1 parent d889598 commit 57ea3e7
Showing 4 changed files with 55 additions and 31 deletions.
2 changes: 1 addition & 1 deletion engine.py
@@ -265,7 +265,7 @@ def train_link_prediction(model, train_ds, val_ds, loss_fn: torch.nn.Module,
epoch + writer_info["starting_epoch"])

# should become "return results"
return val_loss
return results #val_loss


def train_step_link_pred_batch_gen(model: torch.nn.Module, batch, loss_fn: torch.nn.Module,
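
With this change train_link_prediction returns the full results dictionary (as the comment above anticipated) instead of only the validation loss, matching how main.py reads results_linkpred["val_acc"][-1] further down. A self-contained sketch of the shape the caller relies on; only the "val_acc" key is confirmed by main.py, the other keys and every value are illustrative:

# Sketch only: shows the dict-of-per-epoch-lists shape that main.py indexes.
# Keys other than "val_acc" and all values here are assumptions, not repo output.
results_linkpred = {
    "train_loss": [0.69, 0.52, 0.41],
    "val_loss": [0.70, 0.55, 0.48],
    "val_acc": [0.55, 0.66, 0.71],
}

print("Link prediction val accuracy: ", results_linkpred["val_acc"][-1])  # -> 0.71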
77 changes: 50 additions & 27 deletions main.py
@@ -24,9 +24,9 @@

#************************************** COMMANDS ************************************

use_grid_search = True #False
dataset_name = "pubmed" # cora - citeseer - pubmed
nets = ["SAGE"] # GCN - GAT - SAGE
use_grid_search = False #False
dataset_name = "cora" # cora - citeseer - pubmed
nets = ["GAT"] # GCN - GAT - SAGE

# ************************************ PARAMETERS ************************************

@@ -155,6 +155,8 @@
batch_size = None

# ************************************ CLASSIFICATION 1 ************************************

print("************************* TRAINING CLASSIFICATION 1 *************************")

input_size = classification_dataset.num_features
hidden_channels = params["hidden_channels"]
@@ -165,7 +167,8 @@
network = model.GCN(input_size=input_size, embedding_size=output_size, hidden_channels=hidden_channels, dropout=dropout)
elif net == "GAT":
heads = params["heads"]
network = model.GAT(input_size=input_size, embedding_size=output_size, hidden_channels=hidden_channels, heads=heads, dropout=dropout)
heads_out = params["heads_out"]
network = model.GAT(input_size=input_size, embedding_size=output_size, hidden_channels=hidden_channels, heads=heads, heads_out=heads_out, dropout=dropout)
else:
network = model.Graph_SAGE(input_size=input_size, embedding_size=output_size, hidden_channels=hidden_channels, dropout=dropout)

@@ -200,17 +203,22 @@
optimizer, epochs, writer, writer_info, device, batch_generation,
num_batch_neighbors, batch_size, lr_schedule)

print()
print("CLASSIFICATION 1 RESULTS")
for k, v in results_class1.items():
print(k + ":" + str(v[-1]))
print("****************************************************** \n")
# print()
# print("CLASSIFICATION 1 RESULTS")
# for k, v in results_class1.items():
# print(k + ":" + str(v[-1]))
# print("****************************************************** \n")

# _, acc1 = engine.eval_classifier(model_classification1, criterion, classification_dataset.data,False,batch_generation,device,num_batch_neighbors,batch_size)
# print(acc1)

_, acc1 = engine.eval_classifier(model_classification1, criterion, classification_dataset.data,False,batch_generation,device,num_batch_neighbors,batch_size)
print(acc1)
print()
print("*****************************************************************************\n")

# ************************************ LINK PREDICTION ************************************

print("************************* TRAINING LINK PREDICTION *************************")

input_size_mlp = params["embedding_size"]
output_size_mlp = params["link_pred_out_size_mlp"] # Not tied to the number of classes ## so what do we put here?
hidden_sizes_mlp = params["hidden_sizes_mlp_link_pred"]
@@ -248,16 +256,17 @@
'second_tr_e': epochs, 'starting_epoch': epochs_cls + epochs}
optimizer = torch.optim.Adam(model_linkpred.parameters(), lr=lr_schedule.get_lr()[0], weight_decay=weight_decay)
epochs = epochs_linkpred - epochs
engine.train_link_prediction(model_linkpred, train_ds, val_ds, criterion, optimizer, epochs, writer,
results_linkpred = engine.train_link_prediction(model_linkpred, train_ds, val_ds, criterion, optimizer, epochs, writer,
writer_info,
device, batch_generation, num_batch_neighbors, batch_size, lr_schedule)

print()
print("LINK PREDICTION TRAINING DONE")
print("****************************************************** \n")
print("*****************************************************************************\n")

# ************************************ CLASSIFICATION 2 ************************************

print("************************* TRAINING CLASSIFICATION 2 *************************")

input_size_mlp = params["embedding_size"]
output_size_mlp = classification_dataset.num_classes
hidden_sizes_mlp = params["hidden_sizes_mlp_class2"]
@@ -290,11 +299,11 @@
optimizer, epochs, writer, writer_info, device, batch_generation,
num_batch_neighbors, batch_size, lr_schedule)

print()
print("CLASSIFICATION 2a RESULTS")
for k, v in results_class2a.items():
print(k + ":" + str(v[-1]))
print("****************************************************** \n")
# print()
# print("CLASSIFICATION 2a RESULTS")
# for k, v in results_class2a.items():
# print(k + ":" + str(v[-1]))
# print("****************************************************** \n")

results_class2b = {}
if net_freezed_classification2 < 1.0:
@@ -306,14 +315,18 @@
results_class2b = engine.train_classification(model_classification2, classification_dataset.data, classification_dataset.data, criterion,
optimizer, epochs, writer, writer_info, device, batch_generation,
num_batch_neighbors, batch_size, lr_schedule)
print()
print("\nCLASSIFICATION 2b RESULTS")
for k,v in results_class2b.items():
print(k + ":" + str(v[-1]))
print("****************************************************** \n")

_, acc2 = engine.eval_classifier(model_classification2, criterion, classification_dataset.data,False,batch_generation,device,num_batch_neighbors,batch_size)
print("test acc with LinkPrediction:", acc2)
# print()
# print("\nCLASSIFICATION 2b RESULTS")
# for k,v in results_class2b.items():
# print(k + ":" + str(v[-1]))
# print("****************************************************** \n")


print()
print("*****************************************************************************")

# _, acc2 = engine.eval_classifier(model_classification2, criterion, classification_dataset.data,False,batch_generation,device,num_batch_neighbors,batch_size)
# print("test acc with LinkPrediction:", acc2)
# ************************************ SAVING RESULTS ************************************

# params_string = "" # part of the key that explicit the parameters used
@@ -364,6 +377,16 @@
with open(results_file, "w") as f:
json.dump(results_dict, f, indent = 4)

print("\nClassification 1 val accuracy: ", results_class1["val_acc"][-1])
print("Link prediction val accuracy: ", results_linkpred["val_acc"][-1])
print("Classification 2a val accuracy: ", results_class2b["val_acc"][-1])
print("Classification 2b val accuracy: ", results_class2b["val_acc"][-1])

_, test_acc = engine.eval_classifier(model_classification2, criterion, classification_dataset.data,False,batch_generation,device,num_batch_neighbors,batch_size)
print("\n Test accuracy: ", test_acc)

print()
print("*****************************************************************************")

if use_grid_search:
num_best_runs = 20
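
The new summary block indexes results_class2b["val_acc"] unconditionally, but results_class2b is initialized to {} and only filled when net_freezed_classification2 < 1.0, so that print can raise a KeyError whenever the second fine-tuning phase is skipped. A defensive sketch (not part of this commit) that degrades gracefully:

# Sketch, not committed code: guard the lookup instead of indexing directly.
results_class2a = {"val_acc": [0.74, 0.78]}   # illustrative history
results_class2b = {}                          # e.g. net_freezed_classification2 >= 1.0

def last_val_acc(results):
    """Final-epoch val_acc, or None if that training phase did not run."""
    history = results.get("val_acc", [])
    return history[-1] if history else None

print("Classification 2a val accuracy: ", last_val_acc(results_class2a))
print("Classification 2b val accuracy: ", last_val_acc(results_class2b))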
3 changes: 2 additions & 1 deletion main_no_MLP.py
@@ -165,8 +165,9 @@
dropout=dropout)
elif net == "GAT":
heads = params["heads"]
heads_out = params["heads_out"]
network = model.GAT(input_size=input_size, embedding_size=output_size, hidden_channels=hidden_channels,
heads=heads, dropout=dropout)
heads=heads, heads_out=heads_out, dropout=dropout)
else:
network = model.Graph_SAGE(input_size=input_size, embedding_size=output_size,
hidden_channels=hidden_channels, dropout=dropout)
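
Both main.py and main_no_MLP.py now read params["heads_out"] in the GAT branch, so every parameter set (including any grid-search space) has to define that key or the lookup raises a KeyError. A hypothetical entry; "heads" and "heads_out" are the keys this diff reads, the remaining keys and all values are illustrative:

# Hypothetical GAT parameter set (values are illustrative, not the repo's defaults).
params = {
    "hidden_channels": 16,
    "embedding_size": 64,
    "heads": 8,
    "heads_out": 1,   # new key: attention heads of the second GATv2Conv layer
    "dropout": 0.6,
}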
4 changes: 2 additions & 2 deletions model.py
@@ -70,13 +70,13 @@ def forward(self, x, edge_index):

# https://arxiv.org/abs/2105.14491
class GAT(LinkPredictor):
def __init__(self, input_size: int, embedding_size: int, hidden_channels: int = 16, heads: int = 8, dropout: float = 0.6):
def __init__(self, input_size: int, embedding_size: int, hidden_channels: int = 16, heads: int = 8, heads_out: int = 1, dropout: float = 0.6):
super().__init__()
# 256 channels seemed the best in the paper (but it depends on the complexity of the dataset)
# LR = 0.001/0.01
self.conv1 = GATv2Conv(input_size, hidden_channels, heads=heads)
# Maybe concat should be set to False for the last layer so that the outputs will be averaged.
self.conv2 = GATv2Conv(hidden_channels * heads, embedding_size, heads=1)
self.conv2 = GATv2Conv(hidden_channels * heads, embedding_size, heads=heads_out)
self.dropout = dropout

def forward(self, x, edge_index):
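
The second GATv2Conv layer now takes its head count from heads_out instead of the hard-coded 1. Because GATv2Conv concatenates head outputs by default, the produced embedding has embedding_size * heads_out columns: heads_out=1 keeps the previous width, while larger values widen it unless concat=False is passed (the averaging alternative the comment above mentions). A minimal shape check with illustrative sizes, assuming torch and torch_geometric are installed:

import torch
from torch_geometric.nn import GATv2Conv

# Illustrative sizes, not the repository's parameters.
input_size, hidden_channels, embedding_size = 1433, 16, 64
heads, heads_out = 8, 1

conv1 = GATv2Conv(input_size, hidden_channels, heads=heads)                 # -> [N, 16 * 8]
conv2 = GATv2Conv(hidden_channels * heads, embedding_size, heads=heads_out)

x = torch.randn(4, input_size)                            # 4 toy nodes
edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])   # toy edges, shape [2, E]
out = conv2(conv1(x, edge_index), edge_index)

print(out.shape)  # torch.Size([4, 64]); with heads_out > 1 it would be [4, 64 * heads_out]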
