Hello,
I am trying to use one of the pre-packaged models (e.g. torch_geometric.nn.models.GraphSAGE) for a graph-level regression task.
Normally, when I build my own model from existing layers, I add a pooling layer at the end (e.g. global_max_pool) to get a graph-level output instead of a node-level output. However, I do not know how to do this with the pre-packaged models. Setting the out_channels parameter of GraphSAGE to 1 merely gives an output tensor of shape [some_big_number, 1], which I presume is [num_nodes_in_batch, 1]. I tried using torch.nn.Sequential to chain the pooling layer after GraphSAGE, but got errors like "global_max_pool is not a Module subclass".
Apologies if this is trivial or not the right forum; I am new to PyTorch and couldn't find anything relevant after googling.
Thanks!