Skip to content

Commit

Permalink
fix wt_clf bug
Browse files Browse the repository at this point in the history
  • Loading branch information
golsun committed Dec 5, 2019
1 parent a8f0c51 commit ef3fc9c
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 16 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

*.pyc
models/
temp.py
4 changes: 2 additions & 2 deletions src/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,9 +225,9 @@ def clf_interact(fld):
print('%.4f'%score)


def clf_eval(path):
def clf_eval(clf_fld, path):
# path is a tsv, last col is hyp
clf = load_classifier(fld)
clf = load_classifier(clf_fld)
sum_score = 0
n = 0
for line in open(path, encoding='utf-8'):
Expand Down
6 changes: 2 additions & 4 deletions src/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,9 +192,7 @@ def rank_nbest(hyps, logP, logP_center, master, inp, infer_args=dict(), base_ran
hyps_no_ie.append((' '+hyp+' ').replace(' i . e . ,',' ').replace(' i . e. ',' ').strip())
hyps = hyps_no_ie[:]

wts_classifier = []
for clf_name in master.clf_names:
wts_classifier.append(infer_args.get(clf_name, 0))
wt_clf = infer_args.get('wt_clf', 0) / len(master.classifiers)
wt_rep = infer_args.get('wt_rep', 0)
wt_len = infer_args.get('wt_len', 0)
wt_center = infer_args.get('wt_center', 0)
Expand Down Expand Up @@ -224,7 +222,7 @@ def rank_nbest(hyps, logP, logP_center, master, inp, infer_args=dict(), base_ran
clf_score_ = []
for k in range(len(master.classifiers)):
s = clf_score[k][i]
score += wts_classifier[k] * s
score += wt_clf * s
clf_score_.append(s)
pq.put((-score, hyp, (logP[i], logP_center[i], logP_base[i], rep, l) + tuple(clf_score_)))

Expand Down
17 changes: 7 additions & 10 deletions src/infer_args.csv
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
method,latent
beam_width,100
n_rand,100
r_rand,2
n_rand,10
r_rand,3
softmax_temperature,1
min_logP,-1000.
wt_center,1.
wt_rep,1.
wt_len,0.5
wt_base,1.
base10m_vs_holmes38k/ngram,1.5
base10m_vs_holmes38k/neural,1.5
base10m_vs_arxiv1m/ngram,1.5
base10m_vs_arxiv1m/neural,1.5
wt_center,0.
wt_rep,0.
wt_len,0.
wt_base,0.
wt_clf,1.
crit_logP,-2.5
crit_rep,-0.3
rep_allow,0.9
Expand Down

0 comments on commit ef3fc9c

Please sign in to comment.