Skip to content

Commit 0f22162

Browse files
committed
enable multiple chat sessions
1 parent 3be1eca commit 0f22162

File tree

8 files changed

+274
-131
lines changed

8 files changed

+274
-131
lines changed

migrations/3.surrealql

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,11 @@
1-
REMOVE FUNCTION fn::vector_search;
1+
2+
DEFINE TABLE IF NOT EXISTS chat_session SCHEMALESS;
3+
4+
DEFINE TABLE IF NOT EXISTS refers_to
5+
TYPE RELATION
6+
FROM chat_session TO notebook;
7+
8+
REMOVE FUNCTION IF EXISTS fn::vector_search;
29

310
DEFINE FUNCTION IF NOT EXISTS fn::vector_search($query: array<float>, $match_count: int, $sources: bool, $show_notes: bool, $min_similarity: float) {
411
let $source_embedding_search =
@@ -16,7 +23,6 @@ DEFINE FUNCTION IF NOT EXISTS fn::vector_search($query: array<float>, $match_cou
1623
)}
1724
ELSE { [] };
1825

19-
-- Busca em source_insight com threshold
2026
let $source_insight_search =
2127
IF $sources {(
2228
SELECT
@@ -67,10 +73,10 @@ DEFINE FUNCTION IF NOT EXISTS fn::vector_search($query: array<float>, $match_cou
6773
};
6874

6975

70-
REMOVE FUNCTION fn::text_search;
76+
REMOVE FUNCTION IF EXISTS fn::text_search;
7177

7278

73-
DEFINE FUNCTION IF NOT EXISTS fn::text_search($query_text: string, $match_count: int, $sources:bool, $show_notes:bool) {
79+
DEFINE FUNCTION IF NOT EXISTS fn::text_search($query_text: string, $match_count: int, $sources:bool, $show_notes:bool) {
7480

7581
let $source_title_search =
7682
IF $sources {(

migrations/3_down.surrealql

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
REMOVE TABLE IF EXISTS chat_session;
2+
3+
REMOVE TABLE IF EXISTS refers_to;
4+
5+
16
REMOVE FUNCTION fn::vector_search;
27

38

open_notebook/domain/notebook.py

Lines changed: 67 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
1-
import os
21
from typing import Any, ClassVar, Dict, List, Literal, Optional
32

43
from loguru import logger
54
from pydantic import BaseModel, Field, field_validator
65

76
from open_notebook.database.repository import (
8-
repo_create,
97
repo_query,
108
)
119
from open_notebook.domain.base import ObjectModel
@@ -68,6 +66,27 @@ def notes(self) -> List["Note"]:
6866
logger.exception(e)
6967
raise DatabaseOperationError(e)
7068

69+
@property
70+
def chat_sessions(self) -> List["ChatSession"]:
71+
try:
72+
srcs = repo_query(f"""
73+
select * from (
74+
select
75+
<- chat_session as chat_session
76+
from refers_to
77+
where out={self.id}
78+
fetch chat_session
79+
)
80+
order by chat_session.updated desc
81+
""")
82+
return (
83+
[ChatSession(**src["chat_session"][0]) for src in srcs] if srcs else []
84+
)
85+
except Exception as e:
86+
logger.error(f"Error fetching notes for notebook {self.id}: {str(e)}")
87+
logger.exception(e)
88+
raise DatabaseOperationError(e)
89+
7190

7291
class Asset(BaseModel):
7392
file_path: Optional[str] = None
@@ -99,6 +118,22 @@ def get_context(
99118
else:
100119
return dict(id=self.id, title=self.title, insights=self.insights)
101120

121+
@property
122+
def embedded_chunks(self) -> int:
123+
try:
124+
result = repo_query(
125+
f"""
126+
select count() as chunks from source_embedding where source={self.id} GROUP ALL
127+
"""
128+
)
129+
if len(result) == 0:
130+
return 0
131+
return result[0]["chunks"]
132+
except Exception as e:
133+
logger.error(f"Error fetching insights for source {self.id}: {str(e)}")
134+
logger.exception(e)
135+
raise DatabaseOperationError(f"Failed to count chunks for source: {str(e)}")
136+
102137
@property
103138
def insights(self) -> List[SourceInsight]:
104139
try:
@@ -118,24 +153,6 @@ def add_to_notebook(self, notebook_id: str) -> Any:
118153
raise InvalidInputError("Notebook ID must be provided")
119154
return self.relate("reference", notebook_id)
120155

121-
def save_chunks(self, text: str) -> None:
122-
if not text:
123-
raise InvalidInputError("Text cannot be empty")
124-
try:
125-
chunks = split_text(text, chunk=500000, overlap=1000)
126-
logger.debug(f"Split into {len(chunks)} chunks")
127-
for i, chunk in enumerate(chunks):
128-
logger.debug(f"Saving chunk {i}")
129-
data = {"source": self.id, "order": i, "content": surreal_clean(chunk)}
130-
repo_create(
131-
"source_chunk",
132-
data,
133-
)
134-
except Exception as e:
135-
logger.exception(e)
136-
logger.error(f"Error saving chunks for source {self.id}: {str(e)}")
137-
raise DatabaseOperationError(e)
138-
139156
def vectorize(self) -> None:
140157
EMBEDDING_MODEL = model_manager.embedding_model
141158

@@ -144,8 +161,6 @@ def vectorize(self) -> None:
144161
return
145162
chunks = split_text(
146163
self.full_text,
147-
chunk=int(os.environ.get("EMBEDDING_CHUNK_SIZE", 1000)),
148-
overlap=int(os.environ.get("EMBEDDING_CHUNK_OVERLAP", 1000)),
149164
)
150165
logger.debug(f"Split into {len(chunks)} chunks")
151166

@@ -166,26 +181,26 @@ def vectorize(self) -> None:
166181
logger.exception(e)
167182
raise DatabaseOperationError(e)
168183

169-
@classmethod
170-
def search(cls, query: str) -> List[Dict[str, Any]]:
171-
if not query:
172-
raise InvalidInputError("Search query cannot be empty")
173-
try:
174-
result = repo_query(
175-
"""
176-
SELECT * omit full_text
177-
FROM source
178-
WHERE string::lowercase(title) CONTAINS $query or title @@ $query
179-
OR string::lowercase(summary) CONTAINS $query or summary @@ $query
180-
OR string::lowercase(full_text) CONTAINS $query or full_text @@ $query
181-
""",
182-
{"query": query},
183-
)
184-
return result
185-
except Exception as e:
186-
logger.error(f"Error searching sources: {str(e)}")
187-
logger.exception(e)
188-
raise DatabaseOperationError("Failed to search sources")
184+
# @classmethod
185+
# def search(cls, query: str) -> List[Dict[str, Any]]:
186+
# if not query:
187+
# raise InvalidInputError("Search query cannot be empty")
188+
# try:
189+
# result = repo_query(
190+
# """
191+
# SELECT * omit full_text
192+
# FROM source
193+
# WHERE string::lowercase(title) CONTAINS $query or title @@ $query
194+
# OR string::lowercase(summary) CONTAINS $query or summary @@ $query
195+
# OR string::lowercase(full_text) CONTAINS $query or full_text @@ $query
196+
# """,
197+
# {"query": query},
198+
# )
199+
# return result
200+
# except Exception as e:
201+
# logger.error(f"Error searching sources: {str(e)}")
202+
# logger.exception(e)
203+
# raise DatabaseOperationError("Failed to search sources")
189204

190205
def add_insight(self, insight_type: str, content: str) -> Any:
191206
EMBEDDING_MODEL = model_manager.embedding_model
@@ -246,6 +261,16 @@ def get_embedding_content(self) -> Optional[str]:
246261
return self.content
247262

248263

264+
class ChatSession(ObjectModel):
265+
table_name: ClassVar[str] = "chat_session"
266+
title: Optional[str] = None
267+
268+
def relate_to_notebook(self, notebook_id: str) -> Any:
269+
if not notebook_id:
270+
raise InvalidInputError("Notebook ID must be provided")
271+
return self.relate("refers_to", notebook_id)
272+
273+
249274
def text_search(keyword: str, results: int, source: bool = True, note: bool = True):
250275
if not keyword:
251276
raise InvalidInputError("Search keyword cannot be empty")
@@ -263,18 +288,6 @@ def text_search(keyword: str, results: int, source: bool = True, note: bool = Tr
263288
raise DatabaseOperationError(e)
264289

265290

266-
# def hybrid_search(
267-
# keyword_search: List[str],
268-
# embed_search: List[str],
269-
# results: int = 50,
270-
# source: bool = True,
271-
# note: bool = True,
272-
# ):
273-
# EMBEDDING_MODEL = model_manager.embedding_model
274-
# embed1_vector = EMBEDDING_MODEL.embed(embed1) if embed1 else None
275-
276-
277-
# todo: mover o embedding pra ca
278291
def vector_search(keyword: str, results: int, source: bool = True, note: bool = True):
279292
if not keyword:
280293
raise InvalidInputError("Search keyword cannot be empty")

pages/2_📒_Notebooks.py

Lines changed: 46 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,14 @@
1010
setup_page("📒 Open Notebook")
1111

1212

13-
def notebook_header(current_notebook):
13+
def notebook_header(current_notebook: Notebook):
14+
"""
15+
Defines the header of the notebook page, including the ability to edit the notebook's name and description.
16+
"""
1417
c1, c2, c3 = st.columns([8, 2, 2])
1518
c1.header(current_notebook.name)
1619
if c2.button("Back to the list", icon="🔙"):
17-
st.session_state["current_notebook"] = None
20+
st.session_state["current_notebook_id"] = None
1821
st.rerun()
1922

2023
if c3.button("Refresh", icon="🔄"):
@@ -49,20 +52,20 @@ def notebook_header(current_notebook):
4952
st.toast("Notebook unarchived", icon="🗃️")
5053
if c3.button("Delete forever", type="primary", icon="☠️"):
5154
current_notebook.delete()
52-
st.session_state["current_notebook"] = None
55+
st.session_state["current_notebook_id"] = None
5356
st.rerun()
5457

5558

56-
def notebook_page(current_notebook_id):
57-
current_notebook: Notebook = Notebook.get(current_notebook_id)
58-
if not current_notebook:
59-
st.error("Notebook not found")
60-
return
61-
if current_notebook_id not in st.session_state.keys():
62-
st.session_state[current_notebook_id] = current_notebook
59+
def notebook_page(current_notebook: Notebook):
60+
# Guarantees that we have an entry for this notebook in the session state
61+
if current_notebook.id not in st.session_state:
62+
st.session_state[current_notebook.id] = {"notebook": current_notebook}
63+
64+
# sets up the active session
65+
current_session = setup_stream_state(
66+
current_notebook=current_notebook,
67+
)
6368

64-
session_id = st.session_state["active_session"]
65-
st.session_state[session_id]["notebook"] = current_notebook
6669
sources = current_notebook.sources
6770
notes = current_notebook.notes
6871

@@ -74,18 +77,18 @@ def notebook_page(current_notebook_id):
7477
with sources_tab:
7578
with st.container(border=True):
7679
if st.button("Add Source", icon="➕"):
77-
add_source(session_id)
80+
add_source(current_notebook.id)
7881
for source in sources:
79-
source_card(session_id=session_id, source=source)
82+
source_card(source=source, notebook_id=current_notebook.id)
8083

8184
with notes_tab:
8285
with st.container(border=True):
8386
if st.button("Write a Note", icon="📝"):
84-
add_note(session_id)
87+
add_note(current_notebook.id)
8588
for note in notes:
86-
note_card(session_id=session_id, note=note)
89+
note_card(note=note, notebook_id=current_notebook.id)
8790
with chat_tab:
88-
chat_sidebar(session_id=session_id)
91+
chat_sidebar(current_notebook=current_notebook, current_session=current_session)
8992

9093

9194
def notebook_list_item(notebook):
@@ -96,40 +99,50 @@ def notebook_list_item(notebook):
9699
)
97100
st.write(notebook.description)
98101
if st.button("Open", key=f"open_notebook_{notebook.id}"):
99-
setup_stream_state(notebook.id)
100-
st.session_state["current_notebook"] = notebook.id
102+
st.session_state["current_notebook_id"] = notebook.id
101103
st.rerun()
102104

103105

104-
if "current_notebook" not in st.session_state:
105-
st.session_state["current_notebook"] = None
106+
if "current_notebook_id" not in st.session_state:
107+
st.session_state["current_notebook_id"] = None
106108

107-
if st.session_state["current_notebook"]:
108-
notebook_page(st.session_state["current_notebook"])
109+
# todo: get the notebook, check if it exists and if it's archived
110+
if st.session_state["current_notebook_id"]:
111+
current_notebook: Notebook = Notebook.get(st.session_state["current_notebook_id"])
112+
if not current_notebook:
113+
st.error("Notebook not found")
114+
st.stop()
115+
notebook_page(current_notebook)
109116
st.stop()
110117

111118
st.title("📒 My Notebooks")
112-
st.caption("Here are all your notebooks")
113-
114-
notebooks = Notebook.get_all(order_by="updated desc")
115-
116-
for notebook in notebooks:
117-
if notebook.archived:
118-
continue
119-
notebook_list_item(notebook)
119+
st.caption(
120+
"Notebooks are a great way to organize your thoughts, ideas, and sources. You can create notebooks for different research topics and projects, to create new articles, etc. "
121+
)
120122

121123
with st.expander("➕ **New Notebook**"):
122124
new_notebook_title = st.text_input("New Notebook Name")
123-
new_notebook_description = st.text_area("Description")
125+
new_notebook_description = st.text_area(
126+
"Description",
127+
placeholder="Explain the purpose of this notebook. The more details the better.",
128+
)
124129
if st.button("Create a new Notebook", icon="➕"):
125130
notebook = Notebook(
126131
name=new_notebook_title, description=new_notebook_description
127132
)
128133
notebook.save()
129-
st.rerun()
134+
st.toast("Notebook created successfully", icon="📒")
130135

136+
notebooks = Notebook.get_all(order_by="updated desc")
131137
archived_notebooks = [nb for nb in notebooks if nb.archived]
138+
139+
for notebook in notebooks:
140+
if notebook.archived:
141+
continue
142+
notebook_list_item(notebook)
143+
132144
if len(archived_notebooks) > 0:
133145
with st.expander(f"**🗃️ {len(archived_notebooks)} archived Notebooks**"):
146+
st.write("ℹ Archived Notebooks can still be accessed and used in search.")
134147
for notebook in archived_notebooks:
135148
notebook_list_item(notebook)

0 commit comments

Comments
 (0)