diff --git a/.github/workflows/check-build.yml b/.github/workflows/check-build.yml index 4f316bb0..f4056f1c 100644 --- a/.github/workflows/check-build.yml +++ b/.github/workflows/check-build.yml @@ -3,20 +3,40 @@ name: check-build on: pull_request +env: + PYTHONUTF8: "1" + jobs: check-build: - runs-on: ubuntu-20.04 + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + python-version: [3.6, 3.7, 3.8, 3.9] + exclude: + - os: windows-latest + python-version: '3.6' # test fails due to UTF8 stuff steps: - - name: update - run: sudo apt-get -y update - - name: install pytest-cov - run: pip install pytest-cov - - uses: actions/checkout@v1 - - name: install - run: sudo python3 setup.py install - - name: install-ja - run: sudo pip install .[ja] - - name: pytest - run: python3 -m pytest - - name: test - run: ./test.sh + # - name: update + # run: sudo apt-get -y update + - uses: actions/checkout@v2 + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - if: matrix.os == 'macos-latest' + name: Install Mac OS requirements + run: brew install bash + - if: matrix.os == 'windows-latest' + name: Install Windows requirements + run: choco install wget unzip + - name: Install python dependencies + run: | + python -m pip install --upgrade pip + pip install pytest-cov + pip install .[ja] + - name: Python pytest test suite + run: python3 -m pytest + - name: CLI bash test suite + shell: bash + run: ./test.sh diff --git a/CHANGELOG.md b/CHANGELOG.md index 9fac42bd..719126ec 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,66 @@ -# VERSION HISTORY +# Release Notes + +- 2.0.0 (2021-07-XX) + - Build: Add Windows and OS X testing to Travis CI. + - Improve documentation and type annotations. + - Drop `Python < 3.6` support and migrate to f-strings. + - Relax `portalocker` version pinning, add `regex, tabulate, numpy` dependencies. + - Drop input type manipulation through `isinstance` checks. If the user does not obey + to the expected annotations, exceptions will be raised. Robustness attempts lead to + confusions and obfuscated score errors in the past (#121) + - Variable # references per segment is supported for all metrics by default. It is + still only available through the API. + - Use colored strings in tabular outputs (multi-system evaluation mode) through + the help of `colorama` package. + - tokenizers: Add caching to tokenizers which seem to speed up things a bit. + - `intl` tokenizer: Use `regex` module. Speed goes from ~4 seconds to ~0.6 seconds + for a particular test set evaluation. (#46) + - Signature: Formatting changed (mostly to remove '+' separator as it was + interfering with chrF++). The field separator is now '|' and key values + are separated with ':' rather than '.'. + - Signature: Boolean true / false values are shortened to yes / no. + - Signature: Number of references is `var` if variable number of references is used. + - Signature: Add effective order (yes/no) to BLEU and chrF signatures. + - Metrics: Scale all metrics into the [0, 100] range (#140) + - Metrics API: Use explicit argument names and defaults for the metrics instead of + passing obscure `argparse.Namespace` objects. + - Metrics API: A base abstract `Metric` class is introduced to guide further + metric development. This class defines the methods that should be implemented + in the derived classes and offers boilerplate methods for the common functionality. + A new metric implemented this way will automatically support significance testing. + - Metrics API: All metrics now receive an optional `references` argument at + initialization time to process and cache the references. Further evaluations + of different systems against the same references becomes faster this way + for example when using significance testing. + - BLEU: In case of no n-gram matches at all, skip smoothing and return 0.0 BLEU (#141). + - CHRF: Added multi-reference support, verified the scores against chrF++.py, added test case. + - CHRF: Added chrF+ support through `word_order` argument. Added test cases against chrF++.py. + Exposed it through the CLI (--chrf-word-order) (#124) + - CHRF: Add possibility to disable effective order smoothing (pass --chrf-eps-smoothing). + This way, the scores obtained are exactly the same as chrF++, Moses and NLTK implementations. + We keep the effective ordering as the default for compatibility, since this only + affects sentence-level scoring with very short sentences. (#144) + - CLI: `--input/-i` can now ingest multiple systems. For this reason, the positional + `references` should always preceed the `-i` flag. + - CLI: Allow modifying TER arguments through CLI. We still keep the TERCOM defaults. + - CLI: Prefix metric-specific arguments with --chrf and --ter. To maintain compatibility, + BLEU argument names are kept the same. + - CLI: Separate metric-specific arguments for clarity when `--help` is printed. + - CLI: Added `--format/-f` flag. The single-system output mode is now `json` by default. + If you want to keep the old text format persistently, you can export `SACREBLEU_FORMAT=text` into your + shell. + - CLI: For multi-system mode, `json` falls back to plain text. `latex` output can only + be generated for multi-system mode. + - CLI: sacreBLEU now supports evaluating multiple systems for a given test set + in an efficient way. Through the use of `tabulate` package, the results are + nicely rendered into a plain text table, LaTeX, HTML or RST (cf. --format/-f argument). + The systems can be either given as a list of plain text files to `-i/--input` or + as a tab-separated single stream redirected into `STDIN`. In the former case, + the basenames of the files will be automatically used as system names. + - Statistical tests: sacreBLEU now supports confidence interval estimation + through bootstrap resampling for single-system evaluation (`--confidence` flag) + as well as paired bootstrap resampling (`--paired-bs`) and paired approximate + randomization tests (`--paired-ar`) when evaluating multiple systems (#40 and #78). - 1.5.1 (2021-03-05) - Fix extraction error for WMT18 extra test sets (test-ts) (#142) diff --git a/DATASETS.md b/DATASETS.md new file mode 100644 index 00000000..29d9d2c8 --- /dev/null +++ b/DATASETS.md @@ -0,0 +1,58 @@ +| Dataset | Description | +| ------------------------------ | ------------------------------------------------------------------------------------------------------------------- | +| mtedx/valid | mTEDx evaluation data, valid: [URL](http://openslr.org/100) | +| mtedx/test | mTEDx evaluation data, test: [URL](http://openslr.org/100) | +| wmt20/robust/set1 | WMT20 robustness task, set 1 | +| wmt20/robust/set2 | WMT20 robustness task, set 2 | +| wmt20/robust/set3 | WMT20 robustness task, set 3 | +| wmt20/tworefs | WMT20 news test sets with two references | +| wmt20 | Official evaluation data for WMT20 | +| mtnt2019 | Test set for the WMT 19 robustness shared task | +| mtnt1.1/test | Test data for the Machine Translation of Noisy Text task: [URL](http://www.cs.cmu.edu/~pmichel1/mtnt/) | +| mtnt1.1/valid | Validation data for the Machine Translation of Noisy Text task: [URL](http://www.cs.cmu.edu/~pmichel1/mtnt/) | +| mtnt1.1/train | Training data for the Machine Translation of Noisy Text task: [URL](http://www.cs.cmu.edu/~pmichel1/mtnt/) | +| wmt20/dev | Development data for tasks new to 2020. | +| wmt19 | Official evaluation data. | +| wmt19/dev | Development data for tasks new to 2019. | +| wmt19/google/ar | Additional high-quality reference for WMT19/en-de. | +| wmt19/google/arp | Additional paraphrase of wmt19/google/ar. | +| wmt19/google/wmtp | Additional paraphrase of the official WMT19 reference. | +| wmt19/google/hqr | Best human selected-reference between wmt19 and wmt19/google/ar. | +| wmt19/google/hqp | Best human-selected reference between wmt19/google/arp and wmt19/google/wmtp. | +| wmt19/google/hqall | Best human-selected reference among original official reference and the Google reference and paraphrases. | +| wmt18 | Official evaluation data. | +| wmt18/test-ts | Official evaluation sources with extra test sets interleaved. | +| wmt18/dev | Development data (Estonian<>English). | +| wmt17 | Official evaluation data. | +| wmt17/B | Additional reference for EN-FI and FI-EN. | +| wmt17/tworefs | Systems with two references. | +| wmt17/improved | Improved zh-en and en-zh translations. | +| wmt17/dev | Development sets released for new languages in 2017. | +| wmt17/ms | Additional Chinese-English references from Microsoft Research. | +| wmt16 | Official evaluation data. | +| wmt16/B | Additional reference for EN-FI. | +| wmt16/tworefs | EN-FI with two references. | +| wmt16/dev | Development sets released for new languages in 2016. | +| wmt15 | Official evaluation data. | +| wmt14 | Official evaluation data. | +| wmt14/full | Evaluation data released after official evaluation for further research. | +| wmt13 | Official evaluation data. | +| wmt12 | Official evaluation data. | +| wmt11 | Official evaluation data. | +| wmt10 | Official evaluation data. | +| wmt09 | Official evaluation data. | +| wmt08 | Official evaluation data. | +| wmt08/nc | Official evaluation data (news commentary). | +| wmt08/europarl | Official evaluation data (Europarl). | +| iwslt17 | Official evaluation data for IWSLT. | +| iwslt17/tst2016 | Development data for IWSLT 2017. | +| iwslt17/tst2015 | Development data for IWSLT 2017. | +| iwslt17/tst2014 | Development data for IWSLT 2017. | +| iwslt17/tst2013 | Development data for IWSLT 2017. | +| iwslt17/tst2012 | Development data for IWSLT 2017. | +| iwslt17/tst2011 | Development data for IWSLT 2017. | +| iwslt17/tst2010 | Development data for IWSLT 2017. | +| iwslt17/dev2010 | Development data for IWSLT 2017. | +| multi30k/2016 | 2016 flickr test set of Multi30k dataset | +| multi30k/2017 | 2017 flickr test set of Multi30k dataset | +| multi30k/2018 | 2018 flickr test set of Multi30k dataset. See [URL](https://competitions.codalab.org/competitions/19917) for evaluation. | diff --git a/README.md b/README.md index 68cfc34f..e56bd86f 100644 --- a/README.md +++ b/README.md @@ -1,115 +1,598 @@ -[![PyPI version](https://badge.fury.io/py/sacrebleu.svg)](https://badge.fury.io/py/sacrebleu) -[![GitHub issues](https://img.shields.io/github/issues/mjpost/sacreBLEU.svg)](https://github.com/awslabs/sockeye/issues) +# sacreBLEU -SacreBLEU ([Post, 2018](http://aclweb.org/anthology/W18-6319)) provides hassle-free computation of shareable, comparable, and reproducible BLEU scores. +[![PyPI version](https://img.shields.io/pypi/v/sacrebleu)](https://img.shields.io/pypi/v/sacrebleu) +[![Python version](https://img.shields.io/pypi/pyversions/sacrebleu)](https://img.shields.io/pypi/pyversions/sacrebleu) +[![GitHub issues](https://img.shields.io/github/issues/mjpost/sacreBLEU.svg)](https://github.com/mjpost/sacrebleu/issues) + +SacreBLEU ([Post, 2018](http://aclweb.org/anthology/W18-6319)) provides hassle-free computation of shareable, comparable, and reproducible **BLEU** scores. Inspired by Rico Sennrich's `multi-bleu-detok.perl`, it produces the official WMT scores but works with plain text. It also knows all the standard test sets and handles downloading, processing, and tokenization for you. -Why use this version of BLEU? +The official version is hosted at . + +# Motivation + +Comparing BLEU scores is harder than it should be. Every decoder has its own implementation, often borrowed from Moses, but maybe with subtle changes. +Moses itself has a number of implementations as standalone scripts, with little indication of how they differ (note: they mostly don't, but `multi-bleu.pl` expects tokenized input). Different flags passed to each of these scripts can produce wide swings in the final score. All of these may handle tokenization in different ways. On top of this, downloading and managing test sets is a moderate annoyance. + +Sacre bleu! What a mess. + +**SacreBLEU** aims to solve these problems by wrapping the original reference implementation ([Papineni et al., 2002](https://www.aclweb.org/anthology/P02-1040.pdf)) together with other useful features. +The defaults are set the way that BLEU should be computed, and furthermore, the script outputs a short version string that allows others to know exactly what you did. +As an added bonus, it automatically downloads and manages test sets for you, so that you can simply tell it to score against `wmt14`, without having to hunt down a path on your local file system. +It is all designed to take BLEU a little more seriously. +After all, even with all its problems, BLEU is the default and---admit it---well-loved metric of our entire research community. +Sacre BLEU. + +# Features + - It automatically downloads common WMT test sets and processes them to plain text - It produces a short version string that facilitates cross-paper comparisons - It properly computes scores on detokenized outputs, using WMT ([Conference on Machine Translation](http://statmt.org/wmt17)) standard tokenization -- It produces the same values as official script (`mteval-v13a.pl`) used by WMT +- It produces the same values as the official script (`mteval-v13a.pl`) used by WMT - It outputs the BLEU score without the comma, so you don't have to remove it with `sed` (Looking at you, `multi-bleu.perl`) +- It supports different tokenizers for BLEU including support for Japanese and Chinese +- It supports **chrF, chrF++** and **Translation error rate (TER)** metrics +- It performs paired bootstrap resampling and paired approximate randomization tests for statistical significance reporting -The official version is hosted at . +# Breaking Changes + +## v2.0.0 + +As of v2.0.0, the default output format is changed to `json` for less painful parsing experience. This means that software that parse the output of sacreBLEU should be modified to either (i) parse the JSON using for example the `jq` utility or (ii) pass `-f text` to sacreBLEU to preserve the old textual output. The latter change can also be made **persistently** by exporting `SACREBLEU_FORMAT=text` in relevant shell configuration files. -# QUICK START +Here's an example of parsing the `score` key of the JSON output using `jq`: -Install the Python module (Python 3 only) +``` +$ sacrebleu -i output.detok.txt -t wmt17 -l en-de | jq -r .score +20.8 +``` + +# Installation - pip3 install sacrebleu +Install the official Python module from PyPI (**Python>=3.6 only**): + + pip install sacrebleu In order to install Japanese tokenizer support through `mecab-python3`, you need to run the following command instead, to perform a full installation with dependencies: - pip3 install sacrebleu[ja] + pip install sacrebleu[ja] + +# Command-line Usage -Alternately, you can install from the source: +You can get a list of available test sets with `sacrebleu --list`. Please see [DATASETS.md](DATASETS.md) +for an up-to-date list of supported datasets. - python3 setup.py install +## Basics -This installs a shell script, `sacrebleu`. -(You can also run `python3 -m sacrebleu`, so long as this root directory is in your `$PYTHONPATH`). +### Downloading test sets -Get a list of available test sets: +Download the **source** for one of the pre-defined test sets: - sacrebleu --list +``` +$ sacrebleu -t wmt17 -l en-de --echo src | head -n1 +28-Year-Old Chef Found Dead at San Francisco Mall +``` + +Download the **reference** for one of the pre-defined test sets: +``` +$ sacrebleu -t wmt17 -l en-de --echo ref | head -n1 +28-jähriger Koch in San Francisco Mall tot aufgefunden +``` + +### JSON output + +As of version `>=2.0.0`, sacreBLEU prints the computed scores in JSON format to make parsing less painful: + +``` +$ sacrebleu -i output.detok.txt -t wmt17 -l en-de +``` -Download the source for one of the pre-defined test sets: +```json +{ + "name": "BLEU", + "score": 20.8, + "signature": "nrefs:1|case:mixed|eff:no|tok:13a|smooth:exp|version:2.0.0", + "verbose_score": "54.4/26.6/14.9/8.7 (BP = 1.000 ratio = 1.026 hyp_len = 62880 ref_len = 61287)", + "nrefs": "1", + "case": "mixed", + "eff": "no", + "tok": "13a", + "smooth": "exp", + "version": "2.0.0" +} +``` - sacrebleu -t wmt14 -l de-en --echo src > wmt14-de-en.src +If you want to keep the old behavior, you can pass `-f text` or export `SACREBLEU_FORMAT=text`: -(you can also use long parameter names for readability): +``` +$ sacrebleu -i output.detok.txt -t wmt17 -l en-de -f text +BLEU|nrefs:1|case:mixed|eff:no|tok:13a|smooth:exp|version:2.0.0 = 20.8 54.4/26.6/14.9/8.7 (BP = 1.000 ratio = 1.026 hyp_len = 62880 ref_len = 61287) +``` - sacrebleu --test-set wmt14 --language-pair de-en --echo src > wmt14-de-en.src +### Scoring -After tokenizing, translating, and detokenizing it, you can score your decoder output easily: +(All examples below assume old-style text output for a compact representation that save space) - cat output.detok.txt | sacrebleu -t wmt14 -l de-en +Let's say that you just translated the `en-de` test set of WMT17 with your fancy MT system and the **detokenized** translations are in a file called `output.detok.txt`: -SacreBLEU knows about common WMT test sets, but you can also use it to score system outputs with arbitrary references. -It also works in backwards compatible model where you manually specify the reference(s), similar to the format of `multi-bleu.txt`: +``` +# Option 1: Redirect system output to STDIN +$ cat output.detok.txt | sacrebleu -t wmt17 -l en-de +BLEU|nrefs:1|case:mixed|eff:no|tok:13a|smooth:exp|version:2.0.0 = 20.8 54.4/26.6/14.9/8.7 (BP = 1.000 ratio = 1.026 hyp_len = 62880 ref_len = 61287) - cat output.detok.txt | sacrebleu REF1 [REF2 ...] +# Option 2: Use the --input/-i argument +$ sacrebleu -t wmt17 -l en-de -i output.detok.txt +BLEU|nrefs:1|case:mixed|eff:no|tok:13a|smooth:exp|version:2.0.0 = 20.8 54.4/26.6/14.9/8.7 (BP = 1.000 ratio = 1.026 hyp_len = 62880 ref_len = 61287) +``` -Note that the system output and references will all be tokenized internally. +You can obtain a short version of the signature with `--short/-sh`: -SacreBLEU generates version strings like the following. -Put them in a footnote in your paper! -Use `--short` for a shorter hash if you like. +``` +$ sacrebleu -t wmt17 -l en-de -i output.detok.txt -sh +BLEU|#:1|c:mixed|e:no|tok:13a|s:exp|v:2.0.0 = 20.8 54.4/26.6/14.9/8.7 (BP = 1.000 ratio = 1.026 hyp_len = 62880 ref_len = 61287) +``` - BLEU+case.mixed+lang.de-en+test.wmt17 = 32.97 66.1/40.2/26.6/18.1 (BP = 0.980 ratio = 0.980 hyp_len = 63134 ref_len = 64399) +If you only want the score to be printed, you can use the `--score-only/-b` flag: + +``` +$ sacrebleu -t wmt17 -l en-de -i output.detok.txt -b +20.8 +``` + +The precision of the scores can be configured via the `--width/-w` flag: + +``` +$ sacrebleu -t wmt17 -l en-de -i output.detok.txt -b -w 4 +20.7965 +``` + +### Using your own reference file + +SacreBLEU knows about common test sets (as detailed in the `--list` example above), but you can also use it to score system outputs with arbitrary references. In this case, do not forget to provide **detokenized** reference and hypotheses files: + +``` +# Let's save the reference to a text file +$ sacrebleu -t wmt17 -l en-de --echo ref > ref.detok.txt + +# Option 1: Pass the reference file as a positional argument to sacreBLEU +$ sacrebleu ref.detok.txt -i output.detok.txt -m bleu -b -w 4 +20.7965 + +# Option 2: Redirect the system into STDIN (Compatible with multi-bleu.perl way of doing things) +$ cat output.detok.txt | sacrebleu ref.detok.txt -m bleu -b -w 4 +20.7965 +``` + +### Using multiple metrics + +Let's first compute BLEU, chrF and TER with the default settings: + +``` +$ sacrebleu -t wmt17 -l en-de -i output.detok.txt -m bleu chrf ter + BLEU|nrefs:1|case:mixed|eff:no|tok:13a|smooth:exp|version:2.0.0 = 20.8 + chrF2|nrefs:1|case:mixed|eff:yes|nc:6|nw:0|space:no|version:2.0.0 = 52.0 +TER|nrefs:1|case:lc|tok:tercom|norm:no|punct:yes|asian:no|version:2.0.0 = 69.0 +``` + +Let's now enable `chrF++` which is a revised version of chrF that takes into account word n-grams. +Observe how the `nw:0` gets changed into `nw:2` in the signature: + +``` +$ sacrebleu -t wmt17 -l en-de -i output.detok.txt -m bleu chrf ter --chrf-word-order 2 + BLEU|nrefs:1|case:mixed|eff:no|tok:13a|smooth:exp|version:2.0.0 = 20.8 + chrF2++|nrefs:1|case:mixed|eff:yes|nc:6|nw:2|space:no|version:2.0.0 = 49.0 +TER|nrefs:1|case:lc|tok:tercom|norm:no|punct:yes|asian:no|version:2.0.0 = 69.0 +``` + +Metric-specific arguments are detailed in the output of `--help`: + +``` +BLEU related arguments: + --smooth-method {none,floor,add-k,exp}, -s {none,floor,add-k,exp} + Smoothing method: exponential decay, floor (increment zero counts), add-k (increment num/denom by k for n>1), or none. (Default: exp) + --smooth-value BLEU_SMOOTH_VALUE, -sv BLEU_SMOOTH_VALUE + The smoothing value. Only valid for floor and add-k. (Defaults: floor: 0.1, add-k: 1) + --tokenize {none,zh,13a,char,intl,ja-mecab}, -tok {none,zh,13a,char,intl,ja-mecab} + Tokenization method to use for BLEU. If not provided, defaults to `zh` for Chinese, `ja-mecab` for Japanese and `13a` (mteval) otherwise. + --lowercase, -lc If True, enables case-insensitivity. (Default: False) + --force Insist that your tokenized input is actually detokenized. + +chrF related arguments: + --chrf-char-order CHRF_CHAR_ORDER, -cc CHRF_CHAR_ORDER + Character n-gram order. (Default: 6) + --chrf-word-order CHRF_WORD_ORDER, -cw CHRF_WORD_ORDER + Word n-gram order (Default: 0). If equals to 2, the metric is referred to as chrF++. + --chrf-beta CHRF_BETA + Determine the importance of recall w.r.t precision. (Default: 2) + --chrf-whitespace Include whitespaces when extracting character n-grams. (Default: False) + --chrf-lowercase Enable case-insensitivity. (Default: False) + --chrf-eps-smoothing Enables epsilon smoothing similar to chrF++.py, NLTK and Moses; instead of effective order smoothing. (Default: False) + +TER related arguments (The defaults replicate TERCOM's behavior): + --ter-case-sensitive Enables case sensitivity (Default: False) + --ter-asian-support Enables special treatment of Asian characters (Default: False) + --ter-no-punct Removes punctuation. (Default: False) + --ter-normalized Applies basic normalization and tokenization. (Default: False) +``` + +### Version Signatures +As you may have noticed, sacreBLEU generates version strings such as `BLEU|nrefs:1|case:mixed|eff:no|tok:13a|smooth:exp|version:2.0.0` for reproducibility reasons. It's strongly recommended to share these signatures in your papers! + +## Translationese Support If you are interested in the translationese effect, you can evaluate BLEU on a subset of sentences -with a given original language (identified based on the origlang tag in the raw SGM files). +with a given original language (identified based on the `origlang` tag in the raw SGM files). E.g., to evaluate only against originally German sentences translated to English use: - sacrebleu -t wmt13 -l de-en --origlang=de < my-wmt13-output.txt + $ sacrebleu -t wmt13 -l de-en --origlang=de -i my-wmt13-output.txt -and to evaluate against the complement (in this case origlang en, fr, cs, ru, de) use: +and to evaluate against the complement (in this case `origlang` en, fr, cs, ru, de) use: - sacrebleu -t wmt13 -l de-en --origlang=non-de < my-wmt13-output.txt + $ sacrebleu -t wmt13 -l de-en --origlang=non-de -i my-wmt13-output.txt -*Please note* that the evaluator will return a BLEU score only on the requested subset, +**Please note** that the evaluator will return a BLEU score only on the requested subset, but it expects that you pass through the entire translated test set. -## Using SacreBLEU from Python +## Languages & Preprocessing + +### BLEU + +- You can compute case-insensitive BLEU by passing `--lowercase` to sacreBLEU +- The default tokenizer for BLEU is `13a` which mimics the `mteval-v13a` script from Moses. +- Other tokenizers are: + - `none` which will not apply any kind of tokenization at all + - `char` for language-agnostic character-level tokenization + - `intl` applies international tokenization and mimics the `mteval-v14` script from Moses + - `zh` separates out **Chinese** characters and tokenizes the non-Chinese parts using `13a` tokenizer + - `ja-mecab` tokenizes **Japanese** inputs using the [MeCab](https://pypi.org/project/mecab-python3) morphological analyzer +- You can switch tokenizers using the `--tokenize` flag of sacreBLEU. Alternatively, if you provide language-pair strings + using `--language-pair/-l`, `zh` and `ja-mecab` tokenizers will be used if the target language is `zh` or `ja`, respectively. +- **Note that** there's no automatic language detection from the hypotheses so you need to make sure that you are correctly + selecting the tokenizer for **Japanese** and **Chinese**. + -For evaluation, it may be useful to compute BLEU inside a script. This is how you can do it: +Default 13a tokenizer will produce poor results for Japanese: + +``` +$ sacrebleu kyoto-test.ref.ja -i kyoto-test.hyp.ja -b +2.1 +``` + +Let's use the `ja-mecab` tokenizer: +``` +$ sacrebleu kyoto-test.ref.ja -i kyoto-test.hyp.ja --tokenize ja-mecab -b +14.5 +``` + +If you provide the language-pair, sacreBLEU will use ja-mecab automatically: + +``` +$ sacrebleu kyoto-test.ref.ja -i kyoto-test.hyp.ja -l en-ja -b +14.5 +``` + +### chrF / chrF++ + +chrF applies minimum to none pre-processing as it deals with character n-grams: + +- If you pass `--chrf-whitespace`, whitespace characters will be preserved when computing character n-grams. +- If you pass `--chrf-lowercase`, sacreBLEU will compute case-insensitive chrF. +- If you enable non-zero `--chrf-word-order` (pass `2` for `chrF++`), a very simple punctuation tokenization will be internally applied. + + +### TER + +Translation Error Rate (TER) has its own special tokenizer that you can configure through the command line. +The defaults provided are **compatible with the upstream TER implementation (TERCOM)** but you can nevertheless modify the +behavior through the command-line: + +- TER is by default case-insensitive. Pass `--ter-case-sensitive` to enable case-sensitivity. +- Pass `--ter-normalize` to apply a general Western tokenization +- Pass `--ter-asian-support` to enable the tokenization of Asian characters. If provided with `--ter-normalize`, + both will be applied. +- Pass `--ter-no-punct` to strip punctuation. + +## Multi-reference Evaluation + +All three metrics support the use of multiple references during evaluation. Let's first pass all references as positional arguments: + +``` +$ sacrebleu ref1 ref2 -i system -m bleu chrf ter + BLEU|nrefs:2|case:mixed|eff:no|tok:13a|smooth:exp|version:2.0.0 = 61.8 + chrF2|nrefs:2|case:mixed|eff:yes|nc:6|nw:0|space:no|version:2.0.0 = 75.0 +TER|nrefs:2|case:lc|tok:tercom|norm:no|punct:yes|asian:no|version:2.0.0 = 31.2 +``` + +Alternatively (less recommended), we can concatenate references using tabs as delimiters as well. Don't forget to pass `--num-refs/-nr` in this case! + +``` +$ paste ref1 ref2 > refs.tsv + +$ sacrebleu refs.tsv --num-refs 2 -i system -m bleu +BLEU|nrefs:2|case:mixed|eff:no|tok:13a|smooth:exp|version:2.0.0 = 61.8 +``` + +## Multi-system Evaluation +As of version `>=2.0.0`, SacreBLEU supports evaluation of an arbitrary number of systems for a particular +test set and language-pair. This has the advantage of seeing all results in a +nicely formatted table. + +Let's pass all system output files that match the shell glob `newstest2017.online-*` to sacreBLEU for evaluation: + +``` +$ sacrebleu -t wmt17 -l en-de -i newstest2017.online-* -m bleu chrf +╒═══════════════════════════════╤════════╤═════════╕ +│ System │ BLEU │ chrF2 │ +╞═══════════════════════════════╪════════╪═════════╡ +│ newstest2017.online-A.0.en-de │ 20.8 │ 52.0 │ +├───────────────────────────────┼────────┼─────────┤ +│ newstest2017.online-B.0.en-de │ 26.7 │ 56.3 │ +├───────────────────────────────┼────────┼─────────┤ +│ newstest2017.online-F.0.en-de │ 15.5 │ 49.3 │ +├───────────────────────────────┼────────┼─────────┤ +│ newstest2017.online-G.0.en-de │ 18.2 │ 51.6 │ +╘═══════════════════════════════╧════════╧═════════╛ + +----------------- +Metric signatures +----------------- + - BLEU nrefs:1|case:mixed|eff:no|tok:13a|smooth:exp|version:2.0.0 + - chrF2 nrefs:1|case:mixed|eff:yes|nc:6|nw:0|space:no|version:2.0.0 +``` + +You can also change the output format to `latex`: + +``` +$ sacrebleu -t wmt17 -l en-de -i newstest2017.online-* -m bleu chrf -f latex +\begin{tabular}{rcc} +\toprule + System & BLEU & chrF2 \\ +\midrule + newstest2017.online-A.0.en-de & 20.8 & 52.0 \\ + newstest2017.online-B.0.en-de & 26.7 & 56.3 \\ + newstest2017.online-F.0.en-de & 15.5 & 49.3 \\ + newstest2017.online-G.0.en-de & 18.2 & 51.6 \\ +\bottomrule +\end{tabular} + +... +``` + +## Confidence Intervals for Single System Evaluation + +When enabled with the `--confidence` flag, SacreBLEU will print +(1) the actual system score, (2) the true mean estimated from bootstrap resampling and (3), +the 95% [confidence interval](https://en.wikipedia.org/wiki/Confidence_interval) around the mean. +By default, the number of bootstrap resamples is 1000 (`bs:1000` in the signature) +and can be changed with `--confidence-n`: + +``` +$ sacrebleu -t wmt17 -l en-de -i output.detok.txt -m bleu chrf --confidence -f text --short + BLEU|#:1|bs:1000|rs:12345|c:mixed|e:no|tok:13a|s:exp|v:2.0.0 = 22.675 (μ = 22.669 ± 0.598) ... +chrF2|#:1|bs:1000|rs:12345|c:mixed|e:yes|nc:6|nw:0|s:no|v:2.0.0 = 51.953 (μ = 51.953 ± 0.462) +``` + +**NOTE:** Although provided as a functionality, having access to confidence intervals for just one system +may not reveal much information about the underlying model. It often makes more sense to perform +**paired statistical tests** across multiple systems. + +**NOTE:** When resampling, the seed of the `numpy`'s random number generator (RNG) +is fixed to `12345`. If you want to relax this and set your own seed, you can +export the environment variable `SACREBLEU_SEED` to an integer. Alternatively, you can export +`SACREBLEU_SEED=None` to skip initializing the RNG's seed and allow for non-deterministic +behavior. + +## Paired Significance Tests for Multi System Evaluation +Ideally, one would have access to many systems in cases such as (1) investigating +whether a newly added feature yields significantly different scores than the baseline or +(2) evaluating submissions for a particular shared task. SacreBLEU offers two different paired significance tests that are widely used in MT research. + +### Paired bootstrap resampling (--paired-bs) + +This is an efficient implementation of the paper [Statistical Significance Tests for Machine Translation Evaluation](https://www.aclweb.org/anthology/W04-3250.pdf) and is result-compliant with the [reference Moses implementation](https://github.com/moses-smt/mosesdecoder/blob/master/scripts/analysis/bootstrap-hypothesis-difference-significance.pl). The number of bootstrap resamples can be changed with the `--paired-bs-n` flag and its default is 1000. + +When launched, paired bootstrap resampling will perform: + - Bootstrap resampling to estimate 95% CI for all systems and the baseline + - A significance test between the **baseline** and each **system** to compute a [p-value](https://en.wikipedia.org/wiki/P-value). + +### Paired approximate randomization (--paired-ar) + +Paired approximate randomization (AR) is another type of paired significance test that is claimed to be more accurate than paired bootstrap resampling when it comes to Type-I errors ([Riezler and Maxwell III, 2005](https://www.aclweb.org/anthology/W05-0908.pdf)). Type-I errors indicate failures to reject the null hypothesis when it is true. In other words, AR should in theory be more robust to subtle changes across systems. + +Our implementation is verified to be result-compliant with the [Multeval toolkit](https://github.com/jhclark/multeval) that also uses paired AR test for pairwise comparison. The number of approximate randomization trials is set to 10,000 by default. This can be changed with the `--paired-ar-n` flag. + +### Running the tests + +- The **first system** provided to `--input/-i` will be automatically taken as the **baseline system** against which you want to compare **other systems.** +- When `--input/-i` is used, the system output files will be automatically named according to the file paths. For the sake of simplicity, SacreBLEU will automatically discard the **baseline system** if it also appears amongst **other systems**. This is useful if you would like to run the tool by passing `-i systems/baseline.txt systems/*.txt`. Here, the `baseline.txt` file will not be also considered as a candidate system. +- Alternatively, you can also use a tab-separated input file redirected to SacreBLEU. In this case, the first column hypotheses will be taken as the **baseline system**. However, this method is **not recommended** as it won't allow naming your systems in a human-readable way. It will instead enumerate the systems from 1 to N following the column order in the tab-separated input. +- On Linux and Mac OS X, you can launch the tests on multiple CPU's by passing the flag `--paired-jobs N`. If `N == 0`, SacreBLEU will launch one worker for each pairwise comparison. If `N > 0`, `N` worker processes will be spawned. This feature will substantially speed up the runtime especially if you want the **TER** metric to be computed. + +#### Example: Paired bootstrap resampling +In the example below, we select `newstest2017.LIUM-NMT.4900.en-de` as the baseline and compare it to 4 other WMT17 submissions using paired bootstrap resampling. According to the results, the null hypothesis (i.e. the two systems being essentially the same) could not be rejected (at the significance level of 0.05) for the following comparisons: + +- 0.1 BLEU difference between the baseline and the online-B system (p = 0.3077) + +``` +$ sacrebleu -t wmt17 -l en-de -i newstest2017.LIUM-NMT.4900.en-de newstest2017.online-* -m bleu chrf --paired-bs +╒════════════════════════════════════════════╤═════════════════════╤══════════════════════╕ +│ System │ BLEU (μ ± 95% CI) │ chrF2 (μ ± 95% CI) │ +╞════════════════════════════════════════════╪═════════════════════╪══════════════════════╡ +│ Baseline: newstest2017.LIUM-NMT.4900.en-de │ 26.6 (26.6 ± 0.6) │ 55.9 (55.9 ± 0.5) │ +├────────────────────────────────────────────┼─────────────────────┼──────────────────────┤ +│ newstest2017.online-A.0.en-de │ 20.8 (20.8 ± 0.6) │ 52.0 (52.0 ± 0.4) │ +│ │ (p = 0.0010)* │ (p = 0.0010)* │ +├────────────────────────────────────────────┼─────────────────────┼──────────────────────┤ +│ newstest2017.online-B.0.en-de │ 26.7 (26.6 ± 0.7) │ 56.3 (56.3 ± 0.5) │ +│ │ (p = 0.3077) │ (p = 0.0240)* │ +├────────────────────────────────────────────┼─────────────────────┼──────────────────────┤ +│ newstest2017.online-F.0.en-de │ 15.5 (15.4 ± 0.5) │ 49.3 (49.3 ± 0.4) │ +│ │ (p = 0.0010)* │ (p = 0.0010)* │ +├────────────────────────────────────────────┼─────────────────────┼──────────────────────┤ +│ newstest2017.online-G.0.en-de │ 18.2 (18.2 ± 0.5) │ 51.6 (51.6 ± 0.4) │ +│ │ (p = 0.0010)* │ (p = 0.0010)* │ +╘════════════════════════════════════════════╧═════════════════════╧══════════════════════╛ + +------------------------------------------------------------ +Paired bootstrap resampling test with 1000 resampling trials +------------------------------------------------------------ + - Each system is pairwise compared to Baseline: newstest2017.LIUM-NMT.4900.en-de. + Actual system score / bootstrap estimated true mean / 95% CI are provided for each metric. + + - Null hypothesis: the system and the baseline translations are essentially + generated by the same underlying process. For a given system and the baseline, + the p-value is roughly the probability of the absolute score difference (delta) + or higher occurring due to chance, under the assumption that the null hypothesis is correct. + + - Assuming a significance threshold of 0.05, the null hypothesis can be rejected + for p-values < 0.05 (marked with "*"). This means that the delta is unlikely to be attributed + to chance, hence the system is significantly "different" than the baseline. + Otherwise, the p-values are highlighted in red. + + - NOTE: Significance does not tell whether a system is "better" than the baseline but rather + emphasizes the "difference" of the systems in terms of the replicability of the delta. + +----------------- +Metric signatures +----------------- + - BLEU nrefs:1|bs:1000|seed:12345|case:mixed|eff:no|tok:13a|smooth:exp|version:2.0.0 + - chrF2 nrefs:1|bs:1000|seed:12345|case:mixed|eff:yes|nc:6|nw:0|space:no|version:2.0.0 +``` + +#### Example: Paired approximate randomization + +Let's now run the paired approximate randomization test for the same comparison. According to the results, the findings are compatible with the paired bootstrap resampling test. However, the p-value for the `baseline vs. online-B` comparison is much higher (`0.8066`) than the paired bootstrap resampling test. + +(**Note that** the AR test does not provide confidence intervals around the true mean as it does not perform bootstrap resampling.) + +``` +$ sacrebleu -t wmt17 -l en-de -i newstest2017.LIUM-NMT.4900.en-de newstest2017.online-* -m bleu chrf --paired-ar +╒════════════════════════════════════════════╤═══════════════╤═══════════════╕ +│ System │ BLEU │ chrF2 │ +╞════════════════════════════════════════════╪═══════════════╪═══════════════╡ +│ Baseline: newstest2017.LIUM-NMT.4900.en-de │ 26.6 │ 55.9 │ +├────────────────────────────────────────────┼───────────────┼───────────────┤ +│ newstest2017.online-A.0.en-de │ 20.8 │ 52.0 │ +│ │ (p = 0.0001)* │ (p = 0.0001)* │ +├────────────────────────────────────────────┼───────────────┼───────────────┤ +│ newstest2017.online-B.0.en-de │ 26.7 │ 56.3 │ +│ │ (p = 0.8066) │ (p = 0.0385)* │ +├────────────────────────────────────────────┼───────────────┼───────────────┤ +│ newstest2017.online-F.0.en-de │ 15.5 │ 49.3 │ +│ │ (p = 0.0001)* │ (p = 0.0001)* │ +├────────────────────────────────────────────┼───────────────┼───────────────┤ +│ newstest2017.online-G.0.en-de │ 18.2 │ 51.6 │ +│ │ (p = 0.0001)* │ (p = 0.0001)* │ +╘════════════════════════════════════════════╧═══════════════╧═══════════════╛ + +------------------------------------------------------- +Paired approximate randomization test with 10000 trials +------------------------------------------------------- + - Each system is pairwise compared to Baseline: newstest2017.LIUM-NMT.4900.en-de. + Actual system score is provided for each metric. + + - Null hypothesis: the system and the baseline translations are essentially + generated by the same underlying process. For a given system and the baseline, + the p-value is roughly the probability of the absolute score difference (delta) + or higher occurring due to chance, under the assumption that the null hypothesis is correct. + + - Assuming a significance threshold of 0.05, the null hypothesis can be rejected + for p-values < 0.05 (marked with "*"). This means that the delta is unlikely to be attributed + to chance, hence the system is significantly "different" than the baseline. + Otherwise, the p-values are highlighted in red. + + - NOTE: Significance does not tell whether a system is "better" than the baseline but rather + emphasizes the "difference" of the systems in terms of the replicability of the delta. + +----------------- +Metric signatures +----------------- + - BLEU nrefs:1|ar:10000|seed:12345|case:mixed|eff:no|tok:13a|smooth:exp|version:2.0.0 + - chrF2 nrefs:1|ar:10000|seed:12345|case:mixed|eff:yes|nc:6|nw:0|space:no|version:2.0.0 +``` + +# Using SacreBLEU from Python + +For evaluation, it may be useful to compute BLEU, chrF or TER from a Python script. The recommended +way of doing this is to use the object-oriented API, by creating an instance of the `metrics.BLEU` class +for example: ```python -import sacrebleu -refs = [['The dog bit the man.', 'It was not unexpected.', 'The man bit him first.'], - ['The dog had bit the man.', 'No one was surprised.', 'The man had bitten the dog.']] -sys = ['The dog bit the man.', "It wasn't surprising.", 'The man had just bitten him.'] -bleu = sacrebleu.corpus_bleu(sys, refs) -print(bleu.score) -``` - -# MOTIVATION - -Comparing BLEU scores is harder than it should be. -Every decoder has its own implementation, often borrowed from Moses, but maybe with subtle changes. -Moses itself has a number of implementations as standalone scripts, with little indication of how they differ (note: they mostly don't, but `multi-bleu.pl` expects tokenized input). -Different flags passed to each of these scripts can produce wide swings in the final score. -All of these may handle tokenization in different ways. -On top of this, downloading and managing test sets is a moderate annoyance. -Sacre bleu! -What a mess. - -SacreBLEU aims to solve these problems by wrapping the original Papineni reference implementation together with other useful features. -The defaults are set the way that BLEU should be computed, and furthermore, the script outputs a short version string that allows others to know exactly what you did. -As an added bonus, it automatically downloads and manages test sets for you, so that you can simply tell it to score against 'wmt14', without having to hunt down a path on your local file system. -It is all designed to take BLEU a little more seriously. -After all, even with all its problems, BLEU is the default and---admit it---well-loved metric of our entire research community. -Sacre BLEU. +In [1]: from sacrebleu.metrics import BLEU, CHRF, TER + ...: + ...: refs = [ # First set of references + ...: ['The dog bit the man.', 'It was not unexpected.', 'The man bit him first.'], + ...: # Second set of references + ...: ['The dog had bit the man.', 'No one was surprised.', 'The man had bitten the dog.'], + ...: ] + ...: sys = ['The dog bit the man.', "It wasn't surprising.", 'The man had just bitten him.'] + +In [2]: bleu = BLEU() + +In [3]: bleu.corpus_score(sys, refs) +Out[3]: BLEU = 48.53 82.4/50.0/45.5/37.5 (BP = 0.943 ratio = 0.944 hyp_len = 17 ref_len = 18) + +In [4]: bleu.get_signature() +Out[4]: nrefs:2|case:mixed|eff:no|tok:13a|smooth:exp|version:2.0.0 + +In [5]: chrf = CHRF() -# LICENSE +In [6]: chrf.corpus_score(sys, refs) +Out[6]: chrF2 = 59.73 +``` + +### Variable Number of References + +Let's now remove the first reference sentence for the first system sentence `The dog bit the man.` by replacing it with either `None` or the empty string `''`. +This allows using a variable number of reference segments per hypothesis. Observe how the signature changes from `nrefs:2` to `nrefs:var`: -SacreBLEU is licensed under the Apache 2.0 License. +```python +In [1]: from sacrebleu.metrics import BLEU, CHRF, TER + ...: + ...: refs = [ # First set of references + # 1st sentence does not have a ref here + ...: ['', 'It was not unexpected.', 'The man bit him first.'], + ...: # Second set of references + ...: ['The dog had bit the man.', 'No one was surprised.', 'The man had bitten the dog.'], + ...: ] + ...: sys = ['The dog bit the man.', "It wasn't surprising.", 'The man had just bitten him.'] + +In [2]: bleu = BLEU() + +In [3]: bleu.corpus_score(sys, refs) +Out[3]: BLEU = 29.44 82.4/42.9/27.3/12.5 (BP = 0.889 ratio = 0.895 hyp_len = 17 ref_len = 19) + +In [4]: bleu.get_signature() +Out[4]: nrefs:var|case:mixed|eff:no|tok:13a|smooth:exp|version:2.0.0 +``` -# CREDITS +## Compatibility API + +You can also use the compatibility API that provides wrapper functions around the object-oriented API to +compute sentence-level and corpus-level BLEU, chrF and TER: (It should be noted that this API can be +removed in future releases) + +```python +In [1]: import sacrebleu + ...: + ...: refs = [ # First set of references + ...: ['The dog bit the man.', 'It was not unexpected.', 'The man bit him first.'], + ...: # Second set of references + ...: ['The dog had bit the man.', 'No one was surprised.', 'The man had bitten the dog.'], + ...: ] + ...: sys = ['The dog bit the man.', "It wasn't surprising.", 'The man had just bitten him.'] + +In [2]: sacrebleu.corpus_bleu(sys, refs) +Out[2]: BLEU = 48.53 82.4/50.0/45.5/37.5 (BP = 0.943 ratio = 0.944 hyp_len = 17 ref_len = 18) +``` + +# License + +SacreBLEU is licensed under the [Apache 2.0 License](LICENSE.txt). + +# Credits This was all Rico Sennrich's idea. Originally written by Matt Post. @@ -130,3 +613,7 @@ If you use SacreBLEU, please cite the following: pages = "186--191", } ``` + +# Release Notes + +Please see [CHANGELOG.md](CHANGELOG.md) for release notes. diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 00000000..7207d687 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,20 @@ +[mypy] +python_version = 3.6 + +[mypy-portalocker.*] +ignore_missing_imports = True + +[mypy-colorama.*] +ignore_missing_imports = True + +[mypy-numpy.*] +ignore_missing_imports = True + +[mypy-regex.*] +ignore_missing_imports = True + +[mypy-ipadic.*] +ignore_missing_imports = True + +[mypy-MeCab.*] +ignore_missing_imports = True diff --git a/sacrebleu/__init__.py b/sacrebleu/__init__.py index 5632afb9..638715d2 100644 --- a/sacrebleu/__init__.py +++ b/sacrebleu/__init__.py @@ -14,22 +14,18 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. -__version__ = '1.5.1' +__version__ = '2.0.0' __description__ = 'Hassle-free computation of shareable, comparable, and reproducible BLEU, chrF, and TER scores' -from .utils import smart_open, SACREBLEU_DIR, download_test_set -from .utils import get_source_file, get_reference_files -from .utils import get_available_testsets, get_langpairs_for_testset -from .dataset import DATASETS -from .tokenizers import TOKENIZERS, DEFAULT_TOKENIZER -from .metrics import BLEU, CHRF +from .utils import smart_open, SACREBLEU_DIR, download_test_set # noqa: F401 +from .utils import get_source_file, get_reference_files # noqa: F401 +from .utils import get_available_testsets, get_langpairs_for_testset # noqa: F401 +from .metrics.helpers import extract_word_ngrams, extract_char_ngrams # noqa: F401 +from .dataset import DATASETS # noqa: F401 +from .metrics import BLEU, CHRF, TER # noqa: F401 # Backward compatibility functions for old style API access (<= 1.4.10) -from .compat import * - -# Other shorthands for backward-compatibility with <= 1.4.10 -extract_ngrams = BLEU.extract_ngrams -extract_char_ngrams = CHRF.extract_char_ngrams -ref_stats = BLEU.reference_stats -compute_bleu = BLEU.compute_bleu +from .compat import corpus_bleu, raw_corpus_bleu, sentence_bleu # noqa: F401 +from .compat import corpus_chrf, sentence_chrf # noqa: F401 +from .compat import corpus_ter, sentence_ter # noqa: F401 diff --git a/sacrebleu/__main__.py b/sacrebleu/__main__.py index ee3c0d82..3833741e 100644 --- a/sacrebleu/__main__.py +++ b/sacrebleu/__main__.py @@ -21,7 +21,7 @@ See the [README.md] file for more information. """ -from .sacrebleu import main +from .sacrebleu import main if __name__ == '__main__': - main() \ No newline at end of file + main() diff --git a/sacrebleu/compat.py b/sacrebleu/compat.py index 5be36f84..acb47c20 100644 --- a/sacrebleu/compat.py +++ b/sacrebleu/compat.py @@ -1,171 +1,198 @@ -from typing import Union, Iterable, List -from argparse import Namespace +from typing import Sequence, Optional -from .tokenizers import DEFAULT_TOKENIZER from .metrics import BLEU, CHRF, TER, BLEUScore, CHRFScore, TERScore ###################################################################### # Backward compatibility functions for old style API access (< 1.4.11) ###################################################################### -def corpus_bleu(sys_stream: Union[str, Iterable[str]], - ref_streams: Union[str, List[Iterable[str]]], +def corpus_bleu(hypotheses: Sequence[str], + references: Sequence[Sequence[str]], smooth_method='exp', smooth_value=None, force=False, lowercase=False, - tokenize=DEFAULT_TOKENIZER, + tokenize=BLEU.TOKENIZER_DEFAULT, use_effective_order=False) -> BLEUScore: - """Produces BLEU scores along with its sufficient statistics from a source against one or more references. + """Computes BLEU for a corpus against a single (or multiple) reference(s). - :param sys_stream: The system stream (a sequence of segments) - :param ref_streams: A list of one or more reference streams (each a sequence of segments) + :param hypotheses: A sequence of hypothesis strings. + :param references: A sequence of reference documents with document being + defined as a sequence of reference strings. :param smooth_method: The smoothing method to use ('floor', 'add-k', 'exp' or 'none') :param smooth_value: The smoothing value for `floor` and `add-k` methods. `None` falls back to default value. :param force: Ignore data that looks already tokenized :param lowercase: Lowercase the data :param tokenize: The tokenizer to use + :param use_effective_order: Don't take into account n-gram orders without any match. :return: a `BLEUScore` object """ - args = Namespace( - smooth_method=smooth_method, smooth_value=smooth_value, force=force, - short=False, lc=lowercase, tokenize=tokenize) + metric = BLEU( + lowercase=lowercase, force=force, tokenize=tokenize, + smooth_method=smooth_method, smooth_value=smooth_value, + effective_order=use_effective_order) - metric = BLEU(args) - return metric.corpus_score( - sys_stream, ref_streams, use_effective_order=use_effective_order) + return metric.corpus_score(hypotheses, references) -def raw_corpus_bleu(sys_stream, - ref_streams, - smooth_value=BLEU.SMOOTH_DEFAULTS['floor']) -> BLEUScore: - """Convenience function that wraps corpus_bleu(). - This is convenient if you're using sacrebleu as a library, say for scoring on dev. - It uses no tokenization and 'floor' smoothing, with the floor default to 0.1. +def raw_corpus_bleu(hypotheses: Sequence[str], + references: Sequence[Sequence[str]], + smooth_value: Optional[float] = BLEU.SMOOTH_DEFAULTS['floor']) -> BLEUScore: + """Computes BLEU for a corpus against a single (or multiple) reference(s). + This convenience function assumes a particular set of arguments i.e. + it disables tokenization and applies a `floor` smoothing with value `0.1`. - :param sys_stream: the system stream (a sequence of segments) - :param ref_streams: a list of one or more reference streams (each a sequence of segments) + :param hypotheses: A sequence of hypothesis strings. + :param references: A sequence of reference documents with document being + defined as a sequence of reference strings. :param smooth_value: The smoothing value for `floor`. If not given, the default of 0.1 is used. :return: Returns a `BLEUScore` object. """ return corpus_bleu( - sys_stream, ref_streams, smooth_method='floor', + hypotheses, references, smooth_method='floor', smooth_value=smooth_value, force=True, tokenize='none', use_effective_order=True) def sentence_bleu(hypothesis: str, - references: List[str], + references: Sequence[str], smooth_method: str = 'exp', smooth_value: float = None, + lowercase: bool = False, + tokenize=BLEU.TOKENIZER_DEFAULT, use_effective_order: bool = True) -> BLEUScore: """ - Computes BLEU on a single sentence pair. + Computes BLEU for a single sentence against a single (or multiple) reference(s). - Disclaimer: computing BLEU on the sentence level is not its intended use, + Disclaimer: Computing BLEU at the sentence level is not its intended use as BLEU is a corpus-level metric. - :param hypothesis: Hypothesis string. - :param references: List of reference strings. + :param hypothesis: A single hypothesis string. + :param references: A sequence of reference strings. :param smooth_method: The smoothing method to use ('floor', 'add-k', 'exp' or 'none') :param smooth_value: The smoothing value for `floor` and `add-k` methods. `None` falls back to default value. - :param use_effective_order: Account for references that are shorter than the largest n-gram. + :param lowercase: Lowercase the data + :param tokenize: The tokenizer to use + :param use_effective_order: Don't take into account n-gram orders without any match. :return: Returns a `BLEUScore` object. """ - args = Namespace( - smooth_method=smooth_method, smooth_value=smooth_value, force=False, - short=False, lc=False, tokenize=DEFAULT_TOKENIZER) + metric = BLEU( + lowercase=lowercase, tokenize=tokenize, force=False, + smooth_method=smooth_method, smooth_value=smooth_value, + effective_order=use_effective_order) - metric = BLEU(args) - return metric.sentence_score( - hypothesis, references, use_effective_order=use_effective_order) + return metric.sentence_score(hypothesis, references) -def corpus_chrf(hypotheses: Iterable[str], - references: List[Iterable[str]], - order: int = CHRF.ORDER, - beta: float = CHRF.BETA, - remove_whitespace: bool = True) -> CHRFScore: +def corpus_chrf(hypotheses: Sequence[str], + references: Sequence[Sequence[str]], + char_order: int = CHRF.CHAR_ORDER, + word_order: int = CHRF.WORD_ORDER, + beta: int = CHRF.BETA, + remove_whitespace: bool = True, + eps_smoothing: bool = False) -> CHRFScore: """ - Computes ChrF on a corpus. - - :param hypotheses: Stream of hypotheses. - :param references: Stream of references. - :param order: Maximum n-gram order. - :param beta: Defines importance of recall w.r.t precision. If beta=1, same importance. - :param remove_whitespace: Whether to delete all whitespace from hypothesis and reference strings. + Computes chrF for a corpus against a single (or multiple) reference(s). + If `word_order` equals to 2, the metric is referred to as chrF++. + + :param hypotheses: A sequence of hypothesis strings. + :param references: A sequence of reference documents with document being + defined as a sequence of reference strings. + :param char_order: Character n-gram order. + :param word_order: Word n-gram order. If equals to 2, the metric is referred to as chrF++. + :param beta: Determine the importance of recall w.r.t precision. + :param eps_smoothing: If `True`, applies epsilon smoothing similar + to reference chrF++.py, NLTK and Moses implementations. Otherwise, + it takes into account effective match order similar to sacreBLEU < 2.0.0. + :param remove_whitespace: If `True`, removes whitespaces prior to character n-gram extraction. :return: A `CHRFScore` object. """ - args = Namespace( - chrf_order=order, chrf_beta=beta, chrf_whitespace=not remove_whitespace, short=False) - metric = CHRF(args) + metric = CHRF( + char_order=char_order, + word_order=word_order, + beta=beta, + whitespace=not remove_whitespace, + eps_smoothing=eps_smoothing) return metric.corpus_score(hypotheses, references) def sentence_chrf(hypothesis: str, - references: List[str], - order: int = CHRF.ORDER, - beta: float = CHRF.BETA, - remove_whitespace: bool = True) -> CHRFScore: + references: Sequence[str], + char_order: int = CHRF.CHAR_ORDER, + word_order: int = CHRF.WORD_ORDER, + beta: int = CHRF.BETA, + remove_whitespace: bool = True, + eps_smoothing: bool = False) -> CHRFScore: """ - Computes ChrF on a single sentence pair. - - :param hypothesis: Hypothesis string. - :param references: Reference string(s). - :param order: Maximum n-gram order. - :param beta: Defines importance of recall w.r.t precision. If beta=1, same importance. - :param remove_whitespace: Whether to delete whitespaces from hypothesis and reference strings. + Computes chrF for a single sentence against a single (or multiple) reference(s). + If `word_order` equals to 2, the metric is referred to as chrF++. + + :param hypothesis: A single hypothesis string. + :param references: A sequence of reference strings. + :param char_order: Character n-gram order. + :param word_order: Word n-gram order. If equals to 2, the metric is referred to as chrF++. + :param beta: Determine the importance of recall w.r.t precision. + :param eps_smoothing: If `True`, applies epsilon smoothing similar + to reference chrF++.py, NLTK and Moses implementations. Otherwise, + it takes into account effective match order similar to sacreBLEU < 2.0.0. + :param remove_whitespace: If `True`, removes whitespaces prior to character n-gram extraction. :return: A `CHRFScore` object. """ - args = Namespace( - chrf_order=order, chrf_beta=beta, chrf_whitespace=not remove_whitespace, short=False) - metric = CHRF(args) + metric = CHRF( + char_order=char_order, + word_order=word_order, + beta=beta, + whitespace=not remove_whitespace, + eps_smoothing=eps_smoothing) return metric.sentence_score(hypothesis, references) -def corpus_ter(hypotheses: Iterable[str], - references: List[Iterable[str]], +def corpus_ter(hypotheses: Sequence[str], + references: Sequence[Sequence[str]], normalized: bool = False, no_punct: bool = False, asian_support: bool = False, case_sensitive: bool = False) -> TERScore: """ - Computes TER on a corpus. + Computes TER for a corpus against a single (or multiple) reference(s). - :param hypotheses: Stream of hypotheses. - :param references: Stream of references. + :param hypotheses: A sequence of hypothesis strings. + :param references: A sequence of reference documents with document being + defined as a sequence of reference strings. :param normalized: Enable character normalization. :param no_punct: Remove punctuation. :param asian_support: Enable special treatment of Asian characters. - :param case_sensitive: Enable case sensitivity. + :param case_sensitive: Enables case-sensitivity. :return: A `TERScore` object. """ - args = Namespace( - normalized=normalized, no_punct=no_punct, - asian_support=asian_support, case_sensitive=case_sensitive) - metric = TER(args) + metric = TER( + normalized=normalized, + no_punct=no_punct, + asian_support=asian_support, + case_sensitive=case_sensitive) return metric.corpus_score(hypotheses, references) def sentence_ter(hypothesis: str, - references: List[str], + references: Sequence[str], normalized: bool = False, no_punct: bool = False, asian_support: bool = False, case_sensitive: bool = False) -> TERScore: """ - Computes TER on a single sentence pair. + Computes TER for a single hypothesis against a single (or multiple) reference(s). - :param hypothesis: Hypothesis string. - :param references: Reference string(s). + :param hypothesis: A single hypothesis string. + :param references: A sequence of reference strings. :param normalized: Enable character normalization. :param no_punct: Remove punctuation. :param asian_support: Enable special treatment of Asian characters. - :param case_sensitive: Enable case sensitivity. + :param case_sensitive: Enable case-sensitivity. :return: A `TERScore` object. """ - args = Namespace( - normalized=normalized, no_punct=no_punct, - asian_support=asian_support, case_sensitive=case_sensitive) - metric = TER(args) + metric = TER( + normalized=normalized, + no_punct=no_punct, + asian_support=asian_support, + case_sensitive=case_sensitive) return metric.sentence_score(hypothesis, references) diff --git a/sacrebleu/dataset.py b/sacrebleu/dataset.py index 0a9e6535..5b4d57ae 100644 --- a/sacrebleu/dataset.py +++ b/sacrebleu/dataset.py @@ -27,38 +27,38 @@ "description": 'mTEDx evaluation data, valid: http://openslr.org/100', "citation": "@misc{salesky2021multilingual,\n title={The Multilingual TEDx Corpus for Speech Recognition and Translation}, \n author={Elizabeth Salesky and Matthew Wiesner and Jacob Bremerman and Roldano Cattoni and Matteo Negri and Marco Turchi and Douglas W. Oard and Matt Post},\n year={2021},\n eprint={2102.01757},\n archivePrefix={arXiv},\n primaryClass={cs.CL}\n}", "md5": ['40618171614c50e6cbb5e5bbceee0635'], - "el-en": ['valid/mtedx-valid-elen.el','valid/mtedx-valid-elen.en'], - "es-en": ['valid/mtedx-valid-esen.es','valid/mtedx-valid-esen.en'], - "es-fr": ['valid/mtedx-valid-esfr.es','valid/mtedx-valid-esfr.fr'], - "es-it": ['valid/mtedx-valid-esit.es','valid/mtedx-valid-esit.it'], - "es-pt": ['valid/mtedx-valid-espt.es','valid/mtedx-valid-espt.pt'], - "fr-en": ['valid/mtedx-valid-fren.fr','valid/mtedx-valid-fren.en'], - "fr-es": ['valid/mtedx-valid-fres.fr','valid/mtedx-valid-fres.es'], - "fr-pt": ['valid/mtedx-valid-frpt.fr','valid/mtedx-valid-frpt.pt'], - "it-en": ['valid/mtedx-valid-iten.it','valid/mtedx-valid-iten.en'], - "it-es": ['valid/mtedx-valid-ites.it','valid/mtedx-valid-ites.es'], - "pt-en": ['valid/mtedx-valid-pten.pt','valid/mtedx-valid-pten.en'], - "pt-es": ['valid/mtedx-valid-ptes.pt','valid/mtedx-valid-ptes.es'], - "ru-en": ['valid/mtedx-valid-ruen.ru','valid/mtedx-valid-ruen.en'] + "el-en": ['valid/mtedx-valid-elen.el', 'valid/mtedx-valid-elen.en'], + "es-en": ['valid/mtedx-valid-esen.es', 'valid/mtedx-valid-esen.en'], + "es-fr": ['valid/mtedx-valid-esfr.es', 'valid/mtedx-valid-esfr.fr'], + "es-it": ['valid/mtedx-valid-esit.es', 'valid/mtedx-valid-esit.it'], + "es-pt": ['valid/mtedx-valid-espt.es', 'valid/mtedx-valid-espt.pt'], + "fr-en": ['valid/mtedx-valid-fren.fr', 'valid/mtedx-valid-fren.en'], + "fr-es": ['valid/mtedx-valid-fres.fr', 'valid/mtedx-valid-fres.es'], + "fr-pt": ['valid/mtedx-valid-frpt.fr', 'valid/mtedx-valid-frpt.pt'], + "it-en": ['valid/mtedx-valid-iten.it', 'valid/mtedx-valid-iten.en'], + "it-es": ['valid/mtedx-valid-ites.it', 'valid/mtedx-valid-ites.es'], + "pt-en": ['valid/mtedx-valid-pten.pt', 'valid/mtedx-valid-pten.en'], + "pt-es": ['valid/mtedx-valid-ptes.pt', 'valid/mtedx-valid-ptes.es'], + "ru-en": ['valid/mtedx-valid-ruen.ru', 'valid/mtedx-valid-ruen.en'] }, "mtedx/test": { "data": ['https://raw.githubusercontent.com/esalesky/mtedx-eval/main/test.tar.gz'], "description": 'mTEDx evaluation data, test: http://openslr.org/100', "citation": "@misc{salesky2021multilingual,\n title={The Multilingual TEDx Corpus for Speech Recognition and Translation}, \n author={Elizabeth Salesky and Matthew Wiesner and Jacob Bremerman and Roldano Cattoni and Matteo Negri and Marco Turchi and Douglas W. Oard and Matt Post},\n year={2021},\n eprint={2102.01757},\n archivePrefix={arXiv},\n primaryClass={cs.CL}\n}", "md5": ['fa4cb1548c210ec424d7d6bc9a3675a7'], - "el-en": ['test/mtedx-test-elen.el','test/mtedx-test-elen.en'], - "es-en": ['test/mtedx-test-esen.es','test/mtedx-test-esen.en'], - "es-fr": ['test/mtedx-test-esfr.es','test/mtedx-test-esfr.fr'], - "es-it": ['test/mtedx-test-esit.es','test/mtedx-test-esit.it'], - "es-pt": ['test/mtedx-test-espt.es','test/mtedx-test-espt.pt'], - "fr-en": ['test/mtedx-test-fren.fr','test/mtedx-test-fren.en'], - "fr-es": ['test/mtedx-test-fres.fr','test/mtedx-test-fres.es'], - "fr-pt": ['test/mtedx-test-frpt.fr','test/mtedx-test-frpt.pt'], - "it-en": ['test/mtedx-test-iten.it','test/mtedx-test-iten.en'], - "it-es": ['test/mtedx-test-ites.it','test/mtedx-test-ites.es'], - "pt-en": ['test/mtedx-test-pten.pt','test/mtedx-test-pten.en'], - "pt-es": ['test/mtedx-test-ptes.pt','test/mtedx-test-ptes.es'], - "ru-en": ['test/mtedx-test-ruen.ru','test/mtedx-test-ruen.en'] + "el-en": ['test/mtedx-test-elen.el', 'test/mtedx-test-elen.en'], + "es-en": ['test/mtedx-test-esen.es', 'test/mtedx-test-esen.en'], + "es-fr": ['test/mtedx-test-esfr.es', 'test/mtedx-test-esfr.fr'], + "es-it": ['test/mtedx-test-esit.es', 'test/mtedx-test-esit.it'], + "es-pt": ['test/mtedx-test-espt.es', 'test/mtedx-test-espt.pt'], + "fr-en": ['test/mtedx-test-fren.fr', 'test/mtedx-test-fren.en'], + "fr-es": ['test/mtedx-test-fres.fr', 'test/mtedx-test-fres.es'], + "fr-pt": ['test/mtedx-test-frpt.fr', 'test/mtedx-test-frpt.pt'], + "it-en": ['test/mtedx-test-iten.it', 'test/mtedx-test-iten.en'], + "it-es": ['test/mtedx-test-ites.it', 'test/mtedx-test-ites.es'], + "pt-en": ['test/mtedx-test-pten.pt', 'test/mtedx-test-pten.en'], + "pt-es": ['test/mtedx-test-ptes.pt', 'test/mtedx-test-ptes.es'], + "ru-en": ['test/mtedx-test-ruen.ru', 'test/mtedx-test-ruen.en'] }, "wmt20/robust/set1": { "data": ["http://data.statmt.org/wmt20/robustness-task/robustness20-3-sets.zip"], @@ -611,7 +611,7 @@ 'ja-en': ['ja-en/IWSLT17.TED.tst2017.ja-en.ja.xml', 'en-ja/IWSLT17.TED.tst2017.en-ja.en.xml'], 'en-ko': ['en-ko/IWSLT17.TED.tst2017.en-ko.en.xml', 'ko-en/IWSLT17.TED.tst2017.ko-en.ko.xml'], 'ko-en': ['ko-en/IWSLT17.TED.tst2017.ko-en.ko.xml', 'en-ko/IWSLT17.TED.tst2017.en-ko.en.xml'], - }, + }, 'iwslt17/tst2016': { 'data': ['https://raw.githubusercontent.com/hlt-mt/WIT3/master/archive/2017-01-ted-test/texts/en/fr/en-fr.tgz', 'https://raw.githubusercontent.com/hlt-mt/WIT3/master/archive/2017-01-ted-test/texts/fr/en/fr-en.tgz', @@ -812,64 +812,81 @@ # when applied on large data (TODO: annotate all documents from recent WMT years, at least for origlang=en, consider renaming "world" to "other"). _SUBSETS = { 'wmt18': 'rt.com.68098=US-crime guardian.181611=US-politics bbc.310963=GB-sport washpost.116881=US-politics scotsman.104228=GB-sport timemagazine.75207=OTHER-world-ID ' - 'euronews-en.117981=OTHER-crime-AE smh.com.au.242810=US-crime msnbc.53726=US-politics euronews-en.117983=US-politics msnbc.53894=US-crime theglobeandmail.com.62700=US-business ' - 'bbc.310870=OTHER-world-AF reuters.196698=US-politics latimes.231739=US-sport thelocal.51929=OTHER-world-SE cbsnews.198694=US-politics reuters.196718=OTHER-sport-RU ' - 'abcnews.255599=EU-sport nytimes.127256=US-entertainment scotsman.104225=GB-politics dailymail.co.uk.233026=GB-scitech independent.181088=GB-entertainment ' - 'brisbanetimes.com.au.181614=OTHER-business-AU washpost.116837=US-politics dailymail.co.uk.232928=GB-world thelocal.51916=OTHER-politics-IT bbc.310871=US-crime ' - 'nytimes.127392=EU-business-DE euronews-en.118001=EU-scitech-FR washpost.116866=OTHER-crime-MX dailymail.co.uk.233025=OTHER-scitech-CA latimes.231829=US-crime ' - 'guardian.181662=US-entertainment msnbc.53731=US-crime rt.com.68127=OTHER-sport-RU latimes.231782=US-business latimes.231840=US-sport reuters.196711=OTHER-scitech ' - 'guardian.181666=GB-entertainment novinite.com.24019=US-politics smh.com.au.242750=OTHER-scitech guardian.181610=US-politics telegraph.364393=OTHER-crime-ZA ' - 'novinite.com.23995=EU-world dailymail.co.uk.233028=GB-scitech independent.181071=GB-sport telegraph.364538=GB-scitech timemagazine.75193=US-politics ' - 'independent.181096=US-entertainment upi.140602=OTHER-world-AF bbc.310946=GB-business independent.181052=EU-sport ', + 'euronews-en.117981=OTHER-crime-AE smh.com.au.242810=US-crime msnbc.53726=US-politics euronews-en.117983=US-politics msnbc.53894=US-crime theglobeandmail.com.62700=US-business ' + 'bbc.310870=OTHER-world-AF reuters.196698=US-politics latimes.231739=US-sport thelocal.51929=OTHER-world-SE cbsnews.198694=US-politics reuters.196718=OTHER-sport-RU ' + 'abcnews.255599=EU-sport nytimes.127256=US-entertainment scotsman.104225=GB-politics dailymail.co.uk.233026=GB-scitech independent.181088=GB-entertainment ' + 'brisbanetimes.com.au.181614=OTHER-business-AU washpost.116837=US-politics dailymail.co.uk.232928=GB-world thelocal.51916=OTHER-politics-IT bbc.310871=US-crime ' + 'nytimes.127392=EU-business-DE euronews-en.118001=EU-scitech-FR washpost.116866=OTHER-crime-MX dailymail.co.uk.233025=OTHER-scitech-CA latimes.231829=US-crime ' + 'guardian.181662=US-entertainment msnbc.53731=US-crime rt.com.68127=OTHER-sport-RU latimes.231782=US-business latimes.231840=US-sport reuters.196711=OTHER-scitech ' + 'guardian.181666=GB-entertainment novinite.com.24019=US-politics smh.com.au.242750=OTHER-scitech guardian.181610=US-politics telegraph.364393=OTHER-crime-ZA ' + 'novinite.com.23995=EU-world dailymail.co.uk.233028=GB-scitech independent.181071=GB-sport telegraph.364538=GB-scitech timemagazine.75193=US-politics ' + 'independent.181096=US-entertainment upi.140602=OTHER-world-AF bbc.310946=GB-business independent.181052=EU-sport ', 'wmt19': 'bbc.381790=GB-politics rt.com.91337=OTHER-politics-MK nytimes.184853=US-world upi.176266=US-crime guardian.221754=GB-business dailymail.co.uk.298595=GB-business ' - 'cnbc.com.6790=US-politics nytimes.184837=OTHER-world-ID upi.176249=GB-sport euronews-en.153835=OTHER-world-ID dailymail.co.uk.298732=GB-crime telegraph.405401=GB-politics ' - 'newsweek.51331=OTHER-crime-CN abcnews.306815=US-world cbsnews.248384=US-politics reuters.218882=GB-politics cbsnews.248387=US-crime abcnews.306764=OTHER-world-MX ' - 'reuters.218888=EU-politics bbc.381780=GB-crime bbc.381746=GB-sport euronews-en.153800=EU-politics bbc.381679=GB-crime bbc.381735=GB-crime newsweek.51338=US-world ' - 'bbc.381765=GB-crime cnn.304489=US-politics reuters.218863=OTHER-world-ID nytimes.184860=OTHER-world-ID cnn.304404=US-crime bbc.381647=US-entertainment ' - 'abcnews.306758=OTHER-politics-MX cnbc.com.6772=US-business reuters.218932=OTHER-politics-MK upi.176251=GB-sport reuters.218921=US-sport cnn.304447=US-politics ' - 'guardian.221679=GB-politics scotsman.133765=GB-sport scotsman.133804=GB-entertainment guardian.221762=OTHER-politics-BO cnbc.com.6769=US-politics ' - 'dailymail.co.uk.298692=EU-entertainment scotsman.133744=GB-world reuters.218911=US-sport newsweek.51310=US-politics independent.226301=US-sport reuters.218923=EU-sport ' - 'reuters.218861=US-politics dailymail.co.uk.298759=US-world scotsman.133791=GB-sport cbsnews.248484=EU-scitech dailymail.co.uk.298630=US-scitech ' - 'newsweek.51329=US-entertainment bbc.381701=GB-crime dailymail.co.uk.298738=GB-entertainment bbc.381669=OTHER-world-CN foxnews.94512=US-politics ' - 'guardian.221718=GB-entertainment dailymail.co.uk.298686=GB-politics cbsnews.248471=US-politics newsweek.51318=US-entertainment rt.com.91335=US-politics ' - 'newsweek.51300=US-politics cnn.304478=US-politics upi.176275=US-politics telegraph.405422=OTHER-world-ID reuters.218933=US-politics newsweek.51328=US-politics ' - 'newsweek.51307=US-business bbc.381692=GB-world independent.226346=GB-entertainment bbc.381646=GB-sport reuters.218914=US-sport scotsman.133758=EU-sport ' - 'rt.com.91350=EU-world scotsman.133773=GB-scitech rt.com.91334=EU-crime bbc.381680=GB-politics guardian.221756=US-politics scotsman.133783=GB-politics cnn.304521=US-sport ' - 'dailymail.co.uk.298622=GB-politics bbc.381789=GB-sport dailymail.co.uk.298644=GB-business dailymail.co.uk.298602=GB-world scotsman.133753=GB-sport ' - 'independent.226317=GB-entertainment nytimes.184862=US-politics thelocal.65969=OTHER-world-SY nytimes.184825=US-politics cnbc.com.6784=US-politics nytimes.184804=US-politics ' - 'nytimes.184830=US-politics scotsman.133801=GB-sport cnbc.com.6770=US-business bbc.381760=GB-crime reuters.218865=OTHER-world-ID newsweek.51339=US-crime ' - 'euronews-en.153797=OTHER-world-ID abcnews.306774=US-crime dailymail.co.uk.298696=GB-politics abcnews.306755=US-politics reuters.218909=US-crime ' - 'independent.226349=OTHER-sport-RU newsweek.51330=US-politics bbc.381705=GB-sport newsweek.51340=OTHER-world-ID cbsnews.248411=OTHER-world-FM abcnews.306776=US-crime ' - 'bbc.381694=GB-entertainment rt.com.91356=US-world telegraph.405430=GB-entertainment telegraph.405404=EU-world bbc.381749=GB-world telegraph.405413=US-politics ' - 'bbc.381736=OTHER-politics-KP cbsnews.248394=US-politics nytimes.184822=US-world telegraph.405408=US-politics euronews-en.153799=OTHER-politics-SY ' - 'euronews-en.153826=EU-sport cnn.304400=US-world' + 'cnbc.com.6790=US-politics nytimes.184837=OTHER-world-ID upi.176249=GB-sport euronews-en.153835=OTHER-world-ID dailymail.co.uk.298732=GB-crime telegraph.405401=GB-politics ' + 'newsweek.51331=OTHER-crime-CN abcnews.306815=US-world cbsnews.248384=US-politics reuters.218882=GB-politics cbsnews.248387=US-crime abcnews.306764=OTHER-world-MX ' + 'reuters.218888=EU-politics bbc.381780=GB-crime bbc.381746=GB-sport euronews-en.153800=EU-politics bbc.381679=GB-crime bbc.381735=GB-crime newsweek.51338=US-world ' + 'bbc.381765=GB-crime cnn.304489=US-politics reuters.218863=OTHER-world-ID nytimes.184860=OTHER-world-ID cnn.304404=US-crime bbc.381647=US-entertainment ' + 'abcnews.306758=OTHER-politics-MX cnbc.com.6772=US-business reuters.218932=OTHER-politics-MK upi.176251=GB-sport reuters.218921=US-sport cnn.304447=US-politics ' + 'guardian.221679=GB-politics scotsman.133765=GB-sport scotsman.133804=GB-entertainment guardian.221762=OTHER-politics-BO cnbc.com.6769=US-politics ' + 'dailymail.co.uk.298692=EU-entertainment scotsman.133744=GB-world reuters.218911=US-sport newsweek.51310=US-politics independent.226301=US-sport reuters.218923=EU-sport ' + 'reuters.218861=US-politics dailymail.co.uk.298759=US-world scotsman.133791=GB-sport cbsnews.248484=EU-scitech dailymail.co.uk.298630=US-scitech ' + 'newsweek.51329=US-entertainment bbc.381701=GB-crime dailymail.co.uk.298738=GB-entertainment bbc.381669=OTHER-world-CN foxnews.94512=US-politics ' + 'guardian.221718=GB-entertainment dailymail.co.uk.298686=GB-politics cbsnews.248471=US-politics newsweek.51318=US-entertainment rt.com.91335=US-politics ' + 'newsweek.51300=US-politics cnn.304478=US-politics upi.176275=US-politics telegraph.405422=OTHER-world-ID reuters.218933=US-politics newsweek.51328=US-politics ' + 'newsweek.51307=US-business bbc.381692=GB-world independent.226346=GB-entertainment bbc.381646=GB-sport reuters.218914=US-sport scotsman.133758=EU-sport ' + 'rt.com.91350=EU-world scotsman.133773=GB-scitech rt.com.91334=EU-crime bbc.381680=GB-politics guardian.221756=US-politics scotsman.133783=GB-politics cnn.304521=US-sport ' + 'dailymail.co.uk.298622=GB-politics bbc.381789=GB-sport dailymail.co.uk.298644=GB-business dailymail.co.uk.298602=GB-world scotsman.133753=GB-sport ' + 'independent.226317=GB-entertainment nytimes.184862=US-politics thelocal.65969=OTHER-world-SY nytimes.184825=US-politics cnbc.com.6784=US-politics nytimes.184804=US-politics ' + 'nytimes.184830=US-politics scotsman.133801=GB-sport cnbc.com.6770=US-business bbc.381760=GB-crime reuters.218865=OTHER-world-ID newsweek.51339=US-crime ' + 'euronews-en.153797=OTHER-world-ID abcnews.306774=US-crime dailymail.co.uk.298696=GB-politics abcnews.306755=US-politics reuters.218909=US-crime ' + 'independent.226349=OTHER-sport-RU newsweek.51330=US-politics bbc.381705=GB-sport newsweek.51340=OTHER-world-ID cbsnews.248411=OTHER-world-FM abcnews.306776=US-crime ' + 'bbc.381694=GB-entertainment rt.com.91356=US-world telegraph.405430=GB-entertainment telegraph.405404=EU-world bbc.381749=GB-world telegraph.405413=US-politics ' + 'bbc.381736=OTHER-politics-KP cbsnews.248394=US-politics nytimes.184822=US-world telegraph.405408=US-politics euronews-en.153799=OTHER-politics-SY ' + 'euronews-en.153826=EU-sport cnn.304400=US-world', } + SUBSETS = {k: {d.split('=')[0]: d.split('=')[1] for d in v.split()} for (k, v) in _SUBSETS.items()} COUNTRIES = sorted(list({v.split('-')[0] for v in SUBSETS['wmt19'].values()})) DOMAINS = sorted(list({v.split('-')[1] for v in SUBSETS['wmt19'].values()})) if __name__ == '__main__': - # check downloading of files and MD5 hashsums - import urllib.request - import hashlib - url_md5 = {} + import sys + try: + cmd = sys.argv[1] + except IndexError: + print(f'Usage: {sys.argv[0]} --check | --dump') + sys.exit(1) + + if cmd == '--check': + import urllib.request + import hashlib + url_md5 = {} - for key, value in DATASETS.items(): - md5_hashes = value.get('md5', None) - if md5_hashes is not None: - assert len(value['data']) == len(md5_hashes) - pairs = zip(value['data'], md5_hashes) - for url, md5_hash in pairs: - url_md5[url] = md5_hash + for key, value in DATASETS.items(): + md5_hashes = value.get('md5', None) + if md5_hashes is not None: + assert len(value['data']) == len(md5_hashes) + pairs = zip(value['data'], md5_hashes) + for url, md5_hash in pairs: + url_md5[url] = md5_hash - for url, md5_hash in url_md5.items(): - try: - print('Downloading ', url) - with urllib.request.urlopen(url) as f: - data = f.read() - except Exception as exc: - raise(exc) + for url, md5_hash in url_md5.items(): + try: + print('Downloading ', url) + with urllib.request.urlopen(url) as f: + data = f.read() + except Exception as exc: + raise(exc) - if hashlib.md5(data).hexdigest() != md5_hash: - print('MD5 check failed for', url) + if hashlib.md5(data).hexdigest() != md5_hash: + print('MD5 check failed for', url) + elif cmd == '--dump': + import re + # Dumps a table in markdown format + print(f'| {"Dataset":<30} | {"Description":<115} |') + header = '| ' + '-' * 30 + ' | ' + '-' * 115 + ' |' + print(header) + for name, dset in DATASETS.items(): + desc = re.sub(r'(http[s]?:\/\/\S+)', r'[URL](\1)', str(dset['description'])) + print(f'| {name:<30} | {desc:<115} |') diff --git a/sacrebleu/metrics/__init__.py b/sacrebleu/metrics/__init__.py index 4f48aa4a..a18c2277 100644 --- a/sacrebleu/metrics/__init__.py +++ b/sacrebleu/metrics/__init__.py @@ -1,11 +1,11 @@ -# -*- coding: utf-8 -*- +"""The implementation of various metrics.""" -from .bleu import BLEU, BLEUScore -from .chrf import CHRF, CHRFScore -from .ter import TER, TERScore +from .bleu import BLEU, BLEUScore # noqa: F401 +from .chrf import CHRF, CHRFScore # noqa: F401 +from .ter import TER, TERScore # noqa: F401 METRICS = { - 'bleu': BLEU, - 'chrf': CHRF, - 'ter': TER, + 'BLEU': BLEU, + 'CHRF': CHRF, + 'TER': TER, } diff --git a/sacrebleu/metrics/base.py b/sacrebleu/metrics/base.py index 08b43011..40597139 100644 --- a/sacrebleu/metrics/base.py +++ b/sacrebleu/metrics/base.py @@ -1,58 +1,442 @@ +"""The base `Score`, `Metric` and `Signature` classes to derive from. + +`Metric` is an abstract class that enforces the implementation of a set +of abstract methods. This way, a correctly implemented metric will work +seamlessly with the rest of the codebase. +""" + +import json +import logging +import statistics +from typing import List, Sequence, Any, Optional, Dict +from abc import ABCMeta, abstractmethod from .. import __version__ +sacrelogger = logging.getLogger('sacrebleu') + -class BaseScore: - """A base score class to derive from.""" - def __init__(self, score): +class Score: + """A base score class to derive from. + + :param name: The name of the underlying metric. + :param score: A floating point number for the final metric. + """ + def __init__(self, name: str, score: float): + """`Score` initializer.""" + self.name = name self.score = score - def format(self, width=2, score_only=False, signature=''): - raise NotImplementedError() + # Statistical test related fields + self._mean = -1.0 + self._ci = -1.0 + + # More info can be added right after the score + self._verbose = '' + + def format(self, width: int = 2, score_only: bool = False, + signature: str = '', is_json: bool = False) -> str: + """Returns a pretty representation of the score. + :param width: Floating point decimal precision width. + :param score_only: If `True`, and the format is not `json`, + returns a single score string. + :param signature: A string representation of the given `Signature` + instance. + :param is_json: If `True`, will output the score in JSON string. + :return: A plain or JSON-formatted string representation. + """ + d = { + 'name': self.name, + 'score': float(f'{self.score:.{width}f}'), + 'signature': signature, + } + + sc = f'{self.score:.{width}f}' + + if self._mean > 0: + confidence_mean = f'{self._mean:.{width}f}' + confidence_var = f'{self._ci:.{width}f}' + confidence_str = f'μ = {confidence_mean} ± {confidence_var}' + + sc += f' ({confidence_str})' + if is_json: + d['confidence_mean'] = float(confidence_mean) + d['confidence_var'] = float(confidence_var) + d['confidence'] = confidence_str + + # Construct full score line + full_score = f"{self.name}|{signature}" if signature else self.name + full_score = f"{full_score} = {sc}" + if self._verbose: + full_score += f' {self._verbose}' + d['verbose_score'] = self._verbose + + if score_only: + return sc + + if is_json: + for param in signature.split('|'): + key, value = param.split(':') + d[key] = value + return json.dumps(d, indent=1, ensure_ascii=False) + + return full_score + + def estimate_ci(self, scores: List['Score']): + """Takes a list of scores and stores mean, stdev and 95% confidence + interval around the mean. + + :param scores: A list of `Score` objects obtained from bootstrap + resampling for example. + """ + # Sort the scores + raw_scores = sorted([x.score for x in scores]) + n = len(raw_scores) + + # Get CI bounds (95%, i.e. 1/40 from left) + lower_idx = n // 40 + upper_idx = n - lower_idx - 1 + lower, upper = raw_scores[lower_idx], raw_scores[upper_idx] + self._ci = 0.5 * (upper - lower) + self._mean = statistics.mean(raw_scores) def __repr__(self): + """Returns a human readable score string.""" return self.format() class Signature: """A convenience class to represent sacreBLEU reproducibility signatures. - :param args: The resulting `Namespace` returned from `parse_args()`. - Argument-value pairs from command-line would then be directly added - to the signature. + :param args: key-value dictionary passed from the actual metric instance. """ - def __init__(self, args): - # Copy the dictionary - self.args = dict(args.__dict__) - self.short = self.args.get('short', False) - + def __init__(self, args: dict): + """`Signature` initializer.""" + # Global items that are shared across all metrics self._abbr = { 'version': 'v', + 'nrefs': '#', 'test': 't', 'lang': 'l', 'subset': 'S', 'origlang': 'o', + 'bs': 'bs', # Bootstrap resampling trials + 'ar': 'ar', # Approximate randomization trials + 'seed': 'rs', # RNG's seed } + if 'num_refs' not in args: + raise RuntimeError( + 'Number of references unknown, please evaluate the metric first.') + + num_refs = args['num_refs'] + if num_refs == -1: + # Detect variable number of refs + num_refs = 'var' + + # Global items that are shared across all metrics + # None's will be ignored self.info = { - # None's will be ignored 'version': __version__, - 'test': self.args.get('test_set', None), - 'lang': self.args.get('langpair', None), - 'origlang': self.args.get('origlang', None), - 'subset': self.args.get('subset', None), + 'nrefs': num_refs, + 'bs': args.get('n_bootstrap', None), + 'ar': None, + 'seed': args.get('seed', None), + 'test': args.get('test_set', None), + 'lang': args.get('langpair', None), + 'origlang': args.get('origlang', None), + 'subset': args.get('subset', None), } - def __str__(self): - """Returns a formatted signature string.""" + def format(self, short: bool = False) -> str: + """Returns a string representation of the signature. + + :param short: If True, shortened signature is produced. + :return: A string representation of the signature. + """ pairs = [] - for name in sorted(self.info.keys()): + keys = list(self.info.keys()) + # keep version always at end + keys.remove('version') + for name in keys + ['version']: value = self.info[name] if value is not None: - final_name = self._abbr[name] if self.short else name - pairs.append('{}.{}'.format(final_name, value)) + if isinstance(value, bool): + # Replace True/False with yes/no + value = 'yes' if value else 'no' + final_name = self._abbr[name] if short else name + pairs.append(f'{final_name}:{value}') + + return '|'.join(pairs) - return '+'.join(pairs) + def update(self, key: str, value: Any): + """Add a new item or update an existing one. + + :param key: The key to use in the dictionary. + :param value: The associated value for the `key`. + """ + self.info[key] = value + + def __str__(self): + """Returns a human-readable signature string.""" + return self.format() def __repr__(self): - return self.__str__() + """Returns a human-readable signature string.""" + return self.format() + + +class Metric(metaclass=ABCMeta): + """A base class for all metrics that ensures the implementation of some + methods. Much of the common functionality is moved to this base class + from other metrics.""" + + # Each metric should define its Signature class' name here + _SIGNATURE_TYPE = Signature + + def __init__(self): + """`Metric` initializer.""" + # The pre-computed reference cache + self._ref_cache = None + + # only useful for BLEU tokenized warnings. Set to True so that + # warnings are not issued for other metrics. + self._force = True + + # Will be used by the signature when bootstrap resampling + self.n_bootstrap = None + self.seed = None + + def _check_sentence_score_args(self, hyp: str, refs: Sequence[str]): + """Performs sanity checks on `sentence_score` method's arguments. + + :param hyp: A single hypothesis string. + :param refs: A sequence of reference strings. + """ + prefix = self.__class__.__name__ + err_msg = None + + if not isinstance(hyp, str): + err_msg = 'The argument `hyp` should be a string.' + elif isinstance(refs, str) or not isinstance(refs, Sequence): + err_msg = 'The argument `refs` should be a sequence of strings.' + elif not isinstance(refs[0], str): + err_msg = 'Each element of `refs` should be a string.' + + if err_msg: + raise RuntimeError(f'{prefix}: {err_msg}') + + def _check_corpus_score_args(self, hyps: Sequence[str], + refs: Optional[Sequence[Sequence[str]]]): + """Performs sanity checks on `corpus_score` method's arguments. + + :param hypses: A sequence of hypothesis strings. + :param refs: A sequence of reference documents with document being + defined as a sequence of reference strings. If `None`, cached references + will be used. + """ + + prefix = self.__class__.__name__ + err_msg = None + + if not isinstance(hyps, Sequence): + err_msg = "`hyps` should be a sequence of strings." + elif not isinstance(hyps[0], str): + err_msg = 'Each element of `hyps` should be a string.' + elif any(line is None for line in hyps): + err_msg = "Undefined line in hypotheses stream!" + + if refs is not None: + if not isinstance(refs, Sequence): + err_msg = "`refs` should be a sequence of sequence of strings." + elif not isinstance(refs[0], Sequence): + err_msg = "Each element of `refs` should be a sequence of strings." + elif not isinstance(refs[0][0], str): + err_msg = "`refs` should be a sequence of sequence of strings." + + if err_msg: + raise RuntimeError(f'{prefix}: {err_msg}') + + @abstractmethod + def _aggregate_and_compute(self, stats: List[List[Any]]) -> Any: + """Computes the final score given the pre-computed match statistics. + + :param stats: A list of segment-level statistics. + :return: A `Score` instance. + """ + pass + + @abstractmethod + def _compute_score_from_stats(self, stats: List[Any]) -> Any: + """Computes the final score from already aggregated statistics. + + :param stats: A list or numpy array of segment-level statistics. + :return: A `Score` object. + """ + pass + + @abstractmethod + def _preprocess_segment(self, sent: str) -> str: + """A wrapper around the metric's tokenization and pre-processing logic. + This should be implemented for reference caching to work correctly. + + :param sent: The input sentence. + :return: The pre-processed output sentence. + """ + pass + + @abstractmethod + def _extract_reference_info(self, refs: Sequence[str]) -> Dict[str, Any]: + """Given a list of reference segments, extract the required + information (such as n-grams for BLEU and chrF). This should be implemented + for the generic `_cache_references()` to work across all metrics. + + :param refs: A sequence of strings. + """ + pass + + @abstractmethod + def _compute_segment_statistics(self, hypothesis: str, ref_kwargs: Dict) -> List[Any]: + """Given a (pre-processed) hypothesis sentence and already computed + reference info, returns the best match statistics across the + references. The return type is usually a List of ints or floats. + + :param hypothesis: A pre-processed hypothesis sentence. + :param ref_kwargs: A dictionary with reference-related information + within. This is formulated as a dictionary as different metrics may + require different information regarding a reference segment. + """ + pass + + def _cache_references(self, references: Sequence[Sequence[str]]) -> List[Any]: + """Given the full set of document references, extract segment n-grams + (or other necessary information) for caching purposes. + + :param references: A sequence of reference documents with document being + defined as a sequence of reference strings. A particular reference + segment can be '' or `None` to allow the use of variable number + of references per segment. + :return: A list where each element is a tuple of segment n-grams and + reference lengths, as returned by `_extract_reference_info()`. + """ + ref_cache = [] + + # Decide on final number of refs here as well + num_refs = set() + + for refs in zip(*references): + # remove undefined / empty references + # i.e. we have fewer references for this particular sentence + lines = [x for x in refs if x is not None and x != ""] + + if len(lines) == 0: + raise RuntimeError("Empty or `None` reference sentence found.") + + # Keep track of reference counts to allow variable reference + # info in the signature + num_refs.add(len(lines)) + + lines = [self._preprocess_segment(x) for x in lines] + + # Get n-grams + ref_cache.append(self._extract_reference_info(lines)) + + if len(num_refs) == 1: + self.num_refs = list(num_refs)[0] + else: + # A variable number of refs exist + self.num_refs = -1 + + return ref_cache + + def _extract_corpus_statistics(self, hypotheses: Sequence[str], + references: Optional[Sequence[Sequence[str]]]) -> Any: + """Reads the corpus and returns sentence-level match statistics for + faster re-computations esp. during statistical tests. + + :param hypotheses: A sequence of hypothesis strings. + :param references: A sequence of reference documents with document being + defined as a sequence of reference strings. If `None`, cached references + will be used. + :return: A list where each sublist corresponds to segment statistics. + """ + # Pre-compute references + # Don't store the cache as the user is explicitly passing refs + if references: + ref_cache = self._cache_references(references) + elif self._ref_cache: + ref_cache = self._ref_cache + else: + raise RuntimeError('No references provided and the cache is empty.') + + stats = [] + tok_count = 0 + + for hyp, ref_kwargs in zip(hypotheses, ref_cache): + # Check for already-tokenized input problem (only for BLEU) + if not self._force and hyp.endswith(' .'): + tok_count += 1 + + hyp = self._preprocess_segment(hyp) + + # Collect stats + stats.append(self._compute_segment_statistics(hyp, ref_kwargs)) + + if tok_count >= 100: + sacrelogger.warning("That's 100 lines that end in a tokenized period ('.')") + sacrelogger.warning("It looks like you forgot to detokenize your test data, which may hurt your score.") + sacrelogger.warning("If you insist your data is detokenized, or don't care, you can suppress this message with the `force` parameter.") + + return stats + + def sentence_score(self, hypothesis: str, references: Sequence[str]) -> Any: + """Compute the metric for a single sentence against a single (or multiple) reference(s). + + :param hypothesis: A single hypothesis string. + :param references: A sequence of reference strings. + :return: A `Score` object. + """ + self._check_sentence_score_args(hypothesis, references) + + stats = self._extract_corpus_statistics( + [hypothesis], [[refs] for refs in references]) + return self._aggregate_and_compute(stats) + + def corpus_score(self, hypotheses: Sequence[str], + references: Optional[Sequence[Sequence[str]]], + n_bootstrap: int = 1) -> Any: + """Compute the metric for a corpus against a single (or multiple) reference(s). + + :param hypotheses: A sequence of hypothesis strings. + :param references: A sequence of reference documents with document being + defined as a sequence of reference strings. If `None`, cached references + will be used. + :param n_bootstrap: If > 1, provides 95% confidence interval around true mean + using bootstrap resampling with `n_bootstrap` samples. + :return: A `Score` object. + """ + self._check_corpus_score_args(hypotheses, references) + + # Collect corpus stats + stats = self._extract_corpus_statistics(hypotheses, references) + + # Compute the actual system score + actual_score = self._aggregate_and_compute(stats) + + if n_bootstrap > 1: + # Compute bootstrap estimate as well + # Delayed import is to escape from numpy import if bootstrap + # is not requested. + from ..significance import _bootstrap_resample + + self.n_bootstrap = n_bootstrap + self.seed, bs_scores = _bootstrap_resample(stats, self, n_bootstrap) + actual_score.estimate_ci(bs_scores) + + return actual_score + + def get_signature(self) -> Signature: + """Creates and returns the signature for the metric. The creation + of signatures is delayed as the number of references is resolved + only at the point of reference caching.""" + return self._SIGNATURE_TYPE(self.__dict__) diff --git a/sacrebleu/metrics/bleu.py b/sacrebleu/metrics/bleu.py index 631d61e3..45b2a266 100644 --- a/sacrebleu/metrics/bleu.py +++ b/sacrebleu/metrics/bleu.py @@ -1,84 +1,127 @@ +"""The implementation of the BLEU metric (Papineni et al., 2002).""" + import math import logging -from collections import Counter -from typing import List, Iterable, Union +from importlib import import_module +from typing import List, Sequence, Optional, Dict, Any + +from ..utils import my_log, sum_of_lists -from ..tokenizers import TOKENIZERS -from ..utils import my_log -from .base import BaseScore, Signature +from .base import Score, Signature, Metric +from .helpers import extract_all_word_ngrams sacrelogger = logging.getLogger('sacrebleu') +# The default for the maximum n-gram order when computing precisions +MAX_NGRAM_ORDER = 4 + +_TOKENIZERS = { + 'none': 'tokenizer_base.BaseTokenizer', + 'zh': 'tokenizer_zh.TokenizerZh', + '13a': 'tokenizer_13a.Tokenizer13a', + 'intl': 'tokenizer_intl.TokenizerV14International', + 'char': 'tokenizer_char.TokenizerChar', + 'ja-mecab': 'tokenizer_ja_mecab.TokenizerJaMecab', +} + + +def _get_tokenizer(name: str): + """Dynamically import tokenizer as importing all is slow.""" + module_name, class_name = _TOKENIZERS[name].rsplit('.', 1) + return getattr( + import_module(f'.tokenizers.{module_name}', 'sacrebleu'), + class_name) + class BLEUSignature(Signature): - def __init__(self, args): + """A convenience class to represent the reproducibility signature for BLEU. + + :param args: key-value dictionary passed from the actual metric instance. + """ + def __init__(self, args: dict): + """`BLEUSignature` initializer.""" super().__init__(args) self._abbr.update({ - 'smooth': 's', 'case': 'c', + 'eff': 'e', 'tok': 'tok', - 'numrefs': '#', + 'smooth': 's', }) # Construct a combined string for smoothing method and value - smooth_str = self.args['smooth_method'] + smooth_str = args['smooth_method'] smooth_def = BLEU.SMOOTH_DEFAULTS[smooth_str] # If the method requires a parameter, add it within brackets if smooth_def is not None: # the following can be None if the user wants to use the default - smooth_val = self.args['smooth_value'] + smooth_val = args['smooth_value'] if smooth_val is None: smooth_val = smooth_def - smooth_str += '[{:.2f}]'.format(smooth_val) + smooth_str += f'[{smooth_val:.2f}]' self.info.update({ + 'case': 'lc' if args['lowercase'] else 'mixed', + 'eff': 'yes' if args['effective_order'] else 'no', + 'tok': args['tokenizer_signature'], 'smooth': smooth_str, - 'case': 'lc' if self.args['lc'] else 'mixed', - 'tok': TOKENIZERS[self.args['tokenize']]().signature(), - 'numrefs': self.args.get('num_refs', '?'), }) -class BLEUScore(BaseScore): - """A convenience class to represent BLEU scores (without signature).""" - def __init__(self, score, counts, totals, precisions, bp, sys_len, ref_len): - super().__init__(score) - - self.prefix = 'BLEU' +class BLEUScore(Score): + """A convenience class to represent BLEU scores. + + :param score: The BLEU score. + :param counts: List of counts of correct ngrams, 1 <= n <= max_ngram_order + :param totals: List of counts of total ngrams, 1 <= n <= max_ngram_order + :param precisions: List of precisions, 1 <= n <= max_ngram_order + :param bp: The brevity penalty. + :param sys_len: The cumulative system length. + :param ref_len: The cumulative reference length. + """ + def __init__(self, score: float, counts: List[int], totals: List[int], + precisions: List[float], bp: float, + sys_len: int, ref_len: int): + """`BLEUScore` initializer.""" + super().__init__('BLEU', score) self.bp = bp self.counts = counts self.totals = totals self.sys_len = sys_len self.ref_len = ref_len self.precisions = precisions - self.prec_str = "/".join(["{:.1f}".format(p) for p in self.precisions]) - - def format(self, width=2, score_only=False, signature=''): - if score_only: - return '{0:.{1}f}'.format(self.score, width) - - prefix = "{}+{}".format(self.prefix, signature) if signature else self.prefix - - s = '{pr} = {sc:.{w}f} {prec} (BP = {bp:.3f} ratio = {r:.3f} hyp_len = {sl:d} ref_len = {rl:d})'.format( - pr=prefix, - sc=self.score, - w=width, - prec=self.prec_str, - bp=self.bp, - r=self.sys_len / self.ref_len, - sl=self.sys_len, - rl=self.ref_len) - return s - -class BLEU: - NGRAM_ORDER = 4 - - SMOOTH_DEFAULTS = { + self.prec_str = "/".join([f"{p:.1f}" for p in self.precisions]) + self.ratio = self.sys_len / self.ref_len if self.ref_len else 0 + + # The verbose part of BLEU + self._verbose = f"{self.prec_str} (BP = {self.bp:.3f} " + self._verbose += f"ratio = {self.ratio:.3f} hyp_len = {self.sys_len:d} " + self._verbose += f"ref_len = {self.ref_len:d})" + + +class BLEU(Metric): + """Computes the BLEU metric given hypotheses and references. + + :param lowercase: If True, lowercased BLEU is computed. + :param force: Ignore data that looks already tokenized. + :param tokenize: The tokenizer to use. If None, defaults to language-specific tokenizers with '13a' as the fallback default. + :param smooth_method: The smoothing method to use ('floor', 'add-k', 'exp' or 'none'). + :param smooth_value: The smoothing value for `floor` and `add-k` methods. `None` falls back to default value. + :param max_ngram_order: If given, it overrides the maximum n-gram order (default: 4) when computing precisions. + :param effective_order: If `True`, stop including n-gram orders for which precision is 0. This should be + `True`, if sentence-level BLEU will be computed. + :param trg_lang: An optional language code to raise potential tokenizer warnings. + :param references: A sequence of reference documents with document being + defined as a sequence of reference strings. If given, the reference n-grams + and lengths will be pre-computed and cached for faster BLEU computation + across many systems. + """ + + SMOOTH_DEFAULTS: Dict[str, Optional[float]] = { # The defaults for `floor` and `add-k` are obtained from the following paper # A Systematic Comparison of Smoothing Techniques for Sentence-Level BLEU # Boxing Chen and Colin Cherry @@ -89,67 +132,69 @@ class BLEU: 'exp': None, # No value is required } - def __init__(self, args): - self.name = 'bleu' - self.force = args.force - self.lc = args.lc - self.smooth_value = args.smooth_value - self.smooth_method = args.smooth_method - self.tokenizer = TOKENIZERS[args.tokenize]() - self.signature = BLEUSignature(args) - - # Sanity check - assert self.smooth_method in self.SMOOTH_DEFAULTS.keys(), \ - "Unknown smooth_method '{}'".format(self.smooth_method) + TOKENIZERS = ['none', 'zh', '13a', 'char', 'intl', 'ja-mecab'] - @staticmethod - def extract_ngrams(line, min_order=1, max_order=NGRAM_ORDER) -> Counter: - """Extracts all the ngrams (min_order <= n <= max_order) from a sequence of tokens. - - :param line: A segment containing a sequence of words. - :param min_order: Minimum n-gram length (default: 1). - :param max_order: Maximum n-gram length (default: NGRAM_ORDER). - :return: a dictionary containing ngrams and counts - """ + # mteval-v13a.pl tokenizer unless Chinese or Japanese is provided + TOKENIZER_DEFAULT = '13a' - ngrams = Counter() # type: Counter - tokens = line.split() - for n in range(min_order, max_order + 1): - for i in range(0, len(tokens) - n + 1): - ngram = ' '.join(tokens[i: i + n]) - ngrams[ngram] += 1 + # Some language specific mappings to use if `trg_lang` is given + # and the tokenizer is not explicitly specified + _TOKENIZER_MAP = { + 'zh': 'zh', + 'ja': 'ja-mecab', + } - return ngrams + _SIGNATURE_TYPE = BLEUSignature + + def __init__(self, lowercase: bool = False, + force: bool = False, + tokenize: Optional[str] = '13a', + smooth_method: str = 'exp', + smooth_value: Optional[float] = None, + max_ngram_order: int = MAX_NGRAM_ORDER, + effective_order: bool = False, + trg_lang: str = '', + references: Optional[Sequence[Sequence[str]]] = None): + """`BLEU` initializer.""" + super().__init__() + + self._force = force + self.trg_lang = trg_lang + self.lowercase = lowercase + self.smooth_value = smooth_value + self.smooth_method = smooth_method + self.max_ngram_order = max_ngram_order + self.effective_order = effective_order - @staticmethod - def reference_stats(refs, output_len): - """Extracts reference statistics for a given segment. + # Sanity check + assert self.smooth_method in self.SMOOTH_DEFAULTS.keys(), \ + "Unknown smooth_method {self.smooth_method!r}" - :param refs: A list of segment tokens. - :param output_len: Hypothesis length for this segment. - :return: a tuple of (ngrams, closest_diff, closest_len) - """ + # Default tokenizer logic + if tokenize is None: + best_tokenizer = self.TOKENIZER_DEFAULT - ngrams = Counter() - closest_diff = None - closest_len = None + # Set `zh` or `ja-mecab` if target language is provided + if self.trg_lang in self._TOKENIZER_MAP: + best_tokenizer = self._TOKENIZER_MAP[self.trg_lang] + else: + best_tokenizer = tokenize + if self.trg_lang == 'zh' and best_tokenizer != 'zh': + sacrelogger.warning( + "You should use the 'zh' tokenizer for Chinese.") + if self.trg_lang == 'ja' and best_tokenizer != 'ja-mecab': + sacrelogger.warning( + "You should use the 'ja-mecab' tokenizer for Japanese.") - for ref in refs: - tokens = ref.split() - reflen = len(tokens) - diff = abs(output_len - reflen) - if closest_diff is None or diff < closest_diff: - closest_diff = diff - closest_len = reflen - elif diff == closest_diff: - if reflen < closest_len: - closest_len = reflen + # Create the tokenizer + self.tokenizer = _get_tokenizer(best_tokenizer)() - ngrams_ref = BLEU.extract_ngrams(ref) - for ngram in ngrams_ref.keys(): - ngrams[ngram] = max(ngrams[ngram], ngrams_ref[ngram]) + # Build the signature + self.tokenizer_signature = self.tokenizer.signature() - return ngrams, closest_diff, closest_len + if references is not None: + # Pre-compute reference ngrams and lengths + self._ref_cache = self._cache_references(references) @staticmethod def compute_bleu(correct: List[int], @@ -158,8 +203,9 @@ def compute_bleu(correct: List[int], ref_len: int, smooth_method: str = 'none', smooth_value=None, - use_effective_order=False) -> BLEUScore: - """Computes BLEU score from its sufficient statistics. Adds smoothing. + effective_order: bool = False, + max_ngram_order: int = MAX_NGRAM_ORDER) -> BLEUScore: + """Computes BLEU score from its sufficient statistics with smoothing. Smoothing methods (citing "A Systematic Comparison of Smoothing Techniques for Sentence-Level BLEU", Boxing Chen and Colin Cherry, WMT 2014: http://aclweb.org/anthology/W14-3346) @@ -169,151 +215,193 @@ def compute_bleu(correct: List[int], - add-k: Method 2 (Generalizing Lin and Och, 2004) - exp: Method 3 (NIST smoothing method i.e. in use with mteval-v13a.pl) - :param correct: List of counts of correct ngrams, 1 <= n <= NGRAM_ORDER - :param total: List of counts of total ngrams, 1 <= n <= NGRAM_ORDER + :param correct: List of counts of correct ngrams, 1 <= n <= max_ngram_order + :param total: List of counts of total ngrams, 1 <= n <= max_ngram_order :param sys_len: The cumulative system length :param ref_len: The cumulative reference length :param smooth_method: The smoothing method to use ('floor', 'add-k', 'exp' or 'none') :param smooth_value: The smoothing value for `floor` and `add-k` methods. `None` falls back to default value. - :param use_effective_order: If true, use the length of `correct` for the n-gram order instead of NGRAM_ORDER. - :return: A BLEU object with the score (100-based) and other statistics. + :param effective_order: If `True`, stop including n-gram orders for which precision is 0. This should be + `True`, if sentence-level BLEU will be computed. + :param max_ngram_order: If given, it overrides the maximum n-gram order (default: 4) when computing precisions. + :return: A `BLEUScore` instance. """ assert smooth_method in BLEU.SMOOTH_DEFAULTS.keys(), \ - "Unknown smooth_method '{}'".format(smooth_method) + "Unknown smooth_method {smooth_method!r}" # Fetch the default value for floor and add-k if smooth_value is None: smooth_value = BLEU.SMOOTH_DEFAULTS[smooth_method] - precisions = [0.0 for x in range(BLEU.NGRAM_ORDER)] + # Compute brevity penalty + if sys_len < ref_len: + bp = math.exp(1 - ref_len / sys_len) if sys_len > 0 else 0.0 + else: + bp = 1.0 + + # n-gram precisions + precisions = [0.0 for x in range(max_ngram_order)] + + # Early stop if there are no matches (#141) + if not any(correct): + return BLEUScore(0.0, correct, total, precisions, bp, sys_len, ref_len) smooth_mteval = 1. - effective_order = BLEU.NGRAM_ORDER - for n in range(1, BLEU.NGRAM_ORDER + 1): + eff_order = max_ngram_order + for n in range(1, len(precisions) + 1): if smooth_method == 'add-k' and n > 1: - correct[n-1] += smooth_value - total[n-1] += smooth_value - if total[n-1] == 0: + correct[n - 1] += smooth_value + total[n - 1] += smooth_value + + if total[n - 1] == 0: break - if use_effective_order: - effective_order = n + # If the system guesses no i-grams, 1 <= i <= max_ngram_order, + # the BLEU score is 0 (technically undefined). This is a problem for sentence + # level BLEU or a corpus of short sentences, where systems will get + # no credit if sentence lengths fall under the max_ngram_order threshold. + # This fix scales max_ngram_order to the observed maximum order. + # It is only available through the API and off by default + if effective_order: + eff_order = n - if correct[n-1] == 0: + if correct[n - 1] == 0: if smooth_method == 'exp': smooth_mteval *= 2 - precisions[n-1] = 100. / (smooth_mteval * total[n-1]) + precisions[n - 1] = 100. / (smooth_mteval * total[n - 1]) elif smooth_method == 'floor': - precisions[n-1] = 100. * smooth_value / total[n-1] + precisions[n - 1] = 100. * smooth_value / total[n - 1] else: - precisions[n-1] = 100. * correct[n-1] / total[n-1] - - # If the system guesses no i-grams, 1 <= i <= NGRAM_ORDER, the BLEU - # score is 0 (technically undefined). This is a problem for sentence - # level BLEU or a corpus of short sentences, where systems will get - # no credit if sentence lengths fall under the NGRAM_ORDER threshold. - # This fix scales NGRAM_ORDER to the observed maximum order. - # It is only available through the API and off by default - - if sys_len < ref_len: - bp = math.exp(1 - ref_len / sys_len) if sys_len > 0 else 0.0 - else: - bp = 1.0 + precisions[n - 1] = 100. * correct[n - 1] / total[n - 1] + # Compute BLEU score score = bp * math.exp( - sum(map(my_log, precisions[:effective_order])) / effective_order) + sum([my_log(p) for p in precisions[:eff_order]]) / eff_order) - return BLEUScore( - score, correct, total, precisions, bp, sys_len, ref_len) + return BLEUScore(score, correct, total, precisions, bp, sys_len, ref_len) - def sentence_score(self, hypothesis: str, - references: List[str], - use_effective_order: bool = True) -> BLEUScore: + def _preprocess_segment(self, sent: str) -> str: + """Given a sentence, lowercases (optionally) and tokenizes it + :param sent: The input sentence string. + :return: The pre-processed output string. """ - Computes BLEU on a single sentence pair. + if self.lowercase: + sent = sent.lower() + return self.tokenizer(sent.rstrip()) - Disclaimer: computing BLEU on the sentence level is not its intended use, - BLEU is a corpus-level metric. + def _compute_score_from_stats(self, stats: List[int]) -> BLEUScore: + """Computes the final score from already aggregated statistics. - :param hypothesis: Hypothesis string. - :param references: List of reference strings. - :param use_effective_order: Account for references that are shorter than the largest n-gram. - :return: a `BLEUScore` object containing everything you'd want - """ - assert not isinstance(references, str), \ - "sentence_score needs a list of references, not a single string" - return self.corpus_score(hypothesis, [[ref] for ref in references], - use_effective_order=use_effective_order) - - def corpus_score(self, sys_stream: Union[str, Iterable[str]], - ref_streams: Union[str, List[Iterable[str]]], - use_effective_order: bool = False) -> BLEUScore: - """Produces BLEU scores along with its sufficient statistics from a source against one or more references. - - :param sys_stream: The system stream (a sequence of segments) - :param ref_streams: A list of one or more reference streams (each a sequence of segments) - :param use_effective_order: Account for references that are shorter than the largest n-gram. - :return: a `BLEUScore` object containing everything you'd want + :param stats: A list or numpy array of segment-level statistics. + :return: A `BLEUScore` object. """ + return self.compute_bleu( + correct=stats[2: 2 + self.max_ngram_order], + total=stats[2 + self.max_ngram_order:], + sys_len=int(stats[0]), ref_len=int(stats[1]), + smooth_method=self.smooth_method, smooth_value=self.smooth_value, + effective_order=self.effective_order) - # Add some robustness to the input arguments - if isinstance(sys_stream, str): - sys_stream = [sys_stream] + def _aggregate_and_compute(self, stats: List[List[int]]) -> BLEUScore: + """Computes the final BLEU score given the pre-computed corpus statistics. - if isinstance(ref_streams, str): - ref_streams = [[ref_streams]] + :param stats: A list of segment-level statistics + :return: A `BLEUScore` instance. + """ + return self._compute_score_from_stats(sum_of_lists(stats)) - sys_len = 0 - ref_len = 0 + def _get_closest_ref_len(self, hyp_len: int, ref_lens: List[int]) -> int: + """Given a hypothesis length and a list of reference lengths, returns + the closest reference length to be used by BLEU. - correct = [0 for n in range(self.NGRAM_ORDER)] - total = [0 for n in range(self.NGRAM_ORDER)] + :param hyp_len: The hypothesis length. + :param ref_lens: A list of reference lengths. + :return: The closest reference length. + """ + closest_diff, closest_len = -1, -1 - # look for already-tokenized sentences - tokenized_count = 0 + for ref_len in ref_lens: + diff = abs(hyp_len - ref_len) + if closest_diff == -1 or diff < closest_diff: + closest_diff = diff + closest_len = ref_len + elif diff == closest_diff and ref_len < closest_len: + closest_len = ref_len - # sanity checks - if any(len(ref_stream) != len(sys_stream) for ref_stream in ref_streams): - raise EOFError("System and reference streams have different lengths!") - if any(line is None for line in sys_stream): - raise EOFError("Undefined line in system stream!") + return closest_len - for output, *refs in zip(sys_stream, *ref_streams): - # remove undefined/empty references (i.e. we have fewer references for this particular sentence) - # but keep empty hypothesis (it's always defined thanks to the sanity check above) - lines = [output] + [x for x in refs if x is not None and x != ""] - if len(lines) < 2: # we need at least hypothesis + 1 defined & non-empty reference - raise EOFError("No valid references for a sentence!") + def _extract_reference_info(self, refs: Sequence[str]) -> Dict[str, Any]: + """Given a list of reference segments, extract the n-grams and reference lengths. + The latter will be useful when comparing hypothesis and reference lengths for BLEU. - if self.lc: - lines = [x.lower() for x in lines] + :param refs: A sequence of strings. + :return: A dictionary that will be passed to `_compute_segment_statistics()` + through keyword arguments. + """ + ngrams = None + ref_lens = [] - if not (self.force or self.tokenizer.signature() == 'none') and lines[0].rstrip().endswith(' .'): - tokenized_count += 1 + for ref in refs: + # extract n-grams for this ref + this_ngrams, ref_len = extract_all_word_ngrams(ref, 1, self.max_ngram_order) + ref_lens.append(ref_len) + + if ngrams is None: + # Set it directly for first set of refs + ngrams = this_ngrams + else: + # Merge counts across multiple references + # The below loop is faster than `ngrams |= this_ngrams` + for ngram, count in this_ngrams.items(): + ngrams[ngram] = max(ngrams[ngram], count) + + return {'ref_ngrams': ngrams, 'ref_lens': ref_lens} + + def _compute_segment_statistics(self, hypothesis: str, + ref_kwargs: Dict) -> List[int]: + """Given a (pre-processed) hypothesis sentence and already computed + reference n-grams & lengths, returns the best match statistics across the + references. + + :param hypothesis: Hypothesis sentence. + :param ref_kwargs: A dictionary with `refs_ngrams`and `ref_lens` keys + that denote the counter containing all n-gram counts and reference lengths, + respectively. + :return: A list of integers with match statistics. + """ - if tokenized_count == 100: - sacrelogger.warning('That\'s 100 lines that end in a tokenized period (\'.\')') - sacrelogger.warning('It looks like you forgot to detokenize your test data, which may hurt your score.') - sacrelogger.warning('If you insist your data is detokenized, or don\'t care, you can suppress this message with \'--force\'.') + ref_ngrams, ref_lens = ref_kwargs['ref_ngrams'], ref_kwargs['ref_lens'] - output, *refs = [self.tokenizer(x.rstrip()) for x in lines] + # Extract n-grams for the hypothesis + hyp_ngrams, hyp_len = extract_all_word_ngrams( + hypothesis, 1, self.max_ngram_order) - output_len = len(output.split()) - ref_ngrams, closest_diff, closest_len = BLEU.reference_stats(refs, output_len) + ref_len = self._get_closest_ref_len(hyp_len, ref_lens) - sys_len += output_len - ref_len += closest_len + # Count the stats + # Although counter has its internal & and | operators, this is faster + correct = [0 for i in range(self.max_ngram_order)] + total = correct[:] + for hyp_ngram, hyp_count in hyp_ngrams.items(): + # n-gram order + n = len(hyp_ngram) - 1 + # count hypothesis n-grams + total[n] += hyp_count + # count matched n-grams + if hyp_ngram in ref_ngrams: + correct[n] += min(hyp_count, ref_ngrams[hyp_ngram]) - sys_ngrams = BLEU.extract_ngrams(output) - for ngram in sys_ngrams.keys(): - n = len(ngram.split()) - correct[n-1] += min(sys_ngrams[ngram], ref_ngrams.get(ngram, 0)) - total[n-1] += sys_ngrams[ngram] + # Return a flattened list for efficient computation + return [hyp_len, ref_len] + correct + total - # Get BLEUScore object - score = self.compute_bleu( - correct, total, sys_len, ref_len, - smooth_method=self.smooth_method, smooth_value=self.smooth_value, - use_effective_order=use_effective_order) + def sentence_score(self, hypothesis: str, references: Sequence[str]) -> BLEUScore: + """Compute the metric for a single sentence against a single (or multiple) reference(s). - return score + :param hypothesis: A single hypothesis string. + :param references: A sequence of reference strings. + :return: a `BLEUScore` object. + """ + if not self.effective_order: + sacrelogger.warning( + 'It is recommended to enable `effective_order` for sentence-level BLEU.') + return super().sentence_score(hypothesis, references) diff --git a/sacrebleu/metrics/chrf.py b/sacrebleu/metrics/chrf.py index 6f032d97..f7d4f685 100644 --- a/sacrebleu/metrics/chrf.py +++ b/sacrebleu/metrics/chrf.py @@ -1,164 +1,284 @@ -import re +"""The implementation of chrF (Popović 2015) and chrF++ (Popović 2017) metrics.""" + +from typing import List, Sequence, Optional, Dict from collections import Counter -from itertools import zip_longest -from typing import List, Iterable, Union -from .base import BaseScore, Signature +from ..utils import sum_of_lists +from .base import Score, Signature, Metric +from .helpers import extract_all_char_ngrams, extract_word_ngrams class CHRFSignature(Signature): - def __init__(self, args): - super().__init__(args) + """A convenience class to represent the reproducibility signature for chrF. + :param args: key-value dictionary passed from the actual metric instance. + """ + def __init__(self, args: dict): + """`CHRFSignature` initializer.""" + super().__init__(args) self._abbr.update({ - 'numchars': 'n', + 'case': 'c', + 'eff': 'e', + 'nc': 'nc', + 'nw': 'nw', 'space': 's', }) self.info.update({ - 'space': str(self.args['chrf_whitespace']).lower(), - 'numchars': self.args['chrf_order'], + 'case': 'lc' if args['lowercase'] else 'mixed', + 'eff': 'yes' if not args['eps_smoothing'] else 'no', + 'nc': args['char_order'], + 'nw': args['word_order'], + 'space': 'yes' if args['whitespace'] else 'no', }) -class CHRFScore(BaseScore): - def __init__(self, score, beta, order): - super().__init__(score) +class CHRFScore(Score): + """A convenience class to represent chrF scores. + :param score: The chrF (chrF++) score. + :param char_order: The character n-gram order. + :param word_order: The word n-gram order. If equals to 2, the metric is referred to as chrF++. + :param beta: Determine the importance of recall w.r.t precision. + """ + def __init__(self, score: float, char_order: int, word_order: int, beta: int): + """`CHRFScore` initializer.""" self.beta = beta - self.order = order - self.prefix = 'chrF{0:d}'.format(self.beta) + self.char_order = char_order + self.word_order = word_order + + # Add + signs to denote chrF+ variant + name = f'chrF{self.beta}' + '+' * self.word_order - def format(self, width=2, score_only=False, signature=''): - # NOTE: Being 0-1 scaled, a default width of 1 is too small for chrF - width += 1 - if score_only: - return '{0:.{1}f}'.format(self.score, width) + super().__init__(name, score) - prefix = "{}+{}".format(self.prefix, signature) if signature else self.prefix - return '{pr} = {sc:.{w}f}'.format(pr=prefix, sc=self.score, w=width) +class CHRF(Metric): + """Computes the chrF(++) metric given hypotheses and references. -class CHRF: - # Default values for CHRF - ORDER = 6 + :param char_order: Character n-gram order. + :param word_order: Word n-gram order. If equals to 2, the metric is referred to as chrF++. + :param beta: Determine the importance of recall w.r.t precision. + :param lowercase: Enable case-insensitivity. + :param whitespace: If `True`, include whitespaces when extracting character n-grams. + :param eps_smoothing: If `True`, applies epsilon smoothing similar + to reference chrF++.py, NLTK and Moses implementations. Otherwise, + it takes into account effective match order similar to sacreBLEU < 2.0.0. + :param references: A sequence of reference documents with document being + defined as a sequence of reference strings. If given, the reference n-grams + will be pre-computed and cached for faster re-computation across many systems. + """ - # default to 2 (per http://www.aclweb.org/anthology/W16-2341) + # Maximum character n-gram order to take into account + CHAR_ORDER = 6 + + # chrF+ additionally takes into account some of the word n-grams + WORD_ORDER = 0 + + # Defaults to 2 (per http://www.aclweb.org/anthology/W16-2341) BETA = 2 - def __init__(self, args): - self.name = 'chrf' - self.include_whitespace = args.chrf_whitespace - self.order = args.chrf_order - self.beta = args.chrf_beta - self.signature = CHRFSignature(args) + # Cache string.punctuation for chrF+' punctuation stripper + _PUNCTS = set('!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~') - if self.include_whitespace: - self._preprocess = lambda x: x - else: - self._preprocess = lambda x: re.sub(r'\s+', '', x).strip() + _SIGNATURE_TYPE = CHRFSignature + + def __init__(self, char_order: int = CHAR_ORDER, + word_order: int = WORD_ORDER, + beta: int = BETA, + lowercase: bool = False, + whitespace: bool = False, + eps_smoothing: bool = False, + references: Optional[Sequence[Sequence[str]]] = None): + """`CHRF` initializer.""" + super().__init__() + + self.beta = beta + self.char_order = char_order + self.word_order = word_order + self.order = self.char_order + self.word_order + self.lowercase = lowercase + self.whitespace = whitespace + self.eps_smoothing = eps_smoothing + + if references is not None: + # Pre-compute reference ngrams + self._ref_cache = self._cache_references(references) @staticmethod - def extract_char_ngrams(s: str, n: int) -> Counter: + def _get_match_statistics(hyp_ngrams: Counter, ref_ngrams: Counter) -> List[int]: + """Computes the match statistics between hypothesis and reference n-grams. + + :param hyp_ngrams: A `Counter` holding hypothesis n-grams. + :param ref_ngrams: A `Counter` holding reference n-grams. + :return: A list of three numbers denoting hypothesis n-gram count, + reference n-gram count and the intersection count. + """ + # Counter's internal intersection is not that fast, count manually + match_count, hyp_count = 0, 0 + for ng, count in hyp_ngrams.items(): + hyp_count += count + if ng in ref_ngrams: + match_count += min(count, ref_ngrams[ng]) + + return [ + # Don't count hits if no reference exists for that n-gram + hyp_count if ref_ngrams else 0, + sum(ref_ngrams.values()), + match_count, + ] + + def _remove_punctuation(self, sent: str) -> List[str]: + """Separates out punctuations from beginning and end of words for chrF. + Adapted from https://github.com/m-popovic/chrF + + :param sent: A string. + :return: A list of words. """ - Yields counts of character n-grams from string s of order n. + tokenized = [] + for w in sent.split(): + if len(w) == 1: + tokenized.append(w) + else: + # NOTE: This splits '(hi)' to '(hi' and ')' (issue #124) + if w[-1] in self._PUNCTS: + tokenized += [w[:-1], w[-1]] + elif w[0] in self._PUNCTS: + tokenized += [w[0], w[1:]] + else: + tokenized.append(w) + return tokenized + + def _preprocess_segment(self, sent: str) -> str: + """Given a sentence, apply optional lowercasing. + + :param sent: The input sentence string. + :return: The pre-processed output string. """ - return Counter([s[i:i + n] for i in range(len(s) - n + 1)]) + return sent.lower() if self.lowercase else sent - @staticmethod - def compute_chrf(statistics: List[int], - order: int, - beta: float) -> CHRFScore: + def _compute_f_score(self, statistics: List[int]) -> float: + """Compute the chrF score given the n-gram match statistics. + :param statistics: A flattened list of 3 * (`char_order` + `word_order`) + elements giving the [hyp, ref, match] counts for each order. + :return: The final f_beta score between [0, 100]. + """ + eps = 1e-16 score = 0.0 - avg_recall = 0.0 - avg_precision = 0.0 effective_order = 0 + factor = self.beta ** 2 + avg_prec, avg_rec = 0.0, 0.0 + + for i in range(self.order): + n_hyp, n_ref, n_match = statistics[3 * i: 3 * i + 3] + + # chrF++.py style EPS smoothing (also used by Moses and NLTK) + prec = n_match / n_hyp if n_hyp > 0 else eps + rec = n_match / n_ref if n_ref > 0 else eps + + denom = factor * prec + rec + score += ((1 + factor) * prec * rec / denom) if denom > 0 else eps - for i in range(order): - hypotheses_ngrams = statistics[3 * i + 0] - references_ngrams = statistics[3 * i + 1] - common_ngrams = statistics[3 * i + 2] - if hypotheses_ngrams > 0 and references_ngrams > 0: - avg_precision += common_ngrams / hypotheses_ngrams - avg_recall += common_ngrams / references_ngrams + # sacreBLEU <2.0.0 style effective order smoothing + if n_hyp > 0 and n_ref > 0: + avg_prec += prec + avg_rec += rec effective_order += 1 + if self.eps_smoothing: + return 100 * score / self.order + if effective_order == 0: - avg_precision, avg_recall = 0.0, 0.0 + avg_prec = avg_rec = 0.0 else: - avg_precision /= effective_order - avg_recall /= effective_order + avg_prec /= effective_order + avg_rec /= effective_order - if avg_precision + avg_recall == 0: - score = 0.0 + if avg_prec + avg_rec: + score = (1 + factor) * avg_prec * avg_rec + score /= ((factor * avg_prec) + avg_rec) + return 100 * score else: - beta_square = beta ** 2 - score = (1 + beta_square) * (avg_precision * avg_recall) - score /= ((beta_square * avg_precision) + avg_recall) - - return CHRFScore(score, beta, order) + return 0.0 - def get_sentence_statistics(self, hypothesis: str, - references: List[str]) -> List[int]: - # NOTE: multi-reference not supported yet - reference = references[0] + def _compute_score_from_stats(self, stats: List[int]) -> CHRFScore: + """Computes the final score from already aggregated statistics. - hypothesis = self._preprocess(hypothesis) - reference = self._preprocess(reference) - statistics = [0] * (self.order * 3) - for i in range(self.order): - n = i + 1 - hypothesis_ngrams = self.extract_char_ngrams(hypothesis, n) - reference_ngrams = self.extract_char_ngrams(reference, n) - common_ngrams = hypothesis_ngrams & reference_ngrams - statistics[3 * i + 0] = sum(hypothesis_ngrams.values()) - statistics[3 * i + 1] = sum(reference_ngrams.values()) - statistics[3 * i + 2] = sum(common_ngrams.values()) - return statistics - - def sentence_score(self, hypothesis: str, references: List[str]) -> CHRFScore: + :param stats: A list or numpy array of segment-level statistics. + :return: A `CHRFScore` object. """ - Computes ChrF on a single sentence pair. + return CHRFScore( + self._compute_f_score(stats), self.char_order, + self.word_order, self.beta) - :param hypothesis: Hypothesis string. - :param references: Reference string(s). - :return: Chrf score. - """ - assert not isinstance(references, str), \ - "sentence_score needs a list of references, not a single string" - stats = self.get_sentence_statistics(hypothesis, references) - return self.compute_chrf(stats, self.order, self.beta) + def _aggregate_and_compute(self, stats: List[List[int]]) -> CHRFScore: + """Computes the final score given the pre-computed corpus statistics. - def corpus_score(self, sys_stream: Union[str, Iterable[str]], - ref_streams: Union[str, List[Iterable[str]]]) -> CHRFScore: + :param stats: A list of segment-level statistics + :return: A `CHRFScore` object. """ - Computes Chrf on a corpus. + return self._compute_score_from_stats(sum_of_lists(stats)) + + def _extract_reference_info(self, refs: Sequence[str]) -> Dict[str, List[List[Counter]]]: + """Given a list of reference segments, extract the character and word n-grams. - :param hypotheses: Stream of hypotheses. - :param references: Stream of references. - :return: Chrf score. + :param refs: A sequence of reference segments. + :return: A list where each element contains n-grams per reference segment. """ + ngrams = [] - # Add some robustness to the input arguments - if isinstance(sys_stream, str): - sys_stream = [sys_stream] + for ref in refs: + # extract character n-grams + stats = extract_all_char_ngrams(ref, self.char_order, self.whitespace) - if isinstance(ref_streams, str): - ref_streams = [[ref_streams]] + # Check chrF+ mode + if self.word_order > 0: + ref_words = self._remove_punctuation(ref) - corpus_statistics = [0] * (self.order * 3) + for n in range(self.word_order): + stats.append(extract_word_ngrams(ref_words, n + 1)) - fhs = [sys_stream] + ref_streams - for lines in zip_longest(*fhs): - if None in lines: - raise EOFError("Source and reference streams have different lengths!") + ngrams.append(stats) - # Unpack - hypothesis, *refs = lines + return {'ref_ngrams': ngrams} - statistics = self.get_sentence_statistics(hypothesis, refs) - for i in range(len(statistics)): - corpus_statistics[i] += statistics[i] + def _compute_segment_statistics( + self, hypothesis: str, ref_kwargs: Dict) -> List[int]: + """Given a (pre-processed) hypothesis sentence and already computed + reference n-grams, returns the best match statistics across the + references. - return self.compute_chrf(corpus_statistics, self.order, self.beta) + :param hypothesis: Hypothesis sentence. + :param ref_kwargs: A dictionary with key `ref_ngrams` which is a list + where each sublist contains n-gram counters for a particular reference sentence. + :return: A list of integers where each triplet denotes [hyp, ref, match] + statistics. + """ + best_stats = [] + best_f_score = -1.0 + + # extract character n-grams + all_hyp_ngrams = extract_all_char_ngrams( + hypothesis, self.char_order, self.whitespace) + + # Check chrF+ mode to see if we'll add word n-grams as well + if self.word_order > 0: + # Primitive tokenization: separate out punctuations + hwords = self._remove_punctuation(hypothesis) + _range = range(1, self.word_order + 1) + all_hyp_ngrams.extend([extract_word_ngrams(hwords, n) for n in _range]) + + # Iterate over multiple references, pick the one with best F score + for _ref_ngrams in ref_kwargs['ref_ngrams']: + stats = [] + # Traverse all orders + for h, r in zip(all_hyp_ngrams, _ref_ngrams): + stats.extend(self._get_match_statistics(h, r)) + f_score = self._compute_f_score(stats) + + if f_score > best_f_score: + best_f_score = f_score + best_stats = stats + + return best_stats diff --git a/sacrebleu/metrics/helpers.py b/sacrebleu/metrics/helpers.py new file mode 100644 index 00000000..72ec1446 --- /dev/null +++ b/sacrebleu/metrics/helpers.py @@ -0,0 +1,69 @@ +"""Various utility functions for word and character n-gram extraction.""" + +from collections import Counter +from typing import List, Tuple + + +def extract_all_word_ngrams(line: str, min_order: int, max_order: int) -> Tuple[Counter, int]: + """Extracts all ngrams (min_order <= n <= max_order) from a sentence. + + :param line: A string sentence. + :param min_order: Minimum n-gram order. + :param max_order: Maximum n-gram order. + :return: a Counter object with n-grams counts and the sequence length. + """ + + ngrams = [] + tokens = line.split() + + for n in range(min_order, max_order + 1): + for i in range(0, len(tokens) - n + 1): + ngrams.append(tuple(tokens[i: i + n])) + + return Counter(ngrams), len(tokens) + + +def extract_word_ngrams(tokens: List[str], n: int) -> Counter: + """Extracts n-grams with order `n` from a list of tokens. + + :param tokens: A list of tokens. + :param n: The order of n-grams. + :return: a Counter object with n-grams counts. + """ + return Counter([' '.join(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]) + + +def extract_char_ngrams(line: str, n: int, include_whitespace: bool = False) -> Counter: + """Yields counts of character n-grams from a sentence. + + :param line: A segment containing a sequence of words. + :param n: The order of the n-grams. + :param include_whitespace: If given, will not strip whitespaces from the line. + :return: a dictionary containing ngrams and counts + """ + if not include_whitespace: + line = ''.join(line.split()) + + return Counter([line[i:i + n] for i in range(len(line) - n + 1)]) + + +def extract_all_char_ngrams( + line: str, max_order: int, include_whitespace: bool = False) -> List[Counter]: + """Extracts all character n-grams at once for convenience. + + :param line: A segment containing a sequence of words. + :param max_order: The maximum order of the n-grams. + :param include_whitespace: If given, will not strip whitespaces from the line. + :return: a list of Counter objects containing ngrams and counts. + """ + + counters = [] + + if not include_whitespace: + line = ''.join(line.split()) + + for n in range(1, max_order + 1): + ngrams = Counter([line[i:i + n] for i in range(len(line) - n + 1)]) + counters.append(ngrams) + + return counters diff --git a/sacrebleu/metrics/lib_ter.py b/sacrebleu/metrics/lib_ter.py new file mode 100644 index 00000000..2d2de494 --- /dev/null +++ b/sacrebleu/metrics/lib_ter.py @@ -0,0 +1,478 @@ +"""This module implements various utility functions for the TER metric.""" + +# Copyright 2020 Memsource +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import math +from typing import List, Tuple, Dict + + +_COST_INS = 1 +_COST_DEL = 1 +_COST_SUB = 1 + +# Tercom-inspired limits +_MAX_SHIFT_SIZE = 10 +_MAX_SHIFT_DIST = 50 +_BEAM_WIDTH = 25 + +# Our own limits +_MAX_CACHE_SIZE = 10000 +_MAX_SHIFT_CANDIDATES = 1000 +_INT_INFINITY = int(1e16) + +_OP_INS = 'i' +_OP_DEL = 'd' +_OP_NOP = ' ' +_OP_SUB = 's' +_OP_UNDEF = 'x' + +_FLIP_OPS = str.maketrans(_OP_INS + _OP_DEL, _OP_DEL + _OP_INS) + + +def translation_edit_rate(words_hyp: List[str], words_ref: List[str]) -> Tuple[int, int]: + """Calculate the translation edit rate. + + :param words_hyp: Tokenized translation hypothesis. + :param words_ref: Tokenized reference translation. + :return: tuple (number of edits, length) + """ + n_words_ref = len(words_ref) + n_words_hyp = len(words_hyp) + if n_words_ref == 0: + # FIXME: This trace here is not used? + trace = _OP_DEL * n_words_hyp + # special treatment of empty refs + return n_words_hyp, 0 + + cached_ed = BeamEditDistance(words_ref) + shifts = 0 + + input_words = words_hyp + checked_candidates = 0 + while True: + # do shifts until they stop reducing the edit distance + delta, new_input_words, checked_candidates = _shift( + input_words, words_ref, cached_ed, checked_candidates) + + if checked_candidates >= _MAX_SHIFT_CANDIDATES: + break + + if delta <= 0: + break + shifts += 1 + input_words = new_input_words + + edit_distance, trace = cached_ed(input_words) + total_edits = shifts + edit_distance + + return total_edits, n_words_ref + + +def _shift(words_h: List[str], words_r: List[str], cached_ed, + checked_candidates: int) -> Tuple[int, List[str], int]: + """Attempt to shift words in hypothesis to match reference. + + Returns the shift that reduces the edit distance the most. + + Note that the filtering of possible shifts and shift selection are heavily + based on somewhat arbitrary heuristics. The code here follows as closely + as possible the logic in Tercom, not always justifying the particular design + choices. + + :param words_h: Hypothesis. + :param words_r: Reference. + :param cached_ed: Cached edit distance. + :param checked_candidates: Number of shift candidates that were already + evaluated. + :return: (score, shifted_words, checked_candidates). Best shift and updated + number of evaluated shift candidates. + """ + pre_score, inv_trace = cached_ed(words_h) + + # to get alignment, we pretend we are rewriting reference into hypothesis, + # so we need to flip the trace of edit operations + trace = _flip_trace(inv_trace) + align, ref_err, hyp_err = trace_to_alignment(trace) + + best = None + + for start_h, start_r, length in _find_shifted_pairs(words_h, words_r): + # don't do the shift unless both the hypothesis was wrong and the + # reference doesn't match hypothesis at the target position + if sum(hyp_err[start_h: start_h + length]) == 0: + continue + + if sum(ref_err[start_r: start_r + length]) == 0: + continue + + # don't try to shift within the subsequence + if start_h <= align[start_r] < start_h + length: + continue + + prev_idx = -1 + for offset in range(-1, length): + if start_r + offset == -1: + idx = 0 # insert before the beginning + elif start_r + offset in align: + # Unlike Tercom which inserts *after* the index, we insert + # *before* the index. + idx = align[start_r + offset] + 1 + else: + break # offset is out of bounds => aims past reference + + if idx == prev_idx: + continue # skip idx if already tried + + prev_idx = idx + + shifted_words = _perform_shift(words_h, start_h, length, idx) + assert(len(shifted_words) == len(words_h)) + + # Elements of the tuple are designed to replicate Tercom ranking + # of shifts: + candidate = ( + pre_score - cached_ed(shifted_words)[0], # highest score first + length, # then, longest match first + -start_h, # then, earliest match first + -idx, # then, earliest target position first + shifted_words, + ) + + checked_candidates += 1 + + if not best or candidate > best: + best = candidate + + if checked_candidates >= _MAX_SHIFT_CANDIDATES: + break + + if not best: + return 0, words_h, checked_candidates + else: + best_score, _, _, _, shifted_words = best + return best_score, shifted_words, checked_candidates + + +def _perform_shift(words: List[str], start: int, length: int, target: int) -> List[str]: + """Perform a shift in `words` from `start` to `target`. + + :param words: Words to shift. + :param start: Where from. + :param length: How many words. + :param target: Where to. + :return: Shifted words. + """ + if target < start: + # shift before previous position + return words[:target] + words[start: start + length] \ + + words[target: start] + words[start + length:] + elif target > start + length: + # shift after previous position + return words[:start] + words[start + length: target] \ + + words[start: start + length] + words[target:] + else: + # shift within the shifted string + return words[:start] + words[start + length: length + target] \ + + words[start: start + length] + words[length + target:] + + +def _find_shifted_pairs(words_h: List[str], words_r: List[str]): + """Find matching word sub-sequences in two lists of words. + + Ignores sub-sequences starting at the same position. + + :param words_h: First word list. + :param words_r: Second word list. + :return: Yields tuples of (h_start, r_start, length) such that: + words_h[h_start:h_start+length] = words_r[r_start:r_start+length] + """ + n_words_h = len(words_h) + n_words_r = len(words_r) + for start_h in range(n_words_h): + for start_r in range(n_words_r): + # this is slightly different from what tercom does but this should + # really only kick in in degenerate cases + if abs(start_r - start_h) > _MAX_SHIFT_DIST: + continue + + length = 0 + while words_h[start_h + length] == words_r[start_r + length] and length < _MAX_SHIFT_SIZE: + length += 1 + + yield start_h, start_r, length + + # If one sequence is consumed, stop processing + if n_words_h == start_h + length or n_words_r == start_r + length: + break + + +def _flip_trace(trace): + """Flip the trace of edit operations. + + Instead of rewriting a->b, get a recipe for rewriting b->a. + + Simply flips insertions and deletions. + """ + return trace.translate(_FLIP_OPS) + + +def trace_to_alignment(trace: str) -> Tuple[Dict, List, List]: + """Transform trace of edit operations into an alignment of the sequences. + + :param trace: Trace of edit operations (' '=no change or 's'/'i'/'d'). + :return: Alignment, error positions in reference, error positions in hypothesis. + """ + pos_hyp = -1 + pos_ref = -1 + hyp_err = [] + ref_err = [] + align = {} + + # we are rewriting a into b + for op in trace: + if op == _OP_NOP: + pos_hyp += 1 + pos_ref += 1 + align[pos_ref] = pos_hyp + hyp_err.append(0) + ref_err.append(0) + elif op == _OP_SUB: + pos_hyp += 1 + pos_ref += 1 + align[pos_ref] = pos_hyp + hyp_err.append(1) + ref_err.append(1) + elif op == _OP_INS: + pos_hyp += 1 + hyp_err.append(1) + elif op == _OP_DEL: + pos_ref += 1 + align[pos_ref] = pos_hyp + ref_err.append(1) + else: + raise Exception(f"unknown operation {op!r}") + + return align, ref_err, hyp_err + + +class BeamEditDistance: + """Edit distance with several features required for TER calculation. + + * internal cache + * "beam" search + * tracking of edit operations + + The internal self._cache works like this: + + Keys are words of the hypothesis. Values are tuples (next_node, row) where: + + * next_node is the cache for the next word in the sequence + * row is the stored row of the edit distance matrix + + Effectively, caching allows to skip several rows in the edit distance + matrix calculation and instead, to initialize the computation with the last + matching matrix row. + + Beam search, as implemented here, only explores a fixed-size sub-row of + candidates around the matrix diagonal (more precisely, it's a + "pseudo"-diagonal since we take the ratio of sequence lengths into account). + + Tracking allows to reconstruct the optimal sequence of edit operations. + + :param words_ref: A list of reference tokens. + """ + def __init__(self, words_ref: List[str]): + """`BeamEditDistance` initializer.""" + self._words_ref = words_ref + self._n_words_ref = len(self._words_ref) + + # first row corresponds to insertion operations of the reference, + # so we do 1 edit operation per reference word + self._initial_row = [(i * _COST_INS, _OP_INS) + for i in range(self._n_words_ref + 1)] + + self._cache = {} # type: Dict[str, Tuple] + self._cache_size = 0 + + # Precomputed empty matrix row. Contains infinities so that beam search + # avoids using the uninitialized cells. + self._empty_row = [(_INT_INFINITY, _OP_UNDEF)] * (self._n_words_ref + 1) + + def __call__(self, words_hyp: List[str]) -> Tuple[int, str]: + """Calculate edit distance between self._words_ref and the hypothesis. + + Uses cache to skip some of the computation. + + :param words_hyp: Words in translation hypothesis. + :return: Edit distance score. + """ + + # skip initial words in the hypothesis for which we already know the + # edit distance + start_position, dist = self._find_cache(words_hyp) + + # calculate the rest of the edit distance matrix + edit_distance, newly_created_matrix, trace = self._edit_distance( + words_hyp, start_position, dist) + + # update our cache with the newly calculated rows + self._add_cache(words_hyp, newly_created_matrix) + + return edit_distance, trace + + def _edit_distance(self, words_h: List[str], start_h: int, + cache: List[List[Tuple[int, str]]]) -> Tuple[int, List, str]: + """Actual edit distance calculation. + + Can be initialized with the last cached row and a start position in + the hypothesis that it corresponds to. + + :param words_h: Words in translation hypothesis. + :param start_h: Position from which to start the calculation. + (This is zero if no cache match was found.) + :param cache: Precomputed rows corresponding to edit distance matrix + before `start_h`. + :return: Edit distance value, newly computed rows to update the + cache, trace. + """ + + n_words_h = len(words_h) + + # initialize the rest of the matrix with infinite edit distances + rest_empty = [list(self._empty_row) + for _ in range(n_words_h - start_h)] + + dist = cache + rest_empty + + assert len(dist) == n_words_h + 1 + + length_ratio = self._n_words_ref / n_words_h if words_h else 1 + + # in some crazy sentences, the difference in length is so large that + # we may end up with zero overlap with previous row + if _BEAM_WIDTH < length_ratio / 2: + beam_width = math.ceil(length_ratio / 2 + _BEAM_WIDTH) + else: + beam_width = _BEAM_WIDTH + + # calculate the Levenshtein distance + for i in range(start_h + 1, n_words_h + 1): + pseudo_diag = math.floor(i * length_ratio) + min_j = max(0, pseudo_diag - beam_width) + max_j = min(self._n_words_ref + 1, pseudo_diag + beam_width) + + if i == n_words_h: + max_j = self._n_words_ref + 1 + + for j in range(min_j, max_j): + if j == 0: + dist[i][j] = (dist[i - 1][j][0] + _COST_DEL, _OP_DEL) + else: + if words_h[i - 1] == self._words_ref[j - 1]: + cost_sub = 0 + op_sub = _OP_NOP + else: + cost_sub = _COST_SUB + op_sub = _OP_SUB + + # Tercom prefers no-op/sub, then insertion, then deletion. + # But since we flip the trace and compute the alignment from + # the inverse, we need to swap order of insertion and + # deletion in the preference. + ops = ( + (dist[i - 1][j - 1][0] + cost_sub, op_sub), + (dist[i - 1][j][0] + _COST_DEL, _OP_DEL), + (dist[i][j - 1][0] + _COST_INS, _OP_INS), + ) + + for op_cost, op_name in ops: + if dist[i][j][0] > op_cost: + dist[i][j] = op_cost, op_name + + # get the trace + trace = "" + i = n_words_h + j = self._n_words_ref + + while i > 0 or j > 0: + op = dist[i][j][1] + trace = op + trace + if op in (_OP_SUB, _OP_NOP): + i -= 1 + j -= 1 + elif op == _OP_INS: + j -= 1 + elif op == _OP_DEL: + i -= 1 + else: + raise Exception(f"unknown operation {op!r}") + + return dist[-1][-1][0], dist[len(cache):], trace + + def _add_cache(self, words_hyp: List[str], mat: List[List[Tuple]]): + """Add newly computed rows to cache. + + Since edit distance is only calculated on the hypothesis suffix that + was not in cache, the number of rows in `mat` may be shorter than + hypothesis length. In that case, we skip over these initial words. + + :param words_hyp: Hypothesis words. + :param mat: Edit distance matrix rows for each position. + """ + if self._cache_size >= _MAX_CACHE_SIZE: + return + + node = self._cache + + n_mat = len(mat) + + # how many initial words to skip + skip_num = len(words_hyp) - n_mat + + # jump through the cache to the current position + for i in range(skip_num): + node = node[words_hyp[i]][0] + + assert len(words_hyp[skip_num:]) == n_mat + + # update cache with newly computed rows + for word, row in zip(words_hyp[skip_num:], mat): + if word not in node: + node[word] = ({}, tuple(row)) + self._cache_size += 1 + value = node[word] + node = value[0] + + def _find_cache(self, words_hyp: List[str]) -> Tuple[int, List[List]]: + """Find the already computed rows of the edit distance matrix in cache. + + Returns a partially computed edit distance matrix. + + :param words_hyp: Translation hypothesis. + :return: Tuple (start position, dist). + """ + node = self._cache + start_position = 0 + dist = [self._initial_row] + for word in words_hyp: + if word in node: + start_position += 1 + node, row = node[word] + dist.append(row) + else: + break + + return start_position, dist diff --git a/sacrebleu/metrics/ter.py b/sacrebleu/metrics/ter.py index 5fe3ff10..3078656c 100644 --- a/sacrebleu/metrics/ter.py +++ b/sacrebleu/metrics/ter.py @@ -1,3 +1,5 @@ +"""The implementation of the TER metric (Snover et al., 2006).""" + # Copyright 2020 Memsource # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,578 +15,167 @@ # limitations under the License. -import math -from typing import List, Tuple, Dict, Union, Iterable -from itertools import zip_longest +from typing import List, Dict, Sequence, Optional, Any from ..tokenizers.tokenizer_ter import TercomTokenizer -from .base import BaseScore, Signature - -# Translation edit rate (TER). -# -# A near-exact reimplementation of the Tercom algorithm, produces identical -# results on all "sane" outputs. -# -# The beam edit distance algorithm uses a slightly different approach (we stay -# around the diagonal which is faster, at least in Python) so in some -# (extreme) corner cases, the output could differ. -# -# Tercom original implementation: -# -# https://github.com/jhclark/tercom -# -# Caching in the edit distance is based partly on the PyTer package by Hiroyuki -# Tanaka (MIT license). -# -# https://github.com/aflc/pyter - -_COST_INS = 1 -_COST_DEL = 1 -_COST_SUB = 1 - -# Tercom-inspired limits -_MAX_SHIFT_SIZE = 10 -_MAX_SHIFT_DIST = 50 -_BEAM_WIDTH = 25 - -# Our own limits -_MAX_CACHE_SIZE = 10000 -_MAX_SHIFT_CANDIDATES = 1000 -_INT_INFINITY = int(1e16) - -_OP_INS = 'i' -_OP_DEL = 'd' -_OP_NOP = ' ' -_OP_SUB = 's' -_OP_UNDEF = 'x' - - -class TERScore(BaseScore): - def __init__(self, num_edits, ref_length): - score = num_edits / ref_length if ref_length > 0 else 1 - super().__init__(score) - - self.num_edits = num_edits - self.ref_length = ref_length - self.prefix = 'TER' - - def format(self, width=2, score_only=False, signature=''): - # the default width of 1 is too small for TER, it's reported with 3 - # decimal places in matrix.statmt.org - if width == 1: - width = 3 - - if score_only: - return '{0:.{1}f}'.format(self.score, width) - - prefix = "{}+{}".format(self.prefix, signature) if signature else self.prefix - return '{pr} = {sc:.{w}f}'.format(pr=prefix, sc=self.score, w=width) +from ..utils import sum_of_lists +from .base import Score, Signature, Metric +from .lib_ter import translation_edit_rate class TERSignature(Signature): - def __init__(self, args): + """A convenience class to represent the reproducibility signature for TER. + + :param args: key-value dictionary passed from the actual metric instance. + """ + def __init__(self, args: dict): + """`TERSignature` initializer.""" super().__init__(args) self._abbr.update({ + 'case': 'c', 'tok': 't', + 'norm': 'nr', + 'punct': 'pn', + 'asian': 'as', }) self.info.update({ - 'tok': TER.create_tokenizer(args).signature(), + 'case': 'mixed' if args['case_sensitive'] else 'lc', + 'tok': args['tokenizer_signature'], + 'norm': args['normalized'], + 'punct': not args['no_punct'], + 'asian': args['asian_support'], }) -class TER: - TOKENIZER_DEFAULTS = { - "normalized": False, - "no_punct": False, - "asian_support": False, - "case_sensitive": False, - } - - @staticmethod - def create_tokenizer(args): - # hackish workaround for specifying tokenizer config - config = dict(TER.TOKENIZER_DEFAULTS) - args_vars = vars(args) - for k in config: - if k in args_vars: - config[k] = args_vars[k] - return TercomTokenizer(**config) - - def __init__(self, args): - self.tokenizer = self.create_tokenizer(args) - self.signature = TERSignature(args) - - def corpus_score(self, sys_stream: Union[str, Iterable[str]], - ref_streams: Union[str, List[Iterable[str]]]) -> TERScore: - # Add some robustness to the input arguments - if isinstance(sys_stream, str): - sys_stream = [sys_stream] - - if isinstance(ref_streams, str): - ref_streams = [[ref_streams]] - - fhs = [sys_stream] + ref_streams - - total_edits = 0 - sum_ref_lengths = 0.0 - - for lines in zip_longest(*fhs): - if None in lines: - raise EOFError("Source and reference streams have different lengths!") - hypo, *refs = lines - - words_hyp = self.tokenizer(hypo).split() - - best_num_edits = _INT_INFINITY - ref_lengths = 0 - - for ref in refs: - words_ref = self.tokenizer(ref).split() - num_edits, ref_len = translation_edit_rate(words_hyp, words_ref) - ref_lengths += ref_len - if num_edits < best_num_edits: - best_num_edits = num_edits - - total_edits += best_num_edits - sum_ref_lengths += (ref_lengths / len(refs)) - - return TERScore(total_edits, sum_ref_lengths) - - def sentence_score(self, hypothesis: str, references: List[str]) -> TERScore: - return self.corpus_score(hypothesis, [[ref] for ref in references]) - +class TERScore(Score): + """A convenience class to represent TER scores. -def translation_edit_rate(words_hyp: List[str], words_ref: List[str]) -> Tuple[int, int]: - """Calculate the translation edit rate. - - :param words_hyp: Tokenized translation hypothesis. - :param words_ref: Tokenized reference translation. - :return: tuple (number of edits, length) + :param score: The TER score. + :param num_edits: The cumulative number of edits. + :param ref_length: The cumulative average reference length. """ - if len(words_ref) == 0: - trace = _OP_DEL * len(words_hyp) - # special treatment of empty refs - return len(words_hyp), 0 - - cached_ed = BeamEditDistance(words_ref) - shifts = 0 - - input_words = words_hyp - checked_candidates = 0 - while True: - # do shifts until they stop reducing the edit distance - delta, new_input_words, checked_candidates = _shift( - input_words, words_ref, cached_ed, checked_candidates) - - if checked_candidates >= _MAX_SHIFT_CANDIDATES: - break - - if delta <= 0: - break - shifts += 1 - input_words = new_input_words - - edit_distance, trace = cached_ed(input_words) - total_edits = shifts + edit_distance - - return total_edits, len(words_ref) - - -def _shift(words_h: List[str], words_r: List[str], cached_ed, - checked_candidates: int) -> Tuple[int, List[str], int]: - """Attempt to shift words in hypothesis to match reference. - - Returns the shift that reduces the edit distance the most. - - Note that the filtering of possible shifts and shift selection are heavily - based on somewhat arbitrary heuristics. The code here follows as closely - as possible the logic in Tercom, not always justifying the particular design - choices. - - :param words_h: Hypothesis. - :param words_r: Reference. - :param cached_ed: Cached edit distance. - :param checked_candidates: Number of shift candidates that were already - evaluated. - :return: (score, shifted_words, checked_candidates). Best shift and updated - number of evaluated shift candidates. - """ - pre_score, inv_trace = cached_ed(words_h) - - # to get alignment, we pretend we are rewriting reference into hypothesis, - # so we need to flip the trace of edit operations - trace = _flip_trace(inv_trace) - align, ref_err, hyp_err = trace_to_alignment(trace) - - best = None - - for start_h, start_r, length in _find_shifted_pairs(words_h, words_r): - # don't do the shift unless both the hypothesis was wrong and the - # reference doesn't match hypothesis at the target position - if sum(hyp_err[start_h:start_h+length]) == 0: - continue - - if sum(ref_err[start_r:start_r+length]) == 0: - continue - - # don't try to shift within the subsequence - if start_h <= align[start_r] < start_h + length: - continue - - prev_idx = -1 - for offset in range(-1, length): - if start_r + offset == -1: - idx = 0 # insert before the beginning - elif start_r + offset in align: - # Unlike Tercom which inserts *after* the index, we insert - # *before* the index. - idx = align[start_r + offset] + 1 - else: - break # offset is out of bounds => aims past reference - - if idx == prev_idx: - continue # skip idx if already tried - - prev_idx = idx - - shifted_words = _perform_shift(words_h, start_h, length, idx) - assert(len(shifted_words) == len(words_h)) - - # Elements of the tuple are designed to replicate Tercom ranking - # of shifts: - candidate = ( - pre_score - cached_ed(shifted_words)[0], # highest score first - length, # then, longest match first - -start_h, # then, earliest match first - -idx, # then, earliest target position first - shifted_words, - ) - - checked_candidates += 1 - - if not best or candidate > best: - best = candidate - - if checked_candidates >= _MAX_SHIFT_CANDIDATES: - break - - if not best: - return 0, words_h, checked_candidates - else: - best_score, _, _, _, shifted_words = best - return best_score, shifted_words, checked_candidates - - -def _perform_shift(words: List[str], start: int, length: int, target: int) -> List[str]: - """Perform a shift in `words` from `start` to `target`. - - :param words: Words to shift. - :param start: Where from. - :param length: How many words. - :param target: Where to. - :return: Shifted words. - """ - if target < start: - # shift before previous position - return (words[:target] + words[start:start+length] - + words[target:start] + words[start+length:]) - elif target > start + length: - # shift after previous position - return (words[:start] + words[start+length:target] - + words[start:start+length] + words[target:]) - else: - # shift within the shifted string - return (words[:start] + words[start+length:length+target] - + words[start:start+length] + words[length+target:]) - - -def _find_shifted_pairs(words_h: List[str], words_r: List[str]): - """Find matching word sub-sequences in two lists of words. - - Ignores sub-sequences starting at the same position. - - :param words_h: First word list. - :param words_r: Second word list. - :return: Yields tuples of (h_start, r_start, length) such that: - - words_h[h_start:h_start+length] = words_r[r_start:r_start+length] - """ - for start_h in range(len(words_h)): - for start_r in range(len(words_r)): - # this is slightly different from what tercom does but this should - # really only kick in in degenerate cases - if abs(start_r - start_h) > _MAX_SHIFT_DIST: - continue - - length = 0 - while (words_h[start_h + length] == words_r[start_r + length] - and length < _MAX_SHIFT_SIZE): - length += 1 - - if length != 0: - yield start_h, start_r, length - - if ((len(words_h) == start_h + length) - or (len(words_r) == start_r + length)): - break + def __init__(self, score: float, num_edits: float, ref_length: float): + """`TERScore` initializer.""" + super().__init__('TER', score) + self.num_edits = int(num_edits) + self.ref_length = ref_length -def _flip_trace(trace): - """Flip the trace of edit operations. +class TER(Metric): + """Translation edit rate (TER). A near-exact reimplementation of the Tercom + algorithm, produces identical results on all "sane" outputs. - Instead of rewriting a->b, get a recipe for rewriting b->a. + Tercom original implementation: https://github.com/jhclark/tercom - Simply flips insertions and deletions. - """ - ret = list(trace) - for i in range(len(ret)): - if ret[i] == _OP_INS: - ret[i] = _OP_DEL - elif ret[i] == _OP_DEL: - ret[i] = _OP_INS - return ''.join(ret) + The beam edit distance algorithm uses a slightly different approach (we stay + around the diagonal which is faster, at least in Python) so in some + (extreme) corner cases, the output could differ. + Caching in the edit distance is based partly on the PyTer package by Hiroyuki + Tanaka (MIT license). (https://github.com/aflc/pyter) -def trace_to_alignment(trace: str) -> Tuple[Dict, List, List]: - """Transform trace of edit operations into an alignment of the sequences. - - :param trace: Trace of edit operations (' '=no change or 's'/'i'/'d'). - :return: Alignment, error positions in reference, error positions in hypothesis. - """ - pos_hyp = -1 - pos_ref = -1 - hyp_err = [] - ref_err = [] - align = {} - - # we are rewriting a into b - for op in trace: - if op == _OP_NOP: - pos_hyp += 1 - pos_ref += 1 - align[pos_ref] = pos_hyp - hyp_err.append(0) - ref_err.append(0) - elif op == _OP_SUB: - pos_hyp += 1 - pos_ref += 1 - align[pos_ref] = pos_hyp - hyp_err.append(1) - ref_err.append(1) - elif op == _OP_INS: - pos_hyp += 1 - hyp_err.append(1) - elif op == _OP_DEL: - pos_ref += 1 - align[pos_ref] = pos_hyp - ref_err.append(1) - else: - raise Exception("unknown operation '{}'".format(op)) - - return align, ref_err, hyp_err - - -class BeamEditDistance: - """Edit distance with several features required for TER calculation. - - * internal cache - * "beam" search - * tracking of edit operations - - The internal self._cache works like this: - - Keys are words of the hypothesis. Values are tuples (next_node, row) where: - - * next_node is the cache for the next word in the sequence - * row is the stored row of the edit distance matrix - - Effectively, caching allows to skip several rows in the edit distance - matrix calculation and instead, to initialize the computation with the last - matching matrix row. - - Beam search, as implemented here, only explores a fixed-size sub-row of - candidates around the matrix diagonal (more precisely, it's a - "pseudo"-diagonal since we take the ratio of sequence lengths into account). - - Tracking allows to reconstruct the optimal sequence of edit operations. + :param normalized: If `True`, applies basic tokenization to sentences. + :param no_punct: If `True`, removes punctuations from sentences. + :param asian_support: If `True`, adds support for Asian character processing. + :param case_sensitive: If `True`, does not lowercase sentences. + :param references: A sequence of reference documents with document being + defined as a sequence of reference strings. If given, the reference info + will be pre-computed and cached for faster re-computation across many systems. """ - def __init__(self, words_ref: List[str]): - self._words_ref = words_ref - - # first row corresponds to insertion operations of the reference, - # so we do 1 edit operation per reference word - self._initial_row = [(i * _COST_INS, _OP_INS) - for i in range(len(self._words_ref) + 1)] - - self._cache = {} # type: Dict[str, Tuple] - self._cache_size = 0 - - # Precomputed empty matrix row. Contains infinities so that beam search - # avoids using the uninitialized cells. - self._empty_row = [(_INT_INFINITY, _OP_UNDEF)] * (len(self._words_ref) + 1) - - def __call__(self, words_hyp: List[str]) -> Tuple[int, str]: - """Calculate edit distance between self._words_ref and the hypothesis. - Uses cache to skip some of the computation. - - :param words_hyp: Words in translation hypothesis. - :return: Edit distance score. + _SIGNATURE_TYPE = TERSignature + + def __init__(self, normalized: bool = False, + no_punct: bool = False, + asian_support: bool = False, + case_sensitive: bool = False, + references: Optional[Sequence[Sequence[str]]] = None): + """`TER` initializer.""" + super().__init__() + + self.no_punct = no_punct + self.normalized = normalized + self.asian_support = asian_support + self.case_sensitive = case_sensitive + + self.tokenizer = TercomTokenizer( + normalized=self.normalized, + no_punct=self.no_punct, + asian_support=self.asian_support, + case_sensitive=self.case_sensitive, + ) + self.tokenizer_signature = self.tokenizer.signature() + + if references is not None: + self._ref_cache = self._cache_references(references) + + def _preprocess_segment(self, sent: str) -> str: + """Given a sentence, apply tokenization if enabled. + + :param sent: The input sentence string. + :return: The pre-processed output string. """ + return self.tokenizer(sent.rstrip()) - # skip initial words in the hypothesis for which we already know the - # edit distance - start_position, dist = self._find_cache(words_hyp) - - # calculate the rest of the edit distance matrix - edit_distance, newly_created_matrix, trace = self._edit_distance( - words_hyp, start_position, dist) - - # update our cache with the newly calculated rows - self._add_cache(words_hyp, newly_created_matrix) - - return edit_distance, trace + def _compute_score_from_stats(self, stats: List[float]) -> TERScore: + """Computes the final score from already aggregated statistics. - def _edit_distance(self, words_h: List[str], start_h: int, - cache: List[List[Tuple[int, str]]]) -> Tuple[int, List, str]: - """Actual edit distance calculation. + :param stats: A list or numpy array of segment-level statistics. + :return: A `TERScore` object. + """ + total_edits, sum_ref_lengths = stats[0], stats[1] + score = total_edits / sum_ref_lengths if sum_ref_lengths > 0 else 1 + return TERScore(100 * score, total_edits, sum_ref_lengths) - Can be initialized with the last cached row and a start position in - the hypothesis that it corresponds to. + def _aggregate_and_compute(self, stats: List[List[float]]) -> TERScore: + """Computes the final TER score given the pre-computed corpus statistics. - :param words_h: Words in translation hypothesis. - :param start_h: Position from which to start the calculation. - (This is zero if no cache match was found.) - :param cache: Precomputed rows corresponding to edit distance matrix - before `start_h`. - :return: Edit distance value, newly computed rows to update the - cache, trace. + :param stats: A list of segment-level statistics + :return: A `TERScore` instance. """ - - # initialize the rest of the matrix with infinite edit distances - rest_empty = [list(self._empty_row) - for _ in range(len(words_h) - start_h)] - - dist = cache + rest_empty - - assert len(dist) == len(words_h) + 1 - - if words_h: - length_ratio = len(self._words_ref) / len(words_h) - else: - length_ratio = 1 - - # in some crazy sentences, the difference in length is so large that - # we may end up with zero overlap with previous row - if _BEAM_WIDTH < length_ratio / 2: - beam_width = math.ceil(length_ratio / 2 + _BEAM_WIDTH) - else: - beam_width = _BEAM_WIDTH - - # calculate the Levenshtein distance - for i in range(start_h + 1, len(words_h) + 1): - pseudo_diag = math.floor(i * length_ratio) - min_j = max(0, pseudo_diag - beam_width) - max_j = min(len(self._words_ref) + 1, pseudo_diag + beam_width) - - if i == len(words_h): - max_j = len(self._words_ref) + 1 - - for j in range(min_j, max_j): - if j == 0: - dist[i][j] = (dist[i - 1][j][0] + _COST_DEL, _OP_DEL) - else: - if words_h[i - 1] == self._words_ref[j - 1]: - cost_sub = 0 - op_sub = _OP_NOP - else: - cost_sub = _COST_SUB - op_sub = _OP_SUB - - # Tercom prefers no-op/sub, then insertion, then deletion. - # But since we flip the trace and compute the alignment from - # the inverse, we need to swap order of insertion and - # deletion in the preference. - ops = ( - (dist[i - 1][j - 1][0] + cost_sub, op_sub), - (dist[i - 1][j][0] + _COST_DEL, _OP_DEL), - (dist[i][j - 1][0] + _COST_INS, _OP_INS), - ) - - for op_cost, op_name in ops: - if dist[i][j][0] > op_cost: - dist[i][j] = op_cost, op_name - - # get the trace - trace = "" - i = len(words_h) - j = len(self._words_ref) - - while i > 0 or j > 0: - op = dist[i][j][1] - trace = op + trace - if op in (_OP_SUB, _OP_NOP): - i -= 1 - j -= 1 - elif op == _OP_INS: - j -= 1 - elif op == _OP_DEL: - i -= 1 - else: - raise Exception("unknown operation '{}'".format(op)) - - return dist[-1][-1][0], dist[len(cache):], trace - - def _add_cache(self, words_hyp: List[str], mat: List[List[Tuple]]): - """Add newly computed rows to cache. - - Since edit distance is only calculated on the hypothesis suffix that - was not in cache, the number of rows in `mat` may be shorter than - hypothesis length. In that case, we skip over these initial words. - - :param words_hyp: Hypothesis words. - :param mat: Edit distance matrix rows for each position. + return self._compute_score_from_stats(sum_of_lists(stats)) + + def _compute_segment_statistics( + self, hypothesis: str, ref_kwargs: Dict) -> List[float]: + """Given a (pre-processed) hypothesis sentence and already computed + reference words, returns the segment statistics required to compute + the full TER score. + + :param hypothesis: Hypothesis sentence. + :param ref_kwargs: A dictionary with `ref_words` key which is a list + where each sublist contains reference words. + :return: A two-element list that contains the 'minimum number of edits' + and 'the average reference length'. """ - if self._cache_size >= _MAX_CACHE_SIZE: - return - node = self._cache + ref_lengths = 0 + best_num_edits = int(1e16) - # how many initial words to skip - skip_num = len(words_hyp) - len(mat) + words_hyp = hypothesis.split() - # jump through the cache to the current position - for i in range(skip_num): - node = node[words_hyp[i]][0] + # Iterate the references + ref_words = ref_kwargs['ref_words'] + for words_ref in ref_words: + num_edits, ref_len = translation_edit_rate(words_hyp, words_ref) + ref_lengths += ref_len + if num_edits < best_num_edits: + best_num_edits = num_edits - assert len(words_hyp[skip_num:]) == len(mat) + avg_ref_len = ref_lengths / len(ref_words) + return [best_num_edits, avg_ref_len] - # update cache with newly computed rows - for word, row in zip(words_hyp[skip_num:], mat): - if word not in node: - node[word] = ({}, tuple(row)) - self._cache_size += 1 - value = node[word] - node = value[0] + def _extract_reference_info(self, refs: Sequence[str]) -> Dict[str, Any]: + """Given a list of reference segments, applies pre-processing & tokenization + and returns list of tokens for each reference. - def _find_cache(self, words_hyp: List[str]) -> Tuple[int, List[List]]: - """Find the already computed rows of the edit distance matrix in cache. + :param refs: A sequence of strings. + :return: A dictionary that will be passed to `_compute_segment_statistics()` + through keyword arguments. + """ + ref_words = [] - Returns a partially computed edit distance matrix. + for ref in refs: + ref_words.append(self._preprocess_segment(ref).split()) - :param words_hyp: Translation hypothesis. - :return: Tuple (start position, dist). - """ - node = self._cache - start_position = 0 - dist = [self._initial_row] - for word in words_hyp: - if word in node: - start_position += 1 - node, row = node[word] - dist.append(row) - else: - break - - return start_position, dist + return {'ref_words': ref_words} diff --git a/sacrebleu/sacrebleu.py b/sacrebleu/sacrebleu.py index 465d387a..f52911e4 100755 --- a/sacrebleu/sacrebleu.py +++ b/sacrebleu/sacrebleu.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 -# -*- coding: utf-8 -*- # Copyright 2017--2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. # @@ -23,10 +22,13 @@ """ import io +import os import sys import logging import pathlib import argparse +from collections import defaultdict + # Allows calling the script as a standalone utility # See: https://github.com/mjpost/sacrebleu/issues/86 @@ -35,13 +37,13 @@ sys.path.insert(0, str(parent)) __package__ = 'sacrebleu' -from .tokenizers import TOKENIZERS, DEFAULT_TOKENIZER -from .dataset import DATASETS, DOMAINS, COUNTRIES, SUBSETS +from .dataset import DATASETS from .metrics import METRICS +from .utils import smart_open, filter_subset, get_langpairs_for_testset, get_available_testsets +from .utils import print_test_set, print_subset_results, get_reference_files, download_test_set +from .utils import args_to_dict, sanity_check_lengths, print_results_table, print_single_results +from .utils import Color -from .utils import smart_open, filter_subset, get_available_origlangs, SACREBLEU_DIR -from .utils import get_langpairs_for_testset, get_available_testsets -from .utils import print_test_set, get_reference_files, download_test_set from . import __version__ as VERSION sacrelogger = logging.getLogger('sacrebleu') @@ -53,9 +55,8 @@ # If SIGPIPE is available, change behaviour to default instead of ignore. from signal import signal, SIG_DFL signal(SIGPIPE, SIG_DFL) - except ImportError: - sacrelogger.warning('Could not import signal.SIGPIPE (this is expected on Windows machines)') + pass def parse_args(): @@ -66,82 +67,168 @@ def parse_args(): formatter_class=argparse.RawDescriptionHelpFormatter) arg_parser.add_argument('--citation', '--cite', default=False, action='store_true', - help='dump the bibtex citation and quit.') + help='Dump the bibtex citation and quit.') arg_parser.add_argument('--list', default=False, action='store_true', - help='print a list of all available test sets.') + help='Print a list of all available test sets.') arg_parser.add_argument('--test-set', '-t', type=str, default=None, - help='the test set to use (see also --list) or a comma-separated list of test sets to be concatenated') + help='The test set to use (see also --list) or a comma-separated list of test sets to be concatenated.') arg_parser.add_argument('--language-pair', '-l', dest='langpair', default=None, - help='source-target language pair (2-char ISO639-1 codes)') + help='Source-target language pair (2-char ISO639-1 codes).') arg_parser.add_argument('--origlang', '-ol', dest='origlang', default=None, - help='use a subset of sentences with a given original language (2-char ISO639-1 codes), "non-" prefix means negation') + help='Use a subset of sentences with a given original language (2-char ISO639-1 codes), "non-" prefix means negation.') arg_parser.add_argument('--subset', dest='subset', default=None, - help='use a subset of sentences whose document annotation matches a give regex (see SUBSETS in the source code)') + help='Use a subset of sentences whose document annotation matches a given regex (see SUBSETS in the source code).') arg_parser.add_argument('--download', type=str, default=None, - help='download a test set and quit') + help='Download a test set and quit.') arg_parser.add_argument('--echo', choices=['src', 'ref', 'both'], type=str, default=None, - help='output the source (src), reference (ref), or both (both, pasted) to STDOUT and quit') + help='Output the source (src), reference (ref), or both (both, pasted) to STDOUT and quit.') # I/O related arguments - arg_parser.add_argument('--input', '-i', type=str, default='-', - help='Read input from a file instead of STDIN') + # Multiple input files can be provided for significance testing for example + arg_parser.add_argument('--input', '-i', type=str, nargs='*', default=None, + help='Read input from file(s) instead of STDIN.') arg_parser.add_argument('refs', nargs='*', default=[], - help='optional list of references (for backwards-compatibility with older scripts)') + help='Optional list of references. If given, it should preceed the -i/--input argument.') arg_parser.add_argument('--num-refs', '-nr', type=int, default=1, - help='Split the reference stream on tabs, and expect this many references. Default: %(default)s.') + help='Split the reference stream on tabs, and expect this many references. (Default: %(default)s)') arg_parser.add_argument('--encoding', '-e', type=str, default='utf-8', - help='open text files with specified encoding (default: %(default)s)') + help='Open text files with specified encoding (Default: %(default)s)') # Metric selection - arg_parser.add_argument('--metrics', '-m', choices=METRICS.keys(), nargs='+', default=['bleu'], - help='metrics to compute (default: bleu)') - arg_parser.add_argument('--sentence-level', '-sl', action='store_true', help='Output metric on each sentence.') + avail_metrics = [m.lower() for m in METRICS] + arg_parser.add_argument('--metrics', '-m', choices=avail_metrics, nargs='+', default=['bleu'], + help='Space-delimited list of metrics to compute (Default: bleu)') + arg_parser.add_argument('--sentence-level', '-sl', action='store_true', help='Compute metric for each sentence.') # BLEU-related arguments - arg_parser.add_argument('-lc', action='store_true', default=False, help='Use case-insensitive BLEU (default: False)') - arg_parser.add_argument('--smooth-method', '-s', choices=METRICS['bleu'].SMOOTH_DEFAULTS.keys(), default='exp', - help='smoothing method: exponential decay (default), floor (increment zero counts), add-k (increment num/denom by k for n>1), or none') - arg_parser.add_argument('--smooth-value', '-sv', type=float, default=None, - help='The value to pass to the smoothing technique, only used for floor and add-k. Default floor: {}, add-k: {}.'.format( - METRICS['bleu'].SMOOTH_DEFAULTS['floor'], METRICS['bleu'].SMOOTH_DEFAULTS['add-k'])) - arg_parser.add_argument('--tokenize', '-tok', choices=TOKENIZERS.keys(), default=None, - help='Tokenization method to use for BLEU. If not provided, defaults to `zh` for Chinese, `mecab` for Japanese and `mteval-v13a` otherwise.') - arg_parser.add_argument('--force', default=False, action='store_true', - help='insist that your tokenized input is actually detokenized') + # since sacreBLEU had only support for BLEU initially, the argument names + # are not prefixed with 'bleu' as in chrF arguments for example. + # Let's do that manually here through dest= options, as otherwise + # things will get quite hard to maintain when other metrics are added. + bleu_args = arg_parser.add_argument_group('BLEU related arguments') + + bleu_args.add_argument('--smooth-method', '-s', choices=METRICS['BLEU'].SMOOTH_DEFAULTS.keys(), default='exp', + dest='bleu_smooth_method', + help='Smoothing method: exponential decay, floor (increment zero counts), add-k (increment num/denom by k for n>1), or none. (Default: %(default)s)') + bleu_args.add_argument('--smooth-value', '-sv', type=float, default=None, + dest='bleu_smooth_value', + help='The smoothing value. Only valid for floor and add-k. ' + f"(Defaults: floor: {METRICS['BLEU'].SMOOTH_DEFAULTS['floor']}, " + f"add-k: {METRICS['BLEU'].SMOOTH_DEFAULTS['add-k']})") + bleu_args.add_argument('--tokenize', '-tok', choices=METRICS['BLEU'].TOKENIZERS, default=None, + dest='bleu_tokenize', + help='Tokenization method to use for BLEU. If not provided, defaults to `zh` for Chinese, `ja-mecab` for Japanese and `13a` (mteval) otherwise.') + bleu_args.add_argument('--lowercase', '-lc', dest='bleu_lowercase', action='store_true', default=False, + help='If True, enables case-insensitivity. (Default: %(default)s)') + bleu_args.add_argument('--force', default=False, action='store_true', + dest='bleu_force', help='Insist that your tokenized input is actually detokenized.') # ChrF-related arguments - arg_parser.add_argument('--chrf-order', type=int, default=METRICS['chrf'].ORDER, - help='chrf character order (default: %(default)s)') - arg_parser.add_argument('--chrf-beta', type=int, default=METRICS['chrf'].BETA, - help='chrf BETA parameter (default: %(default)s)') - arg_parser.add_argument('--chrf-whitespace', action='store_true', default=False, - help='include whitespace in chrF calculation (default: %(default)s)') + chrf_args = arg_parser.add_argument_group('chrF related arguments') + chrf_args.add_argument('--chrf-char-order', '-cc', type=int, default=METRICS['CHRF'].CHAR_ORDER, + help='Character n-gram order. (Default: %(default)s)') + chrf_args.add_argument('--chrf-word-order', '-cw', type=int, default=METRICS['CHRF'].WORD_ORDER, + help='Word n-gram order (Default: %(default)s). If equals to 2, the metric is referred to as chrF++.') + chrf_args.add_argument('--chrf-beta', type=int, default=METRICS['CHRF'].BETA, + help='Determine the importance of recall w.r.t precision. (Default: %(default)s)') + chrf_args.add_argument('--chrf-whitespace', action='store_true', default=False, + help='Include whitespaces when extracting character n-grams. (Default: %(default)s)') + chrf_args.add_argument('--chrf-lowercase', action='store_true', default=False, + help='Enable case-insensitivity. (Default: %(default)s)') + chrf_args.add_argument('--chrf-eps-smoothing', action='store_true', default=False, + help='Enables epsilon smoothing similar to chrF++.py, NLTK and Moses; instead of effective order smoothing. (Default: %(default)s)') + + # TER related arguments + ter_args = arg_parser.add_argument_group("TER related arguments (The defaults replicate TERCOM's behavior)") + ter_args.add_argument('--ter-case-sensitive', action='store_true', + help='Enables case sensitivity. (Default: %(default)s)') + ter_args.add_argument('--ter-asian-support', action='store_true', + help='Enables special treatment of Asian characters. (Default: %(default)s)') + ter_args.add_argument('--ter-no-punct', action='store_true', + help='Removes punctuation. (Default: %(default)s)') + ter_args.add_argument('--ter-normalized', action='store_true', + help='Applies basic normalization and tokenization. (Default: %(default)s)') + + # Bootstrap resampling for confidence intervals + sign_args = arg_parser.add_argument_group('Confidence interval (CI) estimation for single-system evaluation') + sign_args.add_argument('--confidence', '-ci', action='store_true', + help='Report confidence interval using bootstrap resampling.') + sign_args.add_argument('--confidence-n', '-cin', type=int, default=1000, + help='Set the number of bootstrap resamples for CI estimation (Default: %(default)s).') + + # Paired significance testing + pair_args = arg_parser.add_argument_group('Paired significance testing for multi-system evaluation') + pair_args_choice = pair_args.add_mutually_exclusive_group() + + pair_args_choice.add_argument('--paired-ar', '-par', action='store_true', + help='Perform paired test using approximate randomization (AR). This option is ' + 'mutually exclusive with --paired-bs (Default: %(default)s).') + pair_args_choice.add_argument('--paired-bs', '-pbs', action='store_true', + help='Perform paired test using bootstrap resampling. This option is ' + 'mutually exclusive with --paired-ar (Default: %(default)s).') + + pair_args.add_argument('--paired-ar-n', '-parn', type=int, default=10000, + help='Number of trials for approximate randomization test (Default: %(default)s).') + + pair_args.add_argument('--paired-bs-n', '-pbsn', type=int, default=1000, + help='Number of bootstrap resamples for paired bootstrap resampling test (Default: %(default)s).') + + pair_args.add_argument('--paired-jobs', '-j', type=int, default=1, + help='If 0, launches as many workers as the number of systems. If > 0, sets the number of workers manually. ' + 'This feature is currently not supported on Windows.') # Reporting related arguments - arg_parser.add_argument('--quiet', '-q', default=False, action='store_true', - help='suppress informative output') - arg_parser.add_argument('--short', default=False, action='store_true', - help='produce a shorter (less human readable) signature') - arg_parser.add_argument('--score-only', '-b', default=False, action='store_true', - help='output only the BLEU score') - arg_parser.add_argument('--width', '-w', type=int, default=1, - help='floating point width (default: %(default)s)') - arg_parser.add_argument('--detail', '-d', default=False, action='store_true', - help='print extra information (split test sets based on origlang)') - - arg_parser.add_argument('-V', '--version', action='version', - version='%(prog)s {}'.format(VERSION)) + report_args = arg_parser.add_argument_group('Reporting related arguments') + report_args.add_argument('--quiet', '-q', default=False, action='store_true', + help='Suppress verbose messages.') + report_args.add_argument('--short', '-sh', default=False, action='store_true', + help='Produce a shorter (less human readable) signature.') + report_args.add_argument('--score-only', '-b', default=False, action='store_true', + help='Print only the computed score.') + report_args.add_argument('--width', '-w', type=int, default=1, + help='Floating point width (Default: %(default)s).') + report_args.add_argument('--detail', '-d', default=False, action='store_true', + help='Print detailed information (split test sets based on origlang).') + report_args.add_argument('--no-color', '-nc', action='store_true', + help='Disable the occasional use of terminal colors.') + + output_formats = ['json', 'text', 'latex'] + report_args.add_argument('--format', '-f', default='json', choices=output_formats, + help='Set the output format. `latex` is only valid for multi-system mode whereas ' + '`json` and `text` apply to single-system mode only. This flag is overridden if the ' + 'SACREBLEU_FORMAT environment variable is set to one of the valid choices (Default: %(default)s).') + + arg_parser.add_argument('--version', '-V', action='version', version='%(prog)s {}'.format(VERSION)) + args = arg_parser.parse_args() + + # Override the format from the environment, if any + if 'SACREBLEU_FORMAT' in os.environ: + _new_value = os.environ['SACREBLEU_FORMAT'].lower() + if _new_value in output_formats: + args.format = _new_value + return args def main(): args = parse_args() + # Is paired test requested? + paired_test_mode = args.paired_bs or args.paired_ar + # Explicitly set the encoding sys.stdin = open(sys.stdin.fileno(), mode='r', encoding='utf-8', buffering=True, newline="\n") sys.stdout = open(sys.stdout.fileno(), mode='w', encoding='utf-8', buffering=True) + if os.environ.get('NO_COLOR', False) or args.no_color: + Color.ENABLE_COLORS = False + else: + # These should come after all stdout manipulations otherwise cause + # issues esp. on Windows + import colorama + colorama.init() + if not args.quiet: logging.basicConfig(level=logging.INFO, format='sacreBLEU: %(message)s') @@ -154,12 +241,13 @@ def main(): print(' '.join(get_langpairs_for_testset(args.test_set))) else: print('The available test sets are:') - for testset in get_available_testsets(): - print('%30s: %s' % (testset, DATASETS[testset].get('description', '').strip())) + for testset in sorted(get_available_testsets()): + desc = DATASETS[testset].get('description', '').strip() + print(f'{testset:<30}: {desc}') sys.exit(0) if args.sentence_level and len(args.metrics) > 1: - sacrelogger.error('Only one metric can be used with Sentence-level reporting.') + sacrelogger.error('Only one metric can be used in sentence-level mode.') sys.exit(1) if args.citation: @@ -168,27 +256,29 @@ def main(): sys.exit(1) for test_set in args.test_set.split(','): if 'citation' not in DATASETS[test_set]: - sacrelogger.error('No citation found for %s', test_set) + sacrelogger.error(f'No citation found for {test_set}') else: print(DATASETS[test_set]['citation']) sys.exit(0) if args.num_refs != 1 and (args.test_set is not None or len(args.refs) > 1): sacrelogger.error('The --num-refs argument allows you to provide any number of tab-delimited references in a single file.') - sacrelogger.error('You can only use it with externaly-provided references, however (i.e., not with `-t`),') + sacrelogger.error('You can only use it with externally provided references, however (i.e., not with `-t`),') sacrelogger.error('and you cannot then provide multiple reference files.') sys.exit(1) if args.test_set is not None: for test_set in args.test_set.split(','): if test_set not in DATASETS: - sacrelogger.error('Unknown test set "%s"', test_set) + sacrelogger.error(f'Unknown test set {test_set!r}') sacrelogger.error('Please run with --list to see the available test sets.') sys.exit(1) if args.test_set is None: if len(args.refs) == 0: - sacrelogger.error('I need either a predefined test set (-t) or a list of references') + sacrelogger.error('If manual references given, make sure to provide them ' + 'before the -i/--input argument to avoid confusion.') + sacrelogger.error('Otherwise, I need a predefined test set (-t) from the following list:') sacrelogger.error(get_available_testsets()) sys.exit(1) elif len(args.refs) > 0: @@ -201,9 +291,10 @@ def main(): for test_set in args.test_set.split(','): langpairs = get_langpairs_for_testset(test_set) if args.langpair not in langpairs: - sacrelogger.error('No such language pair "%s"', args.langpair) - sacrelogger.error('Available language pairs for test set "%s": %s', test_set, - ', '.join(langpairs)) + sacrelogger.error(f'No such language pair {args.langpair!r}') + sacrelogger.error(f'Available language pairs for {test_set!r} are:') + for lp in langpairs: + sacrelogger.error(f' > {lp}') sys.exit(1) if args.echo: @@ -214,142 +305,260 @@ def main(): print_test_set(test_set, args.langpair, args.echo, args.origlang, args.subset) sys.exit(0) - if args.test_set is not None and args.tokenize == 'none': - sacrelogger.warning("You are turning off sacrebleu's internal tokenization ('--tokenize none'), presumably to supply\n" - "your own reference tokenization. Published numbers will not be comparable with other papers.\n") - - if 'ter' in args.metrics and args.tokenize is not None: - logging.warning("Your setting of --tokenize will be ignored when " - "computing TER") - - # Internal tokenizer settings - if args.tokenize is None: - # set default - if args.langpair is not None and args.langpair.split('-')[1] == 'zh': - args.tokenize = 'zh' - elif args.langpair is not None and args.langpair.split('-')[1] == 'ja': - args.tokenize = 'ja-mecab' - else: - args.tokenize = DEFAULT_TOKENIZER - - if args.langpair is not None and 'bleu' in args.metrics: - if args.langpair.split('-')[1] == 'zh' and args.tokenize != 'zh': - sacrelogger.warning('You should also pass "--tok zh" when scoring Chinese...') - if args.langpair.split('-')[1] == 'ja' and not args.tokenize.startswith('ja-'): - sacrelogger.warning('You should also pass "--tok ja-mecab" when scoring Japanese...') - - # concat_ref_files is a list of list of reference filenames, for example: - # concat_ref_files = [[testset1_refA, testset1_refB], [testset2_refA, testset2_refB]] + # Hack: inject target language info for BLEU, so that it can + # select the tokenizer based on it + if args.langpair: + args.bleu_trg_lang = args.langpair.split('-')[1] + + if args.test_set is not None and args.bleu_tokenize == 'none': + sacrelogger.warning( + "You are turning off BLEU's internal tokenizer " + "presumably to supply your own tokenized files.") + sacrelogger.warning( + "Published numbers will not be comparable to other papers.") + + # concat_ref_files is a list of list of reference filenames + # (concatenation happens if multiple test sets are given through -t) + # Example: [[testset1_refA, testset1_refB], [testset2_refA, testset2_refB]] + concat_ref_files = [] if args.test_set is None: - concat_ref_files = [args.refs] + concat_ref_files.append(args.refs) else: - concat_ref_files = [] + # Multiple test sets can be given for test_set in args.test_set.split(','): ref_files = get_reference_files(test_set, args.langpair) if len(ref_files) == 0: - sacrelogger.warning('No references found for test set {}/{}.'.format(test_set, args.langpair)) + sacrelogger.warning( + f'No references found for test set {test_set}/{args.langpair}.') concat_ref_files.append(ref_files) + ################# # Read references + ################# full_refs = [[] for x in range(max(len(concat_ref_files[0]), args.num_refs))] for ref_files in concat_ref_files: for refno, ref_file in enumerate(ref_files): for lineno, line in enumerate(smart_open(ref_file, encoding=args.encoding), 1): - if args.num_refs != 1: - splits = line.rstrip().split(sep='\t', maxsplit=args.num_refs-1) - if len(splits) != args.num_refs: - sacrelogger.error('FATAL: line {}: expected {} fields, but found {}.'.format(lineno, args.num_refs, len(splits))) - sys.exit(17) - for refno, split in enumerate(splits): - full_refs[refno].append(split) - else: + line = line.rstrip() + if args.num_refs == 1: full_refs[refno].append(line) + else: + refs = line.split(sep='\t', maxsplit=args.num_refs - 1) + # We are strict in fixed number of references through CLI + # But the API supports having variable refs per each segment + # by simply having '' or None's as dummy placeholders + if len(refs) != args.num_refs: + sacrelogger.error(f'FATAL: line {lineno}: expected {args.num_refs} fields, but found {len(refs)}.') + sys.exit(17) + for refno, ref in enumerate(refs): + full_refs[refno].append(ref) # Decide on the number of final references, override the argument args.num_refs = len(full_refs) - # Read hypotheses stream - if args.input == '-': + # Read hypotheses + # Can't tokenize yet as each metric has its own way of tokenizing things + full_systems, sys_names = [], [] + + if args.input is None: + # Read from STDIN inputfh = io.TextIOWrapper(sys.stdin.buffer, encoding=args.encoding) + + # guess the number of systems by looking at the first line + fields = inputfh.readline().rstrip().split('\t') + + # Set number of systems + num_sys = len(fields) + + # place the first lines already + full_systems = [[s] for s in fields] + + # Enumerate the systems + sys_names = [f'System {i + 1}' for i in range(num_sys)] + + # Read the rest + for line in inputfh: + fields = line.rstrip().split('\t') + if len(fields) != num_sys: + sacrelogger.error('FATAL: the number of tab-delimited fields in the input stream differ across lines.') + sys.exit(17) + # Place systems into the list + for sys_idx, sent in enumerate(fields): + full_systems[sys_idx].append(sent.rstrip()) else: - inputfh = smart_open(args.input, encoding=args.encoding) - full_system = inputfh.readlines() - - # Filter sentences according to a given origlang - system, *refs = filter_subset( - [full_system, *full_refs], args.test_set, args.langpair, args.origlang, args.subset) - - if len(system) == 0: - message = 'Test set %s contains no sentence' % args.test_set - if args.origlang is not None or args.subset is not None: - message += ' with' - message += '' if args.origlang is None else ' origlang=' + args.origlang - message += '' if args.subset is None else ' subset=' + args.subset - sacrelogger.error(message) + # Separate files are given for each system output + # Ex: --input smt.txt nmt.txt + for fname in args.input: + sys_name = fname + + if sys_name in sys_names: + if paired_test_mode and sys_name == sys_names[0]: + # We skip loading a system, if it was already the baseline + sacrelogger.info(f'Ignoring {sys_name!r} as it was also given as the baseline.') + continue + else: + # To avoid ambiguities, we fail if two systems have same names + sacrelogger.error(f"{sys_name!r} already used to name a system.") + sacrelogger.error("Make sure to have a different basename for each system.") + sys.exit(1) + + # Read the system + lines = [] + for line in smart_open(fname, encoding=args.encoding): + lines.append(line.rstrip()) + full_systems.append(lines) + sys_names.append(sys_name) + + # Set final number of systems + num_sys = len(sys_names) + + # Add baseline prefix to the first system for clarity + if paired_test_mode: + if args.input is None: + # STDIN mode, no explicit system names + sys_names = ['Baseline'] + [f'System {i + 1}' for i in range(num_sys - 1)] + else: + # --input mode, we have names for the systems, just change the 1st one + sys_names[0] = f'Baseline: {sys_names[0]}' + + if args.sentence_level: + if num_sys > 1: + sacrelogger.error('Only one system can be evaluated in sentence-level mode.') + sys.exit(1) + if args.confidence or paired_test_mode: + sacrelogger.error('Statistical tests are unavailable in sentence-level mode.') + sys.exit(1) + + # >=2.0.0: effective_order is now part of BLEU class. For sentence-BLEU + # we now need to explicitly enable it without user's intervention + # for backward compatibility. + args.bleu_effective_order = True + + if paired_test_mode and num_sys == 1: + sacrelogger.error('Paired tests require multiple input systems given to --input (-i).') + sys.exit(1) + + if num_sys > 1 and args.confidence: + sacrelogger.error('Use paired tests (--paired) for multiple systems.') sys.exit(1) - # Create metric inventory, let each metric consume relevant args from argparse - metrics = [METRICS[met](args) for met in args.metrics] + # Filter subsets if requested + outputs = filter_subset( + [*full_systems, *full_refs], args.test_set, args.langpair, + args.origlang, args.subset) + + # Unpack systems & references back + systems, refs = outputs[:num_sys], outputs[num_sys:] + + # Perform some sanity checks + for system in systems: + if len(system) == 0: + message = f'Test set {args.test_set!r} contains no sentence' + if args.origlang is not None or args.subset is not None: + message += ' with' + if args.origlang: + message += f' origlang={args.origlang}' + if args.subset: + message += f' subset={args.subset}' + args.subset + sacrelogger.error(message) + sys.exit(1) + + # Check lengths + sanity_check_lengths(system, refs, test_set=args.test_set) + + # Create the metrics + metrics = {} + for name in args.metrics: + # Each metric's specific arguments are prefixed with `metricname_` + # for grouping. Filter accordingly and strip the prefixes prior to + # metric object construction. + metric_args = args_to_dict(args, name.lower(), strip_prefix=True) + + # This will cache reference stats for faster re-computation if required + metric_args['references'] = refs + + # Make it uppercase for the rest of the code + name = name.upper() + metrics[name] = METRICS[name](**metric_args) # Handle sentence level and quit if args.sentence_level: - # one metric in use for sentence-level - metric = metrics[0] - for output, *references in zip(system, *refs): - score = metric.sentence_score(output, references) - print(score.format(args.width, args.score_only, metric.signature)) + # one metric and one system in use for sentence-level + metric, system = list(metrics.values())[0], systems[0] + + for hypothesis, *references in zip(system, *refs): + score = metric.sentence_score(hypothesis, references) + sig = metric.get_signature().format(args.short) + print(score.format(args.width, args.score_only, sig)) sys.exit(0) - # Else, handle system level - for metric in metrics: - try: - score = metric.corpus_score(system, refs) - except EOFError: - sacrelogger.error('The input and reference stream(s) were of different lengths.') - if args.test_set is not None: - sacrelogger.error('\nThis could be a problem with your system output or with sacreBLEU\'s reference database.\n' - 'If the latter, you can clean out the references cache by typing:\n' - '\n' - ' rm -r %s/%s\n' - '\n' - 'They will be downloaded automatically again the next time you run sacreBLEU.', SACREBLEU_DIR, - args.test_set) - sys.exit(1) + if args.detail and args.format == 'json': + # The translationese info will interfere with JSON output, disable + args.format = 'text' + + ############################## + # Corpus level evaluation mode + ############################## + if num_sys == 1: + # Single system evaluation mode + results = [] + for name in sorted(metrics): + # compute the score + score = metrics[name].corpus_score( + system, references=None, + n_bootstrap=args.confidence_n if args.confidence else 1) + # get the signature + sig = metrics[name].get_signature().format( + args.short if args.format != 'json' else False) + results.append( + score.format(args.width, args.score_only, sig, args.format == 'json')) + + print_single_results(results, args) + + # Prints detailed information for translationese effect experiments + if args.detail: + print_subset_results(metrics, full_systems[0], full_refs, args) + else: + # Multi-system evaluation mode + named_systems = [(sys_names[i], systems[i]) for i in range(num_sys)] + sacrelogger.info(f'Found {num_sys} systems.') + + if not paired_test_mode: + # Bootstrap resampling or the usual single score computation mode + sigs = {} + scores = defaultdict(list) + scores['System'] = sys_names + + for sys_name, system in named_systems: + for name in sorted(metrics): + score = metrics[name].corpus_score(system, references=None) + sigs[score.name] = metrics[name].get_signature().format(args.short) + scores[score.name].append(score.format(args.width, True)) + else: - print(score.format(args.width, args.score_only, metric.signature)) - - if args.detail: - width = args.width - sents_digits = len(str(len(full_system))) - origlangs = args.origlang if args.origlang else get_available_origlangs(args.test_set, args.langpair) - for origlang in origlangs: - subsets = [None] - if args.subset is not None: - subsets += [args.subset] - elif all(t in SUBSETS for t in args.test_set.split(',')): - subsets += COUNTRIES + DOMAINS - for subset in subsets: - system, *refs = filter_subset([full_system, *full_refs], args.test_set, args.langpair, origlang, subset) - if len(system) == 0: - continue - if subset in COUNTRIES: - subset_str = '%20s' % ('country=' + subset) - elif subset in DOMAINS: - subset_str = '%20s' % ('domain=' + subset) - else: - subset_str = '%20s' % '' - for metric in metrics: - # FIXME: handle this in metrics - if metric.name == 'bleu': - _refs = refs - elif metric.name == 'chrf': - _refs = refs[0] - - score = metric.corpus_score(system, _refs) - print('origlang={} {}: sentences={:{}} {}={:{}.{}f}'.format( - origlang, subset_str, len(system), sents_digits, - score.prefix, score.score, width+4, width)) + # Paired significance testing mode + from .significance import PairedTest + + # Set params + test_type = 'bs' if args.paired_bs else 'ar' + n_samples = args.paired_bs_n if args.paired_bs else args.paired_ar_n + + ps = PairedTest(named_systems, metrics, references=None, + test_type=test_type, n_samples=n_samples, + n_jobs=args.paired_jobs) + + # Set back the number of trials + args.paired_n = ps.n_samples + + # Run the test + sigs, scores = ps() + + # Get signature strings + sigs = {k: v.format(args.short) for k, v in sigs.items()} + + # Dump the results + print_results_table(scores, sigs, args) if __name__ == '__main__': diff --git a/sacrebleu/significance.py b/sacrebleu/significance.py new file mode 100644 index 00000000..b39e0a59 --- /dev/null +++ b/sacrebleu/significance.py @@ -0,0 +1,434 @@ +import os +import logging +import multiprocessing as mp +from typing import Sequence, Dict, Optional, Tuple, List, Union, Any + +import numpy as np + +from .metrics.base import Metric, Score, Signature + +IS_WINDOWS = os.name == 'nt' + + +sacrelogger = logging.getLogger('sacrebleu') + + +class Result: + """A container to represent results from a particular statistical + significance test. + :param score: The floating point score for the system at hand. + :param p_value: If exists, represents the p-value when the system at + hand is compared to a baseline using a paired test. + :param mean: When paired bootstrap test is applied, this represents + the true mean score estimated from bootstrap resamples of the system. + :param ci: When paired bootstrap test is applied, this represents + the 95% confidence interval around the true mean score `sys_mean`. + """ + def __init__(self, score: float, p_value: Optional[float] = None, + mean: Optional[float] = None, ci: Optional[float] = None): + self.score = score + self.p_value = p_value + self.mean = mean + self.ci = ci + + def __repr__(self): + return ','.join([f'{k}={str(v)}' for k, v in self.__dict__.items()]) + + +def estimate_ci(scores: np.ndarray) -> Tuple[float, float]: + """Takes a list of scores and returns mean and 95% confidence + interval around the mean. + + :param scores: A list of floating point scores. + :return: A tuple of mean and the 95% CI. + """ + # Sort the scores + scores = np.sort(scores) + n = len(scores) + + # Get CI bounds (95%, i.e. 1/40 from left) + lower_idx = n // 40 + upper_idx = n - lower_idx - 1 + lower, upper = scores[lower_idx], scores[upper_idx] + ci = 0.5 * (upper - lower) + return (scores.mean(), ci) + + +def _bootstrap_resample(stats: List[List[Union[int, float]]], + metric: Metric, n_samples: int = 1000) -> Tuple[str, List[Score]]: + """Performs bootstrap resampling for a single system to estimate + a confidence interval around the true mean. + :param stats: A list of statistics extracted from the system's hypotheses. + :param metric: The `Metric` instance to be used for score computation. + :n_samples: Number of bootstrap resamples to use. + + :return: A tuple of the seed choice as string and the list of `Score` + instances for all bootstrap resamples. + """ + + # Set numpy RNG's seed + # If given -> Fix to the given value + # If given but =='[Nn]one', don't fix the seed i.e. pull entropy from OS + seed = os.environ.get('SACREBLEU_SEED', '12345') + _seed = None if seed.lower() == 'none' else int(seed) + rng = np.random.default_rng(_seed) + + # The indices that'll produce all bootstrap resamples at once + idxs = rng.choice(len(stats), size=(n_samples, len(stats)), replace=True) + + # convert to numpy array. float32 is more efficient + stats = np.array(stats, dtype='float32') + + # recompute scores for all resamples + scores = [ + metric._compute_score_from_stats(_s.sum(0)) for _s in stats[idxs]] + + return str(seed).lower(), scores + + +def _compute_p_value(stats: np.ndarray, real_difference: float) -> float: + """Computes the p-value given the sample statistics and the real statistic. + :param stats: A numpy array with the sample statistics. + :real_difference: The real statistic. + :return: The p-value. + """ + # Taken from: significance/StratifiedApproximateRandomizationTest.java + # https://github.com/jhclark/multeval.git + + # "the != is important. if we want to score the same system against itself + # having a zero difference should not be attributed to chance." + + c = np.sum(stats > real_difference) + + # "+1 applies here, though it only matters for small numbers of shufflings, + # which we typically never do. it's necessary to ensure the probability of + # falsely rejecting the null hypothesis is no greater than the rejection + # level of the test (see william and morgan on significance tests) + p = (c + 1) / (len(stats) + 1) + + return p + + +def _paired_ar_test(baseline_info: Dict[str, Tuple[np.ndarray, Result]], + sys_name: str, + hypotheses: Sequence[str], + references: Optional[Sequence[Sequence[str]]], + metrics: Dict[str, Metric], + n_samples: int = 10000, + n_ar_confidence: int = -1, + seed: Optional[int] = None) -> Tuple[str, Dict[str, Result]]: + """Paired two-sided approximate randomization (AR) test for MT evaluation. + + :param baseline_info: A dictionary with `Metric` instances as the keys, + that contains sufficient statistics and a `Result` instance for the baseline system. + :param sys_name: The name of the system to be evaluated. + :param hypotheses: A sequence of string hypotheses for the system. + :param references: A sequence of reference documents with document being + defined as a sequence of reference strings. If `None`, references + will be used through each metric's internal cache. + :param metrics: A dictionary of `Metric` instances that will be computed + for each system. + :param n_samples: The number of AR trials. + :param n_ar_confidence: The number of bootstrap resamples to use for + confidence estimation. A value of -1 disables confidence estimation. + :param seed: The seed value for the RNG. If `None`, the RNG will not be + fixed to a particular seed. + + :return: A tuple with first element being the system name and the second + being a `Result` namedtuple. + """ + # Seed the RNG + rng = np.random.default_rng(seed) + + # Generate indices that'll select stats + pos_sel = rng.integers(2, size=(n_samples, len(hypotheses)), dtype=bool) + + # Flip mask to obtain selectors for system hypotheses + neg_sel = ~pos_sel + + if n_ar_confidence > 0: + # Perform confidence estimation as well + bs_idxs = rng.choice( + len(hypotheses), size=(n_ar_confidence, len(hypotheses)), replace=True) + + results = {} + + for name, metric in metrics.items(): + # Use pre-computed match stats for the baseline + bl_stats, bl_result = baseline_info[name] + + # Compute system's stats and score + sacrelogger.info(f'Computing {name} for {sys_name!r} and extracting sufficient statistics') + sys_stats = metric._extract_corpus_statistics(hypotheses, references) + sys_score = metric._aggregate_and_compute(sys_stats) + + # original test statistic: absolute difference between baseline and the system + diff = abs(bl_result.score - sys_score.score) + + sacrelogger.info(f' > Performing approximate randomization test (# trials: {n_samples})') + # get shuffled pseudo systems + shuf_a = pos_sel @ bl_stats + neg_sel @ sys_stats + shuf_b = neg_sel @ bl_stats + pos_sel @ sys_stats + + # Aggregate trial stats and compute scores for each + scores_a = np.array( + [metric._aggregate_and_compute(x).score for x in shuf_a[:, None]]) + scores_b = np.array( + [metric._aggregate_and_compute(x).score for x in shuf_b[:, None]]) + + # Count the statistical difference and compute the p-value + p = _compute_p_value( + np.abs(np.array(scores_a) - np.array(scores_b)), diff) + + res = Result(sys_score.score, p) + + if n_ar_confidence > 0: + sacrelogger.info(f' > Performing bootstrap resampling for confidence interval (# resamples: {n_ar_confidence})') + sys_stats = np.array(sys_stats, dtype='float32') + # recompute scores for all resamples + sys_scores = [ + metric._compute_score_from_stats(_s.sum(0)).score for _s in sys_stats[bs_idxs]] + res.mean, res.ci = estimate_ci(sys_scores) + + # Store the result + results[name] = res + + return sys_name, results + + +def _paired_bs_test(baseline_info: Dict[str, Tuple[np.ndarray, Result]], + sys_name: str, + hypotheses: Sequence[str], + references: Optional[Sequence[Sequence[str]]], + metrics: Dict[str, Metric], + n_samples: int = 1000, + n_ar_confidence: int = -1, + seed: Optional[int] = None) -> Tuple[str, Dict[str, Result]]: + """Paired bootstrap resampling test for MT evaluation. This function + replicates the behavior of the Moses script called + `bootstrap-hypothesis-difference-significance.pl`. + + :param baseline_info: A dictionary with `Metric` instances as the keys, + that contains sufficient statistics and a `Result` instance for the baseline system. + :param sys_name: The name of the system to be evaluated. + :param hypotheses: A sequence of string hypotheses for the system. + :param references: A sequence of reference documents with document being + defined as a sequence of reference strings. If `None`, references + will be used through each metric's internal cache. + :param metrics: A dictionary of `Metric` instances that will be computed + for each system. + :param n_samples: The number of bootstrap resamples. + :param n_ar_confidence: This parameter is not used for this function but + is there for signature compatibility in the API. + :param seed: The seed value for the RNG. If `None`, the RNG will not be + fixed to a particular seed. + + :return: A tuple with first element being the system name and the second + being a `Result` namedtuple. + """ + # Seed the RNG + rng = np.random.default_rng(seed) + + results = {} + + # It takes ~10ms to generated the indices + idxs = rng.choice( + len(hypotheses), size=(n_samples, len(hypotheses)), replace=True) + + for name, metric in metrics.items(): + # Use pre-computed match stats for the baseline + bl_stats, bl_result = baseline_info[name] + + # Compute system's stats and score + sacrelogger.info(f'Computing {name} for {sys_name!r} and extracting sufficient statistics') + sys_stats = metric._extract_corpus_statistics(hypotheses, references) + sys_score = metric._aggregate_and_compute(sys_stats) + + # Convert to numpy arrays for efficient indexing + sys_stats = np.array(sys_stats, dtype='float32') + bl_stats = np.array(bl_stats, dtype='float32') + + # original test statistic: absolute difference between baseline and the system + diff = abs(bl_result.score - sys_score.score) + + sacrelogger.info(f' > Performing paired bootstrap resampling test (# resamples: {n_samples})') + scores_bl = np.array( + [metric._compute_score_from_stats(_s.sum(0)).score for _s in bl_stats[idxs]]) + scores_sys = np.array( + [metric._compute_score_from_stats(_s.sum(0)).score for _s in sys_stats[idxs]]) + + # Compute CI as well + sys_mean, sys_ci = estimate_ci(scores_sys) + + # Compute the statistics + sample_diffs = np.abs(scores_sys - scores_bl) + stats = sample_diffs - sample_diffs.mean() + + # Count the statistical difference and compute the p-value + p = _compute_p_value(stats, diff) + + results[name] = Result(sys_score.score, p, sys_mean, sys_ci) + + return sys_name, results + + +class PairedTest: + """This is the manager class that will call the actual standalone implementation + for approximate randomization or paired bootstrap resampling, based on the + `test_type` argument. + + :param named_systems: A lisf of (system_name, system_hypotheses) tuples on + which the test will be applied. + :param metrics: A dictionary of `Metric` instances that will be computed + for each system. + :param references: A sequence of reference documents with document being + defined as a sequence of reference strings. If `None`, already cached references + will be used through each metric's internal cache. + :param test_type: `ar` for approximate randomization, `bs` for paired bootstrap. + :param n_samples: The number of AR trials (for `ar`) or bootstrap resamples (for `bs`). + The defaults (10000 or 1000 respectively) will be used if 0 is passed. + :param n_ar_confidence: If `approximate randomization` is selected, the number + of bootstrap resamples to use for confidence estimation. A value of -1 disables + confidence estimation. 0 will use the default of 1000. + :param n_jobs: If 0, a worker process will be spawned for each system variant. + If > 0, the number of workers will be set accordingly. The default of 1 + does not use multi-processing. + """ + _DEFAULT_SAMPLES = { + 'ar': 10000, + 'bs': 1000, + } + + def __init__(self, named_systems: List[Tuple[str, Sequence[str]]], + metrics: Dict[str, Metric], + references: Optional[Sequence[Sequence[str]]], + test_type: str = 'ar', + n_samples: int = 0, + n_ar_confidence: int = -1, + n_jobs: int = 1): + assert test_type in ('ar', 'bs'), f"Unknown test type {test_type!r}" + self.test_type = test_type + + # Set method + if self.test_type == 'ar': + self._fn = _paired_ar_test + elif self.test_type == 'bs': + self._fn = _paired_bs_test + + # Set numpy RNG's seed + # If given -> Fix to the given value + # If given but =='[Nn]one', don't fix the seed i.e. pull entropy from OS + seed = os.environ.get('SACREBLEU_SEED', '12345') + self._seed = None if seed.lower() == 'none' else int(seed) + self.n_jobs = n_jobs + self.references = references + self.named_systems = named_systems + + # Set the defaults if requested + self.n_ar_confidence = n_ar_confidence if n_ar_confidence != 0 else \ + self._DEFAULT_SAMPLES['bs'] + + self.n_samples = n_samples if n_samples > 0 else \ + self._DEFAULT_SAMPLES[self.test_type] + + # Number of systems (excluding the baseline) + self.n_systems = len(named_systems) - 1 + + # Decide on number of workers + if IS_WINDOWS: + sacrelogger.warning('Parallel tests are not supported on Windows.') + self.n_jobs = 1 + elif self.n_jobs == 0: + # Decide automatically + # Divide by two to ignore hyper-threading + n_max_jobs = mp.cpu_count() // 2 + if n_max_jobs == 0: + self.n_jobs = 1 + else: + # Don't use more workers than the number of CPUs + self.n_jobs = min(n_max_jobs, self.n_systems) + + self._signatures: Dict[str, Signature] = {} + self._baseline_info: Dict[str, Tuple[Any, Result]] = {} + + ################################################## + # Pre-compute and cache baseline system statistics + ################################################## + self.metrics = {} + + bl_name, bl_hyps = self.named_systems[0] + + for name, metric in metrics.items(): + sacrelogger.info(f'Pre-computing {name} statistics for {bl_name!r}') + bl_stats = metric._extract_corpus_statistics(bl_hyps, self.references) + bl_score = metric._aggregate_and_compute(bl_stats) + + # Compute CI for the baseline here once + confidence_n = self.n_samples if self.test_type == 'bs' \ + else self.n_ar_confidence + + bl_mean, bl_ci = None, None + if confidence_n > 0: + _, bl_scores = _bootstrap_resample(bl_stats, metric, confidence_n) + bl_mean, bl_ci = estimate_ci(np.array([x.score for x in bl_scores])) + + result = Result(bl_score.score, mean=bl_mean, ci=bl_ci) + # Use updated name for the metric + self._baseline_info[bl_score.name] = (bl_stats, result) + self.metrics[bl_score.name] = metric + + # Update metric signature as well + sig = metric.get_signature() + sig.update('seed', str(self._seed).lower()) + + # Num samples for bs, num trials for AR + sig.update(self.test_type, self.n_samples) + if self.n_ar_confidence > 0: + # Bootstrap is used for AR CI as well + sig.update('bs', self.n_ar_confidence) + self._signatures[bl_score.name] = sig + + def __call__(self) -> Tuple[Dict[str, Signature], Dict[str, List[Union[str, Result]]]]: + """Runs the paired test either on single or multiple worker processes.""" + tasks = [] + scores: Dict[str, List[Union[str, Result]]] = {} + + # Add the name column + scores['System'] = [ns[0] for ns in self.named_systems] + + # Store baseline results as the first position + for metric, (_, result) in self._baseline_info.items(): + scores[metric] = [result] + + # Prepare list of arguments for each comparison + # Skip the baseline (pos: 0) + for idx, (name, hyps) in enumerate(self.named_systems[1:]): + seed = self._seed if self._seed else None + + tasks.append( + (self._baseline_info, name, hyps, self.references, + self.metrics, self.n_samples, self.n_ar_confidence, seed)) + + # Run the test(s) + if self.n_jobs == 1: + results = [self._fn(*args) for args in tasks] + else: + # NOTE: The overhead of worker creation is not negligible + # but if you have many systems and TER enabled, this significantly + # speeds up the test. + # NOTE: This only works on Linux/Mac OS X but not Windows. Windows only + # supports `spawn` backend which requires things to be called + # from within __main__. + sacrelogger.info(f'Launching {self.n_jobs} parallel workers.') + with mp.get_context('fork').Pool(self.n_jobs) as pool: + jobs = [pool.apply_async(self._fn, args) for args in tasks] + + # wait for completion + results = [j.get() for j in jobs] + + # Keep the order deterministic + for sys_name, sys_results in results: + for metric, _result in sys_results.items(): + scores[metric].append(_result) + + return self._signatures, scores diff --git a/sacrebleu/tokenizers/__init__.py b/sacrebleu/tokenizers/__init__.py index fc9f739e..d658a1ba 100644 --- a/sacrebleu/tokenizers/__init__.py +++ b/sacrebleu/tokenizers/__init__.py @@ -1,21 +1,2 @@ -# -*- coding: utf-8 -*- - -from .tokenizer_none import NoneTokenizer -from .tokenizer_13a import Tokenizer13a -from .tokenizer_intl import TokenizerV14International -from .tokenizer_zh import TokenizerZh -from .tokenizer_ja_mecab import TokenizerJaMecab -from .tokenizer_char import TokenizerChar - - -DEFAULT_TOKENIZER = '13a' - - -TOKENIZERS = { - 'none': NoneTokenizer, - '13a': Tokenizer13a, - 'intl': TokenizerV14International, - 'zh': TokenizerZh, - 'ja-mecab': TokenizerJaMecab, - 'char': TokenizerChar, -} +# Base tokenizer to derive from +from .tokenizer_base import BaseTokenizer # noqa: F401 diff --git a/sacrebleu/tokenizers/tokenizer_13a.py b/sacrebleu/tokenizers/tokenizer_13a.py index 01e64e88..331d171c 100644 --- a/sacrebleu/tokenizers/tokenizer_13a.py +++ b/sacrebleu/tokenizers/tokenizer_13a.py @@ -1,10 +1,9 @@ -# -*- coding: utf-8 -*- - -from .tokenizer_none import NoneTokenizer +from functools import lru_cache +from .tokenizer_base import BaseTokenizer from .tokenizer_re import TokenizerRegexp -class Tokenizer13a(NoneTokenizer): +class Tokenizer13a(BaseTokenizer): def signature(self): return '13a' @@ -12,6 +11,7 @@ def signature(self): def __init__(self): self._post_tokenizer = TokenizerRegexp() + @lru_cache(maxsize=None) def __call__(self, line): """Tokenizes an input line using a relatively minimal tokenization that is however equivalent to mteval-v13a, used by WMT. @@ -24,10 +24,11 @@ def __call__(self, line): line = line.replace('', '') line = line.replace('-\n', '') line = line.replace('\n', ' ') - line = line.replace('"', '"') - line = line.replace('&', '&') - line = line.replace('<', '<') - line = line.replace('>', '>') - line = " {} ".format(line) - return self._post_tokenizer(line) + if '&' in line: + line = line.replace('"', '"') + line = line.replace('&', '&') + line = line.replace('<', '<') + line = line.replace('>', '>') + + return self._post_tokenizer(f' {line} ') diff --git a/sacrebleu/tokenizers/tokenizer_none.py b/sacrebleu/tokenizers/tokenizer_base.py similarity index 89% rename from sacrebleu/tokenizers/tokenizer_none.py rename to sacrebleu/tokenizers/tokenizer_base.py index 7b45484c..e2a52e8b 100644 --- a/sacrebleu/tokenizers/tokenizer_none.py +++ b/sacrebleu/tokenizers/tokenizer_base.py @@ -1,6 +1,4 @@ -# -*- coding: utf-8 -*- - -class NoneTokenizer: +class BaseTokenizer: """A base dummy tokenizer to derive from.""" def signature(self): diff --git a/sacrebleu/tokenizers/tokenizer_char.py b/sacrebleu/tokenizers/tokenizer_char.py index 59f5afaf..aab7ac26 100644 --- a/sacrebleu/tokenizers/tokenizer_char.py +++ b/sacrebleu/tokenizers/tokenizer_char.py @@ -1,15 +1,15 @@ -# -*- coding: utf-8 -*- +from functools import lru_cache +from .tokenizer_base import BaseTokenizer -from .tokenizer_none import NoneTokenizer - -class TokenizerChar(NoneTokenizer): +class TokenizerChar(BaseTokenizer): def signature(self): return 'char' def __init__(self): pass + @lru_cache(maxsize=None) def __call__(self, line): """Tokenizes all the characters in the input line. diff --git a/sacrebleu/tokenizers/tokenizer_intl.py b/sacrebleu/tokenizers/tokenizer_intl.py index a314ae61..9af05493 100644 --- a/sacrebleu/tokenizers/tokenizer_intl.py +++ b/sacrebleu/tokenizers/tokenizer_intl.py @@ -1,76 +1,50 @@ -# -*- coding: utf-8 -*- +from functools import lru_cache -import re -import sys -import functools -import unicodedata +import regex -from .tokenizer_none import NoneTokenizer +from .tokenizer_base import BaseTokenizer -class UnicodeRegex: - """Ad-hoc hack to recognize all punctuation and symbols - without depending on https://pypi.python.org/pypi/regex/.""" +class TokenizerV14International(BaseTokenizer): + """Tokenizes a string following the official BLEU implementation. - @staticmethod - def _property_chars(prefix): - return ''.join(chr(x) for x in range(sys.maxunicode) - if unicodedata.category(chr(x)).startswith(prefix)) + See github.com/moses-smt/mosesdecoder/blob/master/scripts/generic/mteval-v14.pl#L954-L983 - @staticmethod - @functools.lru_cache(maxsize=1) - def punctuation(): - return UnicodeRegex._property_chars('P') + In our case, the input string is expected to be just one line. + We just tokenize on punctuation and symbols, + except when a punctuation is preceded and followed by a digit + (e.g. a comma/dot as a thousand/decimal separator). + We do not recover escaped forms of punctuations such as ' or > + as these should never appear in MT system outputs (see issue #138) - @staticmethod - @functools.lru_cache(maxsize=1) - def nondigit_punct_re(): - return re.compile(r'([^\d])([' + UnicodeRegex.punctuation() + r'])') + Note that a number (e.g., a year) followed by a dot at the end of + sentence is NOT tokenized, i.e. the dot stays with the number because + `s/(\\p{P})(\\P{N})/ $1 $2/g` does not match this case (unless we add a + space after each sentence). However, this error is already in the + original mteval-v14.pl and we want to be consistent with it. + The error is not present in the non-international version, + which uses `$norm_text = " $norm_text "`. - @staticmethod - @functools.lru_cache(maxsize=1) - def punct_nondigit_re(): - return re.compile(r'([' + UnicodeRegex.punctuation() + r'])([^\d])') - - @staticmethod - @functools.lru_cache(maxsize=1) - def symbol_re(): - return re.compile('([' + UnicodeRegex._property_chars('S') + '])') - - -class TokenizerV14International(NoneTokenizer): + :param line: the input string to tokenize. + :return: The tokenized string. + """ def signature(self): return 'intl' def __init__(self): - self.nondigit_punct_re = UnicodeRegex.nondigit_punct_re() - self.punct_nondigit_re = UnicodeRegex.punct_nondigit_re() - self.symbol_re = UnicodeRegex.symbol_re() - - def __call__(self, line): - r"""Tokenize a string following the official BLEU implementation. - - See https://github.com/moses-smt/mosesdecoder/blob/master/scripts/generic/mteval-v14.pl#L954-L983 - In our case, the input string is expected to be just one line - and no HTML entities de-escaping is needed. - So we just tokenize on punctuation and symbols, - except when a punctuation is preceded and followed by a digit - (e.g. a comma/dot as a thousand/decimal separator). - - Note that a number (e.g., a year) followed by a dot at the end of - sentence is NOT tokenized, i.e. the dot stays with the number because - `s/(\p{P})(\P{N})/ $1 $2/g` does not match this case (unless we add a - space after each sentence). However, this error is already in the - original mteval-v14.pl and we want to be consistent with it. - The error is not present in the non-international version, - which uses - `$norm_text = " $norm_text "` (or `norm = " {} ".format(norm)` in Python). - - :param line: the input string - :return: a list of tokens - """ - line = self.nondigit_punct_re.sub(r'\1 \2 ', line) - line = self.punct_nondigit_re.sub(r' \1 \2', line) - line = self.symbol_re.sub(r' \1 ', line) - return line.strip() + self._re = [ + # Separate out punctuations preceeded by a non-digit + (regex.compile(r'(\P{N})(\p{P})'), r'\1 \2 '), + # Separate out punctuations followed by a non-digit + (regex.compile(r'(\p{P})(\P{N})'), r' \1 \2'), + # Separate out symbols + (regex.compile(r'(\p{S})'), r' \1 '), + ] + + @lru_cache(maxsize=None) + def __call__(self, line: str) -> str: + for (_re, repl) in self._re: + line = _re.sub(repl, line) + + return ' '.join(line.split()) diff --git a/sacrebleu/tokenizers/tokenizer_ja_mecab.py b/sacrebleu/tokenizers/tokenizer_ja_mecab.py index 9f9e44fc..bc6aca9c 100644 --- a/sacrebleu/tokenizers/tokenizer_ja_mecab.py +++ b/sacrebleu/tokenizers/tokenizer_ja_mecab.py @@ -1,4 +1,5 @@ -# -*- coding: utf-8 -*- +from functools import lru_cache + try: import MeCab import ipadic @@ -6,7 +7,7 @@ # Don't fail until the tokenizer is actually used MeCab = None -from .tokenizer_none import NoneTokenizer +from .tokenizer_base import BaseTokenizer FAIL_MESSAGE = """ Japanese tokenization requires extra dependencies, but you do not have them installed. @@ -15,7 +16,8 @@ pip install sacrebleu[ja] """ -class TokenizerJaMecab(NoneTokenizer): + +class TokenizerJaMecab(BaseTokenizer): def __init__(self): if MeCab is None: raise RuntimeError(FAIL_MESSAGE) @@ -28,6 +30,7 @@ def __init__(self): # This asserts that no user dictionary has been loaded assert d.next is None + @lru_cache(maxsize=None) def __call__(self, line): """ Tokenizes an Japanese input line using MeCab morphological analyzer. diff --git a/sacrebleu/tokenizers/tokenizer_re.py b/sacrebleu/tokenizers/tokenizer_re.py index 1bbffda4..4328886f 100644 --- a/sacrebleu/tokenizers/tokenizer_re.py +++ b/sacrebleu/tokenizers/tokenizer_re.py @@ -1,9 +1,10 @@ +from functools import lru_cache import re -from .tokenizer_none import NoneTokenizer +from .tokenizer_base import BaseTokenizer -class TokenizerRegexp(NoneTokenizer): +class TokenizerRegexp(BaseTokenizer): def signature(self): return 're' @@ -19,9 +20,11 @@ def __init__(self): # tokenize dash when preceded by a digit (re.compile(r'([0-9])(-)'), r'\1 \2 '), # one space only between words - (re.compile(r'\s+'), r' '), + # NOTE: Doing this in Python (below) is faster + # (re.compile(r'\s+'), r' '), ] + @lru_cache(maxsize=None) def __call__(self, line): """Common post-processing tokenizer for `13a` and `zh` tokenizers. @@ -31,5 +34,5 @@ def __call__(self, line): for (_re, repl) in self._re: line = _re.sub(repl, line) - # no leading or trailing spaces - return line.strip() + # no leading or trailing spaces, single space within words + return ' '.join(line.split()) diff --git a/sacrebleu/tokenizers/tokenizer_ter.py b/sacrebleu/tokenizers/tokenizer_ter.py index a58f7e38..dda5735d 100644 --- a/sacrebleu/tokenizers/tokenizer_ter.py +++ b/sacrebleu/tokenizers/tokenizer_ter.py @@ -14,8 +14,9 @@ import re +from functools import lru_cache -from .tokenizer_none import NoneTokenizer +from .tokenizer_base import BaseTokenizer def _normalize_general_and_western(sent: str) -> str: @@ -34,7 +35,7 @@ def _normalize_general_and_western(sent: str) -> str: sent = re.sub(r">", ">", sent) # language-dependent (Western) part - sent = " {} ".format(sent) + sent = f" {sent} " # tokenize punctuation sent = re.sub(r"([{-~[-` -&(-+:-@/])", r" \1 ", sent) @@ -105,8 +106,8 @@ def _remove_asian_punct(sent: str) -> str: return sent -class TercomTokenizer(NoneTokenizer): - """Re-implementation Tercom Tokenizer in Python 3. +class TercomTokenizer(BaseTokenizer): + """Re-implementation of Tercom Tokenizer in Python 3. See src/ter/core/Normalizer.java in https://github.com/jhclark/tercom @@ -135,6 +136,10 @@ def __init__(self, self._asian_support = asian_support self._case_sensitive = case_sensitive + @lru_cache(maxsize=None) + # Although the cache is shared across different instances, same sentence + # queries do not return invalid returns across different instances since + # `self` becomes part of the query as well. def __call__(self, sent: str) -> str: if not sent: return "" @@ -147,22 +152,13 @@ def __call__(self, sent: str) -> str: if self._asian_support: sent = _normalize_asian(sent) - sent = re.sub(r"\s+", " ", sent) # one space only between words - sent = re.sub(r"^\s+", "", sent) # no leading space - sent = re.sub(r"\s+$", "", sent) # no trailing space - if self._no_punct: sent = _remove_punct(sent) if self._asian_support: sent = _remove_asian_punct(sent) - return sent + # Strip extra whitespaces + return ' '.join(sent.split()) def signature(self): - return("-".join([ - 'tercom', - 'norm' if self._normalized else 'nonorm', - 'nopunct' if self._no_punct else 'punct', - 'asian' if self._asian_support else 'noasian', - 'cased' if self._case_sensitive else 'uncased', - ])) + return 'tercom' diff --git a/sacrebleu/tokenizers/tokenizer_zh.py b/sacrebleu/tokenizers/tokenizer_zh.py index e7a93d58..c1f65fc1 100644 --- a/sacrebleu/tokenizers/tokenizer_zh.py +++ b/sacrebleu/tokenizers/tokenizer_zh.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - # Copyright 2017--2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"). You may not @@ -39,7 +37,9 @@ # Author: Shujian Huang huangsj@nju.edu.cn -from .tokenizer_none import NoneTokenizer +from functools import lru_cache + +from .tokenizer_base import BaseTokenizer from .tokenizer_re import TokenizerRegexp _UCODE_RANGES = [ @@ -69,7 +69,7 @@ ] -class TokenizerZh(NoneTokenizer): +class TokenizerZh(BaseTokenizer): def signature(self): return 'zh' @@ -78,6 +78,7 @@ def __init__(self): self._post_tokenizer = TokenizerRegexp() @staticmethod + @lru_cache(maxsize=None) def _is_chinese_char(uchar): """ :param uchar: input char in unicode @@ -88,6 +89,7 @@ def _is_chinese_char(uchar): return True return False + @lru_cache(maxsize=None) def __call__(self, line): """The tokenization of Chinese text in this script contains two steps: separate each Chinese characters (by utf-8 encoding); tokenize @@ -100,11 +102,12 @@ def __call__(self, line): """ line = line.strip() + line_in_chars = "" # TODO: the below code could probably be replaced with the following: + # @ozan: Gives slightly different scores, need to investigate # import regex # line = regex.sub(r'(\p{Han})', r' \1 ', line) - line_in_chars = "" for char in line: if self._is_chinese_char(char): line_in_chars += " " @@ -112,6 +115,5 @@ def __call__(self, line): line_in_chars += " " else: line_in_chars += char - line = line_in_chars - return self._post_tokenizer(line) + return self._post_tokenizer(line_in_chars) diff --git a/sacrebleu/utils.py b/sacrebleu/utils.py index e692e982..5a390202 100644 --- a/sacrebleu/utils.py +++ b/sacrebleu/utils.py @@ -1,20 +1,19 @@ -# -*- coding: utf-8 -*- - -import gzip -import hashlib -import logging -import math import os -import portalocker import re import sys -import ssl -import urllib.request +import gzip +import math +import hashlib +import logging +from collections import defaultdict +from typing import List, Optional, Sequence, Dict +from argparse import Namespace -from itertools import filterfalse -from typing import List -from .dataset import DATASETS, SUBSETS +import portalocker +from tabulate import tabulate +import colorama +from .dataset import DATASETS, SUBSETS, DOMAINS, COUNTRIES # Where to store downloaded test sets. # Define the environment variable $SACREBLEU, or use the default of ~/.sacrebleu. @@ -27,6 +26,203 @@ sacrelogger = logging.getLogger('sacrebleu') + +class Color: + ENABLE_COLORS = True + + @staticmethod + def format(msg: str, color: str) -> str: + """Returns a colored version of the given message string. + + :param msg: The string to Color.format. + :param color: The color specifier i.e. 'red', 'blue', 'green', etc. + :return: A colored version of the string if the output is a terminal. + """ + if not Color.ENABLE_COLORS: + return msg + _ansi_str = getattr(colorama.Fore, color.upper(), None) + if _ansi_str: + return f'{_ansi_str}{msg}{colorama.Style.RESET_ALL}' + + return msg + + +def _format_score_lines(scores: dict, + width: int = 2, + multiline: bool = True) -> Dict[str, List[str]]: + """Formats the scores prior to tabulating them.""" + new_scores = {'System': scores.pop('System')} + p_val_break_char = '\n' if multiline else ' ' + is_bootstrap = False + + def _color_p_value(p: float): + msg = f'(p = {p:.4f})' + if p > 0.05: + return Color.format(msg, 'red') + return msg + '*' + + for metric, vals in scores.items(): + new_vals = [] + + for result in vals: + if not isinstance(result, str): + # Format result instances + _str = f'{result.score:.{width}f}' + if result.mean is not None: + is_bootstrap = True + _str += f' ({result.mean:.{width}f} ± {result.ci:.{width}f})' + if result.p_value is not None: + _str += p_val_break_char + _color_p_value(result.p_value) + else: + # Already formatted in non paired-test mode + _str = result + + new_vals.append(_str) + + if is_bootstrap: + # Change titles + metric += ' (μ ± 95% CI)' + + new_scores[metric] = new_vals + + return new_scores + + +def print_results_table(results: dict, signatures: dict, args: Namespace): + """Prints out a nicely formatted table for multi-system evaluation mode.""" + + tablefmt = args.format + if tablefmt in ('text', 'json'): + # Fallback to simple table if json is given + tablefmt = 'fancy_grid' + elif tablefmt == 'latex': + # Use booktabs + tablefmt = 'latex_booktabs' + + # If paired testing has been given, this'll format the score lines + results = _format_score_lines( + results, args.width, multiline=tablefmt == 'fancy_grid') + + new_dict = {} + + # Color the column names and the baseline system name and scores + has_baseline = False + baseline_name = '' + for name in results.keys(): + val = results[name] + if val[0].startswith('Baseline:') or has_baseline: + if val[0].startswith('Baseline:'): + baseline_name = val[0] + has_baseline = True + val[0] = Color.format(val[0], 'yellow') + new_dict[Color.format(name, 'cyan')] = results[name] + + # Finally tabulate + table = tabulate( + new_dict, headers='keys', tablefmt=tablefmt, + colalign=('right', ), + stralign='center', + numalign='center', + floatfmt=f'.{args.width}f') + + print(table) + print() + + is_paired = args.paired_bs or args.paired_ar + + if is_paired: + test_type = 'bootstrap resampling' if args.paired_bs else 'approximate randomization' + n_samples_or_trials = args.paired_bs_n if args.paired_bs else args.paired_ar_n + test_sample_type = 'resampling trials' if args.paired_bs else 'trials' + msg = f'Paired {test_type} test with {n_samples_or_trials} {test_sample_type}' + + bline = Color.format('baseline', 'yellow') + bline_name = Color.format(baseline_name, 'yellow') + null_hyp = Color.format('Null hypothesis', 'green') + pval_color = Color.format('highlighted in red', 'red') + + # Print fancy header + print('-' * len(msg) + '\n' + msg + '\n' + '-' * len(msg)) + print(f' - Each system is pairwise compared to {bline_name}.') + if args.paired_bs: + print(' Actual system score / bootstrap estimated true mean / 95% CI are provided for each metric.') + else: + print(' Actual system score is provided for each metric.') + print() + print(f' - {null_hyp}: the system and the {bline} translations are essentially') + print(f' generated by the same underlying process. For a given system and the {bline},') + print(' the p-value is roughly the probability of the absolute score difference (delta)') + print(f' or higher occurring due to chance, under the assumption that the {null_hyp.lower()} is correct.') + print() + print(f' - Assuming a significance threshold of 0.05, the {null_hyp.lower()} can be rejected') + print(' for p-values < 0.05 (marked with "*"). This means that the delta is unlikely to be attributed') + print(f' to chance, hence the system is significantly "different" than the {bline}.') + print(f' Otherwise, the p-values are {pval_color}.') + print() + print(f' - NOTE: Significance does not tell whether a system is "better" than the {bline} but rather') + print(' emphasizes the "difference" of the systems in terms of the replicability of the delta.') + print() + + print('-----------------') + print('Metric signatures') + print('-----------------') + for name, sig in signatures.items(): + print(f' - {name:<10} {sig}') + + +def print_single_results(results: List[str], args: Namespace): + """Re-process metric strings to align them nicely.""" + if args.format == 'json': + if len(results) > 1: + proper_json = '[\n' + ',\n'.join(results) + '\n]' + print(proper_json) + else: + print(results[0]) + return + + # Color confidence strings for emphasis + if 'μ' in results[0]: + color_re = re.compile(r'(\(μ = [0-9\.]+ ± [0-9\.]+\))') + for idx in range(len(results)): + results[idx] = color_re.sub( + lambda m: Color.format(m.group(), 'cyan'), results[idx]) + + if len(results) == 1: + # Just one system, nothing to align. + print(results[0]) + return + + # Align by '=' character + lens = [] + for line in results: + # If not score_only, split lines from '=' for re-alignment + try: + lens.append(line.index('=') - 1) + except ValueError: + print(line) + + if len(lens) > 0: + w = max(lens) + for (_len, line) in zip(lens, results): + left, right = line[:_len], line[_len:] + print(f'{left:>{w}}{right}') + + +def sanity_check_lengths(system: Sequence[str], + refs: Sequence[Sequence[str]], + test_set: Optional[str] = None): + n_hyps = len(system) + if any(len(ref_stream) != n_hyps for ref_stream in refs): + sacrelogger.error("System and reference streams have different lengths.") + if test_set: + sacrelogger.error("This could be an issue with your system output " + "or with sacreBLEU's reference database if -t is given.") + sacrelogger.error("For the latter, try cleaning out the cache by typing:\n") + sacrelogger.error(f" rm -r {SACREBLEU_DIR}/{test_set}\n") + sacrelogger.error("The test sets will be re-downloaded the next time you run sacreBLEU.") + sys.exit(1) + + def smart_open(file, mode='rt', encoding='utf-8'): """Convenience function for reading compressed or plain text files. :param file: The file to read. @@ -38,7 +234,7 @@ def smart_open(file, mode='rt', encoding='utf-8'): return open(file, mode=mode, encoding=encoding, newline="\n") -def my_log(num): +def my_log(num: float) -> float: """ Floors the log function @@ -51,6 +247,33 @@ def my_log(num): return math.log(num) +def sum_of_lists(lists): + """Aggregates list of numeric lists by summing.""" + if len(lists) == 1: + return lists[0] + + # Preserve datatype + size = len(lists[0]) + init_val = type(lists[0][0])(0.0) + total = [init_val] * size + for ll in lists: + for i in range(size): + total[i] += ll[i] + return total + + +def args_to_dict(args, prefix: str, strip_prefix: bool = False): + """Filters argparse's `Namespace` into dictionary with arguments + beginning with the given prefix.""" + prefix += '_' + d = {} + for k, v in args.__dict__.items(): + if k.startswith(prefix): + k = k.replace(prefix, '') if strip_prefix else k + d[k] = v + return d + + def process_to_text(rawfile, txtfile, field: int = None): """Processes raw files to plain text files. Can handle SGML, XML, TSV files, and plain text. Called after downloading datasets. @@ -69,7 +292,7 @@ def _clean(s): return re.sub(r'\s+', ' ', s.strip()) if not os.path.exists(txtfile) or os.path.getsize(txtfile) == 0: - sacrelogger.info("Processing %s to %s", rawfile, txtfile) + sacrelogger.info(f"Processing {rawfile} to {txtfile}") if rawfile.endswith('.sgm') or rawfile.endswith('.sgml'): with smart_open(rawfile) as fin, smart_open(txtfile, 'wt') as fout: for line in fin: @@ -115,7 +338,7 @@ def print_test_set(test_set, langpair, side, origlang=None, subset=None): print('\t'.join(map(lambda x: x.rstrip(), lines))) -def get_source_file(test_set, langpair): +def get_source_file(test_set: str, langpair: str) -> str: """ Returns the source file for a given testset/langpair. Downloads it first if it is not already local. @@ -127,7 +350,7 @@ def get_source_file(test_set, langpair): return get_files(test_set, langpair)[0] -def get_reference_files(test_set, langpair): +def get_reference_files(test_set: str, langpair: str) -> List[str]: """ Returns a list of one or more reference file paths for the given testset/langpair. Downloads the references first if they are not already local. @@ -139,7 +362,7 @@ def get_reference_files(test_set, langpair): return get_files(test_set, langpair)[1:] -def get_files(test_set, langpair): +def get_files(test_set, langpair) -> List[str]: """ Returns the path of the source file and all reference files for the provided test set / language pair. @@ -151,25 +374,29 @@ def get_files(test_set, langpair): """ if test_set not in DATASETS: - raise Exception("No such test set {}".format(test_set)) + raise Exception(f"No such test set {test_set}") if langpair not in DATASETS[test_set]: - raise Exception("No such language pair {}/{}".format(test_set, langpair)) + raise Exception(f"No such language pair {test_set}/{langpair}") cachedir = os.path.join(SACREBLEU_DIR, test_set) source, target = langpair.split("-") - source_path = os.path.join(cachedir, "{}.{}".format(langpair, source)) + source_path = os.path.join(cachedir, f"{langpair}.{source}") num_refs = len(DATASETS[test_set][langpair]) - 1 if num_refs == 1: - reference_paths = [os.path.join(cachedir, "{}.{}".format(langpair, target))] + reference_paths = [os.path.join(cachedir, f"{langpair}.{target}")] else: - reference_paths = [os.path.join(cachedir, "{}.{}.{}".format(langpair, target, num)) for num in range(num_refs)] + reference_paths = [os.path.join(cachedir, f"{langpair}.{target}.{num}") for num in range(num_refs)] - if any(filterfalse(os.path.exists, [source_path] + reference_paths)): - download_test_set(test_set, langpair) + all_files = [source_path] + reference_paths - return [source_path] + reference_paths + for fname in all_files: + if not os.path.exists(fname): + download_test_set(test_set, langpair) + break + + return all_files def download_test_set(test_set, langpair=None): @@ -181,7 +408,10 @@ def download_test_set(test_set, langpair=None): """ if test_set not in DATASETS: - raise Exception("No such test set {}".format(test_set)) + raise Exception(f"No such test set {test_set}") + + import urllib.request + import ssl outdir = os.path.join(SACREBLEU_DIR, test_set) os.makedirs(outdir, exist_ok=True) @@ -191,17 +421,17 @@ def download_test_set(test_set, langpair=None): tarball = os.path.join(outdir, os.path.basename(dataset)) rawdir = os.path.join(outdir, 'raw') - lockfile = '{}.lock'.format(tarball) + lockfile = f'{tarball}.lock' with portalocker.Lock(lockfile, 'w', timeout=60): if not os.path.exists(tarball) or os.path.getsize(tarball) == 0: - sacrelogger.info("Downloading %s to %s", dataset, tarball) + sacrelogger.info(f"Downloading {dataset} to {tarball}") try: with urllib.request.urlopen(dataset) as f, open(tarball, 'wb') as out: out.write(f.read()) except ssl.SSLError: sacrelogger.warning('An SSL error was encountered in downloading the files. If you\'re on a Mac, ' - 'you may need to run the "Install Certificates.command" file located in the ' - '"Python 3" folder, often found under /Applications') + 'you may need to run the "Install Certificates.command" file located in the ' + '"Python 3" folder, often found under /Applications') sys.exit(1) # Check md5sum @@ -210,16 +440,17 @@ def download_test_set(test_set, langpair=None): with open(tarball, 'rb') as infile: for line in infile: md5.update(line) - if md5.hexdigest() != expected_md5: - sacrelogger.error('Fatal: MD5 sum of downloaded file was incorrect (got {}, expected {}).'.format(md5.hexdigest(), expected_md5)) - sacrelogger.error('Please manually delete "{}" and rerun the command.'.format(tarball)) + cur_md5 = md5.hexdigest() + if cur_md5 != expected_md5: + sacrelogger.error(f'Fatal: MD5 sum of downloaded file was incorrect (got {cur_md5}, expected {expected_md5}).') + sacrelogger.error(f'Please manually delete {tarball!r} and rerun the command.') sacrelogger.error('If the problem persists, the tarball may have changed, in which case, please contact the SacreBLEU maintainer.') sys.exit(1) else: - sacrelogger.info('Checksum passed: {}'.format(md5.hexdigest())) + sacrelogger.info(f'Checksum passed: {cur_md5}') # Extract the tarball - sacrelogger.info('Extracting %s', tarball) + sacrelogger.info(f'Extracting {tarball}') if tarball.endswith('.tar.gz') or tarball.endswith('.tgz'): import tarfile with tarfile.open(tarball) as tar: @@ -241,7 +472,7 @@ def download_test_set(test_set, langpair=None): field, rawfile = rawfile.split(':', maxsplit=1) field = int(field) rawpath = os.path.join(rawdir, rawfile) - outpath = os.path.join(outdir, '{}.{}'.format(pair, src)) + outpath = os.path.join(outdir, f'{pair}.{src}') process_to_text(rawpath, outpath, field=field) file_paths.append(outpath) @@ -253,26 +484,26 @@ def download_test_set(test_set, langpair=None): field = int(field) rawpath = os.path.join(rawdir, ref) if len(refs) >= 2: - outpath = os.path.join(outdir, '{}.{}.{}'.format(pair, tgt, i)) + outpath = os.path.join(outdir, f'{pair}.{tgt}.{i}') else: - outpath = os.path.join(outdir, '{}.{}'.format(pair, tgt)) + outpath = os.path.join(outdir, f'{pair}.{tgt}') process_to_text(rawpath, outpath, field=field) file_paths.append(outpath) return file_paths -def get_langpairs_for_testset(testset: str) -> List: +def get_langpairs_for_testset(testset: str) -> List[str]: """Return a list of language pairs for a given test set.""" return list(filter(lambda x: re.match(r'\w\w\-\w\w', x), DATASETS.get(testset, {}).keys())) -def get_available_testsets() -> List: +def get_available_testsets() -> List[str]: """Return a list of available test sets.""" return sorted(DATASETS.keys(), reverse=True) -def get_available_origlangs(test_sets, langpair): +def get_available_origlangs(test_sets, langpair) -> List[str]: """Return a list of origlang values in according to the raw SGM files.""" if test_sets is None: return [] @@ -296,11 +527,15 @@ def filter_subset(systems, test_sets, langpair, origlang, subset=None): if test_sets is None or langpair is None: raise ValueError('Filtering for --origlang or --subset needs a test (-t) and a language pair (-l).') + re_origlang = re.compile(r'.* origlang="([^"]+)".*\n') + re_id = re.compile(r'.* docid="([^"]+)".*\n') + indices_to_keep = [] + for test_set in test_sets.split(','): rawfile = os.path.join(SACREBLEU_DIR, test_set, 'raw', DATASETS[test_set][langpair][0]) if not rawfile.endswith('.sgm'): - raise Exception('--origlang and --subset supports only *.sgm files, not %s', rawfile) + raise Exception(f'--origlang and --subset supports only *.sgm files, not {rawfile!r}') if subset is not None: if test_set not in SUBSETS: raise Exception('No subset annotation available for test set ' + test_set) @@ -313,16 +548,59 @@ def filter_subset(systems, test_sets, langpair, origlang, subset=None): if origlang is None: include_doc = True else: - doc_origlang = re.sub(r'.* origlang="([^"]+)".*\n', '\\1', line) + doc_origlang = re_origlang.sub(r'\1', line) if origlang.startswith('non-'): include_doc = doc_origlang != origlang[4:] else: include_doc = doc_origlang == origlang + if subset is not None: - doc_id = re.sub(r'.* docid="([^"]+)".*\n', '\\1', line) + doc_id = re_id.sub(r'\1', line) if not re.search(subset, doc_to_tags.get(doc_id, '')): include_doc = False if line.startswith(' [no-cache] ', end='') + measure(klass, kwargs, systems, refs, cache=False) + + print(' > [cached] ', end='') + measure(klass, kwargs, systems, refs, cache=True) diff --git a/setup.py b/setup.py index a54b820a..85ff1938 100755 --- a/setup.py +++ b/setup.py @@ -73,31 +73,39 @@ def get_description(): return DESCRIPTION_RE.search(init).group(1) -setup( - name = 'sacrebleu', +def get_long_description(): + with open('README.md') as f: + long_description = f.read() - # Versions should comply with PEP440. For a discussion on single-sourcing - # the version across setup.py and the project code, see - # https://packaging.python.org/en/latest/single_source_version.html - version = get_version(), + with open('CHANGELOG.md') as f: + release_notes = f.read() - description = get_description(), + # Plug release notes into the long description + long_description = long_description.replace( + '# Release Notes\n\nPlease see [CHANGELOG.md](CHANGELOG.md) for release notes.', + release_notes) - long_description = 'SacreBLEU is a standard BLEU implementation that downloads and manages WMT datasets, produces scores on detokenized outputs, and reports a string encapsulating BLEU parameters, facilitating the production of shareable, comparable BLEU scores.', + return long_description - # The project's main homepage. - url = 'https://github.com/mjpost/sacrebleu', - author = 'Matt Post', +setup( + name='sacrebleu', + # Versions should comply with PEP440. For a discussion on single-sourcing + # the version across setup.py and the project code, see + # https://packaging.python.org/en/latest/single_source_version.html + version=get_version(), + description=get_description(), + long_description_content_type='text/markdown', + long_description=get_long_description(), + url='https://github.com/mjpost/sacrebleu', + author='Matt Post', author_email='post@cs.jhu.edu', maintainer_email='post@cs.jhu.edu', - - license = 'Apache License 2.0', - - python_requires = '>=3', - + license='Apache License 2.0', + # We don't support Python < 3.6 anymore + python_requires='>=3.6', # See https://pypi.python.org/pypi?%3Aaction=list_classifiers - classifiers = [ + classifiers=[ # How mature is this project? Common values are # 3 - Alpha # 4 - Beta @@ -114,16 +122,21 @@ def get_description(): # Pick your license as you wish (should match "license" above) 'License :: OSI Approved :: Apache Software License', + # List operating systems + 'Operating System :: POSIX', + 'Operating System :: MacOS :: MacOS X', + 'Operating System :: Microsoft :: Windows', + # Specify the Python versions you support here. In particular, ensure # that you indicate whether you support Python 2, Python 3 or both. 'Programming Language :: Python :: 3 :: Only', ], # What does your project relate to? - keywords = ['machine translation, evaluation, NLP, natural language processing, computational linguistics'], + keywords=['machine translation, evaluation, NLP, natural language processing, computational linguistics'], # Which packages to deploy (currently sacrebleu, sacrebleu.matrics and sacrebleu.tokenizers)? - packages = find_packages(), + packages=find_packages(), # Mark sacrebleu (and recursively all its sub-packages) as supporting mypy type hints (see PEP 561). package_data={"sacrebleu": ["py.typed"]}, @@ -132,16 +145,13 @@ def get_description(): # your project is installed. For an analysis of "install_requires" vs pip's # requirements files see: # https://packaging.python.org/en/latest/requirements.html - install_requires = [ - 'typing;python_version<"3.5"', - 'portalocker==2.0.0', - ], + install_requires=['portalocker', 'regex', 'tabulate>=0.8.9', 'numpy>=1.17', 'colorama'], # List additional groups of dependencies here (e.g. development # dependencies). You can install these using the following syntax, # for example: # $ pip install -e .[dev,test] - extras_require = {'ja': ['mecab-python3==1.0.3', 'ipadic>=1.0,<2.0'] }, + extras_require={'ja': ['mecab-python3==1.0.3', 'ipadic>=1.0,<2.0']}, # To provide executable scripts, use entry points in preference to the # "scripts" keyword. Entry points provide cross-platform support and allow diff --git a/test.sh b/test.sh index 69b4912e..81d2b494 100755 --- a/test.sh +++ b/test.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. # @@ -23,20 +23,42 @@ set -u if [[ $(echo $BASH_VERSION | cut -d. -f1) -lt 4 ]]; then - echo "This script requires BASH version 4 or above (since it uses hashes)." - exit 1 + echo "This script requires BASH version 4 or above (since it uses hashes)." + exit 1 +fi + +# Switch from JSON output to text +export SACREBLEU_FORMAT="text" + +# Cleanup temporary files +trap "rm -f .tmp* data/.tmp*" EXIT INT TERM + +# For Travis CI to work on Windows/Mac OS X +if [[ "$OSTYPE" == "linux-gnu"* ]]; then + CMD="python -m sacrebleu" +elif [[ "$OSTYPE" == "msys" ]]; then + CMD="python -m sacrebleu" +elif [[ "$OSTYPE" == "darwin"* ]]; then + # OS X ships python -> python2 by default, be explicit + CMD="python3 -m sacrebleu" fi export SACREBLEU=$(pwd)/.sacrebleu export PYTHONPATH="${PWD}" # assuming PYTHONPATH=. as the default -CMD="python3 -m sacrebleu" +export NO_COLOR=1 # Only run this test limit_test=${1:-} +SKIP_INITIAL=${SKIP_INITIAL:-} SKIP_CHRF=${SKIP_CHRF:-} SKIP_TER=${SKIP_TER:-} SKIP_MECAB=${SKIP_MECAB:-} +SKIP_MTEVAL13=${SKIP_MTEVAL13:-} +SKIP_MTEVAL14=${SKIP_MTEVAL14:-} + +# test case counter +declare -i i=0 # TEST 1: download and process WMT17 data [[ -d $SACREBLEU/wmt17 ]] && rm -f $SACREBLEU/wmt17/{en-*,*-en*} @@ -48,44 +70,22 @@ declare -A EXPECTED EXPECTED["${CMD} -t wmt16,wmt17 -l en-fi --echo ref | ${CMD} -b -w 4 -t wmt16/B,wmt17/B -l en-fi"]=53.7432 EXPECTED["${CMD} -t wmt16,wmt17 -l en-fi --echo ref | ${CMD} -b -w 4 -t wmt16/B,wmt17/B -l en-fi --origlang=en"]=18.9054 EXPECTED["${CMD} -t wmt17 -l en-fi --echo ref | ${CMD} -b -t wmt17/B -l en-fi --detail"]="55.6 -origlang=en : sentences=1502 BLEU= 21.4 -origlang=fi : sentences=1500 BLEU=100.0" +origlang=en : sentences=1502 BLEU = 21.4 +origlang=fi : sentences=1500 BLEU = 100.0" EXPECTED["${CMD} -t wmt18,wmt19 -l en-de --echo=src | ${CMD} -t wmt18,wmt19 -l en-de -b --detail"]="3.6 -origlang=de : sentences=1498 BLEU= 3.6 -origlang=en : sentences=3497 BLEU= 3.5 -origlang=en country=EU: sentences= 265 BLEU= 2.5 -origlang=en country=GB: sentences= 913 BLEU= 3.1 -origlang=en country=OTHER: sentences= 801 BLEU= 2.5 -origlang=en country=US: sentences=1518 BLEU= 4.2 -origlang=en domain=business: sentences= 241 BLEU= 3.4 -origlang=en domain=crime: sentences= 570 BLEU= 3.6 -origlang=en domain=entertainment: sentences= 322 BLEU= 5.1 -origlang=en domain=politics: sentences= 959 BLEU= 3.0 -origlang=en domain=scitech: sentences= 211 BLEU= 3.1 -origlang=en domain=sport: sentences= 534 BLEU= 3.6 -origlang=en domain=world: sentences= 660 BLEU= 3.1" - -for command in "${!EXPECTED[@]}"; do - echo Testing $command - obtained=`eval $command` - expected=${EXPECTED[$command]} - if [[ $obtained != $expected ]]; then - echo -e "\nFAILED:\n expected = $expected\n obtained = $obtained" - exit 1 - fi - echo PASS -done - -# Test loading via file instead of STDIN -echo "Testing loading via file instead of STDIN" -${CMD} -t wmt17 -l en-de --echo ref > .wmt17.en-de.de.tmp -score=$(${CMD} -t wmt17 -l en-de -i .wmt17.en-de.de.tmp -b) -rm .wmt17.en-de.de.tmp -if [[ $score != '100.0' ]]; then - echo "File test failed." - exit 1 -fi -echo PASS +origlang=de : sentences=1498 BLEU = 3.6 +origlang=en : sentences=3497 BLEU = 3.5 +origlang=en country=EU : sentences=265 BLEU = 2.5 +origlang=en country=GB : sentences=913 BLEU = 3.1 +origlang=en country=OTHER : sentences=801 BLEU = 2.5 +origlang=en country=US : sentences=1518 BLEU = 4.2 +origlang=en domain=business : sentences=241 BLEU = 3.4 +origlang=en domain=crime : sentences=570 BLEU = 3.6 +origlang=en domain=entertainment : sentences=322 BLEU = 5.1 +origlang=en domain=politics : sentences=959 BLEU = 3.0 +origlang=en domain=scitech : sentences=211 BLEU = 3.1 +origlang=en domain=sport : sentences=534 BLEU = 3.6 +origlang=en domain=world : sentences=660 BLEU = 3.1" [[ ! -d data ]] && mkdir data cd data @@ -102,79 +102,111 @@ if [[ ! -d en-ja-translation-example-master ]]; then unzip master.zip fi -# Test echoing of source, reference, and both -${CMD} -t wmt17/ms -l zh-en --echo src > .tmp.echo -diff .tmp.echo $SACREBLEU/wmt17/ms/zh-en.zh -if [[ $? -ne 0 ]]; then - echo "Source echo failed." - exit 1 -fi -${CMD} -t wmt17/ms -l zh-en --echo ref | cut -f3 > .tmp.echo -diff .tmp.echo $SACREBLEU/wmt17/ms/zh-en.en.2 -if [[ $? -ne 0 ]]; then - echo "Source echo failed." - exit 1 -fi - -export LC_ALL=C - -declare -i i=0 - -echo '-----------------------' -echo 'Control character tests' -echo '-----------------------' -score1=$( echo "Hello! How are you doing today?" | ${CMD} -w 2 -b <(printf "Hello! How are you \r doing today?") ) -score2=$( echo "Hello! How are you doing today?" | ${CMD} -w 2 -b <(echo "Hello! How are you doing today?") ) -if [[ $score1 != $score2 ]]; then - echo "Control character in reference test failed" - exit 1 -fi -let i++ -echo "Passed control character in reference test" - -##################################################################### -# Tests for single-ref BLEU, multi-ref BLEU, signature and tokenizers -##################################################################### -path="wmt17-submitted-data/txt/system-outputs/newstest2017/cs-en" -ref1="${path}/newstest2017.online-A.0.cs-en" -ref2="${path}/newstest2017.online-B.0.cs-en" -sys="${path}/newstest2017.PJATK.4760.cs-en" - -echo '---------------------' -echo 'BLEU regression tests' -echo '---------------------' -unset EXPECTED -declare -A EXPECTED +if [ -z $SKIP_INITIAL ]; then + for command in "${!EXPECTED[@]}"; do + echo Testing $command + # Convert line endings to UNIX for Windows tests + obtained=`eval $command | tr -d '\015'` + expected=${EXPECTED[$command]} + if [[ $obtained != $expected ]]; then + echo -e "\nFAILED:\n expected = $expected\n obtained = $obtained" + exit 1 + fi + echo PASS + done -# Single ref, tokenizer variants, lowercase -EXPECTED["${CMD} -w 4 -b -l cs-en -i $sys $ref1"]=36.8799 -EXPECTED["${CMD} -lc -w 4 -b -l cs-en -i $sys $ref1"]=38.1492 -EXPECTED["${CMD} --tokenize 13a -w 4 -b -l cs-en -i $sys $ref1"]=36.8799 -EXPECTED["${CMD} --tokenize none -w 4 -b -l cs-en -i $sys $ref1"]=34.0638 -EXPECTED["${CMD} --tokenize intl -w 4 -b -l cs-en -i $sys $ref1"]=37.3859 -# multiple REF files -EXPECTED["${CMD} -w 4 -b -l cs-en -i $sys $ref1 $ref2"]=44.6732 -# multiple REFs with tab-delimited stream -EXPECTED["${CMD} -w 4 -b -l cs-en -i $sys --num-refs 2 <(paste $ref1 $ref2)"]=44.6732 -# Check signature correctness for multi-reference -# separate files -EXPECTED["${CMD} -l cs-en -i $sys $ref1 $ref2 | perl -pe 's/.*numrefs\.([0-9]).*/\1/'"]=2 -# tab delimited stream -EXPECTED["${CMD} -l cs-en -i $sys --num-refs 2 <(paste $ref1 $ref2) | perl -pe 's/.*numrefs\.([0-9]).*/\1/'"]=2 - - -# Run the tests -for command in "${!EXPECTED[@]}"; do - echo Testing $command - obtained=`eval $command` - expected=${EXPECTED[$command]} - if [[ $obtained != $expected ]]; then - echo -e "\nFAILED:\n expected = $expected\n obtained = $obtained" + # Test loading via file instead of STDIN + echo "Testing loading via file instead of STDIN" + ${CMD} -t wmt17 -l en-de --echo ref > .tmp.wmt17 + score=$(${CMD} -t wmt17 -l en-de -i .tmp.wmt17 -b) + if [[ $score != '100.0' ]]; then + echo "File test failed." exit 1 fi echo PASS + + # Test echoing of source, reference, and both + # Replace \r\n with \n for Windows compatibility + ${CMD} -t wmt17/ms -l zh-en --echo src | tr -d '\015' > .tmp.echo + diff .tmp.echo $SACREBLEU/wmt17/ms/zh-en.zh + if [[ $? -ne 0 ]]; then + echo "Source echo failed." + exit 1 + fi + ${CMD} -t wmt17/ms -l zh-en --echo ref | tr -d '\015' | cut -f3 > .tmp.echo + diff .tmp.echo $SACREBLEU/wmt17/ms/zh-en.en.2 + if [[ $? -ne 0 ]]; then + echo "Source echo failed." + exit 1 + fi + + export LC_ALL=C + + echo '-----------------------' + echo 'Control character tests' + echo '-----------------------' + printf "Hello! How are you \r doing today?" > .tmp_ref_buggy + echo "Hello! How are you doing today?" > .tmp_ref_okay + score1=$( echo "Hello! How are you doing today?" | ${CMD} -w 2 -b .tmp_ref_buggy ) + score2=$( echo "Hello! How are you doing today?" | ${CMD} -w 2 -b .tmp_ref_okay ) + if [[ $score1 != $score2 ]]; then + echo "Control character in reference test failed" + exit 1 + fi let i++ -done + echo "Passed control character in reference test" + + ##################################################################### + # Tests for single-ref BLEU, multi-ref BLEU, signature and tokenizers + ##################################################################### + path="wmt17-submitted-data/txt/system-outputs/newstest2017/cs-en" + ref1="${path}/newstest2017.online-A.0.cs-en" + ref2="${path}/newstest2017.online-B.0.cs-en" + sys="${path}/newstest2017.PJATK.4760.cs-en" + + echo '--------------------------' + echo 'BLEU/CHRF regression tests' + echo '--------------------------' + unset EXPECTED + declare -A EXPECTED + + paste $ref1 $ref2 > .tmp_refs + + # Single ref, tokenizer variants, lowercase + EXPECTED["${CMD} -w 4 -b -l cs-en $ref1 -i $sys"]=36.8799 + EXPECTED["${CMD} -lc -w 4 -b -l cs-en $ref1 -i $sys"]=38.1492 + EXPECTED["${CMD} --tokenize 13a -w 4 -b -l cs-en $ref1 -i $sys"]=36.8799 + EXPECTED["${CMD} --tokenize none -w 4 -b -l cs-en $ref1 -i $sys"]=34.0638 + # multiple REF files + EXPECTED["${CMD} -w 4 -b -l cs-en $ref1 $ref2 -i $sys"]=44.6732 + # multiple REF CHRF (epsilon smoothing) + EXPECTED["${CMD} -m chrf --chrf-eps-smoothing -w 4 -b -l cs-en $ref1 $ref2 -i $sys"]=67.8596 + # multiple REF CHRF (effective order smoothing) + EXPECTED["${CMD} -m chrf -w 4 -b -l cs-en $ref1 $ref2 -i $sys"]=67.8603 + # multiple REF CHRF++ + EXPECTED["${CMD} -m chrf -cw 2 -w 4 -b -l cs-en $ref1 $ref2 -i $sys"]=66.1016 + # multiple REFs with tab-delimited stream + EXPECTED["${CMD} -w 4 -b -l cs-en --num-refs 2 .tmp_refs -i $sys"]=44.6732 + # Check signature correctness for multi-reference + # separate files + EXPECTED["${CMD} -l cs-en $ref1 $ref2 -i $sys | perl -pe 's/.*nrefs:([0-9]).*/\1/'"]=2 + # tab delimited stream + EXPECTED["${CMD} -l cs-en --num-refs 2 .tmp_refs -i $sys | perl -pe 's/.*nrefs:([0-9]).*/\1/'"]=2 + + + # Run the tests + for command in "${!EXPECTED[@]}"; do + echo Testing $command + obtained=`eval $command` + expected=${EXPECTED[$command]} + if [[ $obtained != $expected ]]; then + echo -e "\nFAILED:\n expected = $expected\n obtained = $obtained" + exit 1 + fi + echo PASS + let i++ + done +fi ####################################################### # Pre-computed chrF scores from official implementation @@ -182,163 +214,48 @@ done ####################################################### declare -A CHRF=( ["newstest2017.PJATK.4760.cs-en.sgm"]=52.5947 ["newstest2017.online-A.0.cs-en.sgm"]=53.3856 - ["newstest2017.online-B.0.cs-en.sgm"]=54.4608 - ["newstest2017.uedin-nmt.4955.cs-en.sgm"]=56.8490 - ["newstest2017.C-3MA.4958.de-en.sgm"]=54.9500 - ["newstest2017.KIT.4951.de-en.sgm"]=59.5876 - ["newstest2017.LIUM-NMT.4733.de-en.sgm"]=56.1531 - ["newstest2017.RWTH-nmt-ensemble.4920.de-en.sgm"]=58.8482 - ["newstest2017.SYSTRAN.4846.de-en.sgm"]=58.6623 - ["newstest2017.TALP-UPC.4830.de-en.sgm"]=55.6962 - ["newstest2017.online-A.0.de-en.sgm"]=59.1026 - ["newstest2017.online-B.0.de-en.sgm"]=59.0564 - ["newstest2017.online-F.0.de-en.sgm"]=50.2126 - ["newstest2017.online-G.0.de-en.sgm"]=55.6530 - ["newstest2017.uedin-nmt.4723.de-en.sgm"]=60.1464 ["newstest2017.CU-Chimera.4886.en-cs.sgm"]=48.3370 ["newstest2017.LIUM-FNMT.4852.en-cs.sgm"]=48.4708 - ["newstest2017.LIUM-NMT.4947.en-cs.sgm"]=48.4079 - ["newstest2017.PJATK.4761.en-cs.sgm"]=43.0152 - ["newstest2017.limsi-factored-norm.4957.en-cs.sgm"]=48.7015 - ["newstest2017.online-A.0.en-cs.sgm"]=45.9326 - ["newstest2017.online-B.0.en-cs.sgm"]=48.4691 - ["newstest2017.tuning-task-afrl_4gb.sgm.0.en-cs.sgm"]=40.8498 - ["newstest2017.tuning-task-afrl_8gb.sgm.0.en-cs.sgm"]=41.3727 - ["newstest2017.tuning-task-baseline_4gb.sgm.0.en-cs.sgm"]=40.4781 - ["newstest2017.tuning-task-baseline_8gb.sgm.0.en-cs.sgm"]=40.5823 - ["newstest2017.tuning-task-denisov_4gb.sgm.0.en-cs.sgm"]=39.9792 - ["newstest2017.tuning-task-ufal_4gb.sgm.0.en-cs.sgm"]=39.4850 - ["newstest2017.tuning-task-ufal_8gb.sgm.0.en-cs.sgm"]=42.4445 - ["newstest2017.uedin-nmt.4956.en-cs.sgm"]=50.5857 ["newstest2017.C-3MA.4959.en-de.sgm"]=51.9533 ["newstest2017.FBK.4870.en-de.sgm"]=54.7152 - ["newstest2017.KIT.4950.en-de.sgm"]=55.7629 - ["newstest2017.LIUM-NMT.4900.en-de.sgm"]=55.9284 - ["newstest2017.LMU-nmt-reranked.4934.en-de.sgm"]=56.3908 - ["newstest2017.LMU-nmt-single.4893.en-de.sgm"]=55.9216 - ["newstest2017.PROMT-Rule-based.4735.en-de.sgm"]=50.3511 - ["newstest2017.RWTH-nmt-ensemble.4921.en-de.sgm"]=55.6116 - ["newstest2017.SYSTRAN.4847.en-de.sgm"]=55.5758 - ["newstest2017.TALP-UPC.4834.en-de.sgm"]=51.6860 - ["newstest2017.online-A.0.en-de.sgm"]=52.0023 - ["newstest2017.online-B.0.en-de.sgm"]=56.2633 - ["newstest2017.online-F.0.en-de.sgm"]=49.2588 - ["newstest2017.online-G.0.en-de.sgm"]=51.5871 - ["newstest2017.uedin-nmt.4722.en-de.sgm"]=57.7227 - ["newstest2017.xmu.4910.en-de.sgm"]=55.9642 ["newstest2017.AaltoHnmtFlatcat.4798.en-fi.sgm"]=50.5981 - ["newstest2017.AaltoHnmtMultitask.4873.en-fi.sgm"]=52.4618 - ["newstest2017.HY-AH.4797.en-fi.sgm"]=46.9995 - ["newstest2017.HY-HNMT.4961.en-fi.sgm"]=54.9460 - ["newstest2017.HY-SMT.4882.en-fi.sgm"]=51.2609 - ["newstest2017.TALP-UPC.4939.en-fi.sgm"]=44.8177 - ["newstest2017.apertium-unconstrained.4769.en-fi.sgm"]=21.7725 - ["newstest2017.jhu-nmt-lattice-rescore.4903.en-fi.sgm"]=51.3314 - ["newstest2017.jhu-pbmt.4968.en-fi.sgm"]=49.7043 ["newstest2017.online-A.0.en-fi.sgm"]=49.5458 - ["newstest2017.online-B.0.en-fi.sgm"]=56.1894 - ["newstest2017.online-G.0.en-fi.sgm"]=51.8957 - ["newstest2017.C-3MA.5069.en-lv.sgm"]=43.8029 - ["newstest2017.HY-HNMT.5066.en-lv.sgm"]=46.3223 - ["newstest2017.KIT.5062.en-lv.sgm"]=51.1055 - ["newstest2017.LIUM-FNMT.5043.en-lv.sgm"]=47.9871 - ["newstest2017.LIUM-NMT.5042.en-lv.sgm"]=48.1380 - ["newstest2017.PJATK.4744.en-lv.sgm"]=35.9152 - ["newstest2017.QT21-System-Combination.5063.en-lv.sgm"]=50.6553 - ["newstest2017.jhu-pbmt.4969.en-lv.sgm"]=46.9511 ["newstest2017.limsi-factored-norm.5041.en-lv.sgm"]=49.3634 ["newstest2017.online-A.0.en-lv.sgm"]=45.2101 - ["newstest2017.online-B.0.en-lv.sgm"]=50.1384 - ["newstest2017.tilde-c-nmt-smt-hybrid.5049.en-lv.sgm"]=51.6770 - ["newstest2017.tilde-nc-nmt-smt-hybrid.5047.en-lv.sgm"]=52.7970 - ["newstest2017.tilde-nc-smt.5044.en-lv.sgm"]=51.5999 - ["newstest2017.uedin-nmt.5016.en-lv.sgm"]=49.2607 - ["newstest2017.usfd-consensus-kit.5078.en-lv.sgm"]=50.7400 - ["newstest2017.usfd-consensus-qt21.5077.en-lv.sgm"]=51.0538 - ["newstest2017.PROMT-Rule-based.4736.en-ru.sgm"]=53.2902 - ["newstest2017.afrl-mitll-backtrans.4907.en-ru.sgm"]=52.2807 ["newstest2017.jhu-pbmt.4986.en-ru.sgm"]=54.9569 - ["newstest2017.online-A.0.en-ru.sgm"]=53.4180 ["newstest2017.online-B.0.en-ru.sgm"]=60.4059 - ["newstest2017.online-F.0.en-ru.sgm"]=42.0595 - ["newstest2017.online-G.0.en-ru.sgm"]=56.5493 - ["newstest2017.online-H.0.en-ru.sgm"]=56.8716 - ["newstest2017.uedin-nmt.4756.en-ru.sgm"]=56.6076 - ["newstest2017.JAIST.4858.en-tr.sgm"]=42.1117 ["newstest2017.LIUM-NMT.4953.en-tr.sgm"]=47.5881 - ["newstest2017.jhu-nmt-lattice-rescore.4904.en-tr.sgm"]=42.5309 - ["newstest2017.jhu-pbmt.4970.en-tr.sgm"]=42.1480 - ["newstest2017.online-A.0.en-tr.sgm"]=47.4192 - ["newstest2017.online-B.0.en-tr.sgm"]=54.1855 ["newstest2017.online-G.0.en-tr.sgm"]=48.7404 - ["newstest2017.uedin-nmt.4932.en-tr.sgm"]=50.3093 - ["newstest2017.CASICT-DCU-NMT.5157.en-zh.sgm"]=27.0468 - ["newstest2017.Oregon-State-University-S.5174.en-zh.sgm"]=24.5325 - ["newstest2017.SogouKnowing-nmt.5131.en-zh.sgm"]=31.3259 - ["newstest2017.UU-HNMT.5134.en-zh.sgm"]=22.6901 - ["newstest2017.jhu-nmt.5153.en-zh.sgm"]=27.9123 - ["newstest2017.online-A.0.en-zh.sgm"]=25.6325 - ["newstest2017.online-B.0.en-zh.sgm"]=29.2984 + ["newstest2017.UU-HNMT.5134.en-zh.sgm"]=22.6844 ["newstest2017.online-F.0.en-zh.sgm"]=18.7403 - ["newstest2017.online-G.0.en-zh.sgm"]=20.6007 - ["newstest2017.uedin-nmt.5111.en-zh.sgm"]=31.8748 - ["newstest2017.xmunmt.5165.en-zh.sgm"]=31.7770 - ["newstest2017.Hunter-MT.4925.fi-en.sgm"]=47.9929 - ["newstest2017.TALP-UPC.4937.fi-en.sgm"]=45.7795 - ["newstest2017.apertium-unconstrained.4793.fi-en.sgm"]=38.6486 - ["newstest2017.online-A.0.fi-en.sgm"]=51.9119 - ["newstest2017.online-B.0.fi-en.sgm"]=55.7417 - ["newstest2017.online-G.0.fi-en.sgm"]=53.8541 - ["newstest2017.C-3MA.5067.lv-en.sgm"]=43.3150 - ["newstest2017.Hunter-MT.5092.lv-en.sgm"]=46.1868 - ["newstest2017.PJATK.4740.lv-en.sgm"]=39.3033 - ["newstest2017.jhu-pbmt.4980.lv-en.sgm"]=46.7783 - ["newstest2017.online-A.0.lv-en.sgm"]=47.1552 - ["newstest2017.online-B.0.lv-en.sgm"]=51.4714 - ["newstest2017.tilde-c-nmt-smt-hybrid.5051.lv-en.sgm"]=49.1392 - ["newstest2017.tilde-nc-nmt-smt-hybrid.5050.lv-en.sgm"]=51.5697 - ["newstest2017.uedin-nmt.5017.lv-en.sgm"]=48.0781 - ["newstest2017.NRC.4855.ru-en.sgm"]=60.1860 - ["newstest2017.afrl-mitll-opennmt.4896.ru-en.sgm"]=59.4356 - ["newstest2017.afrl-mitll-syscomb.4905.ru-en.sgm"]=59.7636 - ["newstest2017.jhu-pbmt.4978.ru-en.sgm"]=58.1248 - ["newstest2017.online-A.0.ru-en.sgm"]=57.9992 - ["newstest2017.online-B.0.ru-en.sgm"]=63.0622 - ["newstest2017.online-F.0.ru-en.sgm"]=49.5420 - ["newstest2017.online-G.0.ru-en.sgm"]=61.8913 - ["newstest2017.uedin-nmt.4890.ru-en.sgm"]=57.4335 - ["newstest2017.JAIST.4859.tr-en.sgm"]=43.1983 - ["newstest2017.LIUM-NMT.4888.tr-en.sgm"]=45.3857 - ["newstest2017.PROMT-SMT.4737.tr-en.sgm"]=46.1464 - ["newstest2017.afrl-mitll-m2w-nr1.4901.tr-en.sgm"]=45.7267 - ["newstest2017.afrl-mitll-syscomb.4902.tr-en.sgm"]=46.1653 - ["newstest2017.jhu-pbmt.4972.tr-en.sgm"]=43.2728 - ["newstest2017.online-A.0.tr-en.sgm"]=52.1165 - ["newstest2017.online-B.0.tr-en.sgm"]=54.1508 - ["newstest2017.online-G.0.tr-en.sgm"]=49.4456 - ["newstest2017.uedin-nmt.4931.tr-en.sgm"]=47.8457 - ["newstest2017.CASICT-DCU-NMT.5144.zh-en.sgm"]=49.7426 - ["newstest2017.NMT-Model-Average-Multi-Cards.5099.zh-en.sgm"]=47.3694 - ["newstest2017.NRC.5172.zh-en.sgm"]=53.6810 - ["newstest2017.Oregon-State-University-S.5173.zh-en.sgm"]=47.6272 - ["newstest2017.PROMT-SMT.5125.zh-en.sgm"]=48.3674 - ["newstest2017.ROCMT.5183.zh-en.sgm"]=50.0904 - ["newstest2017.SogouKnowing-nmt.5171.zh-en.sgm"]=55.0223 - ["newstest2017.UU-HNMT.5162.zh-en.sgm"]=45.2487 - ["newstest2017.afrl-mitll-opennmt.5109.zh-en.sgm"]=50.3686 - ["newstest2017.jhu-nmt.5151.zh-en.sgm"]=49.3613 - ["newstest2017.online-A.0.zh-en.sgm"]=53.8268 - ["newstest2017.online-B.0.zh-en.sgm"]=59.2377 - ["newstest2017.online-F.0.zh-en.sgm"]=45.6546 - ["newstest2017.online-G.0.zh-en.sgm"]=49.9084 - ["newstest2017.uedin-nmt.5112.zh-en.sgm"]=53.5398 - ["newstest2017.xmunmt.5160.zh-en.sgm"]=54.3314 + ) + +######################################################### +# Pre-computed chrF++ scores from official implementation +# Cmd: chrF++.py -H hyp -R ref +######################################################### +declare -A CHRFPP=( ["newstest2017.PJATK.4760.cs-en.sgm"]=50.2947 + ["newstest2017.online-A.0.cs-en.sgm"]=51.1037 + ["newstest2017.CU-Chimera.4886.en-cs.sgm"]=45.6732 + ["newstest2017.LIUM-FNMT.4852.en-cs.sgm"]=45.7210 + ["newstest2017.C-3MA.4959.en-de.sgm"]=49.1683 + ["newstest2017.FBK.4870.en-de.sgm"]=52.2330 + ["newstest2017.AaltoHnmtFlatcat.4798.en-fi.sgm"]=46.6295 + ["newstest2017.online-A.0.en-fi.sgm"]=44.9147 + ["newstest2017.limsi-factored-norm.5041.en-lv.sgm"]=45.7117 + ["newstest2017.online-A.0.en-lv.sgm"]=40.8963 + ["newstest2017.jhu-pbmt.4986.en-ru.sgm"]=52.0154 + ["newstest2017.online-B.0.en-ru.sgm"]=57.6404 + ["newstest2017.LIUM-NMT.4953.en-tr.sgm"]=43.8311 + ["newstest2017.online-G.0.en-tr.sgm"]=44.6236 + ["newstest2017.UU-HNMT.5134.en-zh.sgm"]=17.0181 + ["newstest2017.online-F.0.en-zh.sgm"]=14.1572 ) if [ -z $SKIP_CHRF ]; then - echo "-------------------" - echo "Starting chrF tests" - echo "-------------------" + echo "------------------------------" + echo "Starting chrF and chrF++ tests" + echo "------------------------------" # Test only for different target languages as there is no tokenization # issue involved in chrF for pair in cs-en en-cs en-de en-fi en-lv en-ru en-tr en-zh; do @@ -347,25 +264,35 @@ if [ -z $SKIP_CHRF ]; then for sgm in wmt17-submitted-data/sgm/system-outputs/newstest2017/$pair/*.sgm; do name=$(basename $sgm) + if [[ ! -v CHRF[$name] ]]; then continue; fi if [[ ! -z $limit_test && $limit_test != $name ]]; then continue; fi sys=$(basename $sgm .sgm | perl -pe 's/newstest2017\.//') txt=$(dirname $sgm | perl -pe 's/sgm/txt/')/$(basename $sgm .sgm) - src=wmt17-submitted-data/sgm/sources/newstest2017-$source$target-src.$source.sgm - ref=wmt17-submitted-data/sgm/references/newstest2017-$source$target-ref.$target.sgm - score=$(cat $txt | ${CMD} -w 4 -t wmt17 -l $source-$target -b --metrics chrf) + # Test chrF + score=$(cat $txt | ${CMD} -w 4 -t wmt17 -l $source-$target -b --metrics chrf --chrf-eps-smoothing) + expected_score=${CHRF[$name]} + echo "import sys; sys.exit(1 if abs(${score}-${expected_score}) > 1e-6 else 0)" | python + + if [[ $? -eq 1 ]]; then + echo "FAILED chrF test $pair/$sys (wanted $expected_score got $score)" + exit 1 + fi + echo "Passed $source-$target $sys chrF++.py: $expected_score sacreCHRF: $score" + let i++ - # rescale to 0-1 - expected_score=`echo "print('{:.4f}'.format(${CHRF[$name]} / 100.0))" | python` + # Test chrF++ + score=$(cat $txt | ${CMD} -w 4 -t wmt17 -l $source-$target -b --metrics chrf --chrf-word-order 2 --chrf-eps-smoothing) + expected_score=${CHRFPP[$name]} - echo "import sys; sys.exit(1 if abs(${score}-${expected_score}) > 0.01 else 0)" | python + echo "import sys; sys.exit(1 if abs(${score}-${expected_score}) > 1e-6 else 0)" | python if [[ $? -eq 1 ]]; then - echo "FAILED test $pair/$sys (wanted $expected_score got $score)" + echo "FAILED chrF++ test $pair/$sys (wanted $expected_score got $score)" exit 1 fi - echo "Passed $source-$target $sys chrF++.py: $expected_score sacreCHRF: $score" + echo "Passed $source-$target $sys chrF++.py: $expected_score sacreCHRF++: $score" let i++ done @@ -375,9 +302,9 @@ fi ################################################################ # Pre-computed results from Moses' mteval-v13a.pl for BLEU tests ################################################################ -echo "-------------------" -echo "Starting BLEU tests" -echo "-------------------" +echo "------------------------------------" +echo "Starting BLEU tests (mteval-v13a.pl)" +echo "------------------------------------" declare -A MTEVAL=( ["newstest2017.PJATK.4760.cs-en.sgm"]=23.15 ["newstest2017.online-A.0.cs-en.sgm"]=25.12 ["newstest2017.online-B.0.cs-en.sgm"]=27.45 @@ -534,34 +461,93 @@ declare -A MTEVAL=( ["newstest2017.PJATK.4760.cs-en.sgm"]=23.15 ["kyoto-test"]=14.48 ) -for pair in cs-en de-en en-cs en-de en-fi en-lv en-ru en-tr en-zh fi-en lv-en ru-en tr-en zh-en; do - source=$(echo $pair | cut -d- -f1) - target=$(echo $pair | cut -d- -f2) - for sgm in wmt17-submitted-data/sgm/system-outputs/newstest2017/$pair/*.sgm; do - name=$(basename $sgm) +if [ -z $SKIP_MTEVAL13 ]; then + for pair in cs-en de-en en-cs en-de en-fi en-lv en-ru en-tr en-zh fi-en lv-en ru-en tr-en zh-en; do + source=$(echo $pair | cut -d- -f1) + target=$(echo $pair | cut -d- -f2) + for sgm in wmt17-submitted-data/sgm/system-outputs/newstest2017/$pair/*.sgm; do + name=$(basename $sgm) - if [[ ! -z $limit_test && $limit_test != $name ]]; then continue; fi + if [[ ! -v MTEVAL[$name] ]]; then continue; fi + if [[ ! -z $limit_test && $limit_test != $name ]]; then continue; fi - sys=$(basename $sgm .sgm | perl -pe 's/newstest2017\.//') - txt=$(dirname $sgm | perl -pe 's/sgm/txt/')/$(basename $sgm .sgm) - src=wmt17-submitted-data/sgm/sources/newstest2017-$source$target-src.$source.sgm - ref=wmt17-submitted-data/sgm/references/newstest2017-$source$target-ref.$target.sgm + sys=$(basename $sgm .sgm | perl -pe 's/newstest2017\.//') + txt=$(dirname $sgm | perl -pe 's/sgm/txt/')/$(basename $sgm .sgm) + src=wmt17-submitted-data/sgm/sources/newstest2017-$source$target-src.$source.sgm + ref=wmt17-submitted-data/sgm/references/newstest2017-$source$target-ref.$target.sgm - # mteval=$($MOSES/scripts/generic/mteval-v13a.pl -c -s $src -r $ref -t $sgm 2> /dev/null | grep "BLEU score" | cut -d' ' -f9) - # mteval=$(echo "print($bleu1 * 100)" | python) - score=$(cat $txt | ${CMD} -w 2 -t wmt17 -l $source-$target -b) + # mteval=$($MOSES/scripts/generic/mteval-v13a.pl -c -s $src -r $ref -t $sgm 2> /dev/null | grep "BLEU score" | cut -d' ' -f9) + # mteval=$(echo "print($bleu1 * 100)" | python) + score=$(cat $txt | ${CMD} -w 2 -t wmt17 -l $source-$target -b) - echo "import sys; sys.exit(1 if abs($score-${MTEVAL[$name]}) > 0.01 else 0)" | python + echo "import sys; sys.exit(1 if abs($score-${MTEVAL[$name]}) > 0.01 else 0)" | python - if [[ $? -eq 1 ]]; then - echo "FAILED test $pair/$sys (wanted ${MTEVAL[$name]} got $score)" - exit 1 - fi - echo "Passed $source-$target $sys mteval-v13a.pl: ${MTEVAL[$name]} sacreBLEU: $score" + if [[ $? -eq 1 ]]; then + echo "FAILED test $pair/$sys (wanted ${MTEVAL[$name]} got $score)" + exit 1 + fi + echo "Passed $source-$target $sys mteval-v13a.pl: ${MTEVAL[$name]} sacreBLEU: $score" - let i++ - done -done + let i++ + done + done +fi + +############################################################################ +# Pre-computed results from Moses' mteval-v14.pl for BLEU tests +# mteval-v14a.pl -c -s $src -r $ref -t $sgm -b --international-tokenization +############################################################################ +echo "-----------------------------------------------------------------" +echo "Starting BLEU tests (mteval-v14.pl, --international-tokenization)" +echo "-----------------------------------------------------------------" + +declare -A MTEVAL14=( ["newstest2017.online-A.0.en-ru.sgm"]=23.99 + ["newstest2017.online-A.0.en-cs.sgm"]=16.65 + ["newstest2017.uedin-nmt.4890.ru-en.sgm"]=30.91 + ["newstest2017.xmunmt.5160.zh-en.sgm"]=26.8 + ["newstest2017.LIUM-NMT.4900.en-de.sgm"]=26.89 + ["newstest2017.tilde-nc-nmt-smt-hybrid.5047.en-lv.sgm"]=20.88 + ["newstest2017.TALP-UPC.4937.fi-en.sgm"]=16.18 + ["newstest2017.LIUM-NMT.4888.tr-en.sgm"]=17.94 + ["newstest2017.SogouKnowing-nmt.5131.en-zh.sgm"]=7.24 + ["newstest2017.online-B.0.en-tr.sgm"]=22.88 + ["newstest2017.apertium-unconstrained.4769.en-fi.sgm"]=1.08 + ["newstest2017.tilde-c-nmt-smt-hybrid.5051.lv-en.sgm"]=20.44 + ["newstest2017.RWTH-nmt-ensemble.4920.de-en.sgm"]=33.73 + # Below two not exactly compatible with mteval due to 's (#138) + ["newstest2017.PJATK.4760.cs-en.sgm"]=23.47 + ["newstest2017.PJATK.4761.en-cs.sgm"]=15.91 + ) + +if [ -z $SKIP_MTEVAL14 ]; then + for pair in cs-en de-en en-cs en-de en-fi en-lv en-ru en-tr en-zh fi-en lv-en ru-en tr-en zh-en; do + source=$(echo $pair | cut -d- -f1) + target=$(echo $pair | cut -d- -f2) + for sgm in wmt17-submitted-data/sgm/system-outputs/newstest2017/$pair/*.sgm; do + name=$(basename $sgm) + + if [[ ! -v MTEVAL14[$name] ]]; then continue; fi + if [[ ! -z $limit_test && $limit_test != $name ]]; then continue; fi + + sys=$(basename $sgm .sgm | perl -pe 's/newstest2017\.//') + txt=$(dirname $sgm | perl -pe 's/sgm/txt/')/$(basename $sgm .sgm) + src=wmt17-submitted-data/sgm/sources/newstest2017-$source$target-src.$source.sgm + ref=wmt17-submitted-data/sgm/references/newstest2017-$source$target-ref.$target.sgm + + score=$(cat $txt | ${CMD} -w 2 -t wmt17 -l $source-$target -b --tokenize intl) + + echo "import sys; sys.exit(1 if abs($score-${MTEVAL14[$name]}) > 0.01 else 0)" | python + + if [[ $? -eq 1 ]]; then + echo "FAILED test $pair/$sys (wanted ${MTEVAL14[$name]} got $score)" + exit 1 + fi + echo "Passed $source-$target $sys mteval-v14.pl: ${MTEVAL14[$name]} sacreBLEU: $score" + + let i++ + done + done +fi ####################################################### # Pre-computed TER scores from official implementation @@ -673,7 +659,7 @@ if [ -z $SKIP_TER ]; then expected_score="${TER[$name]}" - echo "import sys; sys.exit(1 if abs(${score}-${expected_score}) > 0.01 else 0)" | python + echo "import sys; sys.exit(1 if abs(0.01 * ${score}-${expected_score}) > 0.01 else 0)" | python if [[ $? -eq 1 ]]; then echo "FAILED test $pair/$sys (wanted $expected_score got $score)" diff --git a/test/test.py b/test/test.py deleted file mode 100644 index 56cb00db..00000000 --- a/test/test.py +++ /dev/null @@ -1,47 +0,0 @@ -import sacrebleu - -segment = "Consistency is the last refuge of the unimaginative" -score = sacrebleu.corpus_chrf([segment], [segment], 6, 3.0) -assert(score == 1.0) - -ref = "AAAAAA" -sys = "BBBB" -score = sacrebleu.corpus_chrf([sys], [ref], 3, 3.0) -assert(score == 0.0) - -ref = "" -sys = "" -score = sacrebleu.corpus_chrf([sys], [ref], 6, 3) -print(score) -#assert(score == 1.0) - -ref = "A" -sys = "" -score = sacrebleu.corpus_chrf([sys], [ref], 6, 3) -assert(score == 0.0) - -ref = "" -sys = "A" -score = sacrebleu.corpus_chrf([sys], [ref], 6, 3) -assert(score == 0.0) - -ref = "AB" -sys = "AA" -score = sacrebleu.corpus_chrf([sys], [ref], 6, 3) -assert(score == 0.25) - -# segment_a = self.tokenize("A") -# segment_b = self.tokenize("A") -ref = "A" -sys = "A" -score = sacrebleu.corpus_chrf([sys], [ref], 6, 3) -assert(score == 1.0) -# scorer = CharacterFScorer('n=6,beta=3') -# scorer.set_reference(segment_a) -# self.assertEqual(scorer.score(segment_b), 1.0) - -ref = "risk assessment has to be undertaken by those who are qualified and expert in that area - that is the scientists ." -sys = " risk assessment must be made of those who are qualified and expertise in the sector - these are the scientists ." -score = sacrebleu.corpus_chrf([sys], [ref], 6, 3) -print(score) -assert('{0:.5f}'.format(score) == '0.63362') diff --git a/test/test_api.py b/test/test_api.py index bfe4ba3a..facc7d90 100644 --- a/test/test_api.py +++ b/test/test_api.py @@ -13,7 +13,9 @@ import pytest -import sacrebleu +from sacrebleu.utils import get_available_testsets, get_langpairs_for_testset +from sacrebleu.utils import get_source_file, get_reference_files +from sacrebleu.dataset import DATASETS test_api_get_data = [ ("wmt19", "de-en", 1, "Schöne Münchnerin 2018: Schöne Münchnerin 2018 in Hvar: Neun Dates", "The Beauty of Munich 2018: the Beauty of Munich 2018 in Hvar: Nine dates"), @@ -21,42 +23,46 @@ ("wmt19/google/ar", "en-de", 1, "Welsh AMs worried about 'looking like muppets'", "Walisische Abgeordnete befürchten als ,Idioten’ dazustehen."), ] + @pytest.mark.parametrize("testset, langpair, sentno, source, reference", test_api_get_data) def test_api_get_source(testset, langpair, sentno, source, reference): - with open(sacrebleu.get_source_file(testset, langpair)) as fh: + with open(get_source_file(testset, langpair)) as fh: line = fh.readlines()[sentno - 1].strip() assert line == source + @pytest.mark.parametrize("testset, langpair, sentno, source, reference", test_api_get_data) def test_api_get_reference(testset, langpair, sentno, source, reference): - with open(sacrebleu.get_reference_files(testset, langpair)[0]) as fh: + with open(get_reference_files(testset, langpair)[0]) as fh: line = fh.readlines()[sentno - 1].strip() assert line == reference + def test_api_get_available_testsets(): """ Loop over the datasets directly, and ensure the API function returns the test sets found. """ - available = sacrebleu.get_available_testsets() + available = get_available_testsets() assert type(available) is list assert "wmt19" in available assert "wmt05" not in available - for testset in sacrebleu.DATASETS.keys(): + for testset in DATASETS.keys(): assert testset in available assert "slashdot_" + testset not in available + def test_api_get_langpairs_for_testset(): """ Loop over the datasets directly, and ensure the API function returns each language pair in each test set. """ - for testset in sacrebleu.DATASETS.keys(): - available = sacrebleu.get_langpairs_for_testset(testset) + for testset in DATASETS.keys(): + available = get_langpairs_for_testset(testset) assert type(available) is list - for langpair in sacrebleu.DATASETS[testset].keys(): + for langpair in DATASETS[testset].keys(): # skip non-language keys if "-" not in langpair: assert langpair not in available diff --git a/test/test_bleu.py b/test/test_bleu.py index eaceab20..be245ecb 100644 --- a/test/test_bleu.py +++ b/test/test_bleu.py @@ -11,17 +11,23 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. +from collections import namedtuple import pytest + import sacrebleu -from collections import namedtuple + +from sacrebleu.metrics import BLEU + EPSILON = 1e-8 Statistics = namedtuple('Statistics', ['common', 'total']) -test_raw_bleu_cases = [(["this is a test", "another test"], ["ref1", "ref2"], 0.003799178428257963), - (["this is a test"], ["this is a test"], 1.0), - (["this is a fest"], ["this is a test"], 0.223606797749979)] +test_raw_bleu_cases = [ + # This now returns 0.0 score (#141) + (["this is a test", "another test"], [["ref1", "ref2"]], 0.0), + (["this is a test"], [["this is a test"]], 1.0), + (["this is a fest"], [["this is a test"]], 0.223606797749979)] # test for README example with empty hypothesis strings check _refs = [ @@ -44,7 +50,7 @@ (_hyps, _refs, {'smooth_method': 'none'}, 48.530827), ] -test_case_offset = [("am I am a character sequence", "I am a symbol string sequence a a", 0.1555722182, 0)] +test_case_offset = [(["am I am a character sequence"], [["I am a symbol string sequence a a"]], 0.1555722182, 0)] # statistic structure: # - common counts @@ -52,14 +58,14 @@ # - hyp_count # - ref_count -test_case_statistics = [("am I am a character sequence", "I am a symbol string sequence a a", +test_case_statistics = [(["am I am a character sequence"], [["I am a symbol string sequence a a"]], Statistics([4, 2, 1, 0], [6, 5, 4, 3]))] test_case_scoring = [((Statistics([9, 7, 5, 3], [10, 8, 6, 4]), 11, 11), 0.8375922397)] -test_case_effective_order = [(["test"], ["a test"], 0.3678794411714425), - (["a test"], ["a test"], 1.0), - (["a little test"], ["a test"], 0.03218297948685433)] +test_case_effective_order = [(["test"], [["a test"]], 0.3678794411714425), + (["a test"], [["a test"]], 1.0), + (["a little test"], [["a test"]], 0.03218297948685433)] # testing that right score is returned for null statistics and different offsets @@ -74,7 +80,7 @@ @pytest.mark.parametrize("hypotheses, references, expected_bleu", test_raw_bleu_cases) def test_raw_bleu(hypotheses, references, expected_bleu): - bleu = sacrebleu.raw_corpus_bleu(hypotheses, [references], .01).score / 100 + bleu = sacrebleu.raw_corpus_bleu(hypotheses, references, .01).score / 100 assert abs(bleu - expected_bleu) < EPSILON @@ -86,7 +92,7 @@ def test_corpus_bleu(hypotheses, references, kwargs, expected_bleu): @pytest.mark.parametrize("hypotheses, references, expected_bleu", test_case_effective_order) def test_effective_order(hypotheses, references, expected_bleu): - bleu = sacrebleu.raw_corpus_bleu(hypotheses, [references], .01).score / 100 + bleu = sacrebleu.raw_corpus_bleu(hypotheses, references, .01).score / 100 assert abs(bleu - expected_bleu) < EPSILON @@ -99,7 +105,7 @@ def test_statistics(hypothesis, reference, expected_stat): @pytest.mark.parametrize("statistics, expected_score", test_case_scoring) def test_scoring(statistics, expected_score): - score = sacrebleu.compute_bleu(statistics[0].common, statistics[0].total, statistics[1], statistics[2]).score / 100 + score = BLEU.compute_bleu(statistics[0].common, statistics[0].total, statistics[1], statistics[2]).score / 100 assert abs(score - expected_score) < EPSILON @@ -120,5 +126,10 @@ def test_offset(hypothesis, reference, expected_with_offset, expected_without_of @pytest.mark.parametrize("statistics, offset, expected_score", test_case_degenerate_stats) def test_degenerate_statistics(statistics, offset, expected_score): - score = sacrebleu.compute_bleu(statistics[0].common, statistics[0].total, statistics[1], statistics[2], smooth_method='floor', smooth_value=offset).score / 100 + score = BLEU.compute_bleu( + statistics[0].common, + statistics[0].total, + statistics[1], + statistics[2], + smooth_method='floor', smooth_value=offset).score / 100 assert score == expected_score diff --git a/test/test_chrf.py b/test/test_chrf.py index 257d9b9e..24df6d52 100644 --- a/test/test_chrf.py +++ b/test/test_chrf.py @@ -11,37 +11,101 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. -# -*- coding: utf-8 -*- - import pytest import sacrebleu -EPSILON = 1e-8 - -test_cases = [(["Niemand hat die Absicht, eine Mauer zu errichten"], ["Niemand hat die Absicht, eine Mauer zu errichten"], 1.0), - (["abcdefg"], ["hijklmnop"], 0.0), - (["a"], ["a"], 1.0), - ([""], [""], 0.0), - ([""], ["reference"], 0.0), - (["a b c"], ["a b c"], 1.0), - (["a b c"], ["abc"], 1.0), - ([""], ["c"], 0.0), - (["a", "b"], ["a", "c"], 0.5), - (["source"], [""], 0.0), - (["aa"], ["ab"], 0.25), - ([" Die Beziehung zwischen Obama und Netanjahu ist nicht gerade freundlich. "], ["Das Verhältnis zwischen Obama und Netanyahu ist nicht gerade freundschaftlich."], 0.64130269831561459), - ([" risk assessment must be made of those who are qualified and expertise in the sector - these are the scientists ."], ["risk assessment has to be undertaken by those who are qualified and expert in that area - that is the scientists ."], 0.63361730303214769)] +EPSILON = 1e-4 + +test_sentence_level_chrf = [ + ( + 'Co nás nejvíc trápí, protože lékaři si vybírají, kdo bude žít a kdo zemře.', + ['Nejvíce smutní jsme z toho, že musíme rozhodovat o tom, kdo bude žít a kdo zemře.'], + 39.14078509, + ), + ( + 'Nebo prostě nemají vybavení, které by jim pomohlo, uvedli lékaři.', + ['A někdy nemáme ani potřebný materiál, abychom jim pomohli, popsali lékaři.'], + 31.22557079, + ), + ( + 'Lapali po dechu, jejich životy skončily dřív, než skutečně začaly.', + ['Lapali po dechu a pak jejich život skončil - dřív, než skutečně mohl začít, připomněli.'], + 57.15704367, + ), +] + + +# hypothesis, reference, expected score +# >= 2.0.0: some orders are not fulfilled in epsilon smoothing (chrF++.py and NLTK) +test_cases = [ + (["abcdefg"], ["hijklmnop"], 0.0), + (["a"], ["b"], 0.0), + ([""], ["b"], 0.0), + ([""], ["ref"], 0.0), + ([""], ["reference"], 0.0), + (["aa"], ["ab"], 8.3333), + (["a", "b"], ["a", "c"], 8.3333), + (["a"], ["a"], 16.6667), + (["a b c"], ["a b c"], 50.0), + (["a b c"], ["abc"], 50.0), + ([" risk assessment must be made of those who are qualified and expertise in the sector - these are the scientists ."], + ["risk assessment has to be undertaken by those who are qualified and expert in that area - that is the scientists ."], 63.361730), + ([" Die Beziehung zwischen Obama und Netanjahu ist nicht gerade freundlich. "], + ["Das Verhältnis zwischen Obama und Netanyahu ist nicht gerade freundschaftlich."], 64.1302698), + (["Niemand hat die Absicht, eine Mauer zu errichten"], ["Niemand hat die Absicht, eine Mauer zu errichten"], 100.0), +] + +# sacreBLEU < 2.0.0 mode +# hypothesis, reference, expected score +test_cases_effective_order = [ + (["a"], ["a"], 100.0), + ([""], ["reference"], 0.0), + (["a b c"], ["a b c"], 100.0), + (["a b c"], ["abc"], 100.0), + ([""], ["c"], 0.0), + (["a", "b"], ["a", "c"], 50.0), + (["aa"], ["ab"], 25.0), +] test_cases_keep_whitespace = [ - (["Die Beziehung zwischen Obama und Netanjahu ist nicht gerade freundlich."], ["Das Verhältnis zwischen Obama und Netanyahu ist nicht gerade freundschaftlich."], 0.67348160629772402), - (["risk assessment must be made of those who are qualified and expertise in the sector - these are the scientists ."], ["risk assessment has to be undertaken by those who are qualified and expert in that area - that is the scientists ."], 0.652414427449)] + ( + ["Die Beziehung zwischen Obama und Netanjahu ist nicht gerade freundlich."], + ["Das Verhältnis zwischen Obama und Netanyahu ist nicht gerade freundschaftlich."], + 67.3481606, + ), + ( + ["risk assessment must be made of those who are qualified and expertise in the sector - these are the scientists ."], + ["risk assessment has to be undertaken by those who are qualified and expert in that area - that is the scientists ."], + 65.2414427, + ), +] + @pytest.mark.parametrize("hypotheses, references, expected_score", test_cases) def test_chrf(hypotheses, references, expected_score): - score = sacrebleu.corpus_chrf(hypotheses, [references], 6, 3).score + score = sacrebleu.corpus_chrf( + hypotheses, [references], char_order=6, word_order=0, beta=3, + eps_smoothing=True).score assert abs(score - expected_score) < EPSILON + +@pytest.mark.parametrize("hypotheses, references, expected_score", test_cases_effective_order) +def test_chrf_eff_order(hypotheses, references, expected_score): + score = sacrebleu.corpus_chrf( + hypotheses, [references], char_order=6, word_order=0, beta=3, + eps_smoothing=False).score + assert abs(score - expected_score) < EPSILON + + @pytest.mark.parametrize("hypotheses, references, expected_score", test_cases_keep_whitespace) def test_chrf_keep_whitespace(hypotheses, references, expected_score): - score = sacrebleu.corpus_chrf(hypotheses, [references], 6, 3, remove_whitespace=False).score + score = sacrebleu.corpus_chrf( + hypotheses, [references], char_order=6, word_order=0, beta=3, + remove_whitespace=False).score + assert abs(score - expected_score) < EPSILON + + +@pytest.mark.parametrize("hypothesis, references, expected_score", test_sentence_level_chrf) +def test_chrf_sentence_level(hypothesis, references, expected_score): + score = sacrebleu.sentence_chrf(hypothesis, references, eps_smoothing=True).score assert abs(score - expected_score) < EPSILON diff --git a/test/test_sentence_bleu.py b/test/test_sentence_bleu.py new file mode 100644 index 00000000..88afa2eb --- /dev/null +++ b/test/test_sentence_bleu.py @@ -0,0 +1,72 @@ +import pytest +import sacrebleu + +EPSILON = 1e-3 + + +# Example taken from #98 +REF = "producţia de zahăr brut se exprimă în zahăr alb;" +SYS = "Producția de zahăr primă va fi exprimată în ceea ce privește zahărul alb;" + +test_cases = [ + # change smoothing + ('exp', None, False, '13a', 8.493), + ('none', None, False, '13a', 0.0), + ('floor', None, False, '13a', 4.51688), # defaults to 0.1 + ('floor', 0.1, False, '13a', 4.51688), + ('floor', 0.5, False, '13a', 10.10), + ('add-k', None, False, '13a', 14.882), # defaults to 1 + ('add-k', 1, False, '13a', 14.882), + ('add-k', 2, False, '13a', 21.389), + # change tok + ('exp', None, False, 'none', 7.347), + ('exp', None, False, 'intl', 8.493), + ('exp', None, False, 'char', 40.8759), + # change case + ('exp', None, True, 'char', 42.0267), +] + + +# Example taken from #141 +REF_0 = "okay thanks" +SYS_0 = "this is a cat" + +test_cases_zero_bleu = [ + ('exp', None, False, '13a', 0.0), + ('none', None, False, '13a', 0.0), + ('floor', None, False, '13a', 0.0), # defaults to 0.1 + ('floor', 0.1, False, '13a', 0.0), + ('add-k', None, False, '13a', 0.0), # defaults to 1 + ('add-k', 1, False, '13a', 0.0), +] + + +@pytest.mark.parametrize("smooth_method, smooth_value, lowercase, tok, expected_score", test_cases) +def test_compat_sentence_bleu(smooth_method, smooth_value, lowercase, tok, expected_score): + score = sacrebleu.compat.sentence_bleu( + SYS, [REF], smooth_method=smooth_method, smooth_value=smooth_value, + tokenize=tok, + lowercase=lowercase, + use_effective_order=True) + assert abs(score.score - expected_score) < EPSILON + + +@pytest.mark.parametrize("smooth_method, smooth_value, lowercase, tok, expected_score", test_cases) +def test_api_sentence_bleu(smooth_method, smooth_value, lowercase, tok, expected_score): + metric = sacrebleu.metrics.BLEU( + lowercase=lowercase, force=False, tokenize=tok, + smooth_method=smooth_method, smooth_value=smooth_value, + effective_order=True) + score = metric.sentence_score(SYS, [REF]) + + assert abs(score.score - expected_score) < EPSILON + + +@pytest.mark.parametrize("smooth_method, smooth_value, lowercase, tok, expected_score", test_cases_zero_bleu) +def test_api_sentence_bleu_zero(smooth_method, smooth_value, lowercase, tok, expected_score): + metric = sacrebleu.metrics.BLEU( + lowercase=lowercase, force=False, tokenize=tok, + smooth_method=smooth_method, smooth_value=smooth_value, + effective_order=True) + score = metric.sentence_score(SYS_0, [REF_0]) + assert abs(score.score - expected_score) < EPSILON diff --git a/test/test_significance.py b/test/test_significance.py new file mode 100644 index 00000000..46679ac4 --- /dev/null +++ b/test/test_significance.py @@ -0,0 +1,104 @@ +import os + +from collections import defaultdict + +from sacrebleu.metrics import BLEU +from sacrebleu.significance import PairedTest + +import pytest + + +def _read_pickle_file(): + import bz2 + import pickle as pkl + with bz2.BZ2File('./test/wmt17_en_de_systems.pkl.bz2', 'rb') as f: + data = pkl.load(f) + return data + + +# P-values obtained from Moses' significance script (mean of 3 runs) +# Script: scripts/moses-sigdiff.pl (modified to bootstrap samples = 2000) +MOSES_P_VALS = { + "newstest2017.C-3MA.4959.en-de": 0.00000, + "newstest2017.FBK.4870.en-de": 0.01267, + "newstest2017.KIT.4950.en-de": 0.02233, + "newstest2017.LMU-nmt-reranked.4934.en-de": 0.04383, + "newstest2017.LMU-nmt-single.4893.en-de": 0.20783, + "newstest2017.online-A.0.en-de": 0.00000, + "newstest2017.online-B.0.en-de": 0.38100, + "newstest2017.online-F.0.en-de": 0.00000, + "newstest2017.online-G.0.en-de": 0.00000, + "newstest2017.PROMT-Rule-based.4735.en-de": 0.00000, + "newstest2017.RWTH-nmt-ensemble.4921.en-de": 0.01167, + "newstest2017.SYSTRAN.4847.en-de": 0.20983, + "newstest2017.TALP-UPC.4834.en-de": 0.00000, + "newstest2017.uedin-nmt.4722.en-de": 0.00000, + "newstest2017.xmu.4910.en-de": 0.25483, +} + +# Obtained from the multeval toolkit, 10,000 AR trials, (BLEU and TER) +# Code: github.com/mjclark/multeval.git +MULTEVAL_P_VALS = { + "newstest2017.C-3MA.4959.en-de": (0.0001, 0.0001), + "newstest2017.FBK.4870.en-de": (0.0218, 0.09569), + "newstest2017.KIT.4950.en-de": (0.0410, 0.0002), + "newstest2017.LMU-nmt-reranked.4934.en-de": (0.09029, 0.0001), + "newstest2017.LMU-nmt-single.4893.en-de": (0.58494, 0.0054), + "newstest2017.online-A.0.en-de": (0.0001, 0.0001), + "newstest2017.online-B.0.en-de": (0.94111, 0.82242), + "newstest2017.online-F.0.en-de": (0.0001, 0.0001), + "newstest2017.online-G.0.en-de": (0.0001, 0.0001), + "newstest2017.PROMT-Rule-based.4735.en-de": (0.0001, 0.0001), + "newstest2017.RWTH-nmt-ensemble.4921.en-de": (0.0207, 0.07539), + "newstest2017.SYSTRAN.4847.en-de": (0.59914, 0.0001), + "newstest2017.TALP-UPC.4834.en-de": (0.0001, 0.0001), + "newstest2017.uedin-nmt.4722.en-de": (0.0001, 0.0001), + "newstest2017.xmu.4910.en-de": (0.71073, 0.0001), +} + + +SACREBLEU_BS_P_VALS = defaultdict(float) +SACREBLEU_AR_P_VALS = defaultdict(float) + +# Load data from pickled file to not bother with WMT17 downloading +named_systems = _read_pickle_file() +_, refs = named_systems.pop() +metrics = {'BLEU': BLEU(references=refs, tokenize='none')} + + +######### +# BS test +######### +os.environ['SACREBLEU_SEED'] = str(12345) +bs_scores = PairedTest( + named_systems, metrics, references=None, + test_type='bs', n_samples=2000)()[1] + +for name, result in zip(bs_scores['System'], bs_scores['BLEU']): + if result.p_value is not None: + SACREBLEU_BS_P_VALS[name] += result.p_value + + +############################################### +# AR test (1 run) +# Test only BLEU as TER will take too much time +############################################### +ar_scores = PairedTest(named_systems, metrics, references=None, + test_type='ar', n_samples=10000)()[1] + +for name, result in zip(ar_scores['System'], ar_scores['BLEU']): + if result.p_value is not None: + SACREBLEU_AR_P_VALS[name] += result.p_value + + +@pytest.mark.parametrize("name, expected_p_val", MOSES_P_VALS.items()) +def test_paired_bootstrap(name, expected_p_val): + p_val = SACREBLEU_BS_P_VALS[name] + assert abs(p_val - expected_p_val) < 1e-2 + + +@pytest.mark.parametrize("name, expected_p_vals", MULTEVAL_P_VALS.items()) +def test_paired_approximate_randomization(name, expected_p_vals): + expected_bleu_p_val = expected_p_vals[0] + p_val = SACREBLEU_AR_P_VALS[name] + assert abs(p_val - expected_bleu_p_val) < 1e-2 diff --git a/test/test_ter.py b/test/test_ter.py index 904eb0ed..8457a5cd 100644 --- a/test/test_ter.py +++ b/test/test_ter.py @@ -1,15 +1,13 @@ import pytest import sacrebleu -from argparse import Namespace - EPSILON = 1e-3 test_cases = [ (['aaaa bbbb cccc dddd'], ['aaaa bbbb cccc dddd'], 0), # perfect match (['dddd eeee ffff'], ['aaaa bbbb cccc'], 1), # no overlap - ([''], [''], 1), # corner case, empty strings - (['d e f g h a b c'], ['a b c d e f g h'], 1/8), # a single shift fixes MT + ([''], ['a'], 1), # corner case, empty hypothesis + (['d e f g h a b c'], ['a b c d e f g h'], 1 / 8), # a single shift fixes MT ( [ 'wählen Sie " Bild neu berechnen , " um beim Ändern der Bildgröße Pixel hinzuzufügen oder zu entfernen , damit das Bild ungefähr dieselbe Größe aufweist wie die andere Größe .', @@ -32,7 +30,6 @@ @pytest.mark.parametrize("hypotheses, references, expected_score", test_cases) def test_ter(hypotheses, references, expected_score): - args = Namespace(tokenize=sacrebleu.DEFAULT_TOKENIZER) - metric = sacrebleu.metrics.TER(args) + metric = sacrebleu.metrics.TER() score = metric.corpus_score(hypotheses, [references]).score - assert abs(score - expected_score) < EPSILON + assert abs(score - 100 * expected_score) < EPSILON diff --git a/test/test_tokenizer_ter.py b/test/test_tokenizer_ter.py index 0cd808a3..e76a03bf 100644 --- a/test/test_tokenizer_ter.py +++ b/test/test_tokenizer_ter.py @@ -16,8 +16,8 @@ ] test_cases_norm = [ - ("a b (c) d.", "a b ( c ) d"), - ("Jim's car.", "Jim 's car ."), + ("a b (c) d.", "a b ( c ) d ."), + ("Jim's car.", "jim 's car ."), ("4.2", "4.2"), ] @@ -34,18 +34,18 @@ def test_ter_tokenizer_default(input, expected): @pytest.mark.parametrize("input, expected", test_cases_no_punct) -def test_ter_tokenizer_default(input, expected): +def test_ter_tokenizer_nopunct(input, expected): tokenizer = TercomTokenizer(no_punct=True) assert tokenizer(input) == expected @pytest.mark.parametrize("input, expected", test_cases_norm) -def test_ter_tokenizer_default(input, expected): +def test_ter_tokenizer_norm(input, expected): tokenizer = TercomTokenizer(normalized=True) assert tokenizer(input) == expected @pytest.mark.parametrize("input, expected", test_cases_asian) -def test_ter_tokenizer_default(input, expected): +def test_ter_tokenizer_asian(input, expected): tokenizer = TercomTokenizer(normalized=True, asian_support=True) assert tokenizer(input) == expected diff --git a/test/wmt17_en_de_systems.pkl.bz2 b/test/wmt17_en_de_systems.pkl.bz2 new file mode 100644 index 00000000..227b630e Binary files /dev/null and b/test/wmt17_en_de_systems.pkl.bz2 differ diff --git a/tox.ini b/tox.ini new file mode 100644 index 00000000..31159a2d --- /dev/null +++ b/tox.ini @@ -0,0 +1,2 @@ +[flake8] +ignore = E501,E265