@@ -59,6 +59,66 @@ def __init__(self, config):
59
59
def close_session(self):
    """Release the resources held by the model's TensorFlow session."""
    self.sess.close()
61
61
62
def eval_with_loss(self):
    """Run evaluation over the eval set, reporting the per-batch loss.

    Lazily builds (on first call) an evaluation `reader.Reader` plus the
    loss graph and the test graph over the same reader output, then
    repeatedly drains the eval queue, accumulating the loss of every
    batch until the reader raises `tf.errors.OutOfRangeError`.

    Returns:
        A 6-tuple ``(accuracy, precision, recall, f1, rouge, mean_loss)``.
        NOTE(review): this revision fetches `predicted_indices` /
        `true_target_strings` / `top_values` per batch but never scores
        them, so the tp/fp/fn and correct-prediction counters stay at
        zero — accuracy, precision, recall, f1 and rouge are placeholder
        zeros until per-batch scoring is wired in. Only `mean_loss` is
        meaningful.
    """
    batch_num = 0
    sum_loss = 0.0

    # Counters for subtoken-level precision/recall and exact-match
    # accuracy. Currently never incremented (see docstring); they are
    # initialized here so the metric computation and the return value
    # cannot raise NameError (the original referenced several of these
    # without ever assigning them).
    true_positive, false_positive, false_negative = 0, 0, 0
    num_correct_predictions = 0
    total_predictions = 0
    rouge = 0

    if self.eval_queue is None:
        # First call: build the eval input pipeline and both graphs.
        self.eval_queue = reader.Reader(subtoken_to_index=self.subtoken_to_index,
                                        node_to_index=self.node_to_index,
                                        target_to_index=self.target_to_index,
                                        config=self.config, is_evaluating=True)
        reader_output = self.eval_queue.get_output()
        # Persist the loss op on self: the original bound it to a local
        # only inside this branch, so any second call hit a NameError.
        _, self.eval_loss_op = self.build_training_graph(reader_output)
        self.eval_predicted_indices_op, self.eval_topk_values, _, _ = \
            self.build_test_graph(reader_output)
        self.eval_true_target_strings_op = reader_output[reader.TARGET_STRING_KEY]

        self.initialize_session_variables(self.sess)
        print('Initalized variables')
        if self.config.LOAD_PATH:
            self.load_model(self.sess)

        time.sleep(1)
        print('Started reader...')

    for iteration in range(1, (self.config.NUM_EPOCHS // self.config.SAVE_EVERY_EPOCHS) + 1):
        self.eval_queue.reset(self.sess)
        try:
            while True:
                batch_num += 1

                batch_loss = self.sess.run([self.eval_loss_op])[0]
                sum_loss += batch_loss

                # Fetched for future per-batch scoring; currently unused.
                predicted_indices, true_target_strings, top_values = self.sess.run(
                    [self.eval_predicted_indices_op, self.eval_true_target_strings_op, self.eval_topk_values],
                )

                print('SINGLE BATCH LOSS', batch_loss)

        except tf.errors.OutOfRangeError:
            # Reader exhausted: one evaluation pass over the data is done.
            self.epochs_trained += self.config.SAVE_EVERY_EPOCHS
            print('Finished %d epochs' % self.config.SAVE_EVERY_EPOCHS)

            # Compute the metrics BEFORE printing them — the original
            # printed `results`, `precision`, `recall` and `f1` here
            # while they were still unassigned, raising NameError on
            # the very first epoch.
            precision, recall, f1 = self.calculate_results(true_positive, false_positive, false_negative)
            results = num_correct_predictions / total_predictions if total_predictions else 0

            if self.config.BEAM_WIDTH == 0:
                print('Accuracy after %d epochs: %.5f' % (self.epochs_trained, results))
            else:
                print('Accuracy after {} epochs: {}'.format(self.epochs_trained, results))
            print('After %d epochs: Precision: %.5f, recall: %.5f, F1: %.5f' % (
                self.epochs_trained, precision, recall, f1))
            print('Rouge: ', rouge)

    precision, recall, f1 = self.calculate_results(true_positive, false_positive, false_negative)
    accuracy = num_correct_predictions / total_predictions if total_predictions else 0
    # Guard the division: an empty eval set leaves batch_num at 0.
    mean_loss = sum_loss / batch_num if batch_num else 0.0

    return accuracy, precision, recall, f1, rouge, mean_loss
62
122
def train (self ):
63
123
print ('Starting training' )
64
124
start_time = time .time ()
@@ -75,6 +135,7 @@ def train(self):
75
135
node_to_index = self .node_to_index ,
76
136
target_to_index = self .target_to_index ,
77
137
config = self .config )
138
+
78
139
optimizer , train_loss = self .build_training_graph (self .queue_thread .get_output ())
79
140
self .print_hyperparams ()
80
141
print ('Number of trainable params:' ,
@@ -339,7 +400,7 @@ def build_training_graph(self, input_tensors):
339
400
path_lengths = input_tensors [reader .PATH_LENGTHS_KEY ]
340
401
path_target_lengths = input_tensors [reader .PATH_TARGET_LENGTHS_KEY ]
341
402
342
- with tf .variable_scope ('model' ):
403
+ with tf .variable_scope ('model' , reuse = tf . AUTO_REUSE ):
343
404
subtoken_vocab = tf .get_variable ('SUBTOKENS_VOCAB' ,
344
405
shape = (self .subtoken_vocab_size , self .config .EMBEDDINGS_SIZE ),
345
406
dtype = tf .float32 ,
@@ -550,7 +611,7 @@ def build_test_graph(self, input_tensors):
550
611
path_lengths = input_tensors [reader .PATH_LENGTHS_KEY ]
551
612
path_target_lengths = input_tensors [reader .PATH_TARGET_LENGTHS_KEY ]
552
613
553
- with tf .variable_scope ('model' , reuse = self . get_should_reuse_variables () ):
614
+ with tf .variable_scope ('model' , reuse = tf . AUTO_REUSE ):
554
615
subtoken_vocab = tf .get_variable ('SUBTOKENS_VOCAB' ,
555
616
shape = (self .subtoken_vocab_size , self .config .EMBEDDINGS_SIZE ),
556
617
dtype = tf .float32 , trainable = False )
0 commit comments