@@ -26,6 +26,18 @@ class RedisMemory(MemoryBase):
2626 .. note:: All Redis keys used by this class will be prefixed by `prefix`
2727 (if provided) to support multi-tenant / multi-app isolation.
2828
29+ **Mark Index Storage:**
30+
31+ This class maintains a `marks_index` (Redis Set) to efficiently track all
32+ mark names within a session. When a mark is created via `add_mark()`, the
33+ mark name is added to this set. This allows quick retrieval of all marks
34+ without scanning all Redis keys. The marks_index key pattern is:
35+ ``user_id:{user_id}:session:{session_id}:marks_index``
36+
37+ Each individual mark stores its associated message IDs in a separate Redis
38+ List with the key pattern:
39+ ``user_id:{user_id}:session:{session_id}:mark:{mark}``
40+
2941 """
3042
3143 SESSION_KEY = "user_id:{user_id}:session:{session_id}:messages"
@@ -45,6 +57,11 @@ class RedisMemory(MemoryBase):
4557 MESSAGE_KEY = "user_id:{user_id}:session:{session_id}:msg:{msg_id}"
4658 """Redis key pattern (without prefix) for storing message payload data."""
4759
60+ MARKS_INDEX_KEY = "user_id:{user_id}:session:{session_id}:marks_index"
61+ """Redis key pattern (without prefix) for storing all mark names as a set.
62+ This is used to avoid scanning all keys to find marks.
63+ """
64+
4865 def __init__ (
4966 self ,
5067 session_id : str = "default_session" ,
@@ -208,6 +225,38 @@ def _get_mark_pattern(self) -> str:
208225 mark = "*" ,
209226 )
210227
228+ def _get_marks_index_key (self ) -> str :
229+ """Get the Redis key for the marks index set.
230+
231+ Returns:
232+ `str`:
233+ The Redis key for storing all mark names as a set.
234+ """
235+ return self .key_prefix + self .MARKS_INDEX_KEY .format (
236+ user_id = self .user_id ,
237+ session_id = self .session_id ,
238+ )
239+
240+ def _extract_mark_from_key (self , mark_key : str ) -> str :
241+ """Extract the mark name from a full mark key.
242+
243+ Args:
244+ mark_key (`str`):
245+ The full Redis key for a mark.
246+
247+ Returns:
248+ `str`:
249+ The mark name extracted from the key.
250+ """
251+ # Remove the prefix and the base pattern to get the mark name
252+ # Example: "prefix:user_id:xxx:session:yyy:mark:my_mark" -> "my_mark"
253+ prefix_pattern = self .key_prefix + self .MARK_KEY .format (
254+ user_id = self .user_id ,
255+ session_id = self .session_id ,
256+ mark = "" ,
257+ )
258+ return mark_key .replace (prefix_pattern , "" )
259+
211260 def _get_message_key (self , msg_id : str ) -> str :
212261 """Get the Redis key for a specific message.
213262
@@ -263,6 +312,70 @@ async def _refresh_session_ttl(
263312 if should_execute :
264313 await pipe .execute ()
265314
315+ async def _scan_and_migrate_marks (self ) -> list [str ]:
316+ """Scan all mark keys and migrate them to the marks index.
317+
318+ This method is only called once for old data that doesn't have
319+ a marks index yet. After migration, the marks index will be
320+ maintained automatically.
321+
322+ Returns:
323+ `list[str]`:
324+ The list of all mark keys found.
325+ """
326+ mark_keys = []
327+ cursor = 0
328+ while True :
329+ cursor , keys = await self ._client .scan (
330+ cursor ,
331+ match = self ._get_mark_pattern (),
332+ count = 50 ,
333+ )
334+ keys = self ._decode_list (keys )
335+ mark_keys .extend (keys )
336+ if cursor == 0 :
337+ break
338+
339+ # Build the marks index
340+ if mark_keys :
341+ pipe = self ._client .pipeline ()
342+ for mark_key in mark_keys :
343+ mark = self ._extract_mark_from_key (mark_key )
344+ await pipe .sadd (self ._get_marks_index_key (), mark )
345+ await pipe .execute ()
346+
347+ return mark_keys
348+
349+ async def _get_all_mark_keys (self ) -> list [str ]:
350+ """Get all mark keys, compatible with both old and new data.
351+
352+ For new data (with marks index), this method uses the index directly.
353+ For old data (without marks index), this method scans once and
354+ migrates to the new structure.
355+
356+ Returns:
357+ `list[str]`:
358+ The list of all mark keys.
359+ """
360+ marks_index_key = self ._get_marks_index_key ()
361+
362+ # Try to read from the index first
363+ marks = await self ._client .smembers (marks_index_key )
364+ if marks :
365+ # Index exists, use it
366+ marks = self ._decode_list (list (marks ))
367+ return [self ._get_mark_key (mark ) for mark in marks ]
368+
369+ # Index doesn't exist, check if this is a new session
370+ session_exists = await self ._client .exists (self ._get_session_key ())
371+ if not session_exists :
372+ # New session, no data at all, return empty
373+ return []
374+
375+ # Old session without index, need to scan and migrate (only once)
376+ mark_keys = await self ._scan_and_migrate_marks ()
377+ return mark_keys
378+
266379 async def get_memory (
267380 self ,
268381 mark : str | None = None ,
@@ -432,6 +545,8 @@ async def add(
432545 # Record the marks if provided
433546 for mark in mark_list :
434547 await pipe .rpush (self ._get_mark_key (mark ), m .id )
548+ # Maintain the marks index
549+ await pipe .sadd (self ._get_marks_index_key (), mark )
435550
436551 # Refresh TTLs
437552 await self ._refresh_session_ttl (pipe = pipe )
@@ -456,20 +571,8 @@ async def delete(
456571 if not msg_ids :
457572 return 0
458573
459- # Get all mark keys once before the pipeline
460- mark_keys = []
461- cursor = 0
462- while True :
463- cursor , keys = await self ._client .scan (
464- cursor ,
465- match = self ._get_mark_pattern (),
466- count = 50 ,
467- )
468- # Decode keys if they are bytes
469- keys = self ._decode_list (keys )
470- mark_keys .extend (keys )
471- if cursor == 0 :
472- break
574+ # Get all mark keys using the new method (compatible with old data)
575+ mark_keys = await self ._get_all_mark_keys ()
473576
474577 pipe = self ._client .pipeline ()
475578 for msg_id in msg_ids :
@@ -539,6 +642,9 @@ async def delete_by_mark(
539642 # Delete the mark list
540643 await self ._client .delete (mark_key )
541644
645+ # Remove from the marks index
646+ await self ._client .srem (self ._get_marks_index_key (), m )
647+
542648 # Refresh TTLs
543649 await self ._refresh_session_ttl ()
544650
@@ -549,20 +655,8 @@ async def clear(self) -> None:
549655 msg_ids = await self ._client .lrange (self ._get_session_key (), 0 , - 1 )
550656 msg_ids = self ._decode_list (msg_ids )
551657
552- # Get all mark keys using SCAN
553- mark_keys = []
554- cursor = 0
555- while True :
556- cursor , keys = await self ._client .scan (
557- cursor ,
558- match = self ._get_mark_pattern (),
559- count = 50 ,
560- )
561- # Decode keys if they are bytes
562- keys = self ._decode_list (keys )
563- mark_keys .extend (keys )
564- if cursor == 0 :
565- break
658+ # Get all mark keys using the new method (compatible with old data)
659+ mark_keys = await self ._get_all_mark_keys ()
566660
567661 pipe = self ._client .pipeline ()
568662
@@ -577,6 +671,9 @@ async def clear(self) -> None:
577671 for mark_key in mark_keys :
578672 await pipe .delete (mark_key )
579673
674+ # Delete the marks index
675+ await pipe .delete (self ._get_marks_index_key ())
676+
580677 await pipe .execute ()
581678
582679 async def size (self ) -> int :
@@ -622,22 +719,19 @@ async def update_messages_mark(
622719 The number of messages updated.
623720 """
624721 # Determine which message IDs to update
625- if old_mark is not None :
626- # Get message IDs from the old mark list
627- mark_msg_ids = await self ._client .lrange (
628- self ._get_mark_key (old_mark ),
629- 0 ,
630- - 1 ,
631- )
632- mark_msg_ids = self ._decode_list (mark_msg_ids )
633- else :
634- # Get all message IDs from the session
635- mark_msg_ids = await self ._client .lrange (
636- self ._get_session_key (),
637- 0 ,
638- - 1 ,
639- )
640- mark_msg_ids = self ._decode_list (mark_msg_ids )
722+ # Get source key based on old_mark
723+ source_key = (
724+ self ._get_mark_key (old_mark )
725+ if old_mark is not None
726+ else self ._get_session_key ()
727+ )
728+ mark_msg_ids = await self ._client .lrange (source_key , 0 , - 1 )
729+ mark_msg_ids = self ._decode_list (mark_msg_ids )
730+
731+ # Check if we're removing all messages from old_mark
732+ removing_all_from_old_mark = old_mark is not None and (
733+ msg_ids is None or all (mid in set (msg_ids ) for mid in mark_msg_ids )
734+ )
641735
642736 # Filter by msg_ids if provided
643737 if msg_ids is not None :
@@ -661,31 +755,33 @@ async def update_messages_mark(
661755 updated_count = 0
662756
663757 for msg_id in mark_msg_ids :
664- # If new_mark is None, remove the old_mark
665- if new_mark is None :
666- if old_mark is not None :
667- await pipe .lrem (
668- self ._get_mark_key (old_mark ),
669- 0 ,
670- msg_id ,
671- )
672- updated_count += 1
673- else :
674- # Remove from old_mark list if applicable
675- if old_mark is not None :
676- await pipe .lrem (
677- self ._get_mark_key (old_mark ),
678- 0 ,
679- msg_id ,
680- )
681-
682- # Add to new_mark list only if not already present
683- if msg_id not in existing_ids_set and new_mark_key is not None :
684- await pipe .rpush (new_mark_key , msg_id )
685- existing_ids_set .add (msg_id )
686-
758+ # Remove from old_mark list if applicable
759+ if old_mark is not None :
760+ await pipe .lrem (
761+ self ._get_mark_key (old_mark ),
762+ 0 ,
763+ msg_id ,
764+ )
765+
766+ # Add to new_mark list only if not already present
767+ if new_mark is not None and msg_id not in existing_ids_set :
768+ await pipe .rpush (new_mark_key , msg_id )
769+ existing_ids_set .add (msg_id )
770+ # Maintain the marks index
771+ await pipe .sadd (self ._get_marks_index_key (), new_mark )
772+
773+ # Count update only if we actually did something
774+ if old_mark is not None or new_mark is not None :
687775 updated_count += 1
688776
777+ # Clean up old_mark only if we removed ALL messages from it
778+ if old_mark is not None and removing_all_from_old_mark :
779+ old_mark_key = self ._get_mark_key (old_mark )
780+ # After lrem operations, the old mark list will be empty
781+ # Delete the mark key and remove from index
782+ await pipe .delete (old_mark_key )
783+ await pipe .srem (self ._get_marks_index_key (), old_mark )
784+
689785 await self ._refresh_session_ttl (pipe = pipe )
690786
691787 await pipe .execute ()
0 commit comments