Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 119 additions & 45 deletions astrbot/core/db/vec_db/faiss_impl/embedding_storage.py
Original file line number Diff line number Diff line change
@@ -1,95 +1,169 @@
try:
import faiss
except ModuleNotFoundError:
except ImportError as e:
raise ImportError(
"faiss 未安装。请使用 'pip install faiss-cpu' 或 'pip install faiss-gpu' 安装。",
)
import os
import shutil
import tempfile

import numpy as np


# ── Faiss C++ fopen() 在 Windows 上使用 ANSI codepage ──
# Python 传给 Faiss 的路径是 UTF-8 字节,但 Windows fopen 期望 ANSI 编码,
# 导致含非 ASCII 字符的路径(如 C:\Users\中文用户名\...)被解读为乱码而失败。
# 本模块通过"纯 ASCII 临时文件桥接"规避此问题。


def _needs_bridge(path: str) -> bool:
    """Return True when *path* must go through the ASCII temp-file bridge.

    Bridging applies only on Windows (``os.name == "nt"``) and only when the
    path contains at least one non-ASCII character.
    """
    if os.name != "nt":
        return False
    return not path.isascii()


def _safe_temp_dir() -> str:
    """Return a writable, pure-ASCII temporary directory for the Faiss I/O bridge.

    Lookup order:
        1. ``%SystemRoot%\\Temp`` on Windows, when its path is pure ASCII and
           a file can actually be created inside it.
        2. ``tempfile.gettempdir()`` on Windows, when its path is pure ASCII.
        3. ``tempfile.gettempdir()`` unconditionally on every other platform.

    Returns:
        The chosen directory path.

    Raises:
        OSError: On Windows, when neither candidate directory is usable.
    """
    if os.name != "nt":
        return tempfile.gettempdir()

    root = os.environ.get("SystemRoot", r"C:\Windows")
    temp_dir = os.path.join(root, "Temp")
    if temp_dir.isascii() and os.path.isdir(temp_dir):
        # os.access(..., os.W_OK) does not consult ACLs on Windows and reports
        # directories as writable, so prove writability by creating (and
        # immediately discarding) a real file instead.
        try:
            with tempfile.TemporaryFile(dir=temp_dir):
                pass
        except OSError:
            pass  # Not usable; fall through to the gettempdir() candidate.
        else:
            return temp_dir

    # gettempdir() only returns directories in which it managed to create a
    # file, so the remaining requirement is that the path is pure ASCII.
    tmp = tempfile.gettempdir()
    if tmp.isascii():
        return tmp

    raise OSError(
        "_safe_temp_dir: 无法找到可写的纯 ASCII 临时目录。"
        f" 检查过 SystemRoot\\Temp={temp_dir}, gettempdir={tmp}"
    )


def _make_temp_file(prefix: str) -> str:
    """Create an empty ``.faiss`` scratch file for the bridge and return its path.

    The file is created in the directory chosen by ``_safe_temp_dir()`` and is
    named ``<prefix>_<random>.faiss``. Nothing here deletes it; removal is
    left to the caller.
    """
    handle, scratch_path = tempfile.mkstemp(
        suffix=".faiss",
        prefix=f"{prefix}_",
        dir=_safe_temp_dir(),
    )
    # mkstemp hands back an open descriptor; only the path is needed here.
    os.close(handle)
    return scratch_path


class EmbeddingStorage:
def __init__(self, dimension: int, path: str | None = None) -> None:
self.dimension = dimension
self.path = path
self.index = None
if path and os.path.exists(path):
self.index = faiss.read_index(path)
self.index = self._read_index(path)
else:
base_index = faiss.IndexFlatL2(dimension)
self.index = faiss.IndexIDMap(base_index)

async def insert(self, vector: np.ndarray, id: int) -> None:
"""插入向量
@staticmethod
def _read_index(path: str) -> "faiss.Index":
    """Load a Faiss index from *path*, tolerating non-ASCII Windows paths.

    A direct ``faiss.read_index(path)`` is attempted first. If that raises
    ``RuntimeError`` and ``_needs_bridge(path)`` is true, the file is copied
    to an ASCII-only temporary file and read from there; the temporary copy
    is removed afterwards. If bridging does not apply, the original error
    propagates unchanged.
    """
    try:
        return faiss.read_index(path)
    except RuntimeError:
        bridge_applies = _needs_bridge(path)
        if not bridge_applies:
            raise

    scratch = _make_temp_file("_faiss_read")
    try:
        shutil.copy2(path, scratch)
        loaded = faiss.read_index(scratch)
    finally:
        # Best-effort cleanup: a missing or locked scratch file is ignored.
        try:
            os.remove(scratch)
        except OSError:
            pass
    return loaded

@staticmethod
def _write_index(index: "faiss.Index", path: str) -> None:
"""保存 Faiss 索引,兼容含非 ASCII 字符的 Windows 路径。"""
dirname = os.path.dirname(path)
if dirname:
os.makedirs(dirname, exist_ok=True)

if not _needs_bridge(path):
faiss.write_index(index, path)
return

Args:
vector (np.ndarray): 要插入的向量
id (int): 向量的ID
Raises:
ValueError: 如果向量的维度与存储的维度不匹配
tmp = _make_temp_file("_faiss_write")
try:
faiss.write_index(index, tmp)
shutil.move(tmp, path)
finally:
if os.path.exists(tmp):
try:
os.remove(tmp)
except OSError:
pass

"""
async def insert(self, vector: np.ndarray, id: int) -> None:
"""插入向量"""
assert self.index is not None, "FAISS index is not initialized."
if vector.shape[0] != self.dimension:
raise ValueError(
f"向量维度不匹配, 期望: {self.dimension}, 实际: {vector.shape[0]}",
)
self.index.add_with_ids(vector.reshape(1, -1), np.array([id]))
self.index.add_with_ids(vector.reshape(1, -1), np.array([id], dtype=np.int64))
await self.save_index()

async def insert_batch(self, vectors: np.ndarray, ids: list[int]) -> None:
"""批量插入向量

Args:
vectors (np.ndarray): 要插入的向量数组
ids (list[int]): 向量的ID列表
Raises:
ValueError: 如果向量的维度与存储的维度不匹配

"""
"""批量插入向量"""
assert self.index is not None, "FAISS index is not initialized."
if len(vectors.shape) != 2:
raise ValueError(
f"向量必须是二维数组, 当前维度: {len(vectors.shape)}",
)
if vectors.shape[1] != self.dimension:
raise ValueError(
f"向量维度不匹配, 期望: {self.dimension}, 实际: {vectors.shape[1]}",
)
self.index.add_with_ids(vectors, np.array(ids))
self.index.add_with_ids(vectors, np.array(ids, dtype=np.int64))
await self.save_index()

async def search(self, vector: np.ndarray, k: int) -> tuple:
"""搜索最相似的向量

Args:
vector (np.ndarray): 查询向量
k (int): 返回的最相似向量的数量
Returns:
tuple: (距离, 索引)

"""
"""搜索向量"""
assert self.index is not None, "FAISS index is not initialized."
faiss.normalize_L2(vector)
distances, indices = self.index.search(vector, k)
if vector.ndim != 1:
raise ValueError(
f"查询向量必须是 1 维, 实际维度: {vector.ndim}。"
" 如需批量搜索请使用 Faiss 原生 API。"
)
if vector.shape[0] != self.dimension:
raise ValueError(
f"向量维度不匹配, 期望: {self.dimension}, 实际: {vector.shape[0]}",
)
distances, indices = self.index.search(vector.reshape(1, -1), k)
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
return distances, indices

async def delete(self, ids: list[int]) -> None:
"""删除向量

Args:
ids (list[int]): 要删除的向量ID列表

删除不存在的 ID 时 Faiss 会抛 RuntimeError。
由于 remove_ids 为幂等操作,此处忽略该错误。
"""
assert self.index is not None, "FAISS index is not initialized."
id_array = np.array(ids, dtype=np.int64)
self.index.remove_ids(id_array)
try:
self.index.remove_ids(np.array(ids, dtype=np.int64))
except RuntimeError:
# 幂等:删除已不存在的 ID,安全忽略
pass
Comment on lines +158 to +162
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Swallowing RuntimeError during remove_ids without any logging can make it difficult to diagnose issues if the operation fails for unexpected reasons (e.g., index corruption or internal Faiss errors). While IndexIDMap generally supports removal, it is better practice to log a warning when an exception is caught.

Suggested change
try:
self.index.remove_ids(np.array(ids, dtype=np.int64))
except RuntimeError:
pass
try:
self.index.remove_ids(np.array(ids, dtype=np.int64))
except RuntimeError:
# Consider logging the error here if a logger is available
pass

await self.save_index()

async def save_index(self) -> None:
"""保存索引

Args:
path (str): 保存索引的路径

"""
if self.index is None:
"""保存索引(兼容含非 ASCII 字符的 Windows 路径)"""
if self.index is None or not self.path:
return
faiss.write_index(self.index, self.path)
self._write_index(self.index, self.path)