diff --git a/model.py b/model.py index 17f1f56..5e0d281 100644 --- a/model.py +++ b/model.py @@ -865,7 +865,7 @@ def get_mrr_ndcg(calc_ndcg=False): for target_slot_idx, ota in zip(targets, outputs_topall_idxes): target_rank_idx = -1 for rank_idx, slot_idx in enumerate(ota): - if target_slot_idx.item() == slot_idx.item(): + if -1 == target_rank_idx and target_slot_idx.item() == slot_idx.item(): target_rank_idx = rank_idx if not calc_ndcg: break