Skip to content

Commit 63fcdac

Browse files
committed
Print hypotheses when using beam search
1 parent b8b4f74 commit 63fcdac

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

model.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -233,17 +233,20 @@ def update_correct_predictions(self, num_correct_predictions, output_file, resul
233233
if self.config.BEAM_WIDTH > 0:
234234
predicted_first = predicted[0]
235235
filtered_predicted_first_parts = Common.filter_impossible_names(predicted_first) # list
236-
output_file.write('Original: ' + Common.internal_delimiter.join(original_name_parts) +
237-
' , predicted 1st: ' + Common.internal_delimiter.join(filtered_predicted_first_parts) + '\n')
238236

239237
if self.config.BEAM_WIDTH == 0:
238+
output_file.write('Original: ' + Common.internal_delimiter.join(original_name_parts) +
239+
' , predicted 1st: ' + Common.internal_delimiter.join(filtered_predicted_first_parts) + '\n')
240240
if filtered_original == filtered_predicted_first_parts or Common.unique(filtered_original) == Common.unique(
241241
filtered_predicted_first_parts) or ''.join(filtered_original) == ''.join(filtered_predicted_first_parts):
242242
num_correct_predictions += 1
243243
else:
244244
filtered_predicted = [Common.internal_delimiter.join(Common.filter_impossible_names(p)) for p in predicted]
245245

246246
true_ref = original_name
247+
output_file.write('Original: ' + ' '.join(original_name_parts) + '\n')
248+
for i, p in enumerate(filtered_predicted):
249+
output_file.write('\t@{}: {}'.format(i + 1, ' '.join(p.split(Common.internal_delimiter)))+ '\n')
247250
if true_ref in filtered_predicted:
248251
index_of_correct = filtered_predicted.index(true_ref)
249252
update = np.concatenate(

0 commit comments

Comments
 (0)