diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 32dd1b4543..61f9115117 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -800,6 +800,37 @@ "edge-tts-voice": "zh-CN-XiaoxiaoNeural", "timeout": 20, }, + "GSV TTS(本地加载)": { + "id": "gsv_tts", + "enable": False, + "type": "gsv_tts_selfhost", + "provider_type": "text_to_speech", + "api_base": "http://127.0.0.1:9880", + "gpt_weights_path": "", + "sovits_weights_path": "", + "timeout": 60, + "gsv_default_parms": { + "gsv_ref_audio_path": "", + "gsv_prompt_text": "", + "gsv_prompt_lang": "zh", + "gsv_aux_ref_audio_paths": "", + "gsv_text_lang": "zh", + "gsv_top_k": 5, + "gsv_top_p": 1.0, + "gsv_temperature": 1.0, + "gsv_text_split_method": "cut3", + "gsv_batch_size": 1, + "gsv_batch_threshold": 0.75, + "gsv_split_bucket": True, + "gsv_speed_factor": 1, + "gsv_fragment_interval": 0.3, + "gsv_streaming_mode": False, + "gsv_seed": -1, + "gsv_parallel_infer": True, + "gsv_repetition_penalty": 1.35, + "gsv_media_type": "wav", + }, + }, "GSVI TTS(API)": { "id": "gsvi_tts", "type": "gsvi_tts_api", @@ -901,6 +932,130 @@ }, }, "items": { + "gpt_weights_path": { + "description": "GPT模型文件路径", + "type": "string", + "hint": "即“.ckpt”后缀的文件,请使用绝对路径,路径两端不要带双引号,不填则默认用GPT_SoVITS内置的SoVITS模型(建议直接在GPT_SoVITS中改默认模型)", + "obvious_hint": True, + }, + "sovits_weights_path": { + "description": "SoVITS模型文件路径", + "type": "string", + "hint": "即“.pth”后缀的文件,请使用绝对路径,路径两端不要带双引号,不填则默认用GPT_SoVITS内置的SoVITS模型(建议直接在GPT_SoVITS中改默认模型)", + "obvious_hint": True, + }, + "gsv_default_parms": { + "description": "GPT_SoVITS默认参数", + "hint": "参考音频文件路径、参考音频文本必填,其他参数根据个人爱好自行填写", + "type": "object", + "items": { + "gsv_ref_audio_path": { + "description": "参考音频文件路径", + "type": "string", + "hint": "必填!请使用绝对路径!路径两端不要带双引号!", + "obvious_hint": True, + }, + "gsv_prompt_text": { + "description": "参考音频文本", + "type": "string", + "hint": "必填!请填写参考音频讲述的文本", + "obvious_hint": True, + }, + "gsv_prompt_lang": { + "description": "参考音频文本语言", + "type": "string", + "hint": "请填写参考音频讲述的文本的语言,默认为中文", + }, + "gsv_aux_ref_audio_paths": { + "description": "辅助参考音频文件路径", + "type": "string", + "hint": "辅助参考音频文件,可不填", + }, + "gsv_text_lang": { + "description": "文本语言", + "type": "string", + "hint": "默认为中文", + }, + "gsv_top_k": { + "description": "生成语音的多样性", + "type": "int", + "hint": "", + }, + "gsv_top_p": { + "description": "核采样的阈值", + "type": "float", + "hint": "", + }, + "gsv_temperature": { + "description": "生成语音的随机性", + "type": "float", + "hint": "", + }, + "gsv_text_split_method": { + "description": "切分文本的方法", + "type": "string", + "hint": "可选值: `cut0`:不切分 `cut1`:四句一切 `cut2`:50字一切 `cut3`:按中文句号切 `cut4`:按英文句号切 `cut5`:按标点符号切", + "options": [ + "cut0", + "cut1", + "cut2", + "cut3", + "cut4", + "cut5", + ], + }, + "gsv_batch_size": { + "description": "批处理大小", + "type": "int", + "hint": "", + }, + "gsv_batch_threshold": { + "description": "批处理阈值", + "type": "float", + "hint": "", + }, + "gsv_split_bucket": { + "description": "将文本分割成桶以便并行处理", + "type": "bool", + "hint": "", + }, + "gsv_speed_factor": { + "description": "语音播放速度", + "type": "float", + "hint": "1为原始语速", + }, + "gsv_fragment_interval": { + "description": "语音片段之间的间隔时间", + "type": "float", + "hint": "", + }, + "gsv_streaming_mode": { + "description": "启用流模式", + "type": "bool", + "hint": "", + }, + "gsv_seed": { + "description": "随机种子", + "type": "int", + "hint": "用于结果的可重复性", + }, + "gsv_parallel_infer": { + "description": "并行执行推理", + "type": "bool", + "hint": "", + }, + "gsv_repetition_penalty": { + "description": "重复惩罚因子", + "type": "float", + "hint": "", + }, + "gsv_media_type": { + "description": "输出媒体的类型", + "type": "string", + "hint": "建议用wav", + }, + }, + }, "embedding_dimensions": { "description": "嵌入维度", "type": "int", diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index b11a3361a9..382f469fef 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -225,6 +225,10 @@ async def load_provider(self, provider_config: dict): from .sources.edge_tts_source import ( ProviderEdgeTTS as ProviderEdgeTTS, ) + case "gsv_tts_selfhost": + from .sources.gsv_selfhosted_source import ( + ProviderGSVTTS as ProviderGSVTTS, + ) case "gsvi_tts_api": from .sources.gsvi_tts_source import ( ProviderGSVITTS as ProviderGSVITTS, diff --git a/astrbot/core/provider/sources/gsv_selfhosted_source.py b/astrbot/core/provider/sources/gsv_selfhosted_source.py new file mode 100644 index 0000000000..6c4d872a9b --- /dev/null +++ b/astrbot/core/provider/sources/gsv_selfhosted_source.py @@ -0,0 +1,148 @@ +import asyncio +import os +import uuid + +import aiohttp +from ..provider import TTSProvider +from ..entities import ProviderType +from ..register import register_provider_adapter +from astrbot import logger +from astrbot.core.utils.astrbot_path import get_astrbot_data_path + + +@register_provider_adapter( + provider_type_name="gsv_tts_selfhost", + desc="GPT-SoVITS TTS(本地加载)", + provider_type=ProviderType.TEXT_TO_SPEECH, +) +class ProviderGSVTTS(TTSProvider): + def __init__( + self, + provider_config: dict, + provider_settings: dict, + ) -> None: + super().__init__(provider_config, provider_settings) + + self.api_base = provider_config.get("api_base", "http://127.0.0.1:9880").rstrip( + "/" + ) + self.gpt_weights_path: str = provider_config.get("gpt_weights_path", "") + self.sovits_weights_path: str = provider_config.get("sovits_weights_path", "") + + # TTS 请求的默认参数,移除前缀gsv_ + self.default_params: dict = { + key.removeprefix("gsv_"): str(value).lower() + for key, value in provider_config.get("gsv_default_parms", {}).items() + } + self.timeout = provider_config.get("timeout", 60) + self._session: aiohttp.ClientSession | None = None + + async def initialize(self): + """异步初始化:在 ProviderManager 中被调用""" + self._session = aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=self.timeout) + ) + try: + await self._set_model_weights() + logger.info("[GSV TTS] 初始化完成") + except Exception as e: + logger.error(f"[GSV TTS] 初始化失败:{e}") + raise + + def get_session(self) -> aiohttp.ClientSession: + if not self._session or self._session.closed: + raise RuntimeError( + "[GSV TTS] Provider HTTP session is not ready or closed." + ) + return self._session + + async def _make_request( + self, endpoint: str, params=None, retries: int = 3 + ) -> bytes | None: + """发起请求""" + for attempt in range(retries): + logger.debug(f"[GSV TTS] 请求地址:{endpoint},参数:{params}") + try: + async with self.get_session().get(endpoint, params=params) as response: + if response.status != 200: + error_text = await response.text() + raise Exception( + f"[GSV TTS] Request to {endpoint} failed with status {response.status}: {error_text}" + ) + return await response.read() + except Exception as e: + if attempt < retries - 1: + logger.warning( + f"[GSV TTS] 请求 {endpoint} 第 {attempt + 1} 次失败:{e},重试中..." + ) + await asyncio.sleep(1) + else: + logger.error(f"[GSV TTS] 请求 {endpoint} 最终失败:{e}") + raise + + async def _set_model_weights(self): + """设置模型路径""" + try: + if self.gpt_weights_path: + await self._make_request( + f"{self.api_base}/set_gpt_weights", + {"weights_path": self.gpt_weights_path}, + ) + logger.info(f"[GSV TTS] 成功设置 GPT 模型路径:{self.gpt_weights_path}") + else: + logger.info("[GSV TTS] GPT 模型路径未配置,将使用内置 GPT 模型") + + if self.sovits_weights_path: + await self._make_request( + f"{self.api_base}/set_sovits_weights", + {"weights_path": self.sovits_weights_path}, + ) + logger.info( + f"[GSV TTS] 成功设置 SoVITS 模型路径:{self.sovits_weights_path}" + ) + else: + logger.info("[GSV TTS] SoVITS 模型路径未配置,将使用内置 SoVITS 模型") + except aiohttp.ClientError as e: + logger.error(f"[GSV TTS] 设置模型路径时发生网络错误:{e}") + except Exception as e: + logger.error(f"[GSV TTS] 设置模型路径时发生未知错误:{e}") + + async def get_audio(self, text: str) -> str: + """实现 TTS 核心方法,根据文本内容自动切换情绪""" + if not text.strip(): + raise ValueError("[GSV TTS] TTS 文本不能为空") + + endpoint = f"{self.api_base}/tts" + + params = self.build_synthesis_params(text) + + temp_dir = os.path.join(get_astrbot_data_path(), "temp") + os.makedirs(temp_dir, exist_ok=True) + path = os.path.join(temp_dir, f"gsv_tts_{uuid.uuid4().hex}.wav") + + logger.debug(f"[GSV TTS] 正在调用语音合成接口,参数:{params}") + + result = await self._make_request(endpoint, params) + if isinstance(result, bytes): + with open(path, "wb") as f: + f.write(result) + return path + else: + raise Exception(f"[GSV TTS] 合成失败,输入文本:{text},错误信息:{result}") + + def build_synthesis_params(self, text: str) -> dict: + """ + 构建语音合成所需的参数字典。 + + 当前仅包含默认参数 + 文本,未来可在此基础上动态添加如情绪、角色等语义控制字段。 + """ + params = self.default_params.copy() + params["text"] = text + # TODO: 在此处添加情绪分析,例如 params["emotion"] = detect_emotion(text) + return params + + async def terminate(self): + """终止释放资源:在 ProviderManager 中被调用""" + if self._session and not self._session.closed: + await self._session.close() + logger.info("[GSV TTS] Session 已关闭")