Wismut committed
Commit 2eaa44a · 1 Parent(s): 788b43b

initial commit

.gitignore ADDED
@@ -0,0 +1,22 @@
+ pretrained_models/Kokoro/*.pth
+ pretrained_models/Kokoro/voices/*.pt
+ pretrained_models/Kokoro/config.json
+ output.wav
+ .env
+ # Omit the .DS_Store folder automatically created by macOS
+ .DS_Store/
+
+ # python virtual environment folder
+ .venv/
+ .vscode/
+
+ # process log file
+ process_log.txt*
+ process_log.txt
+
+ # Python cache
+ __pycache__/
+ /*/__pycache__/
+ /*/*/__pycache__/
+
+
.python-version ADDED
@@ -0,0 +1 @@
+ 3.11
README.md CHANGED
@@ -8,7 +8,164 @@ sdk_version: 5.9.1
  app_file: app.py
  pinned: false
  license: apache-2.0
- short_description: Simple Space for the Kokoro Model
+ short_description: Simple Space for comparing espeak-ng and OpenPhonemizer with the Kokoro model
+ base_model:
+ - yl4579/StyleTTS2-LJSpeech
+ - hexgrad/Kokoro-82M
+ - openphonemizer/ckpt
+ pipeline_tag: text-to-speech
  ---
 
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ❤️ Kokoro Discord Server: <https://discord.gg/QuGxSWBfQy>
+
+ <audio controls><source src="https://huggingface.co/hexgrad/Kokoro-82M/resolve/main/demo/HEARME.wav" type="audio/wav"></audio>
+
+ **Kokoro** is a frontier TTS model for its size of **82 million parameters** (text in/audio out).
+
+ On 25 Dec 2024, Kokoro v0.19 weights were permissively released in full fp32 precision along with 2 voicepacks (Bella and Sarah), all under an Apache 2.0 license.
+
+ As of 28 Dec 2024, **8 unique Voicepacks have been released**: 2 female and 2 male voices each for American and British English.
+
+ At the time of release, Kokoro v0.19 was the #1🥇 ranked model in [TTS Spaces Arena](https://huggingface.co/spaces/Pendrokar/TTS-Spaces-Arena). Kokoro achieved a higher Elo in this single-voice Arena setting than other models, using fewer parameters and less data:
+
+ 1. **Kokoro v0.19: 82M params, Apache, trained on <100 hours of audio, for <20 epochs**
+ 2. XTTS v2: 467M, CPML, >10k hours
+ 3. Edge TTS: Microsoft, proprietary
+ 4. MetaVoice: 1.2B, Apache, 100k hours
+ 5. Parler Mini: 880M, Apache, 45k hours
+ 6. Fish Speech: ~500M, CC-BY-NC-SA, 1M hours
+
+ Kokoro's ability to top this Elo ladder suggests that the scaling law (Elo vs compute/data/params) for traditional TTS models might have a steeper slope than previously expected.
+
+ You can find a hosted demo at [hf.co/spaces/hexgrad/Kokoro-TTS](https://huggingface.co/spaces/hexgrad/Kokoro-TTS).
+
+ ### Usage
+
+ The following can be run in a single cell on [Google Colab](https://colab.research.google.com/).
+
+ ```py
+ # 1️⃣ Install dependencies silently
+ !git clone https://huggingface.co/hexgrad/Kokoro-82M
+ %cd Kokoro-82M
+ !apt-get -qq -y install espeak-ng > /dev/null 2>&1
+ !pip install -q phonemizer torch transformers scipy munch
+
+ # 2️⃣ Build the model and load the default voicepack
+ from models import build_model
+ import torch
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ MODEL = build_model('kokoro-v0_19.pth', device)
+ VOICE_NAME = [
+     'af',  # Default voice is a 50-50 mix of af_bella & af_sarah
+     'af_bella', 'af_sarah', 'am_adam', 'am_michael',
+     'bf_emma', 'bf_isabella', 'bm_george', 'bm_lewis',
+ ][0]
+ VOICEPACK = torch.load(f'voices/{VOICE_NAME}.pt', weights_only=True).to(device)
+ print(f'Loaded voice: {VOICE_NAME}')
+
+ # 3️⃣ Call generate, which returns a 24khz audio waveform and a string of output phonemes
+ from kokoro import generate
+ text = "How could I know? It's an unanswerable question. Like asking an unborn child if they'll lead a good life. They haven't even been born."
+ audio, out_ps = generate(MODEL, text, VOICEPACK, lang=VOICE_NAME[0])
+ # Language is determined by the first letter of the VOICE_NAME:
+ # 🇺🇸 'a' => American English => en-us
+ # 🇬🇧 'b' => British English => en-gb
+
+ # 4️⃣ Display the 24khz audio and print the output phonemes
+ from IPython.display import display, Audio
+ display(Audio(data=audio, rate=24000, autoplay=True))
+ print(out_ps)
+ ```
+
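+ For example, to hear one of the British English voices instead, only the voice selection changes. The snippet below is a minimal, illustrative variation on the cell above (the particular voice picked here is arbitrary):
+
+ ```py
+ # Illustrative: any 'b'-prefixed voicepack selects British English ('b' => en-gb)
+ VOICE_NAME = 'bf_emma'
+ VOICEPACK = torch.load(f'voices/{VOICE_NAME}.pt', weights_only=True).to(device)
+ audio, out_ps = generate(MODEL, text, VOICEPACK, lang=VOICE_NAME[0])
+ ```
+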
+ The inference code was quickly hacked together on Christmas Day. It is not clean code and leaves a lot of room for improvement. If you'd like to contribute, feel free to open a PR.
+
+ ### Model Facts
+
+ No affiliation can be assumed between parties on different lines.
+
+ **Architecture:**
+
+ - StyleTTS 2: <https://arxiv.org/abs/2306.07691>
+ - ISTFTNet: <https://arxiv.org/abs/2203.02395>
+ - Decoder only: no diffusion, no encoder release
+
+ **Architected by:** Li et al @ <https://github.com/yl4579/StyleTTS2>
+
+ **Trained by:** `@rzvzn` on Discord
+
+ **Supported Languages:** American English, British English
+
+ **Model SHA256 Hash:** `3b0c392f87508da38fad3a2f9d94c359f1b657ebd2ef79f9d56d69503e470b0a`
+
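+ To sanity-check a downloaded checkpoint against this hash, a minimal snippet such as the following can be used (it assumes the weights are saved as `kokoro-v0_19.pth` in the working directory, as in the usage example above):
+
+ ```py
+ import hashlib
+
+ # Compare the local file's SHA256 digest with the published hash
+ with open('kokoro-v0_19.pth', 'rb') as f:
+     digest = hashlib.sha256(f.read()).hexdigest()
+ print(digest == '3b0c392f87508da38fad3a2f9d94c359f1b657ebd2ef79f9d56d69503e470b0a')
+ ```
+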
+ ### Releases
+
+ - 25 Dec 2024: Model v0.19, `af_bella`, `af_sarah`
+ - 26 Dec 2024: `am_adam`, `am_michael`
+ - 28 Dec 2024: `bf_emma`, `bf_isabella`, `bm_george`, `bm_lewis`
+
+ ### Licenses
+
+ - Apache 2.0 weights in this repository
+ - MIT inference code in [spaces/hexgrad/Kokoro-TTS](https://huggingface.co/spaces/hexgrad/Kokoro-TTS) adapted from [yl4579/StyleTTS2](https://github.com/yl4579/StyleTTS2)
+ - GPLv3 dependency in [espeak-ng](https://github.com/espeak-ng/espeak-ng)
+
+ The inference code was originally MIT licensed by the paper author. Note that this card applies only to this model, Kokoro. Original models published by the paper author can be found at [hf.co/yl4579](https://huggingface.co/yl4579).
+
+ ### Evaluation
+
+ **Metric:** Elo rating
+
+ **Leaderboard:** [hf.co/spaces/Pendrokar/TTS-Spaces-Arena](https://huggingface.co/spaces/Pendrokar/TTS-Spaces-Arena)
+
+ ![TTS-Spaces-Arena-25-Dec-2024](demo/TTS-Spaces-Arena-25-Dec-2024.png)
+
+ The voice ranked in the Arena is a 50-50 mix of Bella and Sarah. For your convenience, this mix is included in this repository as `af.pt`, but you can trivially reproduce it like this:
+
+ ```py
+ import torch
+ bella = torch.load('voices/af_bella.pt', weights_only=True)
+ sarah = torch.load('voices/af_sarah.pt', weights_only=True)
+ af = torch.mean(torch.stack([bella, sarah]), dim=0)
+ assert torch.equal(af, torch.load('voices/af.pt', weights_only=True))
+ ```
+
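+ The same averaging idea extends to other blend ratios. For instance, a purely illustrative 70/30 Bella/Sarah blend (not an official voicepack) could be built as a weighted sum of the tensors loaded above:
+
+ ```py
+ # Hypothetical custom blend; any weights summing to 1 work the same way
+ custom = 0.7 * bella + 0.3 * sarah
+ ```
+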
+ ### Training Details
+
+ **Compute:** Kokoro was trained on A100 80GB vRAM instances rented from [Vast.ai](https://cloud.vast.ai/?ref_id=79907) (referral link). Vast was chosen over other compute providers due to its competitive on-demand hourly rates. The average hourly cost for the A100 80GB vRAM instances used for training was below $1/hr per GPU, which was around half the quoted rates from other providers at the time.
+
+ **Data:** Kokoro was trained exclusively on **permissive/non-copyrighted audio data** and IPA phoneme labels. Examples of permissive/non-copyrighted audio include:
+
+ - Public domain audio
+ - Audio licensed under Apache, MIT, etc.
+ - Synthetic audio<sup>[1]</sup> generated by closed<sup>[2]</sup> TTS models from large providers<br/>
+ [1] <https://copyright.gov/ai/ai_policy_guidance.pdf><br/>
+ [2] No synthetic audio from open TTS models or "custom voice clones"
+
+ **Epochs:** Less than **20 epochs**
+
+ **Total Dataset Size:** Less than **100 hours** of audio
+
+ ### Limitations
+
+ Kokoro v0.19 is limited in some specific ways, due to its training set and/or architecture:
+
+ - [Data] Lacks voice cloning capability, likely due to the small (<100h) training set
+ - [Arch] Relies on an external g2p (espeak-ng), which introduces a class of g2p failure modes
+ - [Data] Training dataset is mostly long-form reading and narration, not conversation
+ - [Arch] At 82M params, Kokoro almost certainly loses to a well-trained 1B+ param diffusion transformer, or a many-billion-param MLLM like GPT-4o / Gemini 2.0 Flash
+ - [Data] Multilingual capability is architecturally feasible, but training data is mostly English
+
+ Refer to the [Philosophy discussion](https://huggingface.co/hexgrad/Kokoro-82M/discussions/5) to better understand these limitations.
+
+ **Will the other voicepacks be released?** There is currently no release date scheduled for the other voicepacks, but in the meantime you can try them in the hosted demo at [hf.co/spaces/hexgrad/Kokoro-TTS](https://huggingface.co/spaces/hexgrad/Kokoro-TTS).
+
+ ### Acknowledgements
+
+ - [@yl4579](https://huggingface.co/yl4579) for architecting StyleTTS 2
+ - [@Pendrokar](https://huggingface.co/Pendrokar) for adding Kokoro as a contender in the TTS Spaces Arena
+
+ ### Model Card Contact
+
+ `@rzvzn` on Discord. Server invite: <https://discord.gg/QuGxSWBfQy>
+
+ <img src="https://static0.gamerantimages.com/wordpress/wp-content/uploads/2024/08/terminator-zero-41-1.jpg" width="400" alt="kokoro" />
app.py ADDED
@@ -0,0 +1,166 @@
1
+ import gradio as gr
2
+ import os
3
+ import torch
4
+
5
+ # Import eSpeak TTS pipeline
6
+ from tts_cli import (
7
+ build_model as build_model_espeak,
8
+ generate_long_form_tts as generate_long_form_tts_espeak,
9
+ )
10
+
11
+ # Import OpenPhonemizer TTS pipeline
12
+ from tts_cli_op import (
13
+ build_model as build_model_open,
14
+ generate_long_form_tts as generate_long_form_tts_open,
15
+ )
16
+ from pretrained_models import Kokoro
17
+
18
+ # ---------------------------------------------------------------------
19
+ # Path to models and voicepacks
20
+ # ---------------------------------------------------------------------
21
+ MODELS_DIR = "pretrained_models/Kokoro"
22
+ VOICES_DIR = "pretrained_models/Kokoro/voices"
23
+
24
+
25
+ # ---------------------------------------------------------------------
26
+ # List the models (.pth) and voices (.pt)
27
+ # ---------------------------------------------------------------------
28
+ def get_models():
29
+ return sorted([f for f in os.listdir(MODELS_DIR) if f.endswith(".pth")])
30
+
31
+
32
+ def get_voices():
33
+ return sorted([f for f in os.listdir(VOICES_DIR) if f.endswith(".pt")])
34
+
35
+
36
+ # ---------------------------------------------------------------------
37
+ # We'll map engine selection -> (build_model_func, generate_func)
38
+ # ---------------------------------------------------------------------
39
+ ENGINES = {
40
+ "espeak": (build_model_espeak, generate_long_form_tts_espeak),
41
+ "openphonemizer": (build_model_open, generate_long_form_tts_open),
42
+ }
43
+
44
+
45
+ # ---------------------------------------------------------------------
46
+ # The main inference function called by Gradio
47
+ # ---------------------------------------------------------------------
48
+ def tts_inference(text, engine, model_file, voice_file, speed=1.0):
49
+ """
50
+ text: Input string
51
+ engine: "espeak" or "openphonemizer"
52
+ model_file: Selected .pth from the models folder
53
+ voice_file: Selected .pt from the voices folder
54
+ speed: Speech speed
55
+ """
56
+ # 1) Map engine to the correct build_model + generate_long_form_tts
57
+ build_fn, gen_fn = ENGINES[engine]
58
+
59
+ # 2) Prepare paths
60
+ model_path = os.path.join(MODELS_DIR, model_file)
61
+ voice_path = os.path.join(VOICES_DIR, voice_file)
62
+
63
+ # 3) Decide device
64
+ device = "cuda" if torch.cuda.is_available() else "cpu"
65
+
66
+ # 4) Load model
67
+ model = build_fn(model_path, device=device)
68
+ # Set submodules eval
69
+ for k, subm in model.items():
70
+ if hasattr(subm, "eval"):
71
+ subm.eval()
72
+
73
+ # 5) Load voicepack
74
+ voicepack = torch.load(voice_path, map_location=device)
75
+ if hasattr(voicepack, "eval"):
76
+ voicepack.eval()
77
+
78
+ # 6) Generate TTS
79
+ audio, phonemes = gen_fn(model, text, voicepack, speed=speed)
80
+ sr = 24000  # Kokoro generates 24 kHz audio (see the model card)
81
+
82
+ return (sr, audio) # Gradio expects (sample_rate, np_array)
83
+
84
+
85
+ # ---------------------------------------------------------------------
86
+ # Build Gradio App
87
+ # ---------------------------------------------------------------------
88
+ def create_gradio_app():
89
+ model_list = get_models()
90
+ voice_list = get_voices()
91
+
92
+ css = """
93
+ h4 {
94
+ text-align: center;
95
+ display:block;
96
+ }
97
+ h2 {
98
+ text-align: center;
99
+ display:block;
100
+ }
101
+ """
102
+ with gr.Blocks(theme=gr.themes.Ocean(), css=css) as demo:
103
+ gr.Markdown("## Kokoro TTS Demo: Choose engine, model, and voice")
104
+
105
+ # Row 1: Text input
106
+ text_input = gr.Textbox(
107
+ label="Input Text",
108
+ value="Hello, world! Testing both eSpeak and OpenPhonemizer. Can you believe that we live in 2024 and have access to advanced AI?",
109
+ lines=3,
110
+ )
111
+
112
+ # Row 2: Engine selection
113
+ engine_dropdown = gr.Dropdown(
114
+ choices=["espeak", "openphonemizer"],
115
+ value="openphonemizer",
116
+ label="Phonemizer",
117
+ )
118
+
119
+ # Row 3: Model dropdown
120
+ model_dropdown = gr.Dropdown(
121
+ choices=model_list,
122
+ value=model_list[0] if model_list else None,
123
+ label="Model (.pth)",
124
+ )
125
+
126
+ # Row 4: Voice dropdown
127
+ voice_dropdown = gr.Dropdown(
128
+ choices=voice_list,
129
+ value=voice_list[0] if voice_list else None,
130
+ label="Voice (.pt)",
131
+ )
132
+
133
+ # Row 5: Speed slider
134
+ speed_slider = gr.Slider(
135
+ minimum=0.5, maximum=2.0, value=1.0, step=0.1, label="Speech Speed"
136
+ )
137
+
138
+ # Generate button + audio output
139
+ generate_btn = gr.Button("Generate")
140
+ tts_output = gr.Audio(label="TTS Output")
141
+
142
+ # Connect the button to our inference function
143
+ generate_btn.click(
144
+ fn=tts_inference,
145
+ inputs=[
146
+ text_input,
147
+ engine_dropdown,
148
+ model_dropdown,
149
+ voice_dropdown,
150
+ speed_slider,
151
+ ],
152
+ outputs=tts_output,
153
+ )
154
+
155
+ gr.Markdown(
156
+ "#### Kokoro TTS Demo based on [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M)"
157
+ )
158
+ return demo
159
+
160
+
161
+ # ---------------------------------------------------------------------
162
+ # Main
163
+ # ---------------------------------------------------------------------
164
+ if __name__ == "__main__":
165
+ app = create_gradio_app()
166
+ app.launch()
espeak_util.py ADDED
@@ -0,0 +1,206 @@
1
+ import platform
2
+ import subprocess
3
+ import shutil
4
+ from pathlib import Path
5
+ import os
6
+ from typing import Optional, Tuple
7
+ from phonemizer.backend.espeak.wrapper import EspeakWrapper
8
+
9
+
10
+ class EspeakConfig:
11
+ """Utility class for configuring espeak-ng library and binary."""
12
+
13
+ @staticmethod
14
+ def find_espeak_binary() -> tuple[bool, Optional[str]]:
15
+ """
16
+ Find espeak-ng binary using multiple methods.
17
+
18
+ Returns:
19
+ tuple: (bool indicating if espeak is available, path to espeak binary if found)
20
+ """
21
+ # Common binary names
22
+ binary_names = ["espeak-ng", "espeak"]
23
+ if platform.system() == "Windows":
24
+ binary_names = ["espeak-ng.exe", "espeak.exe"]
25
+
26
+ # Common installation directories for Linux
27
+ linux_paths = [
28
+ "/usr/bin",
29
+ "/usr/local/bin",
30
+ "/usr/lib/espeak-ng",
31
+ "/usr/local/lib/espeak-ng",
32
+ "/opt/espeak-ng/bin",
33
+ ]
34
+
35
+ # First check if it's in PATH
36
+ for name in binary_names:
37
+ espeak_path = shutil.which(name)
38
+ if espeak_path:
39
+ return True, espeak_path
40
+
41
+ # For Linux, check common installation directories
42
+ if platform.system() == "Linux":
43
+ for directory in linux_paths:
44
+ for name in binary_names:
45
+ path = Path(directory) / name
46
+ if path.exists():
47
+ return True, str(path)
48
+
49
+ # Try running the command directly as a last resort
50
+ try:
51
+ subprocess.run(
52
+ ["espeak-ng", "--version"],
53
+ stdout=subprocess.PIPE,
54
+ stderr=subprocess.PIPE,
55
+ check=True,
56
+ )
57
+ return True, "espeak-ng"
58
+ except (subprocess.SubprocessError, FileNotFoundError):
59
+ pass
60
+
61
+ return False, None
62
+
63
+ @staticmethod
64
+ def find_library_path() -> Optional[str]:
65
+ """
66
+ Find the espeak-ng library using multiple search methods.
67
+
68
+ Returns:
69
+ Optional[str]: Path to the library if found, None otherwise
70
+ """
71
+ system = platform.system()
72
+
73
+ if system == "Linux":
74
+ lib_names = ["libespeak-ng.so", "libespeak-ng.so.1"]
75
+ common_paths = [
76
+ # Debian/Ubuntu paths
77
+ "/usr/lib/x86_64-linux-gnu",
78
+ "/usr/lib/aarch64-linux-gnu", # For ARM64
79
+ "/usr/lib/arm-linux-gnueabihf", # For ARM32
80
+ "/usr/lib",
81
+ "/usr/local/lib",
82
+ # Fedora/RHEL paths
83
+ "/usr/lib64",
84
+ "/usr/lib32",
85
+ # Common additional paths
86
+ "/usr/lib/espeak-ng",
87
+ "/usr/local/lib/espeak-ng",
88
+ "/opt/espeak-ng/lib",
89
+ ]
90
+
91
+ # Check common locations first
92
+ for path in common_paths:
93
+ for lib_name in lib_names:
94
+ lib_path = Path(path) / lib_name
95
+ if lib_path.exists():
96
+ return str(lib_path)
97
+
98
+ # Search system library paths
99
+ try:
100
+ # Use ldconfig to find the library
101
+ result = subprocess.run(
102
+ ["ldconfig", "-p"], capture_output=True, text=True, check=True
103
+ )
104
+ for line in result.stdout.splitlines():
105
+ if "libespeak-ng.so" in line:
106
+ # Extract path from ldconfig output
107
+ return line.split("=>")[-1].strip()
108
+ except (subprocess.SubprocessError, FileNotFoundError):
109
+ pass
110
+
111
+ elif system == "Darwin": # macOS
112
+ common_paths = [
113
+ Path("/opt/homebrew/lib/libespeak-ng.dylib"),
114
+ Path("/usr/local/lib/libespeak-ng.dylib"),
115
+ *list(
116
+ Path("/opt/homebrew/Cellar/espeak-ng").glob(
117
+ "*/lib/libespeak-ng.dylib"
118
+ )
119
+ ),
120
+ *list(
121
+ Path("/usr/local/Cellar/espeak-ng").glob("*/lib/libespeak-ng.dylib")
122
+ ),
123
+ ]
124
+
125
+ for path in common_paths:
126
+ if path.exists():
127
+ return str(path)
128
+
129
+ elif system == "Windows":
130
+ common_paths = [
131
+ Path(os.environ.get("PROGRAMFILES", "C:\\Program Files"))
132
+ / "eSpeak NG"
133
+ / "libespeak-ng.dll",
134
+ Path(os.environ.get("PROGRAMFILES(X86)", "C:\\Program Files (x86)"))
135
+ / "eSpeak NG"
136
+ / "libespeak-ng.dll",
137
+ *[
138
+ Path(p) / "libespeak-ng.dll"
139
+ for p in os.environ.get("PATH", "").split(os.pathsep)
140
+ ],
141
+ ]
142
+
143
+ for path in common_paths:
144
+ if path.exists():
145
+ return str(path)
146
+
147
+ return None
148
+
149
+ @classmethod
150
+ def configure_espeak(cls) -> Tuple[bool, str]:
151
+ """
152
+ Configure espeak-ng for use with the phonemizer.
153
+
154
+ Returns:
155
+ Tuple[bool, str]: (Success status, Status message)
156
+ """
157
+ # First check if espeak binary is available
158
+ espeak_available, espeak_path = cls.find_espeak_binary()
159
+ if not espeak_available:
160
+ raise FileNotFoundError(
161
+ "Could not find espeak-ng binary. Please install espeak-ng:\n"
162
+ "Ubuntu/Debian: sudo apt-get install espeak-ng espeak-ng-data\n"
163
+ "Fedora: sudo dnf install espeak-ng\n"
164
+ "Arch: sudo pacman -S espeak-ng\n"
165
+ "MacOS: brew install espeak-ng\n"
166
+ "Windows: Download from https://github.com/espeak-ng/espeak-ng/releases"
167
+ )
168
+
169
+ # Find the library
170
+ library_path = cls.find_library_path()
171
+ if not library_path:
172
+ # On Linux, we might not need to explicitly set the library path
173
+ if platform.system() == "Linux":
174
+ return True, f"Using system espeak-ng installation at: {espeak_path}"
175
+ else:
176
+ raise FileNotFoundError(
177
+ "Could not find espeak-ng library. Please ensure espeak-ng is properly installed."
178
+ )
179
+
180
+ # Try to set the library path
181
+ try:
182
+ EspeakWrapper.set_library(library_path)
183
+ return True, f"Successfully configured espeak-ng library at: {library_path}"
184
+ except Exception as e:
185
+ if platform.system() == "Linux":
186
+ # On Linux, try to continue without explicit library path
187
+ return True, f"Using system espeak-ng installation at: {espeak_path}"
188
+ else:
189
+ raise RuntimeError(f"Failed to configure espeak-ng library: {str(e)}")
190
+
191
+
192
+ def setup_espeak():
193
+ """
194
+ Set up espeak-ng for use with the phonemizer.
195
+ Raises appropriate exceptions if setup fails.
196
+ """
197
+ try:
198
+ success, message = EspeakConfig.configure_espeak()
199
+ print(message)
200
+ except Exception as e:
201
+ print(f"Error configuring espeak-ng: {str(e)}")
202
+ raise
203
+
204
+
205
+ # Replace the original set_espeak_library function with this
206
+ set_espeak_library = setup_espeak
istftnet.py ADDED
@@ -0,0 +1,523 @@
1
+ # https://github.com/yl4579/StyleTTS2/blob/main/Modules/istftnet.py
2
+ from scipy.signal import get_window
3
+ from torch.nn import Conv1d, ConvTranspose1d
4
+ from torch.nn.utils import weight_norm, remove_weight_norm
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ # https://github.com/yl4579/StyleTTS2/blob/main/Modules/utils.py
11
+ def init_weights(m, mean=0.0, std=0.01):
12
+ classname = m.__class__.__name__
13
+ if classname.find("Conv") != -1:
14
+ m.weight.data.normal_(mean, std)
15
+
16
+ def get_padding(kernel_size, dilation=1):
17
+ return int((kernel_size*dilation - dilation)/2)
18
+
19
+ LRELU_SLOPE = 0.1
20
+
21
+ class AdaIN1d(nn.Module):
22
+ def __init__(self, style_dim, num_features):
23
+ super().__init__()
24
+ self.norm = nn.InstanceNorm1d(num_features, affine=False)
25
+ self.fc = nn.Linear(style_dim, num_features*2)
26
+
27
+ def forward(self, x, s):
28
+ h = self.fc(s)
29
+ h = h.view(h.size(0), h.size(1), 1)
30
+ gamma, beta = torch.chunk(h, chunks=2, dim=1)
31
+ return (1 + gamma) * self.norm(x) + beta
32
+
33
+ class AdaINResBlock1(torch.nn.Module):
34
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), style_dim=64):
35
+ super(AdaINResBlock1, self).__init__()
36
+ self.convs1 = nn.ModuleList([
37
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
38
+ padding=get_padding(kernel_size, dilation[0]))),
39
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
40
+ padding=get_padding(kernel_size, dilation[1]))),
41
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
42
+ padding=get_padding(kernel_size, dilation[2])))
43
+ ])
44
+ self.convs1.apply(init_weights)
45
+
46
+ self.convs2 = nn.ModuleList([
47
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
48
+ padding=get_padding(kernel_size, 1))),
49
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
50
+ padding=get_padding(kernel_size, 1))),
51
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
52
+ padding=get_padding(kernel_size, 1)))
53
+ ])
54
+ self.convs2.apply(init_weights)
55
+
56
+ self.adain1 = nn.ModuleList([
57
+ AdaIN1d(style_dim, channels),
58
+ AdaIN1d(style_dim, channels),
59
+ AdaIN1d(style_dim, channels),
60
+ ])
61
+
62
+ self.adain2 = nn.ModuleList([
63
+ AdaIN1d(style_dim, channels),
64
+ AdaIN1d(style_dim, channels),
65
+ AdaIN1d(style_dim, channels),
66
+ ])
67
+
68
+ self.alpha1 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs1))])
69
+ self.alpha2 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs2))])
70
+
71
+
72
+ def forward(self, x, s):
73
+ for c1, c2, n1, n2, a1, a2 in zip(self.convs1, self.convs2, self.adain1, self.adain2, self.alpha1, self.alpha2):
74
+ xt = n1(x, s)
75
+ xt = xt + (1 / a1) * (torch.sin(a1 * xt) ** 2) # Snake1D
76
+ xt = c1(xt)
77
+ xt = n2(xt, s)
78
+ xt = xt + (1 / a2) * (torch.sin(a2 * xt) ** 2) # Snake1D
79
+ xt = c2(xt)
80
+ x = xt + x
81
+ return x
82
+
83
+ def remove_weight_norm(self):
84
+ for l in self.convs1:
85
+ remove_weight_norm(l)
86
+ for l in self.convs2:
87
+ remove_weight_norm(l)
88
+
89
+ class TorchSTFT(torch.nn.Module):
90
+ def __init__(self, filter_length=800, hop_length=200, win_length=800, window='hann'):
91
+ super().__init__()
92
+ self.filter_length = filter_length
93
+ self.hop_length = hop_length
94
+ self.win_length = win_length
95
+ self.window = torch.from_numpy(get_window(window, win_length, fftbins=True).astype(np.float32))
96
+
97
+ def transform(self, input_data):
98
+ forward_transform = torch.stft(
99
+ input_data,
100
+ self.filter_length, self.hop_length, self.win_length, window=self.window.to(input_data.device),
101
+ return_complex=True)
102
+
103
+ return torch.abs(forward_transform), torch.angle(forward_transform)
104
+
105
+ def inverse(self, magnitude, phase):
106
+ inverse_transform = torch.istft(
107
+ magnitude * torch.exp(phase * 1j),
108
+ self.filter_length, self.hop_length, self.win_length, window=self.window.to(magnitude.device))
109
+
110
+ return inverse_transform.unsqueeze(-2) # unsqueeze to stay consistent with conv_transpose1d implementation
111
+
112
+ def forward(self, input_data):
113
+ self.magnitude, self.phase = self.transform(input_data)
114
+ reconstruction = self.inverse(self.magnitude, self.phase)
115
+ return reconstruction
116
+
117
+ class SineGen(torch.nn.Module):
118
+ """ Definition of sine generator
119
+ SineGen(samp_rate, harmonic_num = 0,
120
+ sine_amp = 0.1, noise_std = 0.003,
121
+ voiced_threshold = 0,
122
+ flag_for_pulse=False)
123
+ samp_rate: sampling rate in Hz
124
+ harmonic_num: number of harmonic overtones (default 0)
125
+ sine_amp: amplitude of sine-waveform (default 0.1)
126
+ noise_std: std of Gaussian noise (default 0.003)
127
+ voiced_threshold: F0 threshold for U/V classification (default 0)
128
+ flag_for_pulse: this SinGen is used inside PulseGen (default False)
129
+ Note: when flag_for_pulse is True, the first time step of a voiced
130
+ segment is always sin(np.pi) or cos(0)
131
+ """
132
+
133
+ def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
134
+ sine_amp=0.1, noise_std=0.003,
135
+ voiced_threshold=0,
136
+ flag_for_pulse=False):
137
+ super(SineGen, self).__init__()
138
+ self.sine_amp = sine_amp
139
+ self.noise_std = noise_std
140
+ self.harmonic_num = harmonic_num
141
+ self.dim = self.harmonic_num + 1
142
+ self.sampling_rate = samp_rate
143
+ self.voiced_threshold = voiced_threshold
144
+ self.flag_for_pulse = flag_for_pulse
145
+ self.upsample_scale = upsample_scale
146
+
147
+ def _f02uv(self, f0):
148
+ # generate uv signal
149
+ uv = (f0 > self.voiced_threshold).type(torch.float32)
150
+ return uv
151
+
152
+ def _f02sine(self, f0_values):
153
+ """ f0_values: (batchsize, length, dim)
154
+ where dim indicates fundamental tone and overtones
155
+ """
156
+ # convert to F0 in rad. The integer part n can be ignored
157
+ # because 2 * np.pi * n doesn't affect phase
158
+ rad_values = (f0_values / self.sampling_rate) % 1
159
+
160
+ # initial phase noise (no noise for fundamental component)
161
+ rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], \
162
+ device=f0_values.device)
163
+ rand_ini[:, 0] = 0
164
+ rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
165
+
166
+ # instantaneous phase sine[t] = sin(2*pi \sum_{i=1}^{t} rad)
167
+ if not self.flag_for_pulse:
168
+ # # for normal case
169
+
170
+ # # To prevent torch.cumsum numerical overflow,
171
+ # # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1.
172
+ # # Buffer tmp_over_one_idx indicates the time step to add -1.
173
+ # # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi
174
+ # tmp_over_one = torch.cumsum(rad_values, 1) % 1
175
+ # tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
176
+ # cumsum_shift = torch.zeros_like(rad_values)
177
+ # cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
178
+
179
+ # phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
180
+ rad_values = torch.nn.functional.interpolate(rad_values.transpose(1, 2),
181
+ scale_factor=1/self.upsample_scale,
182
+ mode="linear").transpose(1, 2)
183
+
184
+ # tmp_over_one = torch.cumsum(rad_values, 1) % 1
185
+ # tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
186
+ # cumsum_shift = torch.zeros_like(rad_values)
187
+ # cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
188
+
189
+ phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
190
+ phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale,
191
+ scale_factor=self.upsample_scale, mode="linear").transpose(1, 2)
192
+ sines = torch.sin(phase)
193
+
194
+ else:
195
+ # If necessary, make sure that the first time step of every
196
+ # voiced segments is sin(pi) or cos(0)
197
+ # This is used for pulse-train generation
198
+
199
+ # identify the last time step in unvoiced segments
200
+ uv = self._f02uv(f0_values)
201
+ uv_1 = torch.roll(uv, shifts=-1, dims=1)
202
+ uv_1[:, -1, :] = 1
203
+ u_loc = (uv < 1) * (uv_1 > 0)
204
+
205
+ # get the instantaneous phase
206
+ tmp_cumsum = torch.cumsum(rad_values, dim=1)
207
+ # different batch needs to be processed differently
208
+ for idx in range(f0_values.shape[0]):
209
+ temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
210
+ temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
211
+ # stores the accumulation of i.phase within
212
+ # each voiced segments
213
+ tmp_cumsum[idx, :, :] = 0
214
+ tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
215
+
216
+ # rad_values - tmp_cumsum: remove the accumulation of i.phase
217
+ # within the previous voiced segment.
218
+ i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
219
+
220
+ # get the sines
221
+ sines = torch.cos(i_phase * 2 * np.pi)
222
+ return sines
223
+
224
+ def forward(self, f0):
225
+ """ sine_tensor, uv = forward(f0)
226
+ input F0: tensor(batchsize=1, length, dim=1)
227
+ f0 for unvoiced steps should be 0
228
+ output sine_tensor: tensor(batchsize=1, length, dim)
229
+ output uv: tensor(batchsize=1, length, 1)
230
+ """
231
+ f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim,
232
+ device=f0.device)
233
+ # fundamental component
234
+ fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))
235
+
236
+ # generate sine waveforms
237
+ sine_waves = self._f02sine(fn) * self.sine_amp
238
+
239
+ # generate uv signal
240
+ # uv = torch.ones(f0.shape)
241
+ # uv = uv * (f0 > self.voiced_threshold)
242
+ uv = self._f02uv(f0)
243
+
244
+ # noise: for unvoiced should be similar to sine_amp
245
+ # std = self.sine_amp/3 -> max value ~ self.sine_amp
246
+ # . for voiced regions is self.noise_std
247
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
248
+ noise = noise_amp * torch.randn_like(sine_waves)
249
+
250
+ # first: set the unvoiced part to 0 by uv
251
+ # then: additive noise
252
+ sine_waves = sine_waves * uv + noise
253
+ return sine_waves, uv, noise
254
+
255
+
256
+ class SourceModuleHnNSF(torch.nn.Module):
257
+ """ SourceModule for hn-nsf
258
+ SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
259
+ add_noise_std=0.003, voiced_threshod=0)
260
+ sampling_rate: sampling_rate in Hz
261
+ harmonic_num: number of harmonic above F0 (default: 0)
262
+ sine_amp: amplitude of sine source signal (default: 0.1)
263
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
264
+ note that amplitude of noise in unvoiced is decided
265
+ by sine_amp
266
+ voiced_threshold: threshold to set U/V given F0 (default: 0)
267
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
268
+ F0_sampled (batchsize, length, 1)
269
+ Sine_source (batchsize, length, 1)
270
+ noise_source (batchsize, length 1)
271
+ uv (batchsize, length, 1)
272
+ """
273
+
274
+ def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
275
+ add_noise_std=0.003, voiced_threshod=0):
276
+ super(SourceModuleHnNSF, self).__init__()
277
+
278
+ self.sine_amp = sine_amp
279
+ self.noise_std = add_noise_std
280
+
281
+ # to produce sine waveforms
282
+ self.l_sin_gen = SineGen(sampling_rate, upsample_scale, harmonic_num,
283
+ sine_amp, add_noise_std, voiced_threshod)
284
+
285
+ # to merge source harmonics into a single excitation
286
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
287
+ self.l_tanh = torch.nn.Tanh()
288
+
289
+ def forward(self, x):
290
+ """
291
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
292
+ F0_sampled (batchsize, length, 1)
293
+ Sine_source (batchsize, length, 1)
294
+ noise_source (batchsize, length 1)
295
+ """
296
+ # source for harmonic branch
297
+ with torch.no_grad():
298
+ sine_wavs, uv, _ = self.l_sin_gen(x)
299
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
300
+
301
+ # source for noise branch, in the same shape as uv
302
+ noise = torch.randn_like(uv) * self.sine_amp / 3
303
+ return sine_merge, noise, uv
304
+ def padDiff(x):
305
+ return F.pad(F.pad(x, (0,0,-1,1), 'constant', 0) - x, (0,0,0,-1), 'constant', 0)
306
+
307
+
308
+ class Generator(torch.nn.Module):
309
+ def __init__(self, style_dim, resblock_kernel_sizes, upsample_rates, upsample_initial_channel, resblock_dilation_sizes, upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size):
310
+ super(Generator, self).__init__()
311
+
312
+ self.num_kernels = len(resblock_kernel_sizes)
313
+ self.num_upsamples = len(upsample_rates)
314
+ resblock = AdaINResBlock1
315
+
316
+ self.m_source = SourceModuleHnNSF(
317
+ sampling_rate=24000,
318
+ upsample_scale=np.prod(upsample_rates) * gen_istft_hop_size,
319
+ harmonic_num=8, voiced_threshod=10)
320
+ self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * gen_istft_hop_size)
321
+ self.noise_convs = nn.ModuleList()
322
+ self.noise_res = nn.ModuleList()
323
+
324
+ self.ups = nn.ModuleList()
325
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
326
+ self.ups.append(weight_norm(
327
+ ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)),
328
+ k, u, padding=(k-u)//2)))
329
+
330
+ self.resblocks = nn.ModuleList()
331
+ for i in range(len(self.ups)):
332
+ ch = upsample_initial_channel//(2**(i+1))
333
+ for j, (k, d) in enumerate(zip(resblock_kernel_sizes,resblock_dilation_sizes)):
334
+ self.resblocks.append(resblock(ch, k, d, style_dim))
335
+
336
+ c_cur = upsample_initial_channel // (2 ** (i + 1))
337
+
338
+ if i + 1 < len(upsample_rates): #
339
+ stride_f0 = np.prod(upsample_rates[i + 1:])
340
+ self.noise_convs.append(Conv1d(
341
+ gen_istft_n_fft + 2, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0+1) // 2))
342
+ self.noise_res.append(resblock(c_cur, 7, [1,3,5], style_dim))
343
+ else:
344
+ self.noise_convs.append(Conv1d(gen_istft_n_fft + 2, c_cur, kernel_size=1))
345
+ self.noise_res.append(resblock(c_cur, 11, [1,3,5], style_dim))
346
+
347
+
348
+ self.post_n_fft = gen_istft_n_fft
349
+ self.conv_post = weight_norm(Conv1d(ch, self.post_n_fft + 2, 7, 1, padding=3))
350
+ self.ups.apply(init_weights)
351
+ self.conv_post.apply(init_weights)
352
+ self.reflection_pad = torch.nn.ReflectionPad1d((1, 0))
353
+ self.stft = TorchSTFT(filter_length=gen_istft_n_fft, hop_length=gen_istft_hop_size, win_length=gen_istft_n_fft)
354
+
355
+
356
+ def forward(self, x, s, f0):
357
+ with torch.no_grad():
358
+ f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
359
+
360
+ har_source, noi_source, uv = self.m_source(f0)
361
+ har_source = har_source.transpose(1, 2).squeeze(1)
362
+ har_spec, har_phase = self.stft.transform(har_source)
363
+ har = torch.cat([har_spec, har_phase], dim=1)
364
+
365
+ for i in range(self.num_upsamples):
366
+ x = F.leaky_relu(x, LRELU_SLOPE)
367
+ x_source = self.noise_convs[i](har)
368
+ x_source = self.noise_res[i](x_source, s)
369
+
370
+ x = self.ups[i](x)
371
+ if i == self.num_upsamples - 1:
372
+ x = self.reflection_pad(x)
373
+
374
+ x = x + x_source
375
+ xs = None
376
+ for j in range(self.num_kernels):
377
+ if xs is None:
378
+ xs = self.resblocks[i*self.num_kernels+j](x, s)
379
+ else:
380
+ xs += self.resblocks[i*self.num_kernels+j](x, s)
381
+ x = xs / self.num_kernels
382
+ x = F.leaky_relu(x)
383
+ x = self.conv_post(x)
384
+ spec = torch.exp(x[:,:self.post_n_fft // 2 + 1, :])
385
+ phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :])
386
+ return self.stft.inverse(spec, phase)
387
+
388
+ def fw_phase(self, x, s):
389
+ for i in range(self.num_upsamples):
390
+ x = F.leaky_relu(x, LRELU_SLOPE)
391
+ x = self.ups[i](x)
392
+ xs = None
393
+ for j in range(self.num_kernels):
394
+ if xs is None:
395
+ xs = self.resblocks[i*self.num_kernels+j](x, s)
396
+ else:
397
+ xs += self.resblocks[i*self.num_kernels+j](x, s)
398
+ x = xs / self.num_kernels
399
+ x = F.leaky_relu(x)
400
+ x = self.reflection_pad(x)
401
+ x = self.conv_post(x)
402
+ spec = torch.exp(x[:,:self.post_n_fft // 2 + 1, :])
403
+ phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :])
404
+ return spec, phase
405
+
406
+ def remove_weight_norm(self):
407
+ print('Removing weight norm...')
408
+ for l in self.ups:
409
+ remove_weight_norm(l)
410
+ for l in self.resblocks:
411
+ l.remove_weight_norm()
412
+ remove_weight_norm(self.conv_pre)
413
+ remove_weight_norm(self.conv_post)
414
+
415
+
416
+ class AdainResBlk1d(nn.Module):
417
+ def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2),
418
+ upsample='none', dropout_p=0.0):
419
+ super().__init__()
420
+ self.actv = actv
421
+ self.upsample_type = upsample
422
+ self.upsample = UpSample1d(upsample)
423
+ self.learned_sc = dim_in != dim_out
424
+ self._build_weights(dim_in, dim_out, style_dim)
425
+ self.dropout = nn.Dropout(dropout_p)
426
+
427
+ if upsample == 'none':
428
+ self.pool = nn.Identity()
429
+ else:
430
+ self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))
431
+
432
+
433
+ def _build_weights(self, dim_in, dim_out, style_dim):
434
+ self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
435
+ self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
436
+ self.norm1 = AdaIN1d(style_dim, dim_in)
437
+ self.norm2 = AdaIN1d(style_dim, dim_out)
438
+ if self.learned_sc:
439
+ self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
440
+
441
+ def _shortcut(self, x):
442
+ x = self.upsample(x)
443
+ if self.learned_sc:
444
+ x = self.conv1x1(x)
445
+ return x
446
+
447
+ def _residual(self, x, s):
448
+ x = self.norm1(x, s)
449
+ x = self.actv(x)
450
+ x = self.pool(x)
451
+ x = self.conv1(self.dropout(x))
452
+ x = self.norm2(x, s)
453
+ x = self.actv(x)
454
+ x = self.conv2(self.dropout(x))
455
+ return x
456
+
457
+ def forward(self, x, s):
458
+ out = self._residual(x, s)
459
+ out = (out + self._shortcut(x)) / np.sqrt(2)
460
+ return out
461
+
462
+ class UpSample1d(nn.Module):
463
+ def __init__(self, layer_type):
464
+ super().__init__()
465
+ self.layer_type = layer_type
466
+
467
+ def forward(self, x):
468
+ if self.layer_type == 'none':
469
+ return x
470
+ else:
471
+ return F.interpolate(x, scale_factor=2, mode='nearest')
472
+
473
+ class Decoder(nn.Module):
474
+ def __init__(self, dim_in=512, F0_channel=512, style_dim=64, dim_out=80,
475
+ resblock_kernel_sizes = [3,7,11],
476
+ upsample_rates = [10, 6],
477
+ upsample_initial_channel=512,
478
+ resblock_dilation_sizes=[[1,3,5], [1,3,5], [1,3,5]],
479
+ upsample_kernel_sizes=[20, 12],
480
+ gen_istft_n_fft=20, gen_istft_hop_size=5):
481
+ super().__init__()
482
+
483
+ self.decode = nn.ModuleList()
484
+
485
+ self.encode = AdainResBlk1d(dim_in + 2, 1024, style_dim)
486
+
487
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
488
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
489
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
490
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 512, style_dim, upsample=True))
491
+
492
+ self.F0_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
493
+
494
+ self.N_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
495
+
496
+ self.asr_res = nn.Sequential(
497
+ weight_norm(nn.Conv1d(512, 64, kernel_size=1)),
498
+ )
499
+
500
+
501
+ self.generator = Generator(style_dim, resblock_kernel_sizes, upsample_rates,
502
+ upsample_initial_channel, resblock_dilation_sizes,
503
+ upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size)
504
+
505
+ def forward(self, asr, F0_curve, N, s):
506
+ F0 = self.F0_conv(F0_curve.unsqueeze(1))
507
+ N = self.N_conv(N.unsqueeze(1))
508
+
509
+ x = torch.cat([asr, F0, N], axis=1)
510
+ x = self.encode(x, s)
511
+
512
+ asr_res = self.asr_res(asr)
513
+
514
+ res = True
515
+ for block in self.decode:
516
+ if res:
517
+ x = torch.cat([x, asr_res, F0, N], axis=1)
518
+ x = block(x, s)
519
+ if block.upsample_type != "none":
520
+ res = False
521
+
522
+ x = self.generator(x, s, F0_curve)
523
+ return x
kokoro.py ADDED
@@ -0,0 +1,187 @@
1
+ import phonemizer
2
+ import re
3
+ import torch
4
+ from espeak_util import set_espeak_library
5
+
6
+ set_espeak_library()
7
+
8
+
9
+ def split_num(num):
10
+ num = num.group()
11
+ if "." in num:
12
+ return num
13
+ elif ":" in num:
14
+ h, m = [int(n) for n in num.split(":")]
15
+ if m == 0:
16
+ return f"{h} o'clock"
17
+ elif m < 10:
18
+ return f"{h} oh {m}"
19
+ return f"{h} {m}"
20
+ year = int(num[:4])
21
+ if year < 1100 or year % 1000 < 10:
22
+ return num
23
+ left, right = num[:2], int(num[2:4])
24
+ s = "s" if num.endswith("s") else ""
25
+ if 100 <= year % 1000 <= 999:
26
+ if right == 0:
27
+ return f"{left} hundred{s}"
28
+ elif right < 10:
29
+ return f"{left} oh {right}{s}"
30
+ return f"{left} {right}{s}"
31
+
32
+
33
+ def flip_money(m):
34
+ m = m.group()
35
+ bill = "dollar" if m[0] == "$" else "pound"
36
+ if m[-1].isalpha():
37
+ return f"{m[1:]} {bill}s"
38
+ elif "." not in m:
39
+ s = "" if m[1:] == "1" else "s"
40
+ return f"{m[1:]} {bill}{s}"
41
+ b, c = m[1:].split(".")
42
+ s = "" if b == "1" else "s"
43
+ c = int(c.ljust(2, "0"))
44
+ coins = (
45
+ f"cent{'' if c == 1 else 's'}"
46
+ if m[0] == "$"
47
+ else ("penny" if c == 1 else "pence")
48
+ )
49
+ return f"{b} {bill}{s} and {c} {coins}"
50
+
51
+
52
+ def point_num(num):
53
+ a, b = num.group().split(".")
54
+ return " point ".join([a, " ".join(b)])
55
+
56
+
57
+ def normalize_text(text):
58
+ text = text.replace(chr(8216), "'").replace(chr(8217), "'")
59
+ text = text.replace("«", chr(8220)).replace("»", chr(8221))
60
+ text = text.replace(chr(8220), '"').replace(chr(8221), '"')
61
+ text = text.replace("(", "«").replace(")", "»")
62
+ for a, b in zip("、。!,:;?", ",.!,:;?"):
63
+ text = text.replace(a, b + " ")
64
+ text = re.sub(r"[^\S \n]", " ", text)
65
+ text = re.sub(r" +", " ", text)
66
+ text = re.sub(r"(?<=\n) +(?=\n)", "", text)
67
+ text = re.sub(r"\bD[Rr]\.(?= [A-Z])", "Doctor", text)
68
+ text = re.sub(r"\b(?:Mr\.|MR\.(?= [A-Z]))", "Mister", text)
69
+ text = re.sub(r"\b(?:Ms\.|MS\.(?= [A-Z]))", "Miss", text)
70
+ text = re.sub(r"\b(?:Mrs\.|MRS\.(?= [A-Z]))", "Mrs", text)
71
+ text = re.sub(r"\betc\.(?! [A-Z])", "etc", text)
72
+ text = re.sub(r"(?i)\b(y)eah?\b", r"\1e'a", text)
73
+ text = re.sub(
74
+ r"\d*\.\d+|\b\d{4}s?\b|(?<!:)\b(?:[1-9]|1[0-2]):[0-5]\d\b(?!:)", split_num, text
75
+ )
76
+ text = re.sub(r"(?<=\d),(?=\d)", "", text)
77
+ text = re.sub(
78
+ r"(?i)[$£]\d+(?:\.\d+)?(?: hundred| thousand| (?:[bm]|tr)illion)*\b|[$£]\d+\.\d\d?\b",
79
+ flip_money,
80
+ text,
81
+ )
82
+ text = re.sub(r"\d*\.\d+", point_num, text)
83
+ text = re.sub(r"(?<=\d)-(?=\d)", " to ", text)
84
+ text = re.sub(r"(?<=\d)S", " S", text)
85
+ text = re.sub(r"(?<=[BCDFGHJ-NP-TV-Z])'?s\b", "'S", text)
86
+ text = re.sub(r"(?<=X')S\b", "s", text)
87
+ text = re.sub(
88
+ r"(?:[A-Za-z]\.){2,} [a-z]", lambda m: m.group().replace(".", "-"), text
89
+ )
90
+ text = re.sub(r"(?i)(?<=[A-Z])\.(?=[A-Z])", "-", text)
91
+ return text.strip()
92
+
93
+
94
+ def get_vocab():
95
+ _pad = "$"
96
+ _punctuation = ';:,.!?¡¿—…"«»“” '
97
+ _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
98
+ _letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
99
+ symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
100
+ dicts = {}
101
+ for i in range(len((symbols))):
102
+ dicts[symbols[i]] = i
103
+ return dicts
104
+
105
+
106
+ VOCAB = get_vocab()
107
+
108
+
109
+ def tokenize(ps):
110
+ return [i for i in map(VOCAB.get, ps) if i is not None]
111
+
112
+
113
+ phonemizers = dict(
114
+ a=phonemizer.backend.EspeakBackend(
115
+ language="en-us", preserve_punctuation=True, with_stress=True
116
+ ),
117
+ b=phonemizer.backend.EspeakBackend(
118
+ language="en-gb", preserve_punctuation=True, with_stress=True
119
+ ),
120
+ )
121
+
122
+
123
+ def phonemize(text, lang, norm=True):
124
+ if norm:
125
+ text = normalize_text(text)
126
+ ps = phonemizers[lang].phonemize([text])
127
+ ps = ps[0] if ps else ""
128
+ # https://en.wiktionary.org/wiki/kokoro#English
129
+ ps = ps.replace("kəkˈoːɹoʊ", "kˈoʊkəɹoʊ").replace("kəkˈɔːɹəʊ", "kˈəʊkəɹəʊ")
130
+ ps = ps.replace("ʲ", "j").replace("r", "ɹ").replace("x", "k").replace("ɬ", "l")
131
+ ps = re.sub(r"(?<=[a-zɹː])(?=hˈʌndɹɪd)", " ", ps)
132
+ ps = re.sub(r' z(?=[;:,.!?¡¿—…"«»“” ]|$)', "z", ps)
133
+ if lang == "a":
134
+ ps = re.sub(r"(?<=nˈaɪn)ti(?!ː)", "di", ps)
135
+ ps = "".join(filter(lambda p: p in VOCAB, ps))
136
+ return ps.strip()
137
+
138
+
139
+ def length_to_mask(lengths):
140
+ mask = (
141
+ torch.arange(lengths.max())
142
+ .unsqueeze(0)
143
+ .expand(lengths.shape[0], -1)
144
+ .type_as(lengths)
145
+ )
146
+ mask = torch.gt(mask + 1, lengths.unsqueeze(1))
147
+ return mask
148
+
149
+
150
+ @torch.no_grad()
151
+ def forward(model, tokens, ref_s, speed):
152
+ device = ref_s.device
153
+ tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
154
+ input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
155
+ text_mask = length_to_mask(input_lengths).to(device)
156
+ bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
157
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
158
+ s = ref_s[:, 128:]
159
+ d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
160
+ x, _ = model.predictor.lstm(d)
161
+ duration = model.predictor.duration_proj(x)
162
+ duration = torch.sigmoid(duration).sum(axis=-1) / speed
163
+ pred_dur = torch.round(duration).clamp(min=1).long()
164
+ pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item())
165
+ c_frame = 0
166
+ for i in range(pred_aln_trg.size(0)):
167
+ pred_aln_trg[i, c_frame : c_frame + pred_dur[0, i].item()] = 1
168
+ c_frame += pred_dur[0, i].item()
169
+ en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device)
170
+ F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
171
+ t_en = model.text_encoder(tokens, input_lengths, text_mask)
172
+ asr = t_en @ pred_aln_trg.unsqueeze(0).to(device)
173
+ return model.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu().numpy()
174
+
175
+
176
+ def generate(model, text, voicepack, lang="a", speed=1):
177
+ ps = phonemize(text, lang)
178
+ tokens = tokenize(ps)
179
+ if not tokens:
180
+ return None
181
+ elif len(tokens) > 510:
182
+ tokens = tokens[:510]
183
+ print("Truncated to 510 tokens")
184
+ ref_s = voicepack[len(tokens)]
185
+ out = forward(model, tokens, ref_s, speed)
186
+ ps = "".join(next(k for k, v in VOCAB.items() if i == v) for i in tokens)
187
+ return out, ps
models.py ADDED
@@ -0,0 +1,738 @@
1
+ # https://github.com/yl4579/StyleTTS2/blob/main/models.py
2
+ from istftnet import Decoder
3
+ from munch import Munch
4
+ from pathlib import Path
5
+ from plbert import load_plbert
6
+ from torch.nn.utils import weight_norm, spectral_norm
7
+ import json
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+
14
+ class LearnedDownSample(nn.Module):
15
+ def __init__(self, layer_type, dim_in):
16
+ super().__init__()
17
+ self.layer_type = layer_type
18
+
19
+ if self.layer_type == "none":
20
+ self.conv = nn.Identity()
21
+ elif self.layer_type == "timepreserve":
22
+ self.conv = spectral_norm(
23
+ nn.Conv2d(
24
+ dim_in,
25
+ dim_in,
26
+ kernel_size=(3, 1),
27
+ stride=(2, 1),
28
+ groups=dim_in,
29
+ padding=(1, 0),
30
+ )
31
+ )
32
+ elif self.layer_type == "half":
33
+ self.conv = spectral_norm(
34
+ nn.Conv2d(
35
+ dim_in,
36
+ dim_in,
37
+ kernel_size=(3, 3),
38
+ stride=(2, 2),
39
+ groups=dim_in,
40
+ padding=1,
41
+ )
42
+ )
43
+ else:
44
+ raise RuntimeError(
45
+ "Got unexpected donwsampletype %s, expected is [none, timepreserve, half]"
46
+ % self.layer_type
47
+ )
48
+
49
+ def forward(self, x):
50
+ return self.conv(x)
51
+
52
+
53
+ class LearnedUpSample(nn.Module):
54
+ def __init__(self, layer_type, dim_in):
55
+ super().__init__()
56
+ self.layer_type = layer_type
57
+
58
+ if self.layer_type == "none":
59
+ self.conv = nn.Identity()
60
+ elif self.layer_type == "timepreserve":
61
+ self.conv = nn.ConvTranspose2d(
62
+ dim_in,
63
+ dim_in,
64
+ kernel_size=(3, 1),
65
+ stride=(2, 1),
66
+ groups=dim_in,
67
+ output_padding=(1, 0),
68
+ padding=(1, 0),
69
+ )
70
+ elif self.layer_type == "half":
71
+ self.conv = nn.ConvTranspose2d(
72
+ dim_in,
73
+ dim_in,
74
+ kernel_size=(3, 3),
75
+ stride=(2, 2),
76
+ groups=dim_in,
77
+ output_padding=1,
78
+ padding=1,
79
+ )
80
+ else:
81
+ raise RuntimeError(
82
+ "Got unexpected upsampletype %s, expected is [none, timepreserve, half]"
83
+ % self.layer_type
84
+ )
85
+
86
+ def forward(self, x):
87
+ return self.conv(x)
88
+
89
+
90
+ class DownSample(nn.Module):
91
+ def __init__(self, layer_type):
92
+ super().__init__()
93
+ self.layer_type = layer_type
94
+
95
+ def forward(self, x):
96
+ if self.layer_type == "none":
97
+ return x
98
+ elif self.layer_type == "timepreserve":
99
+ return F.avg_pool2d(x, (2, 1))
100
+ elif self.layer_type == "half":
101
+ if x.shape[-1] % 2 != 0:
102
+ x = torch.cat([x, x[..., -1].unsqueeze(-1)], dim=-1)
103
+ return F.avg_pool2d(x, 2)
104
+ else:
105
+ raise RuntimeError(
106
+ "Got unexpected donwsampletype %s, expected is [none, timepreserve, half]"
107
+ % self.layer_type
108
+ )
109
+
110
+
111
+ class UpSample(nn.Module):
112
+ def __init__(self, layer_type):
113
+ super().__init__()
114
+ self.layer_type = layer_type
115
+
116
+ def forward(self, x):
117
+ if self.layer_type == "none":
118
+ return x
119
+ elif self.layer_type == "timepreserve":
120
+ return F.interpolate(x, scale_factor=(2, 1), mode="nearest")
121
+ elif self.layer_type == "half":
122
+ return F.interpolate(x, scale_factor=2, mode="nearest")
123
+ else:
124
+ raise RuntimeError(
125
+ "Got unexpected upsampletype %s, expected is [none, timepreserve, half]"
126
+ % self.layer_type
127
+ )
128
+
129
+
130
+ class ResBlk(nn.Module):
131
+ def __init__(
132
+ self,
133
+ dim_in,
134
+ dim_out,
135
+ actv=nn.LeakyReLU(0.2),
136
+ normalize=False,
137
+ downsample="none",
138
+ ):
139
+ super().__init__()
140
+ self.actv = actv
141
+ self.normalize = normalize
142
+ self.downsample = DownSample(downsample)
143
+ self.downsample_res = LearnedDownSample(downsample, dim_in)
144
+ self.learned_sc = dim_in != dim_out
145
+ self._build_weights(dim_in, dim_out)
146
+
147
+ def _build_weights(self, dim_in, dim_out):
148
+ self.conv1 = spectral_norm(nn.Conv2d(dim_in, dim_in, 3, 1, 1))
149
+ self.conv2 = spectral_norm(nn.Conv2d(dim_in, dim_out, 3, 1, 1))
150
+ if self.normalize:
151
+ self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
152
+ self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
153
+ if self.learned_sc:
154
+ self.conv1x1 = spectral_norm(
155
+ nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)
156
+ )
157
+
158
+ def _shortcut(self, x):
159
+ if self.learned_sc:
160
+ x = self.conv1x1(x)
161
+ if self.downsample:
162
+ x = self.downsample(x)
163
+ return x
164
+
165
+ def _residual(self, x):
166
+ if self.normalize:
167
+ x = self.norm1(x)
168
+ x = self.actv(x)
169
+ x = self.conv1(x)
170
+ x = self.downsample_res(x)
171
+ if self.normalize:
172
+ x = self.norm2(x)
173
+ x = self.actv(x)
174
+ x = self.conv2(x)
175
+ return x
176
+
177
+ def forward(self, x):
178
+ x = self._shortcut(x) + self._residual(x)
179
+ return x / np.sqrt(2) # unit variance
180
+
181
+
182
+ class LinearNorm(torch.nn.Module):
183
+ def __init__(self, in_dim, out_dim, bias=True, w_init_gain="linear"):
184
+ super(LinearNorm, self).__init__()
185
+ self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
186
+
187
+ torch.nn.init.xavier_uniform_(
188
+ self.linear_layer.weight, gain=torch.nn.init.calculate_gain(w_init_gain)
189
+ )
190
+
191
+ def forward(self, x):
192
+ return self.linear_layer(x)
193
+
194
+
195
+ class Discriminator2d(nn.Module):
196
+ def __init__(self, dim_in=48, num_domains=1, max_conv_dim=384, repeat_num=4):
197
+ super().__init__()
198
+ blocks = []
199
+ blocks += [spectral_norm(nn.Conv2d(1, dim_in, 3, 1, 1))]
200
+
201
+ for lid in range(repeat_num):
202
+ dim_out = min(dim_in * 2, max_conv_dim)
203
+ blocks += [ResBlk(dim_in, dim_out, downsample="half")]
204
+ dim_in = dim_out
205
+
206
+ blocks += [nn.LeakyReLU(0.2)]
207
+ blocks += [spectral_norm(nn.Conv2d(dim_out, dim_out, 5, 1, 0))]
208
+ blocks += [nn.LeakyReLU(0.2)]
209
+ blocks += [nn.AdaptiveAvgPool2d(1)]
210
+ blocks += [spectral_norm(nn.Conv2d(dim_out, num_domains, 1, 1, 0))]
211
+ self.main = nn.Sequential(*blocks)
212
+
213
+ def get_feature(self, x):
214
+ features = []
215
+ for l in self.main:
216
+ x = l(x)
217
+ features.append(x)
218
+ out = features[-1]
219
+ out = out.view(out.size(0), -1) # (batch, num_domains)
220
+ return out, features
221
+
222
+ def forward(self, x):
223
+ out, features = self.get_feature(x)
224
+ out = out.squeeze() # (batch)
225
+ return out, features
226
+
227
+
228
+ class ResBlk1d(nn.Module):
229
+ def __init__(
230
+ self,
231
+ dim_in,
232
+ dim_out,
233
+ actv=nn.LeakyReLU(0.2),
234
+ normalize=False,
235
+ downsample="none",
236
+ dropout_p=0.2,
237
+ ):
238
+ super().__init__()
239
+ self.actv = actv
240
+ self.normalize = normalize
241
+ self.downsample_type = downsample
242
+ self.learned_sc = dim_in != dim_out
243
+ self._build_weights(dim_in, dim_out)
244
+ self.dropout_p = dropout_p
245
+
246
+ if self.downsample_type == "none":
247
+ self.pool = nn.Identity()
248
+ else:
249
+ self.pool = weight_norm(
250
+ nn.Conv1d(
251
+ dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1
252
+ )
253
+ )
254
+
255
+ def _build_weights(self, dim_in, dim_out):
256
+ self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_in, 3, 1, 1))
257
+ self.conv2 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
258
+ if self.normalize:
259
+ self.norm1 = nn.InstanceNorm1d(dim_in, affine=True)
260
+ self.norm2 = nn.InstanceNorm1d(dim_in, affine=True)
261
+ if self.learned_sc:
262
+ self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
263
+
264
+ def downsample(self, x):
265
+ if self.downsample_type == "none":
266
+ return x
267
+ else:
268
+ if x.shape[-1] % 2 != 0:
269
+ x = torch.cat([x, x[..., -1].unsqueeze(-1)], dim=-1)
270
+ return F.avg_pool1d(x, 2)
271
+
272
+ def _shortcut(self, x):
273
+ if self.learned_sc:
274
+ x = self.conv1x1(x)
275
+ x = self.downsample(x)
276
+ return x
277
+
278
+ def _residual(self, x):
279
+ if self.normalize:
280
+ x = self.norm1(x)
281
+ x = self.actv(x)
282
+ x = F.dropout(x, p=self.dropout_p, training=self.training)
283
+
284
+ x = self.conv1(x)
285
+ x = self.pool(x)
286
+ if self.normalize:
287
+ x = self.norm2(x)
288
+
289
+ x = self.actv(x)
290
+ x = F.dropout(x, p=self.dropout_p, training=self.training)
291
+
292
+ x = self.conv2(x)
293
+ return x
294
+
295
+ def forward(self, x):
296
+ x = self._shortcut(x) + self._residual(x)
297
+ return x / np.sqrt(2) # unit variance
298
+
299
+
300
+ class LayerNorm(nn.Module):
301
+ def __init__(self, channels, eps=1e-5):
302
+ super().__init__()
303
+ self.channels = channels
304
+ self.eps = eps
305
+
306
+ self.gamma = nn.Parameter(torch.ones(channels))
307
+ self.beta = nn.Parameter(torch.zeros(channels))
308
+
309
+ def forward(self, x):
310
+ x = x.transpose(1, -1)
311
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
312
+ return x.transpose(1, -1)
313
+
314
+
315
+ class TextEncoder(nn.Module):
316
+ def __init__(self, channels, kernel_size, depth, n_symbols, actv=nn.LeakyReLU(0.2)):
317
+ super().__init__()
318
+ self.embedding = nn.Embedding(n_symbols, channels)
319
+
320
+ padding = (kernel_size - 1) // 2
321
+ self.cnn = nn.ModuleList()
322
+ for _ in range(depth):
323
+ self.cnn.append(
324
+ nn.Sequential(
325
+ weight_norm(
326
+ nn.Conv1d(
327
+ channels, channels, kernel_size=kernel_size, padding=padding
328
+ )
329
+ ),
330
+ LayerNorm(channels),
331
+ actv,
332
+ nn.Dropout(0.2),
333
+ )
334
+ )
335
+ # self.cnn = nn.Sequential(*self.cnn)
336
+
337
+ self.lstm = nn.LSTM(
338
+ channels, channels // 2, 1, batch_first=True, bidirectional=True
339
+ )
340
+
341
+ def forward(self, x, input_lengths, m):
342
+ x = self.embedding(x) # [B, T, emb]
343
+ x = x.transpose(1, 2) # [B, emb, T]
344
+ m = m.to(input_lengths.device).unsqueeze(1)
345
+ x.masked_fill_(m, 0.0)
346
+
347
+ for c in self.cnn:
348
+ x = c(x)
349
+ x.masked_fill_(m, 0.0)
350
+
351
+ x = x.transpose(1, 2) # [B, T, chn]
352
+
353
+ input_lengths = input_lengths.cpu().numpy()
354
+ x = nn.utils.rnn.pack_padded_sequence(
355
+ x, input_lengths, batch_first=True, enforce_sorted=False
356
+ )
357
+
358
+ self.lstm.flatten_parameters()
359
+ x, _ = self.lstm(x)
360
+ x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True)
361
+
362
+ x = x.transpose(-1, -2)
363
+ x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])
364
+
365
+ x_pad[:, :, : x.shape[-1]] = x
366
+ x = x_pad.to(x.device)
367
+
368
+ x.masked_fill_(m, 0.0)
369
+
370
+ return x
371
+
372
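+ # NOTE: this inference() helper appears unused in this Space; as written it would
+ # fail, since self.cnn is an nn.ModuleList (the nn.Sequential wrapper above is commented out).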
+ def inference(self, x):
373
+ x = self.embedding(x)
374
+ x = x.transpose(1, 2)
375
+ x = self.cnn(x)
376
+ x = x.transpose(1, 2)
377
+ self.lstm.flatten_parameters()
378
+ x, _ = self.lstm(x)
379
+ return x
380
+
381
+ def length_to_mask(self, lengths):
382
+ mask = (
383
+ torch.arange(lengths.max())
384
+ .unsqueeze(0)
385
+ .expand(lengths.shape[0], -1)
386
+ .type_as(lengths)
387
+ )
388
+ mask = torch.gt(mask + 1, lengths.unsqueeze(1))
389
+ return mask
390
+
391
+
392
+ class AdaIN1d(nn.Module):
393
+ def __init__(self, style_dim, num_features):
394
+ super().__init__()
395
+ self.norm = nn.InstanceNorm1d(num_features, affine=False)
396
+ self.fc = nn.Linear(style_dim, num_features * 2)
397
+
398
+ def forward(self, x, s):
399
+ h = self.fc(s)
400
+ h = h.view(h.size(0), h.size(1), 1)
401
+ gamma, beta = torch.chunk(h, chunks=2, dim=1)
402
+ return (1 + gamma) * self.norm(x) + beta
403
+
404
+
405
+ class UpSample1d(nn.Module):
406
+ def __init__(self, layer_type):
407
+ super().__init__()
408
+ self.layer_type = layer_type
409
+
410
+ def forward(self, x):
411
+ if self.layer_type == "none":
412
+ return x
413
+ else:
414
+ return F.interpolate(x, scale_factor=2, mode="nearest")
415
+
416
+
417
+ class AdainResBlk1d(nn.Module):
418
+ def __init__(
419
+ self,
420
+ dim_in,
421
+ dim_out,
422
+ style_dim=64,
423
+ actv=nn.LeakyReLU(0.2),
424
+ upsample="none",
425
+ dropout_p=0.0,
426
+ ):
427
+ super().__init__()
428
+ self.actv = actv
429
+ self.upsample_type = upsample
430
+ self.upsample = UpSample1d(upsample)
431
+ self.learned_sc = dim_in != dim_out
432
+ self._build_weights(dim_in, dim_out, style_dim)
433
+ self.dropout = nn.Dropout(dropout_p)
434
+
435
+ if upsample == "none":
436
+ self.pool = nn.Identity()
437
+ else:
438
+ self.pool = weight_norm(
439
+ nn.ConvTranspose1d(
440
+ dim_in,
441
+ dim_in,
442
+ kernel_size=3,
443
+ stride=2,
444
+ groups=dim_in,
445
+ padding=1,
446
+ output_padding=1,
447
+ )
448
+ )
449
+
450
+ def _build_weights(self, dim_in, dim_out, style_dim):
451
+ self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
452
+ self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
453
+ self.norm1 = AdaIN1d(style_dim, dim_in)
454
+ self.norm2 = AdaIN1d(style_dim, dim_out)
455
+ if self.learned_sc:
456
+ self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
457
+
458
+ def _shortcut(self, x):
459
+ x = self.upsample(x)
460
+ if self.learned_sc:
461
+ x = self.conv1x1(x)
462
+ return x
463
+
464
+ def _residual(self, x, s):
465
+ x = self.norm1(x, s)
466
+ x = self.actv(x)
467
+ x = self.pool(x)
468
+ x = self.conv1(self.dropout(x))
469
+ x = self.norm2(x, s)
470
+ x = self.actv(x)
471
+ x = self.conv2(self.dropout(x))
472
+ return x
473
+
474
+ def forward(self, x, s):
475
+ out = self._residual(x, s)
476
+ out = (out + self._shortcut(x)) / np.sqrt(2)
477
+ return out
478
+
479
+
480
+ class AdaLayerNorm(nn.Module):
481
+ def __init__(self, style_dim, channels, eps=1e-5):
482
+ super().__init__()
483
+ self.channels = channels
484
+ self.eps = eps
485
+
486
+ self.fc = nn.Linear(style_dim, channels * 2)
487
+
488
+ def forward(self, x, s):
489
+ x = x.transpose(-1, -2)
490
+ x = x.transpose(1, -1)
491
+
492
+ h = self.fc(s)
493
+ h = h.view(h.size(0), h.size(1), 1)
494
+ gamma, beta = torch.chunk(h, chunks=2, dim=1)
495
+ gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1)
496
+
497
+ x = F.layer_norm(x, (self.channels,), eps=self.eps)
498
+ x = (1 + gamma) * x + beta
499
+ return x.transpose(1, -1).transpose(-1, -2)
500
+
501
+
502
+ class ProsodyPredictor(nn.Module):
503
+ def __init__(self, style_dim, d_hid, nlayers, max_dur=50, dropout=0.1):
504
+ super().__init__()
505
+
506
+ self.text_encoder = DurationEncoder(
507
+ sty_dim=style_dim, d_model=d_hid, nlayers=nlayers, dropout=dropout
508
+ )
509
+
510
+ self.lstm = nn.LSTM(
511
+ d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True
512
+ )
513
+ self.duration_proj = LinearNorm(d_hid, max_dur)
514
+
515
+ self.shared = nn.LSTM(
516
+ d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True
517
+ )
518
+ self.F0 = nn.ModuleList()
519
+ self.F0.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
520
+ self.F0.append(
521
+ AdainResBlk1d(
522
+ d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout
523
+ )
524
+ )
525
+ self.F0.append(
526
+ AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout)
527
+ )
528
+
529
+ self.N = nn.ModuleList()
530
+ self.N.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
531
+ self.N.append(
532
+ AdainResBlk1d(
533
+ d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout
534
+ )
535
+ )
536
+ self.N.append(
537
+ AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout)
538
+ )
539
+
540
+ self.F0_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
541
+ self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
542
+
543
+ def forward(self, texts, style, text_lengths, alignment, m):
544
+ d = self.text_encoder(texts, style, text_lengths, m)
545
+
546
+ batch_size = d.shape[0]
547
+ text_size = d.shape[1]
548
+
549
+ # predict duration
550
+ input_lengths = text_lengths.cpu().numpy()
551
+ x = nn.utils.rnn.pack_padded_sequence(
552
+ d, input_lengths, batch_first=True, enforce_sorted=False
553
+ )
554
+
555
+ m = m.to(text_lengths.device).unsqueeze(1)
556
+
557
+ self.lstm.flatten_parameters()
558
+ x, _ = self.lstm(x)
559
+ x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True)
560
+
561
+ x_pad = torch.zeros([x.shape[0], m.shape[-1], x.shape[-1]])
562
+
563
+ x_pad[:, : x.shape[1], :] = x
564
+ x = x_pad.to(x.device)
565
+
566
+ duration = self.duration_proj(
567
+ nn.functional.dropout(x, 0.5, training=self.training)
568
+ )
569
+
570
+ en = d.transpose(-1, -2) @ alignment
571
+
572
+ return duration.squeeze(-1), en
573
+
574
+ def F0Ntrain(self, x, s):
575
+ x, _ = self.shared(x.transpose(-1, -2))
576
+
577
+ F0 = x.transpose(-1, -2)
578
+ for block in self.F0:
579
+ F0 = block(F0, s)
580
+ F0 = self.F0_proj(F0)
581
+
582
+ N = x.transpose(-1, -2)
583
+ for block in self.N:
584
+ N = block(N, s)
585
+ N = self.N_proj(N)
586
+
587
+ return F0.squeeze(1), N.squeeze(1)
588
+
589
+ def length_to_mask(self, lengths):
590
+ mask = (
591
+ torch.arange(lengths.max())
592
+ .unsqueeze(0)
593
+ .expand(lengths.shape[0], -1)
594
+ .type_as(lengths)
595
+ )
596
+ mask = torch.gt(mask + 1, lengths.unsqueeze(1))
597
+ return mask
598
+
599
+
600
+ class DurationEncoder(nn.Module):
601
+ def __init__(self, sty_dim, d_model, nlayers, dropout=0.1):
602
+ super().__init__()
603
+ self.lstms = nn.ModuleList()
604
+ for _ in range(nlayers):
605
+ self.lstms.append(
606
+ nn.LSTM(
607
+ d_model + sty_dim,
608
+ d_model // 2,
609
+ num_layers=1,
610
+ batch_first=True,
611
+ bidirectional=True,
612
+ dropout=dropout,
613
+ )
614
+ )
615
+ self.lstms.append(AdaLayerNorm(sty_dim, d_model))
616
+
617
+ self.dropout = dropout
618
+ self.d_model = d_model
619
+ self.sty_dim = sty_dim
620
+
621
+ def forward(self, x, style, text_lengths, m):
622
+ masks = m.to(text_lengths.device)
623
+
624
+ x = x.permute(2, 0, 1)
625
+ s = style.expand(x.shape[0], x.shape[1], -1)
626
+ x = torch.cat([x, s], axis=-1)
627
+ x.masked_fill_(masks.unsqueeze(-1).transpose(0, 1), 0.0)
628
+
629
+ x = x.transpose(0, 1)
630
+ input_lengths = text_lengths.cpu().numpy()
631
+ x = x.transpose(-1, -2)
632
+
633
+ for block in self.lstms:
634
+ if isinstance(block, AdaLayerNorm):
635
+ x = block(x.transpose(-1, -2), style).transpose(-1, -2)
636
+ x = torch.cat([x, s.permute(1, -1, 0)], axis=1)
637
+ x.masked_fill_(masks.unsqueeze(-1).transpose(-1, -2), 0.0)
638
+ else:
639
+ x = x.transpose(-1, -2)
640
+ x = nn.utils.rnn.pack_padded_sequence(
641
+ x, input_lengths, batch_first=True, enforce_sorted=False
642
+ )
643
+ block.flatten_parameters()
644
+ x, _ = block(x)
645
+ x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True)
646
+ x = F.dropout(x, p=self.dropout, training=self.training)
647
+ x = x.transpose(-1, -2)
648
+
649
+ x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])
650
+
651
+ x_pad[:, :, : x.shape[-1]] = x
652
+ x = x_pad.to(x.device)
653
+
654
+ return x.transpose(-1, -2)
655
+
656
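+ # NOTE: legacy method carried over from upstream StyleTTS2; it references
+ # self.embedding, self.pos_encoder and self.transformer_encoder, which are not
+ # defined in this class, and does not appear to be called anywhere in this Space.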
+ def inference(self, x, style):
657
+ x = self.embedding(x.transpose(-1, -2)) * np.sqrt(self.d_model)
658
+ style = style.expand(x.shape[0], x.shape[1], -1)
659
+ x = torch.cat([x, style], axis=-1)
660
+ src = self.pos_encoder(x)
661
+ output = self.transformer_encoder(src).transpose(0, 1)
662
+ return output
663
+
664
+ def length_to_mask(self, lengths):
665
+ mask = (
666
+ torch.arange(lengths.max())
667
+ .unsqueeze(0)
668
+ .expand(lengths.shape[0], -1)
669
+ .type_as(lengths)
670
+ )
671
+ mask = torch.gt(mask + 1, lengths.unsqueeze(1))
672
+ return mask
673
+
674
+
675
+ # https://github.com/yl4579/StyleTTS2/blob/main/utils.py
676
+ def recursive_munch(d):
677
+ if isinstance(d, dict):
678
+ return Munch((k, recursive_munch(v)) for k, v in d.items())
679
+ elif isinstance(d, list):
680
+ return [recursive_munch(v) for v in d]
681
+ else:
682
+ return d
683
+
684
+
685
+ def build_model(path, device):
686
+ config = Path(path).parent / "config.json"
687
+ assert config.exists(), f"Config path incorrect: config.json not found at {config}"
688
+ with open(config, "r") as r:
689
+ args = recursive_munch(json.load(r))
690
+ assert args.decoder.type == "istftnet", f"Unknown decoder type: {args.decoder.type}"
691
+ decoder = Decoder(
692
+ dim_in=args.hidden_dim,
693
+ style_dim=args.style_dim,
694
+ dim_out=args.n_mels,
695
+ resblock_kernel_sizes=args.decoder.resblock_kernel_sizes,
696
+ upsample_rates=args.decoder.upsample_rates,
697
+ upsample_initial_channel=args.decoder.upsample_initial_channel,
698
+ resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
699
+ upsample_kernel_sizes=args.decoder.upsample_kernel_sizes,
700
+ gen_istft_n_fft=args.decoder.gen_istft_n_fft,
701
+ gen_istft_hop_size=args.decoder.gen_istft_hop_size,
702
+ )
703
+ text_encoder = TextEncoder(
704
+ channels=args.hidden_dim,
705
+ kernel_size=5,
706
+ depth=args.n_layer,
707
+ n_symbols=args.n_token,
708
+ )
709
+ predictor = ProsodyPredictor(
710
+ style_dim=args.style_dim,
711
+ d_hid=args.hidden_dim,
712
+ nlayers=args.n_layer,
713
+ max_dur=args.max_dur,
714
+ dropout=args.dropout,
715
+ )
716
+ bert = load_plbert()
717
+ bert_encoder = nn.Linear(bert.config.hidden_size, args.hidden_dim)
718
+ for parent in [bert, bert_encoder, predictor, decoder, text_encoder]:
719
+ for child in parent.children():
720
+ if isinstance(child, nn.RNNBase):
721
+ child.flatten_parameters()
722
+ model = Munch(
723
+ bert=bert.to(device).eval(),
724
+ bert_encoder=bert_encoder.to(device).eval(),
725
+ predictor=predictor.to(device).eval(),
726
+ decoder=decoder.to(device).eval(),
727
+ text_encoder=text_encoder.to(device).eval(),
728
+ )
729
+ for key, state_dict in torch.load(path, map_location="cpu", weights_only=True)[
730
+ "net"
731
+ ].items():
732
+ assert key in model, key
733
+ try:
734
+ model[key].load_state_dict(state_dict)
735
+ except:
736
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
737
+ model[key].load_state_dict(state_dict, strict=False)
738
+ return model
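For orientation, a minimal sketch of how `build_model` is meant to be consumed once the checkpoint and `config.json` are in place (not part of the commit; the paths are assumptions matching the defaults used by `downloader.py` and the CLI scripts below):

```py
# Minimal sketch: build_model() returns a Munch of submodules, not a single nn.Module.
# Paths are assumptions; they match the defaults used elsewhere in this repo.
import torch
from models import build_model

device = "cuda" if torch.cuda.is_available() else "cpu"
model = build_model("pretrained_models/Kokoro/kokoro-v0_19.pth", device)

for name, module in model.items():
    n_params = sum(p.numel() for p in module.parameters())
    print(f"{name}: {n_params / 1e6:.1f}M parameters")
```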
packages.txt ADDED
@@ -0,0 +1 @@
1
+ espeak-ng
plbert.py ADDED
@@ -0,0 +1,15 @@
1
+ # https://github.com/yl4579/StyleTTS2/blob/main/Utils/PLBERT/util.py
2
+ from transformers import AlbertConfig, AlbertModel
3
+
4
+ class CustomAlbert(AlbertModel):
5
+ def forward(self, *args, **kwargs):
6
+ # Call the original forward method
7
+ outputs = super().forward(*args, **kwargs)
8
+ # Only return the last_hidden_state
9
+ return outputs.last_hidden_state
10
+
11
+ def load_plbert():
12
+ plbert_config = {'vocab_size': 178, 'hidden_size': 768, 'num_attention_heads': 12, 'intermediate_size': 2048, 'max_position_embeddings': 512, 'num_hidden_layers': 12, 'dropout': 0.1}
13
+ albert_base_configuration = AlbertConfig(**plbert_config)
14
+ bert = CustomAlbert(albert_base_configuration)
15
+ return bert
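A quick sanity check of the wrapper (a sketch only; the token IDs are dummies, and the randomly initialised ALBERT weights are overwritten later when `build_model` loads the Kokoro checkpoint):

```py
# Sketch: CustomAlbert.forward returns last_hidden_state directly, shaped (batch, seq, hidden_size).
import torch
from plbert import load_plbert

bert = load_plbert().eval()
tokens = torch.zeros(1, 16, dtype=torch.long)   # dummy phoneme-token IDs
mask = torch.ones(1, 16, dtype=torch.long)
with torch.no_grad():
    hidden = bert(tokens, attention_mask=mask)
print(hidden.shape)  # torch.Size([1, 16, 768])
```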
pretrained_models/Kokoro/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .downloader import initialize_files
2
+
3
+ initialize_files()
pretrained_models/Kokoro/downloader.py ADDED
@@ -0,0 +1,61 @@
1
+ import os
2
+ import shutil
3
+ from huggingface_hub import snapshot_download
4
+
5
+
6
+ def download_files(repo_id, target_base_dir, patterns):
7
+ """
8
+ Downloads files matching patterns from a Hugging Face repository and organizes them in a directory structure.
9
+
10
+ :param repo_id: Hugging Face repository ID.
11
+ :param target_base_dir: Base directory where files should be stored.
12
+ :param patterns: A dictionary mapping subdirectories to file patterns.
13
+ Example: {"root": ["config.json", "*.pth"], "voices": ["*.pt"]}
14
+ """
15
+ # Ensure target base directory exists
16
+ if not os.path.exists(target_base_dir):
17
+ os.makedirs(target_base_dir)
18
+
19
+ # Download the snapshot containing all matching files
20
+ snapshot_dir = snapshot_download(repo_id=repo_id, allow_patterns="*")
21
+
22
+ # Loop through patterns and subdirectories
23
+ for subdir, file_patterns in patterns.items():
24
+ # Set target directory for root-level files
25
+ target_dir = (
26
+ target_base_dir
27
+ if subdir == "root"
28
+ else os.path.join(target_base_dir, subdir)
29
+ )
30
+ os.makedirs(target_dir, exist_ok=True)
31
+
32
+ for file_pattern in file_patterns:
33
+ # Walk through the snapshot directory to find matching files
34
+ for root, _, files in os.walk(snapshot_dir):
35
+ for file in files:
36
+ if file.endswith(file_pattern.lstrip("*")): # Match pattern
37
+ source_path = os.path.join(root, file)
38
+ target_file_path = os.path.join(target_dir, file)
39
+
40
+ # Check if file already exists
41
+ if not os.path.exists(target_file_path):
42
+ # Copy the file to the target directory
43
+ shutil.copy(source_path, target_file_path)
44
+ print(f"Downloaded and saved: {file} to {target_dir}")
45
+ else:
46
+ print(f"File already exists, skipping: {target_file_path}")
47
+
48
+
49
+ def initialize_files():
50
+ repo_id = "hexgrad/Kokoro-82M"
51
+ target_base_dir = "pretrained_models/Kokoro" # Base directory for files
52
+ file_patterns = {
53
+ "root": ["config.json", "*.pth"], # Files for the root directory
54
+ "voices": ["*.pt"], # Wildcard for voice pack files
55
+ }
56
+
57
+ download_files(repo_id, target_base_dir, file_patterns)
58
+
59
+
60
+ if __name__ == "__main__":
61
+ initialize_files()
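The helper is generic enough to be pointed at other repositories or pattern layouts; a hypothetical example (the repo ID and target paths below are placeholders, not something this Space actually uses):

```py
# Hypothetical usage of download_files() outside initialize_files().
# "some-org/some-tts-model" is a placeholder repo ID, not a dependency of this Space.
# Note: importing the package also triggers initialize_files() via __init__.py.
from pretrained_models.Kokoro.downloader import download_files

download_files(
    repo_id="some-org/some-tts-model",
    target_base_dir="pretrained_models/SomeModel",
    patterns={
        "root": ["config.json", "*.pth"],   # top-level config + checkpoint
        "speakers": ["*.pt"],               # speaker embeddings into speakers/
    },
)
```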
pyproject.toml ADDED
@@ -0,0 +1,17 @@
1
+ [project]
2
+ name = "kokoro-studio"
3
+ version = "0.1.0"
4
+ description = "Add your description here"
5
+ readme = "README.md"
6
+ requires-python = ">=3.11"
7
+ dependencies = [
8
+ "gradio>=5.9.1",
9
+ "munch>=4.0.0",
10
+ "openphonemizer>=0.1.2",
11
+ "phonemizer>=3.3.0",
12
+ "pyyaml>=6.0.2",
13
+ "scipy>=1.14.1",
14
+ "soundfile>=0.12.1",
15
+ "torch>=2.5.1",
16
+ "transformers>=4.47.1",
17
+ ]
requirements.txt ADDED
@@ -0,0 +1,9 @@
1
+ gradio>=5.9.1
2
+ munch>=4.0.0
3
+ openphonemizer>=0.1.2
4
+ phonemizer>=3.3.0
5
+ pyyaml>=6.0.2
6
+ scipy>=1.14.1
7
+ soundfile>=0.12.1
8
+ torch>=2.5.1
9
+ transformers>=4.47.1
tts_cli.py ADDED
@@ -0,0 +1,510 @@
1
+ #!/usr/bin/env python3
2
+ # tts_cli.py
3
+ """
4
+ Example CLI for generating audio with Kokoro-StyleTTS2.
5
+
6
+ Usage:
7
+ python tts_cli.py \
8
+ --model /path/to/kokoro-v0_19.pth \
9
+ --config /path/to/config.json \
10
+ --text "Hello, my stinking friends from 1906! You stink." \
11
+ --voicepack /path/to/af.pt \
12
+ --output output.wav
13
+
14
+ Make sure:
15
+ 1. `models.py` is in the same folder (with `build_model`, `Decoder`, etc.).
16
+ 2. You have installed the needed libraries:
17
+ pip install torch phonemizer munch soundfile pyyaml
18
+ 3. The model is a checkpoint that your `build_model` can load.
19
+
20
+ Adapt as needed!
21
+ """
22
+
23
+ import argparse
24
+ import os
25
+ import re
26
+ import torch
27
+ import soundfile as sf
28
+ import numpy as np
29
+ from phonemizer import backend as phonemizer_backend
30
+
31
+ # If you use eSpeak library:
32
+ try:
33
+ from espeak_util import set_espeak_library
34
+
35
+ set_espeak_library()
36
+ except ImportError:
37
+ pass
38
+
39
+ # --------------------------------------------------------------------
40
+ # Import from your local `models.py` (requires that file to be present).
41
+ # This example assumes `build_model` loads the entire TTS submodules
42
+ # (bert, bert_encoder, predictor, decoder, text_encoder).
43
+ # --------------------------------------------------------------------
44
+ from models import build_model
45
+
46
+
47
+ def resplit_strings(arr):
48
+ """
49
+ Given a list of string tokens (e.g. words, phrases), tries to
50
+ split them into two sub-lists whose total lengths are as balanced
51
+ as possible. The goal is to chunk a large string in half without
52
+ splitting in the middle of a word.
53
+ """
54
+ if not arr:
55
+ return "", ""
56
+ if len(arr) == 1:
57
+ return arr[0], ""
58
+
59
+ min_diff = float("inf")
60
+ best_split = 0
61
+ lengths = [len(s) for s in arr]
62
+ spaces = len(arr) - 1
63
+ left_len = 0
64
+ right_len = sum(lengths) + spaces
65
+
66
+ for i in range(1, len(arr)):
67
+ # Add current word + space to left side
68
+ left_len += lengths[i - 1] + (1 if i > 1 else 0)
69
+ # Remove from right side
70
+ right_len -= lengths[i - 1] + 1
71
+ diff = abs(left_len - right_len)
72
+ if diff < min_diff:
73
+ min_diff = diff
74
+ best_split = i
75
+
76
+ return " ".join(arr[:best_split]), " ".join(arr[best_split:])
77
+
78
+
79
+ def recursive_split(text, lang="a"):
80
+ """
81
+ Splits a piece of text into smaller segments so that
82
+ each segment's phoneme length < some ~limit (~500 tokens).
83
+ """
84
+ # We'll reuse your existing `phonemize_text` + `tokenize` from script 1
85
+ # to see if it is < 512 tokens. If it is, return it as a single chunk.
86
+ # Otherwise, split on punctuation or whitespace and recurse.
87
+
88
+ # 1. Phonemize first, check length
89
+ ps = phonemize_text(text, lang=lang, do_normalize=True)
90
+ tokens = tokenize(ps)
91
+ if len(tokens) < 512:
92
+ return [(text, ps)]
93
+
94
+ # If too large, we split on certain punctuation or fallback to whitespace
95
+ # We'll look for punctuation that often indicates sentence boundaries
96
+ # If none found, fallback to space-split
97
+ for punctuation in [r"[.?!…]", r"[:,;—]"]:
98
+ pattern = f"(?:(?<={punctuation})|(?<={punctuation}[\"'»])) "
99
+ # Attempt to split on that punctuation
100
+ splits = re.split(pattern, text)
101
+ if len(splits) > 1:
102
+ break
103
+ else:
104
+ # If we didn't break out, just do whitespace split
105
+ splits = text.split(" ")
106
+
107
+ # Use resplit_strings to chunk it about halfway
108
+ left, right = resplit_strings(splits)
109
+ # Recurse
110
+ return recursive_split(left, lang=lang) + recursive_split(right, lang=lang)
111
+
112
+
113
+ def segment_and_tokenize(long_text, lang="a"):
114
+ """
115
+ Takes a large text, optionally normalizes or cleans it,
116
+ then breaks it into a list of (segment_text, segment_phonemes).
117
+ """
118
+ # Additional cleaning if you want:
119
+ # long_text = normalize_text(long_text) # your existing function
120
+ # We chunk it up using recursive_split
121
+ segments = recursive_split(long_text, lang=lang)
122
+ return segments
123
+
124
+
125
+ # -------------- Normalization & Phonemization Routines -------------- #
126
+ def parens_to_angles(s):
127
+ return s.replace("(", "«").replace(")", "»")
128
+
129
+
130
+ def split_num(num):
131
+ num = num.group()
132
+ if "." in num:
133
+ return num
134
+ elif ":" in num:
135
+ h, m = [int(n) for n in num.split(":")]
136
+ if m == 0:
137
+ return f"{h} o'clock"
138
+ elif m < 10:
139
+ return f"{h} oh {m}"
140
+ return f"{h} {m}"
141
+ year = int(num[:4])
142
+ if year < 1100 or year % 1000 < 10:
143
+ return num
144
+ left, right = num[:2], int(num[2:4])
145
+ s = "s" if num.endswith("s") else ""
146
+ if 100 <= year % 1000 <= 999:
147
+ if right == 0:
148
+ return f"{left} hundred{s}"
149
+ elif right < 10:
150
+ return f"{left} oh {right}{s}"
151
+ return f"{left} {right}{s}"
152
+
153
+
154
+ def flip_money(m):
155
+ m = m.group()
156
+ bill = "dollar" if m[0] == "$" else "pound"
157
+ if m[-1].isalpha():
158
+ return f"{m[1:]} {bill}s"
159
+ elif "." not in m:
160
+ s = "" if m[1:] == "1" else "s"
161
+ return f"{m[1:]} {bill}{s}"
162
+ b, c = m[1:].split(".")
163
+ s = "" if b == "1" else "s"
164
+ c = int(c.ljust(2, "0"))
165
+ coins = (
166
+ f"cent{'' if c == 1 else 's'}"
167
+ if m[0] == "$"
168
+ else ("penny" if c == 1 else "pence")
169
+ )
170
+ return f"{b} {bill}{s} and {c} {coins}"
171
+
172
+
173
+ def point_num(num):
174
+ a, b = num.group().split(".")
175
+ return " point ".join([a, " ".join(b)])
176
+
177
+
178
+ def normalize_text(text):
179
+ text = text.replace(chr(8216), "'").replace(chr(8217), "'")
180
+ text = text.replace("«", chr(8220)).replace("»", chr(8221))
181
+ text = text.replace(chr(8220), '"').replace(chr(8221), '"')
182
+ text = parens_to_angles(text)
183
+
184
+ # Replace some common full-width punctuation in CJK:
185
+ for a, b in zip("、。!,:;?", ",.!,:;?"):
186
+ text = text.replace(a, b + " ")
187
+
188
+ text = re.sub(r"[^\S \n]", " ", text)
189
+ text = re.sub(r" +", " ", text)
190
+ text = re.sub(r"(?<=\n) +(?=\n)", "", text)
191
+ text = re.sub(r"\bD[Rr]\.(?= [A-Z])", "Doctor", text)
192
+ text = re.sub(r"\b(?:Mr\.|MR\.(?= [A-Z]))", "Mister", text)
193
+ text = re.sub(r"\b(?:Ms\.|MS\.(?= [A-Z]))", "Miss", text)
194
+ text = re.sub(r"\b(?:Mrs\.|MRS\.(?= [A-Z]))", "Mrs", text)
195
+ text = re.sub(r"\betc\.(?! [A-Z])", "etc", text)
196
+ text = re.sub(r"(?i)\b(y)eah?\b", r"\1e'a", text)
197
+ text = re.sub(
198
+ r"\d*\.\d+|\b\d{4}s?\b|(?<!:)\b(?:[1-9]|1[0-2]):[0-5]\d\b(?!:)",
199
+ split_num,
200
+ text,
201
+ )
202
+ text = re.sub(r"(?<=\d),(?=\d)", "", text)
203
+ text = re.sub(
204
+ r"(?i)[$£]\d+(?:\.\d+)?(?: hundred| thousand| (?:[bm]|tr)illion)*\b|[$£]\d+\.\d\d?\b",
205
+ flip_money,
206
+ text,
207
+ )
208
+ text = re.sub(r"\d*\.\d+", point_num, text)
209
+ text = re.sub(r"(?<=\d)-(?=\d)", " to ", text) # Could be minus; adjust if needed
210
+ text = re.sub(r"(?<=\d)S", " S", text)
211
+ text = re.sub(r"(?<=[BCDFGHJ-NP-TV-Z])'?s\b", "'S", text)
212
+ text = re.sub(r"(?<=X')S\b", "s", text)
213
+ text = re.sub(
214
+ r"(?:[A-Za-z]\.){2,} [a-z]", lambda m: m.group().replace(".", "-"), text
215
+ )
216
+ text = re.sub(r"(?i)(?<=[A-Z])\.(?=[A-Z])", "-", text)
217
+ return text.strip()
218
+
219
+
220
+ # -------------------------------------------------------------------
221
+ # Vocab and Symbol Mapping
222
+ # -------------------------------------------------------------------
223
+ def get_vocab():
224
+ _pad = "$"
225
+ _punctuation = ';:,.!?¡¿—…"«»“” '
226
+ _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
227
+ _letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
228
+ symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
229
+ dicts = {}
230
+ for i, s in enumerate(symbols):
231
+ dicts[s] = i
232
+ return dicts
233
+
234
+
235
+ VOCAB = get_vocab()
236
+
237
+
238
+ def tokenize(ps: str):
239
+ """Convert the phoneme string into integer tokens based on VOCAB."""
240
+ return [VOCAB.get(p) for p in ps if p in VOCAB]
241
+
242
+
243
+ # -------------------------------------------------------------------
244
+ # Initialize a simple phonemizer
245
+ # For English:
246
+ # 'a' ~ en-us
247
+ # 'b' ~ en-gb
248
+ # -------------------------------------------------------------------
249
+ phonemizers = dict(
250
+ a=phonemizer_backend.EspeakBackend(
251
+ language="en-us", preserve_punctuation=True, with_stress=True
252
+ ),
253
+ b=phonemizer_backend.EspeakBackend(
254
+ language="en-gb", preserve_punctuation=True, with_stress=True
255
+ ),
256
+ # You can add more, e.g. 'j': some Japanese phonemizer, etc.
257
+ )
258
+
259
+
260
+ def phonemize_text(text, lang="a", do_normalize=True):
261
+ if do_normalize:
262
+ text = normalize_text(text)
263
+ ps_list = phonemizers[lang].phonemize([text])
264
+ ps = ps_list[0] if ps_list else ""
265
+
266
+ # Some custom replacements (from your code)
267
+ ps = ps.replace("kəkˈoːɹoʊ", "kˈoʊkəɹoʊ").replace("kəkˈɔːɹəʊ", "kˈəʊkəɹəʊ")
268
+ ps = ps.replace("ʲ", "j").replace("r", "ɹ").replace("x", "k").replace("ɬ", "l")
269
+ # Example: insert space before "hˈʌndɹɪd" if there's a letter, e.g. "nˈaɪn" => "nˈaɪn hˈʌndɹɪd"
270
+ ps = re.sub(r"(?<=[a-zɹː])(?=hˈʌndɹɪd)", " ", ps)
271
+ # "z" at the end of a word -> remove space (just your snippet)
272
+ ps = re.sub(r' z(?=[;:,.!?¡¿—…"«»“” ]|$)', "z", ps)
273
+ # If lang is 'a', handle "ninety" => "ninedi"? Just from your snippet:
274
+ if lang == "a":
275
+ ps = re.sub(r"(?<=nˈaɪn)ti(?!ː)", "di", ps)
276
+
277
+ # Only keep valid symbols
278
+ ps = "".join(p for p in ps if p in VOCAB)
279
+ return ps.strip()
280
+
281
+
282
+ # -------------------------------------------------------------------
283
+ # Utility for generating text masks
284
+ # -------------------------------------------------------------------
285
+ def length_to_mask(lengths):
286
+ # lengths is a Tensor of shape [B], containing the text length for each batch
287
+ max_len = lengths.max()
288
+ row_ids = torch.arange(max_len, device=lengths.device).unsqueeze(0)
289
+ mask = row_ids.expand(lengths.shape[0], -1)
290
+ return (mask + 1) > lengths.unsqueeze(1)
291
+
292
+
293
+ # -------------------------------------------------------------------
294
+ # The forward pass for inference (from your snippet).
295
+ # This version references `model.predictor`, `model.decoder`, etc.
296
+ # -------------------------------------------------------------------
297
+ @torch.no_grad()
298
+ def forward_tts(model, tokens, ref_s, speed=1.0):
299
+ """
300
+ model: Munch with submodels: bert, bert_encoder, predictor, decoder, text_encoder
301
+ tokens: list[int], the tokenized input (without [0, ... , 0] yet)
302
+ ref_s: reference embedding (torch.Tensor)
303
+ speed: float, speed factor
304
+ """
305
+ device = ref_s.device
306
+ tokens_t = torch.LongTensor([[0, *tokens, 0]]).to(device) # add boundary tokens
307
+ input_lengths = torch.LongTensor([tokens_t.shape[-1]]).to(device)
308
+ text_mask = length_to_mask(input_lengths).to(device)
309
+
310
+ # 1. Encode with BERT
311
+ bert_dur = model.bert(tokens_t, attention_mask=(~text_mask).int())
312
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
313
+
314
+ # 2. Prosody predictor
315
+ s = ref_s[
316
+ :, 128:
317
+ ]  # ref_s[:, 128:] conditions the prosody predictor; ref_s[:, :128] is the decoder (acoustic) style used below
318
+
319
+ d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
320
+ x, _ = model.predictor.lstm(d)
321
+ duration = model.predictor.duration_proj(x)
322
+ duration = torch.sigmoid(duration).sum(axis=-1) / speed
323
+ pred_dur = torch.round(duration).clamp(min=1).long()
324
+
325
+ # 3. Expand alignment
326
+ total_len = pred_dur.sum().item()
327
+ pred_aln_trg = torch.zeros(input_lengths, total_len, device=device)
328
+ c_frame = 0
329
+ for i in range(pred_aln_trg.size(0)):
330
+ n = pred_dur[0, i].item()
331
+ pred_aln_trg[i, c_frame : c_frame + n] = 1
332
+ c_frame += n
333
+
334
+ # 4. Run F0 + Noise predictor
335
+ en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0)
336
+ F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
337
+
338
+ # 5. Text encoder -> asr
339
+ t_en = model.text_encoder(tokens_t, input_lengths, text_mask)
340
+ asr = t_en @ pred_aln_trg.unsqueeze(0)
341
+
342
+ # 6. Decode audio
343
+ audio = model.decoder(asr, F0_pred, N_pred, ref_s[:, :128]) # B x audio_len
344
+ return audio.squeeze().cpu().numpy()
345
+
346
+
347
+ def generate_tts(model, text, voicepack, lang="a", speed=1.0):
348
+ """
349
+ model: the Munch returned by build_model(...)
350
+ text: the input text (string)
351
+ voicepack: the torch Tensor reference embedding, or a dict of them
352
+ lang: 'a' or 'b' or etc. from your phonemizers
353
+ speed: speech speed factor
354
+ sample_rate: sampling rate for the output
355
+ """
356
+ # 1. Phonemize
357
+ ps = phonemize_text(text, lang=lang, do_normalize=True)
358
+ tokens = tokenize(ps)
359
+ if not tokens:
360
+ return None, ps
361
+
362
+ # 2. Retrieve reference style
363
+ # If your voicepack is a single embedding for all lengths, adapt as needed.
364
+ # If your voicepack is something like `voicepack[len(tokens)]`, do that.
365
+ # If you have multiple voices, you might do something else.
366
+ try:
367
+ ref_s = voicepack[len(tokens)]
368
+ except:
369
+ # fallback if len(tokens) is out of range
370
+ ref_s = voicepack[-1]
371
+ ref_s = ref_s.to("cpu" if not next(model.bert.parameters()).is_cuda else "cuda")
372
+
373
+ # 3. Generate
374
+ audio = forward_tts(model, tokens, ref_s, speed=speed)
375
+ return audio, ps
376
+
377
+
378
+ def generate_long_form_tts(model, full_text, voicepack, lang="a", speed=1.0):
379
+ """
380
+ Generate TTS for a large `full_text`, splitting it into smaller segments
381
+ and concatenating the resulting audio.
382
+
383
+ Returns: (np.float32 array) final_audio, list_of_segment_phonemes
384
+ """
385
+ # 1. Segment the text
386
+ segments = segment_and_tokenize(full_text, lang=lang)
387
+ # segments is a list of (seg_text, seg_phonemes)
388
+
389
+ # 2. For each segment, call `generate_tts(...)`
390
+ audio_chunks = []
391
+ all_phonemes = []
392
+ for i, (seg_text, seg_ps) in enumerate(segments, 1):
393
+ print(f"[LongForm] Generating chunk {i}/{len(segments)}: {seg_text[:40]}...")
394
+ audio, used_phonemes = generate_tts(
395
+ model, seg_text, voicepack, lang=lang, speed=speed
396
+ )
397
+ if audio is not None:
398
+ audio_chunks.append(audio)
399
+ all_phonemes.append(used_phonemes)
400
+ else:
401
+ print(f"[LongForm] Skipped empty segment {i}...")
402
+
403
+ if not audio_chunks:
404
+ return None, []
405
+
406
+ # 3. Concatenate the audio
407
+ final_audio = np.concatenate(audio_chunks, axis=0)
408
+ return final_audio, all_phonemes
409
+
410
+
411
+ # -------------------------------------------------------------------
412
+ # Main CLI
413
+ # -------------------------------------------------------------------
414
+ def main():
415
+ parser = argparse.ArgumentParser(description="Kokoro-StyleTTS2 CLI Example")
416
+ parser.add_argument(
417
+ "--model",
418
+ type=str,
419
+ default="pretrained_models/Kokoro/kokoro-v0_19.pth",
420
+ help="Path to your model checkpoint (e.g. kokoro-v0_19.pth).",
421
+ )
422
+ parser.add_argument(
423
+ "--config",
424
+ type=str,
425
+ default="pretrained_models/Kokoro/config.json",
426
+ help="Path to config.json (used by build_model).",
427
+ )
428
+ parser.add_argument(
429
+ "--text",
430
+ type=str,
431
+ default="Hello world! This is Kokoro, a new text-to-speech model based on StyleTTS2 from 2024!",
432
+ help="Text to be converted into speech.",
433
+ )
434
+ parser.add_argument(
435
+ "--voicepack",
436
+ type=str,
437
+ default="pretrained_models/Kokoro/voices/af.pt",
438
+ help="Path to a .pt file for your reference embedding(s).",
439
+ )
440
+ parser.add_argument(
441
+ "--output", type=str, default="output.wav", help="Output WAV filename."
442
+ )
443
+ parser.add_argument(
444
+ "--speed",
445
+ type=float,
446
+ default=1.0,
447
+ help="Speech speed factor, e.g. 0.8 slower, 1.2 faster, etc.",
448
+ )
449
+ parser.add_argument(
450
+ "--device",
451
+ type=str,
452
+ default="cpu",
453
+ choices=["cpu", "cuda"],
454
+ help="Device to run inference on.",
455
+ )
456
+ args = parser.parse_args()
457
+
458
+ # 1. Build model using your local build_model function
459
+ # (which loads TextEncoder, Decoder, etc. and returns a Munch).
460
+ if not os.path.isfile(args.config):
461
+ raise FileNotFoundError(f"config.json not found: {args.config}")
462
+
463
+ # Optionally load config as Munch (depends on your build_model usage)
464
+ # But your snippet does something like:
465
+ # with open(config, 'r') as r: ...
466
+ # ...
467
+ # model = build_model(path, device)
468
+ # We'll do the same but in a simpler form:
469
+ device = (
470
+ args.device if (args.device == "cuda" and torch.cuda.is_available()) else "cpu"
471
+ )
472
+ print(f"Loading model from: {args.model}")
473
+ model = build_model(
474
+ args.model, device
475
+ ) # This requires that `args.model` is the checkpoint path
476
+
477
+ # Because `build_model` returns a Munch (dict of submodules),
478
+ # we can't just do `model.eval()`, we must set each submodule to eval:
479
+ for k, subm in model.items():
480
+ if isinstance(subm, torch.nn.Module):
481
+ subm.eval()
482
+
483
+ # 2. Load voicepack
484
+ if not os.path.isfile(args.voicepack):
485
+ raise FileNotFoundError(f"Voicepack file not found: {args.voicepack}")
486
+ print(f"Loading voicepack from: {args.voicepack}")
487
+ vp = torch.load(args.voicepack, map_location=device)
488
+ # If your voicepack is an nn.Module, set it to eval as well
489
+ if isinstance(vp, torch.nn.Module):
490
+ vp.eval()
491
+
492
+ # 3. Generate audio
493
+ print(f"Generating speech for text: {args.text}")
494
+ audio, phonemes = generate_long_form_tts(
495
+ model, args.text, vp, lang="a", speed=args.speed
496
+ )
497
+ if audio is None:
498
+ print("No tokens were generated (maybe empty text?). Exiting.")
499
+ return
500
+
501
+ # 4. Write WAV
502
+ print(f"Writing output to: {args.output}")
503
+ sf.write(args.output, audio, 24000)  # Kokoro generates 24 kHz audio
504
+
505
+ print("Finished!")
506
+ print(f"Phonemes used: {phonemes}")
507
+
508
+
509
+ if __name__ == "__main__":
510
+ main()
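The same pipeline can also be driven without the CLI; a minimal sketch (assuming espeak-ng is installed, the pretrained files have already been downloaded, and the script is run from the repo root):

```py
# Sketch: use tts_cli.py as a library instead of via argparse.
import torch
import soundfile as sf

from models import build_model
from tts_cli import generate_long_form_tts

device = "cuda" if torch.cuda.is_available() else "cpu"
model = build_model("pretrained_models/Kokoro/kokoro-v0_19.pth", device)
voicepack = torch.load("pretrained_models/Kokoro/voices/af.pt", map_location=device)

audio, phonemes = generate_long_form_tts(
    model,
    "Kokoro reads long passages by splitting them into smaller chunks first.",
    voicepack,
    lang="a",
    speed=1.0,
)
sf.write("output.wav", audio, 24000)  # Kokoro generates 24 kHz audio
```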
tts_cli_op.py ADDED
@@ -0,0 +1,569 @@
1
+ #!/usr/bin/env python3
2
+ # tts_cli_op.py
3
+ """
4
+ Example CLI for generating audio with Kokoro-StyleTTS2.
5
+
6
+ Usage:
7
+ python tts_cli_op.py \
8
+ --model /path/to/kokoro-v0_19.pth \
9
+ --config /path/to/config.json \
10
+ --text "Hello, my stinking friends from 1906! You stink." \
11
+ --voicepack /path/to/af.pt \
12
+ --output output.wav
13
+
14
+ Make sure:
15
+ 1. `models.py` is in the same folder (with `build_model`, `Decoder`, etc.).
16
+ 2. You have installed the needed libraries:
17
+ pip install torch openphonemizer joblib munch soundfile pyyaml
18
+ 3. The model is a checkpoint that your `build_model` can load.
19
+
20
+ Adapt as needed!
21
+ """
22
+
23
+ import argparse
24
+ import os
25
+ import re
26
+ import torch
27
+ import soundfile as sf
28
+ import numpy as np
29
+ from openphonemizer import OpenPhonemizer
30
+ from typing import List
31
+ import joblib
32
+
33
+ # --------------------------------------------------------------------
34
+ # Import from your local `models.py` (requires that file to be present).
35
+ # This example assumes `build_model` loads the entire TTS submodules
36
+ # (bert, bert_encoder, predictor, decoder, text_encoder).
37
+ # --------------------------------------------------------------------
38
+ from models import build_model
39
+
40
+
41
+ def resplit_strings(arr):
42
+ """
43
+ Given a list of string tokens (e.g. words, phrases), tries to
44
+ split them into two sub-lists whose total lengths are as balanced
45
+ as possible. The goal is to chunk a large string in half without
46
+ splitting in the middle of a word.
47
+ """
48
+ if not arr:
49
+ return "", ""
50
+ if len(arr) == 1:
51
+ return arr[0], ""
52
+
53
+ min_diff = float("inf")
54
+ best_split = 0
55
+ lengths = [len(s) for s in arr]
56
+ spaces = len(arr) - 1
57
+ left_len = 0
58
+ right_len = sum(lengths) + spaces
59
+
60
+ for i in range(1, len(arr)):
61
+ # Add current word + space to left side
62
+ left_len += lengths[i - 1] + (1 if i > 1 else 0)
63
+ # Remove from right side
64
+ right_len -= lengths[i - 1] + 1
65
+ diff = abs(left_len - right_len)
66
+ if diff < min_diff:
67
+ min_diff = diff
68
+ best_split = i
69
+
70
+ return " ".join(arr[:best_split]), " ".join(arr[best_split:])
71
+
72
+
73
+ def recursive_split(text, lang="a"):
74
+ """
75
+ Splits a piece of text into smaller segments so that
76
+ each segment's phoneme length < some ~limit (~500 tokens).
77
+ """
78
+ # We'll reuse your existing `phonemize_text` + `tokenize` from script 1
79
+ # to see if it is < 512 tokens. If it is, return it as a single chunk.
80
+ # Otherwise, split on punctuation or whitespace and recurse.
81
+
82
+ # 1. Phonemize first, check length
83
+ ps = phonemize_text(text, do_normalize=True)
84
+ tokens = tokenize(ps)
85
+ if len(tokens) < 512:
86
+ return [(text, ps)]
87
+
88
+ # If too large, we split on certain punctuation or fallback to whitespace
89
+ # We'll look for punctuation that often indicates sentence boundaries
90
+ # If none found, fallback to space-split
91
+ for punctuation in [r"[.?!…]", r"[:,;—]"]:
92
+ pattern = f"(?:(?<={punctuation})|(?<={punctuation}[\"'»])) "
93
+ # Attempt to split on that punctuation
94
+ splits = re.split(pattern, text)
95
+ if len(splits) > 1:
96
+ break
97
+ else:
98
+ # If we didn't break out, just do whitespace split
99
+ splits = text.split(" ")
100
+
101
+ # Use resplit_strings to chunk it about halfway
102
+ left, right = resplit_strings(splits)
103
+ # Recurse
104
+ return recursive_split(left) + recursive_split(right)
105
+
106
+
107
+ def segment_and_tokenize(long_text, lang="a"):
108
+ """
109
+ Takes a large text, optionally normalizes or cleans it,
110
+ then breaks it into a list of (segment_text, segment_phonemes).
111
+ """
112
+ # Additional cleaning if you want:
113
+ # long_text = normalize_text(long_text) # your existing function
114
+ # We chunk it up using recursive_split
115
+ segments = recursive_split(long_text)
116
+ return segments
117
+
118
+
119
+ # -------------- Normalization & Phonemization Routines -------------- #
120
+ def parens_to_angles(s):
121
+ return s.replace("(", "«").replace(")", "»")
122
+
123
+
124
+ def split_num(num):
125
+ num = num.group()
126
+ if "." in num:
127
+ return num
128
+ elif ":" in num:
129
+ h, m = [int(n) for n in num.split(":")]
130
+ if m == 0:
131
+ return f"{h} o'clock"
132
+ elif m < 10:
133
+ return f"{h} oh {m}"
134
+ return f"{h} {m}"
135
+ year = int(num[:4])
136
+ if year < 1100 or year % 1000 < 10:
137
+ return num
138
+ left, right = num[:2], int(num[2:4])
139
+ s = "s" if num.endswith("s") else ""
140
+ if 100 <= year % 1000 <= 999:
141
+ if right == 0:
142
+ return f"{left} hundred{s}"
143
+ elif right < 10:
144
+ return f"{left} oh {right}{s}"
145
+ return f"{left} {right}{s}"
146
+
147
+
148
+ def flip_money(m):
149
+ m = m.group()
150
+ bill = "dollar" if m[0] == "$" else "pound"
151
+ if m[-1].isalpha():
152
+ return f"{m[1:]} {bill}s"
153
+ elif "." not in m:
154
+ s = "" if m[1:] == "1" else "s"
155
+ return f"{m[1:]} {bill}{s}"
156
+ b, c = m[1:].split(".")
157
+ s = "" if b == "1" else "s"
158
+ c = int(c.ljust(2, "0"))
159
+ coins = (
160
+ f"cent{'' if c == 1 else 's'}"
161
+ if m[0] == "$"
162
+ else ("penny" if c == 1 else "pence")
163
+ )
164
+ return f"{b} {bill}{s} and {c} {coins}"
165
+
166
+
167
+ def point_num(num):
168
+ a, b = num.group().split(".")
169
+ return " point ".join([a, " ".join(b)])
170
+
171
+
172
+ def normalize_text(text):
173
+ text = text.replace(chr(8216), "'").replace(chr(8217), "'")
174
+ text = text.replace("«", chr(8220)).replace("»", chr(8221))
175
+ text = text.replace(chr(8220), '"').replace(chr(8221), '"')
176
+ text = parens_to_angles(text)
177
+
178
+ # Replace some common full-width punctuation in CJK:
179
+ for a, b in zip("、。!,:;?", ",.!,:;?"):
180
+ text = text.replace(a, b + " ")
181
+
182
+ text = re.sub(r"[^\S \n]", " ", text)
183
+ text = re.sub(r" +", " ", text)
184
+ text = re.sub(r"(?<=\n) +(?=\n)", "", text)
185
+ text = re.sub(r"\bD[Rr]\.(?= [A-Z])", "Doctor", text)
186
+ text = re.sub(r"\b(?:Mr\.|MR\.(?= [A-Z]))", "Mister", text)
187
+ text = re.sub(r"\b(?:Ms\.|MS\.(?= [A-Z]))", "Miss", text)
188
+ text = re.sub(r"\b(?:Mrs\.|MRS\.(?= [A-Z]))", "Mrs", text)
189
+ text = re.sub(r"\betc\.(?! [A-Z])", "etc", text)
190
+ text = re.sub(r"(?i)\b(y)eah?\b", r"\1e'a", text)
191
+ text = re.sub(
192
+ r"\d*\.\d+|\b\d{4}s?\b|(?<!:)\b(?:[1-9]|1[0-2]):[0-5]\d\b(?!:)",
193
+ split_num,
194
+ text,
195
+ )
196
+ text = re.sub(r"(?<=\d),(?=\d)", "", text)
197
+ text = re.sub(
198
+ r"(?i)[$£]\d+(?:\.\d+)?(?: hundred| thousand| (?:[bm]|tr)illion)*\b|[$£]\d+\.\d\d?\b",
199
+ flip_money,
200
+ text,
201
+ )
202
+ text = re.sub(r"\d*\.\d+", point_num, text)
203
+ text = re.sub(r"(?<=\d)-(?=\d)", " to ", text) # Could be minus; adjust if needed
204
+ text = re.sub(r"(?<=\d)S", " S", text)
205
+ text = re.sub(r"(?<=[BCDFGHJ-NP-TV-Z])'?s\b", "'S", text)
206
+ text = re.sub(r"(?<=X')S\b", "s", text)
207
+ text = re.sub(
208
+ r"(?:[A-Za-z]\.){2,} [a-z]", lambda m: m.group().replace(".", "-"), text
209
+ )
210
+ text = re.sub(r"(?i)(?<=[A-Z])\.(?=[A-Z])", "-", text)
211
+ return text.strip()
212
+
213
+
214
+ # -------------------------------------------------------------------
215
+ # Vocab and Symbol Mapping
216
+ # -------------------------------------------------------------------
217
+ def get_vocab():
218
+ _pad = "$"
219
+ _punctuation = ';:,.!?¡¿—…"«»“” '
220
+ _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
221
+ _letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
222
+ symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
223
+ dicts = {}
224
+ for i, s in enumerate(symbols):
225
+ dicts[s] = i
226
+ return dicts
227
+
228
+
229
+ VOCAB = get_vocab()
230
+
231
+
232
+ def tokenize(ps: str):
233
+ """Convert the phoneme string into integer tokens based on VOCAB."""
234
+ return [VOCAB.get(p) for p in ps if p in VOCAB]
235
+
236
+
237
+ # -------------------------------------------------------------------
238
+ # Initialize a simple phonemizer
239
+ # For English:
240
+ # 'a' ~ en-us
241
+ # 'b' ~ en-gb
242
+ # -------------------------------------------------------------------
243
+
244
+ # Wrapper around OpenPhonemizer to enable batch phonemization
245
+ open_phonemizer = OpenPhonemizer()
246
+
247
+
248
+ class Phonemizer:
249
+ def __init__(self, open_phonemizer):
250
+ """
251
+ open_phonemizer: an instance of OpenPhonemizer()
252
+ """
253
+ self.open_phonemizer = open_phonemizer
254
+
255
+ def phonemize(
256
+ self,
257
+ text: List[str],
258
+ njobs: int = 4,
259
+ strip: bool = False,
260
+ ) -> List[str]:
261
+ """
262
+ Phonemizes a list of input strings using OpenPhonemizer (which
263
+ itself only supports single-string input).
264
+
265
+ Parameters
266
+ ----------
267
+ text : list of str
268
+ Each element is an utterance (or line) to be phonemized.
269
+ njobs : int
270
+ The number of parallel jobs for phonemization.
271
+ strip : bool
272
+ Not used by OpenPhonemizer directly, but you can implement
273
+ an additional final “strip” step if needed.
274
+
275
+ Returns
276
+ -------
277
+ list of str
278
+ The phonemized text strings, in the same order as input.
279
+ """
280
+
281
+ if not isinstance(text, list) or any(not isinstance(x, str) for x in text):
282
+ raise ValueError("`text` must be a list of strings.")
283
+
284
+ # Optionally do any pre-processing you want here ...
285
+ # e.g. text = [self._some_preprocess(line) for line in text]
286
+
287
+ # If we only have one job, do it in a single loop
288
+ if njobs == 1:
289
+ return [self._phonemize_single(utterance, strip) for utterance in text]
290
+
291
+ # Otherwise, dispatch the utterances to parallel joblib workers: each
292
+ # string below is submitted as its own job, and the results come back in
293
+ # the same order as the input. For very large corpora you may want to
294
+ # batch several utterances per job instead.
295
+ chunked_results = joblib.Parallel(n_jobs=njobs)(
296
+ joblib.delayed(self._phonemize_single)(t, strip) for t in text
297
+ )
298
+ return chunked_results
299
+
300
+ def _phonemize_single(self, line: str, strip: bool) -> str:
301
+ """
302
+ Phonemize a single line using openphonemizer.
303
+ """
304
+ # OpenPhonemizer usage:
305
+ # out_str = self.open_phonemizer(line)
306
+ # (That returns the phonemes as a string.)
307
+ phonemes = self.open_phonemizer(line)
308
+
309
+ # Implement a post-strip if you want to mimic removing trailing delimiters
310
+ # (this is just a placeholder for demonstration).
311
+ if strip:
312
+ phonemes = phonemes.rstrip()
313
+
314
+ return phonemes
315
+
316
+
317
+ # initiate the wrapper:
318
+ phonemizer_wrapper = Phonemizer(open_phonemizer)
319
+
320
+
321
+ def phonemize_text(text, lang="a", do_normalize=True):
322
+ if do_normalize:
323
+ text = normalize_text(text)
324
+ ps_list = phonemizer_wrapper.phonemize([text])
325
+ ps = ps_list[0] if ps_list else ""
326
+
327
+ # Some custom replacements
328
+ ps = ps.replace("kəkˈoːɹoʊ", "kˈoʊkəɹoʊ").replace("kəkˈɔːɹəʊ", "kˈəʊkəɹəʊ")
329
+ ps = ps.replace("ʲ", "j").replace("r", "ɹ").replace("x", "k").replace("ɬ", "l")
330
+ # Example: insert space before "hˈʌndɹɪd" if there's a letter, e.g. "nˈaɪn" => "nˈaɪn hˈʌndɹɪd"
331
+ ps = re.sub(r"(?<=[a-zɹː])(?=hˈʌndɹɪd)", " ", ps)
332
+ # "z" at the end of a word -> remove space (just your snippet)
333
+ ps = re.sub(r' z(?=[;:,.!?¡¿—…"«»“” ]|$)', "z", ps)
334
+ # If lang is 'a', handle "ninety" => "ninedi" (carried over from the original snippet):
336
+ if lang == "a":
337
+ ps = re.sub(r"(?<=nˈaɪn)ti(?!ː)", "di", ps)
338
+
339
+ # Only keep valid symbols
340
+ ps = "".join(p for p in ps if p in VOCAB)
341
+ return ps.strip()
342
+
343
+
344
+ # -------------------------------------------------------------------
345
+ # Utility for generating text masks
346
+ # -------------------------------------------------------------------
347
+ def length_to_mask(lengths):
348
+ # lengths is a Tensor of shape [B], containing the text length for each batch
349
+ max_len = lengths.max()
350
+ row_ids = torch.arange(max_len, device=lengths.device).unsqueeze(0)
351
+ mask = row_ids.expand(lengths.shape[0], -1)
352
+ return (mask + 1) > lengths.unsqueeze(1)
353
+
354
+
355
+ # -------------------------------------------------------------------
356
+ # The forward pass for inference (from your snippet).
357
+ # This version references `model.predictor`, `model.decoder`, etc.
358
+ # -------------------------------------------------------------------
359
+ @torch.no_grad()
360
+ def forward_tts(model, tokens, ref_s, speed=1.0):
361
+ """
362
+ model: Munch with submodels: bert, bert_encoder, predictor, decoder, text_encoder
363
+ tokens: list[int], the tokenized input (without [0, ... , 0] yet)
364
+ ref_s: reference embedding (torch.Tensor)
365
+ speed: float, speed factor
366
+ """
367
+ device = ref_s.device
368
+ tokens_t = torch.LongTensor([[0, *tokens, 0]]).to(device) # add boundary tokens
369
+ input_lengths = torch.LongTensor([tokens_t.shape[-1]]).to(device)
370
+ text_mask = length_to_mask(input_lengths).to(device)
371
+
372
+ # 1. Encode with BERT
373
+ bert_dur = model.bert(tokens_t, attention_mask=(~text_mask).int())
374
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
375
+
376
+ # 2. Prosody predictor
377
+ s = ref_s[
378
+ :, 128:
379
+ ]  # ref_s[:, 128:] conditions the prosody predictor; ref_s[:, :128] is the decoder (acoustic) style used below
380
+
381
+ d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
382
+ x, _ = model.predictor.lstm(d)
383
+ duration = model.predictor.duration_proj(x)
384
+ duration = torch.sigmoid(duration).sum(axis=-1) / speed
385
+ pred_dur = torch.round(duration).clamp(min=1).long()
386
+
387
+ # 3. Expand alignment
388
+ total_len = pred_dur.sum().item()
389
+ pred_aln_trg = torch.zeros(input_lengths, total_len, device=device)
390
+ c_frame = 0
391
+ for i in range(pred_aln_trg.size(0)):
392
+ n = pred_dur[0, i].item()
393
+ pred_aln_trg[i, c_frame : c_frame + n] = 1
394
+ c_frame += n
395
+
396
+ # 4. Run F0 + Noise predictor
397
+ en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0)
398
+ F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
399
+
400
+ # 5. Text encoder -> asr
401
+ t_en = model.text_encoder(tokens_t, input_lengths, text_mask)
402
+ asr = t_en @ pred_aln_trg.unsqueeze(0)
403
+
404
+ # 6. Decode audio
405
+ audio = model.decoder(asr, F0_pred, N_pred, ref_s[:, :128]) # B x audio_len
406
+ return audio.squeeze().cpu().numpy()
407
+
408
+
+ def generate_tts(model, text, voicepack, lang="a", speed=1.0):
+     """
+     model: the Munch returned by build_model(...)
+     text: the input text (string)
+     voicepack: the torch Tensor reference embedding, or a dict of them
+     lang: language/voice code (currently unused here)
+     speed: speech speed factor
+     """
+     # 1. Phonemize
+     ps = phonemize_text(text, do_normalize=True)
+     tokens = tokenize(ps)
+     if not tokens:
+         return None, ps
+
+     # 2. Retrieve the reference style for this utterance length.
+     # The voicepacks used here are indexed by token count (voicepack[len(tokens)]);
+     # adapt this lookup if yours is a single embedding or a dict of voices.
+     try:
+         ref_s = voicepack[len(tokens)]
+     except Exception:
+         # fallback if len(tokens) is out of range
+         ref_s = voicepack[-1]
+     # move the reference embedding onto the same device as the model
+     ref_s = ref_s.to(next(model.bert.parameters()).device)
+
+     # 3. Generate
+     audio = forward_tts(model, tokens, ref_s, speed=speed)
+     return audio, ps
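+ 
+ # Illustrative usage (assumes `model` comes from build_model(...) and `voicepack`
+ # from torch.load(...), as in main() below):
+ #     audio, phonemes = generate_tts(model, "Hello there!", voicepack, speed=1.0)
+ #     sf.write("hello.wav", audio, 24000)  # Kokoro generates 24 kHz audio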
+
+
+ def generate_long_form_tts(model, full_text, voicepack, lang="a", speed=1.0):
+     """
+     Generate TTS for a long `full_text` by splitting it into smaller segments
+     and concatenating the resulting audio.
+
+     Returns: (np.float32 array) final_audio, list_of_segment_phonemes
+     """
+     # 1. Segment the text; segments is a list of (seg_text, seg_phonemes)
+     segments = segment_and_tokenize(full_text)
+
+     # 2. Synthesize each segment with `generate_tts(...)`
+     audio_chunks = []
+     all_phonemes = []
+     for i, (seg_text, seg_ps) in enumerate(segments, 1):
+         print(f"[LongForm] Generating chunk {i}/{len(segments)}: {seg_text[:40]}...")
+         audio, used_phonemes = generate_tts(model, seg_text, voicepack, speed=speed)
+         if audio is not None:
+             audio_chunks.append(audio)
+             all_phonemes.append(used_phonemes)
+         else:
+             print(f"[LongForm] Skipped empty segment {i}...")
+
+     if not audio_chunks:
+         return None, []
+
+     # 3. Concatenate the audio chunks (no crossfade)
+     final_audio = np.concatenate(audio_chunks, axis=0)
+     return final_audio, all_phonemes
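+ 
+ # Illustrative usage (same assumptions as the generate_tts example above):
+ #     audio, seg_phonemes = generate_long_form_tts(model, long_text, voicepack, speed=1.0)
+ #     sf.write("long_form.wav", audio, 24000)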
+
+
+ # -------------------------------------------------------------------
+ # Main CLI
+ # -------------------------------------------------------------------
+ def main():
+     parser = argparse.ArgumentParser(description="Kokoro-StyleTTS2 CLI Example")
+     parser.add_argument(
+         "--model",
+         type=str,
+         default="pretrained_models/Kokoro/kokoro-v0_19.pth",
+         help="Path to your model checkpoint (e.g. kokoro-v0_19.pth).",
+     )
+     parser.add_argument(
+         "--config",
+         type=str,
+         default="pretrained_models/Kokoro/config.json",
+         help="Path to config.json (used by build_model).",
+     )
+     parser.add_argument(
+         "--text",
+         type=str,
+         default="Hello world! This is Kokoro, a new text-to-speech model based on StyleTTS2 from 2024!",
+         help="Text to be converted into speech.",
+     )
+     parser.add_argument(
+         "--voicepack",
+         type=str,
+         default="pretrained_models/Kokoro/voices/af.pt",
+         help="Path to a .pt file with the reference embedding(s).",
+     )
+     parser.add_argument(
+         "--output", type=str, default="output.wav", help="Output WAV filename."
+     )
+     parser.add_argument(
+         "--speed",
+         type=float,
+         default=1.2,
+         help="Speech speed factor, e.g. 0.8 is slower and 1.2 is faster.",
+     )
+     parser.add_argument(
+         "--device",
+         type=str,
+         default="cpu",
+         choices=["cpu", "cuda"],
+         help="Device to run inference on.",
+     )
+     args = parser.parse_args()
+
+     # 1. Build the model with the local build_model function
+     #    (which loads TextEncoder, Decoder, etc. and returns a Munch).
+     if not os.path.isfile(args.config):
+         raise FileNotFoundError(f"config.json not found: {args.config}")
+
+     # build_model only needs the checkpoint path and the target device; the config
+     # path is just validated above.
+     device = (
+         args.device if (args.device == "cuda" and torch.cuda.is_available()) else "cpu"
+     )
+     print(f"Loading model from: {args.model}")
+     model = build_model(args.model, device)  # args.model must be the checkpoint path
+
+     # Because `build_model` returns a Munch (dict of submodules), we can't just call
+     # `model.eval()`; set each submodule to eval mode instead:
+     for k, subm in model.items():
+         if isinstance(subm, torch.nn.Module):
+             subm.eval()
+
+     # 2. Load voicepack
+     if not os.path.isfile(args.voicepack):
+         raise FileNotFoundError(f"Voicepack file not found: {args.voicepack}")
+     print(f"Loading voicepack from: {args.voicepack}")
+     vp = torch.load(args.voicepack, map_location=device)
+     # If the voicepack is an nn.Module, set it to eval as well
+     if isinstance(vp, torch.nn.Module):
+         vp.eval()
+
+     # 3. Generate audio
+     print(f"Generating speech for text: {args.text}")
+     audio, phonemes = generate_long_form_tts(
+         model, args.text, vp, lang="a", speed=args.speed
+     )
+     if audio is None:
+         print("No tokens were generated (maybe empty text?). Exiting.")
+         return
+
+     # 4. Write WAV (Kokoro generates 24 kHz audio)
+     print(f"Writing output to: {args.output}")
+     sf.write(args.output, audio, 24000)
+
+     print("Finished!")
+     print(f"Phonemes used: {phonemes}")
+
+
+ if __name__ == "__main__":
+     main()
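+ 
+ # Example invocation (the script filename below is illustrative; use whatever this
+ # file is saved as):
+ #     python kokoro_cli.py --text "Kokoro says hello." \
+ #         --voicepack pretrained_models/Kokoro/voices/af.pt --speed 1.0 --device cpu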
uv.lock ADDED
The diff for this file is too large to render. See raw diff