Skip to content

Commit ce59228

Browse files
Alexandre LissyAlexandre Lissy
Alexandre Lissy
authored and
Alexandre Lissy
committed
Localizeable validate_label
Fixes #2804
1 parent f9e05fe commit ce59228

15 files changed

+81
-18
lines changed

bin/import_cv.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from threading import RLock
1919
from multiprocessing.dummy import Pool
2020
from multiprocessing import cpu_count
21-
from util.text import validate_label
21+
from util.importers import validate_label_eng as validate_label
2222
from util.downloader import maybe_download, SIMPLE_BAR
2323

2424
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']

bin/import_cv2.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from multiprocessing import cpu_count
2727
from util.downloader import SIMPLE_BAR
2828
from util.text import Alphabet
29-
from util.importers import get_importers_parser, validate_label_eng as validate_label
29+
from util.importers import get_importers_parser, get_validate_label
3030
from util.helpers import secs_to_hours
3131

3232

@@ -144,6 +144,7 @@ def _maybe_convert_wav(mp3_filename, wav_filename):
144144
PARSER.add_argument('--space_after_every_character', action='store_true', help='To help transcript join by white space')
145145

146146
PARAMS = PARSER.parse_args()
147+
validate_label = get_validate_label(PARAMS)
147148

148149
AUDIO_DIR = PARAMS.audio_dir if PARAMS.audio_dir else os.path.join(PARAMS.tsv_dir, 'clips')
149150
ALPHABET = Alphabet(PARAMS.filter_alphabet) if PARAMS.filter_alphabet else None

bin/import_fisher.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import librosa
2020
import soundfile # <= Has an external dependency on libsndfile
2121

22-
from util.text import validate_label
22+
from util.importers import validate_label_eng as validate_label
2323

2424
def _download_and_preprocess_data(data_dir):
2525
# Assume data_dir contains extracted LDC2004S13, LDC2004T19, LDC2005S13, LDC2005T19

bin/import_gram_vaani.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import math
1111
import urllib
1212
import logging
13-
from util.importers import get_importers_parser
13+
from util.importers import get_importers_parser, get_validate_label
1414
import subprocess
1515
from os import path
1616
from pathlib import Path
@@ -19,8 +19,6 @@
1919
import pandas as pd
2020
from sox import Transformer
2121

22-
from util.text import validate_label
23-
2422

2523
__version__ = "0.1.0"
2624
_logger = logging.getLogger(__name__)
@@ -290,6 +288,7 @@ def main(args):
290288
args ([str]): command line parameter list
291289
"""
292290
args = parse_args(args)
291+
validate_label = get_validate_label(args)
293292
setup_logging(args.loglevel)
294293
_logger.info("Starting GramVaani importer...")
295294
_logger.info("Starting loading GramVaani csv...")

bin/import_lingua_libre.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77
import sys
88
sys.path.insert(1, os.path.join(sys.path[0], '..'))
99

10-
from util.importers import get_importers_parser
10+
from util.importers import get_importers_parser, get_validate_label
1111

12+
import argparse
1213
import csv
1314
import re
1415
import sox
@@ -26,7 +27,7 @@
2627
from glob import glob
2728

2829
from util.downloader import maybe_download
29-
from util.text import Alphabet, validate_label
30+
from util.text import Alphabet
3031
from util.helpers import secs_to_hours
3132

3233
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
@@ -185,6 +186,7 @@ def handle_args():
185186
if __name__ == "__main__":
186187
CLI_ARGS = handle_args()
187188
ALPHABET = Alphabet(CLI_ARGS.filter_alphabet) if CLI_ARGS.filter_alphabet else None
189+
validate_label = get_validate_label(CLI_ARGS)
188190

189191
bogus_regexes = []
190192
if CLI_ARGS.bogus_records:

bin/import_m-ailabs.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
sys.path.insert(1, os.path.join(sys.path[0], '..'))
1111

12-
from util.importers import get_importers_parser
12+
from util.importers import get_importers_parser, get_validate_label
1313

1414
import csv
1515
import subprocess
@@ -26,7 +26,7 @@
2626
from glob import glob
2727

2828
from util.downloader import maybe_download
29-
from util.text import Alphabet, validate_label
29+
from util.text import Alphabet
3030
from util.helpers import secs_to_hours
3131

3232
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
@@ -182,6 +182,7 @@ def handle_args():
182182
CLI_ARGS = handle_args()
183183
ALPHABET = Alphabet(CLI_ARGS.filter_alphabet) if CLI_ARGS.filter_alphabet else None
184184
SKIP_LIST = filter(None, CLI_ARGS.skiplist.split(','))
185+
validate_label = get_validate_label(CLI_ARGS)
185186

186187
def label_filter(label):
187188
if CLI_ARGS.normalize:

bin/import_slr57.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import sys
88
sys.path.insert(1, os.path.join(sys.path[0], '..'))
99

10-
from util.importers import get_importers_parser
10+
from util.importers import get_importers_parser, get_validate_label
1111

1212
import csv
1313
import re
@@ -27,7 +27,7 @@
2727
from glob import glob
2828

2929
from util.downloader import maybe_download
30-
from util.text import Alphabet, validate_label
30+
from util.text import Alphabet
3131
from util.helpers import secs_to_hours
3232

3333
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
@@ -203,6 +203,7 @@ def handle_args():
203203
if __name__ == "__main__":
204204
CLI_ARGS = handle_args()
205205
ALPHABET = Alphabet(CLI_ARGS.filter_alphabet) if CLI_ARGS.filter_alphabet else None
206+
validate_label = get_validate_label(CLI_ARGS)
206207

207208
def label_filter(label):
208209
if CLI_ARGS.normalize:

bin/import_swb.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import codecs
2121
import tarfile
2222
import requests
23-
from util.text import validate_label
23+
from util.importers import validate_label_eng as validate_label
2424
import librosa
2525
import soundfile # <= Has an external dependency on libsndfile
2626

bin/import_swc.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@
2727
from glob import glob
2828
from collections import Counter
2929
from multiprocessing.pool import ThreadPool
30-
from util.text import Alphabet, validate_label
30+
from util.text import Alphabet
31+
from util.importers import validate_label_eng as validate_label
3132
from util.downloader import maybe_download, SIMPLE_BAR
3233

3334
SWC_URL = "https://www2.informatik.uni-hamburg.de/nats/pub/SWC/SWC_{language}.tar"

bin/import_ts.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import sys
99
sys.path.insert(1, os.path.join(sys.path[0], '..'))
1010

11-
from util.importers import get_importers_parser
11+
from util.importers import get_importers_parser, get_validate_label
1212

1313
import csv
1414
import unidecode
@@ -25,7 +25,6 @@
2525
from os import path
2626

2727
from util.downloader import maybe_download
28-
from util.text import validate_label
2928
from util.helpers import secs_to_hours
3029

3130
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
@@ -193,4 +192,5 @@ def handle_args():
193192

194193
if __name__ == "__main__":
195194
cli_args = handle_args()
195+
validate_label = get_validate_label(cli_args)
196196
_download_and_preprocess_data(cli_args.target_dir, cli_args.english_compatible)

bin/import_tuda.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121

2222
from os import path
2323
from collections import Counter
24-
from util.text import Alphabet, validate_label
24+
from util.text import Alphabet
25+
from util.importers import validate_label_eng as validate_label
2526
from util.downloader import maybe_download, SIMPLE_BAR
2627

2728
TUDA_VERSION = 'v2'

requirements_tests.txt

+1
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
absl-py
2+
argparse

util/importers.py

+28
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,38 @@
11
import argparse
2+
import importlib
3+
import os
24
import re
5+
import sys
36

47
def get_importers_parser(description):
    """
    Build the argument parser shared by the dataset importers.

    The parser comes pre-populated with the ``--validate_label_locale``
    option, which get_validate_label() later reads from the parsed arguments.

    :param description: Text shown as the program description in ``--help``
    :type description: str

    :return: A parser the importer can extend with its own arguments
    :type: argparse.ArgumentParser
    """
    importer_parser = argparse.ArgumentParser(description=description)
    importer_parser.add_argument(
        '--validate_label_locale',
        help='Path to a Python file defining a |validate_label| function for your locale. WARNING: THIS WILL ADD THIS FILE\'s DIRECTORY INTO PYTHONPATH.',
    )
    return importer_parser
711

12+
def get_validate_label(args):
    """
    Return the label validation function selected on the importer's command line.

    Expects an argparse.Namespace argument to search for validate_label_locale parameter.
    If found, this will modify Python's library search path and add the directory of the
    file pointed by the validate_label_locale argument (so the locale file can import
    modules living next to it). The directory is only added once.

    The locale file is loaded from its exact path rather than by bare module name, so a
    file that happens to be named like an already imported module (``json.py``,
    ``text.py``, ...) is still the one that gets loaded.

    :param args: The importer's CLI argument object
    :type args: argparse.Namespace

    :return: validate_label_eng when no locale file was requested, the user-supplied
             validate_label function when the file could be loaded, or None when
             validate_label_locale does not point to an existing file
    :type: function
    """
    locale_path = getattr(args, 'validate_label_locale', None)
    if locale_path is None:
        print('WARNING: No --validate_label_locale specified, your might end with inconsistent dataset.')
        return validate_label_eng

    locale_path = os.path.abspath(locale_path)
    # isfile() rather than exists(): a directory is not a usable locale module.
    if not os.path.isfile(locale_path):
        print('ERROR: Inexistent --validate_label_locale specified. Please check.')
        return None

    module_dir = os.path.dirname(locale_path)
    if module_dir not in sys.path:
        sys.path.insert(1, module_dir)

    # splitext() only strips the trailing extension; str.replace('.py', '')
    # would also eat a '.py' found in the middle of the file name.
    module_name = os.path.splitext(os.path.basename(locale_path))[0]
    locale_module = _load_module_from_file(module_name, locale_path)
    return locale_module.validate_label


def _load_module_from_file(module_name, module_path):
    """Load the Python file at absolute path module_path as module module_name."""
    # Function-scope import: a plain ``import importlib`` does not guarantee
    # that the ``importlib.util`` submodule has been loaded.
    import importlib.util

    # Reuse the module when this very file has already been imported.
    loaded = sys.modules.get(module_name)
    loaded_file = getattr(loaded, '__file__', None)
    if loaded_file is not None and os.path.abspath(loaded_file) == module_path:
        return loaded

    spec = importlib.util.spec_from_file_location(module_name, module_path)
    if spec is None or spec.loader is None:
        # Unrecognised file suffix: defer to the regular import machinery.
        return importlib.import_module(module_name)

    module = importlib.util.module_from_spec(spec)
    # Keep the module reachable by name when that name is free, but never
    # replace an unrelated module that is already registered under it.
    register = module_name not in sys.modules
    if register:
        sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        if register:
            sys.modules.pop(module_name, None)
        raise
    return module
35+
836
# Validate and normalize transcriptions. Returns a cleaned version of the label
937
# or None if it's invalid.
1038
def validate_label_eng(label):

util/test_data/validate_locale_fra.py

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
def validate_label(label):
    """Test fixture locale validator: hand the label back unchanged."""
    unchanged = label
    return unchanged

util/test_importers.py

+27-1
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,38 @@
11
import unittest
22

3-
from .importers import validate_label_eng
3+
from argparse import Namespace
4+
from .importers import validate_label_eng, get_validate_label
45

56
class TestValidateLabelEng(unittest.TestCase):
    """Checks for the English label validator."""

    def test_numbers(self):
        """A transcript containing digits is expected to be rejected."""
        transcript = "this is a 1 2 3 test"
        self.assertEqual(validate_label_eng(transcript), None)
1011

12+
class TestGetValidateLabel(unittest.TestCase):
    """Checks for get_validate_label(), the --validate_label_locale resolver."""

    @staticmethod
    def _test_data_path(filename):
        """Return the path of a fixture in the test_data directory next to this file."""
        # Resolved from __file__ so the tests do not depend on being launched
        # from the repository root.
        import os
        here = os.path.dirname(os.path.abspath(__file__))
        return os.path.join(here, 'test_data', filename)

    def _assert_english_validator(self, validator):
        """Assert that validator accepts plain text and rejects digits and brackets."""
        self.assertEqual(validator('toto'), 'toto')
        self.assertIsNone(validator('toto1234'))
        self.assertIsNone(validator('toto1234[{[{[]'))

    def test_no_validate_label_locale(self):
        # Namespace without the attribute at all: fall back to the English validator.
        f = get_validate_label(Namespace())
        self._assert_english_validator(f)

    def test_validate_label_locale_default(self):
        # Attribute present but unset (argparse default): same fallback.
        f = get_validate_label(Namespace(validate_label_locale=None))
        self._assert_english_validator(f)

    def test_get_validate_label_missing(self):
        # The referenced fixture does not exist, so no validator is returned.
        args = Namespace(validate_label_locale=self._test_data_path('validate_locale_ger.py'))
        f = get_validate_label(args)
        self.assertIsNone(f)

    def test_get_validate_label(self):
        # The fixture's validate_label returns its input unchanged.
        args = Namespace(validate_label_locale=self._test_data_path('validate_locale_fra.py'))
        f = get_validate_label(args)
        l = f('toto')
        self.assertEqual(l, 'toto')
36+
1137
if __name__ == '__main__':
1238
unittest.main()

0 commit comments

Comments
 (0)