From 92426329bdbe15b250a90f28229f4651891937ec Mon Sep 17 00:00:00 2001 From: Brandon Morris Date: Tue, 2 Oct 2018 11:48:31 -0700 Subject: [PATCH 1/4] Switch main.py to use argparse --- .gitignore | 3 +++ main.py | 71 ++++++++++++++++++++++-------------------------------- 2 files changed, 32 insertions(+), 42 deletions(-) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..1f7c3da --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +save/ +**/__pycache__ +experiment-log-*.csv diff --git a/main.py b/main.py index 9721d05..2924e2c 100644 --- a/main.py +++ b/main.py @@ -1,4 +1,4 @@ -import getopt +from argparse import ArgumentParser import sys from colorama import Fore @@ -12,17 +12,21 @@ from models.textGan_MMD.Textgan import TextganMmd +supported_gans = { + 'seqgan': Seqgan, + 'gsgan': Gsgan, + 'textgan': TextganMmd, + 'leakgan': Leakgan, + 'rankgan': Rankgan, + 'maligan': Maligan, + 'mle': Mle +} +supported_training = {'oracle', 'cfg', 'real'} + + def set_gan(gan_name): - gans = dict() - gans['seqgan'] = Seqgan - gans['gsgan'] = Gsgan - gans['textgan'] = TextganMmd - gans['leakgan'] = Leakgan - gans['rankgan'] = Rankgan - gans['maligan'] = Maligan - gans['mle'] = Mle try: - Gan = gans[gan_name.lower()] + Gan = supported_gans[gan_name.lower()] gan = Gan() gan.vocab_size = 5000 gan.generate_num = 10000 @@ -32,7 +36,6 @@ def set_gan(gan_name): sys.exit(-2) - def set_training(gan, training_method): try: if training_method == 'oracle': @@ -50,36 +53,20 @@ def set_training(gan, training_method): return gan_func -def parse_cmd(argv): - try: - opts, args = getopt.getopt(argv, "hg:t:d:") - - opt_arg = dict(opts) - if '-h' in opt_arg.keys(): - print('usage: python main.py -g ') - print(' python main.py -g -t ') - print(' python main.py -g -t realdata -d ') - sys.exit(0) - if not '-g' in opt_arg.keys(): - print('unspecified GAN type, use MLE training only...') - gan = set_gan('mle') - else: - gan = set_gan(opt_arg['-g']) - if not '-t' in opt_arg.keys(): - gan.train_oracle() - else: - gan_func = set_training(gan, opt_arg['-t']) - if opt_arg['-t'] == 'real' and '-d' in opt_arg.keys(): - gan_func(opt_arg['-d']) - else: - gan_func() - except getopt.GetoptError: - print('invalid arguments!') - print('`python main.py -h` for help') - sys.exit(-1) - pass - +def parse_cmd(): + parser = ArgumentParser() + parser.add_argument('-g', '--gan-type', help='The type of GAN to use', + choices=set(supported_gans.keys()), default='mle') + parser.add_argument('-t', '--train-type', help='Type of training to use', + choices=supported_training, default='oracle') + parser.add_argument('-d', '--data', default='data/image_coco.txt') + return parser.parse_known_args() if __name__ == '__main__': - gan = None - parse_cmd(sys.argv[1:]) + args, unused_args = parse_cmd() + gan = set_gan(args.gan_type) + train_f = set_training(gan, args.train_type) + if args.train_type == 'real': + train_f(args.data) + else: + train_f() From ed993ad900ffc20469cdd041553153dfa463969e Mon Sep 17 00:00:00 2001 From: Brandon Morris Date: Thu, 4 Oct 2018 13:33:33 -0700 Subject: [PATCH 2/4] Switch main.py to use tf.app.flags --- main.py | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/main.py b/main.py index 2924e2c..f61800a 100644 --- a/main.py +++ b/main.py @@ -1,7 +1,7 @@ -from argparse import ArgumentParser import sys from colorama import Fore +import tensorflow as tf from models.gsgan.Gsgan import Gsgan from models.leakgan.Leakgan import Leakgan @@ -54,19 +54,20 @@ def set_training(gan, training_method): def parse_cmd(): - parser = ArgumentParser() - parser.add_argument('-g', '--gan-type', help='The type of GAN to use', - choices=set(supported_gans.keys()), default='mle') - parser.add_argument('-t', '--train-type', help='Type of training to use', - choices=supported_training, default='oracle') - parser.add_argument('-d', '--data', default='data/image_coco.txt') - return parser.parse_known_args() + flags = tf.app.flags + flags.DEFINE_enum('gan_type', 'mle', list(supported_gans.keys()), + 'Type of GAN to use') + flags.DEFINE_enum('train_type', 'oracle', supported_training, + 'Type of training to use') + flags.DEFINE_string('data', 'data/image_coco.txt', '') + return if __name__ == '__main__': - args, unused_args = parse_cmd() - gan = set_gan(args.gan_type) - train_f = set_training(gan, args.train_type) - if args.train_type == 'real': - train_f(args.data) + parse_cmd() + flags = tf.app.flags.FLAGS + gan = set_gan(flags.gan_type) + train_f = set_training(gan, flags.train_type) + if flags.train_type == 'real': + train_f(flags.data) else: train_f() From b04bbcfe834b5fd0f0f13da6affed9f51e0aa34a Mon Sep 17 00:00:00 2001 From: Brandon Morris Date: Thu, 4 Oct 2018 14:07:57 -0700 Subject: [PATCH 3/4] Allow output files to be specified as cli args --- models/Gan.py | 10 ++++++++++ models/gsgan/Gsgan.py | 4 ---- models/leakgan/Leakgan.py | 4 ---- models/maligan_basic/Maligan.py | 4 ---- models/mle/Mle.py | 4 ---- models/pg_bleu/Pgbleu.py | 3 --- models/rankgan/Rankgan.py | 4 ---- models/seqgan/Seqgan.py | 4 ---- models/textGan_MMD/Textgan.py | 4 ---- 9 files changed, 10 insertions(+), 31 deletions(-) diff --git a/models/Gan.py b/models/Gan.py index f727d43..43dc6e0 100644 --- a/models/Gan.py +++ b/models/Gan.py @@ -1,4 +1,5 @@ from abc import abstractmethod +from tensorflow.app import flags from utils.utils import init_sess @@ -19,6 +20,15 @@ def __init__(self): self.log = None self.reward = None + flags.DEFINE_string('oracle_file', 'save/oracle.txt', '') + flags.DEFINE_string('generator_file', 'save/generator.txt', '') + flags.DEFINE_string('test_file', 'save/test_file.txt', '') + FLAGS = flags.FLAGS + self.oracle_file = FLAGS.oracle_file + self.generator_file = FLAGS.generator_file + self.test_file = FLAGS.test_file + + def set_oracle(self, oracle): self.oracle = oracle diff --git a/models/gsgan/Gsgan.py b/models/gsgan/Gsgan.py index 7daa7a5..631be4a 100644 --- a/models/gsgan/Gsgan.py +++ b/models/gsgan/Gsgan.py @@ -31,10 +31,6 @@ def __init__(self, oracle=None): self.generate_num = 128 self.start_token = 0 - self.oracle_file = 'save/oracle.txt' - self.generator_file = 'save/generator.txt' - self.test_file = 'save/test_file.txt' - def init_oracle_trainng(self, oracle=None): if oracle is None: oracle = OracleLstm(num_vocabulary=self.vocab_size, batch_size=self.batch_size, emb_dim=self.emb_dim, diff --git a/models/leakgan/Leakgan.py b/models/leakgan/Leakgan.py index a6fec01..4f2f752 100644 --- a/models/leakgan/Leakgan.py +++ b/models/leakgan/Leakgan.py @@ -72,10 +72,6 @@ def __init__(self, oracle=None): self.dis_embedding_dim = 64 self.goal_size = 16 - self.oracle_file = 'save/oracle.txt' - self.generator_file = 'save/generator.txt' - self.test_file = 'save/test_file.txt' - def init_oracle_trainng(self, oracle=None): goal_out_size = sum(self.num_filters) diff --git a/models/maligan_basic/Maligan.py b/models/maligan_basic/Maligan.py index 32e1d0d..40ee4c9 100644 --- a/models/maligan_basic/Maligan.py +++ b/models/maligan_basic/Maligan.py @@ -28,10 +28,6 @@ def __init__(self, oracle=None): self.generate_num = 128 self.start_token = 0 - self.oracle_file = 'save/oracle.txt' - self.generator_file = 'save/generator.txt' - self.test_file = 'save/test_file.txt' - def init_oracle_trainng(self, oracle=None): if oracle is None: oracle = OracleLstm(num_vocabulary=self.vocab_size, batch_size=self.batch_size, emb_dim=self.emb_dim, diff --git a/models/mle/Mle.py b/models/mle/Mle.py index cbbe947..d1193ae 100644 --- a/models/mle/Mle.py +++ b/models/mle/Mle.py @@ -26,10 +26,6 @@ def __init__(self, oracle=None): self.generate_num = 128 self.start_token = 0 - self.oracle_file = 'save/oracle.txt' - self.generator_file = 'save/generator.txt' - self.test_file = 'save/test_file.txt' - def init_oracle_trainng(self, oracle=None): if oracle is None: oracle = OracleLstm(num_vocabulary=self.vocab_size, batch_size=self.batch_size, emb_dim=self.emb_dim, diff --git a/models/pg_bleu/Pgbleu.py b/models/pg_bleu/Pgbleu.py index 4d2aa86..83a2b17 100644 --- a/models/pg_bleu/Pgbleu.py +++ b/models/pg_bleu/Pgbleu.py @@ -27,9 +27,6 @@ def __init__(self, oracle=None): self.generate_num = 128 self.start_token = 0 - self.oracle_file = 'save/oracle.txt' - self.generator_file = 'save/generator.txt' - def init_oracle_trainng(self, oracle=None): if oracle is None: oracle = OracleLstm(num_vocabulary=self.vocab_size, batch_size=self.batch_size, emb_dim=self.emb_dim, diff --git a/models/rankgan/Rankgan.py b/models/rankgan/Rankgan.py index d6dfcf6..9ac660d 100644 --- a/models/rankgan/Rankgan.py +++ b/models/rankgan/Rankgan.py @@ -28,10 +28,6 @@ def __init__(self, oracle=None): self.generate_num = 128 self.start_token = 0 - self.oracle_file = 'save/oracle.txt' - self.generator_file = 'save/generator.txt' - self.test_file = 'save/test_file.txt' - def init_oracle_trainng(self, oracle=None): if oracle is None: oracle = OracleLstm(num_vocabulary=self.vocab_size, batch_size=self.batch_size, emb_dim=self.emb_dim, diff --git a/models/seqgan/Seqgan.py b/models/seqgan/Seqgan.py index 16b5320..1a69e5c 100644 --- a/models/seqgan/Seqgan.py +++ b/models/seqgan/Seqgan.py @@ -31,10 +31,6 @@ def __init__(self, oracle=None): self.generate_num = 128 self.start_token = 0 - self.oracle_file = 'save/oracle.txt' - self.generator_file = 'save/generator.txt' - self.test_file = 'save/test_file.txt' - def init_metric(self): nll = Nll(data_loader=self.oracle_data_loader, rnn=self.oracle, sess=self.sess) self.add_metric(nll) diff --git a/models/textGan_MMD/Textgan.py b/models/textGan_MMD/Textgan.py index dcd2e43..a560d21 100644 --- a/models/textGan_MMD/Textgan.py +++ b/models/textGan_MMD/Textgan.py @@ -50,10 +50,6 @@ def __init__(self, oracle=None): self.generate_num = 128 self.start_token = 0 - self.oracle_file = 'save/oracle.txt' - self.generator_file = 'save/generator.txt' - self.test_file = 'save/test_file.txt' - def init_oracle_trainng(self, oracle=None): if oracle is None: oracle = OracleLstm(num_vocabulary=self.vocab_size, batch_size=self.batch_size, emb_dim=self.emb_dim, From 7ed4ec2c69129bc01f115b826ec4a7e7d8c4605b Mon Sep 17 00:00:00 2001 From: Brandon Morris Date: Tue, 20 Nov 2018 15:36:09 -0700 Subject: [PATCH 4/4] Correct the cli args when running seqgan --- main.py | 3 +++ models/Gan.py | 3 --- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/main.py b/main.py index f61800a..ed7d0c2 100644 --- a/main.py +++ b/main.py @@ -64,6 +64,9 @@ def parse_cmd(): if __name__ == '__main__': parse_cmd() + tf.app.flags.DEFINE_string('oracle_file', 'save/oracle.txt', '') + tf.app.flags.DEFINE_string('generator_file', 'save/generator.txt', '') + tf.app.flags.DEFINE_string('test_file', 'save/test_file.txt', '') flags = tf.app.flags.FLAGS gan = set_gan(flags.gan_type) train_f = set_training(gan, flags.train_type) diff --git a/models/Gan.py b/models/Gan.py index 43dc6e0..a06a2f7 100644 --- a/models/Gan.py +++ b/models/Gan.py @@ -20,9 +20,6 @@ def __init__(self): self.log = None self.reward = None - flags.DEFINE_string('oracle_file', 'save/oracle.txt', '') - flags.DEFINE_string('generator_file', 'save/generator.txt', '') - flags.DEFINE_string('test_file', 'save/test_file.txt', '') FLAGS = flags.FLAGS self.oracle_file = FLAGS.oracle_file self.generator_file = FLAGS.generator_file