generated from amazon-archives/__template_MIT-0
-
Notifications
You must be signed in to change notification settings - Fork 105
Expand file tree
/
Copy pathkendra_chat_falcon_40b.py
More file actions
135 lines (115 loc) · 4.58 KB
/
kendra_chat_falcon_40b.py
File metadata and controls
135 lines (115 loc) · 4.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
from langchain.retrievers import AmazonKendraRetriever
from langchain.chains import ConversationalRetrievalChain
from langchain import SagemakerEndpoint
from langchain.llms.sagemaker_endpoint import LLMContentHandler
from langchain.prompts import PromptTemplate
import sys
import json
import os
class bcolors:
    """ANSI escape sequences used to style text printed to the terminal.

    Each attribute is a plain string; prepend one to a message and append
    ``ENDC`` afterwards to reset the terminal styling.
    """

    # Colour sequences.
    HEADER = "\033[95m"
    OKBLUE = "\033[94m"
    OKCYAN = "\033[96m"
    OKGREEN = "\033[92m"
    WARNING = "\033[93m"
    FAIL = "\033[91m"

    # Reset sequence (SGR 0).
    ENDC = "\033[0m"

    # Text attributes.
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"
# Maximum number of (question, answer) turns kept in the chat history; the
# interactive loop drops the oldest turn once the history reaches this length.
MAX_HISTORY_LENGTH = 5
def build_chain():
    """Build the conversational retrieval chain (Kendra retriever + SageMaker LLM).

    Configuration is read from environment variables:

    * ``AWS_REGION`` - region used for both the SageMaker endpoint and the
      Kendra retriever.
    * ``KENDRA_INDEX_ID`` - index id passed to the Kendra retriever.
    * ``FALCON_40B_ENDPOINT`` - name of the SageMaker endpoint to invoke.
    * ``INFERENCE_COMPONENT_NAME`` - optional. When set to a non-empty value
      it is forwarded to the endpoint through ``endpoint_kwargs`` together
      with the ``accept_eula=true`` custom attribute.

    Returns:
        The object produced by ``ConversationalRetrievalChain.from_llm``,
        configured to return source documents alongside the answer.

    Raises:
        KeyError: If ``AWS_REGION``, ``KENDRA_INDEX_ID`` or
            ``FALCON_40B_ENDPOINT`` is not set.
    """
    region = os.environ["AWS_REGION"]
    kendra_index_id = os.environ["KENDRA_INDEX_ID"]
    endpoint_name = os.environ["FALCON_40B_ENDPOINT"]
    # Use .get() so the name is always bound: previously it was assigned only
    # when the variable existed, and the later truthiness check raised
    # UnboundLocalError whenever INFERENCE_COMPONENT_NAME was not set.
    inference_component_name = os.environ.get("INFERENCE_COMPONENT_NAME")

    class ContentHandler(LLMContentHandler):
        """Translate between prompt text and the endpoint's JSON payloads."""

        content_type = "application/json"
        accepts = "application/json"

        def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:
            """Serialize the prompt and generation parameters as UTF-8 JSON."""
            input_str = json.dumps({"inputs": prompt, "parameters": model_kwargs})
            return input_str.encode('utf-8')

        def transform_output(self, output: bytes) -> str:
            """Return ``generated_text`` from the first element of the JSON reply."""
            # NOTE: despite the ``bytes`` annotation, the value is consumed
            # through ``.read()``, i.e. it is treated as a stream-like object.
            response_json = json.loads(output.read().decode("utf-8"))
            return response_json[0]["generated_text"]

    # Arguments shared by both endpoint configurations.
    llm_kwargs = {
        "endpoint_name": endpoint_name,
        "region_name": region,
        # Sent to the model as the "parameters" field of each request.
        "model_kwargs": {
            "temperature": 0.8,
            "max_new_tokens": 512,
            "do_sample": True,
            "top_p": 0.9,
            "repetition_penalty": 1.03,
            "stop": ["\nUser:", "<|endoftext|>", "</s>"],
        },
        "content_handler": ContentHandler(),
    }
    if inference_component_name:
        # Extra invocation arguments are only supplied when an inference
        # component has been configured.
        llm_kwargs["endpoint_kwargs"] = {
            "CustomAttributes": "accept_eula=true",
            "InferenceComponentName": inference_component_name,
        }
    llm = SagemakerEndpoint(**llm_kwargs)

    retriever = AmazonKendraRetriever(
        index_id=kendra_index_id, region_name=region, top_k=2
    )

    # Prompt used to answer the (standalone) question from retrieved documents.
    prompt_template = """
{context}
Instruction: Based on the above documents, provide a detailed answer for, {question} Answer "don't know"
if not present in the document.
Solution:"""
    PROMPT = PromptTemplate(
        template=prompt_template, input_variables=["context", "question"]
    )

    # Prompt used to rewrite a follow-up into a standalone question using the
    # chat history.
    condense_qa_template = """
Given the following conversation and a follow up question, rephrase the follow up question
to be a standalone question.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""
    standalone_question_prompt = PromptTemplate.from_template(condense_qa_template)

    qa = ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=retriever,
        condense_question_prompt=standalone_question_prompt,
        return_source_documents=True,
        verbose=True,
        combine_docs_chain_kwargs={"prompt": PROMPT},
    )
    return qa
def run_chain(chain, prompt: str, history=None):
    """Invoke ``chain`` with a question and the accumulated chat history.

    Args:
        chain: Callable accepting a single dict with the keys ``"question"``
            and ``"chat_history"``.
        prompt: The user's question.
        history: Previous conversation turns. When omitted, a fresh empty
            list is used for this call.

    Returns:
        Whatever ``chain`` returns for the given input.
    """
    # A ``None`` sentinel replaces the former ``history=[]`` default: a list
    # default is created once and shared by every call, so any mutation made
    # downstream would leak into later calls that rely on the default.
    if history is None:
        history = []
    return chain({"question": prompt, "chat_history": history})
if __name__ == "__main__":
    # Interactive console session: read questions from stdin line by line
    # until EOF, answer each one through the chain, and list the sources.
    ask_banner = (
        bcolors.OKCYAN
        + "Ask a question, start a New search: or CTRL-D to exit."
        + bcolors.ENDC
    )

    turns = []
    chain = build_chain()

    print(bcolors.OKBLUE + "Hello! How can I help you?" + bcolors.ENDC)
    print(ask_banner)
    print(">", end=" ", flush=True)

    for line in sys.stdin:
        normalized = line.strip().lower()
        if normalized.startswith("new search:"):
            # Start a fresh conversation; the question sent on is the
            # stripped, lower-cased input with the marker removed.
            question = normalized.replace("new search:", "")
            turns = []
        else:
            question = line
            if len(turns) == MAX_HISTORY_LENGTH:
                # Make room for the new turn by discarding the oldest one.
                turns.pop(0)

        result = run_chain(chain, question, turns)
        answer = result["answer"]
        turns.append((question, answer))
        print(bcolors.OKGREEN + answer + bcolors.ENDC)

        if 'source_documents' in result:
            print(bcolors.OKGREEN + 'Sources:')
            for doc in result['source_documents']:
                print(doc.metadata['source'])
        # Reset terminal colours after the (optional) source listing.
        print(bcolors.ENDC)

        print(ask_banner)
        print(">", end=" ", flush=True)

    print(bcolors.OKBLUE + "Bye" + bcolors.ENDC)