diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index 64e3a6645a..9b73819a38 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -1,5 +1,6 @@ import asyncio import base64 +import binascii import copy import inspect import json @@ -55,6 +56,17 @@ ) class ProviderOpenAIOfficial(Provider): _ERROR_TEXT_CANDIDATE_MAX_CHARS = 4096 + # 部分 OpenAI 兼容中转站会校验 data URL 的 MIME 类型是否和图片字节一致。 + # 这里统一维护格式映射,确保本地文件和 `base64://` 图片引用使用相同声明。 + _IMAGE_FORMAT_MIME_TYPES = { + "JPEG": "image/jpeg", + "PNG": "image/png", + "GIF": "image/gif", + "WEBP": "image/webp", + "BMP": "image/bmp", + "TIFF": "image/tiff", + "AVIF": "image/avif", + } @classmethod def _truncate_error_text_candidate(cls, text: str) -> str: @@ -195,25 +207,55 @@ def _encode_image_file_to_data_url( raise return None - try: - with PILImage.open(BytesIO(image_bytes)) as image: - image.verify() - image_format = str(image.format or "").upper() - except (OSError, UnidentifiedImageError): + image_format = cls._detect_image_format(image_bytes) + if image_format is None: if mode == "strict": raise ValueError(f"Invalid image file: {image_path}") return None - mime_type = { - "JPEG": "image/jpeg", - "PNG": "image/png", - "GIF": "image/gif", - "WEBP": "image/webp", - "BMP": "image/bmp", - }.get(image_format, "image/jpeg") + mime_type = cls._image_format_to_mime_type(image_format) image_bs64 = base64.b64encode(image_bytes).decode("utf-8") return f"data:{mime_type};base64,{image_bs64}" + @classmethod + def _detect_image_format(cls, image_bytes: bytes) -> str | None: + """返回 Pillow 校验后的图片格式,非法图片返回 None。""" + try: + # verify() 只校验图片容器,不完整解码像素。 + # 这里仅需要可信的格式标签,因此这种方式足够且开销较小。 + with PILImage.open(BytesIO(image_bytes)) as image: + image.verify() + return str(image.format or "").upper() + except (OSError, UnidentifiedImageError): + return None + + @classmethod + def _image_format_to_mime_type(cls, image_format: str | None) -> str: + """将 Pillow 图片格式映射为 data URL 使用的 MIME 类型。""" + # 未识别格式保持历史 JPEG 兜底,兼容传入任意 `base64://` 内容的旧调用方。 + return cls._IMAGE_FORMAT_MIME_TYPES.get( + str(image_format or "").upper(), "image/jpeg" + ) + + @classmethod + def _base64_image_ref_to_data_url(cls, image_ref: str) -> str: + """将 `base64://` 图片引用转换为带真实 MIME 的 data URL。""" + raw_base64 = image_ref.removeprefix("base64://") + mime_type = "image/jpeg" + try: + # 平台适配器可能通过 `base64://` 传入 PNG/GIF/WebP 等图片字节, + # 但不会额外携带 MIME 元数据。发送 OpenAI 请求前先识别真实格式, + # 避免把 PNG 等图片错误声明为 JPEG。 + image_bytes = base64.b64decode(raw_base64) + except (binascii.Error, ValueError): + # 对错误或非图片 base64 保持旧行为:继续返回 JPEG data URL, + # 避免让历史调用方因为格式识别失败而直接抛异常。 + pass + else: + image_format = cls._detect_image_format(image_bytes) + mime_type = cls._image_format_to_mime_type(image_format) + return f"data:{mime_type};base64,{raw_base64}" + @staticmethod def _file_uri_to_path(file_uri: str) -> str: """Normalize file URIs to paths. @@ -242,7 +284,7 @@ async def _image_ref_to_data_url( mode: Literal["safe", "strict"] = "safe", ) -> str | None: if image_ref.startswith("base64://"): - return image_ref.replace("base64://", "data:image/jpeg;base64,") + return self._base64_image_ref_to_data_url(image_ref) if image_ref.startswith("http"): image_path = await download_image_by_url(image_ref) diff --git a/tests/test_openai_source.py b/tests/test_openai_source.py index 950e2ea162..b5587ffb14 100644 --- a/tests/test_openai_source.py +++ b/tests/test_openai_source.py @@ -1,4 +1,6 @@ +import base64 import builtins +from io import BytesIO from types import SimpleNamespace import pytest @@ -1031,6 +1033,27 @@ async def test_resolve_image_part_supports_base64_scheme(): await provider.terminate() +@pytest.mark.asyncio +async def test_resolve_image_part_preserves_base64_png_mime_type(): + provider = _make_provider() + try: + image_buffer = BytesIO() + PILImage.new("RGBA", (1, 1), (255, 0, 0, 255)).save( + image_buffer, + format="PNG", + ) + image_base64 = base64.b64encode(image_buffer.getvalue()).decode("ascii") + + image_part = await provider._resolve_image_part(f"base64://{image_base64}") + + assert image_part == { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{image_base64}"}, + } + finally: + await provider.terminate() + + @pytest.mark.asyncio async def test_prepare_chat_payload_materializes_context_localhost_file_uri_image_urls( tmp_path,