Skip to content

Commit ca86a7f

Browse files
committed
Add sentence transformer for embeddings + config for what models to use
1 parent 3559621 commit ca86a7f

File tree

5 files changed

+43
-14
lines changed

5 files changed

+43
-14
lines changed

bot.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import streamlit as st
44
from langchain.vectorstores.neo4j_vector import Neo4jVector
55
from langchain.embeddings.openai import OpenAIEmbeddings
6-
from langchain.embeddings import OllamaEmbeddings
6+
from langchain.embeddings import OllamaEmbeddings, SentenceTransformerEmbeddings
77
from langchain.chat_models import ChatOpenAI, ChatOllama
88
from langchain.chains import RetrievalQAWithSourcesChain
99
from langchain.prompts.chat import (
@@ -19,14 +19,30 @@
1919
username = os.getenv("NEO4J_USERNAME")
2020
password = os.getenv("NEO4J_PASSWORD")
2121
ollama_base_url = os.getenv("OLLAMA_BASE_URL")
22+
embedding_model_name = os.getenv("EMBEDDING_MODEL")
23+
llm_name = os.getenv("LLM")
2224

2325
os.environ["NEO4J_URL"] = url
2426

25-
# embeddings = OllamaEmbeddings(base_url=ollama_base_url)
26-
# llm = ChatOllama(temperature=0, base_url=ollama_base_url)
27-
28-
embeddings = OpenAIEmbeddings()
29-
llm = ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo")
27+
if embedding_model_name == "ollama":
28+
embeddings = OllamaEmbeddings(base_url=ollama_base_url, model="llama2")
29+
print("Embedding: Using Ollama")
30+
elif embedding_model_name == "openai":
31+
embeddings = OpenAIEmbeddings()
32+
print("Embedding: Using OpenAI")
33+
else:
34+
embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
35+
print("Embedding: Using SentenceTransformer")
36+
37+
if llm_name == "gpt-4":
38+
llm = ChatOpenAI(temperature=0, model_name="gpt-4")
39+
print("LLM: Using GPT-4")
40+
elif llm_name == "ollama":
41+
llm = ChatOllama(temperature=0, base_url=ollama_base_url, model="llama2")
42+
print("LLM: Using Ollama (llama2)")
43+
else:
44+
llm = ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo")
45+
print("LLM: Using GPT-3.5 Turbo")
3046

3147
# LLM only response
3248
template = "You are a helpful assistant that helps with programming questions."

docker-compose.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ services:
2828
- NEO4J_USERNAME=${NEO4J_USERNAME-neo4j}
2929
- OPENAI_API_KEY=${OPENAI_API_KEY}
3030
- OLLAMA_BASE_URL=${OLLAMA_BASE_URL-http://host.docker.internal:11434}
31+
- EMBEDDING_MODEL=${EMBEDDING_MODEL-sentence_transformer}
3132
networks:
3233
- net
3334
depends_on:
@@ -53,6 +54,8 @@ services:
5354
- NEO4J_USERNAME=${NEO4J_USERNAME-neo4j}
5455
- OPENAI_API_KEY=${OPENAI_API_KEY}
5556
- OLLAMA_BASE_URL=${OLLAMA_BASE_URL-http://host.docker.internal:11434}
57+
- LLM=${LLM-gpt-3.5}
58+
- EMBEDDING_MODEL=${EMBEDDING_MODEL-sentence_transformer}
5659
networks:
5760
- net
5861
depends_on:

example.env

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,5 @@ OLLAMA_BASE_URL=http://host.docker.internal:11434
33
#NEO4J_URI=neo4j://localhost:7687
44
#NEO4J_USERNAME=neo4j
55
#NEO4J_PASSWORD=password
6+
LLM=ollama # or gpt-4 or gpt-3.5 (any other value falls back to gpt-3.5-turbo)
7+
EMBEDDING_MODEL=sentence_transformer # or openai or ollama (any other value falls back to sentence_transformer)

loader.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import requests
33

44
from dotenv import load_dotenv
5-
from langchain.embeddings import OllamaEmbeddings, OpenAIEmbeddings
5+
from langchain.embeddings import OllamaEmbeddings, OpenAIEmbeddings, SentenceTransformerEmbeddings
66
from langchain.graphs import Neo4jGraph
77

88
import streamlit as st
@@ -13,18 +13,25 @@
1313
username = os.getenv("NEO4J_USERNAME")
1414
password = os.getenv("NEO4J_PASSWORD")
1515
ollama_base_url = os.getenv("OLLAMA_BASE_URL")
16+
embedding_model_name = os.getenv("EMBEDDING_MODEL")
1617

1718
os.environ["NEO4J_URL"] = url
1819

19-
# embeddings = OllamaEmbeddings(base_url=ollama_base_url)
20-
# dimension = 4096 # Ollama
21-
22-
embeddings = OpenAIEmbeddings()
23-
dimension = 1536 # OpenAi
20+
if embedding_model_name == "ollama":
21+
embeddings = OllamaEmbeddings(base_url=ollama_base_url, model="llama2")
22+
dimension = 4096
23+
print("Embedding: Using Ollama")
24+
elif embedding_model_name == "openai":
25+
embeddings = OpenAIEmbeddings()
26+
dimension = 1536
27+
print("Embedding: Using OpenAI")
28+
else:
29+
embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
30+
dimension = 384
31+
print("Embedding: Using SentenceTransformer")
2432

2533
neo4j_graph = Neo4jGraph(url=url, username=username, password=password)
2634

27-
2835
def create_constraints():
2936
neo4j_graph.query(
3037
"CREATE CONSTRAINT question_id IF NOT EXISTS FOR (q:Question) REQUIRE (q.id) IS UNIQUE"

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,5 @@ python-dotenv
44
wikipedia
55
tiktoken
66
neo4j
7-
streamlit
7+
streamlit
8+
sentence_transformers

0 commit comments

Comments
 (0)