
Commit 06e9a21

Author: Yan
At least executable now. ORZ
1 parent 4de0d86 commit 06e9a21

6 files changed, +163 −40 lines

README.md

+1 −1

@@ -1,3 +1,3 @@
 # zhihu_cup
 draft code for zhihu cup using MxNet</br>
-Code under construction
+Barely executable now.

concise_data.py

+0 −1

@@ -1,4 +1,3 @@
-import mxnet as mx
 from read_embed import read_embed
 char_raw = open('sorted_char_count.txt').readlines()
 word_raw = open('sorted_word_count.txt').readlines()

iter.py

+28 −10

@@ -1,4 +1,6 @@
-import os
+import os,sys
+curr_path = os.path.abspath(os.path.dirname(__file__))
+sys.path.append(os.path.join(curr_path, "../mxnet/python"))
 import numpy as np
 import mxnet as mx
 from read_embed import read_embed
@@ -49,14 +51,26 @@ def __init__(self, question_set_path,
         for v in self.buckets:
             f.write(v+'\n')
         self.reset()
+        keys = self.max_bucket_key.split(',')
+        self.provide_data = []
+
         if embed_mode %2 == 0:
             self.char_dict, self.char_dict_size, self.char_dict_dim = read_embed(char_embed_path)
+            self.provide_data +=[('tc_array',(self.batch_size,int(keys[0]),self.char_dict_dim)),\
+                ('cc_array',(self.batch_size,int(keys[1]),self.char_dict_dim))]
         if embed_mode > 0:
             self.word_dict, self.word_dict_size, self.word_dict_dim = read_embed(word_embed_path)
-        self.provide_data = [('data', (self.batch_size, ))]
+            if embed_mode %2 == 0:
+                self.provide_data +=[('tw_array',(self.batch_size,int(keys[2]),self.word_dict_dim)),\
+                    ('cw_array',(self.batch_size,int(keys[3]),self.word_dict_dim))]
+            else:
+                self.provide_data +=[('tw_array',(self.batch_size,int(keys[0]),self.word_dict_dim)),\
+                    ('cw_array',(self.batch_size,int(keys[1]),self.word_dict_dim))]
+
         self.provide_label = [('label', (self.batch_size, len(self.topic_info) + 1))]


+
     def create_buckets(self, buckets=None):
         if buckets is None:
             self.buckets = self.default_buckets()
@@ -178,7 +192,6 @@ def __iter__(self):
             bucket_key = self.buckets[idx]
             inds= self.bucket_samples_inds[bucket_key] \
                 [self.bucket_offset[idx]:self.bucket_offset[idx]+self.batch_size]
-
             shapes= [(self.batch_size, int(v)) for v in bucket_key.split(',')]
             if len(shapes) == 4:
                 tc_array = np.zeros(shapes[0]+(self.char_dict_dim,))
@@ -191,6 +204,8 @@ def __iter__(self):
             else:
                 tw_array = np.zeros(shapes[0]+(self.word_dict_dim,))
                 cw_array = np.zeros(shapes[1]+(self.word_dict_dim,))
+            #print '*'*20
+            #print shapes,bucket_key,tw_array.shape, cw_array.shape

             label = np.zeros((self.batch_size, len(self.topic_encode)+1))
             for i,ind in enumerate(inds):
@@ -209,33 +224,36 @@ def __iter__(self):
                 cc = tc
                 cw = tw

-                data_name = []
-                data = []
                 if self.embed_mode %2 == 0:
                     for j, v in enumerate(tc.split(',')):
                         tc_array[i,j] = self.char_dict[v]
                     for j, v in enumerate(cc.split(',')):
                         cc_array[i,j] = self.char_dict[v]
-                    data_name += ['tc_array', 'cc_array']
-                    data += [mx.nd.array(tc_array), mx.nd.array(cc_array)]
                 if self.embed_mode > 0:
                     for j, v in enumerate(tw.split(',')):
                         tw_array[i,j] = self.word_dict[v]
                     for j, v in enumerate(cw.split(',')):
                         cw_array[i,j] = self.word_dict[v]
-                    data_name += ['tw_array', 'cw_array']
-                    data += [mx.nd.array(tw_array), mx.nd.array(cw_array)]

                 top = self.question_topic[ind].split()[1].split(',')
                 for t in top:
                     label[i,self.topic_encode[t]] = 1
+            data_name = []
+            data = []
+            if self.embed_mode %2 == 0:
+                data_name += ['tc_array', 'cc_array']
+                data += [mx.nd.array(tc_array), mx.nd.array(cc_array)]
+            if self.embed_mode > 0:
+                data_name += ['tw_array', 'cw_array']
+                data += [mx.nd.array(tw_array), mx.nd.array(cw_array)]
             label = [mx.nd.array(label)]
             label_name = ['label']
+            #print bucket_key, data
             yield SimpleBatch(data_name, data, label_name, label, bucket_key)
         raise StopIteration

 if __name__ == '__main__':
     ziter = zhihu_iter('tidy_question_train_set.txt','tidy_question_topic_train_set.txt',embed_mode=1)
     #ziter.reset()
     for i in ziter:
-        print i
+        print i.provide_data
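
The substance of the iter.py change: the iterator now derives `provide_data` from `max_bucket_key`, so a `BucketingModule` can bind buffers for the largest bucket before any batch is seen, and the `data_name`/`data` lists are built once per batch instead of once per sample. A minimal sketch of the resulting descriptors, assuming hypothetical values (embed_mode=1, a max bucket key of '30,120', batch size 4, 256-d word vectors; none of these numbers come from the repo):

batch_size, word_dict_dim = 4, 256
keys = '30,120'.split(',')   # illustrative max bucket key: title length 30, content length 120
provide_data = [('tw_array', (batch_size, int(keys[0]), word_dict_dim)),
                ('cw_array', (batch_size, int(keys[1]), word_dict_dim))]
# -> [('tw_array', (4, 30, 256)), ('cw_array', (4, 120, 256))]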

purge_data.py

+0 −1

@@ -1,4 +1,3 @@
-import mxnet as mx
 from read_embed import read_embed
 char_embed_path='./char_embedding.txt'
 word_embed_path='./word_embedding.txt'

sym.py

+65 −27

@@ -1,4 +1,6 @@
-
+import os,sys
+curr_path = os.path.abspath(os.path.dirname(__file__))
+sys.path.append(os.path.join(curr_path, "../mxnet/python"))
 import mxnet as mx
 import numpy as np

@@ -43,6 +45,9 @@ def fc_module(data, prefix, num_hidden=256):
     return relu_fc1

 def sym_gen_char(bucket_key):
+    num_layers = 1
+    num_class = 2000
+    num_hidden = 512
     key = bucket_key.split(',')
     tc_length = int(key[0])
     cc_length = int(key[1])
@@ -55,17 +60,55 @@ def sym_gen_char(bucket_key):
     cc_slices = list(mx.symbol.SliceChannel(data=cc_data, axis=1, num_outputs=cc_length, squeeze_axis=True, name='cc_slice'))
     tc_concat, _ = tc_cell.unroll(tc_length, inputs = tc_slices, merge_outputs=True, layout='TNC')
     cc_concat, _ = cc_cell.unroll(cc_length, inputs = cc_slices, merge_outputs=True, layout='TNC')
-    tc_concat = mx.sym.swapaxes(tc_concat, 0, 1)
-    cc_concat = mx.sym.swapaxes(cc_concat, 0, 1)
+    tc_concat = mx.sym.transpose(tc_concat, (1, 2, 0))
+    cc_concat = mx.sym.transpose(cc_concat, (1, 2, 0))
+    tc_concat = mx.sym.Pooling(tc_concat, kernel=(1,), global_pool = True, pool_type='max')
+    cc_concat = mx.sym.Pooling(cc_concat, kernel=(1,), global_pool = True, pool_type='max')
+    feature = mx.sym.Concat(*[tc_concat, cc_concat], name= 'concat')
+    feature = mx.sym.Dropout(feature, p=0.5)
+    feature = fc_module(feature, 'feature', num_hidden=2000)
+    loss = mx.sym.LogisticRegressionOutput(feature, label=label, name='regression')
+    return loss


 def sym_gen_word(bucket_key):
+    num_layers = 1
+    num_class = 2000
+    num_hidden = 512
+    key = bucket_key.split(',')
+    tw_length = int(key[0])
+    cw_length = int(key[1])
+    tw_data = mx.sym.Variable('tw_array')
+    cw_data = mx.sym.Variable('cw_array')
+    label = mx.sym.Variable('label')
+    tw_cell = mx.rnn.FusedRNNCell(num_hidden, num_layers=num_layers, bidirectional=True, mode='lstm', prefix ='tw_')
+    cw_cell = mx.rnn.FusedRNNCell(num_hidden, num_layers=num_layers, bidirectional=True, mode='lstm', prefix ='cw_')
+    tw_slices = list(mx.symbol.SliceChannel(data=tw_data, axis=1, num_outputs=tw_length, squeeze_axis=True, name='tw_slice'))
+    cw_slices = list(mx.symbol.SliceChannel(data=cw_data, axis=1, num_outputs=cw_length, squeeze_axis=True, name='cw_slice'))
+    tw_concat, _ = tw_cell.unroll(tw_length, inputs = tw_slices, merge_outputs=True, layout='TNC')
+    cw_concat, _ = cw_cell.unroll(cw_length, inputs = cw_slices, merge_outputs=True, layout='TNC')
+    tw_concat = mx.sym.transpose(tw_concat, (1, 2, 0))
+    cw_concat = mx.sym.transpose(cw_concat, (1, 2, 0))
+    tw_concat = mx.sym.Pooling(tw_concat, kernel=(1,), global_pool = True, pool_type='max')
+    cw_concat = mx.sym.Pooling(cw_concat, kernel=(1,), global_pool = True, pool_type='max')
+    feature = mx.sym.Concat(*[tw_concat, cw_concat], name= 'concat')
+    feature = mx.sym.Dropout(feature, p=0.5)
+    feature = fc_module(feature, 'feature', num_hidden=2000)
+    loss = mx.sym.LogisticRegressionOutput(feature, label=label, name='regression')
+    data_name = ['tw_array', 'cw_array']
+    label_name = ['label']
+    return loss, data_name, label_name


 def sym_gen_both(bucket_key):
-
-
+    num_layers = 1
+    num_class = 2000
+    num_hidden = 512
     key = bucket_key.split(',')
+    tc_length = int(key[0])
+    cc_length = int(key[1])
+    tw_length = int(key[2])
+    cw_length = int(key[3])
     tc_data = mx.sym.Variable('tc_array')
     cc_data = mx.sym.Variable('cc_array')
     tw_data = mx.sym.Variable('tw_array')
@@ -83,34 +126,29 @@ def sym_gen_both(bucket_key):
     cc_concat, _ = cc_cell.unroll(cc_length, inputs = cc_slices, merge_outputs=True, layout='TNC')
     tw_concat, _ = tw_cell.unroll(tw_length, inputs = tw_slices, merge_outputs=True, layout='TNC')
     cw_concat, _ = cw_cell.unroll(cw_length, inputs = cw_slices, merge_outputs=True, layout='TNC')
-    tc_concat = mx.sym.swapaxes(tc_concat, 0, 1)
-    cc_concat = mx.sym.swapaxes(cc_concat, 0, 1)
-    tw_concat = mx.sym.swapaxes(tw_concat, 0, 1)
-    cw_concat = mx.sym.swapaxes(cw_concat, 0, 1)
-    #ch_outputs = mx.sym.Concat(*[tc_concat, cc_concat])
-    #wd_outputs = mx.sym.Concat(*[tw_concat, cw_concat])
-    #title_outputs= mx.sym.Concat(*[tc_concat, tw_concat])
-    #content_outputs= mx.sym.Concat(*[cc_concat, cw_concat])
-    #ch_outputs = fc_module(ch_outputs, 'ch_', num_hidden = 2000)
-    #wd_outputs = fc_module(wd_outputs, 'wd_', num_hidden = 2000)
-    #title_outputs = fc_module(title_outputs, 'title_', num_hidden = 2000)
-    #content_outputs = fc_module(content_outputs, 'content_', num_hidden = 2000)
-    #feature = mx.sym.Concat(*[ch_outputs, wd_outputs, title_outputs, content_outputs])
-    feature = mx.sym.Concat(*[tc_concat, cc_concat, tw_concat, cw_concat])
-    feature = fc_module(feature, 'feature', num_hidden=4000)
-    feature = mx.sym.FullyConnected(data=feature, num_hidden=num_class, name='fc1')
-    loss = mx.sym.LogisticRegressionOutput(feature, label, name='regression')
+    tc_concat = mx.sym.transpose(tc_concat, (1, 2, 0))
+    cc_concat = mx.sym.transpose(cc_concat, (1, 2, 0))
+    tw_concat = mx.sym.transpose(tw_concat, (1, 2, 0))
+    cw_concat = mx.sym.transpose(cw_concat, (1, 2, 0))
+    tc_concat = mx.sym.Pooling(tc_concat, kernel=(1,), global_pool = True, pool_type='max')
+    cc_concat = mx.sym.Pooling(cc_concat, kernel=(1,), global_pool = True, pool_type='max')
+    tw_concat = mx.sym.Pooling(tw_concat, kernel=(1,), global_pool = True, pool_type='max')
+    cw_concat = mx.sym.Pooling(cw_concat, kernel=(1,), global_pool = True, pool_type='max')
+    feature = mx.sym.Concat(*[tc_concat, cc_concat, tw_concat, cw_concat], name= 'concat')
+    feature = mx.sym.Dropout(feature, p=0.5)
+    feature = fc_module(feature, 'feature', num_hidden=2000)
+    loss = mx.sym.LogisticRegressionOutput(feature, label=label, name='regression')
     return loss

 if __name__ == '__main__':
-    sym = sym_gen(100,100, 100, 100)
+    sym = sym_gen_both('100,33,11,21')
     batch_size = 32
     dim = 256
     length = 100
-    shapes = sym.infer_shape_partial(tc_array=(batch_size,length,dim),
-                                     cc_array=(batch_size,length,dim),
-                                     tw_array=(batch_size,length,dim),
-                                     cw_array=(batch_size,length,dim),
+    shapes = sym.infer_shape_partial(tc_array=(batch_size,100,dim),
+                                     cc_array=(batch_size,33,dim),
+                                     tw_array=(batch_size,11,dim),
+                                     cw_array=(batch_size,21,dim),
                                      label=(batch_size,2000))
     names = sym.list_arguments()
     for name, shape in zip(names, shapes[0]):
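
The recurring pattern in the new symbol code: `unroll(..., layout='TNC', merge_outputs=True)` yields a (time, batch, channel) tensor, `transpose(..., (1, 2, 0))` moves time to the last axis, and global max pooling then collapses it, so every bucket length maps to a fixed-size feature before the shared FC layer. A NumPy sketch of the shape bookkeeping (values are illustrative; 1024 = 2 × num_hidden for the bidirectional LSTM):

import numpy as np
T, N, C = 30, 4, 1024                  # time steps, batch, channels (illustrative)
outputs = np.random.rand(T, N, C)      # what unroll(..., layout='TNC') produces
nct = outputs.transpose(1, 2, 0)       # (N, C, T): time is now the pooled axis
pooled = nct.max(axis=2)               # global max pool over time
assert pooled.shape == (N, C)          # fixed size regardless of T (MXNet's Pooling keeps a trailing length-1 axis)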

train.py

+69 −0

@@ -0,0 +1,69 @@
+import os,sys
+curr_path = os.path.abspath(os.path.dirname(__file__))
+sys.path.append(os.path.join(curr_path, "../mxnet/python"))
+import numpy as np
+import mxnet as mx
+from iter import zhihu_iter
+batch_size =4
+ziter = zhihu_iter('tiny_train.txt','tiny_topic.txt',batch_size=batch_size,embed_mode=1)
+
+
+num_layers = 1
+num_class = 2000
+num_hidden = 512
+tw_cell = mx.rnn.FusedRNNCell(num_hidden, num_layers=num_layers, bidirectional=True, mode='lstm', prefix ='tw_')
+cw_cell = mx.rnn.FusedRNNCell(num_hidden, num_layers=num_layers, bidirectional=True, mode='lstm', prefix ='cw_')
+
+def fc_module(data, prefix, num_hidden=256):
+    with mx.name.Prefix(prefix):
+        fc1 = mx.sym.FullyConnected(data=data, num_hidden=num_hidden, name='fc1')
+        relu_fc1 = mx.sym.Activation(data=fc1, act_type='relu', name='relu_fc1')
+    return relu_fc1
+
+data_name = [i[0] for i in ziter.provide_data]
+label_name = [i[0] for i in ziter.provide_label]
+def sym_gen_word(bucket_key):
+    key = bucket_key.split(',')
+    tw_length = int(key[0])
+    cw_length = int(key[1])
+    tw_data = mx.sym.Variable('tw_array')
+    cw_data = mx.sym.Variable('cw_array')
+    label = mx.sym.Variable('label')
+    tw_slices = list(mx.symbol.SliceChannel(data=tw_data, axis=1, num_outputs=tw_length, squeeze_axis=True, name='tw_slice'))
+    cw_slices = list(mx.symbol.SliceChannel(data=cw_data, axis=1, num_outputs=cw_length, squeeze_axis=True, name='cw_slice'))
+    tw_concat, _ = tw_cell.unroll(tw_length, inputs = tw_slices, merge_outputs=True, layout='TNC')
+    cw_concat, _ = cw_cell.unroll(cw_length, inputs = cw_slices, merge_outputs=True, layout='TNC')
+    tw_concat = mx.sym.transpose(tw_concat, (1, 2, 0))
+    cw_concat = mx.sym.transpose(cw_concat, (1, 2, 0))
+    tw_concat = mx.sym.Pooling(tw_concat, kernel=(1,), global_pool = True, pool_type='max')
+    cw_concat = mx.sym.Pooling(cw_concat, kernel=(1,), global_pool = True, pool_type='max')
+    feature = mx.sym.Concat(*[tw_concat, cw_concat], name= 'concat')
+    feature = mx.sym.Dropout(feature, p=0.5)
+    feature = fc_module(feature, 'feature', num_hidden=2000)
+    loss = mx.sym.LogisticRegressionOutput(feature, label=label, name='regression')
+    return loss, data_name, label_name
+
+#mod = mx.module.BucketingModule(sym_gen_word, default_bucket_key=ziter.max_bucket_key,context=mx.gpu(1),data_names=data_name, label_names=label_name)
+mod = mx.module.BucketingModule(sym_gen_word, default_bucket_key=ziter.max_bucket_key,context=mx.context.gpu(1))
+import logging
+head = '%(asctime)-15s %(message)s'
+logging.basicConfig(level=logging.DEBUG, format=head)
+prefix='model/textline'
+learning_rate = 0.01
+optimizer_params={'learning_rate': learning_rate,
+                  'clip_gradient': 10 }
+monitor=mx.mon.Monitor(200, pattern='.*')
+
+
+
+num_epoch = 10
+print 'fit begin'
+mod.fit(train_data=ziter, eval_data=ziter,
+        optimizer='adadelta',
+        optimizer_params = optimizer_params,
+        eval_metric = mx.metric.MSE(),
+        num_epoch=num_epoch,
+        initializer=mx.init.Xavier(factor_type="in", magnitude=2.34),
+        batch_end_callback=mx.callback.Speedometer(batch_size, 50),
+        epoch_end_callback = mx.rnn.do_rnn_checkpoint([tw_cell, cw_cell], prefix, 1))
+
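
train.py wires the word-branch symbol generator to a BucketingModule: `sym_gen_word` returns the (symbol, data_names, label_names) triple the module expects for each bucket key, and `do_rnn_checkpoint` unpacks the fused cuDNN LSTM weights at each epoch end so the saved parameters are portable. A sketch of restoring such a checkpoint with the matching helper (the epoch number is illustrative):

# load the symbol and unpacked parameters saved by do_rnn_checkpoint
sym, arg_params, aux_params = mx.rnn.load_rnn_checkpoint([tw_cell, cw_cell], prefix, 1)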
