We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
2 parents c8c49e5 + 8703513 commit ec0ae30Copy full SHA for ec0ae30
model.py
@@ -599,7 +599,7 @@ def predict(self, predict_data_lines):
599
predicted_strings = [[self.index_to_target[sugg] for sugg in timestep]
600
for timestep in predicted_indices] # (target_length, top-k)
601
predicted_strings = list(map(list, zip(*predicted_strings))) # (top-k, target_length)
602
- top_scores = [np.exp(np.sum(s, 0)) for s in top_scores]
+ top_scores = [np.exp(np.sum(s)) for s in zip(*top_scores)]
603
else:
604
predicted_strings = [self.index_to_target[idx]
605
for idx in predicted_indices] # (batch, target_length)
0 commit comments