Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 55 additions & 13 deletions astrbot/core/provider/sources/openai_source.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import base64
import binascii
import copy
import inspect
import json
Expand Down Expand Up @@ -55,6 +56,17 @@
)
class ProviderOpenAIOfficial(Provider):
_ERROR_TEXT_CANDIDATE_MAX_CHARS = 4096
# 部分 OpenAI 兼容中转站会校验 data URL 的 MIME 类型是否和图片字节一致。
# 这里统一维护格式映射,确保本地文件和 `base64://` 图片引用使用相同声明。
_IMAGE_FORMAT_MIME_TYPES = {
"JPEG": "image/jpeg",
"PNG": "image/png",
"GIF": "image/gif",
"WEBP": "image/webp",
"BMP": "image/bmp",
"TIFF": "image/tiff",
"AVIF": "image/avif",
}

@classmethod
def _truncate_error_text_candidate(cls, text: str) -> str:
Expand Down Expand Up @@ -195,25 +207,55 @@ def _encode_image_file_to_data_url(
raise
return None

try:
with PILImage.open(BytesIO(image_bytes)) as image:
image.verify()
image_format = str(image.format or "").upper()
except (OSError, UnidentifiedImageError):
image_format = cls._detect_image_format(image_bytes)
if image_format is None:
if mode == "strict":
raise ValueError(f"Invalid image file: {image_path}")
return None

mime_type = {
"JPEG": "image/jpeg",
"PNG": "image/png",
"GIF": "image/gif",
"WEBP": "image/webp",
"BMP": "image/bmp",
}.get(image_format, "image/jpeg")
mime_type = cls._image_format_to_mime_type(image_format)
image_bs64 = base64.b64encode(image_bytes).decode("utf-8")
return f"data:{mime_type};base64,{image_bs64}"

@classmethod
def _detect_image_format(cls, image_bytes: bytes) -> str | None:
"""返回 Pillow 校验后的图片格式,非法图片返回 None。"""
try:
# verify() 只校验图片容器,不完整解码像素。
# 这里仅需要可信的格式标签,因此这种方式足够且开销较小。
with PILImage.open(BytesIO(image_bytes)) as image:
image.verify()
return str(image.format or "").upper()
except (OSError, UnidentifiedImageError):
return None

@classmethod
def _image_format_to_mime_type(cls, image_format: str | None) -> str:
"""将 Pillow 图片格式映射为 data URL 使用的 MIME 类型。"""
# 未识别格式保持历史 JPEG 兜底,兼容传入任意 `base64://` 内容的旧调用方。
return cls._IMAGE_FORMAT_MIME_TYPES.get(
str(image_format or "").upper(), "image/jpeg"
)

@classmethod
def _base64_image_ref_to_data_url(cls, image_ref: str) -> str:
"""将 `base64://` 图片引用转换为带真实 MIME 的 data URL。"""
raw_base64 = image_ref.removeprefix("base64://")
mime_type = "image/jpeg"
try:
# 平台适配器可能通过 `base64://` 传入 PNG/GIF/WebP 等图片字节,
# 但不会额外携带 MIME 元数据。发送 OpenAI 请求前先识别真实格式,
# 避免把 PNG 等图片错误声明为 JPEG。
image_bytes = base64.b64decode(raw_base64)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Decoding the entire base64 string into memory just to detect the image format can be inefficient for very large images. Since most image headers are within the first few dozen bytes, you could potentially optimize this by decoding only the beginning of the string. However, given that the full base64 is required for the final data URL and Pillow's verify() is relatively lightweight, this is a minor performance consideration. Additionally, as this introduces new functionality for handling attachments, please ensure it is accompanied by corresponding unit tests.

References
  1. New functionality, such as handling attachments, should be accompanied by corresponding unit tests.

except (binascii.Error, ValueError):
# 对错误或非图片 base64 保持旧行为:继续返回 JPEG data URL,
# 避免让历史调用方因为格式识别失败而直接抛异常。
pass
else:
image_format = cls._detect_image_format(image_bytes)
mime_type = cls._image_format_to_mime_type(image_format)
return f"data:{mime_type};base64,{raw_base64}"

@staticmethod
def _file_uri_to_path(file_uri: str) -> str:
"""Normalize file URIs to paths.
Expand Down Expand Up @@ -242,7 +284,7 @@ async def _image_ref_to_data_url(
mode: Literal["safe", "strict"] = "safe",
) -> str | None:
if image_ref.startswith("base64://"):
return image_ref.replace("base64://", "data:image/jpeg;base64,")
return self._base64_image_ref_to_data_url(image_ref)

if image_ref.startswith("http"):
image_path = await download_image_by_url(image_ref)
Expand Down
23 changes: 23 additions & 0 deletions tests/test_openai_source.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import base64
import builtins
from io import BytesIO
from types import SimpleNamespace

import pytest
Expand Down Expand Up @@ -1031,6 +1033,27 @@ async def test_resolve_image_part_supports_base64_scheme():
await provider.terminate()


@pytest.mark.asyncio
async def test_resolve_image_part_preserves_base64_png_mime_type():
provider = _make_provider()
try:
image_buffer = BytesIO()
PILImage.new("RGBA", (1, 1), (255, 0, 0, 255)).save(
image_buffer,
format="PNG",
)
image_base64 = base64.b64encode(image_buffer.getvalue()).decode("ascii")

image_part = await provider._resolve_image_part(f"base64://{image_base64}")
Comment on lines +1045 to +1047
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): Add a test for invalid or malformed base64 inputs to ensure legacy JPEG behavior is preserved

Since _base64_image_ref_to_data_url is intended to preserve the existing behavior for invalid/malformed base64 (or non-image) input by still returning a JPEG data URL, please add a test where image_ref is base64:// plus an invalid base64 string (e.g. "not-base64") and assert that _resolve_image_part still returns a data:image/jpeg;base64,... URL rather than raising. This will explicitly cover the backward-compatibility case described in the implementation comments.


assert image_part == {
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{image_base64}"},
}
finally:
await provider.terminate()


@pytest.mark.asyncio
async def test_prepare_chat_payload_materializes_context_localhost_file_uri_image_urls(
tmp_path,
Expand Down
Loading