6
6
NumNeighbors ,
7
7
SamplerOutput ,
8
8
)
9
- from torch_geometric .sampler .utils import local_to_global_node_idx , global_to_local_node_idx
9
+ from torch_geometric .sampler .utils import global_to_local_node_idx
10
10
from torch_geometric .testing import get_random_edge_index
11
11
from torch_geometric .utils import is_undirected
12
12
@@ -121,6 +121,7 @@ def test_heterogeneous_to_bidirectional():
121
121
assert is_undirected (
122
122
torch .stack ([obj .row ['v1' , 'to' , 'v1' ], obj .col ['v1' , 'to' , 'v1' ]], 0 ))
123
123
124
+
124
125
def test_homogeneous_sampler_output_global_fields ():
125
126
output = SamplerOutput (
126
127
node = torch .tensor ([0 , 2 , 3 ]),
@@ -152,7 +153,8 @@ def test_homogeneous_sampler_output_global_fields():
152
153
global_values .append (seed_node )
153
154
154
155
output_bidirectional = output .to_bidirectional (keep_orig_edges = True )
155
- global_bidir_row , global_bidir_col = output_bidirectional .global_row , output_bidirectional .global_col
156
+ global_bidir_row , global_bidir_col = \
157
+ output_bidirectional .global_row , output_bidirectional .global_col
156
158
assert torch .equal (global_bidir_row , torch .tensor ([2 , 0 , 3 , 2 ]))
157
159
assert torch .equal (global_bidir_col , torch .tensor ([0 , 2 , 2 , 3 ]))
158
160
local_values .append (output_bidirectional .row )
@@ -162,10 +164,12 @@ def test_homogeneous_sampler_output_global_fields():
162
164
163
165
assert torch .equal (output .global_row , output_bidirectional .global_orig_row )
164
166
assert torch .equal (output .global_col , output_bidirectional .global_orig_col )
165
-
167
+
166
168
# Make sure reverse mapping is correct
167
169
for local_value , global_value in zip (local_values , global_values ):
168
- assert torch .equal (global_to_local_node_idx (output .node , global_value ), local_value )
170
+ assert torch .equal (global_to_local_node_idx (output .node , global_value ),
171
+ local_value )
172
+
169
173
170
174
def test_heterogeneous_sampler_output_global_fields ():
171
175
def _tensor_dict_equal (dict1 , dict2 ):
@@ -177,46 +181,89 @@ def _tensor_dict_equal(dict1, dict2):
177
181
178
182
output = HeteroSamplerOutput (
179
183
node = {"person" : torch .tensor ([0 , 2 , 3 ])},
180
- row = {("person" , "works_with" , "person" ): torch .tensor ([1 ]), ("person" , "leads" , "person" ): torch .tensor ([0 ])},
181
- col = {("person" , "works_with" , "person" ): torch .tensor ([2 ]), ("person" , "leads" , "person" ): torch .tensor ([1 ])},
182
- edge = {("person" , "works_with" , "person" ): torch .tensor ([1 ]), ("person" , "leads" , "person" ): torch .tensor ([0 ])},
184
+ row = {
185
+ ("person" , "works_with" , "person" ): torch .tensor ([1 ]),
186
+ ("person" , "leads" , "person" ): torch .tensor ([0 ])
187
+ },
188
+ col = {
189
+ ("person" , "works_with" , "person" ): torch .tensor ([2 ]),
190
+ ("person" , "leads" , "person" ): torch .tensor ([1 ])
191
+ },
192
+ edge = {
193
+ ("person" , "works_with" , "person" ): torch .tensor ([1 ]),
194
+ ("person" , "leads" , "person" ): torch .tensor ([0 ])
195
+ },
183
196
batch = {"person" : torch .tensor ([0 , 0 , 0 ])},
184
197
num_sampled_nodes = {"person" : torch .tensor ([1 , 1 , 1 ])},
185
- num_sampled_edges = {("person" , "works_with" , "person" ): torch .tensor ([1 ]), ("person" , "leads" , "person" ): torch .tensor ([1 ])},
198
+ num_sampled_edges = {
199
+ ("person" , "works_with" , "person" ): torch .tensor ([1 ]),
200
+ ("person" , "leads" , "person" ): torch .tensor ([1 ])
201
+ },
186
202
orig_row = None ,
187
203
orig_col = None ,
188
204
metadata = (None , None ),
189
205
)
190
206
191
- local_values = []
192
- global_values = []
193
-
194
207
global_row , global_col = output .global_row , output .global_col
195
- assert _tensor_dict_equal (global_row , {("person" , "works_with" , "person" ): torch .tensor ([2 ]), ("person" , "leads" , "person" ): torch .tensor ([0 ])})
196
- assert _tensor_dict_equal (global_col , {("person" , "works_with" , "person" ): torch .tensor ([3 ]), ("person" , "leads" , "person" ): torch .tensor ([2 ])})
197
-
198
- local_row_dict = {k : global_to_local_node_idx (output .node [k [0 ]], v ) for k , v in global_row .items ()}
208
+ assert _tensor_dict_equal (
209
+ global_row , {
210
+ ("person" , "works_with" , "person" ): torch .tensor ([2 ]),
211
+ ("person" , "leads" , "person" ): torch .tensor ([0 ])
212
+ })
213
+ assert _tensor_dict_equal (
214
+ global_col , {
215
+ ("person" , "works_with" , "person" ): torch .tensor ([3 ]),
216
+ ("person" , "leads" , "person" ): torch .tensor ([2 ])
217
+ })
218
+
219
+ local_row_dict = {
220
+ k : global_to_local_node_idx (output .node [k [0 ]], v )
221
+ for k , v in global_row .items ()
222
+ }
199
223
assert _tensor_dict_equal (local_row_dict , output .row )
200
224
201
- local_col_dict = {k : global_to_local_node_idx (output .node [k [2 ]], v ) for k , v in global_col .items ()}
225
+ local_col_dict = {
226
+ k : global_to_local_node_idx (output .node [k [2 ]], v )
227
+ for k , v in global_col .items ()
228
+ }
202
229
assert _tensor_dict_equal (local_col_dict , output .col )
203
230
204
231
seed_node = output .seed_node
205
232
assert _tensor_dict_equal (seed_node , {"person" : torch .tensor ([0 , 0 , 0 ])})
206
233
207
- local_batch_dict = {k : global_to_local_node_idx (output .node [k ], v ) for k , v in seed_node .items ()}
234
+ local_batch_dict = {
235
+ k : global_to_local_node_idx (output .node [k ], v )
236
+ for k , v in seed_node .items ()
237
+ }
208
238
assert _tensor_dict_equal (local_batch_dict , output .batch )
209
239
210
240
output_bidirectional = output .to_bidirectional (keep_orig_edges = True )
211
- global_bidir_row , global_bidir_col = output_bidirectional .global_row , output_bidirectional .global_col
212
- assert _tensor_dict_equal (global_bidir_row , {("person" , "works_with" , "person" ): torch .tensor ([3 , 2 ]), ("person" , "leads" , "person" ): torch .tensor ([2 , 0 ])})
213
- assert _tensor_dict_equal (global_bidir_col , {("person" , "works_with" , "person" ): torch .tensor ([2 , 3 ]), ("person" , "leads" , "person" ): torch .tensor ([0 , 2 ])})
214
-
215
- local_bidir_row_dict = {k : global_to_local_node_idx (output_bidirectional .node [k [0 ]], v ) for k , v in global_bidir_row .items ()}
241
+ global_bidir_row , global_bidir_col = \
242
+ output_bidirectional .global_row , output_bidirectional .global_col
243
+ assert _tensor_dict_equal (
244
+ global_bidir_row , {
245
+ ("person" , "works_with" , "person" ): torch .tensor ([3 , 2 ]),
246
+ ("person" , "leads" , "person" ): torch .tensor ([2 , 0 ])
247
+ })
248
+ assert _tensor_dict_equal (
249
+ global_bidir_col , {
250
+ ("person" , "works_with" , "person" ): torch .tensor ([2 , 3 ]),
251
+ ("person" , "leads" , "person" ): torch .tensor ([0 , 2 ])
252
+ })
253
+
254
+ local_bidir_row_dict = {
255
+ k : global_to_local_node_idx (output_bidirectional .node [k [0 ]], v )
256
+ for k , v in global_bidir_row .items ()
257
+ }
216
258
assert _tensor_dict_equal (local_bidir_row_dict , output_bidirectional .row )
217
259
218
- local_bidir_col_dict = {k : global_to_local_node_idx (output_bidirectional .node [k [2 ]], v ) for k , v in global_bidir_col .items ()}
260
+ local_bidir_col_dict = {
261
+ k : global_to_local_node_idx (output_bidirectional .node [k [2 ]], v )
262
+ for k , v in global_bidir_col .items ()
263
+ }
219
264
assert _tensor_dict_equal (local_bidir_col_dict , output_bidirectional .col )
220
265
221
- assert _tensor_dict_equal (output .global_row , output_bidirectional .global_orig_row )
222
- assert _tensor_dict_equal (output .global_col , output_bidirectional .global_orig_col )
266
+ assert _tensor_dict_equal (output .global_row ,
267
+ output_bidirectional .global_orig_row )
268
+ assert _tensor_dict_equal (output .global_col ,
269
+ output_bidirectional .global_orig_col )
0 commit comments