From 9c73f9908bfce4d8c47d23fa00b55d83ad4f1399 Mon Sep 17 00:00:00 2001 From: Matthias Fey Date: Mon, 20 Jan 2025 14:11:44 +0100 Subject: [PATCH] Fix `LinkPredMetric` for empty ground-truths (#9962) --- test/metrics/test_link_pred_metric.py | 17 +++++++++++++++++ torch_geometric/metrics/link_pred.py | 8 ++++++++ 2 files changed, 25 insertions(+) diff --git a/test/metrics/test_link_pred_metric.py b/test/metrics/test_link_pred_metric.py index 13e338cb1002..67a57683d8a7 100644 --- a/test/metrics/test_link_pred_metric.py +++ b/test/metrics/test_link_pred_metric.py @@ -221,3 +221,20 @@ def test_link_pred_metric_collection(num_src_nodes, num_dst_nodes, num_edges): metric_collection.update(pred_index_mat, edge_label_index) assert metric_collection.compute() == expected metric_collection.reset() + + +def test_empty_ground_truth(): + pred = torch.rand(10, 5) + pred_index_mat = pred.argsort(dim=1) + edge_label_index = torch.empty(2, 0, dtype=torch.long) + edge_label_weight = torch.empty(0) + + metric = LinkPredMAP(k=5) + metric.update(pred_index_mat, edge_label_index) + assert metric.compute() == 0 + metric.reset() + + metric = LinkPredNDCG(k=5, weighted=True) + metric.update(pred_index_mat, edge_label_index, edge_label_weight) + assert metric.compute() == 0 + metric.reset() diff --git a/torch_geometric/metrics/link_pred.py b/torch_geometric/metrics/link_pred.py index b80ecc706054..b9ad50a642a5 100644 --- a/torch_geometric/metrics/link_pred.py +++ b/torch_geometric/metrics/link_pred.py @@ -30,6 +30,14 @@ def pred_rel_mat(self) -> Tensor: if hasattr(self, '_pred_rel_mat'): return self._pred_rel_mat # type: ignore + if self.edge_label_index[1].numel() == 0: + self._pred_rel_mat = torch.zeros_like( + self.pred_index_mat, + dtype=torch.bool if self.edge_label_weight is None else + torch.get_default_dtype(), + ) + return self._pred_rel_mat + # Flatten both prediction and ground-truth indices, and determine # overlaps afterwards via `torch.searchsorted`. max_index = max( # type: ignore