1818from metagpt .exp_pool .context_builders import RoleZeroContextBuilder
1919from metagpt .exp_pool .serializers import RoleZeroSerializer
2020from metagpt .logs import logger
21+ from metagpt .memory .role_zero_memory import RoleZeroLongTermMemory
2122from metagpt .prompts .di .role_zero import (
2223 ASK_HUMAN_COMMAND ,
2324 CMD_PROMPT ,
3435 THOUGHT_GUIDANCE ,
3536)
3637from metagpt .roles import Role
37- from metagpt .schema import AIMessage , Message , UserMessage
38+ from metagpt .schema import AIMessage , LongTermMemoryItem , Message , UserMessage
3839from metagpt .strategy .experience_retriever import DummyExpRetriever , ExpRetriever
3940from metagpt .strategy .planner import Planner
4041from metagpt .tools .libs .browser import Browser
4142from metagpt .tools .libs .editor import Editor
4243from metagpt .tools .tool_recommend import BM25ToolRecommender , ToolRecommender
4344from metagpt .tools .tool_registry import register_tool
4445from metagpt .utils .common import CodeParser , any_to_str , extract_and_encode_images
46+ from metagpt .utils .exceptions import handle_exception
4547from metagpt .utils .repair_llm_raw_output import (
4648 RepairType ,
4749 repair_escape_error ,
@@ -86,6 +88,8 @@ class RoleZero(Role):
8688 command_rsp : str = "" # the raw string containing the commands
8789 commands : list [dict ] = [] # commands to be executed
8890 memory_k : int = 20 # number of memories (messages) to use as historical context
91+ enable_longterm_memory : bool = True # whether to use longterm memory
92+ longterm_memory : RoleZeroLongTermMemory = None
8993 use_fixed_sop : bool = False
9094 requirements_constraints : str = "" # the constraints in user requirements
9195 use_summary : bool = True # whether to summarize at the end
@@ -140,6 +144,19 @@ def set_tool_execution(self) -> "RoleZero":
140144 self ._update_tool_execution ()
141145 return self
142146
@model_validator(mode="after")
def set_longterm_memory(self) -> "RoleZero":
    """Lazily initialize long-term memory after model validation.

    A RoleZeroLongTermMemory is created only when enable_longterm_memory
    is True and no instance was supplied. The role's name, with spaces
    stripped, is used as the backing collection name.
    """
    needs_default = self.enable_longterm_memory and not self.longterm_memory
    if needs_default:
        collection = self.name.replace(" ", "")
        self.longterm_memory = RoleZeroLongTermMemory(collection_name=collection)

    return self
def _update_tool_execution(self):
    """Hook for subclasses to register their tool-execution mappings; no-op in the base class."""
    pass
145162
@@ -154,7 +171,7 @@ async def _think(self) -> bool:
154171 return False
155172
156173 if not self .planner .plan .goal :
157- self .planner .plan .goal = self .get_memories ()[- 1 ].content
174+ self .planner .plan .goal = self ._get_all_memories ()[- 1 ].content
158175 self .requirements_constraints = await AnalyzeRequirementsRestrictions ().run (self .planner .plan .goal )
159176
160177 ### 1. Experience ###
@@ -186,7 +203,7 @@ async def _think(self) -> bool:
186203 )
187204
188205 ### Recent Observation ###
189- memory = self .rc . memory . get ( self . memory_k )
206+ memory = self ._fetch_memories ( )
190207 memory = await self .parse_browser_actions (memory )
191208 memory = self .parse_images (memory )
192209
@@ -202,7 +219,7 @@ async def _think(self) -> bool:
202219
203220 self .command_rsp = await self ._check_duplicates (req , self .command_rsp )
204221
205- self .rc . memory . add (AIMessage (content = self .command_rsp ))
222+ self ._add_memory (AIMessage (content = self .command_rsp ))
206223 return True
207224
208225 @exp_cache (context_builder = RoleZeroContextBuilder (), serializer = RoleZeroSerializer ())
@@ -245,12 +262,12 @@ async def _act(self) -> Message:
245262 commands , ok = await self ._parse_commands (self .command_rsp )
246263 if not ok :
247264 error_msg = commands
248- self .rc . memory . add (UserMessage (content = error_msg ))
265+ self ._add_memory (UserMessage (content = error_msg ))
249266 return error_msg
250267 logger .info (f"Commands: \n { commands } " )
251268 outputs = await self ._run_commands (commands )
252269 logger .info (f"Commands outputs: \n { outputs } " )
253- self .rc . memory . add (UserMessage (content = outputs ))
270+ self ._add_memory (UserMessage (content = outputs ))
254271
255272 return AIMessage (
256273 content = f"I have finished the task, please mark my task as finished. Outputs: { outputs } " ,
@@ -303,7 +320,7 @@ async def _quick_think(self) -> Tuple[Message, str]:
303320 return rsp_msg , ""
304321
305322 # routing
306- memory = self .get_memories ( k = self . memory_k )
323+ memory = self ._fetch_memories ( )
307324 context = self .llm .format_msg (memory + [UserMessage (content = QUICK_THINK_PROMPT )])
308325 async with ThoughtReporter () as reporter :
309326 await reporter .async_report ({"type" : "classify" })
@@ -328,7 +345,7 @@ async def _quick_think(self) -> Tuple[Message, str]:
328345 answer = await SearchEnhancedQA ().run (query )
329346
330347 if answer :
331- self .rc . memory . add (AIMessage (content = answer , cause_by = RunCommand ))
348+ self ._add_memory (AIMessage (content = answer , cause_by = RunCommand ))
332349 await self .reply_to_human (content = answer )
333350 rsp_msg = AIMessage (
334351 content = "Complete run" ,
@@ -339,7 +356,7 @@ async def _quick_think(self) -> Tuple[Message, str]:
339356 return rsp_msg , intent_result
340357
341358 async def _check_duplicates (self , req : list [dict ], command_rsp : str ):
342- past_rsp = [mem .content for mem in self .rc . memory . get ( self . memory_k )]
359+ past_rsp = [mem .content for mem in self ._fetch_memories ( )]
343360 if command_rsp in past_rsp :
344361 # Normal response with thought contents are highly unlikely to reproduce
345362 # If an identical response is detected, it is a bad response, mostly due to LLM repeating generated content
@@ -479,7 +496,7 @@ def _get_plan_status(self) -> Tuple[str, str]:
479496
480497 def _retrieve_experience (self ) -> str :
481498 """Default implementation of experience retrieval. Can be overwritten in subclasses."""
482- context = [str (msg ) for msg in self .rc . memory . get ( self . memory_k )]
499+ context = [str (msg ) for msg in self ._fetch_memories ( )]
483500 context = "\n \n " .join (context )
484501 example = self .experience_retriever .retrieve (context = context )
485502 return example
@@ -504,9 +521,9 @@ async def reply_to_human(self, content: str) -> str:
504521
505522 async def _end (self ):
506523 self ._set_state (- 1 )
507- memory = self .rc . memory . get ( self . memory_k )
524+ memory = self ._fetch_memories ( )
508525 # Ensure reply to the human before the "end" command is executed. Hard code k=5 for checking.
509- if not any (["reply_to_human" in memory .content for memory in self .get_memories (k = 5 )]):
526+ if not any (["reply_to_human" in memory .content for memory in self ._fetch_memories (k = 5 )]):
510527 logger .info ("manually reply to human" )
511528 pattern = r"\[Language Restrictions\](.*?)\n"
512529 match = re .search (pattern , self .requirements_constraints , re .DOTALL )
@@ -515,10 +532,95 @@ async def _end(self):
515532 await reporter .async_report ({"type" : "quick" })
516533 reply_content = await self .llm .aask (self .llm .format_msg (memory + [UserMessage (reply_to_human_prompt )]))
517534 await self .reply_to_human (content = reply_content )
518- self .rc . memory . add (AIMessage (content = reply_content , cause_by = RunCommand ))
535+ self ._add_memory (AIMessage (content = reply_content , cause_by = RunCommand ))
519536 outputs = ""
520537 # Summary of the Completed Task and Deliverables
521538 if self .use_summary :
522539 logger .info ("end current run and summarize" )
523540 outputs = await self .llm .aask (self .llm .format_msg (memory + [UserMessage (SUMMARY_PROMPT )]))
524541 return outputs
542+
def _get_all_memories(self) -> list[Message]:
    """Return the entire message history (k=0 disables the recency cut-off)."""
    everything = self._fetch_memories(k=0)
    return everything
545+
def _fetch_memories(self, k: Optional[int] = None) -> list[Message]:
    """Fetch recent memories, optionally prefixed with related long-term memories.

    When long-term memory is disabled, or the newest message is not from the
    user, only the recent messages are returned.

    Args:
        k (Optional[int]): How many recent messages to fetch; defaults to
            self.memory_k when None.

    Returns:
        List[Message]: The combined list of long-term and recent messages.
    """

    limit = self.memory_k if k is None else k

    recent = self.rc.memory.get(limit)

    if not self._should_use_longterm_memory(k=limit, k_memories=recent):
        return recent

    # Query long-term storage using the newest (user) message as the key.
    related = self.longterm_memory.fetch(recent[-1].content)
    logger.info(f"Fetched {len(related)} long-term memories.")

    # Drop a leading AI message so the combined history starts with a user turn.
    if related and self._is_first_message_from_ai(recent):
        recent = recent[1:]

    return related + recent
576+
def _add_memory(self, message: Message):
    """Store *message* in working memory and, when eligible, spill the oldest exchange to long-term memory."""
    self.rc.memory.add(message)

    if self._should_use_longterm_memory():
        self._transfer_to_longterm_memory()
584+
def _should_use_longterm_memory(self, k: int = None, k_memories: list[Message] = None) -> bool:
    """Decide whether long-term memory should participate.

    All of the following must hold:
    - k is not 0 (k=0 means "fetch everything", which bypasses long-term memory).
    - k_memories is None, or its last message comes from the user.
    - Long-term memory is enabled on this role.
    - The working-memory count exceeds self.memory_k.
    """

    if k == 0:
        return False
    if k_memories is not None and not self._is_last_message_from_user(k_memories):
        return False
    if not self.enable_longterm_memory:
        return False
    return self.rc.memory.count() > self.memory_k
603+
def _transfer_to_longterm_memory(self):
    """Move the oldest user/AI exchange beyond the recency window into long-term memory.

    `_get_longterm_memory_item` can return None — either its explicit
    "not an AI message" path, or its @handle_exception decorator swallowing
    a lookup error — so guard against storing an empty item.
    """
    item = self._get_longterm_memory_item()
    if item is None:
        return
    self.longterm_memory.add(item)
607+
@handle_exception
def _get_longterm_memory_item(self) -> Optional[LongTermMemoryItem]:
    """Return the user/AI message pair that just slid out of the recent-k window, or None.

    Positions are counted from the end of working memory: the message at
    -(memory_k + 1) must be an AI reply; the one before it is taken as the
    triggering user message. Errors are converted to None by @handle_exception.
    """

    ai_position = -(self.memory_k + 1)
    ai_message = self.rc.memory.get_by_position(ai_position)
    if not ai_message.is_ai_message():
        return None

    user_message = self.rc.memory.get_by_position(ai_position - 1)

    return LongTermMemoryItem(user_message=user_message, ai_message=ai_message)
621+
622+ def _is_last_message_from_user (self , memories : list [Message ]) -> bool :
623+ return bool (memories and memories [- 1 ].is_user_message ())
624+
625+ def _is_first_message_from_ai (self , memories : list [Message ]) -> bool :
626+ return bool (memories and memories [0 ].is_ai_message ())
0 commit comments