diff --git a/.env.template b/.env.template index 8fc8213..93380fb 100644 --- a/.env.template +++ b/.env.template @@ -30,6 +30,14 @@ TTS_API_KEY="your_tts_key" # OpenAI TTS 模型 TTS_MODEL="tts-1" +# CAMB AI TTS 配置(当 USE_CAMBAI_TTS=true 时需配置) +USE_CAMBAI_TTS=false +CAMBAI_API_KEY="" +CAMBAI_SPEECH_MODEL="mars-flash" +CAMBAI_LANGUAGE="en-us" +# CAMB AI 声音映射(JSON格式,将角色映射到 CAMB AI voice ID) +CAMBAI_VOICE_MAP={} + # 微软 TTS 代理配置 # 当USE_OPENAI_TTS_MODEL为false时,且在国内网络环境使用时,需要配置代理 HTTPS_PROXY="http://localhost:7890" diff --git a/requirements.in b/requirements.in index e935b56..4d6c28d 100644 --- a/requirements.in +++ b/requirements.in @@ -21,4 +21,5 @@ python-multipart bcrypt==4.0.1 alembic tenacity -loguru \ No newline at end of file +loguru +camb-sdk \ No newline at end of file diff --git a/server/core/config.py b/server/core/config.py index 15f189f..3104b7a 100644 --- a/server/core/config.py +++ b/server/core/config.py @@ -41,6 +41,13 @@ class Settings(BaseSettings): TTS_BASE_URL: str TTS_API_KEY: str TTS_MODEL: str + + # CAMB AI TTS配置 + USE_CAMBAI_TTS: bool = False + CAMBAI_API_KEY: str = "" + CAMBAI_SPEECH_MODEL: str = "mars-flash" + CAMBAI_LANGUAGE: str = "en-us" + CAMBAI_VOICE_MAP: Dict[str, int] = {} # 代理配置 HTTPS_PROXY: str | None = None @@ -96,6 +103,11 @@ class ConfigManager: 'TTS_BASE_URL', 'TTS_API_KEY', 'TTS_MODEL', + 'USE_CAMBAI_TTS', + 'CAMBAI_API_KEY', + 'CAMBAI_SPEECH_MODEL', + 'CAMBAI_LANGUAGE', + 'CAMBAI_VOICE_MAP', 'ANCHOR_TYPE_MAP', 'HTTPS_PROXY', 'ALLOW_REGISTRATION', diff --git a/server/services/cambai_tts.py b/server/services/cambai_tts.py new file mode 100644 index 0000000..8496ac8 --- /dev/null +++ b/server/services/cambai_tts.py @@ -0,0 +1,43 @@ +import os +import tempfile + +from camb.client import CambAI, save_stream_to_file +from camb.types import StreamTtsOutputConfiguration +from core.config import settings +from core.logging import log + + +class CambAITTSService: + def __init__(self): + self.client = CambAI(api_key=settings.CAMBAI_API_KEY) + self.voice_mapping = settings.CAMBAI_VOICE_MAP + + def generate_speech(self, text: str, voice: str, response_format: str = "mp3") -> str: + voice_id = self._resolve_voice_id(voice) + speech_model = settings.CAMBAI_SPEECH_MODEL + + temp_output = tempfile.NamedTemporaryFile(delete=False, suffix=f".{response_format}") + try: + stream = self.client.text_to_speech.tts( + text=text, + language=settings.CAMBAI_LANGUAGE if hasattr(settings, 'CAMBAI_LANGUAGE') else "en-us", + voice_id=voice_id, + speech_model=speech_model, + output_configuration=StreamTtsOutputConfiguration(format=response_format), + ) + save_stream_to_file(stream, temp_output.name) + return temp_output.name + except Exception as e: + log.error(f"CAMB AI TTS generation failed: {str(e)}") + if os.path.exists(temp_output.name): + os.remove(temp_output.name) + raise + + def _resolve_voice_id(self, voice: str) -> int: + if voice in self.voice_mapping: + return self.voice_mapping[voice] + try: + return int(voice) + except (ValueError, TypeError): + log.warning(f"Unknown voice '{voice}', using default voice 147320") + return 147320 diff --git a/server/services/task/steps/audio.py b/server/services/task/steps/audio.py index d40d4e3..d46823f 100644 --- a/server/services/task/steps/audio.py +++ b/server/services/task/steps/audio.py @@ -11,6 +11,7 @@ from services.edgetts import EdgeTTSService from openai import OpenAI from core.config import settings +from services.cambai_tts import CambAITTSService from core.logging import log from services.task.utils.progress_tracker import ProgressTracker @@ -36,6 +37,7 @@ def __init__( base_url=settings.TTS_BASE_URL, api_key=settings.TTS_API_KEY ) + self.cambai_tts = CambAITTSService() if settings.USE_CAMBAI_TTS else None def _verify_audio_file(self, file_path: str) -> bool: """验证音频文件是否有效""" @@ -60,18 +62,24 @@ def _generate_audio_with_retry(self, item: dict, file_path: str, anchor_type: st """生成音频文件,支持重试""" for attempt in range(max_retries): try: - if settings.USE_OPENAI_TTS_MODEL: + if settings.USE_CAMBAI_TTS: + temp_audio_file = self.cambai_tts.generate_speech(item['content'], anchor_type) + if not temp_audio_file or not os.path.exists(temp_audio_file): + raise Exception("CAMB AI TTS 生成失败") + + shutil.move(temp_audio_file, file_path) + elif settings.USE_OPENAI_TTS_MODEL: audio_content = self._sync_openai_tts_request(item['content'], anchor_type) if audio_content is None: raise Exception("OpenAI TTS 返回空内容") - + with open(file_path, 'wb') as f: f.write(audio_content) else: temp_audio_file = self.edge_tts.generate_speech(item['content'], anchor_type) if not temp_audio_file or not os.path.exists(temp_audio_file): raise Exception("Edge TTS 生成失败") - + shutil.move(temp_audio_file, file_path) # 验证生成的音频文件 diff --git a/server/tests/test_cambai_tts.py b/server/tests/test_cambai_tts.py new file mode 100644 index 0000000..9c20ad9 --- /dev/null +++ b/server/tests/test_cambai_tts.py @@ -0,0 +1,168 @@ +import os +import sys +from pathlib import Path + +# Add server/ to path so 'services', 'core' etc. are importable +sys.path.insert(0, str(Path(__file__).parent.parent)) + +import pytest +import tempfile +from unittest.mock import Mock, patch, MagicMock + + +@pytest.fixture +def mock_settings(): + with patch("services.cambai_tts.settings") as mock_s: + mock_s.CAMBAI_API_KEY = "test-key" + mock_s.CAMBAI_SPEECH_MODEL = "mars-flash" + mock_s.CAMBAI_LANGUAGE = "en-us" + mock_s.CAMBAI_VOICE_MAP = {"host_en": 147320, "guest_en": 147321} + yield mock_s + + +@pytest.fixture +def cambai_service(mock_settings): + with patch("services.cambai_tts.CambAI") as mock_camb_cls: + mock_client = Mock() + mock_camb_cls.return_value = mock_client + from services.cambai_tts import CambAITTSService + service = CambAITTSService() + service._mock_client = mock_client + yield service + + +class TestVoiceMapping: + def test_resolve_voice_id_from_map(self, cambai_service): + assert cambai_service._resolve_voice_id("host_en") == 147320 + assert cambai_service._resolve_voice_id("guest_en") == 147321 + + def test_resolve_voice_id_fallback_to_int(self, cambai_service): + assert cambai_service._resolve_voice_id("99999") == 99999 + + def test_resolve_voice_id_fallback_to_default(self, cambai_service): + assert cambai_service._resolve_voice_id("unknown_voice") == 147320 + + +class TestGenerateSpeech: + def test_success_returns_temp_file_path(self, cambai_service): + mock_stream = iter([b"fake-audio-bytes"]) + cambai_service._mock_client.text_to_speech.tts.return_value = mock_stream + + with patch("services.cambai_tts.save_stream_to_file") as mock_save: + result = cambai_service.generate_speech("Hello world", "host_en") + + assert isinstance(result, str) + assert result.endswith(".mp3") + mock_save.assert_called_once() + cambai_service._mock_client.text_to_speech.tts.assert_called_once() + + def test_uses_correct_speech_model(self, cambai_service, mock_settings): + mock_settings.CAMBAI_SPEECH_MODEL = "mars-pro" + cambai_service._mock_client.text_to_speech.tts.return_value = iter([b"audio"]) + + with patch("services.cambai_tts.save_stream_to_file"): + cambai_service.generate_speech("test", "host_en") + + call_kwargs = cambai_service._mock_client.text_to_speech.tts.call_args.kwargs + assert call_kwargs["speech_model"] == "mars-pro" + + def test_uses_correct_language(self, cambai_service, mock_settings): + mock_settings.CAMBAI_LANGUAGE = "ja-jp" + cambai_service._mock_client.text_to_speech.tts.return_value = iter([b"audio"]) + + with patch("services.cambai_tts.save_stream_to_file"): + cambai_service.generate_speech("test", "host_en") + + call_kwargs = cambai_service._mock_client.text_to_speech.tts.call_args.kwargs + assert call_kwargs["language"] == "ja-jp" + + def test_failure_cleans_up_temp_file(self, cambai_service): + cambai_service._mock_client.text_to_speech.tts.side_effect = Exception("API error") + + with pytest.raises(Exception, match="API error"): + cambai_service.generate_speech("Hello", "host_en") + + def test_custom_response_format(self, cambai_service): + cambai_service._mock_client.text_to_speech.tts.return_value = iter([b"audio"]) + + with patch("services.cambai_tts.save_stream_to_file"): + result = cambai_service.generate_speech("test", "host_en", response_format="wav") + + assert result.endswith(".wav") + + +class TestAudioStepIntegration: + def test_audio_step_uses_cambai_when_enabled(self): + with patch("services.task.steps.audio.settings") as mock_s, \ + patch("services.task.steps.audio.CambAITTSService") as mock_cls, \ + patch("services.task.steps.audio.EdgeTTSService"), \ + patch("services.task.steps.audio.OpenAI"), \ + patch("services.task.steps.audio.ProgressTracker"): + mock_s.USE_CAMBAI_TTS = True + mock_s.USE_OPENAI_TTS_MODEL = False + mock_s.TTS_BASE_URL = "http://test" + mock_s.TTS_API_KEY = "test" + + mock_cambai = Mock() + mock_cambai.generate_speech.return_value = "/tmp/test.mp3" + mock_cls.return_value = mock_cambai + + from services.task.steps.audio import AudioStep + step = AudioStep( + level="beginner", + lang="en", + progress_tracker=Mock(), + context_manager=Mock() + ) + assert step.cambai_tts is not None + + def test_audio_step_skips_cambai_when_disabled(self): + with patch("services.task.steps.audio.settings") as mock_s, \ + patch("services.task.steps.audio.CambAITTSService") as mock_cls, \ + patch("services.task.steps.audio.EdgeTTSService"), \ + patch("services.task.steps.audio.OpenAI"), \ + patch("services.task.steps.audio.ProgressTracker"): + mock_s.USE_CAMBAI_TTS = False + mock_s.USE_OPENAI_TTS_MODEL = False + mock_s.TTS_BASE_URL = "http://test" + mock_s.TTS_API_KEY = "test" + + from services.task.steps.audio import AudioStep + step = AudioStep( + level="beginner", + lang="en", + progress_tracker=Mock(), + context_manager=Mock() + ) + assert step.cambai_tts is None + + +@pytest.mark.integration +class TestCambAIIntegration: + @pytest.fixture(autouse=True) + def skip_without_key(self): + if not os.environ.get("CAMB_API_KEY"): + pytest.skip("CAMB_API_KEY not set") + + def test_real_api_generates_audio(self): + from camb.client import CambAI, save_stream_to_file + from camb.types import StreamTtsOutputConfiguration + + client = CambAI(api_key=os.environ["CAMB_API_KEY"]) + with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as f: + filepath = f.name + + try: + stream = client.text_to_speech.tts( + text="Integration test for lingopod.", + language="en-us", + voice_id=147320, + speech_model="mars-flash", + output_configuration=StreamTtsOutputConfiguration(format="mp3"), + ) + save_stream_to_file(stream, filepath) + assert os.path.exists(filepath) + assert os.path.getsize(filepath) > 0 + finally: + if os.path.exists(filepath): + os.remove(filepath)