diff --git a/src/model/seq2seq.py b/src/model/seq2seq.py index bbff9c05..80cb0141 100644 --- a/src/model/seq2seq.py +++ b/src/model/seq2seq.py @@ -70,7 +70,8 @@ from tensorflow.python.ops import nn_ops from tensorflow.contrib.rnn.python.ops import rnn, rnn_cell from tensorflow.python.ops import variable_scope -linear = rnn_cell._linear # pylint: disable=protected-access +from tensorflow.contrib.rnn.python.ops import core_rnn_cell +linear = core_rnn_cell._linear # pylint: disable=protected-access def _extract_argmax_and_embed(embedding, output_projection=None, update_embedding=True): diff --git a/src/model/seq2seq_model.py b/src/model/seq2seq_model.py index 98b2bea0..cf9ebc06 100644 --- a/src/model/seq2seq_model.py +++ b/src/model/seq2seq_model.py @@ -84,22 +84,25 @@ def __init__(self, encoder_masks, encoder_inputs_tensor, self.encoder_masks = encoder_masks # Create the internal multi-layer cell for our RNN. - single_cell = tf.contrib.rnn.core_rnn_cell.BasicLSTMCell(attn_num_hidden, forget_bias=0.0, state_is_tuple=False) + if use_gru: print("using GRU CELL in decoder") - single_cell = tf.contrib.rnn.core_rnn_cell.GRUCell(attn_num_hidden) - cell = single_cell + single_cell = tf.contrib.rnn.GRUCell(attn_num_hidden) + else: + single_cell = tf.contrib.rnn.BasicLSTMCell(attn_num_hidden, forget_bias=0.0, state_is_tuple=False) if attn_num_layers > 1: - cell = tf.contrib.rnn.core_rnn_cell.MultiRNNCell([single_cell] * attn_num_layers, state_is_tuple=False) + cell = tf.contrib.rnn.MultiRNNCell([single_cell] * attn_num_layers, state_is_tuple=False) + else: + cell = single_cell # The seq2seq function: we use embedding for the input and attention. def seq2seq_f(lstm_inputs, decoder_inputs, seq_length, do_decode): num_hidden = attn_num_layers * attn_num_hidden - lstm_fw_cell = tf.contrib.rnn.core_rnn_cell.BasicLSTMCell(num_hidden, forget_bias=0.0, state_is_tuple=False) + lstm_fw_cell = tf.contrib.rnn.BasicLSTMCell(num_hidden, forget_bias=0.0, state_is_tuple=False) # Backward direction cell - lstm_bw_cell = tf.contrib.rnn.core_rnn_cell.BasicLSTMCell(num_hidden, forget_bias=0.0, state_is_tuple=False) + lstm_bw_cell = tf.contrib.rnn.BasicLSTMCell(num_hidden, forget_bias=0.0, state_is_tuple=False) pre_encoder_inputs, output_state_fw, output_state_bw = tf.contrib.rnn.static_bidirectional_rnn(lstm_fw_cell, lstm_bw_cell, lstm_inputs, initial_state_fw=None, initial_state_bw=None,