Skip to content

Commit 2b8703a

Browse files
author
GuangxiaoSong
committed
mini batch
1 parent 620a42b commit 2b8703a

5 files changed

+178
-9
lines changed

0503_1_tf_TFrecords_input.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,11 @@ def read_and_decode(filename):
2929

3030
return X, y
3131

32-
img, label = read_and_decode("data/merge/scat_data_test.tfrecords")
32+
img, label = read_and_decode("data/tvtsets/test_scat_data.tfrecords")
3333

3434
#使用shuffle_batch可以随机打乱输入
3535
img_batch, label_batch = tf.train.shuffle_batch([img, label],
36-
batch_size=2, capacity=2000,
36+
batch_size=20, capacity=2000,
3737
min_after_dequeue=1000)
3838
init = tf.global_variables_initializer()
3939

0503_2_tf_TFrecords_single_input.py

+63
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# -*- coding:utf-8 -*-
2+
3+
"""
4+
@author: Songgx
5+
@file: 0503_1_tf_TFrecords_input.py
6+
@time: 12/1/16 7:33 PM
7+
"""
8+
9+
from __future__ import print_function
10+
import tensorflow as tf
11+
12+
# https://indico.io/blog/tensorflow-data-inputs-part1-placeholders-protobufs-queues/
13+
14+
def read_and_decode_single_example(filename):
15+
# first construct a queue containing a list of filenames.
16+
# this lets a user split up there dataset in multiple files to keep
17+
# size down
18+
filename_queue = tf.train.string_input_producer([filename], num_epochs=None)
19+
# Unlike the TFRecordWriter, the TFRecordReader is symbolic
20+
reader = tf.TFRecordReader()
21+
# One can read a single serialized example from a filename
22+
# serialized_example is a Tensor of type string.
23+
_, serialized_example = reader.read(filename_queue)
24+
# The serialized example is converted back to actual values.
25+
# One needs to describe the format of the objects to be returned
26+
features = tf.parse_single_example(
27+
serialized_example,
28+
features={
29+
# We know the length of both fields. If not the
30+
# tf.VarLenFeature could be used
31+
'label': tf.FixedLenFeature([], tf.int64),
32+
'feature': tf.VarLenFeature(tf.float32)
33+
})
34+
# now return the converted data
35+
label = features['label']
36+
audio = features['feature']
37+
return label, audio
38+
39+
# returns symbolic label and audio
40+
label, audio = read_and_decode_single_example("data/tvtsets/test_scat_data.tfrecords")
41+
42+
sess = tf.Session()
43+
44+
# Required. See below for explanation
45+
init = tf.global_variables_initializer()
46+
sess.run(init)
47+
tf.train.start_queue_runners(sess=sess)
48+
49+
# grab examples back.
50+
# first example from file
51+
label_val_1, audio_val_1 = sess.run([label, audio])
52+
# second example from file
53+
label_val_2, audio_val_2 = sess.run([label, audio])
54+
55+
'''
56+
The fact that this works requires a fair bit of effort behind the scenes.
57+
First, it is important to remember that TensorFlow’s graphs contain state.
58+
It is this state that allows the TFRecordReader to remember the location of the tfrecord
59+
it’s reading and always return the next one. This is why for almost all TensorFlow work
60+
we need to initialize the graph. We can use the helper function tf.initialize_all_variables(),
61+
which constructs an op that initializes the state on the graph when you run it.
62+
63+
'''
+76
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# -*- coding:utf-8 -*-
2+
3+
"""
4+
@author: Songgx
5+
@file: 0503_1_tf_TFrecords_input.py
6+
@time: 12/1/16 7:33 PM
7+
"""
8+
9+
from __future__ import print_function
10+
import tensorflow as tf
11+
12+
# https://indico.io/blog/tensorflow-data-inputs-part1-placeholders-protobufs-queues/
13+
14+
def read_and_decode(filename):
15+
filename_queue = tf.train.string_input_producer([filename])
16+
17+
reader = tf.TFRecordReader()
18+
_, serialized_example = reader.read(filename_queue)
19+
features = tf.parse_single_example(serialized_example,
20+
features={
21+
'label': tf.FixedLenFeature([], tf.int64),
22+
# We know the length of both fields. If not the
23+
# tf.VarLenFeature could be used
24+
'features': tf.FixedLenFeature([8660], tf.float32),
25+
})
26+
27+
X = tf.cast(features['features'], tf.float32)
28+
y = tf.cast(features['label'], tf.int32)
29+
30+
return X, y
31+
32+
img, label = read_and_decode("data/tvtsets/test_scat_data.tfrecords")
33+
34+
#使用shuffle_batch可以随机打乱输入
35+
img_batch, label_batch = tf.train.shuffle_batch([img, label],
36+
batch_size=20, capacity=2000,
37+
min_after_dequeue=1000)
38+
init = tf.global_variables_initializer()
39+
40+
# simple model
41+
w = tf.get_variable("w1", [8660, 10])
42+
y_pred = tf.matmul(img_batch, w)
43+
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(y_pred, label_batch)
44+
45+
# for monitoring
46+
loss_mean = tf.reduce_mean(loss)
47+
train_op = tf.train.AdamOptimizer().minimize(loss)
48+
49+
sess = tf.Session()
50+
init = tf.global_variables_initializer()
51+
sess.run(init)
52+
tf.train.start_queue_runners(sess=sess)
53+
54+
for i in range(200):
55+
# pass it in through the feed_dict
56+
_, loss_val = sess.run([train_op, loss_mean])
57+
print (loss_val)
58+
59+
60+
'''
61+
with tf.Session() as sess:
62+
sess.run(init)
63+
coord = tf.train.Coordinator()
64+
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
65+
try:
66+
for i in range(10):
67+
val, l= sess.run([img_batch, label_batch])
68+
print(val[-10:], l)
69+
except tf.errors.OutOfRangeError:
70+
print ('Done reading')
71+
finally:
72+
coord.request_stop()
73+
74+
coord.join(threads)
75+
sess.close()
76+
'''

0504_tf_full_connect_NN.py

+24-4
Original file line numberDiff line numberDiff line change
@@ -24,20 +24,42 @@ def count_column_num(fname, field_delim):
2424
# the last column is the class number --> -1
2525
return len(line)
2626

27+
2728
def dense_to_one_hot(labels_dense, num_classes=10):
2829
"""Convert class labels from scalars to one-hot vectors."""
2930
num_labels = labels_dense.shape[0]
3031
index_offset = np.arange(num_labels) * num_classes
3132
labels_one_hot = np.zeros((num_labels, num_classes))
3233
labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1
3334
return labels_one_hot
35+
36+
37+
def read_and_decode(filename):
38+
filename_queue = tf.train.string_input_producer([filename])
39+
40+
reader = tf.TFRecordReader()
41+
_, serialized_example = reader.read(filename_queue)
42+
features = tf.parse_single_example(serialized_example,
43+
features={
44+
'label': tf.FixedLenFeature([], tf.int64),
45+
# We know the length of both fields. If not the
46+
# tf.VarLenFeature could be used
47+
'features': tf.FixedLenFeature([8660], tf.float32),
48+
})
49+
50+
X = tf.cast(features['features'], tf.float32)
51+
y = tf.cast(features['label'], tf.int32)
52+
53+
return X, y
54+
55+
3456
# Parameters
3557
learning_rate = 0.001
3658
training_epochs = 10000
3759
display_step = 1
3860
num_threads = 4
39-
csv_file_path = "data/merge/scat_data.txt"
40-
training_file_path = "data/merge/scat_data.tfrecords"
61+
csv_file_path = "data/tvtsets/training_scat_data.txt"
62+
training_file_path = "data/tvtsets/training_scat_data.tfrecords"
4163
column_num = count_column_num(csv_file_path, " ")
4264
# file_length = file_len(csv_file_path)
4365
# Network Parameters
@@ -106,8 +128,6 @@ def multilayer_perceptron(x, weights, biases):
106128
features_array = np.reshape(features_array, (1, n_input))
107129
label_array = dense_to_one_hot(np.array([label]), num_classes = n_classes)
108130

109-
with open("0504_log.txt", "w") as f:
110-
f.write("features: {}, label: {}".format(features_array, label_array))
111131
_, c = sess.run([optimizer, cost], feed_dict={x: features_array, y: label_array})
112132
# Display logs per epoch step
113133
if epoch % display_step == 0:

data/0203_convert_to_TFrecords.py

+13-3
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
# 2,5,3,3,6,7,7,5,1,1
1818

1919

20-
def convert_tfrecords(input_filename, output_filename):
21-
current_path = os.getcwd() + "/merge/"
20+
def convert_tfrecords(input_filename, output_filename, data_folder):
21+
current_path = os.getcwd() + data_folder
2222
input_file = os.path.join(current_path, input_filename)
2323
output_file = os.path.join(current_path, output_filename)
2424
print("Start to convert {} to {}".format(input_file, output_file))
@@ -43,4 +43,14 @@ def convert_tfrecords(input_filename, output_filename):
4343
print("Successfully convert {} to {}".format(input_file, output_file))
4444

4545

46-
convert_tfrecords("scat_data_test.txt", "scat_data_test.tfrecords")
46+
# convert_tfrecords("scat_data_test.txt", "scat_data_test.tfrecords", "/merge/")
47+
48+
if __name__ == "__main__":
49+
50+
# 转换所有tvtsets目录下的txt文件为tfrecords文件
51+
for root, dirs, file in os.walk("tvtsets"):
52+
for fn in file:
53+
if fn.endswith(".txt"):
54+
tfrecords_name = fn.replace(".txt", ".tfrecords")
55+
# print (tfrecords_name)
56+
convert_tfrecords(fn, tfrecords_name, "/tvtsets/")

0 commit comments

Comments
 (0)