Skip to content

Commit

Permalink
add generate_second_order_model method with unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
Jan von Pichowski committed Jan 23, 2025
1 parent b1bbd91 commit b974e96
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 0 deletions.
65 changes: 65 additions & 0 deletions src/pathpyG/utils/dbgnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,14 @@

import torch

from torch_geometric.utils import coalesce
from torch_geometric.data import Data

from pathpyG.algorithms.lift_order import aggregate_edge_index
from pathpyG.core.graph import Graph
from pathpyG.core.index_map import IndexMap
from pathpyG.core.multi_order_model import MultiOrderModel
from pathpyG.core.temporal_graph import TemporalGraph


def generate_bipartite_edge_index(g: Graph, g2: Graph, mapping: str = "last") -> torch.Tensor:
Expand All @@ -22,3 +29,61 @@ def generate_bipartite_edge_index(g: Graph, g2: Graph, mapping: str = "last") ->
)

return bipartide_edge_index


def generate_second_order_model(
    g: TemporalGraph, delta: float | int = 1, weight: str = "edge_weight"
) -> MultiOrderModel:
    """Generate a multi-order model with first- and second-order layers from a temporal graph.

    This method is optimized for the memory footprint of large graphs and may be
    slower than creating small models with the default approach.

    Args:
        g: Temporal graph whose time-stamped edges are lifted to second order.
        delta: Size of the time window. An edge observed at time ``t`` is connected
            to edges observed at times in ``(t, t + delta]`` that start at the
            node where the first edge ends.
        weight: Name of the edge attribute used as weight of the first-order
            edges. If the attribute is not present in the graph data, every
            edge is weighted with one. The second-order edge weights are built
            from occurrence counts and do not use this attribute.

    Returns:
        A ``MultiOrderModel`` whose ``layers[1]`` and ``layers[2]`` hold the
        first- and second-order graphs.
    """
    data = g.data.sort_by_time()
    edge_index1, timestamps1 = data.edge_index, data.time
    device = edge_index1.device

    # First-order layer: every node is a sequence of length one.
    node_sequence1 = torch.arange(data.num_nodes, device=device).unsqueeze(1)
    if weight in data:
        edge_weight = data[weight]
    else:
        edge_weight = torch.ones(edge_index1.size(1), device=device)

    layer1 = aggregate_edge_index(
        edge_index=edge_index1, node_sequence=node_sequence1, edge_weight=edge_weight
    )
    layer1.mapping = g.mapping

    # Second-order nodes are the unique (source, target) pairs of first-order edges;
    # edge1_to_node2 gives, for every first-order edge, the index of its second-order node.
    node_sequence2 = torch.cat(
        [node_sequence1[edge_index1[0]], node_sequence1[edge_index1[1]][:, -1:]], dim=1
    )
    node_sequence2, edge1_to_node2 = torch.unique(node_sequence2, dim=0, return_inverse=True)

    # Process one timestamp at a time instead of materializing all candidate
    # edge pairs at once.
    edge_index2 = []
    edge_weight2 = []
    for timestamp in timestamps1.unique():
        src_nodes2, src_nodes2_counts = edge1_to_node2[timestamps1 == timestamp].unique(return_counts=True)
        in_window = (timestamps1 > timestamp) & (timestamps1 <= timestamp + delta)
        dst_nodes2, dst_nodes2_counts = edge1_to_node2[in_window].unique(return_counts=True)
        if dst_nodes2.numel() == 0:
            # No edge follows within the window, so nothing can be connected.
            continue

        # Loop-invariant: first node of every candidate destination.
        dst_first_nodes = node_sequence2[dst_nodes2, 0]
        for src_node2, src_node2_count in zip(src_nodes2, src_nodes2_counts):
            # A path continues only if the destination starts where the source ends.
            continues_path = dst_first_nodes == node_sequence2[src_node2, -1]
            dst_node2 = dst_nodes2[continues_path]
            dst_node2_count = dst_nodes2_counts[continues_path]

            edge_index2.append(torch.stack([src_node2.expand(dst_node2.size(0)), dst_node2], dim=0))
            # Weight is the number of (source occurrence, destination occurrence) combinations.
            edge_weight2.append(src_node2_count.expand(dst_node2.size(0)) * dst_node2_count)

    if edge_index2:
        edge_index2 = torch.cat(edge_index2, dim=1)
        edge_weight2 = torch.cat(edge_weight2, dim=0)
        # The same second-order edge can be found at several timestamps; sum the weights.
        edge_index2, edge_weight2 = coalesce(
            edge_index2, edge_attr=edge_weight2, num_nodes=node_sequence2.size(0), reduce="sum"
        )
    else:
        # torch.cat raises on an empty list, so build the empty layer explicitly.
        edge_index2 = torch.empty((2, 0), dtype=torch.long, device=device)
        edge_weight2 = torch.empty(0, dtype=torch.long, device=device)

    data2 = Data(
        edge_index=edge_index2,
        num_nodes=node_sequence2.size(0),
        node_sequence=node_sequence2,
        edge_weight=edge_weight2,
        inverse_idx=edge1_to_node2,
    )
    layer2 = Graph(data2)
    # Second-order node ids are tuples of the first-order node ids they consist of.
    layer2.mapping = IndexMap(
        [tuple(layer1.mapping.to_ids(v.cpu())) for v in node_sequence2]
    )

    m = MultiOrderModel()
    m.layers[1] = layer1
    m.layers[2] = layer2
    return m
29 changes: 29 additions & 0 deletions tests/utils/test_generate_second_order_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import torch

from pathpyG.core.multi_order_model import MultiOrderModel
from pathpyG.core.temporal_graph import TemporalGraph
from pathpyG.utils.dbgnn import generate_second_order_model

def test_generate_second_order_model():
    """Compare generate_second_order_model with MultiOrderModel.from_temporal_graph on a small temporal graph."""
    temporal_edges = [
        ("a", "b", 1), ("c", "b", 1), ("c", "a", 1), ("c", "a", 1), ("f", "c", 1),
        ("b", "c", 5), ("a", "d", 5), ("c", "d", 9), ("a", "d", 9), ("c", "e", 9),
        ("c", "f", 11), ("f", "a", 13), ("a", "g", 18), ("b", "f", 21),
        ("a", "g", 26), ("c", "f", 27), ("h", "f", 27), ("g", "h", 28),
        ("a", "c", 30), ("a", "b", 31), ("c", "h", 32), ("f", "h", 33),
        ("b", "i", 42), ("i", "b", 42), ("c", "i", 47), ("h", "i", 50),
    ]
    delta = 10

    # Expected output comes from the default model construction.
    reference_graph = TemporalGraph.from_edge_list(temporal_edges)
    expected = MultiOrderModel.from_temporal_graph(
        reference_graph, max_order=2, delta=delta
    ).to_dbgnn_data()

    # The function under test gets its own graph built from the same edge list.
    graph = TemporalGraph.from_edge_list(temporal_edges)
    actual = generate_second_order_model(graph, delta=delta).to_dbgnn_data()

    assert actual.num_nodes == expected.num_nodes
    assert actual.num_ho_nodes == expected.num_ho_nodes

    # Tensor attributes must match exactly, checked in a fixed order.
    tensor_fields = (
        "x",
        "edge_index",
        "edge_weights",
        "x_h",
        "edge_index_higher_order",
        "edge_weights_higher_order",
        "bipartite_edge_index",
    )
    for field in tensor_fields:
        assert torch.equal(getattr(actual, field), getattr(expected, field))

0 comments on commit b974e96

Please sign in to comment.