Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions src/google/adk/sessions/redis_memory_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
from .session import Session
from .state import State

import base64

logger = logging.getLogger('google_adk.' + __name__)

DEFAULT_EXPIRATION = 60 * 60 # 1 hour
Expand All @@ -46,7 +48,7 @@ def _json_serializer(obj):
return list(obj)
if isinstance(obj, bytes):
try:
return obj.decode("utf-8")
return base64.b64encode(obj).decode("ascii")
except Exception:
return repr(obj)
if isinstance(obj, (datetime.datetime, datetime.date)):
Expand All @@ -62,6 +64,19 @@ def _json_serializer(obj):
return "Infinity" if obj > 0 else "-Infinity"
return str(obj)

def _restore_bytes(obj):
if isinstance(obj, dict):
return {k: _restore_bytes(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [_restore_bytes(v) for v in obj]
elif isinstance(obj, str):
try:
# intenta decodificar base64
data = base64.b64decode(obj, validate=True)
return data
except Exception:
return obj
return obj

class RedisMemorySessionService(BaseSessionService):
"""A Redis-backed implementation of the session service."""
Expand Down Expand Up @@ -315,7 +330,8 @@ async def _load_sessions(self, app_name: str, user_id: str) -> dict[str, dict]:
raw = await self.cache.get(key)
if not raw:
return {}
return json.loads(raw.decode())
raw_data = json.loads(raw.decode())
return raw_data

async def _save_sessions(self, app_name: str, user_id: str, sessions: dict[str, Any]):
key = f"{State.APP_PREFIX}{app_name}:{user_id}"
Expand Down
Loading