Skip to content

Commit

Permalink
fix support for multiple test set
Browse files Browse the repository at this point in the history
  • Loading branch information
BrightXiaoHan committed Oct 8, 2022
1 parent 2464590 commit efc376d
Showing 1 changed file with 28 additions and 5 deletions.
33 changes: 28 additions & 5 deletions sacrebleu/sacrebleu.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,7 @@ def main():
# Read references
#################
full_refs = [[] for x in range(max(len(concat_ref_files[0]), args.num_refs))]
split_points = []
for ref_files in concat_ref_files:
for refno, ref_file in enumerate(ref_files):
for lineno, line in enumerate(smart_open(ref_file, encoding=args.encoding), 1):
Expand All @@ -364,6 +365,7 @@ def main():
sys.exit(17)
for refno, ref in enumerate(refs):
full_refs[refno].append(ref)
split_points.append(len(full_refs[0]))

# Decide on the number of final references, override the argument
args.num_refs = len(full_refs)
Expand Down Expand Up @@ -424,6 +426,32 @@ def main():
# Set final number of systems
num_sys = len(sys_names)

# Merge sentences from same doc for doc aligned datasets.
if args.test_set:
start_point = 0
doc_aligned_systems = [[] for _ in range(num_sys)]
doc_aligned_refs = [
[] for _ in range(args.num_refs)
] # args.num_refs is always "1" here, because args.test_set is not None
for test_set, end_point in zip(args.test_set.split(","), split_points):
dataset = DATASETS[test_set]
for i in range(num_sys):
doc_aligned_systems[i].extend(
dataset.doc_align(
args.langpair, full_systems[i][start_point:end_point], "src"
)
)
for i in range(args.num_refs):
doc_aligned_refs[i].extend(
dataset.doc_align(
args.langpair, full_refs[i][start_point:end_point], "ref"
)
)
start_point = end_point

full_systems = doc_aligned_systems
full_refs = doc_aligned_refs

# Add baseline prefix to the first system for clarity
if paired_test_mode:
if args.input is None:
Expand Down Expand Up @@ -462,11 +490,6 @@ def main():
# Unpack systems & references back
systems, refs = outputs[:num_sys], outputs[num_sys:]

# Merge sentences from same doc for doc aligned datasets.
if args.test_set:
refs = [DATASETS[args.test_set].doc_align(args.langpair, ref, "ref") for ref in refs]
systems = [DATASETS[args.test_set].doc_align(args.langpair, system, "ref") for system in systems]

# Perform some sanity checks
for system in systems:
if len(system) == 0:
Expand Down

0 comments on commit efc376d

Please sign in to comment.