
Commit 0763aff

Author: Li-Wei Chen
First commit
1 parent 6f4627d · commit 0763aff

29 files changed: +502894 -1 lines

README.md

+117-1
@@ -1 +1,117 @@
# MQTTS
- Official implementation for the paper [TODO]().
- Audio samples (40 per system) can be accessed [here](https://cmu.box.com/s/ktbk9pi04e2z1dlyepkkw69xcu9w91dj).
- A quick demo can be accessed at [TODO]().
## Set up the environment
1. Set up the conda environment:
```
conda create --name mqtts python=3.9
conda activate mqtts
conda install pytorch==1.10.1 torchvision==0.11.2 torchaudio==0.10.1 cudatoolkit=11.3 -c pytorch -c conda-forge
pip install -r requirements.txt
```
(Update) You may need to create an access token to use the pyannote speaker embedding, as pyannote has updated its access policy.
If that's the case, follow the [pyannote repo](https://github.com/pyannote/pyannote-audio) and change every `Inference("pyannote/embedding", window="whole")` accordingly (a hedged sketch follows this list).

2. Download the pretrained phonemizer checkpoint:
```
wget https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/DeepPhonemizer/en_us_cmudict_forward.pt
```

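If the access-token requirement applies (see the note under step 1), here is a minimal sketch, not part of this repo, of one way to adapt the call in `data/QuantizeDataset.py`, assuming the pyannote.audio 2.x API and a Hugging Face access token; the token string is a placeholder:
```
from pyannote.audio import Model, Inference

# "hf_xxx" is a placeholder for your own Hugging Face access token.
model = Model.from_pretrained("pyannote/embedding", use_auth_token="hf_xxx")
spkr_embedding = Inference(model, window="whole")
```
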
## Preprocess the dataset
1. Get the GigaSpeech dataset from the [official repo](https://github.com/SpeechColab/GigaSpeech).
2. Install [FFmpeg](https://ffmpeg.org), then
```
conda install ffmpeg=4.3=hf484d3e_0
conda update ffmpeg
```
3. Run the preprocessing script:
```
python preprocess.py --giga_speech_dir GIGASPEECH --outputdir datasets
```

## Train the quantizer and run inference
1. Train:
```
cd quantizer/
python train.py --input_wavs_dir ../datasets/audios \
                --input_training_file ../datasets/training.txt \
                --input_validation_file ../datasets/validation.txt \
                --checkpoint_path ./checkpoints \
                --config config.json
```

2. Run inference to get the codes used to train the second stage:
```
python get_labels.py --input_json ../datasets/train.json \
                     --input_wav_dir ../datasets/audios \
                     --output_json ../datasets/train_q.json \
                     --checkpoint_file ./checkpoints/g_{training_steps}
python get_labels.py --input_json ../datasets/dev.json \
                     --input_wav_dir ../datasets/audios \
                     --output_json ../datasets/dev_q.json \
                     --checkpoint_file ./checkpoints/g_{training_steps}
```

## Train the transformer (below is an example for the 100M version)
```
cd ..
mkdir ckpt
python train.py \
    --distributed \
    --saving_path ckpt/ \
    --sampledir logs/ \
    --vocoder_config_path quantizer/checkpoints/config.json \
    --vocoder_ckpt_path quantizer/checkpoints/g_{training_steps} \
    --datadir datasets/audios \
    --metapath datasets/train_q.json \
    --val_metapath datasets/dev_q.json \
    --use_repetition_token \
    --ar_layer 4 \
    --ar_ffd_size 1024 \
    --ar_hidden_size 256 \
    --ar_nheads 4 \
    --speaker_embed_dropout 0.05 \
    --enc_nlayers 6 \
    --dec_nlayers 6 \
    --ffd_size 3072 \
    --hidden_size 768 \
    --nheads 12 \
    --batch_size 200 \
    --precision bf16 \
    --training_step 800000 \
    --layer_norm_eps 1e-05
```
You can monitor training progress with:
```
tensorboard --logdir logs/
```

## Run batched inference
(You'll have to adapt `speaker_to_text.json`; the provided file is just an example. A hypothetical sketch of one possible layout follows the command below.)
```
mkdir infer_samples
CUDA_VISIBLE_DEVICES=0 python infer.py \
    --phonemizer_dict_path en_us_cmudict_forward.pt \
    --model_path ckpt/last.ckpt \
    --config_path ckpt/config.json \
    --input_path speaker_to_text.json \
    --outputdir infer_samples \
    --batch_size {batch_size} \
    --top_p 0.8 \
    --min_top_k 2 \
    --max_output_length {maximum output frames, to prevent infinite loops} \
    --phone_context_window 3 \
    --clean_speech_prior
```

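The exact schema of `speaker_to_text.json` is not shown in this commit; the snippet below is a purely hypothetical illustration (file names and keys are made up) of one plausible mapping from a reference utterance per speaker to the sentences to synthesize. Check `infer.py` for the format it actually expects.
```
# Hypothetical example only; the real schema is defined by infer.py and may differ.
import json

example = {
    "speaker_1": {
        "speaker_wav": "prompts/speaker_1.wav",  # hypothetical reference utterance
        "texts": ["Hello world.", "A second sentence to synthesize."],
    },
}
with open("speaker_to_text.json", "w") as f:
    json.dump(example, f, indent=2)
```
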
### Pretrained checkpoints

1. Quantizer (put it under `quantizer/checkpoints/`):
```
wget https://anonfiles.com/Tf52ua4dy8/g_00600000
```

2. Transformer (100M version) (put it under `ckpt/`):
```
wget https://anonfiles.com/o6C1u747y6/last_ckpt
```

data/QuantizeDataset.py

+138
@@ -0,0 +1,138 @@
1+
import os
2+
from torch.utils import data
3+
import torch
4+
import json
5+
import numpy as np
6+
import soundfile as sf
7+
import random
8+
from pathlib import Path
9+
from librosa.util import normalize
10+
from pyannote.audio import Inference
11+
12+
import torch.nn.functional as F
13+
14+
def random_crop(x, maxseqlen):
15+
if x.shape[0] >= maxseqlen:
16+
offset = random.randrange(x.shape[0] - maxseqlen + 1)
17+
x = x[offset: offset + maxseqlen]
18+
else:
19+
offset = 0
20+
return x, offset
21+
22+
def dynamic_range_compression(x, C=0.3, M=6.5, clip_val=1e-5):
23+
return (np.log(np.clip(x, a_min=clip_val, a_max=None)) + M) * C
24+
25+
def dynamic_range_decompression(x, C=0.3, M=6.5):
26+
return np.exp(x / C - M)
27+
28+
class QuantizeDataset(data.Dataset):
29+
def __init__(self, hp, metapath):
30+
self.hp = hp
31+
print (f'Loading metadata in {metapath}...')
32+
with open(metapath, 'r') as f:
33+
self.text = json.load(f) #{name: {text:, phoneme:, ..., duration: }}
34+
self.datasetbase = [x for x in self.text.keys()]
35+
self.dataset = [os.path.join(self.hp.datadir, x) for x in self.datasetbase]
36+
self.phoneset = ['<pad>', 'AA', 'AE', 'AH', 'AO', 'AW', 'AY', 'B', 'CH', 'D', 'DH', 'EH', 'ER', 'EY', 'F', 'G', 'HH', 'IH', 'IY', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OY', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UW', 'V', 'W', 'Y', 'Z', 'ZH', ',', '.']
37+
print (self.phoneset)
38+
if self.hp.speaker_embedding_dir is None:
39+
self.spkr_embedding = Inference("pyannote/embedding", window="whole")
40+
41+
#Print statistics:
42+
l = len(self.dataset)
43+
print (f'Total {l} examples')
44+
45+
self.lengths = [float(v['duration']) for v in self.text.values()]
46+
avglen = sum(self.lengths) / len(self.lengths)
47+
maxlen = max(self.lengths)
48+
minlen = min(self.lengths)
49+
print (f"Average duration of audio: {avglen} sec, Maximum duration: {maxlen} sec, Minimum duration: {minlen} sec")
50+
51+
def __len__(self):
52+
return len(self.dataset)
53+
54+
def __getitem__(self, i):
55+
dataname = self.dataset[i]
56+
_name = self.datasetbase[i]
57+
metadata = self.text[_name]
58+
#To synthesized phoneme sequence
59+
phonemes = [self.phoneset.index(ph) for ph in metadata['phoneme'].split() if ph in self.phoneset]
60+
61+
if self.hp.speaker_embedding_dir is None:
62+
audio, sampling_rate = sf.read(dataname)
63+
audio = normalize(audio) * 0.95
64+
speaker_embedding = self.spkr_embedding({'waveform': torch.FloatTensor(audio).unsqueeze(0), 'sample_rate': self.hp.sample_rate})
65+
else:
66+
speaker_embedding = os.path.join(self.hp.speaker_embedding_dir, os.path.splitext(_name)[0] + '.npy')
67+
speaker_embedding = np.load(speaker_embedding).astype(np.float32)
68+
69+
#Ground truth for TTS system
70+
quantization = np.array(metadata['quantization']).T # ..., 4
71+
#Add start token, end token
72+
start, end = np.full((1, self.hp.n_cluster_groups), self.hp.n_codes + 1, dtype=np.int16), np.full((1, self.hp.n_cluster_groups), self.hp.n_codes, dtype=np.int16)
73+
quantization_s = np.concatenate([start, quantization.copy()], 0)
74+
#Add repetition token if needed for ground truth "label"
75+
if self.hp.use_repetition_token:
76+
pad = np.full((1, self.hp.n_cluster_groups), -100, dtype=np.int16)
77+
np_mask = np.diff(quantization, axis=0, prepend=pad)
78+
quantization[np_mask == 0] = self.hp.n_codes + 2
79+
quantization_e = np.concatenate([quantization, end], 0)
80+
return speaker_embedding, quantization_s, quantization_e, phonemes, dataname
81+
82+
def seqCollate(self, batch):
83+
output = {
84+
'speaker': [],
85+
'phone': [],
86+
'phone_mask': [],
87+
'tts_quantize_input': [],
88+
'tts_quantize_output': [],
89+
'quantize_mask': [],
90+
}
91+
#Get the max length of everything
92+
max_len_q, max_phonelen = 0, 0
93+
for spkr, q_s, q_e, ph, _ in batch:
94+
if len(q_s) > max_len_q:
95+
max_len_q = len(q_s)
96+
if len(ph) > max_phonelen:
97+
max_phonelen = len(ph)
98+
output['speaker'].append(spkr)
99+
#Pad each element, create mask
100+
for _, qs, qe, phone, _ in batch:
101+
#Deal with phonemes
102+
phone_mask = np.array([False] * len(phone) + [True] * (max_phonelen - len(phone)))
103+
phone = np.pad(phone, [0, max_phonelen-len(phone)])
104+
#Deal with quantizations
105+
q_mask = np.array([False] * len(qs) + [True] * (max_len_q - len(qs)))
106+
qs = np.pad(qs, [[0, max_len_q-len(qs)], [0, 0]], constant_values=self.hp.n_codes)
107+
qe = np.pad(qe, [[0, max_len_q-len(qe)], [0, 0]], constant_values=self.hp.n_codes)
108+
#Aggregate
109+
output['phone'].append(phone)
110+
output['phone_mask'].append(phone_mask)
111+
output['tts_quantize_input'].append(qs)
112+
output['tts_quantize_output'].append(qe)
113+
output['quantize_mask'].append(q_mask)
114+
for k in output.keys():
115+
output[k] = np.array(output[k])
116+
if 'mask' in k:
117+
output[k] = torch.BoolTensor(output[k])
118+
elif k in ['phone', 'tts_quantize_input', 'tts_quantize_output']:
119+
output[k] = torch.LongTensor(output[k])
120+
else:
121+
output[k] = torch.FloatTensor(output[k])
122+
return output
123+
124+
class QuantizeDatasetVal(QuantizeDataset):
125+
def __len__(self):
126+
return len(self.dataset)
127+
128+
def __getitem__(self, i):
129+
speaker_embedding, quantization_s, quantization_e, phonemes, dataname = super().__getitem__(i)
130+
audio, sampling_rate = sf.read(dataname)
131+
audio = normalize(audio) * 0.95
132+
return (
133+
torch.FloatTensor(speaker_embedding),
134+
torch.LongTensor(quantization_s),
135+
torch.LongTensor(quantization_e),
136+
torch.LongTensor(phonemes),
137+
torch.FloatTensor(audio)
138+
)
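
For orientation, here is a hedged sketch (not part of this commit) of how `QuantizeDataset` and its `seqCollate` might be wired into a PyTorch `DataLoader`. The `hp` object and its values are assumptions based only on the attributes the class reads above; the real hyperparameters come from the repo's configs.
```
# Sketch only: `hp` stands in for the hyperparameter object QuantizeDataset expects.
from types import SimpleNamespace
from torch.utils.data import DataLoader
from data.QuantizeDataset import QuantizeDataset

hp = SimpleNamespace(
    datadir="datasets/audios",
    sample_rate=16000,           # assumed sample rate; use the repo config's value
    n_cluster_groups=4,          # matches the (T, 4) quantization shape above
    n_codes=160,                 # assumed codebook size
    use_repetition_token=True,
    speaker_embedding_dir=None,  # None -> embeddings computed on the fly with pyannote
)

dataset = QuantizeDataset(hp, metapath="datasets/train_q.json")
loader = DataLoader(dataset, batch_size=8, shuffle=True,
                    collate_fn=dataset.seqCollate, num_workers=4)

batch = next(iter(loader))
# batch['phone']:               LongTensor [B, max_phone_len]
# batch['tts_quantize_input']:  LongTensor [B, T, n_cluster_groups] (with start token)
# batch['tts_quantize_output']: LongTensor [B, T, n_cluster_groups] (with end/repetition tokens)
# batch['phone_mask'], batch['quantize_mask']: BoolTensor padding masks
```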

data/sampler.py

+104
@@ -0,0 +1,104 @@
1+
from torch.utils import data
2+
import torch
3+
import math
4+
import numpy as np
5+
import random
6+
7+
def StandardSampler(dataset, shuffle, distributed=False,
8+
world_size=None, rank=None):
9+
if distributed:
10+
return data.distributed.DistributedSampler(dataset, shuffle=shuffle,
11+
num_replicas=world_size, rank=rank)
12+
if shuffle:
13+
return data.RandomSampler(dataset)
14+
return data.SequentialSampler(dataset)
15+
16+
def RandomBucketSampler(nbuckets, length, batch_size, drop_last, distributed=False,
17+
world_size=None, rank=None):
18+
if distributed:
19+
return DistributedRandomBucketSampler(nbuckets, length, batch_size, drop_last, world_size, rank)
20+
return SingleRandomBucketSampler(nbuckets, length, batch_size, drop_last)
21+
22+
class SingleRandomBucketSampler(data.Sampler):
23+
def __init__(self, nbuckets, length, batch_size, drop_last):
24+
self.length = length
25+
self.batch_size = batch_size
26+
self.drop_last = drop_last
27+
indices = np.argsort([-x for x in length])
28+
split = len(indices) // nbuckets
29+
self.indices = []
30+
for i in range(nbuckets):
31+
self.indices.append(indices[i*split:(i+1)*split])
32+
if nbuckets * split < len(length):
33+
self.indices.append(indices[nbuckets*split:])
34+
35+
def __iter__(self):
36+
random.shuffle(self.indices)
37+
for x in self.indices:
38+
random.shuffle(x)
39+
idxs = [i for x in self.indices for i in x]
40+
batches, batch, sum_len, max_len = [], [], 0, 0
41+
for idx in idxs:
42+
batch.append(idx)
43+
sum_len += self.length[idx]
44+
max_len = max(self.length[idx], max_len)
45+
if max_len * len(batch) > self.batch_size:
46+
batches.append(batch[:-1])
47+
batch, sum_len, max_len = [batch[-1]], self.length[idx], self.length[idx]
48+
if len(batch) > 0 and not self.drop_last:
49+
batches.append(batch)
50+
random.shuffle(batches)
51+
return iter(batches)
52+
53+
class DistributedRandomBucketSampler(data.Sampler):
54+
def __init__(self, nbuckets, length, batch_size,
55+
drop_last, num_replicas, rank, seed=1234):
56+
if rank >= num_replicas or rank < 0:
57+
raise ValueError(
58+
"Invalid rank {}, rank should be in the interval"
59+
" [0, {}]".format(rank, num_replicas - 1))
60+
indices = np.argsort(length)
61+
split = len(indices) // nbuckets
62+
self.length = length
63+
self.batch_size = batch_size
64+
self.indices = []
65+
for i in range(nbuckets):
66+
self.indices.append(indices[i*split:(i+1)*split])
67+
if nbuckets * split < len(length):
68+
self.indices.append(indices[nbuckets*split:])
69+
self.num_replicas = num_replicas
70+
self.rank = rank
71+
self.epoch = 0
72+
self.seed = seed
73+
74+
def __iter__(self):
75+
#Deterministic shuffling
76+
random.Random(self.epoch + self.seed).shuffle(self.indices)
77+
for i, x in enumerate(self.indices):
78+
seed = self.epoch + self.seed + i * 5
79+
random.Random(seed).shuffle(x)
80+
indices = [i for x in self.indices for i in x]
81+
82+
#Batching
83+
batches, batch, sum_len, max_len = [], [], 0, 0
84+
for idx in indices:
85+
batch.append(idx)
86+
sum_len += self.length[idx]
87+
max_len = max(self.length[idx], max_len)
88+
if max_len * len(batch) > self.batch_size:
89+
batches.append(batch[:-1])
90+
batch, sum_len, max_len = [batch[-1]], self.length[idx], self.length[idx]
91+
# subsample
92+
num_samples = math.ceil((len(batches) - self.num_replicas) / self.num_replicas)
93+
total_size = num_samples * self.num_replicas
94+
batches = batches[:total_size]
95+
batches = batches[self.rank*num_samples: (self.rank+1)*num_samples]
96+
assert len(batches) == num_samples
97+
98+
#Stochastic suffling
99+
random.shuffle(batches)
100+
return iter(batches)
101+
102+
def set_epoch(self, epoch):
103+
self.epoch = epoch
104+
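
Similarly, a hedged sketch of how `RandomBucketSampler` could feed the dataset above through `batch_sampler`. Note that `batch_size` here acts as a per-batch length budget (batches are cut when `max_len * len(batch)` exceeds it), not an example count; all values below are illustrative.
```
from types import SimpleNamespace
from torch.utils.data import DataLoader
from data.QuantizeDataset import QuantizeDataset
from data.sampler import RandomBucketSampler

# Same assumed hyperparameters as in the QuantizeDataset sketch above.
hp = SimpleNamespace(datadir="datasets/audios", sample_rate=16000,
                     n_cluster_groups=4, n_codes=160,
                     use_repetition_token=True, speaker_embedding_dir=None)
dataset = QuantizeDataset(hp, metapath="datasets/train_q.json")

# Bucket clips of similar duration, then cut batches by the length budget.
sampler = RandomBucketSampler(nbuckets=10, length=dataset.lengths,
                              batch_size=200, drop_last=True,
                              distributed=False)

loader = DataLoader(dataset, batch_sampler=sampler,
                    collate_fn=dataset.seqCollate, num_workers=4)

# In the distributed case, pass distributed=True with world_size/rank and call
# sampler.set_epoch(epoch) at the start of each epoch so shuffling varies per
# epoch but stays consistent across ranks.
```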
