Skip to content

Commit c150053

Browse files
committed
3 modes
1 parent 6d6da90 commit c150053

File tree

1 file changed

+87
-34
lines changed

1 file changed

+87
-34
lines changed

bot.py

Lines changed: 87 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,12 @@
77
from langchain.chat_models import ChatOpenAI
88
from langchain.chains import ConversationalRetrievalChain
99
from langchain.memory import ConversationBufferMemory
10+
from langchain.prompts.chat import (
11+
ChatPromptTemplate,
12+
SystemMessagePromptTemplate,
13+
AIMessagePromptTemplate,
14+
HumanMessagePromptTemplate,
15+
)
1016
from dotenv import load_dotenv
1117

1218
load_dotenv(".env")
@@ -22,63 +28,110 @@
2228
embeddings = OpenAIEmbeddings()
2329
# embeddings = OllamaEmbeddings()
2430

31+
# Single chat model shared by all three answer modes (LLM-only, RAG, RAG + KG).
llm = ChatOpenAI(temperature=0)


# LLM only response: a fixed system message plus the raw user text, no retrieval.
template = "You are a helpful assistant that helps with programming questions."
human_template = "{text}"

system_message_prompt = SystemMessagePromptTemplate.from_template(template)
human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)

chat_prompt = ChatPromptTemplate.from_messages(
    [system_message_prompt, human_message_prompt]
)
41+
42+
43+
def generate_llm_output(user_input: str) -> str:
    """Answer *user_input* with the bare chat model (no retrieval involved)."""
    messages = chat_prompt.format_prompt(text=user_input).to_messages()
    return llm(messages).content
49+
50+
51+
# Rag response: plain vector similarity search over the existing Neo4j index.
neo4j_db = Neo4jVector.from_existing_index(
    url=url,
    username=username,
    password=password,
    embedding=embeddings,
    database="neo4j",  # neo4j by default
    index_name="stackoverflow",  # vector by default
    text_node_property="body",  # text by default
)

# Conversation history for the vector-only chain.
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
qa = ConversationalRetrievalChain.from_llm(llm, neo4j_db.as_retriever(), memory=memory)
65+
# Rag + KG: same index, but with a custom Cypher retrieval query so graph
# context can be returned alongside the vector hit.
kg = Neo4jVector.from_existing_index(
    url=url,
    username=username,
    password=password,
    embedding=embeddings,
    database="neo4j",  # neo4j by default
    index_name="stackoverflow",  # vector by default
    text_node_property="body",  # text by default
    # TODO: placeholder that returns dummy data — fix this with a real Cypher
    # query expanding the matched node into its knowledge-graph context.
    retrieval_query="RETURN 'fancy' AS text, 1 AS score, {} AS metadata",  # Fix this
)

# Separate history so the KG chain does not share state with the vector chain.
kg_memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
kg_qa = ConversationalRetrievalChain.from_llm(llm, kg.as_retriever(), memory=kg_memory)
5279

5380
# Streamlit stuff

# Make sure the text input is at the bottom
# We can't use chat_input as it can't be put into a tab
# NOTE: this was an f-string with no placeholders, which forced the CSS braces
# to be escaped as {{ }}; a plain string literal renders the identical markup.
styl = """
<style>
.stTextInput {
  position: fixed;
  bottom: 6rem;
}
</style>
"""
st.markdown(styl, unsafe_allow_html=True)
93+
94+
95+
def tab_view(name, output_function):
    """Render one chat tab: a text input plus the last three exchanges.

    name            -- short tab identifier used to namespace session-state keys
    output_function -- callable mapping the user's question to an answer string
    """
    # Per-tab history in session state.
    # BUG FIX: the original tested the bare keys "generated"/"user_input"
    # (which are never set anywhere) while writing the namespaced keys, so the
    # guard was always true and both history lists were re-created empty on
    # every Streamlit rerun — only the latest exchange ever survived.
    if f"generated_{name}" not in st.session_state:
        st.session_state[f"generated_{name}"] = []

    if f"user_input_{name}" not in st.session_state:
        st.session_state[f"user_input_{name}"] = []

    user_input = st.text_input(
        f"{name} mode",
        placeholder="Ask your question",
        key=f"route_{name}",
        label_visibility="hidden",
    )

    if user_input:
        # Show a spinner while the (potentially slow) chain/LLM call runs.
        with st.spinner():
            output = output_function(user_input)

        st.session_state[f"user_input_{name}"].append(user_input)
        st.session_state[f"generated_{name}"].append(output)

    if st.session_state[f"generated_{name}"]:
        size = len(st.session_state[f"generated_{name}"])
        # Display only the last three exchanges
        for i in range(max(size - 3, 0), size):
            with st.chat_message("user"):
                st.write(st.session_state[f"user_input_{name}"][i])

            with st.chat_message("assistant"):
                st.write(st.session_state[f"generated_{name}"][i])
128+
llm_view, rag_view, kgrag_view = st.tabs(["LLM only", "Vector", "Vector + Graph"])

# One tab per answering mode; each tab keeps its own namespaced history.
for view, mode, answer_fn in (
    (llm_view, "llm", generate_llm_output),
    (rag_view, "rag", qa.run),
    (kgrag_view, "kg", kg_qa.run),
):
    with view:
        tab_view(mode, answer_fn)

0 commit comments

Comments
 (0)