 import tensorflow as tf


+_DO_SUMMARIES = True
+
+
 def residual_conv(x, repeat, k, hparams, name, reuse=None):
   """A stack of convolution blocks with residual connections."""
   with tf.variable_scope(name, reuse=reuse):
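
Note on the hunk above: the new module-level _DO_SUMMARIES flag gates every
tf.summary call in this file, so summaries can be switched off when a graph
branch would break them (see the do_refine change further down). A minimal
sketch of the pattern, with maybe_histogram as a hypothetical helper that is
not part of this commit:

    def maybe_histogram(tag, values):
      # Emit the histogram only while the module flag is on; the flag is
      # flipped off in ae_transformer_internal() when do_refine is set.
      if _DO_SUMMARIES:
        tf.summary.histogram(tag, tf.reshape(values, [-1]))
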
@@ -110,7 +113,8 @@ def dae(x, hparams, name):
     s = tf.nn.softmax((logsm + gumbel_samples) / temperature)
     m = tf.nn.softmax(m)
     kl = - tf.reduce_max(logsm, axis=-1)
-    tf.summary.histogram("max-log", tf.reshape(kl, [-1]))
+    if _DO_SUMMARIES:
+      tf.summary.histogram("max-log", tf.reshape(kl, [-1]))
     # Calculate the argmax and construct hot vectors.
     maxvec = tf.reshape(tf.argmax(m, axis=-1), [-1])
     maxvhot = tf.stop_gradient(tf.one_hot(maxvec, hparams.v_size))
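
For context, dae() above draws its relaxed sample s via the Gumbel-softmax
trick. A self-contained sketch of that sampling step, with illustrative names
rather than this file's API:

    def gumbel_softmax_sample(logits, temperature=0.5, eps=1e-20):
      # Gumbel(0, 1) noise via the inverse-CDF transform -log(-log(U)).
      uniform = tf.random_uniform(tf.shape(logits), minval=0.0, maxval=1.0)
      gumbel = -tf.log(-tf.log(uniform + eps) + eps)
      # Low temperature pushes the softmax toward a one-hot sample.
      return tf.nn.softmax((logits + gumbel) / temperature)
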
@@ -134,7 +138,9 @@ def vae(x, z_size, name):
     z = mu + tf.exp(log_sigma / 2) * epsilon
     kl = 0.5 * tf.reduce_mean(
         tf.exp(log_sigma) + tf.square(mu) - 1. - log_sigma, axis=-1)
-    return z, tf.reduce_mean(kl), mu, log_sigma
+    free_bits = z_size // 2
+    kl_loss = tf.maximum(tf.reduce_mean(kl) - free_bits, 0.0)
+    return z, kl_loss, mu, log_sigma


 def nearest(x, means, hparams):
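
The vae() change above is the standard "free bits" relaxation: KL below a
fixed budget (here z_size // 2) is not penalized, which discourages the
posterior from collapsing to the prior early in training. A condensed sketch
under those assumptions:

    def kl_with_free_bits(mu, log_sigma, free_bits):
      # Analytic KL(N(mu, sigma^2) || N(0, 1)), averaged over the latent axis.
      kl = 0.5 * tf.reduce_mean(
          tf.exp(log_sigma) + tf.square(mu) - 1.0 - log_sigma, axis=-1)
      # Only KL in excess of the free-bits budget enters the loss.
      return tf.maximum(tf.reduce_mean(kl) - free_bits, 0.0)
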
@@ -187,35 +193,39 @@ def int_to_bit(x_int, nbits):

 def bottleneck(x, hparams, filter_size, name):
   """Bottleneck."""
-  def embed1(x):
-    if hparams.bottleneck_kind == "semhash":
-      c = int_to_bit(x, c_size)
-      h1a = tf.layers.dense(c, filter_size, name="vch1a")
-      h1b = tf.layers.dense(1.0 - c, filter_size, name="vch1b")
-      return h1a + h1b
-    elif hparams.bottleneck_kind == "gumbel-softmax":
-      hot = tf.one_hot(x, hparams.v_size)
-      with tf.variable_scope(name, reuse=True):
-        return tf.layers.dense(hot, hparams.hidden_size, name="dae_dense")
-
   def embed(x):
+    """Embedding function; must be compatible with the code later."""
     with tf.variable_scope(name, reuse=True):
-      h1 = embed1(x)
+      if hparams.bottleneck_kind == "semhash":
+        c = int_to_bit(x, z_size)
+        h1a = tf.layers.dense(c, filter_size, name="vch1a")
+        h1b = tf.layers.dense(1.0 - c, filter_size, name="vch1b")
+        h1 = h1a + h1b
+      elif hparams.bottleneck_kind == "gumbel-softmax":
+        hot = tf.one_hot(x, hparams.v_size)
+        h1 = tf.layers.dense(hot, hparams.hidden_size, name="dae_dense")
+      elif hparams.bottleneck_kind == "vq-vae":
+        means = tf.get_variable(name="means",
+                                shape=[hparams.v_size, hparams.hidden_size])
+        h1 = tf.gather(means, x)
+
       h2 = tf.layers.dense(tf.nn.relu(h1), filter_size, name="vch2")
-      res = tf.layers.dense(tf.nn.relu(h2), hparams.hidden_size, name="vcfin")
-      return res
+      return tf.layers.dense(tf.nn.relu(h2), hparams.hidden_size, name="vcfin")

   with tf.variable_scope(name):
-    c_size = hparams.c_size
+    z_size = hparams.z_size
     l = tf.constant(0.0)
     if hparams.bottleneck_kind == "dense":
-      c = tf.layers.dense(x, c_size, name="vcc")
+      c = tf.layers.dense(x, z_size, name="vcc")
+      h1 = tf.layers.dense(c, filter_size, name="vch1")
+    if hparams.bottleneck_kind == "vae":
+      c, l, _, _ = vae(x, z_size, "vae")
       h1 = tf.layers.dense(c, filter_size, name="vch1")
     if hparams.bottleneck_kind == "semhash":
-      c = tf.layers.dense(x, c_size, name="vcc")
+      c = tf.layers.dense(x, z_size, name="vcc")
       y_clean = common_layers.saturating_sigmoid(c)
-      tf.summary.histogram("y_clean", tf.reshape(y_clean, [-1]))
-      # l = tf.reduce_mean(y_clean * (1.0 - y_clean))
+      if _DO_SUMMARIES:
+        tf.summary.histogram("y_clean", tf.reshape(y_clean, [-1]))
     if hparams.noise_dev > 0 and hparams.mode == tf.estimator.ModeKeys.TRAIN:
       dev = hparams.noise_dev
       noise = tf.truncated_normal(tf.shape(c), mean=0.0, stddev=dev)
@@ -233,7 +243,7 @@ def embed(x):
       h1b = tf.layers.dense(1.0 - c, filter_size, name="vch1b")
       h1 = h1a + h1b
       dx = tf.to_int32(tf.stop_gradient(d))
-      c = bit_to_int(dx, c_size)
+      c = bit_to_int(dx, z_size)
     if hparams.bottleneck_kind == "gumbel-softmax":
       _, hot, l = dae(x, hparams, name)
       c = tf.argmax(hot, axis=-1)
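
The new "vq-vae" branch of embed() decodes discrete indices by looking them up
in a learned codebook (the "means" variable). A standalone sketch of just that
lookup; the sizes are the transformer_ae_small defaults and are used here only
for illustration:

    def vq_embed(codes, v_size=1024 * 64, hidden_size=384):
      # codes: an int32 tensor of codebook indices, of any shape.
      means = tf.get_variable("means", shape=[v_size, hidden_size])
      return tf.gather(means, codes)  # appends a hidden_size axis
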
@@ -331,43 +341,54 @@ def next_bit(t_bit, i):
 def ae_transformer_internal(inputs, targets, target_space, hparams,
                             beam_size, cache=None, predict_mask=1.0):
   """AE Transformer, main step used for training."""
-  hparams.z_size = hparams.hidden_size
-  with tf.variable_scope("ae_transformer"):
-    # Prepare inputs, targets, k.
-    orig_targets = targets
-    batch_size = tf.shape(orig_targets)[0]
-    targets = tf.reshape(targets, [batch_size, -1, 1, hparams.hidden_size])
-    k = hparams.num_compress_steps
-
-    # Encoder.
-    if inputs is not None:
-      inputs = common_layers.flatten4d3d(inputs)
-      inputs, ed = encode(inputs, target_space, hparams, "input_enc")
-    else:
-      ed = None
-
-    # Autoencoding.
-    losses = {"vc": tf.constant(0.0), "sm": tf.constant(0.0)}
-    if hparams.do_ae:
-      targets, _ = common_layers.pad_to_same_length(
-          targets, targets, final_length_divisible_by=2**k)
-      targets_c = compress(targets, False, hparams, "compress")
-      if hparams.mode != tf.estimator.ModeKeys.PREDICT:
-        # Compress and bottleneck.
-        t_c, t_bit, vc_loss, _ = bottleneck(targets_c, hparams, 2*2048, "vc")
+  # Summaries break with the do_refine cond, turn them off in that case.
+  global _DO_SUMMARIES
+  if hparams.do_refine:
+    _DO_SUMMARIES = False
+
+  # Prepare.
+  orig_targets = targets
+  batch_size = tf.shape(orig_targets)[0]
+  targets = tf.reshape(targets, [batch_size, -1, 1, hparams.hidden_size])
+
+  # Encoder.
+  if inputs is not None:
+    inputs = common_layers.flatten4d3d(inputs)
+    inputs, ed = encode(inputs, target_space, hparams, "input_enc")
+  else:
+    ed = None
+
+  # Autoencoding.
+  losses = {"extra": tf.constant(0.0), "latent_pred": tf.constant(0.0)}
+  if hparams.do_ae:
+    max_targets_len_from_inputs = tf.concat([inputs, inputs], axis=1)
+    targets, _ = common_layers.pad_to_same_length(
+        targets, max_targets_len_from_inputs,
+        final_length_divisible_by=2**hparams.num_compress_steps)
+    targets_c = compress(targets, False, hparams, "compress")
+    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
+      # Compress and bottleneck.
+      t_c, t_bit, vc_loss, _ = bottleneck(targets_c, hparams, 2*2048, "vc")
+      if _DO_SUMMARIES:
         tf.summary.histogram("bit0", tf.reshape(t_bit[:, 0, :], [-1]))
-        pc = common_layers.inverse_exp_decay(hparams.startup_steps) * 0.95
-        pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
-        cond = tf.less(tf.random_uniform([]), pc)
-        t_c = tf.cond(cond, lambda: t_c, lambda: targets_c)
-        losses["vc"] = vc_loss * tf.to_float(cond)
-        # Extra loss predicting latent code from input.
+      pc = common_layers.inverse_exp_decay(hparams.startup_steps) * 0.95
+      pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
+      cond = tf.less(tf.random_uniform([]), pc)
+      t_c = tf.cond(cond, lambda: t_c, lambda: targets_c)
+      losses["extra"] = vc_loss * tf.to_float(cond)
+      # Extra loss predicting latent code from input. Discrete only.
+      if hparams.bottleneck_kind not in ["dense", "vae"]:
         t_pred = decode_transformer(
             inputs, ed, tf.stop_gradient(t_c), hparams, "extra")
         t_pred = tf.layers.dense(t_pred, 2**16, name="extra_logits")
-        losses["sm"] = tf.nn.sparse_softmax_cross_entropy_with_logits(
+        losses["latent_pred"] = tf.nn.sparse_softmax_cross_entropy_with_logits(
             labels=t_bit, logits=t_pred)
-        losses["sm"] = tf.reduce_mean(losses["sm"]) * 0.5 * tf.to_float(cond)
+        losses["latent_pred"] = tf.reduce_mean(
+            losses["latent_pred"]) * 0.5 * tf.to_float(cond)
+    else:
+      if hparams.bottleneck_kind in ["dense", "vae"]:
+        targets_rand = tf.random_uniform(tf.shape(targets_c))
+        t_c, _, _, _ = bottleneck(targets_rand, hparams, 2*2048, "vc")
       else:
         latent_len = tf.shape(targets_c)[1]
         _, _, _, embed = bottleneck(targets_c, hparams, 2*2048, "vc")
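
The renamed "latent_pred" loss above trains a decoder to predict the discrete
latent code t_bit from the inputs, with a softmax cross-entropy over the 2**16
possible codes. A toy, self-contained sketch of the loss computation with
made-up shapes:

    batch, length, depth, code_vocab = 8, 16, 64, 2**16
    dec_out = tf.random_normal([batch, length, depth])  # stand-in decoder output
    t_bit = tf.random_uniform([batch, length], maxval=code_vocab, dtype=tf.int32)
    logits = tf.layers.dense(dec_out, code_vocab, name="extra_logits")
    xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=t_bit, logits=logits)
    latent_pred_loss = 0.5 * tf.reduce_mean(xent)
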
@@ -378,33 +399,39 @@ def ae_transformer_internal(inputs, targets, target_space, hparams,
         cache = tf.reshape(cache, [1, latent_len, 1])
         cache = tf.tile(cache, [beam_size, 1, 1])
       t_c = embed(cache)
-      # Postprocess.
-      d = t_c
-      pos = tf.get_variable("pos", [1, 1000, 1, hparams.hidden_size])
-      pos = pos[:, :tf.shape(t_c)[1] + 1, :, :]
-      t_c = tf.pad(t_c, [[0, 0], [1, 0], [0, 0], [0, 0]]) + pos
-
-      # Masking.
-      if hparams.do_mask:
-        masking = common_layers.inverse_lin_decay(100000)
-        masking *= common_layers.inverse_exp_decay(25000)  # Not much at start.
+    # Postprocess.
+    d = t_c
+    pos = tf.get_variable("pos", [1, 1000, 1, hparams.hidden_size])
+    pos = pos[:, :tf.shape(t_c)[1] + 1, :, :]
+    t_c = tf.pad(t_c, [[0, 0], [1, 0], [0, 0], [0, 0]]) + pos
+
+    # Masking.
+    if hparams.do_mask:
+      masking = common_layers.inverse_lin_decay(100000)
+      masking *= common_layers.inverse_exp_decay(25000)  # Not much at start.
+      if not hparams.do_refine:
         masking -= tf.random_uniform([]) * 0.3
-        masking = tf.minimum(tf.maximum(masking, 0.0), 1.0)
-        if hparams.mode == tf.estimator.ModeKeys.PREDICT:
-          masking = predict_mask
-        mask = tf.less(masking, tf.random_uniform(tf.shape(targets)[:-1]))
-        mask = tf.expand_dims(tf.to_float(mask), 3)
-        for i in xrange(hparams.num_compress_steps):
-          j = hparams.num_compress_steps - i - 1
-          d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j)
-          d = decompress_step(d, hparams, i > 0, False, "decompress_%d" % j)
-        targets = mask * targets + (1.0 - mask) * d
-      targets = tf.concat([tf.reverse(t_c, [1]), targets], axis=1)
-
-    res = decode_transformer(inputs, ed, targets, hparams, "decoder")
-    if hparams.do_ae:
-      res = res[:, tf.shape(t_c)[1]:, :, :]
-    return res, losses, cache
+      masking = tf.minimum(tf.maximum(masking, 0.0), 1.0)
+      if hparams.mode == tf.estimator.ModeKeys.PREDICT:
+        masking = predict_mask
+      mask = tf.less(masking, tf.random_uniform(tf.shape(targets)[:-1]))
+      mask = tf.expand_dims(tf.to_float(mask), 3)
+      for i in xrange(hparams.num_compress_steps):
+        j = hparams.num_compress_steps - i - 1
+        d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j)
+        d = decompress_step(d, hparams, i > 0, False, "decompress_%d" % j)
+      targets = mask * targets + (1.0 - mask) * d
+    targets = tf.concat([tf.reverse(t_c, [1]), targets], axis=1)
+
+  res = decode_transformer(inputs, ed, targets, hparams, "decoder")
+  if hparams.do_ae:
+    res = res[:, tf.shape(t_c)[1]:, :, :]
+    if hparams.do_mask and hparams.do_refine:
+      def refine_res():
+        return residual_conv(res, 1, (5, 1), hparams, "refine")
+      all_masked = tf.less(tf.reduce_sum(mask), 0.1)
+      res = tf.cond(all_masked, refine_res, lambda: res)
+  return res, losses, cache

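The do_refine addition at the end of ae_transformer_internal() runs one extra
residual_conv pass, but only once the scheduled mask has replaced essentially
every target position, i.e. in the fully-generative regime used at inference.
A sketch of the guard, with maybe_refine as a hypothetical wrapper:

    def maybe_refine(res, mask, refine_fn):
      # mask sums to ~0 only when all positions were generated, not copied.
      all_masked = tf.less(tf.reduce_sum(mask), 0.1)
      # tf.cond traces both branches but runs only one, which is also why
      # summaries are disabled when do_refine is on.
      return tf.cond(all_masked, lambda: refine_fn(res), lambda: res)
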
 @registry.register_model
@@ -466,7 +493,7 @@ def infer(self, features=None, decode_length=50, beam_size=1, top_beams=1,
     else:
       batch_size = tf.shape(features["inputs"])[0]
       length = tf.shape(features["inputs"])[1]
-      target_length = tf.to_int32(1.3 * tf.to_float(length))
+      target_length = tf.to_int32(2.0 * tf.to_float(length))
     initial_output = tf.zeros((batch_size, target_length, 1, 1),
                               dtype=tf.int64)

@@ -489,15 +516,15 @@ def transformer_ae_small():
   hparams.hidden_size = 384
   hparams.filter_size = 2048
   hparams.label_smoothing = 0.0
-  hparams.add_hparam("c_size", 16)
+  hparams.add_hparam("z_size", 16)
   hparams.add_hparam("noise_dev", 1.0)
   hparams.add_hparam("d_mix", 0.5)
-  # Bottleneck kinds supported: dense, semhash, gumbel-softmax.
+  # Bottleneck kinds supported: dense, vae, semhash, gumbel-softmax, vq-vae.
   hparams.add_hparam("bottleneck_kind", "semhash")
   hparams.add_hparam("do_ae", True)
   hparams.add_hparam("do_mask", True)
+  hparams.add_hparam("do_refine", True)
   hparams.add_hparam("drop_inputs", False)
-  hparams.add_hparam("z_size", 128)
   hparams.add_hparam("v_size", 1024*64)
   hparams.add_hparam("max_context_length", 64)
   hparams.add_hparam("num_compress_steps", 3)
@@ -522,8 +549,6 @@ def transformer_ae_cifar():
   hparams = transformer_ae_small()
   hparams.hidden_size = 256
   hparams.filter_size = 512
-  hparams.z_size = 256  # 64
-  hparams.z_size2 = 0  # 16
   hparams.batch_size = 1024*4
   hparams.num_compress_steps = 2
   hparams.v_size = 1024*16
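
A hypothetical usage sketch (not part of this commit) of the reworked hparams:
overriding the registered base config to try the new bottleneck options:

    hparams = transformer_ae_small()
    hparams.bottleneck_kind = "vq-vae"  # dense, vae, semhash, gumbel-softmax, vq-vae
    hparams.do_refine = False
    hparams.z_size = 16  # replaces the old c_size hparam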