File size: 12,446 Bytes
0a639cb
 
 
 
165abce
 
 
 
 
 
 
 
 
 
 
 
 
 
3c8cbc9
0a639cb
 
165abce
0a639cb
 
 
 
 
165abce
3c8cbc9
0a639cb
165abce
 
 
 
 
 
 
 
 
 
 
 
 
0a639cb
 
 
 
165abce
 
 
 
0a639cb
3c8cbc9
0a639cb
3c8cbc9
 
 
 
 
 
0a639cb
 
 
 
 
 
 
165abce
 
0a639cb
3c8cbc9
165abce
 
 
 
0a639cb
165abce
 
 
 
 
3c8cbc9
 
 
 
 
 
 
 
0a639cb
3c8cbc9
 
 
 
 
 
 
 
 
 
 
 
 
0a639cb
 
 
 
 
 
3c8cbc9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a639cb
 
 
 
165abce
0a639cb
 
 
 
 
 
 
 
 
 
 
 
3c8cbc9
 
 
165abce
 
 
 
 
 
3c8cbc9
 
 
165abce
0a639cb
 
 
 
3c8cbc9
 
 
0a639cb
3c8cbc9
165abce
3c8cbc9
 
 
 
 
0a639cb
165abce
3c8cbc9
165abce
 
3c8cbc9
165abce
 
 
 
3c8cbc9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a639cb
165abce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3c8cbc9
165abce
3c8cbc9
 
 
165abce
 
 
 
 
 
 
 
 
 
 
3c8cbc9
165abce
 
 
 
 
 
 
3c8cbc9
 
 
 
 
 
 
 
 
 
0a639cb
 
165abce
3c8cbc9
 
 
 
 
 
 
 
 
 
 
 
 
 
0a639cb
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
import os
import torch
import numpy as np
import time
import matplotlib.pyplot as plt
from typing import Tuple, List
from statistics import mean, median, stdev
from lib import (
    normalize_text,
    chunk_text,
    count_tokens,
    load_module_from_file,
    download_model_files,
    list_voice_files,
    download_voice_files,
    ensure_dir,
    concatenate_audio_chunks
)
import spaces

class TTSModel:
    """GPU-accelerated TTS model manager"""
    
    def __init__(self):
        """Set up default state and load the Kokoro helper modules.

        Downloads the ``.py`` sources named below from ``self.model_repo``,
        hands each one to ``load_module_from_file`` and keeps references to
        ``kokoro.generate`` and ``models.build_model``. The model itself is
        not built here; ``self.model`` and ``self.model_path`` start as None.
        """
        self.model = None
        self.model_path = None
        self.model_repo = "hexgrad/Kokoro-82M"
        self.voices_dir = "voices"
        ensure_dir(self.voices_dir)

        # Source modules that make up the model code, in load order.
        module_names = ["istftnet", "plbert", "models", "kokoro"]
        source_names = [f"{name}.py" for name in module_names]
        source_paths = download_model_files(self.model_repo, source_names)

        # Pair every module name with its downloaded file and load it.
        for name, path in zip(module_names, source_paths):
            load_module_from_file(name, path)

        # Look the loaded modules up by name (presumably load_module_from_file
        # makes them importable; verify in lib) and keep the two entry points.
        kokoro_module = __import__("kokoro")
        self.generate = kokoro_module.generate
        models_module = __import__("models")
        self.build_model = models_module.build_model
        
    def initialize(self) -> bool:
        """Fetch the model checkpoint, its config and the voice files.

        Stores the first path returned by ``download_model_files`` in
        ``self.model_path``. Any exception is printed and reported through
        the return value instead of being propagated.

        Returns:
            True if every step completed, False if any of them raised.
        """
        try:
            print("Initializing model...")

            # The checkpoint is requested first and read back from index 0
            # (assumes results come back in request order; confirm in lib).
            wanted_files = ["kokoro-v0_19.pth", "config.json"]
            fetched = download_model_files(self.model_repo, wanted_files)
            self.model_path = fetched[0]

            # Pull the voice files into the local voices directory.
            download_voice_files(self.model_repo, "voices", self.voices_dir)

            # The listing is not used further; the call is kept so that a
            # failure while reading the directory still yields False.
            self.list_voices()

            print("Model initialization complete")
        except Exception as err:
            print(f"Error initializing model: {str(err)}")
            return False
        return True
    
    def ensure_voice_downloaded(self, voice_name: str) -> bool:
        """Ensure specific voice is downloaded"""
        try:
            voice_path = os.path.join(self.voices_dir, "voices", f"{voice_name}.pt")
            if not os.path.exists(voice_path):
                print(f"Downloading voice {voice_name}.pt...")
                download_voice_files(self.model_repo, [f"{voice_name}.pt"], self.voices_dir)
            return True
        except Exception as e:
            print(f"Error downloading voice {voice_name}: {str(e)}")
            return False

    def list_voices(self) -> List[str]:
        """List available voices"""
        voices = []
        voices_subdir = os.path.join(self.voices_dir, "voices")
        if os.path.exists(voices_subdir):
            for file in os.listdir(voices_subdir):
                if file.endswith(".pt"):
                    voice_name = file[:-3]
                    voices.append(voice_name)
        return voices
    
    # def _ensure_model_on_gpu(self) -> None:
    #     """Ensure model is on GPU and stays there"""
    #     if not hasattr(self, '_model_on_gpu') or not self._model_on_gpu:
    #         print("Moving model to GPU...")
    #         with torch.cuda.device(0):
    #             torch.cuda.set_device(0)
    #             if hasattr(self.model, 'to'):
    #                 self.model.to('cuda')
    #             else:
    #                 for name in self.model:
    #                     if isinstance(self.model[name], torch.Tensor):
    #                         self.model[name] = self.model[name].cuda()
    #             self._model_on_gpu = True
    
    def _generate_audio(self, text: str, voicepack: torch.Tensor, lang: str, speed: float) -> np.ndarray:
        """Run one generation call on CUDA device 0.

        Builds the model on first use, performs a one-time move to CUDA
        (tracked through a ``_on_gpu`` attribute set on the model), moves the
        voicepack to CUDA and calls ``self.generate``.

        Args:
            text: Text passed to ``self.generate``.
            voicepack: Voice tensor; moved to CUDA before the call.
            lang: Language code forwarded as ``lang``.
            speed: Speed value forwarded as ``speed``.

        Returns:
            The first element of the pair returned by ``self.generate``.

        Raises:
            ValueError: If ``self.build_model`` returns None.
            Exception: Any other failure is printed and re-raised.
        """
        try:
            with torch.cuda.device(0):
                torch.cuda.set_device(0)
                try:
                    # Lazily build the model the first time audio is requested.
                    if self.model is None:
                        print("Building model...")
                        cuda_device = torch.device('cuda')
                        self.model = self.build_model(self.model_path, device=cuda_device)
                        if self.model is None:
                            raise ValueError("Failed to build model")
                        print("Model built successfully")

                    # One-time transfer, remembered via the `_on_gpu` marker.
                    needs_transfer = not hasattr(self.model, '_on_gpu')
                    if needs_transfer:
                        print("Moving model to GPU...")
                        if hasattr(self.model, 'to'):
                            self.model = self.model.to('cuda')
                        else:
                            # Mapping-style model: move only the tensor entries.
                            for key in self.model:
                                entry = self.model[key]
                                if isinstance(entry, torch.Tensor):
                                    self.model[key] = entry.cuda()
                        self.model._on_gpu = True
                except Exception as build_err:
                    print(f"Error building model: {str(build_err)}")
                    # NOTE(review): the message below says "continue", yet the
                    # error is re-raised right after it.
                    print("Attempting to continue")
                    raise build_err

                # The voicepack has to live on the same device as the model.
                gpu_voicepack = voicepack.cuda()

                # Only the audio is needed; the second return value is dropped.
                generated, _ = self.generate(
                    self.model,
                    text,
                    gpu_voicepack,
                    lang=lang,
                    speed=speed,
                )
                return generated

        except Exception as gen_err:
            print(f"Error in audio generation: {str(gen_err)}")
            raise gen_err
        
    @spaces.GPU(duration=None)  # Duration will be set by the UI
    def generate_speech(self, text: str, voice_names: list[str], speed: float = 1.0, gpu_timeout: int = 60, progress_callback=None, progress_state=None, progress=None) -> Tuple[np.ndarray, float, dict]:
        """Generate speech from text, chunk by chunk, on CUDA device 0.

        Args:
            text: Input text to convert to speech.
            voice_names: Names of the voices to use. With more than one name
                the loaded voicepacks are averaged; voices that fail to load
                are skipped with a warning. A single string is accepted too.
            speed: Speech speed multiplier.
            gpu_timeout: Not used here other than being forwarded to
                ``progress_callback``.
            progress_callback: Optional callback function(chunk_num,
                total_chunks, tokens_per_sec, rtf, progress_state, start_time,
                gpu_timeout, progress), called after every chunk.
            progress_state: Optional dictionary tracking generation progress
                metrics. When it holds ``"tokens_per_sec"`` / ``"rtf"`` those
                series are reported back; otherwise the per-chunk values
                measured in this method are used.
            progress: Progress callback from Gradio, forwarded untouched.

        Returns:
            Tuple ``(audio, duration_seconds, metrics)`` where ``metrics`` is
            a dict with the keys ``chunk_times``, ``chunk_sizes``,
            ``tokens_per_sec``, ``rtf``, ``total_tokens`` and ``total_time``.

        Raises:
            ValueError: If text or voice names are missing, the model cannot
                be built, none of several requested voices could be loaded,
                or the text is empty after normalization.
            Exception: Any other failure is printed and re-raised.
        """
        # Samples per second, used to turn sample counts into seconds.
        sample_rate = 24000
        try:
            start_time = time.time()
            with torch.cuda.device(0):
                torch.cuda.set_device(0)
                if not text or not voice_names:
                    raise ValueError("Text and voice name are required")

                # A bare string would otherwise be indexed character-wise
                # below, so wrap it; other sequences are copied into a list.
                if isinstance(voice_names, str):
                    voice_names = [voice_names]
                else:
                    voice_names = list(voice_names)

                # Build model if needed
                if self.model is None:
                    print("Building model...")
                    self.model = self.build_model(self.model_path, device='cuda')
                    if self.model is None:
                        raise ValueError("Failed to build model")
                    print("Model built successfully")

                # Move model to GPU if needed (done once, tracked by `_on_gpu`)
                if not hasattr(self.model, '_on_gpu'):
                    print("Moving model to GPU...")
                    if hasattr(self.model, 'to'):
                        self.model = self.model.to('cuda')
                    else:
                        for name in self.model:
                            if isinstance(self.model[name], torch.Tensor):
                                self.model[name] = self.model[name].cuda()
                    self.model._on_gpu = True

                if len(voice_names) > 1:
                    loaded_voices = []
                    for voice in voice_names:
                        try:
                            voice_path = os.path.join(self.voices_dir, "voices", f"{voice}.pt")
                            loaded_voices.append(torch.load(voice_path, weights_only=True))
                        except Exception as e:
                            # Best effort: blend whatever voices did load.
                            print(f"Warning: Failed to load voice {voice}: {str(e)}")

                    # torch.stack rejects an empty list with an unhelpful
                    # message, so report the real problem instead.
                    if not loaded_voices:
                        raise ValueError("None of the requested voices could be loaded")

                    # Combine voices by taking mean
                    voicepack = torch.mean(torch.stack(loaded_voices), dim=0)
                    voice_name = "_".join(voice_names)
                else:
                    voice_name = voice_names[0]
                    voice_path = os.path.join(self.voices_dir, "voices", f"{voice_name}.pt")
                    voicepack = torch.load(voice_path, weights_only=True)

                # Count tokens on the raw text, then normalize it
                total_tokens = count_tokens(text)
                text = normalize_text(text)
                if not text:
                    raise ValueError("Text is empty after normalization")

                # Break text into chunks for better memory management
                chunks = chunk_text(text)
                print(f"Processing {len(chunks)} chunks...")

                # Process all chunks within same GPU context
                audio_chunks = []
                chunk_times = []
                chunk_sizes = []  # Store chunk lengths (characters)
                tokens_per_sec_history = []
                rtf_history = []

                for i, chunk in enumerate(chunks):
                    chunk_start = time.time()
                    chunk_audio = self._generate_audio(
                        text=chunk,
                        voicepack=voicepack,
                        # The first character of the voice name is used as
                        # the language code.
                        lang=voice_name[0],
                        speed=speed
                    )
                    chunk_time = time.time() - chunk_start

                    # Per-chunk metrics; a rate that cannot be computed
                    # (zero elapsed time or empty audio) is reported as 0.0
                    # instead of raising ZeroDivisionError.
                    chunk_tokens = count_tokens(chunk)
                    chunk_tokens_per_sec = chunk_tokens / chunk_time if chunk_time > 0 else 0.0
                    chunk_duration = len(chunk_audio) / sample_rate  # audio duration in seconds
                    rtf = chunk_time / chunk_duration if chunk_duration > 0 else 0.0
                    times_faster = 1 / rtf if rtf > 0 else 0.0

                    chunk_times.append(chunk_time)
                    chunk_sizes.append(len(chunk))
                    tokens_per_sec_history.append(chunk_tokens_per_sec)
                    rtf_history.append(rtf)
                    print(f"Chunk {i+1}/{len(chunks)} processed in {chunk_time:.2f}s")
                    print(f"Current tokens/sec: {chunk_tokens_per_sec:.2f}")
                    print(f"Real-time factor: {rtf:.2f}x")
                    print(f"{times_faster:.1f}x faster than real-time")

                    audio_chunks.append(chunk_audio)

                    # Call progress callback if provided
                    if progress_callback:
                        progress_callback(
                            i + 1,  # chunk_num
                            len(chunks),  # total_chunks
                            chunk_tokens_per_sec,  # per-chunk rate, not cumulative
                            rtf,
                            progress_state,
                            start_time,
                            gpu_timeout,  # timeout value from UI
                            progress
                        )

            # Concatenate audio chunks
            audio = concatenate_audio_chunks(audio_chunks)

            # Prefer the series kept in progress_state (presumably filled by
            # the callback; verify against caller). Fall back to the locally
            # measured values so a call without progress_state still works.
            state = progress_state or {}
            tokens_series = state.get("tokens_per_sec", tokens_per_sec_history)
            rtf_series = state.get("rtf", rtf_history)

            # Return audio and metrics
            return (
                audio,  # Audio array
                len(audio) / sample_rate,  # Duration in seconds
                {
                    "chunk_times": chunk_times,
                    "chunk_sizes": chunk_sizes,
                    "tokens_per_sec": [float(x) for x in tokens_series],
                    "rtf": [float(x) for x in rtf_series],
                    "total_tokens": total_tokens,
                    "total_time": time.time() - start_time
                }
            )
        except Exception as e:
            print(f"Error generating speech: {str(e)}")
            raise