-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path1102_0_mtt_mini_batch_input.py
71 lines (53 loc) · 2.2 KB
/
1102_0_mtt_mini_batch_input.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
# -*- coding:utf-8 -*-
"""
@author: Songgx
@file: 1102_0_mtt_mini_batch_input.py
@time: 2017/1/15 10:33
"""
from __future__ import print_function
import tensorflow as tf
import numpy as np
# Input pipeline modeled on:
# https://indico.io/blog/tensorflow-data-inputs-part1-placeholders-protobufs-queues/

# Spatial size of one mel-spectrogram example; read_and_decode reshapes the
# flat decoded float32 buffer to (x_height, x_width).
x_height, x_width = 96, 1366

# Total number of tags, i.e. the length of each example's label vector.
n_tags = 189

# Integer column indices (into the n_tags-wide label vector) that
# get_top_50_tags keeps; presumably the 50 most frequent tags — verify against
# the script that produced the file. The path is resolved relative to the
# current working directory.
top_50_tags_index = np.loadtxt('data/top_50_tags.txt', dtype=int, delimiter=',')
def read_and_decode(filename):
    """Build TF 1.x graph ops that yield one (mel features, label) example.

    Args:
        filename: path of the TFRecord file to read. It is fed through a
            ``tf.train.string_input_producer`` queue, so queue runners must be
            started before the returned tensors can be evaluated.

    Returns:
        ``(x, y)``: ``x`` is the serialized ``features_mel`` bytes decoded as
        float32 and reshaped to ``[x_height, x_width]``; ``y`` is the
        ``label`` feature, a float32 vector of length ``n_tags``.
    """
    queue = tf.train.string_input_producer([filename])
    record_reader = tf.TFRecordReader()
    _, serialized = record_reader.read(queue)

    example_spec = {
        'features_mel': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([n_tags], tf.float32),
    }
    parsed = tf.parse_single_example(serialized, features=example_spec)

    mel = tf.reshape(tf.decode_raw(parsed['features_mel'], tf.float32),
                     [x_height, x_width])
    # The label feature is already parsed as float32, so this cast is a no-op
    # kept for explicitness.
    label = tf.cast(parsed['label'], tf.float32)
    return mel, label
def get_top_50_tags(top_50_tags_index, tags_batch_val):
    """Keep only the label columns whose index appears in ``top_50_tags_index``.

    Args:
        top_50_tags_index: array-like of integer column indices to keep.
            Duplicate indices and indices outside ``[0, n_columns)`` are
            ignored.
        tags_batch_val: 2-D array-like of shape ``(batch, n_columns)``, e.g. a
            batch of label vectors returned by ``sess.run``.

    Returns:
        NumPy array of shape ``(batch, k)`` holding the selected columns in
        ascending column order (not in the order listed in
        ``top_50_tags_index``), where ``k`` is the number of distinct valid
        indices. An empty batch yields an empty 1-D array.
    """
    tags = np.asarray(tags_batch_val)
    if len(tags) == 0:
        return np.array([])
    # Build the column mask once, rather than linearly scanning the index
    # array for every cell of every row.
    keep = np.isin(np.arange(tags.shape[1]), top_50_tags_index)
    return tags[:, keep]
mel_features, tags = read_and_decode("data/merge/mtt_mel_training.tfrecords")
# shuffle_batch randomly shuffles examples while grouping them into batches.
mel_features_batch, tags_batch = tf.train.shuffle_batch([mel_features, tags],
                                                        batch_size=20, capacity=2000,
                                                        min_after_dequeue=1000)
sess = tf.Session()
# No tf.Variables are created above as far as this script shows; the
# initializer is kept so the script keeps working if a model is added.
init = tf.global_variables_initializer()
sess.run(init)
# Queue runners fill the filename/example queues on background threads. The
# coordinator lets us stop and join those threads before closing the session,
# instead of leaving them running (and erroring) at interpreter exit.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
try:
    for _ in range(1):
        # Dequeue one shuffled batch; no feed_dict is used, the values come
        # straight from the input queues.
        mel_features_batch_val, tags_batch_val = sess.run([mel_features_batch, tags_batch])
        print(mel_features_batch_val)
        print(tags_batch_val)
        top_50_tags_batch_val = get_top_50_tags(top_50_tags_index, tags_batch_val)
        print(top_50_tags_batch_val)
finally:
    coord.request_stop()
    coord.join(threads)
    sess.close()