forked from volcengine/OpenViking
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_session_task_tracking.py
More file actions
353 lines (248 loc) · 11.6 KB
/
test_session_task_tracking.py
File metadata and controls
353 lines (248 loc) · 11.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd.
# SPDX-License-Identifier: Apache-2.0
"""Integration tests for session commit task tracking via HTTP API."""
import asyncio
from typing import AsyncGenerator, Tuple
import httpx
import pytest_asyncio
from openviking import AsyncOpenViking
from openviking.server.app import create_app
from openviking.server.config import ServerConfig
from openviking.server.dependencies import set_service
from openviking.service.core import OpenVikingService
from openviking.service.task_tracker import reset_task_tracker
@pytest_asyncio.fixture
async def api_client(temp_dir) -> AsyncGenerator[Tuple[httpx.AsyncClient, OpenVikingService], None]:
    """Create an in-process HTTP client wired to a fresh ``OpenVikingService``.

    Yields:
        ``(client, service)``: an ``httpx.AsyncClient`` that talks to the app
        through ``httpx.ASGITransport`` (in-process, no real server), and the
        service instance backing it so tests can monkeypatch its internals.

    The global task tracker is reset before and after each test so task
    state does not leak between tests.
    """
    reset_task_tracker()
    service = OpenVikingService(path=str(temp_dir / "api_data"))
    await service.initialize()
    try:
        app = create_app(config=ServerConfig(), service=service)
        set_service(service)
        transport = httpx.ASGITransport(app=app)
        async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client:
            yield client, service
    finally:
        # Run every cleanup step even if an earlier one raises, so a failing
        # close() cannot leave AsyncOpenViking or the task tracker dirty for
        # the next test.
        try:
            await service.close()
        finally:
            try:
                await AsyncOpenViking.reset()
            finally:
                reset_task_tracker()
async def _new_session_with_message(client: httpx.AsyncClient) -> str:
    """Create a session and add a single user message to it.

    Args:
        client: HTTP client bound to the test app.

    Returns:
        The id of the newly created session.
    """
    resp = await client.post("/api/v1/sessions", json={})
    assert resp.status_code == 200
    session_id = resp.json()["result"]["session_id"]
    msg_resp = await client.post(
        f"/api/v1/sessions/{session_id}/messages",
        json={"role": "user", "content": "hello world"},
    )
    # Fail fast here; otherwise a broken message endpoint would only show up
    # later as a confusing commit-related assertion failure.
    assert msg_resp.is_success, msg_resp.text
    return session_id
# ── wait=false returns task_id ──
async def test_commit_wait_false_returns_task_id(api_client):
    """A non-blocking commit (wait=false) answers right away with a pollable task_id."""
    client, service = api_client
    sid = await _new_session_with_message(client)
    commit_finished = asyncio.Event()

    async def fake_commit(_sid, _ctx):
        # Simulate a short background commit, then signal that it finished.
        await asyncio.sleep(0.1)
        commit_finished.set()
        return {"session_id": _sid, "status": "committed", "memories_extracted": 0}

    service.sessions.commit_async = fake_commit

    response = await client.post(f"/api/v1/sessions/{sid}/commit", params={"wait": False})
    assert response.status_code == 200
    payload = response.json()
    accepted = payload["result"]
    assert accepted["status"] == "accepted"
    assert "task_id" in accepted

    # Let the background commit run to completion before fixture teardown.
    await asyncio.wait_for(commit_finished.wait(), timeout=2.0)
async def test_commit_wait_false_rejects_full_telemetry(api_client):
    """Requesting a full telemetry payload together with wait=false is an invalid argument."""
    client, _service = api_client
    sid = await _new_session_with_message(client)

    response = await client.post(
        f"/api/v1/sessions/{sid}/commit",
        params={"wait": False},
        json={"telemetry": True},
    )

    assert response.status_code == 400
    payload = response.json()
    assert payload["status"] == "error"
    err = payload["error"]
    assert err["code"] == "INVALID_ARGUMENT"
    assert "wait=false" in err["message"]
async def test_commit_wait_false_rejects_summary_only_telemetry(api_client):
    """Even a summary-only telemetry request is refused when combined with wait=false."""
    client, _service = api_client
    sid = await _new_session_with_message(client)

    response = await client.post(
        f"/api/v1/sessions/{sid}/commit",
        params={"wait": False},
        json={"telemetry": {"summary": True}},
    )

    assert response.status_code == 400
    payload = response.json()
    assert payload["status"] == "error"
    err = payload["error"]
    assert err["code"] == "INVALID_ARGUMENT"
    assert "wait=false" in err["message"]
# ── Task lifecycle: pending → running → completed ──
async def test_task_lifecycle_success(api_client):
    """A successful background commit is reported as running, then completed.

    The fake commit is held on a gate so the intermediate ``running`` state can
    be observed deterministically. Completion is then awaited by polling the
    task endpoint with a timeout rather than a fixed sleep.
    """
    client, service = api_client
    session_id = await _new_session_with_message(client)
    commit_started = asyncio.Event()
    commit_gate = asyncio.Event()

    async def gated_commit(_sid, _ctx):
        commit_started.set()
        await commit_gate.wait()
        return {"session_id": _sid, "status": "committed", "memories_extracted": 5}

    service.sessions.commit_async = gated_commit

    async def wait_until_terminal(task_id):
        # Poll until the tracker reports a terminal state.
        while True:
            task_resp = await client.get(f"/api/v1/tasks/{task_id}")
            assert task_resp.status_code == 200
            task = task_resp.json()["result"]
            if task["status"] in {"completed", "failed"}:
                return task
            await asyncio.sleep(0.02)

    try:
        # Fire background commit
        resp = await client.post(f"/api/v1/sessions/{session_id}/commit", params={"wait": False})
        assert resp.status_code == 200
        task_id = resp.json()["result"]["task_id"]

        # Wait for the commit to start; the task must now be running.
        await asyncio.wait_for(commit_started.wait(), timeout=2.0)
        task_resp = await client.get(f"/api/v1/tasks/{task_id}")
        assert task_resp.status_code == 200
        assert task_resp.json()["result"]["status"] == "running"
    finally:
        # Always release the commit; if an assertion above failed, the
        # background commit would otherwise stay blocked during teardown.
        commit_gate.set()

    result = await asyncio.wait_for(wait_until_terminal(task_id), timeout=5.0)
    assert result["status"] == "completed"
    assert result["result"]["memories_extracted"] == 5
# ── Task lifecycle: pending → running → failed ──
async def test_task_lifecycle_failure(api_client):
    """A commit that raises is reported as a failed task carrying the error text."""
    client, service = api_client
    session_id = await _new_session_with_message(client)

    async def failing_commit(_sid, _ctx):
        raise RuntimeError("LLM provider timeout")

    service.sessions.commit_async = failing_commit

    resp = await client.post(f"/api/v1/sessions/{session_id}/commit", params={"wait": False})
    assert resp.status_code == 200
    task_id = resp.json()["result"]["task_id"]

    async def wait_until_terminal():
        # Poll instead of sleeping a fixed time, which races the background task.
        while True:
            task_resp = await client.get(f"/api/v1/tasks/{task_id}")
            assert task_resp.status_code == 200
            task = task_resp.json()["result"]
            if task["status"] in {"completed", "failed"}:
                return task
            await asyncio.sleep(0.02)

    result = await asyncio.wait_for(wait_until_terminal(), timeout=5.0)
    assert result["status"] == "failed"
    assert "LLM provider timeout" in result["error"]
async def test_task_failed_when_memory_extraction_raises(api_client):
    """Extractor failures should propagate to task error instead of silent completed+0."""
    client, service = api_client
    session_id = await _new_session_with_message(client)

    async def failing_extract(_context, _user, _session_id):
        raise RuntimeError("memory_extraction_failed: synthetic extractor error")

    # Patch the extractor rather than commit_async, so the unmocked commit
    # path runs and must surface the extractor's error itself.
    service.sessions._session_compressor.extractor.extract = failing_extract

    resp = await client.post(f"/api/v1/sessions/{session_id}/commit", params={"wait": False})
    assert resp.status_code == 200
    task_id = resp.json()["result"]["task_id"]

    async def wait_until_terminal():
        while True:
            task_resp = await client.get(f"/api/v1/tasks/{task_id}")
            assert task_resp.status_code == 200
            task = task_resp.json()["result"]
            if task["status"] in {"completed", "failed"}:
                return task
            await asyncio.sleep(0.1)

    # Larger budget than the fully mocked tests because the real commit path runs here.
    result = await asyncio.wait_for(wait_until_terminal(), timeout=12.0)
    assert result["status"] == "failed", f"expected a failed task, got {result!r}"
    assert "memory_extraction_failed" in result["error"]
# ── Duplicate commit rejection ──
async def test_duplicate_commit_rejected(api_client):
    """Second commit on same session should be rejected while first is running."""
    client, service = api_client
    session_id = await _new_session_with_message(client)
    gate = asyncio.Event()

    async def slow_commit(_sid, _ctx):
        await gate.wait()
        return {"session_id": _sid, "status": "committed", "memories_extracted": 0}

    service.sessions.commit_async = slow_commit

    try:
        # First commit is accepted and blocks on the gate.
        resp1 = await client.post(f"/api/v1/sessions/{session_id}/commit", params={"wait": False})
        assert resp1.status_code == 200
        first = resp1.json()["result"]
        assert first["status"] == "accepted"

        # Second commit should be rejected while the first is still running.
        resp2 = await client.post(f"/api/v1/sessions/{session_id}/commit", params={"wait": False})
        body2 = resp2.json()
        assert body2["status"] == "error"
        assert "already has a commit in progress" in body2["error"]["message"]
    finally:
        # Release the first commit even when an assertion failed above.
        gate.set()

    async def wait_until_terminal(task_id):
        while True:
            task_resp = await client.get(f"/api/v1/tasks/{task_id}")
            if task_resp.json()["result"]["status"] in {"completed", "failed"}:
                return
            await asyncio.sleep(0.02)

    # Drain the background commit deterministically before fixture teardown.
    await asyncio.wait_for(wait_until_terminal(first["task_id"]), timeout=5.0)
async def test_wait_true_rejected_while_background_commit_running(api_client):
    """wait=true must also reject duplicate commits for the same session."""
    client, service = api_client
    session_id = await _new_session_with_message(client)
    gate = asyncio.Event()

    async def slow_commit(_sid, _ctx):
        await gate.wait()
        return {"session_id": _sid, "status": "committed", "memories_extracted": 0}

    service.sessions.commit_async = slow_commit

    try:
        resp1 = await client.post(f"/api/v1/sessions/{session_id}/commit", params={"wait": False})
        assert resp1.status_code == 200
        first = resp1.json()["result"]
        assert first["status"] == "accepted"

        resp2 = await client.post(
            f"/api/v1/sessions/{session_id}/commit",
            params={"wait": True},
            json={"telemetry": True},
        )
        # The rejection is reported in the response envelope, not via HTTP status.
        assert resp2.status_code == 200
        body2 = resp2.json()
        assert body2["status"] == "error"
        assert "already has a commit in progress" in body2["error"]["message"]
    finally:
        # Release the background commit even when an assertion failed above.
        gate.set()

    async def wait_until_terminal(task_id):
        while True:
            task_resp = await client.get(f"/api/v1/tasks/{task_id}")
            if task_resp.json()["result"]["status"] in {"completed", "failed"}:
                return
            await asyncio.sleep(0.02)

    # Drain the background commit deterministically before fixture teardown.
    await asyncio.wait_for(wait_until_terminal(first["task_id"]), timeout=5.0)
# ── GET /tasks/{id} 404 ──
async def test_get_nonexistent_task_returns_404(api_client):
    """Looking up an unknown task id yields HTTP 404."""
    client, _service = api_client
    missing = await client.get("/api/v1/tasks/nonexistent-id")
    assert missing.status_code == 404
# ── GET /tasks list ──
async def test_list_tasks(api_client):
    """Listing tasks filtered by task_type returns session_commit entries."""
    client, service = api_client
    sid = await _new_session_with_message(client)

    async def instant_commit(_sid, _ctx):
        return {"session_id": _sid, "status": "committed", "memories_extracted": 0}

    service.sessions.commit_async = instant_commit
    await client.post(f"/api/v1/sessions/{sid}/commit", params={"wait": False})
    await asyncio.sleep(0.1)

    listing = await client.get("/api/v1/tasks", params={"task_type": "session_commit"})
    assert listing.status_code == 200
    entries = listing.json()["result"]
    assert len(entries) >= 1
    head = entries[0]
    assert head["task_type"] == "session_commit"
async def test_list_tasks_filter_status(api_client):
    """Filtering by status=completed returns a non-empty list of completed tasks only."""
    client, service = api_client

    async def instant_commit(_sid, _ctx):
        return {"session_id": _sid, "status": "committed", "memories_extracted": 0}

    service.sessions.commit_async = instant_commit
    session_id = await _new_session_with_message(client)
    resp = await client.post(f"/api/v1/sessions/{session_id}/commit", params={"wait": False})
    assert resp.status_code == 200
    task_id = resp.json()["result"]["task_id"]

    async def wait_until_terminal():
        while True:
            task_resp = await client.get(f"/api/v1/tasks/{task_id}")
            assert task_resp.status_code == 200
            task = task_resp.json()["result"]
            if task["status"] in {"completed", "failed"}:
                return task
            await asyncio.sleep(0.02)

    # Make sure at least our task has completed before querying the filter.
    task = await asyncio.wait_for(wait_until_terminal(), timeout=5.0)
    assert task["status"] == "completed"

    # completed tasks
    resp = await client.get("/api/v1/tasks", params={"status": "completed"})
    assert resp.status_code == 200
    completed = resp.json()["result"]
    # Without this, the loop below would pass vacuously on an empty result.
    assert completed, "expected at least one completed task"
    for t in completed:
        assert t["status"] == "completed"
# ── wait=true still works (backward compat) ──
async def test_wait_true_still_works(api_client):
    """A blocking commit (wait=true) returns the commit result inline, with no task_id."""
    client, service = api_client
    sid = await _new_session_with_message(client)

    async def instant_commit(_sid, _ctx):
        return {"session_id": _sid, "status": "committed", "memories_extracted": 2}

    service.sessions.commit_async = instant_commit

    response = await client.post(f"/api/v1/sessions/{sid}/commit", params={"wait": True})
    assert response.status_code == 200
    inline = response.json()["result"]
    assert inline["status"] == "committed"
    assert "task_id" not in inline
# ── Error sanitization in task ──
async def test_error_sanitized_in_task(api_client):
    """Errors stored in tasks should have secrets redacted."""
    client, service = api_client
    session_id = await _new_session_with_message(client)

    async def leaky_commit(_sid, _ctx):
        raise RuntimeError("Auth failed with key sk-ant-api03-DAqSsuperSecretKey123")

    service.sessions.commit_async = leaky_commit

    resp = await client.post(f"/api/v1/sessions/{session_id}/commit", params={"wait": False})
    assert resp.status_code == 200
    task_id = resp.json()["result"]["task_id"]

    async def wait_until_terminal():
        while True:
            task_resp = await client.get(f"/api/v1/tasks/{task_id}")
            assert task_resp.status_code == 200
            task = task_resp.json()["result"]
            if task["status"] in {"completed", "failed"}:
                return task
            await asyncio.sleep(0.02)

    result = await asyncio.wait_for(wait_until_terminal(), timeout=5.0)
    # Check the state first: reading "error" before the task fails would
    # give a non-string and a confusing TypeError instead of a clear failure.
    assert result["status"] == "failed"
    error = result["error"]
    assert "superSecretKey" not in error
    assert "[REDACTED]" in error