Skip to content

Commit 1242d28

Browse files
authored
Merge pull request #14 from conversence/feature/avoid_ensure_metadata
Optimizations and abstractions
2 parents f696edf + 793a1ab commit 1242d28

File tree

5 files changed

+167
-141
lines changed

5 files changed

+167
-141
lines changed

agentmemory/check_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def _download(url: str, fname: Path, chunk_size: int = 1024) -> None:
1818
size = file.write(data)
1919
bar.update(size)
2020

21-
default_model_path = str(Path.home() / ".cache" / "onnx_models")
21+
default_model_path = Path.home() / ".cache" / "onnx_models"
2222

2323
def check_model(model_name = "all-MiniLM-L6-v2", model_path = default_model_path) -> str:
2424
DOWNLOAD_PATH = Path(model_path) / model_name

agentmemory/chroma_client.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,78 @@
22

33
import chromadb
44

5+
from .client import CollectionMemory, AgentMemory
6+
7+
class ChromaCollectionMemory(CollectionMemory):
8+
def __init__(self, collection, metadata=None) -> None:
9+
self.collection = collection
10+
11+
def count(self):
12+
return self.collection.count()
13+
14+
def add(self, ids, documents=None, metadatas=None, embeddings=None):
15+
return self.collection.add(ids, documents, metadatas, embeddings)
16+
17+
def get(
18+
self,
19+
ids=None,
20+
where=None,
21+
limit=None,
22+
offset=None,
23+
where_document=None,
24+
include=["metadatas", "documents"],
25+
):
26+
return self.collection.get(ids, where, limit, offset, where_document, include)
27+
28+
def peek(self, limit=10):
29+
return self.collection.peek(limit)
30+
31+
def query(
32+
self,
33+
query_embeddings=None,
34+
query_texts=None,
35+
n_results=10,
36+
where=None,
37+
where_document=None,
38+
include=["metadatas", "documents", "distances"],
39+
):
40+
return self.collection.query(query_embeddings, query_texts, n_results, where, where_document, include)
41+
42+
def update(self, ids, documents=None, metadatas=None, embeddings=None):
43+
return self.collection.update(ids, embeddings, metadatas, documents)
44+
45+
def upsert(self, ids, documents=None, metadatas=None, embeddings=None):
46+
# if no id is provided, generate one based on count of documents in collection
47+
if any(id is None for id in ids):
48+
origin = self.count()
49+
# pad the id with zeros to make it 16 digits long
50+
ids = [str(id_).zfill(16) for id_ in range(origin, origin+len(documents))]
51+
52+
return self.collection.upsert(ids, embeddings, metadatas, documents)
53+
54+
def delete(self, ids=None, where=None, where_document=None):
55+
return self.collection.delete(ids, where, where_document)
56+
57+
58+
class ChromaMemory(AgentMemory):
59+
def __init__(self, path) -> None:
60+
self.chroma = chromadb.PersistentClient(path=path)
61+
62+
def get_or_create_collection(self, category, metadata=None) -> CollectionMemory:
63+
memory = self.chroma.get_or_create_collection(category)
64+
return ChromaCollectionMemory(memory, metadata)
65+
66+
def get_collection(self, category) -> CollectionMemory:
67+
memory = self.chroma.get_collection(category)
68+
return ChromaCollectionMemory(memory)
69+
70+
def delete_collection(self, category):
71+
self.chroma.delete_collection(category)
72+
73+
def list_collections(self):
74+
return self.chroma.list_collections()
75+
576

677
def create_client():
778
STORAGE_PATH = os.environ.get("STORAGE_PATH", "./memory")
8-
return chromadb.PersistentClient(path=STORAGE_PATH)
79+
return ChromaMemory(path=STORAGE_PATH)

agentmemory/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ class AgentCollection():
105105

106106
class AgentMemory(ABC):
107107
@abstractmethod
108-
def get_or_create_collection(self, category) -> CollectionMemory:
108+
def get_or_create_collection(self, category, metadata=None) -> CollectionMemory:
109109
raise NotImplementedError()
110110

111111
@abstractmethod

agentmemory/main.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,6 @@ def create_memory(category, text, metadata={}, embedding=None, id=None):
3737
metadata["created_at"] = datetime.datetime.now().timestamp()
3838
metadata["updated_at"] = datetime.datetime.now().timestamp()
3939

40-
# if no id is provided, generate one based on count of documents in collection
41-
if id is None:
42-
id = str(memories.count())
43-
# pad the id with zeros to make it 16 digits long
44-
id = id.zfill(16)
45-
4640
# for each field in metadata...
4741
# if the field is a boolean, convert it to a string
4842
for key, value in metadata.items():
@@ -52,7 +46,7 @@ def create_memory(category, text, metadata={}, embedding=None, id=None):
5246

5347
# insert the document into the collection
5448
memories.upsert(
55-
ids=[str(id)],
49+
ids=[id],
5650
documents=[text],
5751
metadatas=[metadata],
5852
embeddings=[embedding] if embedding is not None else None,

0 commit comments

Comments
 (0)