Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 3b663b6

Browse files
authored
Merge pull request #771 from lukaszkaiser/push
1.6.2
2 parents 4ebb860 + 485e5d9 commit 3b663b6

30 files changed

+1095
-115
lines changed

.travis.yml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,13 @@ matrix:
2121
- python: "3.6"
2222
env: TF_VERSION="1.7.*"
2323
before_install:
24-
- echo "deb [arch=amd64] http://storage.googleapis.com/tensorflow-serving-apt stable tensorflow-model-server tensorflow-model-server-universal" | sudo tee /etc/apt/sources.list.d/tensorflow-serving.list
25-
- curl https://storage.googleapis.com/tensorflow-serving-apt/tensorflow-serving.release.pub.gpg | sudo apt-key add -
24+
# Disabled TensorFlow Serving install until bug fixed. See "Export and query"
25+
# section below.
26+
# - echo "deb [arch=amd64] http://storage.googleapis.com/tensorflow-serving-apt stable tensorflow-model-server tensorflow-model-server-universal" | sudo tee /etc/apt/sources.list.d/tensorflow-serving.list
27+
# - curl https://storage.googleapis.com/tensorflow-serving-apt/tensorflow-serving.release.pub.gpg | sudo apt-key add -
2628
- sudo apt-get update -qq
2729
- sudo apt-get install -qq libhdf5-dev
28-
- sudo apt-get install -qq tensorflow-model-server
30+
# - sudo apt-get install -qq tensorflow-model-server
2931
install:
3032
- pip install -q "tensorflow==$TF_VERSION"
3133
- pip install -q .[tests]

docs/cloud_tpu.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,11 @@ See the official tutorial for [running Transformer
1818
on Cloud TPUs](https://cloud.google.com/tpu/docs/tutorials/transformer)
1919
for some examples and try out your own problems.
2020

21+
You can train an Automatic Speech Recognition (ASR) model with Transformer
22+
on TPU by using `transformer` as `model` with `transformer_librispeech_tpu` as
23+
`hparams_set` and `librispeech` as `problem`. See this [tutorial](tutorials/ast_with_transformer.md) for more details on training it and this
24+
[notebook](https://colab.research.google.com/github/tensorflow/tensor2tensor/blob/master/tensor2tensor/notebooks/asr_transformer.ipynb) to see how the resulting model transcribes your speech to text.
25+
2126
Image Transformer:
2227
* `imagetransformer` with `imagetransformer_base_tpu` (or
2328
`imagetransformer_tiny_tpu`)

docs/index.md

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,14 @@ accessible and [accelerate ML
1616
research](https://research.googleblog.com/2017/06/accelerating-deep-learning-research.html).
1717

1818

19-
## Basics
19+
## Introduction
2020

2121
* [Walkthrough](walkthrough.md): Install and run.
2222
* [IPython notebook](https://colab.research.google.com/github/tensorflow/tensor2tensor/blob/master/tensor2tensor/notebooks/hello_t2t.ipynb): Get a hands-on experience.
23+
* [Automatic Speech Recognition notebook](https://colab.research.google.com/github/tensorflow/tensor2tensor/blob/master/tensor2tensor/notebooks/asr_transformer.ipynb): Transcribe speech to text with a T2T model.
24+
25+
## Basics
26+
2327
* [Overview](overview.md): How all parts of T2T code are connected.
2428
* [New Problem](new_problem.md): Train T2T models on your data.
2529
* [New Model](new_model.md): Create your own T2T model.
@@ -29,6 +33,7 @@ research](https://research.googleblog.com/2017/06/accelerating-deep-learning-res
2933
* [Training on Google Cloud ML](cloud_mlengine.md)
3034
* [Training on Google Cloud TPUs](cloud_tpu.md)
3135
* [Distributed Training](distributed_training.md)
36+
# [Automatic Speech Recognition (ASR) with Transformer](tutorials/asr_with_transformer.md)
3237

3338
## Solving your task
3439

docs/tutorials/asr_with_transformer.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
# Automatic Speech Recognition (ASR) with Transformer
22

3+
Check out the [Automatic Speech Recognition notebook](https://colab.research.google.com/github/tensorflow/tensor2tensor/blob/master/tensor2tensor/notebooks/asr_transformer.ipynb) to see how the resulting model transcribes your speech to text.
4+
35
## Data set
46

57
This tutorial uses the publicly available
68
[Librispeech](http://www.openslr.org/12/) ASR corpus.
79

10+
811
## Generate the dataset
912

1013
To generate the dataset use `t2t-datagen`. You need to create environment

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
setup(
77
name='tensor2tensor',
8-
version='1.6.1',
8+
version='1.6.2',
99
description='Tensor2Tensor',
1010
author='Google Inc.',
1111
author_email='[email protected]',

tensor2tensor/data_generators/all_problems.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
"tensor2tensor.data_generators.ice_parsing",
3737
"tensor2tensor.data_generators.imagenet",
3838
"tensor2tensor.data_generators.imdb",
39+
"tensor2tensor.data_generators.lambada",
3940
"tensor2tensor.data_generators.librispeech",
4041
"tensor2tensor.data_generators.lm1b",
4142
"tensor2tensor.data_generators.mnist",

tensor2tensor/data_generators/gym.py

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from tensor2tensor.data_generators import problem
2929
from tensor2tensor.data_generators import video_utils
3030

31+
from tensor2tensor.models.research import autoencoders
3132
from tensor2tensor.models.research import rl
3233
from tensor2tensor.rl import collect
3334
from tensor2tensor.rl.envs import tf_atari_wrappers as atari
@@ -42,7 +43,9 @@
4243
flags = tf.flags
4344
FLAGS = flags.FLAGS
4445

45-
flags.DEFINE_string("agent_policy_path", "", "File with model for agent")
46+
47+
flags.DEFINE_string("agent_policy_path", "", "File with model for agent.")
48+
flags.DEFINE_string("autoencoder_path", "", "File with model for autoencoder.")
4649

4750

4851
class GymDiscreteProblem(video_utils.VideoProblem):
@@ -179,6 +182,7 @@ class GymPongRandom50k(GymPongRandom5k):
179182
def num_steps(self):
180183
return 50000
181184

185+
182186
@registry.register_problem
183187
class GymFreewayRandom5k(GymDiscreteProblem):
184188
"""Freeway game, random actions."""
@@ -209,7 +213,6 @@ def num_steps(self):
209213
return 50000
210214

211215

212-
@registry.register_problem
213216
class GymDiscreteProblemWithAgent(GymDiscreteProblem):
214217
"""Gym environment with discrete actions and rewards and an agent."""
215218

@@ -239,7 +242,7 @@ def _setup(self):
239242
generator_batch_env = batch_env_factory(
240243
self.environment_spec, env_hparams, num_agents=1, xvfb=False)
241244

242-
with tf.variable_scope("", reuse=tf.AUTO_REUSE):
245+
with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
243246
if FLAGS.agent_policy_path:
244247
policy_lambda = self.collect_hparams.network
245248
else:
@@ -252,7 +255,7 @@ def _setup(self):
252255
create_scope_now_=True,
253256
unique_name_="network")
254257

255-
with tf.variable_scope("", reuse=tf.AUTO_REUSE):
258+
with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
256259
self.collect_hparams.epoch_length = 10
257260
_, self.collect_trigger_op = collect.define_collect(
258261
policy_factory, generator_batch_env, self.collect_hparams,
@@ -267,6 +270,22 @@ def restore_networks(self, sess):
267270
tf.global_variables(".*network_parameters.*"))
268271
model_saver.restore(sess, FLAGS.agent_policy_path)
269272

273+
def autoencode(self, image, sess):
274+
with tf.Graph().as_default():
275+
hparams = autoencoders.autoencoder_discrete_pong()
276+
hparams.data_dir = "unused"
277+
hparams.problem_hparams = self.get_hparams(hparams)
278+
hparams.problem = self
279+
model = autoencoders.AutoencoderOrderedDiscrete(
280+
hparams, tf.estimator.ModeKeys.EVAL)
281+
img = tf.constant(image)
282+
img = tf.to_int32(tf.reshape(
283+
img, [1, 1, self.frame_height, self.frame_width, self.num_channels]))
284+
encoded = model.encode(img)
285+
model_saver = tf.train.Saver(tf.global_variables())
286+
model_saver.restore(sess, FLAGS.autoencoder_path)
287+
return sess.run(encoded)
288+
270289
def generate_encoded_samples(self, data_dir, tmp_dir, unused_dataset_split):
271290
self._setup()
272291
self.debug_dump_frames_path = os.path.join(
@@ -275,17 +294,14 @@ def generate_encoded_samples(self, data_dir, tmp_dir, unused_dataset_split):
275294
with tf.Session() as sess:
276295
sess.run(tf.global_variables_initializer())
277296
self.restore_networks(sess)
278-
# Actions are shifted by 1 by MemoryWrapper, compensate here.
279-
avilable_data_size = sess.run(self.avilable_data_size_op)
280-
if avilable_data_size < 1:
281-
sess.run(self.collect_trigger_op)
282297
pieces_generated = 0
283-
observ, reward, _, _ = sess.run(self.data_get_op)
284298
while pieces_generated < self.num_steps + self.warm_up:
285299
avilable_data_size = sess.run(self.avilable_data_size_op)
286300
if avilable_data_size < 1:
287301
sess.run(self.collect_trigger_op)
288-
next_observ, next_reward, action, _ = sess.run(self.data_get_op)
302+
observ, reward, action, _, img = sess.run(self.data_get_op)
303+
if FLAGS.autoencoder_path:
304+
observ = self.autoencode(img, sess)
289305
yield {"image/encoded": [observ],
290306
"image/format": ["png"],
291307
"image/height": [self.frame_height],
@@ -294,7 +310,6 @@ def generate_encoded_samples(self, data_dir, tmp_dir, unused_dataset_split):
294310
"done": [int(False)],
295311
"reward": [int(reward) - self.min_reward]}
296312
pieces_generated += 1
297-
observ, reward = next_observ, next_reward
298313

299314

300315
@registry.register_problem
@@ -318,20 +333,24 @@ def restore_networks(self, sess):
318333

319334

320335
@registry.register_problem
321-
class GymSimulatedDiscreteProblemWithAgentOnPong(GymSimulatedDiscreteProblemWithAgent, GymPongRandom5k):
336+
class GymSimulatedDiscreteProblemWithAgentOnPong(
337+
GymSimulatedDiscreteProblemWithAgent, GymPongRandom5k):
322338
pass
323339

324340

325341
@registry.register_problem
326-
class GymDiscreteProblemWithAgentOnPong(GymDiscreteProblemWithAgent, GymPongRandom5k):
342+
class GymDiscreteProblemWithAgentOnPong(
343+
GymDiscreteProblemWithAgent, GymPongRandom5k):
327344
pass
328345

329346

330347
@registry.register_problem
331-
class GymSimulatedDiscreteProblemWithAgentOnFreeway(GymSimulatedDiscreteProblemWithAgent, GymFreewayRandom5k):
348+
class GymSimulatedDiscreteProblemWithAgentOnFreeway(
349+
GymSimulatedDiscreteProblemWithAgent, GymFreewayRandom5k):
332350
pass
333351

334352

335353
@registry.register_problem
336-
class GymDiscreteProblemWithAgentOnFreeway(GymDiscreteProblemWithAgent, GymFreewayRandom5k):
354+
class GymDiscreteProblemWithAgentOnFreeway(
355+
GymDiscreteProblemWithAgent, GymFreewayRandom5k):
337356
pass

0 commit comments

Comments
 (0)