Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 60 additions & 34 deletions astrbot/core/message/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,21 @@
from astrbot.core.utils.io import download_file, download_image_by_url, file_to_base64


def _write_base64_payload_to_temp_file(
payload: str,
*,
prefix: str,
suffix: str,
) -> str:
file_bytes = base64.b64decode(payload)
file_path = os.path.join(
get_astrbot_temp_path(), f"{prefix}_{uuid.uuid4().hex}{suffix}"
)
with open(file_path, "wb") as f:
f.write(file_bytes)
return os.path.abspath(file_path)


class ComponentType(str, Enum):
# Basic Segment Types
Plain = "Plain" # plain text message
Expand Down Expand Up @@ -139,32 +154,37 @@ def fromURL(url: str, **_):
def fromBase64(bs64_data: str, **_):
return Record(file=f"base64://{bs64_data}", **_)

def _get_source(self) -> str:
for candidate in (self.url, self.file, self.path):
if isinstance(candidate, str) and candidate:
return candidate
return ""

async def convert_to_file_path(self) -> str:
"""将这个语音统一转换为本地文件路径。这个方法避免了手动判断语音数据类型,直接返回语音数据的本地路径(如果是网络 URL, 则会自动进行下载)。

Returns:
str: 语音的本地路径,以绝对路径表示。

"""
if not self.file:
source = self._get_source()
if not source:
raise Exception(f"not a valid file: {self.file}")
if self.file.startswith("file:///"):
return self.file[8:]
if self.file.startswith("http"):
file_path = await download_image_by_url(self.file)
if source.startswith("file:///"):
return source[8:]
if source.startswith("http"):
file_path = await download_image_by_url(source)
return os.path.abspath(file_path)
if self.file.startswith("base64://"):
bs64_data = self.file.removeprefix("base64://")
image_bytes = base64.b64decode(bs64_data)
file_path = os.path.join(
get_astrbot_temp_path(), f"recordseg_{uuid.uuid4()}.jpg"
if source.startswith("base64://"):
bs64_data = source.removeprefix("base64://")
return _write_base64_payload_to_temp_file(
bs64_data,
prefix="recordseg",
suffix=".amr",
)
with open(file_path, "wb") as f:
f.write(image_bytes)
return os.path.abspath(file_path)
if os.path.exists(self.file):
return os.path.abspath(self.file)
raise Exception(f"not a valid file: {self.file}")
if os.path.exists(source):
return os.path.abspath(source)
raise Exception(f"not a valid file: {source}")

async def convert_to_base64(self) -> str:
"""将语音统一转换为 base64 编码。这个方法避免了手动判断语音数据类型,直接返回语音数据的 base64 编码。
Expand All @@ -174,19 +194,20 @@ async def convert_to_base64(self) -> str:

"""
# convert to base64
if not self.file:
source = self._get_source()
if not source:
raise Exception(f"not a valid file: {self.file}")
if self.file.startswith("file:///"):
bs64_data = file_to_base64(self.file[8:])
elif self.file.startswith("http"):
file_path = await download_image_by_url(self.file)
if source.startswith("file:///"):
bs64_data = file_to_base64(source[8:])
elif source.startswith("http"):
file_path = await download_image_by_url(source)
bs64_data = file_to_base64(file_path)
elif self.file.startswith("base64://"):
bs64_data = self.file
elif os.path.exists(self.file):
bs64_data = file_to_base64(self.file)
elif source.startswith("base64://"):
bs64_data = source
elif os.path.exists(source):
bs64_data = file_to_base64(source)
else:
raise Exception(f"not a valid file: {self.file}")
raise Exception(f"not a valid file: {source}")
bs64_data = bs64_data.removeprefix("base64://")
return bs64_data

Expand Down Expand Up @@ -217,6 +238,7 @@ async def register_to_file_service(self) -> str:
class Video(BaseMessageComponent):
type: ComponentType = ComponentType.Video
file: str
url: str | None = ""
cover: str | None = ""
# 额外
path: str | None = ""
Expand All @@ -234,14 +256,20 @@ def fromURL(url: str, **_):
return Video(file=url, **_)
raise Exception("not a valid url")

def _get_source(self) -> str:
for candidate in (self.url, self.file, self.path):
if isinstance(candidate, str) and candidate:
return candidate
return ""

async def convert_to_file_path(self) -> str:
"""将这个视频统一转换为本地文件路径。这个方法避免了手动判断视频数据类型,直接返回视频数据的本地路径(如果是网络 URL,则会自动进行下载)。

Returns:
str: 视频的本地路径,以绝对路径表示。

"""
url = self.file
url = self._get_source()
if url and url.startswith("file:///"):
return url[8:]
if url and url.startswith("http"):
Expand Down Expand Up @@ -281,7 +309,7 @@ async def register_to_file_service(self) -> str:

async def to_dict(self):
"""需要和 toDict 区分开,toDict 是同步方法"""
url_or_path = self.file
url_or_path = self._get_source()
if url_or_path.startswith("http"):
payload_file = url_or_path
elif callback_host := astrbot_config.get("callback_api_base"):
Expand Down Expand Up @@ -440,13 +468,11 @@ async def convert_to_file_path(self) -> str:
return os.path.abspath(image_file_path)
if url.startswith("base64://"):
bs64_data = url.removeprefix("base64://")
image_bytes = base64.b64decode(bs64_data)
image_file_path = os.path.join(
get_astrbot_temp_path(), f"imgseg_{uuid.uuid4()}.jpg"
return _write_base64_payload_to_temp_file(
bs64_data,
prefix="imgseg",
suffix=".jpg",
)
with open(image_file_path, "wb") as f:
f.write(image_bytes)
return os.path.abspath(image_file_path)
if os.path.exists(url):
return os.path.abspath(url)
raise Exception(f"not a valid file: {url}")
Expand Down
166 changes: 165 additions & 1 deletion astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import inspect
import itertools
import logging
import os
import time
import uuid
from collections.abc import Awaitable
Expand All @@ -27,6 +28,141 @@
from .aiocqhttp_message_event import AiocqhttpMessageEvent


def _looks_like_resolved_media_ref(value: str) -> bool:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (complexity): Consider refactoring the new media resolution helpers into a small MediaResolver class with shared helpers and action templates to reduce duplication and make the control flow easier to follow.

You can reduce the new complexity without changing behavior by:

  1. Encapsulating shared context and flow in a small resolver object
  2. Centralizing the “is resolved” check
  3. De‑duplicating action construction with static templates

1. Encapsulate context into a MediaResolver

Instead of threading bot, message_type, group_id, seg_type through multiple helpers and call sites, wrap them in a small resolver. This keeps the public surface small while keeping functionality identical:

class MediaResolver:
    def __init__(self, bot: CQHttp, message_type: str, group_id: str | int | None):
        self.bot = bot
        self.message_type = message_type
        self.group_id = group_id

    async def normalize(self, seg_type: str, seg_data: dict[str, Any]) -> dict[str, Any]:
        # move current _normalize_onebot_media_data body here
        # but replace direct calls with self._resolve_file_reference(...)
        ...

    async def _resolve_file_reference(self, file_ref: str, seg_type: str | None = None) -> str | None:
        # move current _resolve_onebot_file_reference body here
        # use self.bot, self.message_type, self.group_id
        ...

Call site becomes simpler and avoids repeated parameters:

resolver = MediaResolver(self.bot, event["message_type"], event.get("group_id"))

# file segment special case
normalized_data = await resolver.normalize("file", m["data"])
...

# generic handler
normalized_data = await resolver.normalize(t, m["data"])
a = ComponentTypes[t](**normalized_data)

This keeps all behavior but removes a lot of cross‑cutting arguments and makes the flow easier to follow.

2. Centralize the “resolved path or URL” check

You currently repeat:

  • _looks_like_resolved_media_ref(value) or os.path.exists(value)

across _pick_usable_media_source, _normalize_onebot_media_data, _resolve_onebot_file_reference.

Extract this into a single helper that returns either a normalized string or None:

def _resolve_local_or_url_candidate(value: str | None) -> str | None:
    if not isinstance(value, str):
        return None
    candidate = value.strip()
    if not candidate:
        return None
    if _looks_like_resolved_media_ref(candidate) or os.path.exists(candidate):
        return candidate
    return None

Then simplify callers, e.g.:

def _pick_usable_media_source(seg_data: dict[str, Any]) -> str:
    for key in ("url", "file", "path", "file_path"):
        resolved = _resolve_local_or_url_candidate(seg_data.get(key))
        if resolved:
            return resolved
    return ""

and in the resolver:

candidate = file_ref or file_id
resolved = _resolve_local_or_url_candidate(candidate)
if resolved:
    normalized["file"] = resolved
    return normalized

This reduces branching and duplication without changing semantics.

3. De‑duplicate action list construction with templates

The actions.extend([...]) pattern is verbose and repeated per candidate. You can describe the patterns once and instantiate them:

# module-level, static
_BASE_ACTION_TEMPLATES: list[tuple[str, dict[str, str]]] = [
    ("get_file", {"file_id": "candidate"}),
    ("get_file", {"file": "candidate"}),
    ("get_image", {"file": "candidate"}),
    ("get_image", {"file_id": "candidate"}),
    ("get_private_file_url", {"file_id": "candidate"}),
]

def _build_actions_for_candidate(
    candidate: str,
    message_type: str,
    group_id: str | int | None,
    seg_type: str | None,
) -> list[tuple[str, dict[str, Any]]]:
    actions: list[tuple[str, dict[str, Any]]] = []
    if seg_type == "record":
        actions.append(("get_record", {"file": candidate}))
    for action, params in _BASE_ACTION_TEMPLATES:
        concrete = {k: (candidate if v == "candidate" else v) for k, v in params.items()}
        actions.append((action, concrete))
    if str(message_type).lower() == "group" and group_id not in (None, ""):
        actions.append(("get_group_file_url", {"group_id": group_id, "file_id": candidate}))
    return actions

Then _resolve_onebot_file_reference (or MediaResolver._resolve_file_reference) becomes clearer:

actions: list[tuple[str, dict[str, Any]]] = []
for candidate in candidates:
    actions.extend(_build_actions_for_candidate(candidate, message_type, group_id, seg_type))

This removes duplicated action construction logic and makes it obvious what variations exist.


These changes keep your new media resolution behavior intact, but:

  • Narrow the public API to MediaResolver.normalize(...)
  • Centralize repeated checks
  • Make candidate/action handling more declarative and easier to scan.

return value.startswith(("http://", "https://", "file://")) or os.path.isabs(value)


def _pick_usable_media_source(seg_data: dict[str, Any]) -> str:
for key in ("url", "file", "path", "file_path"):
value = seg_data.get(key)
if isinstance(value, str) and value.strip():
candidate = value.strip()
if _looks_like_resolved_media_ref(candidate) or os.path.exists(candidate):
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
return candidate
return ""


def _unwrap_onebot_action_data(payload: Any) -> dict[str, Any]:
if not isinstance(payload, dict):
return {}
data = payload.get("data")
if isinstance(data, dict):
return data
return payload


async def _resolve_onebot_file_reference(
bot: CQHttp,
*,
message_type: str,
group_id: str | int | None,
file_ref: str,
seg_type: str | None = None,
) -> str | None:
normalized = str(file_ref or "").strip()
if not normalized:
return None
if _looks_like_resolved_media_ref(normalized) or os.path.exists(normalized):
return normalized

candidates = [normalized]
base_name = os.path.basename(normalized)
if base_name and base_name not in candidates:
candidates.append(base_name)
stem, ext = os.path.splitext(base_name)
if stem and ext and stem not in candidates:
candidates.append(stem)

actions: list[tuple[str, dict[str, Any]]] = []
if seg_type == "record":
for candidate in candidates:
actions.append(("get_record", {"file": candidate}))
for candidate in candidates:
actions.extend(
[
("get_file", {"file_id": candidate}),
("get_file", {"file": candidate}),
("get_image", {"file": candidate}),
("get_image", {"file_id": candidate}),
("get_private_file_url", {"file_id": candidate}),
]
)
if str(message_type).lower() == "group" and group_id not in (None, ""):
actions.append(
(
"get_group_file_url",
{"group_id": group_id, "file_id": candidate},
)
)

for action, params in actions:
try:
ret = await bot.call_action(action=action, **params)
except Exception:
continue
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
data = _unwrap_onebot_action_data(ret)
for key in ("url", "file"):
value = data.get(key)
if isinstance(value, str) and value.strip():
resolved = value.strip()
if _looks_like_resolved_media_ref(resolved) or os.path.exists(resolved):
return resolved
return None


async def _normalize_onebot_media_data(
bot: CQHttp,
seg_type: str,
seg_data: dict[str, Any],
*,
message_type: str,
group_id: str | int | None,
) -> dict[str, Any]:
normalized = dict(seg_data)

if seg_type not in {"video", "record", "file"}:
return normalized

direct_url = normalized.get("url")
if isinstance(direct_url, str) and direct_url.strip():
usable_source = _pick_usable_media_source(normalized) or direct_url.strip()
normalized["file"] = usable_source
if seg_type == "file":
normalized.setdefault("url", usable_source)
return normalized

file_ref = normalized.get("file")
file_id = normalized.get("file_id")
candidate = file_ref or file_id
if not isinstance(candidate, str) or not candidate.strip():
usable_source = _pick_usable_media_source(normalized)
if usable_source:
normalized["file"] = usable_source
if seg_type == "file":
normalized.setdefault("url", usable_source)
return normalized
return normalized
candidate = candidate.strip()
if _looks_like_resolved_media_ref(candidate) or os.path.exists(candidate):
normalized["file"] = candidate
return normalized

resolved = await _resolve_onebot_file_reference(
bot,
message_type=message_type,
group_id=group_id,
file_ref=candidate,
seg_type=seg_type,
)
if not resolved:
return normalized

normalized["file"] = resolved
if seg_type == "file":
normalized.setdefault("url", resolved)
return normalized


@register_platform_adapter(
"aiocqhttp",
"适用于 OneBot V11 标准的消息平台适配器,支持反向 WebSockets。",
Expand Down Expand Up @@ -265,6 +401,27 @@ async def _convert_handle_message_event(
abm.message.append(File(name=file_name, url=m["data"]["url"]))
else:
try:
normalized_data = await _normalize_onebot_media_data(
self.bot,
"file",
m["data"],
message_type=event["message_type"],
group_id=event.get("group_id"),
)
if normalized_data.get("url"):
file_name = (
normalized_data.get("file_name", "")
or normalized_data.get("name", "")
or normalized_data.get("file", "")
or "file"
)
abm.message.append(
File(
name=file_name,
url=cast(str, normalized_data["url"]),
)
)
continue
# Napcat
ret = None
if abm.type == MessageType.GROUP_MESSAGE:
Expand Down Expand Up @@ -402,7 +559,14 @@ async def _convert_handle_message_event(
f"不支持的消息段类型,已忽略: {t}, data={m['data']}"
)
continue
a = ComponentTypes[t](**m["data"])
normalized_data = await _normalize_onebot_media_data(
self.bot,
t,
m["data"],
message_type=event["message_type"],
group_id=event.get("group_id"),
)
a = ComponentTypes[t](**normalized_data)
abm.message.append(a)
except Exception as e:
logger.exception(
Expand Down
Loading
Loading