diff --git a/dev/plans/2026-03-19-phase10-msrc-csaf-fix-plan.md b/dev/plans/2026-03-19-phase10-msrc-csaf-fix-plan.md new file mode 100644 index 00000000..5dab1034 --- /dev/null +++ b/dev/plans/2026-03-19-phase10-msrc-csaf-fix-plan.md @@ -0,0 +1,827 @@ +# MSRC Adapter CSAF Fix + Phase 10 Completion Plan + +> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task. + +**Goal:** Fix the broken MSRC adapter (switch from non-existent `/csaf/{id}` API endpoint to Microsoft's real CSAF 2.0 static file distribution), capture proper CSAF golden fixtures, write remaining golden tests (EPSS + MSRC), add MSRC to SeedCorpus, and complete Phase 10 test fixture corpus verification. + +**Architecture:** The MSRC adapter's `Fetch()` method is rewritten to use Microsoft's CSAF static file distribution (`msrc.microsoft.com/csaf/advisories/`). Discovery uses `changes.csv` (a standard CSAF directory mechanism) for incremental sync. Individual per-CVE CSAF 2.0 JSON files are downloaded and parsed by the existing `csaf.Parse()` function. The `csafToPatches()` and `buildVendorEnrichment()` functions are unchanged — they already handle CSAF 2.0 correctly. + +**Tech Stack:** Go 1.26, `net/http`, `encoding/csv`, `encoding/json`, `internal/feed/csaf` (existing parser), `httptest` (golden file tests), testcontainers (EPSS test) + +**Worktree:** All work happens in `.claude/worktrees/phase10-fixture-corpus` on branch `phase10/test-fixture-corpus`. + +--- + +## Context + +**The Bug:** The MSRC adapter constructs URLs as `api.msrc.microsoft.com/cvrf/v3.0/csaf/{releaseID}`, but this endpoint does not exist. It returns HTTP 400 "Invalid ID format" for every release ID. The adapter was tested against hand-crafted CSAF fixtures and passed all tests, but never worked against real Microsoft data. + +**The Fix:** Microsoft publishes real CSAF 2.0 files at `msrc.microsoft.com/csaf/advisories/`. Discovery mechanism: +- `changes.csv` — CSV of `"path","timestamp"` pairs, sorted newest-first, no header row +- `index.txt` — plain text list of all file paths +- Individual files at `https://msrc.microsoft.com/csaf/advisories/{year}/msrc_cve-{id}.json` + +Each file is ~6KB of proper CSAF 2.0 JSON (keys: `document`, `product_tree`, `vulnerabilities`) that the existing `csaf.Parse()` handles correctly. Files are per-CVE (one vulnerability per document), not per-release-month. + +**What stays unchanged:** +- `csafToPatches()` — converts `csaf.Document` to `[]feed.CanonicalPatch` (all field mappings preserved) +- `buildVendorEnrichment()` — extracts vendor-specific metadata (severity, fix state, KB articles, exploitability) +- `parseCSAFDocument()` — delegates to `csaf.Parse()` +- All `TestCSAFToPatches_*` unit tests (6 tests) — they test CSAF parsing, not Fetch +- The `csaf` parser package — untouched + +**What changes:** +- `Fetch()` method — new two-phase: download `changes.csv` → download per-CVE CSAF files +- `Cursor` struct — `last_release_date` → `last_updated` (cursor semantics changed) +- Constants — `baseURL` → new CSAF base URL; size limits adjusted +- Removed: `updateEntry`, `updatesResponse`, `parseUpdates()`, `dateTimeRe`, OData filter logic +- Fetch tests — rewritten for new flow (5 tests) + +**Testing-pitfalls warnings applicable to this plan:** +- §9.4 (Falsy-value preservation): CVSS 0.0 must be preserved. Existing `TestCSAFToPatches_CVSSZeroIsValid` covers this and stays unchanged. +- §9.1 (Wire format assumptions): We've verified the real CSAF file format with `curl` — proper CSAF 2.0 JSON. +- §7 (Test data must flow through production code paths): SeedCorpus feeds through `merge.Ingest`, not raw SQL. +- §16 (Test setup must not discard errors): All `os.ReadFile`, `json.Unmarshal` etc. must check errors with `require.NoError` or `t.Fatalf`. + +**Implementation-pitfalls warnings:** +- FEED-1 (Wire format): Verified — per-CVE CSAF files use `{"document":..., "product_tree":..., "vulnerabilities":[...]}`. Not streaming (files are ~6KB). +- FEED-5 (defer in loops): When downloading multiple CSAF files in a loop, use explicit `resp.Body.Close()` per iteration, NOT `defer`. +- FEED-10 (String cloning): `csafToPatches()` already clones strings. No changes needed there. +- FEED-16 (Body drain): Drain response body before close on non-200 responses. + +--- + +## Subagent Execution Protocol + +All tasks that write tests MUST follow this protocol. + +### Before starting any task: +``` +1. Read dev/testing-pitfalls.md +2. Read the TDD skill at .claude/skills/test-driven-development/ (or invoke /test-driven-development) +For pure test additions: write the test, verify it fails for the right reason +(or passes if it's testing already-correct behavior), then move on. +For code bugs: write failing test → fix code → verify green. +``` + +### Before marking any task complete: +``` +1. Review your tests against dev/testing-pitfalls.md +2. Verify test coverage of the fix (are error paths tested? edge cases?) +3. Run tests for the relevant packages only (e.g., `go test ./internal/feed/msrc/... -count=1`). + Do NOT run `go test ./...` — Docker container overload is a known issue. +``` + +### After completing each phase: +``` +You MUST carefully review the batch of work from multiple perspectives +and revise/refine as appropriate. Repeat this review loop (you must do +a minimum of three review rounds; if you still find substantive issues +in the third review, keep going with additional rounds until there are +no findings) until you're confident there aren't any more issues. Then +update your private journal and continue onto the next phase. +``` + +--- + +## Phase A: Capture Real CSAF Fixtures + +### Task 1: Download real CSAF files for golden test data + +**Files:** +- Replace: `internal/feed/msrc/testdata/golden/csaf/*.json` (currently CVRF-format files, must be replaced with real CSAF 2.0) +- Create: `internal/feed/msrc/testdata/golden/changes.csv` +- Remove: `internal/feed/msrc/testdata/golden/updates.json` (OData response, no longer needed) + +**Step 1: Download CSAF files for manifest CVEs** + +The test fixture manifest (`dev/plans/test-fixture-manifest.json`) includes 3 MSRC CVEs: CVE-2026-3909, CVE-2026-21510, CVE-2025-14174. Download their CSAF files. Note that the CSAF filename format is `msrc_cve-{id}.json` with lowercase `cve`. + +```bash +# From worktree root +mkdir -p internal/feed/msrc/testdata/golden/csaf + +# Try each CVE — some may not have CSAF files yet +curl -sL "https://msrc.microsoft.com/csaf/advisories/2026/msrc_cve-2026-3909.json" -o /tmp/msrc_cve-2026-3909.json +curl -sL "https://msrc.microsoft.com/csaf/advisories/2026/msrc_cve-2026-21510.json" -o /tmp/msrc_cve-2026-21510.json +curl -sL "https://msrc.microsoft.com/csaf/advisories/2025/msrc_cve-2025-14174.json" -o /tmp/msrc_cve-2025-14174.json +``` + +Check each file: a valid CSAF file starts with `{"document":`. If any returns a 404 or HTML, it doesn't exist. In that case, pick replacement CVEs from the CSAF index: + +```bash +curl -sL "https://msrc.microsoft.com/csaf/advisories/index.txt" | grep "2026/" | head -10 +``` + +Download 3-5 valid CSAF files. Place them in `internal/feed/msrc/testdata/golden/csaf/` with their original filenames (e.g., `msrc_cve-2026-3909.json`). + +Verify each file is valid CSAF 2.0: +- Contains `"document"` key with `"tracking"` sub-key +- Contains `"vulnerabilities"` array with at least one entry having a `"cve"` field +- Contains `"product_tree"` key + +**Step 2: Create a changes.csv fixture** + +Build a `changes.csv` listing only the downloaded CSAF files. Format: `"path","timestamp"` with no header row. Use the `document.tracking.current_release_date` from each file as the timestamp. Example: + +```csv +"2026/msrc_cve-2026-3909.json","2026-03-18T01:00:00Z" +"2026/msrc_cve-2026-21510.json","2026-03-17T07:00:00Z" +"2025/msrc_cve-2025-14174.json","2026-03-12T07:00:00Z" +``` + +Save to `internal/feed/msrc/testdata/golden/changes.csv`. + +**Step 3: Remove old CVRF fixtures** + +Delete: +- `internal/feed/msrc/testdata/golden/updates.json` (OData response, no longer used) +- All files in `internal/feed/msrc/testdata/golden/csaf/` that are CVRF format (check: CVRF files have `"DocumentTitle"` key; CSAF files have `"document"` key) + +**Step 4: Commit** + +```bash +git add internal/feed/msrc/testdata/golden/ +git commit -m "test: replace CVRF fixtures with real CSAF 2.0 files from msrc.microsoft.com" +``` + +--- + +## Phase B: Fix MSRC Adapter + +### Task 2: Rewrite MSRC adapter Fetch method for CSAF static files + +**Files:** +- Modify: `internal/feed/msrc/adapter.go` + +**Current behavior:** Fetches from `api.msrc.microsoft.com/cvrf/v3.0/updates` (OData) then `api.msrc.microsoft.com/cvrf/v3.0/csaf/{releaseID}` (broken endpoint). Returns patches grouped by monthly release. + +**Desired behavior:** Fetches `changes.csv` from `msrc.microsoft.com/csaf/advisories/changes.csv`, filters by cursor timestamp, downloads individual per-CVE CSAF files, parses them with the existing CSAF parser. Returns patches — one per CVE file. + +**Step 1: Update constants and imports** + +Replace the constants section: + +```go +const ( + // SourceName is the canonical feed name stored in cve_sources. + SourceName = "msrc" + + // baseURL is the MSRC CSAF advisory distribution base. + baseURL = "https://msrc.microsoft.com/csaf/advisories/" + + // maxChangesSize caps the changes.csv response to prevent OOM. + maxChangesSize = 10 << 20 // 10 MB + + // maxCSAFDocSize caps individual CSAF file response to prevent OOM. + // Per-CVE files are typically ~6KB; 1MB is generous. + maxCSAFDocSize = 1 << 20 // 1 MB +) +``` + +Remove these imports (no longer needed): `"net/url"`, `"regexp"`. +Add this import: `"encoding/csv"`. + +**Step 2: Update Cursor struct** + +```go +// Cursor is the JSON-serializable sync state for the MSRC adapter. +type Cursor struct { + LastUpdated string `json:"last_updated"` +} +``` + +**Step 3: Remove dead code** + +Delete these types and functions entirely: +- `dateTimeRe` (regexp for OData injection prevention — no longer needed) +- `updateEntry` struct +- `updatesResponse` struct +- `parseUpdates()` function + +Keep these functions unchanged: +- `parseCSAFDocument()` — still used +- `csafToPatches()` — still used +- `buildVendorEnrichment()` — still used + +**Step 4: Add changes.csv parser** + +```go +// changeEntry represents a single row from the CSAF changes.csv file. +type changeEntry struct { + Path string // e.g., "2026/msrc_cve-2026-3909.json" + Timestamp string // e.g., "2026-03-18T01:00:00Z" +} + +// parseChangesCSV parses the CSAF changes.csv file. The CSV has no header row; +// each row is "path","timestamp". Returns entries sorted by the CSV's natural +// order (newest first). +func parseChangesCSV(r io.Reader) ([]changeEntry, error) { + cr := csv.NewReader(r) + cr.FieldsPerRecord = 2 + cr.ReuseRecord = true + + var entries []changeEntry + for { + record, err := cr.Read() + if err == io.EOF { + break + } + if err != nil { + return nil, fmt.Errorf("msrc: parse changes.csv: %w", err) + } + entries = append(entries, changeEntry{ + Path: strings.Clone(record[0]), + Timestamp: strings.Clone(record[1]), + }) + } + return entries, nil +} +``` + +**Step 5: Rewrite Fetch method** + +Replace the entire `Fetch` method with: + +```go +// Fetch implements feed.Adapter. Two-phase: +// 1. Download changes.csv to discover updated CSAF advisory files +// 2. Download and parse each changed per-CVE CSAF file +func (a *Adapter) Fetch(ctx context.Context, cursorJSON json.RawMessage) (*feed.FetchResult, error) { + var cur Cursor + if len(cursorJSON) > 0 { + if err := json.Unmarshal(cursorJSON, &cur); err != nil { + return nil, fmt.Errorf("msrc: parse cursor: %w", err) + } + } + + // Phase 1: download changes.csv + if err := a.rateLimiter.Wait(ctx); err != nil { + return nil, fmt.Errorf("msrc: rate limit: %w", err) + } + + changesURL := baseURL + "changes.csv" + req, err := http.NewRequestWithContext(ctx, http.MethodGet, changesURL, nil) + if err != nil { + return nil, fmt.Errorf("msrc: build changes request: %w", err) + } + + resp, err := a.client.Do(req) //nolint:gosec // URL constructed from constant base + if err != nil { + return nil, fmt.Errorf("msrc: fetch changes.csv: %w", err) + } + defer resp.Body.Close() //nolint:errcheck + + if resp.StatusCode != http.StatusOK { + io.Copy(io.Discard, resp.Body) //nolint:errcheck,gosec // drain for connection reuse + return nil, fmt.Errorf("msrc: changes.csv HTTP %d", resp.StatusCode) + } + + entries, err := parseChangesCSV(io.LimitReader(resp.Body, maxChangesSize)) + if err != nil { + return nil, err + } + + // Filter to entries newer than cursor + var pending []changeEntry + var latestTimestamp string + for _, e := range entries { + if e.Timestamp > latestTimestamp { + latestTimestamp = e.Timestamp + } + if cur.LastUpdated != "" && e.Timestamp <= cur.LastUpdated { + continue + } + pending = append(pending, e) + } + + // Short-circuit: no new changes + if len(pending) == 0 { + effectiveTS := cur.LastUpdated + if latestTimestamp > effectiveTS { + effectiveTS = latestTimestamp + } + nextCursor := Cursor{LastUpdated: effectiveTS} + nextCursorJSON, marshalErr := json.Marshal(nextCursor) + if marshalErr != nil { + return nil, fmt.Errorf("msrc: marshal cursor: %w", marshalErr) + } + return &feed.FetchResult{ + SourceMeta: feed.SourceMeta{ + SourceName: SourceName, + FetchedAt: time.Now().UTC(), + }, + NextCursor: nextCursorJSON, + LastPage: true, + }, nil + } + + // Phase 2: download and parse each changed CSAF file + fetchedAt := time.Now().UTC() + var allPatches []feed.CanonicalPatch + + for _, entry := range pending { + if err := a.rateLimiter.Wait(ctx); err != nil { + return nil, fmt.Errorf("msrc: rate limit: %w", err) + } + + fileURL := baseURL + entry.Path + fileReq, reqErr := http.NewRequestWithContext(ctx, http.MethodGet, fileURL, nil) + if reqErr != nil { + return nil, fmt.Errorf("msrc: build request for %s: %w", entry.Path, reqErr) + } + fileReq.Header.Set("Accept", "application/json") + + fileResp, doErr := a.client.Do(fileReq) //nolint:gosec // URL constructed from constant base + CSV path + if doErr != nil { + return nil, fmt.Errorf("msrc: fetch %s: %w", entry.Path, doErr) + } + + if fileResp.StatusCode != http.StatusOK { + io.Copy(io.Discard, fileResp.Body) //nolint:errcheck,gosec // drain for connection reuse + fileResp.Body.Close() //nolint:errcheck,gosec + return nil, fmt.Errorf("msrc: %s HTTP %d", entry.Path, fileResp.StatusCode) + } + + body, readErr := io.ReadAll(io.LimitReader(fileResp.Body, maxCSAFDocSize)) + fileResp.Body.Close() //nolint:errcheck,gosec + if readErr != nil { + return nil, fmt.Errorf("msrc: read %s: %w", entry.Path, readErr) + } + + doc, parseErr := parseCSAFDocument(body) + if parseErr != nil { + return nil, fmt.Errorf("msrc: parse %s: %w", entry.Path, parseErr) + } + + patches := csafToPatches(doc) + allPatches = append(allPatches, patches...) + } + + // Update cursor to latest timestamp seen + effectiveTS := cur.LastUpdated + if latestTimestamp > effectiveTS { + effectiveTS = latestTimestamp + } + nextCursor := Cursor{LastUpdated: effectiveTS} + nextCursorJSON, err := json.Marshal(nextCursor) + if err != nil { + return nil, fmt.Errorf("msrc: marshal cursor: %w", err) + } + + return &feed.FetchResult{ + Patches: allPatches, + SourceMeta: feed.SourceMeta{ + SourceName: SourceName, + FetchedAt: fetchedAt, + }, + NextCursor: nextCursorJSON, + LastPage: true, + }, nil +} +``` + +**Step 6: Run existing CSAF parsing tests** + +```bash +cd && go test ./internal/feed/msrc/... -run TestCSAFToPatches -v -count=1 +``` + +Expected: All 6 `TestCSAFToPatches_*` tests PASS (these test `csafToPatches()` which is unchanged). + +The `TestFetch_*` and `TestParseUpdates` tests will fail — that's expected and fixed in Task 3. + +**Step 7: Commit** + +```bash +git add internal/feed/msrc/adapter.go +git commit -m "fix(msrc): switch to real CSAF 2.0 static file distribution + +The previous /cvrf/v3.0/csaf/{id} endpoint never existed on Microsoft's +API — it returned HTTP 400 for all release IDs. The adapter now uses +Microsoft's CSAF static file distribution at msrc.microsoft.com/csaf/ +advisories/, which serves proper CSAF 2.0 JSON per-CVE files. + +Discovery uses changes.csv (standard CSAF directory mechanism) for +incremental sync. The csafToPatches() and buildVendorEnrichment() +functions are unchanged." +``` + +--- + +### Task 3: Update MSRC adapter unit tests for new Fetch flow + +**Files:** +- Modify: `internal/feed/msrc/adapter_test.go` + +**Depends on:** Task 2 (adapter rewrite) + +The test file has two sections: +1. **CSAF parsing tests** (lines 66-776) — `TestCSAFToPatches_*` and `csafToPatchesFromJSON` helper. These are UNCHANGED. +2. **Fetch tests** (lines 433-651) — `TestFetch_*`, `TestParseUpdates`, and `redirectTransport`. These must be REWRITTEN. + +**Step 1: Remove dead test code** + +Delete: +- `TestParseUpdates` function (tests removed `parseUpdates()`) +- `redirectTransport` struct and its `RoundTrip` method (replace with `testutil.NewURLRewriteTransport`) + +**Step 2: Add `parseChangesCSV` unit test** + +```go +func TestParseChangesCSV(t *testing.T) { + t.Parallel() + + body := `"2026/msrc_cve-2026-3909.json","2026-03-18T01:00:00Z" +"2026/msrc_cve-2026-21510.json","2026-03-17T07:00:00Z" +"2025/msrc_cve-2025-14174.json","2026-03-12T07:00:00Z" +` + + entries, err := parseChangesCSV(strings.NewReader(body)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(entries) != 3 { + t.Fatalf("len(entries) = %d, want 3", len(entries)) + } + if entries[0].Path != "2026/msrc_cve-2026-3909.json" { + t.Errorf("entries[0].Path = %q, want %q", entries[0].Path, "2026/msrc_cve-2026-3909.json") + } + if entries[0].Timestamp != "2026-03-18T01:00:00Z" { + t.Errorf("entries[0].Timestamp = %q, want %q", entries[0].Timestamp, "2026-03-18T01:00:00Z") + } +} +``` + +**Step 3: Rewrite Fetch tests** + +Replace `TestFetch_Success` with a test that: +1. Creates an httptest server with two route handlers: + - `/csaf/advisories/changes.csv` → serves a 1-entry CSV fixture + - `/csaf/advisories/2026/msrc_cve-2026-21001.json` → serves `minimalCSAFDoc` (the existing test constant) +2. Uses `testutil.NewURLRewriteTransport("https://msrc.microsoft.com", srv.URL, http.DefaultTransport)` — import `"github.com/scarson/cvert-ops/internal/testutil"` +3. Calls `adapter.Fetch(ctx, nil)` with nil cursor +4. Asserts: 1 patch returned, CVEID = "CVE-2026-21001", SourceName = "msrc", cursor.LastUpdated is set, LastPage = true + +Replace `TestFetch_ShortCircuit` with a test that: +1. Creates server serving a 1-entry changes.csv with timestamp `"2026-03-12T08:00:00Z"` +2. Provides a cursor with `LastUpdated: "2026-03-12T08:00:00Z"` (same as CSV) +3. Asserts: 0 patches, only 1 HTTP request (changes.csv only, no CSAF file requests) + +Replace `TestFetch_HTTPError` with a test that: +1. Creates server returning 500 on `/changes.csv` +2. Asserts: error contains "HTTP 500" + +Replace `TestFetch_CSAFHTTPError` with a test that: +1. Creates server returning valid changes.csv but 500 on the CSAF file request +2. Asserts: error contains "HTTP 500" + +Remove `TestFetch_InvalidCursorDate` entirely — the OData injection test is no longer relevant (no OData filter is constructed). The cursor is just a timestamp string compared locally. + +**Step 4: Run all MSRC tests** + +```bash +cd && go test ./internal/feed/msrc/... -v -count=1 +``` + +Expected: ALL tests PASS — both the unchanged CSAF parsing tests and the new Fetch tests. + +**Step 5: Commit** + +```bash +git add internal/feed/msrc/adapter_test.go +git commit -m "test(msrc): update Fetch tests for CSAF static file distribution" +``` + +**Phase B review loop:** After completing Tasks 2-3, run the review loop described in the Subagent Execution Protocol. Specifically verify: (1) all 6 `TestCSAFToPatches_*` tests still pass unchanged; (2) new Fetch tests cover success, short-circuit, and both error paths; (3) no existing test was deleted without replacement; (4) `parseChangesCSV` has its own unit test; (5) no `url` or `regexp` imports remain (removed dead code); (6) ABOUTME comments are updated if needed. + +--- + +## Phase C: Golden File Tests + +### Task 4: Write MSRC golden file test + +**Files:** +- Create: `internal/feed/msrc/golden_test.go` + +**Depends on:** Tasks 1 (fixtures) and 2 (adapter fix) + +This test is in the `msrc_test` package (external test package). It serves the golden fixtures via httptest, runs the adapter's `Fetch()`, and verifies the real CSAF data parses correctly. + +**Step 1: Write the test** + +```go +// ABOUTME: Golden file test for the MSRC adapter using real CSAF 2.0 files. +// ABOUTME: Verifies vendor enrichment, CVSS extraction, and CSAF parsing from real Microsoft data. +package msrc_test +``` + +The test must: +1. Read `testdata/golden/changes.csv` fixture +2. Read all `testdata/golden/csaf/*.json` fixture files into a map keyed by path +3. Create an httptest server that routes: + - Requests ending in `/changes.csv` → serve the CSV fixture + - Requests containing path segments matching CSAF filenames → serve the corresponding fixture +4. Use `testutil.NewURLRewriteTransport("https://msrc.microsoft.com", srv.URL, http.DefaultTransport)` +5. Create adapter with `msrc.New(client)`, call `Fetch(ctx, nil)` +6. Loop until `LastPage == true`, collecting all patches + +**Required assertions (from Phase 10 plan Task 10F):** +1. `len(allPatches) > 0` — non-zero patches +2. Every patch: `p.CVEID != ""` — all entries map to CVE IDs +3. At least one patch: `p.VendorEnrichment != nil` with `len(p.VendorEnrichment.Data) > 0` +4. At least one patch: `p.CVSSv3Score != nil` — CVSS data extracted +5. Every patch: `p.CVEID` starts with `"CVE-"` — proper CVE format +6. **Falsy-value check (testing-pitfalls §9.4):** If any patch has `CVSSv3Score == 0.0`, log it as correctly preserved + +**Step 2: Run and verify** + +```bash +cd && go test ./internal/feed/msrc/... -run TestFetch_GoldenFiles -v -count=1 +``` + +Expected: PASS + +**Step 3: Commit** + +```bash +git add internal/feed/msrc/golden_test.go +git commit -m "test: add MSRC golden file test against real CSAF 2.0 advisories" +``` + +--- + +### Task 5: Run EPSS golden file test + +**Files:** `internal/feed/epss/golden_test.go` (already created in worktree) + +**Depends on:** Docker Desktop must be running (testcontainers) + +The EPSS golden test was written in a previous session but never executed (computer shut down). Run it now. + +**Step 1: Verify Docker is available** + +```bash +docker info >/dev/null 2>&1 && echo "Docker available" || echo "Docker NOT available" +``` + +If Docker is NOT available, this is a **HARD BLOCKER**. Stop and report. + +**Step 2: Run the test** + +```bash +cd && go test ./internal/feed/epss/... -run TestApply_GoldenFiles -v -count=1 -timeout=300s +``` + +Expected: PASS — seeds NVD CVEs via merge pipeline, applies EPSS scores, verifies DB values. + +If the test FAILS, debug and fix. Common issues: +- NVD fixture path resolution (`../nvd/testdata/golden/` relative to EPSS package) +- EPSS rate limiter blocking (24h limiter — should succeed on first call) +- Testcontainer startup timeout + +**Step 3: If test passes, no commit needed** (test file was already committed by previous session) + +If test required fixes, commit the fixes: +```bash +git add internal/feed/epss/golden_test.go +git commit -m "fix: correct EPSS golden test [describe what was fixed]" +``` + +--- + +## Phase D: SeedCorpus Integration + +### Task 6: Add MSRC to SeedCorpus helper + +**Files:** +- Modify: `internal/testutil/seedcorpus.go` + +**Depends on:** Tasks 1-2 (fixtures + adapter fix) + +The `SeedCorpus` function currently seeds 6 feeds: NVD, MITRE, GHSA, OSV, KEV, Red Hat. MSRC is missing. Add it between KEV and Red Hat (matching source precedence order from PLAN.md §5.1). + +**Step 1: Add MSRC import** + +Add to the import block: +```go +"github.com/scarson/cvert-ops/internal/feed/msrc" +``` + +**Step 2: Add MSRC to feeds list** + +In the `feeds` slice (around line 58-65), add MSRC between KEV and Red Hat: + +```go +feeds := []feedDef{ + {"nvd", "nvd", fetchNVDGolden}, + {"mitre", "mitre", fetchMITREGolden}, + {"ghsa", "ghsa", fetchGHSAGolden}, + {"osv", "osv", fetchOSVGolden}, + {"kev", "kev", fetchKEVGolden}, + {"msrc", "msrc", fetchMSRCGolden}, // ADD THIS LINE + {"redhat", "redhat", fetchRedHatGolden}, +} +``` + +**Step 3: Add `fetchMSRCGolden` function** + +Add this function after `fetchKEVGolden` and before `fetchRedHatGolden`: + +```go +func fetchMSRCGolden(t *testing.T, projectRoot string) []feed.CanonicalPatch { + t.Helper() + goldenDir := filepath.Join(projectRoot, "internal", "feed", "msrc", "testdata", "golden") + + // Read changes.csv fixture + changesData, err := os.ReadFile(filepath.Join(goldenDir, "changes.csv")) + if err != nil { + t.Fatalf("MSRC changes.csv fixture missing: %v", err) + } + + // Read all CSAF fixture files into a map + csafDir := filepath.Join(goldenDir, "csaf") + csafEntries, err := os.ReadDir(csafDir) + if err != nil { + t.Fatalf("MSRC CSAF fixtures missing: %v", err) + } + + csafByName := make(map[string][]byte) + for _, e := range csafEntries { + if filepath.Ext(e.Name()) != ".json" { + continue + } + data, readErr := os.ReadFile(filepath.Join(csafDir, e.Name())) + if readErr != nil { + t.Fatalf("read MSRC CSAF fixture %s: %v", e.Name(), readErr) + } + csafByName[e.Name()] = data + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + path := r.URL.Path + + if strings.HasSuffix(path, "/changes.csv") { + w.Header().Set("Content-Type", "text/csv") + w.Write(changesData) //nolint:errcheck + return + } + + // Serve CSAF files by filename + for name, data := range csafByName { + if strings.HasSuffix(path, "/"+name) { + w.Write(data) //nolint:errcheck + return + } + } + + http.NotFound(w, r) + })) + t.Cleanup(srv.Close) + + client := &http.Client{ + Transport: NewURLRewriteTransport("https://msrc.microsoft.com", srv.URL, http.DefaultTransport), + } + + return fetchAllPatches(t, msrc.New(client), nil) +} +``` + +**Step 4: Run SeedCorpus test** + +```bash +cd && go test ./internal/testutil/... -run TestSeedCorpus -v -count=1 -timeout=600s +``` + +Expected: PASS — now seeds 8 feeds (was 7 before, added MSRC). The test asserts `FeedsSeeded == len(requiredFeeds)` where `requiredFeeds` includes "msrc". + +Wait — check the existing test's `requiredFeeds` list. If it doesn't already include "msrc", this will fail correctly. If it does include "msrc", the test was already expecting MSRC and was previously failing (or the test counted differently). Read `internal/testutil/seedcorpus_test.go` to verify, and update `requiredFeeds` if needed. + +**Step 5: Commit** + +```bash +git add internal/testutil/seedcorpus.go +git commit -m "test: add MSRC to SeedCorpus golden fixture helper" +``` + +**Phase D review loop:** After completing Task 6, run the review loop. Verify: (1) MSRC is positioned correctly in source precedence order; (2) the test server handles both changes.csv and per-CVE CSAF file routes; (3) `fetchMSRCGolden` uses `NewURLRewriteTransport` consistently with other adapters; (4) `SeedCorpus` now reports 8 feeds seeded (not 7). + +--- + +## Phase E: Verification & Documentation + +### Task 7: Verify full feed test suite (Task 10H from Phase 10 plan) + +**Step 1: Run all feed adapter tests** + +```bash +cd && go test ./internal/feed/... -v -count=1 -timeout=300s +``` + +Verify: +1. All golden file tests RUN (not SKIP) for: NVD, KEV, GHSA, MITRE, OSV, Red Hat, MSRC +2. EPSS golden test may skip if Docker is not available (it uses testcontainers) — this is acceptable for feed-only verification +3. All existing inline-JSON tests still pass +4. No compilation errors or import cycles + +**Step 2: Verify full-project compilation** + +```bash +cd && go build ./... +``` + +Expected: clean build, no errors. + +**Step 3: No commit** — this is a verification step. + +--- + +### Task 8: Final verification of changed packages (Task 12 from Phase 10 plan) + +```bash +cd && go test ./internal/feed/... ./internal/testutil/... -count=1 -timeout=600s +``` + +Also verify compilation: +```bash +cd && go build ./... +``` + +If EPSS test needs Docker and it's available, also run: +```bash +cd && go test ./internal/feed/epss/... -run TestApply_GoldenFiles -v -count=1 -timeout=300s +``` + +All tests must pass. If any golden file tests skip, go back and fix the missing fixtures. + +--- + +### Task 9: Document the refresh process (Task 13 from Phase 10 plan) + +**Files:** +- Modify: `dev/plans/2026-03-15-phase10-test-fixture-corpus-plan.md` + +Add a `## Refresh Process` section at the end of the document (before the dependency graph if one exists): + +```markdown +## Refresh Process + +1. Re-run the capture: `go run ./dev/cmd/capture-feeds/... all` +2. Re-run the selection agent (Task 6 instructions) against new captures +3. Review and commit the updated canonical manifest at `dev/plans/test-fixture-manifest.json` +4. Re-run extraction: `go run ./dev/cmd/extract-fixtures/...` +5. For MSRC: download updated CSAF files from `https://msrc.microsoft.com/csaf/advisories/` and rebuild `changes.csv` +6. Run all adapter tests: `go test ./internal/feed/...` +7. If tests pass, commit the updated manifest and fixtures together +8. If tests fail, investigate — the upstream schema may have changed + +**When to refresh:** +- When an adapter test breaks in a way suggesting upstream schema change +- When adding a new edge case category to the matrix +- When adding a new feed adapter + +**Adding a new feed adapter:** +1. Add a capture case to `dev/cmd/capture-feeds/main.go` +2. Add extraction logic to `dev/cmd/extract-fixtures/main.go` +3. Add a `golden_test.go` to the new adapter package +4. Re-run capture and extraction to populate fixtures +``` + +**Commit:** + +```bash +git add dev/plans/2026-03-15-phase10-test-fixture-corpus-plan.md +git commit -m "docs: add fixture corpus refresh process documentation" +``` + +--- + +## Dependency Graph + +``` +Task 1 (capture fixtures) ──┐ + ├──→ Task 2 (fix adapter) ──→ Task 3 (fix tests) ──→ Task 4 (MSRC golden test) + │ │ + └──────────────────────────────────────────────────→ Task 6 (SeedCorpus) + │ +Task 5 (EPSS golden test) ─────────────────────────────────────────────────────────────→ │ + ↓ + Task 7 (verify feed suite) + ↓ + Task 8 (final verification) + ↓ + Task 9 (documentation) +``` + +Tasks 1-4 and Task 5 are independent and can run in parallel. Tasks 7-9 are sequential and depend on all prior tasks. + +--- + +## Execution Notes + +- **All work in the worktree** at `.claude/worktrees/phase10-fixture-corpus` +- **Do NOT run `go test ./...`** — Docker container overload. Only run relevant package subsets. +- **MSRC rate limiter:** The adapter uses 1 req/sec. For golden tests with 3-5 fixture files, this adds ~3-5 seconds of delay. Acceptable. +- **EPSS test requires Docker Desktop** — if unavailable, this is a hard blocker for Task 5 only. Other tasks can proceed. +- **Worktree already has dev merged** — SCIM code is present. No further merges needed. diff --git a/internal/api/scim_admin.go b/internal/api/scim_admin.go index f4020cf3..0b456a9b 100644 --- a/internal/api/scim_admin.go +++ b/internal/api/scim_admin.go @@ -405,13 +405,13 @@ func (srv *Server) patchSCIMGroupMappingHandler(w http.ResponseWriter, r *http.R } // Get the SCIM group and verify it belongs to this org. - scimGroup, err := srv.store.GetSCIMGroup(r.Context(), groupID) + scimGroup, err := srv.store.GetSCIMGroup(r.Context(), orgID, groupID) if err != nil { slog.ErrorContext(r.Context(), "scim group mapping patch: get group", "error", err) writeProblem(w, http.StatusInternalServerError, "internal error") return } - if scimGroup == nil || scimGroup.OrgID != orgID { + if scimGroup == nil { writeProblem(w, http.StatusNotFound, "SCIM group not found") return } @@ -442,7 +442,7 @@ func (srv *Server) patchSCIMGroupMappingHandler(w http.ResponseWriter, r *http.R // Validate mapped_group_id if provided: must be same org and active. if groupIDSent { - group, err := srv.store.GetGroupIfActive(r.Context(), *req.MappedGroupID) + group, err := srv.store.GetGroupIfActive(r.Context(), orgID, *req.MappedGroupID) if err != nil { slog.ErrorContext(r.Context(), "scim group mapping patch: get notification group", "error", err) writeProblem(w, http.StatusInternalServerError, "internal error") @@ -452,10 +452,6 @@ func (srv *Server) patchSCIMGroupMappingHandler(w http.ResponseWriter, r *http.R writeProblem(w, http.StatusBadRequest, "notification group not found or deleted") return } - if group.OrgID != orgID { - writeProblem(w, http.StatusBadRequest, "notification group belongs to a different organization") - return - } } // Capture old mapping for comparison. @@ -478,7 +474,7 @@ func (srv *Server) patchSCIMGroupMappingHandler(w http.ResponseWriter, r *http.R mappedGroupIDPtr = &scimGroup.MappedGroupID.UUID } - if err := srv.store.UpdateSCIMGroupMapping(r.Context(), groupID, mappedRolePtr, mappedGroupIDPtr); err != nil { + if err := srv.store.UpdateSCIMGroupMapping(r.Context(), orgID, groupID, mappedRolePtr, mappedGroupIDPtr); err != nil { slog.ErrorContext(r.Context(), "scim group mapping patch: update", "error", err) writeProblem(w, http.StatusInternalServerError, "internal error") return @@ -497,7 +493,7 @@ func (srv *Server) patchSCIMGroupMappingHandler(w http.ResponseWriter, r *http.R } // Apply immediate effects to all current members. - members, err := srv.store.ListSCIMGroupMembers(r.Context(), groupID) + members, err := srv.store.ListSCIMGroupMembers(r.Context(), orgID, groupID) if err != nil { slog.ErrorContext(r.Context(), "scim group mapping patch: list members", "error", err) writeProblem(w, http.StatusInternalServerError, "internal error") @@ -545,7 +541,7 @@ func (srv *Server) patchSCIMGroupMappingHandler(w http.ResponseWriter, r *http.R }) // Re-read the group to get updated state. - updated, err := srv.store.GetSCIMGroup(r.Context(), groupID) + updated, err := srv.store.GetSCIMGroup(r.Context(), orgID, groupID) if err != nil || updated == nil { slog.ErrorContext(r.Context(), "scim group mapping patch: re-read", "error", err) writeProblem(w, http.StatusInternalServerError, "internal error") diff --git a/internal/api/scim_admin_test.go b/internal/api/scim_admin_test.go index a122fae0..45ad4be1 100644 --- a/internal/api/scim_admin_test.go +++ b/internal/api/scim_admin_test.go @@ -739,7 +739,7 @@ func TestGroupMapping_ClearMapping(t *testing.T) { t.Fatalf("create scim group: %v", err) } admin := "admin" - if err := env.db.UpdateSCIMGroupMapping(ctx, scimGroup.ID, &admin, nil); err != nil { + if err := env.db.UpdateSCIMGroupMapping(ctx, env.orgID, scimGroup.ID, &admin, nil); err != nil { t.Fatalf("set initial mapping: %v", err) } diff --git a/internal/api/scim_e2e_test.go b/internal/api/scim_e2e_test.go index fa8577f3..9679c2b0 100644 --- a/internal/api/scim_e2e_test.go +++ b/internal/api/scim_e2e_test.go @@ -330,7 +330,7 @@ func TestSCIME2E_GroupRoleMapping(t *testing.T) { // Set mapped_role="admin" via store (simulating admin action). adminRole := "admin" - if err := env.db.UpdateSCIMGroupMapping(ctx, groupUUID, &adminRole, nil); err != nil { + if err := env.db.UpdateSCIMGroupMapping(ctx, env.orgID, groupUUID, &adminRole, nil); err != nil { t.Fatalf("UpdateSCIMGroupMapping: %v", err) } diff --git a/internal/api/scim_groups_handler.go b/internal/api/scim_groups_handler.go index b0b860a0..61c53a73 100644 --- a/internal/api/scim_groups_handler.go +++ b/internal/api/scim_groups_handler.go @@ -33,6 +33,8 @@ func (srv *Server) scimCreateGroup(w http.ResponseWriter, r *http.Request) { ctx := r.Context() orgID := ctx.Value(ctxOrgID).(uuid.UUID) + slog.InfoContext(ctx, "scim create group", slog.String("org_id", orgID.String())) + var body scimGroupRequest if err := json.NewDecoder(r.Body).Decode(&body); err != nil { writeSCIMError(w, http.StatusBadRequest, "invalidValue", "invalid JSON body") @@ -86,7 +88,7 @@ func (srv *Server) scimCreateGroup(w http.ResponseWriter, r *http.Request) { } // Load members for response. - memberIDs, err := srv.store.ListSCIMGroupMembers(ctx, group.ID) + memberIDs, err := srv.store.ListSCIMGroupMembers(ctx, orgID, group.ID) if err != nil { slog.ErrorContext(ctx, "scim: list members after create", "error", err) } @@ -108,6 +110,10 @@ func (srv *Server) scimCreateGroup(w http.ResponseWriter, r *http.Request) { // scimGetGroup handles GET /Groups/{id}. func (srv *Server) scimGetGroup(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + orgID := ctx.Value(ctxOrgID).(uuid.UUID) + + slog.InfoContext(ctx, "scim get group", slog.String("org_id", orgID.String())) + groupIDStr := chi.URLParam(r, "id") groupID, err := uuid.Parse(groupIDStr) if err != nil { @@ -115,7 +121,7 @@ func (srv *Server) scimGetGroup(w http.ResponseWriter, r *http.Request) { return } - group, err := srv.store.GetSCIMGroup(ctx, groupID) + group, err := srv.store.GetSCIMGroup(ctx, orgID, groupID) if err != nil { slog.ErrorContext(ctx, "scim: get group", "error", err) writeSCIMError(w, http.StatusInternalServerError, "", "internal error") @@ -126,14 +132,7 @@ func (srv *Server) scimGetGroup(w http.ResponseWriter, r *http.Request) { return } - // Verify the group belongs to this org. - orgID := ctx.Value(ctxOrgID).(uuid.UUID) - if group.OrgID != orgID { - writeSCIMError(w, http.StatusNotFound, "", "group not found") - return - } - - memberIDs, err := srv.store.ListSCIMGroupMembers(ctx, group.ID) + memberIDs, err := srv.store.ListSCIMGroupMembers(ctx, orgID, group.ID) if err != nil { slog.ErrorContext(ctx, "scim: list group members", "error", err) writeSCIMError(w, http.StatusInternalServerError, "", "internal error") @@ -149,6 +148,8 @@ func (srv *Server) scimListGroups(w http.ResponseWriter, r *http.Request) { ctx := r.Context() orgID := ctx.Value(ctxOrgID).(uuid.UUID) + slog.InfoContext(ctx, "scim list groups", slog.String("org_id", orgID.String())) + filterStr := r.URL.Query().Get("filter") exprs, err := parseSCIMFilter(filterStr) if err != nil { @@ -167,7 +168,7 @@ func (srv *Server) scimListGroups(w http.ResponseWriter, r *http.Request) { var filtered []any for _, g := range groups { if matchesSCIMGroupFilter(g.ID.String(), g.ExternalID.String, g.DisplayName, exprs) { - memberIDs, mErr := srv.store.ListSCIMGroupMembers(ctx, g.ID) + memberIDs, mErr := srv.store.ListSCIMGroupMembers(ctx, orgID, g.ID) if mErr != nil { slog.ErrorContext(ctx, "scim: list group members for list", "group_id", g.ID, "error", mErr) continue @@ -195,6 +196,8 @@ func (srv *Server) scimReplaceGroup(w http.ResponseWriter, r *http.Request) { ctx := r.Context() orgID := ctx.Value(ctxOrgID).(uuid.UUID) + slog.InfoContext(ctx, "scim replace group", slog.String("org_id", orgID.String())) + groupIDStr := chi.URLParam(r, "id") groupID, err := uuid.Parse(groupIDStr) if err != nil { @@ -202,13 +205,13 @@ func (srv *Server) scimReplaceGroup(w http.ResponseWriter, r *http.Request) { return } - group, err := srv.store.GetSCIMGroup(ctx, groupID) + group, err := srv.store.GetSCIMGroup(ctx, orgID, groupID) if err != nil { slog.ErrorContext(ctx, "scim: get group for replace", "error", err) writeSCIMError(w, http.StatusInternalServerError, "", "internal error") return } - if group == nil || group.OrgID != orgID { + if group == nil { writeSCIMError(w, http.StatusNotFound, "", "group not found") return } @@ -229,7 +232,7 @@ func (srv *Server) scimReplaceGroup(w http.ResponseWriter, r *http.Request) { externalID = &body.ExternalID } - if err := srv.store.UpdateSCIMGroup(ctx, groupID, body.DisplayName, externalID); err != nil { + if err := srv.store.UpdateSCIMGroup(ctx, orgID, groupID, body.DisplayName, externalID); err != nil { if strings.Contains(err.Error(), "duplicate key") || strings.Contains(err.Error(), "unique constraint") { writeSCIMError(w, http.StatusConflict, "uniqueness", "group displayName already exists in this organization") return @@ -252,7 +255,7 @@ func (srv *Server) scimReplaceGroup(w http.ResponseWriter, r *http.Request) { } // Diff members: current vs new. - currentMembers, err := srv.store.ListSCIMGroupMembers(ctx, groupID) + currentMembers, err := srv.store.ListSCIMGroupMembers(ctx, orgID, groupID) if err != nil { slog.ErrorContext(ctx, "scim: list members for diff", "error", err) writeSCIMError(w, http.StatusInternalServerError, "", "internal error") @@ -274,7 +277,7 @@ func (srv *Server) scimReplaceGroup(w http.ResponseWriter, r *http.Request) { } // Reload group to get the current mapped_role and mapped_group_id. - group, _ = srv.store.GetSCIMGroup(ctx, groupID) + group, _ = srv.store.GetSCIMGroup(ctx, orgID, groupID) // Add new members. for uid := range newMemberSet { @@ -290,7 +293,7 @@ func (srv *Server) scimReplaceGroup(w http.ResponseWriter, r *http.Request) { // Remove absent members. for _, uid := range currentMembers { if !newMemberSet[uid] { - if rmErr := srv.store.RemoveSCIMGroupMember(ctx, groupID, uid); rmErr != nil { + if rmErr := srv.store.RemoveSCIMGroupMember(ctx, orgID, groupID, uid); rmErr != nil { slog.ErrorContext(ctx, "scim: remove member in replace", "user_id", uid, "error", rmErr) continue } @@ -309,8 +312,8 @@ func (srv *Server) scimReplaceGroup(w http.ResponseWriter, r *http.Request) { }) // Reload for response. - memberIDs, _ := srv.store.ListSCIMGroupMembers(ctx, groupID) - group, _ = srv.store.GetSCIMGroup(ctx, groupID) + memberIDs, _ := srv.store.ListSCIMGroupMembers(ctx, orgID, groupID) + group, _ = srv.store.GetSCIMGroup(ctx, orgID, groupID) resp := srv.buildSCIMGroupResponse(r, group.ID.String(), group.ExternalID.String, group.DisplayName, group.CreatedAt.Format("2006-01-02T15:04:05Z"), group.UpdatedAt.Format("2006-01-02T15:04:05Z"), memberIDs) writeSCIMJSON(w, http.StatusOK, resp) } @@ -320,6 +323,8 @@ func (srv *Server) scimPatchGroup(w http.ResponseWriter, r *http.Request) { ctx := r.Context() orgID := ctx.Value(ctxOrgID).(uuid.UUID) + slog.InfoContext(ctx, "scim patch group", slog.String("org_id", orgID.String())) + groupIDStr := chi.URLParam(r, "id") groupID, err := uuid.Parse(groupIDStr) if err != nil { @@ -327,13 +332,13 @@ func (srv *Server) scimPatchGroup(w http.ResponseWriter, r *http.Request) { return } - group, err := srv.store.GetSCIMGroup(ctx, groupID) + group, err := srv.store.GetSCIMGroup(ctx, orgID, groupID) if err != nil { slog.ErrorContext(ctx, "scim: get group for patch", "error", err) writeSCIMError(w, http.StatusInternalServerError, "", "internal error") return } - if group == nil || group.OrgID != orgID { + if group == nil { writeSCIMError(w, http.StatusNotFound, "", "group not found") return } @@ -376,7 +381,7 @@ func (srv *Server) scimPatchGroup(w http.ResponseWriter, r *http.Request) { case "remove": userIDs := srv.extractRemoveTargets(op) for _, userID := range userIDs { - if rmErr := srv.store.RemoveSCIMGroupMember(ctx, group.ID, userID); rmErr != nil { + if rmErr := srv.store.RemoveSCIMGroupMember(ctx, orgID, group.ID, userID); rmErr != nil { slog.ErrorContext(ctx, "scim: patch remove member", "user_id", userID, "error", rmErr) continue } @@ -394,7 +399,7 @@ func (srv *Server) scimPatchGroup(w http.ResponseWriter, r *http.Request) { writeSCIMError(w, http.StatusBadRequest, "invalidValue", "displayName cannot be empty") return } - if updateErr := srv.store.UpdateSCIMGroup(ctx, group.ID, newName, nil); updateErr != nil { + if updateErr := srv.store.UpdateSCIMGroup(ctx, orgID, group.ID, newName, nil); updateErr != nil { if strings.Contains(updateErr.Error(), "duplicate key") || strings.Contains(updateErr.Error(), "unique constraint") { writeSCIMError(w, http.StatusConflict, "uniqueness", "group displayName already exists in this organization") return @@ -425,8 +430,8 @@ func (srv *Server) scimPatchGroup(w http.ResponseWriter, r *http.Request) { }) // Reload for response. - group, _ = srv.store.GetSCIMGroup(ctx, groupID) - memberIDs, _ := srv.store.ListSCIMGroupMembers(ctx, groupID) + group, _ = srv.store.GetSCIMGroup(ctx, orgID, groupID) + memberIDs, _ := srv.store.ListSCIMGroupMembers(ctx, orgID, groupID) resp := srv.buildSCIMGroupResponse(r, group.ID.String(), group.ExternalID.String, group.DisplayName, group.CreatedAt.Format("2006-01-02T15:04:05Z"), group.UpdatedAt.Format("2006-01-02T15:04:05Z"), memberIDs) writeSCIMJSON(w, http.StatusOK, resp) } @@ -436,6 +441,8 @@ func (srv *Server) scimDeleteGroup(w http.ResponseWriter, r *http.Request) { ctx := r.Context() orgID := ctx.Value(ctxOrgID).(uuid.UUID) + slog.InfoContext(ctx, "scim delete group", slog.String("org_id", orgID.String())) + groupIDStr := chi.URLParam(r, "id") groupID, err := uuid.Parse(groupIDStr) if err != nil { @@ -443,20 +450,20 @@ func (srv *Server) scimDeleteGroup(w http.ResponseWriter, r *http.Request) { return } - group, err := srv.store.GetSCIMGroup(ctx, groupID) + group, err := srv.store.GetSCIMGroup(ctx, orgID, groupID) if err != nil { slog.ErrorContext(ctx, "scim: get group for delete", "error", err) writeSCIMError(w, http.StatusInternalServerError, "", "internal error") return } - if group == nil || group.OrgID != orgID { + if group == nil { // Idempotent — already deleted returns 204. w.WriteHeader(http.StatusNoContent) return } // Collect affected non-exempt users before deletion. - memberIDs, err := srv.store.ListSCIMGroupMembers(ctx, groupID) + memberIDs, err := srv.store.ListSCIMGroupMembers(ctx, orgID, groupID) if err != nil { slog.ErrorContext(ctx, "scim: list members for delete", "error", err) writeSCIMError(w, http.StatusInternalServerError, "", "internal error") @@ -476,7 +483,7 @@ func (srv *Server) scimDeleteGroup(w http.ResponseWriter, r *http.Request) { } // Delete the group (CASCADE deletes scim_group_members). - if err := srv.store.DeleteSCIMGroup(ctx, groupID); err != nil { + if err := srv.store.DeleteSCIMGroup(ctx, orgID, groupID); err != nil { slog.ErrorContext(ctx, "scim: delete group", "error", err) writeSCIMError(w, http.StatusInternalServerError, "", "internal error") return @@ -628,6 +635,7 @@ func matchesSCIMGroupFilter(id, externalID, displayName string, exprs []SCIMFilt return false } default: + slog.Warn("scim list groups: unsupported filter attribute", slog.String("attribute", expr.Attr)) //nolint:gosec // G706: slog structured field, not interpolated into log format string return false } } @@ -663,16 +671,9 @@ func (srv *Server) buildSCIMGroupResponse(r *http.Request, id, externalID, displ // For a request to /api/v1/orgs/{org_id}/scim/v2/Groups/..., returns // the URL up to and including /scim/v2. func scimBaseURL(r *http.Request) string { - scheme := "https" - if r.TLS == nil { - scheme = "http" - } - if fwd := r.Header.Get("X-Forwarded-Proto"); fwd != "" { - scheme = fwd - } path := r.URL.Path if idx := strings.Index(path, "/scim/v2"); idx >= 0 { path = path[:idx+len("/scim/v2")] } - return fmt.Sprintf("%s://%s%s", scheme, r.Host, path) + return fmt.Sprintf("%s://%s%s", scimScheme(r), r.Host, path) } diff --git a/internal/api/scim_groups_handler_test.go b/internal/api/scim_groups_handler_test.go index b0657917..f632c314 100644 --- a/internal/api/scim_groups_handler_test.go +++ b/internal/api/scim_groups_handler_test.go @@ -503,7 +503,7 @@ func TestSCIMDeleteGroup_CascadesMembers(t *testing.T) { assert.Equal(t, http.StatusNotFound, getResp.StatusCode) // Verify members are gone. - memberIDs, err := env.db.ListSCIMGroupMembers(ctx, uuid.MustParse(created.ID)) + memberIDs, err := env.db.ListSCIMGroupMembers(ctx, env.orgID, uuid.MustParse(created.ID)) require.NoError(t, err) assert.Empty(t, memberIDs) } @@ -519,7 +519,7 @@ func TestSCIMDeleteGroup_RecomputesRoles(t *testing.T) { group, err := env.db.CreateSCIMGroup(ctx, env.orgID, nil, "AdminGroup") require.NoError(t, err) adminRole := "admin" - require.NoError(t, env.db.UpdateSCIMGroupMapping(ctx, group.ID, &adminRole, nil)) + require.NoError(t, env.db.UpdateSCIMGroupMapping(ctx, env.orgID, group.ID, &adminRole, nil)) require.NoError(t, env.db.AddSCIMGroupMember(ctx, group.ID, userID, env.orgID)) // Recompute role to set user to admin. diff --git a/internal/api/scim_notif_sync.go b/internal/api/scim_notif_sync.go index 4a6b7962..603995f2 100644 --- a/internal/api/scim_notif_sync.go +++ b/internal/api/scim_notif_sync.go @@ -13,7 +13,7 @@ import ( // a member (manually or via SCIM), the existing membership is preserved via ON CONFLICT DO NOTHING. func (srv *Server) syncNotifGroupAdd(ctx context.Context, orgID, userID, mappedGroupID, _ uuid.UUID) error { // Verify the target notification group exists and is not soft-deleted. - group, err := srv.store.GetGroupIfActive(ctx, mappedGroupID) + group, err := srv.store.GetGroupIfActive(ctx, orgID, mappedGroupID) if err != nil { return err } @@ -27,9 +27,9 @@ func (srv *Server) syncNotifGroupAdd(ctx context.Context, orgID, userID, mappedG // syncNotifGroupRemove removes a user from a notification group, but only if: // - The membership is scim_managed=true (manual memberships are preserved) // - No other SCIM group with the same mapped_group_id still includes the user -func (srv *Server) syncNotifGroupRemove(ctx context.Context, _, userID, mappedGroupID, scimGroupID uuid.UUID) error { +func (srv *Server) syncNotifGroupRemove(ctx context.Context, orgID, userID, mappedGroupID, scimGroupID uuid.UUID) error { // Check if another SCIM group maps to the same notification group and includes this user. - count, err := srv.store.CountOtherSCIMGroupsWithSameMapping(ctx, userID, mappedGroupID, scimGroupID) + count, err := srv.store.CountOtherSCIMGroupsWithSameMapping(ctx, orgID, userID, mappedGroupID, scimGroupID) if err != nil { return err } @@ -37,5 +37,5 @@ func (srv *Server) syncNotifGroupRemove(ctx context.Context, _, userID, mappedGr return nil // another SCIM group still maps here — keep the membership } - return srv.store.RemoveSCIMManagedGroupMember(ctx, mappedGroupID, userID) + return srv.store.RemoveSCIMManagedGroupMember(ctx, mappedGroupID, userID, orgID) } diff --git a/internal/api/scim_notif_sync_test.go b/internal/api/scim_notif_sync_test.go index 7fea0980..b84a7fd5 100644 --- a/internal/api/scim_notif_sync_test.go +++ b/internal/api/scim_notif_sync_test.go @@ -188,10 +188,10 @@ func TestNotifSync_Remove_MultiMapping(t *testing.T) { // Map both SCIM groups to the same notification group. notifGroupID := notifGroup.ID - if err := db.UpdateSCIMGroupMapping(ctx, scimGroupA.ID, nil, ¬ifGroupID); err != nil { + if err := db.UpdateSCIMGroupMapping(ctx, org.ID, scimGroupA.ID, nil, ¬ifGroupID); err != nil { t.Fatalf("setup: UpdateSCIMGroupMapping A: %v", err) } - if err := db.UpdateSCIMGroupMapping(ctx, scimGroupB.ID, nil, ¬ifGroupID); err != nil { + if err := db.UpdateSCIMGroupMapping(ctx, org.ID, scimGroupB.ID, nil, ¬ifGroupID); err != nil { t.Fatalf("setup: UpdateSCIMGroupMapping B: %v", err) } @@ -218,7 +218,7 @@ func TestNotifSync_Remove_MultiMapping(t *testing.T) { } // Now remove from SCIM group B — no more mappings, should remove. - if err := db.RemoveSCIMGroupMember(ctx, scimGroupA.ID, user.ID); err != nil { + if err := db.RemoveSCIMGroupMember(ctx, org.ID, scimGroupA.ID, user.ID); err != nil { t.Fatalf("RemoveSCIMGroupMember A: %v", err) } if err := srv.syncNotifGroupRemove(ctx, org.ID, user.ID, notifGroup.ID, scimGroupB.ID); err != nil { @@ -256,7 +256,7 @@ func TestNotifSync_GroupDelete_NoRemoval(t *testing.T) { if err := db.AddSCIMGroupMember(ctx, scimGroup.ID, user.ID, org.ID); err != nil { t.Fatalf("setup: AddSCIMGroupMember: %v", err) } - if err := db.DeleteSCIMGroup(ctx, scimGroup.ID); err != nil { + if err := db.DeleteSCIMGroup(ctx, org.ID, scimGroup.ID); err != nil { t.Fatalf("DeleteSCIMGroup: %v", err) } diff --git a/internal/api/scim_roles_test.go b/internal/api/scim_roles_test.go index d0ac38da..5d2aff0b 100644 --- a/internal/api/scim_roles_test.go +++ b/internal/api/scim_roles_test.go @@ -52,7 +52,7 @@ func setupSCIMGroupWithMapping(t *testing.T, db *testutil.TestDB, ctx context.Co } if mappedRole != nil { - if err := db.UpdateSCIMGroupMapping(ctx, group.ID, mappedRole, nil); err != nil { + if err := db.UpdateSCIMGroupMapping(ctx, orgID, group.ID, mappedRole, nil); err != nil { t.Fatalf("setup: UpdateSCIMGroupMapping: %v", err) } } @@ -227,7 +227,7 @@ func TestRoleRecompute_RemovedFromAllGroups(t *testing.T) { // Add to a group with mapped role, then remove. adminRole := "admin" groupID := setupSCIMGroupWithMapping(t, db, ctx, orgID, userID, "Admins", &adminRole) - if err := db.RemoveSCIMGroupMember(ctx, groupID, userID); err != nil { + if err := db.RemoveSCIMGroupMember(ctx, orgID, groupID, userID); err != nil { t.Fatalf("RemoveSCIMGroupMember: %v", err) } diff --git a/internal/api/scim_users.go b/internal/api/scim_users.go index 5e861ddc..f6c84736 100644 --- a/internal/api/scim_users.go +++ b/internal/api/scim_users.go @@ -34,14 +34,22 @@ func scimProvider(scimConfigID uuid.UUID) string { return fmt.Sprintf("scim:%s", scimConfigID) } +// scimScheme returns "https" or "http" based on TLS state and X-Forwarded-Proto. +// Only trusts X-Forwarded-Proto if it's a valid scheme. +func scimScheme(r *http.Request) string { + if fwd := r.Header.Get("X-Forwarded-Proto"); fwd == "https" || fwd == "http" { + return fwd + } + if r.TLS != nil { + return "https" + } + return "http" +} + // scimUserLocation returns the SCIM resource location for a user. func scimUserLocation(r *http.Request, orgID, userID uuid.UUID) string { - scheme := "https" - if r.TLS == nil { - scheme = "http" - } return fmt.Sprintf("%s://%s/api/v1/orgs/%s/scim/v2/Users/%s", - scheme, r.Host, orgID, userID) + scimScheme(r), r.Host, orgID, userID) } // buildSCIMUser constructs a SCIMUser response from component data. @@ -370,6 +378,11 @@ func (srv *Server) scimGetUser(w http.ResponseWriter, r *http.Request) { provider := scimProvider(scimConfigID) ctx := r.Context() + slog.InfoContext(ctx, "scim get user", + slog.String("org_id", orgID.String()), + slog.String("scim_config_id", scimConfigID.String()), + ) + userIDStr := chi.URLParam(r, "id") userID, err := uuid.Parse(userIDStr) if err != nil { @@ -408,6 +421,11 @@ func (srv *Server) scimListUsers(w http.ResponseWriter, r *http.Request) { provider := scimProvider(scimConfigID) ctx := r.Context() + slog.InfoContext(ctx, "scim list users", + slog.String("org_id", orgID.String()), + slog.String("scim_config_id", scimConfigID.String()), + ) + // Parse pagination params. startIndex := 1 count := 100 @@ -450,6 +468,20 @@ func (srv *Server) scimListUsers(w http.ResponseWriter, r *http.Request) { return } + // Batch-load external IDs for all members. + userIDs := make([]uuid.UUID, len(members)) + for i, m := range members { + userIDs[i] = m.UserID + } + extIDMap := make(map[uuid.UUID]string) + identities, idErr := srv.store.ListIdentitiesByProviderAndUsers(ctx, provider, userIDs) + if idErr != nil { + slog.ErrorContext(ctx, "scim list users: batch load identities", "error", idErr) + } + for _, identity := range identities { + extIDMap[identity.UserID] = identity.ProviderUserID + } + // Apply filters. type scimMember struct { UserID uuid.UUID @@ -464,13 +496,7 @@ func (srv *Server) scimListUsers(w http.ResponseWriter, r *http.Request) { var filtered []scimMember for _, m := range members { active := !m.DeactivatedAt.Valid - - // Get external ID for this member. - extID := "" - identity, _ := srv.store.GetIdentityByProviderAndUser(ctx, provider, m.UserID) - if identity != nil { - extID = identity.ProviderUserID - } + extID := extIDMap[m.UserID] // Apply filters. match := true @@ -548,6 +574,11 @@ func (srv *Server) scimReplaceUser(w http.ResponseWriter, r *http.Request) { provider := scimProvider(scimConfigID) ctx := r.Context() + slog.InfoContext(ctx, "scim replace user", + slog.String("org_id", orgID.String()), + slog.String("scim_config_id", scimConfigID.String()), + ) + userIDStr := chi.URLParam(r, "id") userID, err := uuid.Parse(userIDStr) if err != nil { @@ -705,6 +736,11 @@ func (srv *Server) scimPatchUser(w http.ResponseWriter, r *http.Request) { provider := scimProvider(scimConfigID) ctx := r.Context() + slog.InfoContext(ctx, "scim patch user", + slog.String("org_id", orgID.String()), + slog.String("scim_config_id", scimConfigID.String()), + ) + userIDStr := chi.URLParam(r, "id") userID, err := uuid.Parse(userIDStr) if err != nil { @@ -948,6 +984,11 @@ func (srv *Server) scimDeleteUser(w http.ResponseWriter, r *http.Request) { scimConfigID := r.Context().Value(ctxSCIMConfigID).(uuid.UUID) ctx := r.Context() + slog.InfoContext(ctx, "scim delete user", + slog.String("org_id", orgID.String()), + slog.String("scim_config_id", scimConfigID.String()), + ) + userIDStr := chi.URLParam(r, "id") userID, err := uuid.Parse(userIDStr) if err != nil { @@ -1064,8 +1105,6 @@ func (srv *Server) getSCIMDefaultRole(ctx context.Context, orgID uuid.UUID) stri // scimAuditLog writes an audit log entry for SCIM operations. // SCIM operations have no human actor (actor_id = nil). func (srv *Server) scimAuditLog(r *http.Request, _, _ uuid.UUID, entry audit.Entry) { - // Override: SCIM has no actor. - noActor := uuid.Nil - entry.ActorID = &noActor + entry.ActorID = nil srv.auditLog(r, entry) } diff --git a/internal/store/auth.go b/internal/store/auth.go index 46857fa8..3ea20aaf 100644 --- a/internal/store/auth.go +++ b/internal/store/auth.go @@ -363,6 +363,19 @@ func (s *Store) UpdateUserProfile(ctx context.Context, id uuid.UUID, email, disp }) } +// ListIdentitiesByProviderAndUsers returns all identity rows for a provider +// and set of user IDs. Used by SCIM list users to batch-load external IDs. +func (s *Store) ListIdentitiesByProviderAndUsers(ctx context.Context, provider string, userIDs []uuid.UUID) ([]generated.UserIdentity, error) { + rows, err := s.q.ListIdentitiesByProviderAndUsers(ctx, generated.ListIdentitiesByProviderAndUsersParams{ + Provider: provider, + Column2: userIDs, + }) + if err != nil { + return nil, fmt.Errorf("list identities by provider and users: %w", err) + } + return rows, nil +} + // GetIdentityByProviderAndUser returns the identity row for a provider+user, // or (nil, nil) if not found. func (s *Store) GetIdentityByProviderAndUser(ctx context.Context, provider string, userID uuid.UUID) (*generated.UserIdentity, error) { diff --git a/internal/store/generated/auth.sql.go b/internal/store/generated/auth.sql.go index 6dd0d0cf..d849f422 100644 --- a/internal/store/generated/auth.sql.go +++ b/internal/store/generated/auth.sql.go @@ -11,6 +11,7 @@ import ( "time" "github.com/google/uuid" + "github.com/lib/pq" ) const clearForcePasswordReset = `-- name: ClearForcePasswordReset :exec @@ -314,6 +315,48 @@ func (q *Queries) IsUserEnabled(ctx context.Context, id uuid.UUID) (bool, error) return enabled, err } +const listIdentitiesByProviderAndUsers = `-- name: ListIdentitiesByProviderAndUsers :many +SELECT id, user_id, provider, provider_user_id, email, created_at FROM user_identities +WHERE provider = $1 AND user_id = ANY($2::uuid[]) +` + +type ListIdentitiesByProviderAndUsersParams struct { + Provider string + Column2 []uuid.UUID +} + +// Returns identity rows for a given provider and set of user IDs. +// Used by SCIM list users to batch-load external IDs. +func (q *Queries) ListIdentitiesByProviderAndUsers(ctx context.Context, arg ListIdentitiesByProviderAndUsersParams) ([]UserIdentity, error) { + rows, err := q.db.QueryContext(ctx, listIdentitiesByProviderAndUsers, arg.Provider, pq.Array(arg.Column2)) + if err != nil { + return nil, err + } + defer rows.Close() + var items []UserIdentity + for rows.Next() { + var i UserIdentity + if err := rows.Scan( + &i.ID, + &i.UserID, + &i.Provider, + &i.ProviderUserID, + &i.Email, + &i.CreatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const markRefreshTokenUsed = `-- name: MarkRefreshTokenUsed :exec UPDATE refresh_tokens SET used_at = now(), replaced_by_jti = $2 diff --git a/internal/store/generated/groups.sql.go b/internal/store/generated/groups.sql.go index d875adde..517ef6ad 100644 --- a/internal/store/generated/groups.sql.go +++ b/internal/store/generated/groups.sql.go @@ -96,11 +96,16 @@ func (q *Queries) GetGroup(ctx context.Context, arg GetGroupParams) (Group, erro } const getGroupIfActive = `-- name: GetGroupIfActive :one -SELECT id, org_id, name, description, created_at, deleted_at FROM groups WHERE id = $1 AND deleted_at IS NULL +SELECT id, org_id, name, description, created_at, deleted_at FROM groups WHERE id = $1 AND org_id = $2 AND deleted_at IS NULL ` -func (q *Queries) GetGroupIfActive(ctx context.Context, id uuid.UUID) (Group, error) { - row := q.db.QueryRowContext(ctx, getGroupIfActive, id) +type GetGroupIfActiveParams struct { + ID uuid.UUID + OrgID uuid.UUID +} + +func (q *Queries) GetGroupIfActive(ctx context.Context, arg GetGroupIfActiveParams) (Group, error) { + row := q.db.QueryRowContext(ctx, getGroupIfActive, arg.ID, arg.OrgID) var i Group err := row.Scan( &i.ID, @@ -235,16 +240,17 @@ func (q *Queries) RemoveGroupMember(ctx context.Context, arg RemoveGroupMemberPa const removeSCIMManagedGroupMember = `-- name: RemoveSCIMManagedGroupMember :exec DELETE FROM group_members -WHERE group_id = $1 AND user_id = $2 AND scim_managed = true +WHERE group_id = $1 AND user_id = $2 AND org_id = $3 AND scim_managed = true ` type RemoveSCIMManagedGroupMemberParams struct { GroupID uuid.UUID UserID uuid.UUID + OrgID uuid.UUID } func (q *Queries) RemoveSCIMManagedGroupMember(ctx context.Context, arg RemoveSCIMManagedGroupMemberParams) error { - _, err := q.db.ExecContext(ctx, removeSCIMManagedGroupMember, arg.GroupID, arg.UserID) + _, err := q.db.ExecContext(ctx, removeSCIMManagedGroupMember, arg.GroupID, arg.UserID, arg.OrgID) return err } diff --git a/internal/store/generated/scim_groups.sql.go b/internal/store/generated/scim_groups.sql.go index 33456dc2..71068c04 100644 --- a/internal/store/generated/scim_groups.sql.go +++ b/internal/store/generated/scim_groups.sql.go @@ -36,16 +36,23 @@ JOIN scim_groups sg ON sgm.scim_group_id = sg.id WHERE sgm.user_id = $1 AND sg.mapped_group_id = $2 AND sg.id != $3 + AND sgm.org_id = $4 ` type CountOtherSCIMGroupsWithSameMappingParams struct { UserID uuid.UUID MappedGroupID uuid.NullUUID ID uuid.UUID + OrgID uuid.UUID } func (q *Queries) CountOtherSCIMGroupsWithSameMapping(ctx context.Context, arg CountOtherSCIMGroupsWithSameMappingParams) (int32, error) { - row := q.db.QueryRowContext(ctx, countOtherSCIMGroupsWithSameMapping, arg.UserID, arg.MappedGroupID, arg.ID) + row := q.db.QueryRowContext(ctx, countOtherSCIMGroupsWithSameMapping, + arg.UserID, + arg.MappedGroupID, + arg.ID, + arg.OrgID, + ) var column_1 int32 err := row.Scan(&column_1) return column_1, err @@ -82,11 +89,16 @@ func (q *Queries) CreateSCIMGroup(ctx context.Context, arg CreateSCIMGroupParams } const deleteSCIMGroup = `-- name: DeleteSCIMGroup :exec -DELETE FROM scim_groups WHERE id = $1 +DELETE FROM scim_groups WHERE id = $1 AND org_id = $2 ` -func (q *Queries) DeleteSCIMGroup(ctx context.Context, id uuid.UUID) error { - _, err := q.db.ExecContext(ctx, deleteSCIMGroup, id) +type DeleteSCIMGroupParams struct { + ID uuid.UUID + OrgID uuid.UUID +} + +func (q *Queries) DeleteSCIMGroup(ctx context.Context, arg DeleteSCIMGroupParams) error { + _, err := q.db.ExecContext(ctx, deleteSCIMGroup, arg.ID, arg.OrgID) return err } @@ -141,11 +153,16 @@ func (q *Queries) GetSCIMGroupByExternalID(ctx context.Context, arg GetSCIMGroup } const getSCIMGroupByID = `-- name: GetSCIMGroupByID :one -SELECT id, org_id, external_id, display_name, mapped_role, mapped_group_id, created_at, updated_at FROM scim_groups WHERE id = $1 +SELECT id, org_id, external_id, display_name, mapped_role, mapped_group_id, created_at, updated_at FROM scim_groups WHERE id = $1 AND org_id = $2 ` -func (q *Queries) GetSCIMGroupByID(ctx context.Context, id uuid.UUID) (ScimGroup, error) { - row := q.db.QueryRowContext(ctx, getSCIMGroupByID, id) +type GetSCIMGroupByIDParams struct { + ID uuid.UUID + OrgID uuid.UUID +} + +func (q *Queries) GetSCIMGroupByID(ctx context.Context, arg GetSCIMGroupByIDParams) (ScimGroup, error) { + row := q.db.QueryRowContext(ctx, getSCIMGroupByID, arg.ID, arg.OrgID) var i ScimGroup err := row.Scan( &i.ID, @@ -161,11 +178,16 @@ func (q *Queries) GetSCIMGroupByID(ctx context.Context, id uuid.UUID) (ScimGroup } const listSCIMGroupMembers = `-- name: ListSCIMGroupMembers :many -SELECT user_id FROM scim_group_members WHERE scim_group_id = $1 +SELECT user_id FROM scim_group_members WHERE scim_group_id = $1 AND org_id = $2 ` -func (q *Queries) ListSCIMGroupMembers(ctx context.Context, scimGroupID uuid.UUID) ([]uuid.UUID, error) { - rows, err := q.db.QueryContext(ctx, listSCIMGroupMembers, scimGroupID) +type ListSCIMGroupMembersParams struct { + ScimGroupID uuid.UUID + OrgID uuid.UUID +} + +func (q *Queries) ListSCIMGroupMembers(ctx context.Context, arg ListSCIMGroupMembersParams) ([]uuid.UUID, error) { + rows, err := q.db.QueryContext(ctx, listSCIMGroupMembers, arg.ScimGroupID, arg.OrgID) if err != nil { return nil, err } @@ -285,16 +307,17 @@ func (q *Queries) ListUserSCIMGroups(ctx context.Context, arg ListUserSCIMGroups } const removeSCIMGroupMember = `-- name: RemoveSCIMGroupMember :exec -DELETE FROM scim_group_members WHERE scim_group_id = $1 AND user_id = $2 +DELETE FROM scim_group_members WHERE scim_group_id = $1 AND user_id = $2 AND org_id = $3 ` type RemoveSCIMGroupMemberParams struct { ScimGroupID uuid.UUID UserID uuid.UUID + OrgID uuid.UUID } func (q *Queries) RemoveSCIMGroupMember(ctx context.Context, arg RemoveSCIMGroupMemberParams) error { - _, err := q.db.ExecContext(ctx, removeSCIMGroupMember, arg.ScimGroupID, arg.UserID) + _, err := q.db.ExecContext(ctx, removeSCIMGroupMember, arg.ScimGroupID, arg.UserID, arg.OrgID) return err } @@ -309,32 +332,44 @@ func (q *Queries) SetSCIMGroupMembers_Delete(ctx context.Context, scimGroupID uu const updateSCIMGroup = `-- name: UpdateSCIMGroup :exec UPDATE scim_groups SET display_name = $2, external_id = $3, updated_at = now() -WHERE id = $1 +WHERE id = $1 AND org_id = $4 ` type UpdateSCIMGroupParams struct { ID uuid.UUID DisplayName string ExternalID sql.NullString + OrgID uuid.UUID } func (q *Queries) UpdateSCIMGroup(ctx context.Context, arg UpdateSCIMGroupParams) error { - _, err := q.db.ExecContext(ctx, updateSCIMGroup, arg.ID, arg.DisplayName, arg.ExternalID) + _, err := q.db.ExecContext(ctx, updateSCIMGroup, + arg.ID, + arg.DisplayName, + arg.ExternalID, + arg.OrgID, + ) return err } const updateSCIMGroupMapping = `-- name: UpdateSCIMGroupMapping :exec UPDATE scim_groups SET mapped_role = $2, mapped_group_id = $3, updated_at = now() -WHERE id = $1 +WHERE id = $1 AND org_id = $4 ` type UpdateSCIMGroupMappingParams struct { ID uuid.UUID MappedRole sql.NullString MappedGroupID uuid.NullUUID + OrgID uuid.UUID } func (q *Queries) UpdateSCIMGroupMapping(ctx context.Context, arg UpdateSCIMGroupMappingParams) error { - _, err := q.db.ExecContext(ctx, updateSCIMGroupMapping, arg.ID, arg.MappedRole, arg.MappedGroupID) + _, err := q.db.ExecContext(ctx, updateSCIMGroupMapping, + arg.ID, + arg.MappedRole, + arg.MappedGroupID, + arg.OrgID, + ) return err } diff --git a/internal/store/group.go b/internal/store/group.go index 4fdb85db..4ff916ce 100644 --- a/internal/store/group.go +++ b/internal/store/group.go @@ -143,12 +143,14 @@ func (s *Store) ListGroupMembers(ctx context.Context, orgID, groupID uuid.UUID) } // GetGroupIfActive returns the group if it exists and is not soft-deleted, -// or (nil, nil) if not found or deleted. Does not require org context since -// it's used by SCIM sync where the group ID is already known. -func (s *Store) GetGroupIfActive(ctx context.Context, id uuid.UUID) (*generated.Group, error) { +// or (nil, nil) if not found or deleted. +func (s *Store) GetGroupIfActive(ctx context.Context, orgID uuid.UUID, id uuid.UUID) (*generated.Group, error) { var result *generated.Group - err := s.withBypassTx(ctx, func(q *generated.Queries) error { - row, err := q.GetGroupIfActive(ctx, id) + err := s.withOrgTx(ctx, orgID, func(q *generated.Queries) error { + row, err := q.GetGroupIfActive(ctx, generated.GetGroupIfActiveParams{ + ID: id, + OrgID: orgID, + }) if errors.Is(err, sql.ErrNoRows) { return nil } @@ -167,7 +169,7 @@ func (s *Store) GetGroupIfActive(ctx context.Context, id uuid.UUID) (*generated. // AddGroupMemberSCIMManaged adds a user to the group with scim_managed=true. // Idempotent — ON CONFLICT DO NOTHING preserves existing memberships (including manual ones). func (s *Store) AddGroupMemberSCIMManaged(ctx context.Context, groupID, userID, orgID uuid.UUID) error { - return s.withBypassTx(ctx, func(q *generated.Queries) error { + return s.withOrgTx(ctx, orgID, func(q *generated.Queries) error { if err := q.AddGroupMemberSCIMManaged(ctx, generated.AddGroupMemberSCIMManagedParams{ GroupID: groupID, UserID: userID, @@ -181,11 +183,12 @@ func (s *Store) AddGroupMemberSCIMManaged(ctx context.Context, groupID, userID, // RemoveSCIMManagedGroupMember removes a user from a group only if their // membership is scim_managed=true. Manual memberships are left intact. -func (s *Store) RemoveSCIMManagedGroupMember(ctx context.Context, groupID, userID uuid.UUID) error { - return s.withBypassTx(ctx, func(q *generated.Queries) error { +func (s *Store) RemoveSCIMManagedGroupMember(ctx context.Context, groupID, userID, orgID uuid.UUID) error { + return s.withOrgTx(ctx, orgID, func(q *generated.Queries) error { if err := q.RemoveSCIMManagedGroupMember(ctx, generated.RemoveSCIMManagedGroupMemberParams{ GroupID: groupID, UserID: userID, + OrgID: orgID, }); err != nil { return fmt.Errorf("remove scim managed group member: %w", err) } diff --git a/internal/store/queries/auth.sql b/internal/store/queries/auth.sql index 76cd4a26..9acc2ff1 100644 --- a/internal/store/queries/auth.sql +++ b/internal/store/queries/auth.sql @@ -114,3 +114,9 @@ UPDATE users SET email = $2, display_name = $3 WHERE id = $1; SELECT * FROM user_identities WHERE provider = $1 AND user_id = $2 LIMIT 1; + +-- name: ListIdentitiesByProviderAndUsers :many +-- Returns identity rows for a given provider and set of user IDs. +-- Used by SCIM list users to batch-load external IDs. +SELECT * FROM user_identities +WHERE provider = $1 AND user_id = ANY($2::uuid[]); diff --git a/internal/store/queries/groups.sql b/internal/store/queries/groups.sql index 872483a5..34d85d50 100644 --- a/internal/store/queries/groups.sql +++ b/internal/store/queries/groups.sql @@ -36,10 +36,10 @@ ON CONFLICT (group_id, user_id) DO NOTHING; -- name: RemoveSCIMManagedGroupMember :exec DELETE FROM group_members -WHERE group_id = $1 AND user_id = $2 AND scim_managed = true; +WHERE group_id = $1 AND user_id = $2 AND org_id = $3 AND scim_managed = true; -- name: IsGroupMemberSCIMManaged :one SELECT scim_managed FROM group_members WHERE group_id = $1 AND user_id = $2; -- name: GetGroupIfActive :one -SELECT * FROM groups WHERE id = $1 AND deleted_at IS NULL; +SELECT * FROM groups WHERE id = $1 AND org_id = $2 AND deleted_at IS NULL; diff --git a/internal/store/queries/scim_groups.sql b/internal/store/queries/scim_groups.sql index 095696a8..3fb1f825 100644 --- a/internal/store/queries/scim_groups.sql +++ b/internal/store/queries/scim_groups.sql @@ -6,7 +6,7 @@ INSERT INTO scim_groups (org_id, external_id, display_name) VALUES ($1, $2, $3) RETURNING *; -- name: GetSCIMGroupByID :one -SELECT * FROM scim_groups WHERE id = $1; +SELECT * FROM scim_groups WHERE id = $1 AND org_id = $2; -- name: GetSCIMGroupByDisplayName :one SELECT * FROM scim_groups WHERE org_id = $1 AND display_name = $2; @@ -24,14 +24,14 @@ ORDER BY sg.display_name; -- name: UpdateSCIMGroup :exec UPDATE scim_groups SET display_name = $2, external_id = $3, updated_at = now() -WHERE id = $1; +WHERE id = $1 AND org_id = $4; -- name: UpdateSCIMGroupMapping :exec UPDATE scim_groups SET mapped_role = $2, mapped_group_id = $3, updated_at = now() -WHERE id = $1; +WHERE id = $1 AND org_id = $4; -- name: DeleteSCIMGroup :exec -DELETE FROM scim_groups WHERE id = $1; +DELETE FROM scim_groups WHERE id = $1 AND org_id = $2; -- name: AddSCIMGroupMember :exec INSERT INTO scim_group_members (scim_group_id, user_id, org_id) @@ -39,10 +39,10 @@ VALUES ($1, $2, $3) ON CONFLICT (scim_group_id, user_id) DO NOTHING; -- name: RemoveSCIMGroupMember :exec -DELETE FROM scim_group_members WHERE scim_group_id = $1 AND user_id = $2; +DELETE FROM scim_group_members WHERE scim_group_id = $1 AND user_id = $2 AND org_id = $3; -- name: ListSCIMGroupMembers :many -SELECT user_id FROM scim_group_members WHERE scim_group_id = $1; +SELECT user_id FROM scim_group_members WHERE scim_group_id = $1 AND org_id = $2; -- name: ListUserSCIMGroups :many SELECT sg.* FROM scim_groups sg @@ -57,4 +57,5 @@ SELECT COUNT(*)::int FROM scim_group_members sgm JOIN scim_groups sg ON sgm.scim_group_id = sg.id WHERE sgm.user_id = $1 AND sg.mapped_group_id = $2 - AND sg.id != $3; + AND sg.id != $3 + AND sgm.org_id = $4; diff --git a/internal/store/scim_groups.go b/internal/store/scim_groups.go index 2f0c09c9..e917a2dc 100644 --- a/internal/store/scim_groups.go +++ b/internal/store/scim_groups.go @@ -36,14 +36,14 @@ func (s *Store) CreateSCIMGroup(ctx context.Context, orgID uuid.UUID, externalID return &row, nil } -// GetSCIMGroup returns the SCIM group by ID, or (nil, nil) if not found. -func (s *Store) GetSCIMGroup(ctx context.Context, id uuid.UUID) (*generated.ScimGroup, error) { +// GetSCIMGroup returns the SCIM group by ID within the given org, or (nil, nil) if not found. +func (s *Store) GetSCIMGroup(ctx context.Context, orgID uuid.UUID, id uuid.UUID) (*generated.ScimGroup, error) { var result *generated.ScimGroup - // RLS requires org context but GetSCIMGroupByID filters by id only. - // We use withBypassTx here since callers may not know the org_id upfront, - // and GetSCIMGroupByID queries by PK (RLS would restrict without SET LOCAL). - err := s.withBypassTx(ctx, func(q *generated.Queries) error { - row, err := q.GetSCIMGroupByID(ctx, id) + err := s.withOrgTx(ctx, orgID, func(q *generated.Queries) error { + row, err := q.GetSCIMGroupByID(ctx, generated.GetSCIMGroupByIDParams{ + ID: id, + OrgID: orgID, + }) if errors.Is(err, sql.ErrNoRows) { return nil } @@ -123,23 +123,24 @@ func (s *Store) ListSCIMGroups(ctx context.Context, orgID uuid.UUID) ([]generate } // UpdateSCIMGroup updates the display name and external ID of a SCIM group. -func (s *Store) UpdateSCIMGroup(ctx context.Context, id uuid.UUID, displayName string, externalID *string) error { +func (s *Store) UpdateSCIMGroup(ctx context.Context, orgID uuid.UUID, id uuid.UUID, displayName string, externalID *string) error { var extID sql.NullString if externalID != nil { extID = sql.NullString{String: *externalID, Valid: true} } - return s.withBypassTx(ctx, func(q *generated.Queries) error { + return s.withOrgTx(ctx, orgID, func(q *generated.Queries) error { return q.UpdateSCIMGroup(ctx, generated.UpdateSCIMGroupParams{ ID: id, DisplayName: displayName, ExternalID: extID, + OrgID: orgID, }) }) } // UpdateSCIMGroupMapping updates the mapped role and notification group for a SCIM group. -func (s *Store) UpdateSCIMGroupMapping(ctx context.Context, id uuid.UUID, mappedRole *string, mappedGroupID *uuid.UUID) error { +func (s *Store) UpdateSCIMGroupMapping(ctx context.Context, orgID uuid.UUID, id uuid.UUID, mappedRole *string, mappedGroupID *uuid.UUID) error { var role sql.NullString if mappedRole != nil { role = sql.NullString{String: *mappedRole, Valid: true} @@ -149,19 +150,23 @@ func (s *Store) UpdateSCIMGroupMapping(ctx context.Context, id uuid.UUID, mapped groupID = uuid.NullUUID{UUID: *mappedGroupID, Valid: true} } - return s.withBypassTx(ctx, func(q *generated.Queries) error { + return s.withOrgTx(ctx, orgID, func(q *generated.Queries) error { return q.UpdateSCIMGroupMapping(ctx, generated.UpdateSCIMGroupMappingParams{ ID: id, MappedRole: role, MappedGroupID: groupID, + OrgID: orgID, }) }) } // DeleteSCIMGroup deletes a SCIM group by ID. Members are cascade-deleted. -func (s *Store) DeleteSCIMGroup(ctx context.Context, id uuid.UUID) error { - return s.withBypassTx(ctx, func(q *generated.Queries) error { - return q.DeleteSCIMGroup(ctx, id) +func (s *Store) DeleteSCIMGroup(ctx context.Context, orgID uuid.UUID, id uuid.UUID) error { + return s.withOrgTx(ctx, orgID, func(q *generated.Queries) error { + return q.DeleteSCIMGroup(ctx, generated.DeleteSCIMGroupParams{ + ID: id, + OrgID: orgID, + }) }) } @@ -177,21 +182,25 @@ func (s *Store) AddSCIMGroupMember(ctx context.Context, scimGroupID, userID, org } // RemoveSCIMGroupMember removes a user from a SCIM group. -func (s *Store) RemoveSCIMGroupMember(ctx context.Context, scimGroupID, userID uuid.UUID) error { - return s.withBypassTx(ctx, func(q *generated.Queries) error { +func (s *Store) RemoveSCIMGroupMember(ctx context.Context, orgID uuid.UUID, scimGroupID, userID uuid.UUID) error { + return s.withOrgTx(ctx, orgID, func(q *generated.Queries) error { return q.RemoveSCIMGroupMember(ctx, generated.RemoveSCIMGroupMemberParams{ ScimGroupID: scimGroupID, UserID: userID, + OrgID: orgID, }) }) } // ListSCIMGroupMembers returns all user IDs in the given SCIM group. -func (s *Store) ListSCIMGroupMembers(ctx context.Context, scimGroupID uuid.UUID) ([]uuid.UUID, error) { +func (s *Store) ListSCIMGroupMembers(ctx context.Context, orgID uuid.UUID, scimGroupID uuid.UUID) ([]uuid.UUID, error) { var result []uuid.UUID - err := s.withBypassTx(ctx, func(q *generated.Queries) error { + err := s.withOrgTx(ctx, orgID, func(q *generated.Queries) error { var err error - result, err = q.ListSCIMGroupMembers(ctx, scimGroupID) + result, err = q.ListSCIMGroupMembers(ctx, generated.ListSCIMGroupMembersParams{ + ScimGroupID: scimGroupID, + OrgID: orgID, + }) return err }) if err != nil { @@ -219,14 +228,15 @@ func (s *Store) ListUserSCIMGroups(ctx context.Context, userID, orgID uuid.UUID) // CountOtherSCIMGroupsWithSameMapping counts how many other SCIM groups (excluding // excludeGroupID) map to the same notification group and contain the given user. -func (s *Store) CountOtherSCIMGroupsWithSameMapping(ctx context.Context, userID uuid.UUID, mappedGroupID uuid.UUID, excludeGroupID uuid.UUID) (int, error) { +func (s *Store) CountOtherSCIMGroupsWithSameMapping(ctx context.Context, orgID uuid.UUID, userID uuid.UUID, mappedGroupID uuid.UUID, excludeGroupID uuid.UUID) (int, error) { var count int32 - err := s.withBypassTx(ctx, func(q *generated.Queries) error { + err := s.withOrgTx(ctx, orgID, func(q *generated.Queries) error { var err error count, err = q.CountOtherSCIMGroupsWithSameMapping(ctx, generated.CountOtherSCIMGroupsWithSameMappingParams{ UserID: userID, MappedGroupID: uuid.NullUUID{UUID: mappedGroupID, Valid: true}, ID: excludeGroupID, + OrgID: orgID, }) return err }) diff --git a/internal/store/scim_groups_test.go b/internal/store/scim_groups_test.go index 7f6c3d05..973dfd60 100644 --- a/internal/store/scim_groups_test.go +++ b/internal/store/scim_groups_test.go @@ -114,11 +114,11 @@ func TestUpdateSCIMGroupMapping(t *testing.T) { notifGroup := db.MustCreateGroup(t, ctx, org.ID, "NotifGroup", "for mapping test") mappedRole := "admin" - err = db.UpdateSCIMGroupMapping(ctx, group.ID, &mappedRole, ¬ifGroup.ID) + err = db.UpdateSCIMGroupMapping(ctx, org.ID, group.ID, &mappedRole, ¬ifGroup.ID) require.NoError(t, err) // Re-read and verify. - got, err := db.GetSCIMGroup(ctx, group.ID) + got, err := db.GetSCIMGroup(ctx, org.ID, group.ID) require.NoError(t, err) require.NotNil(t, got) require.True(t, got.MappedRole.Valid) @@ -142,15 +142,15 @@ func TestDeleteSCIMGroup_CascadesMembers(t *testing.T) { require.NoError(t, db.AddSCIMGroupMember(ctx, group.ID, user.ID, org.ID)) // Verify member exists. - members, err := db.ListSCIMGroupMembers(ctx, group.ID) + members, err := db.ListSCIMGroupMembers(ctx, org.ID, group.ID) require.NoError(t, err) require.Len(t, members, 1) // Delete group. - require.NoError(t, db.DeleteSCIMGroup(ctx, group.ID)) + require.NoError(t, db.DeleteSCIMGroup(ctx, org.ID, group.ID)) // Verify group gone. - got, err := db.GetSCIMGroup(ctx, group.ID) + got, err := db.GetSCIMGroup(ctx, org.ID, group.ID) require.NoError(t, err) require.Nil(t, got) @@ -178,7 +178,7 @@ func TestAddSCIMGroupMember_Idempotent(t *testing.T) { require.NoError(t, db.AddSCIMGroupMember(ctx, group.ID, user.ID, org.ID)) require.NoError(t, db.AddSCIMGroupMember(ctx, group.ID, user.ID, org.ID)) - members, err := db.ListSCIMGroupMembers(ctx, group.ID) + members, err := db.ListSCIMGroupMembers(ctx, org.ID, group.ID) require.NoError(t, err) require.Len(t, members, 1, "member count should be 1 after duplicate add") } @@ -198,9 +198,9 @@ func TestRemoveSCIMGroupMember(t *testing.T) { require.NoError(t, db.AddSCIMGroupMember(ctx, group.ID, user.ID, org.ID)) // Remove. - require.NoError(t, db.RemoveSCIMGroupMember(ctx, group.ID, user.ID)) + require.NoError(t, db.RemoveSCIMGroupMember(ctx, org.ID, group.ID, user.ID)) - members, err := db.ListSCIMGroupMembers(ctx, group.ID) + members, err := db.ListSCIMGroupMembers(ctx, org.ID, group.ID) require.NoError(t, err) require.Empty(t, members) } @@ -253,8 +253,8 @@ func TestCountOtherSCIMGroupsWithSameMapping(t *testing.T) { // Map both SCIM groups to the same notification group. role := "member" - require.NoError(t, db.UpdateSCIMGroupMapping(ctx, groupA.ID, &role, ¬ifGroup.ID)) - require.NoError(t, db.UpdateSCIMGroupMapping(ctx, groupB.ID, &role, ¬ifGroup.ID)) + require.NoError(t, db.UpdateSCIMGroupMapping(ctx, org.ID, groupA.ID, &role, ¬ifGroup.ID)) + require.NoError(t, db.UpdateSCIMGroupMapping(ctx, org.ID, groupB.ID, &role, ¬ifGroup.ID)) user := db.MustCreateUser(t, ctx, "countuser@example.com", "CountUser", "hash", 1) require.NoError(t, db.CreateOrgMember(ctx, org.ID, user.ID, "member")) @@ -264,15 +264,15 @@ func TestCountOtherSCIMGroupsWithSameMapping(t *testing.T) { require.NoError(t, db.AddSCIMGroupMember(ctx, groupB.ID, user.ID, org.ID)) // Excluding group A → count should be 1 (group B). - count, err := db.CountOtherSCIMGroupsWithSameMapping(ctx, user.ID, notifGroup.ID, groupA.ID) + count, err := db.CountOtherSCIMGroupsWithSameMapping(ctx, org.ID, user.ID, notifGroup.ID, groupA.ID) require.NoError(t, err) require.Equal(t, 1, count) // Now remove user from group B. - require.NoError(t, db.RemoveSCIMGroupMember(ctx, groupB.ID, user.ID)) + require.NoError(t, db.RemoveSCIMGroupMember(ctx, org.ID, groupB.ID, user.ID)) // Excluding group A → count should be 0. - count, err = db.CountOtherSCIMGroupsWithSameMapping(ctx, user.ID, notifGroup.ID, groupA.ID) + count, err = db.CountOtherSCIMGroupsWithSameMapping(ctx, org.ID, user.ID, notifGroup.ID, groupA.ID) require.NoError(t, err) require.Equal(t, 0, count) }