diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..ecb2dcaf --- /dev/null +++ b/.gitignore @@ -0,0 +1,80 @@ +# Python +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# Virtual Environment +venv/ +ENV/ +env/ +.env +.venv +env.bak/ +venv.bak/ + +# IDE +.idea/ +.vscode/ +*.swp +*.swo +.DS_Store + +# Testing +.coverage +htmlcov/ +.tox/ +.nox/ +.pytest_cache/ +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ + +# Jupyter Notebook +.ipynb_checkpoints + +# Logs +*.log +logs/ +log/ + +# Local development settings +.env.local +.env.development.local +.env.test.local +.env.production.local +uv.lock + +# Google Cloud specific +.gcloudignore +.gcloudignore.local + +# Documentation +docs/_build/ +site/ + +# Misc +Thumbs.db +*.bak +*.tmp +*.temp diff --git a/pyproject.toml b/pyproject.toml index e5be0da6..b71b1197 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,9 +25,11 @@ classifiers = [ # List of https://pypi.org/classifiers/ ] dependencies = [ # go/keep-sorted start - "google-genai>=1.21.1, <2.0.0", # Google GenAI SDK - "google-adk", # Google ADK + "google-genai>=1.21.1, <2.0.0", # Google GenAI SDK + "google-adk", # Google ADK + "redis>=5.0.0, <6.0.0", # Redis for session storage # go/keep-sorted end + "orjson>=3.11.3", ] dynamic = ["version"] @@ -61,6 +63,12 @@ pyink-annotation-pragmas = [ requires = ["flit_core >=3.8,<4"] build-backend = "flit_core.buildapi" +[dependency-groups] +dev = [ + "pytest>=8.4.2", + "pytest-asyncio>=1.2.0", +] + [tool.flit.sdist] include = ['src/**/*', 'README.md', 'pyproject.toml', 'LICENSE'] diff --git a/src/google/adk_community/sessions/__init__.py b/src/google/adk_community/sessions/__init__.py new file mode 100644 index 00000000..3748e39f --- /dev/null +++ b/src/google/adk_community/sessions/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Community session services for ADK.""" + +from .redis_session_service import RedisSessionService + +__all__ = ["RedisMemorySessionService"] diff --git a/src/google/adk_community/sessions/redis_session_service.py b/src/google/adk_community/sessions/redis_session_service.py new file mode 100644 index 00000000..bd8e2891 --- /dev/null +++ b/src/google/adk_community/sessions/redis_session_service.py @@ -0,0 +1,297 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import asyncio +import bisect +import logging +import time +import uuid +from typing import Any, Optional + +import orjson +import redis.asyncio as redis +from redis.crc import key_slot +from typing_extensions import override + +from google.adk.events.event import Event +from google.adk.sessions.base_session_service import ( + BaseSessionService, + GetSessionConfig, + ListSessionsResponse, +) +from google.adk.sessions.session import Session +from google.adk.sessions.state import State + +from .utils import _json_serializer + +logger = logging.getLogger("google_adk." + __name__) + +DEFAULT_EXPIRATION = 60 * 60 # 1 hour + + +def _session_serializer(obj: Session) -> bytes: + """Serialize ADK Session to JSON bytes.""" + return orjson.dumps(obj.model_dump(), default=_json_serializer) + + +class RedisKeys: + """Helper to generate Redis keys consistently.""" + + @staticmethod + def session(session_id: str) -> str: + return f"session:{session_id}" + + @staticmethod + def user_sessions(app_name: str, user_id: str) -> str: + return f"{State.APP_PREFIX}:{app_name}:{user_id}" + + @staticmethod + def app_state(app_name: str) -> str: + return f"{State.APP_PREFIX}{app_name}" + + @staticmethod + def user_state(app_name: str, user_id: str) -> str: + return f"{State.USER_PREFIX}{app_name}:{user_id}" + + +class RedisSessionService(BaseSessionService): + """A Redis-backed implementation of the session service.""" + + def __init__( + self, + host="localhost", + port=6379, + db=0, + uri=None, + cluster_uri=None, + expire=DEFAULT_EXPIRATION, + **kwargs, + ): + self.expire = expire + + if cluster_uri: + self.cache = redis.RedisCluster.from_url(cluster_uri, **kwargs) + elif uri: + self.cache = redis.Redis.from_url(uri, **kwargs) + else: + self.cache = redis.Redis(host=host, port=port, db=db, **kwargs) + + async def health_check(self) -> bool: + try: + await self.cache.ping() + return True + except redis.RedisError: + return False + + @override + async def create_session( + self, + *, + app_name: str, + user_id: str, + state: Optional[dict[str, Any]] = None, + session_id: Optional[str] = None, + ) -> Session: + session_id = ( + session_id.strip() + if session_id and session_id.strip() + else str(uuid.uuid4()) + ) + session = Session( + app_name=app_name, + user_id=user_id, + id=session_id, + state=state or {}, + last_update_time=time.time(), + ) + + user_sessions_key = RedisKeys.user_sessions(app_name, user_id) + session_key = RedisKeys.session(session_id) + + async with self.cache.pipeline(transaction=False) as pipe: + pipe.sadd(user_sessions_key, session_id) + pipe.expire(user_sessions_key, self.expire) + pipe.set( + session_key, + _session_serializer(session), + ex=self.expire, + ) + await pipe.execute() + + return await self._merge_state(app_name, user_id, session) + + @override + async def get_session( + self, + *, + app_name: str, + user_id: str, + session_id: str, + config: Optional[GetSessionConfig] = None, + ) -> Optional[Session]: + session_key = RedisKeys.session(session_id) + raw_session = await self.cache.get(session_key) + if not raw_session: + user_sessions_key = RedisKeys.user_sessions(app_name, user_id) + await self.cache.srem(user_sessions_key, session_id) + return None + + try: + session_dict = orjson.loads(raw_session) + session = Session.model_validate(session_dict) + except (orjson.JSONDecodeError, Exception) as e: + logger.error(f"Error decoding session {session_id}: {e}") + return None + + if config: + if config.num_recent_events: + session.events = session.events[-config.num_recent_events :] + if config.after_timestamp: + timestamps = [e.timestamp for e in session.events] + start_index = bisect.bisect_left(timestamps, config.after_timestamp) + session.events = session.events[start_index:] + + return await self._merge_state(app_name, user_id, session) + + @override + async def list_sessions( + self, *, app_name: str, user_id: str + ) -> ListSessionsResponse: + sessions = await self._load_sessions(app_name, user_id) + sessions_without_events = [] + + for session_data in sessions.values(): + session = Session.model_validate(session_data) + session.events = [] + session.state = {} + sessions_without_events.append(session) + + return ListSessionsResponse(sessions=sessions_without_events) + + @override + async def delete_session( + self, *, app_name: str, user_id: str, session_id: str + ) -> None: + user_sessions_key = RedisKeys.user_sessions(app_name, user_id) + session_key = RedisKeys.session(session_id) + + async with self.cache.pipeline(transaction=False) as pipe: + pipe.srem(user_sessions_key, session_id) + pipe.delete(session_key) + await pipe.execute() + + @override + async def append_event(self, session: Session, event: Event) -> Event: + await super().append_event(session=session, event=event) + session.last_update_time = event.timestamp + + async with self.cache.pipeline(transaction=False) as pipe: + user_sessions_key = RedisKeys.user_sessions( + session.app_name, session.user_id + ) + pipe.expire(user_sessions_key, self.expire) + + if event.actions and event.actions.state_delta: + for key, value in event.actions.state_delta.items(): + if key.startswith(State.APP_PREFIX): + pipe.hset( + RedisKeys.app_state(session.app_name), + key.removeprefix(State.APP_PREFIX), + orjson.dumps(value), + ) + if key.startswith(State.USER_PREFIX): + pipe.hset( + RedisKeys.user_state(session.app_name, session.user_id), + key.removeprefix(State.USER_PREFIX), + orjson.dumps(value), + ) + + pipe.set( + RedisKeys.session(session.id), + _session_serializer(session), + ex=self.expire, + ) + await pipe.execute() + + return event + + async def _merge_state( + self, app_name: str, user_id: str, session: Session + ) -> Session: + app_state = await self.cache.hgetall(RedisKeys.app_state(app_name)) + for k, v in app_state.items(): + session.state[State.APP_PREFIX + k.decode()] = orjson.loads(v) + + user_state = await self.cache.hgetall(RedisKeys.user_state(app_name, user_id)) + for k, v in user_state.items(): + session.state[State.USER_PREFIX + k.decode()] = orjson.loads(v) + + return session + + async def _load_sessions(self, app_name: str, user_id: str) -> dict[str, dict]: + key = RedisKeys.user_sessions(app_name, user_id) + try: + session_ids_bytes = await self.cache.smembers(key) + if not session_ids_bytes: + return {} + + session_ids = [s.decode() for s in session_ids_bytes] + session_keys = [RedisKeys.session(sid) for sid in session_ids] + + # Group by slot for Redis Cluster + slot_groups: dict[int, list[str]] = {} + for k in session_keys: + slot = key_slot(k.encode()) + slot_groups.setdefault(slot, []).append(k) + + async def fetch_group(keys: list[str]): + async with self.cache.pipeline(transaction=False) as pipe: + for k in keys: + pipe.get(k) + return await pipe.execute() + + results_per_group = await asyncio.gather( + *(fetch_group(keys) for keys in slot_groups.values()) + ) + + raw_sessions = [] + for group_keys, group_results in zip( + slot_groups.values(), results_per_group + ): + raw_sessions.extend(zip(group_keys, group_results)) + + sessions = {} + sessions_to_cleanup = [] + for key_name, raw_session in raw_sessions: + session_id = key_name.split(":", 1)[1] + if raw_session: + try: + sessions[session_id] = orjson.loads(raw_session) + except orjson.JSONDecodeError as e: + logger.error(f"Error decoding session {session_id}: {e}") + else: + logger.warning( + "Session ID %s found in user set but session data is missing. Cleaning up.", + session_id, + ) + sessions_to_cleanup.append(session_id) + + if sessions_to_cleanup: + await self.cache.srem(key, *sessions_to_cleanup) + + return sessions + except redis.RedisError as e: + logger.error(f"Error loading sessions for {user_id}: {e}") + return {} diff --git a/src/google/adk_community/sessions/utils.py b/src/google/adk_community/sessions/utils.py new file mode 100644 index 00000000..bc53d2b2 --- /dev/null +++ b/src/google/adk_community/sessions/utils.py @@ -0,0 +1,37 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import base64 +import datetime +from decimal import Decimal +import uuid + + +def _json_serializer(obj): + """Fallback serializer to handle non-JSON-compatible types.""" + if isinstance(obj, set): + return list(obj) + if isinstance(obj, bytes): + try: + return base64.b64encode(obj).decode("ascii") + except Exception: + return repr(obj) + if isinstance(obj, (datetime.datetime, datetime.date)): + return obj.isoformat() + if isinstance(obj, uuid.UUID): + return str(obj) + if isinstance(obj, Decimal): + return float(obj) + return str(obj) diff --git a/tests/unittests/__init__.py b/tests/unittests/__init__.py new file mode 100644 index 00000000..0a2669d7 --- /dev/null +++ b/tests/unittests/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unittests/sessions/__init__.py b/tests/unittests/sessions/__init__.py new file mode 100644 index 00000000..0a2669d7 --- /dev/null +++ b/tests/unittests/sessions/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unittests/sessions/test_redis_session_service.py b/tests/unittests/sessions/test_redis_session_service.py new file mode 100644 index 00000000..fe3abc85 --- /dev/null +++ b/tests/unittests/sessions/test_redis_session_service.py @@ -0,0 +1,539 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import orjson +from datetime import datetime, timezone +import pytest +import pytest_asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +from google.adk.events.event import Event +from google.adk.events.event_actions import EventActions +from google.adk.sessions.base_session_service import GetSessionConfig +from google.adk_community.sessions.redis_session_service import RedisSessionService +from google.genai import types + + +class TestRedisSessionService: + """Test cases for RedisSessionService.""" + + @pytest_asyncio.fixture + async def redis_service(self): + """Create a Redis session service for testing.""" + with patch("redis.asyncio.Redis") as mock_redis: + mock_client = AsyncMock() + mock_redis.return_value = mock_client + service = RedisSessionService() + service.cache = mock_client + yield service + + @pytest_asyncio.fixture + async def redis_cluster_service(self): + """Create a Redis cluster session service for testing.""" + with patch("redis.asyncio.RedisCluster.from_url") as mock_redis_cluster: + mock_client = AsyncMock() + mock_redis_cluster.return_value = mock_client + cluster_uri = "redis://redis-node1:6379" + service = RedisSessionService(cluster_uri=cluster_uri) + service.cache = mock_client + yield service + + @pytest_asyncio.fixture + async def redis_cluster_uri_service(self): + """Create a Redis cluster session service using URI for testing.""" + with patch("redis.asyncio.RedisCluster.from_url") as mock_redis_cluster: + mock_client = AsyncMock() + mock_redis_cluster.return_value = mock_client + cluster_uri = "redis://node1:6379,node2:6379" + service = RedisSessionService(cluster_uri=cluster_uri) + service.cache = mock_client + yield service + + def _setup_redis_mocks(self, redis_service, sessions_data=None): + """Helper to set up Redis mocks for the new storage strategy.""" + if sessions_data is None: + sessions_data = {} + + session_ids = list(sessions_data.keys()) + redis_service.cache.smembers = AsyncMock( + return_value={sid.encode() for sid in session_ids} + ) + + # Mock the new cluster-aware pipeline approach + session_values = [ + orjson.dumps(sessions_data[sid]) if sid in sessions_data else None + for sid in session_ids + ] + + # For backward compatibility with mget approach (still used in some tests) + redis_service.cache.mget = AsyncMock(return_value=session_values) + + # Mock pipeline for the new cluster approach + if session_ids: + # Group sessions as the actual implementation does + results_per_group = [] + for i in range(len(session_ids)): + results_per_group.append([session_values[i]]) + + mock_context_manager = MagicMock() + mock_pipe = MagicMock() + mock_pipe.get = MagicMock(return_value=mock_pipe) + mock_pipe.execute = AsyncMock(side_effect=results_per_group) + mock_context_manager.__aenter__ = AsyncMock(return_value=mock_pipe) + mock_context_manager.__aexit__ = AsyncMock(return_value=None) + redis_service.cache.pipeline = MagicMock(return_value=mock_context_manager) + else: + mock_context_manager = MagicMock() + mock_pipe = MagicMock() + mock_pipe.get = MagicMock(return_value=mock_pipe) + mock_pipe.execute = AsyncMock(return_value=[]) + mock_context_manager.__aenter__ = AsyncMock(return_value=mock_pipe) + mock_context_manager.__aexit__ = AsyncMock(return_value=None) + redis_service.cache.pipeline = MagicMock(return_value=mock_context_manager) + redis_service.cache.srem = AsyncMock() + redis_service.cache.get = AsyncMock(return_value=None) # Default to no session + + # Additional pipeline operations for create/update operations + if not session_ids: + mock_context_manager = MagicMock() + mock_pipe = MagicMock() + mock_pipe.set = MagicMock(return_value=mock_pipe) # Allow chaining + mock_pipe.sadd = MagicMock(return_value=mock_pipe) + mock_pipe.expire = MagicMock(return_value=mock_pipe) + mock_pipe.delete = MagicMock(return_value=mock_pipe) + mock_pipe.srem = MagicMock(return_value=mock_pipe) + mock_pipe.hset = MagicMock(return_value=mock_pipe) + mock_pipe.get = MagicMock(return_value=mock_pipe) + mock_pipe.execute = AsyncMock(return_value=[]) + mock_context_manager.__aenter__ = AsyncMock(return_value=mock_pipe) + mock_context_manager.__aexit__ = AsyncMock(return_value=None) + redis_service.cache.pipeline = MagicMock(return_value=mock_context_manager) + + redis_service.cache.hgetall = AsyncMock(return_value={}) + redis_service.cache.hset = AsyncMock() + + @pytest.mark.asyncio + async def test_get_empty_session(self, redis_service): + """Test getting a non-existent session.""" + self._setup_redis_mocks(redis_service) + + session = await redis_service.get_session( + app_name="test_app", user_id="test_user", session_id="nonexistent" + ) + + assert session is None + + @pytest.mark.asyncio + async def test_create_get_session(self, redis_service): + """Test session creation and retrieval.""" + app_name = "test_app" + user_id = "test_user" + state = {"key": "value"} + + self._setup_redis_mocks(redis_service) + + session = await redis_service.create_session( + app_name=app_name, user_id=user_id, state=state + ) + + assert session.app_name == app_name + assert session.user_id == user_id + assert session.id is not None + assert session.state == state + assert ( + session.last_update_time + <= datetime.now().astimezone(timezone.utc).timestamp() + ) + + # Mock individual session retrieval + redis_service.cache.get = AsyncMock( + return_value=session.model_dump_json().encode() + ) + + got_session = await redis_service.get_session( + app_name=app_name, user_id=user_id, session_id=session.id + ) + + assert got_session.app_name == session.app_name + assert got_session.user_id == session.user_id + assert got_session.id == session.id + assert got_session.state == session.state + + @pytest.mark.asyncio + async def test_create_and_list_sessions(self, redis_service): + """Test creating multiple sessions and listing them. + + list_sessions() is expected to return lightweight session summaries, + i.e., with events and state stripped for performance. + """ + app_name = "test_app" + user_id = "test_user" + + self._setup_redis_mocks(redis_service) + + session_ids = ["session" + str(i) for i in range(3)] + sessions_data = {} + + for i, session_id in enumerate(session_ids): + session = await redis_service.create_session( + app_name=app_name, + user_id=user_id, + session_id=session_id, + state={"key": "value" + session_id}, + ) + # Add at least one event to ensure list_sessions actually strips them. + session.events.append(Event(author="user", timestamp=float(i + 1))) + sessions_data[session_id] = session.model_dump() + + # Now mock Redis to return those sessions (with events present in storage) + self._setup_redis_mocks(redis_service, sessions_data) + + list_sessions_response = await redis_service.list_sessions( + app_name=app_name, user_id=user_id + ) + sessions = list_sessions_response.sessions + + assert len(sessions) == len(session_ids) + returned_session_ids = {s.id for s in sessions} + assert returned_session_ids == set(session_ids) + + for s in sessions: + # list_sessions returns summaries: events and state removed for perf. + assert len(s.events) == 0 + assert s.state == {} + + @pytest.mark.asyncio + async def test_session_state_management(self, redis_service): + """Test session state management with app, user, and temp state.""" + app_name = "test_app" + user_id = "test_user" + session_id = "test_session" + + self._setup_redis_mocks(redis_service) + + session = await redis_service.create_session( + app_name=app_name, + user_id=user_id, + session_id=session_id, + state={"initial_key": "initial_value"}, + ) + + event = Event( + invocation_id="invocation", + author="user", + content=types.Content(role="user", parts=[types.Part(text="text")]), + actions=EventActions( + state_delta={ + "app:key": "app_value", + "user:key1": "user_value", + "temp:key": "temp_value", + "initial_key": "updated_value", + } + ), + ) + + redis_service.cache.get = AsyncMock( + return_value=session.model_dump_json().encode() + ) + + await redis_service.append_event(session=session, event=event) + + assert session.state.get("app:key") == "app_value" + assert session.state.get("user:key1") == "user_value" + assert session.state.get("initial_key") == "updated_value" + assert session.state.get("temp:key") is None # Temp state filtered + + pipeline_mock = redis_service.cache.pipeline.return_value + pipe_mock = await pipeline_mock.__aenter__() + pipe_mock.hset.assert_any_call("app:test_app", "key", orjson.dumps("app_value")) + pipe_mock.hset.assert_any_call( + "user:test_app:test_user", "key1", orjson.dumps("user_value") + ) + + @pytest.mark.asyncio + async def test_append_event_with_bytes(self, redis_service): + """Test appending events with binary content and serialization roundtrip.""" + app_name = "test_app" + user_id = "test_user" + + self._setup_redis_mocks(redis_service) + + session = await redis_service.create_session(app_name=app_name, user_id=user_id) + + test_content = types.Content( + role="user", + parts=[ + types.Part.from_bytes(data=b"test_image_data", mime_type="image/png"), + ], + ) + test_grounding_metadata = types.GroundingMetadata( + search_entry_point=types.SearchEntryPoint(sdk_blob=b"test_sdk_blob") + ) + event = Event( + invocation_id="invocation", + author="user", + content=test_content, + grounding_metadata=test_grounding_metadata, + ) + + redis_service.cache.get = AsyncMock( + return_value=session.model_dump_json().encode() + ) + + await redis_service.append_event(session=session, event=event) + + # Verify the event was appended to in-memory session + assert len(session.events) == 1 + assert session.events[0].content == test_content + assert session.events[0].grounding_metadata == test_grounding_metadata + + # Test serialization/deserialization roundtrip to ensure binary data is preserved + # Simulate what happens when session is stored and retrieved from Redis + serialized_session = session.model_dump_json() + + redis_service.cache.get = AsyncMock(return_value=serialized_session.encode()) + + retrieved_session = await redis_service.get_session( + app_name=app_name, user_id=user_id, session_id=session.id + ) + + assert retrieved_session is not None + assert len(retrieved_session.events) == 1 + + # Verify the binary content was preserved through serialization + retrieved_event = retrieved_session.events[0] + assert retrieved_event.content.parts[0].inline_data.data == b"test_image_data" + assert retrieved_event.content.parts[0].inline_data.mime_type == "image/png" + assert ( + retrieved_event.grounding_metadata.search_entry_point.sdk_blob + == b"test_sdk_blob" + ) + + @pytest.mark.asyncio + async def test_get_session_with_config(self, redis_service): + """Test getting session with configuration filters.""" + app_name = "test_app" + user_id = "test_user" + + self._setup_redis_mocks(redis_service) + + session = await redis_service.create_session(app_name=app_name, user_id=user_id) + + # Add multiple events with different timestamps + num_test_events = 5 + for i in range(1, num_test_events + 1): + event = Event(author="user", timestamp=float(i)) + session.events.append(event) + + redis_service.cache.get = AsyncMock( + return_value=session.model_dump_json().encode() + ) + + # Test num_recent_events filter + config = GetSessionConfig(num_recent_events=3) + filtered_session = await redis_service.get_session( + app_name=app_name, user_id=user_id, session_id=session.id, config=config + ) + + assert len(filtered_session.events) == 3 + assert filtered_session.events[0].timestamp == 3.0 # Last 3 events + + # Test after_timestamp filter + config = GetSessionConfig(after_timestamp=3.0) + filtered_session = await redis_service.get_session( + app_name=app_name, user_id=user_id, session_id=session.id, config=config + ) + + assert len(filtered_session.events) == 3 # Events 3, 4, 5 + assert filtered_session.events[0].timestamp == 3.0 + + @pytest.mark.asyncio + async def test_delete_session(self, redis_service): + """Test session deletion.""" + app_name = "test_app" + user_id = "test_user" + session_id = "test_session" + + self._setup_redis_mocks(redis_service) # Empty sessions + await redis_service.delete_session( + app_name=app_name, user_id=user_id, session_id=session_id + ) + pipeline_mock = redis_service.cache.pipeline.return_value + pipe_mock = await pipeline_mock.__aenter__() + pipe_mock.execute.assert_called() + + redis_service.cache.pipeline.reset_mock() + self._setup_redis_mocks(redis_service) + + await redis_service.delete_session( + app_name=app_name, user_id=user_id, session_id=session_id + ) + + pipeline_mock = redis_service.cache.pipeline.return_value + pipe_mock = await pipeline_mock.__aenter__() + pipe_mock.execute.assert_called() + + @pytest.mark.asyncio + async def test_cluster_health_check(self, redis_cluster_service): + """Test health check for Redis cluster.""" + redis_cluster_service.cache.ping = AsyncMock(return_value=True) + + result = await redis_cluster_service.health_check() + assert result is True + redis_cluster_service.cache.ping.assert_called_once() + + @pytest.mark.asyncio + async def test_cluster_health_check_failure(self, redis_cluster_service): + """Test health check failure for Redis cluster.""" + from redis import RedisError + + redis_cluster_service.cache.ping = AsyncMock( + side_effect=RedisError("Connection failed") + ) + + result = await redis_cluster_service.health_check() + assert result is False + + @pytest.mark.asyncio + async def test_cluster_create_and_get_session(self, redis_cluster_service): + """Test session creation and retrieval in cluster mode.""" + app_name = "cluster_test_app" + user_id = "cluster_test_user" + state = {"cluster_key": "cluster_value"} + + self._setup_redis_mocks(redis_cluster_service) + + session = await redis_cluster_service.create_session( + app_name=app_name, user_id=user_id, state=state + ) + + assert session.app_name == app_name + assert session.user_id == user_id + assert session.id is not None + assert session.state == state + + # Mock individual session retrieval + redis_cluster_service.cache.get = AsyncMock( + return_value=session.model_dump_json().encode() + ) + + got_session = await redis_cluster_service.get_session( + app_name=app_name, user_id=user_id, session_id=session.id + ) + + assert got_session.app_name == session.app_name + assert got_session.user_id == session.user_id + assert got_session.id == session.id + assert got_session.state == session.state + + @pytest.mark.asyncio + async def test_cluster_uri_initialization(self, redis_cluster_uri_service): + """Test Redis cluster initialization with URI.""" + assert redis_cluster_uri_service.cache is not None + + @pytest.mark.asyncio + async def test_cluster_error_handling(self, redis_cluster_service): + """Test error handling in cluster operations.""" + from redis import RedisError + + app_name = "test_app" + user_id = "test_user" + + # Mock Redis error during session loading + redis_cluster_service.cache.smembers = AsyncMock( + side_effect=RedisError("Cluster error") + ) + + sessions_response = await redis_cluster_service.list_sessions( + app_name=app_name, user_id=user_id + ) + + assert len(sessions_response.sessions) == 0 + + @pytest.mark.asyncio + async def test_cluster_connection_validation(self): + """Test cluster connection validation during initialization.""" + cluster_uri = "redis://redis-node1:6379" + + with patch("redis.asyncio.RedisCluster.from_url") as mock_redis_cluster: + mock_client = AsyncMock() + mock_redis_cluster.return_value = mock_client + + service = RedisSessionService(cluster_uri=cluster_uri) + assert service.cache is not None + mock_redis_cluster.assert_called_once() + + @pytest.mark.asyncio + async def test_cluster_session_cleanup_on_error(self, redis_cluster_service): + """Test session cleanup when corrupted data is found in cluster.""" + app_name = "test_app" + user_id = "test_user" + + # Setup mock with corrupted session data + valid_session_data = { + "app_name": "test_app", + "user_id": "test_user", + "id": "session1", + "state": {}, + "events": [], + "last_update_time": 1234567890, + } + redis_cluster_service.cache.smembers = AsyncMock( + return_value={b"session1", b"session2"} + ) + + # Mock the pipeline for cluster approach + mock_context_manager = MagicMock() + mock_pipe = MagicMock() + mock_pipe.get = MagicMock(return_value=mock_pipe) + mock_pipe.execute = AsyncMock( + side_effect=[ + [orjson.dumps(valid_session_data)], # session1 result + [None], # session2 result (missing) + ] + ) + mock_context_manager.__aenter__ = AsyncMock(return_value=mock_pipe) + mock_context_manager.__aexit__ = AsyncMock(return_value=None) + redis_cluster_service.cache.pipeline = MagicMock( + return_value=mock_context_manager + ) + redis_cluster_service.cache.srem = AsyncMock() + redis_cluster_service.cache.hgetall = AsyncMock(return_value={}) + + sessions_response = await redis_cluster_service.list_sessions( + app_name=app_name, user_id=user_id + ) + + redis_cluster_service.cache.srem.assert_called() + assert len(sessions_response.sessions) == 1 + + @pytest.mark.asyncio + async def test_decode_responses_handling(self, redis_service): + """Test proper handling of decode_responses setting.""" + app_name = "test_app" + user_id = "test_user" + session_id = "test_session" + + # Test with bytes response (decode_responses=False) + session_data = '{"app_name": "test_app", "user_id": "test_user", "id": "test_session", "state": {}, "events": [], "last_update_time": 1234567890}' + redis_service.cache.get = AsyncMock(return_value=session_data.encode()) + redis_service.cache.hgetall = AsyncMock(return_value={}) + + session = await redis_service.get_session( + app_name=app_name, user_id=user_id, session_id=session_id + ) + + assert session is not None + assert session.app_name == app_name + assert session.user_id == user_id