diff --git a/LLM/language_model.py b/LLM/language_model.py
index 3fb3329..7f8dc47 100644
--- a/LLM/language_model.py
+++ b/LLM/language_model.py
@@ -145,3 +145,4 @@ def process(self, prompt):
 
         # don't forget last sentence
         yield (printable_text, language_code)
+        yield b"DONE"
diff --git a/TTS/melo_handler.py b/TTS/melo_handler.py
index 6dd50f1..2b7be4c 100644
--- a/TTS/melo_handler.py
+++ b/TTS/melo_handler.py
@@ -56,6 +56,11 @@ def warmup(self):
         _ = self.model.tts_to_file("text", self.speaker_id, quiet=True)
 
     def process(self, llm_sentence):
+        if llm_sentence == b"DONE":
+            self.should_listen.set()
+            yield b"DONE"
+            return
+
         language_code = None
 
         if isinstance(llm_sentence, tuple):
@@ -94,10 +99,11 @@ def process(self, llm_sentence):
             )
         except (AssertionError, RuntimeError) as e:
             logger.error(f"Error in MeloTTSHandler: {e}")
-            audio_chunk = np.array([])
-        if len(audio_chunk) == 0:
-            self.should_listen.set()
-            return
+            audio_chunk = np.zeros([self.blocksize])
+        except Exception as e:
+            logger.error(f"Unknown error in MeloTTSHandler: {e}")
+            audio_chunk = np.zeros([self.blocksize])
+
         audio_chunk = librosa.resample(audio_chunk, orig_sr=44100, target_sr=16000)
         audio_chunk = (audio_chunk * 32768).astype(np.int16)
         for i in range(0, len(audio_chunk), self.blocksize):
diff --git a/TTS/parler_handler.py b/TTS/parler_handler.py
index ac539c7..7b1351c 100644
--- a/TTS/parler_handler.py
+++ b/TTS/parler_handler.py
@@ -189,3 +189,4 @@ def process(self, llm_sentence):
         )
 
         self.should_listen.set()
+        yield b"END"
\ No newline at end of file
diff --git a/arguments_classes/module_arguments.py b/arguments_classes/module_arguments.py
index d2f00e1..e0375fd 100644
--- a/arguments_classes/module_arguments.py
+++ b/arguments_classes/module_arguments.py
@@ -11,7 +11,7 @@ class ModuleArguments:
     mode: Optional[str] = field(
         default="socket",
         metadata={
-            "help": "The mode to run the pipeline in. Either 'local' or 'socket'. Default is 'socket'."
+            "help": "The mode to run the pipeline in. Either 'local', 'socket', or 'none'. Default is 'socket'."
        },
     )
     local_mac_optimal_settings: bool = field(
diff --git a/audio_streaming_client.py b/audio_streaming_client.py
new file mode 100644
index 0000000..f9e1604
--- /dev/null
+++ b/audio_streaming_client.py
@@ -0,0 +1,185 @@
+import threading
+from queue import Queue
+import sounddevice as sd
+import numpy as np
+import time
+from dataclasses import dataclass, field
+import websocket
+import ssl
+
+
+@dataclass
+class AudioStreamingClientArguments:
+    sample_rate: int = field(
+        default=16000, metadata={"help": "Audio sample rate in Hz. Default is 16000."}
+    )
+    chunk_size: int = field(
+        default=512,
+        metadata={"help": "The size of audio chunks in samples. Default is 512."},
+    )
+    api_url: str = field(
+        default="https://yxfmjcvuzgi123sw.us-east-1.aws.endpoints.huggingface.cloud",
+        metadata={"help": "The URL of the API endpoint."},
+    )
+    auth_token: str = field(
+        default="your_auth_token",
+        metadata={"help": "Authentication token for the API."},
+    )
+
+
+class AudioStreamingClient:
+    def __init__(self, args: AudioStreamingClientArguments):
+        self.args = args
+        self.stop_event = threading.Event()
+        self.send_queue = Queue()
+        self.recv_queue = Queue()
+        self.session_id = None
+        self.headers = {
+            "Accept": "application/json",
+            "Authorization": f"Bearer {self.args.auth_token}",
+            "Content-Type": "application/json",
+        }
+        self.session_state = "idle"  # Possible states: idle, processing, listen
+        self.ws_ready = threading.Event()
+
+    def start(self):
+        print("Starting audio streaming...")
+
+        ws_url = self.args.api_url.replace("http", "ws") + "/ws"
+
+        self.ws = websocket.WebSocketApp(
+            ws_url,
+            header=[f"{key}: {value}" for key, value in self.headers.items()],
+            on_open=self.on_open,
+            on_message=self.on_message,
+            on_error=self.on_error,
+            on_close=self.on_close,
+        )
+
+        self.ws_thread = threading.Thread(
+            target=self.ws.run_forever, kwargs={"sslopt": {"cert_reqs": ssl.CERT_NONE}}
+        )
+        self.ws_thread.start()
+
+        # Wait for the WebSocket to be ready
+        self.ws_ready.wait()
+        self.start_audio_streaming()
+
+    def start_audio_streaming(self):
+        self.send_thread = threading.Thread(target=self.send_audio)
+        self.play_thread = threading.Thread(target=self.play_audio)
+
+        with sd.InputStream(
+            samplerate=self.args.sample_rate,
+            channels=1,
+            dtype="int16",
+            callback=self.audio_input_callback,
+            blocksize=self.args.chunk_size,
+        ):
+            self.send_thread.start()
+            self.play_thread.start()
+            input("Press Enter to stop streaming... \n")
+            self.on_shutdown()
+
+    def on_open(self, ws):
+        print("WebSocket connection opened.")
+        self.ws_ready.set()  # Signal that the WebSocket is ready
+
+    def on_message(self, ws, message):
+        # message is bytes
+        if message == b"DONE":
+            print("listen")
+            self.session_state = "listen"
+        else:
+            if self.session_state != "processing":
+                print("processing")
+                self.session_state = "processing"
+            audio_np = np.frombuffer(message, dtype=np.int16)
+            self.recv_queue.put(audio_np)
+
+    def on_error(self, ws, error):
+        print(f"WebSocket error: {error}")
+
+    def on_close(self, ws, close_status_code, close_msg):
+        print("WebSocket connection closed.")
+
+    def on_shutdown(self):
+        self.stop_event.set()
+        self.send_thread.join()
+        self.play_thread.join()
+        self.ws.close()
+        self.ws_thread.join()
+        print("Service shutdown.")
+
+    def send_audio(self):
+        while not self.stop_event.is_set():
+            if not self.send_queue.empty():
+                chunk = self.send_queue.get()
+                if self.session_state != "processing":
+                    self.ws.send(chunk.tobytes(), opcode=websocket.ABNF.OPCODE_BINARY)
+                else:
+                    self.ws.send([], opcode=websocket.ABNF.OPCODE_BINARY)  # handshake
+            time.sleep(0.01)
+
+    def audio_input_callback(self, indata, frames, time, status):
+        self.send_queue.put(indata.copy())
+
+    def audio_out_callback(self, outdata, frames, time, status):
+        if not self.recv_queue.empty():
+            chunk = self.recv_queue.get()
+
+            # Ensure chunk is int16 and clip to valid range
+            chunk_int16 = np.clip(chunk, -32768, 32767).astype(np.int16)
+
+            if len(chunk_int16) < len(outdata):
+                outdata[: len(chunk_int16), 0] = chunk_int16
+                outdata[len(chunk_int16) :] = 0
+            else:
+                outdata[:, 0] = chunk_int16[: len(outdata)]
+        else:
+            outdata[:] = 0
+
+    def play_audio(self):
+        with sd.OutputStream(
+            samplerate=self.args.sample_rate,
+            channels=1,
+            dtype="int16",
+            callback=self.audio_out_callback,
+            blocksize=self.args.chunk_size,
+        ):
+            while not self.stop_event.is_set():
+                time.sleep(0.1)
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser(description="Audio Streaming Client")
+    parser.add_argument(
+        "--sample_rate",
+        type=int,
+        default=16000,
+        help="Audio sample rate in Hz. Default is 16000.",
+    )
+    parser.add_argument(
+        "--chunk_size",
+        type=int,
+        default=1024,
+        help="The size of audio chunks in samples. Default is 1024.",
+    )
+    parser.add_argument(
+        "--api_url", type=str, required=True, help="The URL of the API endpoint."
+    )
+    parser.add_argument(
+        "--auth_token",
+        type=str,
+        required=True,
+        help="Authentication token for the API.",
+    )
+
+    args = parser.parse_args()
+    client_args = AudioStreamingClientArguments(**vars(args))
+    client = AudioStreamingClient(client_args)
+    client.start()
diff --git a/listen_and_play.py b/listen_and_play.py
index 701fe04..dea7c6d 100644
--- a/listen_and_play.py
+++ b/listen_and_play.py
@@ -125,4 +125,4 @@ def receive_full_chunk(conn, chunk_size):
 if __name__ == "__main__":
     parser = HfArgumentParser((ListenAndPlayArguments,))
     (listen_and_play_kwargs,) = parser.parse_args_into_dataclasses()
-    listen_and_play(**vars(listen_and_play_kwargs))
+    listen_and_play(**vars(listen_and_play_kwargs))
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 19bb4d7..084bf46 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,4 +8,5 @@ funasr>=1.1.6
 faster-whisper>=1.0.3
 modelscope>=1.17.1
 deepfilternet>=0.5.6
-openai>=1.40.1
\ No newline at end of file
+openai>=1.40.1
+websocket-client>=1.8.0
diff --git a/requirements_mac.txt b/requirements_mac.txt
index f7c0a9c..7887147 100644
--- a/requirements_mac.txt
+++ b/requirements_mac.txt
@@ -10,4 +10,5 @@ funasr>=1.1.6
 faster-whisper>=1.0.3
 modelscope>=1.17.1
 deepfilternet>=0.5.6
-openai>=1.40.1
\ No newline at end of file
+openai>=1.40.1
+websocket-client>=1.8.0
diff --git a/s2s_handler.py b/s2s_handler.py
new file mode 100644
index 0000000..2adb21d
--- /dev/null
+++ b/s2s_handler.py
@@ -0,0 +1,105 @@
+from typing import Dict, Any, List, Generator, Optional
+import torch
+import os
+import logging
+from s2s_pipeline import main, prepare_all_args, get_default_arguments, setup_logger, initialize_queues_and_events, build_pipeline
+import numpy as np
+from queue import Queue, Empty
+import threading
+import base64
+import uuid
+
+class EndpointHandler:
+    def __init__(self, path=""):
+        device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        lm_model_name = os.getenv('LM_MODEL_NAME', 'meta-llama/Meta-Llama-3.1-8B-Instruct')
+        chat_size = int(os.getenv('CHAT_SIZE', 10))
+
+        (
+            self.module_kwargs,
+            self.socket_receiver_kwargs,
+            self.socket_sender_kwargs,
+            self.vad_handler_kwargs,
+            self.whisper_stt_handler_kwargs,
+            self.paraformer_stt_handler_kwargs,
+            self.faster_whisper_stt_handler_kwargs,
+            self.language_model_handler_kwargs,
+            self.open_api_language_model_handler_kwargs,
+            self.mlx_language_model_handler_kwargs,
+            self.parler_tts_handler_kwargs,
+            self.melo_tts_handler_kwargs,
+            self.chat_tts_handler_kwargs,
+            self.facebook_mm_stts_handler_kwargs,
+        ) = get_default_arguments(mode='none', log_level='DEBUG', lm_model_name=lm_model_name,
+                                  tts="melo", device=device, chat_size=chat_size)
+        setup_logger(self.module_kwargs.log_level)
+
+        prepare_all_args(
+            self.module_kwargs,
+            self.whisper_stt_handler_kwargs,
+            self.paraformer_stt_handler_kwargs,
+            self.faster_whisper_stt_handler_kwargs,
+            self.language_model_handler_kwargs,
+            self.open_api_language_model_handler_kwargs,
+            self.mlx_language_model_handler_kwargs,
+            self.parler_tts_handler_kwargs,
+            self.melo_tts_handler_kwargs,
+            self.chat_tts_handler_kwargs,
+            self.facebook_mm_stts_handler_kwargs,
+        )
+
+        self.queues_and_events = initialize_queues_and_events()
+
+        self.pipeline_manager = build_pipeline(
+            self.module_kwargs,
+            self.socket_receiver_kwargs,
+            self.socket_sender_kwargs,
+            self.vad_handler_kwargs,
+            self.whisper_stt_handler_kwargs,
+            self.paraformer_stt_handler_kwargs,
+            self.faster_whisper_stt_handler_kwargs,
+            self.language_model_handler_kwargs,
+            self.open_api_language_model_handler_kwargs,
+            self.mlx_language_model_handler_kwargs,
+            self.parler_tts_handler_kwargs,
+            self.melo_tts_handler_kwargs,
+            self.chat_tts_handler_kwargs,
+            self.facebook_mm_stts_handler_kwargs,
+            self.queues_and_events,
+        )
+
+        self.vad_chunk_size = 512  # Chunk size (in samples) required by the VAD model
+        self.sample_rate = 16000  # Expected sample rate in Hz
+
+    def process_streaming_data(self, data: bytes) -> Optional[bytes]:
+        audio_array = np.frombuffer(data, dtype=np.int16)
+
+        # Process the audio data in chunks
+        chunks = [audio_array[i:i + self.vad_chunk_size] for i in range(0, len(audio_array), self.vad_chunk_size)]
+
+        for chunk in chunks:
+            if len(chunk) == self.vad_chunk_size:
+                self.queues_and_events['recv_audio_chunks_queue'].put(chunk.tobytes())
+            elif len(chunk) < self.vad_chunk_size:
+                # Pad the last chunk if it's smaller than the required size
+                padded_chunk = np.pad(chunk, (0, self.vad_chunk_size - len(chunk)), 'constant')
+                self.queues_and_events['recv_audio_chunks_queue'].put(padded_chunk.tobytes())
+
+        # Collect the output, if any
+        try:
+            output = self.queues_and_events['send_audio_chunks_queue'].get_nowait()  # improvement idea: group all available output chunks
+            if isinstance(output, np.ndarray):
+                return output.tobytes()
+            else:
+                return output
+        except Empty:
+            return None
+
+    def cleanup(self):
+        # Stop the pipeline
+        self.pipeline_manager.stop()
+
+        # Unblock any consumer waiting on the output queue
+        self.queues_and_events['send_audio_chunks_queue'].put(b"END")
\ No newline at end of file
diff --git a/s2s_pipeline.py b/s2s_pipeline.py
index 3aafc8a..35c391e 100644
--- a/s2s_pipeline.py
+++ b/s2s_pipeline.py
@@ -70,6 +70,30 @@ def rename_args(args, prefix):
     args.__dict__["gen_kwargs"] = gen_kwargs
 
 
+def get_default_arguments(**kwargs):
+    default_args = [
+        ModuleArguments(),
+        SocketReceiverArguments(),
+        SocketSenderArguments(),
+        VADHandlerArguments(),
+        WhisperSTTHandlerArguments(),
+        ParaformerSTTHandlerArguments(),
+        FasterWhisperSTTHandlerArguments(),
+        LanguageModelHandlerArguments(),
+        OpenApiLanguageModelHandlerArguments(),
+        MLXLanguageModelHandlerArguments(),
+        ParlerTTSHandlerArguments(),
+        MeloTTSHandlerArguments(),
+        ChatTTSHandlerArguments(),
+        FacebookMMSTTSHandlerArguments(),
+    ]
+    # Update arguments with provided kwargs
+    for arg_obj in default_args:
+        for key, value in kwargs.items():
+            if hasattr(arg_obj, key):
+                setattr(arg_obj, key, value)
+
+    return tuple(default_args)
+
+
 def parse_arguments():
     parser = HfArgumentParser(
@@ -248,7 +272,7 @@ def build_pipeline(
         )
         comms_handlers = [local_audio_streamer]
         should_listen.set()
-    else:
+    elif module_kwargs.mode == "socket":
         from connections.socket_receiver import SocketReceiver
         from connections.socket_sender import SocketSender
 
@@ -268,6 +292,9 @@ def build_pipeline(
                 port=socket_sender_kwargs.send_port,
             ),
         ]
+    else:
+        comms_handlers = []
+        should_listen.set()
 
     vad = VADHandler(
         stop_event,
@@ -319,7 +346,7 @@ def get_stt_handler(module_kwargs, stop_event, spoken_prompt_queue, text_prompt_
             setup_kwargs=vars(faster_whisper_stt_handler_kwargs),
         )
     else:
-        raise ValueError("The STT should be either whisper, whisper-mlx, or paraformer.")
+        raise ValueError("The STT should be either whisper, whisper-mlx, faster-whisper, or paraformer.")
 
 
 def get_llm_handler(
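
A minimal usage sketch for the new client script (assumes the endpoint behind --api_url serves the /ws route that this patch's server side expects; the URL, token, and file name below are placeholders, not values from the diff):

    # smoke_test_client.py -- hypothetical helper, not part of the patch
    from audio_streaming_client import AudioStreamingClient, AudioStreamingClientArguments

    client_args = AudioStreamingClientArguments(
        sample_rate=16000,  # matches EndpointHandler.sample_rate in s2s_handler.py
        chunk_size=512,     # matches EndpointHandler.vad_chunk_size in s2s_handler.py
        api_url="https://<your-endpoint>.endpoints.huggingface.cloud",  # placeholder
        auth_token="<hf-token>",  # placeholder
    )
    AudioStreamingClient(client_args).start()  # streams the mic until Enter is pressed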