mrfakename commited on
Commit
1646c30
·
verified ·
0 Parent(s):

Super-squash branch 'main' using huggingface_hub

Browse files
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Yushen CHEN
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: E2/F5 TTS
3
+ emoji: 🗣️
4
+ colorFrom: green
5
+ colorTo: green
6
+ sdk: gradio
7
+ app_file: app.py
8
+ pinned: true
9
+ short_description: 'E2-TTS & F5-TTS: Zero-Shot Voice Cloning (Unofficial Demo)'
10
+ sdk_version: 5.0.1
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import torch
4
+ import torchaudio
5
+ import gradio as gr
6
+ import numpy as np
7
+ import tempfile
8
+ from einops import rearrange
9
+ from ema_pytorch import EMA
10
+ from vocos import Vocos
11
+ from pydub import AudioSegment
12
+ from model import CFM, UNetT, DiT, MMDiT
13
+ from cached_path import cached_path
14
+ from model.utils import (
15
+ get_tokenizer,
16
+ convert_char_to_pinyin,
17
+ save_spectrogram,
18
+ )
19
+ from transformers import pipeline
20
+ import spaces
21
+ import librosa
22
+
23
+ device = "cuda" if torch.cuda.is_available() else "cpu"
24
+
25
+ pipe = pipeline(
26
+ "automatic-speech-recognition",
27
+ model="openai/whisper-large-v3-turbo",
28
+ torch_dtype=torch.float16,
29
+ device=device,
30
+ )
31
+
32
+ # --------------------- Settings -------------------- #
33
+
34
+ target_sample_rate = 24000
35
+ n_mel_channels = 100
36
+ hop_length = 256
37
+ target_rms = 0.1
38
+ nfe_step = 32 # 16, 32
39
+ cfg_strength = 2.0
40
+ ode_method = 'euler'
41
+ sway_sampling_coef = -1.0
42
+ speed = 1.0
43
+ # fix_duration = 27 # None or float (duration in seconds)
44
+ fix_duration = None
45
+
46
+ def load_model(exp_name, model_cls, model_cfg, ckpt_step):
47
+ checkpoint = torch.load(str(cached_path(f"hf://SWivid/F5-TTS/{exp_name}/model_{ckpt_step}.pt")), map_location=device)
48
+ vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
49
+ model = CFM(
50
+ transformer=model_cls(
51
+ **model_cfg,
52
+ text_num_embeds=vocab_size,
53
+ mel_dim=n_mel_channels
54
+ ),
55
+ mel_spec_kwargs=dict(
56
+ target_sample_rate=target_sample_rate,
57
+ n_mel_channels=n_mel_channels,
58
+ hop_length=hop_length,
59
+ ),
60
+ odeint_kwargs=dict(
61
+ method=ode_method,
62
+ ),
63
+ vocab_char_map=vocab_char_map,
64
+ ).to(device)
65
+
66
+ ema_model = EMA(model, include_online_model=False).to(device)
67
+ ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
68
+ ema_model.copy_params_from_ema_to_model()
69
+
70
+ return ema_model, model
71
+
72
+ # load models
73
+ F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
74
+ E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
75
+
76
+ F5TTS_ema_model, F5TTS_base_model = load_model("F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)
77
+ E2TTS_ema_model, E2TTS_base_model = load_model("E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000)
78
+
79
+ @spaces.GPU
80
+ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence):
81
+ print(gen_text)
82
+ if len(gen_text) > 200:
83
+ raise gr.Error("Please keep your text under 200 chars.")
84
+ gr.Info("Converting audio...")
85
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
86
+ aseg = AudioSegment.from_file(ref_audio_orig)
87
+ audio_duration = len(aseg)
88
+ if audio_duration > 15000:
89
+ gr.Warning("Audio is over 15s, clipping to only first 15s.")
90
+ aseg = aseg[:15000]
91
+ aseg.export(f.name, format="wav")
92
+ ref_audio = f.name
93
+ if exp_name == "F5-TTS":
94
+ ema_model = F5TTS_ema_model
95
+ base_model = F5TTS_base_model
96
+ elif exp_name == "E2-TTS":
97
+ ema_model = E2TTS_ema_model
98
+ base_model = E2TTS_base_model
99
+
100
+ if not ref_text.strip():
101
+ gr.Info("No reference text provided, transcribing reference audio...")
102
+ ref_text = outputs = pipe(
103
+ ref_audio,
104
+ chunk_length_s=30,
105
+ batch_size=128,
106
+ generate_kwargs={"task": "transcribe"},
107
+ return_timestamps=False,
108
+ )['text'].strip()
109
+ gr.Info("Finished transcription")
110
+ else:
111
+ gr.Info("Using custom reference text...")
112
+ audio, sr = torchaudio.load(ref_audio)
113
+
114
+ rms = torch.sqrt(torch.mean(torch.square(audio)))
115
+ if rms < target_rms:
116
+ audio = audio * target_rms / rms
117
+ if sr != target_sample_rate:
118
+ resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
119
+ audio = resampler(audio)
120
+ audio = audio.to(device)
121
+
122
+ # Prepare the text
123
+ text_list = [ref_text + gen_text]
124
+ final_text_list = convert_char_to_pinyin(text_list)
125
+
126
+ # Calculate duration
127
+ ref_audio_len = audio.shape[-1] // hop_length
128
+ # if fix_duration is not None:
129
+ # duration = int(fix_duration * target_sample_rate / hop_length)
130
+ # else:
131
+ zh_pause_punc = r"。,、;:?!"
132
+ ref_text_len = len(ref_text) + len(re.findall(zh_pause_punc, ref_text))
133
+ gen_text_len = len(gen_text) + len(re.findall(zh_pause_punc, gen_text))
134
+ duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
135
+
136
+ # inference
137
+ gr.Info(f"Generating audio using {exp_name}")
138
+ with torch.inference_mode():
139
+ generated, _ = base_model.sample(
140
+ cond=audio,
141
+ text=final_text_list,
142
+ duration=duration,
143
+ steps=nfe_step,
144
+ cfg_strength=cfg_strength,
145
+ sway_sampling_coef=sway_sampling_coef,
146
+ )
147
+
148
+ generated = generated[:, ref_audio_len:, :]
149
+ generated_mel_spec = rearrange(generated, '1 n d -> 1 d n')
150
+ gr.Info("Running vocoder")
151
+ vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
152
+ generated_wave = vocos.decode(generated_mel_spec.cpu())
153
+ if rms < target_rms:
154
+ generated_wave = generated_wave * rms / target_rms
155
+
156
+ # wav -> numpy
157
+ generated_wave = generated_wave.squeeze().cpu().numpy()
158
+
159
+ if remove_silence:
160
+ gr.Info("Removing audio silences")
161
+ non_silent_intervals = librosa.effects.split(generated_wave, top_db=30)
162
+ non_silent_wave = np.array([])
163
+ for interval in non_silent_intervals:
164
+ start, end = interval
165
+ non_silent_wave = np.concatenate([non_silent_wave, generated_wave[start:end]])
166
+ generated_wave = non_silent_wave
167
+
168
+
169
+ # spectogram
170
+ with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
171
+ spectrogram_path = tmp_spectrogram.name
172
+ save_spectrogram(generated_mel_spec[0].cpu().numpy(), spectrogram_path)
173
+
174
+ return (target_sample_rate, generated_wave), spectrogram_path
175
+
176
+ with gr.Blocks() as app:
177
+ gr.Markdown("""
178
+ # E2/F5 TTS
179
+
180
+ This is an unofficial E2/F5 TTS demo. This demo supports the following TTS models:
181
+
182
+ * [E2-TTS](https://arxiv.org/abs/2406.18009) (Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS)
183
+ * [F5-TTS](https://arxiv.org/abs/2410.06885) (A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching)
184
+
185
+ This demo is based on the [F5-TTS](https://github.com/SWivid/F5-TTS) codebase, which is based on an [unofficial E2-TTS implementation](https://github.com/lucidrains/e2-tts-pytorch).
186
+
187
+ The checkpoints support English and Chinese.
188
+
189
+ **NOTE: Reference text will be automatically transcribed with Whisper if not provided. For best results, keep your reference clips short (<15s). Ensure the audio is fully uploaded before generating.**
190
+ """)
191
+
192
+ ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
193
+ gen_text_input = gr.Textbox(label="Text to Generate (max 200 chars.)", lines=4)
194
+ model_choice = gr.Radio(choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS")
195
+ generate_btn = gr.Button("Synthesize", variant="primary")
196
+ with gr.Accordion("Advanced Settings", open=False):
197
+ ref_text_input = gr.Textbox(label="Reference Text", info="Leave blank to automatically transcribe the reference audio. If you enter text it will override automatic transcription.", lines=2)
198
+ remove_silence = gr.Checkbox(label="[EXPERIMENTAL] Remove Silences", info="The model tends to leave silences, we can manually remove silences if needed. This may produce strange results and is not guarenteed to work.")
199
+
200
+ audio_output = gr.Audio(label="Synthesized Audio")
201
+ spectrogram_output = gr.Image(label="Spectrogram")
202
+
203
+ generate_btn.click(infer, inputs=[ref_audio_input, ref_text_input, gen_text_input, model_choice, remove_silence], outputs=[audio_output, spectrogram_output])
204
+ gr.Markdown("Unofficial demo by [mrfakename](https://x.com/realmrfakename)")
205
+
206
+
207
+ app.queue().launch()
ckpts/README.md ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ Pretrained model ckpts. https://huggingface.co/SWivid/F5-TTS
3
+
4
+ ```
5
+ ckpts/
6
+ E2TTS_Base/
7
+ model_1200000.pt
8
+ F5TTS_Base/
9
+ model_1200000.pt
10
+ ```
data/Emilia_ZH_EN_pinyin/vocab.txt ADDED
@@ -0,0 +1,2545 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ !
3
+ "
4
+ #
5
+ $
6
+ %
7
+ &
8
+ '
9
+ (
10
+ )
11
+ *
12
+ +
13
+ ,
14
+ -
15
+ .
16
+ /
17
+ 0
18
+ 1
19
+ 2
20
+ 3
21
+ 4
22
+ 5
23
+ 6
24
+ 7
25
+ 8
26
+ 9
27
+ :
28
+ ;
29
+ =
30
+ >
31
+ ?
32
+ @
33
+ A
34
+ B
35
+ C
36
+ D
37
+ E
38
+ F
39
+ G
40
+ H
41
+ I
42
+ J
43
+ K
44
+ L
45
+ M
46
+ N
47
+ O
48
+ P
49
+ Q
50
+ R
51
+ S
52
+ T
53
+ U
54
+ V
55
+ W
56
+ X
57
+ Y
58
+ Z
59
+ [
60
+ \
61
+ ]
62
+ _
63
+ a
64
+ a1
65
+ ai1
66
+ ai2
67
+ ai3
68
+ ai4
69
+ an1
70
+ an3
71
+ an4
72
+ ang1
73
+ ang2
74
+ ang4
75
+ ao1
76
+ ao2
77
+ ao3
78
+ ao4
79
+ b
80
+ ba
81
+ ba1
82
+ ba2
83
+ ba3
84
+ ba4
85
+ bai1
86
+ bai2
87
+ bai3
88
+ bai4
89
+ ban1
90
+ ban2
91
+ ban3
92
+ ban4
93
+ bang1
94
+ bang2
95
+ bang3
96
+ bang4
97
+ bao1
98
+ bao2
99
+ bao3
100
+ bao4
101
+ bei
102
+ bei1
103
+ bei2
104
+ bei3
105
+ bei4
106
+ ben1
107
+ ben2
108
+ ben3
109
+ ben4
110
+ beng
111
+ beng1
112
+ beng2
113
+ beng3
114
+ beng4
115
+ bi1
116
+ bi2
117
+ bi3
118
+ bi4
119
+ bian1
120
+ bian2
121
+ bian3
122
+ bian4
123
+ biao1
124
+ biao2
125
+ biao3
126
+ bie1
127
+ bie2
128
+ bie3
129
+ bie4
130
+ bin1
131
+ bin4
132
+ bing1
133
+ bing2
134
+ bing3
135
+ bing4
136
+ bo
137
+ bo1
138
+ bo2
139
+ bo3
140
+ bo4
141
+ bu2
142
+ bu3
143
+ bu4
144
+ c
145
+ ca1
146
+ cai1
147
+ cai2
148
+ cai3
149
+ cai4
150
+ can1
151
+ can2
152
+ can3
153
+ can4
154
+ cang1
155
+ cang2
156
+ cao1
157
+ cao2
158
+ cao3
159
+ ce4
160
+ cen1
161
+ cen2
162
+ ceng1
163
+ ceng2
164
+ ceng4
165
+ cha1
166
+ cha2
167
+ cha3
168
+ cha4
169
+ chai1
170
+ chai2
171
+ chan1
172
+ chan2
173
+ chan3
174
+ chan4
175
+ chang1
176
+ chang2
177
+ chang3
178
+ chang4
179
+ chao1
180
+ chao2
181
+ chao3
182
+ che1
183
+ che2
184
+ che3
185
+ che4
186
+ chen1
187
+ chen2
188
+ chen3
189
+ chen4
190
+ cheng1
191
+ cheng2
192
+ cheng3
193
+ cheng4
194
+ chi1
195
+ chi2
196
+ chi3
197
+ chi4
198
+ chong1
199
+ chong2
200
+ chong3
201
+ chong4
202
+ chou1
203
+ chou2
204
+ chou3
205
+ chou4
206
+ chu1
207
+ chu2
208
+ chu3
209
+ chu4
210
+ chua1
211
+ chuai1
212
+ chuai2
213
+ chuai3
214
+ chuai4
215
+ chuan1
216
+ chuan2
217
+ chuan3
218
+ chuan4
219
+ chuang1
220
+ chuang2
221
+ chuang3
222
+ chuang4
223
+ chui1
224
+ chui2
225
+ chun1
226
+ chun2
227
+ chun3
228
+ chuo1
229
+ chuo4
230
+ ci1
231
+ ci2
232
+ ci3
233
+ ci4
234
+ cong1
235
+ cong2
236
+ cou4
237
+ cu1
238
+ cu4
239
+ cuan1
240
+ cuan2
241
+ cuan4
242
+ cui1
243
+ cui3
244
+ cui4
245
+ cun1
246
+ cun2
247
+ cun4
248
+ cuo1
249
+ cuo2
250
+ cuo4
251
+ d
252
+ da
253
+ da1
254
+ da2
255
+ da3
256
+ da4
257
+ dai1
258
+ dai2
259
+ dai3
260
+ dai4
261
+ dan1
262
+ dan2
263
+ dan3
264
+ dan4
265
+ dang1
266
+ dang2
267
+ dang3
268
+ dang4
269
+ dao1
270
+ dao2
271
+ dao3
272
+ dao4
273
+ de
274
+ de1
275
+ de2
276
+ dei3
277
+ den4
278
+ deng1
279
+ deng2
280
+ deng3
281
+ deng4
282
+ di1
283
+ di2
284
+ di3
285
+ di4
286
+ dia3
287
+ dian1
288
+ dian2
289
+ dian3
290
+ dian4
291
+ diao1
292
+ diao3
293
+ diao4
294
+ die1
295
+ die2
296
+ die4
297
+ ding1
298
+ ding2
299
+ ding3
300
+ ding4
301
+ diu1
302
+ dong1
303
+ dong3
304
+ dong4
305
+ dou1
306
+ dou2
307
+ dou3
308
+ dou4
309
+ du1
310
+ du2
311
+ du3
312
+ du4
313
+ duan1
314
+ duan2
315
+ duan3
316
+ duan4
317
+ dui1
318
+ dui4
319
+ dun1
320
+ dun3
321
+ dun4
322
+ duo1
323
+ duo2
324
+ duo3
325
+ duo4
326
+ e
327
+ e1
328
+ e2
329
+ e3
330
+ e4
331
+ ei2
332
+ en1
333
+ en4
334
+ er
335
+ er2
336
+ er3
337
+ er4
338
+ f
339
+ fa1
340
+ fa2
341
+ fa3
342
+ fa4
343
+ fan1
344
+ fan2
345
+ fan3
346
+ fan4
347
+ fang1
348
+ fang2
349
+ fang3
350
+ fang4
351
+ fei1
352
+ fei2
353
+ fei3
354
+ fei4
355
+ fen1
356
+ fen2
357
+ fen3
358
+ fen4
359
+ feng1
360
+ feng2
361
+ feng3
362
+ feng4
363
+ fo2
364
+ fou2
365
+ fou3
366
+ fu1
367
+ fu2
368
+ fu3
369
+ fu4
370
+ g
371
+ ga1
372
+ ga2
373
+ ga3
374
+ ga4
375
+ gai1
376
+ gai2
377
+ gai3
378
+ gai4
379
+ gan1
380
+ gan2
381
+ gan3
382
+ gan4
383
+ gang1
384
+ gang2
385
+ gang3
386
+ gang4
387
+ gao1
388
+ gao2
389
+ gao3
390
+ gao4
391
+ ge1
392
+ ge2
393
+ ge3
394
+ ge4
395
+ gei2
396
+ gei3
397
+ gen1
398
+ gen2
399
+ gen3
400
+ gen4
401
+ geng1
402
+ geng3
403
+ geng4
404
+ gong1
405
+ gong3
406
+ gong4
407
+ gou1
408
+ gou2
409
+ gou3
410
+ gou4
411
+ gu
412
+ gu1
413
+ gu2
414
+ gu3
415
+ gu4
416
+ gua1
417
+ gua2
418
+ gua3
419
+ gua4
420
+ guai1
421
+ guai2
422
+ guai3
423
+ guai4
424
+ guan1
425
+ guan2
426
+ guan3
427
+ guan4
428
+ guang1
429
+ guang2
430
+ guang3
431
+ guang4
432
+ gui1
433
+ gui2
434
+ gui3
435
+ gui4
436
+ gun3
437
+ gun4
438
+ guo1
439
+ guo2
440
+ guo3
441
+ guo4
442
+ h
443
+ ha1
444
+ ha2
445
+ ha3
446
+ hai1
447
+ hai2
448
+ hai3
449
+ hai4
450
+ han1
451
+ han2
452
+ han3
453
+ han4
454
+ hang1
455
+ hang2
456
+ hang4
457
+ hao1
458
+ hao2
459
+ hao3
460
+ hao4
461
+ he1
462
+ he2
463
+ he4
464
+ hei1
465
+ hen2
466
+ hen3
467
+ hen4
468
+ heng1
469
+ heng2
470
+ heng4
471
+ hong1
472
+ hong2
473
+ hong3
474
+ hong4
475
+ hou1
476
+ hou2
477
+ hou3
478
+ hou4
479
+ hu1
480
+ hu2
481
+ hu3
482
+ hu4
483
+ hua1
484
+ hua2
485
+ hua4
486
+ huai2
487
+ huai4
488
+ huan1
489
+ huan2
490
+ huan3
491
+ huan4
492
+ huang1
493
+ huang2
494
+ huang3
495
+ huang4
496
+ hui1
497
+ hui2
498
+ hui3
499
+ hui4
500
+ hun1
501
+ hun2
502
+ hun4
503
+ huo
504
+ huo1
505
+ huo2
506
+ huo3
507
+ huo4
508
+ i
509
+ j
510
+ ji1
511
+ ji2
512
+ ji3
513
+ ji4
514
+ jia
515
+ jia1
516
+ jia2
517
+ jia3
518
+ jia4
519
+ jian1
520
+ jian2
521
+ jian3
522
+ jian4
523
+ jiang1
524
+ jiang2
525
+ jiang3
526
+ jiang4
527
+ jiao1
528
+ jiao2
529
+ jiao3
530
+ jiao4
531
+ jie1
532
+ jie2
533
+ jie3
534
+ jie4
535
+ jin1
536
+ jin2
537
+ jin3
538
+ jin4
539
+ jing1
540
+ jing2
541
+ jing3
542
+ jing4
543
+ jiong3
544
+ jiu1
545
+ jiu2
546
+ jiu3
547
+ jiu4
548
+ ju1
549
+ ju2
550
+ ju3
551
+ ju4
552
+ juan1
553
+ juan2
554
+ juan3
555
+ juan4
556
+ jue1
557
+ jue2
558
+ jue4
559
+ jun1
560
+ jun4
561
+ k
562
+ ka1
563
+ ka2
564
+ ka3
565
+ kai1
566
+ kai2
567
+ kai3
568
+ kai4
569
+ kan1
570
+ kan2
571
+ kan3
572
+ kan4
573
+ kang1
574
+ kang2
575
+ kang4
576
+ kao1
577
+ kao2
578
+ kao3
579
+ kao4
580
+ ke1
581
+ ke2
582
+ ke3
583
+ ke4
584
+ ken3
585
+ keng1
586
+ kong1
587
+ kong3
588
+ kong4
589
+ kou1
590
+ kou2
591
+ kou3
592
+ kou4
593
+ ku1
594
+ ku2
595
+ ku3
596
+ ku4
597
+ kua1
598
+ kua3
599
+ kua4
600
+ kuai3
601
+ kuai4
602
+ kuan1
603
+ kuan2
604
+ kuan3
605
+ kuang1
606
+ kuang2
607
+ kuang4
608
+ kui1
609
+ kui2
610
+ kui3
611
+ kui4
612
+ kun1
613
+ kun3
614
+ kun4
615
+ kuo4
616
+ l
617
+ la
618
+ la1
619
+ la2
620
+ la3
621
+ la4
622
+ lai2
623
+ lai4
624
+ lan2
625
+ lan3
626
+ lan4
627
+ lang1
628
+ lang2
629
+ lang3
630
+ lang4
631
+ lao1
632
+ lao2
633
+ lao3
634
+ lao4
635
+ le
636
+ le1
637
+ le4
638
+ lei
639
+ lei1
640
+ lei2
641
+ lei3
642
+ lei4
643
+ leng1
644
+ leng2
645
+ leng3
646
+ leng4
647
+ li
648
+ li1
649
+ li2
650
+ li3
651
+ li4
652
+ lia3
653
+ lian2
654
+ lian3
655
+ lian4
656
+ liang2
657
+ liang3
658
+ liang4
659
+ liao1
660
+ liao2
661
+ liao3
662
+ liao4
663
+ lie1
664
+ lie2
665
+ lie3
666
+ lie4
667
+ lin1
668
+ lin2
669
+ lin3
670
+ lin4
671
+ ling2
672
+ ling3
673
+ ling4
674
+ liu1
675
+ liu2
676
+ liu3
677
+ liu4
678
+ long1
679
+ long2
680
+ long3
681
+ long4
682
+ lou1
683
+ lou2
684
+ lou3
685
+ lou4
686
+ lu1
687
+ lu2
688
+ lu3
689
+ lu4
690
+ luan2
691
+ luan3
692
+ luan4
693
+ lun1
694
+ lun2
695
+ lun4
696
+ luo1
697
+ luo2
698
+ luo3
699
+ luo4
700
+ lv2
701
+ lv3
702
+ lv4
703
+ lve3
704
+ lve4
705
+ m
706
+ ma
707
+ ma1
708
+ ma2
709
+ ma3
710
+ ma4
711
+ mai2
712
+ mai3
713
+ mai4
714
+ man1
715
+ man2
716
+ man3
717
+ man4
718
+ mang2
719
+ mang3
720
+ mao1
721
+ mao2
722
+ mao3
723
+ mao4
724
+ me
725
+ mei2
726
+ mei3
727
+ mei4
728
+ men
729
+ men1
730
+ men2
731
+ men4
732
+ meng
733
+ meng1
734
+ meng2
735
+ meng3
736
+ meng4
737
+ mi1
738
+ mi2
739
+ mi3
740
+ mi4
741
+ mian2
742
+ mian3
743
+ mian4
744
+ miao1
745
+ miao2
746
+ miao3
747
+ miao4
748
+ mie1
749
+ mie4
750
+ min2
751
+ min3
752
+ ming2
753
+ ming3
754
+ ming4
755
+ miu4
756
+ mo1
757
+ mo2
758
+ mo3
759
+ mo4
760
+ mou1
761
+ mou2
762
+ mou3
763
+ mu2
764
+ mu3
765
+ mu4
766
+ n
767
+ n2
768
+ na1
769
+ na2
770
+ na3
771
+ na4
772
+ nai2
773
+ nai3
774
+ nai4
775
+ nan1
776
+ nan2
777
+ nan3
778
+ nan4
779
+ nang1
780
+ nang2
781
+ nang3
782
+ nao1
783
+ nao2
784
+ nao3
785
+ nao4
786
+ ne
787
+ ne2
788
+ ne4
789
+ nei3
790
+ nei4
791
+ nen4
792
+ neng2
793
+ ni1
794
+ ni2
795
+ ni3
796
+ ni4
797
+ nian1
798
+ nian2
799
+ nian3
800
+ nian4
801
+ niang2
802
+ niang4
803
+ niao2
804
+ niao3
805
+ niao4
806
+ nie1
807
+ nie4
808
+ nin2
809
+ ning2
810
+ ning3
811
+ ning4
812
+ niu1
813
+ niu2
814
+ niu3
815
+ niu4
816
+ nong2
817
+ nong4
818
+ nou4
819
+ nu2
820
+ nu3
821
+ nu4
822
+ nuan3
823
+ nuo2
824
+ nuo4
825
+ nv2
826
+ nv3
827
+ nve4
828
+ o
829
+ o1
830
+ o2
831
+ ou1
832
+ ou2
833
+ ou3
834
+ ou4
835
+ p
836
+ pa1
837
+ pa2
838
+ pa4
839
+ pai1
840
+ pai2
841
+ pai3
842
+ pai4
843
+ pan1
844
+ pan2
845
+ pan4
846
+ pang1
847
+ pang2
848
+ pang4
849
+ pao1
850
+ pao2
851
+ pao3
852
+ pao4
853
+ pei1
854
+ pei2
855
+ pei4
856
+ pen1
857
+ pen2
858
+ pen4
859
+ peng1
860
+ peng2
861
+ peng3
862
+ peng4
863
+ pi1
864
+ pi2
865
+ pi3
866
+ pi4
867
+ pian1
868
+ pian2
869
+ pian4
870
+ piao1
871
+ piao2
872
+ piao3
873
+ piao4
874
+ pie1
875
+ pie2
876
+ pie3
877
+ pin1
878
+ pin2
879
+ pin3
880
+ pin4
881
+ ping1
882
+ ping2
883
+ po1
884
+ po2
885
+ po3
886
+ po4
887
+ pou1
888
+ pu1
889
+ pu2
890
+ pu3
891
+ pu4
892
+ q
893
+ qi1
894
+ qi2
895
+ qi3
896
+ qi4
897
+ qia1
898
+ qia3
899
+ qia4
900
+ qian1
901
+ qian2
902
+ qian3
903
+ qian4
904
+ qiang1
905
+ qiang2
906
+ qiang3
907
+ qiang4
908
+ qiao1
909
+ qiao2
910
+ qiao3
911
+ qiao4
912
+ qie1
913
+ qie2
914
+ qie3
915
+ qie4
916
+ qin1
917
+ qin2
918
+ qin3
919
+ qin4
920
+ qing1
921
+ qing2
922
+ qing3
923
+ qing4
924
+ qiong1
925
+ qiong2
926
+ qiu1
927
+ qiu2
928
+ qiu3
929
+ qu1
930
+ qu2
931
+ qu3
932
+ qu4
933
+ quan1
934
+ quan2
935
+ quan3
936
+ quan4
937
+ que1
938
+ que2
939
+ que4
940
+ qun2
941
+ r
942
+ ran2
943
+ ran3
944
+ rang1
945
+ rang2
946
+ rang3
947
+ rang4
948
+ rao2
949
+ rao3
950
+ rao4
951
+ re2
952
+ re3
953
+ re4
954
+ ren2
955
+ ren3
956
+ ren4
957
+ reng1
958
+ reng2
959
+ ri4
960
+ rong1
961
+ rong2
962
+ rong3
963
+ rou2
964
+ rou4
965
+ ru2
966
+ ru3
967
+ ru4
968
+ ruan2
969
+ ruan3
970
+ rui3
971
+ rui4
972
+ run4
973
+ ruo4
974
+ s
975
+ sa1
976
+ sa2
977
+ sa3
978
+ sa4
979
+ sai1
980
+ sai4
981
+ san1
982
+ san2
983
+ san3
984
+ san4
985
+ sang1
986
+ sang3
987
+ sang4
988
+ sao1
989
+ sao2
990
+ sao3
991
+ sao4
992
+ se4
993
+ sen1
994
+ seng1
995
+ sha1
996
+ sha2
997
+ sha3
998
+ sha4
999
+ shai1
1000
+ shai2
1001
+ shai3
1002
+ shai4
1003
+ shan1
1004
+ shan3
1005
+ shan4
1006
+ shang
1007
+ shang1
1008
+ shang3
1009
+ shang4
1010
+ shao1
1011
+ shao2
1012
+ shao3
1013
+ shao4
1014
+ she1
1015
+ she2
1016
+ she3
1017
+ she4
1018
+ shei2
1019
+ shen1
1020
+ shen2
1021
+ shen3
1022
+ shen4
1023
+ sheng1
1024
+ sheng2
1025
+ sheng3
1026
+ sheng4
1027
+ shi
1028
+ shi1
1029
+ shi2
1030
+ shi3
1031
+ shi4
1032
+ shou1
1033
+ shou2
1034
+ shou3
1035
+ shou4
1036
+ shu1
1037
+ shu2
1038
+ shu3
1039
+ shu4
1040
+ shua1
1041
+ shua2
1042
+ shua3
1043
+ shua4
1044
+ shuai1
1045
+ shuai3
1046
+ shuai4
1047
+ shuan1
1048
+ shuan4
1049
+ shuang1
1050
+ shuang3
1051
+ shui2
1052
+ shui3
1053
+ shui4
1054
+ shun3
1055
+ shun4
1056
+ shuo1
1057
+ shuo4
1058
+ si1
1059
+ si2
1060
+ si3
1061
+ si4
1062
+ song1
1063
+ song3
1064
+ song4
1065
+ sou1
1066
+ sou3
1067
+ sou4
1068
+ su1
1069
+ su2
1070
+ su4
1071
+ suan1
1072
+ suan4
1073
+ sui1
1074
+ sui2
1075
+ sui3
1076
+ sui4
1077
+ sun1
1078
+ sun3
1079
+ suo
1080
+ suo1
1081
+ suo2
1082
+ suo3
1083
+ t
1084
+ ta1
1085
+ ta2
1086
+ ta3
1087
+ ta4
1088
+ tai1
1089
+ tai2
1090
+ tai4
1091
+ tan1
1092
+ tan2
1093
+ tan3
1094
+ tan4
1095
+ tang1
1096
+ tang2
1097
+ tang3
1098
+ tang4
1099
+ tao1
1100
+ tao2
1101
+ tao3
1102
+ tao4
1103
+ te4
1104
+ teng2
1105
+ ti1
1106
+ ti2
1107
+ ti3
1108
+ ti4
1109
+ tian1
1110
+ tian2
1111
+ tian3
1112
+ tiao1
1113
+ tiao2
1114
+ tiao3
1115
+ tiao4
1116
+ tie1
1117
+ tie2
1118
+ tie3
1119
+ tie4
1120
+ ting1
1121
+ ting2
1122
+ ting3
1123
+ tong1
1124
+ tong2
1125
+ tong3
1126
+ tong4
1127
+ tou
1128
+ tou1
1129
+ tou2
1130
+ tou4
1131
+ tu1
1132
+ tu2
1133
+ tu3
1134
+ tu4
1135
+ tuan1
1136
+ tuan2
1137
+ tui1
1138
+ tui2
1139
+ tui3
1140
+ tui4
1141
+ tun1
1142
+ tun2
1143
+ tun4
1144
+ tuo1
1145
+ tuo2
1146
+ tuo3
1147
+ tuo4
1148
+ u
1149
+ v
1150
+ w
1151
+ wa
1152
+ wa1
1153
+ wa2
1154
+ wa3
1155
+ wa4
1156
+ wai1
1157
+ wai3
1158
+ wai4
1159
+ wan1
1160
+ wan2
1161
+ wan3
1162
+ wan4
1163
+ wang1
1164
+ wang2
1165
+ wang3
1166
+ wang4
1167
+ wei1
1168
+ wei2
1169
+ wei3
1170
+ wei4
1171
+ wen1
1172
+ wen2
1173
+ wen3
1174
+ wen4
1175
+ weng1
1176
+ weng4
1177
+ wo1
1178
+ wo2
1179
+ wo3
1180
+ wo4
1181
+ wu1
1182
+ wu2
1183
+ wu3
1184
+ wu4
1185
+ x
1186
+ xi1
1187
+ xi2
1188
+ xi3
1189
+ xi4
1190
+ xia1
1191
+ xia2
1192
+ xia4
1193
+ xian1
1194
+ xian2
1195
+ xian3
1196
+ xian4
1197
+ xiang1
1198
+ xiang2
1199
+ xiang3
1200
+ xiang4
1201
+ xiao1
1202
+ xiao2
1203
+ xiao3
1204
+ xiao4
1205
+ xie1
1206
+ xie2
1207
+ xie3
1208
+ xie4
1209
+ xin1
1210
+ xin2
1211
+ xin4
1212
+ xing1
1213
+ xing2
1214
+ xing3
1215
+ xing4
1216
+ xiong1
1217
+ xiong2
1218
+ xiu1
1219
+ xiu3
1220
+ xiu4
1221
+ xu
1222
+ xu1
1223
+ xu2
1224
+ xu3
1225
+ xu4
1226
+ xuan1
1227
+ xuan2
1228
+ xuan3
1229
+ xuan4
1230
+ xue1
1231
+ xue2
1232
+ xue3
1233
+ xue4
1234
+ xun1
1235
+ xun2
1236
+ xun4
1237
+ y
1238
+ ya
1239
+ ya1
1240
+ ya2
1241
+ ya3
1242
+ ya4
1243
+ yan1
1244
+ yan2
1245
+ yan3
1246
+ yan4
1247
+ yang1
1248
+ yang2
1249
+ yang3
1250
+ yang4
1251
+ yao1
1252
+ yao2
1253
+ yao3
1254
+ yao4
1255
+ ye1
1256
+ ye2
1257
+ ye3
1258
+ ye4
1259
+ yi
1260
+ yi1
1261
+ yi2
1262
+ yi3
1263
+ yi4
1264
+ yin1
1265
+ yin2
1266
+ yin3
1267
+ yin4
1268
+ ying1
1269
+ ying2
1270
+ ying3
1271
+ ying4
1272
+ yo1
1273
+ yong1
1274
+ yong2
1275
+ yong3
1276
+ yong4
1277
+ you1
1278
+ you2
1279
+ you3
1280
+ you4
1281
+ yu1
1282
+ yu2
1283
+ yu3
1284
+ yu4
1285
+ yuan1
1286
+ yuan2
1287
+ yuan3
1288
+ yuan4
1289
+ yue1
1290
+ yue4
1291
+ yun1
1292
+ yun2
1293
+ yun3
1294
+ yun4
1295
+ z
1296
+ za1
1297
+ za2
1298
+ za3
1299
+ zai1
1300
+ zai3
1301
+ zai4
1302
+ zan1
1303
+ zan2
1304
+ zan3
1305
+ zan4
1306
+ zang1
1307
+ zang4
1308
+ zao1
1309
+ zao2
1310
+ zao3
1311
+ zao4
1312
+ ze2
1313
+ ze4
1314
+ zei2
1315
+ zen3
1316
+ zeng1
1317
+ zeng4
1318
+ zha1
1319
+ zha2
1320
+ zha3
1321
+ zha4
1322
+ zhai1
1323
+ zhai2
1324
+ zhai3
1325
+ zhai4
1326
+ zhan1
1327
+ zhan2
1328
+ zhan3
1329
+ zhan4
1330
+ zhang1
1331
+ zhang2
1332
+ zhang3
1333
+ zhang4
1334
+ zhao1
1335
+ zhao2
1336
+ zhao3
1337
+ zhao4
1338
+ zhe
1339
+ zhe1
1340
+ zhe2
1341
+ zhe3
1342
+ zhe4
1343
+ zhen1
1344
+ zhen2
1345
+ zhen3
1346
+ zhen4
1347
+ zheng1
1348
+ zheng2
1349
+ zheng3
1350
+ zheng4
1351
+ zhi1
1352
+ zhi2
1353
+ zhi3
1354
+ zhi4
1355
+ zhong1
1356
+ zhong2
1357
+ zhong3
1358
+ zhong4
1359
+ zhou1
1360
+ zhou2
1361
+ zhou3
1362
+ zhou4
1363
+ zhu1
1364
+ zhu2
1365
+ zhu3
1366
+ zhu4
1367
+ zhua1
1368
+ zhua2
1369
+ zhua3
1370
+ zhuai1
1371
+ zhuai3
1372
+ zhuai4
1373
+ zhuan1
1374
+ zhuan2
1375
+ zhuan3
1376
+ zhuan4
1377
+ zhuang1
1378
+ zhuang4
1379
+ zhui1
1380
+ zhui4
1381
+ zhun1
1382
+ zhun2
1383
+ zhun3
1384
+ zhuo1
1385
+ zhuo2
1386
+ zi
1387
+ zi1
1388
+ zi2
1389
+ zi3
1390
+ zi4
1391
+ zong1
1392
+ zong2
1393
+ zong3
1394
+ zong4
1395
+ zou1
1396
+ zou2
1397
+ zou3
1398
+ zou4
1399
+ zu1
1400
+ zu2
1401
+ zu3
1402
+ zuan1
1403
+ zuan3
1404
+ zuan4
1405
+ zui2
1406
+ zui3
1407
+ zui4
1408
+ zun1
1409
+ zuo
1410
+ zuo1
1411
+ zuo2
1412
+ zuo3
1413
+ zuo4
1414
+ {
1415
+ ~
1416
+ ¡
1417
+ ¢
1418
+ £
1419
+ ¥
1420
+ §
1421
+ ¨
1422
+ ©
1423
+ «
1424
+ ®
1425
+ ¯
1426
+ °
1427
+ ±
1428
+ ²
1429
+ ³
1430
+ ´
1431
+ µ
1432
+ ·
1433
+ ¹
1434
+ º
1435
+ »
1436
+ ¼
1437
+ ½
1438
+ ¾
1439
+ ¿
1440
+ À
1441
+ Á
1442
+ Â
1443
+ Ã
1444
+ Ä
1445
+ Å
1446
+ Æ
1447
+ Ç
1448
+ È
1449
+ É
1450
+ Ê
1451
+ Í
1452
+ Î
1453
+ Ñ
1454
+ Ó
1455
+ Ö
1456
+ ×
1457
+ Ø
1458
+ Ú
1459
+ Ü
1460
+ Ý
1461
+ Þ
1462
+ ß
1463
+ à
1464
+ á
1465
+ â
1466
+ ã
1467
+ ä
1468
+ å
1469
+ æ
1470
+ ç
1471
+ è
1472
+ é
1473
+ ê
1474
+ ë
1475
+ ì
1476
+ í
1477
+ î
1478
+ ï
1479
+ ð
1480
+ ñ
1481
+ ò
1482
+ ó
1483
+ ô
1484
+ õ
1485
+ ö
1486
+ ø
1487
+ ù
1488
+ ú
1489
+ û
1490
+ ü
1491
+ ý
1492
+ Ā
1493
+ ā
1494
+ ă
1495
+ ą
1496
+ ć
1497
+ Č
1498
+ č
1499
+ Đ
1500
+ đ
1501
+ ē
1502
+ ė
1503
+ ę
1504
+ ě
1505
+ ĝ
1506
+ ğ
1507
+ ħ
1508
+ ī
1509
+ į
1510
+ İ
1511
+ ı
1512
+ Ł
1513
+ ł
1514
+ ń
1515
+ ņ
1516
+ ň
1517
+ ŋ
1518
+ Ō
1519
+ ō
1520
+ ő
1521
+ œ
1522
+ ř
1523
+ Ś
1524
+ ś
1525
+ Ş
1526
+ ş
1527
+ Š
1528
+ š
1529
+ Ť
1530
+ ť
1531
+ ũ
1532
+ ū
1533
+ ź
1534
+ Ż
1535
+ ż
1536
+ Ž
1537
+ ž
1538
+ ơ
1539
+ ư
1540
+ ǎ
1541
+ ǐ
1542
+ ǒ
1543
+ ǔ
1544
+ ǚ
1545
+ ș
1546
+ ț
1547
+ ɑ
1548
+ ɔ
1549
+ ɕ
1550
+ ə
1551
+ ɛ
1552
+ ɜ
1553
+ ɡ
1554
+ ɣ
1555
+ ɪ
1556
+ ɫ
1557
+ ɴ
1558
+ ɹ
1559
+ ɾ
1560
+ ʃ
1561
+ ʊ
1562
+ ʌ
1563
+ ʒ
1564
+ ʔ
1565
+ ʰ
1566
+ ʷ
1567
+ ʻ
1568
+ ʾ
1569
+ ʿ
1570
+ ˈ
1571
+ ː
1572
+ ˙
1573
+ ˜
1574
+ ˢ
1575
+ ́
1576
+ ̅
1577
+ Α
1578
+ Β
1579
+ Δ
1580
+ Ε
1581
+ Θ
1582
+ Κ
1583
+ Λ
1584
+ Μ
1585
+ Ξ
1586
+ Π
1587
+ Σ
1588
+ Τ
1589
+ Φ
1590
+ Χ
1591
+ Ψ
1592
+ Ω
1593
+ ά
1594
+ έ
1595
+ ή
1596
+ ί
1597
+ α
1598
+ β
1599
+ γ
1600
+ δ
1601
+ ε
1602
+ ζ
1603
+ η
1604
+ θ
1605
+ ι
1606
+ κ
1607
+ λ
1608
+ μ
1609
+ ν
1610
+ ξ
1611
+ ο
1612
+ π
1613
+ ρ
1614
+ ς
1615
+ σ
1616
+ τ
1617
+ υ
1618
+ φ
1619
+ χ
1620
+ ψ
1621
+ ω
1622
+ ϊ
1623
+ ό
1624
+ ύ
1625
+ ώ
1626
+ ϕ
1627
+ ϵ
1628
+ Ё
1629
+ А
1630
+ Б
1631
+ В
1632
+ Г
1633
+ Д
1634
+ Е
1635
+ Ж
1636
+ З
1637
+ И
1638
+ Й
1639
+ К
1640
+ Л
1641
+ М
1642
+ Н
1643
+ О
1644
+ П
1645
+ Р
1646
+ С
1647
+ Т
1648
+ У
1649
+ Ф
1650
+ Х
1651
+ Ц
1652
+ Ч
1653
+ Ш
1654
+ Щ
1655
+ Ы
1656
+ Ь
1657
+ Э
1658
+ Ю
1659
+ Я
1660
+ а
1661
+ б
1662
+ в
1663
+ г
1664
+ д
1665
+ е
1666
+ ж
1667
+ з
1668
+ и
1669
+ й
1670
+ к
1671
+ л
1672
+ м
1673
+ н
1674
+ о
1675
+ п
1676
+ р
1677
+ с
1678
+ т
1679
+ у
1680
+ ф
1681
+ х
1682
+ ц
1683
+ ч
1684
+ ш
1685
+ щ
1686
+ ъ
1687
+ ы
1688
+ ь
1689
+ э
1690
+ ю
1691
+ я
1692
+ ё
1693
+ і
1694
+ ְ
1695
+ ִ
1696
+ ֵ
1697
+ ֶ
1698
+ ַ
1699
+ ָ
1700
+ ֹ
1701
+ ּ
1702
+ ־
1703
+ ׁ
1704
+ א
1705
+ ב
1706
+ ג
1707
+ ד
1708
+ ה
1709
+ ו
1710
+ ז
1711
+ ח
1712
+ ט
1713
+ י
1714
+ כ
1715
+ ל
1716
+ ם
1717
+ מ
1718
+ ן
1719
+ נ
1720
+ ס
1721
+ ע
1722
+ פ
1723
+ ק
1724
+ ר
1725
+ ש
1726
+ ת
1727
+ أ
1728
+ ب
1729
+ ة
1730
+ ت
1731
+ ج
1732
+ ح
1733
+ د
1734
+ ر
1735
+ ز
1736
+ س
1737
+ ص
1738
+ ط
1739
+ ع
1740
+ ق
1741
+ ك
1742
+ ل
1743
+ م
1744
+ ن
1745
+ ه
1746
+ و
1747
+ ي
1748
+ َ
1749
+ ُ
1750
+ ِ
1751
+ ْ
1752
+
1753
+
1754
+
1755
+
1756
+
1757
+
1758
+
1759
+
1760
+
1761
+
1762
+
1763
+
1764
+
1765
+
1766
+
1767
+
1768
+
1769
+
1770
+
1771
+
1772
+
1773
+
1774
+
1775
+
1776
+
1777
+
1778
+
1779
+
1780
+
1781
+
1782
+
1783
+
1784
+
1785
+
1786
+
1787
+
1788
+
1789
+
1790
+
1791
+
1792
+
1793
+
1794
+
1795
+
1796
+
1797
+
1798
+
1799
+
1800
+ ế
1801
+
1802
+
1803
+
1804
+
1805
+
1806
+
1807
+
1808
+
1809
+
1810
+
1811
+
1812
+
1813
+
1814
+
1815
+
1816
+
1817
+
1818
+
1819
+
1820
+
1821
+
1822
+
1823
+
1824
+
1825
+
1826
+
1827
+
1828
+
1829
+
1830
+ ���
1831
+
1832
+
1833
+
1834
+
1835
+
1836
+
1837
+
1838
+
1839
+
1840
+
1841
+
1842
+
1843
+
1844
+
1845
+
1846
+
1847
+
1848
+
1849
+
1850
+
1851
+
1852
+
1853
+
1854
+
1855
+
1856
+
1857
+
1858
+
1859
+
1860
+
1861
+
1862
+
1863
+
1864
+
1865
+
1866
+
1867
+
1868
+
1869
+
1870
+
1871
+
1872
+
1873
+
1874
+
1875
+
1876
+
1877
+
1878
+
1879
+
1880
+
1881
+
1882
+
1883
+
1884
+
1885
+
1886
+
1887
+
1888
+
1889
+
1890
+
1891
+
1892
+
1893
+
1894
+
1895
+
1896
+
1897
+
1898
+
1899
+
1900
+
1901
+
1902
+
1903
+
1904
+
1905
+
1906
+
1907
+
1908
+
1909
+
1910
+
1911
+
1912
+
1913
+
1914
+
1915
+
1916
+
1917
+
1918
+
1919
+
1920
+
1921
+
1922
+
1923
+
1924
+
1925
+
1926
+
1927
+
1928
+
1929
+
1930
+
1931
+
1932
+
1933
+
1934
+
1935
+
1936
+
1937
+
1938
+
1939
+
1940
+
1941
+
1942
+
1943
+
1944
+
1945
+
1946
+
1947
+
1948
+
1949
+
1950
+
1951
+
1952
+
1953
+
1954
+
1955
+
1956
+
1957
+
1958
+
1959
+
1960
+
1961
+
1962
+
1963
+
1964
+
1965
+
1966
+
1967
+
1968
+
1969
+
1970
+
1971
+
1972
+
1973
+
1974
+
1975
+
1976
+
1977
+
1978
+
1979
+
1980
+
1981
+
1982
+
1983
+
1984
+
1985
+
1986
+
1987
+
1988
+
1989
+
1990
+
1991
+
1992
+
1993
+
1994
+
1995
+
1996
+
1997
+
1998
+
1999
+
2000
+
2001
+
2002
+
2003
+
2004
+
2005
+
2006
+
2007
+
2008
+
2009
+
2010
+
2011
+
2012
+
2013
+
2014
+
2015
+
2016
+
2017
+
2018
+
2019
+
2020
+
2021
+
2022
+
2023
+
2024
+
2025
+
2026
+
2027
+
2028
+
2029
+
2030
+
2031
+
2032
+
2033
+
2034
+
2035
+
2036
+
2037
+
2038
+
2039
+
2040
+
2041
+
2042
+
2043
+
2044
+
2045
+
2046
+
2047
+
2048
+
2049
+
2050
+
2051
+
2052
+
2053
+
2054
+
2055
+
2056
+
2057
+
2058
+
2059
+
2060
+
2061
+
2062
+
2063
+
2064
+
2065
+
2066
+
2067
+
2068
+
2069
+
2070
+
2071
+
2072
+
2073
+
2074
+
2075
+
2076
+
2077
+
2078
+
2079
+
2080
+
2081
+
2082
+
2083
+
2084
+
2085
+
2086
+
2087
+
2088
+
2089
+
2090
+
2091
+
2092
+
2093
+
2094
+
2095
+
2096
+
2097
+
2098
+
2099
+
2100
+
2101
+
2102
+
2103
+
2104
+
2105
+
2106
+
2107
+
2108
+
2109
+
2110
+
2111
+
2112
+
2113
+
2114
+
2115
+
2116
+
2117
+
2118
+
2119
+
2120
+
2121
+
2122
+
2123
+
2124
+
2125
+
2126
+
2127
+
2128
+
2129
+
2130
+
2131
+
2132
+
2133
+
2134
+
2135
+
2136
+
2137
+
2138
+
2139
+
2140
+
2141
+
2142
+
2143
+
2144
+
2145
+
2146
+
2147
+
2148
+
2149
+
2150
+
2151
+
2152
+
2153
+
2154
+
2155
+
2156
+
2157
+
2158
+
2159
+
2160
+
2161
+
2162
+
2163
+
2164
+
2165
+
2166
+
2167
+
2168
+
2169
+
2170
+
2171
+
2172
+
2173
+
2174
+
2175
+
2176
+
2177
+
2178
+
2179
+
2180
+
2181
+
2182
+
2183
+
2184
+
2185
+
2186
+
2187
+
2188
+
2189
+
2190
+
2191
+
2192
+
2193
+
2194
+
2195
+
2196
+
2197
+
2198
+
2199
+
2200
+
2201
+
2202
+
2203
+
2204
+
2205
+
2206
+
2207
+
2208
+
2209
+
2210
+
2211
+
2212
+
2213
+
2214
+
2215
+
2216
+
2217
+
2218
+
2219
+
2220
+
2221
+
2222
+
2223
+
2224
+
2225
+
2226
+
2227
+
2228
+
2229
+
2230
+
2231
+
2232
+
2233
+
2234
+
2235
+
2236
+
2237
+
2238
+
2239
+
2240
+
2241
+
2242
+
2243
+
2244
+
2245
+
2246
+
2247
+
2248
+
2249
+
2250
+
2251
+
2252
+
2253
+
2254
+
2255
+
2256
+
2257
+
2258
+
2259
+
2260
+
2261
+
2262
+
2263
+
2264
+
2265
+
2266
+
2267
+
2268
+
2269
+
2270
+
2271
+
2272
+
2273
+
2274
+
2275
+
2276
+
2277
+
2278
+
2279
+
2280
+
2281
+
2282
+
2283
+
2284
+
2285
+
2286
+
2287
+
2288
+
2289
+
2290
+
2291
+
2292
+
2293
+
2294
+
2295
+
2296
+
2297
+
2298
+
2299
+
2300
+
2301
+
2302
+
2303
+
2304
+
2305
+
2306
+
2307
+
2308
+
2309
+
2310
+
2311
+
2312
+
2313
+
2314
+
2315
+
2316
+
2317
+
2318
+
2319
+
2320
+
2321
+
2322
+
2323
+
2324
+
2325
+
2326
+
2327
+
2328
+
2329
+
2330
+
2331
+
2332
+
2333
+
2334
+
2335
+
2336
+
2337
+
2338
+
2339
+
2340
+
2341
+
2342
+
2343
+
2344
+
2345
+
2346
+
2347
+
2348
+
2349
+
2350
+
2351
+
2352
+
2353
+
2354
+
2355
+
2356
+
2357
+
2358
+
2359
+
2360
+
2361
+
2362
+
2363
+
2364
+
2365
+
2366
+
2367
+
2368
+
2369
+
2370
+
2371
+
2372
+
2373
+
2374
+
2375
+
2376
+
2377
+
2378
+
2379
+
2380
+
2381
+
2382
+
2383
+
2384
+
2385
+
2386
+
2387
+
2388
+
2389
+
2390
+
2391
+
2392
+
2393
+
2394
+
2395
+
2396
+
2397
+
2398
+
2399
+
2400
+
2401
+
2402
+
2403
+
2404
+
2405
+
2406
+
2407
+
2408
+
2409
+
2410
+
2411
+
2412
+
2413
+
2414
+
2415
+
2416
+
2417
+
2418
+
2419
+
2420
+
2421
+
2422
+
2423
+
2424
+
2425
+
2426
+
2427
+
2428
+
2429
+
2430
+
2431
+
2432
+
2433
+
2434
+
2435
+
2436
+
2437
+
2438
+
2439
+
2440
+
2441
+
2442
+
2443
+
2444
+
2445
+
2446
+
2447
+
2448
+
2449
+
2450
+
2451
+
2452
+
2453
+
2454
+
2455
+
2456
+
2457
+
2458
+
2459
+
2460
+
2461
+
2462
+
2463
+
2464
+
2465
+
2466
+
2467
+
2468
+
2469
+
2470
+
2471
+
2472
+
2473
+
2474
+
2475
+
2476
+
2477
+
2478
+
2479
+
2480
+
2481
+
2482
+
2483
+
2484
+
2485
+
2486
+
2487
+
2488
+
2489
+
2490
+
2491
+
2492
+
2493
+
2494
+
2495
+
2496
+
2497
+
2498
+
2499
+
2500
+
2501
+
2502
+
2503
+
2504
+
2505
+
2506
+
2507
+
2508
+
2509
+
2510
+
2511
+
2512
+
2513
+
2514
+
2515
+
2516
+
2517
+
2518
+
2519
+
2520
+
2521
+
2522
+
2523
+
2524
+
2525
+
2526
+
2527
+
2528
+
2529
+
2530
+
2531
+
2532
+
2533
+
2534
+
2535
+
2536
+
2537
+
2538
+
2539
+
2540
+
2541
+
2542
+
2543
+
2544
+
2545
+ 𠮶
data/librispeech_pc_test_clean_cross_sentence.lst ADDED
The diff for this file is too large to render. See raw diff
 
model/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from model.cfm import CFM
2
+
3
+ from model.backbones.unett import UNetT
4
+ from model.backbones.dit import DiT
5
+ from model.backbones.mmdit import MMDiT
6
+
7
+ from model.trainer import Trainer
model/backbones/README.md ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Backbones quick introduction
2
+
3
+
4
+ ### unett.py
5
+ - flat unet transformer
6
+ - structure same as in e2-tts & voicebox paper except using rotary pos emb
7
+ - update: allow possible abs pos emb & convnextv2 blocks for embedded text before concat
8
+
9
+ ### dit.py
10
+ - adaln-zero dit
11
+ - embedded timestep as condition
12
+ - concatted noised_input + masked_cond + embedded_text, linear proj in
13
+ - possible abs pos emb & convnextv2 blocks for embedded text before concat
14
+ - possible long skip connection (first layer to last layer)
15
+
16
+ ### mmdit.py
17
+ - sd3 structure
18
+ - timestep as condition
19
+ - left stream: text embedded and applied a abs pos emb
20
+ - right stream: masked_cond & noised_input concatted and with same conv pos emb as unett
model/backbones/dit.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ein notation:
3
+ b - batch
4
+ n - sequence
5
+ nt - text sequence
6
+ nw - raw wave length
7
+ d - dimension
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import torch
13
+ from torch import nn
14
+ import torch.nn.functional as F
15
+
16
+ from einops import repeat
17
+
18
+ from x_transformers.x_transformers import RotaryEmbedding
19
+
20
+ from model.modules import (
21
+ TimestepEmbedding,
22
+ ConvNeXtV2Block,
23
+ ConvPositionEmbedding,
24
+ DiTBlock,
25
+ AdaLayerNormZero_Final,
26
+ precompute_freqs_cis, get_pos_embed_indices,
27
+ )
28
+
29
+
30
+ # Text embedding
31
+
32
+ class TextEmbedding(nn.Module):
33
+ def __init__(self, text_num_embeds, text_dim, conv_layers = 0, conv_mult = 2):
34
+ super().__init__()
35
+ self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
36
+
37
+ if conv_layers > 0:
38
+ self.extra_modeling = True
39
+ self.precompute_max_pos = 4096 # ~44s of 24khz audio
40
+ self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
41
+ self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)])
42
+ else:
43
+ self.extra_modeling = False
44
+
45
+ def forward(self, text: int['b nt'], seq_len, drop_text = False):
46
+ batch, text_len = text.shape[0], text.shape[1]
47
+ text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
48
+ text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
49
+ text = F.pad(text, (0, seq_len - text_len), value = 0)
50
+
51
+ if drop_text: # cfg for text
52
+ text = torch.zeros_like(text)
53
+
54
+ text = self.text_embed(text) # b n -> b n d
55
+
56
+ # possible extra modeling
57
+ if self.extra_modeling:
58
+ # sinus pos emb
59
+ batch_start = torch.zeros((batch,), dtype=torch.long)
60
+ pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
61
+ text_pos_embed = self.freqs_cis[pos_idx]
62
+ text = text + text_pos_embed
63
+
64
+ # convnextv2 blocks
65
+ text = self.text_blocks(text)
66
+
67
+ return text
68
+
69
+
70
+ # noised input audio and context mixing embedding
71
+
72
+ class InputEmbedding(nn.Module):
73
+ def __init__(self, mel_dim, text_dim, out_dim):
74
+ super().__init__()
75
+ self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
76
+ self.conv_pos_embed = ConvPositionEmbedding(dim = out_dim)
77
+
78
+ def forward(self, x: float['b n d'], cond: float['b n d'], text_embed: float['b n d'], drop_audio_cond = False):
79
+ if drop_audio_cond: # cfg for cond audio
80
+ cond = torch.zeros_like(cond)
81
+
82
+ x = self.proj(torch.cat((x, cond, text_embed), dim = -1))
83
+ x = self.conv_pos_embed(x) + x
84
+ return x
85
+
86
+
87
+ # Transformer backbone using DiT blocks
88
+
89
+ class DiT(nn.Module):
90
+ def __init__(self, *,
91
+ dim, depth = 8, heads = 8, dim_head = 64, dropout = 0.1, ff_mult = 4,
92
+ mel_dim = 100, text_num_embeds = 256, text_dim = None, conv_layers = 0,
93
+ long_skip_connection = False,
94
+ ):
95
+ super().__init__()
96
+
97
+ self.time_embed = TimestepEmbedding(dim)
98
+ if text_dim is None:
99
+ text_dim = mel_dim
100
+ self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers = conv_layers)
101
+ self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
102
+
103
+ self.rotary_embed = RotaryEmbedding(dim_head)
104
+
105
+ self.dim = dim
106
+ self.depth = depth
107
+
108
+ self.transformer_blocks = nn.ModuleList(
109
+ [
110
+ DiTBlock(
111
+ dim = dim,
112
+ heads = heads,
113
+ dim_head = dim_head,
114
+ ff_mult = ff_mult,
115
+ dropout = dropout
116
+ )
117
+ for _ in range(depth)
118
+ ]
119
+ )
120
+ self.long_skip_connection = nn.Linear(dim * 2, dim, bias = False) if long_skip_connection else None
121
+
122
+ self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
123
+ self.proj_out = nn.Linear(dim, mel_dim)
124
+
125
+ def forward(
126
+ self,
127
+ x: float['b n d'], # nosied input audio
128
+ cond: float['b n d'], # masked cond audio
129
+ text: int['b nt'], # text
130
+ time: float['b'] | float[''], # time step
131
+ drop_audio_cond, # cfg for cond audio
132
+ drop_text, # cfg for text
133
+ mask: bool['b n'] | None = None,
134
+ ):
135
+ batch, seq_len = x.shape[0], x.shape[1]
136
+ if time.ndim == 0:
137
+ time = repeat(time, ' -> b', b = batch)
138
+
139
+ # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
140
+ t = self.time_embed(time)
141
+ text_embed = self.text_embed(text, seq_len, drop_text = drop_text)
142
+ x = self.input_embed(x, cond, text_embed, drop_audio_cond = drop_audio_cond)
143
+
144
+ rope = self.rotary_embed.forward_from_seq_len(seq_len)
145
+
146
+ if self.long_skip_connection is not None:
147
+ residual = x
148
+
149
+ for block in self.transformer_blocks:
150
+ x = block(x, t, mask = mask, rope = rope)
151
+
152
+ if self.long_skip_connection is not None:
153
+ x = self.long_skip_connection(torch.cat((x, residual), dim = -1))
154
+
155
+ x = self.norm_out(x, t)
156
+ output = self.proj_out(x)
157
+
158
+ return output
model/backbones/mmdit.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ein notation:
3
+ b - batch
4
+ n - sequence
5
+ nt - text sequence
6
+ nw - raw wave length
7
+ d - dimension
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import torch
13
+ from torch import nn
14
+
15
+ from einops import repeat
16
+
17
+ from x_transformers.x_transformers import RotaryEmbedding
18
+
19
+ from model.modules import (
20
+ TimestepEmbedding,
21
+ ConvPositionEmbedding,
22
+ MMDiTBlock,
23
+ AdaLayerNormZero_Final,
24
+ precompute_freqs_cis, get_pos_embed_indices,
25
+ )
26
+
27
+
28
+ # text embedding
29
+
30
+ class TextEmbedding(nn.Module):
31
+ def __init__(self, out_dim, text_num_embeds):
32
+ super().__init__()
33
+ self.text_embed = nn.Embedding(text_num_embeds + 1, out_dim) # will use 0 as filler token
34
+
35
+ self.precompute_max_pos = 1024
36
+ self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False)
37
+
38
+ def forward(self, text: int['b nt'], drop_text = False) -> int['b nt d']:
39
+ text = text + 1
40
+ if drop_text:
41
+ text = torch.zeros_like(text)
42
+ text = self.text_embed(text)
43
+
44
+ # sinus pos emb
45
+ batch_start = torch.zeros((text.shape[0],), dtype=torch.long)
46
+ batch_text_len = text.shape[1]
47
+ pos_idx = get_pos_embed_indices(batch_start, batch_text_len, max_pos=self.precompute_max_pos)
48
+ text_pos_embed = self.freqs_cis[pos_idx]
49
+
50
+ text = text + text_pos_embed
51
+
52
+ return text
53
+
54
+
55
+ # noised input & masked cond audio embedding
56
+
57
+ class AudioEmbedding(nn.Module):
58
+ def __init__(self, in_dim, out_dim):
59
+ super().__init__()
60
+ self.linear = nn.Linear(2 * in_dim, out_dim)
61
+ self.conv_pos_embed = ConvPositionEmbedding(out_dim)
62
+
63
+ def forward(self, x: float['b n d'], cond: float['b n d'], drop_audio_cond = False):
64
+ if drop_audio_cond:
65
+ cond = torch.zeros_like(cond)
66
+ x = torch.cat((x, cond), dim = -1)
67
+ x = self.linear(x)
68
+ x = self.conv_pos_embed(x) + x
69
+ return x
70
+
71
+
72
+ # Transformer backbone using MM-DiT blocks
73
+
74
+ class MMDiT(nn.Module):
75
+ def __init__(self, *,
76
+ dim, depth = 8, heads = 8, dim_head = 64, dropout = 0.1, ff_mult = 4,
77
+ text_num_embeds = 256, mel_dim = 100,
78
+ ):
79
+ super().__init__()
80
+
81
+ self.time_embed = TimestepEmbedding(dim)
82
+ self.text_embed = TextEmbedding(dim, text_num_embeds)
83
+ self.audio_embed = AudioEmbedding(mel_dim, dim)
84
+
85
+ self.rotary_embed = RotaryEmbedding(dim_head)
86
+
87
+ self.dim = dim
88
+ self.depth = depth
89
+
90
+ self.transformer_blocks = nn.ModuleList(
91
+ [
92
+ MMDiTBlock(
93
+ dim = dim,
94
+ heads = heads,
95
+ dim_head = dim_head,
96
+ dropout = dropout,
97
+ ff_mult = ff_mult,
98
+ context_pre_only = i == depth - 1,
99
+ )
100
+ for i in range(depth)
101
+ ]
102
+ )
103
+ self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
104
+ self.proj_out = nn.Linear(dim, mel_dim)
105
+
106
+ def forward(
107
+ self,
108
+ x: float['b n d'], # nosied input audio
109
+ cond: float['b n d'], # masked cond audio
110
+ text: int['b nt'], # text
111
+ time: float['b'] | float[''], # time step
112
+ drop_audio_cond, # cfg for cond audio
113
+ drop_text, # cfg for text
114
+ mask: bool['b n'] | None = None,
115
+ ):
116
+ batch = x.shape[0]
117
+ if time.ndim == 0:
118
+ time = repeat(time, ' -> b', b = batch)
119
+
120
+ # t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
121
+ t = self.time_embed(time)
122
+ c = self.text_embed(text, drop_text = drop_text)
123
+ x = self.audio_embed(x, cond, drop_audio_cond = drop_audio_cond)
124
+
125
+ seq_len = x.shape[1]
126
+ text_len = text.shape[1]
127
+ rope_audio = self.rotary_embed.forward_from_seq_len(seq_len)
128
+ rope_text = self.rotary_embed.forward_from_seq_len(text_len)
129
+
130
+ for block in self.transformer_blocks:
131
+ c, x = block(x, c, t, mask = mask, rope = rope_audio, c_rope = rope_text)
132
+
133
+ x = self.norm_out(x, t)
134
+ output = self.proj_out(x)
135
+
136
+ return output
model/backbones/unett.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ein notation:
3
+ b - batch
4
+ n - sequence
5
+ nt - text sequence
6
+ nw - raw wave length
7
+ d - dimension
8
+ """
9
+
10
+ from __future__ import annotations
11
+ from typing import Literal
12
+
13
+ import torch
14
+ from torch import nn
15
+ import torch.nn.functional as F
16
+
17
+ from einops import repeat, pack, unpack
18
+
19
+ from x_transformers import RMSNorm
20
+ from x_transformers.x_transformers import RotaryEmbedding
21
+
22
+ from model.modules import (
23
+ TimestepEmbedding,
24
+ ConvNeXtV2Block,
25
+ ConvPositionEmbedding,
26
+ Attention,
27
+ AttnProcessor,
28
+ FeedForward,
29
+ precompute_freqs_cis, get_pos_embed_indices,
30
+ )
31
+
32
+
33
+ # Text embedding
34
+
35
+ class TextEmbedding(nn.Module):
36
+ def __init__(self, text_num_embeds, text_dim, conv_layers = 0, conv_mult = 2):
37
+ super().__init__()
38
+ self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
39
+
40
+ if conv_layers > 0:
41
+ self.extra_modeling = True
42
+ self.precompute_max_pos = 4096 # ~44s of 24khz audio
43
+ self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
44
+ self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)])
45
+ else:
46
+ self.extra_modeling = False
47
+
48
+ def forward(self, text: int['b nt'], seq_len, drop_text = False):
49
+ batch, text_len = text.shape[0], text.shape[1]
50
+ text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
51
+ text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
52
+ text = F.pad(text, (0, seq_len - text_len), value = 0)
53
+
54
+ if drop_text: # cfg for text
55
+ text = torch.zeros_like(text)
56
+
57
+ text = self.text_embed(text) # b n -> b n d
58
+
59
+ # possible extra modeling
60
+ if self.extra_modeling:
61
+ # sinus pos emb
62
+ batch_start = torch.zeros((batch,), dtype=torch.long)
63
+ pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
64
+ text_pos_embed = self.freqs_cis[pos_idx]
65
+ text = text + text_pos_embed
66
+
67
+ # convnextv2 blocks
68
+ text = self.text_blocks(text)
69
+
70
+ return text
71
+
72
+
73
+ # noised input audio and context mixing embedding
74
+
75
+ class InputEmbedding(nn.Module):
76
+ def __init__(self, mel_dim, text_dim, out_dim):
77
+ super().__init__()
78
+ self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
79
+ self.conv_pos_embed = ConvPositionEmbedding(dim = out_dim)
80
+
81
+ def forward(self, x: float['b n d'], cond: float['b n d'], text_embed: float['b n d'], drop_audio_cond = False):
82
+ if drop_audio_cond: # cfg for cond audio
83
+ cond = torch.zeros_like(cond)
84
+
85
+ x = self.proj(torch.cat((x, cond, text_embed), dim = -1))
86
+ x = self.conv_pos_embed(x) + x
87
+ return x
88
+
89
+
90
+ # Flat UNet Transformer backbone
91
+
92
+ class UNetT(nn.Module):
93
+ def __init__(self, *,
94
+ dim, depth = 8, heads = 8, dim_head = 64, dropout = 0.1, ff_mult = 4,
95
+ mel_dim = 100, text_num_embeds = 256, text_dim = None, conv_layers = 0,
96
+ skip_connect_type: Literal['add', 'concat', 'none'] = 'concat',
97
+ ):
98
+ super().__init__()
99
+ assert depth % 2 == 0, "UNet-Transformer's depth should be even."
100
+
101
+ self.time_embed = TimestepEmbedding(dim)
102
+ if text_dim is None:
103
+ text_dim = mel_dim
104
+ self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers = conv_layers)
105
+ self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
106
+
107
+ self.rotary_embed = RotaryEmbedding(dim_head)
108
+
109
+ # transformer layers & skip connections
110
+
111
+ self.dim = dim
112
+ self.skip_connect_type = skip_connect_type
113
+ needs_skip_proj = skip_connect_type == 'concat'
114
+
115
+ self.depth = depth
116
+ self.layers = nn.ModuleList([])
117
+
118
+ for idx in range(depth):
119
+ is_later_half = idx >= (depth // 2)
120
+
121
+ attn_norm = RMSNorm(dim)
122
+ attn = Attention(
123
+ processor = AttnProcessor(),
124
+ dim = dim,
125
+ heads = heads,
126
+ dim_head = dim_head,
127
+ dropout = dropout,
128
+ )
129
+
130
+ ff_norm = RMSNorm(dim)
131
+ ff = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
132
+
133
+ skip_proj = nn.Linear(dim * 2, dim, bias = False) if needs_skip_proj and is_later_half else None
134
+
135
+ self.layers.append(nn.ModuleList([
136
+ skip_proj,
137
+ attn_norm,
138
+ attn,
139
+ ff_norm,
140
+ ff,
141
+ ]))
142
+
143
+ self.norm_out = RMSNorm(dim)
144
+ self.proj_out = nn.Linear(dim, mel_dim)
145
+
146
+ def forward(
147
+ self,
148
+ x: float['b n d'], # nosied input audio
149
+ cond: float['b n d'], # masked cond audio
150
+ text: int['b nt'], # text
151
+ time: float['b'] | float[''], # time step
152
+ drop_audio_cond, # cfg for cond audio
153
+ drop_text, # cfg for text
154
+ mask: bool['b n'] | None = None,
155
+ ):
156
+ batch, seq_len = x.shape[0], x.shape[1]
157
+ if time.ndim == 0:
158
+ time = repeat(time, ' -> b', b = batch)
159
+
160
+ # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
161
+ t = self.time_embed(time)
162
+ text_embed = self.text_embed(text, seq_len, drop_text = drop_text)
163
+ x = self.input_embed(x, cond, text_embed, drop_audio_cond = drop_audio_cond)
164
+
165
+ # postfix time t to input x, [b n d] -> [b n+1 d]
166
+ x, ps = pack((t, x), 'b * d')
167
+ if mask is not None:
168
+ mask = F.pad(mask, (1, 0), value=1)
169
+
170
+ rope = self.rotary_embed.forward_from_seq_len(seq_len + 1)
171
+
172
+ # flat unet transformer
173
+ skip_connect_type = self.skip_connect_type
174
+ skips = []
175
+ for idx, (maybe_skip_proj, attn_norm, attn, ff_norm, ff) in enumerate(self.layers):
176
+ layer = idx + 1
177
+
178
+ # skip connection logic
179
+ is_first_half = layer <= (self.depth // 2)
180
+ is_later_half = not is_first_half
181
+
182
+ if is_first_half:
183
+ skips.append(x)
184
+
185
+ if is_later_half:
186
+ skip = skips.pop()
187
+ if skip_connect_type == 'concat':
188
+ x = torch.cat((x, skip), dim = -1)
189
+ x = maybe_skip_proj(x)
190
+ elif skip_connect_type == 'add':
191
+ x = x + skip
192
+
193
+ # attention and feedforward blocks
194
+ x = attn(attn_norm(x), rope = rope, mask = mask) + x
195
+ x = ff(ff_norm(x)) + x
196
+
197
+ assert len(skips) == 0
198
+
199
+ _, x = unpack(self.norm_out(x), ps, 'b * d')
200
+
201
+ return self.proj_out(x)
model/cfm.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ein notation:
3
+ b - batch
4
+ n - sequence
5
+ nt - text sequence
6
+ nw - raw wave length
7
+ d - dimension
8
+ """
9
+
10
+ from __future__ import annotations
11
+ from typing import Callable
12
+ from random import random
13
+
14
+ import torch
15
+ from torch import nn
16
+ import torch.nn.functional as F
17
+ from torch.nn.utils.rnn import pad_sequence
18
+
19
+ from torchdiffeq import odeint
20
+
21
+ from einops import rearrange
22
+
23
+ from model.modules import MelSpec
24
+
25
+ from model.utils import (
26
+ default, exists,
27
+ list_str_to_idx, list_str_to_tensor,
28
+ lens_to_mask, mask_from_frac_lengths,
29
+ )
30
+
31
+
32
+ class CFM(nn.Module):
33
+ def __init__(
34
+ self,
35
+ transformer: nn.Module,
36
+ sigma = 0.,
37
+ odeint_kwargs: dict = dict(
38
+ # atol = 1e-5,
39
+ # rtol = 1e-5,
40
+ method = 'euler' # 'midpoint'
41
+ ),
42
+ audio_drop_prob = 0.3,
43
+ cond_drop_prob = 0.2,
44
+ num_channels = None,
45
+ mel_spec_module: nn.Module | None = None,
46
+ mel_spec_kwargs: dict = dict(),
47
+ frac_lengths_mask: tuple[float, float] = (0.7, 1.),
48
+ vocab_char_map: dict[str: int] | None = None
49
+ ):
50
+ super().__init__()
51
+
52
+ self.frac_lengths_mask = frac_lengths_mask
53
+
54
+ # mel spec
55
+ self.mel_spec = default(mel_spec_module, MelSpec(**mel_spec_kwargs))
56
+ num_channels = default(num_channels, self.mel_spec.n_mel_channels)
57
+ self.num_channels = num_channels
58
+
59
+ # classifier-free guidance
60
+ self.audio_drop_prob = audio_drop_prob
61
+ self.cond_drop_prob = cond_drop_prob
62
+
63
+ # transformer
64
+ self.transformer = transformer
65
+ dim = transformer.dim
66
+ self.dim = dim
67
+
68
+ # conditional flow related
69
+ self.sigma = sigma
70
+
71
+ # sampling related
72
+ self.odeint_kwargs = odeint_kwargs
73
+
74
+ # vocab map for tokenization
75
+ self.vocab_char_map = vocab_char_map
76
+
77
+ @property
78
+ def device(self):
79
+ return next(self.parameters()).device
80
+
81
+ @torch.no_grad()
82
+ def sample(
83
+ self,
84
+ cond: float['b n d'] | float['b nw'],
85
+ text: int['b nt'] | list[str],
86
+ duration: int | int['b'],
87
+ *,
88
+ lens: int['b'] | None = None,
89
+ steps = 32,
90
+ cfg_strength = 1.,
91
+ sway_sampling_coef = None,
92
+ seed: int | None = None,
93
+ max_duration = 4096,
94
+ vocoder: Callable[[float['b d n']], float['b nw']] | None = None,
95
+ no_ref_audio = False,
96
+ duplicate_test = False,
97
+ t_inter = 0.1,
98
+ ):
99
+ self.eval()
100
+
101
+ # raw wave
102
+
103
+ if cond.ndim == 2:
104
+ cond = self.mel_spec(cond)
105
+ cond = rearrange(cond, 'b d n -> b n d')
106
+ assert cond.shape[-1] == self.num_channels
107
+
108
+ batch, cond_seq_len, device = *cond.shape[:2], cond.device
109
+ if not exists(lens):
110
+ lens = torch.full((batch,), cond_seq_len, device = device, dtype = torch.long)
111
+
112
+ # text
113
+
114
+ if isinstance(text, list):
115
+ if exists(self.vocab_char_map):
116
+ text = list_str_to_idx(text, self.vocab_char_map).to(device)
117
+ else:
118
+ text = list_str_to_tensor(text).to(device)
119
+ assert text.shape[0] == batch
120
+
121
+ if exists(text):
122
+ text_lens = (text != -1).sum(dim = -1)
123
+ lens = torch.maximum(text_lens, lens) # make sure lengths are at least those of the text characters
124
+
125
+ # duration
126
+
127
+ cond_mask = lens_to_mask(lens)
128
+
129
+ if isinstance(duration, int):
130
+ duration = torch.full((batch,), duration, device = device, dtype = torch.long)
131
+
132
+ duration = torch.maximum(lens + 1, duration) # just add one token so something is generated
133
+ duration = duration.clamp(max = max_duration)
134
+ max_duration = duration.amax()
135
+
136
+ # duplicate test corner for inner time step oberservation
137
+ if duplicate_test:
138
+ test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2*cond_seq_len), value = 0.)
139
+
140
+ cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value = 0.)
141
+ cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value = False)
142
+ cond_mask = rearrange(cond_mask, '... -> ... 1')
143
+ step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond)) # allow direct control (cut cond audio) with lens passed in
144
+
145
+ mask = lens_to_mask(duration)
146
+
147
+ # test for no ref audio
148
+ if no_ref_audio:
149
+ cond = torch.zeros_like(cond)
150
+
151
+ # neural ode
152
+
153
+ def fn(t, x):
154
+ # at each step, conditioning is fixed
155
+ # step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))
156
+
157
+ # predict flow
158
+ pred = self.transformer(x = x, cond = step_cond, text = text, time = t, mask = mask, drop_audio_cond = False, drop_text = False)
159
+ if cfg_strength < 1e-5:
160
+ return pred
161
+
162
+ null_pred = self.transformer(x = x, cond = step_cond, text = text, time = t, mask = mask, drop_audio_cond = True, drop_text = True)
163
+ return pred + (pred - null_pred) * cfg_strength
164
+
165
+ # noise input
166
+ # to make sure batch inference result is same with different batch size, and for sure single inference
167
+ # still some difference maybe due to convolutional layers
168
+ y0 = []
169
+ for dur in duration:
170
+ if exists(seed):
171
+ torch.manual_seed(seed)
172
+ y0.append(torch.randn(dur, self.num_channels, device = self.device))
173
+ y0 = pad_sequence(y0, padding_value = 0, batch_first = True)
174
+
175
+ t_start = 0
176
+
177
+ # duplicate test corner for inner time step oberservation
178
+ if duplicate_test:
179
+ t_start = t_inter
180
+ y0 = (1 - t_start) * y0 + t_start * test_cond
181
+ steps = int(steps * (1 - t_start))
182
+
183
+ t = torch.linspace(t_start, 1, steps, device = self.device)
184
+ if sway_sampling_coef is not None:
185
+ t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
186
+
187
+ trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
188
+
189
+ sampled = trajectory[-1]
190
+ out = sampled
191
+ out = torch.where(cond_mask, cond, out)
192
+
193
+ if exists(vocoder):
194
+ out = rearrange(out, 'b n d -> b d n')
195
+ out = vocoder(out)
196
+
197
+ return out, trajectory
198
+
199
+ def forward(
200
+ self,
201
+ inp: float['b n d'] | float['b nw'], # mel or raw wave
202
+ text: int['b nt'] | list[str],
203
+ *,
204
+ lens: int['b'] | None = None,
205
+ noise_scheduler: str | None = None,
206
+ ):
207
+ # handle raw wave
208
+ if inp.ndim == 2:
209
+ inp = self.mel_spec(inp)
210
+ inp = rearrange(inp, 'b d n -> b n d')
211
+ assert inp.shape[-1] == self.num_channels
212
+
213
+ batch, seq_len, dtype, device, σ1 = *inp.shape[:2], inp.dtype, self.device, self.sigma
214
+
215
+ # handle text as string
216
+ if isinstance(text, list):
217
+ if exists(self.vocab_char_map):
218
+ text = list_str_to_idx(text, self.vocab_char_map).to(device)
219
+ else:
220
+ text = list_str_to_tensor(text).to(device)
221
+ assert text.shape[0] == batch
222
+
223
+ # lens and mask
224
+ if not exists(lens):
225
+ lens = torch.full((batch,), seq_len, device = device)
226
+
227
+ mask = lens_to_mask(lens, length = seq_len) # useless here, as collate_fn will pad to max length in batch
228
+
229
+ # get a random span to mask out for training conditionally
230
+ frac_lengths = torch.zeros((batch,), device = self.device).float().uniform_(*self.frac_lengths_mask)
231
+ rand_span_mask = mask_from_frac_lengths(lens, frac_lengths)
232
+
233
+ if exists(mask):
234
+ rand_span_mask &= mask
235
+
236
+ # mel is x1
237
+ x1 = inp
238
+
239
+ # x0 is gaussian noise
240
+ x0 = torch.randn_like(x1)
241
+
242
+ # time step
243
+ time = torch.rand((batch,), dtype = dtype, device = self.device)
244
+ # TODO. noise_scheduler
245
+
246
+ # sample xt (φ_t(x) in the paper)
247
+ t = rearrange(time, 'b -> b 1 1')
248
+ φ = (1 - t) * x0 + t * x1
249
+ flow = x1 - x0
250
+
251
+ # only predict what is within the random mask span for infilling
252
+ cond = torch.where(
253
+ rand_span_mask[..., None],
254
+ torch.zeros_like(x1), x1
255
+ )
256
+
257
+ # transformer and cfg training with a drop rate
258
+ drop_audio_cond = random() < self.audio_drop_prob # p_drop in voicebox paper
259
+ if random() < self.cond_drop_prob: # p_uncond in voicebox paper
260
+ drop_audio_cond = True
261
+ drop_text = True
262
+ else:
263
+ drop_text = False
264
+
265
+ # if want rigourously mask out padding, record in collate_fn in dataset.py, and pass in here
266
+ # adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences
267
+ pred = self.transformer(x = φ, cond = cond, text = text, time = time, drop_audio_cond = drop_audio_cond, drop_text = drop_text)
268
+
269
+ # flow matching loss
270
+ loss = F.mse_loss(pred, flow, reduction = 'none')
271
+ loss = loss[rand_span_mask]
272
+
273
+ return loss.mean(), cond, pred
model/dataset.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import random
3
+ from tqdm import tqdm
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch.utils.data import Dataset, Sampler
8
+ import torchaudio
9
+ from datasets import load_dataset, load_from_disk
10
+ from datasets import Dataset as Dataset_
11
+
12
+ from einops import rearrange
13
+
14
+ from model.modules import MelSpec
15
+
16
+
17
+ class HFDataset(Dataset):
18
+ def __init__(
19
+ self,
20
+ hf_dataset: Dataset,
21
+ target_sample_rate = 24_000,
22
+ n_mel_channels = 100,
23
+ hop_length = 256,
24
+ ):
25
+ self.data = hf_dataset
26
+ self.target_sample_rate = target_sample_rate
27
+ self.hop_length = hop_length
28
+ self.mel_spectrogram = MelSpec(target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length)
29
+
30
+ def get_frame_len(self, index):
31
+ row = self.data[index]
32
+ audio = row['audio']['array']
33
+ sample_rate = row['audio']['sampling_rate']
34
+ return audio.shape[-1] / sample_rate * self.target_sample_rate / self.hop_length
35
+
36
+ def __len__(self):
37
+ return len(self.data)
38
+
39
+ def __getitem__(self, index):
40
+ row = self.data[index]
41
+ audio = row['audio']['array']
42
+
43
+ # logger.info(f"Audio shape: {audio.shape}")
44
+
45
+ sample_rate = row['audio']['sampling_rate']
46
+ duration = audio.shape[-1] / sample_rate
47
+
48
+ if duration > 30 or duration < 0.3:
49
+ return self.__getitem__((index + 1) % len(self.data))
50
+
51
+ audio_tensor = torch.from_numpy(audio).float()
52
+
53
+ if sample_rate != self.target_sample_rate:
54
+ resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate)
55
+ audio_tensor = resampler(audio_tensor)
56
+
57
+ audio_tensor = rearrange(audio_tensor, 't -> 1 t')
58
+
59
+ mel_spec = self.mel_spectrogram(audio_tensor)
60
+
61
+ mel_spec = rearrange(mel_spec, '1 d t -> d t')
62
+
63
+ text = row['text']
64
+
65
+ return dict(
66
+ mel_spec = mel_spec,
67
+ text = text,
68
+ )
69
+
70
+
71
+ class CustomDataset(Dataset):
72
+ def __init__(
73
+ self,
74
+ custom_dataset: Dataset,
75
+ durations = None,
76
+ target_sample_rate = 24_000,
77
+ hop_length = 256,
78
+ n_mel_channels = 100,
79
+ preprocessed_mel = False,
80
+ ):
81
+ self.data = custom_dataset
82
+ self.durations = durations
83
+ self.target_sample_rate = target_sample_rate
84
+ self.hop_length = hop_length
85
+ self.preprocessed_mel = preprocessed_mel
86
+ if not preprocessed_mel:
87
+ self.mel_spectrogram = MelSpec(target_sample_rate=target_sample_rate, hop_length=hop_length, n_mel_channels=n_mel_channels)
88
+
89
+ def get_frame_len(self, index):
90
+ if self.durations is not None: # Please make sure the separately provided durations are correct, otherwise 99.99% OOM
91
+ return self.durations[index] * self.target_sample_rate / self.hop_length
92
+ return self.data[index]["duration"] * self.target_sample_rate / self.hop_length
93
+
94
+ def __len__(self):
95
+ return len(self.data)
96
+
97
+ def __getitem__(self, index):
98
+ row = self.data[index]
99
+ audio_path = row["audio_path"]
100
+ text = row["text"]
101
+ duration = row["duration"]
102
+
103
+ if self.preprocessed_mel:
104
+ mel_spec = torch.tensor(row["mel_spec"])
105
+
106
+ else:
107
+ audio, source_sample_rate = torchaudio.load(audio_path)
108
+
109
+ if duration > 30 or duration < 0.3:
110
+ return self.__getitem__((index + 1) % len(self.data))
111
+
112
+ if source_sample_rate != self.target_sample_rate:
113
+ resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate)
114
+ audio = resampler(audio)
115
+
116
+ mel_spec = self.mel_spectrogram(audio)
117
+ mel_spec = rearrange(mel_spec, '1 d t -> d t')
118
+
119
+ return dict(
120
+ mel_spec = mel_spec,
121
+ text = text,
122
+ )
123
+
124
+
125
+ # Dynamic Batch Sampler
126
+
127
+ class DynamicBatchSampler(Sampler[list[int]]):
128
+ """ Extension of Sampler that will do the following:
129
+ 1. Change the batch size (essentially number of sequences)
130
+ in a batch to ensure that the total number of frames are less
131
+ than a certain threshold.
132
+ 2. Make sure the padding efficiency in the batch is high.
133
+ """
134
+
135
+ def __init__(self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_last: bool = False):
136
+ self.sampler = sampler
137
+ self.frames_threshold = frames_threshold
138
+ self.max_samples = max_samples
139
+
140
+ indices, batches = [], []
141
+ data_source = self.sampler.data_source
142
+
143
+ for idx in tqdm(self.sampler, desc=f"Sorting with sampler... if slow, check whether dataset is provided with duration"):
144
+ indices.append((idx, data_source.get_frame_len(idx)))
145
+ indices.sort(key=lambda elem : elem[1])
146
+
147
+ batch = []
148
+ batch_frames = 0
149
+ for idx, frame_len in tqdm(indices, desc=f"Creating dynamic batches with {frames_threshold} audio frames per gpu"):
150
+ if batch_frames + frame_len <= self.frames_threshold and (max_samples == 0 or len(batch) < max_samples):
151
+ batch.append(idx)
152
+ batch_frames += frame_len
153
+ else:
154
+ if len(batch) > 0:
155
+ batches.append(batch)
156
+ if frame_len <= self.frames_threshold:
157
+ batch = [idx]
158
+ batch_frames = frame_len
159
+ else:
160
+ batch = []
161
+ batch_frames = 0
162
+
163
+ if not drop_last and len(batch) > 0:
164
+ batches.append(batch)
165
+
166
+ del indices
167
+
168
+ # if want to have different batches between epochs, may just set a seed and log it in ckpt
169
+ # cuz during multi-gpu training, although the batch on per gpu not change between epochs, the formed general minibatch is different
170
+ # e.g. for epoch n, use (random_seed + n)
171
+ random.seed(random_seed)
172
+ random.shuffle(batches)
173
+
174
+ self.batches = batches
175
+
176
+ def __iter__(self):
177
+ return iter(self.batches)
178
+
179
+ def __len__(self):
180
+ return len(self.batches)
181
+
182
+
183
+ # Load dataset
184
+
185
+ def load_dataset(
186
+ dataset_name: str,
187
+ tokenizer: str,
188
+ dataset_type: str = "CustomDataset",
189
+ audio_type: str = "raw",
190
+ mel_spec_kwargs: dict = dict()
191
+ ) -> CustomDataset | HFDataset:
192
+
193
+ print("Loading dataset ...")
194
+
195
+ if dataset_type == "CustomDataset":
196
+ if audio_type == "raw":
197
+ try:
198
+ train_dataset = load_from_disk(f"data/{dataset_name}_{tokenizer}/raw")
199
+ except:
200
+ train_dataset = Dataset_.from_file(f"data/{dataset_name}_{tokenizer}/raw.arrow")
201
+ preprocessed_mel = False
202
+ elif audio_type == "mel":
203
+ train_dataset = Dataset_.from_file(f"data/{dataset_name}_{tokenizer}/mel.arrow")
204
+ preprocessed_mel = True
205
+ with open(f"data/{dataset_name}_{tokenizer}/duration.json", 'r', encoding='utf-8') as f:
206
+ data_dict = json.load(f)
207
+ durations = data_dict["duration"]
208
+ train_dataset = CustomDataset(train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs)
209
+
210
+ elif dataset_type == "HFDataset":
211
+ print("Should manually modify the path of huggingface dataset to your need.\n" +
212
+ "May also the corresponding script cuz different dataset may have different format.")
213
+ pre, post = dataset_name.split("_")
214
+ train_dataset = HFDataset(load_dataset(f"{pre}/{pre}", split=f"train.{post}", cache_dir="./data"),)
215
+
216
+ return train_dataset
217
+
218
+
219
+ # collation
220
+
221
+ def collate_fn(batch):
222
+ mel_specs = [item['mel_spec'].squeeze(0) for item in batch]
223
+ mel_lengths = torch.LongTensor([spec.shape[-1] for spec in mel_specs])
224
+ max_mel_length = mel_lengths.amax()
225
+
226
+ padded_mel_specs = []
227
+ for spec in mel_specs: # TODO. maybe records mask for attention here
228
+ padding = (0, max_mel_length - spec.size(-1))
229
+ padded_spec = F.pad(spec, padding, value = 0)
230
+ padded_mel_specs.append(padded_spec)
231
+
232
+ mel_specs = torch.stack(padded_mel_specs)
233
+
234
+ text = [item['text'] for item in batch]
235
+ text_lengths = torch.LongTensor([len(item) for item in text])
236
+
237
+ return dict(
238
+ mel = mel_specs,
239
+ mel_lengths = mel_lengths,
240
+ text = text,
241
+ text_lengths = text_lengths,
242
+ )
model/ecapa_tdnn.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # just for speaker similarity evaluation, third-party code
2
+
3
+ # From https://github.com/microsoft/UniSpeech/blob/main/downstreams/speaker_verification/models/
4
+ # part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN
5
+
6
+ import os
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+
12
+ ''' Res2Conv1d + BatchNorm1d + ReLU
13
+ '''
14
+
15
+ class Res2Conv1dReluBn(nn.Module):
16
+ '''
17
+ in_channels == out_channels == channels
18
+ '''
19
+
20
+ def __init__(self, channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True, scale=4):
21
+ super().__init__()
22
+ assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
23
+ self.scale = scale
24
+ self.width = channels // scale
25
+ self.nums = scale if scale == 1 else scale - 1
26
+
27
+ self.convs = []
28
+ self.bns = []
29
+ for i in range(self.nums):
30
+ self.convs.append(nn.Conv1d(self.width, self.width, kernel_size, stride, padding, dilation, bias=bias))
31
+ self.bns.append(nn.BatchNorm1d(self.width))
32
+ self.convs = nn.ModuleList(self.convs)
33
+ self.bns = nn.ModuleList(self.bns)
34
+
35
+ def forward(self, x):
36
+ out = []
37
+ spx = torch.split(x, self.width, 1)
38
+ for i in range(self.nums):
39
+ if i == 0:
40
+ sp = spx[i]
41
+ else:
42
+ sp = sp + spx[i]
43
+ # Order: conv -> relu -> bn
44
+ sp = self.convs[i](sp)
45
+ sp = self.bns[i](F.relu(sp))
46
+ out.append(sp)
47
+ if self.scale != 1:
48
+ out.append(spx[self.nums])
49
+ out = torch.cat(out, dim=1)
50
+
51
+ return out
52
+
53
+
54
+ ''' Conv1d + BatchNorm1d + ReLU
55
+ '''
56
+
57
+ class Conv1dReluBn(nn.Module):
58
+ def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True):
59
+ super().__init__()
60
+ self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)
61
+ self.bn = nn.BatchNorm1d(out_channels)
62
+
63
+ def forward(self, x):
64
+ return self.bn(F.relu(self.conv(x)))
65
+
66
+
67
+ ''' The SE connection of 1D case.
68
+ '''
69
+
70
+ class SE_Connect(nn.Module):
71
+ def __init__(self, channels, se_bottleneck_dim=128):
72
+ super().__init__()
73
+ self.linear1 = nn.Linear(channels, se_bottleneck_dim)
74
+ self.linear2 = nn.Linear(se_bottleneck_dim, channels)
75
+
76
+ def forward(self, x):
77
+ out = x.mean(dim=2)
78
+ out = F.relu(self.linear1(out))
79
+ out = torch.sigmoid(self.linear2(out))
80
+ out = x * out.unsqueeze(2)
81
+
82
+ return out
83
+
84
+
85
+ ''' SE-Res2Block of the ECAPA-TDNN architecture.
86
+ '''
87
+
88
+ # def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
89
+ # return nn.Sequential(
90
+ # Conv1dReluBn(channels, 512, kernel_size=1, stride=1, padding=0),
91
+ # Res2Conv1dReluBn(512, kernel_size, stride, padding, dilation, scale=scale),
92
+ # Conv1dReluBn(512, channels, kernel_size=1, stride=1, padding=0),
93
+ # SE_Connect(channels)
94
+ # )
95
+
96
+ class SE_Res2Block(nn.Module):
97
+ def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, scale, se_bottleneck_dim):
98
+ super().__init__()
99
+ self.Conv1dReluBn1 = Conv1dReluBn(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
100
+ self.Res2Conv1dReluBn = Res2Conv1dReluBn(out_channels, kernel_size, stride, padding, dilation, scale=scale)
101
+ self.Conv1dReluBn2 = Conv1dReluBn(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
102
+ self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim)
103
+
104
+ self.shortcut = None
105
+ if in_channels != out_channels:
106
+ self.shortcut = nn.Conv1d(
107
+ in_channels=in_channels,
108
+ out_channels=out_channels,
109
+ kernel_size=1,
110
+ )
111
+
112
+ def forward(self, x):
113
+ residual = x
114
+ if self.shortcut:
115
+ residual = self.shortcut(x)
116
+
117
+ x = self.Conv1dReluBn1(x)
118
+ x = self.Res2Conv1dReluBn(x)
119
+ x = self.Conv1dReluBn2(x)
120
+ x = self.SE_Connect(x)
121
+
122
+ return x + residual
123
+
124
+
125
+ ''' Attentive weighted mean and standard deviation pooling.
126
+ '''
127
+
128
+ class AttentiveStatsPool(nn.Module):
129
+ def __init__(self, in_dim, attention_channels=128, global_context_att=False):
130
+ super().__init__()
131
+ self.global_context_att = global_context_att
132
+
133
+ # Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs.
134
+ if global_context_att:
135
+ self.linear1 = nn.Conv1d(in_dim * 3, attention_channels, kernel_size=1) # equals W and b in the paper
136
+ else:
137
+ self.linear1 = nn.Conv1d(in_dim, attention_channels, kernel_size=1) # equals W and b in the paper
138
+ self.linear2 = nn.Conv1d(attention_channels, in_dim, kernel_size=1) # equals V and k in the paper
139
+
140
+ def forward(self, x):
141
+
142
+ if self.global_context_att:
143
+ context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
144
+ context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
145
+ x_in = torch.cat((x, context_mean, context_std), dim=1)
146
+ else:
147
+ x_in = x
148
+
149
+ # DON'T use ReLU here! In experiments, I find ReLU hard to converge.
150
+ alpha = torch.tanh(self.linear1(x_in))
151
+ # alpha = F.relu(self.linear1(x_in))
152
+ alpha = torch.softmax(self.linear2(alpha), dim=2)
153
+ mean = torch.sum(alpha * x, dim=2)
154
+ residuals = torch.sum(alpha * (x ** 2), dim=2) - mean ** 2
155
+ std = torch.sqrt(residuals.clamp(min=1e-9))
156
+ return torch.cat([mean, std], dim=1)
157
+
158
+
159
+ class ECAPA_TDNN(nn.Module):
160
+ def __init__(self, feat_dim=80, channels=512, emb_dim=192, global_context_att=False,
161
+ feat_type='wavlm_large', sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None):
162
+ super().__init__()
163
+
164
+ self.feat_type = feat_type
165
+ self.feature_selection = feature_selection
166
+ self.update_extract = update_extract
167
+ self.sr = sr
168
+
169
+ torch.hub._validate_not_a_forked_repo=lambda a,b,c: True
170
+ try:
171
+ local_s3prl_path = os.path.expanduser("~/.cache/torch/hub/s3prl_s3prl_main")
172
+ self.feature_extract = torch.hub.load(local_s3prl_path, feat_type, source='local', config_path=config_path)
173
+ except:
174
+ self.feature_extract = torch.hub.load('s3prl/s3prl', feat_type)
175
+
176
+ if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention"):
177
+ self.feature_extract.model.encoder.layers[23].self_attn.fp32_attention = False
178
+ if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention"):
179
+ self.feature_extract.model.encoder.layers[11].self_attn.fp32_attention = False
180
+
181
+ self.feat_num = self.get_feat_num()
182
+ self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))
183
+
184
+ if feat_type != 'fbank' and feat_type != 'mfcc':
185
+ freeze_list = ['final_proj', 'label_embs_concat', 'mask_emb', 'project_q', 'quantizer']
186
+ for name, param in self.feature_extract.named_parameters():
187
+ for freeze_val in freeze_list:
188
+ if freeze_val in name:
189
+ param.requires_grad = False
190
+ break
191
+
192
+ if not self.update_extract:
193
+ for param in self.feature_extract.parameters():
194
+ param.requires_grad = False
195
+
196
+ self.instance_norm = nn.InstanceNorm1d(feat_dim)
197
+ # self.channels = [channels] * 4 + [channels * 3]
198
+ self.channels = [channels] * 4 + [1536]
199
+
200
+ self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
201
+ self.layer2 = SE_Res2Block(self.channels[0], self.channels[1], kernel_size=3, stride=1, padding=2, dilation=2, scale=8, se_bottleneck_dim=128)
202
+ self.layer3 = SE_Res2Block(self.channels[1], self.channels[2], kernel_size=3, stride=1, padding=3, dilation=3, scale=8, se_bottleneck_dim=128)
203
+ self.layer4 = SE_Res2Block(self.channels[2], self.channels[3], kernel_size=3, stride=1, padding=4, dilation=4, scale=8, se_bottleneck_dim=128)
204
+
205
+ # self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
206
+ cat_channels = channels * 3
207
+ self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
208
+ self.pooling = AttentiveStatsPool(self.channels[-1], attention_channels=128, global_context_att=global_context_att)
209
+ self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
210
+ self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)
211
+
212
+
213
+ def get_feat_num(self):
214
+ self.feature_extract.eval()
215
+ wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
216
+ with torch.no_grad():
217
+ features = self.feature_extract(wav)
218
+ select_feature = features[self.feature_selection]
219
+ if isinstance(select_feature, (list, tuple)):
220
+ return len(select_feature)
221
+ else:
222
+ return 1
223
+
224
+ def get_feat(self, x):
225
+ if self.update_extract:
226
+ x = self.feature_extract([sample for sample in x])
227
+ else:
228
+ with torch.no_grad():
229
+ if self.feat_type == 'fbank' or self.feat_type == 'mfcc':
230
+ x = self.feature_extract(x) + 1e-6 # B x feat_dim x time_len
231
+ else:
232
+ x = self.feature_extract([sample for sample in x])
233
+
234
+ if self.feat_type == 'fbank':
235
+ x = x.log()
236
+
237
+ if self.feat_type != "fbank" and self.feat_type != "mfcc":
238
+ x = x[self.feature_selection]
239
+ if isinstance(x, (list, tuple)):
240
+ x = torch.stack(x, dim=0)
241
+ else:
242
+ x = x.unsqueeze(0)
243
+ norm_weights = F.softmax(self.feature_weight, dim=-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
244
+ x = (norm_weights * x).sum(dim=0)
245
+ x = torch.transpose(x, 1, 2) + 1e-6
246
+
247
+ x = self.instance_norm(x)
248
+ return x
249
+
250
+ def forward(self, x):
251
+ x = self.get_feat(x)
252
+
253
+ out1 = self.layer1(x)
254
+ out2 = self.layer2(out1)
255
+ out3 = self.layer3(out2)
256
+ out4 = self.layer4(out3)
257
+
258
+ out = torch.cat([out2, out3, out4], dim=1)
259
+ out = F.relu(self.conv(out))
260
+ out = self.bn(self.pooling(out))
261
+ out = self.linear(out)
262
+
263
+ return out
264
+
265
+
266
+ def ECAPA_TDNN_SMALL(feat_dim, emb_dim=256, feat_type='wavlm_large', sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None):
267
+ return ECAPA_TDNN(feat_dim=feat_dim, channels=512, emb_dim=emb_dim,
268
+ feat_type=feat_type, sr=sr, feature_selection=feature_selection, update_extract=update_extract, config_path=config_path)
model/modules.py ADDED
@@ -0,0 +1,575 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ein notation:
3
+ b - batch
4
+ n - sequence
5
+ nt - text sequence
6
+ nw - raw wave length
7
+ d - dimension
8
+ """
9
+
10
+ from __future__ import annotations
11
+ from typing import Optional
12
+ import math
13
+
14
+ import torch
15
+ from torch import nn
16
+ import torch.nn.functional as F
17
+ import torchaudio
18
+
19
+ from einops import rearrange
20
+ from x_transformers.x_transformers import apply_rotary_pos_emb
21
+
22
+
23
+ # raw wav to mel spec
24
+
25
+ class MelSpec(nn.Module):
26
+ def __init__(
27
+ self,
28
+ filter_length = 1024,
29
+ hop_length = 256,
30
+ win_length = 1024,
31
+ n_mel_channels = 100,
32
+ target_sample_rate = 24_000,
33
+ normalize = False,
34
+ power = 1,
35
+ norm = None,
36
+ center = True,
37
+ ):
38
+ super().__init__()
39
+ self.n_mel_channels = n_mel_channels
40
+
41
+ self.mel_stft = torchaudio.transforms.MelSpectrogram(
42
+ sample_rate = target_sample_rate,
43
+ n_fft = filter_length,
44
+ win_length = win_length,
45
+ hop_length = hop_length,
46
+ n_mels = n_mel_channels,
47
+ power = power,
48
+ center = center,
49
+ normalized = normalize,
50
+ norm = norm,
51
+ )
52
+
53
+ self.register_buffer('dummy', torch.tensor(0), persistent = False)
54
+
55
+ def forward(self, inp):
56
+ if len(inp.shape) == 3:
57
+ inp = rearrange(inp, 'b 1 nw -> b nw')
58
+
59
+ assert len(inp.shape) == 2
60
+
61
+ if self.dummy.device != inp.device:
62
+ self.to(inp.device)
63
+
64
+ mel = self.mel_stft(inp)
65
+ mel = mel.clamp(min = 1e-5).log()
66
+ return mel
67
+
68
+
69
+ # sinusoidal position embedding
70
+
71
+ class SinusPositionEmbedding(nn.Module):
72
+ def __init__(self, dim):
73
+ super().__init__()
74
+ self.dim = dim
75
+
76
+ def forward(self, x, scale=1000):
77
+ device = x.device
78
+ half_dim = self.dim // 2
79
+ emb = math.log(10000) / (half_dim - 1)
80
+ emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
81
+ emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
82
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
83
+ return emb
84
+
85
+
86
+ # convolutional position embedding
87
+
88
+ class ConvPositionEmbedding(nn.Module):
89
+ def __init__(self, dim, kernel_size = 31, groups = 16):
90
+ super().__init__()
91
+ assert kernel_size % 2 != 0
92
+ self.conv1d = nn.Sequential(
93
+ nn.Conv1d(dim, dim, kernel_size, groups = groups, padding = kernel_size // 2),
94
+ nn.Mish(),
95
+ nn.Conv1d(dim, dim, kernel_size, groups = groups, padding = kernel_size // 2),
96
+ nn.Mish(),
97
+ )
98
+
99
+ def forward(self, x: float['b n d'], mask: bool['b n'] | None = None):
100
+ if mask is not None:
101
+ mask = mask[..., None]
102
+ x = x.masked_fill(~mask, 0.)
103
+
104
+ x = rearrange(x, 'b n d -> b d n')
105
+ x = self.conv1d(x)
106
+ out = rearrange(x, 'b d n -> b n d')
107
+
108
+ if mask is not None:
109
+ out = out.masked_fill(~mask, 0.)
110
+
111
+ return out
112
+
113
+
114
+ # rotary positional embedding related
115
+
116
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.):
117
+ # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
118
+ # has some connection to NTK literature
119
+ # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
120
+ # https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
121
+ theta *= theta_rescale_factor ** (dim / (dim - 2))
122
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
123
+ t = torch.arange(end, device=freqs.device) # type: ignore
124
+ freqs = torch.outer(t, freqs).float() # type: ignore
125
+ freqs_cos = torch.cos(freqs) # real part
126
+ freqs_sin = torch.sin(freqs) # imaginary part
127
+ return torch.cat([freqs_cos, freqs_sin], dim=-1)
128
+
129
+ def get_pos_embed_indices(start, length, max_pos, scale=1.):
130
+ # length = length if isinstance(length, int) else length.max()
131
+ scale = scale * torch.ones_like(start, dtype=torch.float32) # in case scale is a scalar
132
+ pos = start.unsqueeze(1) + (
133
+ torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) *
134
+ scale.unsqueeze(1)).long()
135
+ # avoid extra long error.
136
+ pos = torch.where(pos < max_pos, pos, max_pos - 1)
137
+ return pos
138
+
139
+
140
+ # Global Response Normalization layer (Instance Normalization ?)
141
+
142
+ class GRN(nn.Module):
143
+ def __init__(self, dim):
144
+ super().__init__()
145
+ self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
146
+ self.beta = nn.Parameter(torch.zeros(1, 1, dim))
147
+
148
+ def forward(self, x):
149
+ Gx = torch.norm(x, p=2, dim=1, keepdim=True)
150
+ Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
151
+ return self.gamma * (x * Nx) + self.beta + x
152
+
153
+
154
+ # ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
155
+ # ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108
156
+
157
+ class ConvNeXtV2Block(nn.Module):
158
+ def __init__(
159
+ self,
160
+ dim: int,
161
+ intermediate_dim: int,
162
+ dilation: int = 1,
163
+ ):
164
+ super().__init__()
165
+ padding = (dilation * (7 - 1)) // 2
166
+ self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation) # depthwise conv
167
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
168
+ self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
169
+ self.act = nn.GELU()
170
+ self.grn = GRN(intermediate_dim)
171
+ self.pwconv2 = nn.Linear(intermediate_dim, dim)
172
+
173
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
174
+ residual = x
175
+ x = x.transpose(1, 2) # b n d -> b d n
176
+ x = self.dwconv(x)
177
+ x = x.transpose(1, 2) # b d n -> b n d
178
+ x = self.norm(x)
179
+ x = self.pwconv1(x)
180
+ x = self.act(x)
181
+ x = self.grn(x)
182
+ x = self.pwconv2(x)
183
+ return residual + x
184
+
185
+
186
+ # AdaLayerNormZero
187
+ # return with modulated x for attn input, and params for later mlp modulation
188
+
189
+ class AdaLayerNormZero(nn.Module):
190
+ def __init__(self, dim):
191
+ super().__init__()
192
+
193
+ self.silu = nn.SiLU()
194
+ self.linear = nn.Linear(dim, dim * 6)
195
+
196
+ self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
197
+
198
+ def forward(self, x, emb = None):
199
+ emb = self.linear(self.silu(emb))
200
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)
201
+
202
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
203
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
204
+
205
+
206
+ # AdaLayerNormZero for final layer
207
+ # return only with modulated x for attn input, cuz no more mlp modulation
208
+
209
+ class AdaLayerNormZero_Final(nn.Module):
210
+ def __init__(self, dim):
211
+ super().__init__()
212
+
213
+ self.silu = nn.SiLU()
214
+ self.linear = nn.Linear(dim, dim * 2)
215
+
216
+ self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
217
+
218
+ def forward(self, x, emb):
219
+ emb = self.linear(self.silu(emb))
220
+ scale, shift = torch.chunk(emb, 2, dim=1)
221
+
222
+ x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
223
+ return x
224
+
225
+
226
+ # FeedForward
227
+
228
+ class FeedForward(nn.Module):
229
+ def __init__(self, dim, dim_out = None, mult = 4, dropout = 0., approximate: str = 'none'):
230
+ super().__init__()
231
+ inner_dim = int(dim * mult)
232
+ dim_out = dim_out if dim_out is not None else dim
233
+
234
+ activation = nn.GELU(approximate=approximate)
235
+ project_in = nn.Sequential(
236
+ nn.Linear(dim, inner_dim),
237
+ activation
238
+ )
239
+ self.ff = nn.Sequential(
240
+ project_in,
241
+ nn.Dropout(dropout),
242
+ nn.Linear(inner_dim, dim_out)
243
+ )
244
+
245
+ def forward(self, x):
246
+ return self.ff(x)
247
+
248
+
249
+ # Attention with possible joint part
250
+ # modified from diffusers/src/diffusers/models/attention_processor.py
251
+
252
+ class Attention(nn.Module):
253
+ def __init__(
254
+ self,
255
+ processor: JointAttnProcessor | AttnProcessor,
256
+ dim: int,
257
+ heads: int = 8,
258
+ dim_head: int = 64,
259
+ dropout: float = 0.0,
260
+ context_dim: Optional[int] = None, # if not None -> joint attention
261
+ context_pre_only = None,
262
+ ):
263
+ super().__init__()
264
+
265
+ if not hasattr(F, "scaled_dot_product_attention"):
266
+ raise ImportError("Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
267
+
268
+ self.processor = processor
269
+
270
+ self.dim = dim
271
+ self.heads = heads
272
+ self.inner_dim = dim_head * heads
273
+ self.dropout = dropout
274
+
275
+ self.context_dim = context_dim
276
+ self.context_pre_only = context_pre_only
277
+
278
+ self.to_q = nn.Linear(dim, self.inner_dim)
279
+ self.to_k = nn.Linear(dim, self.inner_dim)
280
+ self.to_v = nn.Linear(dim, self.inner_dim)
281
+
282
+ if self.context_dim is not None:
283
+ self.to_k_c = nn.Linear(context_dim, self.inner_dim)
284
+ self.to_v_c = nn.Linear(context_dim, self.inner_dim)
285
+ if self.context_pre_only is not None:
286
+ self.to_q_c = nn.Linear(context_dim, self.inner_dim)
287
+
288
+ self.to_out = nn.ModuleList([])
289
+ self.to_out.append(nn.Linear(self.inner_dim, dim))
290
+ self.to_out.append(nn.Dropout(dropout))
291
+
292
+ if self.context_pre_only is not None and not self.context_pre_only:
293
+ self.to_out_c = nn.Linear(self.inner_dim, dim)
294
+
295
+ def forward(
296
+ self,
297
+ x: float['b n d'], # noised input x
298
+ c: float['b n d'] = None, # context c
299
+ mask: bool['b n'] | None = None,
300
+ rope = None, # rotary position embedding for x
301
+ c_rope = None, # rotary position embedding for c
302
+ ) -> torch.Tensor:
303
+ if c is not None:
304
+ return self.processor(self, x, c = c, mask = mask, rope = rope, c_rope = c_rope)
305
+ else:
306
+ return self.processor(self, x, mask = mask, rope = rope)
307
+
308
+
309
+ # Attention processor
310
+
311
+ class AttnProcessor:
312
+ def __init__(self):
313
+ pass
314
+
315
+ def __call__(
316
+ self,
317
+ attn: Attention,
318
+ x: float['b n d'], # noised input x
319
+ mask: bool['b n'] | None = None,
320
+ rope = None, # rotary position embedding
321
+ ) -> torch.FloatTensor:
322
+
323
+ batch_size = x.shape[0]
324
+
325
+ # `sample` projections.
326
+ query = attn.to_q(x)
327
+ key = attn.to_k(x)
328
+ value = attn.to_v(x)
329
+
330
+ # apply rotary position embedding
331
+ if rope is not None:
332
+ freqs, xpos_scale = rope
333
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if xpos_scale is not None else (1., 1.)
334
+
335
+ query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
336
+ key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
337
+
338
+ # attention
339
+ inner_dim = key.shape[-1]
340
+ head_dim = inner_dim // attn.heads
341
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
342
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
343
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
344
+
345
+ # mask. e.g. inference got a batch with different target durations, mask out the padding
346
+ if mask is not None:
347
+ attn_mask = mask
348
+ attn_mask = rearrange(attn_mask, 'b n -> b 1 1 n')
349
+ attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
350
+ else:
351
+ attn_mask = None
352
+
353
+ x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
354
+ x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
355
+ x = x.to(query.dtype)
356
+
357
+ # linear proj
358
+ x = attn.to_out[0](x)
359
+ # dropout
360
+ x = attn.to_out[1](x)
361
+
362
+ if mask is not None:
363
+ mask = rearrange(mask, 'b n -> b n 1')
364
+ x = x.masked_fill(~mask, 0.)
365
+
366
+ return x
367
+
368
+
369
+ # Joint Attention processor for MM-DiT
370
+ # modified from diffusers/src/diffusers/models/attention_processor.py
371
+
372
+ class JointAttnProcessor:
373
+ def __init__(self):
374
+ pass
375
+
376
+ def __call__(
377
+ self,
378
+ attn: Attention,
379
+ x: float['b n d'], # noised input x
380
+ c: float['b nt d'] = None, # context c, here text
381
+ mask: bool['b n'] | None = None,
382
+ rope = None, # rotary position embedding for x
383
+ c_rope = None, # rotary position embedding for c
384
+ ) -> torch.FloatTensor:
385
+ residual = x
386
+
387
+ batch_size = c.shape[0]
388
+
389
+ # `sample` projections.
390
+ query = attn.to_q(x)
391
+ key = attn.to_k(x)
392
+ value = attn.to_v(x)
393
+
394
+ # `context` projections.
395
+ c_query = attn.to_q_c(c)
396
+ c_key = attn.to_k_c(c)
397
+ c_value = attn.to_v_c(c)
398
+
399
+ # apply rope for context and noised input independently
400
+ if rope is not None:
401
+ freqs, xpos_scale = rope
402
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if xpos_scale is not None else (1., 1.)
403
+ query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
404
+ key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
405
+ if c_rope is not None:
406
+ freqs, xpos_scale = c_rope
407
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if xpos_scale is not None else (1., 1.)
408
+ c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
409
+ c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
410
+
411
+ # attention
412
+ query = torch.cat([query, c_query], dim=1)
413
+ key = torch.cat([key, c_key], dim=1)
414
+ value = torch.cat([value, c_value], dim=1)
415
+
416
+ inner_dim = key.shape[-1]
417
+ head_dim = inner_dim // attn.heads
418
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
419
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
420
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
421
+
422
+ # mask. e.g. inference got a batch with different target durations, mask out the padding
423
+ if mask is not None:
424
+ attn_mask = F.pad(mask, (0, c.shape[1]), value = True) # no mask for c (text)
425
+ attn_mask = rearrange(attn_mask, 'b n -> b 1 1 n')
426
+ attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
427
+ else:
428
+ attn_mask = None
429
+
430
+ x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
431
+ x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
432
+ x = x.to(query.dtype)
433
+
434
+ # Split the attention outputs.
435
+ x, c = (
436
+ x[:, :residual.shape[1]],
437
+ x[:, residual.shape[1]:],
438
+ )
439
+
440
+ # linear proj
441
+ x = attn.to_out[0](x)
442
+ # dropout
443
+ x = attn.to_out[1](x)
444
+ if not attn.context_pre_only:
445
+ c = attn.to_out_c(c)
446
+
447
+ if mask is not None:
448
+ mask = rearrange(mask, 'b n -> b n 1')
449
+ x = x.masked_fill(~mask, 0.)
450
+ # c = c.masked_fill(~mask, 0.) # no mask for c (text)
451
+
452
+ return x, c
453
+
454
+
455
+ # DiT Block
456
+
457
+ class DiTBlock(nn.Module):
458
+
459
+ def __init__(self, dim, heads, dim_head, ff_mult = 4, dropout = 0.1):
460
+ super().__init__()
461
+
462
+ self.attn_norm = AdaLayerNormZero(dim)
463
+ self.attn = Attention(
464
+ processor = AttnProcessor(),
465
+ dim = dim,
466
+ heads = heads,
467
+ dim_head = dim_head,
468
+ dropout = dropout,
469
+ )
470
+
471
+ self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
472
+ self.ff = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
473
+
474
+ def forward(self, x, t, mask = None, rope = None): # x: noised input, t: time embedding
475
+ # pre-norm & modulation for attention input
476
+ norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
477
+
478
+ # attention
479
+ attn_output = self.attn(x=norm, mask=mask, rope=rope)
480
+
481
+ # process attention output for input x
482
+ x = x + gate_msa.unsqueeze(1) * attn_output
483
+
484
+ norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
485
+ ff_output = self.ff(norm)
486
+ x = x + gate_mlp.unsqueeze(1) * ff_output
487
+
488
+ return x
489
+
490
+
491
+ # MMDiT Block https://arxiv.org/abs/2403.03206
492
+
493
+ class MMDiTBlock(nn.Module):
494
+ r"""
495
+ modified from diffusers/src/diffusers/models/attention.py
496
+
497
+ notes.
498
+ _c: context related. text, cond, etc. (left part in sd3 fig2.b)
499
+ _x: noised input related. (right part)
500
+ context_pre_only: last layer only do prenorm + modulation cuz no more ffn
501
+ """
502
+
503
+ def __init__(self, dim, heads, dim_head, ff_mult = 4, dropout = 0.1, context_pre_only = False):
504
+ super().__init__()
505
+
506
+ self.context_pre_only = context_pre_only
507
+
508
+ self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
509
+ self.attn_norm_x = AdaLayerNormZero(dim)
510
+ self.attn = Attention(
511
+ processor = JointAttnProcessor(),
512
+ dim = dim,
513
+ heads = heads,
514
+ dim_head = dim_head,
515
+ dropout = dropout,
516
+ context_dim = dim,
517
+ context_pre_only = context_pre_only,
518
+ )
519
+
520
+ if not context_pre_only:
521
+ self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
522
+ self.ff_c = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
523
+ else:
524
+ self.ff_norm_c = None
525
+ self.ff_c = None
526
+ self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
527
+ self.ff_x = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
528
+
529
+ def forward(self, x, c, t, mask = None, rope = None, c_rope = None): # x: noised input, c: context, t: time embedding
530
+ # pre-norm & modulation for attention input
531
+ if self.context_pre_only:
532
+ norm_c = self.attn_norm_c(c, t)
533
+ else:
534
+ norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(c, emb=t)
535
+ norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(x, emb=t)
536
+
537
+ # attention
538
+ x_attn_output, c_attn_output = self.attn(x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope)
539
+
540
+ # process attention output for context c
541
+ if self.context_pre_only:
542
+ c = None
543
+ else: # if not last layer
544
+ c = c + c_gate_msa.unsqueeze(1) * c_attn_output
545
+
546
+ norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
547
+ c_ff_output = self.ff_c(norm_c)
548
+ c = c + c_gate_mlp.unsqueeze(1) * c_ff_output
549
+
550
+ # process attention output for input x
551
+ x = x + x_gate_msa.unsqueeze(1) * x_attn_output
552
+
553
+ norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None]
554
+ x_ff_output = self.ff_x(norm_x)
555
+ x = x + x_gate_mlp.unsqueeze(1) * x_ff_output
556
+
557
+ return c, x
558
+
559
+
560
+ # time step conditioning embedding
561
+
562
+ class TimestepEmbedding(nn.Module):
563
+ def __init__(self, dim, freq_embed_dim=256):
564
+ super().__init__()
565
+ self.time_embed = SinusPositionEmbedding(freq_embed_dim)
566
+ self.time_mlp = nn.Sequential(
567
+ nn.Linear(freq_embed_dim, dim),
568
+ nn.SiLU(),
569
+ nn.Linear(dim, dim)
570
+ )
571
+
572
+ def forward(self, timestep: float['b']):
573
+ time_hidden = self.time_embed(timestep)
574
+ time = self.time_mlp(time_hidden) # b d
575
+ return time
model/trainer.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import gc
5
+ from tqdm import tqdm
6
+ import wandb
7
+
8
+ import torch
9
+ from torch.optim import AdamW
10
+ from torch.utils.data import DataLoader, Dataset, SequentialSampler
11
+ from torch.optim.lr_scheduler import LinearLR, SequentialLR
12
+
13
+ from einops import rearrange
14
+
15
+ from accelerate import Accelerator
16
+ from accelerate.utils import DistributedDataParallelKwargs
17
+
18
+ from ema_pytorch import EMA
19
+
20
+ from model import CFM
21
+ from model.utils import exists, default
22
+ from model.dataset import DynamicBatchSampler, collate_fn
23
+
24
+
25
+ # trainer
26
+
27
+ class Trainer:
28
+ def __init__(
29
+ self,
30
+ model: CFM,
31
+ epochs,
32
+ learning_rate,
33
+ num_warmup_updates = 20000,
34
+ save_per_updates = 1000,
35
+ checkpoint_path = None,
36
+ batch_size = 32,
37
+ batch_size_type: str = "sample",
38
+ max_samples = 32,
39
+ grad_accumulation_steps = 1,
40
+ max_grad_norm = 1.0,
41
+ noise_scheduler: str | None = None,
42
+ duration_predictor: torch.nn.Module | None = None,
43
+ wandb_project = "test_e2-tts",
44
+ wandb_run_name = "test_run",
45
+ wandb_resume_id: str = None,
46
+ last_per_steps = None,
47
+ accelerate_kwargs: dict = dict(),
48
+ ema_kwargs: dict = dict()
49
+ ):
50
+
51
+ ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters = True)
52
+
53
+ self.accelerator = Accelerator(
54
+ log_with = "wandb",
55
+ kwargs_handlers = [ddp_kwargs],
56
+ gradient_accumulation_steps = grad_accumulation_steps,
57
+ **accelerate_kwargs
58
+ )
59
+
60
+ if exists(wandb_resume_id):
61
+ init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name, 'id': wandb_resume_id}}
62
+ else:
63
+ init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name}}
64
+ self.accelerator.init_trackers(
65
+ project_name = wandb_project,
66
+ init_kwargs=init_kwargs,
67
+ config={"epochs": epochs,
68
+ "learning_rate": learning_rate,
69
+ "num_warmup_updates": num_warmup_updates,
70
+ "batch_size": batch_size,
71
+ "batch_size_type": batch_size_type,
72
+ "max_samples": max_samples,
73
+ "grad_accumulation_steps": grad_accumulation_steps,
74
+ "max_grad_norm": max_grad_norm,
75
+ "gpus": self.accelerator.num_processes,
76
+ "noise_scheduler": noise_scheduler}
77
+ )
78
+
79
+ self.model = model
80
+
81
+ if self.is_main:
82
+ self.ema_model = EMA(
83
+ model,
84
+ include_online_model = False,
85
+ **ema_kwargs
86
+ )
87
+
88
+ self.ema_model.to(self.accelerator.device)
89
+
90
+ self.epochs = epochs
91
+ self.num_warmup_updates = num_warmup_updates
92
+ self.save_per_updates = save_per_updates
93
+ self.last_per_steps = default(last_per_steps, save_per_updates * grad_accumulation_steps)
94
+ self.checkpoint_path = default(checkpoint_path, 'ckpts/test_e2-tts')
95
+
96
+ self.batch_size = batch_size
97
+ self.batch_size_type = batch_size_type
98
+ self.max_samples = max_samples
99
+ self.grad_accumulation_steps = grad_accumulation_steps
100
+ self.max_grad_norm = max_grad_norm
101
+
102
+ self.noise_scheduler = noise_scheduler
103
+
104
+ self.duration_predictor = duration_predictor
105
+
106
+ self.optimizer = AdamW(model.parameters(), lr=learning_rate)
107
+ self.model, self.optimizer = self.accelerator.prepare(
108
+ self.model, self.optimizer
109
+ )
110
+
111
+ @property
112
+ def is_main(self):
113
+ return self.accelerator.is_main_process
114
+
115
+ def save_checkpoint(self, step, last=False):
116
+ self.accelerator.wait_for_everyone()
117
+ if self.is_main:
118
+ checkpoint = dict(
119
+ model_state_dict = self.accelerator.unwrap_model(self.model).state_dict(),
120
+ optimizer_state_dict = self.accelerator.unwrap_model(self.optimizer).state_dict(),
121
+ ema_model_state_dict = self.ema_model.state_dict(),
122
+ scheduler_state_dict = self.scheduler.state_dict(),
123
+ step = step
124
+ )
125
+ if not os.path.exists(self.checkpoint_path):
126
+ os.makedirs(self.checkpoint_path)
127
+ if last == True:
128
+ self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
129
+ print(f"Saved last checkpoint at step {step}")
130
+ else:
131
+ self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{step}.pt")
132
+
133
+ def load_checkpoint(self):
134
+ if not exists(self.checkpoint_path) or not os.path.exists(self.checkpoint_path) or not os.listdir(self.checkpoint_path):
135
+ return 0
136
+
137
+ self.accelerator.wait_for_everyone()
138
+ if "model_last.pt" in os.listdir(self.checkpoint_path):
139
+ latest_checkpoint = "model_last.pt"
140
+ else:
141
+ latest_checkpoint = sorted(os.listdir(self.checkpoint_path), key=lambda x: int(''.join(filter(str.isdigit, x))))[-1]
142
+ # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
143
+ checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location="cpu")
144
+ self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint['model_state_dict'])
145
+ self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint['optimizer_state_dict'])
146
+
147
+ if self.is_main:
148
+ self.ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
149
+
150
+ if self.scheduler:
151
+ self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
152
+
153
+ step = checkpoint['step']
154
+ del checkpoint; gc.collect()
155
+ return step
156
+
157
+ def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
158
+
159
+ if exists(resumable_with_seed):
160
+ generator = torch.Generator()
161
+ generator.manual_seed(resumable_with_seed)
162
+ else:
163
+ generator = None
164
+
165
+ if self.batch_size_type == "sample":
166
+ train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, num_workers=num_workers, pin_memory=True,
167
+ batch_size=self.batch_size, shuffle=True, generator=generator)
168
+ elif self.batch_size_type == "frame":
169
+ self.accelerator.even_batches = False
170
+ sampler = SequentialSampler(train_dataset)
171
+ batch_sampler = DynamicBatchSampler(sampler, self.batch_size, max_samples=self.max_samples, random_seed=resumable_with_seed, drop_last=False)
172
+ train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, num_workers=num_workers, pin_memory=True,
173
+ batch_sampler=batch_sampler)
174
+ else:
175
+ raise ValueError(f"batch_size_type must be either 'sample' or 'frame', but recieved {self.batch_size_type}")
176
+
177
+ # accelerator.prepare() dispatches batches to devices;
178
+ # which means the length of dataloader calculated before, should consider the number of devices
179
+ warmup_steps = self.num_warmup_updates * self.accelerator.num_processes # consider a fixed warmup steps while using accelerate multi-gpu ddp
180
+ # otherwise by default with split_batches=False, warmup steps change with num_processes
181
+ total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps
182
+ decay_steps = total_steps - warmup_steps
183
+ warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
184
+ decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
185
+ self.scheduler = SequentialLR(self.optimizer,
186
+ schedulers=[warmup_scheduler, decay_scheduler],
187
+ milestones=[warmup_steps])
188
+ train_dataloader, self.scheduler = self.accelerator.prepare(train_dataloader, self.scheduler) # actual steps = 1 gpu steps / gpus
189
+ start_step = self.load_checkpoint()
190
+ global_step = start_step
191
+
192
+ if exists(resumable_with_seed):
193
+ orig_epoch_step = len(train_dataloader)
194
+ skipped_epoch = int(start_step // orig_epoch_step)
195
+ skipped_batch = start_step % orig_epoch_step
196
+ skipped_dataloader = self.accelerator.skip_first_batches(train_dataloader, num_batches=skipped_batch)
197
+ else:
198
+ skipped_epoch = 0
199
+
200
+ for epoch in range(skipped_epoch, self.epochs):
201
+ self.model.train()
202
+ if exists(resumable_with_seed) and epoch == skipped_epoch:
203
+ progress_bar = tqdm(skipped_dataloader, desc=f"Epoch {epoch+1}/{self.epochs}", unit="step", disable=not self.accelerator.is_local_main_process,
204
+ initial=skipped_batch, total=orig_epoch_step)
205
+ else:
206
+ progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{self.epochs}", unit="step", disable=not self.accelerator.is_local_main_process)
207
+
208
+ for batch in progress_bar:
209
+ with self.accelerator.accumulate(self.model):
210
+ text_inputs = batch['text']
211
+ mel_spec = rearrange(batch['mel'], 'b d n -> b n d')
212
+ mel_lengths = batch["mel_lengths"]
213
+
214
+ # TODO. add duration predictor training
215
+ if self.duration_predictor is not None and self.accelerator.is_local_main_process:
216
+ dur_loss = self.duration_predictor(mel_spec, lens=batch.get('durations'))
217
+ self.accelerator.log({"duration loss": dur_loss.item()}, step=global_step)
218
+
219
+ loss, cond, pred = self.model(mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler)
220
+ self.accelerator.backward(loss)
221
+
222
+ if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
223
+ self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
224
+
225
+ self.optimizer.step()
226
+ self.scheduler.step()
227
+ self.optimizer.zero_grad()
228
+
229
+ if self.is_main:
230
+ self.ema_model.update()
231
+
232
+ global_step += 1
233
+
234
+ if self.accelerator.is_local_main_process:
235
+ self.accelerator.log({"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_step)
236
+
237
+ progress_bar.set_postfix(step=str(global_step), loss=loss.item())
238
+
239
+ if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0:
240
+ self.save_checkpoint(global_step)
241
+
242
+ if global_step % self.last_per_steps == 0:
243
+ self.save_checkpoint(global_step, last=True)
244
+
245
+ self.accelerator.end_training()
model/utils.py ADDED
@@ -0,0 +1,545 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import re
5
+ import math
6
+ import random
7
+ import string
8
+ from tqdm import tqdm
9
+ from collections import defaultdict
10
+
11
+ import matplotlib
12
+ matplotlib.use("Agg")
13
+ import matplotlib.pylab as plt
14
+
15
+ import torch
16
+ import torch.nn.functional as F
17
+ from torch.nn.utils.rnn import pad_sequence
18
+ import torchaudio
19
+
20
+ import einx
21
+ from einops import rearrange, reduce
22
+
23
+ import jieba
24
+ from pypinyin import lazy_pinyin, Style
25
+ import zhconv
26
+ from zhon.hanzi import punctuation
27
+ from jiwer import compute_measures
28
+
29
+ from funasr import AutoModel
30
+ from faster_whisper import WhisperModel
31
+
32
+ from model.ecapa_tdnn import ECAPA_TDNN_SMALL
33
+ from model.modules import MelSpec
34
+
35
+
36
+ # seed everything
37
+
38
+ def seed_everything(seed = 0):
39
+ random.seed(seed)
40
+ os.environ['PYTHONHASHSEED'] = str(seed)
41
+ torch.manual_seed(seed)
42
+ torch.cuda.manual_seed(seed)
43
+ torch.cuda.manual_seed_all(seed)
44
+ torch.backends.cudnn.deterministic = True
45
+ torch.backends.cudnn.benchmark = False
46
+
47
+ # helpers
48
+
49
+ def exists(v):
50
+ return v is not None
51
+
52
+ def default(v, d):
53
+ return v if exists(v) else d
54
+
55
+ # tensor helpers
56
+
57
+ def lens_to_mask(
58
+ t: int['b'],
59
+ length: int | None = None
60
+ ) -> bool['b n']:
61
+
62
+ if not exists(length):
63
+ length = t.amax()
64
+
65
+ seq = torch.arange(length, device = t.device)
66
+ return einx.less('n, b -> b n', seq, t)
67
+
68
+ def mask_from_start_end_indices(
69
+ seq_len: int['b'],
70
+ start: int['b'],
71
+ end: int['b']
72
+ ):
73
+ max_seq_len = seq_len.max().item()
74
+ seq = torch.arange(max_seq_len, device = start.device).long()
75
+ return einx.greater_equal('n, b -> b n', seq, start) & einx.less('n, b -> b n', seq, end)
76
+
77
+ def mask_from_frac_lengths(
78
+ seq_len: int['b'],
79
+ frac_lengths: float['b']
80
+ ):
81
+ lengths = (frac_lengths * seq_len).long()
82
+ max_start = seq_len - lengths
83
+
84
+ rand = torch.rand_like(frac_lengths)
85
+ start = (max_start * rand).long().clamp(min = 0)
86
+ end = start + lengths
87
+
88
+ return mask_from_start_end_indices(seq_len, start, end)
89
+
90
+ def maybe_masked_mean(
91
+ t: float['b n d'],
92
+ mask: bool['b n'] = None
93
+ ) -> float['b d']:
94
+
95
+ if not exists(mask):
96
+ return t.mean(dim = 1)
97
+
98
+ t = einx.where('b n, b n d, -> b n d', mask, t, 0.)
99
+ num = reduce(t, 'b n d -> b d', 'sum')
100
+ den = reduce(mask.float(), 'b n -> b', 'sum')
101
+
102
+ return einx.divide('b d, b -> b d', num, den.clamp(min = 1.))
103
+
104
+
105
+ # simple utf-8 tokenizer, since paper went character based
106
+ def list_str_to_tensor(
107
+ text: list[str],
108
+ padding_value = -1
109
+ ) -> int['b nt']:
110
+ list_tensors = [torch.tensor([*bytes(t, 'UTF-8')]) for t in text] # ByT5 style
111
+ text = pad_sequence(list_tensors, padding_value = padding_value, batch_first = True)
112
+ return text
113
+
114
+ # char tokenizer, based on custom dataset's extracted .txt file
115
+ def list_str_to_idx(
116
+ text: list[str] | list[list[str]],
117
+ vocab_char_map: dict[str, int], # {char: idx}
118
+ padding_value = -1
119
+ ) -> int['b nt']:
120
+ list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
121
+ text = pad_sequence(list_idx_tensors, padding_value = padding_value, batch_first = True)
122
+ return text
123
+
124
+
125
+ # Get tokenizer
126
+
127
+ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
128
+ '''
129
+ tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file
130
+ - "char" for char-wise tokenizer, need .txt vocab_file
131
+ - "byte" for utf-8 tokenizer
132
+ vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
133
+ - if use "char", derived from unfiltered character & symbol counts of custom dataset
134
+ - if use "byte", set to 256 (unicode byte range)
135
+ '''
136
+ if tokenizer in ["pinyin", "char"]:
137
+ with open (f"data/{dataset_name}_{tokenizer}/vocab.txt", "r") as f:
138
+ vocab_char_map = {}
139
+ for i, char in enumerate(f):
140
+ vocab_char_map[char[:-1]] = i
141
+ vocab_size = len(vocab_char_map)
142
+ assert vocab_char_map[" "] == 0, "make sure space is of idx 0 in vocab.txt, cuz 0 is used for unknown char"
143
+
144
+ elif tokenizer == "byte":
145
+ vocab_char_map = None
146
+ vocab_size = 256
147
+
148
+ return vocab_char_map, vocab_size
149
+
150
+
151
+ # convert char to pinyin
152
+
153
+ def convert_char_to_pinyin(text_list, polyphone = True):
154
+ final_text_list = []
155
+ god_knows_why_en_testset_contains_zh_quote = str.maketrans({'“': '"', '”': '"', '‘': "'", '’': "'"}) # in case librispeech (orig no-pc) test-clean
156
+ for text in text_list:
157
+ char_list = []
158
+ text = text.translate(god_knows_why_en_testset_contains_zh_quote)
159
+ for seg in jieba.cut(text):
160
+ seg_byte_len = len(bytes(seg, 'UTF-8'))
161
+ if seg_byte_len == len(seg): # if pure alphabets and symbols
162
+ if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
163
+ char_list.append(" ")
164
+ char_list.extend(seg)
165
+ elif polyphone and seg_byte_len == 3 * len(seg): # if pure chinese characters
166
+ seg = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
167
+ for c in seg:
168
+ if c not in "。,、;:?!《》【】—…":
169
+ char_list.append(" ")
170
+ char_list.append(c)
171
+ else: # if mixed chinese characters, alphabets and symbols
172
+ for c in seg:
173
+ if ord(c) < 256:
174
+ char_list.extend(c)
175
+ else:
176
+ if c not in "。,、;:?!《》【】—…":
177
+ char_list.append(" ")
178
+ char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True))
179
+ else: # if is zh punc
180
+ char_list.append(c)
181
+ final_text_list.append(char_list)
182
+
183
+ return final_text_list
184
+
185
+
186
+ # save spectrogram
187
+ def save_spectrogram(spectrogram, path):
188
+ plt.figure(figsize=(12, 4))
189
+ plt.imshow(spectrogram, origin='lower', aspect='auto')
190
+ plt.colorbar()
191
+ plt.savefig(path)
192
+ plt.close()
193
+
194
+
195
+ # seedtts testset metainfo: utt, prompt_text, prompt_wav, gt_text, gt_wav
196
+ def get_seedtts_testset_metainfo(metalst):
197
+ f = open(metalst); lines = f.readlines(); f.close()
198
+ metainfo = []
199
+ for line in lines:
200
+ if len(line.strip().split('|')) == 5:
201
+ utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split('|')
202
+ elif len(line.strip().split('|')) == 4:
203
+ utt, prompt_text, prompt_wav, gt_text = line.strip().split('|')
204
+ gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav")
205
+ if not os.path.isabs(prompt_wav):
206
+ prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
207
+ metainfo.append((utt, prompt_text, prompt_wav, gt_text, gt_wav))
208
+ return metainfo
209
+
210
+
211
+ # librispeech test-clean metainfo: gen_utt, ref_txt, ref_wav, gen_txt, gen_wav
212
+ def get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path):
213
+ f = open(metalst); lines = f.readlines(); f.close()
214
+ metainfo = []
215
+ for line in lines:
216
+ ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split('\t')
217
+
218
+ # ref_txt = ref_txt[0] + ref_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc)
219
+ ref_spk_id, ref_chaptr_id, _ = ref_utt.split('-')
220
+ ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + '.flac')
221
+
222
+ # gen_txt = gen_txt[0] + gen_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc)
223
+ gen_spk_id, gen_chaptr_id, _ = gen_utt.split('-')
224
+ gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + '.flac')
225
+
226
+ metainfo.append((gen_utt, ref_txt, ref_wav, " " + gen_txt, gen_wav))
227
+
228
+ return metainfo
229
+
230
+
231
+ # padded to max length mel batch
232
+ def padded_mel_batch(ref_mels):
233
+ max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax()
234
+ padded_ref_mels = []
235
+ for mel in ref_mels:
236
+ padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value = 0)
237
+ padded_ref_mels.append(padded_ref_mel)
238
+ padded_ref_mels = torch.stack(padded_ref_mels)
239
+ padded_ref_mels = rearrange(padded_ref_mels, 'b d n -> b n d')
240
+ return padded_ref_mels
241
+
242
+
243
+ # get prompts from metainfo containing: utt, prompt_text, prompt_wav, gt_text, gt_wav
244
+
245
+ def get_inference_prompt(
246
+ metainfo,
247
+ speed = 1., tokenizer = "pinyin", polyphone = True,
248
+ target_sample_rate = 24000, n_mel_channels = 100, hop_length = 256, target_rms = 0.1,
249
+ use_truth_duration = False,
250
+ infer_batch_size = 1, num_buckets = 200, min_secs = 3, max_secs = 40,
251
+ ):
252
+ prompts_all = []
253
+
254
+ min_tokens = min_secs * target_sample_rate // hop_length
255
+ max_tokens = max_secs * target_sample_rate // hop_length
256
+
257
+ batch_accum = [0] * num_buckets
258
+ utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = \
259
+ ([[] for _ in range(num_buckets)] for _ in range(6))
260
+
261
+ mel_spectrogram = MelSpec(target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length)
262
+
263
+ for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(metainfo, desc="Processing prompts..."):
264
+
265
+ # Audio
266
+ ref_audio, ref_sr = torchaudio.load(prompt_wav)
267
+ ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio)))
268
+ if ref_rms < target_rms:
269
+ ref_audio = ref_audio * target_rms / ref_rms
270
+ assert ref_audio.shape[-1] > 5000, f"Empty prompt wav: {prompt_wav}, or torchaudio backend issue."
271
+ if ref_sr != target_sample_rate:
272
+ resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
273
+ ref_audio = resampler(ref_audio)
274
+
275
+ # Text
276
+ text = [prompt_text + gt_text]
277
+ if tokenizer == "pinyin":
278
+ text_list = convert_char_to_pinyin(text, polyphone = polyphone)
279
+ else:
280
+ text_list = text
281
+
282
+ # Duration, mel frame length
283
+ ref_mel_len = ref_audio.shape[-1] // hop_length
284
+ if use_truth_duration:
285
+ gt_audio, gt_sr = torchaudio.load(gt_wav)
286
+ if gt_sr != target_sample_rate:
287
+ resampler = torchaudio.transforms.Resample(gt_sr, target_sample_rate)
288
+ gt_audio = resampler(gt_audio)
289
+ total_mel_len = ref_mel_len + int(gt_audio.shape[-1] / hop_length / speed)
290
+
291
+ # # test vocoder resynthesis
292
+ # ref_audio = gt_audio
293
+ else:
294
+ zh_pause_punc = r"。,、;:?!"
295
+ ref_text_len = len(prompt_text) + len(re.findall(zh_pause_punc, prompt_text))
296
+ gen_text_len = len(gt_text) + len(re.findall(zh_pause_punc, gt_text))
297
+ total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed)
298
+
299
+ # to mel spectrogram
300
+ ref_mel = mel_spectrogram(ref_audio)
301
+ ref_mel = rearrange(ref_mel, '1 d n -> d n')
302
+
303
+ # deal with batch
304
+ assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
305
+ assert min_tokens <= total_mel_len <= max_tokens, \
306
+ f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]."
307
+ bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets)
308
+
309
+ utts[bucket_i].append(utt)
310
+ ref_rms_list[bucket_i].append(ref_rms)
311
+ ref_mels[bucket_i].append(ref_mel)
312
+ ref_mel_lens[bucket_i].append(ref_mel_len)
313
+ total_mel_lens[bucket_i].append(total_mel_len)
314
+ final_text_list[bucket_i].extend(text_list)
315
+
316
+ batch_accum[bucket_i] += total_mel_len
317
+
318
+ if batch_accum[bucket_i] >= infer_batch_size:
319
+ # print(f"\n{len(ref_mels[bucket_i][0][0])}\n{ref_mel_lens[bucket_i]}\n{total_mel_lens[bucket_i]}")
320
+ prompts_all.append((
321
+ utts[bucket_i],
322
+ ref_rms_list[bucket_i],
323
+ padded_mel_batch(ref_mels[bucket_i]),
324
+ ref_mel_lens[bucket_i],
325
+ total_mel_lens[bucket_i],
326
+ final_text_list[bucket_i]
327
+ ))
328
+ batch_accum[bucket_i] = 0
329
+ utts[bucket_i], ref_rms_list[bucket_i], ref_mels[bucket_i], ref_mel_lens[bucket_i], total_mel_lens[bucket_i], final_text_list[bucket_i] = [], [], [], [], [], []
330
+
331
+ # add residual
332
+ for bucket_i, bucket_frames in enumerate(batch_accum):
333
+ if bucket_frames > 0:
334
+ prompts_all.append((
335
+ utts[bucket_i],
336
+ ref_rms_list[bucket_i],
337
+ padded_mel_batch(ref_mels[bucket_i]),
338
+ ref_mel_lens[bucket_i],
339
+ total_mel_lens[bucket_i],
340
+ final_text_list[bucket_i]
341
+ ))
342
+ # not only leave easy work for last workers
343
+ random.seed(666)
344
+ random.shuffle(prompts_all)
345
+
346
+ return prompts_all
347
+
348
+
349
+ # get wav_res_ref_text of seed-tts test metalst
350
+ # https://github.com/BytedanceSpeech/seed-tts-eval
351
+
352
+ def get_seed_tts_test(metalst, gen_wav_dir, gpus):
353
+ f = open(metalst)
354
+ lines = f.readlines()
355
+ f.close()
356
+
357
+ test_set_ = []
358
+ for line in tqdm(lines):
359
+ if len(line.strip().split('|')) == 5:
360
+ utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split('|')
361
+ elif len(line.strip().split('|')) == 4:
362
+ utt, prompt_text, prompt_wav, gt_text = line.strip().split('|')
363
+
364
+ if not os.path.exists(os.path.join(gen_wav_dir, utt + '.wav')):
365
+ continue
366
+ gen_wav = os.path.join(gen_wav_dir, utt + '.wav')
367
+ if not os.path.isabs(prompt_wav):
368
+ prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
369
+
370
+ test_set_.append((gen_wav, prompt_wav, gt_text))
371
+
372
+ num_jobs = len(gpus)
373
+ if num_jobs == 1:
374
+ return [(gpus[0], test_set_)]
375
+
376
+ wav_per_job = len(test_set_) // num_jobs + 1
377
+ test_set = []
378
+ for i in range(num_jobs):
379
+ test_set.append((gpus[i], test_set_[i*wav_per_job:(i+1)*wav_per_job]))
380
+
381
+ return test_set
382
+
383
+
384
+ # get librispeech test-clean cross sentence test
385
+
386
+ def get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth = False):
387
+ f = open(metalst)
388
+ lines = f.readlines()
389
+ f.close()
390
+
391
+ test_set_ = []
392
+ for line in tqdm(lines):
393
+ ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split('\t')
394
+
395
+ if eval_ground_truth:
396
+ gen_spk_id, gen_chaptr_id, _ = gen_utt.split('-')
397
+ gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + '.flac')
398
+ else:
399
+ if not os.path.exists(os.path.join(gen_wav_dir, gen_utt + '.wav')):
400
+ raise FileNotFoundError(f"Generated wav not found: {gen_utt}")
401
+ gen_wav = os.path.join(gen_wav_dir, gen_utt + '.wav')
402
+
403
+ ref_spk_id, ref_chaptr_id, _ = ref_utt.split('-')
404
+ ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + '.flac')
405
+
406
+ test_set_.append((gen_wav, ref_wav, gen_txt))
407
+
408
+ num_jobs = len(gpus)
409
+ if num_jobs == 1:
410
+ return [(gpus[0], test_set_)]
411
+
412
+ wav_per_job = len(test_set_) // num_jobs + 1
413
+ test_set = []
414
+ for i in range(num_jobs):
415
+ test_set.append((gpus[i], test_set_[i*wav_per_job:(i+1)*wav_per_job]))
416
+
417
+ return test_set
418
+
419
+
420
+ # load asr model
421
+
422
+ def load_asr_model(lang, ckpt_dir = ""):
423
+ if lang == "zh":
424
+ model = AutoModel(
425
+ model = os.path.join(ckpt_dir, "paraformer-zh"),
426
+ # vad_model = os.path.join(ckpt_dir, "fsmn-vad"),
427
+ # punc_model = os.path.join(ckpt_dir, "ct-punc"),
428
+ # spk_model = os.path.join(ckpt_dir, "cam++"),
429
+ disable_update=True,
430
+ ) # following seed-tts setting
431
+ elif lang == "en":
432
+ model_size = "large-v3" if ckpt_dir == "" else ckpt_dir
433
+ model = WhisperModel(model_size, device="cuda", compute_type="float16")
434
+ return model
435
+
436
+
437
+ # WER Evaluation, the way Seed-TTS does
438
+
439
+ def run_asr_wer(args):
440
+ rank, lang, test_set, ckpt_dir = args
441
+
442
+ if lang == "zh":
443
+ torch.cuda.set_device(rank)
444
+ elif lang == "en":
445
+ os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
446
+ else:
447
+ raise NotImplementedError("lang support only 'zh' (funasr paraformer-zh), 'en' (faster-whisper-large-v3), for now.")
448
+
449
+ asr_model = load_asr_model(lang, ckpt_dir = ckpt_dir)
450
+
451
+ punctuation_all = punctuation + string.punctuation
452
+ wers = []
453
+
454
+ for gen_wav, prompt_wav, truth in tqdm(test_set):
455
+ if lang == "zh":
456
+ res = asr_model.generate(input=gen_wav, batch_size_s=300, disable_pbar=True)
457
+ hypo = res[0]["text"]
458
+ hypo = zhconv.convert(hypo, 'zh-cn')
459
+ elif lang == "en":
460
+ segments, _ = asr_model.transcribe(gen_wav, beam_size=5, language="en")
461
+ hypo = ''
462
+ for segment in segments:
463
+ hypo = hypo + ' ' + segment.text
464
+
465
+ # raw_truth = truth
466
+ # raw_hypo = hypo
467
+
468
+ for x in punctuation_all:
469
+ truth = truth.replace(x, '')
470
+ hypo = hypo.replace(x, '')
471
+
472
+ truth = truth.replace(' ', ' ')
473
+ hypo = hypo.replace(' ', ' ')
474
+
475
+ if lang == "zh":
476
+ truth = " ".join([x for x in truth])
477
+ hypo = " ".join([x for x in hypo])
478
+ elif lang == "en":
479
+ truth = truth.lower()
480
+ hypo = hypo.lower()
481
+
482
+ measures = compute_measures(truth, hypo)
483
+ wer = measures["wer"]
484
+
485
+ # ref_list = truth.split(" ")
486
+ # subs = measures["substitutions"] / len(ref_list)
487
+ # dele = measures["deletions"] / len(ref_list)
488
+ # inse = measures["insertions"] / len(ref_list)
489
+
490
+ wers.append(wer)
491
+
492
+ return wers
493
+
494
+
495
+ # SIM Evaluation
496
+
497
+ def run_sim(args):
498
+ rank, test_set, ckpt_dir = args
499
+ device = f"cuda:{rank}"
500
+
501
+ model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type='wavlm_large', config_path=None)
502
+ state_dict = torch.load(ckpt_dir, map_location=lambda storage, loc: storage)
503
+ model.load_state_dict(state_dict['model'], strict=False)
504
+
505
+ use_gpu=True if torch.cuda.is_available() else False
506
+ if use_gpu:
507
+ model = model.cuda(device)
508
+ model.eval()
509
+
510
+ sim_list = []
511
+ for wav1, wav2, truth in tqdm(test_set):
512
+
513
+ wav1, sr1 = torchaudio.load(wav1)
514
+ wav2, sr2 = torchaudio.load(wav2)
515
+
516
+ resample1 = torchaudio.transforms.Resample(orig_freq=sr1, new_freq=16000)
517
+ resample2 = torchaudio.transforms.Resample(orig_freq=sr2, new_freq=16000)
518
+ wav1 = resample1(wav1)
519
+ wav2 = resample2(wav2)
520
+
521
+ if use_gpu:
522
+ wav1 = wav1.cuda(device)
523
+ wav2 = wav2.cuda(device)
524
+ with torch.no_grad():
525
+ emb1 = model(wav1)
526
+ emb2 = model(wav2)
527
+
528
+ sim = F.cosine_similarity(emb1, emb2)[0].item()
529
+ # print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).")
530
+ sim_list.append(sim)
531
+
532
+ return sim_list
533
+
534
+
535
+ # filter func for dirty data with many repetitions
536
+
537
+ def repetition_found(text, length = 2, tolerance = 10):
538
+ pattern_count = defaultdict(int)
539
+ for i in range(len(text) - length + 1):
540
+ pattern = text[i:i + length]
541
+ pattern_count[pattern] += 1
542
+ for pattern, count in pattern_count.items():
543
+ if count > tolerance:
544
+ return True
545
+ return False
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ ffmpeg
requirements.txt ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate>=0.33.0
2
+ datasets
3
+ einops>=0.8.0
4
+ einx>=0.3.0
5
+ ema_pytorch>=0.5.2
6
+ faster_whisper
7
+ funasr
8
+ jieba
9
+ jiwer
10
+ librosa
11
+ matplotlib
12
+ pypinyin
13
+ torch>=2.0
14
+ torchaudio>=2.3.0
15
+ torchdiffeq
16
+ tqdm>=4.65.0
17
+ transformers
18
+ vocos
19
+ wandb
20
+ x_transformers>=1.31.14
21
+ zhconv
22
+ zhon
23
+ cached_path
24
+ pydub
scripts/count_max_epoch.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''ADAPTIVE BATCH SIZE'''
2
+ print('Adaptive batch size: using grouping batch sampler, frames_per_gpu fixed fed in')
3
+ print(' -> least padding, gather wavs with accumulated frames in a batch\n')
4
+
5
+ # data
6
+ total_hours = 95282
7
+ mel_hop_length = 256
8
+ mel_sampling_rate = 24000
9
+
10
+ # target
11
+ wanted_max_updates = 1000000
12
+
13
+ # train params
14
+ gpus = 8
15
+ frames_per_gpu = 38400 # 8 * 38400 = 307200
16
+ grad_accum = 1
17
+
18
+ # intermediate
19
+ mini_batch_frames = frames_per_gpu * grad_accum * gpus
20
+ mini_batch_hours = mini_batch_frames * mel_hop_length / mel_sampling_rate / 3600
21
+ updates_per_epoch = total_hours / mini_batch_hours
22
+ steps_per_epoch = updates_per_epoch * grad_accum
23
+
24
+ # result
25
+ epochs = wanted_max_updates / updates_per_epoch
26
+ print(f"epochs should be set to: {epochs:.0f} ({epochs/grad_accum:.1f} x gd_acum {grad_accum})")
27
+ print(f"progress_bar should show approx. 0/{updates_per_epoch:.0f} updates")
28
+ print(f" or approx. 0/{steps_per_epoch:.0f} steps")
29
+
30
+ # others
31
+ print(f"total {total_hours:.0f} hours")
32
+ print(f"mini-batch of {mini_batch_frames:.0f} frames, {mini_batch_hours:.2f} hours per mini-batch")
scripts/count_params_gflops.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys, os
2
+ sys.path.append(os.getcwd())
3
+
4
+ from model import M2_TTS, UNetT, DiT, MMDiT
5
+
6
+ import torch
7
+ import thop
8
+
9
+
10
+ ''' ~155M '''
11
+ # transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4)
12
+ # transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4, text_dim = 512, conv_layers = 4)
13
+ # transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2)
14
+ # transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2, text_dim = 512, conv_layers = 4)
15
+ # transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2, text_dim = 512, conv_layers = 4, long_skip_connection = True)
16
+ # transformer = MMDiT(dim = 512, depth = 16, heads = 16, ff_mult = 2)
17
+
18
+ ''' ~335M '''
19
+ # FLOPs: 622.1 G, Params: 333.2 M
20
+ # transformer = UNetT(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
21
+ # FLOPs: 363.4 G, Params: 335.8 M
22
+ transformer = DiT(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
23
+
24
+
25
+ model = M2_TTS(transformer=transformer)
26
+ target_sample_rate = 24000
27
+ n_mel_channels = 100
28
+ hop_length = 256
29
+ duration = 20
30
+ frame_length = int(duration * target_sample_rate / hop_length)
31
+ text_length = 150
32
+
33
+ flops, params = thop.profile(model, inputs=(torch.randn(1, frame_length, n_mel_channels), torch.zeros(1, text_length, dtype=torch.long)))
34
+ print(f"FLOPs: {flops / 1e9} G")
35
+ print(f"Params: {params / 1e6} M")
scripts/eval_librispeech_test_clean.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Evaluate with Librispeech test-clean, ~3s prompt to generate 4-10s audio (the way of valle/voicebox evaluation)
2
+
3
+ import sys, os
4
+ sys.path.append(os.getcwd())
5
+
6
+ import multiprocessing as mp
7
+ import numpy as np
8
+
9
+ from model.utils import (
10
+ get_librispeech_test,
11
+ run_asr_wer,
12
+ run_sim,
13
+ )
14
+
15
+
16
+ eval_task = "wer" # sim | wer
17
+ lang = "en"
18
+ metalst = "data/librispeech_pc_test_clean_cross_sentence.lst"
19
+ librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean" # test-clean path
20
+ gen_wav_dir = "PATH_TO_GENERATED" # generated wavs
21
+
22
+ gpus = [0,1,2,3,4,5,6,7]
23
+ test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path)
24
+
25
+ ## In LibriSpeech, some speakers utilized varying voice characteristics for different characters in the book,
26
+ ## leading to a low similarity for the ground truth in some cases.
27
+ # test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth = True) # eval ground truth
28
+
29
+ local = False
30
+ if local: # use local custom checkpoint dir
31
+ asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3"
32
+ else:
33
+ asr_ckpt_dir = "" # auto download to cache dir
34
+
35
+ wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"
36
+
37
+
38
+ # --------------------------- WER ---------------------------
39
+
40
+ if eval_task == "wer":
41
+ wers = []
42
+
43
+ with mp.Pool(processes=len(gpus)) as pool:
44
+ args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
45
+ results = pool.map(run_asr_wer, args)
46
+ for wers_ in results:
47
+ wers.extend(wers_)
48
+
49
+ wer = round(np.mean(wers)*100, 3)
50
+ print(f"\nTotal {len(wers)} samples")
51
+ print(f"WER : {wer}%")
52
+
53
+
54
+ # --------------------------- SIM ---------------------------
55
+
56
+ if eval_task == "sim":
57
+ sim_list = []
58
+
59
+ with mp.Pool(processes=len(gpus)) as pool:
60
+ args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
61
+ results = pool.map(run_sim, args)
62
+ for sim_ in results:
63
+ sim_list.extend(sim_)
64
+
65
+ sim = round(sum(sim_list)/len(sim_list), 3)
66
+ print(f"\nTotal {len(sim_list)} samples")
67
+ print(f"SIM : {sim}")
scripts/eval_seedtts_testset.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Evaluate with Seed-TTS testset
2
+
3
+ import sys, os
4
+ sys.path.append(os.getcwd())
5
+
6
+ import multiprocessing as mp
7
+ import numpy as np
8
+
9
+ from model.utils import (
10
+ get_seed_tts_test,
11
+ run_asr_wer,
12
+ run_sim,
13
+ )
14
+
15
+
16
+ eval_task = "wer" # sim | wer
17
+ lang = "zh" # zh | en
18
+ metalst = f"data/seedtts_testset/{lang}/meta.lst" # seed-tts testset
19
+ # gen_wav_dir = f"data/seedtts_testset/{lang}/wavs" # ground truth wavs
20
+ gen_wav_dir = f"PATH_TO_GENERATED" # generated wavs
21
+
22
+
23
+ # NOTE. paraformer-zh result will be slightly different according to the number of gpus, cuz batchsize is different
24
+ # zh 1.254 seems a result of 4 workers wer_seed_tts
25
+ gpus = [0,1,2,3,4,5,6,7]
26
+ test_set = get_seed_tts_test(metalst, gen_wav_dir, gpus)
27
+
28
+ local = False
29
+ if local: # use local custom checkpoint dir
30
+ if lang == "zh":
31
+ asr_ckpt_dir = "../checkpoints/funasr" # paraformer-zh dir under funasr
32
+ elif lang == "en":
33
+ asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3"
34
+ else:
35
+ asr_ckpt_dir = "" # auto download to cache dir
36
+
37
+ wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"
38
+
39
+
40
+ # --------------------------- WER ---------------------------
41
+
42
+ if eval_task == "wer":
43
+ wers = []
44
+
45
+ with mp.Pool(processes=len(gpus)) as pool:
46
+ args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
47
+ results = pool.map(run_asr_wer, args)
48
+ for wers_ in results:
49
+ wers.extend(wers_)
50
+
51
+ wer = round(np.mean(wers)*100, 3)
52
+ print(f"\nTotal {len(wers)} samples")
53
+ print(f"WER : {wer}%")
54
+
55
+
56
+ # --------------------------- SIM ---------------------------
57
+
58
+ if eval_task == "sim":
59
+ sim_list = []
60
+
61
+ with mp.Pool(processes=len(gpus)) as pool:
62
+ args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
63
+ results = pool.map(run_sim, args)
64
+ for sim_ in results:
65
+ sim_list.extend(sim_)
66
+
67
+ sim = round(sum(sim_list)/len(sim_list), 3)
68
+ print(f"\nTotal {len(sim_list)} samples")
69
+ print(f"SIM : {sim}")
scripts/prepare_emilia.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Emilia Dataset: https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07
2
+ # if use updated new version, i.e. WebDataset, feel free to modify / draft your own script
3
+
4
+ # generate audio text map for Emilia ZH & EN
5
+ # evaluate for vocab size
6
+
7
+ import sys, os
8
+ sys.path.append(os.getcwd())
9
+
10
+ from pathlib import Path
11
+ import json
12
+ from tqdm import tqdm
13
+ from concurrent.futures import ProcessPoolExecutor
14
+
15
+ from datasets import Dataset
16
+ from datasets.arrow_writer import ArrowWriter
17
+
18
+ from model.utils import (
19
+ repetition_found,
20
+ convert_char_to_pinyin,
21
+ )
22
+
23
+
24
+ out_zh = {"ZH_B00041_S06226", "ZH_B00042_S09204", "ZH_B00065_S09430", "ZH_B00065_S09431", "ZH_B00066_S09327", "ZH_B00066_S09328"}
25
+ zh_filters = ["い", "て"]
26
+ # seems synthesized audios, or heavily code-switched
27
+ out_en = {
28
+ "EN_B00013_S00913", "EN_B00042_S00120", "EN_B00055_S04111", "EN_B00061_S00693", "EN_B00061_S01494", "EN_B00061_S03375",
29
+
30
+ "EN_B00059_S00092", "EN_B00111_S04300", "EN_B00100_S03759", "EN_B00087_S03811", "EN_B00059_S00950", "EN_B00089_S00946", "EN_B00078_S05127", "EN_B00070_S04089", "EN_B00074_S09659", "EN_B00061_S06983", "EN_B00061_S07060", "EN_B00059_S08397", "EN_B00082_S06192", "EN_B00091_S01238", "EN_B00089_S07349", "EN_B00070_S04343", "EN_B00061_S02400", "EN_B00076_S01262", "EN_B00068_S06467", "EN_B00076_S02943", "EN_B00064_S05954", "EN_B00061_S05386", "EN_B00066_S06544", "EN_B00076_S06944", "EN_B00072_S08620", "EN_B00076_S07135", "EN_B00076_S09127", "EN_B00065_S00497", "EN_B00059_S06227", "EN_B00063_S02859", "EN_B00075_S01547", "EN_B00061_S08286", "EN_B00079_S02901", "EN_B00092_S03643", "EN_B00096_S08653", "EN_B00063_S04297", "EN_B00063_S04614", "EN_B00079_S04698", "EN_B00104_S01666", "EN_B00061_S09504", "EN_B00061_S09694", "EN_B00065_S05444", "EN_B00063_S06860", "EN_B00065_S05725", "EN_B00069_S07628", "EN_B00083_S03875", "EN_B00071_S07665", "EN_B00071_S07665", "EN_B00062_S04187", "EN_B00065_S09873", "EN_B00065_S09922", "EN_B00084_S02463", "EN_B00067_S05066", "EN_B00106_S08060", "EN_B00073_S06399", "EN_B00073_S09236", "EN_B00087_S00432", "EN_B00085_S05618", "EN_B00064_S01262", "EN_B00072_S01739", "EN_B00059_S03913", "EN_B00069_S04036", "EN_B00067_S05623", "EN_B00060_S05389", "EN_B00060_S07290", "EN_B00062_S08995",
31
+ }
32
+ en_filters = ["ا", "い", "て"]
33
+
34
+
35
+ def deal_with_audio_dir(audio_dir):
36
+ audio_jsonl = audio_dir.with_suffix(".jsonl")
37
+ sub_result, durations = [], []
38
+ vocab_set = set()
39
+ bad_case_zh = 0
40
+ bad_case_en = 0
41
+ with open(audio_jsonl, "r") as f:
42
+ lines = f.readlines()
43
+ for line in tqdm(lines, desc=f"{audio_jsonl.stem}"):
44
+ obj = json.loads(line)
45
+ text = obj["text"]
46
+ if obj['language'] == "zh":
47
+ if obj["wav"].split("/")[1] in out_zh or any(f in text for f in zh_filters) or repetition_found(text):
48
+ bad_case_zh += 1
49
+ continue
50
+ else:
51
+ text = text.translate(str.maketrans({',': ',', '!': '!', '?': '?'})) # not "。" cuz much code-switched
52
+ if obj['language'] == "en":
53
+ if obj["wav"].split("/")[1] in out_en or any(f in text for f in en_filters) or repetition_found(text, length=4):
54
+ bad_case_en += 1
55
+ continue
56
+ if tokenizer == "pinyin":
57
+ text = convert_char_to_pinyin([text], polyphone = polyphone)[0]
58
+ duration = obj["duration"]
59
+ sub_result.append({"audio_path": str(audio_dir.parent / obj["wav"]), "text": text, "duration": duration})
60
+ durations.append(duration)
61
+ vocab_set.update(list(text))
62
+ return sub_result, durations, vocab_set, bad_case_zh, bad_case_en
63
+
64
+
65
+ def main():
66
+ assert tokenizer in ["pinyin", "char"]
67
+ result = []
68
+ duration_list = []
69
+ text_vocab_set = set()
70
+ total_bad_case_zh = 0
71
+ total_bad_case_en = 0
72
+
73
+ # process raw data
74
+ executor = ProcessPoolExecutor(max_workers=max_workers)
75
+ futures = []
76
+ for lang in langs:
77
+ dataset_path = Path(os.path.join(dataset_dir, lang))
78
+ [
79
+ futures.append(executor.submit(deal_with_audio_dir, audio_dir))
80
+ for audio_dir in dataset_path.iterdir()
81
+ if audio_dir.is_dir()
82
+ ]
83
+ for futures in tqdm(futures, total=len(futures)):
84
+ sub_result, durations, vocab_set, bad_case_zh, bad_case_en = futures.result()
85
+ result.extend(sub_result)
86
+ duration_list.extend(durations)
87
+ text_vocab_set.update(vocab_set)
88
+ total_bad_case_zh += bad_case_zh
89
+ total_bad_case_en += bad_case_en
90
+ executor.shutdown()
91
+
92
+ # save preprocessed dataset to disk
93
+ if not os.path.exists(f"data/{dataset_name}"):
94
+ os.makedirs(f"data/{dataset_name}")
95
+ print(f"\nSaving to data/{dataset_name} ...")
96
+ # dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list}) # oom
97
+ # dataset.save_to_disk(f"data/{dataset_name}/raw", max_shard_size="2GB")
98
+ with ArrowWriter(path=f"data/{dataset_name}/raw.arrow") as writer:
99
+ for line in tqdm(result, desc=f"Writing to raw.arrow ..."):
100
+ writer.write(line)
101
+
102
+ # dup a json separately saving duration in case for DynamicBatchSampler ease
103
+ with open(f"data/{dataset_name}/duration.json", 'w', encoding='utf-8') as f:
104
+ json.dump({"duration": duration_list}, f, ensure_ascii=False)
105
+
106
+ # vocab map, i.e. tokenizer
107
+ # add alphabets and symbols (optional, if plan to ft on de/fr etc.)
108
+ # if tokenizer == "pinyin":
109
+ # text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)])
110
+ with open(f"data/{dataset_name}/vocab.txt", "w") as f:
111
+ for vocab in sorted(text_vocab_set):
112
+ f.write(vocab + "\n")
113
+
114
+ print(f"\nFor {dataset_name}, sample count: {len(result)}")
115
+ print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
116
+ print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
117
+ if "ZH" in langs: print(f"Bad zh transcription case: {total_bad_case_zh}")
118
+ if "EN" in langs: print(f"Bad en transcription case: {total_bad_case_en}\n")
119
+
120
+
121
+ if __name__ == "__main__":
122
+
123
+ max_workers = 32
124
+
125
+ tokenizer = "pinyin" # "pinyin" | "char"
126
+ polyphone = True
127
+
128
+ langs = ["ZH", "EN"]
129
+ dataset_dir = "<SOME_PATH>/Emilia_Dataset/raw"
130
+ dataset_name = f"Emilia_{'_'.join(langs)}_{tokenizer}"
131
+ print(f"\nPrepare for {dataset_name}\n")
132
+
133
+ main()
134
+
135
+ # Emilia ZH & EN
136
+ # samples count 37837916 (after removal)
137
+ # pinyin vocab size 2543 (polyphone)
138
+ # total duration 95281.87 (hours)
139
+ # bad zh asr cnt 230435 (samples)
140
+ # bad eh asr cnt 37217 (samples)
141
+
142
+ # vocab size may be slightly different due to jieba tokenizer and pypinyin (e.g. way of polyphoneme)
143
+ # please be careful if using pretrained model, make sure the vocab.txt is same
scripts/prepare_wenetspeech4tts.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # generate audio text map for WenetSpeech4TTS
2
+ # evaluate for vocab size
3
+
4
+ import sys, os
5
+ sys.path.append(os.getcwd())
6
+
7
+ import json
8
+ from tqdm import tqdm
9
+ from concurrent.futures import ProcessPoolExecutor
10
+
11
+ import torchaudio
12
+ from datasets import Dataset
13
+
14
+ from model.utils import convert_char_to_pinyin
15
+
16
+
17
+ def deal_with_sub_path_files(dataset_path, sub_path):
18
+ print(f"Dealing with: {sub_path}")
19
+
20
+ text_dir = os.path.join(dataset_path, sub_path, "txts")
21
+ audio_dir = os.path.join(dataset_path, sub_path, "wavs")
22
+ text_files = os.listdir(text_dir)
23
+
24
+ audio_paths, texts, durations = [], [], []
25
+ for text_file in tqdm(text_files):
26
+ with open(os.path.join(text_dir, text_file), 'r', encoding='utf-8') as file:
27
+ first_line = file.readline().split("\t")
28
+ audio_nm = first_line[0]
29
+ audio_path = os.path.join(audio_dir, audio_nm + ".wav")
30
+ text = first_line[1].strip()
31
+
32
+ audio_paths.append(audio_path)
33
+
34
+ if tokenizer == "pinyin":
35
+ texts.extend(convert_char_to_pinyin([text], polyphone = polyphone))
36
+ elif tokenizer == "char":
37
+ texts.append(text)
38
+
39
+ audio, sample_rate = torchaudio.load(audio_path)
40
+ durations.append(audio.shape[-1] / sample_rate)
41
+
42
+ return audio_paths, texts, durations
43
+
44
+
45
+ def main():
46
+ assert tokenizer in ["pinyin", "char"]
47
+
48
+ audio_path_list, text_list, duration_list = [], [], []
49
+
50
+ executor = ProcessPoolExecutor(max_workers=max_workers)
51
+ futures = []
52
+ for dataset_path in dataset_paths:
53
+ sub_items = os.listdir(dataset_path)
54
+ sub_paths = [item for item in sub_items if os.path.isdir(os.path.join(dataset_path, item))]
55
+ for sub_path in sub_paths:
56
+ futures.append(executor.submit(deal_with_sub_path_files, dataset_path, sub_path))
57
+ for future in tqdm(futures, total=len(futures)):
58
+ audio_paths, texts, durations = future.result()
59
+ audio_path_list.extend(audio_paths)
60
+ text_list.extend(texts)
61
+ duration_list.extend(durations)
62
+ executor.shutdown()
63
+
64
+ if not os.path.exists("data"):
65
+ os.makedirs("data")
66
+
67
+ print(f"\nSaving to data/{dataset_name}_{tokenizer} ...")
68
+ dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list})
69
+ dataset.save_to_disk(f"data/{dataset_name}_{tokenizer}/raw", max_shard_size="2GB") # arrow format
70
+
71
+ with open(f"data/{dataset_name}_{tokenizer}/duration.json", 'w', encoding='utf-8') as f:
72
+ json.dump({"duration": duration_list}, f, ensure_ascii=False) # dup a json separately saving duration in case for DynamicBatchSampler ease
73
+
74
+ print("\nEvaluating vocab size (all characters and symbols / all phonemes) ...")
75
+ text_vocab_set = set()
76
+ for text in tqdm(text_list):
77
+ text_vocab_set.update(list(text))
78
+
79
+ # add alphabets and symbols (optional, if plan to ft on de/fr etc.)
80
+ if tokenizer == "pinyin":
81
+ text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)])
82
+
83
+ with open(f"data/{dataset_name}_{tokenizer}/vocab.txt", "w") as f:
84
+ for vocab in sorted(text_vocab_set):
85
+ f.write(vocab + "\n")
86
+ print(f"\nFor {dataset_name}, sample count: {len(text_list)}")
87
+ print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}\n")
88
+
89
+
90
+ if __name__ == "__main__":
91
+
92
+ max_workers = 32
93
+
94
+ tokenizer = "pinyin" # "pinyin" | "char"
95
+ polyphone = True
96
+ dataset_choice = 1 # 1: Premium, 2: Standard, 3: Basic
97
+
98
+ dataset_name = ["WenetSpeech4TTS_Premium", "WenetSpeech4TTS_Standard", "WenetSpeech4TTS_Basic"][dataset_choice-1]
99
+ dataset_paths = [
100
+ "<SOME_PATH>/WenetSpeech4TTS/Basic",
101
+ "<SOME_PATH>/WenetSpeech4TTS/Standard",
102
+ "<SOME_PATH>/WenetSpeech4TTS/Premium",
103
+ ][-dataset_choice:]
104
+ print(f"\nChoose Dataset: {dataset_name}\n")
105
+
106
+ main()
107
+
108
+ # Results (if adding alphabets with accents and symbols):
109
+ # WenetSpeech4TTS Basic Standard Premium
110
+ # samples count 3932473 1941220 407494
111
+ # pinyin vocab size 1349 1348 1344 (no polyphone)
112
+ # - - 1459 (polyphone)
113
+ # char vocab size 5264 5219 5042
114
+
115
+ # vocab size may be slightly different due to jieba tokenizer and pypinyin (e.g. way of polyphoneme)
116
+ # please be careful if using pretrained model, make sure the vocab.txt is same
test_infer_batch.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import random
4
+ from tqdm import tqdm
5
+ import argparse
6
+
7
+ import torch
8
+ import torchaudio
9
+ from accelerate import Accelerator
10
+ from einops import rearrange
11
+ from ema_pytorch import EMA
12
+ from vocos import Vocos
13
+
14
+ from model import CFM, UNetT, DiT
15
+ from model.utils import (
16
+ get_tokenizer,
17
+ get_seedtts_testset_metainfo,
18
+ get_librispeech_test_clean_metainfo,
19
+ get_inference_prompt,
20
+ )
21
+
22
+ accelerator = Accelerator()
23
+ device = f"cuda:{accelerator.process_index}"
24
+
25
+
26
+ # --------------------- Dataset Settings -------------------- #
27
+
28
+ target_sample_rate = 24000
29
+ n_mel_channels = 100
30
+ hop_length = 256
31
+ target_rms = 0.1
32
+
33
+ tokenizer = "pinyin"
34
+
35
+
36
+ # ---------------------- infer setting ---------------------- #
37
+
38
+ parser = argparse.ArgumentParser(description="batch inference")
39
+
40
+ parser.add_argument('-s', '--seed', default=None, type=int)
41
+ parser.add_argument('-d', '--dataset', default="Emilia_ZH_EN")
42
+ parser.add_argument('-n', '--expname', required=True)
43
+ parser.add_argument('-c', '--ckptstep', default=1200000, type=int)
44
+
45
+ parser.add_argument('-nfe', '--nfestep', default=32, type=int)
46
+ parser.add_argument('-o', '--odemethod', default="euler")
47
+ parser.add_argument('-ss', '--swaysampling', default=-1, type=float)
48
+
49
+ parser.add_argument('-t', '--testset', required=True)
50
+
51
+ args = parser.parse_args()
52
+
53
+
54
+ seed = args.seed
55
+ dataset_name = args.dataset
56
+ exp_name = args.expname
57
+ ckpt_step = args.ckptstep
58
+ checkpoint = torch.load(f"ckpts/{exp_name}/model_{ckpt_step}.pt", map_location=device)
59
+
60
+ nfe_step = args.nfestep
61
+ ode_method = args.odemethod
62
+ sway_sampling_coef = args.swaysampling
63
+
64
+ testset = args.testset
65
+
66
+
67
+ infer_batch_size = 1 # max frames. 1 for ddp single inference (recommended)
68
+ cfg_strength = 2.
69
+ speed = 1.
70
+ use_truth_duration = False
71
+ no_ref_audio = False
72
+
73
+
74
+ if exp_name == "F5TTS_Base":
75
+ model_cls = DiT
76
+ model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
77
+
78
+ elif exp_name == "E2TTS_Base":
79
+ model_cls = UNetT
80
+ model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
81
+
82
+
83
+ if testset == "ls_pc_test_clean":
84
+ metalst = "data/librispeech_pc_test_clean_cross_sentence.lst"
85
+ librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean" # test-clean path
86
+ metainfo = get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path)
87
+
88
+ elif testset == "seedtts_test_zh":
89
+ metalst = "data/seedtts_testset/zh/meta.lst"
90
+ metainfo = get_seedtts_testset_metainfo(metalst)
91
+
92
+ elif testset == "seedtts_test_en":
93
+ metalst = "data/seedtts_testset/en/meta.lst"
94
+ metainfo = get_seedtts_testset_metainfo(metalst)
95
+
96
+
97
+ # path to save genereted wavs
98
+ if seed is None: seed = random.randint(-10000, 10000)
99
+ output_dir = f"results/{exp_name}_{ckpt_step}/{testset}/" \
100
+ f"seed{seed}_{ode_method}_nfe{nfe_step}" \
101
+ f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}" \
102
+ f"_cfg{cfg_strength}_speed{speed}" \
103
+ f"{'_gt-dur' if use_truth_duration else ''}" \
104
+ f"{'_no-ref-audio' if no_ref_audio else ''}"
105
+
106
+
107
+ # -------------------------------------------------#
108
+
109
+ use_ema = True
110
+
111
+ prompts_all = get_inference_prompt(
112
+ metainfo,
113
+ speed = speed,
114
+ tokenizer = tokenizer,
115
+ target_sample_rate = target_sample_rate,
116
+ n_mel_channels = n_mel_channels,
117
+ hop_length = hop_length,
118
+ target_rms = target_rms,
119
+ use_truth_duration = use_truth_duration,
120
+ infer_batch_size = infer_batch_size,
121
+ )
122
+
123
+ # Vocoder model
124
+ local = False
125
+ if local:
126
+ vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
127
+ vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
128
+ state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device)
129
+ vocos.load_state_dict(state_dict)
130
+ vocos.eval()
131
+ else:
132
+ vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
133
+
134
+ # Tokenizer
135
+ vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
136
+
137
+ # Model
138
+ model = CFM(
139
+ transformer = model_cls(
140
+ **model_cfg,
141
+ text_num_embeds = vocab_size,
142
+ mel_dim = n_mel_channels
143
+ ),
144
+ mel_spec_kwargs = dict(
145
+ target_sample_rate = target_sample_rate,
146
+ n_mel_channels = n_mel_channels,
147
+ hop_length = hop_length,
148
+ ),
149
+ odeint_kwargs = dict(
150
+ method = ode_method,
151
+ ),
152
+ vocab_char_map = vocab_char_map,
153
+ ).to(device)
154
+
155
+ if use_ema == True:
156
+ ema_model = EMA(model, include_online_model = False).to(device)
157
+ ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
158
+ ema_model.copy_params_from_ema_to_model()
159
+ else:
160
+ model.load_state_dict(checkpoint['model_state_dict'])
161
+
162
+ if not os.path.exists(output_dir) and accelerator.is_main_process:
163
+ os.makedirs(output_dir)
164
+
165
+ # start batch inference
166
+ accelerator.wait_for_everyone()
167
+ start = time.time()
168
+
169
+ with accelerator.split_between_processes(prompts_all) as prompts:
170
+
171
+ for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process):
172
+ utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = prompt
173
+ ref_mels = ref_mels.to(device)
174
+ ref_mel_lens = torch.tensor(ref_mel_lens, dtype = torch.long).to(device)
175
+ total_mel_lens = torch.tensor(total_mel_lens, dtype = torch.long).to(device)
176
+
177
+ # Inference
178
+ with torch.inference_mode():
179
+ generated, _ = model.sample(
180
+ cond = ref_mels,
181
+ text = final_text_list,
182
+ duration = total_mel_lens,
183
+ lens = ref_mel_lens,
184
+ steps = nfe_step,
185
+ cfg_strength = cfg_strength,
186
+ sway_sampling_coef = sway_sampling_coef,
187
+ no_ref_audio = no_ref_audio,
188
+ seed = seed,
189
+ )
190
+ # Final result
191
+ for i, gen in enumerate(generated):
192
+ gen = gen[ref_mel_lens[i]:total_mel_lens[i], :].unsqueeze(0)
193
+ gen_mel_spec = rearrange(gen, '1 n d -> 1 d n')
194
+ generated_wave = vocos.decode(gen_mel_spec.cpu())
195
+ if ref_rms_list[i] < target_rms:
196
+ generated_wave = generated_wave * ref_rms_list[i] / target_rms
197
+ torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave, target_sample_rate)
198
+
199
+ accelerator.wait_for_everyone()
200
+ if accelerator.is_main_process:
201
+ timediff = time.time() - start
202
+ print(f"Done batch inference in {timediff / 60 :.2f} minutes.")
test_infer_batch.sh ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ # e.g. F5-TTS, 16 NFE
4
+ accelerate launch test_infer_batch.py -n "F5TTS_Base" -t "seedtts_test_zh" -nfe 16
5
+ accelerate launch test_infer_batch.py -n "F5TTS_Base" -t "seedtts_test_en" -nfe 16
6
+ accelerate launch test_infer_batch.py -n "F5TTS_Base" -t "ls_pc_test_clean" -nfe 16
7
+
8
+ # e.g. Vanilla E2 TTS, 32 NFE
9
+ accelerate launch test_infer_batch.py -n "E2TTS_Base" -t "seedtts_test_zh" -o "midpoint" -ss 0
10
+ accelerate launch test_infer_batch.py -n "E2TTS_Base" -t "seedtts_test_en" -o "midpoint" -ss 0
11
+ accelerate launch test_infer_batch.py -n "E2TTS_Base" -t "ls_pc_test_clean" -o "midpoint" -ss 0
12
+
13
+ # etc.
test_infer_single.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+
4
+ import torch
5
+ import torchaudio
6
+ from einops import rearrange
7
+ from ema_pytorch import EMA
8
+ from vocos import Vocos
9
+
10
+ from model import CFM, UNetT, DiT, MMDiT
11
+ from model.utils import (
12
+ get_tokenizer,
13
+ convert_char_to_pinyin,
14
+ save_spectrogram,
15
+ )
16
+
17
+ device = "cuda" if torch.cuda.is_available() else "cpu"
18
+
19
+
20
+ # --------------------- Dataset Settings -------------------- #
21
+
22
+ target_sample_rate = 24000
23
+ n_mel_channels = 100
24
+ hop_length = 256
25
+ target_rms = 0.1
26
+
27
+ tokenizer = "pinyin"
28
+ dataset_name = "Emilia_ZH_EN"
29
+
30
+
31
+ # ---------------------- infer setting ---------------------- #
32
+
33
+ seed = None # int | None
34
+
35
+ exp_name = "F5TTS_Base" # F5TTS_Base | E2TTS_Base
36
+ ckpt_step = 1200000
37
+
38
+ nfe_step = 32 # 16, 32
39
+ cfg_strength = 2.
40
+ ode_method = 'euler' # euler | midpoint
41
+ sway_sampling_coef = -1.
42
+ speed = 1.
43
+ fix_duration = 27 # None (will linear estimate. if code-switched, consider fix) | float (total in seconds, include ref audio)
44
+
45
+ if exp_name == "F5TTS_Base":
46
+ model_cls = DiT
47
+ model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
48
+
49
+ elif exp_name == "E2TTS_Base":
50
+ model_cls = UNetT
51
+ model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
52
+
53
+ checkpoint = torch.load(f"ckpts/{exp_name}/model_{ckpt_step}.pt", map_location=device)
54
+ output_dir = "tests"
55
+
56
+ ref_audio = "tests/ref_audio/test_en_1_ref_short.wav"
57
+ ref_text = "Some call me nature, others call me mother nature."
58
+ gen_text = "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences."
59
+
60
+ # ref_audio = "tests/ref_audio/test_zh_1_ref_short.wav"
61
+ # ref_text = "对,这就是我,万人敬仰的太乙真人。"
62
+ # gen_text = "突然,身边一阵笑声。我看着他们,意气风发地挺直了胸膛,甩了甩那稍显肉感的双臂,轻笑道:\"我身上的肉,是为了掩饰我爆棚的魅力,否则,岂不吓坏了你们呢?\""
63
+
64
+
65
+ # -------------------------------------------------#
66
+
67
+ use_ema = True
68
+
69
+ if not os.path.exists(output_dir):
70
+ os.makedirs(output_dir)
71
+
72
+ # Vocoder model
73
+ local = False
74
+ if local:
75
+ vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
76
+ vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
77
+ state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device)
78
+ vocos.load_state_dict(state_dict)
79
+ vocos.eval()
80
+ else:
81
+ vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
82
+
83
+ # Tokenizer
84
+ vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
85
+
86
+ # Model
87
+ model = CFM(
88
+ transformer = model_cls(
89
+ **model_cfg,
90
+ text_num_embeds = vocab_size,
91
+ mel_dim = n_mel_channels
92
+ ),
93
+ mel_spec_kwargs = dict(
94
+ target_sample_rate = target_sample_rate,
95
+ n_mel_channels = n_mel_channels,
96
+ hop_length = hop_length,
97
+ ),
98
+ odeint_kwargs = dict(
99
+ method = ode_method,
100
+ ),
101
+ vocab_char_map = vocab_char_map,
102
+ ).to(device)
103
+
104
+ if use_ema == True:
105
+ ema_model = EMA(model, include_online_model = False).to(device)
106
+ ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
107
+ ema_model.copy_params_from_ema_to_model()
108
+ else:
109
+ model.load_state_dict(checkpoint['model_state_dict'])
110
+
111
+ # Audio
112
+ audio, sr = torchaudio.load(ref_audio)
113
+ rms = torch.sqrt(torch.mean(torch.square(audio)))
114
+ if rms < target_rms:
115
+ audio = audio * target_rms / rms
116
+ if sr != target_sample_rate:
117
+ resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
118
+ audio = resampler(audio)
119
+ audio = audio.to(device)
120
+
121
+ # Text
122
+ text_list = [ref_text + gen_text]
123
+ if tokenizer == "pinyin":
124
+ final_text_list = convert_char_to_pinyin(text_list)
125
+ else:
126
+ final_text_list = [text_list]
127
+ print(f"text : {text_list}")
128
+ print(f"pinyin: {final_text_list}")
129
+
130
+ # Duration
131
+ ref_audio_len = audio.shape[-1] // hop_length
132
+ if fix_duration is not None:
133
+ duration = int(fix_duration * target_sample_rate / hop_length)
134
+ else: # simple linear scale calcul
135
+ zh_pause_punc = r"。,、;:?!"
136
+ ref_text_len = len(ref_text) + len(re.findall(zh_pause_punc, ref_text))
137
+ gen_text_len = len(gen_text) + len(re.findall(zh_pause_punc, gen_text))
138
+ duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
139
+
140
+ # Inference
141
+ with torch.inference_mode():
142
+ generated, trajectory = model.sample(
143
+ cond = audio,
144
+ text = final_text_list,
145
+ duration = duration,
146
+ steps = nfe_step,
147
+ cfg_strength = cfg_strength,
148
+ sway_sampling_coef = sway_sampling_coef,
149
+ seed = seed,
150
+ )
151
+ print(f"Generated mel: {generated.shape}")
152
+
153
+ # Final result
154
+ generated = generated[:, ref_audio_len:, :]
155
+ generated_mel_spec = rearrange(generated, '1 n d -> 1 d n')
156
+ generated_wave = vocos.decode(generated_mel_spec.cpu())
157
+ if rms < target_rms:
158
+ generated_wave = generated_wave * rms / target_rms
159
+
160
+ save_spectrogram(generated_mel_spec[0].cpu().numpy(), f"{output_dir}/test_single.png")
161
+ torchaudio.save(f"{output_dir}/test_single.wav", generated_wave, target_sample_rate)
162
+ print(f"Generated wav: {generated_wave.shape}")
test_train.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from model import CFM, UNetT, DiT, MMDiT, Trainer
2
+ from model.utils import get_tokenizer
3
+ from model.dataset import load_dataset
4
+
5
+
6
+ # -------------------------- Dataset Settings --------------------------- #
7
+
8
+ target_sample_rate = 24000
9
+ n_mel_channels = 100
10
+ hop_length = 256
11
+
12
+ tokenizer = "pinyin"
13
+ dataset_name = "Emilia_ZH_EN"
14
+
15
+
16
+ # -------------------------- Training Settings -------------------------- #
17
+
18
+ exp_name = "F5TTS_Base" # F5TTS_Base | E2TTS_Base
19
+
20
+ learning_rate = 7.5e-5
21
+
22
+ batch_size_per_gpu = 38400 # 8 GPUs, 8 * 38400 = 307200
23
+ batch_size_type = "frame" # "frame" or "sample"
24
+ max_samples = 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
25
+ grad_accumulation_steps = 1 # note: updates = steps / grad_accumulation_steps
26
+ max_grad_norm = 1.
27
+
28
+ epochs = 11 # use linear decay, thus epochs control the slope
29
+ num_warmup_updates = 20000 # warmup steps
30
+ save_per_updates = 50000 # save checkpoint per steps
31
+ last_per_steps = 5000 # save last checkpoint per steps
32
+
33
+ # model params
34
+ if exp_name == "F5TTS_Base":
35
+ wandb_resume_id = None
36
+ model_cls = DiT
37
+ model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
38
+ elif exp_name == "E2TTS_Base":
39
+ wandb_resume_id = None
40
+ model_cls = UNetT
41
+ model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
42
+
43
+
44
+ # ----------------------------------------------------------------------- #
45
+
46
+ def main():
47
+
48
+ vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
49
+
50
+ mel_spec_kwargs = dict(
51
+ target_sample_rate = target_sample_rate,
52
+ n_mel_channels = n_mel_channels,
53
+ hop_length = hop_length,
54
+ )
55
+
56
+ e2tts = CFM(
57
+ transformer = model_cls(
58
+ **model_cfg,
59
+ text_num_embeds = vocab_size,
60
+ mel_dim = n_mel_channels
61
+ ),
62
+ mel_spec_kwargs = mel_spec_kwargs,
63
+ vocab_char_map = vocab_char_map,
64
+ )
65
+
66
+ trainer = Trainer(
67
+ e2tts,
68
+ epochs,
69
+ learning_rate,
70
+ num_warmup_updates = num_warmup_updates,
71
+ save_per_updates = save_per_updates,
72
+ checkpoint_path = f'ckpts/{exp_name}',
73
+ batch_size = batch_size_per_gpu,
74
+ batch_size_type = batch_size_type,
75
+ max_samples = max_samples,
76
+ grad_accumulation_steps = grad_accumulation_steps,
77
+ max_grad_norm = max_grad_norm,
78
+ wandb_project = "CFM-TTS",
79
+ wandb_run_name = exp_name,
80
+ wandb_resume_id = wandb_resume_id,
81
+ last_per_steps = last_per_steps,
82
+ )
83
+
84
+ train_dataset = load_dataset(dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
85
+ trainer.train(train_dataset,
86
+ resumable_with_seed = 666 # seed for shuffling dataset
87
+ )
88
+
89
+
90
+ if __name__ == '__main__':
91
+ main()
tests/ref_audio/test_en_1_ref_short.wav ADDED
Binary file (256 kB). View file
 
tests/ref_audio/test_zh_1_ref_short.wav ADDED
Binary file (325 kB). View file