From 3ff0332d67df45ee6d3fa588262dd26e07188602 Mon Sep 17 00:00:00 2001 From: Donghyeon Kim <12129692+donghyeonk@users.noreply.github.com> Date: Mon, 24 Apr 2023 23:40:47 +0900 Subject: [PATCH] check if found or not --- model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model.py b/model.py index 17f1f56..5e0d281 100644 --- a/model.py +++ b/model.py @@ -865,7 +865,7 @@ def get_mrr_ndcg(calc_ndcg=False): for target_slot_idx, ota in zip(targets, outputs_topall_idxes): target_rank_idx = -1 for rank_idx, slot_idx in enumerate(ota): - if target_slot_idx.item() == slot_idx.item(): + if -1 == target_rank_idx and target_slot_idx.item() == slot_idx.item(): target_rank_idx = rank_idx if not calc_ndcg: break