19
19
help = 'Path to model .pt file' )
20
20
parser .add_argument ('-src' , required = True ,
21
21
help = 'Source sequence to decode (one line per sequence)' )
22
+ parser .add_argument ('-bio' )
23
+ parser .add_argument ('-feats' , default = [], nargs = '+' , type = str )
22
24
parser .add_argument ('-tgt' ,
23
25
help = 'True target sequence (optional)' )
24
26
parser .add_argument ('-output' , default = 'pred.txt' ,
@@ -66,7 +68,6 @@ def addPair(f1, f2):
66
68
67
69
68
70
def main ():
69
- raise Exception ('Not implemented' )
70
71
opt = parser .parse_args ()
71
72
logger .info (opt )
72
73
opt .cuda = opt .gpu > - 1
@@ -80,15 +81,22 @@ def main():
80
81
predScoreTotal , predWordsTotal , goldScoreTotal , goldWordsTotal = 0 , 0 , 0 , 0
81
82
82
83
srcBatch , tgtBatch = [], []
84
+ bio_batch , feats_batch = [], []
83
85
84
86
count = 0
85
87
86
88
tgtF = open (opt .tgt ) if opt .tgt else None
89
+ bioF = open (opt .bio , encoding = 'utf-8' )
90
+ featFs = [open (x , encoding = 'utf-8' ) for x in opt .feats ]
87
91
for line in addone (open (opt .src , encoding = 'utf-8' )):
88
92
89
93
if (line is not None ):
90
94
srcTokens = line .strip ().split (' ' )
91
95
srcBatch += [srcTokens ]
96
+ bio_tokens = bioF .readline ().strip ().split (' ' )
97
+ bio_batch += [bio_tokens ]
98
+ feats_tokens = [reader .readline ().strip ().split ((' ' )) for reader in featFs ]
99
+ feats_batch += [feats_tokens ]
92
100
if tgtF :
93
101
tgtTokens = tgtF .readline ().split (' ' ) if tgtF else None
94
102
tgtBatch += [tgtTokens ]
@@ -100,7 +108,7 @@ def main():
100
108
if len (srcBatch ) == 0 :
101
109
break
102
110
103
- predBatch , predScore , goldScore = translator .translate (srcBatch , tgtBatch )
111
+ predBatch , predScore , goldScore = translator .translate (srcBatch , bio_batch , feats_batch , tgtBatch )
104
112
105
113
predScoreTotal += sum (score [0 ] for score in predScore )
106
114
predWordsTotal += sum (len (x [0 ]) for x in predBatch )
@@ -136,6 +144,7 @@ def main():
136
144
logger .info ('' )
137
145
138
146
srcBatch , tgtBatch = [], []
147
+ bio_batch , feats_batch = [], []
139
148
140
149
reportScore ('PRED' , predScoreTotal , predWordsTotal )
141
150
# if tgtF:
0 commit comments