diff --git a/sacrebleu/sacrebleu.py b/sacrebleu/sacrebleu.py index e34a7a8..295186f 100755 --- a/sacrebleu/sacrebleu.py +++ b/sacrebleu/sacrebleu.py @@ -343,6 +343,7 @@ def main(): # Read references ################# full_refs = [[] for x in range(max(len(concat_ref_files[0]), args.num_refs))] + split_points = [] for ref_files in concat_ref_files: for refno, ref_file in enumerate(ref_files): for lineno, line in enumerate(smart_open(ref_file, encoding=args.encoding), 1): @@ -359,6 +360,7 @@ def main(): sys.exit(17) for refno, ref in enumerate(refs): full_refs[refno].append(ref) + split_points.append(len(full_refs[0])) # Decide on the number of final references, override the argument args.num_refs = len(full_refs) @@ -419,6 +421,32 @@ def main(): # Set final number of systems num_sys = len(sys_names) + # Merge sentences from same doc for doc aligned datasets. + if args.test_set: + start_point = 0 + doc_aligned_systems = [[] for _ in range(num_sys)] + doc_aligned_refs = [ + [] for _ in range(args.num_refs) + ] # args.num_refs is always "1" here, because args.test_set is not None + for test_set, end_point in zip(args.test_set.split(","), split_points): + dataset = DATASETS[test_set] + for i in range(num_sys): + doc_aligned_systems[i].extend( + dataset.doc_align( + args.langpair, full_systems[i][start_point:end_point], "src" + ) + ) + for i in range(args.num_refs): + doc_aligned_refs[i].extend( + dataset.doc_align( + args.langpair, full_refs[i][start_point:end_point], "ref" + ) + ) + start_point = end_point + + full_systems = doc_aligned_systems + full_refs = doc_aligned_refs + # Add baseline prefix to the first system for clarity if paired_test_mode: if args.input is None: @@ -457,11 +485,6 @@ def main(): # Unpack systems & references back systems, refs = outputs[:num_sys], outputs[num_sys:] - # Merge sentences from same doc for doc aligned datasets. - if args.test_set: - refs = [DATASETS[args.test_set].doc_align(args.langpair, ref, "ref") for ref in refs] - systems = [DATASETS[args.test_set].doc_align(args.langpair, system, "ref") for system in systems] - # Perform some sanity checks for system in systems: if len(system) == 0: