import os
from torch.utils import data
import torch
import json
import numpy as np
import soundfile as sf
import random
from pathlib import Path
from librosa.util import normalize
from pyannote.audio import Inference

import torch.nn.functional as F

def random_crop(x, maxseqlen):
    # Randomly crop x along its first axis to at most maxseqlen frames;
    # return the crop together with the chosen offset.
    if x.shape[0] >= maxseqlen:
        offset = random.randrange(x.shape[0] - maxseqlen + 1)
        x = x[offset: offset + maxseqlen]
    else:
        offset = 0
    return x, offset

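# Hedged sketch (not part of the original file): a minimal check of what
# random_crop does. The shapes and the 200-frame budget are assumptions made
# for illustration only; the helper below is defined but never called.
def _check_random_crop():
    feats = np.random.randn(312, 80)        # e.g. (frames, feature_dim)
    crop, offset = random_crop(feats, 200)  # at most 200 frames are kept
    assert crop.shape == (200, 80) and 0 <= offset <= 312 - 200
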
def dynamic_range_compression(x, C=0.3, M=6.5, clip_val=1e-5):
    # Scaled, shifted log compression: clip to clip_val, take the log, shift by M, scale by C.
    return (np.log(np.clip(x, a_min=clip_val, a_max=None)) + M) * C

def dynamic_range_decompression(x, C=0.3, M=6.5):
    # Inverse of dynamic_range_compression (exact for inputs above clip_val).
    return np.exp(x / C - M)

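# Hedged sketch (not part of the original file): round-trip sanity check for
# the two helpers above. The sample values are illustrative; the helper is
# defined but never called.
def _check_compression_roundtrip():
    x = np.array([1e-3, 1e-2, 0.1, 1.0])
    assert np.allclose(dynamic_range_decompression(dynamic_range_compression(x)), x)
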
class QuantizeDataset(data.Dataset):
    def __init__(self, hp, metapath):
        self.hp = hp
        print(f'Loading metadata in {metapath}...')
        with open(metapath, 'r') as f:
            self.text = json.load(f)  # {filename: {'text': ..., 'phoneme': ..., ..., 'duration': ...}}
        self.datasetbase = [x for x in self.text.keys()]
        self.dataset = [os.path.join(self.hp.datadir, x) for x in self.datasetbase]
        self.phoneset = ['<pad>', 'AA', 'AE', 'AH', 'AO', 'AW', 'AY', 'B', 'CH', 'D', 'DH', 'EH', 'ER', 'EY', 'F', 'G', 'HH', 'IH', 'IY', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OY', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UW', 'V', 'W', 'Y', 'Z', 'ZH', ',', '.']
        print(self.phoneset)
        if self.hp.speaker_embedding_dir is None:
            self.spkr_embedding = Inference("pyannote/embedding", window="whole")

        # Print dataset statistics
        l = len(self.dataset)
        print(f'Total {l} examples')

        self.lengths = [float(v['duration']) for v in self.text.values()]
        avglen = sum(self.lengths) / len(self.lengths)
        maxlen = max(self.lengths)
        minlen = min(self.lengths)
        print(f"Average duration of audio: {avglen} sec, Maximum duration: {maxlen} sec, Minimum duration: {minlen} sec")

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i):
        dataname = self.dataset[i]
        _name = self.datasetbase[i]
        metadata = self.text[_name]
        # Phoneme sequence to synthesize, mapped to indices in self.phoneset
        phonemes = [self.phoneset.index(ph) for ph in metadata['phoneme'].split() if ph in self.phoneset]

        if self.hp.speaker_embedding_dir is None:
            audio, sampling_rate = sf.read(dataname)
            audio = normalize(audio) * 0.95
            speaker_embedding = self.spkr_embedding({'waveform': torch.FloatTensor(audio).unsqueeze(0), 'sample_rate': self.hp.sample_rate})
        else:
            speaker_embedding = os.path.join(self.hp.speaker_embedding_dir, os.path.splitext(_name)[0] + '.npy')
            speaker_embedding = np.load(speaker_embedding).astype(np.float32)

        # Ground truth for the TTS system
        quantization = np.array(metadata['quantization']).T  # (T, n_cluster_groups)
        # Add start token (n_codes + 1) and end token (n_codes)
        start, end = np.full((1, self.hp.n_cluster_groups), self.hp.n_codes + 1, dtype=np.int16), np.full((1, self.hp.n_cluster_groups), self.hp.n_codes, dtype=np.int16)
        quantization_s = np.concatenate([start, quantization.copy()], 0)
        # Replace repeated codes with the repetition token (n_codes + 2) in the ground-truth "label" if needed
        if self.hp.use_repetition_token:
            pad = np.full((1, self.hp.n_cluster_groups), -100, dtype=np.int16)
            np_mask = np.diff(quantization, axis=0, prepend=pad)
            quantization[np_mask == 0] = self.hp.n_codes + 2
        quantization_e = np.concatenate([quantization, end], 0)
        return speaker_embedding, quantization_s, quantization_e, phonemes, dataname

    def seqCollate(self, batch):
        output = {
            'speaker': [],
            'phone': [],
            'phone_mask': [],
            'tts_quantize_input': [],
            'tts_quantize_output': [],
            'quantize_mask': [],
        }
        # Get the max length of everything
        max_len_q, max_phonelen = 0, 0
        for spkr, q_s, q_e, ph, _ in batch:
            if len(q_s) > max_len_q:
                max_len_q = len(q_s)
            if len(ph) > max_phonelen:
                max_phonelen = len(ph)
            output['speaker'].append(spkr)
        # Pad each element, create mask
        for _, qs, qe, phone, _ in batch:
            # Deal with phonemes
            phone_mask = np.array([False] * len(phone) + [True] * (max_phonelen - len(phone)))
            phone = np.pad(phone, [0, max_phonelen - len(phone)])
            # Deal with quantizations
            q_mask = np.array([False] * len(qs) + [True] * (max_len_q - len(qs)))
            qs = np.pad(qs, [[0, max_len_q - len(qs)], [0, 0]], constant_values=self.hp.n_codes)
            qe = np.pad(qe, [[0, max_len_q - len(qe)], [0, 0]], constant_values=self.hp.n_codes)
            # Aggregate
            output['phone'].append(phone)
            output['phone_mask'].append(phone_mask)
            output['tts_quantize_input'].append(qs)
            output['tts_quantize_output'].append(qe)
            output['quantize_mask'].append(q_mask)
        for k in output.keys():
            output[k] = np.array(output[k])
            if 'mask' in k:
                output[k] = torch.BoolTensor(output[k])
            elif k in ['phone', 'tts_quantize_input', 'tts_quantize_output']:
                output[k] = torch.LongTensor(output[k])
            else:
                output[k] = torch.FloatTensor(output[k])
        return output

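# Hedged sketch (not part of the original file): what seqCollate produces for a
# toy batch. The embedding size (512), n_codes (1024) and the shapes below are
# assumptions for illustration; the helper is defined but never called.
def _demo_seq_collate():
    from types import SimpleNamespace
    fake_self = SimpleNamespace(hp=SimpleNamespace(n_codes=1024))
    spkr = np.zeros(512, dtype=np.float32)
    qs = np.zeros((5, 4), dtype=np.int64)   # (T + 1, n_cluster_groups)
    qe = np.zeros((5, 4), dtype=np.int64)
    batch = [(spkr, qs, qe, [1, 2, 3], 'a.wav'),
             (spkr, qs[:3], qe[:3], [1, 2], 'b.wav')]
    out = QuantizeDataset.seqCollate(fake_self, batch)
    # Shorter items are padded (quantize rows with n_codes, phones with 0)
    # and the *_mask tensors mark the padded positions with True.
    print({k: tuple(v.shape) for k, v in out.items()})
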
class QuantizeDatasetVal(QuantizeDataset):
    # Validation variant: additionally returns the raw waveform and converts everything to tensors.
    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i):
        speaker_embedding, quantization_s, quantization_e, phonemes, dataname = super().__getitem__(i)
        audio, sampling_rate = sf.read(dataname)
        audio = normalize(audio) * 0.95
        return (
            torch.FloatTensor(speaker_embedding),
            torch.LongTensor(quantization_s),
            torch.LongTensor(quantization_e),
            torch.LongTensor(phonemes),
            torch.FloatTensor(audio)
        )
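
# Hedged usage sketch (not part of the original file): how these datasets are
# typically wired into a DataLoader. All hyperparameter values and paths below
# are illustrative assumptions, not the project's real configuration.
if __name__ == '__main__':
    from types import SimpleNamespace
    hp = SimpleNamespace(
        datadir='data/wavs',                  # assumed audio directory
        speaker_embedding_dir='data/spkr',    # set to None to run pyannote on the fly
        sample_rate=16000,
        n_cluster_groups=4,
        n_codes=1024,
        use_repetition_token=True,
    )
    dataset = QuantizeDataset(hp, 'data/metadata.json')
    loader = data.DataLoader(dataset, batch_size=8, shuffle=True,
                             collate_fn=dataset.seqCollate, num_workers=4)
    for batch in loader:
        print({k: tuple(v.shape) for k, v in batch.items()})
        break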