This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 5a72e5c

Adapted to upstream tokenizer change
1 parent 7bf4936 · commit 5a72e5c

File tree: 3 files changed, +10 −10 lines changed


tensor2tensor/data_generators/generator_utils.py

Lines changed: 4 additions & 5 deletions
@@ -300,21 +300,20 @@ def get_or_generate_tabbed_vocab(tmp_dir, source_filename, index, vocab_filename
     vocab = text_encoder.SubwordTextEncoder(vocab_filepath)
     return vocab
 
-  tokenizer = Tokenizer()
-
   # Use Tokenizer to count the word occurrences.
+  token_counts = defaultdict(int)
   filepath = os.path.join(tmp_dir, source_filename)
   with tf.gfile.GFile(filepath, mode="r") as source_file:
     for line in source_file:
       line = line.strip()
       if line and '\t' in line:
         parts = line.split('\t', maxsplit = 1)
         part = parts[index].strip()
-        _ = tokenizer.encode(text_encoder.native_to_unicode(part))
+        for tok in tokenizer.encode(text_encoder.native_to_unicode(part)):
+          token_counts[tok] += 1
 
   vocab = text_encoder.SubwordTextEncoder.build_to_target_size(
-      vocab_size, tokenizer.token_counts, 1,
-      min(1e3, vocab_size + text_encoder.NUM_RESERVED_TOKENS))
+      vocab_size, token_counts, 1, 1e3)
   vocab.store_to_file(vocab_filepath)
   return vocab
 
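For readers adapting similar call sites: after this change the caller owns the count dictionary, and the module-level tokenizer.encode only returns tokens. A minimal sketch of that pattern, assuming an encode function that maps a unicode string to a list of token strings (the count_tokens helper below is illustrative only, not part of the commit):

# Illustrative sketch only: caller-side token counting after the upstream
# change, where encode_fn is a module-level tokenizer such as
# tensor2tensor.data_generators.tokenizer.encode.
from collections import defaultdict

def count_tokens(texts, encode_fn):
  """Count token occurrences across an iterable of unicode strings."""
  token_counts = defaultdict(int)
  for text in texts:
    for tok in encode_fn(text):
      token_counts[tok] += 1
  return token_counts

# Example usage (assuming tensor2tensor is importable):
#   from tensor2tensor.data_generators import tokenizer
#   counts = count_tokens([u"hello world", u"hello again"], tokenizer.encode)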

tensor2tensor/data_generators/text_encoder.py

Lines changed: 5 additions & 5 deletions
@@ -54,7 +54,7 @@
 PAD_TOKEN = RESERVED_TOKENS.index(PAD)  # Normally 0
 EOS_TOKEN = RESERVED_TOKENS.index(EOS)  # Normally 1
 
-if six.PY2:
+if PY2:
   RESERVED_TOKENS_BYTES = RESERVED_TOKENS
 else:
   RESERVED_TOKENS_BYTES = [bytes(PAD, "ascii"), bytes(EOS, "ascii")]
@@ -110,7 +110,7 @@ class ByteTextEncoder(TextEncoder):
 
   def encode(self, s):
     numres = self._num_reserved_ids
-    if six.PY2:
+    if PY2:
       return [ord(c) + numres for c in s]
     # Python3: explicitly convert to UTF-8
     return [c + numres for c in s.encode("utf-8")]
@@ -124,7 +124,7 @@ def decode(self, ids):
         decoded_ids.append(RESERVED_TOKENS_BYTES[int(id_)])
       else:
         decoded_ids.append(int2byte(id_ - numres))
-    if six.PY2:
+    if PY2:
       return "".join(decoded_ids)
     # Python3: join byte arrays and then decode string
     return b"".join(decoded_ids).decode("utf-8", "replace")
@@ -469,7 +469,7 @@ def store_to_file(self, filename):
         f.write("'" + unicode_to_native(subtoken_string) + "'\n")
 
   def _escape_token(self, token):
-    r"""Escape away underscores and OOV characters and append '_'.
+    """Escape away underscores and OOV characters and append '_'.
 
     This allows the token to be experessed as the concatenation of a list
     of subtokens from the vocabulary. The underscore acts as a sentinel
@@ -491,7 +491,7 @@ def _escape_token(self, token):
     return ret
 
   def _unescape_token(self, escaped_token):
-    r"""Inverse of _escape_token().
+    """Inverse of _escape_token().
 
     Args:
       escaped_token: a unicode string
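The five hunks above switch call sites from six.PY2 to a bare PY2 name and drop the raw-string prefix from two docstrings. The commit itself does not show where PY2 comes from; one common way such a module-level flag is defined (an assumption, not part of this diff) is:

# Assumption: a module-level Python-2 flag along these lines is expected to
# already exist in text_encoder.py; the commit only switches its call sites.
import sys

PY2 = sys.version_info[0] == 2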

tensor2tensor/data_generators/tokenizer.py

File mode changed: 100644 → 100755
Lines changed: 1 addition & 0 deletions
@@ -141,6 +141,7 @@ def read_corpus():
         if corpus_max_lines > 0 and lines_read > corpus_max_lines:
           return docs
     return docs
+
   counts = defaultdict(int)
   for doc in read_corpus():
     for tok in encode(_native_to_unicode(doc)):
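The blank line and mode change above are the only edits to tokenizer.py itself; the surrounding context shows counting done with a module-level encode function rather than a Tokenizer instance, which is the upstream change this commit adapts to. A hedged sketch of that module-level API (the exact signatures are an assumption, not shown in this commit):

# Assumed shape of the upstream module-level tokenizer API:
#   from tensor2tensor.data_generators import tokenizer
#   toks = tokenizer.encode(u"Dude - that's so cool.")  # list of unicode tokens
#   text = tokenizer.decode(toks)                       # intended inverse of encode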
