Skip to content

Commit cc91df5

Browse files
committed
replace langchain with llama-index
1 parent 01d40e0 commit cc91df5

File tree

13 files changed

+175
-71
lines changed

13 files changed

+175
-71
lines changed

.gitignore

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,11 @@ key.yaml
154154
data
155155
data.ms
156156
examples/nb/
157+
examples/default__vector_store.json
158+
examples/docstore.json
159+
examples/graph_store.json
160+
examples/image__vector_store.json
161+
examples/index_store.json
157162
.chroma
158163
*~$*
159164
workspace/*
@@ -168,6 +173,7 @@ output
168173
tmp.png
169174
.dependencies.json
170175
tests/metagpt/utils/file_repo_git
176+
tests/data/rsp_cache.json
171177
*.tmp
172178
*.png
173179
htmlcov

examples/search_kb.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"""
77
import asyncio
88

9-
from langchain.embeddings import OpenAIEmbeddings
9+
from llama_index.embeddings import OpenAIEmbedding
1010

1111
from metagpt.config2 import config
1212
from metagpt.const import DATA_PATH, EXAMPLE_PATH
@@ -17,7 +17,7 @@
1717

1818
def get_store():
1919
llm = config.get_openai_llm()
20-
embedding = OpenAIEmbeddings(openai_api_key=llm.api_key, openai_api_base=llm.base_url)
20+
embedding = OpenAIEmbedding(api_key=llm.api_key, api_base=llm.base_url)
2121
return FaissStore(DATA_PATH / "example.json", embedding=embedding)
2222

2323

metagpt/document.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,8 @@
1111
from typing import Optional, Union
1212

1313
import pandas as pd
14-
from langchain.text_splitter import CharacterTextSplitter
15-
from langchain_community.document_loaders import (
16-
TextLoader,
17-
UnstructuredPDFLoader,
18-
UnstructuredWordDocumentLoader,
19-
)
14+
from llama_index.node_parser import SimpleNodeParser
15+
from llama_index.readers import Document, PDFReader, SimpleDirectoryReader
2016
from pydantic import BaseModel, ConfigDict, Field
2117
from tqdm import tqdm
2218

@@ -29,7 +25,7 @@ def validate_cols(content_col: str, df: pd.DataFrame):
2925
raise ValueError("Content column not found in DataFrame.")
3026

3127

32-
def read_data(data_path: Path):
28+
def read_data(data_path: Path) -> Union[pd.DataFrame, list[Document]]:
3329
suffix = data_path.suffix
3430
if ".xlsx" == suffix:
3531
data = pd.read_excel(data_path)
@@ -38,14 +34,13 @@ def read_data(data_path: Path):
3834
elif ".json" == suffix:
3935
data = pd.read_json(data_path)
4036
elif suffix in (".docx", ".doc"):
41-
data = UnstructuredWordDocumentLoader(str(data_path), mode="elements").load()
37+
data = SimpleDirectoryReader(input_files=[str(data_path)]).load_data()
4238
elif ".txt" == suffix:
43-
data = TextLoader(str(data_path)).load()
44-
text_splitter = CharacterTextSplitter(separator="\n", chunk_size=256, chunk_overlap=0)
45-
texts = text_splitter.split_documents(data)
46-
data = texts
39+
data = SimpleDirectoryReader(input_files=[str(data_path)]).load_data()
40+
node_parser = SimpleNodeParser.from_defaults(separator="\n", chunk_size=256, chunk_overlap=0)
41+
data = node_parser.get_nodes_from_documents(data)
4742
elif ".pdf" == suffix:
48-
data = UnstructuredPDFLoader(str(data_path), mode="elements").load()
43+
data = PDFReader.load_data(str(data_path))
4944
else:
5045
raise NotImplementedError("File format not supported.")
5146
return data
@@ -150,17 +145,17 @@ def _get_docs_and_metadatas_by_df(self) -> (list, list):
150145
metadatas.append({})
151146
return docs, metadatas
152147

153-
def _get_docs_and_metadatas_by_langchain(self) -> (list, list):
148+
def _get_docs_and_metadatas_by_llamaindex(self) -> (list, list):
154149
data = self.data
155-
docs = [i.page_content for i in data]
150+
docs = [i.text for i in data]
156151
metadatas = [i.metadata for i in data]
157152
return docs, metadatas
158153

159154
def get_docs_and_metadatas(self) -> (list, list):
160155
if isinstance(self.data, pd.DataFrame):
161156
return self._get_docs_and_metadatas_by_df()
162157
elif isinstance(self.data, list):
163-
return self._get_docs_and_metadatas_by_langchain()
158+
return self._get_docs_and_metadatas_by_llamaindex()
164159
else:
165160
raise NotImplementedError("Data type not supported for metadata extraction.")
166161

metagpt/document_store/base_store.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@ def __init__(self, raw_data_path: Path, cache_dir: Path = None):
3939
self.store = self.write()
4040

4141
def _get_index_and_store_fname(self, index_ext=".index", pkl_ext=".pkl"):
42-
index_file = self.cache_dir / f"{self.fname}{index_ext}"
43-
store_file = self.cache_dir / f"{self.fname}{pkl_ext}"
42+
index_file = self.cache_dir / "default__vector_store.json"
43+
store_file = self.cache_dir / "docstore.json"
4444
return index_file, store_file
4545

4646
@abstractmethod

metagpt/document_store/faiss_store.py

Lines changed: 41 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,14 @@
77
"""
88
import asyncio
99
from pathlib import Path
10-
from typing import Optional
10+
from typing import Any, Optional
1111

12-
from langchain.vectorstores import FAISS
13-
from langchain_core.embeddings import Embeddings
12+
import faiss
13+
from llama_index import VectorStoreIndex, load_index_from_storage
14+
from llama_index.embeddings import BaseEmbedding
15+
from llama_index.schema import Document, QueryBundle, TextNode
16+
from llama_index.storage import StorageContext
17+
from llama_index.vector_stores import FaissVectorStore
1418

1519
from metagpt.document import IndexableDocument
1620
from metagpt.document_store.base_store import LocalStore
@@ -20,36 +24,52 @@
2024

2125
class FaissStore(LocalStore):
2226
def __init__(
23-
self, raw_data: Path, cache_dir=None, meta_col="source", content_col="output", embedding: Embeddings = None
27+
self, raw_data: Path, cache_dir=None, meta_col="source", content_col="output", embedding: BaseEmbedding = None
2428
):
2529
self.meta_col = meta_col
2630
self.content_col = content_col
2731
self.embedding = embedding or get_embedding()
32+
self.store: VectorStoreIndex
2833
super().__init__(raw_data, cache_dir)
2934

30-
def _load(self) -> Optional["FaissStore"]:
31-
index_file, store_file = self._get_index_and_store_fname(index_ext=".faiss") # langchain FAISS using .faiss
35+
def _load(self) -> Optional["VectorStoreIndex"]:
36+
index_file, store_file = self._get_index_and_store_fname(index_ext=".faiss") # FAISS using .faiss
3237

3338
if not (index_file.exists() and store_file.exists()):
3439
logger.info("Missing at least one of index_file/store_file, load failed and return None")
3540
return None
41+
vector_store = FaissVectorStore.from_persist_dir(persist_dir=self.cache_dir)
42+
storage_context = StorageContext.from_defaults(persist_dir=self.cache_dir, vector_store=vector_store)
43+
index = load_index_from_storage(storage_context)
3644

37-
return FAISS.load_local(self.raw_data_path.parent, self.embedding, self.fname)
45+
return index
3846

39-
def _write(self, docs, metadatas):
40-
store = FAISS.from_texts(docs, self.embedding, metadatas=metadatas)
41-
return store
47+
def _write(self, docs: list[str], metadatas: list[dict[str, Any]]) -> VectorStoreIndex:
48+
assert len(docs) == len(metadatas)
49+
texts_embeds = self.embedding.get_text_embedding_batch(docs)
50+
documents = [Document(text=doc, metadata=metadatas[idx]) for idx, doc in enumerate(docs)]
51+
52+
[TextNode(embedding=embed, metadata=metadatas[idx]) for idx, embed in enumerate(texts_embeds)]
53+
# doc_store = SimpleDocumentStore()
54+
# doc_store.add_documents(nodes)
55+
vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(1536))
56+
storage_context = StorageContext.from_defaults(vector_store=vector_store)
57+
index = VectorStoreIndex.from_documents(documents=documents, storage_context=storage_context)
58+
59+
return index
4260

4361
def persist(self):
44-
self.store.save_local(self.raw_data_path.parent, self.fname)
62+
self.store.storage_context.persist(self.cache_dir)
63+
64+
def search(self, query: str, expand_cols=False, sep="\n", *args, k=5, **kwargs):
65+
retriever = self.store.as_retriever(similarity_top_k=k)
66+
rsp = retriever.retrieve(QueryBundle(query_str=query, embedding=self.embedding.get_text_embedding(query)))
4567

46-
def search(self, query, expand_cols=False, sep="\n", *args, k=5, **kwargs):
47-
rsp = self.store.similarity_search(query, k=k, **kwargs)
4868
logger.debug(rsp)
4969
if expand_cols:
50-
return str(sep.join([f"{x.page_content}: {x.metadata}" for x in rsp]))
70+
return str(sep.join([f"{x.node.text}: {x.node.metadata}" for x in rsp]))
5171
else:
52-
return str(sep.join([f"{x.page_content}" for x in rsp]))
72+
return str(sep.join([f"{x.node.text}" for x in rsp]))
5373

5474
async def asearch(self, *args, **kwargs):
5575
return await asyncio.to_thread(self.search, *args, **kwargs)
@@ -67,8 +87,12 @@ def write(self):
6787

6888
def add(self, texts: list[str], *args, **kwargs) -> list[str]:
6989
"""FIXME: Currently, the store is not updated after adding."""
70-
return self.store.add_texts(texts)
90+
texts_embeds = self.embedding.get_text_embedding_batch(texts)
91+
nodes = [TextNode(embedding=embed) for embed in texts_embeds]
92+
self.store.insert_nodes(nodes)
93+
94+
return []
7195

7296
def delete(self, *args, **kwargs):
73-
"""Currently, langchain does not provide a delete interface."""
97+
"""Currently, faiss does not provide a delete interface."""
7498
raise NotImplementedError

metagpt/memory/memory2.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
# @Desc : memory mechanism including store/retrieval/rank
4+
5+
from typing import Union, Optional
6+
from pydantic import Field, BaseModel
7+
8+
from metagpt.memory.memory_network import MemoryNetwork
9+
from metagpt.memory.schema import MemoryNode
10+
from metagpt.schema import Message
11+
12+
13+
class Memory(BaseModel):
14+
mem_network: Optional[MemoryNetwork] = Field(default_factory=MemoryNetwork, description="the network to store memory")
15+
16+
def add_msg(self, message: Message):
17+
mem_node = MemoryNode.create_mem_node_from_message(message)
18+
self.mem_network.add_mem(mem_node)
19+
20+
def add_msgs(self, messages: list[Message]):
21+
for msg in messages:
22+
self.add_msg(msg)

metagpt/memory/memory_network.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
# @Desc : the memory network to store memory segment
4+
5+
from pydantic import Field, BaseModel
6+
7+
from metagpt.memory.schema import MemorySegment, MemoryNode
8+
9+
10+
class MemoryNetwork(BaseModel):
11+
mem_seg: MemorySegment = Field(default_factory=MemorySegment, description="the memory segment to store memory nodes")
12+
13+
def add_mem(self, mem_node: MemoryNode):
14+
self.mem_seg.add_mem_node(mem_node)
15+
16+
def add_mems(self, mem_nodes: list[MemoryNode]):
17+
for mem_node in mem_nodes:
18+
self.add_mem(mem_node)

metagpt/memory/memory_storage.py

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,8 @@
55
"""
66

77
from pathlib import Path
8-
from typing import Optional
98

10-
from langchain.vectorstores.faiss import FAISS
11-
from langchain_core.embeddings import Embeddings
9+
from llama_index.embeddings import BaseEmbedding
1210

1311
from metagpt.const import DATA_PATH, MEM_TTL
1412
from metagpt.document_store.faiss_store import FaissStore
@@ -23,29 +21,17 @@ class MemoryStorage(FaissStore):
2321
The memory storage with Faiss as ANN search engine
2422
"""
2523

26-
def __init__(self, mem_ttl: int = MEM_TTL, embedding: Embeddings = None):
24+
def __init__(self, mem_ttl: int = MEM_TTL, embedding: BaseEmbedding = None):
2725
self.role_id: str = None
2826
self.role_mem_path: str = None
2927
self.mem_ttl: int = mem_ttl # later use
3028
self.threshold: float = 0.1 # experience value. TODO The threshold to filter similar memories
3129
self._initialized: bool = False
3230

33-
self.embedding = embedding or get_embedding()
34-
self.store: FAISS = None # Faiss engine
35-
3631
@property
3732
def is_initialized(self) -> bool:
3833
return self._initialized
3934

40-
def _load(self) -> Optional["FaissStore"]:
41-
index_file, store_file = self._get_index_and_store_fname(index_ext=".faiss") # langchain FAISS using .faiss
42-
43-
if not (index_file.exists() and store_file.exists()):
44-
logger.info("Missing at least one of index_file/store_file, load failed and return None")
45-
return None
46-
47-
return FAISS.load_local(self.role_mem_path, self.embedding, self.role_id)
48-
4935
def recover_memory(self, role_id: str) -> list[Message]:
5036
self.role_id = role_id
5137
self.role_mem_path = Path(DATA_PATH / f"role_mem/{self.role_id}/")
@@ -69,6 +55,7 @@ def _get_index_and_store_fname(self, index_ext=".index", pkl_ext=".pkl"):
6955
return None, None
7056
index_fpath = Path(self.role_mem_path / f"{self.role_id}{index_ext}")
7157
storage_fpath = Path(self.role_mem_path / f"{self.role_id}{pkl_ext}")
58+
self.cache_dir = Path(self.role_mem_path).joinpath(self.role_id)
7259
return index_fpath, storage_fpath
7360

7461
def persist(self):

metagpt/memory/schema.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
# @Desc : the memory schema definition
4+
5+
from datetime import datetime
6+
from enum import Enum
7+
from typing import Optional, Union
8+
from uuid import UUID, uuid4
9+
10+
from pydantic import BaseModel, Field
11+
12+
13+
class MemNodeType(Enum):
14+
OBSERVE = "observe" # memory from observation
15+
THINK = "think" # memory from self-think/reflect
16+
17+
18+
class MemoryNode(BaseModel):
19+
"""base unit of memory abstraction"""
20+
21+
mem_node_id: UUID = Field(default_factory=uuid4(), description="unique node id")
22+
parent_node_id: Optional[str] = Field(default=None, description="memory's parent memory node id")
23+
node_type: MemNodeType = Field(default=MemNodeType.OBSERVE, description="memory node type")
24+
25+
content: str = Field(default="", description="the memory content")
26+
summary: Optional[str] = Field(default=None, description="the summary of the content by providers")
27+
keywords: list[str] = Field(default=[], description="the extracted keywords of the content")
28+
embedding: list[float] = Field(default=[], description="the embedding of the content")
29+
30+
raw_path: Optional[str] = Field(default=None, description="the relative path of the media like image")
31+
raw_corpus: list[Union[str, dict, tuple]] = Field(default=[], description="the raw corpus of the memory")
32+
33+
create_at: datetime = Field(default_factory=datetime, description="the memory create time")
34+
access_at: datetime = Field(default_factory=datetime, description="the memory last access time")
35+
expire_at: datetime = Field(default_factory=datetime, description="the memory expire time due to a TTL")
36+
37+
importance: int = Field(default=0, ge=0, le=10, description="the memory importance")
38+
access_cnt: int = Field(default=0, description="the memory access count")
39+
40+
@classmethod
41+
def create_mem_node(
42+
cls,
43+
content: str,
44+
summary: Optional[str] = None,
45+
keywords: list[str] = [],
46+
node_type: MemNodeType = MemNodeType.OBSERVE,
47+
):
48+
pass
49+
50+
@classmethod
51+
def create_mem_node_from_message(cls, message: "Message"):
52+
pass
53+
54+
55+
class MemorySegment(BaseModel):
56+
"""segment abstraction to store memory_node"""
57+
58+
mem_nodes: list[MemoryNode] = Field(default=[], description="memory list to store MemoryNode")
59+
60+
def add_mem_node(self, mem_node: MemoryNode):
61+
self.mem_nodes.append(mem_node)

metagpt/roles/role.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -108,12 +108,6 @@ class RoleContext(BaseModel):
108108
) # see `Role._set_react_mode` for definitions of the following two attributes
109109
max_react_loop: int = 1
110110

111-
def check(self, role_id: str):
112-
# if hasattr(CONFIG, "enable_longterm_memory") and CONFIG.enable_longterm_memory:
113-
# self.long_term_memory.recover_memory(role_id, self)
114-
# self.memory = self.long_term_memory # use memory to act as long_term_memory for unify operation
115-
pass
116-
117111
@property
118112
def important_memory(self) -> list[Message]:
119113
"""Retrieve information corresponding to the attention action."""
@@ -313,8 +307,6 @@ def _watch(self, actions: Iterable[Type[Action]] | Iterable[Action]):
313307
buffer during _observe.
314308
"""
315309
self.rc.watch = {any_to_str(t) for t in actions}
316-
# check RoleContext after adding watch actions
317-
self.rc.check(self.role_id)
318310

319311
def is_watch(self, caused_by: str):
320312
return caused_by in self.rc.watch

0 commit comments

Comments
 (0)