2
2
3
3
"""
4
4
@author: Songgx
5
- @file: 1304_mtt_multi-layer_LSTM_OOM .py
5
+ @file: 1309_1_mtt_multi-layer_BNLSTM_truncated_v2 .py
6
6
@time: 2017/2/27 21:10
7
7
8
8
https://medium.com/@erikhallstrm/using-the-dynamicrnn-api-in-tensorflow-7237aba7f7ea#.qduigey12
15
15
import bnlstm
16
16
17
17
18
+ training_num = 14951
19
+ valdation_num = 1825
20
+ test_num = 4332
21
+
18
22
x_height = 96
19
23
x_width = 1366
20
24
# 总共的tag数
21
25
n_tags = 50
22
26
23
27
learning_rate = 1e-5
24
- training_iterations = 150 * 64 * 2000 # 150 * 64 * 2000 == 19,200,000 iterations, about 150 epochs
25
- display_step = 1000
28
+ training_iterations = 200000 # 00000 # 200 * 64 * 2000 == iterations, about 200 epochs
29
+ display_step = 500
26
30
state_size = 512
27
31
batch_size = 64
28
32
num_layers = 5
29
- dropout = 0.75
33
+ dropout = 0.7
30
34
31
35
# truncated_backprop_length
32
36
truncated_backprop_length = 1366
@@ -87,6 +91,8 @@ def get_roc_auc_scores(tags, logits):
87
91
# placeholders
88
92
batchX_placeholder = tf .placeholder (tf .float32 , [batch_size , truncated_num , truncated_backprop_length ])
89
93
batchY_placeholder = tf .placeholder (tf .float32 , [batch_size , n_tags ])
94
+ output_keep_prob = tf .placeholder (tf .float32 )
95
+ phase_training = tf .placeholder (tf .bool )
90
96
91
97
init_state = tf .placeholder (tf .float32 , [num_layers , 2 , batch_size , state_size ])
92
98
@@ -109,7 +115,7 @@ def get_roc_auc_scores(tags, logits):
109
115
# initial state with zeros
110
116
rnn_tuple_state_initial = tuple (np .zeros ((num_layers , 2 , batch_size , state_size ), dtype = np .float32 ))
111
117
112
- def RNN (X , weights , biases , rnn_tuple_state ):
118
+ def RNN (X , weights , biases , rnn_tuple_state , phase_training = np . array ( True ) ):
113
119
# Prepare data shape to match `rnn` function requirements
114
120
# Current data input shape: (batch_size, n_steps, n_input)
115
121
# Required shape: 'n_steps' tensors list of shape (batch_size, n_input)
@@ -120,8 +126,8 @@ def RNN(X, weights, biases, rnn_tuple_state):
120
126
# Split to get a list of 'n_steps' tensors of shape (batch_size, n_input)
121
127
X = tf .split (0 , truncated_num , X )
122
128
123
- cell = bnlstm .BNLSTMCell (state_size , tf . constant ( True ) )
124
- cell = tf .nn .rnn_cell .DropoutWrapper (cell , output_keep_prob = dropout )
129
+ cell = bnlstm .BNLSTMCell (state_size , phase_training )
130
+ cell = tf .nn .rnn_cell .DropoutWrapper (cell , 1.0 , dropout )
125
131
cell = tf .nn .rnn_cell .MultiRNNCell ([cell ] * num_layers , state_is_tuple = True )
126
132
127
133
# Forward passes
@@ -146,6 +152,7 @@ def RNN(X, weights, biases, rnn_tuple_state):
146
152
147
153
with tf .Session () as sess :
148
154
sess .run (tf .global_variables_initializer ())
155
+ saver = tf .train .Saver ()
149
156
# Start input enqueue threads.
150
157
coord = tf .train .Coordinator ()
151
158
threads = tf .train .start_queue_runners (sess = sess , coord = coord )
@@ -157,35 +164,41 @@ def RNN(X, weights, biases, rnn_tuple_state):
157
164
_ , loss_val , pred_ = sess .run ([train_step , mean_batch_loss , logits ],
158
165
feed_dict = {batchX_placeholder : audio_batch_vals_training ,
159
166
batchY_placeholder : label_batch_vals_training ,
167
+ output_keep_prob : dropout ,
168
+ phase_training : np .array (True )
160
169
})
161
170
if (iteration_idx + 1 ) % display_step == 0 :
162
- validation_iterations = 20
171
+ validation_iterations = 28 # valdation_num / batch_size = 28.5
163
172
cur_validation_acc = 0.
164
173
for _ in range (validation_iterations ):
165
174
audio_batch_validation_vals , label_batch_validation_vals = sess .run ([audio_batch_validation , label_batch_validation ])
166
175
167
176
logits_validation , loss_val_validation = sess .run ([logits , mean_batch_loss ],
168
177
feed_dict = {batchX_placeholder : audio_batch_validation_vals ,
169
- batchY_placeholder : label_batch_validation_vals ,# keep_prob: 1.0
178
+ batchY_placeholder : label_batch_validation_vals ,
179
+ output_keep_prob : 1.0 ,
180
+ phase_training :np .array (False )
170
181
})
171
182
validation_accuracy = get_roc_auc_scores (label_batch_validation_vals , logits_validation )
172
183
cur_validation_acc += validation_accuracy
173
184
174
185
cur_validation_acc /= validation_iterations
175
186
print ("iter %d, training loss: %f, validation accuracy: %f" % ((iteration_idx + 1 ), loss_val , cur_validation_acc ))
176
-
177
- print ("######### Training finished. #########" )
187
+ save_path = saver . save ( sess , "model/1309_1_mtt_multi-layer_BNLSTM_truncated_v2.ckpt" )
188
+ print ("######### Training finished && model saved . #########" )
178
189
179
190
# Test model
180
191
# batch_test --> reduce_mean --> final_test_accuracy
181
192
182
- test_iterations = 70
193
+ test_iterations = 67 # 4332/64=67.68
183
194
test_accuracy_final = 0.
184
195
for _ in range (test_iterations ):
185
196
audio_test_vals , label_test_vals = sess .run ([audio_batch_test , label_batch_test ])
186
197
logits_test , test_loss_val = sess .run ([logits , mean_batch_loss ],
187
198
feed_dict = {batchX_placeholder : audio_test_vals ,
188
- batchY_placeholder : label_test_vals , #keep_prob: 1.0
199
+ batchY_placeholder : label_test_vals ,
200
+ output_keep_prob : 1.0 ,
201
+ phase_training :np .array (False )
189
202
})
190
203
test_accuracy = get_roc_auc_scores (label_test_vals , logits_test )
191
204
test_accuracy_final += test_accuracy
0 commit comments