1
+ # -*- coding:utf-8 -*-
2
+
3
+ """
4
+ @author: Songgx
5
+ @file: 0504_tf_full_connect_NN.py
6
+ @time: 11/30/16 3:44 PM
7
+ """
8
+
9
+ from __future__ import print_function
10
+
11
+ import numpy as np
12
+ import tensorflow as tf
13
+
14
+
15
def csv_file_line_number(fname):
    """Return the number of lines in the text file *fname*."""
    with open(fname, "r") as f:
        return sum(1 for _ in f)
21
+
22
+
23
def count_column_num(fname, field_delim):
    """Return the number of *field_delim*-separated fields in the first line of *fname*."""
    with open(fname) as f:
        first_line = f.readline()
    # the last column is the class number --> -1
    return len(first_line.split(field_delim))
28
+
29
+
30
def dense_to_one_hot(labels_dense, num_classes=10):
    """Convert class labels from scalars to one-hot vectors.

    *labels_dense* is an integer array whose first axis indexes examples;
    returns a float array of shape (num_examples, num_classes).
    """
    num_labels = labels_dense.shape[0]
    labels_one_hot = np.zeros((num_labels, num_classes))
    # Fancy indexing: set column `label` to 1 in each row.
    labels_one_hot[np.arange(num_labels), labels_dense.ravel()] = 1
    return labels_one_hot
37
+
38
+
39
def read_and_decode(filename):
    """Build graph ops that decode one (features, label) example from a TFRecord file.

    Returns a pair of tensors: the float32 feature vector and the int32 label.
    Uses the TF 0.x queue-runner input pipeline, so the returned tensors only
    produce values while queue runners are active in a session.
    """
    filename_queue = tf.train.string_input_producer([filename])

    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    # NOTE(review): `n_input` is a module-level global assigned later in this
    # file; it must be set before this function is first called.
    example = tf.parse_single_example(
        serialized_example,
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            # We know the length of both fields. If not the
            # tf.VarLenFeature could be used
            'features': tf.FixedLenFeature([n_input], tf.float32),
        })

    feature_vec = tf.cast(example['features'], tf.float32)
    label = tf.cast(example['label'], tf.int32)

    return feature_vec, label
56
+
57
+
58
# Training parameters
learning_rate = 0.001
training_epochs = 10000
display_step = 50   # print the loss every `display_step` epochs
num_threads = 8

# Data locations. The CSV copies are only used to derive sizes/column counts;
# the actual training data is read from the TFRecord files.
training_csv_file_path = "data/tvtsets/training_scat_data.txt"
training_file_path = "data/tvtsets/training_scat_data.tfrecords"
test_csv_file_path = "data/tvtsets/test_scat_data.txt"
test_file_path = "data/tvtsets/test_scat_data.tfrecords"
batch_size = 20
column_num = count_column_num(training_csv_file_path, " ")
# file_length = file_len(csv_file_path)

# Network parameters
n_steps = 28    # timesteps
n_hidden = 128  # hidden layer num of features

# The last CSV column is the class label, so the remaining columns are inputs.
# (Dead placeholder assignments n_input = 28 / a duplicate n_classes = 10 that
# were immediately overwritten have been removed; final values are unchanged.)
n_input = column_num - 1
n_classes = 10  # total classes (0-9 digits)
81
+
82
# tf Graph input
# x holds one batch of sequences: (batch_size, n_steps, n_input).
x = tf.placeholder(tf.float32, [batch_size, n_steps, n_input])
# y holds the integer class id of each example in the batch.
y = tf.placeholder(tf.int32, [batch_size,])
# NOTE(review): the training loop feeds `audio_batch_vals`, which appears to be
# 2-D (batch_size, n_input) — that does not match x's 3-D shape; confirm the
# upstream data layout.

# Define weights
# Only the output projection is a variable here; the LSTM cell creates its own
# internal weights when the graph is built.
weights = {
    'out': tf.Variable(tf.random_normal([n_hidden, n_classes]))
}
biases = {
    'out': tf.Variable(tf.random_normal([n_classes]))
}
94
+
95
+
96
# Create model
def LSTM(x, weights, biases):
    """Single-layer LSTM classifier.

    Args:
        x: float32 tensor of shape (batch_size, n_steps, n_input).
        weights: dict with an 'out' (n_hidden, n_classes) projection matrix.
        biases: dict with an 'out' (n_classes,) bias vector.

    Returns:
        Unscaled logits tensor of shape (batch_size, n_classes) computed from
        the LSTM output at the last timestep.
    """
    # Prepare data shape to match `rnn` function requirements
    # Current data input shape: (batch_size, n_steps, n_input)
    # Required shape: 'n_steps' tensors list of shape (batch_size, n_input)

    # Permuting batch_size and n_steps
    x = tf.transpose(x, [1, 0, 2])
    # Reshaping to (n_steps*batch_size, n_input)
    x = tf.reshape(x, [-1, n_input])
    # Split to get a list of 'n_steps' tensors of shape (batch_size, n_input)
    x = tf.split(0, n_steps, x)

    # Define a lstm cell with tensorflow.
    # Fix: `rnn_cell` and `rnn` were referenced without ever being imported
    # (NameError at call time). Use their tf.nn.* locations from the same
    # pre-1.0 TF API generation as the tf.split(dim, num, value) call above.
    lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0)

    # Get lstm cell output
    outputs, states = tf.nn.rnn(lstm_cell, x, dtype=tf.float32)

    # Linear activation, using rnn inner loop last output
    return tf.matmul(outputs[-1], weights['out']) + biases['out']
118
+
119
+
120
+
121
# Construct model
pred = LSTM(x, weights, biases)

# Define loss and optimizer & correct_prediction
# Positional (logits, labels) order matches the pre-1.0 TF signature used by
# the rest of this file.
cost = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(pred, y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)
127
+
128
# Build the input pipelines (queue-runner based).
audio, label = read_and_decode(training_file_path)
audio_test, label_test = read_and_decode(test_file_path)


# shuffle_batch randomly shuffles the queued examples into batches
audio_batch, label_batch = tf.train.shuffle_batch([audio, label],
                                                  batch_size=batch_size, capacity=2000,
                                                  min_after_dequeue=1000)

# Initializing the variables
init = tf.global_variables_initializer()
141
+
142
+ # Launch the graph
143
+ with tf .Session () as sess :
144
+ sess .run (init )
145
+
146
+ # Start input enqueue threads.
147
+ coord = tf .train .Coordinator ()
148
+ threads = tf .train .start_queue_runners (sess = sess , coord = coord )
149
+ for epoch in range (10000 ):
150
+ # pass it in through the feed_dict
151
+ audio_batch_vals , label_batch_vals = sess .run ([audio_batch , label_batch ])
152
+
153
+ _ , loss_val = sess .run ([optimizer , cost ], feed_dict = {x :audio_batch_vals , y :label_batch_vals })
154
+ if (epoch + 1 ) % display_step == 0 :
155
+ print ("Epoch:" , '%06d' % (epoch + 1 ), "cost=" , "{:.9f}" .format (loss_val ))
156
+
157
+ print ("Training finished." )
158
+ coord .request_stop ()
159
+ coord .join (threads )
160
+
161
+ # Test model
162
+ # Calculate accuracy
163
+ test_example_number = csv_file_line_number (test_csv_file_path )
164
+ correct_num = 0
165
+ for _ in range (test_example_number ):
166
+ audio_test_val , label_test_val = sess .run ([audio_test , label_test ])
167
+ audio_test_val_vector = np .array ([audio_test_val ])
168
+ test_pred = multilayer_perceptron (audio_test_val_vector , weights , biases )
169
+ '''
170
+ print (sess.run([test_pred]))
171
+
172
+ [array([[ 11.13519478, 12.56501865, 14.68154907, 2.02798128,
173
+ -5.89219952, -0.76298785, 0.46614531, 5.27717066,
174
+ 7.54774714, 7.12729597]], dtype=float32)]
175
+ '''
176
+ pred_class_index = sess .run (tf .argmax (test_pred , 1 ))
177
+ '''
178
+ print (sess.run(tf.argmax(test_pred, 1)))
179
+
180
+ [8]
181
+ 0
182
+ [5]
183
+ 0
184
+ [9]
185
+ 0
186
+ [8]
187
+ 0
188
+ ....
189
+
190
+ [9]
191
+ 9
192
+ [9]
193
+ 9
194
+ [4]
195
+ 9
196
+ [1]
197
+ 9
198
+ '''
199
+
200
+ if label_test_val == pred_class_index [0 ]:
201
+ correct_num += 1
202
+ print ("%i / %i is correct." % (correct_num , test_example_number ))
203
+ print ("Accuracy is %f ." % (float (correct_num ) / test_example_number ))
204
+ sess .close ()