Skip to content

Commit

Permalink
Add LoTTE eval script
Browse files Browse the repository at this point in the history
  • Loading branch information
santhnm2 committed Jul 31, 2023
1 parent 3df30ae commit bf36ec0
Showing 1 changed file with 85 additions and 0 deletions.
85 changes: 85 additions & 0 deletions utility/evaluate/evaluate_lotte_rankings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import argparse
from collections import defaultdict
import jsonlines
import os
import sys


def _load_rankings(rankings_path):
    """Read a ranking TSV into a mapping of qid -> passage ids in rank order.

    Each non-blank line must carry at least three tab-separated integer
    columns: qid, pid and rank. Any further columns are ignored.

    Raises:
        ValueError: if a line has fewer than three columns, a column is not
            an integer, or the ranks of a query are not the contiguous
            sequence 1, 2, 3, ... in file order.
    """
    rankings = defaultdict(list)
    with open(rankings_path, "r") as f:
        for line_number, raw_line in enumerate(f, start=1):
            stripped = raw_line.strip()
            if not stripped:
                # Tolerate blank lines (e.g. a stray trailing newline) instead
                # of failing with an opaque tuple-unpacking error.
                continue
            qid, pid, rank = stripped.split("\t")[:3]
            qid = int(qid)
            pid = int(pid)
            rank = int(rank)
            rankings[qid].append(pid)
            # The i-th row seen for a query must be labelled with rank i;
            # otherwise slicing the list by k would not mean "top k".
            # Raised explicitly (not asserted) so the check survives -O.
            expected_rank = len(rankings[qid])
            if rank != expected_rank:
                raise ValueError(
                    f"{rankings_path}:{line_number}: expected rank "
                    f"{expected_rank} for qid {qid}, got {rank}"
                )
    return rankings


def evaluate_dataset(query_type, dataset, split, k, data_rootdir, rankings_rootdir):
    """Print Success@k for one (query_type, dataset) pair.

    A query counts as a success when at least one of its answer passage ids
    appears among the first ``k`` ranked passage ids. The score is the
    percentage of successes over *all* queries in the qas file, so queries
    absent from the ranking file count as failures.

    Args:
        query_type: Query set name; used in both file names.
        dataset: Dataset name; used in both paths.
        split: Split name; used in both paths.
        k: Number of top-ranked passages to consider per query.
        data_rootdir: Root directory; queries are read from
            ``<data_rootdir>/<dataset>/<split>/qas.<query_type>.jsonl``, where
            each record has ``qid`` and ``answer_pids`` keys.
        rankings_rootdir: Root directory; rankings are read from
            ``<rankings_rootdir>/<split>/<dataset>.<query_type>.ranking.tsv``.

    Returns:
        None. The result is printed to stdout; ``???`` is printed in place of
        a score when the ranking file is missing or the qas file is empty.

    Raises:
        ValueError: if the ranking file is malformed (see ``_load_rankings``).
    """
    data_path = os.path.join(data_rootdir, dataset, split)
    rankings_path = os.path.join(
        rankings_rootdir, split, f"{dataset}.{query_type}.ranking.tsv"
    )
    if not os.path.exists(rankings_path):
        print(f"[query_type={query_type}, dataset={dataset}] Success@{k}: ???")
        return
    rankings = _load_rankings(rankings_path)

    qas_path = os.path.join(data_path, f"qas.{query_type}.jsonl")
    success = 0
    num_total_qids = 0
    with jsonlines.open(qas_path, mode="r") as f:
        for line in f:
            qid = int(line["qid"])
            num_total_qids += 1
            if qid not in rankings:
                print(f"WARNING: qid {qid} not found in {rankings_path}!", file=sys.stderr)
                continue
            answer_pids = set(line["answer_pids"])
            # Success if any of the top-k passages is a known answer.
            if not answer_pids.isdisjoint(rankings[qid][:k]):
                success += 1

    if num_total_qids == 0:
        # Nothing to average over; avoid dividing by zero.
        print(f"WARNING: no queries found in {qas_path}!", file=sys.stderr)
        print(f"[query_type={query_type}, dataset={dataset}] Success@{k}: ???")
        return
    print(
        f"[query_type={query_type}, dataset={dataset}] "
        f"Success@{k}: {success / num_total_qids * 100:.1f}"
    )


def main(args):
    """Evaluate every LoTTE dataset for both query types.

    Reads ``split``, ``k``, ``data_dir`` and ``rankings_dir`` from ``args``
    and calls ``evaluate_dataset`` once per (query type, dataset) pair. An
    empty line is printed after each query type's group of results.
    """
    query_types = ("search", "forum")
    datasets = (
        "writing",
        "recreation",
        "science",
        "technology",
        "lifestyle",
        "pooled",
    )
    for query_type in query_types:
        for dataset in datasets:
            evaluate_dataset(
                query_type=query_type,
                dataset=dataset,
                split=args.split,
                k=args.k,
                data_rootdir=args.data_dir,
                rankings_rootdir=args.rankings_dir,
            )
        # Blank separator line between the per-query-type result groups.
        print()


if __name__ == "__main__":
    # Command-line entry point: parse the options, then evaluate everything.
    cli = argparse.ArgumentParser(description="LoTTE evaluation script")

    # Cutoff for Success@k; optional, defaults to 5.
    cli.add_argument("--k", default=5, type=int, help="Success@k")

    # Which split to score; restricted to the two known values.
    cli.add_argument(
        "-s",
        "--split",
        required=True,
        choices=["dev", "test"],
        help="Split",
    )

    # Root directories for the data and for the rankings to score.
    cli.add_argument(
        "-d",
        "--data_dir",
        required=True,
        type=str,
        help="Path to LoTTE data directory",
    )
    cli.add_argument(
        "-r",
        "--rankings_dir",
        required=True,
        type=str,
        help="Path to LoTTE rankings directory",
    )

    main(cli.parse_args())

0 comments on commit bf36ec0

Please sign in to comment.