"""Retrieval-augmented generative chat bot: ranks stored Spock lines with a
SentenceTransformer encoder and generates replies with a fine-tuned
flan-t5-small model."""

import pickle
from collections import deque

import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

from utilities import cosine_sim_rag, encode_rag, generate_response, top_candidates


class ChatBot:
    def __init__(self):
        self.conversation_history = deque([], maxlen=10)
        self.generative_model = None
        self.generative_tokenizer = None
        self.vect_data = []
        self.scripts = []
        self.ranking_model = None

    def load(self):
        """This method is called first to load all datasets and models
        used by the chat bot; the data files live in the data folder,
        and the models are loaded from the Hugging Face Hub."""
        # Precomputed embeddings of the script lines, plus the lines themselves.
        with open("data/spock_lines_vectorized.pkl", "rb") as fp:
            self.vect_data = pickle.load(fp)
        self.scripts = pd.read_pickle("data/spock_lines.pkl")
        # Encoder used to rank stored lines against the incoming query.
        self.ranking_model = SentenceTransformer("greatakela/gnlp_hw1_encoder")
        # Seq2seq model and tokenizer used to generate the final reply.
        self.generative_model = AutoModelForSeq2SeqLM.from_pretrained(
            "greatakela/flan-t5-small-gen-chat_v3"
        )
        self.generative_tokenizer = AutoTokenizer.from_pretrained(
            "greatakela/flan-t5-small-gen-chat_v3"
        )

    def generate_answer(self, utterance):
        # Encode the query together with the recent conversation history.
        query_encoding = encode_rag(
            texts=utterance,
            model=self.ranking_model,
            contexts=self.conversation_history,
        )
        print("Query Encoding Shape:", query_encoding.shape)
        print("Stored Embeddings Shape:", np.array(self.vect_data).shape)
        # Score the query against every stored line embedding.
        bot_cosine_scores = cosine_sim_rag(
            self.vect_data,
            query_encoding,
        )
        top_scores, top_indexes = top_candidates(
            bot_cosine_scores, initial_data=self.scripts
        )
        print(top_scores, top_indexes)  # for debugging
        if top_scores[0] >= 0.9:
            # Confident retrieval hit: seed generation with the
            # best-scoring stored answer.
            rag_answer = self.scripts.iloc[top_indexes[0]]["ANSWER"]
            answer = generate_response(
                model=self.generative_model,
                tokenizer=self.generative_tokenizer,
                question=utterance,
                context=self.conversation_history,
                top_p=0.9,
                temperature=0.95,
                rag_answer=rag_answer,
            )
        else:
            # No sufficiently similar stored line: generate from the
            # query and conversation history alone.
            answer = generate_response(
                model=self.generative_model,
                tokenizer=self.generative_tokenizer,
                question=utterance,
                context=self.conversation_history,
                top_p=0.9,
                temperature=0.95,
            )
        self.conversation_history.append(utterance)
        self.conversation_history.append(answer)
        return answer
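

if __name__ == "__main__":
    # Minimal usage sketch (an illustrative addition, not part of the original
    # module): it assumes the pickled files exist under data/ and that the
    # Hugging Face models above can be downloaded on first run. The sample
    # utterance is hypothetical.
    bot = ChatBot()
    bot.load()
    print(bot.generate_answer("What do you think of human intuition, Mr. Spock?"))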