
debug infer
golsun committed Dec 4, 2019
1 parent 67cac03 commit fbbc4db
Showing 4 changed files with 32 additions and 4 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -1,2 +1,3 @@

*.pyc
models/
18 changes: 18 additions & 0 deletions src/infer_args.csv
@@ -0,0 +1,18 @@
method,latent
beam_width,100
n_rand,100
r_rand,2
softmax_temperature,1
min_logP,-1000.
wt_center,1.
wt_rep,1.
wt_len,0.5
wt_base,1.
base10m_vs_holmes38k/ngram,1.5
base10m_vs_holmes38k/neural,1.5
base10m_vs_arxiv1m/ngram,1.5
base10m_vs_arxiv1m/neural,1.5
crit_logP,-2.5
crit_rep,-0.3
rep_allow,0.9
lm_wt,0.95
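
Note: infer_args.csv is a flat name,value table of decoding hyperparameters (beam width, sampling counts, ranking weights such as wt_rep and wt_len, per-classifier weights, and cutoff criteria like crit_logP and crit_rep). The loader for this file is not part of this commit; the sketch below only illustrates, as an assumption, how such a two-column CSV could be read into a dict, with parse_infer_args being a hypothetical helper name.

import csv

def parse_infer_args(path='src/infer_args.csv'):
    # Read name,value rows; keep numeric values as float, everything else as str.
    args = {}
    with open(path) as f:
        for name, value in csv.reader(f):
            try:
                args[name] = float(value)   # e.g. beam_width, wt_rep, crit_logP
            except ValueError:
                args[name] = value          # e.g. method,latent
    return args

# Usage: infer_args = parse_infer_args(); infer_args['beam_width'] -> 100.0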
15 changes: 12 additions & 3 deletions src/main.py
@@ -38,6 +38,7 @@ def run_master(mode, args):
else:
allowed_words = None


model_class = args.model_class.lower()
if model_class.startswith('fuse'):
Master = StyleFusion
@@ -51,6 +52,7 @@ def run_master(mode, args):
pass
else:
raise ValueError


if model_class == 's2s+lm':
master = Seq2SeqLM(args, allowed_words)
@@ -70,7 +72,12 @@ def run_master(mode, args):
if mode in ['vis', 'load']:
return master

CLF_NAMES = []
if args.clf_name.lower() == 'holmes':
CLF_NAMES = ['classifier/Reddit_vs_Holmes/neural', 'classifier/Reddit_vs_Holmes/ngram']
elif args.clf_name.lower() == 'arxiv':
CLF_NAMES = ['classifier/Reddit_vs_arxiv/neural', 'classifier/Reddit_vs_arxiv/ngram']
else:
CLF_NAMES = [args.clf_name]
print('loading classifiers '+str(CLF_NAMES))
master.clf_names = CLF_NAMES
master.classifiers = []
@@ -120,6 +127,7 @@ def run_master(mode, args):
print(s_surrogate)
return

"""
if model_class != 's2s+lm':
with tf.variable_scope('base_rankder', reuse=tf.AUTO_REUSE):
fld_base_ranker = 'restore/%s/%s/pretrained/'%(args.model_class.replace('fuse1','fuse'), args.data_name)
@@ -131,10 +139,11 @@ def run_master(mode, args):
base_ranker.load_weights(path)
print('*'*10 + ' base_ranker loaded from: '+path)
else:
base_ranker = None
"""
base_ranker = None

def print_results(results):
ss = ['total', 'logP', 'rep', 'len', 's2s', 'clf_h', 'clf_v']
ss = ['total', 'logP', 'logP_c', 'logP_b', 'rep', 'len',] + ['clf%i'%i for i in range(len(CLF_NAMES))]
print('; '.join([' '*(6-len(s))+s for s in ss]))
for score, resp, terms in results:
print('%6.3f; '%score + '; '.join(['%6.3f'%x for x in terms]) + '; ' + resp)
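
Note: the main.py hunks above do three things: pick classifier paths from args.clf_name, disable the base_ranker loading by commenting it out (base_ranker is now always None), and make print_results emit one clf%i column per loaded classifier. As a reading aid only, the snippet below restates the classifier selection in isolation; resolve_clf_names is a hypothetical name, while the paths are taken verbatim from the diff.

def resolve_clf_names(clf_name):
    # 'holmes' and 'arxiv' expand to a neural + ngram classifier pair.
    if clf_name.lower() == 'holmes':
        return ['classifier/Reddit_vs_Holmes/neural', 'classifier/Reddit_vs_Holmes/ngram']
    if clf_name.lower() == 'arxiv':
        return ['classifier/Reddit_vs_arxiv/neural', 'classifier/Reddit_vs_arxiv/ngram']
    return [clf_name]   # any other value is used verbatim as a single classifier path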
2 changes: 1 addition & 1 deletion src/model.py
@@ -82,7 +82,7 @@ def train(self, batch_per_load=100):
def load_weights(self, path):
self.prev_wt_fuse = None
print('loading weights from %s'%path)
npz = np.load(path, encoding='latin1')
npz = np.load(path, encoding='latin1', allow_pickle=True)
print(npz.files)
weights = npz['layers'].item()
for k in weights:
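
Note: the model.py change reflects a NumPy behaviour change: starting with NumPy 1.16.3, np.load defaults to allow_pickle=False and refuses to deserialize object arrays, so loading a checkpoint whose 'layers' entry is a pickled dict raises a ValueError without the flag. The sketch below mirrors the access pattern in load_weights (npz['layers'].item()) with made-up weight names and a /tmp path; it is an illustration, not the project's checkpoint format.

import numpy as np

# A dict saved via np.savez is stored as a 0-d object array, which requires pickling.
weights = {'embedding': np.zeros((4, 3)), 'decoder': np.ones((3, 3))}
np.savez('/tmp/ckpt.npz', layers=weights)

npz = np.load('/tmp/ckpt.npz', encoding='latin1', allow_pickle=True)
restored = npz['layers'].item()   # unwrap the 0-d object array back into a dict
assert restored['decoder'].shape == (3, 3)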
