forked from ParallelMeaningBank/elephant
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathelephant-train
More file actions
executable file
·115 lines (107 loc) · 4.34 KB
/
elephant-train
File metadata and controls
executable file
·115 lines (107 loc) · 4.34 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
#!/usr/bin/env python
import errno
import imp
import itertools
import argparse
import os
import shutil
import subprocess
import sys
import unicodedata as ud
from tempfile import mkstemp
def makedirs(path):
    """Create directory ``path`` (and missing parents) if it does not exist.

    An already existing directory is silently accepted, so the call is
    idempotent.

    Args:
        path: Directory to create.

    Raises:
        OSError: If the directory cannot be created, including the case
            where ``path`` already exists but is not a directory.
    """
    try:
        os.makedirs(path)
    except OSError as error:
        # EEXIST is only harmless when the thing that exists is a directory;
        # a regular file in the way would make later writes fail obscurely.
        if error.errno != errno.EEXIST or not os.path.isdir(path):
            # Bare raise keeps the original traceback.
            raise
def fst(pair):
    """Return the first element of a two-item iterable.

    Unpacking happens in the body instead of the parameter list because
    tuple parameters are not valid syntax on Python 3; a non-pair argument
    still fails at the unpacking step.
    """
    first, _ = pair
    return first
def snd(pair):
    """Return the second element of a two-item iterable.

    Unpacking happens in the body instead of the parameter list because
    tuple parameters are not valid syntax on Python 3; a non-pair argument
    still fails at the unpacking step.
    """
    _, second = pair
    return second
def read_iob_file(iob_file):
    """Yield one ``(text, labels)`` pair per sequence stored in an IOB file.

    Every non-blank line holds two whitespace-separated fields: a decimal
    character code and the label of that character. Blank (or
    whitespace-only) lines separate sequences.

    Args:
        iob_file: Path of the file to read.

    Yields:
        Tuples ``(text, labels)`` where ``text`` is the string built from
        the characters of one sequence and ``labels`` is a list with one
        label per character, in the same order.

    Raises:
        ValueError: If a non-blank line does not have exactly two fields or
            its first field is not an integer.
    """
    # unichr only exists on Python 2; on Python 3 chr covers all of Unicode.
    try:
        to_char = unichr  # noqa: F821
    except NameError:
        to_char = chr
    sequence = []
    with open(iob_file) as f:
        # The trailing '' acts as a final blank line, so the last sequence
        # is flushed even when the file does not end with a separator.
        for line in itertools.chain(f, ['']):
            if line.rstrip():
                code, label = line.split()
                sequence.append((to_char(int(code)), label))
            elif sequence:
                text = ''.join([char for char, _ in sequence])
                # A real list, not a lazy map object, on every Python version.
                labels = [tag for _, tag in sequence]
                yield text, labels
                sequence = []
def create_tmp_file(iob_file):
    """Featurize an IOB corpus into a temporary training file for wapiti.

    For every character of every sequence one line is written, holding the
    character code, its Unicode category, optionally the features returned
    by ``elephant.augment`` and finally the label. Sequences are separated
    by an empty line.

    Relies on the module-level names ``elephant`` and ``elman_model`` that
    the script entry point sets before calling this function.

    Args:
        iob_file: Path of the corpus in IOB format.

    Returns:
        Path of the temporary file. The caller is responsible for deleting
        it.
    """
    # read IOB file
    sequences = read_iob_file(iob_file)
    replace = elephant.create_replacement_function()  # TODO support vocabularies?
    # write the training text on a temporary file for wapiti
    fh, tmp_file = mkstemp()
    try:
        # Wrap the descriptor mkstemp already opened instead of opening the
        # path a second time, so the descriptor is closed and not leaked.
        with os.fdopen(fh, 'w') as fd:
            # TODO We run a new elman process for each sequence. This is
            # quite slow. Is there a better way?
            for i, (text, labels) in enumerate(sequences):
                sys.stderr.write('featurizing sequence {}\n'.format(i + 1))
                if elman_model:  # Add rnn features
                    augmented = elephant.augment(text, elman_model, replace)
                    for (ch, features, label) in zip(text, augmented, labels):
                        fd.write('{0} {1} {2} {3}\n'.format(
                            ord(ch), ud.category(ch), features, label))
                else:  # just character codes and categories
                    for (ch, label) in zip(text, labels):
                        fd.write('{0} {1} {2}\n'.format(
                            ord(ch), ud.category(ch), label))
                fd.write('\n')
    except BaseException:
        # Nobody will ever learn the path if we fail here, so remove the
        # half-written file before propagating the error.
        os.unlink(tmp_file)
        raise
    return tmp_file
if __name__ == '__main__':
    # Script entry point: parse the command line, featurize the training
    # (and optional development) corpus into temporary files and run
    # `wapiti train` on them, storing the result in the model directory.
    #
    # `elephant` and `elman_model` must remain module-level names:
    # create_tmp_file() reads them as globals.
    # The sibling `elephant` script has no .py extension, hence load_source.
    elephant = imp.load_source('elephant', os.path.join(os.path.dirname(
        __file__), 'elephant'))
    parser = argparse.ArgumentParser()
    parser.add_argument('-m', '--model', dest='model_dir', required=True,
                        help="""Target directory for the model that will be trained.""")
    parser.add_argument('-e', '--elman', dest='elman_model',
                        help="""Parameter file describing the trained Elman RNN
                        character-level language model, generated by the elman binary
                        from a large corpus. Each model distributed with Elephant
                        contains such a parameter file, called "elman". They can be
                        reused for training Elephant on different data in the same
                        language, although creating a new RNN on comparable data can
                        improve performance.""")
    parser.add_argument('-w', '--wapiti', dest='wapiti_pattern_file',
                        required=True,
                        help="""Wapiti pattern file specifying the feature set to use for
                        training the CRF model. A number of pattern files are
                        distributed with Elephant.""")
    parser.add_argument('-i', '--input', dest='input_iob_file', required=True,
                        help="""Training corpus in IOB format.""")
    parser.add_argument('-d', '--devel', dest='devel_iob_file',
                        help="""Development corpus in IOB format.""")
    args = parser.parse_args()
    model_dir = args.model_dir
    elman_model = args.elman_model
    wapiti_pattern_file = args.wapiti_pattern_file
    input_iob_file = args.input_iob_file
    devel_iob_file = args.devel_iob_file
    wapiti_model = os.path.join(model_dir, 'wapiti')

    # Fail fast: featurizing is slow, so make sure wapiti can be run before
    # doing any work or creating any temporary file.
    if not elephant.check_exec('wapiti'):
        raise RuntimeError("Unable to run wapiti")

    # prepare output directory
    makedirs(model_dir)
    if elman_model:
        shutil.copy(elman_model, os.path.join(model_dir, 'elman'))

    tmp_input_file = None
    tmp_devel_file = None
    try:
        tmp_input_file = create_tmp_file(input_iob_file)
        if devel_iob_file:
            tmp_devel_file = create_tmp_file(devel_iob_file)
        # TODO make number of threads configurable
        command = ['wapiti', 'train', '-t', '4', '-p',
                   wapiti_pattern_file, tmp_input_file, wapiti_model]
        if devel_iob_file:
            command.extend(['--devel', tmp_devel_file, '--stopwin', '20'])
        returncode = subprocess.call(command)
    finally:
        # Remove the temporary files even when featurization or training
        # raised, so failed runs do not litter the temp directory.
        for tmp_file in (tmp_input_file, tmp_devel_file):
            if tmp_file is not None:
                os.unlink(tmp_file)
    # Propagate wapiti's exit status so callers can detect failed training.
    sys.exit(returncode)