88from langchain .embeddings .openai import OpenAIEmbeddings
99from langchain .embeddings import OllamaEmbeddings , SentenceTransformerEmbeddings
1010from langchain .chat_models import ChatOpenAI , ChatOllama
11- from langchain .chains import ConversationalRetrievalChain
11+ from langchain .chains import RetrievalQAWithSourcesChain
12+ from langchain .chains .qa_with_sources import load_qa_with_sources_chain
1213from langchain .prompts .chat import (
1314 ChatPromptTemplate ,
1415 SystemMessagePromptTemplate ,
3435
3536neo4j_graph = Neo4jGraph (url = url , username = username , password = password )
3637
37- def create_vector_index (dimension ):
38+
39+ def create_vector_index (dimension : int ) -> None :
3840 index_query = "CALL db.index.vector.createNodeIndex('stackoverflow', 'Question', 'embedding', $dimension, 'cosine')"
3941 try :
4042 neo4j_graph .query (index_query , {"dimension" : dimension })
@@ -46,6 +48,7 @@ def create_vector_index(dimension):
4648 except : # Already exists
4749 pass
4850
51+
4952class StreamHandler (BaseCallbackHandler ):
5053 def __init__ (self , container , initial_text = "" ):
5154 self .container = container
@@ -55,6 +58,7 @@ def on_llm_new_token(self, token: str, **kwargs) -> None:
5558 self .text += token
5659 self .container .markdown (self .text )
5760
61+
5862if embedding_model_name == "ollama" :
5963 embeddings = OllamaEmbeddings (base_url = ollama_base_url , model = "llama2" )
6064 dimension = 4096
@@ -64,7 +68,9 @@ def on_llm_new_token(self, token: str, **kwargs) -> None:
6468 dimension = 1536
6569 logger .info ("Embedding: Using OpenAI" )
6670else :
67- embeddings = SentenceTransformerEmbeddings (model_name = "all-MiniLM-L6-v2" , cache_folder = "/embedding_model" )
71+ embeddings = SentenceTransformerEmbeddings (
72+ model_name = "all-MiniLM-L6-v2" , cache_folder = "/embedding_model"
73+ )
6874 dimension = 384
6975 logger .info ("Embedding: Using SentenceTransformer" )
7076
@@ -75,7 +81,13 @@ def on_llm_new_token(self, token: str, **kwargs) -> None:
7581 logger .info ("LLM: Using GPT-4" )
7682elif llm_name == "ollama" :
7783 llm = ChatOllama (
78- temperature = 0 , base_url = ollama_base_url , model = "llama2" , streaming = True
84+ temperature = 0 ,
85+ base_url = ollama_base_url ,
86+ model = "llama2" ,
87+ streaming = True ,
88+ top_k = 10 , # A higher value (100) will give more diverse answers, while a lower value (10) will be more conservative.
89+ top_p = 0.3 , # Higher value (0.95) will lead to more diverse text, while a lower value (0.5) will generate more focused text.
90+ num_ctx = 3072 , # Sets the size of the context window used to generate the next token.
7991 )
8092 logger .info ("LLM: Using Ollama (llama2)" )
8193else :
@@ -102,10 +114,10 @@ def generate_llm_output(user_input: str, callbacks: List[Any]) -> str:
102114 ).to_messages (),
103115 callbacks = callbacks ,
104116 ).content
105- return answer
117+ return { ' answer' : answer }
106118
107119
108- # Rag response
120+ # Vector response
109121neo4j_db = Neo4jVector .from_existing_index (
110122 embedding = embeddings ,
111123 url = url ,
@@ -116,18 +128,28 @@ def generate_llm_output(user_input: str, callbacks: List[Any]) -> str:
116128 text_node_property = "body" , # text by default
117129 retrieval_query = """
118130 OPTIONAL MATCH (node)-[:ANSWERS]->(question)
119- RETURN question.title + '\n ' + question.body + '\n ' + coalesce(node.body,"") AS text, score, {source:question.link} AS metadata
131+ RETURN 'Question: ' + question.title + '\n ' + question.body + '\n Answer: ' +
132+ coalesce(node.body,"") AS text, score, {source:question.link} AS metadata
133+ ORDER BY score ASC // so that best answer are the last
120134""" ,
121135)
122136
# System prompt for both RAG chains (qa / kg_qa). The retrieved Stack Overflow
# question-answer documents are injected into the {summaries} placeholder by
# load_qa_with_sources_chain, and the model is told to cite source links.
# Fix: corrected the typo "accuate" -> "accurate" in the instructions.
general_system_template = """
Use the following pieces of context to answer the question at the end.
The context contains question-answer pairs and their links from Stackoverflow.
You should prefer information from accepted or more upvoted answers.
Make sure to rely on information from the answers and not on questions to provide accurate responses.
When you find particular answer in the context useful, make sure to cite it in the answer using the link.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
----
{summaries}
----
Each answer you generate should contain a section at the end of links to
Stackoverflow questions and answers you found useful, which are described under Source value.
You can only use links to StackOverflow questions that are present in the context and always
add links to the end of the answer in the style of citations.
Generate concise answers with references sources section of links to
relevant StackOverflow questions only at the end of the answer.
"""
# Human turn of the prompt: the user's question, fenced in backticks so the
# model can clearly delimit it from the surrounding instructions.
general_user_template = "Question:```{question}```"
133155messages = [
@@ -136,13 +158,19 @@ def generate_llm_output(user_input: str, callbacks: List[Any]) -> str:
136158]
# Assemble the full chat prompt (system instructions + templated user question).
qa_prompt = ChatPromptTemplate.from_messages(messages)
138160
# "Stuff"-style combine chain: all retrieved documents are concatenated into
# the {summaries} slot of qa_prompt and answered in a single LLM call.
qa_chain = load_qa_with_sources_chain(
    llm,
    chain_type="stuff",
    prompt=qa_prompt,
)

# Vector-only RAG pipeline: pull the 2 most similar questions from the Neo4j
# vector index, trimming retrieved context to stay under the token budget.
qa = RetrievalQAWithSourcesChain(
    combine_documents_chain=qa_chain,
    retriever=neo4j_db.as_retriever(search_kwargs={"k": 2}),
    reduce_k_below_max_tokens=True,
    max_tokens_limit=3375,
)
144172
145- # Rag + Knowledge Graph response
173+ # Vector + Knowledge Graph response
146174kg = Neo4jVector .from_existing_index (
147175 embedding = embeddings ,
148176 url = url ,
@@ -156,20 +184,24 @@ def generate_llm_output(user_input: str, callbacks: List[Any]) -> str:
156184 MATCH (node)<-[:ANSWERS]-(a)
157185 WITH a
158186 ORDER BY a.is_accepted DESC, a.score DESC
159- WITH collect(a.body)[..2] as answers
160- RETURN reduce(str='', text IN answers | str + text + '\n ') as answerTexts
187+ WITH collect(a)[..2] as answers
188+ RETURN reduce(str='', a IN answers | str +
189+ '\n ### Answer (Accepted: '+ a.is_accepted +' Score: ' + a.score+ '): '+ a.body + '\n ') as answerTexts
161190}
162- RETURN node.title + '\n ' + node.body + '\n ' + answerTexts AS text, score, {source:node.link} AS metadata
191+ RETURN '##Question: ' + node.title + '\n ' + node.body + '\n '
192+ + answerTexts AS text, score, {source: node.link} AS metadata
193+ ORDER BY score ASC // so that best answers are the last
163194""" ,
164195)
165196
# Vector + knowledge-graph RAG pipeline: reuses the same combine chain, but
# its retriever's Cypher query also gathers the top accepted/upvoted answers
# for each matched question, giving the LLM richer grounded context.
kg_qa = RetrievalQAWithSourcesChain(
    combine_documents_chain=qa_chain,
    retriever=kg.as_retriever(search_kwargs={"k": 2}),
    reduce_k_below_max_tokens=True,
    max_tokens_limit=3375,
)
171203
172- # Streamlit stuff
204+ # Streamlit UI
173205styl = f"""
174206<style>
175207 /* not great support for :has yet (hello FireFox), but using it for now */
@@ -195,8 +227,8 @@ def chat_input():
195227 stream_handler = StreamHandler (st .empty ())
196228 result = output_function (
197229 {"question" : user_input , "chat_history" : []}, callbacks = [stream_handler ]
198- )
199- output = result # ["answer"] + "\n" + result["sources"]
230+ )[ 'answer' ]
231+ output = result
200232 st .session_state [f"user_input" ].append (user_input )
201233 st .session_state [f"generated" ].append (output )
202234 st .session_state [f"rag_mode" ].append (name )
@@ -216,7 +248,6 @@ def display_chat():
216248 if st .session_state [f"generated" ]:
217249 size = len (st .session_state [f"generated" ])
218250 # Display only the last three exchanges
219- # Excluding the latest since it's streamed
220251 for i in range (max (size - 3 , 0 ), size ):
221252 with st .chat_message ("user" ):
222253 st .write (st .session_state [f"user_input" ][i ])
@@ -235,9 +266,9 @@ def mode_select() -> str:
# Map the RAG mode chosen in the sidebar to the callable used by chat_input().
# Every option is invoked as fn({"question": ..., "chat_history": []},
# callbacks=[...]) and returns a mapping with an "answer" key.
if name == "LLM only":
    output_function = generate_llm_output
elif name == "Vector":
    output_function = qa
elif name == "Vector + Graph":
    output_function = kg_qa

display_chat()
chat_input()
0 commit comments