Merge pull request divelab#239 from hongyiling/dig-stable
Fix Graphair bug
hongyiling authored Feb 4, 2024
2 parents 5a72709 + 244843c commit 21476b0
Showing 5 changed files with 26 additions and 33 deletions.
2 changes: 1 addition & 1 deletion dig/fairgraph/dataset/fairgraph_dataset.py
@@ -166,7 +166,7 @@ def __init__(self,
self.label_number = 100
self.sens_number = 500
self.seed = 20
- self.test_idx=True
+ self.test_idx=False
self.data_path = data_path
self.process()

45 changes: 19 additions & 26 deletions dig/fairgraph/method/Graphair/graphair.py
@@ -53,7 +53,7 @@ class graphair(nn.Module):
:type num_proj_hidden: int,optional
'''
- def __init__(self, aug_model, f_encoder, sens_model, classifier_model, lr = 1e-4, weight_decay = 1e-5, alpha = 10.0, beta = 0.1, gamma = 0.5, lam = 0.5, dataset = 'POKEC', num_hidden = 64, num_proj_hidden = 64):
+ def __init__(self, aug_model, f_encoder, sens_model, classifier_model, lr = 1e-4, weight_decay = 1e-5, alpha = 0.1, beta = 1.0, gamma = 10.0, lam = 1.0, dataset = 'POKEC', num_hidden = 64, num_proj_hidden = 64):
super(graphair, self).__init__()
self.aug_model = aug_model
self.f_encoder = f_encoder
@@ -69,13 +69,13 @@ def __init__(self, aug_model, f_encoder, sens_model, classifier_model, lr = 1e-4
self.criterion_cont= nn.CrossEntropyLoss()
self.criterion_recons = nn.MSELoss()

- self.optimizer_s = torch.optim.Adam(self.sens_model.parameters(), lr = 1e-3, weight_decay = 1e-5)
+ self.optimizer_s = torch.optim.Adam(self.sens_model.parameters(), lr = 1e-4, weight_decay = 1e-5)

FG_params = [{'params': self.aug_model.parameters(), 'lr': 1e-4} , {'params':self.f_encoder.parameters()}]
- self.optimizer = torch.optim.Adam(FG_params, lr = 1e-3, weight_decay = weight_decay)
+ self.optimizer = torch.optim.Adam(FG_params, lr = 1e-4, weight_decay = weight_decay)

- self.optimizer_aug = torch.optim.Adam(self.aug_model.parameters(), lr = 1e-3, weight_decay = weight_decay)
- self.optimizer_enc = torch.optim.Adam(self.f_encoder.parameters(), lr = 1e-3, weight_decay = weight_decay)
+ self.optimizer_aug = torch.optim.Adam(self.aug_model.parameters(), lr = 1e-4, weight_decay = weight_decay)
+ self.optimizer_enc = torch.optim.Adam(self.f_encoder.parameters(), lr = 1e-4, weight_decay = weight_decay)

self.fc1 = torch.nn.Linear(num_hidden, num_proj_hidden)
self.fc2 = torch.nn.Linear(num_proj_hidden, num_hidden)
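
For quick reference, the hyperparameter defaults before and after this commit, as a sketch of the signature only (values taken from the diff above; the full class body is elided):

```python
# Sketch of the updated graphair constructor defaults (old values in comments).
def __init__(self, aug_model, f_encoder, sens_model, classifier_model,
             lr=1e-4, weight_decay=1e-5,
             alpha=0.1,   # was 10.0
             beta=1.0,    # was 0.1
             gamma=10.0,  # was 0.5
             lam=1.0,     # was 0.5
             dataset='POKEC', num_hidden=64, num_proj_hidden=64):
    ...
```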
@@ -150,7 +150,7 @@ def fit_whole(self, epochs, adj, x,sens,idx_sens,warmup=None, adv_epoches=1):
edge_loss = norm_w * F.binary_cross_entropy_with_logits(adj_logits, adj_orig.cuda())

feat_loss = self.criterion_recons(x_aug, x)
- recons_loss = edge_loss + self.beta * feat_loss
+ recons_loss = edge_loss + self.lam * feat_loss

self.optimizer_aug.zero_grad()
with torch.autograd.set_detect_anomaly(True):
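
The change above is the core of this hunk: the feature-reconstruction term is now scaled by self.lam rather than self.beta. A self-contained toy of the corrected warmup objective (tensor shapes and values are illustrative only):

```python
# Toy reproduction of the corrected reconstruction objective.
import torch
import torch.nn.functional as F

adj_logits = torch.randn(4, 4)        # predicted edge logits (toy)
adj_orig = torch.eye(4)               # "original" adjacency (toy)
x_aug, x = torch.randn(4, 8), torch.randn(4, 8)

norm_w, lam = 1.0, 1.0                # lam defaults to 1.0 after this commit
edge_loss = norm_w * F.binary_cross_entropy_with_logits(adj_logits, adj_orig)
feat_loss = F.mse_loss(x_aug, x)      # criterion_recons is nn.MSELoss()
recons_loss = edge_loss + lam * feat_loss   # lam, not beta, scales features
```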
@@ -213,13 +213,9 @@ def fit_whole(self, epochs, adj, x,sens,idx_sens,warmup=None, adv_epoches=1):

## Due to license issues, we re-implemented the batch training code using pyg.loader.GraphSAINTRandomWalkSampler instead of directly using the GraphSAINT code from the original implementation.
def fit_batch(self, epochs, adj, x,sens,idx_sens,warmup=None, adv_epoches=1):

assert sp.issparse(adj)
- if not isinstance(adj, sp.coo_matrix):
-     adj = sp.coo_matrix(adj)
- adj.setdiag(1)
- adj_orig = sp.csr_matrix(adj)
- norm_w = adj_orig.shape[0]**2 / float((adj_orig.shape[0]**2 - adj_orig.sum()) * 2)
+ norm_w = adj.shape[0] ** 2 / float((adj.shape[0]**2 - adj.sum()) * 2)

idx_sens = idx_sens.cpu().numpy()
sens_mask = np.zeros((x.shape[0],1))
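
The rewritten norm_w keeps the same class re-weighting as before but computes it directly on adj: with N nodes and E nonzero entries, N**2 / (2 * (N**2 - E)) compensates for the imbalance between absent and present edges in the dense BCE target. A toy check, assuming a scipy sparse adjacency:

```python
# Toy check of the BCE weight computed above.
import scipy.sparse as sp

adj = sp.identity(4, format="csr")         # 4 nodes, 4 nonzeros (toy graph)
N, E = adj.shape[0], adj.sum()
norm_w = N ** 2 / float((N ** 2 - E) * 2)  # 16 / (2 * 12) = 0.666...
```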
@@ -228,21 +224,21 @@ def fit_batch(self, epochs, adj, x,sens,idx_sens,warmup=None, adv_epoches=1):

edge_index, _ = from_scipy_sparse_matrix(adj)

- miniBatchLoader = GraphSAINTRandomWalkSampler(Data(x=x, edge_index=edge_index, sens = sens, sens_mask = sens_mask),
+ miniBatchLoader = GraphSAINTRandomWalkSampler(Data(x=x, edge_index=edge_index, sens = sens, sens_mask = sens_mask, deg = torch.tensor(np.array(adj.sum(1)).flatten())),
batch_size = 1000,
walk_length = 3,
sample_coverage = 500,
num_workers = 0,
save_dir = "./checkpoint/{}".format(self.dataset))
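
Passing precomputed full-graph degrees through the Data object matters because each GraphSAINT subgraph sees only a sample of a node's edges. With sample_coverage > 0, PyG's GraphSAINT samplers also attach node_norm and edge_norm to every batch, which the code below consumes (data.edge_norm as sparse-adjacency values, data.node_norm as per-node loss weights). A hedged sketch of inspecting them, assuming a loader built like miniBatchLoader:

```python
# Sketch: the normalization attributes GraphSAINT attaches to each batch.
for data in miniBatchLoader:
    # node_norm re-weights per-node losses; edge_norm re-weights sampled
    # edges so subgraph estimates stay unbiased w.r.t. the full graph.
    print(data.x.shape, data.node_norm.shape, data.edge_norm.shape)
    break
```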

- def normalize_adjacency(adj):
-     # Calculate the degrees
+ def normalize_adjacency(adj, deg):
+     # Calculate the degrees
row, col = adj.indices()
edge_weight = adj.values() if adj.values() is not None else torch.ones(row.size(0))
- degree = torch_scatter.scatter_add(edge_weight, row, dim=0, dim_size=adj.size(0))
+ # degree = torch_scatter.scatter_add(edge_weight, row, dim=0, dim_size=adj.size(0))

# Inverse square root of degree matrix
- degree_inv_sqrt = degree.pow(-0.5)
+ degree_inv_sqrt = deg.pow(-0.5)
degree_inv_sqrt[degree_inv_sqrt == float('inf')] = 0

# Normalize
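
For reference, a self-contained toy of the D^{-1/2} A D^{-1/2} normalization this helper performs, with the degree vector supplied externally as in the new signature:

```python
# Toy symmetric normalization with externally supplied degrees, mirroring
# normalize_adjacency(adj, deg) above.
import torch

edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])
adj = torch.sparse_coo_tensor(edge_index, torch.ones(4), (3, 3)).coalesce()

deg = torch.tensor([1.0, 2.0, 1.0])            # full-graph degrees
deg_inv_sqrt = deg.pow(-0.5)
deg_inv_sqrt[deg_inv_sqrt == float("inf")] = 0

row, col = adj.indices()
norm_values = deg_inv_sqrt[row] * adj.values() * deg_inv_sqrt[col]
adj_norm = torch.sparse_coo_tensor(adj.indices(), norm_values, adj.shape)
```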
@@ -258,9 +254,8 @@ def normalize_adjacency(adj):
for _ in range(warmup):
for data in miniBatchLoader:
data = data.cuda()
- edge_index,_ = add_remaining_self_loops(to_undirected(data.edge_index))
- sub_adj = normalize_adjacency(to_torch_sparse_tensor(edge_index)).cuda()
- sub_adj_dense = to_dense_adj(edge_index = edge_index, max_num_nodes = data.x.shape[0])[0].float()
+ sub_adj = normalize_adjacency(to_torch_sparse_tensor(data.edge_index, data.edge_norm), data.deg.float()).cuda()
+ sub_adj_dense = to_dense_adj(edge_index = data.edge_index, max_num_nodes = data.x.shape[0])[0].float()
adj_aug, x_aug, adj_logits = self.aug_model(sub_adj, data.x, adj_orig = sub_adj_dense)

edge_loss = norm_w * F.binary_cross_entropy_with_logits(adj_logits, sub_adj_dense)
Expand All @@ -284,10 +279,8 @@ def normalize_adjacency(adj):
data = data.cuda()

### generate fair view
- edge_index,_ = add_remaining_self_loops(to_undirected(data.edge_index))
- sub_adj = normalize_adjacency(to_torch_sparse_tensor(edge_index)).cuda()
-
- sub_adj_dense = to_dense_adj(edge_index = edge_index, max_num_nodes = data.x.shape[0])[0].float()
+ sub_adj = normalize_adjacency(to_torch_sparse_tensor(data.edge_index, data.edge_norm), data.deg.float()).cuda()
+ sub_adj_dense = to_dense_adj(edge_index = data.edge_index, max_num_nodes = data.x.shape[0])[0].float()
adj_aug, x_aug, adj_logits = self.aug_model(sub_adj, data.x, adj_orig = sub_adj_dense)


@@ -316,11 +309,11 @@

s_pred , _ = self.sens_model(adj_aug, x_aug)
senloss = torch.nn.BCEWithLogitsLoss(weight=data.node_norm, reduction='sum')(s_pred[mask].squeeze(), data.sens[mask].float())

## update aug model
logits, labels = self.info_nce_loss_2views(torch.cat((h, h_prime), dim = 0))
contrastive_loss = (nn.CrossEntropyLoss(reduction='none')(logits, labels) * data.node_norm.repeat(2)).sum()

## update encoder
edge_loss = norm_w * F.binary_cross_entropy_with_logits(adj_logits, sub_adj_dense)
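
For context, a minimal two-view InfoNCE head consistent with how info_nce_loss_2views is consumed above, returning logits plus integer labels for a cross-entropy; the in-repo implementation and its temperature are assumptions:

```python
# Hedged sketch of a 2-view InfoNCE head: input is the concatenation of two
# projected views (2N x d); row i's positive sits at row (i + N) mod 2N.
import torch
import torch.nn.functional as F

def info_nce_loss_2views(features, temperature=0.07):  # temperature assumed
    n = features.size(0) // 2
    features = F.normalize(features, dim=1)
    sim = features @ features.t() / temperature        # cosine-similarity logits
    sim.fill_diagonal_(float("-inf"))                  # exclude self-pairs
    labels = torch.cat([torch.arange(n) + n, torch.arange(n)])
    return sim, labels

# Usage mirroring the training step above:
#   logits, labels = info_nce_loss_2views(torch.cat((h, h_prime), dim=0))
#   loss = (nn.CrossEntropyLoss(reduction='none')(logits, labels)
#           * data.node_norm.repeat(2)).sum()
```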

2 changes: 1 addition & 1 deletion dig/fairgraph/method/run.py
@@ -55,7 +55,7 @@ def run(self,device,dataset,model='Graphair',epochs=10_000,test_epochs=1_000,
aug_model = aug_module(features, n_hidden=64, temperature=1).to(device)
f_encoder = GCN_Body(in_feats = features.shape[1], n_hidden = 64, out_feats = 64, dropout = 0.1, nlayer = 2).to(device)
sens_model = GCN(in_feats = features.shape[1], n_hidden = 64, out_feats = 64, nclass = 1).to(device)
- classifier_model = Classifier(input_dim=64,hidden_dim=64)
+ classifier_model = Classifier(input_dim=64,hidden_dim=128)
model = graphair(aug_model=aug_model,f_encoder=f_encoder,sens_model=sens_model,classifier_model=classifier_model, lr=lr,weight_decay=weight_decay,dataset=dataset_name).to(device)
else:
raise Exception('At this moment, only Graphair is supported!')
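
The classifier's hidden width doubles from 64 to 128 while its 64-d input (the encoder output) stays the same. A hedged stand-in with that interface, assuming Classifier is a small MLP (the real definition lives elsewhere in dig/fairgraph and may differ):

```python
# Hypothetical stand-in matching Classifier(input_dim=64, hidden_dim=128).
import torch.nn as nn

class Classifier(nn.Module):
    def __init__(self, input_dim=64, hidden_dim=128):
        super().__init__()
        self.body = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),   # single logit for the downstream label
        )

    def forward(self, h):
        return self.body(h)
```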
4 changes: 2 additions & 2 deletions examples/fairgraph/Graphair/run_graphair_nba.py
@@ -8,5 +8,5 @@
# Train and evaluate
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
run_fairgraph = run()
- run_fairgraph.run(device,dataset=nba,model='Graphair',epochs=2000,test_epochs=500,
-     lr=1e-4,weight_decay=1e-5)
+ run_fairgraph.run(device,dataset=nba,model='Graphair',epochs=10000,test_epochs=500,
+     lr=1e-3,weight_decay=1e-5)
6 changes: 3 additions & 3 deletions examples/fairgraph/Graphair/run_graphair_pokec.py
@@ -3,11 +3,11 @@
import torch

# Load the dataset and split
- pokec = POKEC(dataset_sample='pokec_z') # you may also choose 'pokec_n'
- # pokec = POKEC(dataset_sample='pokec_n')
+ # pokec = POKEC(dataset_sample='pokec_z') # you may also choose 'pokec_n'
+ pokec = POKEC(dataset_sample='pokec_n')

# Train and evaluate
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
run_fairgraph = run()
- run_fairgraph.run(device,dataset=pokec,model='Graphair',epochs=10_000,test_epochs=1000,
+ run_fairgraph.run(device,dataset=pokec,model='Graphair',epochs=10_000,test_epochs=500,
lr=1e-3,weight_decay=1e-5)
