|
2 | 2 |
|
3 | 3 | import chromadb |
4 | 4 |
|
| 5 | +from .client import CollectionMemory, AgentMemory |
| 6 | + |
| 7 | +class ChromaCollectionMemory(CollectionMemory): |
| 8 | + def __init__(self, collection, metadata=None) -> None: |
| 9 | + self.collection = collection |
| 10 | + |
| 11 | + def count(self): |
| 12 | + return self.collection.count() |
| 13 | + |
| 14 | + def add(self, ids, documents=None, metadatas=None, embeddings=None): |
| 15 | + return self.collection.add(ids, documents, metadatas, embeddings) |
| 16 | + |
| 17 | + def get( |
| 18 | + self, |
| 19 | + ids=None, |
| 20 | + where=None, |
| 21 | + limit=None, |
| 22 | + offset=None, |
| 23 | + where_document=None, |
| 24 | + include=["metadatas", "documents"], |
| 25 | + ): |
| 26 | + return self.collection.get(ids, where, limit, offset, where_document, include) |
| 27 | + |
| 28 | + def peek(self, limit=10): |
| 29 | + return self.collection.peek(limit) |
| 30 | + |
| 31 | + def query( |
| 32 | + self, |
| 33 | + query_embeddings=None, |
| 34 | + query_texts=None, |
| 35 | + n_results=10, |
| 36 | + where=None, |
| 37 | + where_document=None, |
| 38 | + include=["metadatas", "documents", "distances"], |
| 39 | + ): |
| 40 | + return self.collection.query(query_embeddings, query_texts, n_results, where, where_document, include) |
| 41 | + |
| 42 | + def update(self, ids, documents=None, metadatas=None, embeddings=None): |
| 43 | + return self.collection.update(ids, embeddings, metadatas, documents) |
| 44 | + |
| 45 | + def upsert(self, ids, documents=None, metadatas=None, embeddings=None): |
| 46 | + # if no id is provided, generate one based on count of documents in collection |
| 47 | + if any(id is None for id in ids): |
| 48 | + origin = self.count() |
| 49 | + # pad the id with zeros to make it 16 digits long |
| 50 | + ids = [str(id_).zfill(16) for id_ in range(origin, origin+len(documents))] |
| 51 | + |
| 52 | + return self.collection.upsert(ids, embeddings, metadatas, documents) |
| 53 | + |
| 54 | + def delete(self, ids=None, where=None, where_document=None): |
| 55 | + return self.collection.delete(ids, where, where_document) |
| 56 | + |
| 57 | + |
| 58 | +class ChromaMemory(AgentMemory): |
| 59 | + def __init__(self, path) -> None: |
| 60 | + self.chroma = chromadb.PersistentClient(path=path) |
| 61 | + |
| 62 | + def get_or_create_collection(self, category, metadata=None) -> CollectionMemory: |
| 63 | + memory = self.chroma.get_or_create_collection(category) |
| 64 | + return ChromaCollectionMemory(memory, metadata) |
| 65 | + |
| 66 | + def get_collection(self, category) -> CollectionMemory: |
| 67 | + memory = self.chroma.get_collection(category) |
| 68 | + return ChromaCollectionMemory(memory) |
| 69 | + |
| 70 | + def delete_collection(self, category): |
| 71 | + self.chroma.delete_collection(category) |
| 72 | + |
| 73 | + def list_collections(self): |
| 74 | + return self.chroma.list_collections() |
| 75 | + |
5 | 76 |
|
6 | 77 | def create_client(): |
7 | 78 | STORAGE_PATH = os.environ.get("STORAGE_PATH", "./memory") |
8 | | - return chromadb.PersistentClient(path=STORAGE_PATH) |
| 79 | + return ChromaMemory(path=STORAGE_PATH) |
0 commit comments