import sys

import numpy as np
import tensorflow as tf


class LTRDNN(object):
    """LTR-DNN model.

    Embeds bag-of-token queries and titles, maps each through a
    fully-connected layer to an l2-normalized sentence vector, and
    trains pairwise with a hinge loss on the difference between the
    query-positive and query-negative cosine similarities.
    """

    def __init__(self, vocab_size, emb_dim=256, repr_dim=256,
                 combiner='sum', lr=1e-4, eps=1.0):
        """Construct the network."""
        if combiner not in ('sum', 'mean'):
            raise ValueError('invalid combiner: %s' % combiner)
        self.dropout_prob = tf.placeholder(tf.float32, name='dropout_prob')
        self.global_step = tf.Variable(0, name='global_step', trainable=False)
        self.pretrained_emb = tf.placeholder(tf.float32, [vocab_size, emb_dim])
        self.eps = eps
        # prepare placeholders for query, pos, neg
        # https://www.tensorflow.org/api_docs/python/tf/sparse_placeholder
        # each input is a batch_size * seq_len sparse tensor of token ids
        self.inp_qry = tf.sparse_placeholder(dtype=tf.int64, name='input_qry')
        self.inp_pos = tf.sparse_placeholder(dtype=tf.int64, name='input_pos')
        self.inp_neg = tf.sparse_placeholder(dtype=tf.int64, name='input_neg')
        # used only when predicting sim-qq
        self.inp_prd = tf.sparse_placeholder(dtype=tf.int64, name='input_prd')
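        # Feeding format note (illustrative example, not part of the model):
        # each sparse input takes a tf.SparseTensorValue; e.g. the two
        # sentences [3, 7, 5] and [2, 9] would be fed as
        #   indices=[[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]],
        #   values=[3, 7, 5, 2, 9], dense_shape=[2, 3].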
        # embedding matrix: randomly initialized here; it can be overwritten
        # with a pretrained matrix via assign_embedding()
        self.embedding = tf.Variable(
            tf.random_uniform([vocab_size, emb_dim], -0.02, 0.02),
            name='emb_mat')
        # op that assigns pretrained values to the embedding matrix
        self.init_embedding = self.embedding.assign(self.pretrained_emb)
        # shape of emb_qry: batch_size * emb_dim
        emb_qry = tf.nn.embedding_lookup_sparse(
            self.embedding, self.inp_qry, sp_weights=None, combiner=combiner)
        emb_pos = tf.nn.embedding_lookup_sparse(
            self.embedding, self.inp_pos, sp_weights=None, combiner=combiner)
        emb_neg = tf.nn.embedding_lookup_sparse(
            self.embedding, self.inp_neg, sp_weights=None, combiner=combiner)
        emb_prd = tf.nn.embedding_lookup_sparse(
            self.embedding, self.inp_prd, sp_weights=None, combiner=combiner)
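        # embedding_lookup_sparse reduces the token embeddings of each row
        # with `combiner`; e.g. with combiner='sum' a sentence [3, 7] maps to
        # embedding[3] + embedding[7], one emb_dim vector per sentence.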
        # fully-connected layer producing the sentence representation
        # (this one is shared by the query-side inputs: qry and prd)
        w = tf.get_variable(
            'q-fc-W', shape=[emb_dim, repr_dim],
            initializer=tf.contrib.layers.xavier_initializer())
        b = tf.Variable(tf.constant(0.1, shape=[repr_dim]), name='q-fc-b')
        with tf.name_scope('tit_qry-vec'):
            # shape of repr_qry: batch_size * repr_dim
            self.repr_qry = tf.nn.l2_normalize(
                tf.nn.softsign(tf.nn.xw_plus_b(emb_qry, w, b)), dim=1)
        with tf.name_scope('tit_prd-vec'):
            self.repr_prd = tf.nn.l2_normalize(
                tf.nn.softsign(tf.nn.xw_plus_b(emb_prd, w, b)), dim=1)
        # a separate fully-connected layer for the title-side inputs
        w = tf.get_variable(
            't-fc-W', shape=[emb_dim, repr_dim],
            initializer=tf.contrib.layers.xavier_initializer())
        b = tf.Variable(tf.constant(0.1, shape=[repr_dim]), name='t-fc-b')
        with tf.name_scope('qry_pos-vec'):
            self.repr_pos = tf.nn.l2_normalize(
                tf.nn.softsign(tf.nn.xw_plus_b(emb_pos, w, b)), dim=1)
        with tf.name_scope('qry_neg-vec'):
            self.repr_neg = tf.nn.l2_normalize(
                tf.nn.softsign(tf.nn.xw_plus_b(emb_neg, w, b)), dim=1)
        # cosine similarity between q&p, q&n, q&q; the representations
        # are l2-normalized, so the dot product equals cosine similarity
        with tf.name_scope('sim-qp'):
            # shape of sim_qp: batch_size
            self.sim_qp = tf.reduce_sum(
                tf.multiply(self.repr_qry, self.repr_pos), axis=1)
        with tf.name_scope('sim-qn'):
            self.sim_qn = tf.reduce_sum(
                tf.multiply(self.repr_qry, self.repr_neg), axis=1)
        with tf.name_scope('sim-qq'):
            self.sim_qq = tf.reduce_sum(
                tf.multiply(self.repr_qry, self.repr_prd), axis=1)
        with tf.name_scope('diff_qp-qn'):
            self.sim_diff = tf.subtract(self.sim_qp, self.sim_qn)
        with tf.name_scope('label'):
            # every training pair is ordered (pos should beat neg),
            # so the pairwise label is always 1
            self.labels = tf.ones(shape=tf.shape(self.sim_diff))
        with tf.name_scope('pairwise-loss'):
            # hinge loss with margin eps:
            #   loss = mean over batch of max(0, eps - sim_diff)
            # tf.losses.hinge_loss(labels=1, logits=x) computes max(0, 1 - x),
            # so feeding sim_diff / eps and scaling the mean by eps recovers
            # the margin form above.
            self.loss = tf.losses.hinge_loss(
                labels=self.labels,
                logits=self.sim_diff / self.eps,
                reduction=tf.losses.Reduction.MEAN
            ) * self.eps
            self.total_loss = self.loss  # regularization loss could be added here
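            # Worked example (illustrative): with eps = 1.0, sim_qp = 0.8
            # and sim_qn = 0.3 give sim_diff = 0.5, so the per-pair loss is
            # max(0, 1.0 - 0.5) = 0.5; a pair separated by at least the
            # margin (sim_diff >= eps) contributes zero loss.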
        # optimizer
        # kindly note the efficiency problem of Adam with sparse ops:
        # https://github.com/tensorflow/tensorflow/issues/6460
        self.opt = tf.contrib.opt.LazyAdamOptimizer(lr).minimize(
            self.total_loss, global_step=self.global_step)
        with tf.name_scope('pairwise-prediction'):
            # pred = 1 if sim_qp >= sim_qn else 0
            self.preds = tf.sign(tf.sign(self.sim_diff) + 1.)
        # @TODO: add regularization such as dropout, l2-reg, etc.
        with tf.name_scope('accuracy'):
            # accumulated (streaming) accuracy; re-run
            # tf.local_variables_initializer() to start a new evaluation
            self.acc, self.update_acc = tf.contrib.metrics.streaming_accuracy(
                labels=self.labels, predictions=self.preds)
        # saver and loader; saving only trainable variables drops the
        # optimizer's local slot variables from checkpoints
        self.saver = tf.train.Saver(tf.trainable_variables())

    def train_step(self, sess, inp_batch_q, inp_batch_p, inp_batch_n):
        """Run one optimization step on a (query, pos, neg) batch."""
        input_dict = {
            self.inp_qry: inp_batch_q,
            self.inp_pos: inp_batch_p,
            self.inp_neg: inp_batch_n,
            self.dropout_prob: 0.5}
        sess.run(self.opt, feed_dict=input_dict)

    def assign_embedding(self, sess, embedding=None):
        """Overwrite the embedding matrix with a pretrained one.

        @embedding: array of shape [vocab_size, emb_dim].
        """
        if embedding is None:
            raise ValueError('embedding is None')
        input_dict = {self.pretrained_emb: embedding}
        sess.run(self.init_embedding, feed_dict=input_dict)

    def eval_step(self, sess, dev_qry, dev_pos, dev_neg, metrics=None):
        """Evaluate the given metrics on a dev batch.

        @return: a list of metric values; only 'loss' is supported so far.
        """
        if not metrics:
            metrics = ['loss']
        eval_dict = {
            self.inp_qry: dev_qry,
            self.inp_pos: dev_pos,
            self.inp_neg: dev_neg,
            self.dropout_prob: 1.0}
        eval_res = []
        for metric in metrics:
            if metric == 'loss':
                eval_res.append(sess.run(self.loss, feed_dict=eval_dict))
        return eval_res

    def predict_sim(self, sess, query, title1, title2):
        """Predict similarity between query & title1 and query & title2.

        @return: [sim_qt1, sim_qt2, preds], where preds is 1.0 if title1
                 is more similar to the query than title2, else 0.0.
        """
        eval_dict = {
            self.inp_qry: query,
            self.inp_pos: title1,
            self.inp_neg: title2,
            self.dropout_prob: 1.0}
        return sess.run([self.sim_qp, self.sim_qn, self.preds],
                        feed_dict=eval_dict)

    def predict_diff(self, sess, inp_qry, inp_pos, inp_neg):
        """Predict which title is more similar to the query.

        @return: 1.0/0.0 if the first/second title is more similar.
        """
        pred_dict = {
            self.inp_qry: inp_qry,
            self.inp_pos: inp_pos,
            self.inp_neg: inp_neg,
            self.dropout_prob: 1.0}
        return sess.run(self.preds, feed_dict=pred_dict)

    def predict_sim_qt(self, sess, inp_query, inp_title):
        """Predict similarity of query and title.

        @return: cosine similarity, value range [-1, 1].
        """
        pred_dict = {
            self.inp_qry: inp_query,
            self.inp_pos: inp_title,
            self.dropout_prob: 1.0}
        return sess.run(self.sim_qp, feed_dict=pred_dict)

    def predict_sim_qq(self, sess, inp_query1, inp_query2):
        """Predict similarity of two queries.

        @return: cosine similarity, value range [-1, 1].
        """
        pred_dict = {
            self.inp_qry: inp_query1,
            self.inp_prd: inp_query2,
            self.dropout_prob: 1.0}
        return sess.run(self.sim_qq, feed_dict=pred_dict)

    def _accumulate_accuracy(self, sess, inp_q, inp_p, inp_n):
        """Update the streaming accuracy with a new batch.

        @return: the newly-updated accumulated accuracy.
        """
        input_dict = {
            self.inp_qry: inp_q,
            self.inp_pos: inp_p,
            self.inp_neg: inp_n,
            self.dropout_prob: 1.0}
        sess.run(self.update_acc, feed_dict=input_dict)
        return sess.run(self.acc)

    def pairwise_accuracy(self, sess, fiter, inp_fn, verb=None):
        """Evaluate the ratio of correctly ordered pairs.

        @return: accuracy = correct_pairs / total_pairs.
        @fiter : an iterable yielding instances (qry & pos & neg per query).
        @inp_fn: a function extracting (qry, pos, neg) from an instance,
                 where qry, pos, neg are batch sentences that can be fed
                 to self.inp_X.
        @verb  : print a progress hint every `verb` instances; None for
                 no hint.
        """
        accuracy = None
        for nl, inst in enumerate(fiter):
            if verb and nl % verb == 0:  # print progress hint
                sys.stderr.write(str(nl) + ' lines in pairwise_acc.\n')
            qrys, poss, negs = inp_fn(inst)
            accuracy = self._accumulate_accuracy(sess, qrys, poss, negs)
        return accuracy
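

if __name__ == '__main__':
    # Minimal usage sketch (illustrative only): the vocab size, token ids
    # and the toy triplet below are made-up assumptions, not part of the
    # model; real inputs would come from a tokenized corpus.
    def to_sparse(batch):
        """Convert a batch of token-id lists to a tf.SparseTensorValue."""
        indices = [[i, j] for i, sent in enumerate(batch)
                   for j, _ in enumerate(sent)]
        values = [tok for sent in batch for tok in sent]
        max_len = max(len(sent) for sent in batch)
        return tf.SparseTensorValue(indices, values, [len(batch), max_len])

    mdl = LTRDNN(vocab_size=1000, emb_dim=256, repr_dim=256)
    qry = to_sparse([[3, 7, 5]])    # one query of three token ids
    pos = to_sparse([[3, 8]])       # a clicked (positive) title
    neg = to_sparse([[42, 19, 6]])  # an unclicked (negative) title
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())  # for streaming accuracy
        # optionally load pretrained embeddings (random here as a stand-in)
        emb = np.random.uniform(
            -0.02, 0.02, (1000, 256)).astype(np.float32)
        mdl.assign_embedding(sess, emb)
        mdl.train_step(sess, qry, pos, neg)
        print(mdl.eval_step(sess, qry, pos, neg, metrics=['loss']))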