Skip to content

Commit 9475b96

Browse files
authored
Merge pull request #5 from neo4j-contrib/retrieval
update loader & bot
2 parents 4b73f6f + b2a8e2a commit 9475b96

File tree

2 files changed

+49
-33
lines changed

2 files changed

+49
-33
lines changed

bot.py

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,10 @@
55
from langchain.embeddings.openai import OpenAIEmbeddings
66
from langchain.embeddings import OllamaEmbeddings
77
from langchain.chat_models import ChatOpenAI, ChatOllama
8-
from langchain.chains import ConversationalRetrievalChain
9-
from langchain.memory import ConversationBufferMemory
8+
from langchain.chains import RetrievalQAWithSourcesChain
109
from langchain.prompts.chat import (
1110
ChatPromptTemplate,
1211
SystemMessagePromptTemplate,
13-
AIMessagePromptTemplate,
1412
HumanMessagePromptTemplate,
1513
)
1614
from dotenv import load_dotenv
@@ -28,7 +26,7 @@
2826
# llm = ChatOllama(temperature=0, base_url=ollama_base_url)
2927

3028
embeddings = OpenAIEmbeddings()
31-
llm = ChatOpenAI(temperature=0)
29+
llm = ChatOpenAI(temperature=0, model_name="gpt-4")
3230

3331
# LLM only response
3432
template = "You are a helpful assistant that helps with programming questions."
@@ -57,10 +55,21 @@ def generate_llm_output(user_input: str) -> str:
5755
database="neo4j", # neo4j by default
5856
index_name="stackoverflow", # vector by default
5957
text_node_property="body", # text by default
58+
retrieval_query="""
59+
CALL { with node
60+
MATCH (node)<-[:ANSWERS]-(a)
61+
WITH a
62+
ORDER BY a.is_accepted DESC, a.score DESC
63+
WITH collect(a.body)[..1] as answers
64+
RETURN reduce(str='', text IN answers | str + text + '\n') as answerTexts
65+
}
66+
RETURN node.body + '\n' + answerTexts AS text, score, {source:node.link} AS metadata
67+
""",
6068
)
6169

62-
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
63-
qa = ConversationalRetrievalChain.from_llm(llm, neo4j_db.as_retriever(), memory=memory)
70+
qa = RetrievalQAWithSourcesChain.from_chain_type(
71+
llm, chain_type="stuff", retriever=neo4j_db.as_retriever(search_kwargs={"k": 2})
72+
)
6473

6574
# Rag + KG
6675
kg = Neo4jVector.from_existing_index(
@@ -74,8 +83,9 @@ def generate_llm_output(user_input: str) -> str:
7483
retrieval_query="RETURN 'fancy' AS text, 1 AS score, {} AS metadata", # Fix this
7584
)
7685

77-
kg_memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
78-
kg_qa = ConversationalRetrievalChain.from_llm(llm, kg.as_retriever(), memory=kg_memory)
86+
kg_qa = RetrievalQAWithSourcesChain.from_chain_type(
87+
llm, chain_type="stuff", retriever=kg.as_retriever(search_kwargs={"k": 2})
88+
)
7989

8090
# Streamlit stuff
8191
styl = f"""
@@ -91,6 +101,7 @@ def generate_llm_output(user_input: str) -> str:
91101
"""
92102
st.markdown(styl, unsafe_allow_html=True)
93103

104+
94105
def chat_input():
95106
# Session state
96107
if "generated" not in st.session_state:
@@ -105,11 +116,16 @@ def chat_input():
105116
user_input = st.chat_input("What coding issue can I help you resolve today?")
106117

107118
if user_input:
108-
output = output_function(user_input)
119+
try:
120+
data = output_function(user_input)
121+
output = data["answer"] + "\n" + data["sources"]
122+
except KeyError:
123+
output = output_function(user_input)
109124
st.session_state[f"user_input"].append(user_input)
110125
st.session_state[f"generated"].append(output)
111126
st.session_state[f"rag_mode"].append(name)
112127

128+
113129
def display_chat():
114130
if st.session_state[f"generated"]:
115131
size = len(st.session_state[f"generated"])
@@ -122,19 +138,19 @@ def display_chat():
122138
st.caption(f"Mode: {st.session_state[f'rag_mode'][i]}")
123139
st.write(st.session_state[f"generated"][i])
124140

141+
125142
def mode_select() -> str:
126143
options = ["LLM only", "Vector", "Vector + Graph"]
127144
return st.radio("Select sophistication mode", options, horizontal=True)
128145

146+
129147
name = mode_select()
130-
if(name == "LLM only"):
148+
if name == "LLM only":
131149
output_function = generate_llm_output
132-
elif(name == "Vector"):
133-
output_function = qa.run
134-
elif(name == "Vector + Graph"):
135-
output_function = kg_qa.run
150+
elif name == "Vector":
151+
output_function = qa
152+
elif name == "Vector + Graph":
153+
output_function = kg_qa
136154

137155
chat_input()
138156
display_chat()
139-
140-

loader.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import requests
33

44
from dotenv import load_dotenv
5-
from bs4 import BeautifulSoup
65
from langchain.embeddings import OllamaEmbeddings, OpenAIEmbeddings
76
from langchain.graphs import Neo4jGraph
87

@@ -43,6 +42,7 @@ def create_constraints():
4342

4443
create_constraints()
4544

45+
4646
def create_vector_index(dimension):
4747
# TODO use Neo4jVector Code from LangChain on the existing graph
4848
index_query = "CALL db.index.vector.createNodeIndex('stackoverflow', 'Question', 'embedding', $dimension, 'cosine')"
@@ -56,27 +56,22 @@ def create_vector_index(dimension):
5656

5757

5858
def load_so_data(tag: str = "neo4j", page: int = 1) -> None:
59-
base_url = "https://api.stackexchange.com/2.2/questions"
59+
base_url = "https://api.stackexchange.com/2.3/search/advanced"
6060
parameters = (
61-
f"?pagesize=100&page={page}&order=desc&sort=creation&tagged={tag}"
62-
"&site=stackoverflow&filter=!6WPIomnMNcVD9"
61+
f"?pagesize=100&page={page}&order=desc&sort=creation&answers=1&tagged={tag}"
62+
"&site=stackoverflow&filter=!51dU0b1n(WTdqj5MH1iGsNShY6BhXXwJ)xwV5b"
6363
)
6464
data = requests.get(base_url + parameters).json()
6565
# Convert html to text and calculate embedding values
6666
for q in data["items"]:
67-
question_text = BeautifulSoup(q["body"], features="html.parser").text
68-
q["body"] = question_text
69-
q["embedding"] = embeddings.embed_query(q["title"] + " " + question_text)
70-
if q.get("answers"):
71-
for a in q.get("answers"):
72-
a["body"] = BeautifulSoup(a["body"], features="html.parser").text
67+
q["embedding"] = embeddings.embed_query(q["title"] + " " + q["body_markdown"])
7368

7469
import_query = """
7570
UNWIND $data AS q
7671
MERGE (question:Question {id:q.question_id})
7772
ON CREATE SET question.title = q.title, question.link = q.link,
78-
question.favorite_count = q.favorite_count, question.creation_date = q.creation_date,
79-
question.body = q.body, question.embedding = q.embedding
73+
question.favorite_count = q.favorite_count, question.creation_date = datetime({epochSeconds: q.creation_date}),
74+
question.body = q.body_markdown, question.embedding = q.embedding
8075
FOREACH (tagName IN q.tags |
8176
MERGE (tag:Tag {name:tagName})
8277
MERGE (question)-[:TAGGED]->(tag)
@@ -85,8 +80,8 @@ def load_so_data(tag: str = "neo4j", page: int = 1) -> None:
8580
MERGE (question)<-[:ANSWERS]-(answer:Answer {id:a.answer_id})
8681
SET answer.is_accepted = a.is_accepted,
8782
answer.score = a.score,
88-
answer.creation_date = a.creation_date,
89-
answer.body = a.body
83+
answer.creation_date = datetime({epochSeconds:a.creation_date}),
84+
answer.body = a.body_markdown
9085
MERGE (answerer:User {id:coalesce(a.owner.user_id, "deleted")})
9186
ON CREATE SET answerer.display_name = a.owner.display_name,
9287
answerer.reputation= a.owner.reputation
@@ -103,19 +98,24 @@ def load_so_data(tag: str = "neo4j", page: int = 1) -> None:
10398

10499
# Streamlit
105100
def get_tag() -> str:
106-
input_text = st.text_input("Which tag questions do you want to import?", value="neo4j")
101+
input_text = st.text_input(
102+
"Which tag questions do you want to import?", value="neo4j"
103+
)
107104
return input_text
108105

109106

110107
def get_pages():
111108
col1, col2 = st.columns(2)
112109
with col1:
113-
num_pages = st.number_input("Number of pages (100 questions per page)", step=1, min_value=1)
110+
num_pages = st.number_input(
111+
"Number of pages (100 questions per page)", step=1, min_value=1
112+
)
114113
with col2:
115114
start_page = st.number_input("Start page", step=1, min_value=1)
116115
st.caption("Only questions with answers will be imported.")
117116
return (int(num_pages), int(start_page))
118117

118+
119119
st.header("StackOverflow Loader")
120120
st.subheader("Choose StackOverflow tags to load into Neo4j")
121121
st.caption("Go to http://localhost:7474/browser/ to explore the graph.")
@@ -127,7 +127,7 @@ def get_pages():
127127
with st.spinner("Loading... This might take a minute or two."):
128128
try:
129129
for page in range(1, num_pages + 1):
130-
load_so_data(user_input, start_page + (page-1))
130+
load_so_data(user_input, start_page + (page - 1))
131131
st.success("Import successful", icon="✅")
132132
except Exception as e:
133133
st.error(f"Error: {e}", icon="🚨")

0 commit comments

Comments (0)