Commit

Clipping link prediction
AlbertoFormaggio1 committed Jan 25, 2024
1 parent 8cce7e0 commit fe07534
Showing 4 changed files with 132 additions and 84 deletions.
25 changes: 14 additions & 11 deletions engine.py
@@ -128,15 +128,16 @@ def train_step_classification(model: torch.nn.Module, ds, loss_fn: torch.nn.Modu
opt.zero_grad() # Reset the gradient
out = model(ds.x, ds.edge_index) # Compute the response of the model
loss = loss_fn(out[ds.train_mask], ds.y[ds.train_mask]) # Compute the loss based on training nodes
loss.backward() # Propagate the gradient
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
opt.step() # Update the weights

train_loss = loss.item() # Get the loss
# Compute the classification accuracy
train_cls = out.argmax(dim=-1)
train_acc = torch.sum(train_cls[ds.train_mask] == ds.y[ds.train_mask])

loss.backward() # Propagate the gradient
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
opt.step() # Update the weights

return train_loss, train_acc.item()
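
The change above moves the update sequence so that clip_grad_norm_ runs between backward() and step(), which is the only place clipping has any effect. For reference, a minimal self-contained sketch of that pattern; the model, data, and max_norm value here are placeholders, not the repo's code:

```python
import torch

# Minimal sketch of the pattern above: backward -> clip -> step.
model = torch.nn.Linear(16, 4)
opt = torch.optim.Adam(model.parameters(), lr=0.01)
loss_fn = torch.nn.CrossEntropyLoss()
x, y = torch.randn(8, 16), torch.randint(0, 4, (8,))  # dummy features and labels

opt.zero_grad()                   # reset the gradient
loss = loss_fn(model(x), y)       # forward pass + loss
loss.backward()                   # propagate the gradient
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # clip before updating
opt.step()                        # update the weights
```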


@@ -264,8 +265,7 @@ def train_link_prediction(model, train_ds, val_ds, loss_fn: torch.nn.Module,
writer.add_scalar(f'{writer_info["dataset_name"]}/{writer_info["model_name"]}/{k}', results[k][-1],
epoch + writer_info["starting_epoch"])

# should become "return results"
return results #val_loss
return results
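
The add_scalar call in this hunk assumes a torch.utils.tensorboard.SummaryWriter plus a writer_info dict carrying dataset_name, model_name, and starting_epoch. A hedged sketch of that logging setup, with placeholder values:

```python
from torch.utils.tensorboard import SummaryWriter

# Sketch of the logging setup assumed by the add_scalar call above.
# The writer_info keys mirror the f-string in the diff; values are placeholders.
writer = SummaryWriter(log_dir="runs/example")
writer_info = {"dataset_name": "cora", "model_name": "SAGE", "starting_epoch": 0}
results = {"train_loss": [0.9, 0.7], "val_loss": [1.0, 0.8]}

for epoch in range(len(results["train_loss"])):
    for k in results:
        # the training loop above logs results[k][-1] each epoch as it appends
        writer.add_scalar(
            f'{writer_info["dataset_name"]}/{writer_info["model_name"]}/{k}',
            results[k][epoch],
            epoch + writer_info["starting_epoch"],
        )
writer.close()
```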


def train_step_link_pred_batch_gen(model: torch.nn.Module, batch, loss_fn: torch.nn.Module,
@@ -422,6 +422,8 @@ def eval_classifier(model: torch.nn.Module, loss_fn: torch.nn.Module, ds, is_val
model = model.to(device)
model.eval()

# In the PyG tutorial for GraphSAGE, the authors don't generate batches
"""
if batch_generation:
if is_validation:
# [25, 10] is the neighbors to keep at each hop defined in the original paper
@@ -436,17 +438,18 @@
mask = ds.val_mask
else:
mask = ds.test_mask
"""

validation_batches = [ds]
if is_validation:
mask = ds.val_mask
else:
mask = ds.test_mask

eval_loss, eval_acc = .0, .0
batch_num = 0
for batch in validation_batches:
batch = batch.to(device)
# Count the number of nodes in the current batch
if batch_generation:
if is_validation:
mask = batch.val_mask
else:
mask = batch.test_mask

# Compute the response of the model
out = model(batch.x, batch.edge_index)
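The commented-out batch_generation branch in eval_classifier refers to PyG's neighbor sampling, with [25, 10] as the per-hop fan-out from the original GraphSAGE paper. A sketch of what batched evaluation could look like with torch_geometric.loader.NeighborLoader; the dataset, model, and accuracy loop are illustrative assumptions, not the repo's code:

```python
import torch
from torch_geometric.datasets import Planetoid
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import GraphSAGE

# Illustrative sketch only. [25, 10] is the per-hop fan-out noted in the comment.
dataset = Planetoid(root="data", name="Cora")
data = dataset[0]
model = GraphSAGE(in_channels=dataset.num_features, hidden_channels=64,
                  num_layers=2, out_channels=dataset.num_classes)

loader = NeighborLoader(data, num_neighbors=[25, 10], batch_size=64,
                        input_nodes=data.val_mask)  # seed batches on validation nodes

model.eval()
correct, total = 0, 0
with torch.no_grad():
    for batch in loader:
        out = model(batch.x, batch.edge_index)
        # the first batch.batch_size rows correspond to the seed nodes
        pred = out[:batch.batch_size].argmax(dim=-1)
        correct += int((pred == batch.y[:batch.batch_size]).sum())
        total += batch.batch_size
print(f"val accuracy: {correct / total:.3f}")
```
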
109 changes: 80 additions & 29 deletions parameters.py
@@ -5,7 +5,7 @@
parameters_grid_GCN = {
"x": [1,2,3,4,5],
"embedding_size": [32], #[32, 48, 64],
"hidden_channels": [16], #[16, 32],
"hidden_channels": [16], #[16, 32],
"dropout": [0.5], # [0.6, 0.7], #[0.3, 0.6], # 0 pag 6
"hidden_sizes_mlp_class1": [[10]], #[[10], [15]],
"hidden_sizes_mlp_link_pred": [[10]], #[[10], [15]],
@@ -46,7 +46,7 @@
"hidden_channels": [8],
"heads": [8],
"heads_out": [1],
"dropout": [0.6],
"dropout": [0.6],
"hidden_sizes_mlp_class1": [[16]], # X
"hidden_sizes_mlp_link_pred": [[16]], # X
"hidden_sizes_mlp_class2": [[16]], # X
@@ -106,45 +106,96 @@

# SAGE
parameters_grid_SAGE = {
"embedding_size": [32],
"hidden_channels": [64],
"dropout": [0.2],
"hidden_sizes_mlp_class1": [[20]],
"hidden_sizes_mlp_link_pred": [[5]],
"hidden_sizes_mlp_class2": [[5]],
"dropout_mlp_class1": [0],
"dropout_mlp_link_pred": [0],
"dropout_mlp_class2": [0],
"num_batch_neighbors": [[10, 4], [15, 6]], # in più rispetto a GCN e GAT
"link_pred_out_size_mlp" : [16],
"epochs_classification1": [50, 100],
"epochs_linkpred": [25, 50],
"net_freezed_linkpred": [0.4, 0.6],
"epochs_classification2": [25, 50],
"net_freezed_classification2": [0.4, 0.6],
"batch_size": [32], # in più rispetto a GCN e GAT
"hidden_channels": [32],
"dropout": [0.7],
"num_batch_neighbors": [[5,2]],
"epochs_linkpred": [8],
"epochs_classification2": [200],
"batch_size": [32],
}

parameters_SAGE = {
"embedding_size": 32,
"hidden_channels": 64,
"dropout": 0.2,
"hidden_sizes_mlp_class1": [20],
"hidden_sizes_mlp_link_pred": [5],
"hidden_sizes_mlp_class2": [5],
"dropout_mlp_class1": 0,
"dropout_mlp_link_pred": 0,
"dropout_mlp_class2": 0,
"num_batch_neighbors": [10, 4],
"dropout": 0.5,
"hidden_sizes_mlp_class1": [16],
"hidden_sizes_mlp_link_pred": [16],
"hidden_sizes_mlp_class2": [16],
"dropout_mlp_class1": 0.4,
"dropout_mlp_link_pred": 0.4,
"dropout_mlp_class2": 0.4,
"num_batch_neighbors": [5, 10],
"link_pred_out_size_mlp" : 16,
"epochs_classification1": 100,
"epochs_classification1": 50,
"epochs_linkpred": 50,
"net_freezed_linkpred": 0.6,
"epochs_classification2": 50,
"net_freezed_classification2": 0.6,
"batch_size": 32
"batch_size": 8
}

parameters_grid_SAGE_pubmed = {
"embedding_size": [32], #32
"hidden_channels": [256], #512
"dropout": [0.7], #0.2
"hidden_sizes_mlp_class1": [[32]],
"hidden_sizes_mlp_link_pred": [[32]],
"hidden_sizes_mlp_class2": [[32]],
"dropout_mlp_class1": [0.5],
"dropout_mlp_link_pred": [0.5],
"dropout_mlp_class2": [0.1],
"num_batch_neighbors": [[5,2]],
"link_pred_out_size_mlp": [256],
"epochs_classification1": [50],
"epochs_linkpred": [100],
"net_freezed_linkpred": [0.4],
"epochs_classification2": [90],
"net_freezed_classification2": [0.4],
"batch_size": [1024], # Done
}

parameters_grid_SAGE_cora = {
"embedding_size": [16], #32
"hidden_channels": [16], #512
"dropout": [0.5], #0.2
"hidden_sizes_mlp_class1": [[32]],
"hidden_sizes_mlp_link_pred": [[32]],
"hidden_sizes_mlp_class2": [[32]],
"dropout_mlp_class1": [0],
"dropout_mlp_link_pred": [0.2],
"dropout_mlp_class2": [0.1],
"num_batch_neighbors": [[5,2]],
"link_pred_out_size_mlp": [16],
"epochs_classification1": [50],
"epochs_linkpred": [50],
"net_freezed_linkpred": [0.4],
"epochs_classification2": [85],
"net_freezed_classification2": [0.45],
"batch_size": [1024], # Done
}

parameters_grid_SAGE_citeseer = {
"embedding_size": [32], #32
"hidden_channels": [32], #512
"dropout": [0.7], #0.2
"hidden_sizes_mlp_class1": [[32]],
"hidden_sizes_mlp_link_pred": [[32]],
"hidden_sizes_mlp_class2": [[32]],
"dropout_mlp_class1": [0.1],
"dropout_mlp_link_pred": [0.2],
"dropout_mlp_class2": [0],
"num_batch_neighbors": [[5,2]],
"link_pred_out_size_mlp": [256],
"epochs_classification1": [50],
"epochs_linkpred": [100],
"net_freezed_linkpred": [0.4],
"epochs_classification2": [150],
"net_freezed_classification2": [0.15],
"batch_size": [1024], # Done
}
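
Each *_grid dict maps a parameter name to a list of candidate values, and grid search enumerates their Cartesian product. The repo's utils.generate_combinations is not shown in this diff; a minimal sketch of what such an expansion might do (an assumption about its behavior, not its actual code):

```python
from itertools import product

def generate_combinations(grid: dict) -> list:
    """Expand {param: [values, ...]} into a list of {param: value} dicts."""
    keys = list(grid)
    return [dict(zip(keys, vals)) for vals in product(*grid.values())]

combos = generate_combinations({"dropout": [0.5, 0.7], "batch_size": [32, 1024]})
print(len(combos))  # 4 combinations: the Cartesian product of the value lists
```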




# epochs_classification1 = [50, 100]
# epochs_linkpred = [25, 50]
20 changes: 12 additions & 8 deletions std_classification/main_std_class.py
@@ -25,8 +25,8 @@
#************************************** COMMANDS ************************************

use_grid_search = False #False
dataset_name = "pubmed" # cora - citeseer - pubmed
nets = ["GAT"] # GCN - GAT - SAGE
dataset_name = "citeseer" # cora - citeseer - pubmed
nets = ["SAGE"] # GCN - GAT - SAGE

# ************************************ PARAMETERS ************************************

@@ -39,8 +39,8 @@
# parameters_GAT = parameters.parameters_GAT

# SAGE
parameters_grid_SAGE = parameters.parameters_grid_SAGE
parameters_SAGE = parameters.parameters_SAGE
#parameters_grid_SAGE = parameters.parameters_grid_SAGE
#parameters_SAGE = parameters.parameters_SAGE

# Others
# lr = parameters.lr
@@ -103,10 +103,14 @@
elif dataset_name == "pubmed":
param_combinations = [parameters.parameters_GAT_pubmed]
else:
if use_grid_search:
param_combinations = utils.generate_combinations(parameters_grid_SAGE)
else:
param_combinations = [parameters_SAGE]
# For SAGE, running this gives results different from those reported in the paper
# since I removed label smoothing and cosineAnnealing
if dataset_name == "cora":
param_combinations = [parameters.parameters_SAGE_cora]
elif dataset_name == "citeseer":
param_combinations = [parameters.parameters_SAGE_citeseer]
elif dataset_name == "pubmed":
param_combinations = [parameters.parameters_SAGE_pubmed]

i = 1
for params in param_combinations:
62 changes: 26 additions & 36 deletions std_classification/parameters_std_class.py
@@ -71,42 +71,32 @@
weight_decay = 0.0005

# SAGE
parameters_grid_SAGE = {
"embedding_size": [32],
"hidden_channels": [64],
"dropout": [0.2],
"hidden_sizes_mlp_class1": [[20]],
"hidden_sizes_mlp_link_pred": [[5]],
"hidden_sizes_mlp_class2": [[5]],
"dropout_mlp_class1": [0],
"dropout_mlp_link_pred": [0],
"dropout_mlp_class2": [0],
"num_batch_neighbors": [[10, 4], [15, 6]], # in più rispetto a GCN e GAT
"link_pred_out_size_mlp" : [16],
"epochs_classification1": [50, 100],
"epochs_linkpred": [25, 50],
"net_freezed_linkpred": [0.4, 0.6],
"epochs_classification2": [25, 50],
"net_freezed_classification2": [0.4, 0.6],
"batch_size": [32], # in più rispetto a GCN e GAT
parameters_SAGE_cora = {
"hidden_channels": 32, #512
"dropout": 0.7, #0.2
"num_batch_neighbors": [5,2],
"epochs": 150,
"batch_size": 4096, # Done
"lr":0.01,
"weight_decay" : 0.0005,
}

parameters_SAGE = {
"embedding_size": 32,
"hidden_channels": 64,
"dropout": 0.2,
"hidden_sizes_mlp_class1": [20],
"hidden_sizes_mlp_link_pred": [5],
"hidden_sizes_mlp_class2": [5],
"dropout_mlp_class1": 0,
"dropout_mlp_link_pred": 0,
"dropout_mlp_class2": 0,
"num_batch_neighbors": [10, 4],
"link_pred_out_size_mlp" : 16,
"epochs_classification1": 100,
"epochs_linkpred": 50,
"net_freezed_linkpred": 0.6,
"epochs_classification2": 50,
"net_freezed_classification2": 0.6,
"batch_size": 32
parameters_SAGE_pubmed = {
"hidden_channels": 16, #512
"dropout": 0.7, #0.2
"num_batch_neighbors": [5,2],
"epochs": 150,
"batch_size": 2048, # Done
"lr":0.01,
"weight_decay" : 0.0005,
}

parameters_SAGE_citeseer = {
"hidden_channels": 32,
"dropout": 0.7,
"num_batch_neighbors": [5,2],
"epochs": 100,
"batch_size": 32,
"lr":0.01,
"weight_decay" : 0.0005,
}
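
Unlike the other parameter dicts, these per-dataset ones also carry the optimizer settings (lr, weight_decay). A small sketch of how they would feed an Adam optimizer; the model here is a placeholder, not the repo's SAGE network:

```python
import torch

# Values as in parameters_SAGE_citeseer; the model is a stand-in.
params = {"lr": 0.01, "weight_decay": 0.0005}
model = torch.nn.Linear(8, 2)
opt = torch.optim.Adam(model.parameters(),
                       lr=params["lr"],
                       weight_decay=params["weight_decay"])
```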
