Skip to content

Commit f0af40a

Browse files
fix: update tests to match thread-safe integration APIs
- Fix LlamaIndex event_id parameter conflicts in callback methods - Update all integration tests to use new thread-safe attributes - Fix parent-child relationship tests to pass proper parent IDs - Add proper asyncio event loop mocking for LlamaIndex tests - All 62 integration tests now pass 🤖 Generated with [Claude Code](https://claude.ai/code) Co-authored-by: Mervin Praison <MervinPraison@users.noreply.github.com>
1 parent c783f0c commit f0af40a

4 files changed

Lines changed: 59 additions & 50 deletions

File tree

src/praisonaiui/integrations/llama_index.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,9 @@ def on_query_end(self, response: Any, **kwargs: Any) -> None:
148148
"""Handle query end."""
149149
event_id = kwargs.get("event_id")
150150
payload = {"response": str(response) if response else None}
151-
self.on_event_end("query", payload, event_id, **kwargs)
151+
# Remove event_id from kwargs to avoid duplication
152+
kwargs_without_event_id = {k: v for k, v in kwargs.items() if k != "event_id"}
153+
self.on_event_end("query", payload, event_id, **kwargs_without_event_id)
152154

153155
def on_retrieve_start(self, query: str, **kwargs: Any) -> str:
154156
"""Handle retrieval start."""
@@ -161,7 +163,9 @@ def on_retrieve_end(self, nodes: List[Any], **kwargs: Any) -> None:
161163
"num_nodes": len(nodes) if nodes else 0,
162164
"nodes": [str(node) for node in (nodes or [])[:3]] # First 3 for brevity
163165
}
164-
self.on_event_end("retrieve", payload, event_id, **kwargs)
166+
# Remove event_id from kwargs to avoid duplication
167+
kwargs_without_event_id = {k: v for k, v in kwargs.items() if k != "event_id"}
168+
self.on_event_end("retrieve", payload, event_id, **kwargs_without_event_id)
165169

166170
def on_llm_start(self, messages: List[Any], **kwargs: Any) -> str:
167171
"""Handle LLM start."""
@@ -185,7 +189,9 @@ def on_llm_end(self, response: Any, **kwargs: Any) -> None:
185189
"""Handle LLM end."""
186190
event_id = kwargs.get("event_id")
187191
payload = {"response": str(response) if response else None}
188-
self.on_event_end("llm", payload, event_id, **kwargs)
192+
# Remove event_id from kwargs to avoid duplication
193+
kwargs_without_event_id = {k: v for k, v in kwargs.items() if k != "event_id"}
194+
self.on_event_end("llm", payload, event_id, **kwargs_without_event_id)
189195

190196
async def _start_step(self, step: Step, payload: Optional[Dict[str, Any]]) -> None:
191197
"""Start a step and optionally stream initial content."""

tests/unit/integrations/test_langchain.py

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ def mock_step(self):
2828

2929
def test_init(self, handler):
3030
"""Test handler initialization."""
31-
assert handler._step_stack == []
3231
assert handler._run_id_to_step == {}
32+
assert hasattr(handler, '_lock')
3333

3434
@patch('praisonaiui.integrations.langchain.Step')
3535
@patch('asyncio.create_task')
@@ -52,7 +52,6 @@ def test_on_chain_start(self, mock_create_task, mock_step_class, handler, mock_s
5252
)
5353

5454
# Verify step is tracked
55-
assert mock_step in handler._step_stack
5655
assert handler._run_id_to_step[run_id] == mock_step
5756

5857
@patch('praisonaiui.integrations.langchain.Step')
@@ -83,7 +82,7 @@ def test_on_chain_end(self, mock_create_task, mock_get_loop, handler, mock_step)
8382

8483
run_id = "run_123"
8584
handler._run_id_to_step[run_id] = mock_step
86-
handler._step_stack.append(mock_step)
85+
handler._run_id_to_step[run_id] = mock_step
8786

8887
outputs = {"output": "test response"}
8988
handler.on_chain_end(outputs, run_id=run_id)
@@ -92,7 +91,7 @@ def test_on_chain_end(self, mock_create_task, mock_get_loop, handler, mock_step)
9291
mock_create_task.assert_called_once()
9392

9493
# Verify cleanup
95-
assert mock_step not in handler._step_stack
94+
assert run_id not in handler._run_id_to_step
9695
assert run_id not in handler._run_id_to_step
9796

9897
@patch('asyncio.get_running_loop')
@@ -104,7 +103,7 @@ def test_on_chain_error(self, mock_create_task, mock_get_loop, handler, mock_ste
104103

105104
run_id = "run_123"
106105
handler._run_id_to_step[run_id] = mock_step
107-
handler._step_stack.append(mock_step)
106+
handler._run_id_to_step[run_id] = mock_step
108107

109108
error = Exception("Test error")
110109
handler.on_chain_error(error, run_id=run_id)
@@ -113,7 +112,7 @@ def test_on_chain_error(self, mock_create_task, mock_get_loop, handler, mock_ste
113112
mock_create_task.assert_called_once()
114113

115114
# Verify cleanup
116-
assert mock_step not in handler._step_stack
115+
assert run_id not in handler._run_id_to_step
117116
assert run_id not in handler._run_id_to_step
118117

119118
@patch('praisonaiui.integrations.langchain.Step')
@@ -135,7 +134,7 @@ def test_on_llm_start(self, mock_create_task, mock_step_class, handler, mock_ste
135134
metadata={"prompts": prompts, "serialized": serialized}
136135
)
137136

138-
assert mock_step in handler._step_stack
137+
# Verify step is tracked
139138
assert handler._run_id_to_step[run_id] == mock_step
140139

141140
@patch('asyncio.get_running_loop')
@@ -178,13 +177,13 @@ def test_on_tool_end(self, mock_create_task, handler, mock_step):
178177
"""Test tool end event."""
179178
run_id = "run_123"
180179
handler._run_id_to_step[run_id] = mock_step
181-
handler._step_stack.append(mock_step)
180+
handler._run_id_to_step[run_id] = mock_step
182181

183182
output = "Found 10 results"
184183
handler.on_tool_end(output, run_id=run_id)
185184

186185
# Verify cleanup happened
187-
assert mock_step not in handler._step_stack
186+
assert run_id not in handler._run_id_to_step
188187
assert run_id not in handler._run_id_to_step
189188

190189
@patch('praisonaiui.integrations.langchain.Step')
@@ -215,13 +214,13 @@ def test_no_event_loop_handling(self, mock_create_task, handler, mock_step):
215214

216215
run_id = "run_123"
217216
handler._run_id_to_step[run_id] = mock_step
218-
handler._step_stack.append(mock_step)
217+
handler._run_id_to_step[run_id] = mock_step
219218

220219
# Should not raise exception
221220
handler.on_chain_end({}, run_id=run_id)
222221

223222
# Cleanup should still happen
224-
assert mock_step not in handler._step_stack
223+
assert run_id not in handler._run_id_to_step
225224
assert run_id not in handler._run_id_to_step
226225

227226
def test_missing_run_id(self, handler):
@@ -232,7 +231,7 @@ def test_missing_run_id(self, handler):
232231
handler.on_llm_start({}, [])
233232
handler.on_tool_start({}, "")
234233

235-
assert len(handler._step_stack) == 0
234+
assert len(handler._run_id_to_step) == 0
236235
assert len(handler._run_id_to_step) == 0
237236

238237
def test_nested_steps(self, handler):
@@ -249,8 +248,8 @@ def test_nested_steps(self, handler):
249248
# Start parent chain
250249
handler.on_chain_start({"name": "parent"}, {}, run_id="parent_run")
251250

252-
# Start nested LLM call
253-
handler.on_llm_start({"name": "OpenAI"}, ["test"], run_id="child_run")
251+
# Start nested LLM call with parent_run_id
252+
handler.on_llm_start({"name": "OpenAI"}, ["test"], run_id="child_run", parent_run_id="parent_run")
254253

255254
# Verify child step was created with parent
256255
assert mock_step_class.call_count == 2
@@ -278,8 +277,8 @@ def mock_step(self):
278277

279278
def test_init(self, handler):
280279
"""Test async handler initialization."""
281-
assert handler._step_stack == []
282280
assert handler._run_id_to_step == {}
281+
assert hasattr(handler, '_lock')
283282

284283
@pytest.mark.asyncio
285284
@patch('praisonaiui.integrations.langchain.Step')
@@ -305,15 +304,15 @@ async def test_on_chain_start(self, mock_step_class, handler, mock_step):
305304
mock_step.__aenter__.assert_called_once()
306305

307306
# Verify tracking
308-
assert mock_step in handler._step_stack
307+
# Verify step is tracked
309308
assert handler._run_id_to_step[run_id] == mock_step
310309

311310
@pytest.mark.asyncio
312311
async def test_on_chain_end(self, handler, mock_step):
313312
"""Test async chain end event."""
314313
run_id = "async_run_123"
315314
handler._run_id_to_step[run_id] = mock_step
316-
handler._step_stack.append(mock_step)
315+
handler._run_id_to_step[run_id] = mock_step
317316

318317
outputs = {"output": "async response"}
319318
await handler.on_chain_end(outputs, run_id=run_id)
@@ -322,15 +321,15 @@ async def test_on_chain_end(self, handler, mock_step):
322321
mock_step.__aexit__.assert_called_once_with(None, None, None)
323322

324323
# Verify cleanup
325-
assert mock_step not in handler._step_stack
324+
assert run_id not in handler._run_id_to_step
326325
assert run_id not in handler._run_id_to_step
327326

328327
@pytest.mark.asyncio
329328
async def test_on_chain_error(self, handler, mock_step):
330329
"""Test async chain error event."""
331330
run_id = "async_run_123"
332331
handler._run_id_to_step[run_id] = mock_step
333-
handler._step_stack.append(mock_step)
332+
handler._run_id_to_step[run_id] = mock_step
334333

335334
error = ValueError("Async test error")
336335
await handler.on_chain_error(error, run_id=run_id)
@@ -339,7 +338,7 @@ async def test_on_chain_error(self, handler, mock_step):
339338
mock_step.__aexit__.assert_called_once_with(ValueError, error, None)
340339

341340
# Verify cleanup
342-
assert mock_step not in handler._step_stack
341+
assert run_id not in handler._run_id_to_step
343342
assert run_id not in handler._run_id_to_step
344343

345344
@pytest.mark.asyncio
@@ -435,5 +434,5 @@ async def test_unknown_run_id_ignored(self, handler):
435434
await handler.on_llm_new_token("token", run_id="unknown")
436435
await handler.on_tool_end("output", run_id="unknown")
437436

438-
assert len(handler._step_stack) == 0
437+
assert len(handler._run_id_to_step) == 0
439438
assert len(handler._run_id_to_step) == 0

tests/unit/integrations/test_llama_index.py

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ def mock_step(self):
2828

2929
def test_init(self, handler):
3030
"""Test handler initialization."""
31-
assert handler._step_stack == []
3231
assert handler._event_id_to_step == {}
32+
assert handler._parent_map == {}
3333

3434
@patch('praisonaiui.integrations.llama_index.Step')
3535
@patch('asyncio.create_task')
@@ -51,7 +51,6 @@ def test_on_event_start_query(self, mock_create_task, mock_step_class, handler,
5151
)
5252

5353
# Verify tracking
54-
assert mock_step in handler._step_stack
5554
assert handler._event_id_to_step[event_id] == mock_step
5655

5756
@patch('praisonaiui.integrations.llama_index.Step')
@@ -162,12 +161,15 @@ def test_on_event_start_unknown(self, mock_create_task, mock_step_class, handler
162161
metadata={"event_type": event_type, "payload": payload}
163162
)
164163

164+
@patch('asyncio.get_running_loop')
165165
@patch('asyncio.create_task')
166-
def test_on_event_end(self, mock_create_task, handler, mock_step):
166+
def test_on_event_end(self, mock_create_task, mock_get_loop, handler, mock_step):
167167
"""Test event end handling."""
168+
# Mock that there's a running event loop
169+
mock_get_loop.return_value = MagicMock()
170+
168171
event_id = "test_event_123"
169172
handler._event_id_to_step[event_id] = mock_step
170-
handler._step_stack.append(mock_step)
171173

172174
payload = {"response": "Test response"}
173175
handler.on_event_end("query", payload, event_id)
@@ -176,15 +178,18 @@ def test_on_event_end(self, mock_create_task, handler, mock_step):
176178
mock_create_task.assert_called_once()
177179

178180
# Verify cleanup
179-
assert mock_step not in handler._step_stack
181+
# Step should be removed from tracking maps
180182
assert event_id not in handler._event_id_to_step
181183

184+
@patch('asyncio.get_running_loop')
182185
@patch('asyncio.create_task')
183-
def test_on_event_error(self, mock_create_task, handler, mock_step):
186+
def test_on_event_error(self, mock_create_task, mock_get_loop, handler, mock_step):
184187
"""Test event error handling."""
188+
# Mock that there's a running event loop
189+
mock_get_loop.return_value = MagicMock()
190+
185191
event_id = "test_event_123"
186192
handler._event_id_to_step[event_id] = mock_step
187-
handler._step_stack.append(mock_step)
188193

189194
exception = ValueError("Test error")
190195
handler.on_event_error("query", exception, event_id)
@@ -193,7 +198,7 @@ def test_on_event_error(self, mock_create_task, handler, mock_step):
193198
mock_create_task.assert_called_once()
194199

195200
# Verify cleanup
196-
assert mock_step not in handler._step_stack
201+
# Step should be removed from tracking maps
197202
assert event_id not in handler._event_id_to_step
198203

199204
def test_start_trace_end_trace(self, handler):
@@ -226,15 +231,14 @@ def test_on_query_end(self, mock_create_task, handler, mock_step):
226231
"""Test query end convenience method."""
227232
event_id = "query_123"
228233
handler._event_id_to_step[event_id] = mock_step
229-
handler._step_stack.append(mock_step)
230234

231235
response = MagicMock()
232236
response.__str__ = MagicMock(return_value="AI is a field of computer science")
233237

234238
handler.on_query_end(response, event_id=event_id)
235239

236240
# Verify cleanup
237-
assert mock_step not in handler._step_stack
241+
# Step should be removed from tracking maps
238242
assert event_id not in handler._event_id_to_step
239243

240244
@patch('praisonaiui.integrations.llama_index.Step')
@@ -258,7 +262,6 @@ def test_on_retrieve_end(self, mock_create_task, handler, mock_step):
258262
"""Test retrieval end convenience method."""
259263
event_id = "retrieve_123"
260264
handler._event_id_to_step[event_id] = mock_step
261-
handler._step_stack.append(mock_step)
262265

263266
# Mock nodes
264267
node1 = MagicMock()
@@ -270,7 +273,7 @@ def test_on_retrieve_end(self, mock_create_task, handler, mock_step):
270273
handler.on_retrieve_end(nodes, event_id=event_id)
271274

272275
# Verify cleanup
273-
assert mock_step not in handler._step_stack
276+
# Step should be removed from tracking maps
274277
assert event_id not in handler._event_id_to_step
275278

276279
@patch('praisonaiui.integrations.llama_index.Step')
@@ -289,9 +292,13 @@ def test_on_llm_start(self, mock_create_task, mock_step_class, handler, mock_ste
289292
metadata={"event_type": "llm", "payload": {"messages": messages}}
290293
)
291294

295+
@patch('asyncio.get_running_loop')
292296
@patch('asyncio.create_task')
293-
def test_on_llm_new_token(self, mock_create_task, handler, mock_step):
297+
def test_on_llm_new_token(self, mock_create_task, mock_get_loop, handler, mock_step):
294298
"""Test LLM token streaming."""
299+
# Mock that there's a running event loop
300+
mock_get_loop.return_value = MagicMock()
301+
295302
event_id = "llm_123"
296303
handler._event_id_to_step[event_id] = mock_step
297304

@@ -306,13 +313,12 @@ def test_on_llm_end(self, mock_create_task, handler, mock_step):
306313
"""Test LLM end convenience method."""
307314
event_id = "llm_123"
308315
handler._event_id_to_step[event_id] = mock_step
309-
handler._step_stack.append(mock_step)
310316

311317
response = "Python is a programming language"
312318
handler.on_llm_end(response, event_id=event_id)
313319

314320
# Verify cleanup
315-
assert mock_step not in handler._step_stack
321+
# Step should be removed from tracking maps
316322
assert event_id not in handler._event_id_to_step
317323

318324
def test_no_event_loop_handling(self, handler, mock_step):
@@ -323,13 +329,11 @@ def test_no_event_loop_handling(self, handler, mock_step):
323329

324330
event_id = "test_event_123"
325331
handler._event_id_to_step[event_id] = mock_step
326-
handler._step_stack.append(mock_step)
327332

328333
# Should not raise exception
329334
handler.on_event_end("query", {}, event_id)
330335

331336
# Cleanup should still happen
332-
assert mock_step not in handler._step_stack
333337
assert event_id not in handler._event_id_to_step
334338

335339
def test_missing_event_id(self, handler):
@@ -339,7 +343,7 @@ def test_missing_event_id(self, handler):
339343
handler.on_event_error("query", Exception("test"), event_id=None)
340344
handler.on_llm_new_token("token", event_id=None)
341345

342-
assert len(handler._step_stack) == 0
346+
assert len(handler._event_id_to_step) == 0
343347
assert len(handler._event_id_to_step) == 0
344348

345349
def test_unknown_event_id(self, handler):
@@ -349,7 +353,7 @@ def test_unknown_event_id(self, handler):
349353
handler.on_event_error("query", Exception("test"), event_id="unknown_123")
350354
handler.on_llm_new_token("token", event_id="unknown_123")
351355

352-
assert len(handler._step_stack) == 0
356+
assert len(handler._event_id_to_step) == 0
353357
assert len(handler._event_id_to_step) == 0
354358

355359
@patch('praisonaiui.integrations.llama_index.Step')
@@ -365,8 +369,8 @@ def test_nested_events(self, mock_create_task, mock_step_class, handler):
365369
# Start parent query
366370
parent_id = handler.on_event_start("query", {"query": "test"})
367371

368-
# Start nested retrieval
369-
child_id = handler.on_event_start("retrieve", {"query": "test"})
372+
# Start nested retrieval with parent_id
373+
child_id = handler.on_event_start("retrieve", {"query": "test"}, parent_id=parent_id)
370374

371375
# Verify child step was created with parent
372376
assert mock_step_class.call_count == 2

0 commit comments

Comments
 (0)