Skip to content

Commit f7578ef

Browse files
author
GuangxiaoSong
committed
wavelet & its tfrecord
1 parent 0d442df commit f7578ef

8 files changed

+473
-3
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,7 @@ data/converted/
220220
data/merge/
221221
data/raw/
222222
data/tvtsets/
223+
data/wavelets/
223224
.idea/
224225

225226

0504_tf_full_connect_NN.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ def read_and_decode(filename):
7171
# Network Parameters
7272
n_hidden_1 = 1024 # 1st layer number of features
7373
n_hidden_2 = 1024 # 2nd layer number of features
74+
n_hidden_3 = 1024 # 2nd layer number of features
7475

7576
# h1:512 h2:512, acc is about:0.4
7677

@@ -91,20 +92,25 @@ def multilayer_perceptron(x, weights, biases):
9192
# Hidden layer with sigmoid activation
9293
layer_2 = tf.add(tf.matmul(layer_1, weights['h2']), biases['b2'])
9394
layer_2 = tf.nn.sigmoid(layer_2)
95+
# Hidden layer with sigmoid activation
96+
layer_3 = tf.add(tf.matmul(layer_2, weights['h3']), biases['b3'])
97+
layer_3 = tf.nn.sigmoid(layer_3)
9498
# Output layer with linear activation
95-
out_layer = tf.matmul(layer_2, weights['out']) + biases['out']
99+
out_layer = tf.matmul(layer_3, weights['out']) + biases['out']
96100
return out_layer
97101

98102

99103
# Store layers weight & bias
100104
weights = {
101105
'h1': tf.Variable(tf.random_normal([n_input, n_hidden_1])),
102106
'h2': tf.Variable(tf.random_normal([n_hidden_1, n_hidden_2])),
103-
'out': tf.Variable(tf.random_normal([n_hidden_2, n_classes]))
107+
'h3': tf.Variable(tf.random_normal([n_hidden_2, n_hidden_3])),
108+
'out': tf.Variable(tf.random_normal([n_hidden_3, n_classes]))
104109
}
105110
biases = {
106111
'b1': tf.Variable(tf.random_normal([n_hidden_1])),
107112
'b2': tf.Variable(tf.random_normal([n_hidden_2])),
113+
'b3': tf.Variable(tf.random_normal([n_hidden_3])),
108114
'out': tf.Variable(tf.random_normal([n_classes]))
109115
}
110116

@@ -136,7 +142,7 @@ def multilayer_perceptron(x, weights, biases):
136142
# Start input enqueue threads.
137143
coord = tf.train.Coordinator()
138144
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
139-
for epoch in range(100):
145+
for epoch in range(10000):
140146
# pass it in through the feed_dict
141147
audio_batch_vals, label_batch_vals = sess.run([audio_batch, label_batch])
142148

0505_LSTM.py

+204
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
# -*- coding:utf-8 -*-
2+
3+
"""
4+
@author: Songgx
5+
@file: 0504_tf_full_connect_NN.py
6+
@time: 11/30/16 3:44 PM
7+
"""
8+
9+
from __future__ import print_function
10+
11+
import numpy as np
12+
import tensorflow as tf
13+
14+
15+
def csv_file_line_number(fname):
16+
with open(fname, "r") as f:
17+
num = 0
18+
for line in f:
19+
num += 1
20+
return num
21+
22+
23+
def count_column_num(fname, field_delim):
24+
with open(fname) as f:
25+
line = f.readline().split(field_delim)
26+
# the last column is the class number --> -1
27+
return len(line)
28+
29+
30+
def dense_to_one_hot(labels_dense, num_classes=10):
31+
"""Convert class labels from scalars to one-hot vectors."""
32+
num_labels = labels_dense.shape[0]
33+
index_offset = np.arange(num_labels) * num_classes
34+
labels_one_hot = np.zeros((num_labels, num_classes))
35+
labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1
36+
return labels_one_hot
37+
38+
39+
def read_and_decode(filename):
40+
filename_queue = tf.train.string_input_producer([filename])
41+
42+
reader = tf.TFRecordReader()
43+
_, serialized_example = reader.read(filename_queue)
44+
features = tf.parse_single_example(serialized_example,
45+
features={
46+
'label': tf.FixedLenFeature([], tf.int64),
47+
# We know the length of both fields. If not the
48+
# tf.VarLenFeature could be used
49+
'features': tf.FixedLenFeature([n_input], tf.float32),
50+
})
51+
52+
X = tf.cast(features['features'], tf.float32)
53+
y = tf.cast(features['label'], tf.int32)
54+
55+
return X, y
56+
57+
58+
# Parameters
59+
learning_rate = 0.001
60+
training_epochs = 10000
61+
display_step = 50
62+
num_threads = 8
63+
64+
training_csv_file_path = "data/tvtsets/training_scat_data.txt"
65+
training_file_path = "data/tvtsets/training_scat_data.tfrecords"
66+
test_csv_file_path = "data/tvtsets/test_scat_data.txt"
67+
test_file_path = "data/tvtsets/test_scat_data.tfrecords"
68+
batch_size = 20
69+
column_num = count_column_num(training_csv_file_path, " ")
70+
# file_length = file_len(csv_file_path)
71+
72+
# Network Parameters
73+
n_input = 28 # data input 8660 * 1 = 433 * 20
74+
n_steps = 28 # timesteps
75+
n_hidden = 128 # hidden layer num of features
76+
n_classes = 10 # MNIST total classes (0-9 digits)
77+
78+
79+
n_input = column_num - 1
80+
n_classes = 10 # total classes (0-9 digits)
81+
82+
# tf Graph input
83+
84+
x = tf.placeholder(tf.float32, [batch_size, n_steps, n_input])
85+
y = tf.placeholder(tf.int32, [batch_size,])
86+
87+
# Define weights
88+
weights = {
89+
'out': tf.Variable(tf.random_normal([n_hidden, n_classes]))
90+
}
91+
biases = {
92+
'out': tf.Variable(tf.random_normal([n_classes]))
93+
}
94+
95+
96+
# Create model
97+
def LSTM(x, weights, biases):
98+
99+
# Prepare data shape to match `rnn` function requirements
100+
# Current data input shape: (batch_size, n_steps, n_input)
101+
# Required shape: 'n_steps' tensors list of shape (batch_size, n_input)
102+
103+
# Permuting batch_size and n_steps
104+
x = tf.transpose(x, [1, 0, 2])
105+
# Reshaping to (n_steps*batch_size, n_input)
106+
x = tf.reshape(x, [-1, n_input])
107+
# Split to get a list of 'n_steps' tensors of shape (batch_size, n_input)
108+
x = tf.split(0, n_steps, x)
109+
110+
# Define a lstm cell with tensorflow
111+
lstm_cell = rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0)
112+
113+
# Get lstm cell output
114+
outputs, states = rnn.rnn(lstm_cell, x, dtype=tf.float32)
115+
116+
# Linear activation, using rnn inner loop last output
117+
return tf.matmul(outputs[-1], weights['out']) + biases['out']
118+
119+
120+
121+
# Construct model
122+
pred = LSTM(x, weights, biases)
123+
124+
# Define loss and optimizer & correct_prediction
125+
cost = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(pred, y))
126+
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)
127+
128+
# Launch the graph
129+
130+
audio, label = read_and_decode(training_file_path)
131+
audio_test, label_test = read_and_decode(test_file_path)
132+
133+
134+
#使用shuffle_batch可以随机打乱输入
135+
audio_batch, label_batch = tf.train.shuffle_batch([audio, label],
136+
batch_size=batch_size, capacity=2000,
137+
min_after_dequeue=1000)
138+
139+
# Initializing the variables
140+
init = tf.global_variables_initializer()
141+
142+
# Launch the graph
143+
with tf.Session() as sess:
144+
sess.run(init)
145+
146+
# Start input enqueue threads.
147+
coord = tf.train.Coordinator()
148+
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
149+
for epoch in range(10000):
150+
# pass it in through the feed_dict
151+
audio_batch_vals, label_batch_vals = sess.run([audio_batch, label_batch])
152+
153+
_, loss_val = sess.run([optimizer, cost], feed_dict={x:audio_batch_vals, y:label_batch_vals})
154+
if (epoch + 1) % display_step == 0:
155+
print("Epoch:", '%06d' % (epoch + 1), "cost=", "{:.9f}".format(loss_val))
156+
157+
print("Training finished.")
158+
coord.request_stop()
159+
coord.join(threads)
160+
161+
# Test model
162+
# Calculate accuracy
163+
test_example_number = csv_file_line_number(test_csv_file_path)
164+
correct_num = 0
165+
for _ in range(test_example_number):
166+
audio_test_val, label_test_val = sess.run([audio_test, label_test])
167+
audio_test_val_vector = np.array([audio_test_val])
168+
test_pred = multilayer_perceptron(audio_test_val_vector, weights, biases)
169+
'''
170+
print (sess.run([test_pred]))
171+
172+
[array([[ 11.13519478, 12.56501865, 14.68154907, 2.02798128,
173+
-5.89219952, -0.76298785, 0.46614531, 5.27717066,
174+
7.54774714, 7.12729597]], dtype=float32)]
175+
'''
176+
pred_class_index = sess.run(tf.argmax(test_pred, 1))
177+
'''
178+
print (sess.run(tf.argmax(test_pred, 1)))
179+
180+
[8]
181+
0
182+
[5]
183+
0
184+
[9]
185+
0
186+
[8]
187+
0
188+
....
189+
190+
[9]
191+
9
192+
[9]
193+
9
194+
[4]
195+
9
196+
[1]
197+
9
198+
'''
199+
200+
if label_test_val == pred_class_index[0]:
201+
correct_num += 1
202+
print("%i / %i is correct." % (correct_num, test_example_number))
203+
print("Accuracy is %f ." % (float(correct_num) / test_example_number))
204+
sess.close()

0600_wavelets_input.py

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# -*- coding:utf-8 -*-
2+
3+
"""
4+
@author: Songgx
5+
@file: 0600_wavelets_input.py
6+
@time: 2016/12/20 16:41
7+
"""
8+
9+
import numpy as np
10+
11+
cA = np.loadtxt("C:/Users/song/data/wavelets/classical.00000.wavelet_0_cA", dtype=np.float32, delimiter="\n")
12+
cD = np.loadtxt("C:/Users/song/data/wavelets/classical.00000.wavelet_0_cD", dtype=np.float32, delimiter="\n")
13+
14+
wavelets = np.vstack((cA, cD))
15+
16+
print (wavelets.shape)

data/0300_wavelet_hw.py

+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# -*- coding:utf-8 -*-
2+
3+
"""
4+
@author: Songgx
5+
@file: 0300_wavelet_hw.py
6+
@time: 2016/12/20 14:23
7+
"""
8+
9+
import os
10+
import sunau
11+
import numpy as np
12+
import pywt
13+
14+
file = os.path.abspath('classical.00000.au')
15+
music = sunau.open(file, 'r')
16+
17+
# 读取格式信息
18+
# (nchannels, sampwidth, framerate, nframes, comptype, compname)
19+
nchannels = music.getnchannels() # nchannels = 1
20+
sampwidth = music.getsampwidth()
21+
framerate = music.getframerate()
22+
nframes = music.getnframes()
23+
24+
# 读取波形数据
25+
str_data = music.readframes(nframes)
26+
music.close()
27+
28+
#将波形数据转换为数组
29+
wave_data = np.fromstring(str_data, dtype=np.float32)
30+
print (wave_data.shape)
31+
# (330897,)
32+
33+
# wavelet transform, db6
34+
# 单尺度低频系数cA, 单尺度高频系数cD
35+
cA, cD = pywt.dwt(wave_data, 'db6')
36+
37+
cA = np.divide(cA, pow(10, 30))
38+
print(cA.shape)
39+
print(cD.shape)
40+
# (165454,)
41+
# (165454,)
42+
print(cA)
43+
44+
# 32768 大约6s的子采样
45+
# 32768 * 5 = 163840

data/0301_wavelet_transform.py

+71
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# -*- coding:utf-8 -*-
2+
3+
"""
4+
@author: Songgx
5+
@file: 0301_wavelet_transform.py
6+
@time: 2016/12/20 15:15
7+
"""
8+
9+
# (330897,)
10+
# 32768 大约6s的子采样
11+
# 32768 * 5 = 163840
12+
13+
import os
14+
import sunau
15+
import numpy as np
16+
import pywt
17+
18+
19+
output_path = "wavelets/"
20+
21+
def wavelet_trans(au_file_path, size=32768):
22+
music = sunau.open(au_file_path, 'r')
23+
24+
if au_file_path.find("/") != -1:
25+
au_filename = au_file_path[(au_file_path.rfind("/")+1):]
26+
else:
27+
au_filename = au_file_path
28+
29+
nframes = music.getnframes()
30+
31+
# 读取波形数据
32+
str_data = music.readframes(nframes)
33+
music.close()
34+
35+
# 将波形数据转换为数组
36+
wave_data = np.fromstring(str_data, dtype=np.float32)
37+
38+
# wavelet transform, db6
39+
# 单尺度低频系数cA, 单尺度高频系数cD
40+
cA, cD = pywt.dwt(wave_data, 'db6')
41+
cA = np.divide(cA, pow(10, 30))
42+
cD = np.divide(cD, pow(10, 30))
43+
spilit_num = int((cA.shape[0]) / size)
44+
for i in range(spilit_num):
45+
cA_i = cA[size * i:size * (i+1)]
46+
cD_i = cD[size * i:size * (i+1)]
47+
spilit_sample_name = (au_filename + "_" + str(i)).replace(".au", ".wavelet")
48+
49+
cA_i_temp_path = spilit_sample_name+"_cA"
50+
cD_i_temp_path = spilit_sample_name + "_cD"
51+
np.savetxt(output_path + cA_i_temp_path, cA_i, fmt='%s', delimiter=' ')
52+
np.savetxt(output_path + cD_i_temp_path, cD_i, fmt='%s', delimiter=' ')
53+
54+
print("{} Done.".format(au_filename))
55+
56+
def transform_all_files(folder_path):
57+
for root, subdirs, files in os.walk(folder_path):
58+
subdirs.sort()
59+
# ['blues', 'classical', 'country', 'disco', 'hiphop', 'jazz', 'metal', 'pop', 'reggae', 'rock']
60+
# 分别对应[0-9]
61+
for subdir in subdirs:
62+
path = root + "/" + subdir
63+
for root1, subdirs1, files1 in os.walk(path):
64+
for file1 in files1:
65+
cur_path = root1 + "/" + file1
66+
wavelet_trans(cur_path)
67+
68+
69+
if __name__ == "__main__":
70+
# wavelet_trans('classical.00000.au')
71+
transform_all_files("D:/dh/DL/music/genres")

0 commit comments

Comments
 (0)