Skip to content

Commit

Permalink
Merge pull request #58 from SherylHYX/link_split_fixes
Browse files Browse the repository at this point in the history
link split fixes
  • Loading branch information
SherylHYX authored Feb 6, 2024
2 parents 4ac0b4c + 5221fa2 commit e91e095
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 4 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
strategy:
matrix:
python-version: [3.7, 3.8, 3.9]
os: [windows-latest, ubuntu-latest, macos-latest]
os: [windows-latest, ubuntu-latest]
torch-version: [2.0.0]
include:
- torch-version: 2.0.0
Expand Down
29 changes: 29 additions & 0 deletions test/signed_data_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,35 @@ def test_load_signed_real_data():
assert signed_dataset.is_signed


def test_connectivity():
seed = 0
dataset_name = 'bitcoin_otc'

# Load data using torch geometric signed directed data loader
data = load_signed_real_data(dataset=dataset_name)

# Create several train, val, test splits

signed_datasets = data.link_split(prob_val=0.1,
prob_test=0.1,
task='sign',
maintain_connect=True,
seed=seed,
splits=1)

# check that all nodes in validation and test have at least an edge in the training set
for split_id in signed_datasets:
val_nodes = torch.unique(torch.flatten(signed_datasets[split_id]['val']['edges']))
for node_id in val_nodes:
node_id_mask = torch.logical_or(signed_datasets[split_id]['train']['edges'][:, 0] == node_id,
signed_datasets[split_id]['train']['edges'][:, 1] == node_id)
assert node_id_mask.sum().item() > 0, f'[VAL] node id: {node_id} has no incident edges in training set'
test_nodes = torch.unique(torch.flatten(signed_datasets[split_id]['test']['edges']))
for node_id in test_nodes:
node_id_mask = torch.logical_or(signed_datasets[split_id]['train']['edges'][:, 0] == node_id,
signed_datasets[split_id]['train']['edges'][:, 1] == node_id)

assert node_id_mask.sum().item() > 0, f'[TEST] node id: {node_id} has no incident edges in training set'

def test_SSBM():
num_nodes = 1000
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def fast_appr_power(A, alpha=0.1, max_iter=100,
personalize = personalize.reshape(n, 1)
s = 1/(1+alpha)/n * personalize
z_T = ((alpha*(1+alpha)) * (r != 0) + ((1-alpha)/(1+alpha)+alpha*(1+alpha))
* (r == 0))[scipy.newaxis, :]
* (r == 0))[np.newaxis, :]
W = (1-alpha) * A.T @ D_1
x = s
oldx = np.zeros((n, 1))
Expand Down
5 changes: 3 additions & 2 deletions torch_geometric_signed_directed/utils/general/link_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,8 @@ def link_class_split(data: torch_geometric.data.Data, size: int = None, splits:
G, algorithm="kruskal", data=False))
all_edges = list(map(tuple, all_edge_index))
mst_r = [t[::-1] for t in mst]
nmst = list(set(all_edges) - set(mst) - set(mst_r))
mst += mst_r
nmst = list(set(all_edges) - set(mst))
if len(nmst) < (len_val+len_test):
raise ValueError(
"There are no enough edges to be removed for validation/testing. Please use a smaller prob_test or prob_val.")
Expand Down Expand Up @@ -418,4 +419,4 @@ def link_class_split(data: torch_geometric.data.Data, size: int = None, splits:
datasets[ind]['test']['label'] = torch.from_numpy(
labels_test).long().to(device)
#datasets[ind]['test']['weight'] = torch.from_numpy(label_test_w).float().to(device)
return datasets
return datasets

0 comments on commit e91e095

Please sign in to comment.