LineGraph improvements #10212

Open · wants to merge 2 commits into master
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -47,6 +47,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Accelerated `LineGraph` and fixed an issue with self-connected inputs ([#10212](https://github.com/pyg-team/pytorch_geometric/pull/10212))
- Updated cuGraph examples to use buffered sampling, which keeps data in memory and is significantly faster than the deprecated unbuffered sampling ([#10079](https://github.com/pyg-team/pytorch_geometric/pull/10079))
- Updated Dockerfile to use latest from NVIDIA ([#9794](https://github.com/pyg-team/pytorch_geometric/pull/9794))
- Dropped Python 3.8 support ([#9696](https://github.com/pyg-team/pytorch_geometric/pull/9696))
71 changes: 37 additions & 34 deletions torch_geometric/transforms/line_graph.py
@@ -4,7 +4,13 @@
from torch_geometric.data import Data
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform
from torch_geometric.utils import coalesce, cumsum, remove_self_loops, scatter
from torch_geometric.utils import (
    coalesce,
    cumsum,
    remove_self_loops,
    scatter,
    to_undirected,
)


@functional_transform('line_graph')
@@ -17,12 +23,11 @@ class LineGraph(BaseTransform):

        \mathcal{V}^{\prime} &= \mathcal{E}

        \mathcal{E}^{\prime} &= \{ (e_1, e_2) : e_1 \cap e_2 \neq \emptyset \}
        \mathcal{E}^{\prime} &=
        \{ (e_1, e_2) : e_1 \cap e_2 \neq \emptyset,\, e_1 \neq e_2 \}

    Line-graph node indices are equal to indices in the original graph's
    coalesced :obj:`edge_index`.
    For undirected graphs, the maximum line-graph node index is
    :obj:`(data.edge_index.size(1) // 2) - 1`.

    New node features are given by old edge attributes.
    For undirected graphs, edge attributes for reciprocal edges
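A minimal usage sketch illustrating the definition above (not part of this diff; the expected sizes follow from the coalesced edge indexing described in the docstring):

```python
import torch
from torch_geometric.data import Data
from torch_geometric.transforms import LineGraph

# Undirected triangle: 3 nodes, 3 undirected edges (6 directed entries).
edge_index = torch.tensor([[0, 1, 1, 2, 0, 2],
                           [1, 0, 2, 1, 2, 0]])
data = Data(edge_index=edge_index, num_nodes=3)

data = LineGraph()(data)
# The line graph of a triangle is again a triangle: every pair of the
# three edges shares exactly one endpoint.
assert data.num_nodes == 3
assert data.edge_index.size(1) == 6  # 3 undirected line-graph edges
```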
@@ -43,57 +48,55 @@ def forward(self, data: Data) -> Data:
        edge_index, edge_attr = coalesce(edge_index, edge_attr, num_nodes=N)
        row, col = edge_index

        def vector_arange(start: Tensor, step: Tensor) -> Tensor:
            step_csum = cumsum(step)[:-1]
            offset = (step_csum - start).repeat_interleave(step)
            rng = torch.arange(offset.size(0), dtype=torch.long,
                               device=offset.device)
            return rng - offset
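        # Illustrative example (not in the diff): start = [0, 5], step = [3, 2]
        # yields [0, 1, 2, 5, 6], i.e. arange(0, 3) followed by arange(5, 7),
        # computed without a Python-level loop over the groups.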

        if self.force_directed or data.is_directed():
            i = torch.arange(row.size(0), dtype=torch.long, device=row.device)

            count = scatter(torch.ones_like(row), row, dim=0,
                            dim_size=data.num_nodes, reduce='sum')
            ptr = cumsum(count)

            cols = [i[ptr[col[j]]:ptr[col[j] + 1]] for j in range(col.size(0))]
            rows = [row.new_full((c.numel(), ), j) for j, c in enumerate(cols)]

            row, col = torch.cat(rows, dim=0), torch.cat(cols, dim=0)
            col_count = count[col]
            row = i.repeat_interleave(col_count)
            col_index = vector_arange(ptr[col], col_count)
            col = i[col_index]
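            # Vectorized replacement for the per-edge Python lists above: each
            # edge j = (row_j, col_j) connects to every edge whose source node
            # equals col_j; ptr[col] and count[col] locate that block of
            # successor edges in the source-sorted edge list.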

            data.edge_index = torch.stack([row, col], dim=0)
            data.x = data.edge_attr
            data.num_nodes = edge_index.size(1)

        else:
            # Compute node indices.
            mask = row < col
            row, col = row[mask], col[mask]
            i = torch.arange(row.size(0), dtype=torch.long, device=row.device)

            (row, col), i = coalesce(
                torch.stack([
                    torch.cat([row, col], dim=0),
                    torch.cat([col, row], dim=0)
                ], dim=0),
                torch.cat([i, i], dim=0),
                N,
            )
            mask = row <= col
            edge_index = edge_index[:, mask]
            N = edge_index.size(1)
            i = torch.arange(N, dtype=torch.long, device=row.device)
            (row, col), i = to_undirected(edge_index, i,
                                          num_nodes=data.num_nodes,
                                          reduce='min')
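            # reduce='min' gives both directed copies of an undirected edge
            # the same line-graph node index, namely the index of its
            # canonical (row <= col) occurrence.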

            # Compute new edge indices according to `i`.
            count = scatter(torch.ones_like(row), row, dim=0,
                            dim_size=data.num_nodes, reduce='sum')
            joints = list(torch.split(i, count.tolist()))

            def generate_grid(x: Tensor) -> Tensor:
                row = x.view(-1, 1).repeat(1, x.numel()).view(-1)
                col = x.repeat(x.numel())
                return torch.stack([row, col], dim=0)
            ptr = cumsum(count)
            row = i.repeat_interleave(count[row])
            count_rep = count.repeat_interleave(count)
            col_index = vector_arange(ptr[:-1].repeat_interleave(count),
                                      count_rep)
            col = i[col_index]
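            # Vectorized per-node Cartesian product of incident edge indices
            # (the role of generate_grid above): row repeats each incidence's
            # edge index once per incident edge of the same node, while
            # col_index enumerates all edge indices grouped under that node.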

            joints = [generate_grid(joint) for joint in joints]
            joint = torch.cat(joints, dim=1)
            joint, _ = remove_self_loops(joint)
            N = row.size(0) // 2
            joint = coalesce(joint, num_nodes=N)
            index = torch.stack([row, col], dim=0)
            index, _ = remove_self_loops(index)
            index = coalesce(index, num_nodes=N)
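            # remove_self_loops drops the (e, e) pairs produced by the
            # Cartesian product; coalesce sorts the result and removes
            # duplicates.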

            if edge_attr is not None:
                data.x = scatter(edge_attr, i, dim=0, dim_size=N, reduce='sum')
            data.edge_index = joint
            data.num_nodes = edge_index.size(1) // 2
            data.edge_index = index
            data.num_nodes = N

        data.edge_attr = None

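As a quick sanity check of the directed branch, a sketch (not part of the PR; the expected output follows from the directed line-graph definition):

```python
import torch
from torch_geometric.data import Data
from torch_geometric.transforms import LineGraph

# Directed path 0 -> 1 -> 2: edge e0 = (0, 1) feeds into e1 = (1, 2),
# so the directed line graph contains the single edge (e0, e1).
data = Data(edge_index=torch.tensor([[0, 1], [1, 2]]), num_nodes=3)
out = LineGraph(force_directed=True)(data)

print(out.num_nodes)   # 2
print(out.edge_index)  # tensor([[0], [1]])
```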