Skip to content
This repository was archived by the owner on Mar 19, 2021. It is now read-only.

Commit 3006ff8

Browse files
committed
Update translate code.
1 parent da36a6d commit 3006ff8

File tree

2 files changed

+15
-6
lines changed

2 files changed

+15
-6
lines changed

seq2seq_pt/s2s/Translator.py

+4 -4
Original file line number | Diff line number | Diff line change
@@ -200,14 +200,14 @@ def updateActive(t, rnnSize):
200200

201201
return allHyp, allScores, allIsCopy, allCopyPosition, allAttn, None
202202

203-
def translate(self, srcBatch, goldBatch):
203+
def translate(self, srcBatch, bio_batch, feats_batch, goldBatch):
204204
# (1) convert words to indexes
205-
dataset = self.buildData(srcBatch, goldBatch)
205+
dataset = self.buildData(srcBatch, bio_batch, feats_batch, goldBatch)
206206
# (wrap(srcBatch), lengths), (wrap(tgtBatch), ), indices
207-
src, tgt, indices = dataset[0]
207+
src, bio, feats, tgt, indices = dataset[0]
208208

209209
# (2) translate
210-
pred, predScore, predIsCopy, predCopyPosition, attn, _ = self.translateBatch(src, tgt)
210+
pred, predScore, predIsCopy, predCopyPosition, attn, _ = self.translateBatch(src, bio, feats, tgt)
211211
pred, predScore, predIsCopy, predCopyPosition, attn = list(zip(
212212
*sorted(zip(pred, predScore, predIsCopy, predCopyPosition, attn, indices),
213213
key=lambda x: x[-1])))[:-1]

seq2seq_pt/translate.py

+11 -2
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,8 @@
1919
help='Path to model .pt file')
2020
parser.add_argument('-src', required=True,
2121
help='Source sequence to decode (one line per sequence)')
22+
parser.add_argument('-bio')
23+
parser.add_argument('-feats', default=[], nargs='+', type=str)
2224
parser.add_argument('-tgt',
2325
help='True target sequence (optional)')
2426
parser.add_argument('-output', default='pred.txt',
@@ -66,7 +68,6 @@ def addPair(f1, f2):
6668

6769

6870
def main():
69-
raise Exception('Not implemented')
7071
opt = parser.parse_args()
7172
logger.info(opt)
7273
opt.cuda = opt.gpu > -1
@@ -80,15 +81,22 @@ def main():
8081
predScoreTotal, predWordsTotal, goldScoreTotal, goldWordsTotal = 0, 0, 0, 0
8182

8283
srcBatch, tgtBatch = [], []
84+
bio_batch, feats_batch = [], []
8385

8486
count = 0
8587

8688
tgtF = open(opt.tgt) if opt.tgt else None
89+
bioF = open(opt.bio, encoding='utf-8')
90+
featFs = [open(x, encoding='utf-8') for x in opt.feats]
8791
for line in addone(open(opt.src, encoding='utf-8')):
8892

8993
if (line is not None):
9094
srcTokens = line.strip().split(' ')
9195
srcBatch += [srcTokens]
96+
bio_tokens = bioF.readline().strip().split(' ')
97+
bio_batch += [bio_tokens]
98+
feats_tokens = [reader.readline().strip().split((' ')) for reader in featFs]
99+
feats_batch += [feats_tokens]
92100
if tgtF:
93101
tgtTokens = tgtF.readline().split(' ') if tgtF else None
94102
tgtBatch += [tgtTokens]
@@ -100,7 +108,7 @@ def main():
100108
if len(srcBatch) == 0:
101109
break
102110

103-
predBatch, predScore, goldScore = translator.translate(srcBatch, tgtBatch)
111+
predBatch, predScore, goldScore = translator.translate(srcBatch, bio_batch, feats_batch, tgtBatch)
104112

105113
predScoreTotal += sum(score[0] for score in predScore)
106114
predWordsTotal += sum(len(x[0]) for x in predBatch)
@@ -136,6 +144,7 @@ def main():
136144
logger.info('')
137145

138146
srcBatch, tgtBatch = [], []
147+
bio_batch, feats_batch = [], []
139148

140149
reportScore('PRED', predScoreTotal, predWordsTotal)
141150
# if tgtF:

0 commit comments

Comments
 (0)