[WIP] EquivariantConv #2824

Draft · wants to merge 16 commits into master
Conversation

wsad1
Member

@wsad1 wsad1 commented Jul 6, 2021

Equivariant conv as proposed in this paper. Two things to note:

  1. Equivariant conv updates a node's embedding as well as its position, so the functions message, aggregate and update return a tuple of tensors.
  2. jit seems to have some issues, as propagate returns a Tuple of tensors and not a tensor. One potential workaround would be to have update return a tensor for the node embedding, while the positional update is saved in a class variable; a rough sketch follows below. Let me know if that change needs to be made or if you think of a better fix.
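
A minimal sketch of that workaround (the class name and the attribute _out_pos are hypothetical, not part of this PR):

from typing import Tuple

from torch import Tensor
from torch_geometric.nn.conv import MessagePassing


class EquivariantConvSketch(MessagePassing):
    def update(self, inputs: Tuple[Tensor, Tensor]) -> Tensor:
        # Return only the node embedding so that propagate (and hence jit)
        # sees a single Tensor; stash the positional update on the module and
        # read it back in forward after calling propagate.
        out_x, out_pos = inputs
        self._out_pos = out_pos
        return out_x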

@codecov-commenter

codecov-commenter commented Jul 10, 2021

Codecov Report

Merging #2824 (a3118b4) into master (dc8f503) will increase coverage by 0.09%.
The diff coverage is 100.00%.


@@            Coverage Diff             @@
##           master    #2824      +/-   ##
==========================================
+ Coverage   84.23%   84.33%   +0.09%     
==========================================
  Files         209      210       +1     
  Lines        9524     9579      +55     
==========================================
+ Hits         8023     8078      +55     
  Misses       1501     1501              
Impacted Files Coverage Δ
torch_geometric/nn/conv/__init__.py 100.00% <100.00%> (ø)
torch_geometric/nn/conv/equivariant_conv.py 100.00% <100.00%> (ø)

Continue to review the full report at Codecov.

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 85c8b78...a3118b4.

@wsad1
Member Author

wsad1 commented Jul 31, 2021

@rusty1s, thanks for updating jittable in message passing (673f947). With that, equivariant conv is jittable.

@tuanle618

Hi @wsad1, thanks for your implementation of the EGNN layer. This implementation updates the positional features based on the local neighbourhood only, right? https://github.com/rusty1s/pytorch_geometric/blob/a3118b4406e12c817e20c3efc8210b228e886d76/torch_geometric/nn/conv/equivariant_conv.py#L111-L145 In the original paper, the positions x_i^{l+1} are computed by iterating over all nodes in the graph. I guess this makes it hard to combine within the MessagePassing framework, as the index provided is based on the given adjacency matrix.

@wsad1 wsad1 changed the title Added equivariant conv. [WIP] : Added equivariant conv. Sep 4, 2021
@wsad1 wsad1 marked this pull request as draft September 4, 2021 06:57
@wsad1
Member Author

wsad1 commented Sep 4, 2021

@tuanle618, thanks for bringing this up. You are absolutely right: the positional embedding pos should be updated based on all nodes in the graph, whereas the current implementation updates it based on node neighbours only.
One way to implement this would be to

  1. create a fully_connected_edge_index for each graph (see the sketch below), pass this to propagate as edge_index, and pass the original edge_index as another argument original_index.
  2. in message, x_i, x_j, pos_i and pos_j would be created based on the fully connected edge index.
  3. in aggregate, msg needs to be aggregated based on original_index[1] and pos based on index.

The above approach would not support edge_attr, and it's not very straightforward or memory efficient, so I might have to think about it more.
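
A rough sketch of step (1), assuming a batch vector that assigns every node to its graph (as produced by torch_geometric.data.Batch); the helper name is hypothetical:

import torch


def fully_connected_edge_index(batch: torch.Tensor) -> torch.Tensor:
    # Nodes i and j are connected iff they belong to the same graph.
    adj = batch.unsqueeze(0) == batch.unsqueeze(1)  # [N, N] boolean block-diagonal mask
    adj.fill_diagonal_(False)                       # drop self-loops
    return adj.nonzero(as_tuple=False).t().contiguous()  # shape [2, E_fc]


# Usage: fc_edge_index = fully_connected_edge_index(data.batch)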

@tuanle618

Hi @wsad1, thanks a lot for your quick response and implementation suggestions. I would like to support you in this matter and have created some code for steps (1) and (2).
For step (3), however, I am currently not certain how we would provide both edge_index and fc_edge_index to the self.aggregate function, as the final goal is actually to just call self.propagate once, right?

For supporting edge_attr, I was thinking of zero-padding the edge_attr tensor along dim=0.
In general, my steps currently are:
Let's say the number of edges is E = edge_index.size(1), and the size of the fully-connected edge_index is bounded by the number of nodes in our batch, i.e. E_fc = fc_edge_index.size(1).

  • (a) create the fully-connected edge index from the batch, i.e. fc_edge_index.
  • (b) get the "fake" indices within fc_edge_index that are not present in the "true" edge_index, and update fc_edge_index accordingly, so that fc_edge_index.size(1) = E_fc - E.
  • (c) now concatenate so that the ordering matches for the edge_attr, i.e. fc_edge_index = torch.cat([edge_index, fc_edge_index], dim=-1) and edge_attr = torch.cat([edge_attr, torch.zeros(size=(E_fc - E, edge_attr.size(-1)), device=x.device)], dim=0).
  • (d) do your step (3) by providing edge_index and fc_edge_index, where fc_edge_index is used to update the positional embeddings and edge_index just to gather all incoming messages to update the node embeddings.

@tuanle618

tuanle618 commented Sep 6, 2021

import torch
from torch_scatter import scatter_add
from torch_geometric.data import Batch

# assumes data is of type `torch_geometric.data.batch.Batch`
x, batch, ptr = data.x, data.batch, data.ptr
batch_size = batch.max().item() + 1
edge_index, edge_attr = data.edge_index, data.edge_attr


def get_fully_connected_edges(n_nodes: int, add_self_loops: bool = False):
    rows, cols = [], []
    for i in range(n_nodes):
        for j in range(n_nodes):
            if i != j or (i == j and add_self_loops):
                rows.append(i)
                cols.append(j)

    edges = [rows, cols]
    edges = torch.tensor(edges, dtype=torch.long).contiguous()
    return edges

batch_num_nodes = scatter_add(src=batch.new_ones(x.size(0)), index=batch, dim=0, dim_size=batch_size)
fc_edge_index = torch.cat([get_fully_connected_edges(n) + p for n, p in zip(batch_num_nodes, ptr)], dim=-1)

# a memory-inefficient and maybe slower version
# adjs = torch.block_diag(*[torch.ones(n, n).fill_diagonal_(0.0) for n in batch_num_nodes]).nonzero().t().contiguous()
# torch.allclose(fc_edge_index, adjs)

# find positions of true edge_index
source, target = edge_index[0].cpu().numpy().tolist(), edge_index[1].cpu().numpy().tolist()

source_target_to_edge_idx = {str([s, t]): i for s, t, i in zip(source, target, range(len(source)))}
edge_idx_to_source_target = {v: k for k, v in source_target_to_edge_idx.items()}

# positions of fake edge_index
source_fc, target_fc = fc_edge_index[0].cpu().numpy().tolist(), fc_edge_index[1].cpu().numpy().tolist()
source_target_to_fc_edge_idx = {str([s, t]): i for s, t, i in zip(source_fc, target_fc, range(len(source_fc)))}
fc_edge_idx_to_source_target = {v: k for k, v in source_target_to_fc_edge_idx.items()}

fake_edges = [s for s in source_target_to_fc_edge_idx.keys() if s not in source_target_to_edge_idx.keys()]
fake_edges_ids = [source_target_to_fc_edge_idx[k] for k in fake_edges]

E_fc = fc_edge_index.shape[1]
E = edge_index.shape[1]

assert len(fake_edges) == E_fc - E
fake_edge_index = fc_edge_index.t()[fake_edges_ids].t()

fake_edge_attr = torch.zeros(size=(fake_edge_index.size(1), edge_attr.size(-1)),
                             device=x.device)


all_edge_index = torch.cat([edge_index, fake_edge_index], dim=-1)
all_edge_attr = torch.cat([edge_attr, fake_edge_attr], dim=0)
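
# Sanity check (sketch): by construction, the first E columns of all_edge_index
# are exactly the original edge_index and the appended "fake" edges carry
# zero-valued edge features; the message caching/slicing discussed below relies
# on this ordering.
assert torch.equal(all_edge_index[:, :E], edge_index)
assert torch.equal(all_edge_attr[:E], edge_attr)
assert bool((all_edge_attr[E:] == 0).all())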

@tuanle618

I've modified your implementation of the EquivariantConv, @wsad1.
Right now, self.propagate is called twice.
The first call constructs all fully-connected messages based on fc_edge_index. After that call, the intermediate messages are saved internally as a tuple in self.__calculated_msgs.
The second call uses the true edge_index as input. Right now, two unnecessary aggregation steps are done: one for x in the first call and one for pos in the second call.

Additionally, I removed the add_self_loops argument: edge features for the self-loop could be included with zero-padded tensors, but the self-message does not make much sense in my opinion when constructing m_ij, as the positional distance is 0, and hence the input for the local_nn is just a concatenation of the node's own features with itself, together with zero vectors for the distance and edge_attr.

Find below the code I've slightly modified from your version:

The Conv:

from typing import Optional, Callable, Tuple
from torch_geometric.typing import OptTensor, Adj

import torch
from torch import Tensor
from torch.nn import Linear
from torch_scatter import scatter
from torch_geometric.nn.conv import MessagePassing

from torch_geometric.nn.inits import reset


class EquivariantConv(MessagePassing):
    r"""The Equivariant graph neural network operator form the
    `"E(n) Equivariant Graph Neural Networks"
    <https://arxiv.org/pdf/2102.09844.pdf>`_ paper.
    .. math::
        \mathbf{m}_{ij}=h_{\mathbf{\Theta}}(\mathbf{x}_i,\mathbf{x}_j,\|
        {\mathbf{pos}_i-\mathbf{pos}_j}\|^2_2,\mathbf{e}_{ij})
        \mathbf{x}^{\prime}_i = \gamma_{\mathbf{\Theta}}(\mathbf{x}_i,
        \sum_{j \in \mathcal{N}(i)} \mathbf{m}_{ij})
        \mathbf{vel}^{\prime}_i = \phi_{\mathbf{\Theta}}(\mathbf{x}_i)\mathbf
        {vel}_i + \frac{1}{|\mathcal{N}(i)|}\sum_{j \in\mathcal{N}(i)}
        (\mathbf{pos}_i-\mathbf{pos}_j)
        \rho_{\mathbf{\Theta}}(\mathbf{m}_{ij})
        \mathbf{pos}^{\prime}_i = \mathbf{pos}_i + \mathbf{vel}_i
    where :math:`\gamma_{\mathbf{\Theta}}`,
    :math:`h_{\mathbf{\Theta}}`, :math:`\rho_{\mathbf{\Theta}}`
    and :math:`\phi_{\mathbf{\Theta}}` denote neural
    networks, *.i.e.* MLPs. :math:`\mathbf{P} \in \mathbb{R}^{N \times D}`
    and :math:`\mathbf{V} \in \mathbb{R}^{N \times D}`
    defines the position and velocity of each point respectively.
    Args:
        local_nn (torch.nn.Module, optional): A neural network
            :math:`h_{\mathbf{\Theta}}` that maps node features :obj:`x`,
            squared distance :math:`\|{\mathbf{pos}_i-\mathbf{pos}_j}\|^2_2`
            and edge features :obj:`edge_attr`
            of shape :obj:`[-1, 2*in_channels + 1 + edge_dim]`
            to shape :obj:`[-1, hidden_channels]`, *e.g.*, defined by
            :class:`torch.nn.Sequential`. (default: :obj:`None`)
        pos_nn (torch.nn.Module, optional): A neural network
            :math:`\rho_{\mathbf{\Theta}}` that
            maps message :obj:`m` of shape
            :obj:`[-1, hidden_channels]`,
            to shape :obj:`[-1, 1]`, *e.g.*, defined by
            :class:`torch.nn.Sequential`. (default: :obj:`None`)
        vel_nn (torch.nn.Module, optional): A neural network
            :math:`\phi_{\mathbf{\Theta}}` that
            maps node features :obj:`x` of shape :obj:`[-1, in_channels]`,
            to shape :obj:`[-1, 1]`, *e.g.*, defined by
            :class:`torch.nn.Sequential`. (default: :obj:`None`)
        global_nn (torch.nn.Module, optional): A neural network
            :math:`\gamma_{\mathbf{\Theta}}` that maps
            message :obj:`m` after aggregation
            and node features :obj:`x` of shape
            :obj:`[-1, hidden_channels + in_channels]`
            to shape :obj:`[-1, out_channels]`, *e.g.*, defined by
            :class:`torch.nn.Sequential`. (default: :obj:`None`)
        aggr (string, optional): The operator used to aggregate message
            :obj:`m` (:obj:`"add"`, :obj:`"mean"`).
            (default: :obj:`"mean"`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """
    def __init__(self, local_nn: Optional[Callable] = None,
                 pos_nn: Optional[Callable] = None,
                 vel_nn: Optional[Callable] = None,
                 global_nn: Optional[Callable] = None,
                 aggr="mean", **kwargs):
        super(EquivariantConv, self).__init__(aggr=aggr, **kwargs)

        self.local_nn = local_nn
        self.pos_nn = pos_nn
        self.vel_nn = vel_nn
        self.global_nn = global_nn

        self.reset_parameters()

    def reset_parameters(self):

        reset(self.local_nn)
        reset(self.pos_nn)
        reset(self.vel_nn)
        reset(self.global_nn)

    def forward(self, x: OptTensor,
                pos: Tensor,
                edge_index: Adj,
                fc_edge_index: Adj,
                vel: OptTensor = None,
                edge_attr: OptTensor = None
                ) -> Tuple[Tensor, Tuple[Tensor, OptTensor]]:
        """"""

        self.__calculated_msgs = (None, None)
        self.__E = edge_index.size(1)

        # propagate_type: (x: OptTensor, pos: Tensor, edge_attr: OptTensor) -> Tuple[Tensor,Tensor] # noqa
        _, out_pos = self.propagate(fc_edge_index, x=x, pos=pos,
                                    edge_attr=edge_attr, size=None)

        out_x, _ = self.propagate(edge_index, x=x, pos=pos,
                                  edge_attr=edge_attr, size=None)

        out_x = out_x if x is None else torch.cat([x, out_x], dim=1)
        if self.global_nn is not None:
            out_x = self.global_nn(out_x)

        if vel is None:
            out_pos += pos
            out_vel = None
        else:
            out_vel = (vel if self.vel_nn is None or x is None else
                       self.vel_nn(x) * vel) + out_pos
            out_pos = pos + out_vel
            
        self.__calculated_msgs = (None, None)

        return (out_x, (out_pos, out_vel))

    def message(self, x_i: OptTensor, x_j: OptTensor, pos_i: Tensor,
                pos_j: Tensor,
                edge_attr: OptTensor = None) -> Tuple[Tensor, Tensor]:

        # only do this calculation once
        if self.__calculated_msgs[0] is None and self.__calculated_msgs[1] is None:
            msg = torch.sum((pos_i - pos_j).square(), dim=1, keepdim=True)
            msg = msg if x_j is None else torch.cat([x_j, msg], dim=1)
            msg = msg if x_i is None else torch.cat([x_i, msg], dim=1)
            msg = msg if edge_attr is None else torch.cat([msg, edge_attr], dim=1)
            msg = msg if self.local_nn is None else self.local_nn(msg)

            pos_msg = ((pos_i - pos_j) if self.pos_nn is None else
                       (pos_i - pos_j) * self.pos_nn(msg))
            self.__calculated_msgs = (msg, pos_msg)
            return (msg, pos_msg)
        else:
            return (self.__calculated_msgs[0][:self.__E], self.__calculated_msgs[1][:self.__E])

    def aggregate(self, inputs: Tuple[Tensor, Tensor],
                  index: Tensor) -> Tuple[Tensor, Tensor]:
        return (scatter(inputs[0], index, 0, reduce=self.aggr),
                scatter(inputs[1], index, 0, reduce="mean"))

    def update(self, inputs: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
        return inputs

    def __repr__(self):
        return ("{}(local_nn={}, pos_nn={},"
                " vel_nn={},"
                " global_nn={})").format(self.__class__.__name__,
                                         self.local_nn, self.pos_nn,
                                         self.vel_nn, self.global_nn)

A minimal example using networkx-generated random graphs, without velocities:

import torch
from torch.nn import Linear

import networkx as nx
from torch_geometric.utils import from_networkx
from torch_geometric.data import DataLoader
from torch_scatter import scatter_add


def get_fully_connected_get_edges(n_nodes: int, add_self_loops: bool = False):
    rows, cols = [], []
    for i in range(n_nodes):
        for j in range(n_nodes):
            if i != j or (i == j and add_self_loops):
                rows.append(i)
                cols.append(j)

    edges = [rows, cols]
    edges = torch.tensor(edges, dtype=torch.long).contiguous()
    return edges


def get_fc_edge_index(batch_num_nodes: list, ptr: torch.Tensor,
                      edge_index: torch.Tensor, edge_attr: torch.Tensor) -> (torch.Tensor, torch.Tensor):
    fc_edge_index = torch.cat([get_fully_connected_get_edges(n) + p for n, p in zip(batch_num_nodes, ptr)], dim=-1)


    source, target = edge_index[0].cpu().numpy().tolist(), edge_index[1].cpu().numpy().tolist()
    source_target_to_edge_idx = {str([s, t]): i for s, t, i in zip(source, target, range(len(source)))}
    # edge_idx_to_source_target = {v: k for k, v in source_target_to_edge_idx.items()}
    # positions of fake edge_index
    source_fc, target_fc = fc_edge_index[0].cpu().numpy().tolist(), fc_edge_index[1].cpu().numpy().tolist()
    source_target_to_fc_edge_idx = {str([s, t]): i for s, t, i in zip(source_fc, target_fc, range(len(source_fc)))}
    # fc_edge_idx_to_source_target = {v: k for k, v in source_target_to_fc_edge_idx.items()}

    fake_edges = [s for s in source_target_to_fc_edge_idx.keys() if s not in source_target_to_edge_idx.keys()]
    fake_edges_ids = [source_target_to_fc_edge_idx[k] for k in fake_edges]
    E_fc = fc_edge_index.shape[1]
    E = edge_index.shape[1]
    assert len(fake_edges) == E_fc - E
    fake_edge_index = fc_edge_index.t()[fake_edges_ids].t()

    fake_edge_attr = torch.zeros(size=(fake_edge_index.size(1), edge_attr.size(-1)),
                                 device=edge_attr.device)

    all_edge_index = torch.cat([edge_index, fake_edge_index], dim=-1)
    all_edge_attr = torch.cat([edge_attr, fake_edge_attr], dim=0)

    return all_edge_index, all_edge_attr


def create_random_graph_pyg(n: int, seed: int):
    G = nx.random_geometric_graph(n=n, radius=0.125, dim=3, seed=seed)
    data = from_networkx(G)
    data.x = torch.randn(data.num_nodes, 16)
    data.edge_attr = torch.randn(data.edge_index.size(1), 8)
    return data

batch_size = 16
max_num_nodes = 100
seed = 42
batch_num_nodes = torch.randint(low=20, high=max_num_nodes, size=(batch_size, ), dtype=torch.long)
datalist = [create_random_graph_pyg(n, seed + i) for i, n in enumerate(batch_num_nodes)]

loader = DataLoader(datalist, 16)

data = next(iter(loader))

x, pos, batch, ptr, edge_index, edge_attr = data.x, data.pos, data.batch, data.ptr, data.edge_index, data.edge_attr

batch_num_nodes = scatter_add(src=batch.new_ones(x.size(0)), index=batch, dim=0, dim_size=batch_size).cpu().numpy().tolist()


fc_edge_index, fc_edge_attr = get_fc_edge_index(batch_num_nodes=batch_num_nodes, ptr=ptr,
                                                edge_index=edge_index, edge_attr=edge_attr)

node_in_channels = 16
edge_in_channels = 8
pos_in_channels = 3


local_nn = Linear(2 * node_in_channels + 1 + edge_in_channels, node_in_channels, bias=False)
pos_nn = Linear(node_in_channels, 1, bias=True)
global_nn = Linear(2 * node_in_channels, node_in_channels, bias=True)

conv = EquivariantConv(local_nn=local_nn, pos_nn=pos_nn, global_nn=global_nn)

x_out, (pos_out, _) = conv(x=x, pos=pos, edge_index=edge_index, fc_edge_index=fc_edge_index, edge_attr=fc_edge_attr)

# test without edge-attrs

node_in_channels = 16
edge_in_channels = 0
pos_in_channels = 3


local_nn = Linear(2 * node_in_channels + 1 + edge_in_channels, node_in_channels, bias=False)
pos_nn = Linear(node_in_channels, 1, bias=True)
global_nn = Linear(2 * node_in_channels, node_in_channels, bias=True)

conv = EquivariantConv(local_nn=local_nn, pos_nn=pos_nn, global_nn=global_nn)

x_out, (pos_out, _) = conv(x=x, pos=pos, edge_index=edge_index, fc_edge_index=fc_edge_index, edge_attr=None)
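
As a rough sanity check of the E(n) property (a sketch): translating all input positions should leave the node embeddings unchanged and shift the output positions by the same vector.

shift = torch.randn(1, pos_in_channels)
x_out_shifted, (pos_out_shifted, _) = conv(x=x, pos=pos + shift, edge_index=edge_index,
                                           fc_edge_index=fc_edge_index, edge_attr=None)

assert torch.allclose(x_out, x_out_shifted, atol=1e-5)
assert torch.allclose(pos_out + shift, pos_out_shifted, atol=1e-5)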

@wsad1
Member Author

wsad1 commented Sep 7, 2021

@tuanle618, firstly, I would be happy to collaborate with you on this. Please feel free to send PRs to this branch with your code edits.

Second, I appreciate your effort to fix EquivariantConv. Some questions and thoughts:

1. I think propagate could be called just once; let me know if I am missing something here. aggregate takes any argument passed to propagate, so I believe something like this should work:

out_x, out_pos = self.propagate(edge_index = fc_edge_index, x=x, pos=pos,
                                  edge_attr=edge_attr, orig_index = edge_index[1], size=None)
       
def aggregate(self, inputs: Tuple[Tensor, Tensor],
                  index: Tensor, orig_index: Tensor) -> Tuple[Tensor, Tensor]:
        return (scatter(inputs[0], orig_index, 0, reduce=self.aggr), # aggregate on original edges
                    scatter(inputs[1], index, 0, reduce="mean")) # aggregate on fc edges
  2. get_fully_connected_get_edges could be simplified to:
def get_fully_connected_get_edges(n_nodes: int, add_self_loops: bool = False):
    edge_index = torch.cartesian_prod(torch.arange(n_nodes), torch.arange(n_nodes)).T
    if not add_self_loops:
        edge_index = edge_index[:, edge_index[0] != edge_index[1]]
    return edge_index
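
# Quick check (sketch): for n_nodes = 3 this yields the n * (n - 1) = 6 directed
# edges of the fully connected graph without self-loops, in the same lexicographic
# order as the loop-based version above.
ei = get_fully_connected_get_edges(3)
assert ei.shape == (2, 6)
assert not bool((ei[0] == ei[1]).any())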

@tuanle618

Thanks for your suggestions, @wsad1.
I'm going to make a PR soon on your forked repository to the enn branch. I will need to add some further tests to make sure the aggregations for the node embeddings x (based on edge_index) and pos (based on fc_edge_index) are also correct. Regarding your point (1): I tried to use just one function call to self.propagate using your

out_x, out_pos = self.propagate(edge_index = fc_edge_index, x=x, pos=pos,
                                  edge_attr=edge_attr, orig_index = edge_index[1], size=None)
       
def aggregate(self, inputs: Tuple[Tensor, Tensor],
                  index: Tensor, orig_index: Tensor) -> Tuple[Tensor, Tensor]:
        return (scatter(inputs[0], orig_index, 0, reduce=self.aggr), # aggregate on original edges
                    scatter(inputs[1], index, 0, reduce="mean")) # aggregate on fc edges

For aggregate, however, I need to slice the tensor inputs[0] to make it match orig_index, as inputs[0] (which was created based on the fc_edge_index) is much longer than orig_index, i.e.,

scatter(inputs[0][:len(orig_index)], orig_index, 0, reduce=self.aggr)
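
Putting both pieces together, the adjusted aggregate would roughly look like this (a sketch; it relies on the true edges coming first in fc_edge_index, as in step (c) above):

def aggregate(self, inputs: Tuple[Tensor, Tensor],
              index: Tensor, orig_index: Tensor) -> Tuple[Tensor, Tensor]:
    # inputs was built on the fully connected edges; by construction the first
    # E entries correspond to the original edge_index.
    E = orig_index.size(0)
    return (scatter(inputs[0][:E], orig_index, 0, reduce=self.aggr),  # node embeddings
            scatter(inputs[1], index, 0, reduce="mean"))              # positions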

The code currently runs without errors, but I'd like to add some more tests to make sure the aggregations are done as intended. I'll ping you once I've made the PR.

Best regards,
Tuan

@rusty1s rusty1s changed the title [WIP] : Added equivariant conv. [WIP] Added equivariant conv. Feb 4, 2022
@rusty1s rusty1s changed the title [WIP] Added equivariant conv. [WIP] Added equivariant conv Feb 4, 2022
@rusty1s rusty1s changed the title [WIP] Added equivariant conv [WIP] EquivariantConv Feb 7, 2022
@KevinCrp

Hi,
Just a question about the example:

node_in_channels = 16
edge_in_channels = 8
pos_in_channels = 3


local_nn = Linear(2 * node_in_channels + 1 + edge_in_channels, node_in_channels, bias=False)
pos_nn = Linear(node_in_channels, 1, bias=True)
global_nn = Linear(2 * node_in_channels, node_in_channels, bias=True)

conv = EquivariantConv(local_nn=local_nn, pos_nn=pos_nn, global_nn=global_nn)

x_out, (pos_out, _) = conv(x=x, pos=pos, edge_index=edge_index, fc_edge_index=fc_edge_index, edge_attr=fc_edge_attr)

I want to be sure that the first argument to pos_nn is indeed node_in_channels and not pos_in_channels.

Best regards
Kevin

@elilaird

@wsad1 What is the status of this PR? I'd be happy to help if there is still work to be done.
