Skip to content

Commit 8846ebb

Browse files
Single point eval loss
1 parent 18b2e8e commit 8846ebb

File tree

2 files changed

+65
-3
lines changed

2 files changed

+65
-3
lines changed

code2seq.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,11 @@
3838
if config.TRAIN_PATH:
3939
model.train()
4040
if config.TEST_PATH and not args.data_path:
41-
results, precision, recall, f1, rouge = model.evaluate()
41+
results, precision, recall, f1, rouge, loss = model.eval_with_loss()
4242
print('Accuracy: ' + str(results))
4343
print('Precision: ' + str(precision) + ', recall: ' + str(recall) + ', F1: ' + str(f1))
4444
print('Rouge: ', rouge)
45+
print("Avg loss: ", loss)
4546
if args.predict:
4647
predictor = InteractivePredictor(config, model)
4748
predictor.predict()

model.py

+63-2
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,66 @@ def __init__(self, config):
5959
def close_session(self):
6060
self.sess.close()
6161

62+
def eval_with_loss(self):
63+
start_time = time.time()
64+
65+
batch_num = 0
66+
sum_loss = 0
67+
68+
true_positive, false_positive, false_negative = 0, 0, 0
69+
total_prediction_batches = 0
70+
71+
if self.eval_queue is None:
72+
self.eval_queue = reader.Reader(subtoken_to_index=self.subtoken_to_index,
73+
node_to_index=self.node_to_index,
74+
target_to_index=self.target_to_index,
75+
config=self.config, is_evaluating=True)
76+
reader_output = self.eval_queue.get_output()
77+
_, loss = self.build_training_graph(reader_output)
78+
self.eval_predicted_indices_op, self.eval_topk_values, _, _ = \
79+
self.build_test_graph(reader_output)
80+
self.eval_true_target_strings_op = reader_output[reader.TARGET_STRING_KEY]
81+
82+
self.initialize_session_variables(self.sess)
83+
print('Initalized variables')
84+
if self.config.LOAD_PATH:
85+
self.load_model(self.sess)
86+
87+
time.sleep(1)
88+
print('Started reader...')
89+
90+
multi_batch_start_time = time.time()
91+
for iteration in range(1, (self.config.NUM_EPOCHS // self.config.SAVE_EVERY_EPOCHS) + 1):
92+
self.eval_queue.reset(self.sess)
93+
try:
94+
while True:
95+
batch_num += 1
96+
97+
batch_loss = self.sess.run([loss])[0]
98+
sum_loss += batch_loss
99+
100+
predicted_indices, true_target_strings, top_values = self.sess.run(
101+
[self.eval_predicted_indices_op, self.eval_true_target_strings_op, self.eval_topk_values],
102+
)
103+
104+
print('SINGLE BATCH LOSS', batch_loss)
105+
106+
except tf.errors.OutOfRangeError:
107+
self.epochs_trained += self.config.SAVE_EVERY_EPOCHS
108+
print('Finished %d epochs' % self.config.SAVE_EVERY_EPOCHS)
109+
if self.config.BEAM_WIDTH == 0:
110+
print('Accuracy after %d epochs: %.5f' % (self.epochs_trained, results))
111+
else:
112+
print('Accuracy after {} epochs: {}'.format(self.epochs_trained, results))
113+
print('After %d epochs: Precision: %.5f, recall: %.5f, F1: %.5f' % (
114+
self.epochs_trained, precision, recall, f1))
115+
print('Rouge: ', rouge)
116+
117+
precision, recall, f1 = self.calculate_results(true_positive, false_positive, false_negative)
118+
119+
return num_correct_predictions / total_predictions, \
120+
precision, recall, f1, rouge, sum_loss / batch_num
121+
62122
def train(self):
63123
print('Starting training')
64124
start_time = time.time()
@@ -75,6 +135,7 @@ def train(self):
75135
node_to_index=self.node_to_index,
76136
target_to_index=self.target_to_index,
77137
config=self.config)
138+
78139
optimizer, train_loss = self.build_training_graph(self.queue_thread.get_output())
79140
self.print_hyperparams()
80141
print('Number of trainable params:',
@@ -339,7 +400,7 @@ def build_training_graph(self, input_tensors):
339400
path_lengths = input_tensors[reader.PATH_LENGTHS_KEY]
340401
path_target_lengths = input_tensors[reader.PATH_TARGET_LENGTHS_KEY]
341402

342-
with tf.variable_scope('model'):
403+
with tf.variable_scope('model', reuse=tf.AUTO_REUSE):
343404
subtoken_vocab = tf.get_variable('SUBTOKENS_VOCAB',
344405
shape=(self.subtoken_vocab_size, self.config.EMBEDDINGS_SIZE),
345406
dtype=tf.float32,
@@ -550,7 +611,7 @@ def build_test_graph(self, input_tensors):
550611
path_lengths = input_tensors[reader.PATH_LENGTHS_KEY]
551612
path_target_lengths = input_tensors[reader.PATH_TARGET_LENGTHS_KEY]
552613

553-
with tf.variable_scope('model', reuse=self.get_should_reuse_variables()):
614+
with tf.variable_scope('model', reuse=tf.AUTO_REUSE):
554615
subtoken_vocab = tf.get_variable('SUBTOKENS_VOCAB',
555616
shape=(self.subtoken_vocab_size, self.config.EMBEDDINGS_SIZE),
556617
dtype=tf.float32, trainable=False)

0 commit comments

Comments
 (0)