Skip to content
Permalink

Comparing changes

This is a direct comparison between two commits made in this repository or its related repositories. View the default comparison for this range or learn more about diff comparisons.

Open a pull request

Create a new pull request by comparing changes across two branches. If you need to, you can also compare across forks. Learn more about diff comparisons here.
base repository: mjpost/sacrebleu
Failed to load repositories. Confirm that selected base ref is valid, then try again.
Loading
base: 2b6ff31ef51b815103fcccbc73ddcffa5f6c05a6
Choose a base ref
..
head repository: mjpost/sacrebleu
Failed to load repositories. Confirm that selected head ref is valid, then try again.
Loading
compare: 5d0fd1746873ee27540c6ad192b0dc607f183095
Choose a head ref
Showing with 26 additions and 12 deletions.
  1. +3 −0 CHANGELOG.md
  2. +1 −1 sacrebleu/__init__.py
  3. +8 −5 sacrebleu/dataset/wmt_xml.py
  4. +14 −6 sacrebleu/utils.py
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# Release Notes

- 2.2.1 (2022-09-13)
Bugfix: Standard usage was returning (and using) each reference twice.

- 2.2.0 (2022-07-25)
Features:
- Added WMT21 datasets (thanks to @BrighXiaoHan)
2 changes: 1 addition & 1 deletion sacrebleu/__init__.py
Original file line number Diff line number Diff line change
@@ -14,7 +14,7 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

__version__ = '2.2.0'
__version__ = '2.2.1'
__description__ = 'Hassle-free computation of shareable, comparable, and reproducible BLEU, chrF, and TER scores'


13 changes: 8 additions & 5 deletions sacrebleu/dataset/wmt_xml.py
Original file line number Diff line number Diff line change
@@ -95,11 +95,6 @@ def get_sents(doc):
orig_langs.append(origlang)
src_sent_count += 1

# For backward compatibility, if "ref" is not in the fields,
# add reference sentences from the first translator as "ref" field
if "ref" not in refs:
refs["ref"] = refs[min(refs.keys())]

return {"src": src, **refs, "docid": docids, "origlang": orig_langs,}

def process_to_text(self, langpair=None):
@@ -130,6 +125,14 @@ def process_to_text(self, langpair=None):
for line in fields[fieldname]:
print(self._clean(line), file=fout)

def get_reference_files(self, langpair):
    """
    Return the files for the given language pair that hold references.

    The files from ``self.get_files(langpair)`` are paired positionally
    with the names from ``self.fieldnames(langpair)``; a file is kept when
    its paired field name starts with ``"ref"``.

    :param langpair: the language pair passed through to ``get_files`` and ``fieldnames``
    :return: list of the selected files, in their original order
    """
    selected = []
    paired = zip(self.get_files(langpair), self.fieldnames(langpair))
    for path, name in paired:
        # Only fields whose name begins with "ref" count as references.
        if not name.startswith("ref"):
            continue
        selected.append(path)
    return selected

def fieldnames(self, langpair):
"""
Return a list of all the field names. For most source, this is just
20 changes: 14 additions & 6 deletions sacrebleu/utils.py
Original file line number Diff line number Diff line change
@@ -274,12 +274,12 @@ def args_to_dict(args, prefix: str, strip_prefix: bool = False):
return d


def print_test_set(test_set, langpair, fields, origlang=None, subset=None):
def print_test_set(test_set, langpair, requested_fields, origlang=None, subset=None):
"""Prints to STDOUT the specified side of the specified test set.
:param test_set: the test set to print
:param langpair: the language pair
:param fields: the fields to print
:param requested_fields: the fields to print
:param origlang: print only sentences with a given original language (2-char ISO639-1 code), "non-" prefix means negation
:param subset: print only sentences whose document annotation matches a given regex
"""
@@ -289,17 +289,25 @@ def print_test_set(test_set, langpair, fields, origlang=None, subset=None):
fieldnames = DATASETS[test_set].fieldnames(langpair)
all_files = DATASETS[test_set].get_files(langpair)

if "all" in fields and len(fields) != 1:
if "all" in requested_fields and len(requested_fields) != 1:
sacrelogger.error("Cannot use --echo all with other fields")
sys.exit(1)
elif "all" in fields:
fields = fieldnames
elif "all" in requested_fields:
requested_fields = fieldnames

# backwards compatibility: allow "ref" even if not present (choose first)
if "ref" in requested_fields and "ref" not in fieldnames:
replacement_ref = min([f for f in fieldnames if f.startswith("ref")])
requested_fields = [f if f != "ref" else replacement_ref for f in requested_fields]

files = []
for field in fields:
for field in requested_fields:
if field not in fieldnames:
sacrelogger.error(f"No such field {field} in test set {test_set} for language pair {langpair}.")
sacrelogger.error(f"available fields for {test_set}/{langpair}: {', '.join(fieldnames)}")
if "ref" not in fieldnames:
subref = min([f for f in fieldnames if f.startswith("ref")])
sacrelogger.error(f"'ref' also allowed for backwards compatibility (will return {subref})")
sys.exit(1)
index = fieldnames.index(field)
files.append(all_files[index])