diff --git a/src/pathpyG/utils/dbgnn.py b/src/pathpyG/utils/dbgnn.py
index ec902729..773d45c2 100644
--- a/src/pathpyG/utils/dbgnn.py
+++ b/src/pathpyG/utils/dbgnn.py
@@ -2,7 +2,14 @@
 
 import torch
 
+from torch_geometric.utils import coalesce
+from torch_geometric.data import Data
+
+from pathpyG.algorithms.lift_order import aggregate_edge_index
 from pathpyG.core.graph import Graph
+from pathpyG.core.index_map import IndexMap
+from pathpyG.core.multi_order_model import MultiOrderModel
+from pathpyG.core.temporal_graph import TemporalGraph
 
 
 def generate_bipartite_edge_index(g: Graph, g2: Graph, mapping: str = "last") -> torch.Tensor:
@@ -22,3 +29,86 @@ def generate_bipartite_edge_index(g: Graph, g2: Graph, mapping: str = "last") ->
     )
 
     return bipartide_edge_index
+
+
+def generate_second_order_model(g: TemporalGraph, delta: float | int = 1, weight: str = "edge_weight") -> MultiOrderModel:
+    """Generate a multi-order model with first- and second-order layer from a temporal graph.
+
+    This method is optimized for the memory footprint of large graphs and may be slower than
+    creating small models with the default approach.
+
+    Second-order edges connect two temporal edges `(u, v)` at time `t` and `(v, w)` at time `t'`
+    with `t < t' <= t + delta`. Their weights count the matching pairs of temporal edges;
+    `weight` only affects the first-order layer.
+
+    Args:
+        g: Temporal graph to build the model from.
+        delta: Maximum time difference between two consecutive temporal edges.
+        weight: Name of the edge attribute used as weight of the first-order edges. If `g.data`
+            has no such attribute, every temporal edge gets weight one.
+
+    Returns:
+        Multi-order model with layers `1` and `2`.
+    """
+    data = g.data.sort_by_time()
+    edge_index1, timestamps1 = data.edge_index, data.time
+
+    node_sequence1 = torch.arange(data.num_nodes, device=edge_index1.device).unsqueeze(1)
+    if weight in data:
+        edge_weight = data[weight]
+    else:
+        edge_weight = torch.ones(edge_index1.size(1), device=edge_index1.device)
+
+    layer1 = aggregate_edge_index(
+        edge_index=edge_index1, node_sequence=node_sequence1, edge_weight=edge_weight
+    )
+    layer1.mapping = g.mapping
+
+    # Every distinct first-order edge (u, v) becomes one second-order node.
+    node_sequence2 = torch.cat([node_sequence1[edge_index1[0]], node_sequence1[edge_index1[1]][:, -1:]], dim=1)
+    node_sequence2, edge1_to_node2 = torch.unique(node_sequence2, dim=0, return_inverse=True)
+
+    # Collect the second-order edges timestamp by timestamp.
+    edge_index2 = []
+    edge_weight2 = []
+    for timestamp in timestamps1.unique():
+        src_nodes2, src_nodes2_counts = edge1_to_node2[timestamps1 == timestamp].unique(return_counts=True)
+        dst_nodes2, dst_nodes2_counts = edge1_to_node2[(timestamps1 > timestamp) & (timestamps1 <= timestamp + delta)].unique(return_counts=True)
+        # First node of every candidate continuation; identical for all sources of this timestamp.
+        dst_start_nodes = node_sequence2[dst_nodes2, 0]
+        for src_node2, src_node2_count in zip(src_nodes2, src_nodes2_counts):
+            # Keep only continuations that start at the node where `src_node2` ends.
+            continues = dst_start_nodes == node_sequence2[src_node2, -1]
+            dst_node2 = dst_nodes2[continues]
+            dst_node2_count = dst_nodes2_counts[continues]
+
+            edge_index2.append(torch.stack([src_node2.expand(dst_node2.size(0)), dst_node2], dim=0))
+            edge_weight2.append(src_node2_count.expand(dst_node2.size(0)) * dst_node2_count)
+
+    if edge_index2:
+        edge_index2 = torch.cat(edge_index2, dim=1)
+        edge_weight2 = torch.cat(edge_weight2, dim=0)
+    else:
+        # `torch.cat` rejects an empty list, which happens for a graph without temporal edges.
+        edge_index2 = edge_index1.new_empty((2, 0))
+        edge_weight2 = edge_index1.new_empty(0, dtype=torch.long)
+
+    # Merge edges that were found at several timestamps by summing up their weights.
+    edge_index2, edge_weight2 = coalesce(edge_index2, edge_attr=edge_weight2, num_nodes=node_sequence2.size(0), reduce="sum")
+
+    data2 = Data(
+        edge_index=edge_index2,
+        num_nodes=node_sequence2.size(0),
+        node_sequence=node_sequence2,
+        edge_weight=edge_weight2,
+        inverse_idx=edge1_to_node2,
+    )
+    layer2 = Graph(data2)
+    layer2.mapping = IndexMap(
+        [tuple(layer1.mapping.to_ids(v.cpu())) for v in node_sequence2]
+    )
+
+    m = MultiOrderModel()
+    m.layers[1] = layer1
+    m.layers[2] = layer2
+    return m
diff --git a/tests/utils/test_generate_second_order_model.py b/tests/utils/test_generate_second_order_model.py
new file mode 100644
index 00000000..1e8f418c
--- /dev/null
+++ b/tests/utils/test_generate_second_order_model.py
@@ -0,0 +1,31 @@
+import torch
+
+from pathpyG.core.multi_order_model import MultiOrderModel
+from pathpyG.core.temporal_graph import TemporalGraph
+from pathpyG.utils.dbgnn import generate_second_order_model
+
+
+def test_generate_second_order_model():
+    """Check that the result matches `MultiOrderModel.from_temporal_graph` with `max_order=2`."""
+    tedges = [('a', 'b', 1), ('c', 'b', 1), ('c', 'a', 1), ('c', 'a', 1), ('f', 'c', 1),
+              ('b', 'c', 5), ('a', 'd', 5), ('c', 'd', 9), ('a', 'd', 9), ('c', 'e', 9),
+              ('c', 'f', 11), ('f', 'a', 13), ('a', 'g', 18), ('b', 'f', 21),
+              ('a', 'g', 26), ('c', 'f', 27), ('h', 'f', 27), ('g', 'h', 28),
+              ('a', 'c', 30), ('a', 'b', 31), ('c', 'h', 32), ('f', 'h', 33),
+              ('b', 'i', 42), ('i', 'b', 42), ('c', 'i', 47), ('h', 'i', 50)]
+
+    g = TemporalGraph.from_edge_list(tedges)
+    reference = MultiOrderModel.from_temporal_graph(g, max_order=2, delta=10).to_dbgnn_data()
+
+    g = TemporalGraph.from_edge_list(tedges)
+    result = generate_second_order_model(g, delta=10).to_dbgnn_data()
+
+    assert result.num_nodes == reference.num_nodes
+    assert result.num_ho_nodes == reference.num_ho_nodes
+    assert torch.equal(result.x, reference.x)
+    assert torch.equal(result.edge_index, reference.edge_index)
+    assert torch.equal(result.edge_weights, reference.edge_weights)
+    assert torch.equal(result.x_h, reference.x_h)
+    assert torch.equal(result.edge_index_higher_order, reference.edge_index_higher_order)
+    assert torch.equal(result.edge_weights_higher_order, reference.edge_weights_higher_order)
+    assert torch.equal(result.bipartite_edge_index, reference.bipartite_edge_index)