diff --git a/examples/TCL_rag/config.yaml b/examples/TCL_rag/config.yaml index 231b3a8..9551aff 100644 --- a/examples/TCL_rag/config.yaml +++ b/examples/TCL_rag/config.yaml @@ -1,12 +1,12 @@ llm: name: openai - base_url: "https://api.gptsapi.net/v1" - api_key: "sk-2T06b7c7f9c3870049fbf8fada596b0f8ef908d1e233KLY2" + base_url: "xxx" + api_key: "xxx" model: "gpt-4.1-mini" embedding: name: huggingface - model_name: "/finance_ML/dataarc_syn_database/model/Qwen/qwen_embedding_0.6B" + model_name: "xxx" model_kwargs: device: "cuda:0" @@ -14,20 +14,20 @@ embedding: store: name: faiss - folder_path: /data/FinAi_Mapping_Knowledge/chenmingzhen/test_faiss_store + folder_path: xxx bm25: name: bm25 k: 10 - data_path: /data/FinAi_Mapping_Knowledge/chenmingzhen/tog3_backend/TCL/syn_table_data/data_all_clearn_short_chunk_with_caption_desc.json + data_path: xxx retriever: name: vectorstore reranker: name: qwen3 - model_name_or_path: "/finance_ML/dataarc_syn_database/model/Qwen/qwen_reranker_0.6B" + model_name_or_path: "xxx" device_id: "cuda:0" dataset: diff --git a/examples/TCL_rag/test.py b/examples/TCL_rag/test.py index 1062cae..34efb74 100644 --- a/examples/TCL_rag/test.py +++ b/examples/TCL_rag/test.py @@ -24,11 +24,8 @@ vector_store_config=vector_store_config, bm25_retriever_config=bm25_retriever_config) - result = rag.invoke("毛细管设计规范按照什么标准",k=50) + result = rag.invoke("模块机传感器端子不防呆的改善方案是什么?由哪个部门负责?",k=20) - result = rag.rerank("毛细管设计规范按照什么标准",result,k=10) - - answer = rag.answer("毛细管设计规范按照什么标准",result) - - - print(answer) \ No newline at end of file + for i in result: + print(i) + print("-"*100) \ No newline at end of file diff --git a/rag_factory/Retrieval/Retriever/Retriever_BM25.py b/rag_factory/Retrieval/Retriever/Retriever_BM25.py index 98eef3e..9749166 100644 --- a/rag_factory/Retrieval/Retriever/Retriever_BM25.py +++ b/rag_factory/Retrieval/Retriever/Retriever_BM25.py @@ -5,7 +5,7 @@ from concurrent.futures import ThreadPoolExecutor from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence from dataclasses import dataclass, field - +import uuid from pydantic import ConfigDict, Field, model_validator logger = logging.getLogger(__name__) @@ -207,7 +207,7 @@ def from_texts( f"与 texts 长度 ({len(texts_list)}) 不匹配" ) else: - ids_list = [None for _ in texts_list] + ids_list = [str(uuid.uuid4()) for _ in texts_list] # 预处理文本 logger.info(f"正在预处理 {len(texts_list)} 个文本...")