Skip to content

Commit 6bc24d4

Browse files
author
keqian
committed
feat: 模型改用openAI
1 parent 4012e10 commit 6bc24d4

File tree

7 files changed

+41
-13
lines changed

7 files changed

+41
-13
lines changed

.env.example

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@ CHAT_MODEL_NAME=qwen3:8b
99
EMBEDDING_MODEL_NAME=bge-m3:latest
1010
MODEL_TEMPERATURE=0
1111

12+
# ===========================================
13+
# OPENAI模型
14+
# ===========================================
15+
OPENAI_API_KEY=
16+
OPENAI_API_BASE=
17+
1218
# ===========================================
1319
# 数据库配置 (Database Configuration)
1420
# ===========================================

.env.template

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@ CHAT_MODEL_NAME=qwen3:8b
99
EMBEDDING_MODEL_NAME=bge-m3:latest
1010
MODEL_TEMPERATURE=0
1111

12+
# ===========================================
13+
# OPENAI模型
14+
# ===========================================
15+
OPENAI_API_KEY=
16+
OPENAI_API_BASE=
17+
1218
# ===========================================
1319
# 数据库配置 (Database Configuration)
1420
# ===========================================

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,4 +204,5 @@ secrets.json
204204
# Local development
205205
local_*
206206
.python-version
207-
qdrant_data
207+
qdrant_data
208+
CLAUDE.md

agent.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from langchain.agents import AgentExecutor, create_openai_tools_agent
99
from langchain_community.chat_message_histories import RedisChatMessageHistory
10+
from langchain_openai import ChatOpenAI
1011
from langchain_core.output_parsers import StrOutputParser
1112
from langchain_core.prompts import MessagesPlaceholder, ChatPromptTemplate
1213
from langchain_core.runnables import RunnableLambda, RunnableWithMessageHistory, RunnableConfig
@@ -41,11 +42,16 @@ def __init__(self, session_id: Optional[str] = None):
4142

4243
# 初始化 Agent 执行器
4344
self.agent_executor = self._init_agent_executor()
44-
45-
def _init_chat_model(self) -> ChatOllama:
45+
46+
def _init_chat_model(self) -> ChatOpenAI:
47+
# def _init_chat_model(self) -> ChatOllama:
4648
"""初始化聊天模型"""
47-
model_config = config.get_model_config()
48-
return ChatOllama(**model_config)
49+
# model_config = config.get_model_config()
50+
# return ChatOllama(**model_config)
51+
return ChatOpenAI(
52+
model="gpt-4o-mini",
53+
temperature=0.2,
54+
)
4955

5056
def _init_agent_executor(self) -> RunnableWithMessageHistory:
5157
"""初始化 Agent 执行器"""

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ langchain-core==0.3.69
1212
langchain-ollama==0.3.5
1313
langchain-qdrant==0.2.0
1414
langchain-text-splitters==0.3.8
15+
langchain-openai==0.3.29
1516

1617
# Database and storage
1718
redis==6.2.0

server.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ async def websocket_endpoint(websocket: WebSocket):
216216
import uvicorn
217217

218218
server_logger.info("🔮 算命师机器人服务启动中...")
219-
server_logger.info(f"📍 服务地址: http://localhost:8000")
220-
server_logger.info(f"🌐 API 文档: http://localhost:8000/docs")
219+
server_logger.info(f"📍 服务地址: http://localhost:8001")
220+
server_logger.info(f"🌐 API 文档: http://localhost:8001/docs")
221221

222-
uvicorn.run(app, host="0.0.0.0", port=8000)
222+
uvicorn.run(app, host="0.0.0.0", port=8001)

services/tools.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,14 @@
33
使用配置管理和更好的错误处理
44
"""
55
import requests
6-
from typing import Optional
76

87
from langchain.agents import tool
98
from langchain_community.utilities import SerpAPIWrapper
109
from langchain_core.output_parsers import JsonOutputParser, StrOutputParser
1110
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
1211
from langchain_core.runnables import RunnableLambda
1312
from langchain_ollama import OllamaEmbeddings, ChatOllama, OllamaLLM
13+
from langchain_openai import OpenAI, ChatOpenAI
1414
from langchain_qdrant import QdrantVectorStore
1515
from qdrant_client import QdrantClient
1616

@@ -96,8 +96,12 @@ def bazi_cesuan(query: str) -> str:
9696

9797
# 创建模型
9898
model_config = config.get_model_config()
99-
model = ChatOllama(**model_config, format="json")
100-
99+
# model = ChatOllama(**model_config, format="json")
100+
model = ChatOpenAI(
101+
model="gpt-4o-mini",
102+
temperature=0.2,
103+
)
104+
101105
# 构建处理链
102106
chain = prompt | model | parser
103107
data = chain.invoke({"query": query})
@@ -151,8 +155,12 @@ def jiemeng(query: str) -> str:
151155

152156
# 创建关键词提取模型
153157
model_config = config.get_model_config()
154-
llm = OllamaLLM(**model_config)
155-
158+
# llm = OllamaLLM(**model_config)
159+
llm = OpenAI(
160+
model="gpt-4o-mini",
161+
temperature=0.2
162+
)
163+
156164
# 直接使用统一管理的模板
157165
dream_prompt_template = SystemPrompts.DREAM_KEYWORD_EXTRACTION_PROMPT
158166

0 commit comments

Comments
 (0)