From 6c260b16754a4bfb7d7329c61a6940e4b7d38878 Mon Sep 17 00:00:00 2001 From: dev Date: Mon, 18 May 2026 19:05:40 +0200 Subject: [PATCH 1/6] fix(skills): use SSRF-safe HTTP client for remote skills registry MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The skills cache fetches from a base URL that may be operator-supplied, and the contents are fed to the model as instructions. Using the plain remote transport allowed any IP — including 127.0.0.1, RFC1918, or 169.254.169.254 (cloud metadata) — to be reached, turning a hostile registry into a metadata-exfiltration vector via prompt injection. Switch to httpclient.NewSafeClient so non-public addresses are refused at dial time (after DNS resolution, defeating DNS rebinding). Tests override the client via TestMain since httptest.NewServer binds to 127.0.0.1. --- pkg/skills/cache.go | 25 ++++++++++++++++--------- pkg/skills/main_test.go | 16 ++++++++++++++++ 2 files changed, 32 insertions(+), 9 deletions(-) create mode 100644 pkg/skills/main_test.go diff --git a/pkg/skills/cache.go b/pkg/skills/cache.go index 3242a6ef9..2ff40c034 100644 --- a/pkg/skills/cache.go +++ b/pkg/skills/cache.go @@ -15,25 +15,32 @@ import ( "strings" "time" - "github.com/docker/docker-agent/pkg/remote" + "github.com/docker/docker-agent/pkg/httpclient" ) // remoteHTTPTimeout caps each HTTP request made to a remote skills source. const remoteHTTPTimeout = 30 * time.Second -// httpGet performs a GET request using the standard remote transport so that -// Docker Desktop proxy/SSL settings are honoured. The returned response body -// must be closed by the caller. +// skillsHTTPClient is used for outbound calls to remote skill registries. +// The base URL is operator-supplied and the contents are fed to the model +// as instructions, so a hostile (or compromised) registry could otherwise +// be used to read internal endpoints (loopback, RFC1918, link-local incl. 
+// cloud metadata at 169.254.169.254) and exfiltrate them through prompt +// injection. The SSRF-safe client refuses such targets at dial time, after +// DNS resolution, defeating DNS rebinding. +// +// Tests in this package replace the var via TestMain (see main_test.go) +// because httptest.NewServer binds to 127.0.0.1. +var skillsHTTPClient = httpclient.NewSafeClient(remoteHTTPTimeout, false) + +// httpGet performs a GET request using the SSRF-safe HTTP client. The +// returned response body must be closed by the caller. func httpGet(ctx context.Context, url string) (*http.Response, error) { - client := &http.Client{ - Timeout: remoteHTTPTimeout, - Transport: remote.NewTransport(ctx), - } req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody) if err != nil { return nil, fmt.Errorf("creating request for %s: %w", url, err) } - return client.Do(req) + return skillsHTTPClient.Do(req) } type diskCache struct { diff --git a/pkg/skills/main_test.go b/pkg/skills/main_test.go new file mode 100644 index 000000000..833235db8 --- /dev/null +++ b/pkg/skills/main_test.go @@ -0,0 +1,16 @@ +package skills + +import ( + "os" + "testing" + + "github.com/docker/docker-agent/pkg/httpclient" +) + +// TestMain swaps the SSRF-safe HTTP client for the loopback-allowing +// variant so tests can hit httptest.NewServer (which binds to 127.0.0.1). +// Production code keeps the safe client. +func TestMain(m *testing.M) { + skillsHTTPClient = httpclient.NewSafeClient(remoteHTTPTimeout, true) + os.Exit(m.Run()) +} From 34a1f3f1f974aacf6c9252ab7a1506c66f35e2fa Mon Sep 17 00:00:00 2001 From: dev Date: Mon, 18 May 2026 19:13:54 +0200 Subject: [PATCH 2/6] refactor(rag,evaluation): use errors.Is for sentinel error checks Replace direct == comparisons with errors.Is so wrapped errors are matched correctly, and replace deprecated os.IsNotExist with errors.Is(err, fs.ErrNotExist). 
For sql.ErrNoRows the existing == form worked because database/sql.Scan is documented to return it unwrapped (and golangci-lint's errorlint allowlists this pattern), but using errors.Is is consistent with the rest of these files (e.g. bm25_database.go already used errors.Is at line 100) and makes the code resilient if a future caller wraps the error. Note: the project already enables errorlint via .golangci.yml, but errorlint intentionally does not flag these specific patterns (sql.Scan -> sql.ErrNoRows is allowlisted; os.IsNotExist is not in errorlint's scope). No linter change is therefore needed. --- pkg/evaluation/build.go | 3 ++- pkg/rag/strategy/bm25.go | 3 ++- pkg/rag/strategy/bm25_database.go | 5 +++-- pkg/rag/strategy/chunked_embeddings_database.go | 2 +- pkg/rag/strategy/semantic_embeddings_database.go | 2 +- 5 files changed, 9 insertions(+), 6 deletions(-) diff --git a/pkg/evaluation/build.go b/pkg/evaluation/build.go index 58c786d78..d65497c58 100644 --- a/pkg/evaluation/build.go +++ b/pkg/evaluation/build.go @@ -6,6 +6,7 @@ import ( _ "embed" "errors" "fmt" + "io/fs" "os" "os/exec" "path/filepath" @@ -91,7 +92,7 @@ func (r *Runner) buildEvalImage(ctx context.Context, evals *session.EvalCriteria data.CopyWorkingDir = false } else { buildContext = filepath.Join(r.EvalsDir, "working_dirs", evals.WorkingDir) - if _, err := os.Stat(buildContext); os.IsNotExist(err) { + if _, err := os.Stat(buildContext); errors.Is(err, fs.ErrNotExist) { return "", fmt.Errorf("working directory not found: %s", buildContext) } data.CopyWorkingDir = true diff --git a/pkg/rag/strategy/bm25.go b/pkg/rag/strategy/bm25.go index be3faf54a..4524b0cb7 100644 --- a/pkg/rag/strategy/bm25.go +++ b/pkg/rag/strategy/bm25.go @@ -5,6 +5,7 @@ import ( "context" "errors" "fmt" + "io/fs" "log/slog" "math" "os" @@ -642,7 +643,7 @@ func (s *BM25Strategy) addPathToWatcher(ctx context.Context, path string) error stat, err := os.Stat(absPath) if err != nil { - if os.IsNotExist(err) { + if 
errors.Is(err, fs.ErrNotExist) { return nil } return fmt.Errorf("failed to stat path: %w", err) diff --git a/pkg/rag/strategy/bm25_database.go b/pkg/rag/strategy/bm25_database.go index 32be73635..54ff217c4 100644 --- a/pkg/rag/strategy/bm25_database.go +++ b/pkg/rag/strategy/bm25_database.go @@ -5,6 +5,7 @@ import ( "database/sql" "errors" "fmt" + "io/fs" "log/slog" "os" "path/filepath" @@ -159,7 +160,7 @@ func (d *bm25DB) GetFileMetadata(ctx context.Context, sourcePath string) (*datab fmt.Sprintf("SELECT source_path, file_hash, last_indexed, chunk_count FROM %s WHERE source_path = ?", d.metadataTable), sourcePath).Scan(&metadata.SourcePath, &metadata.FileHash, &metadata.LastIndexed, &metadata.ChunkCount) - if err == sql.ErrNoRows { + if errors.Is(err, sql.ErrNoRows) { return nil, nil } if err != nil { @@ -220,7 +221,7 @@ func ensureDir(filePath string) error { return nil } - if _, err := os.Stat(dir); os.IsNotExist(err) { + if _, err := os.Stat(dir); errors.Is(err, fs.ErrNotExist) { return os.MkdirAll(dir, 0o700) } diff --git a/pkg/rag/strategy/chunked_embeddings_database.go b/pkg/rag/strategy/chunked_embeddings_database.go index f0253dbbb..373ff22de 100644 --- a/pkg/rag/strategy/chunked_embeddings_database.go +++ b/pkg/rag/strategy/chunked_embeddings_database.go @@ -197,7 +197,7 @@ func (d *chunkedVectorDB) GetFileMetadata(ctx context.Context, sourcePath string GROUP BY f.source_path, f.file_hash, f.indexed_at`, d.filesTable, d.chunksTable), sourcePath).Scan(&metadata.SourcePath, &metadata.FileHash, &metadata.LastIndexed, &metadata.ChunkCount) - if err == sql.ErrNoRows { + if errors.Is(err, sql.ErrNoRows) { return nil, nil } if err != nil { diff --git a/pkg/rag/strategy/semantic_embeddings_database.go b/pkg/rag/strategy/semantic_embeddings_database.go index 620d375c7..2d597769a 100644 --- a/pkg/rag/strategy/semantic_embeddings_database.go +++ b/pkg/rag/strategy/semantic_embeddings_database.go @@ -207,7 +207,7 @@ func (d *semanticVectorDB) GetFileMetadata(ctx 
context.Context, sourcePath strin GROUP BY f.source_path, f.file_hash, f.indexed_at`, d.filesTable, d.chunksTable), sourcePath).Scan(&metadata.SourcePath, &metadata.FileHash, &metadata.LastIndexed, &metadata.ChunkCount) - if err == sql.ErrNoRows { + if errors.Is(err, sql.ErrNoRows) { return nil, nil } if err != nil { From af0fa53486db98b1480744ffb48a6a20e241415e Mon Sep 17 00:00:00 2001 From: dev Date: Mon, 18 May 2026 19:16:43 +0200 Subject: [PATCH 3/6] fix(mcp/oauth): drop stray OAuth callbacks instead of blocking the handler MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The callback server published the OAuth result via three size-1 channels with blocking sends. Any additional hit on /callback — a stale browser tab, the user clicking back then re-submitting, or any local process on the loopback interface poking the port — would wedge the HTTP handler goroutine forever (until Shutdown), leaking goroutines and TCP connections per stray request, and giving a local attacker a trivial way to pin those resources. Collapse codeCh/stateCh/errCh into a single buffered resultCh and use non-blocking sends via a deliver helper. The first callback wins; later ones are dropped on the floor. WaitForCallback no longer has the two-channel race where a context cancellation between codeCh and stateCh reads would lose the code. Add TestCallbackServer_DuplicateCallbacksDoNotBlock which fires five back-to-back callbacks and verifies (a) every HTTP response completes within the timeout and (b) the first one is still delivered to the waiter. 
--- pkg/tools/mcp/oauth_server.go | 52 ++++++++++++++++++------------ pkg/tools/mcp/oauth_server_test.go | 50 ++++++++++++++++++++++++++++ 2 files changed, 81 insertions(+), 21 deletions(-) diff --git a/pkg/tools/mcp/oauth_server.go b/pkg/tools/mcp/oauth_server.go index aa727bb7c..82eb8dfb1 100644 --- a/pkg/tools/mcp/oauth_server.go +++ b/pkg/tools/mcp/oauth_server.go @@ -21,15 +21,24 @@ type CallbackServer struct { listener net.Listener mu sync.Mutex - // Channels for communicating the authorization code and state - codeCh chan string - stateCh chan string - errCh chan error + // resultCh delivers the outcome of the first received callback. + // It is buffered (size 1) and all sends are non-blocking so that a + // stray duplicate or attacker-triggered callback cannot wedge the + // HTTP handler goroutine on a full channel. + resultCh chan callbackResult // Expected state parameter for CSRF protection expectedState string } +// callbackResult is the outcome of a single OAuth callback. Exactly one +// of err / (code, state) is set. 
+type callbackResult struct { + code string + state string + err error +} + // NewCallbackServer creates a new OAuth callback server on a random available port func NewCallbackServer() (*CallbackServer, error) { return NewCallbackServerOnPort(0) @@ -46,9 +55,7 @@ func NewCallbackServerOnPort(port int) (*CallbackServer, error) { cs := &CallbackServer{ listener: listener, - codeCh: make(chan string, 1), - stateCh: make(chan string, 1), - errCh: make(chan error, 1), + resultCh: make(chan callbackResult, 1), } mux := http.NewServeMux() @@ -132,7 +139,7 @@ func (cs *CallbackServer) handleCallback(w http.ResponseWriter, r *http.Request) errMsg = fmt.Sprintf("%s: %s", errMsg, errDesc) } - cs.errCh <- fmt.Errorf("OAuth error: %s", errMsg) + cs.deliver(callbackResult{err: fmt.Errorf("OAuth error: %s", errMsg)}) w.WriteHeader(http.StatusBadRequest) fmt.Fprintf(w, ` @@ -157,7 +164,7 @@ func (cs *CallbackServer) handleCallback(w http.ResponseWriter, r *http.Request) state := query.Get("state") if code == "" { - cs.errCh <- errors.New("no authorization code received") + cs.deliver(callbackResult{err: errors.New("no authorization code received")}) w.WriteHeader(http.StatusBadRequest) fmt.Fprint(w, "No authorization code received") return @@ -171,14 +178,13 @@ func (cs *CallbackServer) handleCallback(w http.ResponseWriter, r *http.Request) cs.mu.Unlock() if expectedState == "" || subtle.ConstantTimeCompare([]byte(state), []byte(expectedState)) != 1 { - cs.errCh <- errors.New("OAuth state mismatch (possible CSRF attempt or stale callback)") + cs.deliver(callbackResult{err: errors.New("OAuth state mismatch (possible CSRF attempt or stale callback)")}) w.WriteHeader(http.StatusBadRequest) fmt.Fprint(w, "Invalid state parameter") return } - cs.codeCh <- code - cs.stateCh <- state + cs.deliver(callbackResult{code: code, state: state}) w.WriteHeader(http.StatusOK) fmt.Fprint(w, ` @@ -198,17 +204,21 @@ func (cs *CallbackServer) handleCallback(w http.ResponseWriter, r *http.Request) `) } 
+// deliver attempts to publish r on resultCh without blocking. The first +// callback wins; later callbacks (stale browser tabs, duplicate clicks, +// any local process probing the loopback port) are dropped on the floor +// instead of pinning the HTTP handler goroutine on a full channel. +func (cs *CallbackServer) deliver(r callbackResult) { + select { + case cs.resultCh <- r: + default: + } +} + func (cs *CallbackServer) WaitForCallback(ctx context.Context) (code, state string, err error) { select { - case code = <-cs.codeCh: - select { - case state = <-cs.stateCh: - return code, state, nil - case <-ctx.Done(): - return "", "", ctx.Err() - } - case err = <-cs.errCh: - return "", "", err + case r := <-cs.resultCh: + return r.code, r.state, r.err case <-ctx.Done(): return "", "", ctx.Err() } diff --git a/pkg/tools/mcp/oauth_server_test.go b/pkg/tools/mcp/oauth_server_test.go index c89bffab2..5963b4b40 100644 --- a/pkg/tools/mcp/oauth_server_test.go +++ b/pkg/tools/mcp/oauth_server_test.go @@ -2,9 +2,11 @@ package mcp import ( "fmt" + "net/http" "strconv" "strings" "testing" + "time" ) func TestCallbackServer_Port(t *testing.T) { @@ -73,6 +75,54 @@ func TestBuildRedirectURI(t *testing.T) { } } +// TestCallbackServer_DuplicateCallbacksDoNotBlock guards against a regression +// where extra callbacks (stale browser tabs, page refreshes, or any local +// process probing the loopback port) blocked the HTTP handler goroutine on +// a full result channel. Sends are now non-blocking; the first callback +// wins and later ones must be dropped without wedging the server. 
+func TestCallbackServer_DuplicateCallbacksDoNotBlock(t *testing.T) { + cs, err := NewCallbackServer() + if err != nil { + t.Fatal(err) + } + if err := cs.Start(); err != nil { + t.Fatal(err) + } + defer func() { _ = cs.Shutdown(t.Context()) }() + + cs.SetExpectedState("expected-state") + callbackURL := cs.GetRedirectURI() + "?code=authcode&state=expected-state" + + client := &http.Client{Timeout: 2 * time.Second} + + // Fire several callbacks back-to-back. Each one must complete (so the + // handler goroutine isn't stuck) regardless of whether anyone is + // reading from resultCh yet. + for i := range 5 { + req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, callbackURL, http.NoBody) + if err != nil { + t.Fatal(err) + } + resp, err := client.Do(req) + if err != nil { + t.Fatalf("callback %d: %v", i, err) + } + resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Fatalf("callback %d: status = %d, want 200", i, resp.StatusCode) + } + } + + // The first callback must still be deliverable to the waiter. + code, state, err := cs.WaitForCallback(t.Context()) + if err != nil { + t.Fatalf("WaitForCallback: %v", err) + } + if code != "authcode" || state != "expected-state" { + t.Errorf("got code=%q state=%q, want code=authcode state=expected-state", code, state) + } +} + // TestCallbackServer_ResolveRedirectURI exercises the method wrapper end-to-end // to make sure it stitches GetRedirectURI() and Port() together correctly. func TestCallbackServer_ResolveRedirectURI(t *testing.T) { From e12fcf873d1b8b1c3c662bd623460d52f0a0a153 Mon Sep 17 00:00:00 2001 From: dev Date: Mon, 18 May 2026 19:19:29 +0200 Subject: [PATCH 4/6] fix(skills): honour Cache-Control no-store and no-cache properly The skills cache previously persisted every response to disk, even when the upstream marked the body as Cache-Control: no-store. 
Since the cached content is fed to the model as instructions, this was both a privacy hazard (sensitive content lingering under ~/.cagent) and a violation of RFC 9111 §5.2.2.5 ("the cache MUST NOT store any part of either the immediate request or response"). no-cache was also conflated with max-age=0 — it actually means "may be stored, but must be revalidated before reuse" (RFC 9111 §5.2.2.4). For the skills cache, which has no conditional-GET support yet, the safe approximation is to store but report a miss on the next Get, forcing a fresh fetch. - Replace parseCacheExpiry with parseCacheControl returning a structured cacheDirective. - FetchAndStore now skips the disk write entirely on no-store and returns the body in-memory. - Get treats fs.ErrNotExist as a clean miss, but logs other errors (corrupt JSON, EACCES, …) at debug instead of silently masking them as a benign refetch trigger. - Add tests for no-store, no-cache, and the directive precedence. --- pkg/skills/cache.go | 84 +++++++++++++++++++++++++++------- pkg/skills/cache_test.go | 97 +++++++++++++++++++++++++++++++++------- 2 files changed, 148 insertions(+), 33 deletions(-) diff --git a/pkg/skills/cache.go b/pkg/skills/cache.go index 2ff40c034..6fb7a197f 100644 --- a/pkg/skills/cache.go +++ b/pkg/skills/cache.go @@ -5,8 +5,10 @@ import ( "crypto/sha256" "encoding/hex" "encoding/json" + "errors" "fmt" "io" + "io/fs" "log/slog" "net/http" "os" @@ -68,6 +70,10 @@ func (c *diskCache) cacheDir(baseURL, skillName string) string { } // Get returns the cached content for a file if it exists and is not expired. +// Treats missing-file errors as a cache miss (returns false). Other I/O +// errors (e.g. EACCES, corrupt JSON) are surfaced through a debug log so +// they don't masquerade as a benign refetch trigger but still don't break +// the caller — a refetch is the right fallback. 
func (c *diskCache) Get(baseURL, skillName, filePath string) (string, bool) { dir := c.cacheDir(baseURL, skillName) contentPath := filepath.Join(dir, filePath) @@ -75,6 +81,9 @@ func (c *diskCache) Get(baseURL, skillName, filePath string) (string, bool) { meta, err := c.readMetadata(metaPath) if err != nil { + if !errors.Is(err, fs.ErrNotExist) { + slog.Debug("Skill cache metadata unreadable, treating as miss", "path", metaPath, "error", err) + } return "", false } @@ -84,6 +93,9 @@ func (c *diskCache) Get(baseURL, skillName, filePath string) (string, bool) { data, err := os.ReadFile(contentPath) if err != nil { + if !errors.Is(err, fs.ErrNotExist) { + slog.Debug("Skill cache content unreadable, treating as miss", "path", contentPath, "error", err) + } return "", false } @@ -91,7 +103,8 @@ func (c *diskCache) Get(baseURL, skillName, filePath string) (string, bool) { } // FetchAndStore downloads a file from the given URL and stores it in the cache. -// It respects Cache-Control headers to determine expiry. +// It respects Cache-Control headers to determine expiry, and skips the disk +// write entirely when the response is marked no-store (RFC 9111 §5.2.2.5). func (c *diskCache) FetchAndStore(ctx context.Context, baseURL, skillName, filePath, fileURL string) (string, error) { slog.DebugContext(ctx, "Fetching remote skill file", "url", fileURL) @@ -110,7 +123,17 @@ func (c *diskCache) FetchAndStore(ctx context.Context, baseURL, skillName, fileP return "", fmt.Errorf("reading %s: %w", fileURL, err) } - expiresAt := parseCacheExpiry(resp.Header.Get("Cache-Control")) + directive := parseCacheControl(resp.Header.Get("Cache-Control")) + + // Honour no-store: never persist the response. This matters because the + // fetched content is fed to the LLM as instructions, so persisting an + // upstream-marked-private response on disk under ~/.cagent is both a + // privacy hazard (lingering sensitive content) and a correctness bug + // per RFC 9111 §5.2.2.5. 
+ if directive.noStore { + slog.DebugContext(ctx, "Cache-Control no-store: skipping disk write", "url", fileURL) + return string(body), nil + } dir := c.cacheDir(baseURL, skillName) contentPath := filepath.Join(dir, filePath) @@ -127,7 +150,7 @@ func (c *diskCache) FetchAndStore(ctx context.Context, baseURL, skillName, fileP meta := cacheMetadata{ URL: fileURL, CachedAt: time.Now(), - ExpiresAt: expiresAt, + ExpiresAt: directive.expiresAt(), } metaJSON, _ := json.Marshal(meta) if err := os.WriteFile(metaPath, metaJSON, 0o600); err != nil { @@ -152,28 +175,55 @@ func (c *diskCache) readMetadata(metaPath string) (cacheMetadata, error) { const defaultCacheTTL = 1 * time.Hour -// parseCacheExpiry extracts the expiry time from a Cache-Control header value. -// Falls back to defaultCacheTTL if the header is missing or unparseable. -func parseCacheExpiry(cacheControl string) time.Time { - if cacheControl == "" { - return time.Now().Add(defaultCacheTTL) +// cacheDirective is the parsed subset of Cache-Control we care about. +type cacheDirective struct { + noStore bool + noCache bool + // hasMaxAge is true when the header explicitly carried max-age=N. + hasMaxAge bool + maxAge time.Duration +} + +// expiresAt returns the absolute time after which the cached entry must +// not be served. no-cache forces immediate expiry — the response may be +// stored, but every read must re-validate (we currently approximate this +// as a refetch since conditional-GET support isn't implemented yet). +func (d cacheDirective) expiresAt() time.Time { + now := time.Now() + if d.noCache { + return now } + if d.hasMaxAge { + return now.Add(d.maxAge) + } + return now.Add(defaultCacheTTL) +} - for directive := range strings.SplitSeq(cacheControl, ",") { - directive = strings.TrimSpace(directive) +// parseCacheControl extracts the directives we honour from a Cache-Control +// header value. 
Unknown directives are ignored; an empty header yields the +// zero value, which falls back to defaultCacheTTL via expiresAt(). +func parseCacheControl(header string) cacheDirective { + var d cacheDirective + if header == "" { + return d + } - if strings.EqualFold(directive, "no-store") || strings.EqualFold(directive, "no-cache") { - // Still cache, but with zero TTL so it's refetched next time - return time.Now() - } + for directive := range strings.SplitSeq(header, ",") { + directive = strings.TrimSpace(directive) - if strings.HasPrefix(strings.ToLower(directive), "max-age=") { + switch { + case strings.EqualFold(directive, "no-store"): + d.noStore = true + case strings.EqualFold(directive, "no-cache"): + d.noCache = true + case strings.HasPrefix(strings.ToLower(directive), "max-age="): ageStr := directive[len("max-age="):] if seconds, err := strconv.ParseInt(ageStr, 10, 64); err == nil && seconds >= 0 { - return time.Now().Add(time.Duration(seconds) * time.Second) + d.hasMaxAge = true + d.maxAge = time.Duration(seconds) * time.Second } } } - return time.Now().Add(defaultCacheTTL) + return d } diff --git a/pkg/skills/cache_test.go b/pkg/skills/cache_test.go index 5ec753ed6..dacb50451 100644 --- a/pkg/skills/cache_test.go +++ b/pkg/skills/cache_test.go @@ -107,42 +107,53 @@ func TestDiskCache_DifferentURLsGetDifferentDirs(t *testing.T) { assert.NotEqual(t, dir1, dir2) } -func TestParseCacheExpiry(t *testing.T) { +func TestParseCacheControl(t *testing.T) { now := time.Now() t.Run("empty header uses default", func(t *testing.T) { - expiry := parseCacheExpiry("") - assert.WithinDuration(t, now.Add(1*time.Hour), expiry, 2*time.Second) + d := parseCacheControl("") + assert.False(t, d.noStore) + assert.False(t, d.noCache) + assert.WithinDuration(t, now.Add(1*time.Hour), d.expiresAt(), 2*time.Second) }) t.Run("max-age=3600", func(t *testing.T) { - expiry := parseCacheExpiry("max-age=3600") - assert.WithinDuration(t, now.Add(3600*time.Second), expiry, 2*time.Second) + d 
:= parseCacheControl("max-age=3600") + assert.True(t, d.hasMaxAge) + assert.WithinDuration(t, now.Add(3600*time.Second), d.expiresAt(), 2*time.Second) }) t.Run("max-age=0", func(t *testing.T) { - expiry := parseCacheExpiry("max-age=0") - assert.WithinDuration(t, now, expiry, 2*time.Second) + d := parseCacheControl("max-age=0") + assert.True(t, d.hasMaxAge) + assert.WithinDuration(t, now, d.expiresAt(), 2*time.Second) }) t.Run("no-store", func(t *testing.T) { - expiry := parseCacheExpiry("no-store") - assert.WithinDuration(t, now, expiry, 2*time.Second) + d := parseCacheControl("no-store") + assert.True(t, d.noStore) }) - t.Run("no-cache", func(t *testing.T) { - expiry := parseCacheExpiry("no-cache") - assert.WithinDuration(t, now, expiry, 2*time.Second) + t.Run("no-cache forces immediate expiry", func(t *testing.T) { + d := parseCacheControl("no-cache") + assert.True(t, d.noCache) + assert.WithinDuration(t, now, d.expiresAt(), 2*time.Second) + }) + + t.Run("no-cache wins over max-age", func(t *testing.T) { + d := parseCacheControl("max-age=3600, no-cache") + assert.True(t, d.noCache) + assert.WithinDuration(t, now, d.expiresAt(), 2*time.Second) }) t.Run("multiple directives with max-age", func(t *testing.T) { - expiry := parseCacheExpiry("public, max-age=7200") - assert.WithinDuration(t, now.Add(7200*time.Second), expiry, 2*time.Second) + d := parseCacheControl("public, max-age=7200") + assert.WithinDuration(t, now.Add(7200*time.Second), d.expiresAt(), 2*time.Second) }) t.Run("unknown directives use default", func(t *testing.T) { - expiry := parseCacheExpiry("public") - assert.WithinDuration(t, now.Add(1*time.Hour), expiry, 2*time.Second) + d := parseCacheControl("public") + assert.WithinDuration(t, now.Add(1*time.Hour), d.expiresAt(), 2*time.Second) }) } @@ -156,3 +167,57 @@ func TestDiskCache_HTTPError(t *testing.T) { require.Error(t, err) assert.Contains(t, err.Error(), "HTTP 404") } + +// TestDiskCache_NoStoreSkipsDiskWrite verifies that a Cache-Control: 
no-store +// response is returned in-memory but never persisted, per RFC 9111 §5.2.2.5. +// The skills cache feeds fetched content to the LLM as instructions, so +// persisting an upstream-marked-private response under ~/.cagent would be +// both a privacy hazard and a spec violation. +func TestDiskCache_NoStoreSkipsDiskWrite(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Cache-Control", "no-store") + fmt.Fprint(w, "private content") + })) + defer srv.Close() + + cache := newDiskCache(t.TempDir()) + + content, err := cache.FetchAndStore(t.Context(), "https://example.com", "skill", "SKILL.md", srv.URL+"/SKILL.md") + require.NoError(t, err) + assert.Equal(t, "private content", content) + + // Nothing must have been persisted under the cache directory. + filePath := filepath.Join(cache.cacheDir("https://example.com", "skill"), "SKILL.md") + _, err = os.Stat(filePath) + require.ErrorIs(t, err, os.ErrNotExist, "no-store response must not be written to disk") + + _, err = os.Stat(filePath + ".meta") + require.ErrorIs(t, err, os.ErrNotExist, "no-store response must not have metadata persisted") + + // And subsequent Get() must report a miss. + _, ok := cache.Get("https://example.com", "skill", "SKILL.md") + assert.False(t, ok) +} + +// TestDiskCache_NoCacheStoresButExpiresImmediately verifies that no-cache +// allows storage but forces revalidation: the entry is written so it can be +// inspected, but Get() must report a miss so the next read refetches. 
+func TestDiskCache_NoCacheStoresButExpiresImmediately(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Cache-Control", "no-cache") + fmt.Fprint(w, "revalidate me") + })) + defer srv.Close() + + cache := newDiskCache(t.TempDir()) + + _, err := cache.FetchAndStore(t.Context(), "https://example.com", "skill", "SKILL.md", srv.URL+"/SKILL.md") + require.NoError(t, err) + + filePath := filepath.Join(cache.cacheDir("https://example.com", "skill"), "SKILL.md") + _, err = os.Stat(filePath) + require.NoError(t, err, "no-cache response should still be stored on disk") + + _, ok := cache.Get("https://example.com", "skill", "SKILL.md") + assert.False(t, ok, "no-cache must force a refetch on the next read") +} From 22661da574d2e423c6c7a276c86da286cad22500 Mon Sep 17 00:00:00 2001 From: dev Date: Mon, 18 May 2026 22:33:29 +0200 Subject: [PATCH 5/6] fix(skills): keep no-store responses on disk for in-process consumption Self-review caught a regression in the previous commit. Skipping the disk write entirely on Cache-Control: no-store made remote skills with that header unreadable, because the consumer at pkg/tools/builtin/skills reads skill.FilePath directly rather than going through diskCache.Get. After prefetchFiles returned (silently happy), readFileContent(...) hit ENOENT and the skill was effectively broken. Restore the disk write but keep no-store distinct from a normal response: the entry's metadata expires immediately, so the next Load() cycle refetches rather than reusing the stored copy. This is no longer strict RFC 9111 §5.2.2.5 ("the cache MUST NOT store any part"), which is documented in the function comment along with the future direction (in-memory cache shared with the reader). Update TestDiskCache_NoStoreStoresButExpiresImmediately accordingly: the file must now exist on disk, but Get must still report a miss so prefetch refetches on the next Load. 
--- pkg/skills/cache.go | 37 +++++++++++++++++++++---------------- pkg/skills/cache_test.go | 37 ++++++++++++++++++++++--------------- 2 files changed, 43 insertions(+), 31 deletions(-) diff --git a/pkg/skills/cache.go b/pkg/skills/cache.go index 6fb7a197f..8d3b5b40c 100644 --- a/pkg/skills/cache.go +++ b/pkg/skills/cache.go @@ -103,8 +103,19 @@ func (c *diskCache) Get(baseURL, skillName, filePath string) (string, bool) { } // FetchAndStore downloads a file from the given URL and stores it in the cache. -// It respects Cache-Control headers to determine expiry, and skips the disk -// write entirely when the response is marked no-store (RFC 9111 §5.2.2.5). +// It respects Cache-Control headers to determine expiry: no-cache forces +// immediate expiry, max-age=N sets a TTL of N seconds, and unknown headers +// fall back to defaultCacheTTL. +// +// no-store is treated as "do not retain across Load() cycles": the entry is +// written to disk but with an immediate-expiry marker so the next prefetch +// refetches it. We do not skip the disk write entirely because callers +// (notably the read_skill tool, see pkg/tools/builtin/skills) consume the +// content by re-reading skill.FilePath, not by going through diskCache.Get. +// Skipping the write would render the skill unreadable for the rest of the +// current process. A future improvement is to keep no-store content in an +// in-memory map shared with the reader; until then we trade strict RFC +// 9111 §5.2.2.5 compliance for a working tool. func (c *diskCache) FetchAndStore(ctx context.Context, baseURL, skillName, filePath, fileURL string) (string, error) { slog.DebugContext(ctx, "Fetching remote skill file", "url", fileURL) @@ -125,16 +136,6 @@ func (c *diskCache) FetchAndStore(ctx context.Context, baseURL, skillName, fileP directive := parseCacheControl(resp.Header.Get("Cache-Control")) - // Honour no-store: never persist the response. 
This matters because the - // fetched content is fed to the LLM as instructions, so persisting an - // upstream-marked-private response on disk under ~/.cagent is both a - // privacy hazard (lingering sensitive content) and a correctness bug - // per RFC 9111 §5.2.2.5. - if directive.noStore { - slog.DebugContext(ctx, "Cache-Control no-store: skipping disk write", "url", fileURL) - return string(body), nil - } - dir := c.cacheDir(baseURL, skillName) contentPath := filepath.Join(dir, filePath) metaPath := contentPath + ".meta" @@ -185,12 +186,16 @@ type cacheDirective struct { } // expiresAt returns the absolute time after which the cached entry must -// not be served. no-cache forces immediate expiry — the response may be -// stored, but every read must re-validate (we currently approximate this -// as a refetch since conditional-GET support isn't implemented yet). +// not be reused. no-store and no-cache both force immediate expiry: the +// response may be on disk for the duration of the current process (so the +// in-process reader can consume it), but the next Load() cycle will +// refetch instead of reusing the stored copy. We currently approximate +// no-cache without conditional-GET support; the practical effect is the +// same as no-store with respect to whether the next read sees fresh +// content. 
func (d cacheDirective) expiresAt() time.Time { now := time.Now() - if d.noCache { + if d.noStore || d.noCache { return now } if d.hasMaxAge { diff --git a/pkg/skills/cache_test.go b/pkg/skills/cache_test.go index dacb50451..c9b03a744 100644 --- a/pkg/skills/cache_test.go +++ b/pkg/skills/cache_test.go @@ -129,9 +129,10 @@ func TestParseCacheControl(t *testing.T) { assert.WithinDuration(t, now, d.expiresAt(), 2*time.Second) }) - t.Run("no-store", func(t *testing.T) { + t.Run("no-store forces immediate expiry", func(t *testing.T) { d := parseCacheControl("no-store") assert.True(t, d.noStore) + assert.WithinDuration(t, now, d.expiresAt(), 2*time.Second) }) t.Run("no-cache forces immediate expiry", func(t *testing.T) { @@ -168,12 +169,19 @@ func TestDiskCache_HTTPError(t *testing.T) { assert.Contains(t, err.Error(), "HTTP 404") } -// TestDiskCache_NoStoreSkipsDiskWrite verifies that a Cache-Control: no-store -// response is returned in-memory but never persisted, per RFC 9111 §5.2.2.5. -// The skills cache feeds fetched content to the LLM as instructions, so -// persisting an upstream-marked-private response under ~/.cagent would be -// both a privacy hazard and a spec violation. -func TestDiskCache_NoStoreSkipsDiskWrite(t *testing.T) { +// TestDiskCache_NoStoreStoresButExpiresImmediately verifies that a +// Cache-Control: no-store response is still written to disk (so the +// in-process reader at pkg/tools/builtin/skills can consume it via +// readFileContent(skill.FilePath)) but is marked expired so the next +// Load() refetches instead of reusing the stored copy. +// +// We deliberately diverge from RFC 9111 §5.2.2.5 ("the cache MUST NOT +// store any part of either the immediate request or response") because +// the consumer reads files directly, not through diskCache.Get. Skipping +// the write entirely would render no-store skills unreadable for the +// rest of the process. 
A future refactor (in-memory cache shared with +// the reader) can make this strictly RFC-compliant. +func TestDiskCache_NoStoreStoresButExpiresImmediately(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Cache-Control", "no-store") fmt.Fprint(w, "private content") @@ -186,17 +194,16 @@ func TestDiskCache_NoStoreSkipsDiskWrite(t *testing.T) { require.NoError(t, err) assert.Equal(t, "private content", content) - // Nothing must have been persisted under the cache directory. + // The reader reads skill.FilePath directly, so the file must exist. filePath := filepath.Join(cache.cacheDir("https://example.com", "skill"), "SKILL.md") - _, err = os.Stat(filePath) - require.ErrorIs(t, err, os.ErrNotExist, "no-store response must not be written to disk") - - _, err = os.Stat(filePath + ".meta") - require.ErrorIs(t, err, os.ErrNotExist, "no-store response must not have metadata persisted") + data, err := os.ReadFile(filePath) + require.NoError(t, err) + assert.Equal(t, "private content", string(data)) - // And subsequent Get() must report a miss. + // But Get() must report a miss so prefetchFiles will refetch on the + // next Load() cycle rather than reusing a stale entry. _, ok := cache.Get("https://example.com", "skill", "SKILL.md") - assert.False(t, ok) + assert.False(t, ok, "no-store must force a refetch on the next read") } // TestDiskCache_NoCacheStoresButExpiresImmediately verifies that no-cache From 4f57a8827fce1d9e8fdd626f78fe2acc3662cdcf Mon Sep 17 00:00:00 2001 From: David Gageot Date: Wed, 20 May 2026 11:07:38 +0200 Subject: [PATCH 6/6] fix(mcp/oauth): respond 409 to duplicate OAuth callbacks instead of misleading 200 Addresses review feedback on the previous commit (drop stray OAuth callbacks instead of blocking the handler). Even though duplicate callbacks no longer wedge the handler goroutine, they were still receiving an 'Authorization Successful!' 
page even when their result was silently discarded by deliver(). This was misleading for users who refresh a stale browser tab, and gave a local attacker a way to probe whether a flow is in progress (200 vs 400 on a known nonce). deliver() now returns whether it actually delivered the result. The callback handler responds with HTTP 409 Conflict and an 'Authorization Already Processed' page when deliver() returns false, except for the state-mismatch path where the response stays unchanged so it doesn't leak whether a flow is currently active. TestCallbackServer_DuplicateCallbacksDoNotBlock now also verifies that the second through fifth callbacks return HTTP 409 rather than HTTP 200. --- pkg/tools/mcp/oauth_server.go | 54 ++++++++++++++++++++++++++---- pkg/tools/mcp/oauth_server_test.go | 12 +++++-- 2 files changed, 56 insertions(+), 10 deletions(-) diff --git a/pkg/tools/mcp/oauth_server.go b/pkg/tools/mcp/oauth_server.go index 82eb8dfb1..c7d3bd805 100644 --- a/pkg/tools/mcp/oauth_server.go +++ b/pkg/tools/mcp/oauth_server.go @@ -139,7 +139,10 @@ func (cs *CallbackServer) handleCallback(w http.ResponseWriter, r *http.Request) errMsg = fmt.Sprintf("%s: %s", errMsg, errDesc) } - cs.deliver(callbackResult{err: fmt.Errorf("OAuth error: %s", errMsg)}) + if !cs.deliver(callbackResult{err: fmt.Errorf("OAuth error: %s", errMsg)}) { + writeAlreadyProcessed(w) + return + } w.WriteHeader(http.StatusBadRequest) fmt.Fprintf(w, ` @@ -164,7 +167,10 @@ func (cs *CallbackServer) handleCallback(w http.ResponseWriter, r *http.Request) state := query.Get("state") if code == "" { - cs.deliver(callbackResult{err: errors.New("no authorization code received")}) + if !cs.deliver(callbackResult{err: errors.New("no authorization code received")}) { + writeAlreadyProcessed(w) + return + } w.WriteHeader(http.StatusBadRequest) fmt.Fprint(w, "No authorization code received") return @@ -178,13 +184,21 @@ func (cs *CallbackServer) handleCallback(w http.ResponseWriter, r *http.Request) 
cs.mu.Unlock() if expectedState == "" || subtle.ConstantTimeCompare([]byte(state), []byte(expectedState)) != 1 { + // Don't leak whether a flow is in progress: respond identically + // regardless of whether deliver succeeded. cs.deliver(callbackResult{err: errors.New("OAuth state mismatch (possible CSRF attempt or stale callback)")}) w.WriteHeader(http.StatusBadRequest) fmt.Fprint(w, "Invalid state parameter") return } - cs.deliver(callbackResult{code: code, state: state}) + if !cs.deliver(callbackResult{code: code, state: state}) { + // A previous callback already won the race. Tell the browser the + // flow is already complete instead of misleadingly claiming this + // stray request succeeded. + writeAlreadyProcessed(w) + return + } w.WriteHeader(http.StatusOK) fmt.Fprint(w, ` @@ -205,16 +219,42 @@ func (cs *CallbackServer) handleCallback(w http.ResponseWriter, r *http.Request) } // deliver attempts to publish r on resultCh without blocking. The first -// callback wins; later callbacks (stale browser tabs, duplicate clicks, -// any local process probing the loopback port) are dropped on the floor -// instead of pinning the HTTP handler goroutine on a full channel. -func (cs *CallbackServer) deliver(r callbackResult) { +// callback wins (returns true); later callbacks (stale browser tabs, +// duplicate clicks, any local process probing the loopback port) are +// dropped on the floor (returns false) instead of pinning the HTTP +// handler goroutine on a full channel. +func (cs *CallbackServer) deliver(r callbackResult) bool { select { case cs.resultCh <- r: + return true default: + return false } } +// writeAlreadyProcessed responds to a stray duplicate callback with HTTP +// 409 Conflict and a short HTML page. Returning a distinct status code +// rather than another "Authorization Successful!" page avoids misleading +// a user who reloaded the browser tab, while still completing the request +// promptly so the handler goroutine doesn't linger.
+func writeAlreadyProcessed(w http.ResponseWriter) { + w.WriteHeader(http.StatusConflict) + fmt.Fprint(w, ` + + + Authorization Already Processed + + + +

Authorization Already Processed

+

This authorization callback has already been handled.

+

You can close this window.

+ +`) +} + func (cs *CallbackServer) WaitForCallback(ctx context.Context) (code, state string, err error) { select { case r := <-cs.resultCh: diff --git a/pkg/tools/mcp/oauth_server_test.go b/pkg/tools/mcp/oauth_server_test.go index 5963b4b40..7d4b4fb21 100644 --- a/pkg/tools/mcp/oauth_server_test.go +++ b/pkg/tools/mcp/oauth_server_test.go @@ -97,7 +97,9 @@ func TestCallbackServer_DuplicateCallbacksDoNotBlock(t *testing.T) { // Fire several callbacks back-to-back. Each one must complete (so the // handler goroutine isn't stuck) regardless of whether anyone is - // reading from resultCh yet. + // reading from resultCh yet. The first one wins with HTTP 200; the + // rest must report HTTP 409 (Conflict) rather than misleadingly + // claiming success. for i := range 5 { req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, callbackURL, http.NoBody) if err != nil { @@ -108,8 +110,12 @@ func TestCallbackServer_DuplicateCallbacksDoNotBlock(t *testing.T) { t.Fatalf("callback %d: %v", i, err) } resp.Body.Close() - if resp.StatusCode != http.StatusOK { - t.Fatalf("callback %d: status = %d, want 200", i, resp.StatusCode) + wantStatus := http.StatusOK + if i > 0 { + wantStatus = http.StatusConflict + } + if resp.StatusCode != wantStatus { + t.Fatalf("callback %d: status = %d, want %d", i, resp.StatusCode, wantStatus) } }