Hello, I'm totally new here and have been dabbling with GNNs and PyTorch Geometric for the past few months. First, I want to say it's an amazing toolkit, so thank you for making it available.

I'm trying to train a GNN to make graph-level predictions: I pass the GNN a graph expressed as a `HeteroData` object, and it returns a single floating-point prediction. I have written something that works by calling forward individually for each training graph, collecting the outputs into a list, and then computing an `MSELoss` against another list that holds the target (y) values. It works, but it's quite slow, so I'm trying to learn how to do it properly with batching, and I'm running into some problems. Currently, I have something like this:

```python
import torch
from torch_geometric.data import HeteroData
from torch_geometric.loader import DataLoader  # PyG's loader, not torch.utils.data.DataLoader

train_data = []
for en in range(nrOfExamples):
    idata = input.exampleTraining[en]
    gdata = HeteroData()
    gdata['node'].x = torch.tensor(idata.nodes.x)
    gdata['node', 'dependson', 'node'].edge_index = ...
    # ... remaining node/edge attributes, etc.
    train_data.append(gdata)

# shuffle=True draws a new random order of whole graphs each time an iterator is created
loader = DataLoader(train_data, batch_size=100, shuffle=True)

for epoch in range(nrTrainingEpochs):
    batch = next(iter(loader))
    # ... call the GNN forward and run the training optimizer ...
```

In my GNN, I'm using some global pooling to get the single-value graph-level prediction, but I'm not sure how to get the batch index array from the batch. I've seen examples online where `batch.batch` returns an array indicating which of the graphs returned by the DataLoader each node belongs to. However, when I try calling it, I get this error
I'm sure I'm missing something simple, so I'm hoping to get some help from the knowledgeable people here. Thank you in advance.

Oh, and a second question: when `next(iter(loader))` returns a batch, does it always include all nodes from a randomly chosen subset of graphs, or does it select a subset of the nodes from a subset of the graphs? I want the DataLoader to always return the complete set of nodes from a randomly chosen subset of graphs. Is DataLoader the right thing to use?
Never mind, I just found the `batch` array. I wasn't looking hard enough, or thinking it through properly. :) For anyone with the same question in the future: in my example, it is accessed on the mini-batch returned by the loader via `batch['node'].batch` (each node type has its own batch vector).
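To illustrate what that vector is (and to answer the second question), here is a simplified plain-PyTorch sketch of how a graph-level loader collates a mini-batch. It is not PyG's actual implementation, and `collate_graphs` and `mean_pool` are hypothetical helpers for a single node type. The idea: whole graphs are merged into one big disconnected graph, node ids are shifted so graphs don't collide, and a batch vector records which graph each node came from. PyG's `DataLoader` works on whole graphs in this way and never samples individual nodes; for `HeteroData` it builds one such vector per node type. A pooling layer such as `torch_geometric.nn.global_mean_pool(x, batch)` then reduces node features to one row per graph.

```python
import torch


def collate_graphs(graphs):
    """Merge whole graphs into one disconnected graph (simplified: one node type).

    graphs: list of (x, edge_index) pairs, with x of shape [num_nodes, F]
    and edge_index of shape [2, num_edges] using local node ids.
    Returns stacked features, shifted edge_index, and a batch vector holding,
    for every node, the index of the graph it came from.
    """
    xs, edges, batch = [], [], []
    offset = 0
    for graph_id, (x, edge_index) in enumerate(graphs):
        xs.append(x)
        # Shift local node ids so nodes from different graphs never collide.
        edges.append(edge_index + offset)
        # Every node of this graph gets the same graph id.
        batch.append(torch.full((x.size(0),), graph_id, dtype=torch.long))
        offset += x.size(0)
    return torch.cat(xs, dim=0), torch.cat(edges, dim=1), torch.cat(batch)


def mean_pool(x, batch, num_graphs):
    """Average node features per graph using the batch vector."""
    sums = torch.zeros(num_graphs, x.size(1), dtype=x.dtype).index_add_(0, batch, x)
    counts = torch.bincount(batch, minlength=num_graphs).clamp(min=1).to(x.dtype)
    return sums / counts.unsqueeze(1)


# Two tiny graphs: 2 nodes / 1 edge and 3 nodes / 2 edges, one feature per node.
g1 = (torch.tensor([[1.0], [3.0]]), torch.tensor([[0], [1]]))
g2 = (torch.tensor([[10.0], [20.0], [30.0]]), torch.tensor([[0, 1], [1, 2]]))

x, edge_index, batch = collate_graphs([g1, g2])
print(batch.tolist())       # [0, 0, 1, 1, 1]
print(edge_index.tolist())  # [[0, 2, 3], [1, 3, 4]]
print(mean_pool(x, batch, num_graphs=2).tolist())  # [[2.0], [20.0]]
```

Because every node of a selected graph is included, each mini-batch always contains complete graphs. With `shuffle=True`, which graphs end up together changes every time the loader is iterated.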
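When I needed to pass in the batch vectors for all the different node types in the graph, I simply tried replacing `data.batch` with `data.batch_dict`, and it worked.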
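Building on the `batch_dict` idea, here is a minimal plain-PyTorch sketch of a heterogeneous graph-level regression step. The whole mini-batch goes through one forward pass and one MSE loss, which replaces a Python loop over individual graphs. `HeteroGraphRegressor`, `mean_pool`, and the second node type `other` are hypothetical. The sketch assumes that in a real PyG loop the inputs would be the GNN's per-node-type output embeddings, `batch.batch_dict`, and `batch.num_graphs`, and that `torch_geometric.nn.global_mean_pool` could replace the hand-written pooling. If each `HeteroData` stores a graph-level target such as `gdata.y = torch.tensor([value])`, the loader concatenates these into a `[num_graphs]` tensor that lines up with the predictions.

```python
import torch
from torch import nn


def mean_pool(x, batch, num_graphs):
    """Average node features per graph using a batch vector."""
    sums = torch.zeros(num_graphs, x.size(1), dtype=x.dtype).index_add_(0, batch, x)
    counts = torch.bincount(batch, minlength=num_graphs).clamp(min=1).to(x.dtype)
    return sums / counts.unsqueeze(1)


class HeteroGraphRegressor(nn.Module):
    """Hypothetical readout: pool each node type with its own batch vector,
    concatenate the per-type graph embeddings, and predict one float per graph."""

    def __init__(self, node_types, hidden_dim):
        super().__init__()
        self.node_types = list(node_types)
        self.lin = nn.Linear(hidden_dim * len(self.node_types), 1)

    def forward(self, x_dict, batch_dict, num_graphs):
        pooled = [mean_pool(x_dict[t], batch_dict[t], num_graphs) for t in self.node_types]
        return self.lin(torch.cat(pooled, dim=1)).squeeze(-1)  # shape [num_graphs]


# Stand-ins for the node embeddings and batch vectors of a 2-graph mini-batch.
# "other" is a hypothetical second node type.
torch.manual_seed(0)
x_dict = {"node": torch.randn(5, 4), "other": torch.randn(3, 4)}
batch_dict = {"node": torch.tensor([0, 0, 1, 1, 1]), "other": torch.tensor([0, 1, 1])}
y = torch.tensor([0.5, -1.0])  # one target per graph

model = HeteroGraphRegressor(["node", "other"], hidden_dim=4)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

pred = model(x_dict, batch_dict, num_graphs=2)
loss = nn.functional.mse_loss(pred, y)  # one loss over the whole mini-batch
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Pooling each node type separately keeps the readout independent of how many nodes of each type a graph has. Concatenating the pooled vectors is just one simple way to combine them.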