This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 24c1fd7

Lukasz Kaiser authored and Ryan Sepassi committed
Clean up transformer_vae and add refining.
PiperOrigin-RevId: 177505082
1 parent 01030eb commit 24c1fd7

File tree: 1 file changed (+111, -86 lines)


tensor2tensor/models/transformer_vae.py

Lines changed: 111 additions & 86 deletions
@@ -32,6 +32,9 @@
 import tensorflow as tf


+_DO_SUMMARIES = True
+
+
 def residual_conv(x, repeat, k, hparams, name, reuse=None):
   """A stack of convolution blocks with residual connections."""
   with tf.variable_scope(name, reuse=reuse):
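The new module-level _DO_SUMMARIES switch exists because, as a later hunk notes, summary ops break when part of the graph is built under the do_refine tf.cond, so the model needs a way to turn them off globally. A minimal standalone sketch of the pattern (the maybe_histogram helper is made up for illustration, not part of this commit):

import tensorflow as tf

_DO_SUMMARIES = True  # module-level switch, as in transformer_vae.py

def maybe_histogram(name, tensor):
  # Add the summary op only when the switch is on; code that will later run
  # under a tf.cond (as with do_refine) can set the switch to False first.
  if _DO_SUMMARIES:
    tf.summary.histogram(name, tf.reshape(tensor, [-1]))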
@@ -110,7 +113,8 @@ def dae(x, hparams, name):
     s = tf.nn.softmax((logsm + gumbel_samples) / temperature)
     m = tf.nn.softmax(m)
     kl = - tf.reduce_max(logsm, axis=-1)
-    tf.summary.histogram("max-log", tf.reshape(kl, [-1]))
+    if _DO_SUMMARIES:
+      tf.summary.histogram("max-log", tf.reshape(kl, [-1]))
     # Calculate the argmax and construct hot vectors.
     maxvec = tf.reshape(tf.argmax(m, axis=-1), [-1])
     maxvhot = tf.stop_gradient(tf.one_hot(maxvec, hparams.v_size))
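For context, s = tf.nn.softmax((logsm + gumbel_samples) / temperature) is the Gumbel-softmax relaxation: add Gumbel(0, 1) noise to the log-probabilities and take a tempered softmax, giving a differentiable, approximately one-hot sample. A self-contained NumPy sketch, with an illustrative temperature rather than the scheduled one used in dae():

import numpy as np

def gumbel_softmax_sample(log_probs, temperature=0.5):
  # Gumbel(0, 1) noise is -log(-log(U)) for U ~ Uniform(0, 1).
  u = np.random.uniform(1e-9, 1.0 - 1e-9, size=np.shape(log_probs))
  gumbel = -np.log(-np.log(u))
  x = (np.asarray(log_probs) + gumbel) / temperature
  e = np.exp(x - x.max(axis=-1, keepdims=True))
  return e / e.sum(axis=-1, keepdims=True)  # soft, near-one-hot sample

sample = gumbel_softmax_sample(np.log([0.7, 0.2, 0.1]))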
@@ -134,7 +138,9 @@ def vae(x, z_size, name):
     z = mu + tf.exp(log_sigma / 2) * epsilon
     kl = 0.5 * tf.reduce_mean(
         tf.exp(log_sigma) + tf.square(mu) - 1. - log_sigma, axis=-1)
-    return z, tf.reduce_mean(kl), mu, log_sigma
+    free_bits = z_size // 2
+    kl_loss = tf.maximum(tf.reduce_mean(kl) - free_bits, 0.0)
+    return z, kl_loss, mu, log_sigma


 def nearest(x, means, hparams):
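The vae() change adds a free-bits style floor: the KL penalty only kicks in above a budget of z_size // 2, so the posterior can use that much divergence without being pushed back toward the prior. A numeric illustration with made-up KL values:

import numpy as np

z_size = 16
free_bits = z_size // 2                       # = 8, the unpenalized budget
kl_per_example = np.array([3.0, 9.5, 12.0])   # hypothetical KL values
kl_loss = np.maximum(kl_per_example.mean() - free_bits, 0.0)
# mean KL is about 8.17, so only about 0.17 is penalized;
# any mean below 8 would give a KL loss of exactly 0.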
@@ -187,35 +193,39 @@ def int_to_bit(x_int, nbits):

 def bottleneck(x, hparams, filter_size, name):
   """Bottleneck."""
-  def embed1(x):
-    if hparams.bottleneck_kind == "semhash":
-      c = int_to_bit(x, c_size)
-      h1a = tf.layers.dense(c, filter_size, name="vch1a")
-      h1b = tf.layers.dense(1.0 - c, filter_size, name="vch1b")
-      return h1a + h1b
-    elif hparams.bottleneck_kind == "gumbel-softmax":
-      hot = tf.one_hot(x, hparams.v_size)
-      with tf.variable_scope(name, reuse=True):
-        return tf.layers.dense(hot, hparams.hidden_size, name="dae_dense")
-
   def embed(x):
+    """Embedding function; must be compatible with the code later."""
     with tf.variable_scope(name, reuse=True):
-      h1 = embed1(x)
+      if hparams.bottleneck_kind == "semhash":
+        c = int_to_bit(x, z_size)
+        h1a = tf.layers.dense(c, filter_size, name="vch1a")
+        h1b = tf.layers.dense(1.0 - c, filter_size, name="vch1b")
+        h1 = h1a + h1b
+      elif hparams.bottleneck_kind == "gumbel-softmax":
+        hot = tf.one_hot(x, hparams.v_size)
+        h1 = tf.layers.dense(hot, hparams.hidden_size, name="dae_dense")
+      elif hparams.bottleneck_kind == "vq-vae":
+        means = tf.get_variable(name="means",
+                                shape=[hparams.v_size, hparams.hidden_size])
+        h1 = tf.gather(means, x)
+
       h2 = tf.layers.dense(tf.nn.relu(h1), filter_size, name="vch2")
-      res = tf.layers.dense(tf.nn.relu(h2), hparams.hidden_size, name="vcfin")
-      return res
+      return tf.layers.dense(tf.nn.relu(h2), hparams.hidden_size, name="vcfin")

   with tf.variable_scope(name):
-    c_size = hparams.c_size
+    z_size = hparams.z_size
     l = tf.constant(0.0)
     if hparams.bottleneck_kind == "dense":
-      c = tf.layers.dense(x, c_size, name="vcc")
+      c = tf.layers.dense(x, z_size, name="vcc")
+      h1 = tf.layers.dense(c, filter_size, name="vch1")
+    if hparams.bottleneck_kind == "vae":
+      c, l, _, _ = vae(x, z_size, "vae")
       h1 = tf.layers.dense(c, filter_size, name="vch1")
     if hparams.bottleneck_kind == "semhash":
-      c = tf.layers.dense(x, c_size, name="vcc")
+      c = tf.layers.dense(x, z_size, name="vcc")
       y_clean = common_layers.saturating_sigmoid(c)
-      tf.summary.histogram("y_clean", tf.reshape(y_clean, [-1]))
-      # l = tf.reduce_mean(y_clean * (1.0 - y_clean))
+      if _DO_SUMMARIES:
+        tf.summary.histogram("y_clean", tf.reshape(y_clean, [-1]))
       if hparams.noise_dev > 0 and hparams.mode == tf.estimator.ModeKeys.TRAIN:
         dev = hparams.noise_dev
         noise = tf.truncated_normal(tf.shape(c), mean=0.0, stddev=dev)
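The rewritten embed() also covers a vq-vae bottleneck, in which a discrete code simply indexes a learned table of means. A standalone NumPy illustration of that branch (sizes are invented; the real code uses tf.get_variable and tf.gather inside the reused variable scope):

import numpy as np

v_size, hidden_size = 8, 4
means = np.random.randn(v_size, hidden_size)  # stands in for the "means" variable
codes = np.array([1, 5, 5, 2])                # discrete latent ids
h1 = means[codes]                             # same role as tf.gather(means, x)
assert h1.shape == (4, hidden_size)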
@@ -233,7 +243,7 @@ def embed(x):
       h1b = tf.layers.dense(1.0 - c, filter_size, name="vch1b")
       h1 = h1a + h1b
       dx = tf.to_int32(tf.stop_gradient(d))
-      c = bit_to_int(dx, c_size)
+      c = bit_to_int(dx, z_size)
     if hparams.bottleneck_kind == "gumbel-softmax":
       _, hot, l = dae(x, hparams, name)
       c = tf.argmax(hot, axis=-1)
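The semhash path still round-trips between z_size-bit vectors and integers via int_to_bit and bit_to_int, now parameterized by z_size rather than the removed c_size. A rough NumPy stand-in for that round trip, not the t2t helpers themselves (the bit ordering here is only illustrative):

import numpy as np

def int_to_bit_np(x_int, nbits):
  # Most-significant bit first, for this sketch.
  return np.array([(x_int >> i) & 1 for i in range(nbits)][::-1])

def bit_to_int_np(bits):
  out = 0
  for b in bits:
    out = (out << 1) | int(b)
  return out

z_size = 16
code = 4242
assert bit_to_int_np(int_to_bit_np(code, z_size)) == code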
@@ -331,43 +341,54 @@ def next_bit(t_bit, i):
 def ae_transformer_internal(inputs, targets, target_space, hparams,
                             beam_size, cache=None, predict_mask=1.0):
   """AE Transformer, main step used for training."""
-  hparams.z_size = hparams.hidden_size
-  with tf.variable_scope("ae_transformer"):
-    # Prepare inputs, targets, k.
-    orig_targets = targets
-    batch_size = tf.shape(orig_targets)[0]
-    targets = tf.reshape(targets, [batch_size, -1, 1, hparams.hidden_size])
-    k = hparams.num_compress_steps
-
-    # Encoder.
-    if inputs is not None:
-      inputs = common_layers.flatten4d3d(inputs)
-      inputs, ed = encode(inputs, target_space, hparams, "input_enc")
-    else:
-      ed = None
-
-    # Autoencoding.
-    losses = {"vc": tf.constant(0.0), "sm": tf.constant(0.0)}
-    if hparams.do_ae:
-      targets, _ = common_layers.pad_to_same_length(
-          targets, targets, final_length_divisible_by=2**k)
-      targets_c = compress(targets, False, hparams, "compress")
-      if hparams.mode != tf.estimator.ModeKeys.PREDICT:
-        # Compress and bottleneck.
-        t_c, t_bit, vc_loss, _ = bottleneck(targets_c, hparams, 2*2048, "vc")
+  # Summaries break with the do_refine cond, turn them off in that case.
+  global _DO_SUMMARIES
+  if hparams.do_refine:
+    _DO_SUMMARIES = False
+
+  # Prepare.
+  orig_targets = targets
+  batch_size = tf.shape(orig_targets)[0]
+  targets = tf.reshape(targets, [batch_size, -1, 1, hparams.hidden_size])
+
+  # Encoder.
+  if inputs is not None:
+    inputs = common_layers.flatten4d3d(inputs)
+    inputs, ed = encode(inputs, target_space, hparams, "input_enc")
+  else:
+    ed = None
+
+  # Autoencoding.
+  losses = {"extra": tf.constant(0.0), "latent_pred": tf.constant(0.0)}
+  if hparams.do_ae:
+    max_targets_len_from_inputs = tf.concat([inputs, inputs], axis=1)
+    targets, _ = common_layers.pad_to_same_length(
+        targets, max_targets_len_from_inputs,
+        final_length_divisible_by=2**hparams.num_compress_steps)
+    targets_c = compress(targets, False, hparams, "compress")
+    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
+      # Compress and bottleneck.
+      t_c, t_bit, vc_loss, _ = bottleneck(targets_c, hparams, 2*2048, "vc")
+      if _DO_SUMMARIES:
         tf.summary.histogram("bit0", tf.reshape(t_bit[:, 0, :], [-1]))
-        pc = common_layers.inverse_exp_decay(hparams.startup_steps) * 0.95
-        pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
-        cond = tf.less(tf.random_uniform([]), pc)
-        t_c = tf.cond(cond, lambda: t_c, lambda: targets_c)
-        losses["vc"] = vc_loss * tf.to_float(cond)
-        # Extra loss predicting latent code from input.
+      pc = common_layers.inverse_exp_decay(hparams.startup_steps) * 0.95
+      pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
+      cond = tf.less(tf.random_uniform([]), pc)
+      t_c = tf.cond(cond, lambda: t_c, lambda: targets_c)
+      losses["extra"] = vc_loss * tf.to_float(cond)
+      # Extra loss predicting latent code from input. Discrete only.
+      if hparams.bottleneck_kind not in ["dense", "vae"]:
         t_pred = decode_transformer(
             inputs, ed, tf.stop_gradient(t_c), hparams, "extra")
         t_pred = tf.layers.dense(t_pred, 2**16, name="extra_logits")
-        losses["sm"] = tf.nn.sparse_softmax_cross_entropy_with_logits(
+        losses["latent_pred"] = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=t_bit, logits=t_pred)
-        losses["sm"] = tf.reduce_mean(losses["sm"]) * 0.5 * tf.to_float(cond)
+        losses["latent_pred"] = tf.reduce_mean(
+            losses["latent_pred"]) * 0.5 * tf.to_float(cond)
+    else:
+      if hparams.bottleneck_kind in ["dense", "vae"]:
+        targets_rand = tf.random_uniform(tf.shape(targets_c))
+        t_c, _, _, _ = bottleneck(targets_rand, hparams, 2*2048, "vc")
       else:
         latent_len = tf.shape(targets_c)[1]
        _, _, _, embed = bottleneck(targets_c, hparams, 2*2048, "vc")
@@ -378,33 +399,39 @@ def ae_transformer_internal(inputs, targets, target_space, hparams,
         cache = tf.reshape(cache, [1, latent_len, 1])
         cache = tf.tile(cache, [beam_size, 1, 1])
         t_c = embed(cache)
-      # Postprocess.
-      d = t_c
-      pos = tf.get_variable("pos", [1, 1000, 1, hparams.hidden_size])
-      pos = pos[:, :tf.shape(t_c)[1] + 1, :, :]
-      t_c = tf.pad(t_c, [[0, 0], [1, 0], [0, 0], [0, 0]]) + pos
-
-      # Masking.
-      if hparams.do_mask:
-        masking = common_layers.inverse_lin_decay(100000)
-        masking *= common_layers.inverse_exp_decay(25000)  # Not much at start.
+    # Postprocess.
+    d = t_c
+    pos = tf.get_variable("pos", [1, 1000, 1, hparams.hidden_size])
+    pos = pos[:, :tf.shape(t_c)[1] + 1, :, :]
+    t_c = tf.pad(t_c, [[0, 0], [1, 0], [0, 0], [0, 0]]) + pos
+
+    # Masking.
+    if hparams.do_mask:
+      masking = common_layers.inverse_lin_decay(100000)
+      masking *= common_layers.inverse_exp_decay(25000)  # Not much at start.
+      if not hparams.do_refine:
         masking -= tf.random_uniform([]) * 0.3
-        masking = tf.minimum(tf.maximum(masking, 0.0), 1.0)
-        if hparams.mode == tf.estimator.ModeKeys.PREDICT:
-          masking = predict_mask
-        mask = tf.less(masking, tf.random_uniform(tf.shape(targets)[:-1]))
-        mask = tf.expand_dims(tf.to_float(mask), 3)
-        for i in xrange(hparams.num_compress_steps):
-          j = hparams.num_compress_steps - i - 1
-          d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j)
-          d = decompress_step(d, hparams, i > 0, False, "decompress_%d" % j)
-        targets = mask * targets + (1.0 - mask) * d
-      targets = tf.concat([tf.reverse(t_c, [1]), targets], axis=1)
-
-    res = decode_transformer(inputs, ed, targets, hparams, "decoder")
-    if hparams.do_ae:
-      res = res[:, tf.shape(t_c)[1]:, :, :]
-    return res, losses, cache
+      masking = tf.minimum(tf.maximum(masking, 0.0), 1.0)
+      if hparams.mode == tf.estimator.ModeKeys.PREDICT:
+        masking = predict_mask
+      mask = tf.less(masking, tf.random_uniform(tf.shape(targets)[:-1]))
+      mask = tf.expand_dims(tf.to_float(mask), 3)
+      for i in xrange(hparams.num_compress_steps):
+        j = hparams.num_compress_steps - i - 1
+        d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j)
+        d = decompress_step(d, hparams, i > 0, False, "decompress_%d" % j)
+      targets = mask * targets + (1.0 - mask) * d
+    targets = tf.concat([tf.reverse(t_c, [1]), targets], axis=1)
+
+  res = decode_transformer(inputs, ed, targets, hparams, "decoder")
+  if hparams.do_ae:
+    res = res[:, tf.shape(t_c)[1]:, :, :]
+    if hparams.do_mask and hparams.do_refine:
+      def refine_res():
+        return residual_conv(res, 1, (5, 1), hparams, "refine")
+      all_masked = tf.less(tf.reduce_sum(mask), 0.1)
+      res = tf.cond(all_masked, refine_res, lambda: res)
+  return res, losses, cache


 @registry.register_model
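The refining step this commit adds (per the commit message) only runs when the decompressed targets were entirely masked, which is the inference situation where predict_mask drives reduce_sum(mask) to zero; training batches with a partial mask skip it. An illustrative TF 1.x sketch of the gating, with made-up shapes and a plain dense layer standing in for residual_conv(res, 1, (5, 1), hparams, "refine"):

import tensorflow as tf

res = tf.random_normal([2, 8, 1, 16])   # pretend decoder output
mask = tf.zeros([2, 8, 1, 1])           # fully masked, as at decode time

def refine_res():
  # Stand-in for the extra refinement pass over the decoder output.
  return tf.layers.dense(res, 16, name="refine_layer")

all_masked = tf.less(tf.reduce_sum(mask), 0.1)
res = tf.cond(all_masked, refine_res, lambda: res)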
@@ -466,7 +493,7 @@ def infer(self, features=None, decode_length=50, beam_size=1, top_beams=1,
     else:
       batch_size = tf.shape(features["inputs"])[0]
       length = tf.shape(features["inputs"])[1]
-      target_length = tf.to_int32(1.3 * tf.to_float(length))
+      target_length = tf.to_int32(2.0 * tf.to_float(length))
       initial_output = tf.zeros((batch_size, target_length, 1, 1),
                                 dtype=tf.int64)

@@ -489,15 +516,15 @@ def transformer_ae_small():
   hparams.hidden_size = 384
   hparams.filter_size = 2048
   hparams.label_smoothing = 0.0
-  hparams.add_hparam("c_size", 16)
+  hparams.add_hparam("z_size", 16)
   hparams.add_hparam("noise_dev", 1.0)
   hparams.add_hparam("d_mix", 0.5)
-  # Bottleneck kinds supported: dense, semhash, gumbel-softmax.
+  # Bottleneck kinds supported: dense, vae, semhash, gumbel-softmax, vq-vae.
   hparams.add_hparam("bottleneck_kind", "semhash")
   hparams.add_hparam("do_ae", True)
   hparams.add_hparam("do_mask", True)
+  hparams.add_hparam("do_refine", True)
   hparams.add_hparam("drop_inputs", False)
-  hparams.add_hparam("z_size", 128)
   hparams.add_hparam("v_size", 1024*64)
   hparams.add_hparam("max_context_length", 64)
   hparams.add_hparam("num_compress_steps", 3)
@@ -522,8 +549,6 @@ def transformer_ae_cifar():
   hparams = transformer_ae_small()
   hparams.hidden_size = 256
   hparams.filter_size = 512
-  hparams.z_size = 256  # 64
-  hparams.z_size2 = 0  # 16
   hparams.batch_size = 1024 * 4
   hparams.num_compress_steps = 2
   hparams.v_size = 1024 * 16
