
Commit c93a188

Lukasz Kaiser authored and Ryan Sepassi committed

New BLEU cleanup and small correction to VAE.
PiperOrigin-RevId: 177547599

1 parent: b1abcf4
File tree (3 files changed, +71 -4 lines):
  tensor2tensor/models/transformer_vae.py
  tensor2tensor/utils/bleu_hook.py
  tensor2tensor/utils/bleu_hook_test.py

tensor2tensor/models/transformer_vae.py (1 addition, 1 deletion)

@@ -139,7 +139,7 @@ def vae(x, z_size, name):
   kl = 0.5 * tf.reduce_mean(
       tf.exp(log_sigma) + tf.square(mu) - 1. - log_sigma, axis=-1)
   free_bits = z_size // 2
-  kl_loss = tf.maximum(tf.reduce_mean(kl) - free_bits, 0.0)
+  kl_loss = tf.reduce_mean(tf.maximum(kl - free_bits, 0.0))
   return z, kl_loss, mu, log_sigma
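The VAE correction changes where the free-bits floor is applied: before, the floor clipped the batch-averaged KL, so examples spending fewer than free_bits nats could offset examples spending more; now each example's KL is clipped before averaging. A minimal numpy sketch of the difference, using made-up per-example KL values (not from the commit):

import numpy as np

kl = np.array([0.5, 3.0, 8.0])  # hypothetical per-example KL values (nats)
free_bits = 2.0

# Old: clip after averaging; the 0.5-nat example offsets the others.
old_loss = max(kl.mean() - free_bits, 0.0)         # 11.5/3 - 2 = 1.83...

# New: clip each example first, then average; under-threshold examples
# contribute exactly zero instead of a negative amount.
new_loss = np.maximum(kl - free_bits, 0.0).mean()  # (0 + 1 + 6)/3 = 2.33...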
tensor2tensor/utils/bleu_hook.py (66 additions, 0 deletions)

@@ -20,10 +20,14 @@

 import collections
 import math
+import re
+import sys
+import unicodedata

 # Dependency imports

 import numpy as np
+import six
 # pylint: disable=redefined-builtin
 from six.moves import xrange
 from six.moves import zip
@@ -93,9 +97,15 @@ def compute_bleu(reference_corpus,
   for ngram in translation_ngram_counts:
     possible_matches_by_order[len(ngram)-1] += translation_ngram_counts[ngram]
   precisions = [0] * max_order
+  smooth = 1.0
   for i in xrange(0, max_order):
     if possible_matches_by_order[i] > 0:
       precisions[i] = matches_by_order[i] / possible_matches_by_order[i]
+      if matches_by_order[i] > 0:
+        precisions[i] = matches_by_order[i] / possible_matches_by_order[i]
+      else:
+        smooth *= 2
+        precisions[i] = 1.0 / (smooth * possible_matches_by_order[i])
     else:
       precisions[i] = 0.0
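With the smoothing, an order with zero n-gram matches no longer forces a zero precision (and hence a zero geometric mean); it is assigned 1.0 / (smooth * possible_matches), with smooth doubling for each such order. A hand-worked check against testComputeNotEqual in the updated test below (our arithmetic, not part of the commit): the translation [[1, 2, 3, 4]] shares no n-gram with the reference [[5, 6, 7, 8]], so all four orders are smoothed.

import math

possible = [4, 3, 2, 1]  # n-grams available in a 4-token sentence, orders 1..4
smooth = 1.0
precisions = []
for p in possible:
  smooth *= 2  # zero matches at every order, so each one hits the smoothing branch
  precisions.append(1.0 / (smooth * p))

# precisions == [1/8, 1/12, 1/16, 1/16]; the brevity penalty is 1.0 because
# translation and reference are both 4 tokens long.
bleu = math.exp(sum(math.log(p) for p in precisions) / 4)
print(round(bleu, 7))  # ~0.0798679, the value asserted in the updated test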

@@ -131,3 +141,59 @@ def bleu_score(predictions, labels, **unused_kwargs):

   bleu = tf.py_func(compute_bleu, (labels, outputs), tf.float32)
   return bleu, tf.constant(1.0)
+
+
+class UnicodeRegex(object):
+  """Ad-hoc hack to recognize all punctuation and symbols."""
+
+  def __init__(self):
+    def _property_chars(prefix):
+      return ''.join(six.unichr(x) for x in range(sys.maxunicode)
+                     if unicodedata.category(six.unichr(x)).startswith(prefix))
+    punctuation = self._property_chars('P')
+    self.nondigit_punct_re = re.compile(r'([^\d])([' + punctuation + r'])')
+    self.punct_nondigit_re = re.compile(r'([' + punctuation + r'])([^\d])')
+    self.symbol_re = re.compile('([' + _property_chars('S') + '])')
+
+
+def bleu_tokenize(string):
+  r"""Tokenize a string following the official BLEU implementation.
+
+  See https://github.com/moses-smt/mosesdecoder/blob/master/scripts/generic/mteval-v14.pl#L954-L983
+
+  In our case, the input string is expected to be just one line and no HTML
+  entity de-escaping is needed. So we just tokenize on punctuation and
+  symbols, except when a punctuation mark is preceded and followed by a digit
+  (e.g. a comma/dot used as a thousands/decimal separator).
+
+  Note that a number (e.g. a year) followed by a dot at the end of a sentence
+  is NOT tokenized, i.e. the dot stays with the number, because
+  `s/(\p{P})(\P{N})/ $1 $2/g` does not match this case (unless we add a space
+  after each sentence). However, this error is already in the original
+  mteval-v14.pl and we want to be consistent with it.
+
+  Args:
+    string: the input string
+
+  Returns:
+    a list of tokens
+  """
+  string = UnicodeRegex.nondigit_punct_re.sub(r'\1 \2 ', string)
+  string = UnicodeRegex.punct_nondigit_re.sub(r' \1 \2', string)
+  string = UnicodeRegex.symbol_re.sub(r' \1 ', string)
+  return string.split()
+
+
+def bleu_wrapper(ref_filename, hyp_filename, case_sensitive=False):
+  """Compute BLEU for two files (reference and hypothesis translation)."""
+  ref_lines = open(ref_filename).read().splitlines()
+  hyp_lines = open(hyp_filename).read().splitlines()
+  assert len(ref_lines) == len(hyp_lines)
+  if not case_sensitive:
+    ref_lines = [x.lower() for x in ref_lines]
+    hyp_lines = [x.lower() for x in hyp_lines]
+  ref_tokens = [bleu_tokenize(x) for x in ref_lines]
+  hyp_tokens = [bleu_tokenize(x) for x in hyp_lines]
+  return compute_bleu(ref_tokens, hyp_tokens)
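Note that, as committed, bleu_tokenize cannot run: it reads nondigit_punct_re and friends off the UnicodeRegex class, but __init__ only sets them on instances, and instantiation itself fails because __init__ calls self._property_chars while _property_chars is a nested function, not a method. A sketch of one way to make it runnable (our suggestion, not part of this commit):

import re
import sys
import unicodedata

import six


class UnicodeRegex(object):
  """Ad-hoc hack to recognize all punctuation and symbols."""

  def __init__(self):
    punctuation = self._property_chars('P')
    self.nondigit_punct_re = re.compile(r'([^\d])([' + punctuation + r'])')
    self.punct_nondigit_re = re.compile(r'([' + punctuation + r'])([^\d])')
    self.symbol_re = re.compile('([' + self._property_chars('S') + '])')

  def _property_chars(self, prefix):
    # Collect every character whose Unicode category starts with `prefix`,
    # e.g. 'P' for punctuation or 'S' for symbols.
    return ''.join(six.unichr(x) for x in range(sys.maxunicode)
                   if unicodedata.category(six.unichr(x)).startswith(prefix))


uregex = UnicodeRegex()  # build the regexes once, at module load


def bleu_tokenize(string):
  string = uregex.nondigit_punct_re.sub(r'\1 \2 ', string)
  string = uregex.punct_nondigit_re.sub(r' \1 \2', string)
  string = uregex.symbol_re.sub(r' \1 ', string)
  return string.split()

With something like that in place, bleu_wrapper('ref.txt', 'hyp.txt') (hypothetical file names) lowercases, tokenizes, and delegates to compute_bleu.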

tensor2tensor/utils/bleu_hook_test.py (4 additions, 3 deletions)

@@ -39,8 +39,9 @@ def testComputeNotEqual(self):
     translation_corpus = [[1, 2, 3, 4]]
     reference_corpus = [[5, 6, 7, 8]]
     bleu = bleu_hook.compute_bleu(reference_corpus, translation_corpus)
-    actual_bleu = 0.0
-    self.assertEqual(bleu, actual_bleu)
+    # The smoothing prevents 0 for small corpora
+    actual_bleu = 0.0798679
+    self.assertAllClose(bleu, actual_bleu, atol=1e-03)

   def testComputeMultipleBatch(self):
     translation_corpus = [[1, 2, 3, 4], [5, 6, 7, 0]]

@@ -53,7 +54,7 @@ def testComputeMultipleNgrams(self):
     reference_corpus = [[1, 2, 1, 13], [12, 6, 7, 4, 8, 9, 10]]
     translation_corpus = [[1, 2, 1, 3], [5, 6, 7, 4]]
     bleu = bleu_hook.compute_bleu(reference_corpus, translation_corpus)
-    actual_bleu = 0.486
+    actual_bleu = 0.3436
     self.assertAllClose(bleu, actual_bleu, atol=1e-03)

 if __name__ == '__main__':
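The new expected value in testComputeMultipleNgrams also follows from the smoothing. Pooling n-gram counts over the two sentence pairs gives precisions 6/8, 4/6 and 2/4 for orders 1-3; order 4 has 0 matches out of 2 possible, so it takes the smoothed value 1/(2*2). A quick check (our arithmetic, not part of the commit; the old 0.486 is consistent with the previous code simply skipping the zero precision in the log sum):

import math

precisions = [6.0/8, 4.0/6, 2.0/4, 1.0/(2 * 2)]  # order 4 smoothed: smooth=2, possible=2
geo_mean = math.exp(sum(math.log(p) for p in precisions) / 4)  # 0.5 exactly
bp = math.exp(1 - 11.0/8)  # brevity penalty: 8 translation tokens vs 11 reference tokens
print(round(geo_mean * bp, 4))  # ~0.3436, matching the updated assertion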
