diff --git a/torch_geometric/metrics/link_pred.py b/torch_geometric/metrics/link_pred.py index b9ad50a642a5..95e274dcd619 100644 --- a/torch_geometric/metrics/link_pred.py +++ b/torch_geometric/metrics/link_pred.py @@ -492,7 +492,7 @@ def _compute(self, data: LinkPredMetricData) -> Tensor: self.discount, self.discount.new_full((1, ), fill_value=float('inf')), ]) - discount = discount[pos.clamp(max=self.k + 1)] + discount = discount[pos.clamp(max=self.k)] idcg = scatter( # Apply discount and aggregate: data.edge_label_weight / discount,