-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path0800_mel_spectrogram_LSTM.py
151 lines (119 loc) · 5.17 KB
/
0800_mel_spectrogram_LSTM.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
# -*- coding:utf-8 -*-
"""
@author: Songgx
@file: 0800_mel_spectrogram_LSTM.py
@time: 2017/1/7 16:21
"""
from __future__ import print_function
import tensorflow as tf
from tensorflow.python.ops import rnn, rnn_cell
def read_and_decode(filename):
    """Build the graph ops that read one (mel features, label) example.

    Args:
        filename: Path of the TFRecord file to read from.

    Returns:
        A ``(features, label)`` pair of tensors. ``features`` is the
        ``features_mel`` byte string decoded as float32 and reshaped to
        ``[x_height, x_width]``; ``label`` is the float32 ``label`` feature
        of length ``n_classes``.

    Note:
        Reads the module-level ``n_classes``, ``x_height`` and ``x_width``,
        so those must be defined before this function is called.
    """
    file_queue = tf.train.string_input_producer([filename])
    record_reader = tf.TFRecordReader()
    _, serialized = record_reader.read(file_queue)

    # Schema of a single serialized example.
    feature_spec = {
        'label': tf.FixedLenFeature([n_classes], tf.float32),
        'features_mel': tf.FixedLenFeature([], tf.string),
    }
    parsed = tf.parse_single_example(serialized, features=feature_spec)

    # The features are stored as raw bytes; decode, then restore the 2-D layout.
    mel_flat = tf.decode_raw(parsed['features_mel'], tf.float32)
    mel = tf.reshape(mel_flat, [x_height, x_width])
    label = tf.cast(parsed['label'], tf.float32)
    return mel, label
# ---- Training parameters ----
learning_rate = 0.001
training_iters = 160  # bound of the training loop (= training_size / batch_size)
training_size = 8000
test_size = 2000
batch_size = 50
display_step = 10  # report training loss/accuracy every `display_step` steps
n_classes = 10  # number of output classes (length of the label vector)

# Every example is reshaped to an x_height x x_width matrix.
x_height = 96
x_width = 1366

# ---- Network parameters ----
n_input = x_width  # features fed to the LSTM at each timestep
n_steps = x_height  # number of timesteps
n_hidden = 256  # LSTM hidden-state size

# ---- Graph inputs ----
x = tf.placeholder(tf.float32, [None, n_steps, n_input])
y = tf.placeholder(tf.float32, [None, n_classes])

# ---- Output-layer parameters (last LSTM output -> class logits) ----
weights = dict(out=tf.Variable(tf.random_normal([n_hidden, n_classes])))
biases = dict(out=tf.Variable(tf.random_normal([n_classes])))
def LSTM(x, weights, biases):
    """Build a single-layer LSTM and project its last output to class logits.

    Args:
        x: Input tensor of shape (batch_size, n_steps, n_input).
        weights: Dict whose ``'out'`` entry is the matrix multiplied with the
            last LSTM output.
        biases: Dict whose ``'out'`` entry is added to that product.

    Returns:
        The un-normalised (linear) output tensor for the last timestep.

    Note:
        Reads the module-level ``n_input``, ``n_steps`` and ``n_hidden``.
    """
    # `rnn.rnn` wants a list of n_steps tensors, each (batch_size, n_input).
    # Go time-major first: (n_steps, batch_size, n_input).
    time_major = tf.transpose(x, [1, 0, 2])
    # Flatten to (n_steps * batch_size, n_input) ...
    flattened = tf.reshape(time_major, [-1, n_input])
    # ... and cut along axis 0 into n_steps equal pieces, one per timestep.
    step_inputs = tf.split(0, n_steps, flattened)

    cell = rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0)
    step_outputs, _ = rnn.rnn(cell, step_inputs, dtype=tf.float32)

    # Linear activation on the output of the final timestep only.
    last_output = step_outputs[-1]
    return tf.matmul(last_output, weights['out']) + biases['out']
# ---------------------------------------------------------------------------
# Graph wiring, training loop and evaluation.
#
# Builds the loss/optimizer/accuracy ops on top of the LSTM logits, creates
# shuffled input queues for the training and test TFRecord files, trains for
# `training_iters` optimizer steps, then averages the accuracy over
# `test_size // batch_size` test batches.
# ---------------------------------------------------------------------------
pred = LSTM(x, weights, biases)

# Define loss and optimizer
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(pred, y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)

# Evaluate model: fraction of examples whose arg-max prediction matches the label
correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# Training & test data
features, label = read_and_decode("data/merge/mel_data_training.tfrecords")
features_test, label_test = read_and_decode("data/merge/mel_data_test.tfrecords")

# shuffle_batch randomly shuffles the input examples.
audio_batch, label_batch = tf.train.shuffle_batch(
    [features, label],
    batch_size=batch_size, capacity=2000,
    min_after_dequeue=1000)
# NOTE(review): the test batches are also drawn through shuffle_batch, so the
# evaluation below presumably does not visit every test example exactly
# once -- confirm whether an unshuffled batch queue is wanted here.
audio_batch_test, label_batch_test = tf.train.shuffle_batch(
    [features_test, label_test],
    batch_size=batch_size, capacity=2000,
    min_after_dequeue=1000)

init = tf.global_variables_initializer()

# Launch the graph
with tf.Session() as sess:
    sess.run(init)

    # for TensorBoard
    summary_op = tf.merge_all_summaries()
    summary_writer = tf.train.SummaryWriter('model/', sess.graph)

    # Start input enqueue threads.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    try:
        # `training_iters` (160 == training_size / batch_size) is a number of
        # optimizer steps, so compare it with `step` directly. Comparing it
        # with `step * batch_size` stopped training after 3 steps and never
        # reached the first `display_step` report.
        step = 1
        while step <= training_iters:
            audio_batch_vals, label_batch_vals = sess.run([audio_batch, label_batch])
            batch_feed = {x: audio_batch_vals, y: label_batch_vals}

            # Run optimization op (backprop)
            sess.run(optimizer, feed_dict=batch_feed)

            if step % display_step == 0:
                # Fetch batch accuracy and loss in a single forward pass.
                acc, loss = sess.run([accuracy, cost], feed_dict=batch_feed)
                print("Iter {}, Minibatch Loss= {:.6f}, Training Accuracy= {:.5f}".format(
                    step * batch_size, loss, acc))
            step += 1
        print("Optimization Finished!")

        # Test model
        # batch_test --> reduce_mean --> final_test_accuracy
        test_epochs = test_size // batch_size
        test_accuracy_final = 0.
        for test_step in range(test_epochs):
            audio_test_vals, label_test_vals = sess.run([audio_batch_test, label_batch_test])
            test_accuracy = sess.run(accuracy, feed_dict={x: audio_test_vals, y: label_test_vals})
            test_accuracy_final += test_accuracy
            print("test epoch: %d, test accuracy: %f" % (test_step, test_accuracy))
        test_accuracy_final /= test_epochs
        print("test accuracy %f" % test_accuracy_final)
    finally:
        # Always stop the queue-runner threads and flush the event file, even
        # if training or evaluation raised. The session itself is closed by
        # the enclosing `with` block.
        coord.request_stop()
        coord.join(threads)
        summary_writer.close()