initial commit
- .gitignore +22 -0
- .python-version +1 -0
- README.md +159 -2
- app.py +166 -0
- espeak_util.py +206 -0
- istftnet.py +523 -0
- kokoro.py +187 -0
- models.py +738 -0
- packages.txt +1 -0
- plbert.py +15 -0
- pretrained_models/Kokoro/__init__.py +3 -0
- pretrained_models/Kokoro/downloader.py +61 -0
- pyproject.toml +17 -0
- requirements.txt +9 -0
- tts_cli.py +510 -0
- tts_cli_op.py +569 -0
- uv.lock +0 -0
.gitignore
ADDED
@@ -0,0 +1,22 @@
+pretrained_models/Kokoro/*.pth
+pretrained_models/Kokoro/voices/*.pt
+pretrained_models/Kokoro/config.json
+output.wav
+.env
+# Omit the DS_Store folder automatically created by macOS
+.DS_Store/
+
+# python virtual environment folder
+.venv/
+.vscode/
+
+# process log file
+process_log.txt*
+process_log.txt
+
+# Python cache
+__pycache__/
+/*/__pycache__/
+/*/*/__pycache__/
+
+
.python-version
ADDED
@@ -0,0 +1 @@
+3.11
README.md
CHANGED
@@ -8,7 +8,164 @@ sdk_version: 5.9.1
 app_file: app.py
 pinned: false
 license: apache-2.0
-short_description: Simple Space for the Kokoro Model
+short_description: Simple Space for comparing espeak-ng and openphonemizer with the Kokoro Model
+base_model:
+- yl4579/StyleTTS2-LJSpeech
+- hexgrad/Kokoro-82M
+- openphonemizer/ckpt
+pipeline_tag: text-to-speech
 ---
 
-
+❤️ Kokoro Discord Server: <https://discord.gg/QuGxSWBfQy>
+
+<audio controls><source src="https://huggingface.co/hexgrad/Kokoro-82M/resolve/main/demo/HEARME.wav" type="audio/wav"></audio>
+
+**Kokoro** is a frontier TTS model for its size of **82 million parameters** (text in/audio out).
+
+On 25 Dec 2024, Kokoro v0.19 weights were permissively released in full fp32 precision along with 2 voicepacks (Bella and Sarah), all under an Apache 2.0 license.
+
+As of 28 Dec 2024, **8 unique Voicepacks have been released**: 2F 2M each for American and British English.
+
+At the time of release, Kokoro v0.19 was the #1🥇 ranked model in [TTS Spaces Arena](https://huggingface.co/spaces/Pendrokar/TTS-Spaces-Arena). Kokoro had achieved higher Elo in this single-voice Arena setting over other models, using fewer parameters and less data:
+
+1. **Kokoro v0.19: 82M params, Apache, trained on <100 hours of audio, for <20 epochs**
+2. XTTS v2: 467M, CPML, >10k hours
+3. Edge TTS: Microsoft, proprietary
+4. MetaVoice: 1.2B, Apache, 100k hours
+5. Parler Mini: 880M, Apache, 45k hours
+6. Fish Speech: ~500M, CC-BY-NC-SA, 1M hours
+
+Kokoro's ability to top this Elo ladder suggests that the scaling law (Elo vs compute/data/params) for traditional TTS models might have a steeper slope than previously expected.
+
+You can find a hosted demo at [hf.co/spaces/hexgrad/Kokoro-TTS](https://huggingface.co/spaces/hexgrad/Kokoro-TTS).
+
+### Usage
+
+The following can be run in a single cell on [Google Colab](https://colab.research.google.com/).
+
+```py
+# 1️⃣ Install dependencies silently
+!git clone https://huggingface.co/hexgrad/Kokoro-82M
+%cd Kokoro-82M
+!apt-get -qq -y install espeak-ng > /dev/null 2>&1
+!pip install -q phonemizer torch transformers scipy munch
+
+# 2️⃣ Build the model and load the default voicepack
+from models import build_model
+import torch
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+MODEL = build_model('kokoro-v0_19.pth', device)
+VOICE_NAME = [
+    'af', # Default voice is a 50-50 mix of af_bella & af_sarah
+    'af_bella', 'af_sarah', 'am_adam', 'am_michael',
+    'bf_emma', 'bf_isabella', 'bm_george', 'bm_lewis',
+][0]
+VOICEPACK = torch.load(f'voices/{VOICE_NAME}.pt', weights_only=True).to(device)
+print(f'Loaded voice: {VOICE_NAME}')
+
+# 3️⃣ Call generate, which returns a 24khz audio waveform and a string of output phonemes
+from kokoro import generate
+text = "How could I know? It's an unanswerable question. Like asking an unborn child if they'll lead a good life. They haven't even been born."
+audio, out_ps = generate(MODEL, text, VOICEPACK, lang=VOICE_NAME[0])
+# Language is determined by the first letter of the VOICE_NAME:
+# 🇺🇸 'a' => American English => en-us
+# 🇬🇧 'b' => British English => en-gb
+
+# 4️⃣ Display the 24khz audio and print the output phonemes
+from IPython.display import display, Audio
+display(Audio(data=audio, rate=24000, autoplay=True))
+print(out_ps)
+```
+
+The inference code was quickly hacked together on Christmas Day. It is not clean code and leaves a lot of room for improvement. If you'd like to contribute, feel free to open a PR.
+
+### Model Facts
+
+No affiliation can be assumed between parties on different lines.
+
+**Architecture:**
+
+- StyleTTS 2: <https://arxiv.org/abs/2306.07691>
+- ISTFTNet: <https://arxiv.org/abs/2203.02395>
+- Decoder only: no diffusion, no encoder release
+
+**Architected by:** Li et al @ <https://github.com/yl4579/StyleTTS2>
+
+**Trained by**: `@rzvzn` on Discord
+
+**Supported Languages:** American English, British English
+
+**Model SHA256 Hash:** `3b0c392f87508da38fad3a2f9d94c359f1b657ebd2ef79f9d56d69503e470b0a`
+
+### Releases
+
+- 25 Dec 2024: Model v0.19, `af_bella`, `af_sarah`
+- 26 Dec 2024: `am_adam`, `am_michael`
+- 28 Dec 2024: `bf_emma`, `bf_isabella`, `bm_george`, `bm_lewis`
+
+### Licenses
+
+- Apache 2.0 weights in this repository
+- MIT inference code in [spaces/hexgrad/Kokoro-TTS](https://huggingface.co/spaces/hexgrad/Kokoro-TTS) adapted from [yl4579/StyleTTS2](https://github.com/yl4579/StyleTTS2)
+- GPLv3 dependency in [espeak-ng](https://github.com/espeak-ng/espeak-ng)
+
+The inference code was originally MIT licensed by the paper author. Note that this card applies only to this model, Kokoro. Original models published by the paper author can be found at [hf.co/yl4579](https://huggingface.co/yl4579).
+
+### Evaluation
+
+**Metric:** Elo rating
+
+**Leaderboard:** [hf.co/spaces/Pendrokar/TTS-Spaces-Arena](https://huggingface.co/spaces/Pendrokar/TTS-Spaces-Arena)
+
+![TTS-Spaces-Arena-25-Dec-2024](demo/TTS-Spaces-Arena-25-Dec-2024.png)
+
+The voice ranked in the Arena is a 50-50 mix of Bella and Sarah. For your convenience, this mix is included in this repository as `af.pt`, but you can trivially reproduce it like this:
+
+```py
+import torch
+bella = torch.load('voices/af_bella.pt', weights_only=True)
+sarah = torch.load('voices/af_sarah.pt', weights_only=True)
+af = torch.mean(torch.stack([bella, sarah]), dim=0)
+assert torch.equal(af, torch.load('voices/af.pt', weights_only=True))
+```
+
+### Training Details
+
+**Compute:** Kokoro was trained on A100 80GB vRAM instances rented from [Vast.ai](https://cloud.vast.ai/?ref_id=79907) (referral link). Vast was chosen over other compute providers due to its competitive on-demand hourly rates. The average hourly cost for the A100 80GB vRAM instances used for training was below $1/hr per GPU, which was around half the quoted rates from other providers at the time.
+
+**Data:** Kokoro was trained exclusively on **permissive/non-copyrighted audio data** and IPA phoneme labels. Examples of permissive/non-copyrighted audio include:
+
+- Public domain audio
+- Audio licensed under Apache, MIT, etc
+- Synthetic audio<sup>[1]</sup> generated by closed<sup>[2]</sup> TTS models from large providers<br/>
+[1] <https://copyright.gov/ai/ai_policy_guidance.pdf><br/>
+[2] No synthetic audio from open TTS models or "custom voice clones"
+
+**Epochs:** Less than **20 epochs**
+
+**Total Dataset Size:** Less than **100 hours** of audio
+
+### Limitations
+
+Kokoro v0.19 is limited in some specific ways, due to its training set and/or architecture:
+
+- [Data] Lacks voice cloning capability, likely due to small <100h training set
+- [Arch] Relies on external g2p (espeak-ng), which introduces a class of g2p failure modes
+- [Data] Training dataset is mostly long-form reading and narration, not conversation
+- [Arch] At 82M params, Kokoro almost certainly falls to a well-trained 1B+ param diffusion transformer, or a many-billion-param MLLM like GPT-4o / Gemini 2.0 Flash
+- [Data] Multilingual capability is architecturally feasible, but training data is mostly English
+
+Refer to the [Philosophy discussion](https://huggingface.co/hexgrad/Kokoro-82M/discussions/5) to better understand these limitations.
+
+**Will the other voicepacks be released?** There is currently no release date scheduled for the other voicepacks, but in the meantime you can try them in the hosted demo at [hf.co/spaces/hexgrad/Kokoro-TTS](https://huggingface.co/spaces/hexgrad/Kokoro-TTS).
+
+### Acknowledgements
+
+- [@yl4579](https://huggingface.co/yl4579) for architecting StyleTTS 2
+- [@Pendrokar](https://huggingface.co/Pendrokar) for adding Kokoro as a contender in the TTS Spaces Arena
+
+### Model Card Contact
+
+`@rzvzn` on Discord. Server invite: <https://discord.gg/QuGxSWBfQy>
+
+<img src="https://static0.gamerantimages.com/wordpress/wp-content/uploads/2024/08/terminator-zero-41-1.jpg" width="400" alt="kokoro" />
app.py
ADDED
@@ -0,0 +1,166 @@
+import gradio as gr
+import os
+import torch
+
+# Import eSpeak TTS pipeline
+from tts_cli import (
+    build_model as build_model_espeak,
+    generate_long_form_tts as generate_long_form_tts_espeak,
+)
+
+# Import OpenPhonemizer TTS pipeline
+from tts_cli_op import (
+    build_model as build_model_open,
+    generate_long_form_tts as generate_long_form_tts_open,
+)
+from pretrained_models import Kokoro
+
+# ---------------------------------------------------------------------
+# Path to models and voicepacks
+# ---------------------------------------------------------------------
+MODELS_DIR = "pretrained_models/Kokoro"
+VOICES_DIR = "pretrained_models/Kokoro/voices"
+
+
+# ---------------------------------------------------------------------
+# List the models (.pth) and voices (.pt)
+# ---------------------------------------------------------------------
+def get_models():
+    return sorted([f for f in os.listdir(MODELS_DIR) if f.endswith(".pth")])
+
+
+def get_voices():
+    return sorted([f for f in os.listdir(VOICES_DIR) if f.endswith(".pt")])
+
+
+# ---------------------------------------------------------------------
+# We'll map engine selection -> (build_model_func, generate_func)
+# ---------------------------------------------------------------------
+ENGINES = {
+    "espeak": (build_model_espeak, generate_long_form_tts_espeak),
+    "openphonemizer": (build_model_open, generate_long_form_tts_open),
+}
+
+
+# ---------------------------------------------------------------------
+# The main inference function called by Gradio
+# ---------------------------------------------------------------------
+def tts_inference(text, engine, model_file, voice_file, speed=1.0):
+    """
+    text: Input string
+    engine: "espeak" or "openphonemizer"
+    model_file: Selected .pth from the models folder
+    voice_file: Selected .pt from the voices folder
+    speed: Speech speed
+    """
+    # 1) Map engine to the correct build_model + generate_long_form_tts
+    build_fn, gen_fn = ENGINES[engine]
+
+    # 2) Prepare paths
+    model_path = os.path.join(MODELS_DIR, model_file)
+    voice_path = os.path.join(VOICES_DIR, voice_file)
+
+    # 3) Decide device
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    # 4) Load model
+    model = build_fn(model_path, device=device)
+    # Set submodules eval
+    for k, subm in model.items():
+        if hasattr(subm, "eval"):
+            subm.eval()
+
+    # 5) Load voicepack
+    voicepack = torch.load(voice_path, map_location=device)
+    if hasattr(voicepack, "eval"):
+        voicepack.eval()
+
+    # 6) Generate TTS
+    audio, phonemes = gen_fn(model, text, voicepack, speed=speed)
+    sr = 22050  # or your actual sample rate
+
+    return (sr, audio)  # Gradio expects (sample_rate, np_array)
+
+
+# ---------------------------------------------------------------------
+# Build Gradio App
+# ---------------------------------------------------------------------
+def create_gradio_app():
+    model_list = get_models()
+    voice_list = get_voices()
+
+    css = """
+    h4 {
+        text-align: center;
+        display:block;
+    }
+    h2 {
+        text-align: center;
+        display:block;
+    }
+    """
+    with gr.Blocks(theme=gr.themes.Ocean(), css=css) as demo:
+        gr.Markdown("## Kokoro TTS Demo: Choose engine, model, and voice")
+
+        # Row 1: Text input
+        text_input = gr.Textbox(
+            label="Input Text",
+            value="Hello, world! Testing both eSpeak and OpenPhonemizer. Can you believe that we live in 2024 and have access to advanced AI?",
+            lines=3,
+        )
+
+        # Row 2: Engine selection
+        engine_dropdown = gr.Dropdown(
+            choices=["espeak", "openphonemizer"],
+            value="openphonemizer",
+            label="Phonemizer",
+        )
+
+        # Row 3: Model dropdown
+        model_dropdown = gr.Dropdown(
+            choices=model_list,
+            value=model_list[0] if model_list else None,
+            label="Model (.pth)",
+        )
+
+        # Row 4: Voice dropdown
+        voice_dropdown = gr.Dropdown(
+            choices=voice_list,
+            value=voice_list[0] if voice_list else None,
+            label="Voice (.pt)",
+        )
+
+        # Row 5: Speed slider
+        speed_slider = gr.Slider(
+            minimum=0.5, maximum=2.0, value=1.0, step=0.1, label="Speech Speed"
+        )
+
+        # Generate button + audio output
+        generate_btn = gr.Button("Generate")
+        tts_output = gr.Audio(label="TTS Output")
+
+        # Connect the button to our inference function
+        generate_btn.click(
+            fn=tts_inference,
+            inputs=[
+                text_input,
+                engine_dropdown,
+                model_dropdown,
+                voice_dropdown,
+                speed_slider,
+            ],
+            outputs=tts_output,
+        )
+
+        gr.Markdown(
+            "#### Kokoro TTS Demo based on [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M)"
+        )
+    return demo
+
+
+# ---------------------------------------------------------------------
+# Main
+# ---------------------------------------------------------------------
+if __name__ == "__main__":
+    app = create_gradio_app()
+    app.launch()
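
The Gradio wiring above simply forwards the dropdown selections into `tts_inference`, so the same function can be exercised from a plain script. A minimal sketch, assuming this Space's repository is checked out with the `.pth` and `.pt` files already under `pretrained_models/Kokoro` (the file names below are placeholders for whatever `get_models()` and `get_voices()` actually list):

```py
import scipy.io.wavfile as wavfile
from app import tts_inference

# Hypothetical file names; substitute entries reported by get_models()/get_voices().
sr, audio = tts_inference(
    text="A quick smoke test outside of Gradio.",
    engine="openphonemizer",
    model_file="kokoro-v0_19.pth",
    voice_file="af.pt",
    speed=1.0,
)
wavfile.write("output.wav", sr, audio)
```

Note that `tts_inference` hard-codes `sr = 22050` ("or your actual sample rate"), while the model card describes 24 kHz output, so that constant may need adjusting to match the pipeline in use.
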
espeak_util.py
ADDED
@@ -0,0 +1,206 @@
+import platform
+import subprocess
+import shutil
+from pathlib import Path
+import os
+from typing import Optional, Tuple
+from phonemizer.backend.espeak.wrapper import EspeakWrapper
+
+
+class EspeakConfig:
+    """Utility class for configuring espeak-ng library and binary."""
+
+    @staticmethod
+    def find_espeak_binary() -> tuple[bool, Optional[str]]:
+        """
+        Find espeak-ng binary using multiple methods.
+
+        Returns:
+            tuple: (bool indicating if espeak is available, path to espeak binary if found)
+        """
+        # Common binary names
+        binary_names = ["espeak-ng", "espeak"]
+        if platform.system() == "Windows":
+            binary_names = ["espeak-ng.exe", "espeak.exe"]
+
+        # Common installation directories for Linux
+        linux_paths = [
+            "/usr/bin",
+            "/usr/local/bin",
+            "/usr/lib/espeak-ng",
+            "/usr/local/lib/espeak-ng",
+            "/opt/espeak-ng/bin",
+        ]
+
+        # First check if it's in PATH
+        for name in binary_names:
+            espeak_path = shutil.which(name)
+            if espeak_path:
+                return True, espeak_path
+
+        # For Linux, check common installation directories
+        if platform.system() == "Linux":
+            for directory in linux_paths:
+                for name in binary_names:
+                    path = Path(directory) / name
+                    if path.exists():
+                        return True, str(path)
+
+        # Try running the command directly as a last resort
+        try:
+            subprocess.run(
+                ["espeak-ng", "--version"],
+                stdout=subprocess.PIPE,
+                stderr=subprocess.PIPE,
+                check=True,
+            )
+            return True, "espeak-ng"
+        except (subprocess.SubprocessError, FileNotFoundError):
+            pass
+
+        return False, None
+
+    @staticmethod
+    def find_library_path() -> Optional[str]:
+        """
+        Find the espeak-ng library using multiple search methods.
+
+        Returns:
+            Optional[str]: Path to the library if found, None otherwise
+        """
+        system = platform.system()
+
+        if system == "Linux":
+            lib_names = ["libespeak-ng.so", "libespeak-ng.so.1"]
+            common_paths = [
+                # Debian/Ubuntu paths
+                "/usr/lib/x86_64-linux-gnu",
+                "/usr/lib/aarch64-linux-gnu",  # For ARM64
+                "/usr/lib/arm-linux-gnueabihf",  # For ARM32
+                "/usr/lib",
+                "/usr/local/lib",
+                # Fedora/RHEL paths
+                "/usr/lib64",
+                "/usr/lib32",
+                # Common additional paths
+                "/usr/lib/espeak-ng",
+                "/usr/local/lib/espeak-ng",
+                "/opt/espeak-ng/lib",
+            ]
+
+            # Check common locations first
+            for path in common_paths:
+                for lib_name in lib_names:
+                    lib_path = Path(path) / lib_name
+                    if lib_path.exists():
+                        return str(lib_path)
+
+            # Search system library paths
+            try:
+                # Use ldconfig to find the library
+                result = subprocess.run(
+                    ["ldconfig", "-p"], capture_output=True, text=True, check=True
+                )
+                for line in result.stdout.splitlines():
+                    if "libespeak-ng.so" in line:
+                        # Extract path from ldconfig output
+                        return line.split("=>")[-1].strip()
+            except (subprocess.SubprocessError, FileNotFoundError):
+                pass
+
+        elif system == "Darwin":  # macOS
+            common_paths = [
+                Path("/opt/homebrew/lib/libespeak-ng.dylib"),
+                Path("/usr/local/lib/libespeak-ng.dylib"),
+                *list(
+                    Path("/opt/homebrew/Cellar/espeak-ng").glob(
+                        "*/lib/libespeak-ng.dylib"
+                    )
+                ),
+                *list(
+                    Path("/usr/local/Cellar/espeak-ng").glob("*/lib/libespeak-ng.dylib")
+                ),
+            ]
+
+            for path in common_paths:
+                if path.exists():
+                    return str(path)
+
+        elif system == "Windows":
+            common_paths = [
+                Path(os.environ.get("PROGRAMFILES", "C:\\Program Files"))
+                / "eSpeak NG"
+                / "libespeak-ng.dll",
+                Path(os.environ.get("PROGRAMFILES(X86)", "C:\\Program Files (x86)"))
+                / "eSpeak NG"
+                / "libespeak-ng.dll",
+                *[
+                    Path(p) / "libespeak-ng.dll"
+                    for p in os.environ.get("PATH", "").split(os.pathsep)
+                ],
+            ]
+
+            for path in common_paths:
+                if path.exists():
+                    return str(path)
+
+        return None
+
+    @classmethod
+    def configure_espeak(cls) -> Tuple[bool, str]:
+        """
+        Configure espeak-ng for use with the phonemizer.
+
+        Returns:
+            Tuple[bool, str]: (Success status, Status message)
+        """
+        # First check if espeak binary is available
+        espeak_available, espeak_path = cls.find_espeak_binary()
+        if not espeak_available:
+            raise FileNotFoundError(
+                "Could not find espeak-ng binary. Please install espeak-ng:\n"
+                "Ubuntu/Debian: sudo apt-get install espeak-ng espeak-ng-data\n"
+                "Fedora: sudo dnf install espeak-ng\n"
+                "Arch: sudo pacman -S espeak-ng\n"
+                "MacOS: brew install espeak-ng\n"
+                "Windows: Download from https://github.com/espeak-ng/espeak-ng/releases"
+            )
+
+        # Find the library
+        library_path = cls.find_library_path()
+        if not library_path:
+            # On Linux, we might not need to explicitly set the library path
+            if platform.system() == "Linux":
+                return True, f"Using system espeak-ng installation at: {espeak_path}"
+            else:
+                raise FileNotFoundError(
+                    "Could not find espeak-ng library. Please ensure espeak-ng is properly installed."
+                )
+
+        # Try to set the library path
+        try:
+            EspeakWrapper.set_library(library_path)
+            return True, f"Successfully configured espeak-ng library at: {library_path}"
+        except Exception as e:
+            if platform.system() == "Linux":
+                # On Linux, try to continue without explicit library path
+                return True, f"Using system espeak-ng installation at: {espeak_path}"
+            else:
+                raise RuntimeError(f"Failed to configure espeak-ng library: {str(e)}")
+
+
+def setup_espeak():
+    """
+    Set up espeak-ng for use with the phonemizer.
+    Raises appropriate exceptions if setup fails.
+    """
+    try:
+        success, message = EspeakConfig.configure_espeak()
+        print(message)
+    except Exception as e:
+        print(f"Error configuring espeak-ng: {str(e)}")
+        raise
+
+
+# Replace the original set_espeak_library function with this
+set_espeak_library = setup_espeak
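
For context, this helper is what `kokoro.py` imports (under the name `set_espeak_library`) before it builds its phonemizer backends. A minimal sketch of using it on its own, assuming espeak-ng and the `phonemizer` package are installed:

```py
from espeak_util import setup_espeak
import phonemizer

# Locates the espeak-ng binary/library and, if found, points phonemizer's
# EspeakWrapper at the shared library.
setup_espeak()

backend = phonemizer.backend.EspeakBackend(
    language="en-us", preserve_punctuation=True, with_stress=True
)
print(backend.phonemize(["Hello world"]))
```
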
istftnet.py
ADDED
@@ -0,0 +1,523 @@
+# https://github.com/yl4579/StyleTTS2/blob/main/Modules/istftnet.py
+from scipy.signal import get_window
+from torch.nn import Conv1d, ConvTranspose1d
+from torch.nn.utils import weight_norm, remove_weight_norm
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+# https://github.com/yl4579/StyleTTS2/blob/main/Modules/utils.py
+def init_weights(m, mean=0.0, std=0.01):
+    classname = m.__class__.__name__
+    if classname.find("Conv") != -1:
+        m.weight.data.normal_(mean, std)
+
+def get_padding(kernel_size, dilation=1):
+    return int((kernel_size*dilation - dilation)/2)
+
+LRELU_SLOPE = 0.1
+
+class AdaIN1d(nn.Module):
+    def __init__(self, style_dim, num_features):
+        super().__init__()
+        self.norm = nn.InstanceNorm1d(num_features, affine=False)
+        self.fc = nn.Linear(style_dim, num_features*2)
+
+    def forward(self, x, s):
+        h = self.fc(s)
+        h = h.view(h.size(0), h.size(1), 1)
+        gamma, beta = torch.chunk(h, chunks=2, dim=1)
+        return (1 + gamma) * self.norm(x) + beta
+
+class AdaINResBlock1(torch.nn.Module):
+    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), style_dim=64):
+        super(AdaINResBlock1, self).__init__()
+        self.convs1 = nn.ModuleList([
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
+                               padding=get_padding(kernel_size, dilation[0]))),
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
+                               padding=get_padding(kernel_size, dilation[1]))),
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
+                               padding=get_padding(kernel_size, dilation[2])))
+        ])
+        self.convs1.apply(init_weights)
+
+        self.convs2 = nn.ModuleList([
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+                               padding=get_padding(kernel_size, 1))),
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+                               padding=get_padding(kernel_size, 1))),
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+                               padding=get_padding(kernel_size, 1)))
+        ])
+        self.convs2.apply(init_weights)
+
+        self.adain1 = nn.ModuleList([
+            AdaIN1d(style_dim, channels),
+            AdaIN1d(style_dim, channels),
+            AdaIN1d(style_dim, channels),
+        ])
+
+        self.adain2 = nn.ModuleList([
+            AdaIN1d(style_dim, channels),
+            AdaIN1d(style_dim, channels),
+            AdaIN1d(style_dim, channels),
+        ])
+
+        self.alpha1 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs1))])
+        self.alpha2 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs2))])
+
+
+    def forward(self, x, s):
+        for c1, c2, n1, n2, a1, a2 in zip(self.convs1, self.convs2, self.adain1, self.adain2, self.alpha1, self.alpha2):
+            xt = n1(x, s)
+            xt = xt + (1 / a1) * (torch.sin(a1 * xt) ** 2)  # Snake1D
+            xt = c1(xt)
+            xt = n2(xt, s)
+            xt = xt + (1 / a2) * (torch.sin(a2 * xt) ** 2)  # Snake1D
+            xt = c2(xt)
+            x = xt + x
+        return x
+
+    def remove_weight_norm(self):
+        for l in self.convs1:
+            remove_weight_norm(l)
+        for l in self.convs2:
+            remove_weight_norm(l)
+
+class TorchSTFT(torch.nn.Module):
+    def __init__(self, filter_length=800, hop_length=200, win_length=800, window='hann'):
+        super().__init__()
+        self.filter_length = filter_length
+        self.hop_length = hop_length
+        self.win_length = win_length
+        self.window = torch.from_numpy(get_window(window, win_length, fftbins=True).astype(np.float32))
+
+    def transform(self, input_data):
+        forward_transform = torch.stft(
+            input_data,
+            self.filter_length, self.hop_length, self.win_length, window=self.window.to(input_data.device),
+            return_complex=True)
+
+        return torch.abs(forward_transform), torch.angle(forward_transform)
+
+    def inverse(self, magnitude, phase):
+        inverse_transform = torch.istft(
+            magnitude * torch.exp(phase * 1j),
+            self.filter_length, self.hop_length, self.win_length, window=self.window.to(magnitude.device))
+
+        return inverse_transform.unsqueeze(-2)  # unsqueeze to stay consistent with conv_transpose1d implementation
+
+    def forward(self, input_data):
+        self.magnitude, self.phase = self.transform(input_data)
+        reconstruction = self.inverse(self.magnitude, self.phase)
+        return reconstruction
+
+class SineGen(torch.nn.Module):
+    """ Definition of sine generator
+    SineGen(samp_rate, harmonic_num = 0,
+            sine_amp = 0.1, noise_std = 0.003,
+            voiced_threshold = 0,
+            flag_for_pulse=False)
+    samp_rate: sampling rate in Hz
+    harmonic_num: number of harmonic overtones (default 0)
+    sine_amp: amplitude of sine-wavefrom (default 0.1)
+    noise_std: std of Gaussian noise (default 0.003)
+    voiced_thoreshold: F0 threshold for U/V classification (default 0)
+    flag_for_pulse: this SinGen is used inside PulseGen (default False)
+    Note: when flag_for_pulse is True, the first time step of a voiced
+        segment is always sin(np.pi) or cos(0)
+    """
+
+    def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
+                 sine_amp=0.1, noise_std=0.003,
+                 voiced_threshold=0,
+                 flag_for_pulse=False):
+        super(SineGen, self).__init__()
+        self.sine_amp = sine_amp
+        self.noise_std = noise_std
+        self.harmonic_num = harmonic_num
+        self.dim = self.harmonic_num + 1
+        self.sampling_rate = samp_rate
+        self.voiced_threshold = voiced_threshold
+        self.flag_for_pulse = flag_for_pulse
+        self.upsample_scale = upsample_scale
+
+    def _f02uv(self, f0):
+        # generate uv signal
+        uv = (f0 > self.voiced_threshold).type(torch.float32)
+        return uv
+
+    def _f02sine(self, f0_values):
+        """ f0_values: (batchsize, length, dim)
+            where dim indicates fundamental tone and overtones
+        """
+        # convert to F0 in rad. The interger part n can be ignored
+        # because 2 * np.pi * n doesn't affect phase
+        rad_values = (f0_values / self.sampling_rate) % 1
+
+        # initial phase noise (no noise for fundamental component)
+        rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], \
+                              device=f0_values.device)
+        rand_ini[:, 0] = 0
+        rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
+
+        # instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
+        if not self.flag_for_pulse:
+#             # for normal case
+
+#             # To prevent torch.cumsum numerical overflow,
+#             # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1.
+#             # Buffer tmp_over_one_idx indicates the time step to add -1.
+#             # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi
+#             tmp_over_one = torch.cumsum(rad_values, 1) % 1
+#             tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
+#             cumsum_shift = torch.zeros_like(rad_values)
+#             cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
+
+#             phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
+            rad_values = torch.nn.functional.interpolate(rad_values.transpose(1, 2),
+                                                          scale_factor=1/self.upsample_scale,
+                                                          mode="linear").transpose(1, 2)
+
+#             tmp_over_one = torch.cumsum(rad_values, 1) % 1
+#             tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
+#             cumsum_shift = torch.zeros_like(rad_values)
+#             cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
+
+            phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
+            phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale,
+                                                    scale_factor=self.upsample_scale, mode="linear").transpose(1, 2)
+            sines = torch.sin(phase)
+
+        else:
+            # If necessary, make sure that the first time step of every
+            # voiced segments is sin(pi) or cos(0)
+            # This is used for pulse-train generation
+
+            # identify the last time step in unvoiced segments
+            uv = self._f02uv(f0_values)
+            uv_1 = torch.roll(uv, shifts=-1, dims=1)
+            uv_1[:, -1, :] = 1
+            u_loc = (uv < 1) * (uv_1 > 0)
+
+            # get the instantanouse phase
+            tmp_cumsum = torch.cumsum(rad_values, dim=1)
+            # different batch needs to be processed differently
+            for idx in range(f0_values.shape[0]):
+                temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
+                temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
+                # stores the accumulation of i.phase within
+                # each voiced segments
+                tmp_cumsum[idx, :, :] = 0
+                tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
+
+            # rad_values - tmp_cumsum: remove the accumulation of i.phase
+            # within the previous voiced segment.
+            i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
+
+            # get the sines
+            sines = torch.cos(i_phase * 2 * np.pi)
+        return sines
+
+    def forward(self, f0):
+        """ sine_tensor, uv = forward(f0)
+        input F0: tensor(batchsize=1, length, dim=1)
+                  f0 for unvoiced steps should be 0
+        output sine_tensor: tensor(batchsize=1, length, dim)
+        output uv: tensor(batchsize=1, length, 1)
+        """
+        f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim,
+                             device=f0.device)
+        # fundamental component
+        fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))
+
+        # generate sine waveforms
+        sine_waves = self._f02sine(fn) * self.sine_amp
+
+        # generate uv signal
+        # uv = torch.ones(f0.shape)
+        # uv = uv * (f0 > self.voiced_threshold)
+        uv = self._f02uv(f0)
+
+        # noise: for unvoiced should be similar to sine_amp
+        # std = self.sine_amp/3 -> max value ~ self.sine_amp
+        # .  for voiced regions is self.noise_std
+        noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
+        noise = noise_amp * torch.randn_like(sine_waves)
+
+        # first: set the unvoiced part to 0 by uv
+        # then: additive noise
+        sine_waves = sine_waves * uv + noise
+        return sine_waves, uv, noise
+
+
+class SourceModuleHnNSF(torch.nn.Module):
+    """ SourceModule for hn-nsf
+    SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
+                 add_noise_std=0.003, voiced_threshod=0)
+    sampling_rate: sampling_rate in Hz
+    harmonic_num: number of harmonic above F0 (default: 0)
+    sine_amp: amplitude of sine source signal (default: 0.1)
+    add_noise_std: std of additive Gaussian noise (default: 0.003)
+        note that amplitude of noise in unvoiced is decided
+        by sine_amp
+    voiced_threshold: threhold to set U/V given F0 (default: 0)
+    Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
+    F0_sampled (batchsize, length, 1)
+    Sine_source (batchsize, length, 1)
+    noise_source (batchsize, length 1)
+    uv (batchsize, length, 1)
+    """
+
+    def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
+                 add_noise_std=0.003, voiced_threshod=0):
+        super(SourceModuleHnNSF, self).__init__()
+
+        self.sine_amp = sine_amp
+        self.noise_std = add_noise_std
+
+        # to produce sine waveforms
+        self.l_sin_gen = SineGen(sampling_rate, upsample_scale, harmonic_num,
+                                 sine_amp, add_noise_std, voiced_threshod)
+
+        # to merge source harmonics into a single excitation
+        self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
+        self.l_tanh = torch.nn.Tanh()
+
+    def forward(self, x):
+        """
+        Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
+        F0_sampled (batchsize, length, 1)
+        Sine_source (batchsize, length, 1)
+        noise_source (batchsize, length 1)
+        """
+        # source for harmonic branch
+        with torch.no_grad():
+            sine_wavs, uv, _ = self.l_sin_gen(x)
+        sine_merge = self.l_tanh(self.l_linear(sine_wavs))
+
+        # source for noise branch, in the same shape as uv
+        noise = torch.randn_like(uv) * self.sine_amp / 3
+        return sine_merge, noise, uv
+def padDiff(x):
+    return F.pad(F.pad(x, (0,0,-1,1), 'constant', 0) - x, (0,0,0,-1), 'constant', 0)
+
+
+class Generator(torch.nn.Module):
+    def __init__(self, style_dim, resblock_kernel_sizes, upsample_rates, upsample_initial_channel, resblock_dilation_sizes, upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size):
+        super(Generator, self).__init__()
+
+        self.num_kernels = len(resblock_kernel_sizes)
+        self.num_upsamples = len(upsample_rates)
+        resblock = AdaINResBlock1
+
+        self.m_source = SourceModuleHnNSF(
+                sampling_rate=24000,
+                upsample_scale=np.prod(upsample_rates) * gen_istft_hop_size,
+                harmonic_num=8, voiced_threshod=10)
+        self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * gen_istft_hop_size)
+        self.noise_convs = nn.ModuleList()
+        self.noise_res = nn.ModuleList()
+
+        self.ups = nn.ModuleList()
+        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
+            self.ups.append(weight_norm(
+                ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)),
+                                k, u, padding=(k-u)//2)))
+
+        self.resblocks = nn.ModuleList()
+        for i in range(len(self.ups)):
+            ch = upsample_initial_channel//(2**(i+1))
+            for j, (k, d) in enumerate(zip(resblock_kernel_sizes,resblock_dilation_sizes)):
+                self.resblocks.append(resblock(ch, k, d, style_dim))
+
+            c_cur = upsample_initial_channel // (2 ** (i + 1))
+
+            if i + 1 < len(upsample_rates): #
+                stride_f0 = np.prod(upsample_rates[i + 1:])
+                self.noise_convs.append(Conv1d(
+                    gen_istft_n_fft + 2, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0+1) // 2))
+                self.noise_res.append(resblock(c_cur, 7, [1,3,5], style_dim))
+            else:
+                self.noise_convs.append(Conv1d(gen_istft_n_fft + 2, c_cur, kernel_size=1))
+                self.noise_res.append(resblock(c_cur, 11, [1,3,5], style_dim))
+
+
+        self.post_n_fft = gen_istft_n_fft
+        self.conv_post = weight_norm(Conv1d(ch, self.post_n_fft + 2, 7, 1, padding=3))
+        self.ups.apply(init_weights)
+        self.conv_post.apply(init_weights)
+        self.reflection_pad = torch.nn.ReflectionPad1d((1, 0))
+        self.stft = TorchSTFT(filter_length=gen_istft_n_fft, hop_length=gen_istft_hop_size, win_length=gen_istft_n_fft)
+
+
+    def forward(self, x, s, f0):
+        with torch.no_grad():
+            f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # bs,n,t
+
+            har_source, noi_source, uv = self.m_source(f0)
+            har_source = har_source.transpose(1, 2).squeeze(1)
+            har_spec, har_phase = self.stft.transform(har_source)
+            har = torch.cat([har_spec, har_phase], dim=1)
+
+        for i in range(self.num_upsamples):
+            x = F.leaky_relu(x, LRELU_SLOPE)
+            x_source = self.noise_convs[i](har)
+            x_source = self.noise_res[i](x_source, s)
+
+            x = self.ups[i](x)
+            if i == self.num_upsamples - 1:
+                x = self.reflection_pad(x)
+
+            x = x + x_source
+            xs = None
+            for j in range(self.num_kernels):
+                if xs is None:
+                    xs = self.resblocks[i*self.num_kernels+j](x, s)
+                else:
+                    xs += self.resblocks[i*self.num_kernels+j](x, s)
+            x = xs / self.num_kernels
+        x = F.leaky_relu(x)
+        x = self.conv_post(x)
+        spec = torch.exp(x[:,:self.post_n_fft // 2 + 1, :])
+        phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :])
+        return self.stft.inverse(spec, phase)
+
+    def fw_phase(self, x, s):
+        for i in range(self.num_upsamples):
+            x = F.leaky_relu(x, LRELU_SLOPE)
+            x = self.ups[i](x)
+            xs = None
+            for j in range(self.num_kernels):
+                if xs is None:
+                    xs = self.resblocks[i*self.num_kernels+j](x, s)
+                else:
+                    xs += self.resblocks[i*self.num_kernels+j](x, s)
+            x = xs / self.num_kernels
+        x = F.leaky_relu(x)
+        x = self.reflection_pad(x)
+        x = self.conv_post(x)
+        spec = torch.exp(x[:,:self.post_n_fft // 2 + 1, :])
+        phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :])
+        return spec, phase
+
+    def remove_weight_norm(self):
+        print('Removing weight norm...')
+        for l in self.ups:
+            remove_weight_norm(l)
+        for l in self.resblocks:
+            l.remove_weight_norm()
+        remove_weight_norm(self.conv_pre)
+        remove_weight_norm(self.conv_post)
+
+
+class AdainResBlk1d(nn.Module):
+    def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2),
+                 upsample='none', dropout_p=0.0):
+        super().__init__()
+        self.actv = actv
+        self.upsample_type = upsample
+        self.upsample = UpSample1d(upsample)
+        self.learned_sc = dim_in != dim_out
+        self._build_weights(dim_in, dim_out, style_dim)
+        self.dropout = nn.Dropout(dropout_p)
+
+        if upsample == 'none':
+            self.pool = nn.Identity()
+        else:
+            self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))
+
+
+    def _build_weights(self, dim_in, dim_out, style_dim):
+        self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
+        self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
+        self.norm1 = AdaIN1d(style_dim, dim_in)
+        self.norm2 = AdaIN1d(style_dim, dim_out)
+        if self.learned_sc:
+            self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
+
+    def _shortcut(self, x):
+        x = self.upsample(x)
+        if self.learned_sc:
+            x = self.conv1x1(x)
+        return x
+
+    def _residual(self, x, s):
+        x = self.norm1(x, s)
+        x = self.actv(x)
+        x = self.pool(x)
+        x = self.conv1(self.dropout(x))
+        x = self.norm2(x, s)
+        x = self.actv(x)
+        x = self.conv2(self.dropout(x))
+        return x
+
+    def forward(self, x, s):
+        out = self._residual(x, s)
+        out = (out + self._shortcut(x)) / np.sqrt(2)
+        return out
+
+class UpSample1d(nn.Module):
+    def __init__(self, layer_type):
+        super().__init__()
+        self.layer_type = layer_type
+
+    def forward(self, x):
+        if self.layer_type == 'none':
+            return x
+        else:
+            return F.interpolate(x, scale_factor=2, mode='nearest')
+
+class Decoder(nn.Module):
+    def __init__(self, dim_in=512, F0_channel=512, style_dim=64, dim_out=80,
+                 resblock_kernel_sizes = [3,7,11],
+                 upsample_rates = [10, 6],
+                 upsample_initial_channel=512,
+                 resblock_dilation_sizes=[[1,3,5], [1,3,5], [1,3,5]],
+                 upsample_kernel_sizes=[20, 12],
+                 gen_istft_n_fft=20, gen_istft_hop_size=5):
+        super().__init__()
+
+        self.decode = nn.ModuleList()
+
+        self.encode = AdainResBlk1d(dim_in + 2, 1024, style_dim)
+
+        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
+        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
+        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
+        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 512, style_dim, upsample=True))
+
+        self.F0_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
+
+        self.N_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
+
+        self.asr_res = nn.Sequential(
+            weight_norm(nn.Conv1d(512, 64, kernel_size=1)),
+        )
+
+
+        self.generator = Generator(style_dim, resblock_kernel_sizes, upsample_rates,
+                                   upsample_initial_channel, resblock_dilation_sizes,
+                                   upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size)
+
+    def forward(self, asr, F0_curve, N, s):
+        F0 = self.F0_conv(F0_curve.unsqueeze(1))
+        N = self.N_conv(N.unsqueeze(1))
+
+        x = torch.cat([asr, F0, N], axis=1)
+        x = self.encode(x, s)
+
+        asr_res = self.asr_res(asr)
+
+        res = True
+        for block in self.decode:
+            if res:
+                x = torch.cat([x, asr_res, F0, N], axis=1)
+            x = block(x, s)
+            if block.upsample_type != "none":
+                res = False
+
+        x = self.generator(x, s, F0_curve)
+        return x
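
The `TorchSTFT` helper above is the piece that turns the generator's predicted magnitude and phase back into a waveform. A quick shape sanity check, a minimal sketch assuming only `torch` and `scipy` are installed (it uses the class defaults rather than the small `gen_istft_n_fft=20` setting the Decoder passes in):

```py
import torch
from istftnet import TorchSTFT

stft = TorchSTFT(filter_length=800, hop_length=200, win_length=800)
wave = torch.randn(1, 24000)        # one second of noise at 24 kHz
mag, phase = stft.transform(wave)   # each (1, 401, frames)
recon = stft.inverse(mag, phase)    # (1, 1, samples), per the unsqueeze(-2)
print(mag.shape, phase.shape, recon.shape)
```
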
kokoro.py
ADDED
@@ -0,0 +1,187 @@
+import phonemizer
+import re
+import torch
+from espeak_util import set_espeak_library
+
+set_espeak_library()
+
+
+def split_num(num):
+    num = num.group()
+    if "." in num:
+        return num
+    elif ":" in num:
+        h, m = [int(n) for n in num.split(":")]
+        if m == 0:
+            return f"{h} o'clock"
+        elif m < 10:
+            return f"{h} oh {m}"
+        return f"{h} {m}"
+    year = int(num[:4])
+    if year < 1100 or year % 1000 < 10:
+        return num
+    left, right = num[:2], int(num[2:4])
+    s = "s" if num.endswith("s") else ""
+    if 100 <= year % 1000 <= 999:
+        if right == 0:
+            return f"{left} hundred{s}"
+        elif right < 10:
+            return f"{left} oh {right}{s}"
+    return f"{left} {right}{s}"
+
+
+def flip_money(m):
+    m = m.group()
+    bill = "dollar" if m[0] == "$" else "pound"
+    if m[-1].isalpha():
+        return f"{m[1:]} {bill}s"
+    elif "." not in m:
+        s = "" if m[1:] == "1" else "s"
+        return f"{m[1:]} {bill}{s}"
+    b, c = m[1:].split(".")
+    s = "" if b == "1" else "s"
+    c = int(c.ljust(2, "0"))
+    coins = (
+        f"cent{'' if c == 1 else 's'}"
+        if m[0] == "$"
+        else ("penny" if c == 1 else "pence")
+    )
+    return f"{b} {bill}{s} and {c} {coins}"
+
+
+def point_num(num):
+    a, b = num.group().split(".")
+    return " point ".join([a, " ".join(b)])
+
+
+def normalize_text(text):
+    text = text.replace(chr(8216), "'").replace(chr(8217), "'")
+    text = text.replace("«", chr(8220)).replace("»", chr(8221))
+    text = text.replace(chr(8220), '"').replace(chr(8221), '"')
+    text = text.replace("(", "«").replace(")", "»")
+    for a, b in zip("、。!,:;?", ",.!,:;?"):
+        text = text.replace(a, b + " ")
+    text = re.sub(r"[^\S \n]", " ", text)
+    text = re.sub(r" +", " ", text)
+    text = re.sub(r"(?<=\n) +(?=\n)", "", text)
+    text = re.sub(r"\bD[Rr]\.(?= [A-Z])", "Doctor", text)
+    text = re.sub(r"\b(?:Mr\.|MR\.(?= [A-Z]))", "Mister", text)
+    text = re.sub(r"\b(?:Ms\.|MS\.(?= [A-Z]))", "Miss", text)
+    text = re.sub(r"\b(?:Mrs\.|MRS\.(?= [A-Z]))", "Mrs", text)
+    text = re.sub(r"\betc\.(?! [A-Z])", "etc", text)
+    text = re.sub(r"(?i)\b(y)eah?\b", r"\1e'a", text)
+    text = re.sub(
+        r"\d*\.\d+|\b\d{4}s?\b|(?<!:)\b(?:[1-9]|1[0-2]):[0-5]\d\b(?!:)", split_num, text
+    )
+    text = re.sub(r"(?<=\d),(?=\d)", "", text)
+    text = re.sub(
+        r"(?i)[$£]\d+(?:\.\d+)?(?: hundred| thousand| (?:[bm]|tr)illion)*\b|[$£]\d+\.\d\d?\b",
+        flip_money,
+        text,
+    )
+    text = re.sub(r"\d*\.\d+", point_num, text)
+    text = re.sub(r"(?<=\d)-(?=\d)", " to ", text)
+    text = re.sub(r"(?<=\d)S", " S", text)
+    text = re.sub(r"(?<=[BCDFGHJ-NP-TV-Z])'?s\b", "'S", text)
+    text = re.sub(r"(?<=X')S\b", "s", text)
+    text = re.sub(
+        r"(?:[A-Za-z]\.){2,} [a-z]", lambda m: m.group().replace(".", "-"), text
+    )
+    text = re.sub(r"(?i)(?<=[A-Z])\.(?=[A-Z])", "-", text)
+    return text.strip()
+
+
+def get_vocab():
+    _pad = "$"
+    _punctuation = ';:,.!?¡¿—…"«»“” '
+    _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
+    _letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
+    symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
+    dicts = {}
+    for i in range(len((symbols))):
+        dicts[symbols[i]] = i
+    return dicts
+
+
+VOCAB = get_vocab()
+
+
+def tokenize(ps):
+    return [i for i in map(VOCAB.get, ps) if i is not None]
+
+
+phonemizers = dict(
+    a=phonemizer.backend.EspeakBackend(
+        language="en-us", preserve_punctuation=True, with_stress=True
+    ),
+    b=phonemizer.backend.EspeakBackend(
+        language="en-gb", preserve_punctuation=True, with_stress=True
+    ),
+)
+
+
+def phonemize(text, lang, norm=True):
+    if norm:
+        text = normalize_text(text)
+    ps = phonemizers[lang].phonemize([text])
+    ps = ps[0] if ps else ""
+    # https://en.wiktionary.org/wiki/kokoro#English
+    ps = ps.replace("kəkˈoːɹoʊ", "kˈoʊkəɹoʊ").replace("kəkˈɔːɹəʊ", "kˈəʊkəɹəʊ")
+    ps = ps.replace("ʲ", "j").replace("r", "ɹ").replace("x", "k").replace("ɬ", "l")
+    ps = re.sub(r"(?<=[a-zɹː])(?=hˈʌndɹɪd)", " ", ps)
+    ps = re.sub(r' z(?=[;:,.!?¡¿—…"«»“” ]|$)', "z", ps)
+    if lang == "a":
+        ps = re.sub(r"(?<=nˈaɪn)ti(?!ː)", "di", ps)
+    ps = "".join(filter(lambda p: p in VOCAB, ps))
+    return ps.strip()
+
+
+def length_to_mask(lengths):
+    mask = (
+        torch.arange(lengths.max())
+        .unsqueeze(0)
+        .expand(lengths.shape[0], -1)
+        .type_as(lengths)
+    )
+    mask = torch.gt(mask + 1, lengths.unsqueeze(1))
+    return mask
+
+
+@torch.no_grad()
+def forward(model, tokens, ref_s, speed):
+    device = ref_s.device
+    tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
+    input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
+    text_mask = length_to_mask(input_lengths).to(device)
+    bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
+    d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
+    s = ref_s[:, 128:]
+    d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
+    x, _ = model.predictor.lstm(d)
+    duration = model.predictor.duration_proj(x)
+    duration = torch.sigmoid(duration).sum(axis=-1) / speed
+    pred_dur = torch.round(duration).clamp(min=1).long()
+    pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item())
+    c_frame = 0
+    for i in range(pred_aln_trg.size(0)):
+        pred_aln_trg[i, c_frame : c_frame + pred_dur[0, i].item()] = 1
+        c_frame += pred_dur[0, i].item()
+    en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device)
+    F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
+    t_en = model.text_encoder(tokens, input_lengths, text_mask)
+    asr = t_en @ pred_aln_trg.unsqueeze(0).to(device)
+    return model.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu().numpy()
+
+
+def generate(model, text, voicepack, lang="a", speed=1):
+    ps = phonemize(text, lang)
+    tokens = tokenize(ps)
+    if not tokens:
+        return None
+    elif len(tokens) > 510:
+        tokens = tokens[:510]
+        print("Truncated to 510 tokens")
+    ref_s = voicepack[len(tokens)]
+    out = forward(model, tokens, ref_s, speed)
+    ps = "".join(next(k for k, v in VOCAB.items() if i == v) for i in tokens)
+    return out, ps
models.py
ADDED
@@ -0,0 +1,738 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# https://github.com/yl4579/StyleTTS2/blob/main/models.py
|
2 |
+
from istftnet import Decoder
|
3 |
+
from munch import Munch
|
4 |
+
from pathlib import Path
|
5 |
+
from plbert import load_plbert
|
6 |
+
from torch.nn.utils import weight_norm, spectral_norm
|
7 |
+
import json
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch.nn.functional as F
|
12 |
+
|
13 |
+
|
14 |
+
class LearnedDownSample(nn.Module):
|
15 |
+
def __init__(self, layer_type, dim_in):
|
16 |
+
super().__init__()
|
17 |
+
self.layer_type = layer_type
|
18 |
+
|
19 |
+
if self.layer_type == "none":
|
20 |
+
self.conv = nn.Identity()
|
21 |
+
elif self.layer_type == "timepreserve":
|
22 |
+
self.conv = spectral_norm(
|
23 |
+
nn.Conv2d(
|
24 |
+
dim_in,
|
25 |
+
dim_in,
|
26 |
+
kernel_size=(3, 1),
|
27 |
+
stride=(2, 1),
|
28 |
+
groups=dim_in,
|
29 |
+
padding=(1, 0),
|
30 |
+
)
|
31 |
+
)
|
32 |
+
elif self.layer_type == "half":
|
33 |
+
self.conv = spectral_norm(
|
34 |
+
nn.Conv2d(
|
35 |
+
dim_in,
|
36 |
+
dim_in,
|
37 |
+
kernel_size=(3, 3),
|
38 |
+
stride=(2, 2),
|
39 |
+
groups=dim_in,
|
40 |
+
padding=1,
|
41 |
+
)
|
42 |
+
)
|
43 |
+
else:
|
44 |
+
raise RuntimeError(
|
45 |
+
"Got unexpected donwsampletype %s, expected is [none, timepreserve, half]"
|
46 |
+
% self.layer_type
|
47 |
+
)
|
48 |
+
|
49 |
+
def forward(self, x):
|
50 |
+
return self.conv(x)
|
51 |
+
|
52 |
+
|
53 |
+
class LearnedUpSample(nn.Module):
|
54 |
+
def __init__(self, layer_type, dim_in):
|
55 |
+
super().__init__()
|
56 |
+
self.layer_type = layer_type
|
57 |
+
|
58 |
+
if self.layer_type == "none":
|
59 |
+
self.conv = nn.Identity()
|
60 |
+
elif self.layer_type == "timepreserve":
|
61 |
+
self.conv = nn.ConvTranspose2d(
|
62 |
+
dim_in,
|
63 |
+
dim_in,
|
64 |
+
kernel_size=(3, 1),
|
65 |
+
stride=(2, 1),
|
66 |
+
groups=dim_in,
|
67 |
+
output_padding=(1, 0),
|
68 |
+
padding=(1, 0),
|
69 |
+
)
|
70 |
+
elif self.layer_type == "half":
|
71 |
+
self.conv = nn.ConvTranspose2d(
|
72 |
+
dim_in,
|
73 |
+
dim_in,
|
74 |
+
kernel_size=(3, 3),
|
75 |
+
stride=(2, 2),
|
76 |
+
groups=dim_in,
|
77 |
+
output_padding=1,
|
78 |
+
padding=1,
|
79 |
+
)
|
80 |
+
else:
|
81 |
+
raise RuntimeError(
|
82 |
+
"Got unexpected upsampletype %s, expected is [none, timepreserve, half]"
|
83 |
+
% self.layer_type
|
84 |
+
)
|
85 |
+
|
86 |
+
def forward(self, x):
|
87 |
+
return self.conv(x)
|
88 |
+
|
89 |
+
|
90 |
+
class DownSample(nn.Module):
|
91 |
+
def __init__(self, layer_type):
|
92 |
+
super().__init__()
|
93 |
+
self.layer_type = layer_type
|
94 |
+
|
95 |
+
def forward(self, x):
|
96 |
+
if self.layer_type == "none":
|
97 |
+
return x
|
98 |
+
elif self.layer_type == "timepreserve":
|
99 |
+
return F.avg_pool2d(x, (2, 1))
|
100 |
+
elif self.layer_type == "half":
|
101 |
+
if x.shape[-1] % 2 != 0:
|
102 |
+
x = torch.cat([x, x[..., -1].unsqueeze(-1)], dim=-1)
|
103 |
+
return F.avg_pool2d(x, 2)
|
104 |
+
else:
|
105 |
+
raise RuntimeError(
|
106 |
+
"Got unexpected donwsampletype %s, expected is [none, timepreserve, half]"
|
107 |
+
% self.layer_type
|
108 |
+
)
|
109 |
+
|
110 |
+
|
111 |
+
class UpSample(nn.Module):
|
112 |
+
def __init__(self, layer_type):
|
113 |
+
super().__init__()
|
114 |
+
self.layer_type = layer_type
|
115 |
+
|
116 |
+
def forward(self, x):
|
117 |
+
if self.layer_type == "none":
|
118 |
+
return x
|
119 |
+
elif self.layer_type == "timepreserve":
|
120 |
+
return F.interpolate(x, scale_factor=(2, 1), mode="nearest")
|
121 |
+
elif self.layer_type == "half":
|
122 |
+
return F.interpolate(x, scale_factor=2, mode="nearest")
|
123 |
+
else:
|
124 |
+
raise RuntimeError(
|
125 |
+
"Got unexpected upsampletype %s, expected is [none, timepreserve, half]"
|
126 |
+
% self.layer_type
|
127 |
+
)
|
128 |
+
|
129 |
+
|
130 |
+
class ResBlk(nn.Module):
|
131 |
+
def __init__(
|
132 |
+
self,
|
133 |
+
dim_in,
|
134 |
+
dim_out,
|
135 |
+
actv=nn.LeakyReLU(0.2),
|
136 |
+
normalize=False,
|
137 |
+
downsample="none",
|
138 |
+
):
|
139 |
+
super().__init__()
|
140 |
+
self.actv = actv
|
141 |
+
self.normalize = normalize
|
142 |
+
self.downsample = DownSample(downsample)
|
143 |
+
self.downsample_res = LearnedDownSample(downsample, dim_in)
|
144 |
+
self.learned_sc = dim_in != dim_out
|
145 |
+
self._build_weights(dim_in, dim_out)
|
146 |
+
|
147 |
+
def _build_weights(self, dim_in, dim_out):
|
148 |
+
self.conv1 = spectral_norm(nn.Conv2d(dim_in, dim_in, 3, 1, 1))
|
149 |
+
self.conv2 = spectral_norm(nn.Conv2d(dim_in, dim_out, 3, 1, 1))
|
150 |
+
if self.normalize:
|
151 |
+
self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
|
152 |
+
self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
|
153 |
+
if self.learned_sc:
|
154 |
+
self.conv1x1 = spectral_norm(
|
155 |
+
nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)
|
156 |
+
)
|
157 |
+
|
158 |
+
def _shortcut(self, x):
|
159 |
+
if self.learned_sc:
|
160 |
+
x = self.conv1x1(x)
|
161 |
+
if self.downsample:
|
162 |
+
x = self.downsample(x)
|
163 |
+
return x
|
164 |
+
|
165 |
+
def _residual(self, x):
|
166 |
+
if self.normalize:
|
167 |
+
x = self.norm1(x)
|
168 |
+
x = self.actv(x)
|
169 |
+
x = self.conv1(x)
|
170 |
+
x = self.downsample_res(x)
|
171 |
+
if self.normalize:
|
172 |
+
x = self.norm2(x)
|
173 |
+
x = self.actv(x)
|
174 |
+
x = self.conv2(x)
|
175 |
+
return x
|
176 |
+
|
177 |
+
def forward(self, x):
|
178 |
+
x = self._shortcut(x) + self._residual(x)
|
179 |
+
return x / np.sqrt(2) # unit variance
|
180 |
+
|
181 |
+
|
182 |
+
class LinearNorm(torch.nn.Module):
|
183 |
+
def __init__(self, in_dim, out_dim, bias=True, w_init_gain="linear"):
|
184 |
+
super(LinearNorm, self).__init__()
|
185 |
+
self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
|
186 |
+
|
187 |
+
torch.nn.init.xavier_uniform_(
|
188 |
+
self.linear_layer.weight, gain=torch.nn.init.calculate_gain(w_init_gain)
|
189 |
+
)
|
190 |
+
|
191 |
+
def forward(self, x):
|
192 |
+
return self.linear_layer(x)
|
193 |
+
|
194 |
+
|
195 |
+
class Discriminator2d(nn.Module):
|
196 |
+
def __init__(self, dim_in=48, num_domains=1, max_conv_dim=384, repeat_num=4):
|
197 |
+
super().__init__()
|
198 |
+
blocks = []
|
199 |
+
blocks += [spectral_norm(nn.Conv2d(1, dim_in, 3, 1, 1))]
|
200 |
+
|
201 |
+
for lid in range(repeat_num):
|
202 |
+
dim_out = min(dim_in * 2, max_conv_dim)
|
203 |
+
blocks += [ResBlk(dim_in, dim_out, downsample="half")]
|
204 |
+
dim_in = dim_out
|
205 |
+
|
206 |
+
blocks += [nn.LeakyReLU(0.2)]
|
207 |
+
blocks += [spectral_norm(nn.Conv2d(dim_out, dim_out, 5, 1, 0))]
|
208 |
+
blocks += [nn.LeakyReLU(0.2)]
|
209 |
+
blocks += [nn.AdaptiveAvgPool2d(1)]
|
210 |
+
blocks += [spectral_norm(nn.Conv2d(dim_out, num_domains, 1, 1, 0))]
|
211 |
+
self.main = nn.Sequential(*blocks)
|
212 |
+
|
213 |
+
def get_feature(self, x):
|
214 |
+
features = []
|
215 |
+
for l in self.main:
|
216 |
+
x = l(x)
|
217 |
+
features.append(x)
|
218 |
+
out = features[-1]
|
219 |
+
out = out.view(out.size(0), -1) # (batch, num_domains)
|
220 |
+
return out, features
|
221 |
+
|
222 |
+
def forward(self, x):
|
223 |
+
out, features = self.get_feature(x)
|
224 |
+
out = out.squeeze() # (batch)
|
225 |
+
return out, features
|
226 |
+
|
227 |
+
|
228 |
+
class ResBlk1d(nn.Module):
|
229 |
+
def __init__(
|
230 |
+
self,
|
231 |
+
dim_in,
|
232 |
+
dim_out,
|
233 |
+
actv=nn.LeakyReLU(0.2),
|
234 |
+
normalize=False,
|
235 |
+
downsample="none",
|
236 |
+
dropout_p=0.2,
|
237 |
+
):
|
238 |
+
super().__init__()
|
239 |
+
self.actv = actv
|
240 |
+
self.normalize = normalize
|
241 |
+
self.downsample_type = downsample
|
242 |
+
self.learned_sc = dim_in != dim_out
|
243 |
+
self._build_weights(dim_in, dim_out)
|
244 |
+
self.dropout_p = dropout_p
|
245 |
+
|
246 |
+
if self.downsample_type == "none":
|
247 |
+
self.pool = nn.Identity()
|
248 |
+
else:
|
249 |
+
self.pool = weight_norm(
|
250 |
+
nn.Conv1d(
|
251 |
+
dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1
|
252 |
+
)
|
253 |
+
)
|
254 |
+
|
255 |
+
def _build_weights(self, dim_in, dim_out):
|
256 |
+
self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_in, 3, 1, 1))
|
257 |
+
self.conv2 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
|
258 |
+
if self.normalize:
|
259 |
+
self.norm1 = nn.InstanceNorm1d(dim_in, affine=True)
|
260 |
+
self.norm2 = nn.InstanceNorm1d(dim_in, affine=True)
|
261 |
+
if self.learned_sc:
|
262 |
+
self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
|
263 |
+
|
264 |
+
def downsample(self, x):
|
265 |
+
if self.downsample_type == "none":
|
266 |
+
return x
|
267 |
+
else:
|
268 |
+
if x.shape[-1] % 2 != 0:
|
269 |
+
x = torch.cat([x, x[..., -1].unsqueeze(-1)], dim=-1)
|
270 |
+
return F.avg_pool1d(x, 2)
|
271 |
+
|
272 |
+
def _shortcut(self, x):
|
273 |
+
if self.learned_sc:
|
274 |
+
x = self.conv1x1(x)
|
275 |
+
x = self.downsample(x)
|
276 |
+
return x
|
277 |
+
|
278 |
+
def _residual(self, x):
|
279 |
+
if self.normalize:
|
280 |
+
x = self.norm1(x)
|
281 |
+
x = self.actv(x)
|
282 |
+
x = F.dropout(x, p=self.dropout_p, training=self.training)
|
283 |
+
|
284 |
+
x = self.conv1(x)
|
285 |
+
x = self.pool(x)
|
286 |
+
if self.normalize:
|
287 |
+
x = self.norm2(x)
|
288 |
+
|
289 |
+
x = self.actv(x)
|
290 |
+
x = F.dropout(x, p=self.dropout_p, training=self.training)
|
291 |
+
|
292 |
+
x = self.conv2(x)
|
293 |
+
return x
|
294 |
+
|
295 |
+
def forward(self, x):
|
296 |
+
x = self._shortcut(x) + self._residual(x)
|
297 |
+
return x / np.sqrt(2) # unit variance
|
298 |
+
|
299 |
+
|
300 |
+
class LayerNorm(nn.Module):
|
301 |
+
def __init__(self, channels, eps=1e-5):
|
302 |
+
super().__init__()
|
303 |
+
self.channels = channels
|
304 |
+
self.eps = eps
|
305 |
+
|
306 |
+
self.gamma = nn.Parameter(torch.ones(channels))
|
307 |
+
self.beta = nn.Parameter(torch.zeros(channels))
|
308 |
+
|
309 |
+
def forward(self, x):
|
310 |
+
x = x.transpose(1, -1)
|
311 |
+
x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
|
312 |
+
return x.transpose(1, -1)
|
313 |
+
|
314 |
+
|
315 |
+
class TextEncoder(nn.Module):
|
316 |
+
def __init__(self, channels, kernel_size, depth, n_symbols, actv=nn.LeakyReLU(0.2)):
|
317 |
+
super().__init__()
|
318 |
+
self.embedding = nn.Embedding(n_symbols, channels)
|
319 |
+
|
320 |
+
padding = (kernel_size - 1) // 2
|
321 |
+
self.cnn = nn.ModuleList()
|
322 |
+
for _ in range(depth):
|
323 |
+
self.cnn.append(
|
324 |
+
nn.Sequential(
|
325 |
+
weight_norm(
|
326 |
+
nn.Conv1d(
|
327 |
+
channels, channels, kernel_size=kernel_size, padding=padding
|
328 |
+
)
|
329 |
+
),
|
330 |
+
LayerNorm(channels),
|
331 |
+
actv,
|
332 |
+
nn.Dropout(0.2),
|
333 |
+
)
|
334 |
+
)
|
335 |
+
# self.cnn = nn.Sequential(*self.cnn)
|
336 |
+
|
337 |
+
self.lstm = nn.LSTM(
|
338 |
+
channels, channels // 2, 1, batch_first=True, bidirectional=True
|
339 |
+
)
|
340 |
+
|
341 |
+
def forward(self, x, input_lengths, m):
|
342 |
+
x = self.embedding(x) # [B, T, emb]
|
343 |
+
x = x.transpose(1, 2) # [B, emb, T]
|
344 |
+
m = m.to(input_lengths.device).unsqueeze(1)
|
345 |
+
x.masked_fill_(m, 0.0)
|
346 |
+
|
347 |
+
for c in self.cnn:
|
348 |
+
x = c(x)
|
349 |
+
x.masked_fill_(m, 0.0)
|
350 |
+
|
351 |
+
x = x.transpose(1, 2) # [B, T, chn]
|
352 |
+
|
353 |
+
input_lengths = input_lengths.cpu().numpy()
|
354 |
+
x = nn.utils.rnn.pack_padded_sequence(
|
355 |
+
x, input_lengths, batch_first=True, enforce_sorted=False
|
356 |
+
)
|
357 |
+
|
358 |
+
self.lstm.flatten_parameters()
|
359 |
+
x, _ = self.lstm(x)
|
360 |
+
x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True)
|
361 |
+
|
362 |
+
x = x.transpose(-1, -2)
|
363 |
+
x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])
|
364 |
+
|
365 |
+
x_pad[:, :, : x.shape[-1]] = x
|
366 |
+
x = x_pad.to(x.device)
|
367 |
+
|
368 |
+
x.masked_fill_(m, 0.0)
|
369 |
+
|
370 |
+
return x
|
371 |
+
|
372 |
+
def inference(self, x):
|
373 |
+
x = self.embedding(x)
|
374 |
+
x = x.transpose(1, 2)
|
375 |
+
x = self.cnn(x)
|
376 |
+
x = x.transpose(1, 2)
|
377 |
+
self.lstm.flatten_parameters()
|
378 |
+
x, _ = self.lstm(x)
|
379 |
+
return x
|
380 |
+
|
381 |
+
def length_to_mask(self, lengths):
|
382 |
+
mask = (
|
383 |
+
torch.arange(lengths.max())
|
384 |
+
.unsqueeze(0)
|
385 |
+
.expand(lengths.shape[0], -1)
|
386 |
+
.type_as(lengths)
|
387 |
+
)
|
388 |
+
mask = torch.gt(mask + 1, lengths.unsqueeze(1))
|
389 |
+
return mask
|
390 |
+
|
391 |
+
|
392 |
+
class AdaIN1d(nn.Module):
|
393 |
+
def __init__(self, style_dim, num_features):
|
394 |
+
super().__init__()
|
395 |
+
self.norm = nn.InstanceNorm1d(num_features, affine=False)
|
396 |
+
self.fc = nn.Linear(style_dim, num_features * 2)
|
397 |
+
|
398 |
+
def forward(self, x, s):
|
399 |
+
h = self.fc(s)
|
400 |
+
h = h.view(h.size(0), h.size(1), 1)
|
401 |
+
gamma, beta = torch.chunk(h, chunks=2, dim=1)
|
402 |
+
return (1 + gamma) * self.norm(x) + beta
|
403 |
+
|
404 |
+
|
405 |
+
class UpSample1d(nn.Module):
|
406 |
+
def __init__(self, layer_type):
|
407 |
+
super().__init__()
|
408 |
+
self.layer_type = layer_type
|
409 |
+
|
410 |
+
def forward(self, x):
|
411 |
+
if self.layer_type == "none":
|
412 |
+
return x
|
413 |
+
else:
|
414 |
+
return F.interpolate(x, scale_factor=2, mode="nearest")
|
415 |
+
|
416 |
+
|
417 |
+
class AdainResBlk1d(nn.Module):
|
418 |
+
def __init__(
|
419 |
+
self,
|
420 |
+
dim_in,
|
421 |
+
dim_out,
|
422 |
+
style_dim=64,
|
423 |
+
actv=nn.LeakyReLU(0.2),
|
424 |
+
upsample="none",
|
425 |
+
dropout_p=0.0,
|
426 |
+
):
|
427 |
+
super().__init__()
|
428 |
+
self.actv = actv
|
429 |
+
self.upsample_type = upsample
|
430 |
+
self.upsample = UpSample1d(upsample)
|
431 |
+
self.learned_sc = dim_in != dim_out
|
432 |
+
self._build_weights(dim_in, dim_out, style_dim)
|
433 |
+
self.dropout = nn.Dropout(dropout_p)
|
434 |
+
|
435 |
+
if upsample == "none":
|
436 |
+
self.pool = nn.Identity()
|
437 |
+
else:
|
438 |
+
self.pool = weight_norm(
|
439 |
+
nn.ConvTranspose1d(
|
440 |
+
dim_in,
|
441 |
+
dim_in,
|
442 |
+
kernel_size=3,
|
443 |
+
stride=2,
|
444 |
+
groups=dim_in,
|
445 |
+
padding=1,
|
446 |
+
output_padding=1,
|
447 |
+
)
|
448 |
+
)
|
449 |
+
|
450 |
+
def _build_weights(self, dim_in, dim_out, style_dim):
|
451 |
+
self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
|
452 |
+
self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
|
453 |
+
self.norm1 = AdaIN1d(style_dim, dim_in)
|
454 |
+
self.norm2 = AdaIN1d(style_dim, dim_out)
|
455 |
+
if self.learned_sc:
|
456 |
+
self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
|
457 |
+
|
458 |
+
def _shortcut(self, x):
|
459 |
+
x = self.upsample(x)
|
460 |
+
if self.learned_sc:
|
461 |
+
x = self.conv1x1(x)
|
462 |
+
return x
|
463 |
+
|
464 |
+
def _residual(self, x, s):
|
465 |
+
x = self.norm1(x, s)
|
466 |
+
x = self.actv(x)
|
467 |
+
x = self.pool(x)
|
468 |
+
x = self.conv1(self.dropout(x))
|
469 |
+
x = self.norm2(x, s)
|
470 |
+
x = self.actv(x)
|
471 |
+
x = self.conv2(self.dropout(x))
|
472 |
+
return x
|
473 |
+
|
474 |
+
def forward(self, x, s):
|
475 |
+
out = self._residual(x, s)
|
476 |
+
out = (out + self._shortcut(x)) / np.sqrt(2)
|
477 |
+
return out
|
478 |
+
|
479 |
+
|
480 |
+
class AdaLayerNorm(nn.Module):
|
481 |
+
def __init__(self, style_dim, channels, eps=1e-5):
|
482 |
+
super().__init__()
|
483 |
+
self.channels = channels
|
484 |
+
self.eps = eps
|
485 |
+
|
486 |
+
self.fc = nn.Linear(style_dim, channels * 2)
|
487 |
+
|
488 |
+
def forward(self, x, s):
|
489 |
+
x = x.transpose(-1, -2)
|
490 |
+
x = x.transpose(1, -1)
|
491 |
+
|
492 |
+
h = self.fc(s)
|
493 |
+
h = h.view(h.size(0), h.size(1), 1)
|
494 |
+
gamma, beta = torch.chunk(h, chunks=2, dim=1)
|
495 |
+
gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1)
|
496 |
+
|
497 |
+
x = F.layer_norm(x, (self.channels,), eps=self.eps)
|
498 |
+
x = (1 + gamma) * x + beta
|
499 |
+
return x.transpose(1, -1).transpose(-1, -2)
|
500 |
+
|
501 |
+
|
502 |
+
class ProsodyPredictor(nn.Module):
|
503 |
+
def __init__(self, style_dim, d_hid, nlayers, max_dur=50, dropout=0.1):
|
504 |
+
super().__init__()
|
505 |
+
|
506 |
+
self.text_encoder = DurationEncoder(
|
507 |
+
sty_dim=style_dim, d_model=d_hid, nlayers=nlayers, dropout=dropout
|
508 |
+
)
|
509 |
+
|
510 |
+
self.lstm = nn.LSTM(
|
511 |
+
d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True
|
512 |
+
)
|
513 |
+
self.duration_proj = LinearNorm(d_hid, max_dur)
|
514 |
+
|
515 |
+
self.shared = nn.LSTM(
|
516 |
+
d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True
|
517 |
+
)
|
518 |
+
self.F0 = nn.ModuleList()
|
519 |
+
self.F0.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
|
520 |
+
self.F0.append(
|
521 |
+
AdainResBlk1d(
|
522 |
+
d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout
|
523 |
+
)
|
524 |
+
)
|
525 |
+
self.F0.append(
|
526 |
+
AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout)
|
527 |
+
)
|
528 |
+
|
529 |
+
self.N = nn.ModuleList()
|
530 |
+
self.N.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
|
531 |
+
self.N.append(
|
532 |
+
AdainResBlk1d(
|
533 |
+
d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout
|
534 |
+
)
|
535 |
+
)
|
536 |
+
self.N.append(
|
537 |
+
AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout)
|
538 |
+
)
|
539 |
+
|
540 |
+
self.F0_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
|
541 |
+
self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
|
542 |
+
|
543 |
+
def forward(self, texts, style, text_lengths, alignment, m):
|
544 |
+
d = self.text_encoder(texts, style, text_lengths, m)
|
545 |
+
|
546 |
+
batch_size = d.shape[0]
|
547 |
+
text_size = d.shape[1]
|
548 |
+
|
549 |
+
# predict duration
|
550 |
+
input_lengths = text_lengths.cpu().numpy()
|
551 |
+
x = nn.utils.rnn.pack_padded_sequence(
|
552 |
+
d, input_lengths, batch_first=True, enforce_sorted=False
|
553 |
+
)
|
554 |
+
|
555 |
+
m = m.to(text_lengths.device).unsqueeze(1)
|
556 |
+
|
557 |
+
self.lstm.flatten_parameters()
|
558 |
+
x, _ = self.lstm(x)
|
559 |
+
x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True)
|
560 |
+
|
561 |
+
x_pad = torch.zeros([x.shape[0], m.shape[-1], x.shape[-1]])
|
562 |
+
|
563 |
+
x_pad[:, : x.shape[1], :] = x
|
564 |
+
x = x_pad.to(x.device)
|
565 |
+
|
566 |
+
duration = self.duration_proj(
|
567 |
+
nn.functional.dropout(x, 0.5, training=self.training)
|
568 |
+
)
|
569 |
+
|
570 |
+
en = d.transpose(-1, -2) @ alignment
|
571 |
+
|
572 |
+
return duration.squeeze(-1), en
|
573 |
+
|
574 |
+
def F0Ntrain(self, x, s):
|
575 |
+
x, _ = self.shared(x.transpose(-1, -2))
|
576 |
+
|
577 |
+
F0 = x.transpose(-1, -2)
|
578 |
+
for block in self.F0:
|
579 |
+
F0 = block(F0, s)
|
580 |
+
F0 = self.F0_proj(F0)
|
581 |
+
|
582 |
+
N = x.transpose(-1, -2)
|
583 |
+
for block in self.N:
|
584 |
+
N = block(N, s)
|
585 |
+
N = self.N_proj(N)
|
586 |
+
|
587 |
+
return F0.squeeze(1), N.squeeze(1)
|
588 |
+
|
589 |
+
def length_to_mask(self, lengths):
|
590 |
+
mask = (
|
591 |
+
torch.arange(lengths.max())
|
592 |
+
.unsqueeze(0)
|
593 |
+
.expand(lengths.shape[0], -1)
|
594 |
+
.type_as(lengths)
|
595 |
+
)
|
596 |
+
mask = torch.gt(mask + 1, lengths.unsqueeze(1))
|
597 |
+
return mask
|
598 |
+
|
599 |
+
|
600 |
+
class DurationEncoder(nn.Module):
|
601 |
+
def __init__(self, sty_dim, d_model, nlayers, dropout=0.1):
|
602 |
+
super().__init__()
|
603 |
+
self.lstms = nn.ModuleList()
|
604 |
+
for _ in range(nlayers):
|
605 |
+
self.lstms.append(
|
606 |
+
nn.LSTM(
|
607 |
+
d_model + sty_dim,
|
608 |
+
d_model // 2,
|
609 |
+
num_layers=1,
|
610 |
+
batch_first=True,
|
611 |
+
bidirectional=True,
|
612 |
+
dropout=dropout,
|
613 |
+
)
|
614 |
+
)
|
615 |
+
self.lstms.append(AdaLayerNorm(sty_dim, d_model))
|
616 |
+
|
617 |
+
self.dropout = dropout
|
618 |
+
self.d_model = d_model
|
619 |
+
self.sty_dim = sty_dim
|
620 |
+
|
621 |
+
def forward(self, x, style, text_lengths, m):
|
622 |
+
masks = m.to(text_lengths.device)
|
623 |
+
|
624 |
+
x = x.permute(2, 0, 1)
|
625 |
+
s = style.expand(x.shape[0], x.shape[1], -1)
|
626 |
+
x = torch.cat([x, s], axis=-1)
|
627 |
+
x.masked_fill_(masks.unsqueeze(-1).transpose(0, 1), 0.0)
|
628 |
+
|
629 |
+
x = x.transpose(0, 1)
|
630 |
+
input_lengths = text_lengths.cpu().numpy()
|
631 |
+
x = x.transpose(-1, -2)
|
632 |
+
|
633 |
+
for block in self.lstms:
|
634 |
+
if isinstance(block, AdaLayerNorm):
|
635 |
+
x = block(x.transpose(-1, -2), style).transpose(-1, -2)
|
636 |
+
x = torch.cat([x, s.permute(1, -1, 0)], axis=1)
|
637 |
+
x.masked_fill_(masks.unsqueeze(-1).transpose(-1, -2), 0.0)
|
638 |
+
else:
|
639 |
+
x = x.transpose(-1, -2)
|
640 |
+
x = nn.utils.rnn.pack_padded_sequence(
|
641 |
+
x, input_lengths, batch_first=True, enforce_sorted=False
|
642 |
+
)
|
643 |
+
block.flatten_parameters()
|
644 |
+
x, _ = block(x)
|
645 |
+
x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True)
|
646 |
+
x = F.dropout(x, p=self.dropout, training=self.training)
|
647 |
+
x = x.transpose(-1, -2)
|
648 |
+
|
649 |
+
x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])
|
650 |
+
|
651 |
+
x_pad[:, :, : x.shape[-1]] = x
|
652 |
+
x = x_pad.to(x.device)
|
653 |
+
|
654 |
+
return x.transpose(-1, -2)
|
655 |
+
|
656 |
+
def inference(self, x, style):
|
657 |
+
x = self.embedding(x.transpose(-1, -2)) * np.sqrt(self.d_model)
|
658 |
+
style = style.expand(x.shape[0], x.shape[1], -1)
|
659 |
+
x = torch.cat([x, style], axis=-1)
|
660 |
+
src = self.pos_encoder(x)
|
661 |
+
output = self.transformer_encoder(src).transpose(0, 1)
|
662 |
+
return output
|
663 |
+
|
664 |
+
def length_to_mask(self, lengths):
|
665 |
+
mask = (
|
666 |
+
torch.arange(lengths.max())
|
667 |
+
.unsqueeze(0)
|
668 |
+
.expand(lengths.shape[0], -1)
|
669 |
+
.type_as(lengths)
|
670 |
+
)
|
671 |
+
mask = torch.gt(mask + 1, lengths.unsqueeze(1))
|
672 |
+
return mask
|
673 |
+
|
674 |
+
|
675 |
+
# https://github.com/yl4579/StyleTTS2/blob/main/utils.py
|
676 |
+
def recursive_munch(d):
|
677 |
+
if isinstance(d, dict):
|
678 |
+
return Munch((k, recursive_munch(v)) for k, v in d.items())
|
679 |
+
elif isinstance(d, list):
|
680 |
+
return [recursive_munch(v) for v in d]
|
681 |
+
else:
|
682 |
+
return d
|
683 |
+
|
684 |
+
|
685 |
+
def build_model(path, device):
|
686 |
+
config = Path(path).parent / "config.json"
|
687 |
+
assert config.exists(), f"Config path incorrect: config.json not found at {config}"
|
688 |
+
with open(config, "r") as r:
|
689 |
+
args = recursive_munch(json.load(r))
|
690 |
+
assert args.decoder.type == "istftnet", f"Unknown decoder type: {args.decoder.type}"
|
691 |
+
decoder = Decoder(
|
692 |
+
dim_in=args.hidden_dim,
|
693 |
+
style_dim=args.style_dim,
|
694 |
+
dim_out=args.n_mels,
|
695 |
+
resblock_kernel_sizes=args.decoder.resblock_kernel_sizes,
|
696 |
+
upsample_rates=args.decoder.upsample_rates,
|
697 |
+
upsample_initial_channel=args.decoder.upsample_initial_channel,
|
698 |
+
resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
|
699 |
+
upsample_kernel_sizes=args.decoder.upsample_kernel_sizes,
|
700 |
+
gen_istft_n_fft=args.decoder.gen_istft_n_fft,
|
701 |
+
gen_istft_hop_size=args.decoder.gen_istft_hop_size,
|
702 |
+
)
|
703 |
+
text_encoder = TextEncoder(
|
704 |
+
channels=args.hidden_dim,
|
705 |
+
kernel_size=5,
|
706 |
+
depth=args.n_layer,
|
707 |
+
n_symbols=args.n_token,
|
708 |
+
)
|
709 |
+
predictor = ProsodyPredictor(
|
710 |
+
style_dim=args.style_dim,
|
711 |
+
d_hid=args.hidden_dim,
|
712 |
+
nlayers=args.n_layer,
|
713 |
+
max_dur=args.max_dur,
|
714 |
+
dropout=args.dropout,
|
715 |
+
)
|
716 |
+
bert = load_plbert()
|
717 |
+
bert_encoder = nn.Linear(bert.config.hidden_size, args.hidden_dim)
|
718 |
+
for parent in [bert, bert_encoder, predictor, decoder, text_encoder]:
|
719 |
+
for child in parent.children():
|
720 |
+
if isinstance(child, nn.RNNBase):
|
721 |
+
child.flatten_parameters()
|
722 |
+
model = Munch(
|
723 |
+
bert=bert.to(device).eval(),
|
724 |
+
bert_encoder=bert_encoder.to(device).eval(),
|
725 |
+
predictor=predictor.to(device).eval(),
|
726 |
+
decoder=decoder.to(device).eval(),
|
727 |
+
text_encoder=text_encoder.to(device).eval(),
|
728 |
+
)
|
729 |
+
for key, state_dict in torch.load(path, map_location="cpu", weights_only=True)[
|
730 |
+
"net"
|
731 |
+
].items():
|
732 |
+
assert key in model, key
|
733 |
+
try:
|
734 |
+
model[key].load_state_dict(state_dict)
|
735 |
+
except:
|
736 |
+
state_dict = {k[7:]: v for k, v in state_dict.items()}
|
737 |
+
model[key].load_state_dict(state_dict, strict=False)
|
738 |
+
return model
|
packages.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
espeak-ng
|
plbert.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# https://github.com/yl4579/StyleTTS2/blob/main/Utils/PLBERT/util.py
|
2 |
+
from transformers import AlbertConfig, AlbertModel
|
3 |
+
|
4 |
+
class CustomAlbert(AlbertModel):
|
5 |
+
def forward(self, *args, **kwargs):
|
6 |
+
# Call the original forward method
|
7 |
+
outputs = super().forward(*args, **kwargs)
|
8 |
+
# Only return the last_hidden_state
|
9 |
+
return outputs.last_hidden_state
|
10 |
+
|
11 |
+
def load_plbert():
|
12 |
+
plbert_config = {'vocab_size': 178, 'hidden_size': 768, 'num_attention_heads': 12, 'intermediate_size': 2048, 'max_position_embeddings': 512, 'num_hidden_layers': 12, 'dropout': 0.1}
|
13 |
+
albert_base_configuration = AlbertConfig(**plbert_config)
|
14 |
+
bert = CustomAlbert(albert_base_configuration)
|
15 |
+
return bert
|
pretrained_models/Kokoro/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .downloader import initialize_files
|
2 |
+
|
3 |
+
initialize_files()
|
pretrained_models/Kokoro/downloader.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import shutil
|
3 |
+
from huggingface_hub import snapshot_download
|
4 |
+
|
5 |
+
|
6 |
+
def download_files(repo_id, target_base_dir, patterns):
|
7 |
+
"""
|
8 |
+
Downloads files matching patterns from a Hugging Face repository and organizes them in a directory structure.
|
9 |
+
|
10 |
+
:param repo_id: Hugging Face repository ID.
|
11 |
+
:param target_base_dir: Base directory where files should be stored.
|
12 |
+
:param patterns: A dictionary mapping subdirectories to file patterns.
|
13 |
+
Example: {"root": ["config.json", "*.pth"], "voices": ["*.pt"]}
|
14 |
+
"""
|
15 |
+
# Ensure target base directory exists
|
16 |
+
if not os.path.exists(target_base_dir):
|
17 |
+
os.makedirs(target_base_dir)
|
18 |
+
|
19 |
+
# Download the snapshot containing all matching files
|
20 |
+
snapshot_dir = snapshot_download(repo_id=repo_id, allow_patterns="*")
|
21 |
+
|
22 |
+
# Loop through patterns and subdirectories
|
23 |
+
for subdir, file_patterns in patterns.items():
|
24 |
+
# Set target directory for root-level files
|
25 |
+
target_dir = (
|
26 |
+
target_base_dir
|
27 |
+
if subdir == "root"
|
28 |
+
else os.path.join(target_base_dir, subdir)
|
29 |
+
)
|
30 |
+
os.makedirs(target_dir, exist_ok=True)
|
31 |
+
|
32 |
+
for file_pattern in file_patterns:
|
33 |
+
# Walk through the snapshot directory to find matching files
|
34 |
+
for root, _, files in os.walk(snapshot_dir):
|
35 |
+
for file in files:
|
36 |
+
if file.endswith(file_pattern.lstrip("*")): # Match pattern
|
37 |
+
source_path = os.path.join(root, file)
|
38 |
+
target_file_path = os.path.join(target_dir, file)
|
39 |
+
|
40 |
+
# Check if file already exists
|
41 |
+
if not os.path.exists(target_file_path):
|
42 |
+
# Copy the file to the target directory
|
43 |
+
shutil.copy(source_path, target_file_path)
|
44 |
+
print(f"Downloaded and saved: {file} to {target_dir}")
|
45 |
+
else:
|
46 |
+
print(f"File already exists, skipping: {target_file_path}")
|
47 |
+
|
48 |
+
|
49 |
+
def initialize_files():
|
50 |
+
repo_id = "hexgrad/Kokoro-82M"
|
51 |
+
target_base_dir = "pretrained_models/Kokoro" # Base directory for files
|
52 |
+
file_patterns = {
|
53 |
+
"root": ["config.json", "*.pth"], # Files for the root directory
|
54 |
+
"voices": ["*.pt"], # Wildcard for voice pack files
|
55 |
+
}
|
56 |
+
|
57 |
+
download_files(repo_id, target_base_dir, file_patterns)
|
58 |
+
|
59 |
+
|
60 |
+
if __name__ == "__main__":
|
61 |
+
initialize_files()
|
pyproject.toml
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[project]
|
2 |
+
name = "kokoro-studio"
|
3 |
+
version = "0.1.0"
|
4 |
+
description = "Add your description here"
|
5 |
+
readme = "README.md"
|
6 |
+
requires-python = ">=3.11"
|
7 |
+
dependencies = [
|
8 |
+
"gradio>=5.9.1",
|
9 |
+
"munch>=4.0.0",
|
10 |
+
"openphonemizer>=0.1.2",
|
11 |
+
"phonemizer>=3.3.0",
|
12 |
+
"pyyaml>=6.0.2",
|
13 |
+
"scipy>=1.14.1",
|
14 |
+
"soundfile>=0.12.1",
|
15 |
+
"torch>=2.5.1",
|
16 |
+
"transformers>=4.47.1",
|
17 |
+
]
|
requirements.txt
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gradio>=5.9.1
|
2 |
+
munch>=4.0.0
|
3 |
+
openphonemizer>=0.1.2
|
4 |
+
phonemizer>=3.3.0
|
5 |
+
pyyaml>=6.0.2
|
6 |
+
scipy>=1.14.1
|
7 |
+
soundfile>=0.12.1
|
8 |
+
torch>=2.5.1
|
9 |
+
transformers>=4.47.1
|
tts_cli.py
ADDED
@@ -0,0 +1,510 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# tts_cli.py
|
3 |
+
"""
|
4 |
+
Example CLI for generating audio with Kokoro-StyleTTS2.
|
5 |
+
|
6 |
+
Usage:
|
7 |
+
python tts_cli.py \
|
8 |
+
--model /path/to/kokoro-v0_19.pth \
|
9 |
+
--config /path/to/config.json \
|
10 |
+
--text "Hello, my stinking friends from 1906! You stink." \
|
11 |
+
--voicepack /path/to/af.pt \
|
12 |
+
--output output.wav
|
13 |
+
|
14 |
+
Make sure:
|
15 |
+
1. `models.py` is in the same folder (with `build_model`, `Decoder`, etc.).
|
16 |
+
2. You have installed the needed libraries:
|
17 |
+
pip install torch phonemizer munch soundfile pyyaml
|
18 |
+
3. The model is a checkpoint that your `build_model` can load.
|
19 |
+
|
20 |
+
Adapt as needed!
|
21 |
+
"""
|
22 |
+
|
23 |
+
import argparse
|
24 |
+
import os
|
25 |
+
import re
|
26 |
+
import torch
|
27 |
+
import soundfile as sf
|
28 |
+
import numpy as np
|
29 |
+
from phonemizer import backend as phonemizer_backend
|
30 |
+
|
31 |
+
# If you use eSpeak library:
|
32 |
+
try:
|
33 |
+
from espeak_util import set_espeak_library
|
34 |
+
|
35 |
+
set_espeak_library()
|
36 |
+
except ImportError:
|
37 |
+
pass
|
38 |
+
|
39 |
+
# --------------------------------------------------------------------
|
40 |
+
# Import from your local `models.py` (requires that file to be present).
|
41 |
+
# This example assumes `build_model` loads the entire TTS submodules
|
42 |
+
# (bert, bert_encoder, predictor, decoder, text_encoder).
|
43 |
+
# --------------------------------------------------------------------
|
44 |
+
from models import build_model
|
45 |
+
|
46 |
+
|
47 |
+
def resplit_strings(arr):
|
48 |
+
"""
|
49 |
+
Given a list of string tokens (e.g. words, phrases), tries to
|
50 |
+
split them into two sub-lists whose total lengths are as balanced
|
51 |
+
as possible. The goal is to chunk a large string in half without
|
52 |
+
splitting in the middle of a word.
|
53 |
+
"""
|
54 |
+
if not arr:
|
55 |
+
return "", ""
|
56 |
+
if len(arr) == 1:
|
57 |
+
return arr[0], ""
|
58 |
+
|
59 |
+
min_diff = float("inf")
|
60 |
+
best_split = 0
|
61 |
+
lengths = [len(s) for s in arr]
|
62 |
+
spaces = len(arr) - 1
|
63 |
+
left_len = 0
|
64 |
+
right_len = sum(lengths) + spaces
|
65 |
+
|
66 |
+
for i in range(1, len(arr)):
|
67 |
+
# Add current word + space to left side
|
68 |
+
left_len += lengths[i - 1] + (1 if i > 1 else 0)
|
69 |
+
# Remove from right side
|
70 |
+
right_len -= lengths[i - 1] + 1
|
71 |
+
diff = abs(left_len - right_len)
|
72 |
+
if diff < min_diff:
|
73 |
+
min_diff = diff
|
74 |
+
best_split = i
|
75 |
+
|
76 |
+
return " ".join(arr[:best_split]), " ".join(arr[best_split:])
|
77 |
+
|
78 |
+
|
79 |
+
def recursive_split(text, lang="a"):
|
80 |
+
"""
|
81 |
+
Splits a piece of text into smaller segments so that
|
82 |
+
each segment's phoneme length < some ~limit (~500 tokens).
|
83 |
+
"""
|
84 |
+
# We'll reuse your existing `phonemize_text` + `tokenize` from script 1
|
85 |
+
# to see if it is < 512 tokens. If it is, return it as a single chunk.
|
86 |
+
# Otherwise, split on punctuation or whitespace and recurse.
|
87 |
+
|
88 |
+
# 1. Phonemize first, check length
|
89 |
+
ps = phonemize_text(text, lang=lang, do_normalize=True)
|
90 |
+
tokens = tokenize(ps)
|
91 |
+
if len(tokens) < 512:
|
92 |
+
return [(text, ps)]
|
93 |
+
|
94 |
+
# If too large, we split on certain punctuation or fallback to whitespace
|
95 |
+
# We'll look for punctuation that often indicates sentence boundaries
|
96 |
+
# If none found, fallback to space-split
|
97 |
+
for punctuation in [r"[.?!…]", r"[:,;—]"]:
|
98 |
+
pattern = f"(?:(?<={punctuation})|(?<={punctuation}[\"'»])) "
|
99 |
+
# Attempt to split on that punctuation
|
100 |
+
splits = re.split(pattern, text)
|
101 |
+
if len(splits) > 1:
|
102 |
+
break
|
103 |
+
else:
|
104 |
+
# If we didn't break out, just do whitespace split
|
105 |
+
splits = text.split(" ")
|
106 |
+
|
107 |
+
# Use resplit_strings to chunk it about halfway
|
108 |
+
left, right = resplit_strings(splits)
|
109 |
+
# Recurse
|
110 |
+
return recursive_split(left, lang=lang) + recursive_split(right, lang=lang)
|
111 |
+
|
112 |
+
|
113 |
+
def segment_and_tokenize(long_text, lang="a"):
|
114 |
+
"""
|
115 |
+
Takes a large text, optionally normalizes or cleans it,
|
116 |
+
then breaks it into a list of (segment_text, segment_phonemes).
|
117 |
+
"""
|
118 |
+
# Additional cleaning if you want:
|
119 |
+
# long_text = normalize_text(long_text) # your existing function
|
120 |
+
# We chunk it up using recursive_split
|
121 |
+
segments = recursive_split(long_text, lang=lang)
|
122 |
+
return segments
|
123 |
+
|
124 |
+
|
125 |
+
# -------------- Normalization & Phonemization Routines -------------- #
|
126 |
+
def parens_to_angles(s):
|
127 |
+
return s.replace("(", "«").replace(")", "»")
|
128 |
+
|
129 |
+
|
130 |
+
def split_num(num):
|
131 |
+
num = num.group()
|
132 |
+
if "." in num:
|
133 |
+
return num
|
134 |
+
elif ":" in num:
|
135 |
+
h, m = [int(n) for n in num.split(":")]
|
136 |
+
if m == 0:
|
137 |
+
return f"{h} o'clock"
|
138 |
+
elif m < 10:
|
139 |
+
return f"{h} oh {m}"
|
140 |
+
return f"{h} {m}"
|
141 |
+
year = int(num[:4])
|
142 |
+
if year < 1100 or year % 1000 < 10:
|
143 |
+
return num
|
144 |
+
left, right = num[:2], int(num[2:4])
|
145 |
+
s = "s" if num.endswith("s") else ""
|
146 |
+
if 100 <= year % 1000 <= 999:
|
147 |
+
if right == 0:
|
148 |
+
return f"{left} hundred{s}"
|
149 |
+
elif right < 10:
|
150 |
+
return f"{left} oh {right}{s}"
|
151 |
+
return f"{left} {right}{s}"
|
152 |
+
|
153 |
+
|
154 |
+
def flip_money(m):
|
155 |
+
m = m.group()
|
156 |
+
bill = "dollar" if m[0] == "$" else "pound"
|
157 |
+
if m[-1].isalpha():
|
158 |
+
return f"{m[1:]} {bill}s"
|
159 |
+
elif "." not in m:
|
160 |
+
s = "" if m[1:] == "1" else "s"
|
161 |
+
return f"{m[1:]} {bill}{s}"
|
162 |
+
b, c = m[1:].split(".")
|
163 |
+
s = "" if b == "1" else "s"
|
164 |
+
c = int(c.ljust(2, "0"))
|
165 |
+
coins = (
|
166 |
+
f"cent{'' if c == 1 else 's'}"
|
167 |
+
if m[0] == "$"
|
168 |
+
else ("penny" if c == 1 else "pence")
|
169 |
+
)
|
170 |
+
return f"{b} {bill}{s} and {c} {coins}"
|
171 |
+
|
172 |
+
|
173 |
+
def point_num(num):
|
174 |
+
a, b = num.group().split(".")
|
175 |
+
return " point ".join([a, " ".join(b)])
|
176 |
+
|
177 |
+
|
178 |
+
def normalize_text(text):
|
179 |
+
text = text.replace(chr(8216), "'").replace(chr(8217), "'")
|
180 |
+
text = text.replace("«", chr(8220)).replace("»", chr(8221))
|
181 |
+
text = text.replace(chr(8220), '"').replace(chr(8221), '"')
|
182 |
+
text = parens_to_angles(text)
|
183 |
+
|
184 |
+
# Replace some common full-width punctuation in CJK:
|
185 |
+
for a, b in zip("、。!,:;?", ",.!,:;?"):
|
186 |
+
text = text.replace(a, b + " ")
|
187 |
+
|
188 |
+
text = re.sub(r"[^\S \n]", " ", text)
|
189 |
+
text = re.sub(r" +", " ", text)
|
190 |
+
text = re.sub(r"(?<=\n) +(?=\n)", "", text)
|
191 |
+
text = re.sub(r"\bD[Rr]\.(?= [A-Z])", "Doctor", text)
|
192 |
+
text = re.sub(r"\b(?:Mr\.|MR\.(?= [A-Z]))", "Mister", text)
|
193 |
+
text = re.sub(r"\b(?:Ms\.|MS\.(?= [A-Z]))", "Miss", text)
|
194 |
+
text = re.sub(r"\b(?:Mrs\.|MRS\.(?= [A-Z]))", "Mrs", text)
|
195 |
+
text = re.sub(r"\betc\.(?! [A-Z])", "etc", text)
|
196 |
+
text = re.sub(r"(?i)\b(y)eah?\b", r"\1e'a", text)
|
197 |
+
text = re.sub(
|
198 |
+
r"\d*\.\d+|\b\d{4}s?\b|(?<!:)\b(?:[1-9]|1[0-2]):[0-5]\d\b(?!:)",
|
199 |
+
split_num,
|
200 |
+
text,
|
201 |
+
)
|
202 |
+
text = re.sub(r"(?<=\d),(?=\d)", "", text)
|
203 |
+
text = re.sub(
|
204 |
+
r"(?i)[$£]\d+(?:\.\d+)?(?: hundred| thousand| (?:[bm]|tr)illion)*\b|[$£]\d+\.\d\d?\b",
|
205 |
+
flip_money,
|
206 |
+
text,
|
207 |
+
)
|
208 |
+
text = re.sub(r"\d*\.\d+", point_num, text)
|
209 |
+
text = re.sub(r"(?<=\d)-(?=\d)", " to ", text) # Could be minus; adjust if needed
|
210 |
+
text = re.sub(r"(?<=\d)S", " S", text)
|
211 |
+
text = re.sub(r"(?<=[BCDFGHJ-NP-TV-Z])'?s\b", "'S", text)
|
212 |
+
text = re.sub(r"(?<=X')S\b", "s", text)
|
213 |
+
text = re.sub(
|
214 |
+
r"(?:[A-Za-z]\.){2,} [a-z]", lambda m: m.group().replace(".", "-"), text
|
215 |
+
)
|
216 |
+
text = re.sub(r"(?i)(?<=[A-Z])\.(?=[A-Z])", "-", text)
|
217 |
+
return text.strip()
|
218 |
+
|
219 |
+
|
220 |
+
# -------------------------------------------------------------------
|
221 |
+
# Vocab and Symbol Mapping
|
222 |
+
# -------------------------------------------------------------------
|
223 |
+
def get_vocab():
|
224 |
+
_pad = "$"
|
225 |
+
_punctuation = ';:,.!?¡¿—…"«»“” '
|
226 |
+
_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
|
227 |
+
_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
|
228 |
+
symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
|
229 |
+
dicts = {}
|
230 |
+
for i, s in enumerate(symbols):
|
231 |
+
dicts[s] = i
|
232 |
+
return dicts
|
233 |
+
|
234 |
+
|
235 |
+
VOCAB = get_vocab()
|
236 |
+
|
237 |
+
|
238 |
+
def tokenize(ps: str):
|
239 |
+
"""Convert the phoneme string into integer tokens based on VOCAB."""
|
240 |
+
return [VOCAB.get(p) for p in ps if p in VOCAB]
|
241 |
+
|
242 |
+
|
243 |
+
# -------------------------------------------------------------------
|
244 |
+
# Initialize a simple phonemizer
|
245 |
+
# For English:
|
246 |
+
# 'a' ~ en-us
|
247 |
+
# 'b' ~ en-gb
|
248 |
+
# -------------------------------------------------------------------
|
249 |
+
phonemizers = dict(
|
250 |
+
a=phonemizer_backend.EspeakBackend(
|
251 |
+
language="en-us", preserve_punctuation=True, with_stress=True
|
252 |
+
),
|
253 |
+
b=phonemizer_backend.EspeakBackend(
|
254 |
+
language="en-gb", preserve_punctuation=True, with_stress=True
|
255 |
+
),
|
256 |
+
# You can add more, e.g. 'j': some Japanese phonemizer, etc.
|
257 |
+
)
|
258 |
+
|
259 |
+
|
260 |
+
def phonemize_text(text, lang="a", do_normalize=True):
|
261 |
+
if do_normalize:
|
262 |
+
text = normalize_text(text)
|
263 |
+
ps_list = phonemizers[lang].phonemize([text])
|
264 |
+
ps = ps_list[0] if ps_list else ""
|
265 |
+
|
266 |
+
# Some custom replacements (from your code)
|
267 |
+
ps = ps.replace("kəkˈoːɹoʊ", "kˈoʊkəɹoʊ").replace("kəkˈɔːɹəʊ", "kˈəʊkəɹəʊ")
|
268 |
+
ps = ps.replace("ʲ", "j").replace("r", "ɹ").replace("x", "k").replace("ɬ", "l")
|
269 |
+
# Example: insert space before "hˈʌndɹɪd" if there's a letter, e.g. "nˈaɪn" => "nˈaɪn hˈʌndɹɪd"
|
270 |
+
ps = re.sub(r"(?<=[a-zɹː])(?=hˈʌndɹɪd)", " ", ps)
|
271 |
+
# "z" at the end of a word -> remove space (just your snippet)
|
272 |
+
ps = re.sub(r' z(?=[;:,.!?¡¿—…"«»“” ]|$)', "z", ps)
|
273 |
+
# If lang is 'a', handle "ninety" => "ninedi"? Just from your snippet:
|
274 |
+
if lang == "a":
|
275 |
+
ps = re.sub(r"(?<=nˈaɪn)ti(?!ː)", "di", ps)
|
276 |
+
|
277 |
+
# Only keep valid symbols
|
278 |
+
ps = "".join(p for p in ps if p in VOCAB)
|
279 |
+
return ps.strip()
|
280 |
+
|
281 |
+
|
282 |
+
# -------------------------------------------------------------------
|
283 |
+
# Utility for generating text masks
|
284 |
+
# -------------------------------------------------------------------
|
285 |
+
def length_to_mask(lengths):
|
286 |
+
# lengths is a Tensor of shape [B], containing the text length for each batch
|
287 |
+
max_len = lengths.max()
|
288 |
+
row_ids = torch.arange(max_len, device=lengths.device).unsqueeze(0)
|
289 |
+
mask = row_ids.expand(lengths.shape[0], -1)
|
290 |
+
return (mask + 1) > lengths.unsqueeze(1)
|
291 |
+
|
292 |
+
|
293 |
+
# -------------------------------------------------------------------
|
294 |
+
# The forward pass for inference (from your snippet).
|
295 |
+
# This version references `model.predictor`, `model.decoder`, etc.
|
296 |
+
# -------------------------------------------------------------------
|
297 |
+
@torch.no_grad()
|
298 |
+
def forward_tts(model, tokens, ref_s, speed=1.0):
|
299 |
+
"""
|
300 |
+
model: Munch with submodels: bert, bert_encoder, predictor, decoder, text_encoder
|
301 |
+
tokens: list[int], the tokenized input (without [0, ... , 0] yet)
|
302 |
+
ref_s: reference embedding (torch.Tensor)
|
303 |
+
speed: float, speed factor
|
304 |
+
"""
|
305 |
+
device = ref_s.device
|
306 |
+
tokens_t = torch.LongTensor([[0, *tokens, 0]]).to(device) # add boundary tokens
|
307 |
+
input_lengths = torch.LongTensor([tokens_t.shape[-1]]).to(device)
|
308 |
+
text_mask = length_to_mask(input_lengths).to(device)
|
309 |
+
|
310 |
+
# 1. Encode with BERT
|
311 |
+
bert_dur = model.bert(tokens_t, attention_mask=(~text_mask).int())
|
312 |
+
d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
|
313 |
+
|
314 |
+
# 2. Prosody predictor
|
315 |
+
s = ref_s[
|
316 |
+
:, 128:
|
317 |
+
] # from your snippet: the last 128 is ???, or the first 128 is ???
|
318 |
+
|
319 |
+
d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
|
320 |
+
x, _ = model.predictor.lstm(d)
|
321 |
+
duration = model.predictor.duration_proj(x)
|
322 |
+
duration = torch.sigmoid(duration).sum(axis=-1) / speed
|
323 |
+
pred_dur = torch.round(duration).clamp(min=1).long()
|
324 |
+
|
325 |
+
# 3. Expand alignment
|
326 |
+
total_len = pred_dur.sum().item()
|
327 |
+
pred_aln_trg = torch.zeros(input_lengths, total_len, device=device)
|
328 |
+
c_frame = 0
|
329 |
+
for i in range(pred_aln_trg.size(0)):
|
330 |
+
n = pred_dur[0, i].item()
|
331 |
+
pred_aln_trg[i, c_frame : c_frame + n] = 1
|
332 |
+
c_frame += n
|
333 |
+
|
334 |
+
# 4. Run F0 + Noise predictor
|
335 |
+
en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0)
|
336 |
+
F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
|
337 |
+
|
338 |
+
# 5. Text encoder -> asr
|
339 |
+
t_en = model.text_encoder(tokens_t, input_lengths, text_mask)
|
340 |
+
asr = t_en @ pred_aln_trg.unsqueeze(0)
|
341 |
+
|
342 |
+
# 6. Decode audio
|
343 |
+
audio = model.decoder(asr, F0_pred, N_pred, ref_s[:, :128]) # B x audio_len
|
344 |
+
return audio.squeeze().cpu().numpy()
|
345 |
+
|
346 |
+
|
347 |
+
def generate_tts(model, text, voicepack, lang="a", speed=1.0):
|
348 |
+
"""
|
349 |
+
model: the Munch returned by build_model(...)
|
350 |
+
text: the input text (string)
|
351 |
+
voicepack: the torch Tensor reference embedding, or a dict of them
|
352 |
+
lang: 'a' or 'b' or etc. from your phonemizers
|
353 |
+
speed: speech speed factor
|
354 |
+
sample_rate: sampling rate for the output
|
355 |
+
"""
|
356 |
+
# 1. Phonemize
|
357 |
+
ps = phonemize_text(text, lang=lang, do_normalize=True)
|
358 |
+
tokens = tokenize(ps)
|
359 |
+
if not tokens:
|
360 |
+
return None, ps
|
361 |
+
|
362 |
+
# 2. Retrieve reference style
|
363 |
+
# If your voicepack is a single embedding for all lengths, adapt as needed.
|
364 |
+
# If your voicepack is something like `voicepack[len(tokens)]`, do that.
|
365 |
+
# If you have multiple voices, you might do something else.
|
366 |
+
try:
|
367 |
+
ref_s = voicepack[len(tokens)]
|
368 |
+
except:
|
369 |
+
# fallback if len(tokens) is out of range
|
370 |
+
ref_s = voicepack[-1]
|
371 |
+
ref_s = ref_s.to("cpu" if not next(model.bert.parameters()).is_cuda else "cuda")
|
372 |
+
|
373 |
+
# 3. Generate
|
374 |
+
audio = forward_tts(model, tokens, ref_s, speed=speed)
|
375 |
+
return audio, ps
|
376 |
+
|
377 |
+
|
378 |
+
def generate_long_form_tts(model, full_text, voicepack, lang="a", speed=1.0):
|
379 |
+
"""
|
380 |
+
Generate TTS for a large `full_text`, splitting it into smaller segments
|
381 |
+
and concatenating the resulting audio.
|
382 |
+
|
383 |
+
Returns: (np.float32 array) final_audio, list_of_segment_phonemes
|
384 |
+
"""
|
385 |
+
# 1. Segment the text
|
386 |
+
segments = segment_and_tokenize(full_text, lang=lang)
|
387 |
+
# segments is a list of (seg_text, seg_phonemes)
|
388 |
+
|
389 |
+
# 2. For each segment, call `generate_tts(...)`
|
390 |
+
audio_chunks = []
|
391 |
+
all_phonemes = []
|
392 |
+
for i, (seg_text, seg_ps) in enumerate(segments, 1):
|
393 |
+
print(f"[LongForm] Generating chunk {i}/{len(segments)}: {seg_text[:40]}...")
|
394 |
+
audio, used_phonemes = generate_tts(
|
395 |
+
model, seg_text, voicepack, lang=lang, speed=speed
|
396 |
+
)
|
397 |
+
if audio is not None:
|
398 |
+
audio_chunks.append(audio)
|
399 |
+
all_phonemes.append(used_phonemes)
|
400 |
+
else:
|
401 |
+
print(f"[LongForm] Skipped empty segment {i}...")
|
402 |
+
|
403 |
+
if not audio_chunks:
|
404 |
+
return None, []
|
405 |
+
|
406 |
+
# 3. Concatenate the audio
|
407 |
+
final_audio = np.concatenate(audio_chunks, axis=0)
|
408 |
+
return final_audio, all_phonemes
|
409 |
+
|
410 |
+
|
411 |
+
# -------------------------------------------------------------------
|
412 |
+
# Main CLI
|
413 |
+
# -------------------------------------------------------------------
|
414 |
+
def main():
|
415 |
+
parser = argparse.ArgumentParser(description="Kokoro-StyleTTS2 CLI Example")
|
416 |
+
parser.add_argument(
|
417 |
+
"--model",
|
418 |
+
type=str,
|
419 |
+
default="pretrained_models/Kokoro/kokoro-v0_19.pth",
|
420 |
+
help="Path to your model checkpoint (e.g. kokoro-v0_19.pth).",
|
421 |
+
)
|
422 |
+
parser.add_argument(
|
423 |
+
"--config",
|
424 |
+
type=str,
|
425 |
+
default="pretrained_models/Kokoro/config.json",
|
426 |
+
help="Path to config.json (used by build_model).",
|
427 |
+
)
|
428 |
+
parser.add_argument(
|
429 |
+
"--text",
|
430 |
+
type=str,
|
431 |
+
default="Hello world! This is Kokoro, a new text-to-speech model based on StyleTTS2 from 2024!",
|
432 |
+
help="Text to be converted into speech.",
|
433 |
+
)
|
434 |
+
parser.add_argument(
|
435 |
+
"--voicepack",
|
436 |
+
type=str,
|
437 |
+
default="pretrained_models/Kokoro/voices/af.pt",
|
438 |
+
help="Path to a .pt file for your reference embedding(s).",
|
439 |
+
)
|
440 |
+
parser.add_argument(
|
441 |
+
"--output", type=str, default="output.wav", help="Output WAV filename."
|
442 |
+
)
|
443 |
+
parser.add_argument(
|
444 |
+
"--speed",
|
445 |
+
type=float,
|
446 |
+
default=1.0,
|
447 |
+
help="Speech speed factor, e.g. 0.8 slower, 1.2 faster, etc.",
|
448 |
+
)
|
449 |
+
parser.add_argument(
|
450 |
+
"--device",
|
451 |
+
type=str,
|
452 |
+
default="cpu",
|
453 |
+
choices=["cpu", "cuda"],
|
454 |
+
help="Device to run inference on.",
|
455 |
+
)
|
456 |
+
args = parser.parse_args()
|
457 |
+
|
458 |
+
# 1. Build model using your local build_model function
|
459 |
+
# (which loads TextEncoder, Decoder, etc. and returns a Munch).
|
460 |
+
if not os.path.isfile(args.config):
|
461 |
+
raise FileNotFoundError(f"config.json not found: {args.config}")
|
462 |
+
|
463 |
+
# Optionally load config as Munch (depends on your build_model usage)
|
464 |
+
# But your snippet does something like:
|
465 |
+
# with open(config, 'r') as r: ...
|
466 |
+
# ...
|
467 |
+
# model = build_model(path, device)
|
468 |
+
# We'll do the same but in a simpler form:
|
469 |
+
device = (
|
470 |
+
args.device if (args.device == "cuda" and torch.cuda.is_available()) else "cpu"
|
471 |
+
)
|
472 |
+
print(f"Loading model from: {args.model}")
|
473 |
+
model = build_model(
|
474 |
+
args.model, device
|
475 |
+
) # This requires that `args.model` is the checkpoint path
|
476 |
+
|
477 |
+
# Because `build_model` returns a Munch (dict of submodules),
|
478 |
+
# we can't just do `model.eval()`, we must set each submodule to eval:
|
479 |
+
for k, subm in model.items():
|
480 |
+
if isinstance(subm, torch.nn.Module):
|
481 |
+
subm.eval()
|
482 |
+
|
483 |
+
# 2. Load voicepack
|
484 |
+
if not os.path.isfile(args.voicepack):
|
485 |
+
raise FileNotFoundError(f"Voicepack file not found: {args.voicepack}")
|
486 |
+
print(f"Loading voicepack from: {args.voicepack}")
|
487 |
+
vp = torch.load(args.voicepack, map_location=device)
|
488 |
+
# If your voicepack is an nn.Module, set it to eval as well
|
489 |
+
if isinstance(vp, torch.nn.Module):
|
490 |
+
vp.eval()
|
491 |
+
|
492 |
+
# 3. Generate audio
|
493 |
+
print(f"Generating speech for text: {args.text}")
|
494 |
+
audio, phonemes = generate_long_form_tts(
|
495 |
+
model, args.text, vp, lang="a", speed=args.speed
|
496 |
+
)
|
497 |
+
if audio is None:
|
498 |
+
print("No tokens were generated (maybe empty text?). Exiting.")
|
499 |
+
return
|
500 |
+
|
501 |
+
# 4. Write WAV
|
502 |
+
print(f"Writing output to: {args.output}")
|
503 |
+
sf.write(args.output, audio, 22050)
|
504 |
+
|
505 |
+
print("Finished!")
|
506 |
+
print(f"Phonemes used: {phonemes}")
|
507 |
+
|
508 |
+
|
509 |
+
if __name__ == "__main__":
|
510 |
+
main()
|
tts_cli_op.py
ADDED
@@ -0,0 +1,569 @@
#!/usr/bin/env python3
# tts_cli_op.py
"""
Example CLI for generating audio with Kokoro-StyleTTS2.

Usage:
  python tts_cli_op.py \
    --model /path/to/kokoro-v0_19.pth \
    --config /path/to/config.json \
    --text "Hello, my stinking friends from 1906! You stink." \
    --voicepack /path/to/af.pt \
    --output output.wav

Make sure:
 1. `models.py` is in the same folder (with `build_model`, `Decoder`, etc.).
 2. You have installed the needed libraries:
      pip install torch openphonemizer joblib munch soundfile pyyaml
 3. The model is a checkpoint that your `build_model` can load.

Adapt as needed!
"""

import argparse
import os
import re
import torch
import soundfile as sf
import numpy as np
from openphonemizer import OpenPhonemizer
from typing import List
import joblib

# --------------------------------------------------------------------
# Import from your local `models.py` (requires that file to be present).
# This example assumes `build_model` loads the entire TTS submodules
# (bert, bert_encoder, predictor, decoder, text_encoder).
# --------------------------------------------------------------------
from models import build_model

def resplit_strings(arr):
    """
    Given a list of string tokens (e.g. words, phrases), tries to
    split them into two sub-lists whose total lengths are as balanced
    as possible. The goal is to chunk a large string in half without
    splitting in the middle of a word.
    """
    if not arr:
        return "", ""
    if len(arr) == 1:
        return arr[0], ""

    min_diff = float("inf")
    best_split = 0
    lengths = [len(s) for s in arr]
    spaces = len(arr) - 1
    left_len = 0
    right_len = sum(lengths) + spaces

    for i in range(1, len(arr)):
        # Add current word + space to left side
        left_len += lengths[i - 1] + (1 if i > 1 else 0)
        # Remove from right side
        right_len -= lengths[i - 1] + 1
        diff = abs(left_len - right_len)
        if diff < min_diff:
            min_diff = diff
            best_split = i

    return " ".join(arr[:best_split]), " ".join(arr[best_split:])

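A quick illustration of the balancing behaviour (values are easy to check by hand):

# Illustrative only (not part of the committed file): the split point that best
# balances character counts (including the joining spaces) on both sides is chosen.
left, right = resplit_strings(["the", "quick", "brown", "fox"])
# left  == "the quick"   (9 characters)
# right == "brown fox"   (9 characters)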
def recursive_split(text, lang="a"):
    """
    Splits a piece of text into smaller segments so that
    each segment's phoneme length < some ~limit (~500 tokens).
    """
    # We'll reuse the existing `phonemize_text` + `tokenize` from script 1
    # to see if it is < 512 tokens. If it is, return it as a single chunk.
    # Otherwise, split on punctuation or whitespace and recurse.

    # 1. Phonemize first, check length
    ps = phonemize_text(text, do_normalize=True)
    tokens = tokenize(ps)
    if len(tokens) < 512:
        return [(text, ps)]

    # If too large, we split on certain punctuation or fall back to whitespace.
    # We look for punctuation that often indicates sentence boundaries;
    # if none is found, fall back to a space-split.
    for punctuation in [r"[.?!…]", r"[:,;—]"]:
        pattern = f"(?:(?<={punctuation})|(?<={punctuation}[\"'»])) "
        # Attempt to split on that punctuation
        splits = re.split(pattern, text)
        if len(splits) > 1:
            break
    else:
        # If we didn't break out, just do whitespace split
        splits = text.split(" ")

    # Use resplit_strings to chunk it about halfway
    left, right = resplit_strings(splits)
    # Recurse
    return recursive_split(left) + recursive_split(right)


def segment_and_tokenize(long_text, lang="a"):
    """
    Takes a large text, optionally normalizes or cleans it,
    then breaks it into a list of (segment_text, segment_phonemes).
    """
    # Additional cleaning if you want:
    # long_text = normalize_text(long_text)  # the existing function below
    # We chunk it up using recursive_split
    segments = recursive_split(long_text)
    return segments


# -------------- Normalization & Phonemization Routines -------------- #
def parens_to_angles(s):
    return s.replace("(", "«").replace(")", "»")


def split_num(num):
    num = num.group()
    if "." in num:
        return num
    elif ":" in num:
        h, m = [int(n) for n in num.split(":")]
        if m == 0:
            return f"{h} o'clock"
        elif m < 10:
            return f"{h} oh {m}"
        return f"{h} {m}"
    year = int(num[:4])
    if year < 1100 or year % 1000 < 10:
        return num
    left, right = num[:2], int(num[2:4])
    s = "s" if num.endswith("s") else ""
    if 100 <= year % 1000 <= 999:
        if right == 0:
            return f"{left} hundred{s}"
        elif right < 10:
            return f"{left} oh {right}{s}"
    return f"{left} {right}{s}"


def flip_money(m):
    m = m.group()
    bill = "dollar" if m[0] == "$" else "pound"
    if m[-1].isalpha():
        return f"{m[1:]} {bill}s"
    elif "." not in m:
        s = "" if m[1:] == "1" else "s"
        return f"{m[1:]} {bill}{s}"
    b, c = m[1:].split(".")
    s = "" if b == "1" else "s"
    c = int(c.ljust(2, "0"))
    coins = (
        f"cent{'' if c == 1 else 's'}"
        if m[0] == "$"
        else ("penny" if c == 1 else "pence")
    )
    return f"{b} {bill}{s} and {c} {coins}"


def point_num(num):
    a, b = num.group().split(".")
    return " point ".join([a, " ".join(b)])


def normalize_text(text):
    text = text.replace(chr(8216), "'").replace(chr(8217), "'")
    text = text.replace("«", chr(8220)).replace("»", chr(8221))
    text = text.replace(chr(8220), '"').replace(chr(8221), '"')
    text = parens_to_angles(text)

    # Replace some common full-width punctuation in CJK:
    for a, b in zip("、。!,:;?", ",.!,:;?"):
        text = text.replace(a, b + " ")

    text = re.sub(r"[^\S \n]", " ", text)
    text = re.sub(r" +", " ", text)
    text = re.sub(r"(?<=\n) +(?=\n)", "", text)
    text = re.sub(r"\bD[Rr]\.(?= [A-Z])", "Doctor", text)
    text = re.sub(r"\b(?:Mr\.|MR\.(?= [A-Z]))", "Mister", text)
    text = re.sub(r"\b(?:Ms\.|MS\.(?= [A-Z]))", "Miss", text)
    text = re.sub(r"\b(?:Mrs\.|MRS\.(?= [A-Z]))", "Mrs", text)
    text = re.sub(r"\betc\.(?! [A-Z])", "etc", text)
    text = re.sub(r"(?i)\b(y)eah?\b", r"\1e'a", text)
    text = re.sub(
        r"\d*\.\d+|\b\d{4}s?\b|(?<!:)\b(?:[1-9]|1[0-2]):[0-5]\d\b(?!:)",
        split_num,
        text,
    )
    text = re.sub(r"(?<=\d),(?=\d)", "", text)
    text = re.sub(
        r"(?i)[$£]\d+(?:\.\d+)?(?: hundred| thousand| (?:[bm]|tr)illion)*\b|[$£]\d+\.\d\d?\b",
        flip_money,
        text,
    )
    text = re.sub(r"\d*\.\d+", point_num, text)
    text = re.sub(r"(?<=\d)-(?=\d)", " to ", text)  # Could be minus; adjust if needed
    text = re.sub(r"(?<=\d)S", " S", text)
    text = re.sub(r"(?<=[BCDFGHJ-NP-TV-Z])'?s\b", "'S", text)
    text = re.sub(r"(?<=X')S\b", "s", text)
    text = re.sub(
        r"(?:[A-Za-z]\.){2,} [a-z]", lambda m: m.group().replace(".", "-"), text
    )
    text = re.sub(r"(?i)(?<=[A-Z])\.(?=[A-Z])", "-", text)
    return text.strip()

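A small example of the kind of rewriting these rules perform:

# Illustrative only (not part of the committed file): titles, currency and
# clock times are spelled out by the regexes above.
print(normalize_text("Dr. Smith paid $5 at 10:30."))
# -> "Doctor Smith paid 5 dollars at 10 30."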
# -------------------------------------------------------------------
# Vocab and Symbol Mapping
# -------------------------------------------------------------------
def get_vocab():
    _pad = "$"
    _punctuation = ';:,.!?¡¿—…"«»“” '
    _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
    _letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
    symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
    dicts = {}
    for i, s in enumerate(symbols):
        dicts[s] = i
    return dicts


VOCAB = get_vocab()


def tokenize(ps: str):
    """Convert the phoneme string into integer tokens based on VOCAB."""
    return [VOCAB.get(p) for p in ps if p in VOCAB]

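Tokenization is a plain lookup into VOCAB; symbols outside the vocabulary are silently dropped:

# Illustrative only (not part of the committed file).
ids = tokenize("həlˈoʊ@")
assert len(ids) == 6                          # "@" is not in VOCAB, so it is skipped
assert all(isinstance(i, int) for i in ids)   # each id indexes into VOCAB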
# -------------------------------------------------------------------
# Initialize a simple phonemizer
# For English:
#   'a' ~ en-us
#   'b' ~ en-gb
# -------------------------------------------------------------------

# Wrapper around OpenPhonemizer to enable batch phonemization
open_phonemizer = OpenPhonemizer()


class Phonemizer:
    def __init__(self, open_phonemizer):
        """
        open_phonemizer: an instance of OpenPhonemizer()
        """
        self.open_phonemizer = open_phonemizer

    def phonemize(
        self,
        text: List[str],
        njobs: int = 4,
        strip: bool = False,
    ) -> List[str]:
        """
        Phonemizes a list of input strings using OpenPhonemizer (which
        itself only supports single-string input).

        Parameters
        ----------
        text : list of str
            Each element is an utterance (or line) to be phonemized.
        njobs : int
            The number of parallel jobs for phonemization.
        strip : bool
            Not used by OpenPhonemizer directly, but you can implement
            an additional final "strip" step if needed.

        Returns
        -------
        list of str
            The phonemized text strings, in the same order as input.
        """

        if not isinstance(text, list) or any(not isinstance(x, str) for x in text):
            raise ValueError("`text` must be a list of strings.")

        # Optionally do any pre-processing you want here ...
        # e.g. text = [self._some_preprocess(line) for line in text]

        # If we only have one job, do it in a single loop
        if njobs == 1:
            return [self._phonemize_single(utterance, strip) for utterance in text]

        # Otherwise, we can split `text` into chunks and process in parallel.
        # The easiest approach is to dispatch each string as its own job and
        # collect the results in order. For large corpora, you might want a
        # more sophisticated chunking strategy.
        chunked_results = joblib.Parallel(n_jobs=njobs)(
            joblib.delayed(self._phonemize_single)(t, strip) for t in text
        )
        return chunked_results

    def _phonemize_single(self, line: str, strip: bool) -> str:
        """
        Phonemize a single line using openphonemizer.
        """
        # OpenPhonemizer usage:
        #   out_str = self.open_phonemizer(line)
        # (That returns the phonemes as a string.)
        phonemes = self.open_phonemizer(line)

        # Implement a post-strip if you want to mimic removing trailing delimiters
        # (this is just a placeholder for demonstration).
        if strip:
            phonemes = phonemes.rstrip()

        return phonemes


# instantiate the wrapper:
phonemizer_wrapper = Phonemizer(open_phonemizer)

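The wrapper is then used as a batch call; a minimal sketch (the phoneme strings themselves depend on the OpenPhonemizer checkpoint):

# Illustrative only (not part of the committed file): one phoneme string comes
# back per input line, in the same order as the inputs.
lines = ["Hello world.", "How are you today?"]
phonemes = phonemizer_wrapper.phonemize(lines, njobs=1)   # serial
# With njobs > 1 the same per-line calls are dispatched through joblib.Parallel.
print(phonemes)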
def phonemize_text(text, lang="a", do_normalize=True):
    if do_normalize:
        text = normalize_text(text)
    ps_list = phonemizer_wrapper.phonemize([text])
    ps = ps_list[0] if ps_list else ""

    # Some custom replacements
    ps = ps.replace("kəkˈoːɹoʊ", "kˈoʊkəɹoʊ").replace("kəkˈɔːɹəʊ", "kˈəʊkəɹəʊ")
    ps = ps.replace("ʲ", "j").replace("r", "ɹ").replace("x", "k").replace("ɬ", "l")
    # Example: insert space before "hˈʌndɹɪd" if there's a letter, e.g. "nˈaɪn" => "nˈaɪn hˈʌndɹɪd"
    ps = re.sub(r"(?<=[a-zɹː])(?=hˈʌndɹɪd)", " ", ps)
    # "z" at the end of a word -> remove the preceding space
    ps = re.sub(r' z(?=[;:,.!?¡¿—…"«»“” ]|$)', "z", ps)
    # American English ('a'): flap the "t" in "ninety" => "ninedi"
    if lang == "a":
        ps = re.sub(r"(?<=nˈaɪn)ti(?!ː)", "di", ps)

    # Only keep valid symbols
    ps = "".join(p for p in ps if p in VOCAB)
    return ps.strip()

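Put together, the text front end is just these two calls; only the shape of the result is predictable, since the phoneme string depends on the OpenPhonemizer weights:

# Illustrative only (not part of the committed file): text -> phonemes -> token ids.
ps = phonemize_text("Kokoro reads this out loud.", lang="a")
tokens = tokenize(ps)
print(ps)       # contains only symbols present in VOCAB
print(tokens)   # list of ints, ready for forward_tts / generate_tts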
# -------------------------------------------------------------------
# Utility for generating text masks
# -------------------------------------------------------------------
def length_to_mask(lengths):
    # lengths is a Tensor of shape [B], containing the text length for each batch
    max_len = lengths.max()
    row_ids = torch.arange(max_len, device=lengths.device).unsqueeze(0)
    mask = row_ids.expand(lengths.shape[0], -1)
    return (mask + 1) > lengths.unsqueeze(1)

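For a batch with lengths 3 and 5, the padded positions come back as True:

# Illustrative only (not part of the committed file).
mask = length_to_mask(torch.LongTensor([3, 5]))
# tensor([[False, False, False,  True,  True],
#         [False, False, False, False, False]])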
# -------------------------------------------------------------------
# The forward pass for inference.
# This version references `model.predictor`, `model.decoder`, etc.
# -------------------------------------------------------------------
@torch.no_grad()
def forward_tts(model, tokens, ref_s, speed=1.0):
    """
    model: Munch with submodels: bert, bert_encoder, predictor, decoder, text_encoder
    tokens: list[int], the tokenized input (without [0, ..., 0] yet)
    ref_s: reference embedding (torch.Tensor)
    speed: float, speed factor
    """
    device = ref_s.device
    tokens_t = torch.LongTensor([[0, *tokens, 0]]).to(device)  # add boundary tokens
    input_lengths = torch.LongTensor([tokens_t.shape[-1]]).to(device)
    text_mask = length_to_mask(input_lengths).to(device)

    # 1. Encode with BERT
    bert_dur = model.bert(tokens_t, attention_mask=(~text_mask).int())
    d_en = model.bert_encoder(bert_dur).transpose(-1, -2)

    # 2. Prosody predictor
    s = ref_s[
        :, 128:
    ]  # the last 128 dims drive the prosody predictor; the first 128 go to the decoder below

    d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
    x, _ = model.predictor.lstm(d)
    duration = model.predictor.duration_proj(x)
    duration = torch.sigmoid(duration).sum(axis=-1) / speed
    pred_dur = torch.round(duration).clamp(min=1).long()

    # 3. Expand alignment
    total_len = pred_dur.sum().item()
    pred_aln_trg = torch.zeros(input_lengths, total_len, device=device)
    c_frame = 0
    for i in range(pred_aln_trg.size(0)):
        n = pred_dur[0, i].item()
        pred_aln_trg[i, c_frame : c_frame + n] = 1
        c_frame += n

    # 4. Run F0 + Noise predictor
    en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0)
    F0_pred, N_pred = model.predictor.F0Ntrain(en, s)

    # 5. Text encoder -> asr
    t_en = model.text_encoder(tokens_t, input_lengths, text_mask)
    asr = t_en @ pred_aln_trg.unsqueeze(0)

    # 6. Decode audio
    audio = model.decoder(asr, F0_pred, N_pred, ref_s[:, :128])  # B x audio_len
    return audio.squeeze().cpu().numpy()


def generate_tts(model, text, voicepack, lang="a", speed=1.0):
    """
    model: the Munch returned by build_model(...)
    text: the input text (string)
    voicepack: the torch Tensor reference embedding, or a dict of them
    speed: speech speed factor
    """
    # 1. Phonemize
    ps = phonemize_text(text, do_normalize=True)
    tokens = tokenize(ps)
    if not tokens:
        return None, ps

    # 2. Retrieve reference style
    # If your voicepack is a single embedding for all lengths, adapt as needed.
    # If your voicepack is something like `voicepack[len(tokens)]`, do that.
    # If you have multiple voices, you might do something else.
    try:
        ref_s = voicepack[len(tokens)]
    except Exception:
        # fallback if len(tokens) is out of range
        ref_s = voicepack[-1]
    ref_s = ref_s.to("cpu" if not next(model.bert.parameters()).is_cuda else "cuda")

    # 3. Generate
    audio = forward_tts(model, tokens, ref_s, speed=speed)
    return audio, ps


def generate_long_form_tts(model, full_text, voicepack, lang="a", speed=1.0):
    """
    Generate TTS for a large `full_text`, splitting it into smaller segments
    and concatenating the resulting audio.

    Returns: (np.float32 array) final_audio, list_of_segment_phonemes
    """
    # 1. Segment the text
    segments = segment_and_tokenize(full_text)
    # segments is a list of (seg_text, seg_phonemes)

    # 2. For each segment, call `generate_tts(...)`
    audio_chunks = []
    all_phonemes = []
    for i, (seg_text, seg_ps) in enumerate(segments, 1):
        print(f"[LongForm] Generating chunk {i}/{len(segments)}: {seg_text[:40]}...")
        audio, used_phonemes = generate_tts(model, seg_text, voicepack, speed=speed)
        if audio is not None:
            audio_chunks.append(audio)
            all_phonemes.append(used_phonemes)
        else:
            print(f"[LongForm] Skipped empty segment {i}...")

    if not audio_chunks:
        return None, []

    # 3. Concatenate the audio
    final_audio = np.concatenate(audio_chunks, axis=0)
    return final_audio, all_phonemes

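Before synthesis, the chunking can be previewed on its own; a sketch assuming a hypothetical chapter1.txt:

# Illustrative only (not part of the committed file): inspect how a long input
# is split before synthesis.
with open("chapter1.txt") as f:              # hypothetical input file
    segments = segment_and_tokenize(f.read())
for seg_text, seg_ps in segments:
    print(len(tokenize(seg_ps)), seg_text[:50])
# each segment should stay under the ~512-token limit used by recursive_split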
# -------------------------------------------------------------------
# Main CLI
# -------------------------------------------------------------------
def main():
    parser = argparse.ArgumentParser(description="Kokoro-StyleTTS2 CLI Example")
    parser.add_argument(
        "--model",
        type=str,
        default="pretrained_models/Kokoro/kokoro-v0_19.pth",
        help="Path to your model checkpoint (e.g. kokoro-v0_19.pth).",
    )
    parser.add_argument(
        "--config",
        type=str,
        default="pretrained_models/Kokoro/config.json",
        help="Path to config.json (used by build_model).",
    )
    parser.add_argument(
        "--text",
        type=str,
        default="Hello world! This is Kokoro, a new text-to-speech model based on StyleTTS2 from 2024!",
        help="Text to be converted into speech.",
    )
    parser.add_argument(
        "--voicepack",
        type=str,
        default="pretrained_models/Kokoro/voices/af.pt",
        help="Path to a .pt file for your reference embedding(s).",
    )
    parser.add_argument(
        "--output", type=str, default="output.wav", help="Output WAV filename."
    )
    parser.add_argument(
        "--speed",
        type=float,
        default=1.2,
        help="Speech speed factor, e.g. 0.8 slower, 1.2 faster, etc.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        choices=["cpu", "cuda"],
        help="Device to run inference on.",
    )
    args = parser.parse_args()

    # 1. Build the model using the local build_model function
    #    (which loads TextEncoder, Decoder, etc. and returns a Munch).
    if not os.path.isfile(args.config):
        raise FileNotFoundError(f"config.json not found: {args.config}")

    # Optionally load the config as a Munch here (depends on how build_model
    # uses it); this CLI simply calls build_model(checkpoint, device) directly.
    device = (
        args.device if (args.device == "cuda" and torch.cuda.is_available()) else "cpu"
    )
    print(f"Loading model from: {args.model}")
    model = build_model(
        args.model, device
    )  # This requires that `args.model` is the checkpoint path

    # Because `build_model` returns a Munch (a dict of submodules),
    # we can't just call `model.eval()`; set each submodule to eval instead:
    for k, subm in model.items():
        if isinstance(subm, torch.nn.Module):
            subm.eval()

    # 2. Load voicepack
    if not os.path.isfile(args.voicepack):
        raise FileNotFoundError(f"Voicepack file not found: {args.voicepack}")
    print(f"Loading voicepack from: {args.voicepack}")
    vp = torch.load(args.voicepack, map_location=device)
    # If the voicepack is an nn.Module, set it to eval as well
    if isinstance(vp, torch.nn.Module):
        vp.eval()

    # 3. Generate audio
    print(f"Generating speech for text: {args.text}")
    audio, phonemes = generate_long_form_tts(
        model, args.text, vp, lang="a", speed=args.speed
    )
    if audio is None:
        print("No tokens were generated (maybe empty text?). Exiting.")
        return

    # 4. Write WAV
    print(f"Writing output to: {args.output}")
    sf.write(args.output, audio, 22050)

    print("Finished!")
    print(f"Phonemes used: {phonemes}")


if __name__ == "__main__":
    main()
uv.lock
ADDED
The diff for this file is too large to render.
See raw diff