[WIP] EquivariantConv
#2824
base: master
Conversation
merge upstream changes
Codecov Report
@@ Coverage Diff @@
## master #2824 +/- ##
==========================================
+ Coverage 84.23% 84.33% +0.09%
==========================================
Files 209 210 +1
Lines 9524 9579 +55
==========================================
+ Hits 8023 8078 +55
Misses 1501 1501
Continue to review full report at Codecov.
Hi @wsad1, thanks for your implementation of the EGNN layer. This implementation updates the positional features based on the local neighbourhood only, right? https://github.com/rusty1s/pytorch_geometric/blob/a3118b4406e12c817e20c3efc8210b228e886d76/torch_geometric/nn/conv/equivariant_conv.py#L111-L145 In the original paper, the positions are updated based on all other nodes of the graph, i.e. on a fully connected graph.
@tuanle618, thanks for bringing this up. You are absolutely right, the positional embedding should be updated based on all nodes in the graph, not just the local neighbourhood.
The above approach would not support ...
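For reference (my paraphrase, using the notation of the docstring further below), the position update in the paper sums over all other nodes rather than only the neighbourhood:

\mathbf{pos}^{\prime}_i = \mathbf{pos}_i + C \sum_{j \neq i} (\mathbf{pos}_i - \mathbf{pos}_j)\, \rho_{\mathbf{\Theta}}(\mathbf{m}_{ij}), \qquad C = \tfrac{1}{N - 1}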
Hi @wsad1, thanks a lot for your quick response and implementation suggestions. I would like to support you on this and have created some code for steps (1) and (2). For supporting ...
Step (3) would come from you, by providing ...
import torch
from torch_scatter import scatter_add
from torch_geometric.data import Batch
# assumes data is of type `torch_geometric.data.batch.Batch`
x, batch, ptr = data.x, data.batch, data.ptr
batch_size = batch.max().item() + 1
edge_index, edge_attr = data.edge_index, data.edge_attr
def get_fully_connected_edges(n_nodes: int, add_self_loops: bool = False):
rows, cols = [], []
for i in range(n_nodes):
for j in range(n_nodes):
if i != j or (i == j and add_self_loops):
rows.append(i)
cols.append(j)
edges = [rows, cols]
edges = torch.tensor(edges, dtype=torch.long).contiguous()
return edges
batch_num_nodes = scatter_add(src=batch.new_ones(x.size(0)), index=batch, dim=0, dim_size=batch_size)
edge_index = data.edge_index
fc_edge_index = torch.cat([get_fully_connected_edges(n) + p for n, p in zip(batch_num_nodes, ptr)], dim=-1)
# a memory-inefficient and maybe slower version
# adjs = torch.block_diag(*[torch.ones(n, n).fill_diagonal_(0.0) for n in batch_num_nodes]).nonzero().t().contiguous()
# torch.allclose(fc_edge_index, adjs)
# find positions of true edge_index
source, target = edge_index[0].cpu().numpy().tolist(), edge_index[1].cpu().numpy().tolist()
source_target_to_edge_idx = {str([s, t]): i for s, t, i in zip(source, target, range(len(source)))}
edge_idx_to_source_target = {v: k for k, v in source_target_to_edge_idx.items()}
# positions of fake edge_index
source_fc, target_fc = fc_edge_index[0].cpu().numpy().tolist(), fc_edge_index[1].cpu().numpy().tolist()
source_target_to_fc_edge_idx = {str([s, t]): i for s, t, i in zip(source_fc, target_fc, range(len(source_fc)))}
fc_edge_idx_to_source_target = {v: k for k, v in source_target_to_fc_edge_idx.items()}
fake_edges = [s for s in source_target_to_fc_edge_idx.keys() if s not in source_target_to_edge_idx.keys()]
fake_edges_ids = [source_target_to_fc_edge_idx[k] for k in fake_edges]
E_fc = fc_edge_index.shape[1]
E = edge_index.shape[1]
assert len(fake_edges) == E_fc - E
fake_edge_index = fc_edge_index.t()[fake_edges_ids].t()
fake_edge_attr = torch.zeros(size=(fake_edge_index.size(1), edge_attr.size(-1)),
device=x.device)
all_edge_index = torch.cat([edge_index, fake_edge_index], dim=-1)
all_edge_attr = torch.cat([edge_attr, fake_edge_attr], dim=0)
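As a side note (my own sketch, not part of the comment above), the nested Python loops in get_fully_connected_edges can also be written in a vectorized form that yields the same edge ordering:
import torch

# Hedged alternative to the loop-based helper above (assumes the same row-major edge ordering is wanted).
def get_fully_connected_edges_vectorized(n_nodes: int, add_self_loops: bool = False) -> torch.Tensor:
    idx = torch.arange(n_nodes)
    edges = torch.cartesian_prod(idx, idx).t()  # shape [2, n_nodes ** 2], ordered (0, 0), (0, 1), ...
    if not add_self_loops:
        edges = edges[:, edges[0] != edges[1]]  # drop the (i, i) pairs
    return edges.contiguous()

# Both helpers should agree, e.g.:
# assert torch.equal(get_fully_connected_edges(5), get_fully_connected_edges_vectorized(5))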
I've modified your implemented version of the conv. Additionally, I removed the ... Find below the code, which I've slightly modified from your version. The Conv:
from typing import Optional, Callable, Tuple
from torch_geometric.typing import OptTensor, Adj
import torch
from torch import Tensor
from torch.nn import Linear
from torch_scatter import scatter
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops
from torch_geometric.nn.inits import reset
class EquivariantConv(MessagePassing):
r"""The Equivariant graph neural network operator form the
`"E(n) Equivariant Graph Neural Networks"
<https://arxiv.org/pdf/2102.09844.pdf>`_ paper.
.. math::
\mathbf{m}_{ij}=h_{\mathbf{\Theta}}(\mathbf{x}_i,\mathbf{x}_j,\|
{\mathbf{pos}_i-\mathbf{pos}_j}\|^2_2,\mathbf{e}_{ij})
\mathbf{x}^{\prime}_i = \gamma_{\mathbf{\Theta}}(\mathbf{x}_i,
\sum_{j \in \mathcal{N}(i)} \mathbf{m}_{ij})
\mathbf{vel}^{\prime}_i = \phi_{\mathbf{\Theta}}(\mathbf{x}_i)\mathbf
{vel}_i + \frac{1}{|\mathcal{N}(i)|}\sum_{j \in\mathcal{N}(i)}
(\mathbf{pos}_i-\mathbf{pos}_j)
\rho_{\mathbf{\Theta}}(\mathbf{m}_{ij})
\mathbf{pos}^{\prime}_i = \mathbf{pos}_i + \mathbf{vel}_i
where :math:`\gamma_{\mathbf{\Theta}}`,
:math:`h_{\mathbf{\Theta}}`, :math:`\rho_{\mathbf{\Theta}}`
and :math:`\phi_{\mathbf{\Theta}}` denote neural
networks, *i.e.* MLPs. :math:`\mathbf{P} \in \mathbb{R}^{N \times D}`
and :math:`\mathbf{V} \in \mathbb{R}^{N \times D}`
define the positions and velocities of the points, respectively.
Args:
local_nn (torch.nn.Module, optional): A neural network
:math:`h_{\mathbf{\Theta}}` that maps node features :obj:`x`,
squared distance :math:`\|{\mathbf{pos}_i-\mathbf{pos}_j}\|^2_2`
and edge_features :obj:`edge_attr`
of shape :obj:`[-1, 2*in_channels + 1 + edge_dim]`
to shape :obj:`[-1, hidden_channels]`, *e.g.*, defined by
:class:`torch.nn.Sequential`. (default: :obj:`None`)
pos_nn (torch.nn.Module, optional): A neural network
:math:`\rho_{\mathbf{\Theta}}` that
maps message :obj:`m` of shape
:obj:`[-1, hidden_channels]`,
to shape :obj:`[-1, 1]`, *e.g.*, defined by
:class:`torch.nn.Sequential`. (default: :obj:`None`)
vel_nn (torch.nn.Module, optional): A neural network
:math:`\phi_{\mathbf{\Theta}}` that
maps node features :obj:`x` of shape :obj:`[-1, in_channels]`,
to shape :obj:`[-1, 1]`, *e.g.*, defined by
:class:`torch.nn.Sequential`. (default: :obj:`None`)
global_nn (torch.nn.Module, optional): A neural network
:math:`\gamma_{\mathbf{\Theta}}` that maps
message :obj:`m` after aggregation
and node features :obj:`x` of shape
:obj:`[-1, hidden_channels + in_channels]`
to shape :obj:`[-1, out_channels]`, *e.g.*, defined by
:class:`torch.nn.Sequential`. (default: :obj:`None`)
add_self_loops (bool, optional): If set to :obj:`False`, will not add
self-loops to the input graph. (default: :obj:`True`)
aggr (string, optional): The operator used to aggregate message
:obj:`m` (:obj:`"add"`, :obj:`"mean"`).
(default: :obj:`"mean"`)
**kwargs (optional): Additional arguments of
:class:`torch_geometric.nn.conv.MessagePassing`.
"""
def __init__(self, local_nn: Optional[Callable] = None,
pos_nn: Optional[Callable] = None,
vel_nn: Optional[Callable] = None,
global_nn: Optional[Callable] = None,
add_self_loops: bool = True,
aggr="mean", **kwargs):
super(EquivariantConv, self).__init__(aggr=aggr, **kwargs)
self.local_nn = local_nn
self.pos_nn = pos_nn
self.vel_nn = vel_nn
self.global_nn = global_nn
self.add_self_loops = add_self_loops
self.reset_parameters()
def reset_parameters(self):
reset(self.local_nn)
reset(self.pos_nn)
reset(self.vel_nn)
reset(self.global_nn)
def forward(self, x: OptTensor,
pos: Tensor,
edge_index: Adj,
fc_edge_index: Adj,
vel: OptTensor = None,
edge_attr: OptTensor = None
) -> Tuple[Tensor, Tuple[Tensor, OptTensor]]:
""""""
self.__calculated_msgs = (None, None)
self.__E = edge_index.size(1)
# propagate_type: (x: OptTensor, pos: Tensor, edge_attr: OptTensor) -> Tuple[Tensor,Tensor] # noqa
_, out_pos = self.propagate(fc_edge_index, x=x, pos=pos,
edge_attr=edge_attr, size=None)
out_x, _ = self.propagate(edge_index, x=x, pos=pos,
edge_attr=edge_attr, size=None)
out_x = out_x if x is None else torch.cat([x, out_x], dim=1)
if self.global_nn is not None:
out_x = self.global_nn(out_x)
if vel is None:
out_pos += pos
out_vel = None
else:
out_vel = (vel if self.vel_nn is None or x is None else
self.vel_nn(x) * vel) + out_pos
out_pos = pos + out_vel
self.__calculated_msgs = (None, None)
return (out_x, (out_pos, out_vel))
def message(self, x_i: OptTensor, x_j: OptTensor, pos_i: Tensor,
pos_j: Tensor,
edge_attr: OptTensor = None) -> Tuple[Tensor, Tensor]:
# only do this calculation once
if self.__calculated_msgs[0] is None and self.__calculated_msgs[1] is None:
msg = torch.sum((pos_i - pos_j).square(), dim=1, keepdim=True)
msg = msg if x_j is None else torch.cat([x_j, msg], dim=1)
msg = msg if x_i is None else torch.cat([x_i, msg], dim=1)
msg = msg if edge_attr is None else torch.cat([msg, edge_attr], dim=1)
msg = msg if self.local_nn is None else self.local_nn(msg)
pos_msg = ((pos_i - pos_j) if self.pos_nn is None else
(pos_i - pos_j) * self.pos_nn(msg))
self.__calculated_msgs = (msg, pos_msg)
return (msg, pos_msg)
else:
return (self.__calculated_msgs[0][:self.__E], self.__calculated_msgs[1][:self.__E])
def aggregate(self, inputs: Tuple[Tensor, Tensor],
index: Tensor) -> Tuple[Tensor, Tensor]:
return (scatter(inputs[0], index, 0, reduce=self.aggr),
scatter(inputs[1], index, 0, reduce="mean"))
def update(self, inputs: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
return inputs
def __repr__(self):
return ("{}(local_nn={}, pos_nn={},"
" vel_nn={},"
" global_nn={})").format(self.__class__.__name__,
self.local_nn, self.pos_nn,
self.vel_nn, self.global_nn)
A minimal example using networkx-generated random graphs, without the velocities:
from torch_geometric.utils import from_networkx
import networkx as nx
from torch_geometric.data import DataLoader
from torch_scatter import scatter_add
def get_fully_connected_edges(n_nodes: int, add_self_loops: bool = False):
rows, cols = [], []
for i in range(n_nodes):
for j in range(n_nodes):
if i != j or (i == j and add_self_loops):
rows.append(i)
cols.append(j)
edges = [rows, cols]
edges = torch.tensor(edges, dtype=torch.long).contiguous()
return edges
def get_fc_edge_index(batch_num_nodes: list, ptr: torch.Tensor,
edge_index: torch.Tensor, edge_attr: torch.Tensor) -> (torch.Tensor, torch.Tensor):
fc_edge_index = torch.cat([get_fully_connected_edges(n) + p for n, p in zip(batch_num_nodes, ptr)], dim=-1)
source, target = edge_index[0].cpu().numpy().tolist(), edge_index[1].cpu().numpy().tolist()
source_target_to_edge_idx = {str([s, t]): i for s, t, i in zip(source, target, range(len(source)))}
# edge_idx_to_source_target = {v: k for k, v in source_target_to_edge_idx.items()}
# positions of fake edge_index
source_fc, target_fc = fc_edge_index[0].cpu().numpy().tolist(), fc_edge_index[1].cpu().numpy().tolist()
source_target_to_fc_edge_idx = {str([s, t]): i for s, t, i in zip(source_fc, target_fc, range(len(source_fc)))}
# fc_edge_idx_to_source_target = {v: k for k, v in source_target_to_fc_edge_idx.items()}
fake_edges = [s for s in source_target_to_fc_edge_idx.keys() if s not in source_target_to_edge_idx.keys()]
fake_edges_ids = [source_target_to_fc_edge_idx[k] for k in fake_edges]
E_fc = fc_edge_index.shape[1]
E = edge_index.shape[1]
assert len(fake_edges) == E_fc - E
fake_edge_index = fc_edge_index.t()[fake_edges_ids].t()
fake_edge_attr = torch.zeros(size=(fake_edge_index.size(1), edge_attr.size(-1)),
device=edge_attr.device)
all_edge_index = torch.cat([edge_index, fake_edge_index], dim=-1)
all_edge_attr = torch.cat([edge_attr, fake_edge_attr], dim=0)
return all_edge_index, all_edge_attr
def create_random_graph_pyg(n: int, seed: int):
G = nx.random_geometric_graph(n=n, radius=0.125, dim=3, seed=seed)
data = from_networkx(G)
data.x = torch.randn(data.num_nodes, 16)
data.edge_attr = torch.randn(data.edge_index.size(1), 8)
return data
batch_size = 16
max_num_nodes = 100
seed = 42
batch_num_nodes = torch.randint(low=20, high=max_num_nodes, size=(batch_size, ), dtype=torch.long)
datalist = [create_random_graph_pyg(n, seed + i) for i, n in enumerate(batch_num_nodes)]
loader = DataLoader(datalist, 16)
data = next(iter(loader))
x, pos, batch, ptr, edge_index, edge_attr = data.x, data.pos, data.batch, data.ptr, data.edge_index, data.edge_attr
batch_num_nodes = scatter_add(src=batch.new_ones(x.size(0)), index=batch, dim=0, dim_size=batch_size).cpu().numpy().tolist()
fc_edge_index, fc_edge_attr = get_fc_edge_index(batch_num_nodes=batch_num_nodes, ptr=ptr,
edge_index=edge_index, edge_attr=edge_attr)
node_in_channels = 16
edge_in_channels = 8
pos_in_channels = 3
local_nn = Linear(2 * node_in_channels + 1 + edge_in_channels, node_in_channels, bias=False)
pos_nn = Linear(node_in_channels, 1, bias=True)
global_nn = Linear(2 * node_in_channels, node_in_channels, bias=True)
conv = EquivariantConv(local_nn=local_nn, pos_nn=pos_nn, global_nn=global_nn)
x_out, (pos_out, _) = conv(x=x, pos=pos, edge_index=edge_index, fc_edge_index=fc_edge_index, edge_attr=fc_edge_attr)
# test without edge-attrs
node_in_channels = 16
edge_in_channels = 0
pos_in_channels = 3
local_nn = Linear(2 * node_in_channels + 1 + edge_in_channels, node_in_channels, bias=False)
pos_nn = Linear(node_in_channels, 1, bias=True)
global_nn = Linear(2 * node_in_channels, node_in_channels, bias=True)
conv = EquivariantConv(local_nn=local_nn, pos_nn=pos_nn, global_nn=global_nn)
x_out, (pos_out, _) = conv(x=x, pos=pos, edge_index=edge_index, fc_edge_index=fc_edge_index, edge_attr=None)
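A quick sanity check one could add here (my own sketch, not part of the original comment): node outputs should be invariant and position outputs equivariant under a random rotation plus translation of pos, reusing the conv, x, pos, edge_index and fc_edge_index defined above:
import torch

# Hedged E(n)-equivariance check, reusing `conv`, `x`, `pos`, `edge_index` and
# `fc_edge_index` from the snippet above (edge_attr=None, as in the last call).
Q, _ = torch.linalg.qr(torch.randn(3, 3))  # random orthogonal 3x3 matrix
t = torch.randn(1, 3)                      # random translation

x_out1, (pos_out1, _) = conv(x=x, pos=pos, edge_index=edge_index, fc_edge_index=fc_edge_index, edge_attr=None)
x_out2, (pos_out2, _) = conv(x=x, pos=pos @ Q + t, edge_index=edge_index, fc_edge_index=fc_edge_index, edge_attr=None)

# Node features should be invariant, positions equivariant:
assert torch.allclose(x_out1, x_out2, atol=1e-4)
assert torch.allclose(pos_out1 @ Q + t, pos_out2, atol=1e-4)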
@tuanle618, firstly, I would be happy to collaborate with you on this. Please feel free to send PRs to this branch with your code edits. Second, I appreciate your effort to fix ...
1. I think ...
Thanks for your suggestions @wsad1.
out_x, out_pos = self.propagate(edge_index=fc_edge_index, x=x, pos=pos,
edge_attr=edge_attr, orig_index=edge_index[1], size=None)
def aggregate(self, inputs: Tuple[Tensor, Tensor],
index: Tensor, orig_index: Tensor) -> Tuple[Tensor, Tensor]:
return (scatter(inputs[0], orig_index, 0, reduce=self.aggr), # aggregate on original edges
scatter(inputs[1], index, 0, reduce="mean")) # aggregate on fc edges
where the first scatter needs to be scatter(inputs[0][:len(orig_index)], orig_index, 0, reduce=self.aggr), since inputs[0] holds messages for all fully connected edges while orig_index only covers the original ones.
The code currently runs without errors, but I'd like to add some more tests to make sure the aggregations are done as intended. I'll ping you once I've made the PR. Best regards,
Hi, I am running the following:
node_in_channels = 16
edge_in_channels = 8
pos_in_channels = 3
local_nn = Linear(2 * node_in_channels + 1 + edge_in_channels, node_in_channels, bias=False)
pos_nn = Linear(node_in_channels, 1, bias=True)
global_nn = Linear(2 * node_in_channels, node_in_channels, bias=True)
conv = EquivariantConv(local_nn=local_nn, pos_nn=pos_nn, global_nn=global_nn)
x_out, (pos_out, _) = conv(x=x, pos=pos, edge_index=edge_index, fc_edge_index=fc_edge_index, edge_attr=fc_edge_attr)
I want to be sure that the first ...
Best regards
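If the question is about the first E messages lining up with the original edges (my guess at the truncated sentence above), one quick check, assuming fc_edge_index and fc_edge_attr come from get_fc_edge_index as defined earlier, is:
import torch

# Hedged check: get_fc_edge_index concatenates the fake edges after the real ones,
# so the first E columns/rows should be exactly the original edge_index/edge_attr.
E = edge_index.size(1)
assert torch.equal(fc_edge_index[:, :E], edge_index)
assert torch.equal(fc_edge_attr[:E], edge_attr)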
@wsad1 What is the status of this PR? I'd be happy to help if there is still work to be done.
Equivariant conv as proposed in this paper. Two things to note:
1. message, aggregate and update return a tuple of tensors.
2. propagate returns a tuple of tensors and not a tensor. One potential workaround would be to have update return a tensor for the node embedding; the positional update could be saved in a class variable. Let me know if that change needs to be made or if you think of a better fix.
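A rough sketch of that workaround (my own sketch, not part of the PR) could look like this, with the method body as it would sit inside EquivariantConv:
from typing import Tuple
from torch import Tensor

# Hedged sketch of the workaround described above: `update` returns only the
# node-embedding tensor, so `propagate` keeps its usual single-Tensor return type,
# while the aggregated positional update is stashed on the module.
def update(self, inputs: Tuple[Tensor, Tensor]) -> Tensor:
    out_x, out_pos = inputs
    self._out_pos = out_pos  # hypothetical attribute, read back in forward()
    return out_x

# forward() would then be adjusted accordingly, e.g.:
#   out_x = self.propagate(edge_index, x=x, pos=pos, edge_attr=edge_attr, size=None)
#   out_pos = self._out_pos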