Bi-directional Sampling for NeighborSampler #10126

Open · wants to merge 44 commits into base: master
Changes from all commits · 44 commits
b4c097c
global indexing of homogeneous sampleoutput
Apr 15, 2025
6dd9538
global indexing of hetero sampleoutput
Apr 16, 2025
a7fc9f1
precommit
Apr 16, 2025
43dc8d1
changelog
zaristei Apr 16, 2025
b65d762
forgot to add null check for hetero
zaristei Apr 16, 2025
e6cbdf2
fix type annotations
Apr 22, 2025
8acbf58
bidirectional_sampler_wip
Mar 18, 2025
f4df732
fix csc transpose transform and merge outputs
Mar 18, 2025
fbc1c81
working bidirectional sampler!
Mar 19, 2025
afa499d
sanity test for neighborsampler
Mar 20, 2025
d504c87
add tests for reverse direction sampling
Mar 20, 2025
b81b878
tests for backwards sampling done
Mar 24, 2025
ef279bc
mark only neighborsampler
Mar 24, 2025
a86e3c0
add basic homo bidirectional test
Mar 25, 2025
7e4c724
bidirectional sampler bugfix and test adjustment
Apr 3, 2025
75d39dc
fix attribute error
Apr 3, 2025
ca0f173
fix some tests
Apr 3, 2025
0eab437
fix sampler merge to account for fact that all indices are local
Apr 11, 2025
b0271c7
precommit
Apr 11, 2025
e31cf14
lint
Apr 11, 2025
e5c1d6c
fix for disjoint sampling
Apr 12, 2025
caf4cdf
precommit
Apr 12, 2025
ea333e1
unittests mostly done for homogeneous sampleroutput, need to add coll…
Apr 15, 2025
a843e2c
Merge branch 'zaristei/sample_output_global_index' into zaristei/bidi…
zaristei Apr 22, 2025
89909c4
refactor to use new utility functions
Apr 16, 2025
916e595
precommit
Apr 16, 2025
ec535d3
final collate tests
Apr 16, 2025
9f2d797
precommit
Apr 16, 2025
a75d539
changelog
Apr 16, 2025
86f87e1
move graph drawing to fix precommit
Apr 16, 2025
f5333cf
fix labeling so that replace flag is accurate
Apr 22, 2025
d8425d7
replaced hetero test case with a slightly more complex example
Apr 22, 2025
3c96020
assert backwards time sampling not allowed
Apr 22, 2025
552cfa9
assert backwards weight sampling not allowed
Apr 22, 2025
29869f3
temporary undo switching edge types
Apr 22, 2025
2983a72
hetero sampler implemented but hitting issues with hetero edge cases
Apr 23, 2025
1714b9d
hetero finally working backwards!
Apr 25, 2025
f09ca30
precommit
Apr 25, 2025
6963cee
removed option for separate backwards neighbors to remove complexity
Apr 25, 2025
bc4f3e7
pseudocode implementation for homogeneous bidirectional
Apr 25, 2025
827d8e8
precommit
Apr 25, 2025
db36c2b
code works!
Apr 25, 2025
4e18c78
bidirectional disjoint working
Apr 26, 2025
e76a2ee
checked that weighted sampling is working
Apr 29, 2025
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added `BidirectionalSampler`, which samples graph edges in both the forward and reverse directions ([#10126](https://github.com/pyg-team/pytorch_geometric/pull/10126))
- Enabled sampling of both forward and reverse edges in `NeighborSampler` ([#10126](https://github.com/pyg-team/pytorch_geometric/pull/10126))
- Added the ability to merge `SamplerOutput` objects ([#10126](https://github.com/pyg-team/pytorch_geometric/pull/10126))
- Added the ability to get global row and col ids from a `SamplerOutput` ([#10200](https://github.com/pyg-team/pytorch_geometric/pull/10200))
- Added PyTorch 2.6 support ([#10170](https://github.com/pyg-team/pytorch_geometric/pull/10170))
- Added support for heterogeneous graphs in `PGExplainer` ([#10168](https://github.com/pyg-team/pytorch_geometric/pull/10168))
- Added support for heterogeneous graphs in `GNNExplainer` ([#10158](https://github.com/pyg-team/pytorch_geometric/pull/10158))
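The four entries above are exercised by the new tests in test/sampler/test_sampler_base.py below. A minimal sketch of how the pieces fit together, with fixture values copied from those tests; the names used (`merge_with`, `collate`, `global_row`/`global_col`, `to_bidirectional(keep_orig_edges=...)`) are the ones the tests call, so treat this as a reading aid rather than documentation:

import torch
from torch_geometric.sampler import SamplerOutput

# Two sampler outputs expressed over the same global node ids.
out_a = SamplerOutput(node=torch.tensor([0, 1, 2]), row=torch.tensor([0, 0]),
                      col=torch.tensor([1, 2]), edge=torch.tensor([0, 1]),
                      num_sampled_nodes=[1, 2], num_sampled_edges=[2])
out_b = SamplerOutput(node=torch.tensor([0, 2, 3]), row=torch.tensor([0, 1]),
                      col=torch.tensor([1, 2]), edge=torch.tensor([1, 2]),
                      num_sampled_nodes=[1, 1, 1], num_sampled_edges=[1, 1])

merged = out_a.merge_with(out_b)                       # merge two outputs into one
batched = SamplerOutput.collate([out_a, out_b])        # fold a list of outputs
rows, cols = merged.global_row, merged.global_col      # row/col in global node ids
bidir = merged.to_bidirectional(keep_orig_edges=True)  # add reverse edges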
310 changes: 310 additions & 0 deletions test/sampler/test_sampler_base.py
@@ -6,6 +6,7 @@
NumNeighbors,
SamplerOutput,
)
from torch_geometric.sampler.utils import global_to_local_node_idx
from torch_geometric.testing import get_random_edge_index
from torch_geometric.utils import is_undirected

@@ -30,6 +31,168 @@ def test_homogeneous_num_neighbors():
assert num_neighbors.num_hops == 2 # Test caching.


'''
Merge and collate tests use the following graph:

#############                    ###########
# Alice (0) # -> "works with" -> # Bob (1) #
#############                    ###########
      |
      v
   "leads"
      |
      v
#############                    ############
# Carol (2) # -> "works with" -> # Dave (3) #
#############                    ############

'''
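# For reference, the homogeneous fixtures below number the global edges as:
#   edge 0: 0 -> 1  ("works with", Alice -> Bob)
#   edge 1: 0 -> 2  ("leads",      Alice -> Carol)
#   edge 2: 2 -> 3  ("works with", Carol -> Dave)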


def _init_merge_sampler_outputs(hetero=False, disjoint=False):
if not hetero:
output1 = SamplerOutput(
node=torch.tensor([0, 1, 2]),
row=torch.tensor([0, 0]),
col=torch.tensor([1, 2]),
edge=torch.tensor([0, 1]),
batch=torch.tensor([0, 0, 0]) if disjoint else None,
            num_sampled_nodes=[1, 2],
            num_sampled_edges=[2],
orig_row=None,
orig_col=None,
metadata=(None, None),
)
output2 = SamplerOutput(
node=torch.tensor([0, 2, 3]),
row=torch.tensor([0, 1]),
col=torch.tensor([1, 2]),
edge=torch.tensor([1, 2]),
batch=torch.tensor([0, 0, 0]) if disjoint else None,
            num_sampled_nodes=[1, 1, 1],
            num_sampled_edges=[1, 1],
orig_row=None,
orig_col=None,
metadata=(None, None),
)

return output1, output2
else:
# TODO(zaristei)
raise NotImplementedError("Heterogeneous merge not implemented")


@pytest.mark.parametrize("disjoint", [True, False])
@pytest.mark.parametrize("bidirectional", [True, False])
def test_homogeneous_merge(disjoint, bidirectional):
"""Merge an output representing 1<-0->2 with one representing 0->2->3."""
output1, output2 = _init_merge_sampler_outputs(disjoint=disjoint)
if bidirectional:
output1 = output1.to_bidirectional(keep_orig_edges=True)
output2 = output2.to_bidirectional(keep_orig_edges=True)

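    # The default merge deduplicates: output2's hops are appended after
    # output1's with already-seen nodes/edges dropped, giving
    # num_sampled_nodes [1, 2] + [0, 0, 1] and num_sampled_edges [2] + [0, 1].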
expected_output = SamplerOutput(
node=torch.tensor([0, 1, 2, 3]),
row=torch.tensor([0, 0, 2]),
col=torch.tensor([1, 2, 3]),
edge=torch.tensor([0, 1, 2]),
batch=torch.tensor([0, 0, 0, 0]) if disjoint else None,
num_sampled_nodes=[1, 2, 0, 0, 1],
num_sampled_edges=[2, 0, 1],
orig_row=None,
orig_col=None,
metadata=[(None, None), (None, None)],
)
if bidirectional:
expected_output = expected_output.to_bidirectional(
keep_orig_edges=True)
merged_output = output1.merge_with(output2)

assert str(merged_output) == str(expected_output)


@pytest.mark.parametrize("disjoint", [True, False])
@pytest.mark.parametrize("bidirectional", [True, False])
def test_homogeneous_merge_no_replace(disjoint, bidirectional):
"""Merge an output representing 1<-0->2 with one representing 0->2->3.
    With replace=False, the merged output is a simple concatenation
    instead of deduplicating already sampled nodes/edges.
"""
output1, output2 = _init_merge_sampler_outputs(disjoint=disjoint)
if bidirectional:
output1 = output1.to_bidirectional(keep_orig_edges=True)
output2 = output2.to_bidirectional(keep_orig_edges=True)

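    # Concatenation shifts output2's local row/col/batch indices by
    # len(output1.node) == 3; nothing is deduplicated.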
expected_output = SamplerOutput(
node=torch.tensor([0, 1, 2, 0, 2, 3]),
row=torch.tensor([0, 0, 3, 4]),
col=torch.tensor([1, 2, 4, 5]),
edge=torch.tensor([0, 1, 1, 2]),
batch=torch.tensor([0, 0, 0, 3, 3, 3]) if disjoint else None,
num_sampled_nodes=[1, 2, 1, 1, 1],
num_sampled_edges=[2, 1, 1],
orig_row=None,
orig_col=None,
metadata=[(None, None), (None, None)],
)
if bidirectional:
expected_output = expected_output.to_bidirectional(
keep_orig_edges=True)
merged_output = output1.merge_with(output2, replace=False)

assert str(merged_output) == str(expected_output)


def _init_collate_sampler_outputs(disjoint=False):
output1, output2 = _init_merge_sampler_outputs(disjoint=disjoint)
# new edge not present in graph above
output3 = SamplerOutput(
node=torch.tensor([3, 4]),
row=torch.tensor([0]),
col=torch.tensor([1]),
edge=torch.tensor([3]),
batch=torch.tensor([0, 0]) if disjoint else None,
        num_sampled_nodes=[1, 1],
        num_sampled_edges=[1],
orig_row=None,
orig_col=None,
metadata=(None, None),
)
return [output1, output2, output3]


@pytest.mark.parametrize("replace", [True, False])
@pytest.mark.parametrize("disjoint", [True, False])
def test_homogeneous_collate(disjoint, replace):
output1, output2, output3 = _init_collate_sampler_outputs(disjoint)
collated = SamplerOutput.collate([output1, output2, output3],
replace=replace)
assert str(collated) == str(
(output1.merge_with(output2, replace=replace)).merge_with(
output3, replace=replace))


def test_homogeneous_collate_empty():
with pytest.raises(ValueError,
match="Cannot collate an empty list of SamplerOutputs"):
SamplerOutput.collate([])


def test_homogeneous_collate_single():
output, _ = _init_merge_sampler_outputs()
collated = SamplerOutput.collate([output])
assert str(collated) == str(output)


def test_homogeneous_collate_missing_fields():
output1, output2, output3 = _init_collate_sampler_outputs()
output3.edge = None
with pytest.raises(
ValueError,
match="Output 3 has a different field than the first output"):
SamplerOutput.collate([output1, output2, output3])


def test_heterogeneous_num_neighbors_list():
num_neighbors = NumNeighbors([25, 10])

@@ -119,3 +282,150 @@ def test_heterogeneous_to_bidirectional():
)
assert is_undirected(
torch.stack([obj.row['v1', 'to', 'v1'], obj.col['v1', 'to', 'v1']], 0))


def test_homogeneous_sampler_output_global_fields():
output = SamplerOutput(
node=torch.tensor([0, 2, 3]),
row=torch.tensor([0, 1]),
col=torch.tensor([1, 2]),
edge=torch.tensor([1, 2]),
batch=torch.tensor([0, 0, 0]),
num_sampled_nodes=[1, 1, 1],
num_sampled_edges=[1, 1],
orig_row=None,
orig_col=None,
metadata=(None, None),
)
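    # global_row/global_col re-express the local row/col indices in terms of
    # the global node ids stored in `node`, i.e. global_row == node[row].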

local_values = []
global_values = []

global_row, global_col = output.global_row, output.global_col
assert torch.equal(global_row, torch.tensor([0, 2]))
assert torch.equal(global_col, torch.tensor([2, 3]))
local_values.append(output.row)
local_values.append(output.col)
global_values.append(global_row)
global_values.append(global_col)

seed_node = output.seed_node
assert torch.equal(seed_node, torch.tensor([0, 0, 0]))
local_values.append(output.batch)
global_values.append(seed_node)

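    # to_bidirectional() adds the reverse of every sampled edge;
    # keep_orig_edges=True keeps the original directed edges available via
    # global_orig_row/global_orig_col (checked further below).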
output_bidirectional = output.to_bidirectional(keep_orig_edges=True)
global_bidir_row, global_bidir_col = \
output_bidirectional.global_row, output_bidirectional.global_col
assert torch.equal(global_bidir_row, torch.tensor([2, 0, 3, 2]))
assert torch.equal(global_bidir_col, torch.tensor([0, 2, 2, 3]))
local_values.append(output_bidirectional.row)
local_values.append(output_bidirectional.col)
global_values.append(global_bidir_row)
global_values.append(global_bidir_col)

assert torch.equal(output.global_row, output_bidirectional.global_orig_row)
assert torch.equal(output.global_col, output_bidirectional.global_orig_col)

# Make sure reverse mapping is correct
for local_value, global_value in zip(local_values, global_values):
assert torch.equal(global_to_local_node_idx(output.node, global_value),
local_value)


def test_heterogeneous_sampler_output_global_fields():
def _tensor_dict_equal(dict1, dict2):
        # Guard against mismatched keys before comparing tensors.
        if dict1.keys() != dict2.keys():
            return False
        return all(torch.equal(dict1[key], dict2[key]) for key in dict1)

output = HeteroSamplerOutput(
node={"person": torch.tensor([0, 2, 3])},
row={
("person", "works_with", "person"): torch.tensor([1]),
("person", "leads", "person"): torch.tensor([0])
},
col={
("person", "works_with", "person"): torch.tensor([2]),
("person", "leads", "person"): torch.tensor([1])
},
edge={
("person", "works_with", "person"): torch.tensor([1]),
("person", "leads", "person"): torch.tensor([0])
},
batch={"person": torch.tensor([0, 0, 0])},
num_sampled_nodes={"person": torch.tensor([1, 1, 1])},
num_sampled_edges={
("person", "works_with", "person"): torch.tensor([1]),
("person", "leads", "person"): torch.tensor([1])
},
orig_row=None,
orig_col=None,
metadata=(None, None),
)
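    # For heterogeneous outputs the mapping is per edge type: for
    # (src, rel, dst), global_row indexes into node[src] and global_col
    # into node[dst].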

global_row, global_col = output.global_row, output.global_col
assert _tensor_dict_equal(
global_row, {
("person", "works_with", "person"): torch.tensor([2]),
("person", "leads", "person"): torch.tensor([0])
})
assert _tensor_dict_equal(
global_col, {
("person", "works_with", "person"): torch.tensor([3]),
("person", "leads", "person"): torch.tensor([2])
})

local_row_dict = {
k: global_to_local_node_idx(output.node[k[0]], v)
for k, v in global_row.items()
}
assert _tensor_dict_equal(local_row_dict, output.row)

local_col_dict = {
k: global_to_local_node_idx(output.node[k[2]], v)
for k, v in global_col.items()
}
assert _tensor_dict_equal(local_col_dict, output.col)

seed_node = output.seed_node
assert _tensor_dict_equal(seed_node, {"person": torch.tensor([0, 0, 0])})

local_batch_dict = {
k: global_to_local_node_idx(output.node[k], v)
for k, v in seed_node.items()
}
assert _tensor_dict_equal(local_batch_dict, output.batch)

output_bidirectional = output.to_bidirectional(keep_orig_edges=True)
global_bidir_row, global_bidir_col = \
output_bidirectional.global_row, output_bidirectional.global_col
assert _tensor_dict_equal(
global_bidir_row, {
("person", "works_with", "person"): torch.tensor([3, 2]),
("person", "leads", "person"): torch.tensor([2, 0])
})
assert _tensor_dict_equal(
global_bidir_col, {
("person", "works_with", "person"): torch.tensor([2, 3]),
("person", "leads", "person"): torch.tensor([0, 2])
})

local_bidir_row_dict = {
k: global_to_local_node_idx(output_bidirectional.node[k[0]], v)
for k, v in global_bidir_row.items()
}
assert _tensor_dict_equal(local_bidir_row_dict, output_bidirectional.row)

local_bidir_col_dict = {
k: global_to_local_node_idx(output_bidirectional.node[k[2]], v)
for k, v in global_bidir_col.items()
}
assert _tensor_dict_equal(local_bidir_col_dict, output_bidirectional.col)

assert _tensor_dict_equal(output.global_row,
output_bidirectional.global_orig_row)
assert _tensor_dict_equal(output.global_col,
output_bidirectional.global_orig_col)