@@ -59,6 +59,66 @@ def __init__(self, config):
5959 def close_session (self ):
6060 self .sess .close ()
6161
62+ def eval_with_loss (self ):
63+ start_time = time .time ()
64+
65+ batch_num = 0
66+ sum_loss = 0
67+
68+ true_positive , false_positive , false_negative = 0 , 0 , 0
69+ total_prediction_batches = 0
70+
71+ if self .eval_queue is None :
72+ self .eval_queue = reader .Reader (subtoken_to_index = self .subtoken_to_index ,
73+ node_to_index = self .node_to_index ,
74+ target_to_index = self .target_to_index ,
75+ config = self .config , is_evaluating = True )
76+ reader_output = self .eval_queue .get_output ()
77+ _ , loss = self .build_training_graph (reader_output )
78+ self .eval_predicted_indices_op , self .eval_topk_values , _ , _ = \
79+ self .build_test_graph (reader_output )
80+ self .eval_true_target_strings_op = reader_output [reader .TARGET_STRING_KEY ]
81+
82+ self .initialize_session_variables (self .sess )
83+ print ('Initalized variables' )
84+ if self .config .LOAD_PATH :
85+ self .load_model (self .sess )
86+
87+ time .sleep (1 )
88+ print ('Started reader...' )
89+
90+ multi_batch_start_time = time .time ()
91+ for iteration in range (1 , (self .config .NUM_EPOCHS // self .config .SAVE_EVERY_EPOCHS ) + 1 ):
92+ self .eval_queue .reset (self .sess )
93+ try :
94+ while True :
95+ batch_num += 1
96+
97+ batch_loss = self .sess .run ([loss ])[0 ]
98+ sum_loss += batch_loss
99+
100+ predicted_indices , true_target_strings , top_values = self .sess .run (
101+ [self .eval_predicted_indices_op , self .eval_true_target_strings_op , self .eval_topk_values ],
102+ )
103+
104+ print ('SINGLE BATCH LOSS' , batch_loss )
105+
106+ except tf .errors .OutOfRangeError :
107+ self .epochs_trained += self .config .SAVE_EVERY_EPOCHS
108+ print ('Finished %d epochs' % self .config .SAVE_EVERY_EPOCHS )
109+ if self .config .BEAM_WIDTH == 0 :
110+ print ('Accuracy after %d epochs: %.5f' % (self .epochs_trained , results ))
111+ else :
112+ print ('Accuracy after {} epochs: {}' .format (self .epochs_trained , results ))
113+ print ('After %d epochs: Precision: %.5f, recall: %.5f, F1: %.5f' % (
114+ self .epochs_trained , precision , recall , f1 ))
115+ print ('Rouge: ' , rouge )
116+
117+ precision , recall , f1 = self .calculate_results (true_positive , false_positive , false_negative )
118+
119+ return num_correct_predictions / total_predictions , \
120+ precision , recall , f1 , rouge , sum_loss / batch_num
121+
62122 def train (self ):
63123 print ('Starting training' )
64124 start_time = time .time ()
@@ -75,6 +135,7 @@ def train(self):
75135 node_to_index = self .node_to_index ,
76136 target_to_index = self .target_to_index ,
77137 config = self .config )
138+
78139 optimizer , train_loss = self .build_training_graph (self .queue_thread .get_output ())
79140 self .print_hyperparams ()
80141 print ('Number of trainable params:' ,
@@ -339,7 +400,7 @@ def build_training_graph(self, input_tensors):
339400 path_lengths = input_tensors [reader .PATH_LENGTHS_KEY ]
340401 path_target_lengths = input_tensors [reader .PATH_TARGET_LENGTHS_KEY ]
341402
342- with tf .variable_scope ('model' ):
403+ with tf .variable_scope ('model' , reuse = tf . AUTO_REUSE ):
343404 subtoken_vocab = tf .get_variable ('SUBTOKENS_VOCAB' ,
344405 shape = (self .subtoken_vocab_size , self .config .EMBEDDINGS_SIZE ),
345406 dtype = tf .float32 ,
@@ -550,7 +611,7 @@ def build_test_graph(self, input_tensors):
550611 path_lengths = input_tensors [reader .PATH_LENGTHS_KEY ]
551612 path_target_lengths = input_tensors [reader .PATH_TARGET_LENGTHS_KEY ]
552613
553- with tf .variable_scope ('model' , reuse = self . get_should_reuse_variables () ):
614+ with tf .variable_scope ('model' , reuse = tf . AUTO_REUSE ):
554615 subtoken_vocab = tf .get_variable ('SUBTOKENS_VOCAB' ,
555616 shape = (self .subtoken_vocab_size , self .config .EMBEDDINGS_SIZE ),
556617 dtype = tf .float32 , trainable = False )
0 commit comments