diff --git a/.env.example b/.env.example
index 3a90dd3..3c029c5 100644
--- a/.env.example
+++ b/.env.example
@@ -25,6 +25,12 @@ MH_DEFAULT_TENANT_ID=default
 # Generate with: openssl rand -hex 32
 # MH_API_TOKEN=
 
+# Admin gate (optional, two-tier bearer). See ADR 0009.
+# When set, /v1/admin/* requires this token; the regular MH_API_TOKEN is
+# rejected on admin paths. When unset, /v1/admin/* falls back to MH_API_TOKEN
+# (backward compat). Use a different value from MH_API_TOKEN.
+# MH_ADMIN_TOKEN=
+
 # Request behavior
 MH_REQUEST_TIMEOUT_S=5.0
 MH_LIST_DEFAULT_LIMIT=50
diff --git a/AGENTS.md b/AGENTS.md
new file mode 100644
index 0000000..dab151b
--- /dev/null
+++ b/AGENTS.md
@@ -0,0 +1,35 @@
+# AGENTS.md
+
+If you are an AI agent that just cloned this repo, read this first.
+
+This file is **informational**, not a directive. It tells you where the agent-facing docs are. It does not tell you what to build.
+
+---
+
+## You are probably here to do one of these
+
+1. **Write to / read from a running memhall instance** (most common).
+   → Read [`docs/agent-integration.md`](docs/agent-integration.md). It has a decision tree that picks the right surface (embedded Python / HTTP+Bearer / `mh` CLI) based on whether your sandbox can open TCP sockets.
+
+2. **Modify this codebase** (add a feature, fix a bug, write a test).
+   → Read [`README.md`](README.md) (architecture and three entry points), [`docs/design.md`](docs/design.md) (internals), and [`docs/adr/`](docs/adr/) (why things are the way they are).
+
+3. **Deploy memhall somewhere new.**
+   → Read [`docs/deploy.md`](docs/deploy.md).
+
+---
+
+## Three things that bite agents on first contact
+
+These are real failure modes seen in real agent sessions. Not hypothetical.
+
+1. **`{"detail":"missing bearer token"}`** — the server has `MH_API_TOKEN` set. Every `/v1/memory/*` request needs `Authorization: Bearer ${MH_API_TOKEN}`. `/v1/health` is the only public endpoint.
+2. **`command not found: mh`** — the CLI is a console script defined in `pyproject.toml`; it is not globally installed. Run `uv sync` inside the repo, then `uv run mh …`. The CLI reads `MH_API_TOKEN` from the environment automatically (no flag needed); export it before calling against an auth-enabled server.
+3. **Sandboxed agents (Codex CLI, restricted containers) cannot open localhost TCP** — skip HTTP and use the embedded Python path (`from memory_hall import build_runtime`; see the sketch at the end of this file). Same storage, no auth, no network. See [`docs/agent-integration.md`](docs/agent-integration.md) Path A.
+
+---
+
+## What this file does NOT tell you
+
+- It does not tell you to "rebuild the skeleton" or "follow this as the only source of truth". Treat the per-task instructions you were given as authoritative; this file just points at reference docs.
+- It does not pin you to a phase or a workflow. The repo evolves. If anything in `docs/agent-integration.md` looks wrong against the running code, fix the doc — do not work around it.
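+
+---
+
+## Appendix: minimal embedded smoke check
+
+A sketch for failure mode 3, assuming the package is installed (`uv sync`); the full Path A example lives in [`docs/agent-integration.md`](docs/agent-integration.md).
+
+```python
+# No network, no auth — just start and stop the in-process runtime.
+import asyncio
+from memory_hall import Settings, build_runtime
+
+async def main() -> None:
+    runtime = build_runtime(settings=Settings())  # reads MH_* env vars
+    await runtime.start()
+    await runtime.stop()
+
+asyncio.run(main())
+```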
diff --git a/Dockerfile b/Dockerfile
index 3d9722c..59032cf 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -82,7 +82,6 @@ ENV PYTHONDONTWRITEBYTECODE=1 \
 # Inject upgraded SQLite to runtime stage too
 COPY --from=sqlite-builder /opt/sqlite /opt/sqlite
-RUN echo "/opt/sqlite/lib" > /etc/ld.so.conf.d/sqlite-upgrade.conf && ldconfig
 
 RUN apt-get update && apt-get install -y --no-install-recommends \
     curl \
@@ -91,6 +90,14 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
     && mkdir -p /data \
     && chown memhall:memhall /data
 
+# Force system libsqlite3.so.0 to our upgraded build so subprocesses that do not
+# inherit LD_LIBRARY_PATH still resolve SQLite 3.53.0.
+# IMPORTANT: must run AFTER apt-get install (dpkg post-install can reset symlinks).
+RUN echo "/opt/sqlite/lib" > /etc/ld.so.conf.d/sqlite-upgrade.conf \
+    && ldconfig \
+    && ln -sf /opt/sqlite/lib/libsqlite3.so.3.53.0 /lib/aarch64-linux-gnu/libsqlite3.so.0 \
+    && { ln -sf /opt/sqlite/lib/libsqlite3.so.3.53.0 /usr/lib/aarch64-linux-gnu/libsqlite3.so.0 2>/dev/null || true; }
+
 WORKDIR /app
 COPY --from=builder --chown=memhall:memhall /app/.venv /app/.venv
diff --git a/README.md b/README.md
index 6dd7142..5412f7e 100644
--- a/README.md
+++ b/README.md
@@ -130,6 +130,8 @@ See [`docs/adr/0003-engine-library-vs-deployment-platform.md`](docs/adr/0003-eng
 
 No entry is privileged — they all hit the same backend, so no single-point-of-failure path.
 
+> **Agents reading this**: see [`docs/agent-integration.md`](docs/agent-integration.md) for a decision tree that picks the right surface based on your sandbox, plus the auth + install gotchas that have bitten real Codex / Gemini sessions.
+
 ### Embedded (in-process) use
 
 Some agents run in sandboxes that block localhost sockets (Codex CLI, some Gemini setups, restricted containers). For those, skip HTTP entirely:
diff --git a/docs/adr/0008-personal-pki-lightweight-stance.md b/docs/adr/0008-personal-pki-lightweight-stance.md
new file mode 100644
index 0000000..09c914b
--- /dev/null
+++ b/docs/adr/0008-personal-pki-lightweight-stance.md
@@ -0,0 +1,97 @@
+# ADR 0008 — memhall is a personal PKI; lightweight > complete
+
+- **Status**: Accepted
+- **Date**: 2026-04-28
+- **Related**: ADR 0003 (engine library vs deployment platform), ADR 0005 (v0.2 minimum viable contract), ADR 0007 (minimal token auth), `rules/four-layer-north-star.md` L4
+
+## Context
+
+A 2026-04-28 health check of Phase A / A.5 / B surfaced a drift tendency: after every reliability incident, the patch tends to smuggle in "industry best practices" (a k8s liveness/readiness split, a weighted linear hybrid with tuning knobs, HMAC + principal registry + key rotation), pushing memhall's complexity toward a production-grade memory platform.
+
+But memhall's actual position is:
+
+- **Single user** (Maki)
+- **Single deployment** (Mac mini on the Tailscale tailnet at `:9100`, mini2 as cold standby)
+- **Scale of ~10² entries**
+- **Fewer than 10 callers** (ops-hub / repo CLI / `.claude/skills/*` / mk-brain), all inside Maki's own tailnet
+- **Purpose**: the shared memory hall of the seven-agent council + the associative entry point of Maki's personal PKI
+
+ADR 0003 already separated "engine library vs deployment platform" — this ADR pushes that one step further and makes explicit that memhall's design goal is **not** a production-grade memory platform; it is the **memory engine of a personal PKI**.
+
+## Decision
+
+memhall adopts the following four north stars, in priority order:
+
+1. **Association quality** (retrieval recall / correct ranking)
+2. **Stability** (does not break, does not swallow errors, does not silently degrade)
+3. **Speed** (search p50 < 200ms, write < 50ms)
+4. **Lightweight** (schema, config knobs, auth mechanisms, and the ops surface must all be understandable by one person)
+
+**Every patch must pass the "personal PKI health check" before landing**:
+
+- Does this change fix a real bug, or does it import an "industry convention"?
+- How many config knobs does it add? Can you explain the default of each one?
+- How many schema columns does it add? Are they worth it at ~10² scale?
+- For a single-caller scenario, does it introduce machinery only needed for cross-org / multi-tenant / multi-operator setups?
+- If the answer is "we might need it later" — reject; wait until it is actually needed.
+
+An explicit **do-not-do** list (unless the sunset criteria fire):
+
+- ❌ The k8s-style liveness/readiness/startup probe trio (a single launchd container does not need it)
+  - **Enforced retroactively on 2026-04-28**: the `/v1/healthz` + `/v1/ready` split introduced by Phase A.5 PR2 Patch F has been reverted, back to a single `/v1/health` (200/503, body carries the full status). Rationale: mini runs with `restart: unless-stopped`, so an unhealthy health status never auto-restarts anything and flapping risk is zero; a single endpoint is lower cognitive load for personal-PKI operations
+- ❌ A tunable α / mode switch for hybrid search (unless a retrieval benchmark proves something beats RRF)
+- ❌ HMAC + nonce + per-key rotation (ADR 0007 minimal token + the Tailscale ACL is sufficient)
+- ❌ A principal registry / role mapping / `key_id → role/ns/agent` table
+- ❌ Per-row failure counters / retry-budget machinery (logging + retrying on the next reindex is enough)
+- ❌ Dashboards / metrics aggregation / observability surfaces you have to open and look at (violates L4)
+
+## Consequences
+
+### Gains
+
+- **The complexity budget goes to association quality** (embedding model, ranking, CJK tokenization), not the ops surface
+- **Maintainable by one person**: schema, auth, and health logic can all be read in a single afternoon
+- **Reversible**: every ADR carries sunset criteria; cross a threshold and upgrade, otherwise stay lightweight
+- **OSS friendly**: `git clone && docker compose up` runs immediately — no ACLs to configure, no certs to sign, no keys to issue
+
+### Costs
+
+- **Not suitable for multiple operators**: when a second operator appears, most decisions in this ADR need re-evaluation
+- **Weaker audit trail**: the self-declared `agent_id` is the only attribution, not a cryptographic guarantee
+- **Some "correct" engineering practices are deliberately deferred**: HMAC, principal registry, retry budgets — not because they are wrong, but because **the ROI of doing them now is insufficient**
+
+### Non-goals
+
+- Does not replace ADR 0003's engine-vs-platform split: production-grade ACLs / multi-tenant ACLs / cross-org audit remain the job of a future `memory-gateway`
+- Does not reject the HMAC spec in `rules/agent-security-hygiene.md` S2.1 — that is the destination; this ADR is the reason not to walk there now
+- Does not abandon reliability: the Phase A SQLite chain / silent-except / WAL fixes all had to be done; this ADR is not a license to skip bug fixes
+
+## Sunset criteria
+
+Revisit this ADR when any of the following holds:
+
+1. A second operator (not Maki) starts writing to the same memhall deployment
+2. Caller count > 20, or a caller appears that Maki does not recognize
+3. Entries exceed 10⁵ (the schema / index strategy may need a redesign)
+4. An incident requires cryptographic attribution (a leaked token + an entry whose author is unknown)
+5. memhall becomes a multi-contributor OSS project and external contributors start demanding "production-grade" features
+
+## Alternatives considered
+
+### A. Skip this ADR and gatekeep via PR review
+
+Rejected. Without a written design philosophy, every PR re-litigates "is this over-design". This ADR writes the criteria down; future PRs / Codex proposals / Claude designs go through this health check first, and whatever fails it is cut outright.
+
+### B. Write it as a rule (`rules/memhall-lightweight.md`) instead of an ADR
+
+Rejected. ADRs are immutable in-repo decision records scoped to memhall. Rules are cross-project behavioral norms. This content is memhall design philosophy, so it belongs in an ADR.
+### C. A "do not do" list without a priority order
+
+Rejected. Without a priority order, trade-offs get decided by gut feel. Making "association quality > stability > speed > lightweight" explicit gives future trade-offs a basis — for example, the BM25 normalize bug touched ranking logic but fixes association quality, the highest priority; parameterizing the hybrid α is a lightweight regression, the lowest priority, and needs a benchmark before it can land.
+
+## Implementation summary
+
+- Add this ADR
+- Update the `docs/adr/README.md` index
+- Future PR descriptions that introduce a new config knob / schema column / auth mechanism must cite this ADR and answer the five "personal PKI health check" questions
diff --git a/docs/adr/0009-admin-gate.md b/docs/adr/0009-admin-gate.md
new file mode 100644
index 0000000..3c6f6f9
--- /dev/null
+++ b/docs/adr/0009-admin-gate.md
@@ -0,0 +1,106 @@
+# ADR 0009 — Admin gate (two-tier bearer, no HMAC)
+
+- **Status**: Accepted
+- **Date**: 2026-04-28
+- **Related**: ADR 0007 (minimal token auth; this is its smallest extension), ADR 0008 (the personal-PKI lightweight stance that supplies the judging criteria), Codex Phase B Dissent 2026-04-27 (the minimal implementation of its D2 Option E)
+
+## Context
+
+Today the two admin endpoints, `/v1/admin/reindex` and `/v1/admin/audit`, are protected by the same `MH_API_TOKEN` — any caller holding the api_token can invoke admin operations. Risks:
+
+- The api_token is shared by multiple callers (ops-hub / repo CLI / 4 Claude skills / mk-brain); compromise of any caller machine, or a careless leak of the token into a log, directly hands over admin rights
+- Reindex is a dangerous operation (it scans the full table and can trip cascading embedder failures) and should not share privileges with ordinary read/write
+
+The seven-agent council's initial Phase B proposal was a full set of production-grade machinery — "HMAC + nonce + replay window + principal registry + 14-day overlap period + retirement after 7 consecutive days of zero bearer writes" (the direction of rules/agent-security-hygiene.md S2.1). Codex Phase B Dissent 2026-04-27 D2 Option E shrank it to "gate admin first, do attribution later". SuperGrok's 2026-04-28 sanity check: across 2025–2026 there is no recent public incident matching this situation (Tailscale tailnet + single tenant + two-tier static bearer), the community does not list this simplification as a known anti-pattern, and a separate admin bearer is in fact the community-recommended least-privilege approach.
+
+ADR 0008 has already ratified "memhall is a personal PKI; lightweight > complete", explicitly excluding HMAC / principal registry / per-key rotation. This ADR shrinks Phase B to the smallest step still permitted under the ADR 0008 stance.
+
+## Decision
+
+**Add `MH_ADMIN_TOKEN` (optional, independent of `MH_API_TOKEN`). When set, `/v1/admin/*` requires the admin token, and the ordinary api_token is rejected on admin paths.**
+
+- New config field: `Settings.admin_token: str | None = None` (`MH_ADMIN_TOKEN` env)
+- Middleware behavior (`require_api_token` in `src/memory_hall/server/app.py`):
+  - `/v1/health*` → always public (carried over from ADR 0007)
+  - `/v1/admin/*` with `admin_token` set → requires `Authorization: Bearer <admin_token>`; presenting the `api_token` also returns `401`
+  - `/v1/admin/*` with `admin_token` unset → falls back to the `api_token` logic (ADR 0007 backward compat)
+  - Other paths → the existing `api_token` logic
+- The `admin_token` cannot be used on non-admin paths in reverse (least privilege both ways)
+- All comparisons use `hmac.compare_digest` (constant-time, carried over from ADR 0007)
+- Error messages differ (`invalid token` vs `invalid admin token`), but we do **not** use `403` to distinguish "your token is a valid api_token but not admin" — that would create a token-validity oracle
+
+Non-code companions (docs only, not enforced in repo code):
+
+- Lock the `/v1/admin/*` path to Maki's own devices in the mini's Tailscale ACL (the second layer of defense-in-depth)
+- Generate the token with `openssl rand -hex 32`, with a value different from `MH_API_TOKEN`
+- Do not log the `Authorization` header (src/memory_hall/ has been grepped; no such log exists today, and this PR introduces none)
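+
+A minimal sketch of the branch logic above (illustrative only — the real implementation is the `require_api_token` middleware; `check_bearer` and its signature are invented for this sketch):
+
+```python
+import hmac
+
+def check_bearer(path: str, presented: str | None,
+                 api_token: str | None, admin_token: str | None) -> bool:
+    if path.startswith("/v1/health"):
+        return True                                   # always public (ADR 0007)
+    if path.startswith("/v1/admin/") and admin_token is not None:
+        # Admin tier: a valid api_token is rejected here on purpose, and the
+        # caller sees the same 401 as for a garbage token (no validity oracle).
+        return presented is not None and hmac.compare_digest(presented, admin_token)
+    if api_token is None:
+        return True                                   # no-auth dev server
+    # admin_token does not open non-admin paths (least privilege both ways).
+    return presented is not None and hmac.compare_digest(presented, api_token)
+```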
+## Consequences
+
+### Gains
+
+- **Admin operations are isolated from the shared token**: a leaked ordinary caller token no longer means admin is lost
+- **Backward compatible**: with `MH_ADMIN_TOKEN` unset, behavior is identical to ADR 0007; existing deployments need no change
+- **~30 lines of implementation** (1 line of config + ~20 lines of middleware changes + 6 test cases), done within 1.5 hours
+- **Passes the personal-PKI health check**: 1 new config knob, 0 new schema columns, 0 cross-org mechanisms
+
+### Costs
+
+- **Still possession-based**: a leaked admin_token = admin lost; there is no cryptographic attribution
+- **No rotation infra**: rotating the admin_token = change the env + restart the container + notify the few callers, same as the api_token
+- **Two invariants are enforced fail-fast at config load** (added after Codex review 2026-04-28 PR1 round 1; a 5-line pydantic validator):
+  - `admin_token` set but `api_token` unset → refuse to start (otherwise non-admin paths would fail open)
+  - `admin_token == api_token` → refuse to start (otherwise the two tiers are silently collapsed)
+  - These two do not violate the ADR 0008 lightweight principle: they prevent operator misconfiguration from causing a silent security regression — 5 lines of code against one high-severity hole, a clear ROI
+
+### Non-goals
+
+- Does not replace HMAC (rules/agent-security-hygiene.md S2.1 remains the destination, but the sunset criteria have not fired)
+- Does not introduce a principal registry / role mapping
+- No 14-day sunset window (there is no old mechanism to retire)
+- Does not enforce the Tailscale ACL at the code layer (infra config belongs to ops)
+
+## Alternatives considered
+
+### A. Codex's full Phase B Option E (registry + HMAC + 14-day overlap + retirement after 7 consecutive days of zero bearer writes)
+
+Rejected: the sunset criteria have not fired (single operator / fewer than 10 callers / everything inside Maki's tailnet). Introduce HMAC only once ADR 0008 sunset criterion 1 (a second operator) or 4 (a token-leak incident requiring cryptographic attribution) occurs.
+
+### B. Use `403 Forbidden` to distinguish "a valid api_token used on an admin path"
+
+Rejected: it creates a token-validity oracle (an attacker sends garbage and gets 401, sends a valid api_token and gets 403, and can infer whether a token is legitimate). A uniform 401 is safer. For internal callers, the `invalid admin token` message string is enough to tell the cases apart while debugging.
+
+### C. No admin gate; lock the path with the Tailscale ACL alone
+
+Rejected: ACLs are device-level and cannot distinguish "ops-hub's read-only flow on a device" from "a LINE bot on the same device that should never call reindex". Code-level self-defense plus the ACL as defense-in-depth beats the ACL alone.
+
+### D. Make admin_token required by default (not backward compatible)
+
+Rejected: it would affect the existing deployment (mini production) and require a migration window. This ADR takes the reversible path: opt-in first; if enforcement is ever needed, a later ADR can supersede this one.
+
+## Sunset criteria
+
+Revisit when any of the following holds:
+
+1. Any ADR 0008 sunset criterion fires (which automatically drags this ADR along)
+2. An admin_token leak incident (this is also why this ADR ships no rotation infra — if it happens, rotation is the first thing to build)
+3. Caller counts require per-caller admin attribution (e.g. knowing whether ops-hub or mk-brain triggered a reindex)
+4. A third privilege tier is needed (read-only / write / admin → read-only / write / reindex / audit / superuser)
+
+## Implementation summary
+
+- `src/memory_hall/config.py`: add the `admin_token` field
+- `src/memory_hall/server/app.py`: extend the `require_api_token` middleware with the admin-path branch
+- `tests/test_auth.py`: 8 new cases (6 middleware behaviors + 2 config-invariant fail-fast)
+- `.env.example`: add the `MH_ADMIN_TOKEN=` example section
+- `docs/api.md`: add the "Admin gate (two-tier bearer)" section
+
+Total: ~140 lines across 6 files. `pytest`: 16 passed (auth); full suite 59 passed, 1 skipped.
+
+## Round 1 review history
+
+- 2026-04-28 Codex review REJECT, 2 findings:
+  1. [HIGH] `admin_token` set + `api_token` unset → non-admin paths fail open (verified live: POST /v1/memory/write returned 201)
+  2. [MEDIUM] `admin_token == api_token` → silently collapses the two tiers
+- Fix: add a `_validate_auth_tokens` model_validator on `Settings`, failing fast at config load
+- Added 2 unit tests to lock the invariants
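+
+A minimal sketch of the two invariants expressed as tests (illustrative — the real coverage lives in `tests/test_auth.py` per the implementation summary; pydantic surfaces the validator's `ValueError` as a `ValidationError`, which is a `ValueError` subclass):
+
+```python
+import pytest
+from memory_hall.config import Settings
+
+def test_admin_token_requires_api_token() -> None:
+    # admin_token without api_token would fail open on non-admin paths
+    with pytest.raises(ValueError):
+        Settings(api_token=None, admin_token="a" * 64)
+
+def test_admin_token_must_differ_from_api_token() -> None:
+    # equal tokens silently collapse the two tiers
+    with pytest.raises(ValueError):
+        Settings(api_token="same-token", admin_token="same-token")
+```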
diff --git a/docs/adr/README.md b/docs/adr/README.md
index d1404e4..ac702dd 100644
--- a/docs/adr/README.md
+++ b/docs/adr/README.md
@@ -11,6 +11,8 @@ Numbered, immutable records of significant design choices. Append new entries; n
 | [0005](0005-v0.2-minimum-viable-contract.md) | v0.2 Minimum Viable Contract (production-facing freeze) | Accepted (2026-04-19) |
 | [0006](0006-http-embedder-embed-queue-isolation.md) | HttpEmbedder: embed path isolation from LLM queue | Accepted (2026-04-20) |
 | [0007](0007-minimal-token-auth.md) | Minimal Token auth (single-tenant deployment shim) | Accepted (2026-04-23) |
+| [0008](0008-personal-pki-lightweight-stance.md) | memhall is a personal PKI; lightweight > complete | Accepted (2026-04-28) |
+| [0009](0009-admin-gate.md) | Admin gate (two-tier bearer, no HMAC) | Accepted (2026-04-28) |
 
 ## Format
diff --git a/docs/agent-integration.md b/docs/agent-integration.md
new file mode 100644
index 0000000..6702fb3
--- /dev/null
+++ b/docs/agent-integration.md
@@ -0,0 +1,165 @@
+# Agent Integration Guide
+
+If you are an AI agent (Claude / Codex / Gemini / a sub-agent / a script in a sandbox) and you need to read or write memhall, this is the doc for you.
+
+The README's "Three entry points" lists the surfaces. This doc is the **decision tree**: which surface you should actually pick, and the gotchas each one has.
+
+> **Status legend** (last verified 2026-04-28 against `fix/reliability-phase-a5-2026-04-27`):
+> - ✅ **verified** — exercised end-to-end in a real session, including against a server with `MH_API_TOKEN` set.
+> - ⚠️ **partial** — works for the no-auth case, but does **not** currently work against a server that requires `MH_API_TOKEN`.
+>
+> If a path is marked ⚠️ and you need it to work with auth, fall back to a ✅ path until the gap is closed.
+
+---
+
+## Decision tree
+
+```
+Are you running in the same process / repo as memory-hall, with `import memory_hall` available?
+├─ Yes → use the embedded Python runtime (Path A)
+└─ No
+   │
+   Can your sandbox open a TCP socket to the memhall host?
+   ├─ Yes → use HTTP + Bearer (Path B), or the `mh` CLI (Path C), which wraps HTTP
+   └─ No (sandboxed agents: Codex CLI, restricted containers, some Gemini setups)
+      └─ install the package and use Path A in-process
+         (the `mh` CLI goes over HTTP under the hood, so it needs TCP too)
+```
+
+If you do not know which one applies to you, default to **Path B (HTTP + Bearer)** — it works from anywhere that has network access and `curl`.
+
+---
+
+## Path A — Embedded Python (in-process) ✅
+
+Status: ✅ verified. Bypasses HTTP + auth entirely (in-process call, no middleware).
+
+Use when: same process, sandboxed environments where TCP is blocked, batch imports, tests.
+
+```python
+import asyncio
+from memory_hall import Settings, build_runtime
+from memory_hall.models import WriteMemoryRequest, SearchMemoryRequest
+
+async def main():
+    runtime = build_runtime(settings=Settings())
+    await runtime.start()
+    try:
+        await runtime.write_entry(
+            tenant_id="default",
+            principal_id="my-agent",
+            payload=WriteMemoryRequest(
+                agent_id="my-agent",
+                namespace="shared",
+                type="note",
+                content="hello from inside the process",
+            ),
+        )
+        hits = await runtime.search_entries(
+            tenant_id="default",
+            payload=SearchMemoryRequest(query="hello", limit=5),
+        )
+        print(hits.total)
+    finally:
+        await runtime.stop()
+
+asyncio.run(main())
+```
+
+**No network, no auth, same storage.** This is the path Codex / Gemini sandboxes should prefer when localhost TCP is blocked by the sandbox.
+
+Gotchas:
+- `Settings()` reads from env (`MH_DB_PATH`, `MH_EMBEDDER_KIND`, …). If the agent's working directory has its own `.env`, runtime config will diverge from the running HTTP server. Point both at the same DB if you want them to share state.
+- The runtime API is async (`start` / `write_entry` / `search_entries` / `stop`), so you need an event loop. In a sync script, wrap the whole flow with `asyncio.run(...)`.
+
+---
+
+## Path B — HTTP + Bearer ✅
+
+Status: ✅ verified against a server with `MH_API_TOKEN` set. This is the most reliable path when the sandbox has TCP access.
+
+Use when: any language, any tool, sandbox can reach the host over TCP.
+
+```bash
+# Set once per shell. Maki's setup keeps the token at ~/.config/memhall/token (0600).
+export MH_API_TOKEN="$(cat ~/.config/memhall/token)"
+
+curl -sS http://127.0.0.1:9000/v1/memory/write \
+  -H "Authorization: Bearer ${MH_API_TOKEN}" \
+  -H 'Content-Type: application/json' \
+  -d '{
+    "agent_id": "my-agent",
+    "namespace": "shared",
+    "type": "note",
+    "content": "hello from curl"
+  }'
+```
+
+Gotchas:
+- **`Authorization: Bearer …` is required** on every `/v1/memory/*` request when the server has `MH_API_TOKEN` set. `/v1/health` is the only public endpoint. Missing the header returns `{"detail":"missing bearer token"}` — the server is alive, you are just unauthenticated.
+- If the server runs without `MH_API_TOKEN` set (dev / standalone), the header is ignored. Sending it anyway is safe and forward-compatible — always send it.
+- Default port is `9000`. Maki's home deployment maps it to `9100` (`http://100.122.171.74:9100`). Check the deployment you are talking to.
+- `/v1/admin/*` requires `MH_ADMIN_TOKEN` (a different token). Regular `MH_API_TOKEN` is **rejected** on admin paths when the admin token is set. See `docs/adr/0009-admin-gate.md`.
+- See `examples/shell/write_memory.sh` for a runnable starter.
+
+---
+
+## Path C — `mh` CLI ✅
+
+Status: ✅ verified. The CLI reads `MH_API_TOKEN` from the environment (via `Settings()`) and attaches `Authorization: Bearer <token>` automatically when set. Works against both auth-enabled and no-auth servers. Verified against `src/memory_hall/cli/main.py:31` on `fix/reliability-phase-a5-2026-04-27`; covered by `tests/test_cli_auth.py`.
+
+Use when: you want a one-liner from a shell, you do not want to hand-roll JSON, and the package is installed.
+
+```bash
+# One-time install in the project venv:
+uv sync
+# Then `mh` is on PATH inside the venv.
+
+# If the server has MH_API_TOKEN set, export it (CLI reads it automatically):
+export MH_API_TOKEN="$(cat ~/.config/memhall/token)"
+
+uv run mh write "DEC-018 落地完成" \
+  --agent-id codex \
+  --namespace project:memory-hall \
+  --type decision \
+  --tag governance
+
+uv run mh search "DEC-018"
+```
+
+Gotchas:
+- `mh` is a console script defined in `pyproject.toml`. It is **not** globally available. If `command -v mh` returns nothing, you have not installed the package — run `uv sync` (or `pip install -e .`) inside the repo first.
+- `uv run mh …` works without prior install but resolves dependencies on first use. In sandboxes where `~/.cache/uv` is not writable, set `UV_CACHE_DIR=/tmp/uv-cache` before calling.
+- The CLI hits HTTP under the hood. `MH_API_TOKEN` is read from the environment on each command; no CLI flag is needed. If unset, no `Authorization` header is sent (works against no-auth servers).
+
+---
+
+## Common failure modes
+
+| Symptom | Likely cause | Fix |
+|---|---|---|
+| `{"detail":"missing bearer token"}` | Path B without `Authorization` header | Set `MH_API_TOKEN` and add `-H "Authorization: Bearer ${MH_API_TOKEN}"` |
+| `curl: (7) Couldn't connect to server` from a sandboxed agent | Sandbox blocks localhost TCP | Switch to Path A (embedded Python) |
+| `command not found: mh` | Package not installed in this shell's PATH | `uv sync` inside the repo, or use `uv run mh …` |
+| `uv run mh` errors on `~/.cache/uv` permission | Sandbox cache dir not writable | `export UV_CACHE_DIR=/tmp/uv-cache` |
+| Writes succeed but search returns nothing | Path A and Path B pointing at different DB files | Align `MH_DB_PATH` in both, or always go through HTTP |
+
+---
+
+## Picking the right `agent_id` and `namespace`
+
+- `agent_id` — stable identity for the agent. Examples: `claude`, `codex`, `gemini`, `max`, `grok`, `gemma4`, `maki`. Do not invent a new id per session; one id per agent persona.
+- `namespace` — scope of the entry. Examples: `home`, `work`, `project:<name>`, `agent:<name>`, `shared`.
+- `type` — one of `episode`, `decision`, `observation`, `experiment`, `fact`, `note`, `question`, `answer`.
+
+Do not write company-sensitive content into `shared` or `work`. Use `project:<name>` or do not write at all. A quick illustration of these conventions follows.
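+
+```python
+# Illustrative values only — the field names come from WriteMemoryRequest
+# as used in the Path A example above.
+from memory_hall.models import WriteMemoryRequest
+
+payload = WriteMemoryRequest(
+    agent_id="codex",                 # stable per-persona id, not per-session
+    namespace="project:memory-hall",  # scoped; avoid "shared" for sensitive content
+    type="decision",                  # one of the eight types listed above
+    content="ADR 0009 admin gate accepted",
+)
+```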
+
+---
+
+## See also
+
+- [`README.md`](../README.md) — full feature list and quickstart
+- [`docs/api.md`](api.md) — HTTP endpoint reference
+- [`docs/adr/0007-minimal-token-auth.md`](adr/0007-minimal-token-auth.md) — why Bearer auth is the way it is
+- [`examples/codex_cli/`](../examples/codex_cli/) — Codex CLI starter
+- [`examples/shell/`](../examples/shell/) — curl starter
diff --git a/docs/api.md b/docs/api.md
index aeb783b..90a7ac5 100644
--- a/docs/api.md
+++ b/docs/api.md
@@ -23,6 +23,15 @@
 Authorization: Bearer <token>
 ```
 
 Missing or wrong token → `401`. `/v1/health` stays public so external uptime probes and the in-image HEALTHCHECK don't need credentials. Rationale and scope limits in [ADR 0007](adr/0007-minimal-token-auth.md). This is **not** a replacement for the production HMAC mode below — it's a local-network deployment shim.
 
+### Admin gate (two-tier bearer, optional)
+
+When `MH_ADMIN_TOKEN` is set in addition to `MH_API_TOKEN`:
+- `/v1/admin/*` paths require `Authorization: Bearer <admin_token>` — the regular `MH_API_TOKEN` is **rejected** on admin paths
+- The regular `MH_API_TOKEN` continues to gate non-admin paths (`/v1/memory/*`)
+- `MH_ADMIN_TOKEN` does **not** grant access to non-admin paths (least privilege both ways)
+- `/v1/health` remains public
+
+When `MH_ADMIN_TOKEN` is unset, admin paths fall back to `MH_API_TOKEN` (backward compatible with ADR 0007). Operators are encouraged to also lock `/v1/admin/*` to specific devices via Tailscale ACL as defense-in-depth. Rationale in [ADR 0009](adr/0009-admin-gate.md).
+
 ### Future HMAC mode (planned via `memory-gateway`, not implemented in this repo yet)
 The long-term production mode is the HMAC scheme below:
 - `Authorization: HMAC <key_id>:<signature>` header
diff --git a/docs/council/2026-04-27-codex-phase-b-dissent.md b/docs/council/2026-04-27-codex-phase-b-dissent.md
new file mode 100644
index 0000000..1f8fb5f
--- /dev/null
+++ b/docs/council/2026-04-27-codex-phase-b-dissent.md
@@ -0,0 +1,21 @@
+# Codex Phase B Dissent
+
+## D1 Objections
+- Disagree with calling this Option C. With B.3 deferred, this round is not actually defense-in-depth; E26 remains a single app-layer fence. The Tailscale ACL on `/v1/admin/*` should be promoted to the same priority.
+- HMAC is not over-engineering — E11 attribution needs it; the ordering is what's wrong. `Principal.role` currently defaults to `admin`, and a role source does not exist. If B.1 lands first and B.2 is backfilled later, every valid key easily ends up admin. E25 also proves the bearer and principal chains are disconnected.
+- Caller cost is underestimated: `ops-hub` only speaks bearer and treats 401/403 as permanent; the repo CLI carries no auth; `.claude/skills/*` are curl-with-bearer; `mk-brain`'s HMAC is the gateway's separate format.
+
+## D2 Alternative
+- **Option E**: gate admin first, then do attribution. E0 locks `/v1/admin/*` with the existing bearer + a static allowlist + the Tailscale ACL; E1 backfills a principal registry (`key_id -> role/ns/agent`) + a shared signer, then cuts over to S2.1 incrementally; E2 retires bearer once telemetry is stable. JWT/PASETO not recommended — they solve neither replay nor body integrity.
+
+## D3 Missing Risks
+- R6: the role model is absent; R7: during rotation, `ops-hub`'s queue will drop queued records on 401/403; R8: the 220 `dev-local` rows cannot be hard-backfilled and should be tagged `legacy-unattributed`.
+
+## D4 Implementation Order
+- Admin gate first, HMAC second. Split into two PRs: PR1 admin gate + tests + ACL; PR2 registry + HMAC + caller helper. Overlap period of at least 14 days, retiring on 7 consecutive days of zero bearer writes — do not hardcode 7 days.
+
+## D5 mk-council Interaction
+- A bypass only skips the council; it is not an auth downgrade. Direct-to-memhall traffic should still prefer HMAC; bearer may serve only as a non-admin fallback during the sunset window, and `/v1/admin/*` must never fall back. ADR-0007 can remain as a deprecated shim and be retired once adoption completes.
+
+## Verdict
+**APPROVE WITH MODIFICATIONS** — the direction is right, but pushing the migration without first sealing E26 and without backfilling role/registry merely swaps the hole from "no gate" to "every valid key is admin".
diff --git a/docs/council/2026-04-27-phase-a5-pr1-codex-answer.md b/docs/council/2026-04-27-phase-a5-pr1-codex-answer.md
new file mode 100644
index 0000000..03b5d75
--- /dev/null
+++ b/docs/council/2026-04-27-phase-a5-pr1-codex-answer.md
@@ -0,0 +1,91 @@
+# 2026-04-27 Phase A.5 PR1 Codex Summary
+
+Source briefing: [AGENT: Claude]
+
+Note:
+This file was deliberately not appended as a 5th commit, to preserve the "4 patch commits" review boundary.
+
+## Patch A
+
+- Commit: `c17b8a0` `Paginate the admin reindex scan flow`
+- Diff stat:
+  - `src/memory_hall/server/app.py` `97 insertions, 24 deletions`
+  - `src/memory_hall/storage/interface.py` `8 insertions`
+  - `src/memory_hall/storage/sqlite_store.py` `34 insertions`
+  - `tests/test_sync_status.py` `117 insertions, 1 deletion`
+- Main changes:
+  - `src/memory_hall/server/app.py:517` reworked `_handle_reindex()` into a cursor-based loop with fixed `limit=200`, batch progress logs, and `CancelledError` passthrough after already-finished batches.
+  - `src/memory_hall/storage/sqlite_store.py:210` added `sync_status` filtering to `list_entries()`.
+  - `src/memory_hall/storage/sqlite_store.py:251` added `count_entries()` so the worker can pre-compute total batch count.
+  - `tests/test_sync_status.py:106` proves admin full reindex over `205` failed rows never calls `limit=None`.
+  - `tests/test_sync_status.py:144` proves `pending_only=True` still paginates only `sync_status='pending'`.
+
+## Patch B
+
+- Commit: `31ddf83` `Recycle SQLite connections on transient failures and retry`
+- Diff stat:
+  - `src/memory_hall/storage/sqlite_store.py` `411 insertions, 138 deletions`
+  - `tests/test_reindex_retry.py` `149 insertions, 1 deletion`
+- Main changes:
+  - `src/memory_hall/storage/sqlite_store.py:18` added transient `sqlite3.OperationalError` markers for `disk i/o error`, `database is locked`, `database table is locked`, and `database is busy`.
+  - `src/memory_hall/storage/sqlite_store.py:626` added `_run_read_operation()` with one-shot recycle + retry.
+  - `src/memory_hall/storage/sqlite_store.py:648` added `_run_writer_operation()` with writer recycle + reopen before retry.
+  - `src/memory_hall/storage/sqlite_store.py:670` logs recycled connection ids via `aiosqlite connection recycled after disk I/O error ...`.
+  - `tests/test_reindex_retry.py:142` verifies writer connection recycle on first `disk I/O error`.
+  - `tests/test_reindex_retry.py:214` verifies read connection recycle on first `database is locked`.
+
+## Patch C
+
+- Commit: `3e7d2ce` `Force the runtime sqlite3 system symlink to 3.53.0`
+- Diff stat:
+  - `Dockerfile` `6 insertions, 1 deletion`
+- Main changes:
+  - `Dockerfile:83` keeps `/opt/sqlite/lib` in `ld.so.conf.d`, runs `ldconfig`, and forces `/lib/aarch64-linux-gnu/libsqlite3.so.0` to point at `/opt/sqlite/lib/libsqlite3.so.3.53.0`.
+  - `Dockerfile:85` adds the rationale comment for child processes that do not inherit `LD_LIBRARY_PATH`.
+
+## Patch D
+
+- Commit: `a19886b` `Add a background WAL checkpoint with dual-database truncation`
+- Diff stat:
+  - `src/memory_hall/config.py` `1 insertion`
+  - `src/memory_hall/server/app.py` `41 insertions`
+  - `src/memory_hall/storage/interface.py` `2 insertions`
+  - `src/memory_hall/storage/sqlite_store.py` `111 insertions, 20 deletions`
+  - `src/memory_hall/storage/vector_store.py` `182 insertions, 79 deletions`
+  - `tests/test_smoke.py` `75 insertions`
+- Main changes:
+  - `src/memory_hall/config.py:38` adds `wal_checkpoint_interval_s` mapped from `MH_WAL_CHECKPOINT_INTERVAL_S`, default `300.0`.
+  - `src/memory_hall/server/app.py:105` adds `_wal_checkpoint_worker`.
+  - `src/memory_hall/server/app.py:388` adds the periodic background task.
+  - `src/memory_hall/server/app.py:736` adds `_checkpoint_wal_databases()` for both main DB and vector DB with `WAL checkpoint completed: ...` logging.
+  - `src/memory_hall/storage/sqlite_store.py:493` adds async `checkpoint_wal()` with reader gating plus `PASSIVE -> RESTART -> TRUNCATE` (see the checkpoint sketch at the end of this file).
+  - `src/memory_hall/storage/sqlite_store.py:706` adds `_acquire_reader_slot()` / `_pause_readers()` so checkpoint can drain in-flight readers before truncation.
+  - `src/memory_hall/storage/vector_store.py:50` adds a connection lock around the shared sqlite3 connection.
+  - `src/memory_hall/storage/vector_store.py:163` adds vector DB `checkpoint_wal()`.
+  - `tests/test_smoke.py:151` writes `100` entries, triggers checkpoint directly, and verifies both WAL files shrink to `<= 32KB`.
+
+## Pytest Results
+
+- After Patch A: `42 passed, 1 skipped, 4 warnings`
+- After Patch B: `44 passed, 1 skipped, 4 warnings`
+- After Patch C: `44 passed, 1 skipped, 4 warnings`
+- After Patch D: `45 passed, 1 skipped, 4 warnings`
+- Extra targeted runs:
+  - `pytest -q tests/test_reindex_retry.py`
+  - `pytest -q tests/test_smoke.py tests/test_vec0.py`
+  - `ruff check src tests`
+
+## Deviations
+
+- Patch C docker validation could not be completed in this sandbox:
+  - `docker build -t memory-hall:0.1.0 .`
+  - Result: `permission denied while trying to connect to the docker API at unix:///Users/maki/.docker/run/docker.sock`
+- I did not make a 5th commit for this summary file, to preserve the requested 4-patch commit structure.
+- I could not update `~/infrastructure/handoff/latest.md` from this session because the sandbox only allows writes inside the repo/workspace roots.
+
+## Latent Bugs Found
+
+- `src/memory_hall/storage/sqlite_store.py:47`
+  - `SqliteStore.open()` still fails hard if a transient SQLite error happens during startup schema/open, because the new recycle+retry layer only wraps steady-state read/write operations, not bootstrap.
+- `src/memory_hall/storage/vector_store.py:50`
+  - Shared vector DB connection had no explicit locking before this PR. I fixed that as part of Patch D because the new checkpoint path made the race impossible to ignore.
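+
+## Appendix: checkpoint triple illustration
+
+A standalone sketch of the `(busy, log, checkpointed)` triple that Patch D logs — plain `sqlite3`, not the `SqliteStore` code:
+
+```python
+import sqlite3
+
+conn = sqlite3.connect("demo.sqlite3")
+conn.execute("PRAGMA journal_mode=WAL")
+conn.execute("CREATE TABLE IF NOT EXISTS t (x)")
+conn.executemany("INSERT INTO t VALUES (?)", [(i,) for i in range(100)])
+conn.commit()
+
+# TRUNCATE checkpoints every frame and resets the -wal file to zero bytes.
+# busy=1 means a reader blocked the checkpoint — the reason Patch D gates
+# readers before truncating.
+busy, log_frames, checkpointed = conn.execute(
+    "PRAGMA wal_checkpoint(TRUNCATE)"
+).fetchone()
+print(f"busy={busy} log={log_frames} ckpt={checkpointed}")
+conn.close()
+```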
diff --git a/docs/council/2026-04-27-phase-a5-pr2-codex-answer.md b/docs/council/2026-04-27-phase-a5-pr2-codex-answer.md
new file mode 100644
index 0000000..99e401f
--- /dev/null
+++ b/docs/council/2026-04-27-phase-a5-pr2-codex-answer.md
@@ -0,0 +1,87 @@
+# 2026-04-27 Phase A.5 PR2 Codex Summary
+
+Source briefing: [AGENT: Claude]
+
+Note:
+This sandbox could modify the working tree and run tests, but could not create new files under `.git/`; `git commit` failed with `.git/index.lock: Operation not permitted`.
+The content below is therefore organized by "intended commit boundaries", not actual commit hashes.
+
+## Intended Commit 1 — Patch E.0
+
+- Intended message:
+  - `fix(search): fix inverted _normalize_bm25 logic (better BM25 scored lower)`
+- Files:
+  - `src/memory_hall/storage/sqlite_store.py`
+  - `tests/test_fts_tokenization.py`
+- Main changes:
+  - `src/memory_hall/storage/sqlite_store.py:839` changed the BM25 normalizer from `1/(1+abs(s))` to `-bm25/(1.0-bm25)`, so SQLite FTS5's negative BM25 values map to scores that increase monotonically with quality.
+  - `tests/test_fts_tokenization.py:100` added 5 pure unit tests over negative BM25 values, verifying that ordering is preserved after normalization.
+  - Caller audit conclusion: `_normalize_bm25()` currently has a single caller, `SqliteStore.search_lexical()`; lexical ranking is decided first by SQL `ORDER BY bm25_score` (raw negative BM25, smaller is better), and the normalized score is then passed up to `MemoryHallRuntime.search_entries()` — no other caller re-sorts on the normalized score in the opposite direction.
+- Pytest after patch:
+  - `46 passed, 1 skipped, 4 warnings`
+
+## Intended Commit 2 — Patch E
+
+- Intended message:
+  - `feat(search): switch hybrid to a weighted linear combination (α=0.3 default)`
+- Files:
+  - `src/memory_hall/config.py`
+  - `src/memory_hall/models.py`
+  - `src/memory_hall/server/app.py`
+  - `tests/test_search_degraded.py`
+  - `tests/test_hybrid_search.py`
+- Main changes:
+  - `src/memory_hall/config.py:36` added `MH_HYBRID_MODE` (default `weighted_linear`, switchable to `rrf`) and `MH_HYBRID_ALPHA` (default `0.3`).
+  - `src/memory_hall/server/app.py:263` the search path now dispatches by mode.
+  - `src/memory_hall/server/app.py:789` added the weighted linear combine: `alpha * bm25 + (1 - alpha) * semantic`.
+  - `src/memory_hall/server/app.py:796` boundary conditions:
+    - `semantic_status != "ok"` → `alpha=1.0` pure lexical fallback
+    - lexical empty → `alpha=0.0` pure semantic
+    - no overlap → union padded with `0.0`
+  - `src/memory_hall/server/app.py:816` keeps the legacy RRF path; `MH_HYBRID_MODE=rrf` switches back to the old behavior.
+  - `src/memory_hall/models.py:183` `score_breakdown` gains `hybrid_mode` and `alpha`. The `rrf` field name is kept for backward compatibility, but its value now carries the final combined score and must be interpreted together with `hybrid_mode`.
+  - `tests/test_hybrid_search.py` added 4 tests:
+    - rare lexical target lands Top 1
+    - a semantic paraphrase is pulled back by semantic when lexical is empty
+    - under conflict resolution, a both-signals entry ranks above single-signal entries
+    - legacy `rrf` mode still works
+  - `tests/test_search_degraded.py` adds verification of the `alpha=1.0` pure lexical fallback when semantic fails.
+- Pytest after patch:
+  - `50 passed, 1 skipped, 4 warnings`
+- Extra targeted runs:
+  - `pytest -q tests/test_hybrid_search.py tests/test_search_degraded.py tests/test_write_search.py tests/test_cjk_search.py`
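+
+A quick numeric check of the two formulas above (a sketch; the function names are illustrative, not the repo's):
+
+```python
+def normalize_bm25(s: float) -> float:
+    # Patch E.0: SQLite FTS5 bm25() is negative, more negative = better match;
+    # -s / (1 - s) maps (-inf, 0] onto [0, 1), increasing with quality.
+    return -s / (1.0 - s)
+
+# Ordering is preserved: a better raw BM25 gets a higher normalized score.
+assert normalize_bm25(-5.0) > normalize_bm25(-1.0) > normalize_bm25(-0.1)
+# The old form 1/(1+abs(s)) inverted it — a better match scored lower:
+assert 1 / (1 + abs(-5.0)) < 1 / (1 + abs(-1.0))
+
+def combine(alpha: float, bm25: float, semantic: float) -> float:
+    # Patch E: weighted linear hybrid
+    return alpha * bm25 + (1.0 - alpha) * semantic
+
+print(combine(0.3, normalize_bm25(-5.0), 0.6))  # 0.3*0.833… + 0.7*0.6 ≈ 0.67
+```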
+## Intended Commit 3 — Patch F
+
+- Intended message:
+  - `feat(health): split liveness /healthz from readiness /ready (industry convention)`
+- Files:
+  - `src/memory_hall/server/app.py`
+  - `src/memory_hall/server/routes/health.py`
+  - `tests/test_smoke.py`
+  - `tests/test_auth.py`
+  - `Dockerfile`
+- Main changes:
+  - `src/memory_hall/server/app.py:342` added `ready()`, reusing the existing `_health_cache` and sub-checks.
+  - `src/memory_hall/server/app.py:348` added `healthz()`, which always returns `{"status": "alive"}` and touches neither DB, vector store, nor embedder.
+  - `src/memory_hall/server/routes/health.py` added `/v1/healthz` and `/v1/ready`, and aliased `/v1/health` to the readiness response.
+  - `src/memory_hall/server/app.py:953` the auth middleware allowlists `/v1/healthz`, `/v1/ready`, and `/v1/health`.
+  - `Dockerfile:115` `HEALTHCHECK` now hits `/v1/healthz`, so embedder/reindex churn in the readiness probe cannot misreport the container as dead.
+  - `tests/test_smoke.py` added:
+    - `healthz` still returns 200 alive while the embedder fails
+    - `/v1/ready` and the `/v1/health` alias return identical responses
+    - the Dockerfile probe points at `healthz`
+  - `tests/test_auth.py` updated so all 3 health endpoints stay public.
+- Pytest after patch:
+  - `53 passed, 1 skipped, 4 warnings`
+- Extra targeted runs:
+  - `pytest -q tests/test_smoke.py tests/test_auth.py`
+
+## Deviations
+
+- Could not produce the 3 actual git commits requested in this session:
+  - `git add` worked
+  - `git commit` failed: `.git/index.lock: Operation not permitted`
+- Could not update `~/infrastructure/handoff/latest.md`:
+  - the path is outside this session's writable roots
+- No new dependency introduced.
diff --git a/examples/codex_cli/README.md b/examples/codex_cli/README.md
index 335ea0a..d3abc2a 100644
--- a/examples/codex_cli/README.md
+++ b/examples/codex_cli/README.md
@@ -2,13 +2,41 @@
 
 `mh` wraps the HTTP API, so Codex or any shell session can write notes without touching MCP.
 
-Start the server:
+> **Auth status (verified 2026-04-28 against `fix/reliability-phase-a5-2026-04-27`)**: the CLI reads `MH_API_TOKEN` from the environment via `Settings()` and attaches `Authorization: Bearer <token>` automatically. Works against both auth-enabled and no-auth servers. See [`docs/agent-integration.md`](../../docs/agent-integration.md) for the full decision tree.
+
+## Install (one-time)
+
+`mh` is a console script defined in `pyproject.toml`. It is **not** globally available until the package is installed in this venv.
+
+```bash
+uv sync
+```
+
+If `command -v mh` still returns nothing, you are not in the project venv. `uv run mh …` invokes it without activation.
+
+In sandboxes where `~/.cache/uv` is not writable (e.g. Codex CLI restricted environments), set `UV_CACHE_DIR=/tmp/uv-cache` before calling `uv`.
+
+## Start the server
 
 ```bash
 uv run python -m memory_hall serve
 ```
 
-Write a note:
+## Auth
+
+The CLI reads `MH_API_TOKEN` from the environment on each command via `Settings()` and attaches it as `Authorization: Bearer <token>`. No flag, no manual header.
+
+Maki's setup keeps the token at `~/.config/memhall/token` (0600). Export before calling:
+
+```bash
+export MH_API_TOKEN="$(cat ~/.config/memhall/token)"
+```
+
+If `MH_API_TOKEN` is unset, no `Authorization` header is sent — `mh` works against no-auth dev servers unchanged.
+
+Implementation: `src/memory_hall/cli/main.py:31` (`_client()` injects the header from `Settings().api_token`). Test coverage: `tests/test_cli_auth.py`.
+
+## Write a note
 
 ```bash
 uv run mh write "DEC-018 落地完成" \
@@ -18,7 +46,7 @@ uv run mh write "DEC-018 落地完成" \
    --tag governance
 ```
 
-Search:
+## Search
 
 ```bash
 uv run mh search "DEC-018"
diff --git a/examples/shell/write_memory.sh b/examples/shell/write_memory.sh
index 3e2ab0b..1fab73b 100644
--- a/examples/shell/write_memory.sh
+++ b/examples/shell/write_memory.sh
@@ -1,7 +1,16 @@
 #!/usr/bin/env bash
 set -euo pipefail
 
+# When the server has MH_API_TOKEN set, every /v1/memory/* request needs
+# `Authorization: Bearer <token>`. /v1/health is the only public endpoint.
+# Maki's setup keeps the token at ~/.config/memhall/token (0600).
+#
+# If the server runs without MH_API_TOKEN (dev / standalone), the header is
+# ignored — sending it anyway is safe and forward-compatible.
+TOKEN="${MH_API_TOKEN:-$(cat ~/.config/memhall/token 2>/dev/null || true)}"
+
 curl -sS http://127.0.0.1:9000/v1/memory/write \
+  ${TOKEN:+-H "Authorization: Bearer ${TOKEN}"} \
   -H 'Content-Type: application/json' \
   -d '{
     "agent_id": "codex",
diff --git a/scripts/bench_hybrid.py b/scripts/bench_hybrid.py
new file mode 100644
index 0000000..fc1269f
--- /dev/null
+++ b/scripts/bench_hybrid.py
@@ -0,0 +1,478 @@
+"""Retrieval quality benchmark — RRF vs weighted_linear (α sweep).
+
+Goal: answer one question — is the weighted_linear default of α=0.3 actually
+better than RRF on memhall's real usage patterns? ADR 0008's stance: without
+benchmark evidence, fall back to RRF.
+
+Usage:
+
+    # synthetic mode (default): runs the built-in fixture corpus, directional only
+    python scripts/bench_hybrid.py
+
+    # real-corpus mode: point at a running memhall instance + your own query list
+    python scripts/bench_hybrid.py --corpus my-queries.jsonl --base-url http://...
+
+Query file format (jsonl):
+    {"query": "...", "relevant_ids": ["ent_a", "ent_b"], "notes": "..."}
+
+Metrics:
+    MRR (mean reciprocal rank), Recall@5, nDCG@10
+
+Exits non-zero if no mode wins on a majority of metrics (optional CI gate).
+"""
+
+from __future__ import annotations
+
+import argparse
+import asyncio
+import json
+import math
+import sys
+from collections.abc import Iterable
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Literal, cast
+
+import httpx
+
+REPO_ROOT = Path(__file__).resolve().parent.parent
+sys.path.insert(0, str(REPO_ROOT / "src"))
+sys.path.insert(0, str(REPO_ROOT))
+
+from memory_hall.config import Settings  # noqa: E402
+from memory_hall.server.app import create_app  # noqa: E402
+from tests.conftest import client_for_app  # noqa: E402
+
+
+# ---------- Synthetic corpus + embedder ----------------------------------
+
+# Synonym groups — words in the same group share an embedding dimension,
+# simulating semantic similarity
+_SYNONYM_GROUPS: list[tuple[str, ...]] = [
+    ("restore", "recover", "resurface", "rebuild", "還原"),
+    ("checklist", "list", "playbook", "清單"),
+    ("rollout", "deploy", "ship", "release", "部署"),
+    ("incident", "outage", "failure", "broken", "事故"),
+    ("hybrid", "combined", "fusion", "混合"),
+    ("ranking", "scoring", "rank", "排序"),
+    ("embedder", "embed", "vector", "嵌入"),
+    ("timeout", "stall", "hang", "逾時"),
+    ("sqlite", "database", "db", "資料庫"),
+    ("memhall", "memory-hall", "memory", "記憶"),
+    ("auth", "token", "bearer", "驗證"),
+    ("tailscale", "tailnet", "vpn"),
+    ("benchmark", "metric", "evaluation", "評估"),
+    ("schema", "migration", "table", "結構"),
+    ("council", "review", "agent", "協作"),
+]
+
+_TOKEN_TO_DIM: dict[str, int] = {}
+for idx, group in enumerate(_SYNONYM_GROUPS):
+    for token in group:
+        _TOKEN_TO_DIM[token.lower()] = idx
+
+_VECTOR_DIM = len(_SYNONYM_GROUPS) + 1  # +1 collision bucket
+
+
+def _tokenize(text: str) -> list[str]:
+    """Lowercase + split on non-alphanumeric, keep CJK runs as single tokens."""
+    out: list[str] = []
+    buf: list[str] = []
+    for ch in text.lower():
+        if ch.isalnum():
+            buf.append(ch)
+        elif "一" <= ch <= "鿿":
+            if buf:
+                out.append("".join(buf))
+                buf = []
+            out.append(ch)
+        else:
+            if buf:
+                out.append("".join(buf))
+                buf = []
+    if buf:
+        out.append("".join(buf))
+    return out
+
+
+class SynonymEmbedder:
+    """Bag-of-words over synonym groups. Provides realistic semantic similarity
+    (synonyms cluster) without perfectly mirroring lexical overlap."""
+
+    def __init__(self) -> None:
+        self.dim = _VECTOR_DIM
+        self.timeout_s = 2.0
+
+    def embed(self, text: str) -> list[float]:
+        vec = [0.0] * self.dim
+        for token in _tokenize(text):
+            dim = _TOKEN_TO_DIM.get(token, self.dim - 1)
+            vec[dim] += 1.0
+        # L2 normalize
+        norm = math.sqrt(sum(v * v for v in vec))
+        if norm > 0:
+            vec = [v / norm for v in vec]
+        return vec
+
+    def embed_batch(self, texts: list[str]) -> list[list[float]]:
+        return [self.embed(t) for t in texts]
+
+
+# Synthetic corpus — 25 entries covering English + CJK + mixed signals
+_CORPUS: list[dict[str, str]] = [
+    {"id": "e01", "content": "quokkamode rollout mitigation log"},
+    {"id": "e02", "content": "rollout playbook for tomorrow morning"},
+    {"id": "e03", "content": "restore recovery list after embed failures"},
+    {"id": "e04", "content": "release calendar for next month"},
+    {"id": "e05", "content": "daily standup reminders"},
+    {"id": "e06", "content": "hybrid ranking marker entry"},
+    {"id": "e07", "content": "combined retrieval ranking strategy notes"},
+    {"id": "e08", "content": "hybrid combined retrieval ranking strategy details"},
+    {"id": "e09", "content": "sqlite WAL corruption incident 2026-04-20"},
+    {"id": "e10", "content": "database migration playbook for schema changes"},
+    {"id": "e11", "content": "embedder timeout stall during reindex"},
+    {"id": "e12", "content": "tailscale ACL setup for admin endpoints"},
+    {"id": "e13", "content": "bearer token auth shim ADR 0007"},
+    {"id": "e14", "content": "council review session for memhall reliability"},
+    {"id": "e15", "content": "benchmark metric evaluation MRR nDCG"},
+    {"id": "e16", "content": "撞牆 incident 記錄 — embedder 逾時"},
+    {"id": "e17", "content": "記憶 大廳 部署 到 mac mini"},
+    {"id": "e18", "content": "資料庫 結構 變更 計畫"},
+    {"id": "e19", "content": "驗證 token 旋轉 流程"},
+    {"id": "e20", "content": "混合 排序 策略 評估"},
+    {"id": "e21", "content": "ollama bge-m3 embedder configuration"},
+    {"id": "e22", "content": "phase A reliability patches summary"},
+    {"id": "e23", "content": "phase B admin gate proposal"},
+    {"id": "e24", "content": "RRF reciprocal rank fusion default"},
+    {"id": "e25", "content": "weighted linear alpha tuning experiments"},
+]
+
+# Hand-labeled queries — relevance based on intent, not just lexical overlap
+_QUERIES: list[dict[str, Any]] = [
+    {
+        "query": "quokkamode rollout",
+        "relevant_ids": ["e01"],
+        "kind": "rare_lexical",
+    },
+    {
+        "query": "resurface checklist",
+        "relevant_ids": ["e03"],
+        "kind": "pure_semantic",
+    },
+    {
+        "query": "hybrid ranking",
+        "relevant_ids": ["e08", "e07", "e06"],
+        "kind": "mixed",
+    },
+    {
+        "query": "deploy plan",
+        "relevant_ids": ["e02", "e04"],
+        "kind": "semantic_paraphrase",
+    },
+    {
+        "query": "WAL corruption sqlite",
+        "relevant_ids": ["e09"],
+        "kind": "rare_lexical",
+    },
+    {
+        "query": "schema migration",
+        "relevant_ids": ["e10", "e18"],
+        "kind": "mixed_cjk",
+    },
+    {
+        "query": "embedder hang",
+        "relevant_ids": ["e11", "e16"],
+        "kind": "semantic_paraphrase",
+    },
+    {
+        "query": "admin endpoint auth",
+        "relevant_ids": ["e12", "e13", "e23"],
+        "kind": "mixed",
+    },
+    {
+        "query": "撞牆",
+        "relevant_ids": ["e16"],
+        "kind": "cjk_short",
+    },
+    {
+        "query": "資料庫 結構",
+        "relevant_ids": ["e18", "e10"],
+        "kind": "cjk_mixed",
+    },
+    {
+        "query": "混合排序",
+        "relevant_ids": ["e20", "e08", "e07", "e06"],
+        "kind": "cjk_semantic",
+    },
+    {
+        "query": "RRF fusion",
+        "relevant_ids": ["e24"],
"rare_lexical", + }, + { + "query": "alpha tuning", + "relevant_ids": ["e25"], + "kind": "rare_lexical", + }, + { + "query": "incident outage", + "relevant_ids": ["e09", "e16"], + "kind": "pure_semantic", + }, + { + "query": "benchmark evaluation", + "relevant_ids": ["e15"], + "kind": "mixed", + }, +] + + +# ---------- Metrics ------------------------------------------------------- + + +def reciprocal_rank(ranked_ids: list[str], relevant: set[str]) -> float: + for i, eid in enumerate(ranked_ids, start=1): + if eid in relevant: + return 1.0 / i + return 0.0 + + +def recall_at_k(ranked_ids: list[str], relevant: set[str], k: int) -> float: + if not relevant: + return 0.0 + hits = sum(1 for eid in ranked_ids[:k] if eid in relevant) + return hits / len(relevant) + + +def ndcg_at_k(ranked_ids: list[str], relevant: set[str], k: int) -> float: + if not relevant: + return 0.0 + dcg = 0.0 + for i, eid in enumerate(ranked_ids[:k], start=1): + if eid in relevant: + dcg += 1.0 / math.log2(i + 1) + ideal_hits = min(len(relevant), k) + idcg = sum(1.0 / math.log2(i + 1) for i in range(1, ideal_hits + 1)) + return dcg / idcg if idcg > 0 else 0.0 + + +# ---------- Bench runner -------------------------------------------------- + + +@dataclass +class ModeResult: + label: str + mrr: float + recall_at_5: float + ndcg_at_10: float + per_query: list[dict[str, Any]] + + +async def _seed_corpus(client: httpx.AsyncClient, corpus: list[dict[str, str]]) -> dict[str, str]: + """Write corpus, return content -> entry_id mapping.""" + mapping: dict[str, str] = {} + for item in corpus: + resp = await client.post( + "/v1/memory/write", + json={ + "agent_id": "bench", + "namespace": "shared", + "type": "note", + "content": item["content"], + }, + ) + resp.raise_for_status() + mapping[item["id"]] = resp.json()["entry_id"] + return mapping + + +async def _run_queries( + client: httpx.AsyncClient, + queries: list[dict[str, Any]], + id_map: dict[str, str], + label: str, +) -> ModeResult: + rrs: list[float] = [] + recalls: list[float] = [] + ndcgs: list[float] = [] + per_query: list[dict[str, Any]] = [] + + for q in queries: + resp = await client.post( + "/v1/memory/search", + json={"query": q["query"], "limit": 10, "mode": "hybrid"}, + ) + resp.raise_for_status() + results = resp.json()["results"] + ranked_real_ids = [r["entry"]["entry_id"] for r in results] + relevant = {id_map[eid] for eid in q["relevant_ids"] if eid in id_map} + + rr = reciprocal_rank(ranked_real_ids, relevant) + r5 = recall_at_k(ranked_real_ids, relevant, 5) + ndcg = ndcg_at_k(ranked_real_ids, relevant, 10) + + rrs.append(rr) + recalls.append(r5) + ndcgs.append(ndcg) + per_query.append( + { + "query": q["query"], + "kind": q.get("kind", ""), + "rr": rr, + "recall_at_5": r5, + "ndcg_at_10": ndcg, + } + ) + + return ModeResult( + label=label, + mrr=sum(rrs) / len(rrs) if rrs else 0.0, + recall_at_5=sum(recalls) / len(recalls) if recalls else 0.0, + ndcg_at_10=sum(ndcgs) / len(ndcgs) if ndcgs else 0.0, + per_query=per_query, + ) + + +def _build_app(hybrid_mode: str, alpha: float, tmp_path: Path) -> Any: + settings = Settings( + database_path=tmp_path / "bench.sqlite3", + vector_database_path=tmp_path / "bench-vectors.sqlite3", + vector_dim=_VECTOR_DIM, + embed_dim=_VECTOR_DIM, + hybrid_mode=cast(Literal["weighted_linear", "rrf"], hybrid_mode), + hybrid_alpha=alpha, + request_timeout_s=2.0, + health_embed_timeout_s=1.0, + api_token=None, + ) + return create_app(settings=settings, embedder=SynonymEmbedder()) + + +async def _run_synthetic_one( + label: str, + 
+    hybrid_mode: str,
+    alpha: float,
+    queries: list[dict[str, Any]],
+    tmp_dir: Path,
+) -> ModeResult:
+    sub = tmp_dir / label.replace("=", "_").replace(" ", "_")
+    sub.mkdir(parents=True, exist_ok=True)
+    app = _build_app(hybrid_mode, alpha, sub)
+    async with client_for_app(app) as client:
+        id_map = await _seed_corpus(client, _CORPUS)
+        return await _run_queries(client, queries, id_map, label)
+
+
+async def run_synthetic(alphas: Iterable[float]) -> list[ModeResult]:
+    import tempfile
+
+    results: list[ModeResult] = []
+    with tempfile.TemporaryDirectory() as raw_tmp:
+        tmp_dir = Path(raw_tmp)
+        results.append(
+            await _run_synthetic_one("rrf", "rrf", 0.0, _QUERIES, tmp_dir)
+        )
+        for alpha in alphas:
+            label = f"weighted_linear(α={alpha})"
+            results.append(
+                await _run_synthetic_one(label, "weighted_linear", alpha, _QUERIES, tmp_dir)
+            )
+    return results
+
+
+# ---------- Real-corpus mode ----------------------------------------------
+
+
+async def run_real_corpus(
+    base_url: str,
+    queries_file: Path,
+    api_token: str | None,
+) -> list[ModeResult]:
+    """Real-corpus mode: alpha / mode must be set on the running server before
+    invocation. This runner only fires queries; switch server config + re-run
+    to compare modes."""
+    queries = [json.loads(line) for line in queries_file.read_text().splitlines() if line.strip()]
+    headers = {"Authorization": f"Bearer {api_token}"} if api_token else {}
+    real_id_map = {eid: eid for q in queries for eid in q["relevant_ids"]}
+
+    async with httpx.AsyncClient(base_url=base_url, headers=headers, timeout=10.0) as client:
+        return [
+            await _run_queries(
+                client,
+                queries,
+                real_id_map,
+                f"server-config@{base_url}",
+            )
+        ]
+
+
+# ---------- Reporting -----------------------------------------------------
+
+
+def print_summary(results: list[ModeResult]) -> None:
+    print("\n=== Summary (higher is better) ===")
+    print(f"{'mode':<28} {'MRR':>8} {'R@5':>8} {'nDCG@10':>10}")
+    print("-" * 56)
+    for r in results:
+        print(f"{r.label:<28} {r.mrr:>8.4f} {r.recall_at_5:>8.4f} {r.ndcg_at_10:>10.4f}")
+
+    best_mrr = max(results, key=lambda r: r.mrr)
+    best_recall = max(results, key=lambda r: r.recall_at_5)
+    best_ndcg = max(results, key=lambda r: r.ndcg_at_10)
+    print()
+    print(f"best MRR: {best_mrr.label}")
+    print(f"best R@5: {best_recall.label}")
+    print(f"best nDCG@10: {best_ndcg.label}")
+
+
+def print_per_query_diffs(results: list[ModeResult]) -> None:
+    print("\n=== Per-query reciprocal rank ===")
+    rrf = next((r for r in results if r.label == "rrf"), None)
+    if rrf is None:
+        return
+    queries = [q["query"] for q in rrf.per_query]
+    header = f"{'query':<30} " + " ".join(f"{r.label[:14]:>14}" for r in results)
+    print(header)
+    print("-" * len(header))
+    for i, query in enumerate(queries):
+        row = f"{query[:30]:<30} "
+        row += " ".join(f"{r.per_query[i]['rr']:>14.3f}" for r in results)
+        print(row)
+
+
+def main() -> int:
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--corpus",
+        type=Path,
+        help="Path to jsonl query file (real-corpus mode). Omit for synthetic.",
+    )
+    parser.add_argument("--base-url", default="http://localhost:9100")
+    parser.add_argument("--api-token", default=None)
+    parser.add_argument(
+        "--alpha",
+        type=float,
+        action="append",
+        help="α value(s) to sweep. Default: 0.1, 0.3, 0.5, 0.7, 0.9.",
+    )
+    args = parser.parse_args()
+
+    alphas = args.alpha or [0.1, 0.3, 0.5, 0.7, 0.9]
+
+    if args.corpus:
+        if not args.corpus.exists():
+            print(f"corpus file not found: {args.corpus}", file=sys.stderr)
+            return 2
+        results = asyncio.run(
+            run_real_corpus(args.base_url, args.corpus, args.api_token)
+        )
+    else:
+        print("Running synthetic benchmark (directional only — confirm with real corpus).")
+        results = asyncio.run(run_synthetic(alphas))
+
+    print_summary(results)
+    if not args.corpus:
+        print_per_query_diffs(results)
+    return 0
+
+
+if __name__ == "__main__":
+    sys.exit(main())
diff --git a/src/memory_hall/cli/main.py b/src/memory_hall/cli/main.py
index a3ce761..0045da3 100644
--- a/src/memory_hall/cli/main.py
+++ b/src/memory_hall/cli/main.py
@@ -29,7 +29,17 @@ def _settings() -> Settings:
 
 
 def _client(base_url: str, timeout_s: float) -> httpx.Client:
-    return httpx.Client(base_url=base_url.rstrip("/"), timeout=timeout_s)
+    settings = _settings()
+    headers = (
+        {"Authorization": f"Bearer {settings.api_token}"}
+        if settings.api_token
+        else None
+    )
+    return httpx.Client(
+        base_url=base_url.rstrip("/"),
+        timeout=timeout_s,
+        headers=headers,
+    )
 
 
 def _parse_metadata(value: str | None) -> dict[str, Any]:
diff --git a/src/memory_hall/config.py b/src/memory_hall/config.py
index 9766f36..68b3ae7 100644
--- a/src/memory_hall/config.py
+++ b/src/memory_hall/config.py
@@ -30,11 +30,15 @@ class Settings(BaseSettings):
     vector_dim: int = 1024
     default_tenant_id: str = "default"
     api_token: str | None = None
+    admin_token: str | None = None
     list_default_limit: int = 50
     search_default_limit: int = 20
     search_candidate_multiplier: int = 5
+    hybrid_mode: Literal["weighted_linear", "rrf"] = "rrf"
+    hybrid_alpha: float = Field(default=0.3, ge=0.0, le=1.0)
     request_timeout_s: float = 5.0
     reindex_batch_size: int = 500
+    wal_checkpoint_interval_s: float = 300.0
     max_content_bytes: int = 64 * 1024
 
     @model_validator(mode="before")
@@ -58,6 +62,14 @@ def _set_default_embed_dim(self) -> Settings:
             self.embed_dim = self.vector_dim
         return self
 
+    @model_validator(mode="after")
+    def _validate_auth_tokens(self) -> Settings:
+        if self.admin_token and not self.api_token:
+            raise ValueError("admin_token requires api_token (would fail-open on non-admin paths)")
+        if self.admin_token and self.api_token and self.admin_token == self.api_token:
+            raise ValueError("admin_token must differ from api_token")
+        return self
+
     def prepare_paths(self) -> None:
         self.database_path.parent.mkdir(parents=True, exist_ok=True)
         self.vector_database_path.parent.mkdir(parents=True, exist_ok=True)
diff --git a/src/memory_hall/models.py b/src/memory_hall/models.py
index d7c6783..bb49274 100644
--- a/src/memory_hall/models.py
+++ b/src/memory_hall/models.py
@@ -16,6 +16,7 @@ SYNC_FAILED = "failed"
 SyncStatus = Literal["pending", "embedded", "failed"]
 SemanticStatus = Literal["ok", "timeout", "embedder_error", "not_attempted"]
+HybridMode = Literal["weighted_linear", "rrf"]
 
 _ULID_ALPHABET = "0123456789ABCDEFGHJKMNPQRSTVWXYZ"
@@ -182,6 +183,8 @@ class ScoreBreakdown(BaseModel):
     bm25: float
     semantic: float
     rrf: float
+    hybrid_mode: HybridMode = "weighted_linear"
+    alpha: float = 0.3
     semantic_status: SemanticStatus = "not_attempted"
diff --git a/src/memory_hall/server/app.py b/src/memory_hall/server/app.py
index 3063ba3..ee0d17b 100644
--- a/src/memory_hall/server/app.py
+++ b/src/memory_hall/server/app.py
@@ -4,6 +4,7 @@
 import hmac
 import json
 import logging
+import math
 import random
 import re
 from contextlib import asynccontextmanager, suppress
@@ -30,10 +31,10 @@
     ListEntriesResponse,
     ReindexResponse,
     ScoreBreakdown,
-    SemanticStatus,
     SearchMemoryRequest,
     SearchMemoryResponse,
     SearchResultItem,
+    SemanticStatus,
     WriteMemoryRequest,
     WriteOutcome,
     build_content_hash,
@@ -53,6 +54,8 @@
 _HEALTH_PROBE_INTERVAL_S = 30.0
 _HEALTH_CACHE_TTL_S = 60.0
 _REINDEX_EMBED_BATCH_SIZE = 16
+_REINDEX_SCAN_PAGE_SIZE = 200
+_WAL_CHECKPOINT_MODE = "TRUNCATE"
 _EMBED_FAILURE_LIMIT = 5
 _MAX_EMBED_ERROR_LENGTH = 500
@@ -99,10 +102,12 @@ def __init__(
         self._worker: asyncio.Task[None] | None = None
         self._reindex_worker: asyncio.Task[None] | None = None
         self._health_probe_worker: asyncio.Task[None] | None = None
+        self._wal_checkpoint_worker: asyncio.Task[None] | None = None
         self._background_reindex_interval_s = _BACKGROUND_REINDEX_INTERVAL_S
         self._background_reindex_jitter_s = min(15.0, _BACKGROUND_REINDEX_INTERVAL_S * 0.1)
         self._health_probe_interval_s = _HEALTH_PROBE_INTERVAL_S
         self._health_cache_ttl_s = _HEALTH_CACHE_TTL_S
+        self._wal_checkpoint_interval_s = self.settings.wal_checkpoint_interval_s
         self._health_cache_checked_at = None
         self._health_last_success_at = None
         self._health_cache = HealthResponse(
@@ -123,8 +128,13 @@ async def start(self) -> None:
         self._worker = asyncio.create_task(self._consume_writes())
         self._reindex_worker = asyncio.create_task(self._run_background_reindex())
         self._health_probe_worker = asyncio.create_task(self._run_health_probe())
+        self._wal_checkpoint_worker = asyncio.create_task(self._run_wal_checkpoint())
 
     async def stop(self) -> None:
+        if self._wal_checkpoint_worker is not None:
+            self._wal_checkpoint_worker.cancel()
+            with suppress(asyncio.CancelledError):
+                await self._wal_checkpoint_worker
         if self._health_probe_worker is not None:
             self._health_probe_worker.cancel()
             with suppress(asyncio.CancelledError):
@@ -250,7 +260,13 @@ async def search_entries(
                     exc,
                 )
 
-        combined = self._combine_hits(payload.query, lexical_hits, semantic_hits, limit)
+        combined = self._combine_hits(
+            query=payload.query,
+            lexical_hits=lexical_hits,
+            semantic_hits=semantic_hits,
+            semantic_status=semantic_status,
+            limit=limit,
+        )
         entry_ids = [item["entry_id"] for item in combined]
         entries = await self.storage.get_entries_by_ids(tenant_id, entry_ids)
         entry_map = {entry.entry_id: entry for entry in entries}
@@ -262,11 +278,14 @@
             results.append(
                 SearchResultItem(
                     entry_id=entry.entry_id,
-                    score=item["rrf"],
+                    score=item["score"],
                     score_breakdown=ScoreBreakdown(
                         bm25=item["bm25"],
                         semantic=item["semantic"],
-                        rrf=item["rrf"],
+                        # Legacy field name preserved for backward compatibility.
+ rrf=item["score"], + hybrid_mode=item["hybrid_mode"], + alpha=item["alpha"], semantic_status=semantic_status, ), entry=EntryDocument.from_entry(entry), @@ -375,6 +394,16 @@ async def _run_health_probe(self) -> None: except Exception as exc: logger.warning("health probe failed: %s", exc) + async def _run_wal_checkpoint(self) -> None: + while True: + await asyncio.sleep(self._wal_checkpoint_interval_s) + try: + await self._checkpoint_wal_databases() + except asyncio.CancelledError: + raise + except Exception as exc: + logger.warning("background WAL checkpoint failed: %s", exc) + async def audit(self) -> AuditResponse: payload = await self.storage.audit() return AuditResponse.model_validate(payload) @@ -495,35 +524,78 @@ async def _handle_link(self, job: LinkJob) -> EntryDocument | None: return EntryDocument.from_entry(entry) async def _handle_reindex(self, job: ReindexJob) -> ReindexResponse: - if job.pending_only: - all_entries = await self.storage.list_pending_entries( - job.tenant_id, - limit=self.settings.reindex_batch_size, - ) - else: - all_entries = await self.storage.list_entries(job.tenant_id, limit=None) - scanned = len(all_entries) - candidates: list[Entry] = [] + sync_status = SYNC_PENDING if job.pending_only else None + total_entries = await self.storage.count_entries( + job.tenant_id, + sync_status=sync_status, + ) + if total_entries == 0: + return ReindexResponse(scanned=0, embedded=0, pending=0) + + total_batches = max(1, math.ceil(total_entries / _REINDEX_SCAN_PAGE_SIZE)) + scanned = 0 embedded_count = 0 pending_count = 0 - for entry in all_entries: - if not job.pending_only: - needs_reindex = entry.sync_status != SYNC_EMBEDDED - if not needs_reindex: - needs_reindex = not await asyncio.to_thread( - self.vector_store.contains, - entry.tenant_id, - entry.entry_id, + cursor: str | None = None + batch_number = 0 + + try: + while scanned < total_entries: + entries = await self.storage.list_entries( + job.tenant_id, + sync_status=sync_status, + limit=_REINDEX_SCAN_PAGE_SIZE, + cursor=cursor, + ) + if not entries: + break + + batch_number += 1 + scanned += len(entries) + candidates: list[Entry] = [] + for entry in entries: + if not job.pending_only: + needs_reindex = entry.sync_status != SYNC_EMBEDDED + if not needs_reindex: + needs_reindex = not await asyncio.to_thread( + self.vector_store.contains, + entry.tenant_id, + entry.entry_id, + ) + if not needs_reindex: + continue + candidates.append(entry) + + for offset in range(0, len(candidates), _REINDEX_EMBED_BATCH_SIZE): + embedded, pending = await self._embed_reindex_batch( + candidates[offset : offset + _REINDEX_EMBED_BATCH_SIZE] ) - if not needs_reindex: - continue - candidates.append(entry) - for offset in range(0, len(candidates), _REINDEX_EMBED_BATCH_SIZE): - embedded, pending = await self._embed_reindex_batch( - candidates[offset : offset + _REINDEX_EMBED_BATCH_SIZE] + embedded_count += embedded + pending_count += pending + + logger.info( + "reindex batch %s/%s, %s done tenant_id=%s pending_only=%s", + batch_number, + total_batches, + scanned, + job.tenant_id, + job.pending_only, + ) + if scanned >= total_entries: + break + tail = entries[-1] + cursor = encode_cursor(tail.created_at, tail.entry_id) + except asyncio.CancelledError: + logger.info( + "reindex cancelled tenant_id=%s batches=%s scanned=%s embedded=%s pending=%s", + job.tenant_id, + batch_number, + scanned, + embedded_count, + pending_count, ) - embedded_count += embedded - pending_count += pending + raise + return ReindexResponse(scanned=scanned, 
embedded=embedded_count, pending=pending_count) async def _embed_reindex_batch(self, entries: list[Entry]) -> tuple[int, int]: @@ -670,13 +742,94 @@ def _record_health_error(self, component: str, exc: Exception) -> str: ) return message[:_MAX_EMBED_ERROR_LENGTH] + async def _checkpoint_wal_databases(self) -> None: + busy, log_frames, checkpointed = await self.storage.checkpoint_wal( + mode=_WAL_CHECKPOINT_MODE + ) + logger.info( + "WAL checkpoint completed: busy=%s log=%s ckpt=%s db=%s", + busy, + log_frames, + checkpointed, + self.settings.database_path, + ) + busy, log_frames, checkpointed = await asyncio.to_thread( + self.vector_store.checkpoint_wal, + mode=_WAL_CHECKPOINT_MODE, + ) + logger.info( + "WAL checkpoint completed: busy=%s log=%s ckpt=%s db=%s", + busy, + log_frames, + checkpointed, + self.settings.vector_database_path, + ) + def _require_queue(self) -> asyncio.Queue[WriteJob | LinkJob | ReindexJob | None]: if self._queue is None: raise RuntimeError("runtime is not started") return self._queue - @staticmethod def _combine_hits( + self, + *, + query: str, + lexical_hits: list[tuple[str, float]], + semantic_hits: list[tuple[str, float]], + semantic_status: SemanticStatus, + limit: int, + ) -> list[dict[str, Any]]: + if self.settings.hybrid_mode == "rrf": + return self._combine_hits_rrf( + query=query, + lexical_hits=lexical_hits, + semantic_hits=semantic_hits, + limit=limit, + ) + return self._combine_hits_weighted_linear( + lexical_hits=lexical_hits, + semantic_hits=semantic_hits, + semantic_status=semantic_status, + limit=limit, + ) + + def _combine_hits_weighted_linear( + self, + *, + lexical_hits: list[tuple[str, float]], + semantic_hits: list[tuple[str, float]], + semantic_status: SemanticStatus, + limit: int, + ) -> list[dict[str, Any]]: + if semantic_status != "ok": + alpha = 1.0 + elif not lexical_hits: + alpha = 0.0 + else: + alpha = self.settings.hybrid_alpha + + lexical_map = {entry_id: score for entry_id, score in lexical_hits} + semantic_map = {entry_id: score for entry_id, score in semantic_hits} + combined = [] + for entry_id in set(lexical_map) | set(semantic_map): + lexical_score = lexical_map.get(entry_id, 0.0) + semantic_score = semantic_map.get(entry_id, 0.0) + combined.append( + { + "entry_id": entry_id, + "bm25": lexical_score, + "semantic": semantic_score, + "score": (alpha * lexical_score) + ((1.0 - alpha) * semantic_score), + "hybrid_mode": "weighted_linear", + "alpha": alpha, + } + ) + combined.sort(key=lambda item: item["score"], reverse=True) + return combined[:limit] + + @staticmethod + def _combine_hits_rrf( + *, query: str, lexical_hits: list[tuple[str, float]], semantic_hits: list[tuple[str, float]], @@ -687,18 +840,32 @@ def _combine_hits( for rank, (entry_id, score) in enumerate(lexical_hits, start=1): payload = combined.setdefault( entry_id, - {"entry_id": entry_id, "bm25": 0.0, "semantic": 0.0, "rrf": 0.0}, + { + "entry_id": entry_id, + "bm25": 0.0, + "semantic": 0.0, + "score": 0.0, + "hybrid_mode": "rrf", + "alpha": 0.0, + }, ) payload["bm25"] = score - payload["rrf"] += lexical_weight / (_RRF_K + rank) + payload["score"] += lexical_weight / (_RRF_K + rank) for rank, (entry_id, score) in enumerate(semantic_hits, start=1): payload = combined.setdefault( entry_id, - {"entry_id": entry_id, "bm25": 0.0, "semantic": 0.0, "rrf": 0.0}, + { + "entry_id": entry_id, + "bm25": 0.0, + "semantic": 0.0, + "score": 0.0, + "hybrid_mode": "rrf", + "alpha": 0.0, + }, ) payload["semantic"] = score - payload["rrf"] += 1.0 / (_RRF_K + rank) - ranked = 
sorted(combined.values(), key=lambda item: item["rrf"], reverse=True) + payload["score"] += 1.0 / (_RRF_K + rank) + ranked = sorted(combined.values(), key=lambda item: item["score"], reverse=True) return ranked[:limit] @@ -777,14 +944,26 @@ async def lifespan(app: FastAPI): @app.middleware("http") async def require_api_token(request: Request, call_next): - # /v1/health is intentionally public — external uptime monitors and the - # in-image HEALTHCHECK probe it without credentials. - if request.url.path.rstrip("/") == "/v1/health": + # Health probe routes stay public for uptime monitors and container + # orchestrators. + path = request.url.path.rstrip("/") + if path == "/v1/health": return await call_next(request) - # Backward compat: when api_token is unset (None) or empty string - # (docker-compose `${MH_API_TOKEN:-}` expands to "" when host env is - # unset — pydantic reads that as "", not None), auth is disabled. - if not active_settings.api_token: + # /v1/admin/* with explicit admin_token configured requires the + # admin_token; the regular api_token is rejected. When admin_token is + # unset, admin paths fall back to api_token (ADR 0007 backward compat). + is_admin_path = path == "/v1/admin" or path.startswith("/v1/admin/") + if is_admin_path and active_settings.admin_token: + expected = active_settings.admin_token + invalid_msg = "invalid admin token" + elif active_settings.api_token: + expected = active_settings.api_token + invalid_msg = "invalid token" + else: + # Backward compat: when api_token is unset (None) or empty string + # (docker-compose `${MH_API_TOKEN:-}` expands to "" when host env + # is unset — pydantic reads that as "", not None), auth is + # disabled. return await call_next(request) header = request.headers.get("authorization", "") prefix = "Bearer " @@ -794,10 +973,10 @@ async def require_api_token(request: Request, call_next): content={"detail": "missing bearer token"}, ) received = header[len(prefix):] - if not hmac.compare_digest(received, active_settings.api_token): + if not hmac.compare_digest(received, expected): return JSONResponse( status_code=401, - content={"detail": "invalid token"}, + content={"detail": invalid_msg}, ) return await call_next(request) diff --git a/src/memory_hall/storage/interface.py b/src/memory_hall/storage/interface.py index d88490e..d64e9eb 100644 --- a/src/memory_hall/storage/interface.py +++ b/src/memory_hall/storage/interface.py @@ -40,12 +40,20 @@ async def list_entries( agent_id: str | None = None, types: list[str] | None = None, tags: list[str] | None = None, + sync_status: str | None = None, since: datetime | None = None, until: datetime | None = None, limit: int | None = None, cursor: str | None = None, ) -> list[Entry]: ... + async def count_entries( + self, + tenant_id: str, + *, + sync_status: str | None = None, + ) -> int: ... + async def search_lexical( self, tenant_id: str, @@ -75,4 +83,6 @@ async def get_references_out(self, tenant_id: str, entry_id: str) -> list[Entry] async def get_references_in(self, tenant_id: str, entry_id: str) -> list[Entry]: ... + async def checkpoint_wal(self, *, mode: str = "TRUNCATE") -> tuple[int, int, int]: ... + async def audit(self) -> dict[str, object]: ... 
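For context on the `checkpoint_wal` contract added to the storage interface above: the `(busy, log, checkpointed)` tuple is SQLite's own `PRAGMA wal_checkpoint` result row, and both store implementations below escalate PASSIVE → RESTART → TRUNCATE when the cheap pass reports contention. A minimal standalone sketch of that escalation, assuming a plain `sqlite3` connection and a hypothetical `db_path` (not the stores' actual code):

```python
import sqlite3


def checkpoint_escalating(db_path: str, final_mode: str = "TRUNCATE") -> tuple[int, int, int]:
    """PRAGMA wal_checkpoint(MODE) returns one row of three ints:
    (busy, log_frames, checkpointed_frames); busy=1 means concurrent
    readers or writers kept the checkpoint from completing."""
    conn = sqlite3.connect(db_path)
    try:
        def run(mode: str) -> tuple[int, int, int]:
            row = conn.execute(f"PRAGMA wal_checkpoint({mode});").fetchone()
            return int(row[0]), int(row[1]), int(row[2])

        busy, _, _ = run("PASSIVE")  # never blocks other connections
        if busy:
            run("RESTART")  # waits out readers, then restarts the log
        return run(final_mode)  # TRUNCATE also shrinks the -wal file to zero bytes
    finally:
        conn.close()
```

PASSIVE goes first because it never stalls concurrent readers; TRUNCATE is what actually reclaims the `-wal` file on disk, which is the behavior the new smoke test asserts by comparing WAL file sizes before and after the checkpoint.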
diff --git a/src/memory_hall/storage/sqlite_store.py b/src/memory_hall/storage/sqlite_store.py index b6fec02..ec86e47 100644 --- a/src/memory_hall/storage/sqlite_store.py +++ b/src/memory_hall/storage/sqlite_store.py @@ -3,7 +3,9 @@ from __future__ import annotations import json +import logging import sqlite3 +from collections.abc import Awaitable, Callable from contextlib import asynccontextmanager from datetime import datetime from pathlib import Path @@ -13,13 +15,37 @@ from memory_hall.models import Entry, InsertOutcome, SearchCandidate, decode_cursor, dump_json +_TRANSIENT_OPERATIONAL_ERROR_MARKERS = ( + "disk i/o error", + "database is locked", + "database table is locked", + "database is busy", +) + +logger = logging.getLogger(__name__) + class SqliteStore: def __init__(self, database_path: Path) -> None: self.database_path = database_path self._writer_connection: aiosqlite.Connection | None = None + self._writer_lock = None + self._reader_state = None + self._active_readers = 0 + self._checkpoint_requested = False + + def _ensure_runtime_primitives(self) -> None: + if self._writer_lock is None: + import asyncio + + self._writer_lock = asyncio.Lock() + if self._reader_state is None: + import asyncio + + self._reader_state = asyncio.Condition() async def open(self) -> None: + self._ensure_runtime_primitives() self.database_path.parent.mkdir(parents=True, exist_ok=True) await self._open_writer_connection() async with self._read_connection() as connection: @@ -31,66 +57,70 @@ async def close(self) -> None: self._writer_connection = None async def healthcheck(self) -> None: - async with self._read_connection() as connection: + async def operation(connection: aiosqlite.Connection) -> None: await connection.execute("SELECT 1") + await self._run_read_operation(operation) + async def insert_entry(self, entry: Entry) -> InsertOutcome: - connection = await self._require_writer_connection() - await connection.execute("BEGIN IMMEDIATE") - try: - await connection.execute( - """ - INSERT INTO entries ( - entry_id, tenant_id, agent_id, namespace, type, content, content_hash, - summary, tags_json, references_json, metadata_json, sync_status, - last_embedded_at, last_embed_error, last_embed_attempted_at, - embed_attempt_count, created_at, created_by_principal + async def operation(connection: aiosqlite.Connection) -> InsertOutcome: + await connection.execute("BEGIN IMMEDIATE") + try: + await connection.execute( + """ + INSERT INTO entries ( + entry_id, tenant_id, agent_id, namespace, type, content, content_hash, + summary, tags_json, references_json, metadata_json, sync_status, + last_embedded_at, last_embed_error, last_embed_attempted_at, + embed_attempt_count, created_at, created_by_principal + ) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + entry.entry_id, + entry.tenant_id, + entry.agent_id, + entry.namespace, + entry.type, + entry.content, + entry.content_hash, + entry.summary, + dump_json(entry.tags), + dump_json(entry.references), + dump_json(entry.metadata), + entry.sync_status, + entry.last_embedded_at.isoformat() if entry.last_embedded_at else None, + entry.last_embed_error, + entry.last_embed_attempted_at.isoformat() + if entry.last_embed_attempted_at + else None, + entry.embed_attempt_count, + entry.created_at.isoformat(), + entry.created_by_principal, + ), ) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
- """, - ( - entry.entry_id, - entry.tenant_id, - entry.agent_id, - entry.namespace, - entry.type, - entry.content, - entry.content_hash, - entry.summary, - dump_json(entry.tags), - dump_json(entry.references), - dump_json(entry.metadata), - entry.sync_status, - entry.last_embedded_at.isoformat() if entry.last_embedded_at else None, - entry.last_embed_error, - entry.last_embed_attempted_at.isoformat() - if entry.last_embed_attempted_at - else None, - entry.embed_attempt_count, - entry.created_at.isoformat(), - entry.created_by_principal, - ), - ) - await connection.execute( - """ - INSERT INTO entries_fts (entry_id, tenant_id, content, summary, tags) - VALUES (?, ?, ?, ?, ?) - """, - (entry.entry_id, entry.tenant_id, *self._build_fts_document(entry)), - ) - await connection.commit() - return InsertOutcome(entry=entry, created=True) - except sqlite3.IntegrityError as exc: - await connection.rollback() - if ( - "entries.tenant_id, entries.content_hash" not in str(exc) - and "UNIQUE" not in str(exc) - ): - raise - existing = await self.get_entry_by_hash(entry.tenant_id, entry.content_hash) - if existing is None: - raise - return InsertOutcome(entry=existing, created=False) + await connection.execute( + """ + INSERT INTO entries_fts (entry_id, tenant_id, content, summary, tags) + VALUES (?, ?, ?, ?, ?) + """, + (entry.entry_id, entry.tenant_id, *self._build_fts_document(entry)), + ) + await connection.commit() + return InsertOutcome(entry=entry, created=True) + except sqlite3.IntegrityError as exc: + await connection.rollback() + if ( + "entries.tenant_id, entries.content_hash" not in str(exc) + and "UNIQUE" not in str(exc) + ): + raise + existing = await self.get_entry_by_hash(entry.tenant_id, entry.content_hash) + if existing is None: + raise + return InsertOutcome(entry=existing, created=False) + + return await self._run_writer_operation(operation) async def update_sync_status( self, @@ -102,34 +132,42 @@ async def update_sync_status( last_embed_attempted_at: datetime | None, embed_attempt_count: int, ) -> Entry | None: - connection = await self._require_writer_connection() - await connection.execute("BEGIN IMMEDIATE") - await connection.execute( - """ - UPDATE entries - SET - sync_status = ?, - last_embedded_at = ?, - last_embed_error = ?, - last_embed_attempted_at = ?, - embed_attempt_count = ? - WHERE tenant_id = ? AND entry_id = ? - """, - ( - sync_status, - last_embedded_at.isoformat() if last_embedded_at else None, - last_embed_error, - last_embed_attempted_at.isoformat() if last_embed_attempted_at else None, - embed_attempt_count, - tenant_id, - entry_id, - ), - ) - await connection.commit() + async def operation(connection: aiosqlite.Connection) -> None: + await connection.execute("BEGIN IMMEDIATE") + try: + await connection.execute( + """ + UPDATE entries + SET + sync_status = ?, + last_embedded_at = ?, + last_embed_error = ?, + last_embed_attempted_at = ?, + embed_attempt_count = ? + WHERE tenant_id = ? AND entry_id = ? 
+ """, + ( + sync_status, + last_embedded_at.isoformat() if last_embedded_at else None, + last_embed_error, + last_embed_attempted_at.isoformat() + if last_embed_attempted_at + else None, + embed_attempt_count, + tenant_id, + entry_id, + ), + ) + await connection.commit() + except Exception: + await connection.rollback() + raise + + await self._run_writer_operation(operation) return await self.get_entry(tenant_id, entry_id) async def get_entry(self, tenant_id: str, entry_id: str) -> Entry | None: - async with self._read_connection() as connection: + async def operation(connection: aiosqlite.Connection) -> Entry | None: cursor = await connection.execute( "SELECT * FROM entries WHERE tenant_id = ? AND entry_id = ?", (tenant_id, entry_id), @@ -137,8 +175,10 @@ async def get_entry(self, tenant_id: str, entry_id: str) -> Entry | None: row = await cursor.fetchone() return self._row_to_entry(row) if row else None + return await self._run_read_operation(operation) + async def get_entry_by_hash(self, tenant_id: str, content_hash: str) -> Entry | None: - async with self._read_connection() as connection: + async def operation(connection: aiosqlite.Connection) -> Entry | None: cursor = await connection.execute( "SELECT * FROM entries WHERE tenant_id = ? AND content_hash = ?", (tenant_id, content_hash), @@ -146,11 +186,14 @@ async def get_entry_by_hash(self, tenant_id: str, content_hash: str) -> Entry | row = await cursor.fetchone() return self._row_to_entry(row) if row else None + return await self._run_read_operation(operation) + async def get_entries_by_ids(self, tenant_id: str, entry_ids: list[str]) -> list[Entry]: if not entry_ids: return [] placeholders = ",".join("?" for _ in entry_ids) - async with self._read_connection() as connection: + + async def operation(connection: aiosqlite.Connection) -> list[Entry]: cursor = await connection.execute( f""" SELECT * FROM entries @@ -159,8 +202,10 @@ async def get_entries_by_ids(self, tenant_id: str, entry_ids: list[str]) -> list (tenant_id, *entry_ids), ) rows = await cursor.fetchall() - mapping = {row["entry_id"]: self._row_to_entry(row) for row in rows} - return [mapping[entry_id] for entry_id in entry_ids if entry_id in mapping] + mapping = {row["entry_id"]: self._row_to_entry(row) for row in rows} + return [mapping[entry_id] for entry_id in entry_ids if entry_id in mapping] + + return await self._run_read_operation(operation) async def list_entries( self, @@ -170,6 +215,7 @@ async def list_entries( agent_id: str | None = None, types: list[str] | None = None, tags: list[str] | None = None, + sync_status: str | None = None, since: datetime | None = None, until: datetime | None = None, limit: int | None = None, @@ -185,6 +231,7 @@ async def list_entries( agent_id=agent_id, types=types, tags=tags, + sync_status=sync_status, since=since, until=until, cursor=cursor, @@ -194,10 +241,41 @@ async def list_entries( if limit is not None: sql += " LIMIT ?" 
params.append(limit) - async with self._read_connection() as connection: + async def operation(connection: aiosqlite.Connection) -> list[Entry]: cursor_obj = await connection.execute(sql, params) rows = await cursor_obj.fetchall() - return [self._row_to_entry(row) for row in rows] + return [self._row_to_entry(row) for row in rows] + + return await self._run_read_operation(operation) + + async def count_entries( + self, + tenant_id: str, + *, + sync_status: str | None = None, + ) -> int: + conditions = ["tenant_id = ?"] + params: list[Any] = [tenant_id] + self._apply_common_filters( + conditions=conditions, + params=params, + alias="entries", + namespaces=None, + agent_id=None, + types=None, + tags=None, + sync_status=sync_status, + since=None, + until=None, + cursor=None, + ) + sql = "SELECT COUNT(*) FROM entries WHERE " + " AND ".join(conditions) + async def operation(connection: aiosqlite.Connection) -> int: + cursor_obj = await connection.execute(sql, params) + row = await cursor_obj.fetchone() + return int(row[0] if row else 0) + + return await self._run_read_operation(operation) async def search_lexical( self, @@ -220,6 +298,7 @@ async def search_lexical( agent_id=agent_id, types=types, tags=tags, + sync_status=None, since=None, until=None, cursor=None, @@ -240,13 +319,18 @@ async def search_lexical( """ sql += " AND ".join(conditions) sql += " ORDER BY bm25_score LIMIT ?" - async with self._read_connection() as connection: + async def operation(connection: aiosqlite.Connection) -> list[SearchCandidate]: cursor_obj = await connection.execute(sql, params) rows = await cursor_obj.fetchall() - return [ - SearchCandidate(entry_id=row["entry_id"], score=self._normalize_bm25(row["bm25_score"])) - for row in rows - ] + return [ + SearchCandidate( + entry_id=row["entry_id"], + score=self._normalize_bm25(row["bm25_score"]), + ) + for row in rows + ] + + return await self._run_read_operation(operation) async def add_reference( self, @@ -261,17 +345,24 @@ async def add_reference( references = list(source.references) if target_entry_id not in references: references.append(target_entry_id) - connection = await self._require_writer_connection() - await connection.execute("BEGIN IMMEDIATE") - await connection.execute( - """ - UPDATE entries - SET references_json = ? - WHERE tenant_id = ? AND entry_id = ? - """, - (dump_json(references), tenant_id, source_entry_id), - ) - await connection.commit() + + async def operation(connection: aiosqlite.Connection) -> None: + await connection.execute("BEGIN IMMEDIATE") + try: + await connection.execute( + """ + UPDATE entries + SET references_json = ? + WHERE tenant_id = ? AND entry_id = ? + """, + (dump_json(references), tenant_id, source_entry_id), + ) + await connection.commit() + except Exception: + await connection.rollback() + raise + + await self._run_writer_operation(operation) return await self.get_entry(tenant_id, source_entry_id) async def list_pending_entries(self, tenant_id: str, limit: int | None = None) -> list[Entry]: @@ -281,13 +372,16 @@ async def list_pending_entries(self, tenant_id: str, limit: int | None = None) - if limit is not None: sql += " LIMIT ?" 
params.append(limit) - async with self._read_connection() as connection: + + async def operation(connection: aiosqlite.Connection) -> list[Entry]: cursor = await connection.execute(sql, params) rows = await cursor.fetchall() - return [self._row_to_entry(row) for row in rows] + return [self._row_to_entry(row) for row in rows] + + return await self._run_read_operation(operation) async def list_tenant_ids(self) -> list[str]: - async with self._read_connection() as connection: + async def operation(connection: aiosqlite.Connection) -> list[str]: cursor = await connection.execute( """ SELECT DISTINCT tenant_id @@ -296,10 +390,12 @@ async def list_tenant_ids(self) -> list[str]: """ ) rows = await cursor.fetchall() - return [row["tenant_id"] for row in rows] + return [row["tenant_id"] for row in rows] + + return await self._run_read_operation(operation) async def get_references_out(self, tenant_id: str, entry_id: str) -> list[Entry]: - async with self._read_connection() as connection: + async def operation(connection: aiosqlite.Connection) -> list[Entry]: cursor = await connection.execute( """ SELECT e.* @@ -321,10 +417,12 @@ async def get_references_out(self, tenant_id: str, entry_id: str) -> list[Entry] (tenant_id, entry_id, tenant_id), ) rows = await cursor.fetchall() - return [self._row_to_entry(row) for row in rows] + return [self._row_to_entry(row) for row in rows] + + return await self._run_read_operation(operation) async def get_references_in(self, tenant_id: str, entry_id: str) -> list[Entry]: - async with self._read_connection() as connection: + async def operation(connection: aiosqlite.Connection) -> list[Entry]: cursor = await connection.execute( """ SELECT DISTINCT e.* @@ -337,10 +435,12 @@ async def get_references_in(self, tenant_id: str, entry_id: str) -> list[Entry]: (entry_id, tenant_id), ) rows = await cursor.fetchall() - return [self._row_to_entry(row) for row in rows] + return [self._row_to_entry(row) for row in rows] + + return await self._run_read_operation(operation) async def audit(self) -> dict[str, object]: - async with self._read_connection() as connection: + async def operation(connection: aiosqlite.Connection) -> dict[str, object]: total_entries = await self._fetch_count(connection, "SELECT COUNT(*) FROM entries") tenant_counts = await self._fetch_key_count( connection, @@ -380,33 +480,48 @@ async def audit(self) -> dict[str, object]: ) """, ) - return { - "total_entries": total_entries, - "tenant_counts": tenant_counts, - "namespace_counts": namespace_counts, - "sync_status_counts": sync_status_counts, - "content_hash_collisions": collisions, - } + return { + "total_entries": total_entries, + "tenant_counts": tenant_counts, + "namespace_counts": namespace_counts, + "sync_status_counts": sync_status_counts, + "content_hash_collisions": collisions, + } + + return await self._run_read_operation(operation) + + async def checkpoint_wal(self, *, mode: str = "TRUNCATE") -> tuple[int, int, int]: + self._ensure_runtime_primitives() + async with self._writer_lock: + async with self._pause_readers(): + connection = await self._open_connection() + try: + busy, log_frames, checkpointed = await self._checkpoint(connection, "PASSIVE") + if busy > 0: + await self._checkpoint(connection, "RESTART") + return await self._checkpoint(connection, mode) + finally: + await connection.close() async def reindex_fts_entries(self, entries: list[Entry]) -> int: if not entries: return 0 - connection = await self._require_writer_connection() - await connection.execute("BEGIN IMMEDIATE") - try: - 
reindexed = 0 - for entry in entries: - reindexed += await self._refresh_fts_row(connection, entry) - await connection.commit() - return reindexed - except Exception: - await connection.rollback() - raise + async def operation(connection: aiosqlite.Connection) -> int: + await connection.execute("BEGIN IMMEDIATE") + try: + reindexed = 0 + for entry in entries: + reindexed += await self._refresh_fts_row(connection, entry) + await connection.commit() + return reindexed + except Exception: + await connection.rollback() + raise + + return await self._run_writer_operation(operation) async def _open_writer_connection(self) -> None: - self._writer_connection = await aiosqlite.connect(self.database_path) - self._writer_connection.row_factory = aiosqlite.Row - await self._apply_pragmas(self._writer_connection) + self._writer_connection = await self._open_connection() await self._create_schema(self._writer_connection) async def _create_schema(self, connection: aiosqlite.Connection) -> None: @@ -496,13 +611,136 @@ async def _require_writer_connection(self) -> aiosqlite.Connection: @asynccontextmanager async def _read_connection(self): - connection = await aiosqlite.connect(self.database_path) + connection = await self._open_connection() try: - await self._apply_pragmas(connection) yield connection finally: await connection.close() + async def _open_connection(self) -> aiosqlite.Connection: + connection = await aiosqlite.connect(self.database_path) + connection.row_factory = aiosqlite.Row + await self._apply_pragmas(connection) + return connection + + async def _run_read_operation( + self, + operation: Callable[[aiosqlite.Connection], Awaitable[Any]], + ) -> Any: + self._ensure_runtime_primitives() + attempts = 0 + while True: + connection: aiosqlite.Connection | None = None + async with self._acquire_reader_slot(): + try: + connection = await self._open_connection() + return await operation(connection) + except sqlite3.OperationalError as exc: + if not self._should_retry_operational_error(exc, attempts): + raise + attempts += 1 + await self._recycle_broken_connection(connection, exc) + connection = None + finally: + if connection is not None: + await connection.close() + + async def _run_writer_operation( + self, + operation: Callable[[aiosqlite.Connection], Awaitable[Any]], + ) -> Any: + self._ensure_runtime_primitives() + attempts = 0 + while True: + async with self._writer_lock: + connection = await self._require_writer_connection() + try: + return await operation(connection) + except sqlite3.OperationalError as exc: + if not self._should_retry_operational_error(exc, attempts): + raise + attempts += 1 + await self._recycle_writer_connection(connection, exc) + + async def _recycle_writer_connection( + self, + connection: aiosqlite.Connection, + exc: sqlite3.OperationalError, + ) -> None: + await self._recycle_broken_connection(connection, exc) + self._writer_connection = None + await self._open_writer_connection() + + async def _recycle_broken_connection( + self, + connection: aiosqlite.Connection | None, + exc: sqlite3.OperationalError, + ) -> None: + if connection is None: + return + logger.warning( + "aiosqlite connection recycled after transient operational error connection_id=%s error=%s", + hex(id(connection)), + exc, + ) + try: + await connection.close() + except Exception as close_exc: # pragma: no cover - best-effort cleanup only + logger.warning( + "aiosqlite connection close failed during recycle connection_id=%s error=%s", + hex(id(connection)), + close_exc, + ) + + @staticmethod + def 
_should_retry_operational_error( + exc: sqlite3.OperationalError, + attempts: int, + ) -> bool: + if attempts >= 1: + return False + message = str(exc).lower() + return any(marker in message for marker in _TRANSIENT_OPERATIONAL_ERROR_MARKERS) + + @asynccontextmanager + async def _acquire_reader_slot(self): + self._ensure_runtime_primitives() + async with self._reader_state: + while self._checkpoint_requested: + await self._reader_state.wait() + self._active_readers += 1 + try: + yield + finally: + async with self._reader_state: + self._active_readers -= 1 + self._reader_state.notify_all() + + @asynccontextmanager + async def _pause_readers(self): + self._ensure_runtime_primitives() + async with self._reader_state: + self._checkpoint_requested = True + while self._active_readers > 0: + await self._reader_state.wait() + try: + yield + finally: + async with self._reader_state: + self._checkpoint_requested = False + self._reader_state.notify_all() + + @staticmethod + async def _checkpoint( + connection: aiosqlite.Connection, + mode: str, + ) -> tuple[int, int, int]: + cursor = await connection.execute(f"PRAGMA wal_checkpoint({mode});") + row = await cursor.fetchone() + if row is None: + raise RuntimeError(f"wal_checkpoint({mode}) returned no result") + return (int(row[0]), int(row[1]), int(row[2])) + @staticmethod async def _apply_pragmas(connection: aiosqlite.Connection) -> None: await connection.execute("PRAGMA journal_mode=WAL;") @@ -548,6 +786,7 @@ def _apply_common_filters( agent_id: str | None, types: list[str] | None, tags: list[str] | None, + sync_status: str | None, since: datetime | None, until: datetime | None, cursor: str | None, @@ -575,6 +814,9 @@ def _apply_common_filters( """ ) params.append(tag) + if sync_status: + conditions.append(f"{alias}.sync_status = ?") + params.append(sync_status) if since: conditions.append(f"{alias}.created_at >= ?") params.append(since.isoformat()) @@ -595,7 +837,7 @@ def _normalize_fts_query(query: str) -> str: @staticmethod def _normalize_bm25(score: float) -> float: - return 1.0 / (1.0 + abs(score)) + return -score / (1.0 - score) @classmethod def _build_fts_document(cls, entry: Entry) -> tuple[str, str, str]: diff --git a/src/memory_hall/storage/vector_store.py b/src/memory_hall/storage/vector_store.py index 2acbae7..30a46d6 100644 --- a/src/memory_hall/storage/vector_store.py +++ b/src/memory_hall/storage/vector_store.py @@ -4,6 +4,7 @@ import logging import math import sqlite3 +import threading from pathlib import Path from typing import Protocol @@ -29,6 +30,8 @@ def contains(self, tenant_id: str, entry_id: str) -> bool: ... def delete(self, tenant_id: str, entry_id: str) -> None: ... + def checkpoint_wal(self, *, mode: str = "TRUNCATE") -> tuple[int, int, int]: ... + class SqliteVecStore: """Vector store backed by sqlite-vec vec0 virtual table when available. 
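Before the next hunk: every `SqliteVecStore` method below gets wrapped in the new `self._connection_lock`, because the store shares one `sqlite3` connection (opened with `check_same_thread=False`) across `asyncio.to_thread` workers, and the stdlib driver requires the caller to serialize cross-thread access once that flag is set. A minimal sketch of the same pattern in isolation, with hypothetical class and table names:

```python
import sqlite3
import threading


class SerializedConn:
    """Serialize all use of one shared sqlite3 connection across threads.

    An RLock (rather than a plain Lock) is the defensive choice: a public
    method that already holds the lock can call another locked method
    without deadlocking on itself.
    """

    def __init__(self, path: str) -> None:
        self._lock = threading.RLock()
        self._conn = sqlite3.connect(path, check_same_thread=False)
        with self._lock:
            self._conn.execute(
                "CREATE TABLE IF NOT EXISTS kv(k TEXT PRIMARY KEY, v TEXT)"
            )

    def upsert(self, key: str, value: str) -> None:
        # Holding the lock across statement + commit keeps the pair atomic
        # with respect to other threads sharing this connection.
        with self._lock:
            self._conn.execute(
                "INSERT INTO kv(k, v) VALUES (?, ?) "
                "ON CONFLICT(k) DO UPDATE SET v = excluded.v",
                (key, value),
            )
            self._conn.commit()
```

Without the lock, two `asyncio.to_thread` callers could interleave statements inside one implicit transaction and commit each other's half-finished work.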
@@ -44,104 +47,126 @@ def __init__(self, database_path: Path, dim: int = 1024) -> None: self.dim = dim self._connection: sqlite3.Connection | None = None self._vec0_enabled: bool = False + self._connection_lock = threading.RLock() def open(self) -> None: self.database_path.parent.mkdir(parents=True, exist_ok=True) - connection = sqlite3.connect(self.database_path, check_same_thread=False) - connection.row_factory = sqlite3.Row - self._apply_pragmas(connection) - self._vec0_enabled = self._try_load_vec0(connection) - if self._vec0_enabled: - self._init_vec0_table(connection) - else: - self._init_fallback_table(connection) - self._connection = connection + with self._connection_lock: + connection = sqlite3.connect(self.database_path, check_same_thread=False) + connection.row_factory = sqlite3.Row + self._apply_pragmas(connection) + self._vec0_enabled = self._try_load_vec0(connection) + if self._vec0_enabled: + self._init_vec0_table(connection) + else: + self._init_fallback_table(connection) + self._connection = connection def close(self) -> None: - if self._connection is not None: - self._connection.close() - self._connection = None + with self._connection_lock: + if self._connection is not None: + self._connection.close() + self._connection = None def healthcheck(self) -> None: - connection = self._require_connection() - connection.execute("SELECT 1").fetchone() + with self._connection_lock: + connection = self._require_connection() + connection.execute("SELECT 1").fetchone() def upsert(self, tenant_id: str, entry_id: str, vec: list[float]) -> None: self._validate_vector(vec) - connection = self._require_connection() - if self._vec0_enabled: - import sqlite_vec # type: ignore[import-not-found] - - blob = sqlite_vec.serialize_float32(vec) - connection.execute( - "DELETE FROM vectors WHERE tenant_id = ? AND entry_id = ?", - (tenant_id, entry_id), - ) - connection.execute( - "INSERT INTO vectors(tenant_id, entry_id, embedding) VALUES (?, ?, ?)", - (tenant_id, entry_id, blob), - ) - else: - connection.execute( - """ - INSERT INTO vectors (tenant_id, entry_id, vector_json) - VALUES (?, ?, ?) - ON CONFLICT(tenant_id, entry_id) - DO UPDATE SET vector_json = excluded.vector_json - """, - (tenant_id, entry_id, json.dumps(vec)), - ) - connection.commit() + with self._connection_lock: + connection = self._require_connection() + if self._vec0_enabled: + import sqlite_vec # type: ignore[import-not-found] + + blob = sqlite_vec.serialize_float32(vec) + connection.execute( + "DELETE FROM vectors WHERE tenant_id = ? AND entry_id = ?", + (tenant_id, entry_id), + ) + connection.execute( + "INSERT INTO vectors(tenant_id, entry_id, embedding) VALUES (?, ?, ?)", + (tenant_id, entry_id, blob), + ) + else: + connection.execute( + """ + INSERT INTO vectors (tenant_id, entry_id, vector_json) + VALUES (?, ?, ?) + ON CONFLICT(tenant_id, entry_id) + DO UPDATE SET vector_json = excluded.vector_json + """, + (tenant_id, entry_id, json.dumps(vec)), + ) + connection.commit() def search(self, tenant_id: str, query_vec: list[float], k: int) -> list[SearchCandidate]: self._validate_vector(query_vec) - connection = self._require_connection() - if self._vec0_enabled: - import sqlite_vec # type: ignore[import-not-found] + with self._connection_lock: + connection = self._require_connection() + if self._vec0_enabled: + import sqlite_vec # type: ignore[import-not-found] + + rows = connection.execute( + """ + SELECT entry_id, distance + FROM vectors + WHERE embedding MATCH ? AND tenant_id = ? AND k = ? 
+ ORDER BY distance + """, + (sqlite_vec.serialize_float32(query_vec), tenant_id, k), + ).fetchall() + return [ + SearchCandidate( + entry_id=row["entry_id"], + score=self._cosine_distance_to_similarity(float(row["distance"])), + ) + for row in rows + ] rows = connection.execute( - """ - SELECT entry_id, distance - FROM vectors - WHERE embedding MATCH ? AND tenant_id = ? AND k = ? - ORDER BY distance - """, - (sqlite_vec.serialize_float32(query_vec), tenant_id, k), + "SELECT entry_id, vector_json FROM vectors WHERE tenant_id = ?", + (tenant_id,), ).fetchall() - return [ - SearchCandidate( - entry_id=row["entry_id"], - score=self._cosine_distance_to_similarity(float(row["distance"])), + scored: list[tuple[str, float]] = [ + ( + row["entry_id"], + self._cosine_similarity(query_vec, json.loads(row["vector_json"])), ) for row in rows ] - - rows = connection.execute( - "SELECT entry_id, vector_json FROM vectors WHERE tenant_id = ?", - (tenant_id,), - ).fetchall() - scored: list[tuple[str, float]] = [ - (row["entry_id"], self._cosine_similarity(query_vec, json.loads(row["vector_json"]))) - for row in rows - ] - scored.sort(key=lambda item: item[1], reverse=True) - return [SearchCandidate(entry_id=entry_id, score=score) for entry_id, score in scored[:k]] + scored.sort(key=lambda item: item[1], reverse=True) + return [ + SearchCandidate(entry_id=entry_id, score=score) + for entry_id, score in scored[:k] + ] def contains(self, tenant_id: str, entry_id: str) -> bool: - connection = self._require_connection() - row = connection.execute( - "SELECT 1 FROM vectors WHERE tenant_id = ? AND entry_id = ?", - (tenant_id, entry_id), - ).fetchone() - return row is not None + with self._connection_lock: + connection = self._require_connection() + row = connection.execute( + "SELECT 1 FROM vectors WHERE tenant_id = ? AND entry_id = ?", + (tenant_id, entry_id), + ).fetchone() + return row is not None def delete(self, tenant_id: str, entry_id: str) -> None: - connection = self._require_connection() - connection.execute( - "DELETE FROM vectors WHERE tenant_id = ? AND entry_id = ?", - (tenant_id, entry_id), - ) - connection.commit() + with self._connection_lock: + connection = self._require_connection() + connection.execute( + "DELETE FROM vectors WHERE tenant_id = ? 
AND entry_id = ?", + (tenant_id, entry_id), + ) + connection.commit() + + def checkpoint_wal(self, *, mode: str = "TRUNCATE") -> tuple[int, int, int]: + with self._connection_lock: + connection = self._require_connection() + busy, log_frames, checkpointed = self._checkpoint(connection, "PASSIVE") + if busy > 0: + self._checkpoint(connection, "RESTART") + return self._checkpoint(connection, mode) def _try_load_vec0(self, connection: sqlite3.Connection) -> bool: if not hasattr(connection, "enable_load_extension"): @@ -204,6 +229,13 @@ def _validate_vector(self, vec: list[float]) -> None: if len(vec) != self.dim: raise ValueError(f"expected vector length {self.dim}, got {len(vec)}") + @staticmethod + def _checkpoint(connection: sqlite3.Connection, mode: str) -> tuple[int, int, int]: + row = connection.execute(f"PRAGMA wal_checkpoint({mode});").fetchone() + if row is None: + raise RuntimeError(f"wal_checkpoint({mode}) returned no result") + return (int(row[0]), int(row[1]), int(row[2])) + @staticmethod def _apply_pragmas(connection: sqlite3.Connection) -> None: connection.execute("PRAGMA journal_mode=WAL;") diff --git a/tests/conftest.py b/tests/conftest.py index 17769b0..8e11f02 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -81,13 +81,22 @@ def deterministic_embedder() -> DeterministicEmbedder: @pytest.fixture(autouse=True) def isolate_api_token_env(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.delenv("MH_API_TOKEN", raising=False) + monkeypatch.delenv("MH_ADMIN_TOKEN", raising=False) @pytest.fixture() def app_factory(tmp_path: Path): - def factory(*, tenant_id: str = "default", embedder=None, base_dir: Path | None = None): + def factory( + *, + tenant_id: str = "default", + embedder=None, + base_dir: Path | None = None, + hybrid_mode: str | None = None, + ): root = base_dir or tmp_path settings = build_settings(root, tenant_id=tenant_id) + if hybrid_mode is not None: + settings.hybrid_mode = hybrid_mode # type: ignore[assignment] active_embedder = embedder or DeterministicEmbedder(dim=settings.vector_dim) return create_app(settings=settings, embedder=active_embedder) diff --git a/tests/test_auth.py b/tests/test_auth.py index a05f56c..3c779d2 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -95,14 +95,13 @@ async def test_auth_enabled_wrong_scheme_returns_401(tmp_path: Path) -> None: @pytest.mark.asyncio -async def test_auth_enabled_health_endpoint_stays_public(tmp_path: Path) -> None: +async def test_auth_enabled_health_endpoints_stay_public(tmp_path: Path) -> None: settings = build_settings(tmp_path) settings.api_token = "secret-token-abc" app = create_app(settings=settings, embedder=DeterministicEmbedder(dim=settings.vector_dim)) async with client_for_app(app) as client: response = await client.get("/v1/health") - # Health returns 200 (or 503 degraded). Point is: not 401. 
- assert response.status_code != 401 + assert response.status_code != 401 @pytest.mark.asyncio @@ -116,3 +115,126 @@ async def test_auth_enabled_search_requires_token(tmp_path: Path) -> None: json={"query": "anything", "mode": "hybrid", "limit": 5}, ) assert response.status_code == 401 + + +# ---------- ADR 0009: admin gate (two-tier bearer) ------------------------ + + +@pytest.mark.asyncio +async def test_admin_token_unset_admin_falls_back_to_api_token(tmp_path: Path) -> None: + """Backward compat: when MH_ADMIN_TOKEN is unset, /v1/admin/* uses api_token.""" + settings = build_settings(tmp_path) + settings.api_token = "shared-token" + settings.admin_token = None + app = create_app(settings=settings, embedder=DeterministicEmbedder(dim=settings.vector_dim)) + async with client_for_app(app) as client: + response = await client.post( + "/v1/admin/audit", + headers={"Authorization": "Bearer shared-token"}, + ) + assert response.status_code == 200 + + +@pytest.mark.asyncio +async def test_admin_token_set_correct_token_allows_admin(tmp_path: Path) -> None: + settings = build_settings(tmp_path) + settings.api_token = "shared-token" + settings.admin_token = "admin-only-token" + app = create_app(settings=settings, embedder=DeterministicEmbedder(dim=settings.vector_dim)) + async with client_for_app(app) as client: + response = await client.post( + "/v1/admin/audit", + headers={"Authorization": "Bearer admin-only-token"}, + ) + assert response.status_code == 200 + + +@pytest.mark.asyncio +async def test_admin_token_set_api_token_rejected_on_admin(tmp_path: Path) -> None: + """When admin_token is set, the regular api_token must NOT grant admin access.""" + settings = build_settings(tmp_path) + settings.api_token = "shared-token" + settings.admin_token = "admin-only-token" + app = create_app(settings=settings, embedder=DeterministicEmbedder(dim=settings.vector_dim)) + async with client_for_app(app) as client: + response = await client.post( + "/v1/admin/audit", + headers={"Authorization": "Bearer shared-token"}, + ) + assert response.status_code == 401 + assert response.json()["detail"] == "invalid admin token" + + +@pytest.mark.asyncio +async def test_admin_token_set_missing_header_returns_401(tmp_path: Path) -> None: + settings = build_settings(tmp_path) + settings.api_token = "shared-token" + settings.admin_token = "admin-only-token" + app = create_app(settings=settings, embedder=DeterministicEmbedder(dim=settings.vector_dim)) + async with client_for_app(app) as client: + response = await client.post("/v1/admin/audit") + assert response.status_code == 401 + assert response.json()["detail"] == "missing bearer token" + + +@pytest.mark.asyncio +async def test_admin_token_does_not_grant_general_endpoints(tmp_path: Path) -> None: + """admin_token is admin-only; it must not work as a general api_token on + non-admin paths (least privilege both directions).""" + settings = build_settings(tmp_path) + settings.api_token = "shared-token" + settings.admin_token = "admin-only-token" + app = create_app(settings=settings, embedder=DeterministicEmbedder(dim=settings.vector_dim)) + async with client_for_app(app) as client: + response = await client.post( + "/v1/memory/write", + json=_write_payload(), + headers={"Authorization": "Bearer admin-only-token"}, + ) + assert response.status_code == 401 + + +@pytest.mark.asyncio +async def test_admin_token_set_health_endpoints_stay_public(tmp_path: Path) -> None: + settings = build_settings(tmp_path) + settings.api_token = "shared-token" + settings.admin_token = 
"admin-only-token" + app = create_app(settings=settings, embedder=DeterministicEmbedder(dim=settings.vector_dim)) + async with client_for_app(app) as client: + response = await client.get("/v1/health") + assert response.status_code != 401 + + +# ---------- ADR 0009: config-level invariants (fail-fast) ---------------- + + +def test_settings_admin_token_without_api_token_fails(tmp_path: Path) -> None: + """Codex review finding #1 [HIGH]: admin_token set + api_token unset would + fail-open on non-admin paths. Settings load must reject this combo.""" + from pydantic import ValidationError + + from memory_hall.config import Settings + + with pytest.raises(ValidationError, match="admin_token requires api_token"): + Settings( + database_path=tmp_path / "db.sqlite3", + vector_database_path=tmp_path / "vec.sqlite3", + admin_token="admin-only-token", + api_token=None, + ) + + +def test_settings_admin_token_equal_to_api_token_fails(tmp_path: Path) -> None: + """Codex review finding #2 [MEDIUM]: equal tokens silently nullify the + two-tier separation. Settings load must reject this combo.""" + from pydantic import ValidationError + + from memory_hall.config import Settings + + with pytest.raises(ValidationError, match="admin_token must differ from api_token"): + Settings( + database_path=tmp_path / "db.sqlite3", + vector_database_path=tmp_path / "vec.sqlite3", + api_token="same-token", + admin_token="same-token", + ) diff --git a/tests/test_cli_auth.py b/tests/test_cli_auth.py new file mode 100644 index 0000000..3fb5fcc --- /dev/null +++ b/tests/test_cli_auth.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +from memory_hall.cli.main import _client + + +def test_client_attaches_bearer_header_when_api_token_set(monkeypatch) -> None: + monkeypatch.setenv("MH_API_TOKEN", "secret-token-abc") + + with _client("http://127.0.0.1:9000", 5.0) as client: + assert client.headers["Authorization"] == "Bearer secret-token-abc" + + +def test_client_omits_bearer_header_when_api_token_unset(monkeypatch) -> None: + monkeypatch.delenv("MH_API_TOKEN", raising=False) + + with _client("http://127.0.0.1:9000", 5.0) as client: + assert "Authorization" not in client.headers + + +def test_client_omits_bearer_header_when_api_token_empty(monkeypatch) -> None: + monkeypatch.setenv("MH_API_TOKEN", "") + + with _client("http://127.0.0.1:9000", 5.0) as client: + assert "Authorization" not in client.headers diff --git a/tests/test_fts_tokenization.py b/tests/test_fts_tokenization.py index d867707..3aed2a0 100644 --- a/tests/test_fts_tokenization.py +++ b/tests/test_fts_tokenization.py @@ -97,6 +97,18 @@ def test_normalize_fts_query_edge_cases() -> None: assert '"系統"' in normalized +def test_normalize_bm25_preserves_rank_order_for_negative_scores() -> None: + raw_scores = [-15.0, -10.0, -5.0, -1.0, -0.1] + normalized = [SqliteStore._normalize_bm25(score) for score in raw_scores] + + assert normalized == sorted(normalized, reverse=True) + assert normalized[0] == pytest.approx(0.9375) + assert normalized[1] == pytest.approx(10.0 / 11.0) + assert normalized[2] == pytest.approx(5.0 / 6.0) + assert normalized[3] == pytest.approx(0.5) + assert normalized[4] == pytest.approx(1.0 / 11.0) + + @pytest.mark.asyncio async def test_reindex_fts_rewrites_legacy_rows(tmp_path) -> None: settings = build_settings(tmp_path) diff --git a/tests/test_hybrid_search.py b/tests/test_hybrid_search.py new file mode 100644 index 0000000..8398748 --- /dev/null +++ b/tests/test_hybrid_search.py @@ -0,0 +1,195 @@ +from __future__ import annotations + 
+from pathlib import Path + +import pytest + +from memory_hall.server.app import create_app +from tests.conftest import DeterministicEmbedder, build_settings, client_for_app + + +class WeightedHybridEmbedder(DeterministicEmbedder): + def embed(self, text: str) -> list[float]: + base = [0.0] * self.dim + lower = text.lower() + base[0] = 1.0 if "quokkamode" in lower else 0.0 + base[1] = 1.0 if "rollout" in lower else 0.0 + base[2] = 1.0 if any( + token in lower for token in ("resurface", "restore", "recovery") + ) else 0.0 + base[3] = 1.0 if any(token in lower for token in ("checklist", "list")) else 0.0 + base[4] = 1.0 if ( + lower.strip() in {"hybrid search", "hybrid ranking"} + or "combined retrieval" in lower + or "ranking strategy" in lower + ) else 0.0 + base[5] = 1.0 if ( + lower.strip() in {"hybrid search", "hybrid ranking"} + or ("hybrid" in lower and "combined retrieval" in lower) + ) else 0.0 + base[-1] = 0.01 + return base + + +@pytest.mark.asyncio +async def test_weighted_linear_prefers_rare_lexical_target(app_factory) -> None: + app = app_factory(embedder=WeightedHybridEmbedder(), hybrid_mode="weighted_linear") + async with client_for_app(app) as client: + target = await client.post( + "/v1/memory/write", + json={ + "agent_id": "codex", + "namespace": "shared", + "type": "note", + "content": "quokkamode rollout mitigation log", + }, + ) + distractor = await client.post( + "/v1/memory/write", + json={ + "agent_id": "codex", + "namespace": "shared", + "type": "note", + "content": "rollout playbook for tomorrow morning", + }, + ) + response = await client.post( + "/v1/memory/search", + json={"query": "quokkamode rollout", "limit": 5, "mode": "hybrid"}, + ) + + assert target.status_code == 201 + assert distractor.status_code == 201 + payload = response.json() + assert payload["results"][0]["entry"]["entry_id"] == target.json()["entry_id"] + assert payload["results"][0]["score_breakdown"]["hybrid_mode"] == "weighted_linear" + assert payload["results"][0]["score_breakdown"]["alpha"] == pytest.approx(0.3) + + +@pytest.mark.asyncio +async def test_weighted_linear_recovers_semantic_paraphrase_without_lexical_overlap( + app_factory, +) -> None: + app = app_factory(embedder=WeightedHybridEmbedder(), hybrid_mode="weighted_linear") + async with client_for_app(app) as client: + relevant = await client.post( + "/v1/memory/write", + json={ + "agent_id": "codex", + "namespace": "shared", + "type": "note", + "content": "restore recovery list after embed failures", + }, + ) + await client.post( + "/v1/memory/write", + json={ + "agent_id": "codex", + "namespace": "shared", + "type": "note", + "content": "release calendar for next month", + }, + ) + await client.post( + "/v1/memory/write", + json={ + "agent_id": "codex", + "namespace": "shared", + "type": "note", + "content": "daily standup reminders", + }, + ) + response = await client.post( + "/v1/memory/search", + json={"query": "resurface checklist", "limit": 5, "mode": "hybrid"}, + ) + + assert relevant.status_code == 201 + payload = response.json() + top_three_ids = [item["entry"]["entry_id"] for item in payload["results"][:3]] + assert relevant.json()["entry_id"] in top_three_ids + assert payload["results"][0]["entry"]["entry_id"] == relevant.json()["entry_id"] + assert payload["results"][0]["score_breakdown"]["alpha"] == pytest.approx(0.0) + assert payload["results"][0]["score_breakdown"]["bm25"] == pytest.approx(0.0) + + +@pytest.mark.asyncio +async def test_weighted_linear_rewards_entries_that_hit_both_signals(app_factory) -> None: + app 
= app_factory(embedder=WeightedHybridEmbedder(), hybrid_mode="weighted_linear") + async with client_for_app(app) as client: + lexical_only = await client.post( + "/v1/memory/write", + json={ + "agent_id": "codex", + "namespace": "shared", + "type": "note", + "content": "hybrid ranking marker", + }, + ) + semantic_only = await client.post( + "/v1/memory/write", + json={ + "agent_id": "codex", + "namespace": "shared", + "type": "note", + "content": "combined retrieval ranking strategy", + }, + ) + both = await client.post( + "/v1/memory/write", + json={ + "agent_id": "codex", + "namespace": "shared", + "type": "note", + "content": "hybrid combined retrieval ranking strategy", + }, + ) + response = await client.post( + "/v1/memory/search", + json={"query": "hybrid ranking", "limit": 5, "mode": "hybrid"}, + ) + + assert lexical_only.status_code == 201 + assert semantic_only.status_code == 201 + assert both.status_code == 201 + payload = response.json() + score_by_id = {item["entry"]["entry_id"]: item["score"] for item in payload["results"]} + both_id = both.json()["entry_id"] + lexical_only_id = lexical_only.json()["entry_id"] + semantic_only_id = semantic_only.json()["entry_id"] + + assert payload["results"][0]["entry"]["entry_id"] == both_id + assert score_by_id[both_id] > score_by_id[lexical_only_id] + assert score_by_id[both_id] > score_by_id[semantic_only_id] + assert score_by_id[both_id] - max( + score_by_id[lexical_only_id], + score_by_id[semantic_only_id], + ) > 0.05 + + +@pytest.mark.asyncio +async def test_hybrid_search_supports_legacy_rrf_mode(tmp_path: Path) -> None: + settings = build_settings(tmp_path) + settings.hybrid_mode = "rrf" + app = create_app( + settings=settings, + embedder=WeightedHybridEmbedder(dim=settings.vector_dim), + ) + + async with client_for_app(app) as client: + await client.post( + "/v1/memory/write", + json={ + "agent_id": "codex", + "namespace": "shared", + "type": "note", + "content": "hybrid combined retrieval ranking strategy", + }, + ) + response = await client.post( + "/v1/memory/search", + json={"query": "hybrid search", "limit": 5, "mode": "hybrid"}, + ) + + payload = response.json() + assert payload["results"][0]["score_breakdown"]["hybrid_mode"] == "rrf" diff --git a/tests/test_reindex_retry.py b/tests/test_reindex_retry.py index 6b7b065..0ebad48 100644 --- a/tests/test_reindex_retry.py +++ b/tests/test_reindex_retry.py @@ -1,10 +1,18 @@ from __future__ import annotations import sqlite3 +from datetime import UTC, datetime import pytest -from memory_hall.models import SYNC_FAILED, SYNC_PENDING, WriteMemoryRequest +from memory_hall.models import ( + SYNC_EMBEDDED, + SYNC_FAILED, + SYNC_PENDING, + Entry, + WriteMemoryRequest, + build_content_hash, +) from memory_hall.server.app import build_runtime from memory_hall.storage.sqlite_store import SqliteStore from tests.conftest import TimeoutEmbedder, build_settings @@ -129,3 +137,142 @@ async def test_store_migrates_legacy_entries_without_data_loss(tmp_path) -> None assert entry.embed_attempt_count == 0 finally: await store.close() + + +@pytest.mark.asyncio +async def test_store_recycles_writer_connection_after_transient_operational_error( + tmp_path, + monkeypatch, +) -> None: + settings = build_settings(tmp_path) + store = SqliteStore(settings.database_path) + await store.open() + try: + outcome = await store.insert_entry( + Entry( + entry_id="01KPREINDEXRETRYTEST0000002", + tenant_id=settings.default_tenant_id, + agent_id="pytest", + namespace="shared", + type="note", + content="writer recycle 
target", + content_hash=build_content_hash("writer recycle target"), + summary=None, + tags=[], + references=[], + metadata={}, + sync_status=SYNC_PENDING, + last_embedded_at=None, + last_embed_error=None, + last_embed_attempted_at=None, + embed_attempt_count=0, + created_at=datetime(2026, 4, 27, tzinfo=UTC), + created_by_principal="pytest", + ) + ) + writer = await store._require_writer_connection() + closed = False + execute_calls = 0 + original_close = writer.close + original_execute = writer.execute + + async def tracked_close() -> None: + nonlocal closed + closed = True + await original_close() + + async def flaky_execute(sql: str, parameters=()): + nonlocal execute_calls + execute_calls += 1 + if execute_calls == 1: + raise sqlite3.OperationalError("disk I/O error") + return await original_execute(sql, parameters) + + monkeypatch.setattr(writer, "close", tracked_close) + monkeypatch.setattr(writer, "execute", flaky_execute) + + updated = await store.update_sync_status( + settings.default_tenant_id, + outcome.entry.entry_id, + SYNC_EMBEDDED, + None, + None, + None, + 0, + ) + new_writer = await store._require_writer_connection() + finally: + await store.close() + + assert updated is not None + assert updated.sync_status == SYNC_EMBEDDED + assert execute_calls == 1 + assert closed is True + assert new_writer is not writer + + +@pytest.mark.asyncio +async def test_store_recycles_read_connection_after_transient_operational_error( + tmp_path, + monkeypatch, +) -> None: + settings = build_settings(tmp_path) + store = SqliteStore(settings.database_path) + entry = Entry( + entry_id="01KPREINDEXRETRYTEST0000003", + tenant_id=settings.default_tenant_id, + agent_id="pytest", + namespace="shared", + type="note", + content="read recycle target", + content_hash=build_content_hash("read recycle target"), + summary=None, + tags=[], + references=[], + metadata={}, + sync_status=SYNC_PENDING, + last_embedded_at=None, + last_embed_error=None, + last_embed_attempted_at=None, + embed_attempt_count=0, + created_at=datetime(2026, 4, 27, tzinfo=UTC), + created_by_principal="pytest", + ) + await store.open() + try: + await store.insert_entry(entry) + + original_open_connection = store._open_connection + open_calls = 0 + closed = False + + async def flaky_open_connection(): + nonlocal open_calls, closed + connection = await original_open_connection() + open_calls += 1 + if open_calls == 1: + original_close = connection.close + + async def tracked_close() -> None: + nonlocal closed + closed = True + await original_close() + + async def flaky_execute(sql: str, parameters=()): + del sql, parameters + raise sqlite3.OperationalError("database is locked") + + monkeypatch.setattr(connection, "close", tracked_close) + monkeypatch.setattr(connection, "execute", flaky_execute) + return connection + + monkeypatch.setattr(store, "_open_connection", flaky_open_connection) + + fetched = await store.get_entry(settings.default_tenant_id, entry.entry_id) + finally: + await store.close() + + assert fetched is not None + assert fetched.entry_id == entry.entry_id + assert open_calls == 2 + assert closed is True diff --git a/tests/test_search_degraded.py b/tests/test_search_degraded.py index 15fb698..f342a2b 100644 --- a/tests/test_search_degraded.py +++ b/tests/test_search_degraded.py @@ -7,7 +7,7 @@ @pytest.mark.asyncio async def test_hybrid_search_marks_timeout_degradation(app_factory) -> None: - app = app_factory() + app = app_factory(hybrid_mode="weighted_linear") async with client_for_app(app) as client: write_response = 
diff --git a/tests/test_search_degraded.py b/tests/test_search_degraded.py
index 15fb698..f342a2b 100644
--- a/tests/test_search_degraded.py
+++ b/tests/test_search_degraded.py
@@ -7,7 +7,7 @@
 
 @pytest.mark.asyncio
 async def test_hybrid_search_marks_timeout_degradation(app_factory) -> None:
-    app = app_factory()
+    app = app_factory(hybrid_mode="weighted_linear")
     async with client_for_app(app) as client:
         write_response = await client.post(
             "/v1/memory/write",
@@ -40,3 +40,8 @@ async def test_hybrid_search_marks_timeout_degradation(app_factory) -> None:
     assert payload["degraded"] is True
     assert payload["results"][0]["entry"]["content"] == "hybrid timeout fallback note"
     assert payload["results"][0]["score_breakdown"]["semantic_status"] == "timeout"
+    assert payload["results"][0]["score_breakdown"]["hybrid_mode"] == "weighted_linear"
+    assert payload["results"][0]["score_breakdown"]["alpha"] == pytest.approx(1.0)
+    assert payload["results"][0]["score"] == pytest.approx(
+        payload["results"][0]["score_breakdown"]["bm25"]
+    )
diff --git a/tests/test_smoke.py b/tests/test_smoke.py
index 2039a73..638c72f 100644
--- a/tests/test_smoke.py
+++ b/tests/test_smoke.py
@@ -1,12 +1,15 @@
 from __future__ import annotations
 
+import asyncio
 import logging
 import time
 from pathlib import Path
 
 import pytest
 
+from memory_hall.models import Entry, build_content_hash, utc_now
 from memory_hall.server.app import create_app
+from memory_hall.storage.vector_store import SqliteVecStore
 from tests.conftest import DeterministicEmbedder, TimeoutEmbedder, build_settings, client_for_app
 
 
@@ -30,6 +33,38 @@ def embed(self, text: str) -> list[float]:
         return super().embed(text)
 
 
+def _entry(index: int, tenant_id: str) -> Entry:
+    content = f"wal checkpoint entry {index}"
+    return Entry(
+        entry_id=f"01KPWALCHECKPOINT{index:08d}",
+        tenant_id=tenant_id,
+        agent_id="pytest",
+        namespace="shared",
+        type="note",
+        content=content,
+        content_hash=build_content_hash(content),
+        summary=None,
+        tags=[],
+        references=[],
+        metadata={},
+        sync_status="embedded",
+        last_embedded_at=None,
+        last_embed_error=None,
+        last_embed_attempted_at=None,
+        embed_attempt_count=0,
+        created_at=utc_now(),
+        created_by_principal="pytest",
+    )
+
+
+def _wal_path(path: Path) -> Path:
+    return path.with_name(f"{path.name}-wal")
+
+
+def _wal_size(path: Path) -> int:
+    return path.stat().st_size if path.exists() else 0
+
+
 @pytest.mark.asyncio
 async def test_health_returns_ok(app_factory) -> None:
     app = app_factory()
@@ -113,6 +148,46 @@ async def test_health_logs_subcheck_error_and_exposes_last_error(app_factory, ca
     )
 
 
+@pytest.mark.asyncio
+async def test_wal_checkpoint_truncates_main_and_vector_wal(tmp_path: Path) -> None:
+    settings = build_settings(tmp_path)
+    settings.wal_checkpoint_interval_s = 300.0
+    app = create_app(
+        settings=settings,
+        embedder=DeterministicEmbedder(dim=settings.vector_dim),
+    )
+
+    async with app.router.lifespan_context(app):
+        runtime = app.state.runtime
+        vector_store = runtime.vector_store
+        assert isinstance(vector_store, SqliteVecStore)
+
+        for index in range(100):
+            entry = _entry(index, settings.default_tenant_id)
+            await runtime.storage.insert_entry(entry)
+            await asyncio.to_thread(
+                vector_store.upsert,
+                entry.tenant_id,
+                entry.entry_id,
+                [float(index + 1)] * settings.vector_dim,
+            )
+
+        main_wal_path = _wal_path(settings.database_path)
+        vector_wal_path = _wal_path(settings.vector_database_path)
+
+        assert main_wal_path.exists()
+        assert vector_wal_path.exists()
+        main_wal_before = _wal_size(main_wal_path)
+        vector_wal_before = _wal_size(vector_wal_path)
+        assert main_wal_before > 32 * 1024
+        assert vector_wal_before > 32 * 1024
+
+        await runtime._checkpoint_wal_databases()
+
+        assert _wal_size(main_wal_path) <= 32 * 1024
+        assert _wal_size(vector_wal_path) <= 32 * 1024
+
+
 @pytest.mark.asyncio
 async def test_list_endpoint_accepts_limit_1000_and_rejects_1001(app_factory) -> None:
     app = app_factory()
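The WAL test above writes 100 entries into both the main and vector databases, requires each `-wal` sidecar to exceed 32 KiB, then expects one checkpoint pass to shrink both below that bound. That bound is exactly what `PRAGMA wal_checkpoint(TRUNCATE)` provides on success: committed frames are copied back into the main database file and the WAL is truncated to zero length. A minimal sketch of one such pass (synchronous `sqlite3` for brevity; the runtime's `_checkpoint_wal_databases` presumably does the async equivalent per database):

```python
import sqlite3
from pathlib import Path


def checkpoint_wal(db_path: Path) -> tuple[int, int, int]:
    """Run a truncating WAL checkpoint; returns (busy, wal_frames, moved)."""
    connection = sqlite3.connect(db_path)
    try:
        row = connection.execute("PRAGMA wal_checkpoint(TRUNCATE)").fetchone()
        # busy == 1 means a concurrent reader blocked truncation; the -wal
        # file then keeps its size until a later pass, so treat as retryable.
        return row
    finally:
        connection.close()
```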
diff --git a/tests/test_sync_status.py b/tests/test_sync_status.py
index a9bd874..a0dfbac 100644
--- a/tests/test_sync_status.py
+++ b/tests/test_sync_status.py
@@ -1,8 +1,14 @@
 from __future__ import annotations
 
+import asyncio
+import logging
+from datetime import timedelta
+
 import pytest
 
-from tests.conftest import DeterministicEmbedder, TimeoutEmbedder, client_for_app
+from memory_hall.models import SYNC_FAILED, SYNC_PENDING, Entry, build_content_hash, utc_now
+from memory_hall.server.app import ReindexJob, build_runtime
+from tests.conftest import DeterministicEmbedder, TimeoutEmbedder, build_settings, client_for_app
 
 
 class BatchTrackingEmbedder(DeterministicEmbedder):
@@ -20,6 +26,30 @@ def embed_batch(self, texts: list[str]) -> list[list[float]]:
         return [DeterministicEmbedder.embed(self, text) for text in texts]
 
 
+def _entry(index: int, *, tenant_id: str, sync_status: str) -> Entry:
+    content = f"{sync_status} backlog entry {index}"
+    return Entry(
+        entry_id=f"01KPHA5REINDEX{index:012d}",
+        tenant_id=tenant_id,
+        agent_id="pytest",
+        namespace="shared",
+        type="note",
+        content=content,
+        content_hash=build_content_hash(f"{content}-{index}"),
+        summary=None,
+        tags=[],
+        references=[],
+        metadata={},
+        sync_status=sync_status,
+        last_embedded_at=None,
+        last_embed_error=None,
+        last_embed_attempted_at=None,
+        embed_attempt_count=0,
+        created_at=utc_now() - timedelta(seconds=index),
+        created_by_principal="pytest",
+    )
+
+
 @pytest.mark.asyncio
 async def test_pending_write_reindexes_to_embedded(app_factory) -> None:
     app = app_factory(embedder=TimeoutEmbedder())
@@ -71,3 +101,88 @@ async def test_reindex_uses_embed_batch_for_pending_backlog(app_factory) -> None
     assert reindex_response.json()["embedded"] == 3
     assert tracking.embed_batch_calls == [3]
     assert tracking.embed_calls == 0
+
+
+@pytest.mark.asyncio
+async def test_admin_reindex_paginates_failed_backlog(app_factory, monkeypatch, caplog) -> None:
+    app = app_factory()
+    async with client_for_app(app) as client:
+        runtime = app.state.runtime
+        tenant_id = app.state.settings.default_tenant_id
+        for index in range(205):
+            await runtime.storage.insert_entry(
+                _entry(index, tenant_id=tenant_id, sync_status=SYNC_FAILED)
+            )
+
+        original_list_entries = runtime.storage.list_entries
+        list_calls: list[dict[str, object]] = []
+
+        async def tracked_list_entries(tenant: str, **kwargs):
+            assert tenant == tenant_id
+            assert kwargs.get("limit") == 200
+            list_calls.append(dict(kwargs))
+            return await original_list_entries(tenant, **kwargs)
+
+        monkeypatch.setattr(runtime.storage, "list_entries", tracked_list_entries)
+
+        with caplog.at_level(logging.INFO):
+            response = await client.post("/v1/admin/reindex")
+
+    assert response.status_code == 200
+    payload = response.json()
+    assert payload["scanned"] == 205
+    assert payload["embedded"] == 205
+    assert payload["pending"] == 0
+    assert len(list_calls) == 2
+    assert [call["cursor"] for call in list_calls] == [None, list_calls[1]["cursor"]]
+    assert list_calls[1]["cursor"] is not None
+    assert all(call["sync_status"] is None for call in list_calls)
+    assert any("reindex batch 1/2, 200 done" in record.message for record in caplog.records)
+    assert any("reindex batch 2/2, 205 done" in record.message for record in caplog.records)
+
+
+@pytest.mark.asyncio
+async def test_pending_only_reindex_paginates_pending_entries_only(tmp_path, monkeypatch) -> None:
+    settings = build_settings(tmp_path)
+    runtime = build_runtime(
+        settings=settings,
+        embedder=DeterministicEmbedder(dim=settings.vector_dim),
+    )
+    await runtime.start()
+    try:
+        for index in range(205):
+            await runtime.storage.insert_entry(
+                _entry(index, tenant_id=settings.default_tenant_id, sync_status=SYNC_PENDING)
+            )
+        failed_entry = _entry(9999, tenant_id=settings.default_tenant_id, sync_status=SYNC_FAILED)
+        await runtime.storage.insert_entry(failed_entry)
+
+        original_list_entries = runtime.storage.list_entries
+        list_calls: list[dict[str, object]] = []
+
+        async def tracked_list_entries(tenant: str, **kwargs):
+            assert tenant == settings.default_tenant_id
+            assert kwargs.get("limit") == 200
+            assert kwargs.get("sync_status") == SYNC_PENDING
+            list_calls.append(dict(kwargs))
+            return await original_list_entries(tenant, **kwargs)
+
+        monkeypatch.setattr(runtime.storage, "list_entries", tracked_list_entries)
+
+        future: asyncio.Future = asyncio.get_running_loop().create_future()
+        outcome = await runtime._handle_reindex(
+            ReindexJob(
+                tenant_id=settings.default_tenant_id,
+                future=future,
+                pending_only=True,
+            )
+        )
+    finally:
+        await runtime.stop()
+
+    assert outcome.scanned == 205
+    assert outcome.embedded == 205
+    assert outcome.pending == 0
+    assert len(list_calls) == 2
+    assert [call["cursor"] for call in list_calls] == [None, list_calls[1]["cursor"]]
+    assert list_calls[1]["cursor"] is not None
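Both pagination tests pin the same contract: batches of 200, `cursor=None` on the first page, an opaque non-None cursor afterwards, and a `sync_status` filter applied only in pending-only mode (the full admin reindex passes `sync_status=None` so failed entries are swept too). A hedged sketch of a scan loop consistent with those assertions — the page object with `.entries`/`.next_cursor` and the embed/upsert step are assumptions, not the runtime's code:

```python
async def reindex_backlog(storage, embedder, tenant_id: str, *, pending_only: bool):
    """Cursor-paginated backlog scan consistent with the tests above."""
    cursor = None
    scanned = embedded = 0
    sync_status = "pending" if pending_only else None  # assumed filter value
    while True:
        page = await storage.list_entries(
            tenant_id, limit=200, cursor=cursor, sync_status=sync_status
        )
        if not page.entries:
            break
        vectors = embedder.embed_batch([entry.content for entry in page.entries])
        for entry, vector in zip(page.entries, vectors):
            ...  # upsert vector, flip entry to SYNC_EMBEDDED
            embedded += 1
        scanned += len(page.entries)
        cursor = page.next_cursor
        if cursor is None:  # second page (205 - 200 = 5 rows) ends the scan
            break
    return scanned, embedded
```

The `reindex batch 1/2` log lines additionally imply the runtime counts the backlog up front so it can report `i/n` progress; the sketch omits that bookkeeping.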