Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions examples/TCL_rag/config.yaml
Original file line number Diff line number Diff line change
@@ -1,33 +1,33 @@
llm:
name: openai
-base_url: "https://api.gptsapi.net/v1"
-api_key: "<redacted>"
+base_url: "xxx"
+api_key: "xxx"
model: "gpt-4.1-mini"

embedding:
name: huggingface
-model_name: "/finance_ML/dataarc_syn_database/model/Qwen/qwen_embedding_0.6B"
+model_name: "xxx"
model_kwargs:
device: "cuda:0"



store:
name: faiss
-folder_path: /data/FinAi_Mapping_Knowledge/chenmingzhen/test_faiss_store
+folder_path: xxx


bm25:
name: bm25
k: 10
-data_path: /data/FinAi_Mapping_Knowledge/chenmingzhen/tog3_backend/TCL/syn_table_data/data_all_clearn_short_chunk_with_caption_desc.json
+data_path: xxx

retriever:
name: vectorstore

reranker:
name: qwen3
-model_name_or_path: "/finance_ML/dataarc_syn_database/model/Qwen/qwen_reranker_0.6B"
+model_name_or_path: "xxx"
device_id: "cuda:0"

dataset:
Expand Down
11 changes: 4 additions & 7 deletions examples/TCL_rag/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,8 @@
vector_store_config=vector_store_config,
bm25_retriever_config=bm25_retriever_config)

-result = rag.invoke("毛细管设计规范按照什么标准",k=50)
+result = rag.invoke("模块机传感器端子不防呆的改善方案是什么?由哪个部门负责?",k=20)

-result = rag.rerank("毛细管设计规范按照什么标准",result,k=10)
-
-answer = rag.answer("毛细管设计规范按照什么标准",result)
-
-
-print(answer)
+for i in result:
+print(i)
+print("-"*100)
4 changes: 2 additions & 2 deletions rag_factory/Retrieval/Retriever/Retriever_BM25.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence
from dataclasses import dataclass, field
-
+import uuid
from pydantic import ConfigDict, Field, model_validator

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -207,7 +207,7 @@ def from_texts(
f"与 texts 长度 ({len(texts_list)}) 不匹配"
)
else:
-ids_list = [None for _ in texts_list]
+ids_list = [str(uuid.uuid4()) for _ in texts_list]

# 预处理文本
logger.info(f"正在预处理 {len(texts_list)} 个文本...")
Expand Down