Skip to content

Commit 606811b

Browse files
authored
perf(redis_memory): avoid scanning operation in the RedisMemory class to improve the efficiency (#1211)
1 parent 734b6d3 commit 606811b

File tree

1 file changed

+163
-67
lines changed

1 file changed

+163
-67
lines changed

src/agentscope/memory/_working_memory/_redis_memory.py

Lines changed: 163 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,18 @@ class RedisMemory(MemoryBase):
2626
.. note:: All Redis keys used by this class will be prefixed by `prefix`
2727
(if provided) to support multi-tenant / multi-app isolation.
2828
29+
**Mark Index Storage:**
30+
31+
This class maintains a `marks_index` (Redis Set) to efficiently track all
32+
mark names within a session. When a mark is created via `add_mark()`, the
33+
mark name is added to this set. This allows quick retrieval of all marks
34+
without scanning all Redis keys. The marks_index key pattern is:
35+
``user_id:{user_id}:session:{session_id}:marks_index``
36+
37+
Each individual mark stores its associated message IDs in a separate Redis
38+
List with the key pattern:
39+
``user_id:{user_id}:session:{session_id}:mark:{mark}``
40+
2941
"""
3042

3143
SESSION_KEY = "user_id:{user_id}:session:{session_id}:messages"
@@ -45,6 +57,11 @@ class RedisMemory(MemoryBase):
4557
MESSAGE_KEY = "user_id:{user_id}:session:{session_id}:msg:{msg_id}"
4658
"""Redis key pattern (without prefix) for storing message payload data."""
4759

60+
MARKS_INDEX_KEY = "user_id:{user_id}:session:{session_id}:marks_index"
61+
"""Redis key pattern (without prefix) for storing all mark names as a set.
62+
This is used to avoid scanning all keys to find marks.
63+
"""
64+
4865
def __init__(
4966
self,
5067
session_id: str = "default_session",
@@ -208,6 +225,38 @@ def _get_mark_pattern(self) -> str:
208225
mark="*",
209226
)
210227

228+
def _get_marks_index_key(self) -> str:
    """Build the Redis key of the set that indexes this session's marks.

    The key is ``key_prefix`` followed by ``MARKS_INDEX_KEY`` filled in
    with the current ``user_id`` and ``session_id``.

    Returns:
        `str`:
            The fully qualified Redis key of the marks index set.
    """
    index_suffix = self.MARKS_INDEX_KEY.format(
        user_id=self.user_id,
        session_id=self.session_id,
    )
    return self.key_prefix + index_suffix
239+
240+
def _extract_mark_from_key(self, mark_key: str) -> str:
    """Extract the mark name from a full mark key.

    The head of the key is rebuilt from ``key_prefix`` and ``MARK_KEY``
    (formatted with an empty mark) and stripped from the *front* of
    ``mark_key`` only, so a mark name that happens to contain the same
    text is returned intact.

    Args:
        mark_key (`str`):
            The full Redis key for a mark.

    Returns:
        `str`:
            The mark name extracted from the key. If ``mark_key`` does not
            start with the expected head, it is returned unchanged.
    """
    # Example: "prefix:user_id:xxx:session:yyy:mark:my_mark" -> "my_mark"
    # NOTE: formatting with mark="" yields the key head only when the
    # `{mark}` placeholder is the last component of MARK_KEY.
    key_head = self.key_prefix + self.MARK_KEY.format(
        user_id=self.user_id,
        session_id=self.session_id,
        mark="",
    )
    if mark_key.startswith(key_head):
        # Slice instead of str.replace(): replace() would also delete any
        # later occurrence of the head inside the mark name itself.
        return mark_key[len(key_head):]
    return mark_key
259+
211260
def _get_message_key(self, msg_id: str) -> str:
212261
"""Get the Redis key for a specific message.
213262
@@ -263,6 +312,70 @@ async def _refresh_session_ttl(
263312
if should_execute:
264313
await pipe.execute()
265314

315+
async def _scan_and_migrate_marks(self) -> list[str]:
    """Scan all mark keys and record their mark names in the marks index.

    This is the fallback path for data written before the marks index
    existed: it iterates the keyspace with SCAN using the mark key
    pattern, then adds every discovered mark name to the marks index set.

    Returns:
        `list[str]`:
            The mark keys found, de-duplicated, in first-seen order.
            Empty if no key matches the pattern (in which case the index
            is not written).
    """
    # Both values are loop-invariant, so compute them once.
    mark_pattern = self._get_mark_pattern()
    index_key = self._get_marks_index_key()

    # SCAN may report the same key more than once during a full
    # iteration; a dict keeps first-seen order while dropping repeats.
    found_keys: dict[str, None] = {}
    cursor = 0
    while True:
        cursor, keys = await self._client.scan(
            cursor,
            match=mark_pattern,
            count=50,
        )
        for key in self._decode_list(keys):
            found_keys.setdefault(key, None)
        # A cursor of 0 signals that the iteration is complete.
        if cursor == 0:
            break

    mark_keys = list(found_keys)

    # Build the marks index with a single variadic SADD.
    if mark_keys:
        mark_names = [self._extract_mark_from_key(key) for key in mark_keys]
        await self._client.sadd(index_key, *mark_names)

    return mark_keys
348+
349+
async def _get_all_mark_keys(self) -> list[str]:
    """Return every mark key of this session, for old and new data alike.

    Resolution order:

    1. If the marks index set has members, each member is turned into its
       mark key and that list is returned.
    2. Otherwise, if the session key does not exist, an empty list is
       returned.
    3. Otherwise the keyspace is scanned and the index is (re)built by
       `_scan_and_migrate_marks()`.

    NOTE(review): an empty ``smembers`` result is treated as "no index",
    so a session that exists but currently has no marks takes the scan
    path on every call. Likewise, a non-empty index is trusted as
    complete; marks that exist only as un-indexed keys are not returned
    once any mark has been indexed. Confirm whether either case matters
    for existing deployments.

    Returns:
        `list[str]`:
            The list of all mark keys.
    """
    indexed_marks = await self._client.smembers(
        self._get_marks_index_key(),
    )

    if not indexed_marks:
        # No usable index: tell a brand-new session apart from old data.
        if not await self._client.exists(self._get_session_key()):
            return []
        # Session has data but no index yet: scan and migrate.
        return await self._scan_and_migrate_marks()

    # Index is populated: map each mark name to its mark key.
    mark_names = self._decode_list(list(indexed_marks))
    return [self._get_mark_key(name) for name in mark_names]
378+
266379
async def get_memory(
267380
self,
268381
mark: str | None = None,
@@ -432,6 +545,8 @@ async def add(
432545
# Record the marks if provided
433546
for mark in mark_list:
434547
await pipe.rpush(self._get_mark_key(mark), m.id)
548+
# Maintain the marks index
549+
await pipe.sadd(self._get_marks_index_key(), mark)
435550

436551
# Refresh TTLs
437552
await self._refresh_session_ttl(pipe=pipe)
@@ -456,20 +571,8 @@ async def delete(
456571
if not msg_ids:
457572
return 0
458573

459-
# Get all mark keys once before the pipeline
460-
mark_keys = []
461-
cursor = 0
462-
while True:
463-
cursor, keys = await self._client.scan(
464-
cursor,
465-
match=self._get_mark_pattern(),
466-
count=50,
467-
)
468-
# Decode keys if they are bytes
469-
keys = self._decode_list(keys)
470-
mark_keys.extend(keys)
471-
if cursor == 0:
472-
break
574+
# Get all mark keys using the new method (compatible with old data)
575+
mark_keys = await self._get_all_mark_keys()
473576

474577
pipe = self._client.pipeline()
475578
for msg_id in msg_ids:
@@ -539,6 +642,9 @@ async def delete_by_mark(
539642
# Delete the mark list
540643
await self._client.delete(mark_key)
541644

645+
# Remove from the marks index
646+
await self._client.srem(self._get_marks_index_key(), m)
647+
542648
# Refresh TTLs
543649
await self._refresh_session_ttl()
544650

@@ -549,20 +655,8 @@ async def clear(self) -> None:
549655
msg_ids = await self._client.lrange(self._get_session_key(), 0, -1)
550656
msg_ids = self._decode_list(msg_ids)
551657

552-
# Get all mark keys using SCAN
553-
mark_keys = []
554-
cursor = 0
555-
while True:
556-
cursor, keys = await self._client.scan(
557-
cursor,
558-
match=self._get_mark_pattern(),
559-
count=50,
560-
)
561-
# Decode keys if they are bytes
562-
keys = self._decode_list(keys)
563-
mark_keys.extend(keys)
564-
if cursor == 0:
565-
break
658+
# Get all mark keys using the new method (compatible with old data)
659+
mark_keys = await self._get_all_mark_keys()
566660

567661
pipe = self._client.pipeline()
568662

@@ -577,6 +671,9 @@ async def clear(self) -> None:
577671
for mark_key in mark_keys:
578672
await pipe.delete(mark_key)
579673

674+
# Delete the marks index
675+
await pipe.delete(self._get_marks_index_key())
676+
580677
await pipe.execute()
581678

582679
async def size(self) -> int:
@@ -622,22 +719,19 @@ async def update_messages_mark(
622719
The number of messages updated.
623720
"""
624721
# Determine which message IDs to update
625-
if old_mark is not None:
626-
# Get message IDs from the old mark list
627-
mark_msg_ids = await self._client.lrange(
628-
self._get_mark_key(old_mark),
629-
0,
630-
-1,
631-
)
632-
mark_msg_ids = self._decode_list(mark_msg_ids)
633-
else:
634-
# Get all message IDs from the session
635-
mark_msg_ids = await self._client.lrange(
636-
self._get_session_key(),
637-
0,
638-
-1,
639-
)
640-
mark_msg_ids = self._decode_list(mark_msg_ids)
722+
# Get source key based on old_mark
723+
source_key = (
724+
self._get_mark_key(old_mark)
725+
if old_mark is not None
726+
else self._get_session_key()
727+
)
728+
mark_msg_ids = await self._client.lrange(source_key, 0, -1)
729+
mark_msg_ids = self._decode_list(mark_msg_ids)
730+
731+
# Check if we're removing all messages from old_mark
732+
removing_all_from_old_mark = old_mark is not None and (
733+
msg_ids is None or all(mid in set(msg_ids) for mid in mark_msg_ids)
734+
)
641735

642736
# Filter by msg_ids if provided
643737
if msg_ids is not None:
@@ -661,31 +755,33 @@ async def update_messages_mark(
661755
updated_count = 0
662756

663757
for msg_id in mark_msg_ids:
664-
# If new_mark is None, remove the old_mark
665-
if new_mark is None:
666-
if old_mark is not None:
667-
await pipe.lrem(
668-
self._get_mark_key(old_mark),
669-
0,
670-
msg_id,
671-
)
672-
updated_count += 1
673-
else:
674-
# Remove from old_mark list if applicable
675-
if old_mark is not None:
676-
await pipe.lrem(
677-
self._get_mark_key(old_mark),
678-
0,
679-
msg_id,
680-
)
681-
682-
# Add to new_mark list only if not already present
683-
if msg_id not in existing_ids_set and new_mark_key is not None:
684-
await pipe.rpush(new_mark_key, msg_id)
685-
existing_ids_set.add(msg_id)
686-
758+
# Remove from old_mark list if applicable
759+
if old_mark is not None:
760+
await pipe.lrem(
761+
self._get_mark_key(old_mark),
762+
0,
763+
msg_id,
764+
)
765+
766+
# Add to new_mark list only if not already present
767+
if new_mark is not None and msg_id not in existing_ids_set:
768+
await pipe.rpush(new_mark_key, msg_id)
769+
existing_ids_set.add(msg_id)
770+
# Maintain the marks index
771+
await pipe.sadd(self._get_marks_index_key(), new_mark)
772+
773+
# Count update only if we actually did something
774+
if old_mark is not None or new_mark is not None:
687775
updated_count += 1
688776

777+
# Clean up old_mark only if we removed ALL messages from it
778+
if old_mark is not None and removing_all_from_old_mark:
779+
old_mark_key = self._get_mark_key(old_mark)
780+
# After lrem operations, the old mark list will be empty
781+
# Delete the mark key and remove from index
782+
await pipe.delete(old_mark_key)
783+
await pipe.srem(self._get_marks_index_key(), old_mark)
784+
689785
await self._refresh_session_ttl(pipe=pipe)
690786

691787
await pipe.execute()

0 commit comments

Comments
 (0)