Description
I am using CopyNetWrapper to wrap a decoder. This is my code:
train_decoder = tf.contrib.seq2seq.AttentionWrapper(decoder, attention_mechanism,
                                                    attention_layer_size=self.config.PHVM_decoder_dim)
train_encoder_state = train_decoder.zero_state(self.batch_size, dtype=tf.float32).clone(
    cell_state=sent_dec_state)
copynet_decoder = CopyNetWrapper(train_decoder, sent_input, sent_lens, sent_lens, self.tgt_vocab_size)
copy_train_encoder_state = copynet_decoder.zero_state(self.batch_size, dtype=tf.float32).clone(
    cell_state=train_encoder_state)
However, during training I get the following error:
Traceback (most recent call last):
File "/home/work/mnt/project/.local/lib/python3.6/site-packages/tensorflow/python/util/nest.py", line 297, in assert_same_structure
expand_composites)
TypeError: The two structures don't have the same nested structure.First structure: type=CopyNetWrapperState str=CopyNetWrapperState(cell_state=AttentionWrapperState(cell_state=(<tf.Tensor 'sentence_level/train/while/sent_deocde/CopyNetWrapperZeroState/AttentionWrapperZeroState/checked_cell_state:0' shape=(?, 300) dtype=float32>, <tf.Tensor 'sentence_level/train/while/sent_deocde/CopyNetWrapperZeroState/AttentionWrapperZeroState/checked_cell_state_1:0' shape=(?, 300) dtype=float32>), attention=<tf.Tensor 'sentence_level/train/while/sent_deocde/CopyNetWrapperZeroState/AttentionWrapperZeroState/zeros_2:0' shape=(?, 300) dtype=float32>, time=<tf.Tensor 'sentence_level/train/while/sent_deocde/CopyNetWrapperZeroState/AttentionWrapperZeroState/zeros_1:0' shape=() dtype=int32>, alignments=<tf.Tensor 'sentence_level/train/while/sent_deocde/CopyNetWrapperZeroState/AttentionWrapperZeroState/zeros:0' shape=(?, ?) dtype=float32>, alignment_history=(), attention_state=<tf.Tensor 'sentence_level/train/while/sent_deocde/CopyNetWrapperZeroState/AttentionWrapperZeroState/zeros_3:0' shape=(?, ?) dtype=float32>), last_ids=<tf.Tensor 'sentence_level/train/while/sent_deocde/CopyNetWrapperZeroState/sub:0' shape=(?,) dtype=int32>, prob_c=<tf.Tensor 'sentence_level/train/while/sent_deocde/CopyNetWrapperZeroState/zeros_1:0' shape=(?, ?) dtype=float32>)
Second structure: type=CopyNetWrapperState str=CopyNetWrapperState(cell_state=(<tf.Tensor 'sentence_level/train/while/sent_deocde/sent_dec_state/dense/BiasAdd:0' shape=(?, 300) dtype=float32>, <tf.Tensor 'sentence_level/train/while/sent_deocde/sent_dec_state/dense_1/BiasAdd:0' shape=(?, 300) dtype=float32>), last_ids=<tf.Tensor 'sentence_level/train/while/sent_deocde/CopyNetWrapperZeroState/sub:0' shape=(?,) dtype=int32>, prob_c=<tf.Tensor 'sentence_level/train/while/sent_deocde/CopyNetWrapperZeroState/zeros_1:0' shape=(?, ?) dtype=float32>)
More specifically: The two namedtuples don't have the same sequence type. First structure type=AttentionWrapperState str=AttentionWrapperState(cell_state=(<tf.Tensor 'sentence_level/train/while/sent_deocde/CopyNetWrapperZeroState/AttentionWrapperZeroState/checked_cell_state:0' shape=(?, 300) dtype=float32>, <tf.Tensor 'sentence_level/train/while/sent_deocde/CopyNetWrapperZeroState/AttentionWrapperZeroState/checked_cell_state_1:0' shape=(?, 300) dtype=float32>), attention=<tf.Tensor 'sentence_level/train/while/sent_deocde/CopyNetWrapperZeroState/AttentionWrapperZeroState/zeros_2:0' shape=(?, 300) dtype=float32>, time=<tf.Tensor 'sentence_level/train/while/sent_deocde/CopyNetWrapperZeroState/AttentionWrapperZeroState/zeros_1:0' shape=() dtype=int32>, alignments=<tf.Tensor 'sentence_level/train/while/sent_deocde/CopyNetWrapperZeroState/AttentionWrapperZeroState/zeros:0' shape=(?, ?) dtype=float32>, alignment_history=(), attention_state=<tf.Tensor 'sentence_level/train/while/sent_deocde/CopyNetWrapperZeroState/AttentionWrapperZeroState/zeros_3:0' shape=(?, ?) dtype=float32>) has type AttentionWrapperState, while second structure type=tuple str=(<tf.Tensor 'sentence_level/train/while/sent_deocde/sent_dec_state/dense/BiasAdd:0' shape=(?, 300) dtype=float32>, <tf.Tensor 'sentence_level/train/while/sent_deocde/sent_dec_state/dense_1/BiasAdd:0' shape=(?, 300) dtype=float32>) has type tuple
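If I read the last part of the message correctly, the cell_state slot of one CopyNetWrapperState holds an AttentionWrapperState, while in the other it holds a plain tuple of two tensors. To make sure I understand the complaint, here is a minimal sketch of that kind of mismatch in plain Python. The namedtuples and the same_structure helper are hypothetical, simplified stand-ins I wrote for illustration; they are not the real TensorFlow classes or the real nest check, and the strings stand in for tensors.

```python
from collections import namedtuple

# Hypothetical stand-ins for the real state classes; the fields are simplified.
AttentionState = namedtuple("AttentionState", ["cell_state", "attention"])
CopyNetState = namedtuple("CopyNetState", ["cell_state", "last_ids", "prob_c"])


def same_structure(a, b):
    """Roughly mimic a nested-structure check: container types must match at every level."""
    a_is_seq = isinstance(a, tuple)
    b_is_seq = isinstance(b, tuple)
    if not a_is_seq and not b_is_seq:
        return True  # both are leaves (tensors in the real case)
    if a_is_seq != b_is_seq:
        return False
    if type(a) is not type(b) or len(a) != len(b):
        return False
    return all(same_structure(x, y) for x, y in zip(a, b))


lstm_state = ("c", "h")  # stands in for the two tensors of sent_dec_state
attn_state = AttentionState(cell_state=lstm_state, attention="attn")
zero_state = CopyNetState(cell_state=attn_state, last_ids="ids", prob_c="p")

# A plain tuple ends up where an AttentionState is expected.
mismatched = zero_state._replace(cell_state=lstm_state)
# The attention-level state is kept, and only its inner cell_state changes.
matched = zero_state._replace(
    cell_state=attn_state._replace(cell_state=("c2", "h2")))

print(same_structure(zero_state, mismatched))  # False
print(same_structure(zero_state, matched))     # True
```

In this sketch the first comparison fails for the same reason the error gives (a namedtuple on one side, a plain tuple on the other), and the second passes because the nesting is preserved.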
I have no idea how to solve this problem. Has anyone run into the same error before?
I would really appreciate any advice.
Also, regarding this call:
copynet_cell = CopyNetWrapper(cell, encoder_outputs, encoder_input_ids,
                              vocab_size, gen_vocab_size)
Could you please explain each of these parameters in detail?
Looking forward to any reply.