 class Transformer(t2t_model.T2TModel):
   """Attention net. See file docstring."""

+  def __init__(self, *args, **kwargs):
+    super(Transformer, self).__init__(*args, **kwargs)
+    self.attention_weights = dict()  # For visualizing attention heads.
+
   def encode(self, inputs, target_space, hparams, features=None):
     """Encode transformer inputs.

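The attention_weights dict is populated while the graph is built, so a caller can fetch the captured tensors alongside the model outputs. A minimal sketch of the read side, assuming the Transformer graph has already been constructed through the usual T2TModel path (only attention_weights itself comes from this patch; the helper below is hypothetical):

def fetch_attention_weights(session, model):
  """Returns the captured attention weights as numpy arrays (sketch only).

  Keys are variable-scope strings such as
  ".../encoder/layer_0/self_attention/..."; values are typically of shape
  [batch, num_heads, query_length, memory_length].
  """
  # tf.Session.run accepts a dict of fetches and returns a same-keyed dict.
  return session.run(model.attention_weights)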
@@ -73,7 +77,8 @@ def encode(self, inputs, target_space, hparams, features=None):

     encoder_output = transformer_encoder(
         encoder_input, self_attention_bias,
-        hparams, nonpadding=_features_to_nonpadding(features, "inputs"))
+        hparams, nonpadding=_features_to_nonpadding(features, "inputs"),
+        save_weights_to=self.attention_weights)

     return encoder_output, encoder_decoder_attention_bias

@@ -114,7 +119,8 @@ def decode(self,
         encoder_decoder_attention_bias,
         hparams,
         cache=cache,
-        nonpadding=nonpadding)
+        nonpadding=nonpadding,
+        save_weights_to=self.attention_weights)

     if hparams.use_tpu and hparams.mode == tf.estimator.ModeKeys.TRAIN:
       # TPU does not react kindly to extra dimensions.
@@ -507,7 +513,8 @@ def transformer_encoder(encoder_input,
                         encoder_self_attention_bias,
                         hparams,
                         name="encoder",
-                        nonpadding=None):
+                        nonpadding=None,
+                        save_weights_to=None):
   """A stack of transformer layers.

   Args:
@@ -522,6 +529,9 @@ def transformer_encoder(encoder_input,
       encoder_self_attention_bias. The knowledge about padding is used
       for pad_remover(efficiency) and to mask out padding in convolutional
       layers.
+    save_weights_to: an optional dictionary to capture attention weights
+      for visualization; the weights tensor will be appended there under
+      a string key created from the variable scope (including name).

   Returns:
     y: a Tensor
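The common_attention side of this change is not shown in the diff; presumably the attention op records its post-softmax weights under the current variable scope, roughly as in the sketch below (an assumption about code outside this hunk, not part of the patch):

# Inside the attention op (sketch; the actual change lives in common_attention):
weights = tf.nn.softmax(logits, name="attention_weights")
if save_weights_to is not None:
  # Key by the enclosing variable scope so individual encoder/decoder layers
  # can be told apart when visualizing.
  save_weights_to[tf.get_variable_scope().name] = weights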
@@ -551,6 +561,7 @@ def transformer_encoder(encoder_input,
               hparams.num_heads,
               hparams.attention_dropout,
               attention_type=hparams.self_attention_type,
+              save_weights_to=save_weights_to,
               max_relative_position=hparams.max_relative_position)
           x = common_layers.layer_postprocess(x, y, hparams)
         with tf.variable_scope("ffn"):
@@ -571,7 +582,8 @@ def transformer_decoder(decoder_input,
                         hparams,
                         cache=None,
                         name="decoder",
-                        nonpadding=None):
+                        nonpadding=None,
+                        save_weights_to=None):
   """A stack of transformer layers.

   Args:
@@ -590,6 +602,9 @@ def transformer_decoder(decoder_input,
       to mask out padding in convolutional layers. We generally only
       need this mask for "packed" datasets, because for ordinary datasets,
       no padding is ever followed by nonpadding.
+    save_weights_to: an optional dictionary to capture attention weights
+      for visualization; the weights tensor will be appended there under
+      a string key created from the variable scope (including name).

   Returns:
     y: a Tensor
@@ -612,6 +627,7 @@ def transformer_decoder(decoder_input,
               hparams.num_heads,
               hparams.attention_dropout,
               attention_type=hparams.self_attention_type,
+              save_weights_to=save_weights_to,
               max_relative_position=hparams.max_relative_position,
               cache=layer_cache)
           x = common_layers.layer_postprocess(x, y, hparams)
@@ -624,7 +640,8 @@ def transformer_decoder(decoder_input,
                 hparams.attention_key_channels or hparams.hidden_size,
                 hparams.attention_value_channels or hparams.hidden_size,
                 hparams.hidden_size, hparams.num_heads,
-                hparams.attention_dropout)
+                hparams.attention_dropout,
+                save_weights_to=save_weights_to)
             x = common_layers.layer_postprocess(x, y, hparams)
         with tf.variable_scope("ffn"):
           y = transformer_ffn_layer(
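The new keyword also works when the encoder or decoder stack is used directly, outside the Transformer class. A hedged usage sketch (construction of the inputs, bias and hparams is assumed):

attention_weights = {}  # Filled in while the stack below is built.
encoder_output = transformer_encoder(
    encoder_input, encoder_self_attention_bias, hparams,
    save_weights_to=attention_weights)
# attention_weights now maps scope names such as "encoder/layer_0/self_attention"
# to the corresponding attention weight tensors.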