-
Notifications
You must be signed in to change notification settings - Fork 3.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
- Loading branch information
1 parent
a5e05f2
commit 9c91169
Showing
4 changed files
with
31 additions
and
39 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,37 +1,38 @@ | ||
import torch | ||
|
||
import torch_geometric.typing | ||
from torch_geometric.testing import withPackage | ||
from torch_geometric.nn import LPFormer | ||
from torch_geometric.testing import withPackage | ||
from torch_geometric.typing import SparseTensor | ||
from torch_geometric.utils import to_undirected | ||
|
||
|
||
@withPackage('numba') # For ppr calculation | ||
def test_lpformer(): | ||
model = LPFormer(16, 32, num_gnn_layers=2, | ||
num_transformer_layers=1) | ||
assert str(model) == 'LPFormer(16, 32, num_gnn_layers=2, num_transformer_layers=1)' | ||
model = LPFormer(16, 32, num_gnn_layers=2, num_transformer_layers=1) | ||
assert str( | ||
model | ||
) == 'LPFormer(16, 32, num_gnn_layers=2, num_transformer_layers=1)' | ||
|
||
num_nodes = 20 | ||
x = torch.randn(num_nodes, 16) | ||
edges = torch.randint(0, num_nodes-1, (2, 110)) | ||
edges = torch.randint(0, num_nodes - 1, (2, 110)) | ||
edge_index, test_edges = edges[:, :100], edges[:, 100:] | ||
edge_index = to_undirected(edge_index) | ||
|
||
ppr_matrix = model.calc_sparse_ppr(edge_index, num_nodes, eps=1e-4) | ||
|
||
assert ppr_matrix.is_sparse | ||
assert ppr_matrix.size() == (num_nodes, num_nodes) | ||
assert ppr_matrix.sum().item() > 0 | ||
|
||
# Test with dense edge_index | ||
out = model(test_edges, x, edge_index, ppr_matrix) | ||
assert out.size() == (10,) | ||
assert out.size() == (10, ) | ||
|
||
# Test with sparse edge_index | ||
if torch_geometric.typing.WITH_TORCH_SPARSE: | ||
adj = SparseTensor.from_edge_index(edge_index, | ||
adj = SparseTensor.from_edge_index(edge_index, | ||
sparse_sizes=(num_nodes, num_nodes)) | ||
out2 = model(test_edges, x, adj, ppr_matrix) | ||
assert out2.size() == (10,) | ||
assert out2.size() == (10, ) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters