-
Notifications
You must be signed in to change notification settings - Fork 25
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
some refactoring + a lot of fixes by Aaron
- Loading branch information
1 parent
445220c
commit ea15fee
Showing
4 changed files
with
445 additions
and
251 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,73 +1,193 @@ | ||
import os,sys | ||
import utils | ||
import copy | ||
import random | ||
|
||
#TODO: this whole file is now quite hacky - used to be mostly useful for | ||
#pseudoProj | ||
class OptionsManager(object): | ||
|
||
def __init__(self,options): | ||
""" | ||
input: parser options | ||
object to harmonise the way we deal with the parser | ||
""" | ||
|
||
print 'Using external embedding:', options.external_embedding | ||
|
||
if options.include and not options.datadir: | ||
raise Exception("You need to specify the data dir to include UD\ | ||
languages") | ||
#TODO: maybe add more sanity checks | ||
raise Exception("You need to specify the data dir to include UD languages") | ||
|
||
if not options.predictFlag: | ||
if not options.include and not options.trainfile: | ||
raise Exception("If not using the --include option, you must specify your training data with --trainfile") | ||
else: | ||
if not options.include and not options.testfile: | ||
raise Exception("If not using the --include option, you must specify your test data with --testfile") | ||
if not options.modeldir: | ||
options.modeldir = options.outdir # set model directory to output directory by default | ||
|
||
if not options.outdir: | ||
raise Exception("You must specify an output directory via the --outdir option") | ||
elif not os.path.exists(options.outdir): # create output directory if it doesn't exist | ||
print "Creating output directory " + options.outdir | ||
os.mkdir(options.outdir) | ||
|
||
if not options.predictFlag and not (options.rlFlag or options.rlMostFlag or options.headFlag): | ||
raise Exception("You must use either --userlmost or --userl or\ | ||
--usehead (you can use multiple)") | ||
#the diff between two is one is r/l/most child / the other is | ||
#element in the sentence | ||
#Eli's paper: | ||
#extended feature set | ||
# rightmost and leftmost modifiers of s0, s1 and s2 + leftmost | ||
# modifier of b0 | ||
|
||
if not options.include: | ||
raise Exception("Must include either head, rl or rlmost (For example, if you specified --disable-head and --disable-rlmost, you must specify --userl)") | ||
|
||
if options.rlFlag and options.rlMostFlag: | ||
print 'Warning: Switching off rlMostFlag to allow rlFlag to take precedence' | ||
options.rlMostFlag = False | ||
|
||
#TODO: maybe add more sanity checks | ||
|
||
#this is now useless | ||
options.drop_nproj = False | ||
|
||
options.multi_monoling = False # set default | ||
self.iterations = 1 # set default | ||
|
||
if not options.include: # must specifiy explicitly train | ||
treebank = utils.Treebank(options.trainfile, \ | ||
options.devfile, options.testfile) | ||
treebank.iso_id = None | ||
treebank.outdir = options.outdir | ||
treebank.modeldir = options.modeldir | ||
#just one model specified by train/dev and/or test | ||
if options.predictFlag: | ||
self.conllu = (os.path.splitext(options.conll_test.lower())[1] == '.conllu') | ||
if not options.testfile: | ||
raise Exception("--testfile must be specified") | ||
elif not os.path.exists(options.testfile): | ||
raise Exception("Test file " + options.testfile + " not found") | ||
else: | ||
self.conllu = (os.path.splitext(options.testfile.lower())[1] == '.conllu') # test if file in conllu format | ||
treebank.test_gold = options.testfile | ||
else: | ||
self.conllu = (os.path.splitext(options.conll_dev.lower())[1] == '.conllu') | ||
self.treebank = utils.Treebank(options.conll_train, \ | ||
options.conll_dev, options.conll_test) | ||
self.treebank.iso_id = None | ||
self.prepareDev(treebank,options) | ||
if options.devfile: | ||
self.conllu = (os.path.splitext(options.devfile.lower())[1] == '.conllu') | ||
elif options.create_dev: | ||
self.conllu = (os.path.splitext(options.trainfile.lower())[1] == '.conllu') | ||
|
||
if options.debug: | ||
self.createDebugData(treebank,options) | ||
|
||
self.languages = [treebank] # make it a list of one element just for the sake of consistency with the "include" case | ||
|
||
else: | ||
self.conllu = True | ||
language_list = utils.parse_list_arg(options.include) | ||
json_treebanks = utils.conll_dir_to_list(language_list,options.datadir,options.shared_task, | ||
self.conllu = True # file is in conllu format | ||
language_list = utils.parse_list_arg(options.include) # languages requested by the user via the include flag | ||
json_treebanks = utils.conll_dir_to_list(language_list,options.datadir,options.shared_task, # list of the available treebanks | ||
options.shared_task_datadir) | ||
self.languages = [lang for lang in json_treebanks if lang.iso_id in language_list] | ||
for language in self.languages: | ||
language.removeme = False | ||
language.outdir= "%s/%s"%(options.output,language.iso_id) | ||
language.modelDir= "%s/%s"%(options.modelDir,language.iso_id) | ||
model = "%s/%s"%(language.modelDir,options.model) | ||
if options.predictFlag and not os.path.exists(model): | ||
if not options.shared_task: | ||
# self.languages = [lang for lang in json_treebanks if lang.iso_id in language_list] | ||
treebank_dict = {lang.iso_id: lang for lang in json_treebanks} | ||
self.languages = [] | ||
for lang in language_list: | ||
if lang in treebank_dict: | ||
self.languages.append(treebank_dict[lang]) | ||
else: | ||
print "Warning: skipping invalid language code " + lang | ||
|
||
if options.multiling: | ||
if options.predictFlag: | ||
model = "%s/%s"%(options.modeldir,options.model) | ||
if not os.path.exists(model): # in multilingual case need model to be found in first language specified | ||
raise Exception("Model not found. Path tried: %s"%model) | ||
else: | ||
#find model for the language in question | ||
for otherl in json_treebanks: | ||
if otherl.lcode == language.lcode: | ||
if otherl.lcode == otherl.iso_id: | ||
language.modelDir = "%s/%s"%(options.modelDir,otherl.iso_id) | ||
|
||
if not os.path.exists(language.outdir): | ||
if options.model_selection: # can only do model selection for monolingual case | ||
print "Warning: model selection on dev data not available for multilingual case" | ||
options.model_selection = False | ||
else: | ||
options.multi_monoling = True | ||
self.iterations = len(self.languages) | ||
|
||
for lang_index in xrange(len(self.languages)): | ||
language = self.languages[lang_index] | ||
|
||
language.outdir= "%s/%s"%(options.outdir,language.iso_id) | ||
if not os.path.exists(language.outdir): # create language-specific output folder if it doesn't exist | ||
print "Creating language-specific output directory " + language.outdir | ||
os.mkdir(language.outdir) | ||
else: | ||
print ("Warning: language-specific subdirectory " + language.outdir | ||
+ " already exists, contents may be overwritten") | ||
|
||
for language in self.languages: | ||
if language.removeme: | ||
self.languages.remove(language) | ||
if not options.predictFlag: | ||
self.prepareDev(language,options) | ||
|
||
if options.include and not options.multiling: | ||
options.multi_monoling = True | ||
self.iterations = len(self.languages) | ||
else: | ||
options.multi_monoling = False | ||
self.iterations = 1 | ||
#this is now useless | ||
options.drop_proj = False | ||
if options.debug: # it is important that prepareDev be called before createDebugData | ||
self.createDebugData(language,options) | ||
|
||
if options.predictFlag and options.multi_monoling: | ||
language.modeldir= "%s/%s"%(options.modeldir,language.iso_id) | ||
model = "%s/%s"%(language.modeldir,options.model) | ||
if not os.path.exists(model): # in multilingual case need model to be found in first language specified | ||
if not options.shared_task: | ||
raise Exception("Model not found. Path tried: %s"%model) | ||
else: | ||
#find model for the language in question | ||
for otherl in json_treebanks: | ||
if otherl.lcode == language.lcode: | ||
if otherl.lcode == otherl.iso_id: | ||
language.modeldir = "%s/%s"%(options.modeldir,otherl.iso_id) | ||
|
||
# creates dev data by siphoning off a portion of the training data (when necessary) | ||
# sets up treebank for prediction and model selection on dev data | ||
def prepareDev(self,treebank,options): | ||
treebank.pred_dev = options.pred_dev # even if options.pred_dev is True, might change treebank.pred_dev to False later if no dev data available | ||
treebank.model_selection = False | ||
if not treebank.devfile or not os.path.exists(treebank.devfile): | ||
if options.create_dev: # create some dev data from the training data | ||
traindata = list(utils.read_conll(treebank.trainfile,treebank.iso_id)) | ||
tot_sen = len(traindata) | ||
if tot_sen > options.min_train_sents: # need to have at least min_train_sents to move forward | ||
dev_file = os.path.join(treebank.outdir,'dev-split' + '.conllu') # location for the new dev file | ||
train_file = os.path.join(treebank.outdir,'train-split' + '.conllu') # location for the new train file | ||
dev_len = int(0.01*options.dev_percent*tot_sen) | ||
print ("Taking " + str(dev_len) + " of " + str(tot_sen) | ||
+ " sentences from training data as new dev data for " + treebank.name) | ||
random.shuffle(traindata) | ||
devdata = traindata[:dev_len] | ||
utils.write_conll(dev_file,devdata) # write the new dev data to file | ||
traindata = traindata[dev_len:] # put the rest of the training data in a new file too | ||
utils.write_conll(train_file,traindata) | ||
# update some variables with the new file locations | ||
treebank.dev_gold = dev_file | ||
treebank.devfile = dev_file | ||
treebank.trainfile = train_file | ||
else: # not enough sentences | ||
print ("Warning: not enough sentences in training data to create dev set for " | ||
+ treebank.name + " (minimum required --min-train-size: " + str(options.min_train_sents) + ")") | ||
treebank.pred_dev = False | ||
else: # option --create-dev not set | ||
print ("Warning: No dev data for " + treebank.name | ||
+ ", consider adding option --create-dev to create dev data from training set") | ||
treebank.pred_dev = False | ||
if options.model_selection: | ||
treebank.dev_best = [options.epochs,0] # epoch (final by default), score of best dev epoch | ||
if treebank.pred_dev: | ||
treebank.model_selection = True | ||
else: | ||
print "Warning: can't do model selection for " + treebank.name + " as prediction on dev data is off" | ||
|
||
# if debug options is set, we read in the training, dev and test files as appropriate, cap the number of sentences and store | ||
# new files with these smaller data sets | ||
def createDebugData(self,treebank,options): | ||
print 'Creating smaller data sets for debugging' | ||
if not options.predictFlag: | ||
traindata = list(utils.read_conll(treebank.trainfile,treebank.iso_id,maxSize=options.debug_train_sents,hard_lim=True)) | ||
train_file = os.path.join(treebank.outdir,'train-debug' + '.conllu') # location for the new train file | ||
utils.write_conll(train_file,traindata) # write the new dev data to file | ||
treebank.trainfile = train_file | ||
if treebank.devfile and os.path.exists(treebank.devfile): | ||
devdata = list(utils.read_conll(treebank.devfile,treebank.iso_id,maxSize=options.debug_dev_sents,hard_lim=True)) | ||
dev_file = os.path.join(treebank.outdir,'dev-debug' + '.conllu') # location for the new dev file | ||
utils.write_conll(dev_file,devdata) # write the new dev data to file | ||
treebank.dev_gold = dev_file | ||
treebank.devfile = dev_file | ||
else: | ||
testdata = list(utils.read_conll(treebank.testfile,treebank.iso_id,maxSize=options.debug_test_sents,hard_lim=True)) | ||
test_file = os.path.join(treebank.outdir,'test-debug' + '.conllu') # location for the new dev file | ||
utils.write_conll(test_file,testdata) # write the new dev data to file | ||
treebank.test_gold = test_file | ||
treebank.testfile = test_file |
Oops, something went wrong.