77"""
88import asyncio
99from pathlib import Path
10- from typing import Optional
10+ from typing import Any , Optional
1111
12- from langchain .vectorstores import FAISS
13- from langchain_core .embeddings import Embeddings
12+ import faiss
13+ from llama_index import VectorStoreIndex , load_index_from_storage
14+ from llama_index .embeddings import BaseEmbedding
15+ from llama_index .schema import Document , QueryBundle , TextNode
16+ from llama_index .storage import StorageContext
17+ from llama_index .vector_stores import FaissVectorStore
1418
1519from metagpt .document import IndexableDocument
1620from metagpt .document_store .base_store import LocalStore
2024
2125class FaissStore (LocalStore ):
2226 def __init__ (
23- self , raw_data : Path , cache_dir = None , meta_col = "source" , content_col = "output" , embedding : Embeddings = None
27+ self , raw_data : Path , cache_dir = None , meta_col = "source" , content_col = "output" , embedding : BaseEmbedding = None
2428 ):
2529 self .meta_col = meta_col
2630 self .content_col = content_col
2731 self .embedding = embedding or get_embedding ()
32+ self .store : VectorStoreIndex
2833 super ().__init__ (raw_data , cache_dir )
2934
30- def _load (self ) -> Optional ["FaissStore " ]:
31- index_file , store_file = self ._get_index_and_store_fname (index_ext = ".faiss" ) # langchain FAISS using .faiss
35+ def _load (self ) -> Optional ["VectorStoreIndex " ]:
36+ index_file , store_file = self ._get_index_and_store_fname (index_ext = ".faiss" ) # FAISS using .faiss
3237
3338 if not (index_file .exists () and store_file .exists ()):
3439 logger .info ("Missing at least one of index_file/store_file, load failed and return None" )
3540 return None
41+ vector_store = FaissVectorStore .from_persist_dir (persist_dir = self .cache_dir )
42+ storage_context = StorageContext .from_defaults (persist_dir = self .cache_dir , vector_store = vector_store )
43+ index = load_index_from_storage (storage_context )
3644
37- return FAISS . load_local ( self . raw_data_path . parent , self . embedding , self . fname )
45+ return index
3846
39- def _write (self , docs , metadatas ):
40- store = FAISS .from_texts (docs , self .embedding , metadatas = metadatas )
41- return store
47+ def _write (self , docs : list [str ], metadatas : list [dict [str , Any ]]) -> VectorStoreIndex :
48+ assert len (docs ) == len (metadatas )
49+ texts_embeds = self .embedding .get_text_embedding_batch (docs )
50+ documents = [Document (text = doc , metadata = metadatas [idx ]) for idx , doc in enumerate (docs )]
51+
52+ [TextNode (embedding = embed , metadata = metadatas [idx ]) for idx , embed in enumerate (texts_embeds )]
53+ # doc_store = SimpleDocumentStore()
54+ # doc_store.add_documents(nodes)
55+ vector_store = FaissVectorStore (faiss_index = faiss .IndexFlatL2 (1536 ))
56+ storage_context = StorageContext .from_defaults (vector_store = vector_store )
57+ index = VectorStoreIndex .from_documents (documents = documents , storage_context = storage_context )
58+
59+ return index
4260
4361 def persist (self ):
44- self .store .save_local (self .raw_data_path .parent , self .fname )
62+ self .store .storage_context .persist (self .cache_dir )
63+
64+ def search (self , query : str , expand_cols = False , sep = "\n " , * args , k = 5 , ** kwargs ):
65+ retriever = self .store .as_retriever (similarity_top_k = k )
66+ rsp = retriever .retrieve (QueryBundle (query_str = query , embedding = self .embedding .get_text_embedding (query )))
4567
46- def search (self , query , expand_cols = False , sep = "\n " , * args , k = 5 , ** kwargs ):
47- rsp = self .store .similarity_search (query , k = k , ** kwargs )
4868 logger .debug (rsp )
4969 if expand_cols :
50- return str (sep .join ([f"{ x .page_content } : { x .metadata } " for x in rsp ]))
70+ return str (sep .join ([f"{ x .node . text } : { x . node .metadata } " for x in rsp ]))
5171 else :
52- return str (sep .join ([f"{ x .page_content } " for x in rsp ]))
72+ return str (sep .join ([f"{ x .node . text } " for x in rsp ]))
5373
5474 async def asearch (self , * args , ** kwargs ):
5575 return await asyncio .to_thread (self .search , * args , ** kwargs )
@@ -67,8 +87,12 @@ def write(self):
6787
6888 def add (self , texts : list [str ], * args , ** kwargs ) -> list [str ]:
6989 """FIXME: Currently, the store is not updated after adding."""
70- return self .store .add_texts (texts )
90+ texts_embeds = self .embedding .get_text_embedding_batch (texts )
91+ nodes = [TextNode (embedding = embed ) for embed in texts_embeds ]
92+ self .store .insert_nodes (nodes )
93+
94+ return []
7195
7296 def delete (self , * args , ** kwargs ):
73- """Currently, langchain does not provide a delete interface."""
97+ """Currently, faiss does not provide a delete interface."""
7498 raise NotImplementedError
0 commit comments