Skip to content

Commit 7979fcf

Browse files
fix: improve session store concurrency safety
- Fix counter delta loss: Use delta tracking instead of max() to preserve concurrent increments
- Fix file stat race conditions: Handle FileNotFoundError when files are deleted between operations
- Add baseline stat tracking for proper merge semantics
- Maintain backward compatibility for existing sessions

Addresses CodeRabbit feedback on concurrent session safety issues.

Co-authored-by: Mervin Praison <MervinPraison@users.noreply.github.com>
1 parent cf8e373 commit 7979fcf

1 file changed

Lines changed: 63 additions & 15 deletions

File tree

src/praisonai/praisonai/cli/session/unified.py

Lines changed: 63 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,12 @@ class UnifiedSession:
5252
total_cost: float = 0.0
5353
request_count: int = 0
5454

55+
# Track baseline values for proper delta merging (not persisted)
56+
_baseline_input_tokens: int = field(default=0, init=False, repr=False)
57+
_baseline_output_tokens: int = field(default=0, init=False, repr=False)
58+
_baseline_cost: float = field(default=0.0, init=False, repr=False)
59+
_baseline_request_count: int = field(default=0, init=False, repr=False)
60+
5561
# Model info
5662
current_model: str = "gpt-4o-mini"
5763

@@ -89,19 +95,45 @@ def update_stats(self, input_tokens: int, output_tokens: int, cost: float = 0.0)
8995
self.request_count += 1
9096
self.updated_at = datetime.now().isoformat()
9197

98+
def set_baseline_stats(self) -> None:
    """Record the current counters as the baseline for delta tracking.

    Copies ``total_input_tokens``, ``total_output_tokens``, ``total_cost``
    and ``request_count`` into the matching ``_baseline_*`` attributes, so
    that ``get_stat_deltas`` reports only what changed after this call.
    """
    self._baseline_input_tokens = self.total_input_tokens
    self._baseline_output_tokens = self.total_output_tokens
    self._baseline_cost = self.total_cost
    self._baseline_request_count = self.request_count
104+
105+
def get_stat_deltas(self) -> Dict[str, int | float]:
106+
"""Get deltas from baseline for proper merge."""
107+
return {
108+
"input_tokens": self.total_input_tokens - self._baseline_input_tokens,
109+
"output_tokens": self.total_output_tokens - self._baseline_output_tokens,
110+
"cost": self.total_cost - self._baseline_cost,
111+
"request_count": self.request_count - self._baseline_request_count,
112+
}
113+
92114
def clear_messages(self) -> None:
    """Clear all messages from the session and refresh ``updated_at``.

    The message list is emptied in place, so other references to the same
    list object also see it become empty.
    """
    self.messages.clear()
    # Timestamp is naive local time in ISO-8601 form.
    self.updated_at = datetime.now().isoformat()
96118

97119
def to_dict(self) -> Dict[str, Any]:
    """Convert session to dictionary, excluding internal baseline fields.

    Returns:
        The ``dataclasses.asdict`` view of this session with every key that
        starts with ``_baseline_`` left out, so delta-tracking state is
        never serialized.
    """
    return {
        key: value
        for key, value in asdict(self).items()
        if not key.startswith("_baseline_")
    }
100127

101128
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "UnifiedSession":
    """Create session from dictionary.

    Args:
        data: Mapping of constructor keyword arguments. Keys starting with
            ``_baseline_`` are ignored, so internal baseline values that
            leaked into saved data do not reach the constructor.

    Returns:
        A new instance whose baseline stats equal its loaded totals.
    """
    clean_data = {k: v for k, v in data.items() if not k.startswith("_baseline_")}
    instance = cls(**clean_data)
    # Start delta tracking from the values that were just loaded.
    instance.set_baseline_stats()
    return instance
105137

106138
@property
107139
def message_count(self) -> int:
@@ -148,7 +180,7 @@ def _messages_common_prefix(
148180
) -> int:
149181
"""Return shared message prefix length for safe concurrent merge."""
150182
prefix = 0
151-
for left_msg, right_msg in zip(left, right):
183+
for left_msg, right_msg in zip(left, right, strict=False):
152184
if left_msg.get("role") != right_msg.get("role"):
153185
break
154186
if left_msg.get("content") != right_msg.get("content"):
@@ -207,17 +239,19 @@ def _merge_sessions(
207239
return incoming
208240

209241
merged = UnifiedSession.from_dict(disk_session.to_dict())
242+
243+
# Use prefix-based merge for append-only scenarios (original design)
210244
prefix = self._messages_common_prefix(disk_session.messages, incoming.messages)
211245
merged.messages = disk_session.messages + incoming.messages[prefix:]
212246

213-
if incoming.total_input_tokens > merged.total_input_tokens:
214-
merged.total_input_tokens = incoming.total_input_tokens
215-
if incoming.total_output_tokens > merged.total_output_tokens:
216-
merged.total_output_tokens = incoming.total_output_tokens
217-
if incoming.total_cost > merged.total_cost:
218-
merged.total_cost = incoming.total_cost
219-
if incoming.request_count > merged.request_count:
220-
merged.request_count = incoming.request_count
247+
# Merge stats using deltas instead of max()
248+
incoming_deltas = incoming.get_stat_deltas()
249+
merged.total_input_tokens += max(0, incoming_deltas["input_tokens"])
250+
merged.total_output_tokens += max(0, incoming_deltas["output_tokens"])
251+
merged.total_cost += max(0.0, incoming_deltas["cost"])
252+
merged.request_count += max(0, incoming_deltas["request_count"])
253+
254+
# Update other fields with incoming values if present
221255
if incoming.current_model:
222256
merged.current_model = incoming.current_model
223257
if incoming.metadata:
@@ -299,8 +333,13 @@ def save(self, session: UnifiedSession) -> None:
299333

300334
# Update cache
301335
self._cache[session.session_id] = session
302-
if path.exists():
303-
self._cache_mtime[session.session_id] = path.stat().st_mtime
336+
# Safely update mtime cache with error handling
337+
try:
338+
if path.exists():
339+
self._cache_mtime[session.session_id] = path.stat().st_mtime
340+
except (FileNotFoundError, OSError):
341+
# File was deleted/moved between write and stat, skip mtime update
342+
pass
304343

305344
# Update last session marker
306345
self._update_last_session(session.session_id)
@@ -328,8 +367,15 @@ def load(self, session_id: str) -> Optional[UnifiedSession]:
328367

329368
session = self._read_session_from_file(path)
330369
if session is not None:
370+
# Set baseline stats for proper delta tracking
371+
session.set_baseline_stats()
331372
self._cache[session_id] = session
332-
self._cache_mtime[session_id] = path.stat().st_mtime
373+
# Safely update mtime cache
374+
try:
375+
self._cache_mtime[session_id] = path.stat().st_mtime
376+
except (FileNotFoundError, OSError):
377+
# File was deleted/moved after read, skip mtime update
378+
pass
333379
logger.debug(f"Loaded session: {session_id}")
334380
return session
335381

@@ -351,6 +397,8 @@ def get_or_create(self, session_id: Optional[str] = None) -> UnifiedSession:
351397
# Create new session
352398
new_id = session_id or str(uuid.uuid4())[:8]
353399
session = UnifiedSession(session_id=new_id)
400+
# Set baseline stats for new session
401+
session.set_baseline_stats()
354402
self.save(session)
355403
return session
356404

0 commit comments

Comments
 (0)