1
- # Copyright (c) 2024, NVIDIA CORPORATION.
1
+ # Copyright (c) 2024-2025 , NVIDIA CORPORATION.
2
2
# Licensed under the Apache License, Version 2.0 (the "License");
3
3
# you may not use this file except in compliance with the License.
4
4
# You may obtain a copy of the License at
@@ -93,7 +93,11 @@ def train(epoch, model, optimizer, train_loader, edge_feature_store, num_steps=N
93
93
optimizer .zero_grad ()
94
94
95
95
for i , batch in enumerate (train_loader ):
96
- r = edge_feature_store [("n" , "e" , "n" ), "rel" ][batch .e_id ].flatten ().cuda ()
96
+ r = (
97
+ edge_feature_store [("n" , "e" , "n" ), "rel" , None ][batch .e_id ]
98
+ .flatten ()
99
+ .cuda ()
100
+ )
97
101
z = model .encode (batch .edge_index , r )
98
102
99
103
loss = model .recon_loss (z , batch .edge_index )
@@ -301,13 +305,18 @@ def load_partitioned_data(rank, edge_path, rel_path, pos_path, neg_path, meta_pa
301
305
feature_store = TensorDictFeatureStore ()
302
306
edge_feature_store = WholeFeatureStore ()
303
307
308
+ with open (meta_path , "r" ) as f :
309
+ meta = json .load (f )
310
+
311
+ print ("num nodes:" , meta ["num_nodes" ])
312
+
304
313
# Load edge index
305
- graph_store [( "n" , "e" , "n" ), "coo" ] = torch . load (
306
- os . path . join ( edge_path , f"rank= { rank } .pt" )
307
- )
314
+ graph_store [
315
+ ( "n" , "e" , "n" ), "coo" , False , ( meta [ "num_nodes" ], meta [ "num_nodes" ] )
316
+ ] = torch . load ( os . path . join ( edge_path , f"rank= { rank } .pt" ) )
308
317
309
318
# Load edge rel type
310
- edge_feature_store [("n" , "e" , "n" ), "rel" ] = torch .load (
319
+ edge_feature_store [("n" , "e" , "n" ), "rel" , None ] = torch .load (
311
320
os .path .join (rel_path , f"rank={ rank } .pt" )
312
321
)
313
322
@@ -333,9 +342,6 @@ def load_partitioned_data(rank, edge_path, rel_path, pos_path, neg_path, meta_pa
333
342
splits [stage ]["tail_neg" ] = tail_neg
334
343
splits [stage ]["relation" ] = relation
335
344
336
- with open (meta_path , "r" ) as f :
337
- meta = json .load (f )
338
-
339
345
return (feature_store , graph_store ), edge_feature_store , splits , meta
340
346
341
347
0 commit comments