diff --git a/astrbot/core/message/components.py b/astrbot/core/message/components.py index 2f19434c9d..0eedd76d4c 100644 --- a/astrbot/core/message/components.py +++ b/astrbot/core/message/components.py @@ -39,6 +39,21 @@ from astrbot.core.utils.io import download_file, download_image_by_url, file_to_base64 +def _write_base64_payload_to_temp_file( + payload: str, + *, + prefix: str, + suffix: str, +) -> str: + file_bytes = base64.b64decode(payload) + file_path = os.path.join( + get_astrbot_temp_path(), f"{prefix}_{uuid.uuid4().hex}{suffix}" + ) + with open(file_path, "wb") as f: + f.write(file_bytes) + return os.path.abspath(file_path) + + class ComponentType(str, Enum): # Basic Segment Types Plain = "Plain" # plain text message @@ -139,6 +154,12 @@ def fromURL(url: str, **_): def fromBase64(bs64_data: str, **_): return Record(file=f"base64://{bs64_data}", **_) + def _get_source(self) -> str: + for candidate in (self.url, self.file, self.path): + if isinstance(candidate, str) and candidate: + return candidate + return "" + async def convert_to_file_path(self) -> str: """将这个语音统一转换为本地文件路径。这个方法避免了手动判断语音数据类型,直接返回语音数据的本地路径(如果是网络 URL, 则会自动进行下载)。 @@ -146,25 +167,24 @@ async def convert_to_file_path(self) -> str: str: 语音的本地路径,以绝对路径表示。 """ - if not self.file: + source = self._get_source() + if not source: raise Exception(f"not a valid file: {self.file}") - if self.file.startswith("file:///"): - return self.file[8:] - if self.file.startswith("http"): - file_path = await download_image_by_url(self.file) + if source.startswith("file:///"): + return source[8:] + if source.startswith("http"): + file_path = await download_image_by_url(source) return os.path.abspath(file_path) - if self.file.startswith("base64://"): - bs64_data = self.file.removeprefix("base64://") - image_bytes = base64.b64decode(bs64_data) - file_path = os.path.join( - get_astrbot_temp_path(), f"recordseg_{uuid.uuid4()}.jpg" + if source.startswith("base64://"): + bs64_data = source.removeprefix("base64://") + return _write_base64_payload_to_temp_file( + bs64_data, + prefix="recordseg", + suffix=".amr", ) - with open(file_path, "wb") as f: - f.write(image_bytes) - return os.path.abspath(file_path) - if os.path.exists(self.file): - return os.path.abspath(self.file) - raise Exception(f"not a valid file: {self.file}") + if os.path.exists(source): + return os.path.abspath(source) + raise Exception(f"not a valid file: {source}") async def convert_to_base64(self) -> str: """将语音统一转换为 base64 编码。这个方法避免了手动判断语音数据类型,直接返回语音数据的 base64 编码。 @@ -174,19 +194,20 @@ async def convert_to_base64(self) -> str: """ # convert to base64 - if not self.file: + source = self._get_source() + if not source: raise Exception(f"not a valid file: {self.file}") - if self.file.startswith("file:///"): - bs64_data = file_to_base64(self.file[8:]) - elif self.file.startswith("http"): - file_path = await download_image_by_url(self.file) + if source.startswith("file:///"): + bs64_data = file_to_base64(source[8:]) + elif source.startswith("http"): + file_path = await download_image_by_url(source) bs64_data = file_to_base64(file_path) - elif self.file.startswith("base64://"): - bs64_data = self.file - elif os.path.exists(self.file): - bs64_data = file_to_base64(self.file) + elif source.startswith("base64://"): + bs64_data = source + elif os.path.exists(source): + bs64_data = file_to_base64(source) else: - raise Exception(f"not a valid file: {self.file}") + raise Exception(f"not a valid file: {source}") bs64_data = bs64_data.removeprefix("base64://") return bs64_data @@ -217,6 +238,7 @@ async def register_to_file_service(self) -> str: class Video(BaseMessageComponent): type: ComponentType = ComponentType.Video file: str + url: str | None = "" cover: str | None = "" # 额外 path: str | None = "" @@ -234,6 +256,12 @@ def fromURL(url: str, **_): return Video(file=url, **_) raise Exception("not a valid url") + def _get_source(self) -> str: + for candidate in (self.url, self.file, self.path): + if isinstance(candidate, str) and candidate: + return candidate + return "" + async def convert_to_file_path(self) -> str: """将这个视频统一转换为本地文件路径。这个方法避免了手动判断视频数据类型,直接返回视频数据的本地路径(如果是网络 URL,则会自动进行下载)。 @@ -241,7 +269,7 @@ async def convert_to_file_path(self) -> str: str: 视频的本地路径,以绝对路径表示。 """ - url = self.file + url = self._get_source() if url and url.startswith("file:///"): return url[8:] if url and url.startswith("http"): @@ -281,7 +309,7 @@ async def register_to_file_service(self) -> str: async def to_dict(self): """需要和 toDict 区分开,toDict 是同步方法""" - url_or_path = self.file + url_or_path = self._get_source() if url_or_path.startswith("http"): payload_file = url_or_path elif callback_host := astrbot_config.get("callback_api_base"): @@ -440,13 +468,11 @@ async def convert_to_file_path(self) -> str: return os.path.abspath(image_file_path) if url.startswith("base64://"): bs64_data = url.removeprefix("base64://") - image_bytes = base64.b64decode(bs64_data) - image_file_path = os.path.join( - get_astrbot_temp_path(), f"imgseg_{uuid.uuid4()}.jpg" + return _write_base64_payload_to_temp_file( + bs64_data, + prefix="imgseg", + suffix=".jpg", ) - with open(image_file_path, "wb") as f: - f.write(image_bytes) - return os.path.abspath(image_file_path) if os.path.exists(url): return os.path.abspath(url) raise Exception(f"not a valid file: {url}") diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py index 7110199afb..3eeb5fb766 100644 --- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py @@ -2,6 +2,7 @@ import inspect import itertools import logging +import os import time import uuid from collections.abc import Awaitable @@ -27,6 +28,141 @@ from .aiocqhttp_message_event import AiocqhttpMessageEvent +def _looks_like_resolved_media_ref(value: str) -> bool: + return value.startswith(("http://", "https://", "file://")) or os.path.isabs(value) + + +def _pick_usable_media_source(seg_data: dict[str, Any]) -> str: + for key in ("url", "file", "path", "file_path"): + value = seg_data.get(key) + if isinstance(value, str) and value.strip(): + candidate = value.strip() + if _looks_like_resolved_media_ref(candidate) or os.path.exists(candidate): + return candidate + return "" + + +def _unwrap_onebot_action_data(payload: Any) -> dict[str, Any]: + if not isinstance(payload, dict): + return {} + data = payload.get("data") + if isinstance(data, dict): + return data + return payload + + +async def _resolve_onebot_file_reference( + bot: CQHttp, + *, + message_type: str, + group_id: str | int | None, + file_ref: str, + seg_type: str | None = None, +) -> str | None: + normalized = str(file_ref or "").strip() + if not normalized: + return None + if _looks_like_resolved_media_ref(normalized) or os.path.exists(normalized): + return normalized + + candidates = [normalized] + base_name = os.path.basename(normalized) + if base_name and base_name not in candidates: + candidates.append(base_name) + stem, ext = os.path.splitext(base_name) + if stem and ext and stem not in candidates: + candidates.append(stem) + + actions: list[tuple[str, dict[str, Any]]] = [] + if seg_type == "record": + for candidate in candidates: + actions.append(("get_record", {"file": candidate})) + for candidate in candidates: + actions.extend( + [ + ("get_file", {"file_id": candidate}), + ("get_file", {"file": candidate}), + ("get_image", {"file": candidate}), + ("get_image", {"file_id": candidate}), + ("get_private_file_url", {"file_id": candidate}), + ] + ) + if str(message_type).lower() == "group" and group_id not in (None, ""): + actions.append( + ( + "get_group_file_url", + {"group_id": group_id, "file_id": candidate}, + ) + ) + + for action, params in actions: + try: + ret = await bot.call_action(action=action, **params) + except Exception: + continue + data = _unwrap_onebot_action_data(ret) + for key in ("url", "file"): + value = data.get(key) + if isinstance(value, str) and value.strip(): + resolved = value.strip() + if _looks_like_resolved_media_ref(resolved) or os.path.exists(resolved): + return resolved + return None + + +async def _normalize_onebot_media_data( + bot: CQHttp, + seg_type: str, + seg_data: dict[str, Any], + *, + message_type: str, + group_id: str | int | None, +) -> dict[str, Any]: + normalized = dict(seg_data) + + if seg_type not in {"video", "record", "file"}: + return normalized + + direct_url = normalized.get("url") + if isinstance(direct_url, str) and direct_url.strip(): + usable_source = _pick_usable_media_source(normalized) or direct_url.strip() + normalized["file"] = usable_source + if seg_type == "file": + normalized.setdefault("url", usable_source) + return normalized + + file_ref = normalized.get("file") + file_id = normalized.get("file_id") + candidate = file_ref or file_id + if not isinstance(candidate, str) or not candidate.strip(): + usable_source = _pick_usable_media_source(normalized) + if usable_source: + normalized["file"] = usable_source + if seg_type == "file": + normalized.setdefault("url", usable_source) + return normalized + return normalized + candidate = candidate.strip() + if _looks_like_resolved_media_ref(candidate) or os.path.exists(candidate): + normalized["file"] = candidate + return normalized + + resolved = await _resolve_onebot_file_reference( + bot, + message_type=message_type, + group_id=group_id, + file_ref=candidate, + seg_type=seg_type, + ) + if not resolved: + return normalized + + normalized["file"] = resolved + if seg_type == "file": + normalized.setdefault("url", resolved) + return normalized + + @register_platform_adapter( "aiocqhttp", "适用于 OneBot V11 标准的消息平台适配器,支持反向 WebSockets。", @@ -265,6 +401,27 @@ async def _convert_handle_message_event( abm.message.append(File(name=file_name, url=m["data"]["url"])) else: try: + normalized_data = await _normalize_onebot_media_data( + self.bot, + "file", + m["data"], + message_type=event["message_type"], + group_id=event.get("group_id"), + ) + if normalized_data.get("url"): + file_name = ( + normalized_data.get("file_name", "") + or normalized_data.get("name", "") + or normalized_data.get("file", "") + or "file" + ) + abm.message.append( + File( + name=file_name, + url=cast(str, normalized_data["url"]), + ) + ) + continue # Napcat ret = None if abm.type == MessageType.GROUP_MESSAGE: @@ -402,7 +559,14 @@ async def _convert_handle_message_event( f"不支持的消息段类型,已忽略: {t}, data={m['data']}" ) continue - a = ComponentTypes[t](**m["data"]) + normalized_data = await _normalize_onebot_media_data( + self.bot, + t, + m["data"], + message_type=event["message_type"], + group_id=event.get("group_id"), + ) + a = ComponentTypes[t](**normalized_data) abm.message.append(a) except Exception as e: logger.exception( diff --git a/tests/unit/test_record_component_sources.py b/tests/unit/test_record_component_sources.py new file mode 100644 index 0000000000..72185165ab --- /dev/null +++ b/tests/unit/test_record_component_sources.py @@ -0,0 +1,59 @@ +import os +from pathlib import Path +from unittest.mock import AsyncMock, patch + +import pytest + +from astrbot.core.message.components import Record + + +@pytest.mark.asyncio +async def test_record_convert_to_file_path_prefers_url_over_invalid_file_name(): + record = Record( + file="0f47835d687410ab50cfed981e80c15c.amr", + url="https://example.com/audio.amr", + ) + + with patch( + "astrbot.core.message.components.download_image_by_url", + AsyncMock(return_value="/tmp/audio.amr"), + ): + path = await record.convert_to_file_path() + + assert os.path.isabs(path) + assert os.path.basename(path) == "audio.amr" + + +@pytest.mark.asyncio +async def test_record_convert_to_base64_prefers_url_over_invalid_file_name(): + record = Record( + file="0f47835d687410ab50cfed981e80c15c.amr", + url="https://example.com/audio.amr", + ) + + with ( + patch( + "astrbot.core.message.components.download_image_by_url", + AsyncMock(return_value="/tmp/audio.amr"), + ), + patch( + "astrbot.core.message.components.file_to_base64", + return_value="base64://ZmFrZQ==", + ), + ): + encoded = await record.convert_to_base64() + + assert encoded == "ZmFrZQ==" + + +@pytest.mark.asyncio +async def test_record_convert_to_file_path_writes_base64_payload_with_audio_extension(): + record = Record(file="base64://ZmFrZQ==") + + path = await record.convert_to_file_path() + + assert os.path.isabs(path) + assert Path(path).suffix == ".amr" + assert Path(path).read_bytes() == b"fake" + + Path(path).unlink(missing_ok=True)