diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6aaf1d96f9de..1355e7102df7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -47,6 +47,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
+- Accelerated `LineGraph` and fixed an issue with self-loops in input graphs ([#10212](https://github.com/pyg-team/pytorch_geometric/pull/10212))
 - Updated cuGraph examples to use buffered sampling, which keeps data in memory and is significantly faster than the deprecated bulk sampling ([#10079](https://github.com/pyg-team/pytorch_geometric/pull/10079))
 - Updated Dockerfile to use latest from NVIDIA ([#9794](https://github.com/pyg-team/pytorch_geometric/pull/9794))
 - Dropped Python 3.8 support ([#9696](https://github.com/pyg-team/pytorch_geometric/pull/9606))
diff --git a/torch_geometric/transforms/line_graph.py b/torch_geometric/transforms/line_graph.py
index 0696f22ae486..be97401d1a8c 100644
--- a/torch_geometric/transforms/line_graph.py
+++ b/torch_geometric/transforms/line_graph.py
@@ -4,7 +4,13 @@
 from torch_geometric.data import Data
 from torch_geometric.data.datapipes import functional_transform
 from torch_geometric.transforms import BaseTransform
-from torch_geometric.utils import coalesce, cumsum, remove_self_loops, scatter
+from torch_geometric.utils import (
+    coalesce,
+    cumsum,
+    remove_self_loops,
+    scatter,
+    to_undirected,
+)
 
 
 @functional_transform('line_graph')
@@ -17,12 +23,11 @@ class LineGraph(BaseTransform):
 
         \mathcal{V}^{\prime} &= \mathcal{E}
 
-        \mathcal{E}^{\prime} &= \{ (e_1, e_2) : e_1 \cap e_2 \neq \emptyset \}
+        \mathcal{E}^{\prime} &=
+        \{ (e_1, e_2) : e_1 \cap e_2 \neq \emptyset,\, e_1 \neq e_2 \}
 
     Line-graph node indices are equal to indices in the original graph's
     coalesced :obj:`edge_index`.
-    For undirected graphs, the maximum line-graph node index is
-    :obj:`(data.edge_index.size(1) // 2) - 1`.
 
     New node features are given by old edge attributes.
     For undirected graphs, edge attributes for reciprocal edges
@@ -43,17 +48,23 @@ def forward(self, data: Data) -> Data:
         edge_index, edge_attr = coalesce(edge_index, edge_attr, num_nodes=N)
         row, col = edge_index
 
+        def vector_arange(start: Tensor, step: Tensor) -> Tensor:
+            step_csum = cumsum(step)[:-1]
+            offset = (step_csum - start).repeat_interleave(step)
+            rng = torch.arange(offset.size(0), dtype=torch.long,
+                               device=offset.device)
+            return rng - offset
+
         if self.force_directed or data.is_directed():
             i = torch.arange(row.size(0), dtype=torch.long, device=row.device)
 
             count = scatter(torch.ones_like(row), row, dim=0,
                             dim_size=data.num_nodes, reduce='sum')
             ptr = cumsum(count)
-
-            cols = [i[ptr[col[j]]:ptr[col[j] + 1]] for j in range(col.size(0))]
-            rows = [row.new_full((c.numel(), ), j) for j, c in enumerate(cols)]
-
-            row, col = torch.cat(rows, dim=0), torch.cat(cols, dim=0)
+            col_count = count[col]
+            row = i.repeat_interleave(col_count)
+            col_index = vector_arange(ptr[col], col_count)
+            col = i[col_index]
 
             data.edge_index = torch.stack([row, col], dim=0)
             data.x = data.edge_attr
@@ -61,39 +72,31 @@ def forward(self, data: Data) -> Data:
 
         else:
             # Compute node indices.
-            mask = row < col
-            row, col = row[mask], col[mask]
-            i = torch.arange(row.size(0), dtype=torch.long, device=row.device)
-
-            (row, col), i = coalesce(
-                torch.stack([
-                    torch.cat([row, col], dim=0),
-                    torch.cat([col, row], dim=0)
-                ], dim=0),
-                torch.cat([i, i], dim=0),
-                N,
-            )
+            mask = row <= col
+            edge_index = edge_index[:, mask]
+            N = edge_index.size(1)
+            i = torch.arange(N, dtype=torch.long, device=row.device)
+            (row, col), i = to_undirected(edge_index, i, num_nodes=N,
+                                          reduce='min')
 
             # Compute new edge indices according to `i`.
             count = scatter(torch.ones_like(row), row, dim=0,
                             dim_size=data.num_nodes, reduce='sum')
-            joints = list(torch.split(i, count.tolist()))
-
-            def generate_grid(x: Tensor) -> Tensor:
-                row = x.view(-1, 1).repeat(1, x.numel()).view(-1)
-                col = x.repeat(x.numel())
-                return torch.stack([row, col], dim=0)
+            ptr = cumsum(count)
+            row = i.repeat_interleave(count[row])
+            count_rep = count.repeat_interleave(count)
+            col_index = vector_arange(ptr[:-1].repeat_interleave(count),
+                                      count_rep)
+            col = i[col_index]
 
-            joints = [generate_grid(joint) for joint in joints]
-            joint = torch.cat(joints, dim=1)
-            joint, _ = remove_self_loops(joint)
-            N = row.size(0) // 2
-            joint = coalesce(joint, num_nodes=N)
+            index = torch.stack([row, col], dim=0)
+            index, _ = remove_self_loops(index)
+            index = coalesce(index, num_nodes=N)
 
             if edge_attr is not None:
                 data.x = scatter(edge_attr, i, dim=0, dim_size=N, reduce='sum')
 
-            data.edge_index = joint
-            data.num_nodes = edge_index.size(1) // 2
+            data.edge_index = index
+            data.num_nodes = N
 
             data.edge_attr = None
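
The speed-up in both branches comes from the `vector_arange` helper introduced above: it replaces the per-edge Python list comprehensions with a single vectorized "ragged arange" that, for every pair `(start[j], step[j])`, materializes `torch.arange(start[j], start[j] + step[j])` and concatenates the results. A minimal, self-contained sketch of the same trick follows (not part of the diff; the name `ragged_arange` and the sample tensors are illustrative, and the exclusive prefix sum is spelled out with plain `torch` instead of `torch_geometric.utils.cumsum`):

```python
import torch


def ragged_arange(start: torch.Tensor, step: torch.Tensor) -> torch.Tensor:
    # Exclusive prefix sum of the segment lengths: where each segment begins
    # in the concatenated output.
    step_csum = torch.cat([step.new_zeros(1), step.cumsum(0)[:-1]])
    # Per-element correction that maps the global position back to
    # `start[j] + (local position within segment j)`.
    offset = (step_csum - start).repeat_interleave(step)
    rng = torch.arange(offset.numel(), dtype=torch.long, device=offset.device)
    return rng - offset


start = torch.tensor([3, 0, 7])
step = torch.tensor([2, 4, 1])

# Reference result built with a Python loop, as the old implementation did.
expected = torch.cat(
    [torch.arange(s, s + n) for s, n in zip(start.tolist(), step.tolist())])
assert torch.equal(ragged_arange(start, step), expected)
print(ragged_arange(start, step))  # tensor([3, 4, 0, 1, 2, 3, 7])
```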
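As a quick sanity check of the undirected branch (again not part of the diff, just a usage sketch): the line graph of a triangle is itself a triangle, since each pair of its three edges shares an endpoint, so the transform should produce three line-graph nodes, one per undirected edge, all connected to one another.

```python
import torch

from torch_geometric.data import Data
from torch_geometric.transforms import LineGraph

# Undirected triangle 0-1-2, with both edge directions stored explicitly.
edge_index = torch.tensor([
    [0, 1, 0, 2, 1, 2],
    [1, 0, 2, 0, 2, 1],
])
data = Data(edge_index=edge_index, num_nodes=3)

data = LineGraph()(data)
print(data.num_nodes)   # 3 line-graph nodes, one per undirected edge
print(data.edge_index)  # symmetric connections between all three nodes
```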