Skip to content

Commit f20d3ca

Browse files
committed
Use long-term memory in RoleZero
1 parent 4e82d86 commit f20d3ca

File tree

8 files changed

+440
-15
lines changed

8 files changed

+440
-15
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@ examples/image__vector_store.json
163163
examples/index_store.json
164164
.chroma
165165
.chroma_exp_data
166+
.role_memory_data
166167
*~$*
167168
workspace/*
168169
tmp

metagpt/base/base_role.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
1+
from __future__ import annotations
2+
13
from abc import abstractmethod
2-
from typing import Optional, Union
4+
from typing import TYPE_CHECKING, Optional, Union
35

46
from metagpt.base.base_serialization import BaseSerialization
5-
from metagpt.schema import Message
7+
8+
if TYPE_CHECKING:
9+
from metagpt.schema import Message
610

711

812
class BaseRole(BaseSerialization):

metagpt/memory/memory.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,3 +104,7 @@ def get_by_actions(self, actions: Set) -> list[Message]:
104104
continue
105105
rsp += self.index[action]
106106
return rsp
107+
108+
def get_by_position(self, position: int) -> Message:
    """Fetch a single message from storage by index.

    Args:
        position: Index into the underlying storage list; negative values
            count from the end, following normal list indexing.

    Returns:
        The Message stored at `position`.
    """
    messages = self.storage
    return messages[position]

metagpt/memory/role_zero_memory.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
from typing import TYPE_CHECKING, Any
2+
3+
from pydantic import BaseModel, Field
4+
5+
from metagpt.schema import LongTermMemoryItem, Message
6+
7+
if TYPE_CHECKING:
8+
from llama_index.core.schema import NodeWithScore
9+
10+
from metagpt.rag.engines import SimpleEngine
11+
12+
13+
class RoleZeroLongTermMemory(BaseModel):
    """Long-term memory for RoleZero roles, backed by a RAG engine.

    Persists (user, AI) message pairs and retrieves the ones related to a
    query. The underlying engine is created lazily so that roles which never
    touch long-term memory do not require the optional rag dependencies.
    """

    persist_path: str = Field(default=".role_memory_data", description="The directory to save data.")
    collection_name: str = Field(default="role_zero", description="The name of the collection, such as the role name.")

    # Lazily-created SimpleEngine; typed Any so the optional rag module is not
    # imported at class-definition time.
    _rag_engine: Any = None

    @property
    def rag_engine(self) -> "SimpleEngine":
        """Build the RAG engine on first access and cache it."""
        if self._rag_engine is None:
            self._rag_engine = self._resolve_rag_engine()

        return self._rag_engine

    def _resolve_rag_engine(self) -> "SimpleEngine":
        """Create a SimpleEngine with a Chroma retriever and an LLM ranker.

        Raises:
            ImportError: If the optional rag module (or its dependencies) is
                not installed.
        """
        try:
            from metagpt.rag.engines import SimpleEngine
            from metagpt.rag.schema import ChromaRetrieverConfig, LLMRankerConfig
        except ImportError as e:
            # Fix: name the actual class in the message (was "RoleZeroMemory")
            # and chain the cause so the missing dependency stays visible.
            raise ImportError("To use the RoleZeroLongTermMemory, you need to install the rag module.") from e

        retriever_configs = [
            ChromaRetrieverConfig(persist_path=self.persist_path, collection_name=self.collection_name)
        ]
        ranker_configs = [LLMRankerConfig()]

        return SimpleEngine.from_objs(retriever_configs=retriever_configs, ranker_configs=ranker_configs)

    def fetch(self, query: str) -> list[Message]:
        """Return messages related to `query` from long-term memory.

        Each retrieved node carries a LongTermMemoryItem in its metadata under
        the "obj" key; the item's user and AI messages are appended in order.
        An empty query short-circuits to an empty result.
        """
        if not query:
            return []

        nodes: list["NodeWithScore"] = self.rag_engine.retrieve(query)

        memories: list[Message] = []
        for node in nodes:
            item: LongTermMemoryItem = node.metadata["obj"]
            memories.append(item.user_message)
            memories.append(item.ai_message)

        return memories

    def add(self, item: LongTermMemoryItem):
        """Persist a LongTermMemoryItem; no-op when `item` is falsy."""
        if not item:
            return

        self.rag_engine.add_objs([item])

metagpt/roles/di/role_zero.py

Lines changed: 115 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from metagpt.exp_pool.context_builders import RoleZeroContextBuilder
1919
from metagpt.exp_pool.serializers import RoleZeroSerializer
2020
from metagpt.logs import logger
21+
from metagpt.memory.role_zero_memory import RoleZeroLongTermMemory
2122
from metagpt.prompts.di.role_zero import (
2223
ASK_HUMAN_COMMAND,
2324
CMD_PROMPT,
@@ -34,14 +35,15 @@
3435
THOUGHT_GUIDANCE,
3536
)
3637
from metagpt.roles import Role
37-
from metagpt.schema import AIMessage, Message, UserMessage
38+
from metagpt.schema import AIMessage, LongTermMemoryItem, Message, UserMessage
3839
from metagpt.strategy.experience_retriever import DummyExpRetriever, ExpRetriever
3940
from metagpt.strategy.planner import Planner
4041
from metagpt.tools.libs.browser import Browser
4142
from metagpt.tools.libs.editor import Editor
4243
from metagpt.tools.tool_recommend import BM25ToolRecommender, ToolRecommender
4344
from metagpt.tools.tool_registry import register_tool
4445
from metagpt.utils.common import CodeParser, any_to_str, extract_and_encode_images
46+
from metagpt.utils.exceptions import handle_exception
4547
from metagpt.utils.repair_llm_raw_output import (
4648
RepairType,
4749
repair_escape_error,
@@ -86,6 +88,8 @@ class RoleZero(Role):
8688
command_rsp: str = "" # the raw string containing the commands
8789
commands: list[dict] = [] # commands to be executed
8890
memory_k: int = 20 # number of memories (messages) to use as historical context
91+
enable_longterm_memory: bool = True # whether to use longterm memory
92+
longterm_memory: RoleZeroLongTermMemory = None
8993
use_fixed_sop: bool = False
9094
requirements_constraints: str = "" # the constraints in user requirements
9195
use_summary: bool = True # whether to summarize at the end
@@ -140,6 +144,19 @@ def set_tool_execution(self) -> "RoleZero":
140144
self._update_tool_execution()
141145
return self
142146

147+
@model_validator(mode="after")
def set_longterm_memory(self) -> "RoleZero":
    """Ensure a long-term memory backend exists when the feature is enabled.

    Builds a RoleZeroLongTermMemory on demand, using the role's name with
    whitespace removed as the collection name.
    """
    needs_backend = self.enable_longterm_memory and not self.longterm_memory
    if needs_backend:
        collection = self.name.replace(" ", "")
        self.longterm_memory = RoleZeroLongTermMemory(collection_name=collection)
    return self
159+
143160
def _update_tool_execution(self):
144161
pass
145162

@@ -154,7 +171,7 @@ async def _think(self) -> bool:
154171
return False
155172

156173
if not self.planner.plan.goal:
157-
self.planner.plan.goal = self.get_memories()[-1].content
174+
self.planner.plan.goal = self._get_all_memories()[-1].content
158175
self.requirements_constraints = await AnalyzeRequirementsRestrictions().run(self.planner.plan.goal)
159176

160177
### 1. Experience ###
@@ -186,7 +203,7 @@ async def _think(self) -> bool:
186203
)
187204

188205
### Recent Observation ###
189-
memory = self.rc.memory.get(self.memory_k)
206+
memory = self._fetch_memories()
190207
memory = await self.parse_browser_actions(memory)
191208
memory = self.parse_images(memory)
192209

@@ -202,7 +219,7 @@ async def _think(self) -> bool:
202219

203220
self.command_rsp = await self._check_duplicates(req, self.command_rsp)
204221

205-
self.rc.memory.add(AIMessage(content=self.command_rsp))
222+
self._add_memory(AIMessage(content=self.command_rsp))
206223
return True
207224

208225
@exp_cache(context_builder=RoleZeroContextBuilder(), serializer=RoleZeroSerializer())
@@ -245,12 +262,12 @@ async def _act(self) -> Message:
245262
commands, ok = await self._parse_commands(self.command_rsp)
246263
if not ok:
247264
error_msg = commands
248-
self.rc.memory.add(UserMessage(content=error_msg))
265+
self._add_memory(UserMessage(content=error_msg))
249266
return error_msg
250267
logger.info(f"Commands: \n{commands}")
251268
outputs = await self._run_commands(commands)
252269
logger.info(f"Commands outputs: \n{outputs}")
253-
self.rc.memory.add(UserMessage(content=outputs))
270+
self._add_memory(UserMessage(content=outputs))
254271

255272
return AIMessage(
256273
content=f"I have finished the task, please mark my task as finished. Outputs: {outputs}",
@@ -303,7 +320,7 @@ async def _quick_think(self) -> Tuple[Message, str]:
303320
return rsp_msg, ""
304321

305322
# routing
306-
memory = self.get_memories(k=self.memory_k)
323+
memory = self._fetch_memories()
307324
context = self.llm.format_msg(memory + [UserMessage(content=QUICK_THINK_PROMPT)])
308325
async with ThoughtReporter() as reporter:
309326
await reporter.async_report({"type": "classify"})
@@ -328,7 +345,7 @@ async def _quick_think(self) -> Tuple[Message, str]:
328345
answer = await SearchEnhancedQA().run(query)
329346

330347
if answer:
331-
self.rc.memory.add(AIMessage(content=answer, cause_by=RunCommand))
348+
self._add_memory(AIMessage(content=answer, cause_by=RunCommand))
332349
await self.reply_to_human(content=answer)
333350
rsp_msg = AIMessage(
334351
content="Complete run",
@@ -339,7 +356,7 @@ async def _quick_think(self) -> Tuple[Message, str]:
339356
return rsp_msg, intent_result
340357

341358
async def _check_duplicates(self, req: list[dict], command_rsp: str):
342-
past_rsp = [mem.content for mem in self.rc.memory.get(self.memory_k)]
359+
past_rsp = [mem.content for mem in self._fetch_memories()]
343360
if command_rsp in past_rsp:
344361
# Normal response with thought contents are highly unlikely to reproduce
345362
# If an identical response is detected, it is a bad response, mostly due to LLM repeating generated content
@@ -479,7 +496,7 @@ def _get_plan_status(self) -> Tuple[str, str]:
479496

480497
def _retrieve_experience(self) -> str:
481498
"""Default implementation of experience retrieval. Can be overwritten in subclasses."""
482-
context = [str(msg) for msg in self.rc.memory.get(self.memory_k)]
499+
context = [str(msg) for msg in self._fetch_memories()]
483500
context = "\n\n".join(context)
484501
example = self.experience_retriever.retrieve(context=context)
485502
return example
@@ -504,9 +521,9 @@ async def reply_to_human(self, content: str) -> str:
504521

505522
async def _end(self):
506523
self._set_state(-1)
507-
memory = self.rc.memory.get(self.memory_k)
524+
memory = self._fetch_memories()
508525
# Ensure reply to the human before the "end" command is executed. Hard code k=5 for checking.
509-
if not any(["reply_to_human" in memory.content for memory in self.get_memories(k=5)]):
526+
if not any(["reply_to_human" in memory.content for memory in self._fetch_memories(k=5)]):
510527
logger.info("manually reply to human")
511528
pattern = r"\[Language Restrictions\](.*?)\n"
512529
match = re.search(pattern, self.requirements_constraints, re.DOTALL)
@@ -515,10 +532,95 @@ async def _end(self):
515532
await reporter.async_report({"type": "quick"})
516533
reply_content = await self.llm.aask(self.llm.format_msg(memory + [UserMessage(reply_to_human_prompt)]))
517534
await self.reply_to_human(content=reply_content)
518-
self.rc.memory.add(AIMessage(content=reply_content, cause_by=RunCommand))
535+
self._add_memory(AIMessage(content=reply_content, cause_by=RunCommand))
519536
outputs = ""
520537
# Summary of the Completed Task and Deliverables
521538
if self.use_summary:
522539
logger.info("end current run and summarize")
523540
outputs = await self.llm.aask(self.llm.format_msg(memory + [UserMessage(SUMMARY_PROMPT)]))
524541
return outputs
542+
543+
def _get_all_memories(self) -> list[Message]:
    """Return the full message history; k=0 asks _fetch_memories for everything."""
    return self._fetch_memories(0)
545+
546+
def _fetch_memories(self, k: Optional[int] = None) -> list[Message]:
    """Fetches recent memories and optionally combines them with related long-term memories.

    If long-term memory is not enabled or the last message is not from the user,
    it returns the recent memories without fetching from long-term memory.

    Args:
        k (Optional[int]): The number of recent memories to fetch. If None, defaults to self.memory_k.

    Returns:
        list[Message]: Related long-term memories (if any) followed by the recent memories.
    """

    if k is None:
        k = self.memory_k

    # Recent short-term memories. NOTE(review): _get_all_memories passes k=0,
    # which appears to mean "fetch everything" for rc.memory.get — confirm
    # Memory.get semantics; _should_use_longterm_memory explicitly skips
    # long-term lookup when k == 0.
    memories = self.rc.memory.get(k)

    if not self._should_use_longterm_memory(k=k, k_memories=memories):
        return memories

    # The guard above guarantees the last message exists and is from the
    # user; its content is used as the retrieval query.
    related_memories = self.longterm_memory.fetch(memories[-1].content)
    logger.info(f"Fetched {len(related_memories)} long-term memories.")

    # Drop a leading AI message when prepending long-term results —
    # NOTE(review): presumably so the combined context does not start
    # mid-exchange with an AI reply; confirm intent.
    if related_memories and self._is_first_message_from_ai(memories):
        memories = memories[1:]

    final_memories = related_memories + memories

    return final_memories
576+
577+
def _add_memory(self, message: Message):
    """Record `message` in working memory; when long-term memory applies,
    migrate the oldest eligible exchange into it as well."""
    self.rc.memory.add(message)

    if self._should_use_longterm_memory():
        self._transfer_to_longterm_memory()
584+
585+
def _should_use_longterm_memory(self, k: int = None, k_memories: list[Message] = None) -> bool:
    """Decide whether long-term memory should participate.

    Returns True only when every condition holds:
    - k is not 0;
    - k_memories is None, or its last message comes from the user;
    - long-term memory usage is enabled on this role;
    - the short-term store holds more than self.memory_k messages.
    """
    if k == 0:
        return False
    if k_memories is not None and not self._is_last_message_from_user(k_memories):
        return False
    if not self.enable_longterm_memory:
        return False
    return self.rc.memory.count() > self.memory_k
603+
604+
def _transfer_to_longterm_memory(self):
    """Move the most recent eligible (user, AI) exchange into long-term storage.

    _get_longterm_memory_item may yield None; longterm_memory.add ignores
    falsy items, so no guard is needed here.
    """
    self.longterm_memory.add(self._get_longterm_memory_item())
607+
608+
@handle_exception
def _get_longterm_memory_item(self) -> Optional[LongTermMemoryItem]:
    """Retrieves the most recent pair of user and AI messages before the last k messages.

    Inspects the message just beyond the short-term window (position
    -(memory_k + 1)); if it is an AI message preceded by a user message,
    the pair is wrapped in a LongTermMemoryItem.

    Returns:
        Optional[LongTermMemoryItem]: The pair, or None when the history is
        too short or the messages do not form a user/AI exchange.
    """
    # Explicit bounds check: previously a short history raised IndexError in
    # get_by_position, silently swallowed by @handle_exception. Make the
    # expected case an explicit early return instead of exception flow.
    if self.rc.memory.count() < self.memory_k + 2:
        return None

    ai_message = self.rc.memory.get_by_position(-(self.memory_k + 1))
    if not ai_message.is_ai_message():
        return None

    user_message = self.rc.memory.get_by_position(-(self.memory_k + 2))
    if not user_message.is_user_message():
        # The stored pair is documented as (user, AI); reject anything else so
        # long-term memory only indexes clean exchanges. (The original accepted
        # any preceding message as the "user" half.)
        return None

    return LongTermMemoryItem(user_message=user_message, ai_message=ai_message)
621+
622+
def _is_last_message_from_user(self, memories: list[Message]) -> bool:
    """True when `memories` is non-empty and its newest entry is a user message."""
    if not memories:
        return False
    return memories[-1].is_user_message()
624+
625+
def _is_first_message_from_ai(self, memories: list[Message]) -> bool:
    """True when `memories` is non-empty and its oldest entry is an AI message."""
    if not memories:
        return False
    return memories[0].is_ai_message()

metagpt/schema.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,12 @@ def create_instruct_value(kvs: Dict[str, Any], class_name: str = "") -> BaseMode
408408
dynamic_class = create_model(class_name, **{key: (value.__class__, ...) for key, value in kvs.items()})
409409
return dynamic_class.model_validate(kvs)
410410

411+
def is_user_message(self) -> bool:
    """Return True if this message was produced by the user (role == "user")."""
    return self.role == "user"
413+
414+
def is_ai_message(self) -> bool:
    """Return True if this message was produced by the assistant (role == "assistant")."""
    return self.role == "assistant"
416+
411417

412418
class UserMessage(Message):
413419
"""便于支持OpenAI的消息
@@ -955,3 +961,11 @@ def __new__(cls, value, desc=None):
955961
obj._value_ = value
956962
obj.desc = desc
957963
return obj
964+
965+
966+
class LongTermMemoryItem(BaseModel):
967+
user_message: Message
968+
ai_message: Message
969+
970+
def rag_key(self) -> str:
971+
return self.user_message.content

0 commit comments

Comments
 (0)