Skip to content

Commit dce2fe6

Browse files
Zack Aristeizaristei
authored andcommitted
precommit
1 parent 089389b commit dce2fe6

File tree

3 files changed

+132
-51
lines changed

3 files changed

+132
-51
lines changed

test/sampler/test_sampler_base.py

Lines changed: 72 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
NumNeighbors,
77
SamplerOutput,
88
)
9-
from torch_geometric.sampler.utils import local_to_global_node_idx, global_to_local_node_idx
9+
from torch_geometric.sampler.utils import global_to_local_node_idx
1010
from torch_geometric.testing import get_random_edge_index
1111
from torch_geometric.utils import is_undirected
1212

@@ -121,6 +121,7 @@ def test_heterogeneous_to_bidirectional():
121121
assert is_undirected(
122122
torch.stack([obj.row['v1', 'to', 'v1'], obj.col['v1', 'to', 'v1']], 0))
123123

124+
124125
def test_homogeneous_sampler_output_global_fields():
125126
output = SamplerOutput(
126127
node=torch.tensor([0, 2, 3]),
@@ -152,7 +153,8 @@ def test_homogeneous_sampler_output_global_fields():
152153
global_values.append(seed_node)
153154

154155
output_bidirectional = output.to_bidirectional(keep_orig_edges=True)
155-
global_bidir_row, global_bidir_col = output_bidirectional.global_row, output_bidirectional.global_col
156+
global_bidir_row, global_bidir_col = \
157+
output_bidirectional.global_row, output_bidirectional.global_col
156158
assert torch.equal(global_bidir_row, torch.tensor([2, 0, 3, 2]))
157159
assert torch.equal(global_bidir_col, torch.tensor([0, 2, 2, 3]))
158160
local_values.append(output_bidirectional.row)
@@ -162,10 +164,12 @@ def test_homogeneous_sampler_output_global_fields():
162164

163165
assert torch.equal(output.global_row, output_bidirectional.global_orig_row)
164166
assert torch.equal(output.global_col, output_bidirectional.global_orig_col)
165-
167+
166168
# Make sure reverse mapping is correct
167169
for local_value, global_value in zip(local_values, global_values):
168-
assert torch.equal(global_to_local_node_idx(output.node, global_value), local_value)
170+
assert torch.equal(global_to_local_node_idx(output.node, global_value),
171+
local_value)
172+
169173

170174
def test_heterogeneous_sampler_output_global_fields():
171175
def _tensor_dict_equal(dict1, dict2):
@@ -177,46 +181,89 @@ def _tensor_dict_equal(dict1, dict2):
177181

178182
output = HeteroSamplerOutput(
179183
node={"person": torch.tensor([0, 2, 3])},
180-
row={("person", "works_with", "person"): torch.tensor([1]), ("person", "leads", "person"): torch.tensor([0])},
181-
col={("person", "works_with", "person"): torch.tensor([2]), ("person", "leads", "person"): torch.tensor([1])},
182-
edge={("person", "works_with", "person"): torch.tensor([1]), ("person", "leads", "person"): torch.tensor([0])},
184+
row={
185+
("person", "works_with", "person"): torch.tensor([1]),
186+
("person", "leads", "person"): torch.tensor([0])
187+
},
188+
col={
189+
("person", "works_with", "person"): torch.tensor([2]),
190+
("person", "leads", "person"): torch.tensor([1])
191+
},
192+
edge={
193+
("person", "works_with", "person"): torch.tensor([1]),
194+
("person", "leads", "person"): torch.tensor([0])
195+
},
183196
batch={"person": torch.tensor([0, 0, 0])},
184197
num_sampled_nodes={"person": torch.tensor([1, 1, 1])},
185-
num_sampled_edges={("person", "works_with", "person"): torch.tensor([1]), ("person", "leads", "person"): torch.tensor([1])},
198+
num_sampled_edges={
199+
("person", "works_with", "person"): torch.tensor([1]),
200+
("person", "leads", "person"): torch.tensor([1])
201+
},
186202
orig_row=None,
187203
orig_col=None,
188204
metadata=(None, None),
189205
)
190206

191-
local_values = []
192-
global_values = []
193-
194207
global_row, global_col = output.global_row, output.global_col
195-
assert _tensor_dict_equal(global_row, {("person", "works_with", "person"): torch.tensor([2]), ("person", "leads", "person"): torch.tensor([0])})
196-
assert _tensor_dict_equal(global_col, {("person", "works_with", "person"): torch.tensor([3]), ("person", "leads", "person"): torch.tensor([2])})
197-
198-
local_row_dict = {k: global_to_local_node_idx(output.node[k[0]], v) for k, v in global_row.items()}
208+
assert _tensor_dict_equal(
209+
global_row, {
210+
("person", "works_with", "person"): torch.tensor([2]),
211+
("person", "leads", "person"): torch.tensor([0])
212+
})
213+
assert _tensor_dict_equal(
214+
global_col, {
215+
("person", "works_with", "person"): torch.tensor([3]),
216+
("person", "leads", "person"): torch.tensor([2])
217+
})
218+
219+
local_row_dict = {
220+
k: global_to_local_node_idx(output.node[k[0]], v)
221+
for k, v in global_row.items()
222+
}
199223
assert _tensor_dict_equal(local_row_dict, output.row)
200224

201-
local_col_dict = {k: global_to_local_node_idx(output.node[k[2]], v) for k, v in global_col.items()}
225+
local_col_dict = {
226+
k: global_to_local_node_idx(output.node[k[2]], v)
227+
for k, v in global_col.items()
228+
}
202229
assert _tensor_dict_equal(local_col_dict, output.col)
203230

204231
seed_node = output.seed_node
205232
assert _tensor_dict_equal(seed_node, {"person": torch.tensor([0, 0, 0])})
206233

207-
local_batch_dict = {k: global_to_local_node_idx(output.node[k], v) for k, v in seed_node.items()}
234+
local_batch_dict = {
235+
k: global_to_local_node_idx(output.node[k], v)
236+
for k, v in seed_node.items()
237+
}
208238
assert _tensor_dict_equal(local_batch_dict, output.batch)
209239

210240
output_bidirectional = output.to_bidirectional(keep_orig_edges=True)
211-
global_bidir_row, global_bidir_col = output_bidirectional.global_row, output_bidirectional.global_col
212-
assert _tensor_dict_equal(global_bidir_row, {("person", "works_with", "person"): torch.tensor([3, 2]), ("person", "leads", "person"): torch.tensor([2, 0])})
213-
assert _tensor_dict_equal(global_bidir_col, {("person", "works_with", "person"): torch.tensor([2, 3]), ("person", "leads", "person"): torch.tensor([0, 2])})
214-
215-
local_bidir_row_dict = {k: global_to_local_node_idx(output_bidirectional.node[k[0]], v) for k, v in global_bidir_row.items()}
241+
global_bidir_row, global_bidir_col = \
242+
output_bidirectional.global_row, output_bidirectional.global_col
243+
assert _tensor_dict_equal(
244+
global_bidir_row, {
245+
("person", "works_with", "person"): torch.tensor([3, 2]),
246+
("person", "leads", "person"): torch.tensor([2, 0])
247+
})
248+
assert _tensor_dict_equal(
249+
global_bidir_col, {
250+
("person", "works_with", "person"): torch.tensor([2, 3]),
251+
("person", "leads", "person"): torch.tensor([0, 2])
252+
})
253+
254+
local_bidir_row_dict = {
255+
k: global_to_local_node_idx(output_bidirectional.node[k[0]], v)
256+
for k, v in global_bidir_row.items()
257+
}
216258
assert _tensor_dict_equal(local_bidir_row_dict, output_bidirectional.row)
217259

218-
local_bidir_col_dict = {k: global_to_local_node_idx(output_bidirectional.node[k[2]], v) for k, v in global_bidir_col.items()}
260+
local_bidir_col_dict = {
261+
k: global_to_local_node_idx(output_bidirectional.node[k[2]], v)
262+
for k, v in global_bidir_col.items()
263+
}
219264
assert _tensor_dict_equal(local_bidir_col_dict, output_bidirectional.col)
220265

221-
assert _tensor_dict_equal(output.global_row, output_bidirectional.global_orig_row)
222-
assert _tensor_dict_equal(output.global_col, output_bidirectional.global_orig_col)
266+
assert _tensor_dict_equal(output.global_row,
267+
output_bidirectional.global_orig_row)
268+
assert _tensor_dict_equal(output.global_col,
269+
output_bidirectional.global_orig_col)

torch_geometric/sampler/base.py

Lines changed: 39 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@
1111
from torch import Tensor
1212

1313
from torch_geometric.data import Data, FeatureStore, GraphStore, HeteroData
14-
from torch_geometric.sampler.utils import to_bidirectional, local_to_global_node_idx
14+
from torch_geometric.sampler.utils import (
15+
local_to_global_node_idx,
16+
to_bidirectional,
17+
)
1518
from torch_geometric.typing import EdgeType, EdgeTypeStr, NodeType, OptTensor
1619
from torch_geometric.utils.mixin import CastMixin
1720

@@ -214,19 +217,22 @@ def global_row(self) -> Tensor:
214217
@property
215218
def global_col(self) -> Tensor:
216219
return local_to_global_node_idx(self.node, self.col)
217-
220+
218221
@property
219222
def seed_node(self) -> Tensor:
220-
return local_to_global_node_idx(self.node, self.batch) if self.batch is not None else None
221-
223+
return local_to_global_node_idx(
224+
self.node, self.batch) if self.batch is not None else None
225+
222226
@property
223227
def global_orig_row(self) -> Tensor:
224-
return local_to_global_node_idx(self.node, self.orig_row) if self.orig_row is not None else None
225-
228+
return local_to_global_node_idx(
229+
self.node, self.orig_row) if self.orig_row is not None else None
230+
226231
@property
227232
def global_orig_col(self) -> Tensor:
228-
return local_to_global_node_idx(self.node, self.orig_col) if self.orig_col is not None else None
229-
233+
return local_to_global_node_idx(
234+
self.node, self.orig_col) if self.orig_col is not None else None
235+
230236
def to_bidirectional(
231237
self,
232238
keep_orig_edges: bool = False,
@@ -316,23 +322,40 @@ class HeteroSamplerOutput(CastMixin):
316322

317323
@property
318324
def global_row(self) -> Tensor:
319-
return {edge_type: local_to_global_node_idx(self.node[edge_type[0]], row) for edge_type, row in self.row.items()}
325+
return {
326+
edge_type: local_to_global_node_idx(self.node[edge_type[0]], row)
327+
for edge_type, row in self.row.items()
328+
}
320329

321330
@property
322331
def global_col(self) -> Tensor:
323-
return {edge_type: local_to_global_node_idx(self.node[edge_type[2]], col) for edge_type, col in self.col.items()}
324-
332+
return {
333+
edge_type: local_to_global_node_idx(self.node[edge_type[2]], col)
334+
for edge_type, col in self.col.items()
335+
}
336+
325337
@property
326338
def seed_node(self) -> Tensor:
327-
return {node_type: local_to_global_node_idx(self.node[node_type], batch) for node_type, batch in self.batch.items()}
328-
339+
return {
340+
node_type: local_to_global_node_idx(self.node[node_type], batch)
341+
for node_type, batch in self.batch.items()
342+
}
343+
329344
@property
330345
def global_orig_row(self) -> Tensor:
331-
return {edge_type: local_to_global_node_idx(self.node[edge_type[0]], orig_row) for edge_type, orig_row in self.orig_row.items()}
332-
346+
return {
347+
edge_type: local_to_global_node_idx(self.node[edge_type[0]],
348+
orig_row)
349+
for edge_type, orig_row in self.orig_row.items()
350+
}
351+
333352
@property
334353
def global_orig_col(self) -> Tensor:
335-
return {edge_type: local_to_global_node_idx(self.node[edge_type[2]], orig_col) for edge_type, orig_col in self.orig_col.items()}
354+
return {
355+
edge_type: local_to_global_node_idx(self.node[edge_type[2]],
356+
orig_col)
357+
for edge_type, orig_col in self.orig_col.items()
358+
}
336359

337360
def to_bidirectional(
338361
self,

torch_geometric/sampler/utils.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -162,35 +162,46 @@ def remap_keys(
162162
}
163163

164164

165-
def local_to_global_node_idx(node_values: Tensor, local_indices: Tensor) -> Tensor:
166-
"""Convert a tensor of indices referring to elements in the node_values tensor to their values.
165+
def local_to_global_node_idx(node_values: Tensor,
166+
local_indices: Tensor) -> Tensor:
167+
"""Convert a tensor of indices referring to elements in the node_values
168+
tensor to their values.
167169
168170
Args:
169171
node_values (Tensor): The node values. (num_nodes, feature_dim)
170172
local_indices (Tensor): The local indices. (num_indices)
171173
172174
Returns:
173-
Tensor: The values of the node_values tensor at the local indices. (num_indices, feature_dim)
175+
Tensor: The values of the node_values tensor at the local indices.
176+
(num_indices, feature_dim)
174177
"""
175178
return torch.index_select(node_values, dim=0, index=local_indices)
176179

177-
def global_to_local_node_idx(node_values: Tensor, local_values: Tensor) -> Tensor:
178-
"""Converts a tensor of values that are contained in the node_values tensor to their indices in that tensor.
180+
181+
def global_to_local_node_idx(node_values: Tensor,
182+
local_values: Tensor) -> Tensor:
183+
"""Converts a tensor of values that are contained in the node_values
184+
tensor to their indices in that tensor.
179185
180186
Args:
181187
node_values (Tensor): The node values. (num_nodes, feature_dim)
182188
local_values (Tensor): The local values. (num_indices, feature_dim)
183189
184190
Returns:
185-
Tensor: The indices of the local values in the node_values tensor. (num_indices)
191+
Tensor: The indices of the local values in the node_values tensor.
192+
(num_indices)
186193
"""
187194
if node_values.dim() == 1:
188195
node_values = node_values.unsqueeze(1)
189196
if local_values.dim() == 1:
190197
local_values = local_values.unsqueeze(1)
191-
node_values_expand = node_values.unsqueeze(-1).expand(*node_values.shape, local_values.shape[0]) # (num_nodes, feature_dim, num_indices)
192-
local_values_expand = local_values.transpose(0, 1).unsqueeze(0).expand(*node_values_expand.shape) # (num_nodes, feature_dim, num_indices)
193-
idx_match = torch.all(node_values_expand == local_values_expand, dim=1).nonzero() # (num_indices, 2)
198+
node_values_expand = node_values.unsqueeze(-1).expand(
199+
*node_values.shape,
200+
local_values.shape[0]) # (num_nodes, feature_dim, num_indices)
201+
local_values_expand = local_values.transpose(0, 1).unsqueeze(0).expand(
202+
*node_values_expand.shape) # (num_nodes, feature_dim, num_indices)
203+
idx_match = torch.all(node_values_expand == local_values_expand,
204+
dim=1).nonzero() # (num_indices, 2)
194205
sort_idx = torch.argsort(idx_match[:, 1])
195206

196-
return idx_match[:, 0][sort_idx]
207+
return idx_match[:, 0][sort_idx]

0 commit comments

Comments
 (0)