From f6441e20cc966436fbd1b8dfb00188f33d3954ee Mon Sep 17 00:00:00 2001 From: xiami762 <> Date: Tue, 2 Jun 2026 13:55:02 +0800 Subject: [PATCH] fix(session): normalize legacy stored messages and parts on cache load Backfill missing assistant/user fields and tool state timestamps so old sessions deserialize without dropping the whole cache; skip invalid entries. Co-authored-by: Cursor --- flocks/session/message.py | 395 +++++++++++++++++- .../session/test_message_parts_persistence.py | 171 ++++++++ 2 files changed, 552 insertions(+), 14 deletions(-) diff --git a/flocks/session/message.py b/flocks/session/message.py index 7eadd39be..9bf05da41 100644 --- a/flocks/session/message.py +++ b/flocks/session/message.py @@ -596,14 +596,35 @@ async def _ensure_cache(cls, session_id: str) -> None: storage_key = f"{cls._MESSAGE_PREFIX}:{session_id}" stored_data = await Storage.get(storage_key) - if stored_data: + message_times: Dict[str, Dict[str, int]] = {} + if isinstance(stored_data, list): messages = [] for msg_data in stored_data: - role = msg_data.get('role', 'assistant') - if role == 'user': - messages.append(UserMessageInfo.model_validate(msg_data)) - else: - messages.append(AssistantMessageInfo.model_validate(msg_data)) + if not isinstance(msg_data, dict): + log.warn("message.cache.skipped_non_dict", { + "session_id": session_id, + "raw_type": type(msg_data).__name__, + }) + continue + + normalized = cls._normalize_stored_message(msg_data, session_id) + role = normalized.get("role", "assistant") + try: + if role == "user": + message = UserMessageInfo.model_validate(normalized) + else: + message = AssistantMessageInfo.model_validate(normalized) + except Exception as exc: + log.warn("message.cache.skipped_invalid", { + "session_id": session_id, + "message_id": normalized.get("id"), + "role": role, + "error": str(exc), + }) + continue + + messages.append(message) + message_times[message.id] = message.time cls._messages_cache[session_id] = messages else: cls._messages_cache[session_id] = [] @@ -638,9 +659,12 @@ async def _ensure_cache(cls, session_id: str) -> None: for msg_id, parts_data in stored_parts.items(): if not isinstance(parts_data, list): continue - cls._parts_cache[session_id][msg_id] = [ - cls.deserialize_part(p) for p in parts_data - ] + cls._parts_cache[session_id][msg_id] = cls._deserialize_parts_list( + session_id, + msg_id, + parts_data, + message_time=message_times.get(msg_id), + ) cls._parts_revision_cache[session_id][msg_id] = 0 cls._parts_serialized_cache[session_id] = { msg_id: list(parts_data) @@ -663,8 +687,347 @@ def _rebuild_id_index(cls, session_id: str) -> None: cls._msg_id_index[session_id] = {m.id: i for i, m in enumerate(messages)} @classmethod - def deserialize_part(cls, part_data: Dict[str, Any]) -> PartType: + def _first_non_none(cls, *values: Any) -> Any: + """Return the first value that is not None, preserving 0-like sentinels.""" + for value in values: + if value is not None: + return value + return None + + @classmethod + def _default_message_time(cls, raw_time: Any) -> Dict[str, int]: + """Normalize stored message timestamps to the current schema.""" + now_ms = int(datetime.now().timestamp() * 1000) + if isinstance(raw_time, dict): + normalized = dict(raw_time) + created = cls._first_non_none( + normalized.get("created"), + normalized.get("start"), + ) + normalized["created"] = int(created) if created is not None else now_ms + if "completed" not in normalized and normalized.get("end") is not None: + normalized["completed"] = int(normalized["end"]) + return normalized + return {"created": now_ms} + + @classmethod + def _default_message_path(cls, raw_path: Any) -> Dict[str, str]: + """Return a safe message path payload for assistant messages.""" + if isinstance(raw_path, MessagePath): + return raw_path.model_dump() + if isinstance(raw_path, dict): + return { + "cwd": str(raw_path.get("cwd") or "./"), + "root": str(raw_path.get("root") or ""), + } + return MessagePath(cwd="./").model_dump() + + @classmethod + def _default_token_usage(cls, raw_tokens: Any = None) -> Dict[str, Any]: + """Return a safe token payload for assistant messages.""" + defaults = TokenUsage().model_dump() + if isinstance(raw_tokens, TokenUsage): + return raw_tokens.model_dump() + if isinstance(raw_tokens, dict): + normalized = dict(defaults) + for key in ("input", "output", "reasoning"): + value = raw_tokens.get(key) + if value is not None: + normalized[key] = value + cache_raw = raw_tokens.get("cache") + if isinstance(cache_raw, dict): + normalized["cache"] = { + "read": cache_raw.get("read", 0), + "write": cache_raw.get("write", 0), + } + return normalized + return defaults + + @classmethod + def _normalize_stored_message( + cls, + msg_data: Dict[str, Any], + session_id: str, + ) -> Dict[str, Any]: + """Backfill missing fields for legacy or partially-written messages.""" + normalized = dict(msg_data) + role = normalized.get("role", "assistant") + normalized["sessionID"] = normalized.get("sessionID") or session_id + normalized["time"] = cls._default_message_time(normalized.get("time")) + + if role == "user": + model_raw = normalized.get("model") + if not isinstance(model_raw, dict): + model_raw = {} + normalized["agent"] = normalized.get("agent") or "rex" + normalized["model"] = { + "providerID": model_raw.get("providerID") + or normalized.get("providerID") + or "", + "modelID": model_raw.get("modelID") + or normalized.get("modelID") + or "", + } + return normalized + + model_raw = normalized.get("model") + if not isinstance(model_raw, dict): + model_raw = {} + normalized["parentID"] = ( + normalized.get("parentID") + or normalized.get("parent_id") + or "" + ) + normalized["modelID"] = ( + normalized.get("modelID") + or normalized.get("model_id") + or model_raw.get("modelID") + or "" + ) + normalized["providerID"] = ( + normalized.get("providerID") + or normalized.get("provider_id") + or model_raw.get("providerID") + or "" + ) + normalized["agent"] = normalized.get("agent") or "rex" + normalized["mode"] = normalized.get("mode") or normalized["agent"] or "standard" + normalized["path"] = cls._default_message_path(normalized.get("path")) + normalized["tokens"] = cls._default_token_usage(normalized.get("tokens")) + return normalized + + @classmethod + def _default_part_time( + cls, + raw_time: Any = None, + *, + message_time: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Optional[int]]: + """Provide best-effort timestamps for legacy stored parts.""" + time_info = raw_time if isinstance(raw_time, dict) else {} + fallback_time = message_time if isinstance(message_time, dict) else {} + start = cls._first_non_none( + time_info.get("start"), + time_info.get("created"), + fallback_time.get("start"), + fallback_time.get("created"), + 0, + ) + end = cls._first_non_none( + time_info.get("end"), + time_info.get("updated"), + time_info.get("completed"), + fallback_time.get("end"), + fallback_time.get("updated"), + fallback_time.get("completed"), + ) + if end is None and start is not None: + end = start + return {"start": int(start), "end": int(end) if end is not None else None} + + @classmethod + def _normalize_tool_state( + cls, + raw_state: Any, + *, + message_time: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + """Convert legacy tool state payloads to the current shape.""" + fallback_time = cls._default_part_time(message_time=message_time) + if isinstance(raw_state, dict): + normalized = dict(raw_state) + else: + normalized = {} + if isinstance(raw_state, str): + normalized["status"] = raw_state + + status = normalized.get("status") + if status not in {"pending", "running", "completed", "error"}: + if "output" in normalized: + status = "completed" + elif "error" in normalized: + status = "error" + elif "time" in normalized: + status = "running" + else: + status = "pending" + normalized["status"] = status + + if status == "pending": + normalized.setdefault("input", {}) + normalized.setdefault("raw", "") + elif status == "running": + normalized.setdefault("input", {}) + raw_time = normalized.get("time") + normalized["time"] = ( + cls._default_part_time(raw_time, message_time=message_time) + if isinstance(raw_time, dict) + else fallback_time + ) + elif status == "completed": + normalized.setdefault("input", {}) + normalized.setdefault("output", "") + normalized.setdefault("title", "") + normalized.setdefault("metadata", {}) + raw_time = normalized.get("time") + normalized["time"] = ( + cls._default_part_time(raw_time, message_time=message_time) + if isinstance(raw_time, dict) + else fallback_time + ) + elif status == "error": + normalized.setdefault("input", {}) + normalized.setdefault("error", "") + raw_time = normalized.get("time") + normalized["time"] = ( + cls._default_part_time(raw_time, message_time=message_time) + if isinstance(raw_time, dict) + else fallback_time + ) + + return normalized + + @classmethod + def _normalize_part_data( + cls, + part_data: Dict[str, Any], + *, + session_id: Optional[str] = None, + message_id: Optional[str] = None, + message_time: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + """Normalize legacy/exported part payloads before validation.""" + normalized = dict(part_data) + if session_id and not normalized.get("sessionID"): + normalized["sessionID"] = session_id + if message_id and not normalized.get("messageID"): + normalized["messageID"] = message_id + + part_type = normalized.get("type", "text") + metadata = normalized.get("metadata") + metadata_dict = metadata if isinstance(metadata, dict) else {} + + if "content" in normalized and "text" not in normalized: + normalized["text"] = normalized.get("content", "") + + raw_time = normalized.get("time") + if part_type == "text": + normalized.setdefault("text", "") + if isinstance(raw_time, dict): + if "start" not in raw_time and any( + key in raw_time for key in ("created", "updated", "completed", "end") + ): + normalized["time"] = cls._default_part_time( + raw_time, + message_time=message_time, + ) + elif raw_time is not None: + normalized.pop("time", None) + if metadata is not None and not isinstance(metadata, dict): + normalized["metadata"] = None + elif part_type == "reasoning": + normalized.setdefault("text", normalized.get("content", "")) + normalized.setdefault("metadata", metadata_dict or None) + normalized["time"] = cls._default_part_time( + raw_time, + message_time=message_time, + ) + elif part_type == "tool": + normalized.setdefault("callID", metadata_dict.get("callID") or normalized.get("id", "")) + normalized.setdefault("tool", metadata_dict.get("tool") or normalized.get("tool", "unknown")) + raw_state = normalized.get("state") + if raw_state is None and metadata_dict: + raw_state = metadata_dict.get("state") + normalized["state"] = cls._normalize_tool_state( + raw_state, + message_time=message_time, + ) + normalized.setdefault("metadata", metadata_dict or None) + elif part_type == "file": + normalized.setdefault("mime", metadata_dict.get("mime") or "application/octet-stream") + normalized.setdefault("filename", metadata_dict.get("filename")) + normalized.setdefault("url", metadata_dict.get("url") or normalized.get("content", "")) + elif part_type == "snapshot": + normalized.setdefault("snapshot", metadata_dict.get("snapshot") or normalized.get("content", "")) + elif part_type == "patch": + normalized.setdefault("hash", metadata_dict.get("hash") or "") + normalized.setdefault("files", metadata_dict.get("files") or []) + elif part_type == "step-finish": + normalized.setdefault("reason", metadata_dict.get("reason") or "completed") + normalized.setdefault("snapshot", metadata_dict.get("snapshot")) + normalized.setdefault("cost", metadata_dict.get("cost") or 0.0) + normalized.setdefault("tokens", metadata_dict.get("tokens") or cls._default_token_usage()) + elif part_type == "agent": + normalized.setdefault("name", metadata_dict.get("name") or normalized.get("content") or "agent") + elif part_type == "subtask": + normalized.setdefault("prompt", metadata_dict.get("prompt") or normalized.get("content", "")) + normalized.setdefault("description", metadata_dict.get("description") or "") + normalized.setdefault("agent", metadata_dict.get("agent") or "agent") + elif part_type == "retry": + normalized.setdefault("attempt", metadata_dict.get("attempt") or 1) + normalized.setdefault("error", metadata_dict.get("error") or {}) + normalized["time"] = cls._default_part_time( + metadata_dict.get("time") or raw_time, + message_time=message_time, + ) + elif part_type == "compaction": + normalized.setdefault("auto", bool(metadata_dict.get("auto", False))) + + return normalized + + @classmethod + def _deserialize_parts_list( + cls, + session_id: str, + message_id: str, + parts_data: List[Dict[str, Any]], + *, + message_time: Optional[Dict[str, Any]] = None, + ) -> List[PartType]: + """Best-effort deserialize a stored parts list without dropping the session.""" + parts: List[PartType] = [] + for raw_part in parts_data: + if not isinstance(raw_part, dict): + log.warn("message.part.skipped_non_dict", { + "session_id": session_id, + "message_id": message_id, + "raw_type": type(raw_part).__name__, + }) + continue + try: + parts.append( + cls.deserialize_part( + raw_part, + session_id=session_id, + message_id=message_id, + message_time=message_time, + ) + ) + except Exception as exc: + log.warn("message.part.skipped_invalid", { + "session_id": session_id, + "message_id": message_id, + "part_id": raw_part.get("id"), + "error": str(exc), + }) + return parts + + @classmethod + def deserialize_part( + cls, + part_data: Dict[str, Any], + *, + session_id: Optional[str] = None, + message_id: Optional[str] = None, + message_time: Optional[Dict[str, Any]] = None, + ) -> PartType: """Deserialize a part from storage format""" + part_data = cls._normalize_part_data( + part_data, + session_id=session_id, + message_id=message_id, + message_time=message_time, + ) part_type = part_data.get('type', 'text') type_map = { @@ -692,10 +1055,14 @@ def _normalize_assistant_message(message: MessageInfo) -> MessageInfo: return message updates: Dict[str, Any] = {} - if isinstance(message.tokens, dict): - updates["tokens"] = TokenUsage.model_validate(message.tokens) - if isinstance(message.path, dict): - updates["path"] = MessagePath.model_validate(message.path) + if not isinstance(message.tokens, TokenUsage): + updates["tokens"] = TokenUsage.model_validate( + Message._default_token_usage(message.tokens) + ) + if not isinstance(message.path, MessagePath): + updates["path"] = MessagePath.model_validate( + Message._default_message_path(message.path) + ) if not updates: return message diff --git a/tests/session/test_message_parts_persistence.py b/tests/session/test_message_parts_persistence.py index a36c9b508..a7229b285 100644 --- a/tests/session/test_message_parts_persistence.py +++ b/tests/session/test_message_parts_persistence.py @@ -45,6 +45,16 @@ async def _write_legacy_session(session_id: str, messages: dict[str, str]) -> No Message.invalidate_cache(session_id) +async def _write_raw_legacy_payload( + session_id: str, + messages: list[dict], + parts: dict[str, list], +) -> None: + await Storage.set(f"message:{session_id}", messages, "message") + await Storage.set(f"message_parts:{session_id}", parts, "message_parts") + Message.invalidate_cache(session_id) + + @pytest.mark.asyncio async def test_new_sessions_write_per_message_parts_keys() -> None: session_id = "ses_parts_per_message_new" @@ -153,3 +163,164 @@ async def test_clear_removes_legacy_blob_and_per_message_keys() -> None: assert await Storage.get(f"message_parts:{per_message_session_id}") is None assert await Storage.list_keys(prefix=f"message_parts:{per_message_session_id}:") == [] + + +def test_deserialize_legacy_text_part_normalizes_content_and_time() -> None: + part = Message.deserialize_part( + { + "id": "part_legacy_text", + "sessionID": "ses_legacy_text", + "messageID": "msg_legacy_text", + "type": "text", + "content": "hello legacy", + "time": {"created": 7}, + } + ) + + assert part.text == "hello legacy" + assert part.time is not None + assert part.time.start == 7 + assert part.time.end == 7 + + +@pytest.mark.asyncio +async def test_ensure_cache_loads_legacy_assistant_missing_fields() -> None: + session_id = "ses_legacy_assistant_missing_fields" + await _write_raw_legacy_payload( + session_id, + messages=[ + { + "id": "msg_assistant_legacy", + "role": "assistant", + "time": {"created": 2}, + "path": [], + "content": "", + } + ], + parts={ + "msg_assistant_legacy": [ + { + "id": "part_assistant_legacy", + "type": "text", + "content": "restored assistant text", + "time": {"created": 2}, + } + ] + }, + ) + + messages = await Message.list_with_parts(session_id) + + assert len(messages) == 1 + info = messages[0].info + assert info.sessionID == session_id + assert info.agent == "rex" + assert info.parentID == "" + assert info.modelID == "" + assert info.providerID == "" + assert info.path.cwd == "./" + assert info.tokens.input == 0 + assert messages[0].parts[0].text == "restored assistant text" + assert messages[0].parts[0].time is not None + assert messages[0].parts[0].time.start == 2 + + +@pytest.mark.asyncio +async def test_ensure_cache_preserves_zero_created_timestamp() -> None: + session_id = "ses_legacy_zero_created" + await _write_raw_legacy_payload( + session_id, + messages=[ + { + "id": "msg_assistant_zero", + "role": "assistant", + "time": {"created": 0}, + "path": [], + "content": "", + } + ], + parts={ + "msg_assistant_zero": [ + { + "id": "part_assistant_zero", + "type": "text", + "content": "zero timestamp text", + "time": {"created": 0}, + } + ] + }, + ) + + messages = await Message.list_with_parts(session_id) + + assert len(messages) == 1 + assert messages[0].info.time["created"] == 0 + assert messages[0].parts[0].time is not None + assert messages[0].parts[0].time.start == 0 + assert messages[0].parts[0].time.end == 0 + + +@pytest.mark.asyncio +async def test_ensure_cache_loads_legacy_tool_part_without_time() -> None: + session_id = "ses_legacy_tool_missing_time" + await _write_raw_legacy_payload( + session_id, + messages=[ + { + "id": "msg_assistant_tool", + "role": "assistant", + "time": {"created": 0}, + "path": [], + "content": "", + } + ], + parts={ + "msg_assistant_tool": [ + { + "id": "part_tool_legacy", + "type": "tool", + "tool": "bash", + "callID": "call_legacy", + "state": { + "status": "completed", + "output": "legacy output", + }, + } + ] + }, + ) + + messages = await Message.list_with_parts(session_id) + + assert len(messages) == 1 + assert len(messages[0].parts) == 1 + tool_part = messages[0].parts[0] + assert tool_part.type == "tool" + assert tool_part.state.status == "completed" + assert tool_part.state.time == {"start": 0, "end": 0} + + +@pytest.mark.asyncio +async def test_ensure_cache_skips_invalid_part_keeps_siblings() -> None: + session_id = "ses_legacy_bad_part_skip" + await _write_raw_legacy_payload( + session_id, + messages=[_user_message(session_id, "msg_a").model_dump()], + parts={ + "msg_a": [ + "not-a-dict-part", + { + "id": "part_good", + "sessionID": session_id, + "messageID": "msg_a", + "type": "text", + "text": "still here", + }, + ] + }, + ) + + messages = await Message.list_with_parts(session_id) + + assert len(messages) == 1 + assert [part.text for part in messages[0].parts] == ["still here"]