Skip to content

Commit d148d14

Browse files
author
GuangxiaoSong
committed
BNLSTM
1 parent f128dac commit d148d14

2 files changed

+84
-13
lines changed
+58
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# -*- coding:utf-8 -*-
2+
3+
"""
4+
@author: Songgx
5+
@file: 1102_0_mtt_mini_batch_input.py
6+
@time: 2017/1/15 10:33
7+
"""
8+
9+
from __future__ import print_function
10+
import os
11+
12+
MTT_mel_path = 'D:/music_data/mtt_mel'
none_top50_file_path = 'none_top_50_tags_file.txt'

# Folder identifiers (single characters) that make up the held-out splits;
# every other folder is counted as training data.
TEST_FOLDERS = ('d', 'e', 'f')
VALIDATION_FOLDERS = ('c',)
SPLIT_NAMES = ('training', 'validation', 'test')


def _split_for_folder(folder):
    """Return 'test', 'validation' or 'training' for a folder identifier."""
    if folder in TEST_FOLDERS:
        return 'test'
    if folder in VALIDATION_FOLDERS:
        return 'validation'
    return 'training'


def _count_files_per_split(mel_root):
    """Count the files below ``mel_root``, grouped by split.

    A file is assigned to a split by the last character of the name of the
    directory that contains it. Files lying directly in ``mel_root`` are not
    counted. Each file is counted exactly once, however deep the tree is.
    """
    counts = dict.fromkeys(SPLIT_NAMES, 0)
    for dirpath, _, filenames in os.walk(mel_root):
        if dirpath == mel_root:
            # Only files inside a sub-folder belong to a split.
            continue
        folder = os.path.basename(dirpath)[-1:]
        counts[_split_for_folder(folder)] += len(filenames)
    return counts


def _count_excluded_per_split(list_path):
    """Count the entries of the exclusion list at ``list_path``, per split.

    The first character of each line is taken as the folder identifier.
    Blank lines are skipped so they are not booked against the training set.
    """
    counts = dict.fromkeys(SPLIT_NAMES, 0)
    # The context manager closes the file even if reading fails.
    with open(list_path, 'r') as listing:
        for line in listing:
            if not line.strip():
                continue
            # Classify by the folder of *this* line (not by a leftover
            # variable from the directory walk).
            counts[_split_for_folder(line[0])] += 1
    return counts


def count_split_sizes(mel_root, excluded_list_path):
    """Return ``(training, validation, test)`` sizes after exclusions.

    Args:
        mel_root: directory whose sub-folders hold the data files.
        excluded_list_path: text file with one excluded entry per line; the
            first character of a line names the folder the entry lives in.

    Returns:
        A 3-tuple of ints: files found per split minus excluded entries
        per split.

    Raises:
        IOError/OSError: if ``excluded_list_path`` cannot be opened.
    """
    found = _count_files_per_split(mel_root)
    excluded = _count_excluded_per_split(excluded_list_path)
    return tuple(found[name] - excluded[name] for name in SPLIT_NAMES)


if __name__ == '__main__':
    training_num, validation_num, test_num = count_split_sizes(
        MTT_mel_path, none_top50_file_path)
    print(training_num, validation_num, test_num)

# Recorded output of an earlier run (training validation test):
#   all:    18706 1825 5329
#   top-50: 14951 1825 4332
# NOTE(review): that run predates the fix to the validation branch of the
# exclusion pass (it compared a stale path variable, so validation entries
# were subtracted from training instead). The training and validation figures
# of the top-50 row may therefore change -- re-run to refresh.
56+
57+
58+

1309_1_mtt_multi-layer_BNLSTM_truncated_v2.py

+26-13
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
"""
44
@author: Songgx
5-
@file: 1304_mtt_multi-layer_LSTM_OOM.py
5+
@file: 1309_1_mtt_multi-layer_BNLSTM_truncated_v2.py
66
@time: 2017/2/27 21:10
77
88
https://medium.com/@erikhallstrm/using-the-dynamicrnn-api-in-tensorflow-7237aba7f7ea#.qduigey12
@@ -15,18 +15,22 @@
1515
import bnlstm
1616

1717

18+
training_num = 14951
19+
valdation_num = 1825
20+
test_num = 4332
21+
1822
x_height = 96
1923
x_width = 1366
2024
# total number of tags
2125
n_tags = 50
2226

2327
learning_rate = 1e-5
24-
training_iterations = 150 * 64 * 2000 # 150 * 64 * 2000 == 19,200,000 iterations, about 150 epochs
25-
display_step = 1000
28+
training_iterations = 200000# 00000 # 200 * 64 * 2000 == iterations, about 200 epochs
29+
display_step = 500
2630
state_size = 512
2731
batch_size = 64
2832
num_layers = 5
29-
dropout = 0.75
33+
dropout=0.7
3034

3135
# truncated_backprop_length
3236
truncated_backprop_length = 1366
@@ -87,6 +91,8 @@ def get_roc_auc_scores(tags, logits):
8791
# placeholders
8892
batchX_placeholder = tf.placeholder(tf.float32, [batch_size, truncated_num, truncated_backprop_length])
8993
batchY_placeholder = tf.placeholder(tf.float32, [batch_size, n_tags])
94+
output_keep_prob=tf.placeholder(tf.float32)
95+
phase_training = tf.placeholder(tf.bool)
9096

9197
init_state = tf.placeholder(tf.float32, [num_layers, 2, batch_size, state_size])
9298

@@ -109,7 +115,7 @@ def get_roc_auc_scores(tags, logits):
109115
# initial state with zeros
110116
rnn_tuple_state_initial = tuple(np.zeros((num_layers, 2, batch_size, state_size), dtype=np.float32))
111117

112-
def RNN(X, weights, biases, rnn_tuple_state):
118+
def RNN(X, weights, biases, rnn_tuple_state, phase_training=np.array(True)):
113119
# Prepare data shape to match `rnn` function requirements
114120
# Current data input shape: (batch_size, n_steps, n_input)
115121
# Required shape: 'n_steps' tensors list of shape (batch_size, n_input)
@@ -120,8 +126,8 @@ def RNN(X, weights, biases, rnn_tuple_state):
120126
# Split to get a list of 'n_steps' tensors of shape (batch_size, n_input)
121127
X = tf.split(0, truncated_num, X)
122128

123-
cell = bnlstm.BNLSTMCell(state_size, tf.constant(True))
124-
cell = tf.nn.rnn_cell.DropoutWrapper(cell, output_keep_prob=dropout)
129+
cell = bnlstm.BNLSTMCell(state_size, phase_training)
130+
cell = tf.nn.rnn_cell.DropoutWrapper(cell, 1.0, dropout)
125131
cell = tf.nn.rnn_cell.MultiRNNCell([cell] * num_layers, state_is_tuple=True)
126132

127133
# Forward passes
@@ -146,6 +152,7 @@ def RNN(X, weights, biases, rnn_tuple_state):
146152

147153
with tf.Session() as sess:
148154
sess.run(tf.global_variables_initializer())
155+
saver = tf.train.Saver()
149156
# Start input enqueue threads.
150157
coord = tf.train.Coordinator()
151158
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
@@ -157,35 +164,41 @@ def RNN(X, weights, biases, rnn_tuple_state):
157164
_, loss_val, pred_ = sess.run([train_step, mean_batch_loss, logits],
158165
feed_dict={batchX_placeholder: audio_batch_vals_training,
159166
batchY_placeholder: label_batch_vals_training,
167+
output_keep_prob: dropout,
168+
phase_training: np.array(True)
160169
})
161170
if (iteration_idx + 1) % display_step == 0:
162-
validation_iterations = 20
171+
validation_iterations = 28 # valdation_num / batch_size = 28.5
163172
cur_validation_acc = 0.
164173
for _ in range(validation_iterations):
165174
audio_batch_validation_vals, label_batch_validation_vals = sess.run([audio_batch_validation, label_batch_validation])
166175

167176
logits_validation, loss_val_validation = sess.run([logits, mean_batch_loss],
168177
feed_dict={batchX_placeholder: audio_batch_validation_vals,
169-
batchY_placeholder: label_batch_validation_vals,# keep_prob: 1.0
178+
batchY_placeholder: label_batch_validation_vals,
179+
output_keep_prob: 1.0,
180+
phase_training:np.array(False)
170181
})
171182
validation_accuracy = get_roc_auc_scores(label_batch_validation_vals, logits_validation)
172183
cur_validation_acc += validation_accuracy
173184

174185
cur_validation_acc /= validation_iterations
175186
print("iter %d, training loss: %f, validation accuracy: %f" % ((iteration_idx + 1), loss_val, cur_validation_acc))
176-
177-
print("######### Training finished. #########")
187+
save_path = saver.save(sess, "model/1309_1_mtt_multi-layer_BNLSTM_truncated_v2.ckpt")
188+
print("######### Training finished && model saved. #########")
178189

179190
# Test model
180191
# batch_test --> reduce_mean --> final_test_accuracy
181192

182-
test_iterations = 70
193+
test_iterations = 67 # 4332/64=67.68
183194
test_accuracy_final = 0.
184195
for _ in range(test_iterations):
185196
audio_test_vals, label_test_vals = sess.run([audio_batch_test, label_batch_test])
186197
logits_test, test_loss_val = sess.run([logits, mean_batch_loss],
187198
feed_dict={batchX_placeholder: audio_test_vals,
188-
batchY_placeholder: label_test_vals, #keep_prob: 1.0
199+
batchY_placeholder: label_test_vals,
200+
output_keep_prob: 1.0,
201+
phase_training:np.array(False)
189202
})
190203
test_accuracy = get_roc_auc_scores(label_test_vals, logits_test)
191204
test_accuracy_final += test_accuracy

0 commit comments

Comments
 (0)