Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jan 16, 2025
1 parent c2bc98f commit 3b8bd00
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 41 deletions.
15 changes: 7 additions & 8 deletions examples/lpformer.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,22 @@
import random
import numpy as np
from tqdm import tqdm
from argparse import ArgumentParser
from collections import defaultdict

import numpy as np
import torch
from ogb.linkproppred import Evaluator, PygLinkPropPredDataset
from torch.utils.data import DataLoader
from torch_sparse import SparseTensor

from ogb.linkproppred import PygLinkPropPredDataset, Evaluator
from tqdm import tqdm

from torch_geometric.nn.models import LPFormer

parser = ArgumentParser()
parser.add_argument('--data_name', type=str, default='ogbl-ppa')
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--runs', help="# random seeds to run over",
type=int, default=5)
parser.add_argument('--runs', help="# random seeds to run over", type=int,
default=5)
parser.add_argument('--batch_size', type=int, default=32768)
parser.add_argument('--hidden_channels', type=int, default=64)
parser.add_argument('--gnn_layers', type=int, default=3)
Expand Down Expand Up @@ -70,6 +69,7 @@
ppr_matrix = model.calc_sparse_ppr(data.edge_index, data.num_nodes,
eps=args.eps)


def train_epoch():
model.train()
train_pos = split_data['train_pos'].to(device)
Expand Down Expand Up @@ -184,8 +184,7 @@ def set_seeds(seed):
best_valid_test = eval_test

print(
f"\nBest Performance:\n Valid={best_valid}\n Test={best_valid_test}"
)
f"\nBest Performance:\n Valid={best_valid}\n Test={best_valid_test}")
val_perf_runs.append(best_valid)
test_perf_runs.append(best_valid_test)

Expand Down
17 changes: 9 additions & 8 deletions test/nn/models/test_lpformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,29 +7,30 @@


def test_lpformer():
model = LPFormer(16, 32, num_gnn_layers=2,
num_transformer_layers=1)
assert str(model) == 'LPFormer(16, 32, num_gnn_layers=2, num_transformer_layers=1)'
model = LPFormer(16, 32, num_gnn_layers=2, num_transformer_layers=1)
assert str(
model
) == 'LPFormer(16, 32, num_gnn_layers=2, num_transformer_layers=1)'

num_nodes = 20
x = torch.randn(num_nodes, 16)
edges = torch.randint(0, num_nodes-1, (2, 110))
edges = torch.randint(0, num_nodes - 1, (2, 110))
edge_index, test_edges = edges[:, :100], edges[:, 100:]
edge_index = to_undirected(edge_index)

ppr_matrix = model.calc_sparse_ppr(edge_index, num_nodes, eps=1e-4)

assert ppr_matrix.is_sparse
assert ppr_matrix.size() == (num_nodes, num_nodes)
assert ppr_matrix.sum().item() > 0

# Test with dense edge_index
out = model(test_edges, x, edge_index, ppr_matrix)
assert out.size() == (10,)
assert out.size() == (10, )

# Test with sparse edge_index
if torch_geometric.typing.WITH_TORCH_SPARSE:
adj = SparseTensor.from_edge_index(edge_index,
adj = SparseTensor.from_edge_index(edge_index,
sparse_sizes=(num_nodes, num_nodes))
out2 = model(test_edges, x, adj, ppr_matrix)
assert out2.size() == (10,)
assert out2.size() == (10, )
38 changes: 14 additions & 24 deletions torch_geometric/nn/models/lpformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Parameter

from torch_geometric.typing import SparseTensor

from ...utils import softmax, get_ppr, scatter
from ...typing import OptTensor, Tuple, Adj
from ...nn.conv import MessagePassing
from ...nn.dense.linear import Linear
from ...nn.inits import glorot, zeros
from ...typing import Adj, OptTensor, Tuple
from ...utils import get_ppr, scatter, softmax
from .basic_gnn import GCN


Expand Down Expand Up @@ -183,8 +184,7 @@ def forward(
return logits

def propagate(self, x: Tensor, adj: Adj) -> Tensor:
"""
Propagate via GNN
"""Propagate via GNN
Args:
x (Tensor): Node features
Expand Down Expand Up @@ -281,8 +281,7 @@ def get_pos_encodings(
def compute_node_mask(
self, batch: Tensor, adj: Tensor, ppr_matrix: Tensor
) -> Tuple[Tuple, Optional[Tuple], Optional[Tuple]]:
r"""
Get mask based on type of node
r"""Get mask based on type of node
When mask_type is not "cn", also return the ppr vals for both
the source and target.
Expand Down Expand Up @@ -354,15 +353,13 @@ def compute_node_mask(
else:
return (cn_ix, cn_src_ppr,
cn_tgt_ppr), (onehop_ix, onehop_src_ppr,
onehop_tgt_ppr), (non1hop_ix,
non1hop_sppr,
onehop_tgt_ppr), (non1hop_ix, non1hop_sppr,
non1hop_tppr)

def get_ppr_vals(
self, batch: Tensor, pair_diff_adj: Tensor,
ppr_matrix: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
r"""
Get the src and tgt ppr vals.
r"""Get the src and tgt ppr vals.
Returns the: link the node belongs to, type of node
(e.g., CN), PPR relative to src, PPR relative to tgt.
Expand Down Expand Up @@ -446,8 +443,7 @@ def get_structure_cnts(
onehop_info: Tuple[Tensor, Tensor],
non1hop_info: Optional[Tuple[Tensor, Tensor]],
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
"""
Counts for CNs, 1-Hop, and >1-Hop that satisfy PPR threshold
"""Counts for CNs, 1-Hop, and >1-Hop that satisfy PPR threshold
Also include total # of neighbors
Expand Down Expand Up @@ -483,8 +479,7 @@ def get_structure_cnts(
def get_num_ppr_thresh(self, batch: Tensor, node_mask: Tensor,
src_ppr: Tensor, tgt_ppr: Tensor,
thresh: float) -> Tensor:
"""
Get # of nodes `v` where `ppr(a, v) >= eta` & `ppr(b, v) >= eta`
"""Get # of nodes `v` where `ppr(a, v) >= eta` & `ppr(b, v) >= eta`
Args:
batch (Tensor): The batch vector.
Expand All @@ -509,8 +504,7 @@ def get_count(
node_mask: Tensor,
batch: Tensor,
) -> Tensor:
"""
# of nodes for each sample in batch
"""# of nodes for each sample in batch
They node have already filtered by PPR beforehand
Expand Down Expand Up @@ -597,8 +591,7 @@ def get_non_1hop_ppr(self, batch: Tensor, adj: Tensor,

def calc_sparse_ppr(self, edge_index: Tensor, num_nodes: int,
alpha: float = 0.15, eps: float = 5e-5) -> Tensor:
r"""
Calculate the PPR of the graph in sparse format
r"""Calculate the PPR of the graph in sparse format
Args:
edge_index: The edge indices
Expand All @@ -616,8 +609,7 @@ def calc_sparse_ppr(self, edge_index: Tensor, num_nodes: int,


class LPAttLayer(MessagePassing):
r"""
Attention Layer for pairwise interaction module.
r"""Attention Layer for pairwise interaction module.
Args:
in_channels (int): Size of input dimension
Expand Down Expand Up @@ -693,8 +685,7 @@ def forward(
node_feats: Tensor,
ppr_rpes: Tensor,
) -> Tensor:
"""
Runs the forward pass of the module.
"""Runs the forward pass of the module.
Args:
edge_index (Tensor): The edge indices.
Expand Down Expand Up @@ -747,8 +738,7 @@ def message(self, x_i: Tensor, x_j: Tensor, ppr_rpes: Tensor,


class MLP(nn.Module):
"""
L Layer MLP
"""L Layer MLP
"""
def __init__(self, in_channels: int, hid_channels: int, out_channels: int,
num_layers: int = 2, drop: int = 0, norm: str = "layer"):
Expand Down
2 changes: 1 addition & 1 deletion torch_geometric/nn/models/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,4 +248,4 @@ def forward(
return (x, emb) if isinstance(return_emb, bool) else x

def __repr__(self) -> str:
return f'{self.__class__.__name__}({str(self.channel_list)[1:-1]})'
return f'{self.__class__.__name__}({str(self.channel_list)[1:-1]})'

0 comments on commit 3b8bd00

Please sign in to comment.