diff --git a/pkg/evaluation/build.go b/pkg/evaluation/build.go index 58c786d78..d65497c58 100644 --- a/pkg/evaluation/build.go +++ b/pkg/evaluation/build.go @@ -6,6 +6,7 @@ import ( _ "embed" "errors" "fmt" + "io/fs" "os" "os/exec" "path/filepath" @@ -91,7 +92,7 @@ func (r *Runner) buildEvalImage(ctx context.Context, evals *session.EvalCriteria data.CopyWorkingDir = false } else { buildContext = filepath.Join(r.EvalsDir, "working_dirs", evals.WorkingDir) - if _, err := os.Stat(buildContext); os.IsNotExist(err) { + if _, err := os.Stat(buildContext); errors.Is(err, fs.ErrNotExist) { return "", fmt.Errorf("working directory not found: %s", buildContext) } data.CopyWorkingDir = true diff --git a/pkg/rag/strategy/bm25.go b/pkg/rag/strategy/bm25.go index be3faf54a..4524b0cb7 100644 --- a/pkg/rag/strategy/bm25.go +++ b/pkg/rag/strategy/bm25.go @@ -5,6 +5,7 @@ import ( "context" "errors" "fmt" + "io/fs" "log/slog" "math" "os" @@ -642,7 +643,7 @@ func (s *BM25Strategy) addPathToWatcher(ctx context.Context, path string) error stat, err := os.Stat(absPath) if err != nil { - if os.IsNotExist(err) { + if errors.Is(err, fs.ErrNotExist) { return nil } return fmt.Errorf("failed to stat path: %w", err) diff --git a/pkg/rag/strategy/bm25_database.go b/pkg/rag/strategy/bm25_database.go index 32be73635..54ff217c4 100644 --- a/pkg/rag/strategy/bm25_database.go +++ b/pkg/rag/strategy/bm25_database.go @@ -5,6 +5,7 @@ import ( "database/sql" "errors" "fmt" + "io/fs" "log/slog" "os" "path/filepath" @@ -159,7 +160,7 @@ func (d *bm25DB) GetFileMetadata(ctx context.Context, sourcePath string) (*datab fmt.Sprintf("SELECT source_path, file_hash, last_indexed, chunk_count FROM %s WHERE source_path = ?", d.metadataTable), sourcePath).Scan(&metadata.SourcePath, &metadata.FileHash, &metadata.LastIndexed, &metadata.ChunkCount) - if err == sql.ErrNoRows { + if errors.Is(err, sql.ErrNoRows) { return nil, nil } if err != nil { @@ -220,7 +221,7 @@ func ensureDir(filePath string) error { return nil } - if _, err := os.Stat(dir); os.IsNotExist(err) { + if _, err := os.Stat(dir); errors.Is(err, fs.ErrNotExist) { return os.MkdirAll(dir, 0o700) } diff --git a/pkg/rag/strategy/chunked_embeddings_database.go b/pkg/rag/strategy/chunked_embeddings_database.go index f0253dbbb..373ff22de 100644 --- a/pkg/rag/strategy/chunked_embeddings_database.go +++ b/pkg/rag/strategy/chunked_embeddings_database.go @@ -197,7 +197,7 @@ func (d *chunkedVectorDB) GetFileMetadata(ctx context.Context, sourcePath string GROUP BY f.source_path, f.file_hash, f.indexed_at`, d.filesTable, d.chunksTable), sourcePath).Scan(&metadata.SourcePath, &metadata.FileHash, &metadata.LastIndexed, &metadata.ChunkCount) - if err == sql.ErrNoRows { + if errors.Is(err, sql.ErrNoRows) { return nil, nil } if err != nil { diff --git a/pkg/rag/strategy/semantic_embeddings_database.go b/pkg/rag/strategy/semantic_embeddings_database.go index 620d375c7..2d597769a 100644 --- a/pkg/rag/strategy/semantic_embeddings_database.go +++ b/pkg/rag/strategy/semantic_embeddings_database.go @@ -207,7 +207,7 @@ func (d *semanticVectorDB) GetFileMetadata(ctx context.Context, sourcePath strin GROUP BY f.source_path, f.file_hash, f.indexed_at`, d.filesTable, d.chunksTable), sourcePath).Scan(&metadata.SourcePath, &metadata.FileHash, &metadata.LastIndexed, &metadata.ChunkCount) - if err == sql.ErrNoRows { + if errors.Is(err, sql.ErrNoRows) { return nil, nil } if err != nil { diff --git a/pkg/skills/cache.go b/pkg/skills/cache.go index 3242a6ef9..8d3b5b40c 100644 --- a/pkg/skills/cache.go +++ b/pkg/skills/cache.go @@ -5,8 +5,10 @@ import ( "crypto/sha256" "encoding/hex" "encoding/json" + "errors" "fmt" "io" + "io/fs" "log/slog" "net/http" "os" @@ -15,25 +17,32 @@ import ( "strings" "time" - "github.com/docker/docker-agent/pkg/remote" + "github.com/docker/docker-agent/pkg/httpclient" ) // remoteHTTPTimeout caps each HTTP request made to a remote skills source. const remoteHTTPTimeout = 30 * time.Second -// httpGet performs a GET request using the standard remote transport so that -// Docker Desktop proxy/SSL settings are honoured. The returned response body -// must be closed by the caller. +// skillsHTTPClient is used for outbound calls to remote skill registries. +// The base URL is operator-supplied and the contents are fed to the model +// as instructions, so a hostile (or compromised) registry could otherwise +// be used to read internal endpoints (loopback, RFC1918, link-local incl. +// cloud metadata at 169.254.169.254) and exfiltrate them through prompt +// injection. The SSRF-safe client refuses such targets at dial time, after +// DNS resolution, defeating DNS rebinding. +// +// Tests in this package replace the var via TestMain (see main_test.go) +// because httptest.NewServer binds to 127.0.0.1. +var skillsHTTPClient = httpclient.NewSafeClient(remoteHTTPTimeout, false) + +// httpGet performs a GET request using the SSRF-safe HTTP client. The +// returned response body must be closed by the caller. func httpGet(ctx context.Context, url string) (*http.Response, error) { - client := &http.Client{ - Timeout: remoteHTTPTimeout, - Transport: remote.NewTransport(ctx), - } req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody) if err != nil { return nil, fmt.Errorf("creating request for %s: %w", url, err) } - return client.Do(req) + return skillsHTTPClient.Do(req) } type diskCache struct { @@ -61,6 +70,10 @@ func (c *diskCache) cacheDir(baseURL, skillName string) string { } // Get returns the cached content for a file if it exists and is not expired. +// Treats missing-file errors as a cache miss (returns false). Other I/O +// errors (e.g. EACCES, corrupt JSON) are surfaced through a debug log so +// they don't masquerade as a benign refetch trigger but still don't break +// the caller — a refetch is the right fallback. func (c *diskCache) Get(baseURL, skillName, filePath string) (string, bool) { dir := c.cacheDir(baseURL, skillName) contentPath := filepath.Join(dir, filePath) @@ -68,6 +81,9 @@ func (c *diskCache) Get(baseURL, skillName, filePath string) (string, bool) { meta, err := c.readMetadata(metaPath) if err != nil { + if !errors.Is(err, fs.ErrNotExist) { + slog.Debug("Skill cache metadata unreadable, treating as miss", "path", metaPath, "error", err) + } return "", false } @@ -77,6 +93,9 @@ func (c *diskCache) Get(baseURL, skillName, filePath string) (string, bool) { data, err := os.ReadFile(contentPath) if err != nil { + if !errors.Is(err, fs.ErrNotExist) { + slog.Debug("Skill cache content unreadable, treating as miss", "path", contentPath, "error", err) + } return "", false } @@ -84,7 +103,19 @@ func (c *diskCache) Get(baseURL, skillName, filePath string) (string, bool) { } // FetchAndStore downloads a file from the given URL and stores it in the cache. -// It respects Cache-Control headers to determine expiry. +// It respects Cache-Control headers to determine expiry: no-cache forces +// immediate expiry, max-age=N sets a TTL of N seconds, and unknown headers +// fall back to defaultCacheTTL. +// +// no-store is treated as "do not retain across Load() cycles": the entry is +// written to disk but with an immediate-expiry marker so the next prefetch +// refetches it. We do not skip the disk write entirely because callers +// (notably the read_skill tool, see pkg/tools/builtin/skills) consume the +// content by re-reading skill.FilePath, not by going through diskCache.Get. +// Skipping the write would render the skill unreadable for the rest of the +// current process. A future improvement is to keep no-store content in an +// in-memory map shared with the reader; until then we trade strict RFC +// 9111 §5.2.2.5 compliance for a working tool. func (c *diskCache) FetchAndStore(ctx context.Context, baseURL, skillName, filePath, fileURL string) (string, error) { slog.DebugContext(ctx, "Fetching remote skill file", "url", fileURL) @@ -103,7 +134,7 @@ func (c *diskCache) FetchAndStore(ctx context.Context, baseURL, skillName, fileP return "", fmt.Errorf("reading %s: %w", fileURL, err) } - expiresAt := parseCacheExpiry(resp.Header.Get("Cache-Control")) + directive := parseCacheControl(resp.Header.Get("Cache-Control")) dir := c.cacheDir(baseURL, skillName) contentPath := filepath.Join(dir, filePath) @@ -120,7 +151,7 @@ func (c *diskCache) FetchAndStore(ctx context.Context, baseURL, skillName, fileP meta := cacheMetadata{ URL: fileURL, CachedAt: time.Now(), - ExpiresAt: expiresAt, + ExpiresAt: directive.expiresAt(), } metaJSON, _ := json.Marshal(meta) if err := os.WriteFile(metaPath, metaJSON, 0o600); err != nil { @@ -145,28 +176,59 @@ func (c *diskCache) readMetadata(metaPath string) (cacheMetadata, error) { const defaultCacheTTL = 1 * time.Hour -// parseCacheExpiry extracts the expiry time from a Cache-Control header value. -// Falls back to defaultCacheTTL if the header is missing or unparseable. -func parseCacheExpiry(cacheControl string) time.Time { - if cacheControl == "" { - return time.Now().Add(defaultCacheTTL) +// cacheDirective is the parsed subset of Cache-Control we care about. +type cacheDirective struct { + noStore bool + noCache bool + // hasMaxAge is true when the header explicitly carried max-age=N. + hasMaxAge bool + maxAge time.Duration +} + +// expiresAt returns the absolute time after which the cached entry must +// not be reused. no-store and no-cache both force immediate expiry: the +// response may be on disk for the duration of the current process (so the +// in-process reader can consume it), but the next Load() cycle will +// refetch instead of reusing the stored copy. We currently approximate +// no-cache without conditional-GET support; the practical effect is the +// same as no-store with respect to whether the next read sees fresh +// content. +func (d cacheDirective) expiresAt() time.Time { + now := time.Now() + if d.noStore || d.noCache { + return now + } + if d.hasMaxAge { + return now.Add(d.maxAge) + } + return now.Add(defaultCacheTTL) +} + +// parseCacheControl extracts the directives we honour from a Cache-Control +// header value. Unknown directives are ignored; an empty header yields the +// zero value, which falls back to defaultCacheTTL via expiresAt(). +func parseCacheControl(header string) cacheDirective { + var d cacheDirective + if header == "" { + return d } - for directive := range strings.SplitSeq(cacheControl, ",") { + for directive := range strings.SplitSeq(header, ",") { directive = strings.TrimSpace(directive) - if strings.EqualFold(directive, "no-store") || strings.EqualFold(directive, "no-cache") { - // Still cache, but with zero TTL so it's refetched next time - return time.Now() - } - - if strings.HasPrefix(strings.ToLower(directive), "max-age=") { + switch { + case strings.EqualFold(directive, "no-store"): + d.noStore = true + case strings.EqualFold(directive, "no-cache"): + d.noCache = true + case strings.HasPrefix(strings.ToLower(directive), "max-age="): ageStr := directive[len("max-age="):] if seconds, err := strconv.ParseInt(ageStr, 10, 64); err == nil && seconds >= 0 { - return time.Now().Add(time.Duration(seconds) * time.Second) + d.hasMaxAge = true + d.maxAge = time.Duration(seconds) * time.Second } } } - return time.Now().Add(defaultCacheTTL) + return d } diff --git a/pkg/skills/cache_test.go b/pkg/skills/cache_test.go index 5ec753ed6..c9b03a744 100644 --- a/pkg/skills/cache_test.go +++ b/pkg/skills/cache_test.go @@ -107,42 +107,54 @@ func TestDiskCache_DifferentURLsGetDifferentDirs(t *testing.T) { assert.NotEqual(t, dir1, dir2) } -func TestParseCacheExpiry(t *testing.T) { +func TestParseCacheControl(t *testing.T) { now := time.Now() t.Run("empty header uses default", func(t *testing.T) { - expiry := parseCacheExpiry("") - assert.WithinDuration(t, now.Add(1*time.Hour), expiry, 2*time.Second) + d := parseCacheControl("") + assert.False(t, d.noStore) + assert.False(t, d.noCache) + assert.WithinDuration(t, now.Add(1*time.Hour), d.expiresAt(), 2*time.Second) }) t.Run("max-age=3600", func(t *testing.T) { - expiry := parseCacheExpiry("max-age=3600") - assert.WithinDuration(t, now.Add(3600*time.Second), expiry, 2*time.Second) + d := parseCacheControl("max-age=3600") + assert.True(t, d.hasMaxAge) + assert.WithinDuration(t, now.Add(3600*time.Second), d.expiresAt(), 2*time.Second) }) t.Run("max-age=0", func(t *testing.T) { - expiry := parseCacheExpiry("max-age=0") - assert.WithinDuration(t, now, expiry, 2*time.Second) + d := parseCacheControl("max-age=0") + assert.True(t, d.hasMaxAge) + assert.WithinDuration(t, now, d.expiresAt(), 2*time.Second) }) - t.Run("no-store", func(t *testing.T) { - expiry := parseCacheExpiry("no-store") - assert.WithinDuration(t, now, expiry, 2*time.Second) + t.Run("no-store forces immediate expiry", func(t *testing.T) { + d := parseCacheControl("no-store") + assert.True(t, d.noStore) + assert.WithinDuration(t, now, d.expiresAt(), 2*time.Second) }) - t.Run("no-cache", func(t *testing.T) { - expiry := parseCacheExpiry("no-cache") - assert.WithinDuration(t, now, expiry, 2*time.Second) + t.Run("no-cache forces immediate expiry", func(t *testing.T) { + d := parseCacheControl("no-cache") + assert.True(t, d.noCache) + assert.WithinDuration(t, now, d.expiresAt(), 2*time.Second) + }) + + t.Run("no-cache wins over max-age", func(t *testing.T) { + d := parseCacheControl("max-age=3600, no-cache") + assert.True(t, d.noCache) + assert.WithinDuration(t, now, d.expiresAt(), 2*time.Second) }) t.Run("multiple directives with max-age", func(t *testing.T) { - expiry := parseCacheExpiry("public, max-age=7200") - assert.WithinDuration(t, now.Add(7200*time.Second), expiry, 2*time.Second) + d := parseCacheControl("public, max-age=7200") + assert.WithinDuration(t, now.Add(7200*time.Second), d.expiresAt(), 2*time.Second) }) t.Run("unknown directives use default", func(t *testing.T) { - expiry := parseCacheExpiry("public") - assert.WithinDuration(t, now.Add(1*time.Hour), expiry, 2*time.Second) + d := parseCacheControl("public") + assert.WithinDuration(t, now.Add(1*time.Hour), d.expiresAt(), 2*time.Second) }) } @@ -156,3 +168,63 @@ func TestDiskCache_HTTPError(t *testing.T) { require.Error(t, err) assert.Contains(t, err.Error(), "HTTP 404") } + +// TestDiskCache_NoStoreStoresButExpiresImmediately verifies that a +// Cache-Control: no-store response is still written to disk (so the +// in-process reader at pkg/tools/builtin/skills can consume it via +// readFileContent(skill.FilePath)) but is marked expired so the next +// Load() refetches instead of reusing the stored copy. +// +// We deliberately diverge from RFC 9111 §5.2.2.5 ("the cache MUST NOT +// store any part of either the immediate request or response") because +// the consumer reads files directly, not through diskCache.Get. Skipping +// the write entirely would render no-store skills unreadable for the +// rest of the process. A future refactor (in-memory cache shared with +// the reader) can make this strictly RFC-compliant. +func TestDiskCache_NoStoreStoresButExpiresImmediately(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Cache-Control", "no-store") + fmt.Fprint(w, "private content") + })) + defer srv.Close() + + cache := newDiskCache(t.TempDir()) + + content, err := cache.FetchAndStore(t.Context(), "https://example.com", "skill", "SKILL.md", srv.URL+"/SKILL.md") + require.NoError(t, err) + assert.Equal(t, "private content", content) + + // The reader reads skill.FilePath directly, so the file must exist. + filePath := filepath.Join(cache.cacheDir("https://example.com", "skill"), "SKILL.md") + data, err := os.ReadFile(filePath) + require.NoError(t, err) + assert.Equal(t, "private content", string(data)) + + // But Get() must report a miss so prefetchFiles will refetch on the + // next Load() cycle rather than reusing a stale entry. + _, ok := cache.Get("https://example.com", "skill", "SKILL.md") + assert.False(t, ok, "no-store must force a refetch on the next read") +} + +// TestDiskCache_NoCacheStoresButExpiresImmediately verifies that no-cache +// allows storage but forces revalidation: the entry is written so it can be +// inspected, but Get() must report a miss so the next read refetches. +func TestDiskCache_NoCacheStoresButExpiresImmediately(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Cache-Control", "no-cache") + fmt.Fprint(w, "revalidate me") + })) + defer srv.Close() + + cache := newDiskCache(t.TempDir()) + + _, err := cache.FetchAndStore(t.Context(), "https://example.com", "skill", "SKILL.md", srv.URL+"/SKILL.md") + require.NoError(t, err) + + filePath := filepath.Join(cache.cacheDir("https://example.com", "skill"), "SKILL.md") + _, err = os.Stat(filePath) + require.NoError(t, err, "no-cache response should still be stored on disk") + + _, ok := cache.Get("https://example.com", "skill", "SKILL.md") + assert.False(t, ok, "no-cache must force a refetch on the next read") +} diff --git a/pkg/skills/main_test.go b/pkg/skills/main_test.go new file mode 100644 index 000000000..833235db8 --- /dev/null +++ b/pkg/skills/main_test.go @@ -0,0 +1,16 @@ +package skills + +import ( + "os" + "testing" + + "github.com/docker/docker-agent/pkg/httpclient" +) + +// TestMain swaps the SSRF-safe HTTP client for the loopback-allowing +// variant so tests can hit httptest.NewServer (which binds to 127.0.0.1). +// Production code keeps the safe client. +func TestMain(m *testing.M) { + skillsHTTPClient = httpclient.NewSafeClient(remoteHTTPTimeout, true) + os.Exit(m.Run()) +} diff --git a/pkg/tools/mcp/oauth_server.go b/pkg/tools/mcp/oauth_server.go index aa727bb7c..c7d3bd805 100644 --- a/pkg/tools/mcp/oauth_server.go +++ b/pkg/tools/mcp/oauth_server.go @@ -21,15 +21,24 @@ type CallbackServer struct { listener net.Listener mu sync.Mutex - // Channels for communicating the authorization code and state - codeCh chan string - stateCh chan string - errCh chan error + // resultCh delivers the outcome of the first received callback. + // It is buffered (size 1) and all sends are non-blocking so that a + // stray duplicate or attacker-triggered callback cannot wedge the + // HTTP handler goroutine on a full channel. + resultCh chan callbackResult // Expected state parameter for CSRF protection expectedState string } +// callbackResult is the outcome of a single OAuth callback. Exactly one +// of err / (code, state) is set. +type callbackResult struct { + code string + state string + err error +} + // NewCallbackServer creates a new OAuth callback server on a random available port func NewCallbackServer() (*CallbackServer, error) { return NewCallbackServerOnPort(0) @@ -46,9 +55,7 @@ func NewCallbackServerOnPort(port int) (*CallbackServer, error) { cs := &CallbackServer{ listener: listener, - codeCh: make(chan string, 1), - stateCh: make(chan string, 1), - errCh: make(chan error, 1), + resultCh: make(chan callbackResult, 1), } mux := http.NewServeMux() @@ -132,7 +139,10 @@ func (cs *CallbackServer) handleCallback(w http.ResponseWriter, r *http.Request) errMsg = fmt.Sprintf("%s: %s", errMsg, errDesc) } - cs.errCh <- fmt.Errorf("OAuth error: %s", errMsg) + if !cs.deliver(callbackResult{err: fmt.Errorf("OAuth error: %s", errMsg)}) { + writeAlreadyProcessed(w) + return + } w.WriteHeader(http.StatusBadRequest) fmt.Fprintf(w, ` @@ -157,7 +167,10 @@ func (cs *CallbackServer) handleCallback(w http.ResponseWriter, r *http.Request) state := query.Get("state") if code == "" { - cs.errCh <- errors.New("no authorization code received") + if !cs.deliver(callbackResult{err: errors.New("no authorization code received")}) { + writeAlreadyProcessed(w) + return + } w.WriteHeader(http.StatusBadRequest) fmt.Fprint(w, "No authorization code received") return @@ -171,14 +184,21 @@ func (cs *CallbackServer) handleCallback(w http.ResponseWriter, r *http.Request) cs.mu.Unlock() if expectedState == "" || subtle.ConstantTimeCompare([]byte(state), []byte(expectedState)) != 1 { - cs.errCh <- errors.New("OAuth state mismatch (possible CSRF attempt or stale callback)") + // Don't leak whether a flow is in progress: respond identically + // regardless of whether deliver succeeded. + cs.deliver(callbackResult{err: errors.New("OAuth state mismatch (possible CSRF attempt or stale callback)")}) w.WriteHeader(http.StatusBadRequest) fmt.Fprint(w, "Invalid state parameter") return } - cs.codeCh <- code - cs.stateCh <- state + if !cs.deliver(callbackResult{code: code, state: state}) { + // A previous callback already won the race. Tell the browser the + // flow is already complete instead of misleadingly claiming this + // stray request succeeded. + writeAlreadyProcessed(w) + return + } w.WriteHeader(http.StatusOK) fmt.Fprint(w, ` @@ -198,17 +218,47 @@ func (cs *CallbackServer) handleCallback(w http.ResponseWriter, r *http.Request) `) } +// deliver attempts to publish r on resultCh without blocking. The first +// callback wins (returns true); later callbacks (stale browser tabs, +// duplicate clicks, any local process probing the loopback port) are +// dropped on the floor (returns false) instead of pinning the HTTP +// handler goroutine on a full channel. +func (cs *CallbackServer) deliver(r callbackResult) bool { + select { + case cs.resultCh <- r: + return true + default: + return false + } +} + +// writeAlreadyProcessed responds to a stray duplicate callback with HTTP +// 409 Conflict and a short HTML page. Returning a distinct status code +// rather than another "Authorization Successful!" page avoids misleading +// the user who reloaded the browser tab while still completing the request +// promptly so the handler goroutine doesn't linger. +func writeAlreadyProcessed(w http.ResponseWriter) { + w.WriteHeader(http.StatusConflict) + fmt.Fprint(w, ` + + + Authorization Already Processed + + + +

Authorization Already Processed

+

This authorization callback has already been handled.

+

You can close this window.

+ +`) +} + func (cs *CallbackServer) WaitForCallback(ctx context.Context) (code, state string, err error) { select { - case code = <-cs.codeCh: - select { - case state = <-cs.stateCh: - return code, state, nil - case <-ctx.Done(): - return "", "", ctx.Err() - } - case err = <-cs.errCh: - return "", "", err + case r := <-cs.resultCh: + return r.code, r.state, r.err case <-ctx.Done(): return "", "", ctx.Err() } diff --git a/pkg/tools/mcp/oauth_server_test.go b/pkg/tools/mcp/oauth_server_test.go index c89bffab2..7d4b4fb21 100644 --- a/pkg/tools/mcp/oauth_server_test.go +++ b/pkg/tools/mcp/oauth_server_test.go @@ -2,9 +2,11 @@ package mcp import ( "fmt" + "net/http" "strconv" "strings" "testing" + "time" ) func TestCallbackServer_Port(t *testing.T) { @@ -73,6 +75,60 @@ func TestBuildRedirectURI(t *testing.T) { } } +// TestCallbackServer_DuplicateCallbacksDoNotBlock guards against a regression +// where extra callbacks (stale browser tabs, page refreshes, or any local +// process probing the loopback port) blocked the HTTP handler goroutine on +// a full result channel. Sends are now non-blocking; the first callback +// wins and later ones must be dropped without wedging the server. +func TestCallbackServer_DuplicateCallbacksDoNotBlock(t *testing.T) { + cs, err := NewCallbackServer() + if err != nil { + t.Fatal(err) + } + if err := cs.Start(); err != nil { + t.Fatal(err) + } + defer func() { _ = cs.Shutdown(t.Context()) }() + + cs.SetExpectedState("expected-state") + callbackURL := cs.GetRedirectURI() + "?code=authcode&state=expected-state" + + client := &http.Client{Timeout: 2 * time.Second} + + // Fire several callbacks back-to-back. Each one must complete (so the + // handler goroutine isn't stuck) regardless of whether anyone is + // reading from resultCh yet. The first one wins with HTTP 200; the + // rest must report HTTP 409 (Conflict) rather than misleadingly + // claiming success. + for i := range 5 { + req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, callbackURL, http.NoBody) + if err != nil { + t.Fatal(err) + } + resp, err := client.Do(req) + if err != nil { + t.Fatalf("callback %d: %v", i, err) + } + resp.Body.Close() + wantStatus := http.StatusOK + if i > 0 { + wantStatus = http.StatusConflict + } + if resp.StatusCode != wantStatus { + t.Fatalf("callback %d: status = %d, want %d", i, resp.StatusCode, wantStatus) + } + } + + // The first callback must still be deliverable to the waiter. + code, state, err := cs.WaitForCallback(t.Context()) + if err != nil { + t.Fatalf("WaitForCallback: %v", err) + } + if code != "authcode" || state != "expected-state" { + t.Errorf("got code=%q state=%q, want code=authcode state=expected-state", code, state) + } +} + // TestCallbackServer_ResolveRedirectURI exercises the method wrapper end-to-end // to make sure it stitches GetRedirectURI() and Port() together correctly. func TestCallbackServer_ResolveRedirectURI(t *testing.T) {