Skip to content

Commit 17cefcc

Browse files
committed
Update RAG notebook
1 parent e2961b9 commit 17cefcc

File tree

5 files changed

+364
-3
lines changed

5 files changed

+364
-3
lines changed

docs/module_04_llm_ops/01_retrieval_augmented_llm_app.ipynb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@
105105
"name": "stdout",
106106
"output_type": "stream",
107107
"text": [
108-
"Overwriting app.py\n"
108+
"Writing app.py\n"
109109
]
110110
}
111111
],
@@ -198,14 +198,14 @@
198198
},
199199
{
200200
"cell_type": "code",
201-
"execution_count": 3,
201+
"execution_count": null,
202202
"metadata": {
203203
"id": "mlGSHYN0bQSm"
204204
},
205205
"outputs": [],
206206
"source": [
207207
"# check the log file for localhost port\n",
208-
"# !streamlit run app.py &>logs.txt & "
208+
"!streamlit run app.py &>logs.txt & "
209209
]
210210
},
211211
{

docs/module_04_llm_ops/app.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
## import required components

import pandas as pd
from utils import (
    get_lines,
    load_data,
    get_relevant_documents,
    get_answer,
    create_db,
    sidebar,
)
import streamlit as st

# Vector DB client and collection; chromadb.Client() is in-memory, so the
# index lives only for the current streamlit session.
chroma_client, db = create_db()

## Setup Page Header and Sidebar
st.set_page_config(page_title="PersonalGPT", page_icon="📖", layout="wide")
lm_model = sidebar()
# BUG FIX: was an f-string with no placeholders
st.header("📖PersonalGPT")
st.markdown(f">:zap: Responses Powered by **{lm_model}**")

# Track whether a document has been ingested across streamlit reruns
if 'is_doc_uploaded' not in st.session_state:
    st.session_state['is_doc_uploaded'] = False


## Add Uploader Component
uploaded_file = st.file_uploader(
    "Upload a txt file",
    type=["txt"],
    help="Text files with each sentence acting as a document",
)

if not st.session_state['is_doc_uploaded']:
    ## Check if upload is complete
    if not uploaded_file:
        st.stop()

    ## Read uploaded file
    try:
        file_data = get_lines(uploaded_file)
        ## Verbose Status update
        st.markdown(f"> Uploaded file has {len(file_data)} lines of text")
        st.session_state['is_doc_uploaded'] = True
    except Exception as e:
        st.markdown(f"Could not upload/read file={e}")
        st.session_state['is_doc_uploaded'] = False
        # BUG FIX: stop the script here — the indexing step below would
        # otherwise raise a NameError because ``file_data`` was never assigned.
        st.stop()

    ## Index Uploaded text file
    # NOTE(review): indexing is scoped to the fresh-upload branch so reruns
    # with an already-ingested document do not reference ``file_data`` again.
    with st.spinner("Indexing document... This may take a while⏳"):
        db_status_msg = load_data(db, documents=file_data)

    ## status update
    st.markdown(f"> Database indexed {db.count()} documents")
    if db.count() == 0:
        st.markdown(db_status_msg)
        st.session_state['is_doc_uploaded'] = False

## Get User Input
with st.form(key="qa_form"):
    query = st.text_area("Enter Your Query:",
                         placeholder="Examples: \nwhat is tf-idf?\nwhich module covers RLHF\nhow many moons does Jupiter have?")
    submit = st.form_submit_button("Submit")

## Provide additional Options for citing source
with st.expander("Advanced Options"):
    show_source = st.checkbox("Show Source")

## Generate Output upon button click
if submit:
    # Get relevant documents from DB
    context = get_relevant_documents(query, db)

    # get answer from LLM
    answer, score, error = get_answer(query, context, lm_model)

    # Showcase response on screen
    st.markdown(f"**Answer:** _{answer}_")
    st.markdown(f"> **Relevance Score**:{score}")
    st.markdown("---")

    # Add more details if advanced option is chosen
    if show_source:
        st.markdown("**Source(s):**")
        st.markdown(f"- <i>{context[:100]}...</i>", unsafe_allow_html=True)
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
#####################
## Set Constants
#####################
# API credentials — replace the placeholders with real tokens before running.
HF_TOKEN = '<YOUR KEY>'
OPENAI_TOKEN = '<YOUR KEY>'
# Auth header used for raw HuggingFace inference API requests (utils.py)
HEADERS = {"Authorization": f"Bearer {HF_TOKEN}"}

# Constants for embedding model
EMB_MODEL_ID = 'pinecone/mpnet-retriever-discourse'
EMB_API_URL = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{EMB_MODEL_ID}"

# Constants for QA model
QA_MODEL_ID = 'deepset/roberta-base-squad2'

# List of Different Endpoints (tags dispatched on in utils.get_answer)
HF_QA_ENDPOINT = 'HF-QA'
HF_LM_ENDPOINT = 'HF-LM'
OPENAI_ENDPOINT = 'OPENAI-LM'
LOCAL_OLLAMA_ENDPOINT = 'OLLAMA'
# Parallel lists consumed as a DataFrame in utils.py: models[i] is served
# via endpoints[i] — keep the two lists the same length when editing.
AVAILABLE_LMs = {
    'models':
        [
            'deepset/roberta-base-squad2',
            'Intel/dynamic_tinybert',
            #'google/gemma-2-2b-it', # this is timing out mostly
            'Local-LLAMA-3.1:8b',
            'OpenAI-GPT4o-mini'
        ],
    'endpoints':
        [
            HF_QA_ENDPOINT,
            HF_QA_ENDPOINT,
            #HF_LM_ENDPOINT, #this is timing out mostly
            LOCAL_OLLAMA_ENDPOINT,
            OPENAI_ENDPOINT,
        ]
}
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Adapted From: https://gist.github.com/psychemedia/925e190e2afd15b050f32334ceff9ef6
2+
import os
3+
import nbformat
4+
5+
class NB_Markdown_Scraper:
    """Collect the markdown-cell text of every notebook under given paths.

    After ``scrape_markdowns()``, ``notebook_md_dict`` maps a derived
    notebook name to the concatenated source of its markdown cells.
    """

    def __init__(self, input_paths=None):
        # notebook name -> concatenated markdown text
        self.notebook_md_dict = dict()
        # iterable of directory paths to scan; may be None until set
        self.input_paths = input_paths

    def nbpathwalk(self, path):
        ''' Walk down a directory path looking for ipynb notebook files... '''
        valid_notebook_files = []
        for path, _, files in os.walk(path):
            # skip jupyter checkpoint dirs and any 'solutions' folders
            if '.ipynb_checkpoints' in path or 'solutions' in path:
                continue
            for f in [i for i in files if i.endswith('.ipynb') and not i.startswith('dontcommit')]:
                valid_notebook_files.append(os.path.join(path, f))
        return valid_notebook_files

    def get_cell_contents(self, nb_fn, c_md=None, cell_typ=None):
        ''' Extract the content of Jupyter notebook cells.

        Each extracted cell is tagged with a running string index "ix" and
        the source notebook filename "title"; returns ``c_md`` extended with
        the matching cells of ``nb_fn``.
        '''
        if cell_typ is None:
            cell_typ = ['markdown']
        if c_md is None:
            c_md = []
        nb = nbformat.read(nb_fn, nbformat.NO_CONVERT)
        _c_md = [i for i in nb.cells if i['cell_type'] in cell_typ]
        ix = len(c_md)
        for c in _c_md:
            c.update({"ix": str(ix)})
            c.update({"title": nb_fn})
            ix = ix + 1
        c_md = c_md + _c_md
        return c_md

    # scraper
    def scrape_markdowns(self):
        for directory in self.input_paths:
            directory_notebooks = self.nbpathwalk(directory)
            for notebook in directory_notebooks:
                notebook_cells = self.get_cell_contents(notebook, cell_typ=['markdown'])
                # BUG FIX: split on os.sep (not '/') so the name derivation
                # also works on Windows paths.
                notebook_name = '_'.join(notebook.split(os.sep)[1:]).split('.')[0]
                # BUG FIX: "ix" is stored as a string; sorting on it directly
                # is lexicographic ("10" < "2") and scrambles cell order for
                # notebooks with 10+ markdown cells — sort numerically.
                self.notebook_md_dict[notebook_name] = ' '.join(
                    [cell['source'] for cell in sorted(notebook_cells,
                                                       key=lambda d: int(d['ix']))])
45+

docs/module_04_llm_ops/utils.py

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
#####################
2+
## imports
3+
#####################
4+
import pandas as pd
5+
import json
6+
import requests
7+
from retry import retry
8+
import streamlit as st
9+
import chromadb.utils.embedding_functions as embedding_functions
10+
from huggingface_hub import InferenceClient
11+
from openai import OpenAI
12+
import ollama
13+
from constants import (
14+
HF_TOKEN,
15+
OPENAI_TOKEN,
16+
HEADERS,
17+
EMB_MODEL_ID,
18+
EMB_API_URL,
19+
QA_MODEL_ID,
20+
HF_QA_ENDPOINT,
21+
HF_LM_ENDPOINT,
22+
OPENAI_ENDPOINT,
23+
LOCAL_OLLAMA_ENDPOINT,
24+
AVAILABLE_LMs)
25+
26+
27+
import chromadb
28+
29+
30+
# Tabular view of AVAILABLE_LMs ('models' / 'endpoints' columns), used by
# get_answer to map a selected model name to its endpoint type and by
# sidebar to populate the model selector.
lm_df = pd.DataFrame.from_dict(AVAILABLE_LMs)
31+
32+
#####################
33+
## Utility Functions
34+
#####################
35+
36+
def get_lines(uploaded_file):
    """
    Utility to read raw text file in binary.

    Iterates the uploaded file object line by line and decodes each raw
    byte line as UTF-8; returns the decoded lines as a list of strings.
    """
    return [raw_line.decode("utf-8") for raw_line in uploaded_file]
44+
45+
def create_db():
    """
    Utility to instantiate vector db client and collection.

    Creates an (in-memory) chroma client and fetches — or creates — the
    ``nlp_llm_workshop`` collection, returning both as a tuple.
    """
    client = chromadb.Client()
    # HuggingFace embedding function kept for reference; the collection
    # currently relies on chroma's default embedding function instead.
    # huggingface_ef = embedding_functions.HuggingFaceEmbeddingFunction(
    #     api_key=HF_TOKEN,
    #     model_name=EMB_MODEL_ID
    # )
    collection = client.get_or_create_collection(name="nlp_llm_workshop")
    return client, collection
57+
58+
def load_data(db, documents):
    """
    Utility to add/index data into vector db.

    Parameters
    ----------
    db : chroma collection to index into
    documents : list of str, one document per entry (ids are their positions)

    Returns
    -------
    str : empty string on success, otherwise a display-ready error message.
    """
    try:
        db.add(
            documents=documents,
            ids=[str(i) for i in range(len(documents))]
        )
        # BUG FIX: previously fell through returning None on success while
        # returning a 3-tuple on failure; the caller (app.py) renders this
        # value with st.markdown, so always return a plain string.
        return ""
    except Exception as ex:
        return f"Apologies but I could not ingest document: {ex}"
69+
70+
def get_relevant_documents(query, db):
    """
    Utility to retrieve relevant documents from vector DB.

    Queries the collection for the single closest document to ``query``.

    Returns
    -------
    str : the best-matching document, or an apology message on failure
          (callers treat the result as a plain context string).
    """
    try:
        relevant_doc = db.query(query_texts=[query], n_results=1)['documents'][0][0]
        return relevant_doc
    except Exception as ex:
        # BUG FIX: previously returned a 3-tuple here; callers slice and
        # display the result as a string, so return only the message.
        return f"Apologies but I could not process your query: {ex}"
79+
80+
def get_hf_qa_answer(payload, lm_model):
    """
    Query the hosted HuggingFace extractive-QA endpoint for ``lm_model``.

    ``payload`` carries ``question`` and ``context`` keys; returns an
    (answer, score, error) triple where error is "" on success.
    """
    request_body = json.dumps(payload)
    try:
        endpoint = f"https://api-inference.huggingface.co/models/{lm_model}"
        api_response = requests.post(endpoint, headers=HEADERS, data=request_body)
        parsed = json.loads(api_response.content.decode("utf-8"))
        return parsed['answer'], parsed['score'], ""
    except Exception as ex:
        return "Apologies but I could not find any relevant answer", 0.0, ex
89+
90+
# this is mostly timing out
91+
# this is mostly timing out
def get_hf_llm_answer(payload, lm_model):
    """
    Answer via a hosted HuggingFace chat LLM (hard-coded to gemma-2-2b-it).

    NOTE(review): ``lm_model`` is accepted for signature parity with the
    other backends but unused — the model id is fixed below.

    Returns
    -------
    tuple : (answer:str, score:float, error) — error is "" on success.
    """
    try:
        client = InferenceClient(
            "google/gemma-2-2b-it",
            token=HF_TOKEN,)

        content = f"Given the context, answer the question. \ncontext:{payload['context']}\nquestion:{payload['question']}"
        response = client.chat_completion(
            messages=[{"role": "user", "content": content}],
            max_tokens=500,
            stream=False,
        )

        # BUG FIX: previously read ``message.choices[0].delta.content`` —
        # ``message`` was undefined and ``delta`` only exists on streamed
        # chunks; a non-streamed completion carries ``choices[0].message``.
        # The prompt requests plain text, so no json.loads; and return a
        # 3-tuple like every other backend (was a 2-tuple).
        return response.choices[0].message.content, 0.0, ""
    except Exception as ex:
        return "Apologies but I could not find any relevant answer", 0.0, ex
107+
108+
def get_local_llama_answer(payload, lm_model):
    """
    Answer via a locally served Ollama llama3.1:8b model.

    The prompt instructs the model to emit a JSON object with ``answer``
    and ``score`` keys, which is parsed into an (answer, score, "") triple.
    """
    prompt = f"Given the context, perform the following tasks:1.Respond with a summarized answer to the question factually in few words only if the provided context contains the answer\n2.Check if your answer is really in the provided context, otherwise respond with 'Sorry I could not find the answer'.\n 3.Generate a relevance score between 0 and 1.\n4. Format the output as a json with answer and score as keys.\n5.Do not add makrdown syntax only respond with json.\nBe careful and Think step by step.\ncontext:{payload['context']}\nquestion:{payload['question']}"
    try:
        llm_reply = ollama.chat(
            model='llama3.1:8b',
            messages=[
                {
                    'role': 'user',
                    'content': prompt,
                },
            ]
        )
        parsed = json.loads(llm_reply['message']['content'])
        return parsed['answer'], parsed['score'], ""
    except Exception as ex:
        # Surface the raw failure in the UI to aid local debugging.
        st.markdown(ex)
        return "Apologies but I could not find any relevant answer", 0.0, ex
123+
124+
def get_opeai_answer(payload, lm_model):
    """
    Answer via the OpenAI chat completions API (gpt-4o-mini).

    The prompt asks the model for a JSON object with ``answer`` and
    ``score`` keys, which is parsed into an (answer, score, "") triple.
    NOTE(review): function name keeps the historical "opeai" spelling —
    get_answer dispatches to it by this name.
    """
    prompt = f"Given the context, perform the following tasks:1.Respond with a summarized answer to the question factually in few words only if the provided context contains the answer\n 2.Generate a relevance score.\n3. Format the output as a json with answer and score as keys. Do not add makrdown syntax.\nThink step by step.\ncontext:{payload['context']}\nquestion:{payload['question']}"
    try:
        openai_client = OpenAI(
            api_key=OPENAI_TOKEN,
        )
        completion = openai_client.chat.completions.create(
            messages=[
                {
                    "role": "user",
                    "content": prompt,
                }
            ],
            model="gpt-4o-mini",
        )
        parsed = json.loads(completion.choices[0].message.content)
        return parsed['answer'], parsed['score'], ""
    except Exception as ex:
        return "Apologies but I could not find any relevant answer", 0.0, ex
143+
144+
145+
def get_answer(question, context, lm_model):
    """
    Utility to leverage QA model for answering question using given context
    and the mentioned model.

    Dispatches to the backend (HF QA / HF LM / OpenAI / local Ollama)
    registered for ``lm_model`` in constants.AVAILABLE_LMs (via lm_df).

    Returns
    -------
    tuple : (answer:str, score:float, error) — error is "" when no failure.
    """
    payload = {
        "question": question,
        "context": context
    }
    try:
        endpoint_type = lm_df[lm_df['models'] == lm_model]['endpoints'].values[0]
        if endpoint_type == HF_QA_ENDPOINT:
            return get_hf_qa_answer(payload, lm_model)
        elif endpoint_type == HF_LM_ENDPOINT:
            return get_hf_llm_answer(payload, lm_model)
        elif endpoint_type == OPENAI_ENDPOINT:
            return get_opeai_answer(payload, lm_model)
        elif endpoint_type == LOCAL_OLLAMA_ENDPOINT:
            return get_local_llama_answer(payload, lm_model)
        else:
            # BUG FIX: this branch was a bare expression (no ``return``)
            # referencing an undefined ``ex`` — it silently yielded None.
            return "This is not implemented yet", 0.0, ""
    except Exception as ex:
        return "Apologies but I could not find any relevant answer", 0.0, ex
168+
169+
170+
def sidebar():
    """
    Utility to add content to sidebar.

    Renders usage instructions, the model selector, and an about section.

    Returns
    -------
    str : the model name selected by the user.
    """
    with st.sidebar:
        st.markdown(
            "## How to use\n"
            "1. Upload a txt file📄\n"
            # BUG FIX: second step was mis-numbered "3." in the rendered list
            "2. Ask a question about the document💬\n"
        )
        st.markdown("---")
        st.markdown("## Which LM would you like to use?")
        option = st.selectbox(
            "Select a Model",
            lm_df['models'],
            label_visibility='hidden'
        )

        st.markdown("---")
        st.markdown("# About")
        st.markdown(
            "📖PersonalGPT is a demo to showcase retrieval augmented question answering system"
        )
        st.markdown(":heart: Made by [raghav bali](https://raghavbali.github.io)")
        st.markdown("---")

    return option

0 commit comments

Comments
 (0)