 # Dependency imports

 import six
+from six import PY2
 from six.moves import xrange  # pylint: disable=redefined-builtin
 from tensor2tensor.data_generators import tokenizer

 import tensorflow as tf

+
+# Conversion between Unicode and UTF-8, if required (on Python2)
+_native_to_unicode = (lambda s: s.decode("utf-8")) if PY2 else (lambda s: s)
+
+
+_unicode_to_native = (lambda s: s.encode("utf-8")) if PY2 else (lambda s: s)
+
+
 # Reserved tokens for things like padding and EOS symbols.
 PAD = "<pad>"
 EOS = "<EOS>"
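The two helpers added above are the heart of this change: every boundary between native strings and unicode now goes through them. A minimal, self-contained sketch of the intended round trip (it assumes only `six`; the string literal is illustrative):

```python
from six import PY2

# Identity on Python 3, UTF-8 decode/encode on Python 2, as in the diff above.
_native_to_unicode = (lambda s: s.decode("utf-8")) if PY2 else (lambda s: s)
_unicode_to_native = (lambda s: s.encode("utf-8")) if PY2 else (lambda s: s)

native = "caf\xc3\xa9" if PY2 else u"caf\u00e9"  # a "native" str on either Python
as_unicode = _native_to_unicode(native)          # always a unicode string
assert as_unicode == u"caf\u00e9"
assert _unicode_to_native(as_unicode) == native  # round trip back to native
```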
@@ -162,15 +171,36 @@ def _load_vocab_from_file(self, filename):


 class SubwordTextEncoder(TextEncoder):
-  """Class for breaking tokens into subtokens.
+  """Class for invertibly encoding text using a limited vocabulary.

-  Invertibly encodes a string as a sequence of subtokens from a limited
+  Invertibly encodes a native string as a sequence of subtokens from a limited
   vocabulary.

   A SubwordTextEncoder is built from a corpus (so it is tailored to the text in
   the corpus), and stored to a file. See text_encoder_build_subword.py.

   It can then be loaded and used to encode/decode any text.
+
+  Encoding has four phases:
+
+  1. Tokenize into a list of tokens.  Each token is a unicode string of either
+     all alphanumeric characters or all non-alphanumeric characters.  We drop
+     tokens consisting of a single space that are between two alphanumeric
+     tokens.
+
+  2. Escape each token.  This escapes away special and out-of-vocabulary
+     characters, and makes sure that each token ends with an underscore and
+     has no other underscores.
+
+  3. Represent each escaped token as the concatenation of a list of subtokens
+     from the limited vocabulary.  Subtoken selection is done greedily from
+     beginning to end.  That is, we construct the list in order, always picking
+     the longest subtoken in our vocabulary that matches a prefix of the
+     remaining portion of the encoded token.
+
+  4. Concatenate these lists.  This concatenation is invertible because the
+     trailing underscores indicate when one list is finished.
+
   """

   def __init__(self, filename=None, num_reserved_ids=2):
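To make the four phases in the new docstring concrete, here is a small self-contained walk-through of phases 2-4 on a made-up token and a made-up vocabulary (the real vocabulary is learned from a corpus; nothing below is the library's API):

```python
token = u"both"                 # phase 1 output: a single alphanumeric token
escaped = token + u"_"          # phase 2: append the sentinel underscore
vocab = {u"b", u"o", u"t", u"h", u"bo", u"th_"}   # hypothetical subtoken vocabulary

subtokens = []                  # phase 3: greedy longest-prefix matching
pos = 0
while pos < len(escaped):
  for end in range(len(escaped), pos, -1):
    if escaped[pos:end] in vocab:
      subtokens.append(escaped[pos:end])
      pos = end
      break

assert subtokens == [u"bo", u"th_"]
assert u"".join(subtokens) == u"both_"   # phase 4: trailing "_" marks the token end
```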
@@ -182,24 +212,26 @@ def __init__(self, filename=None, num_reserved_ids=2):
     super(SubwordTextEncoder, self).__init__(num_reserved_ids=num_reserved_ids)

   def encode(self, raw_text):
-    """Converts a string to a list of subtoken ids.
+    """Converts a native string to a list of subtoken ids.

     Args:
-      raw_text: a string.
+      raw_text: a native string.
     Returns:
       a list of integers in the range [0, vocab_size)
     """
-    return self._tokens_to_subtokens(self._tokenizer.encode(raw_text))
+    return self._tokens_to_subtokens(self._tokenizer.encode(
+        _native_to_unicode(raw_text)))

   def decode(self, subtokens):
-    """Converts a sequence of subtoken ids to a string.
+    """Converts a sequence of subtoken ids to a native string.

     Args:
       subtokens: a list of integers in the range [0, vocab_size)
     Returns:
-      a string
+      a native string
     """
-    return self._tokenizer.decode(self._subtokens_to_tokens(subtokens))
+    return _unicode_to_native(self._tokenizer.decode(
+        self._subtokens_to_tokens(subtokens)))

   @property
   def vocab_size(self):
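Assuming a subword vocabulary file has already been built and stored (the filename below is illustrative), the new native-string contract looks like this in use:

```python
from tensor2tensor.data_generators import text_encoder

# "vocab.subwords" is a hypothetical path to a previously stored vocabulary.
encoder = text_encoder.SubwordTextEncoder("vocab.subwords")

ids = encoder.encode("Machine learning is fun.")       # native str in, ints out
assert all(0 <= i < encoder.vocab_size for i in ids)
print(encoder.decode(ids))                             # native str back out
```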
@@ -239,8 +271,8 @@ def subtoken_to_subtoken_string(self, subtoken):
     if subtoken_string:
       return subtoken_string
     if 0 <= subtoken < self._num_reserved_ids:
-      return "%s_" % RESERVED_TOKENS[subtoken]
-    return "ID%d_" % subtoken
+      return u"%s_" % RESERVED_TOKENS[subtoken]
+    return u"ID%d_" % subtoken

   def _escaped_token_to_subtokens(self, escaped_token):
     """Converts an escaped token string to a list of subtokens.
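These two returns now produce unicode strings like everything else in the class. A tiny stand-alone mirror of just this fallback logic (it assumes, as PAD and EOS above suggest, that RESERVED_TOKENS is the pad and EOS markers in order):

```python
RESERVED_TOKENS = [u"<pad>", u"<EOS>"]  # assumption: the module's PAD and EOS, in order

def fallback_subtoken_string(subtoken_id, num_reserved_ids=2):
  # Mirrors only the branch above: reserved ids and ids missing from the vocab.
  if 0 <= subtoken_id < num_reserved_ids:
    return u"%s_" % RESERVED_TOKENS[subtoken_id]
  return u"ID%d_" % subtoken_id

assert fallback_subtoken_string(1) == u"<EOS>_"
assert fallback_subtoken_string(7) == u"ID7_"
```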
@@ -260,27 +292,11 @@ def _escaped_token_to_subtokens(self, escaped_token):
         if subtoken != -1:
           break
         end -= 1
-      if end > pos:
-        ret.append(subtoken)
-        pos = end
-      else:
-        # No subtoken in the vocabulary matches escaped_token[pos].
-        # This can happen if the token contains a Unicode character
-        # that did not occur in the vocabulary training set.
-        # The id self.vocab_size - 1 is decoded as Unicode uFFFD,
-        # REPLACEMENT_CHARACTER.
-        ret.append(self.vocab_size - 1)
-        # Ensure that the outer loop continues
-        pos += 1
-    return ret
+      assert end > pos
+      ret.append(subtoken)
+      pos = end

-  @classmethod
-  def alphabet(cls, token_counts):
-    """Return the set of Unicode characters that appear in the tokens."""
-    alphabet_set = set()
-    for token in six.iterkeys(token_counts):
-      alphabet_set |= set(token)
-    return alphabet_set
+    return ret

   @classmethod
   def build_to_target_size(cls,
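The fallback to a replacement character is gone: after this change every character of an escaped token is guaranteed to be in the vocabulary, so greedy matching can simply assert that it always finds a prefix. A self-contained sketch of that invariant (the table below is made up):

```python
# Hypothetical subtoken table; every single character used below is present,
# so the assert in the loop can never fire.
subtoken_string_to_id = {u"re": 0, u"new": 1, u"ed_": 2,
                         u"r": 3, u"e": 4, u"n": 5, u"w": 6, u"d": 7, u"_": 8}

def escaped_token_to_subtokens(escaped_token):
  ret, pos = [], 0
  while pos < len(escaped_token):
    end = len(escaped_token)
    while end > pos and escaped_token[pos:end] not in subtoken_string_to_id:
      end -= 1
    assert end > pos  # guaranteed: each single character is itself a subtoken
    ret.append(subtoken_string_to_id[escaped_token[pos:end]])
    pos = end
  return ret

assert escaped_token_to_subtokens(u"renewed_") == [0, 1, 2]  # "re" + "new" + "ed_"
```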
@@ -304,23 +320,21 @@ def build_to_target_size(cls,
     Returns:
       a SubwordTextEncoder instance.
     """
-    # Calculate the alphabet, i.e. the set of all Unicode characters
-    # that appear in the tokens.
-    alphabet_set = cls.alphabet(token_counts)
-    tf.logging.info("Alphabet contains %d characters" % len(alphabet_set))
-
     def bisect(min_val, max_val):
+      """Bisection to find the right size."""
       present_count = (max_val + min_val) // 2
       tf.logging.info("Trying min_count %d" % present_count)
       subtokenizer = cls()
-      subtokenizer.build_from_token_counts(token_counts, alphabet_set,
+      subtokenizer.build_from_token_counts(token_counts,
                                            present_count, num_iterations)
       if min_val >= max_val or subtokenizer.vocab_size == target_size:
         return subtokenizer
+
       if subtokenizer.vocab_size > target_size:
         other_subtokenizer = bisect(present_count + 1, max_val)
       else:
         other_subtokenizer = bisect(min_val, present_count - 1)
+
       if (abs(other_subtokenizer.vocab_size - target_size) <
           abs(subtokenizer.vocab_size - target_size)):
         return other_subtokenizer
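build_to_target_size no longer precomputes an alphabet; it just bisects on min_count, relying on the fact that a larger min_count yields a smaller vocabulary. A generic sketch of that search, with a toy stand-in for "build an encoder and measure its vocab size" (not the class's actual signature):

```python
def bisect_min_count(vocab_size_for, target_size, min_val, max_val):
  # Same shape as the bisect() above, but returns the chosen min_count.
  present = (min_val + max_val) // 2
  if min_val >= max_val or vocab_size_for(present) == target_size:
    return present
  if vocab_size_for(present) > target_size:
    other = bisect_min_count(vocab_size_for, target_size, present + 1, max_val)
  else:
    other = bisect_min_count(vocab_size_for, target_size, min_val, present - 1)
  return min(present, other, key=lambda m: abs(vocab_size_for(m) - target_size))

# Toy monotone stand-in: vocab size shrinks as min_count grows.
assert bisect_min_count(lambda m: 1000 // m, target_size=125, min_val=1, max_val=100) == 8
```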
@@ -330,17 +344,29 @@ def bisect(min_val, max_val):

   def build_from_token_counts(self,
                               token_counts,
-                              alphabet_set,
                               min_count,
                               num_iterations=4):
     """Train a SubwordTextEncoder based on a dictionary of word counts.

     Args:
       token_counts: a dictionary of Unicode strings to int.
-      alphabet_set: the set of Unicode characters that appear in the tokens.
       min_count: an integer - discard subtokens with lower counts.
       num_iterations: an integer.  how many iterations of refinement.
     """
+    # First, determine the alphabet: all characters with count at least
+    # min_count in the dataset.
+    char_counts = defaultdict(int)
+    for token, count in six.iteritems(token_counts):
+      for c in token:
+        char_counts[c] += count
+    self._alphabet = set()
+    for c, count in six.iteritems(char_counts):
+      if count >= min_count:
+        self._alphabet.add(c)
+    # Make sure all characters needed for escaping are included.
+    for c in u"\\_;0123456789":
+      self._alphabet.add(c)
+
     # We build iteratively.  On each iteration, we segment all the words,
     # then count the resulting potential subtokens, keeping the ones
     # with high enough counts for our new vocabulary.
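The alphabet is now derived from per-character counts right inside build_from_token_counts (the added code relies on `defaultdict`, presumably imported from `collections` earlier in the file, outside this excerpt). The same idea in isolation:

```python
from collections import defaultdict

def compute_alphabet(token_counts, min_count):
  """Characters seen at least min_count times, plus those the escaping needs."""
  char_counts = defaultdict(int)
  for token, count in token_counts.items():
    for c in token:
      char_counts[c] += count
  alphabet = set(c for c, count in char_counts.items() if count >= min_count)
  alphabet |= set(u"\\_;0123456789")  # escape syntax must always be representable
  return alphabet

alphabet = compute_alphabet({u"low": 5, u"lower": 2, u"new": 1}, min_count=2)
assert u"w" in alphabet and u"n" not in alphabet  # "n" is too rare and will be escaped
```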
@@ -364,43 +390,36 @@ def build_from_token_counts(self,
          for end in xrange(start + 1, len(escaped_token) + 1):
            subtoken_string = escaped_token[start:end]
            counts[subtoken_string] += count
+      # Make sure every character of the alphabet remains a viable subtoken.
+      for c in self._alphabet:
+        counts[c] += min_count
       # Array of sets of candidate subtoken strings, by length
       len_to_subtoken_strings = []
       for subtoken_string, count in six.iteritems(counts):
         lsub = len(subtoken_string)
-        # All subtoken strings of length 1 are automatically included
-        # later, so we don't need to consider them here
-        if count < min_count or lsub <= 1:
-          continue
-        # Add this subtoken string to its length set
-        while len(len_to_subtoken_strings) <= lsub:
-          len_to_subtoken_strings.append(set())
-        len_to_subtoken_strings[lsub].add(subtoken_string)
+        if count >= min_count:
+          # Add this subtoken string to its length set
+          while len(len_to_subtoken_strings) <= lsub:
+            len_to_subtoken_strings.append(set())
+          len_to_subtoken_strings[lsub].add(subtoken_string)
       new_subtoken_strings = []
       # consider the candidates longest to shortest, so that if we accept
       # a longer subtoken string, we can decrement the counts of its prefixes.
-      for subtoken_strings in reversed(len_to_subtoken_strings[2:]):
+      for lsub in reversed(range(1, len(len_to_subtoken_strings))):
+        subtoken_strings = len_to_subtoken_strings[lsub]
         for subtoken_string in subtoken_strings:
           count = counts[subtoken_string]
-          if count < min_count:
-            continue
-          new_subtoken_strings.append((count, subtoken_string))
-          for l in xrange(1, len(subtoken_string)):
-            counts[subtoken_string[:l]] -= count
-      # Sort what we've got so far in decreasing order by count
+          if count >= min_count:
+            new_subtoken_strings.append((count, subtoken_string))
+            for l in xrange(1, lsub):
+              counts[subtoken_string[:l]] -= count
+      # Sort in decreasing order by count
       new_subtoken_strings.sort(reverse=True)
-      # Add the alphabet set at the end of the vocabulary list
-      for char in alphabet_set:
-        new_subtoken_strings.append((0, char))
-      # Also include the Unicode REPLACEMENT CHARACTER to use
-      # when encountering previously unseen Unicode characters
-      # in the input (i.e. input external to the tokenizer training
-      # set, which may thus contain characters not in the alphabet_set).
-      # This must be the last entry in the subtoken vocabulary list.
-      new_subtoken_strings.append((0, u"\uFFFD"))
       # Now we have a candidate vocabulary
+      old_alphabet = self._alphabet
       self._init_from_list([u""] * self._num_reserved_ids +
                            [p[1] for p in new_subtoken_strings])
+      assert old_alphabet == self._alphabet
       tf.logging.info("vocab_size = %d" % self.vocab_size)

     original = "This sentence was encoded by the SubwordTextEncoder."
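The candidate selection now buckets substrings by length, walks the buckets from longest to shortest, and charges each accepted candidate's count against all of its proper prefixes, so a prefix only survives on the weight it earns elsewhere. A compact single-pass sketch of just that pruning step (the counts and threshold are made up, and the real loop also tops up every alphabet character by min_count):

```python
from collections import defaultdict

def prune_candidates(counts, min_count):
  by_len = defaultdict(set)
  for s, c in counts.items():
    if c >= min_count:
      by_len[len(s)].add(s)
  kept = []
  for lsub in sorted(by_len, reverse=True):   # longest candidates first
    for s in by_len[lsub]:
      if counts[s] >= min_count:
        kept.append((counts[s], s))
        for l in range(1, lsub):
          counts[s[:l]] -= counts[s]          # prefixes must earn their keep elsewhere
  return sorted(kept, reverse=True)

counts = {u"the_": 10, u"the": 10, u"th": 11, u"t": 12, u"h": 11, u"e": 10, u"_": 10}
result = prune_candidates(counts, min_count=3)
assert (10, u"the_") in result                                  # kept
assert all(s not in (u"t", u"th", u"the") for _, s in result)   # absorbed into "the_"
```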
@@ -423,46 +442,77 @@ def _init_from_list(self, subtoken_strings):
     self._all_subtoken_strings = subtoken_strings
     self._subtoken_string_to_id = {
         s: i for i, s in enumerate(subtoken_strings) if s}
+    self._alphabet = set([c for c in subtoken_strings if len(c) == 1])

   def _load_from_file(self, filename):
     """Load from a file."""
     subtoken_strings = []
     with tf.gfile.Open(filename) as f:
       for line in f:
-        if six.PY2:
-          subtoken_strings.append(line.strip()[1:-1].decode("utf-8"))
-        else:
-          subtoken_strings.append(line.strip()[1:-1])
+        subtoken_strings.append(_native_to_unicode(line.strip()[1:-1]))
     self._init_from_list(subtoken_strings)

   def store_to_file(self, filename):
     with tf.gfile.Open(filename, "w") as f:
       for subtoken_string in self._all_subtoken_strings:
-        if six.PY2:
-          f.write("'" + subtoken_string.encode("utf-8") + "'\n")
-        else:
-          f.write("'" + subtoken_string + "'\n")
+        f.write("'" + _unicode_to_native(subtoken_string) + "'\n")

   def _escape_token(self, token):
-    r"""Translate '\'->'\\' and '_'->'\u', then append '_'.
+    r"""Escape away underscores and OOV characters and append '_'.
+
+    This allows the token to be expressed as the concatenation of a list
+    of subtokens from the vocabulary.  The underscore acts as a sentinel
+    which allows us to invertibly concatenate multiple such lists.

     Args:
-      token: a string
+      token: a unicode string
     Returns:
-      escaped_token: a string
+      escaped_token: a unicode string
     """
-    return token.replace("\\", "\\\\").replace("_", "\\u") + "_"
+    token = token.replace("\\", "\\\\").replace("_", "\\u") + "_"
+    ret = u""
+    for c in token:
+      if c in self._alphabet:
+        ret += c
+      else:
+        ret += u"\\%d;" % ord(c)
+    return ret

   def _unescape_token(self, escaped_token):
-    r"""Remove '_' from end, then translate '\\'->'\' and '\u'->'_'.
+    r"""Inverse of _escape_token().

     Args:
-      escaped_token: a string
+      escaped_token: a unicode string
     Returns:
-      token: a string
+      token: a unicode string
     """
-    assert escaped_token[-1] == "_"
-    return escaped_token[:-1].replace("\\u", "_").replace("\\\\", "\\")
+    ret = u""
+    escaped_token = escaped_token[:-1]
+    pos = 0
+    while pos < len(escaped_token):
+      c = escaped_token[pos]
+      if c == "\\":
+        pos += 1
+        c = escaped_token[pos]
+        if c == u"u":
+          ret += u"_"
+          pos += 1
+        elif c == "\\":
+          ret += u"\\"
+          pos += 1
+        else:
+          semicolon_pos = escaped_token.find(u";", pos)
+          if semicolon_pos == -1:
+            continue
+          try:
+            ret += six.unichr(int(escaped_token[pos:semicolon_pos]))
+            pos = semicolon_pos + 1
+          except (ValueError, OverflowError):
+            pass
+      else:
+        ret += c
+        pos += 1
+    return ret

   @classmethod
   def get_token_counts(cls, text_filepattern, corpus_max_lines):
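Escaping now has two layers: backslash and underscore are rewritten as before, and any character outside the alphabet becomes `\<ordinal>;` (digits, backslash and semicolon are always forced into the alphabet, so the escape syntax itself is representable). A self-contained sketch of both directions with a made-up alphabet, round-tripping an out-of-alphabet character:

```python
import six

_ALPHABET = set(u"abcdefghijklmnopqrstuvwxyz\\_;0123456789")  # hypothetical alphabet

def escape_token(token):
  token = token.replace(u"\\", u"\\\\").replace(u"_", u"\\u") + u"_"
  return u"".join(c if c in _ALPHABET else u"\\%d;" % ord(c) for c in token)

def unescape_token(escaped):
  ret, pos, escaped = u"", 0, escaped[:-1]   # drop the trailing sentinel "_"
  while pos < len(escaped):
    if escaped[pos] == u"\\" and escaped[pos + 1] == u"u":
      ret, pos = ret + u"_", pos + 2                        # "\u" -> "_"
    elif escaped[pos] == u"\\" and escaped[pos + 1] == u"\\":
      ret, pos = ret + u"\\", pos + 2                       # "\\" -> "\"
    elif escaped[pos] == u"\\":
      semi = escaped.index(u";", pos)
      ret, pos = ret + six.unichr(int(escaped[pos + 1:semi])), semi + 1
    else:
      ret, pos = ret + escaped[pos], pos + 1
  return ret

token = u"caf\u00e9_menu"
assert escape_token(token) == u"caf\\233;\\umenu_"
assert unescape_token(escape_token(token)) == token
```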
@@ -474,7 +524,7 @@ def get_token_counts(cls, text_filepattern, corpus_max_lines):
     with tf.gfile.Open(text_filename) as f:
       for line in f:
         # The tokenizer updates token_counts in encode()
-        tok.encode(line.strip())
+        tok.encode(_native_to_unicode(line.strip()))
         lines_read += 1
         if corpus_max_lines > 0 and lines_read > corpus_max_lines:
           return tok.token_counts
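Putting the pieces together, a hypothetical end-to-end flow using only the methods visible in this file (the file pattern, min_count, and output path are made up); the point of the commit is that even characters never seen during training now survive the round trip via the `\<ordinal>;` escapes:

```python
from tensor2tensor.data_generators import text_encoder

token_counts = text_encoder.SubwordTextEncoder.get_token_counts(
    "corpus-*.txt", corpus_max_lines=10000)

encoder = text_encoder.SubwordTextEncoder()
encoder.build_from_token_counts(token_counts, min_count=5)
encoder.store_to_file("my_vocab.subwords")

sentence = "Tokens with unseen characters still round-trip."
assert encoder.decode(encoder.encode(sentence)) == sentence
```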