Skip to content

Commit 9e0d126

Browse files
authored
Merge pull request #8 from neo4j-contrib/sourceprompt
prompt engineering sources
2 parents 57a22b8 + cdaa78c commit 9e0d126

File tree

1 file changed

+56
-25
lines changed

1 file changed

+56
-25
lines changed

bot.py

Lines changed: 56 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
from langchain.embeddings.openai import OpenAIEmbeddings
99
from langchain.embeddings import OllamaEmbeddings, SentenceTransformerEmbeddings
1010
from langchain.chat_models import ChatOpenAI, ChatOllama
11-
from langchain.chains import ConversationalRetrievalChain
11+
from langchain.chains import RetrievalQAWithSourcesChain
12+
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
1213
from langchain.prompts.chat import (
1314
ChatPromptTemplate,
1415
SystemMessagePromptTemplate,
@@ -34,7 +35,8 @@
3435

3536
neo4j_graph = Neo4jGraph(url=url, username=username, password=password)
3637

37-
def create_vector_index(dimension):
38+
39+
def create_vector_index(dimension: int) -> None:
3840
index_query = "CALL db.index.vector.createNodeIndex('stackoverflow', 'Question', 'embedding', $dimension, 'cosine')"
3941
try:
4042
neo4j_graph.query(index_query, {"dimension": dimension})
@@ -46,6 +48,7 @@ def create_vector_index(dimension):
4648
except: # Already exists
4749
pass
4850

51+
4952
class StreamHandler(BaseCallbackHandler):
5053
def __init__(self, container, initial_text=""):
5154
self.container = container
@@ -55,6 +58,7 @@ def on_llm_new_token(self, token: str, **kwargs) -> None:
5558
self.text += token
5659
self.container.markdown(self.text)
5760

61+
5862
if embedding_model_name == "ollama":
5963
embeddings = OllamaEmbeddings(base_url=ollama_base_url, model="llama2")
6064
dimension = 4096
@@ -64,7 +68,9 @@ def on_llm_new_token(self, token: str, **kwargs) -> None:
6468
dimension = 1536
6569
logger.info("Embedding: Using OpenAI")
6670
else:
67-
embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2", cache_folder="/embedding_model")
71+
embeddings = SentenceTransformerEmbeddings(
72+
model_name="all-MiniLM-L6-v2", cache_folder="/embedding_model"
73+
)
6874
dimension = 384
6975
logger.info("Embedding: Using SentenceTransformer")
7076

@@ -75,7 +81,13 @@ def on_llm_new_token(self, token: str, **kwargs) -> None:
7581
logger.info("LLM: Using GPT-4")
7682
elif llm_name == "ollama":
7783
llm = ChatOllama(
78-
temperature=0, base_url=ollama_base_url, model="llama2", streaming=True
84+
temperature=0,
85+
base_url=ollama_base_url,
86+
model="llama2",
87+
streaming=True,
88+
top_k=10, # A higher value (100) will give more diverse answers, while a lower value (10) will be more conservative.
89+
top_p=0.3, # Higher value (0.95) will lead to more diverse text, while a lower value (0.5) will generate more focused text.
90+
num_ctx=3072, # Sets the size of the context window used to generate the next token.
7991
)
8092
logger.info("LLM: Using Ollama (llama2)")
8193
else:
@@ -102,10 +114,10 @@ def generate_llm_output(user_input: str, callbacks: List[Any]) -> str:
102114
).to_messages(),
103115
callbacks=callbacks,
104116
).content
105-
return answer
117+
return {'answer':answer}
106118

107119

108-
# Rag response
120+
# Vector response
109121
neo4j_db = Neo4jVector.from_existing_index(
110122
embedding=embeddings,
111123
url=url,
@@ -116,18 +128,28 @@ def generate_llm_output(user_input: str, callbacks: List[Any]) -> str:
116128
text_node_property="body", # text by default
117129
retrieval_query="""
118130
OPTIONAL MATCH (node)-[:ANSWERS]->(question)
119-
RETURN question.title + '\n' + question.body + '\n' + coalesce(node.body,"") AS text, score, {source:question.link} AS metadata
131+
RETURN 'Question: ' + question.title + '\n' + question.body + '\nAnswer: ' +
132+
coalesce(node.body,"") AS text, score, {source:question.link} AS metadata
133+
ORDER BY score ASC // so that best answers are the last
120134
""",
121135
)
122136

123137
general_system_template = """
124138
Use the following pieces of context to answer the question at the end.
139+
The context contains question-answer pairs and their links from Stackoverflow.
140+
You should prefer information from accepted or more upvoted answers.
141+
Make sure to rely on information from the answers and not on questions to provide accurate responses.
142+
When you find particular answer in the context useful, make sure to cite it in the answer using the link.
125143
If you don't know the answer, just say that you don't know, don't try to make up an answer.
126-
Each document in the context contains the source information.
127-
For every document that you use information from, return their source link at the end of the answer.
128144
----
129-
{context}
145+
{summaries}
130146
----
147+
Each answer you generate should contain a section at the end with links to
148+
Stackoverflow questions and answers you found useful, which are described under Source value.
149+
You can only use links to StackOverflow questions that are present in the context and always
150+
add links to the end of the answer in the style of citations.
151+
Generate concise answers with a references section of links to
152+
relevant StackOverflow questions only at the end of the answer.
131153
"""
132154
general_user_template = "Question:```{question}```"
133155
messages = [
@@ -136,13 +158,19 @@ def generate_llm_output(user_input: str, callbacks: List[Any]) -> str:
136158
]
137159
qa_prompt = ChatPromptTemplate.from_messages(messages)
138160

139-
qa = ConversationalRetrievalChain.from_llm(
161+
qa_chain = load_qa_with_sources_chain(
140162
llm,
163+
chain_type="stuff",
164+
prompt=qa_prompt,
165+
)
166+
qa = RetrievalQAWithSourcesChain(
167+
combine_documents_chain=qa_chain,
141168
retriever=neo4j_db.as_retriever(search_kwargs={"k": 2}),
142-
combine_docs_chain_kwargs={"prompt": qa_prompt},
169+
reduce_k_below_max_tokens=True,
170+
max_tokens_limit=3375,
143171
)
144172

145-
# Rag + Knowledge Graph response
173+
# Vector + Knowledge Graph response
146174
kg = Neo4jVector.from_existing_index(
147175
embedding=embeddings,
148176
url=url,
@@ -156,20 +184,24 @@ def generate_llm_output(user_input: str, callbacks: List[Any]) -> str:
156184
MATCH (node)<-[:ANSWERS]-(a)
157185
WITH a
158186
ORDER BY a.is_accepted DESC, a.score DESC
159-
WITH collect(a.body)[..2] as answers
160-
RETURN reduce(str='', text IN answers | str + text + '\n') as answerTexts
187+
WITH collect(a)[..2] as answers
188+
RETURN reduce(str='', a IN answers | str +
189+
'\n### Answer (Accepted: '+ a.is_accepted +' Score: ' + a.score+ '): '+ a.body + '\n') as answerTexts
161190
}
162-
RETURN node.title + '\n' + node.body + '\n' + answerTexts AS text, score, {source:node.link} AS metadata
191+
RETURN '##Question: ' + node.title + '\n' + node.body + '\n'
192+
+ answerTexts AS text, score, {source: node.link} AS metadata
193+
ORDER BY score ASC // so that best answers are the last
163194
""",
164195
)
165196

166-
kg_qa = ConversationalRetrievalChain.from_llm(
167-
llm,
197+
kg_qa = RetrievalQAWithSourcesChain(
198+
combine_documents_chain=qa_chain,
168199
retriever=kg.as_retriever(search_kwargs={"k": 2}),
169-
combine_docs_chain_kwargs={"prompt": qa_prompt},
200+
reduce_k_below_max_tokens=True,
201+
max_tokens_limit=3375,
170202
)
171203

172-
# Streamlit stuff
204+
# Streamlit UI
173205
styl = f"""
174206
<style>
175207
/* not great support for :has yet (hello FireFox), but using it for now */
@@ -195,8 +227,8 @@ def chat_input():
195227
stream_handler = StreamHandler(st.empty())
196228
result = output_function(
197229
{"question": user_input, "chat_history": []}, callbacks=[stream_handler]
198-
)
199-
output = result # ["answer"] + "\n" + result["sources"]
230+
)['answer']
231+
output = result
200232
st.session_state[f"user_input"].append(user_input)
201233
st.session_state[f"generated"].append(output)
202234
st.session_state[f"rag_mode"].append(name)
@@ -216,7 +248,6 @@ def display_chat():
216248
if st.session_state[f"generated"]:
217249
size = len(st.session_state[f"generated"])
218250
# Display only the last three exchanges
219-
# Excluding the latest since it's streamed
220251
for i in range(max(size - 3, 0), size):
221252
with st.chat_message("user"):
222253
st.write(st.session_state[f"user_input"][i])
@@ -235,9 +266,9 @@ def mode_select() -> str:
235266
if name == "LLM only":
236267
output_function = generate_llm_output
237268
elif name == "Vector":
238-
output_function = qa.run
269+
output_function = qa
239270
elif name == "Vector + Graph":
240-
output_function = kg_qa.run
271+
output_function = kg_qa
241272

242273
display_chat()
243274
chat_input()

0 commit comments

Comments
 (0)