21
21
import unicodedata
22
22
23
23
from os import path
24
- from threading import RLock
25
- from multiprocessing .dummy import Pool
26
- from multiprocessing import cpu_count
24
+ from multiprocessing import Pool
27
25
from util .downloader import SIMPLE_BAR
28
26
from util .text import Alphabet
29
- from util .importers import get_importers_parser , get_validate_label
30
- from util .helpers import secs_to_hours
27
+ from util .importers import get_importers_parser , get_validate_label , get_counter , get_imported_samples , print_import_report
31
28
32
29
33
30
FIELDNAMES = ['wav_filename' , 'wav_filesize' , 'transcript' ]
34
31
SAMPLE_RATE = 16000
35
32
MAX_SECS = 10
36
33
37
34
38
- def _preprocess_data (tsv_dir , audio_dir , label_filter , space_after_every_character = False ):
35
+ def _preprocess_data (tsv_dir , audio_dir , space_after_every_character = False ):
39
36
for dataset in ['train' , 'test' , 'dev' , 'validated' , 'other' ]:
40
37
input_tsv = path .join (path .abspath (tsv_dir ), dataset + ".tsv" )
41
38
if os .path .isfile (input_tsv ):
42
39
print ("Loading TSV file: " , input_tsv )
43
- _maybe_convert_set (input_tsv , audio_dir , label_filter , space_after_every_character )
44
-
45
-
46
- def _maybe_convert_set (input_tsv , audio_dir , label_filter , space_after_every_character = None ):
40
+ _maybe_convert_set (input_tsv , audio_dir , space_after_every_character )
41
+
42
+ def one_sample (sample ):
43
+ """ Take a audio file, and optionally convert it to 16kHz WAV """
44
+ mp3_filename = sample [0 ]
45
+ if not path .splitext (mp3_filename .lower ())[1 ] == '.mp3' :
46
+ mp3_filename += ".mp3"
47
+ # Storing wav files next to the mp3 ones - just with a different suffix
48
+ wav_filename = path .splitext (mp3_filename )[0 ] + ".wav"
49
+ _maybe_convert_wav (mp3_filename , wav_filename )
50
+ file_size = - 1
51
+ frames = 0
52
+ if path .exists (wav_filename ):
53
+ file_size = path .getsize (wav_filename )
54
+ frames = int (subprocess .check_output (['soxi' , '-s' , wav_filename ], stderr = subprocess .STDOUT ))
55
+ label = label_filter_fun (sample [1 ])
56
+ rows = []
57
+ counter = get_counter ()
58
+ if file_size == - 1 :
59
+ # Excluding samples that failed upon conversion
60
+ counter ['failed' ] += 1
61
+ elif label is None :
62
+ # Excluding samples that failed on label validation
63
+ counter ['invalid_label' ] += 1
64
+ elif int (frames / SAMPLE_RATE * 1000 / 10 / 2 ) < len (str (label )):
65
+ # Excluding samples that are too short to fit the transcript
66
+ counter ['too_short' ] += 1
67
+ elif frames / SAMPLE_RATE > MAX_SECS :
68
+ # Excluding very long samples to keep a reasonable batch-size
69
+ counter ['too_long' ] += 1
70
+ else :
71
+ # This one is good - keep it for the target CSV
72
+ rows .append ((os .path .split (wav_filename )[- 1 ], file_size , label ))
73
+ counter ['all' ] += 1
74
+ counter ['total_time' ] += frames
75
+
76
+ return (counter , rows )
77
+
78
+ def _maybe_convert_set (input_tsv , audio_dir , space_after_every_character = None ):
47
79
output_csv = path .join (audio_dir , os .path .split (input_tsv )[- 1 ].replace ('tsv' , 'csv' ))
48
80
print ("Saving new DeepSpeech-formatted CSV file to: " , output_csv )
49
81
@@ -52,51 +84,18 @@ def _maybe_convert_set(input_tsv, audio_dir, label_filter, space_after_every_cha
52
84
with open (input_tsv , encoding = 'utf-8' ) as input_tsv_file :
53
85
reader = csv .DictReader (input_tsv_file , delimiter = '\t ' )
54
86
for row in reader :
55
- samples .append ((row ['path' ], row ['sentence' ]))
87
+ samples .append ((path . join ( audio_dir , row ['path' ]) , row ['sentence' ]))
56
88
57
- # Keep track of how many samples are good vs. problematic
58
- counter = {'all' : 0 , 'failed' : 0 , 'invalid_label' : 0 , 'too_short' : 0 , 'too_long' : 0 , 'total_time' : 0 }
59
- lock = RLock ()
89
+ counter = get_counter ()
60
90
num_samples = len (samples )
61
91
rows = []
62
92
63
- def one_sample (sample ):
64
- """ Take a audio file, and optionally convert it to 16kHz WAV """
65
- mp3_filename = path .join (audio_dir , sample [0 ])
66
- if not path .splitext (mp3_filename .lower ())[1 ] == '.mp3' :
67
- mp3_filename += ".mp3"
68
- # Storing wav files next to the mp3 ones - just with a different suffix
69
- wav_filename = path .splitext (mp3_filename )[0 ] + ".wav"
70
- _maybe_convert_wav (mp3_filename , wav_filename )
71
- file_size = - 1
72
- frames = 0
73
- if path .exists (wav_filename ):
74
- file_size = path .getsize (wav_filename )
75
- frames = int (subprocess .check_output (['soxi' , '-s' , wav_filename ], stderr = subprocess .STDOUT ))
76
- label = label_filter (sample [1 ])
77
- with lock :
78
- if file_size == - 1 :
79
- # Excluding samples that failed upon conversion
80
- counter ['failed' ] += 1
81
- elif label is None :
82
- # Excluding samples that failed on label validation
83
- counter ['invalid_label' ] += 1
84
- elif int (frames / SAMPLE_RATE * 1000 / 10 / 2 ) < len (str (label )):
85
- # Excluding samples that are too short to fit the transcript
86
- counter ['too_short' ] += 1
87
- elif frames / SAMPLE_RATE > MAX_SECS :
88
- # Excluding very long samples to keep a reasonable batch-size
89
- counter ['too_long' ] += 1
90
- else :
91
- # This one is good - keep it for the target CSV
92
- rows .append ((os .path .split (wav_filename )[- 1 ], file_size , label ))
93
- counter ['all' ] += 1
94
- counter ['total_time' ] += frames
95
-
96
93
print ("Importing mp3 files..." )
97
- pool = Pool (cpu_count () )
94
+ pool = Pool ()
98
95
bar = progressbar .ProgressBar (max_value = num_samples , widgets = SIMPLE_BAR )
99
- for i , _ in enumerate (pool .imap_unordered (one_sample , samples ), start = 1 ):
96
+ for i , processed in enumerate (pool .imap_unordered (one_sample , samples ), start = 1 ):
97
+ counter += processed [0 ]
98
+ rows += processed [1 ]
100
99
bar .update (i )
101
100
bar .update (num_samples )
102
101
pool .close ()
@@ -113,16 +112,11 @@ def one_sample(sample):
113
112
else :
114
113
writer .writerow ({'wav_filename' : filename , 'wav_filesize' : file_size , 'transcript' : transcript })
115
114
116
- print ('Imported %d samples.' % (counter ['all' ] - counter ['failed' ] - counter ['too_short' ] - counter ['too_long' ]))
117
- if counter ['failed' ] > 0 :
118
- print ('Skipped %d samples that failed upon conversion.' % counter ['failed' ])
119
- if counter ['invalid_label' ] > 0 :
120
- print ('Skipped %d samples that failed on transcript validation.' % counter ['invalid_label' ])
121
- if counter ['too_short' ] > 0 :
122
- print ('Skipped %d samples that were too short to match the transcript.' % counter ['too_short' ])
123
- if counter ['too_long' ] > 0 :
124
- print ('Skipped %d samples that were longer than %d seconds.' % (counter ['too_long' ], MAX_SECS ))
125
- print ('Final amount of imported audio: %s.' % secs_to_hours (counter ['total_time' ] / SAMPLE_RATE ))
115
+ imported_samples = get_imported_samples (counter )
116
+ assert counter ['all' ] == num_samples
117
+ assert len (rows ) == imported_samples
118
+
119
+ print_import_report (counter , SAMPLE_RATE , MAX_SECS )
126
120
127
121
128
122
def _maybe_convert_wav (mp3_filename , wav_filename ):
@@ -162,4 +156,4 @@ def label_filter_fun(label):
162
156
label = None
163
157
return label
164
158
165
- _preprocess_data (PARAMS .tsv_dir , AUDIO_DIR , label_filter_fun , PARAMS .space_after_every_character )
159
+ _preprocess_data (PARAMS .tsv_dir , AUDIO_DIR , PARAMS .space_after_every_character )
0 commit comments