From af548174838aa4d7b377421d69d8bcd09e8ec097 Mon Sep 17 00:00:00 2001 From: Sas Swart Date: Fri, 1 May 2026 12:36:05 +0000 Subject: [PATCH 1/8] test(proxy): integration tests for session correlation audit and header agreement Add integration tests that verify the core invariants of session correlation across the proxy, auditor, and forwarded request headers working together. These tests fill the gap identified during review of the session correlation PR stack (#196, #197, #198) where unit tests verified each component in isolation but did not verify them in concert. New test file: proxy/proxy_session_correlation_integration_test.go Tests added: - LLMRequestAuditAndHeadersAgree: audit sequence number matches the forwarded header value on inject-target requests. - NonLLMRequestAuditedWithoutHeaders: allowed non-inject-target requests are audited but carry no correlation headers. - DeniedRequestAuditedNeverForwarded: denied requests consume a sequence number but are never forwarded. - MixedRequestsSequenceOrdering: interleaved LLM, non-LLM, and denied requests all advance the counter monotonically. - SequenceGapRevealsAgenticLoop: gap between two LLM sequence numbers precisely equals intermediate tool-use requests. - SpoofedHeadersOverwrittenWithCorrectSequence: client-supplied headers are replaced and the audit event still agrees. - DisabledCorrelationNoHeadersNoPreallocatedSequence: disabled correlation means no headers and no pre-allocated sequence. - ConcurrentRequestsUniqueSequenceNumbers: concurrent requests each get a unique, dense sequence number. --- ...xy_session_correlation_integration_test.go | 578 ++++++++++++++++++ 1 file changed, 578 insertions(+) create mode 100644 proxy/proxy_session_correlation_integration_test.go diff --git a/proxy/proxy_session_correlation_integration_test.go b/proxy/proxy_session_correlation_integration_test.go new file mode 100644 index 0000000..b8035b7 --- /dev/null +++ b/proxy/proxy_session_correlation_integration_test.go @@ -0,0 +1,578 @@ +package proxy + +import ( + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "strconv" + "sync" + "testing" + + "github.com/coder/boundary/audit" + "github.com/coder/boundary/config" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// multiRequestCapturingBackend records the headers from every request it +// receives, not just the last one. This is needed by integration tests +// that send multiple requests to the same backend and want to verify +// each one independently. +type multiRequestCapturingBackend struct { + server *httptest.Server + mu sync.Mutex + all []http.Header +} + +func newMultiRequestCapturingBackend() *multiRequestCapturingBackend { + mcb := &multiRequestCapturingBackend{} + mcb.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mcb.mu.Lock() + mcb.all = append(mcb.all, r.Header.Clone()) + mcb.mu.Unlock() + w.WriteHeader(http.StatusOK) + })) + return mcb +} + +func (m *multiRequestCapturingBackend) close() { m.server.Close() } + +func (m *multiRequestCapturingBackend) requestCount() int { + m.mu.Lock() + defer m.mu.Unlock() + return len(m.all) +} + +func (m *multiRequestCapturingBackend) headersAt(i int) http.Header { + m.mu.Lock() + defer m.mu.Unlock() + return m.all[i].Clone() +} + +// sessionCorrelationIntegrationSetup holds the shared objects for an +// integration test: the proxy, auditor, backend(s), and sequence +// counter. Tests build one via newSessionCorrelationIntegrationSetup +// and tear it down with stop. +type sessionCorrelationIntegrationSetup struct { + pt *ProxyTest + auditor *capturingAuditor + seq *audit.SequenceCounter + llmBackend *multiRequestCapturingBackend + otherBackend *multiRequestCapturingBackend +} + +func (s *sessionCorrelationIntegrationSetup) stop() { + s.pt.Stop() + if s.llmBackend != nil { + s.llmBackend.close() + } + if s.otherBackend != nil { + s.otherBackend.close() + } +} + +// newSessionCorrelationIntegrationSetup builds a proxy that allows +// traffic to two httptest backends: one that matches an inject target +// (simulating an LLM provider) and one that does not (simulating a +// generic allowed domain like github.com). Both backends capture all +// received request headers. A capturingAuditor records every audit +// event for later inspection. +func newSessionCorrelationIntegrationSetup(t *testing.T, sessionID string) *sessionCorrelationIntegrationSetup { + t.Helper() + + llm := newMultiRequestCapturingBackend() + other := newMultiRequestCapturingBackend() + + llmURL, err := url.Parse(llm.server.URL) + require.NoError(t, err) + + otherURL, err := url.Parse(other.server.URL) + require.NoError(t, err) + + aud := &capturingAuditor{} + seq := &audit.SequenceCounter{} + + // Both httptest backends resolve to 127.0.0.1, so a domain-only + // inject target would match both. We use a path glob on the LLM + // paths (/v1/*) to limit header injection to LLM requests. + pt := NewProxyTest(t, + WithCertManager(t.TempDir()), + // Allow both backends. + WithAllowedDomain(llmURL.Hostname()), + WithAllowedDomain(otherURL.Hostname()), + // Only requests matching the LLM path receive headers. + WithSessionCorrelation(config.SessionCorrelationConfig{ + Enabled: true, + InjectTargets: []config.InjectTarget{{ + Domain: llmURL.Hostname(), + Path: "/v1/*", + }}, + SessionIDHeaderName: config.DefaultSessionIDHeaderName, + SequenceNumberHeaderName: config.DefaultSequenceNumberHeaderName, + }), + WithSessionID(sessionID), + WithSequenceCounter(seq), + WithAuditor(aud), + ).Start() + + return &sessionCorrelationIntegrationSetup{ + pt: pt, + auditor: aud, + seq: seq, + llmBackend: llm, + otherBackend: other, + } +} + +// ---------- Integration Tests ---------- + +// TestIntegration_LLMRequestAuditAndHeadersAgree verifies the core +// correlation invariant: when an allowed request hits an inject target, +// the sequence number in the audit event equals the sequence number in +// the forwarded header. +func TestIntegration_LLMRequestAuditAndHeadersAgree(t *testing.T) { + const sessionID = "e5f6a7b8-0000-0000-0000-000000000000" + s := newSessionCorrelationIntegrationSetup(t, sessionID) + defer s.stop() + + resp, err := s.pt.proxyClient.Get(s.llmBackend.server.URL + "/v1/messages") + require.NoError(t, err) + defer resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusOK, resp.StatusCode) + + // Audit event. + events := s.auditor.getRequests() + require.Len(t, events, 1) + require.True(t, events[0].Allowed) + require.NotNil(t, events[0].SequenceNumber) + assert.Equal(t, uint64(0), *events[0].SequenceNumber) + + // Forwarded headers. + require.Equal(t, 1, s.llmBackend.requestCount()) + hdr := s.llmBackend.headersAt(0) + assert.Equal(t, sessionID, hdr.Get(config.DefaultSessionIDHeaderName)) + assert.Equal(t, "0", hdr.Get(config.DefaultSequenceNumberHeaderName)) + + // The two must agree. + assert.Equal(t, + strconv.FormatUint(*events[0].SequenceNumber, 10), + hdr.Get(config.DefaultSequenceNumberHeaderName), + "audit event and forwarded header must carry the same sequence number", + ) +} + +// TestIntegration_NonLLMRequestAuditedWithoutHeaders verifies that an +// allowed request to a domain that is NOT an inject target still gets +// audited (with a sequence number) but does NOT receive correlation +// headers. +func TestIntegration_NonLLMRequestAuditedWithoutHeaders(t *testing.T) { + s := newSessionCorrelationIntegrationSetup(t, "test-session") + defer s.stop() + + resp, err := s.pt.proxyClient.Get(s.otherBackend.server.URL + "/pulls") + require.NoError(t, err) + defer resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusOK, resp.StatusCode) + + // Audit event recorded. + events := s.auditor.getRequests() + require.Len(t, events, 1) + require.True(t, events[0].Allowed) + require.NotNil(t, events[0].SequenceNumber) + assert.Equal(t, uint64(0), *events[0].SequenceNumber) + + // No correlation headers on the backend. + require.Equal(t, 1, s.otherBackend.requestCount()) + hdr := s.otherBackend.headersAt(0) + assert.Empty(t, hdr.Get(config.DefaultSessionIDHeaderName), + "non-inject-target requests must not carry session ID header") + assert.Empty(t, hdr.Get(config.DefaultSequenceNumberHeaderName), + "non-inject-target requests must not carry sequence number header") +} + +// TestIntegration_DeniedRequestAuditedNeverForwarded verifies that a +// request denied by the rules engine is audited (consuming a sequence +// number) but is never forwarded to any backend. +func TestIntegration_DeniedRequestAuditedNeverForwarded(t *testing.T) { + // Create a setup with a custom deny-all proxy, but keep the same + // pattern of shared sequence counter and auditor. + llm := newMultiRequestCapturingBackend() + defer llm.close() + + aud := &capturingAuditor{} + seq := &audit.SequenceCounter{} + + pt := NewProxyTest(t, + WithCertManager(t.TempDir()), + // No allowed domains: deny everything. + WithSessionCorrelation(config.SessionCorrelationConfig{ + Enabled: true, + InjectTargets: []config.InjectTarget{{Domain: "anything.example.com"}}, + SessionIDHeaderName: config.DefaultSessionIDHeaderName, + SequenceNumberHeaderName: config.DefaultSequenceNumberHeaderName, + }), + WithSessionID("test-session"), + WithSequenceCounter(seq), + WithAuditor(aud), + ).Start() + defer pt.Stop() + + resp, err := pt.proxyClient.Get(llm.server.URL + "/exfil") + require.NoError(t, err) + defer resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusForbidden, resp.StatusCode) + + // Audit event recorded. + events := aud.getRequests() + require.Len(t, events, 1) + require.False(t, events[0].Allowed) + require.NotNil(t, events[0].SequenceNumber) + assert.Equal(t, uint64(0), *events[0].SequenceNumber) + + // Backend never hit. + assert.Equal(t, 0, llm.requestCount(), + "denied requests must not be forwarded to the backend") +} + +// TestIntegration_MixedRequestsSequenceOrdering sends a realistic +// sequence of LLM, non-LLM, and denied requests, then verifies: +// 1. Sequence numbers increase monotonically across all request types. +// 2. Only inject-target requests carry correlation headers. +// 3. The sequence numbers in headers match the audit events. +// 4. The gap between two LLM requests' sequence numbers reveals the +// intermediate non-LLM and denied activity. +func TestIntegration_MixedRequestsSequenceOrdering(t *testing.T) { + const sessionID = "mixed-test-session" + + // Two allowed backends (LLM and "github"), one denied domain. + llm := newMultiRequestCapturingBackend() + defer llm.close() + + other := newMultiRequestCapturingBackend() + defer other.close() + + llmURL, err := url.Parse(llm.server.URL) + require.NoError(t, err) + + otherURL, err := url.Parse(other.server.URL) + require.NoError(t, err) + + aud := &capturingAuditor{} + seq := &audit.SequenceCounter{} + + pt := NewProxyTest(t, + WithCertManager(t.TempDir()), + WithAllowedDomain(llmURL.Hostname()), + WithAllowedDomain(otherURL.Hostname()), + // Only LLM is an inject target. + WithSessionCorrelation(config.SessionCorrelationConfig{ + Enabled: true, + InjectTargets: []config.InjectTarget{{ + Domain: llmURL.Hostname(), + Path: "/v1/*", + }}, + SessionIDHeaderName: config.DefaultSessionIDHeaderName, + SequenceNumberHeaderName: config.DefaultSequenceNumberHeaderName, + }), + WithSessionID(sessionID), + WithSequenceCounter(seq), + WithAuditor(aud), + ).Start() + defer pt.Stop() + + // Request 0: LLM (allowed, inject target). + resp, err := pt.proxyClient.Get(llm.server.URL + "/v1/messages") + require.NoError(t, err) + resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusOK, resp.StatusCode) + + // Request 1: non-LLM (allowed, no inject). + resp, err = pt.proxyClient.Get(other.server.URL + "/coder/coder") + require.NoError(t, err) + resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusOK, resp.StatusCode) + + // Request 2: denied (nothing is allowed for evil.example.com). + resp, err = pt.proxyClient.Get("http://evil.example.com/exfil") + require.NoError(t, err) + resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusForbidden, resp.StatusCode) + + // Request 3: LLM again. + resp, err = pt.proxyClient.Get(llm.server.URL + "/v1/messages") + require.NoError(t, err) + resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusOK, resp.StatusCode) + + // -- Verify audit events -- + events := aud.getRequests() + require.Len(t, events, 4, "expected exactly four audit events") + + expectedSeq := []uint64{0, 1, 2, 3} + expectedAllowed := []bool{true, true, false, true} + for i, ev := range events { + require.NotNil(t, ev.SequenceNumber, "event %d: sequence number must be set", i) + assert.Equal(t, expectedSeq[i], *ev.SequenceNumber, + "event %d: wrong sequence number", i) + assert.Equal(t, expectedAllowed[i], ev.Allowed, + "event %d: wrong allowed flag", i) + } + + // -- Verify LLM backend headers -- + require.Equal(t, 2, llm.requestCount(), + "LLM backend should have received exactly two requests") + + firstLLMHdr := llm.headersAt(0) + assert.Equal(t, sessionID, firstLLMHdr.Get(config.DefaultSessionIDHeaderName)) + assert.Equal(t, "0", firstLLMHdr.Get(config.DefaultSequenceNumberHeaderName), + "first LLM request must have sequence 0") + + secondLLMHdr := llm.headersAt(1) + assert.Equal(t, sessionID, secondLLMHdr.Get(config.DefaultSessionIDHeaderName)) + assert.Equal(t, "3", secondLLMHdr.Get(config.DefaultSequenceNumberHeaderName), + "second LLM request must have sequence 3") + + // -- Verify non-LLM backend has no correlation headers -- + require.Equal(t, 1, other.requestCount()) + otherHdr := other.headersAt(0) + assert.Empty(t, otherHdr.Get(config.DefaultSessionIDHeaderName)) + assert.Empty(t, otherHdr.Get(config.DefaultSequenceNumberHeaderName)) + + // -- Verify the gap reveals intermediate activity -- + // The gap between the two LLM sequence numbers (0 and 3) means + // that sequence numbers 1 and 2 were consumed by non-LLM + // activity, matching audit events 1 (non-LLM allowed) and 2 + // (denied). + firstLLMSeq := *events[0].SequenceNumber + secondLLMSeq := *events[3].SequenceNumber + gap := secondLLMSeq - firstLLMSeq - 1 + assert.Equal(t, uint64(2), gap, + "gap between LLM requests should reveal 2 intermediate events") +} + +// TestIntegration_SequenceGapRevealsAgenticLoop sends two LLM requests +// with several non-LLM requests in between, simulating an agentic loop +// where the model triggers tool-use HTTP calls between prompts. The +// test verifies that the gap in LLM sequence numbers precisely +// reflects the count of intermediate boundary events. +func TestIntegration_SequenceGapRevealsAgenticLoop(t *testing.T) { + const sessionID = "agentic-loop-session" + + llm := newMultiRequestCapturingBackend() + defer llm.close() + + other := newMultiRequestCapturingBackend() + defer other.close() + + llmURL, err := url.Parse(llm.server.URL) + require.NoError(t, err) + + otherURL, err := url.Parse(other.server.URL) + require.NoError(t, err) + + aud := &capturingAuditor{} + seq := &audit.SequenceCounter{} + + pt := NewProxyTest(t, + WithCertManager(t.TempDir()), + WithAllowedDomain(llmURL.Hostname()), + WithAllowedDomain(otherURL.Hostname()), + WithSessionCorrelation(config.SessionCorrelationConfig{ + Enabled: true, + InjectTargets: []config.InjectTarget{{ + Domain: llmURL.Hostname(), + Path: "/v1/*", + }}, + SessionIDHeaderName: config.DefaultSessionIDHeaderName, + SequenceNumberHeaderName: config.DefaultSequenceNumberHeaderName, + }), + WithSessionID(sessionID), + WithSequenceCounter(seq), + WithAuditor(aud), + ).Start() + defer pt.Stop() + + // First LLM prompt (seq 0). + resp, err := pt.proxyClient.Get(llm.server.URL + "/v1/messages") + require.NoError(t, err) + resp.Body.Close() //nolint:errcheck + + // Agentic loop: three tool-use HTTP calls. + for _, p := range []string{"/coder/coder", "/coder/coder/issues", "/coder/coder/pulls"} { + resp, err = pt.proxyClient.Get(other.server.URL + p) + require.NoError(t, err) + resp.Body.Close() //nolint:errcheck + } + + // Second LLM prompt (seq 4). + resp, err = pt.proxyClient.Get(llm.server.URL + "/v1/messages") + require.NoError(t, err) + resp.Body.Close() //nolint:errcheck + + // Verify LLM sequence headers. + require.Equal(t, 2, llm.requestCount()) + assert.Equal(t, "0", llm.headersAt(0).Get(config.DefaultSequenceNumberHeaderName)) + assert.Equal(t, "4", llm.headersAt(1).Get(config.DefaultSequenceNumberHeaderName)) + + // The gap between sequence numbers 0 and 4 is 3, matching the + // three tool-use requests in between. + events := aud.getRequests() + require.Len(t, events, 5) + + firstLLMSeq := *events[0].SequenceNumber + secondLLMSeq := *events[4].SequenceNumber + gap := secondLLMSeq - firstLLMSeq - 1 + assert.Equal(t, uint64(3), gap, + "gap between prompts should equal number of tool-use requests") + + // Verify the intermediate events are the tool-use requests. + for i := 1; i <= 3; i++ { + require.NotNil(t, events[i].SequenceNumber) + assert.Equal(t, uint64(i), *events[i].SequenceNumber) + assert.True(t, events[i].Allowed) + } +} + +// TestIntegration_SpoofedHeadersOverwrittenWithCorrectSequence +// verifies that when a jailed client sets its own correlation headers, +// the proxy replaces them with the real session ID and the real +// sequence number, and the audit event still agrees with the header. +func TestIntegration_SpoofedHeadersOverwrittenWithCorrectSequence(t *testing.T) { + const sessionID = "real-session-uuid" + s := newSessionCorrelationIntegrationSetup(t, sessionID) + defer s.stop() + + req, err := http.NewRequest(http.MethodPost, s.llmBackend.server.URL+"/v1/messages", nil) + require.NoError(t, err) + req.Header.Set(config.DefaultSessionIDHeaderName, "spoofed-session") + req.Header.Set(config.DefaultSequenceNumberHeaderName, "9999") + + resp, err := s.pt.proxyClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusOK, resp.StatusCode) + + // Backend received real values, not spoofed. + require.Equal(t, 1, s.llmBackend.requestCount()) + hdr := s.llmBackend.headersAt(0) + assert.Equal(t, sessionID, hdr.Get(config.DefaultSessionIDHeaderName)) + assert.Equal(t, "0", hdr.Get(config.DefaultSequenceNumberHeaderName)) + + // Audit event agrees with header. + events := s.auditor.getRequests() + require.Len(t, events, 1) + require.NotNil(t, events[0].SequenceNumber) + assert.Equal(t, + strconv.FormatUint(*events[0].SequenceNumber, 10), + hdr.Get(config.DefaultSequenceNumberHeaderName), + ) +} + +// TestIntegration_DisabledCorrelationNoHeadersNoPreallocatedSequence +// verifies that when session correlation is disabled, the proxy does +// not inject headers and does not pre-allocate sequence numbers (the +// auditor falls back to its own counter instead). +func TestIntegration_DisabledCorrelationNoHeadersNoPreallocatedSequence(t *testing.T) { + llm := newMultiRequestCapturingBackend() + defer llm.close() + + llmURL, err := url.Parse(llm.server.URL) + require.NoError(t, err) + + aud := &capturingAuditor{} + + pt := NewProxyTest(t, + WithCertManager(t.TempDir()), + WithAllowedDomain(llmURL.Hostname()), + // Correlation disabled; no sequence counter. + WithSessionCorrelation(config.SessionCorrelationConfig{ + Enabled: false, + InjectTargets: []config.InjectTarget{{Domain: llmURL.Hostname()}}, + SessionIDHeaderName: config.DefaultSessionIDHeaderName, + SequenceNumberHeaderName: config.DefaultSequenceNumberHeaderName, + }), + WithSessionID("should-not-appear"), + // Explicitly do NOT set WithSequenceCounter; seqCounter is nil. + WithAuditor(aud), + ).Start() + defer pt.Stop() + + resp, err := pt.proxyClient.Get(llm.server.URL + "/v1/messages") + require.NoError(t, err) + defer resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusOK, resp.StatusCode) + + // No correlation headers. + require.Equal(t, 1, llm.requestCount()) + hdr := llm.headersAt(0) + assert.Empty(t, hdr.Get(config.DefaultSessionIDHeaderName)) + assert.Empty(t, hdr.Get(config.DefaultSequenceNumberHeaderName)) + + // Audit event recorded but without a pre-allocated sequence + // number (nil), because no SequenceCounter was provided. + events := aud.getRequests() + require.Len(t, events, 1) + assert.Nil(t, events[0].SequenceNumber, + "no sequence counter means no pre-allocated sequence number") +} + +// TestIntegration_ConcurrentRequestsUniqueSequenceNumbers sends +// multiple requests concurrently and verifies that every request +// receives a unique sequence number, and that the set of numbers is +// dense (no gaps, no duplicates). +func TestIntegration_ConcurrentRequestsUniqueSequenceNumbers(t *testing.T) { + const sessionID = "concurrent-session" + const numRequests = 10 + + s := newSessionCorrelationIntegrationSetup(t, sessionID) + defer s.stop() + + var wg sync.WaitGroup + for i := 0; i < numRequests; i++ { + wg.Add(1) + go func() { + defer wg.Done() + resp, err := s.pt.proxyClient.Get(s.llmBackend.server.URL + "/v1/messages") + assert.NoError(t, err) + if resp != nil { + resp.Body.Close() //nolint:errcheck + } + }() + } + wg.Wait() + + // Every request should have been audited. + events := s.auditor.getRequests() + require.Len(t, events, numRequests) + + // Collect all sequence numbers and verify uniqueness. + seen := make(map[uint64]bool, numRequests) + for i, ev := range events { + require.NotNil(t, ev.SequenceNumber, + "event %d: sequence number must not be nil", i) + assert.False(t, seen[*ev.SequenceNumber], + "event %d: duplicate sequence number %d", i, *ev.SequenceNumber) + seen[*ev.SequenceNumber] = true + } + + // The set should be exactly {0, 1, ..., numRequests-1}. + for i := uint64(0); i < numRequests; i++ { + assert.True(t, seen[i], + "sequence number %d is missing from the set", i) + } + + // Every header should also carry a matching sequence number. + require.Equal(t, numRequests, s.llmBackend.requestCount()) + headerSeqs := make(map[string]bool, numRequests) + for i := 0; i < numRequests; i++ { + hdr := s.llmBackend.headersAt(i) + seqStr := hdr.Get(config.DefaultSequenceNumberHeaderName) + assert.NotEmpty(t, seqStr, "request %d: sequence header must be set", i) + headerSeqs[seqStr] = true + } + for i := uint64(0); i < numRequests; i++ { + assert.True(t, headerSeqs[fmt.Sprintf("%d", i)], + "header sequence number %d is missing", i) + } +} From 5bd67ece35b747664bf57d0e184134dfea19e88f Mon Sep 17 00:00:00 2001 From: Sas Swart Date: Thu, 7 May 2026 17:27:55 +0000 Subject: [PATCH 2/8] refactor(proxy): update session correlation tests to use new header names and sequence number type Modified integration tests to reflect changes in session correlation header names and updated the sequence number type from uint64 to int32. Adjusted assertions in tests to ensure consistency with the new data types and header configurations, enhancing clarity and correctness in the test suite. --- ...xy_session_correlation_integration_test.go | 109 ++++++++---------- 1 file changed, 46 insertions(+), 63 deletions(-) diff --git a/proxy/proxy_session_correlation_integration_test.go b/proxy/proxy_session_correlation_integration_test.go index b8035b7..4000270 100644 --- a/proxy/proxy_session_correlation_integration_test.go +++ b/proxy/proxy_session_correlation_integration_test.go @@ -108,11 +108,8 @@ func newSessionCorrelationIntegrationSetup(t *testing.T, sessionID string) *sess Domain: llmURL.Hostname(), Path: "/v1/*", }}, - SessionIDHeaderName: config.DefaultSessionIDHeaderName, - SequenceNumberHeaderName: config.DefaultSequenceNumberHeaderName, }), WithSessionID(sessionID), - WithSequenceCounter(seq), WithAuditor(aud), ).Start() @@ -146,18 +143,18 @@ func TestIntegration_LLMRequestAuditAndHeadersAgree(t *testing.T) { require.Len(t, events, 1) require.True(t, events[0].Allowed) require.NotNil(t, events[0].SequenceNumber) - assert.Equal(t, uint64(0), *events[0].SequenceNumber) + assert.Equal(t, int32(0), events[0].SequenceNumber) // Forwarded headers. require.Equal(t, 1, s.llmBackend.requestCount()) hdr := s.llmBackend.headersAt(0) - assert.Equal(t, sessionID, hdr.Get(config.DefaultSessionIDHeaderName)) - assert.Equal(t, "0", hdr.Get(config.DefaultSequenceNumberHeaderName)) + assert.Equal(t, sessionID, hdr.Get(config.SessionIDHeaderName)) + assert.Equal(t, "0", hdr.Get(config.SequenceNumberHeaderName)) // The two must agree. assert.Equal(t, - strconv.FormatUint(*events[0].SequenceNumber, 10), - hdr.Get(config.DefaultSequenceNumberHeaderName), + strconv.Itoa(int(events[0].SequenceNumber)), + hdr.Get(config.SequenceNumberHeaderName), "audit event and forwarded header must carry the same sequence number", ) } @@ -180,14 +177,14 @@ func TestIntegration_NonLLMRequestAuditedWithoutHeaders(t *testing.T) { require.Len(t, events, 1) require.True(t, events[0].Allowed) require.NotNil(t, events[0].SequenceNumber) - assert.Equal(t, uint64(0), *events[0].SequenceNumber) + assert.Equal(t, int32(0), events[0].SequenceNumber) // No correlation headers on the backend. require.Equal(t, 1, s.otherBackend.requestCount()) hdr := s.otherBackend.headersAt(0) - assert.Empty(t, hdr.Get(config.DefaultSessionIDHeaderName), + assert.Empty(t, hdr.Get(config.SessionIDHeaderName), "non-inject-target requests must not carry session ID header") - assert.Empty(t, hdr.Get(config.DefaultSequenceNumberHeaderName), + assert.Empty(t, hdr.Get(config.SequenceNumberHeaderName), "non-inject-target requests must not carry sequence number header") } @@ -201,19 +198,15 @@ func TestIntegration_DeniedRequestAuditedNeverForwarded(t *testing.T) { defer llm.close() aud := &capturingAuditor{} - seq := &audit.SequenceCounter{} pt := NewProxyTest(t, WithCertManager(t.TempDir()), // No allowed domains: deny everything. WithSessionCorrelation(config.SessionCorrelationConfig{ - Enabled: true, - InjectTargets: []config.InjectTarget{{Domain: "anything.example.com"}}, - SessionIDHeaderName: config.DefaultSessionIDHeaderName, - SequenceNumberHeaderName: config.DefaultSequenceNumberHeaderName, + Enabled: true, + InjectTargets: []config.InjectTarget{{Domain: "anything.example.com"}}, }), WithSessionID("test-session"), - WithSequenceCounter(seq), WithAuditor(aud), ).Start() defer pt.Stop() @@ -228,7 +221,7 @@ func TestIntegration_DeniedRequestAuditedNeverForwarded(t *testing.T) { require.Len(t, events, 1) require.False(t, events[0].Allowed) require.NotNil(t, events[0].SequenceNumber) - assert.Equal(t, uint64(0), *events[0].SequenceNumber) + assert.Equal(t, int32(0), events[0].SequenceNumber) // Backend never hit. assert.Equal(t, 0, llm.requestCount(), @@ -259,7 +252,6 @@ func TestIntegration_MixedRequestsSequenceOrdering(t *testing.T) { require.NoError(t, err) aud := &capturingAuditor{} - seq := &audit.SequenceCounter{} pt := NewProxyTest(t, WithCertManager(t.TempDir()), @@ -272,11 +264,8 @@ func TestIntegration_MixedRequestsSequenceOrdering(t *testing.T) { Domain: llmURL.Hostname(), Path: "/v1/*", }}, - SessionIDHeaderName: config.DefaultSessionIDHeaderName, - SequenceNumberHeaderName: config.DefaultSequenceNumberHeaderName, }), WithSessionID(sessionID), - WithSequenceCounter(seq), WithAuditor(aud), ).Start() defer pt.Stop() @@ -309,11 +298,11 @@ func TestIntegration_MixedRequestsSequenceOrdering(t *testing.T) { events := aud.getRequests() require.Len(t, events, 4, "expected exactly four audit events") - expectedSeq := []uint64{0, 1, 2, 3} + expectedSeq := []int32{0, 1, 2, 3} expectedAllowed := []bool{true, true, false, true} for i, ev := range events { require.NotNil(t, ev.SequenceNumber, "event %d: sequence number must be set", i) - assert.Equal(t, expectedSeq[i], *ev.SequenceNumber, + assert.Equal(t, expectedSeq[i], ev.SequenceNumber, "event %d: wrong sequence number", i) assert.Equal(t, expectedAllowed[i], ev.Allowed, "event %d: wrong allowed flag", i) @@ -324,30 +313,30 @@ func TestIntegration_MixedRequestsSequenceOrdering(t *testing.T) { "LLM backend should have received exactly two requests") firstLLMHdr := llm.headersAt(0) - assert.Equal(t, sessionID, firstLLMHdr.Get(config.DefaultSessionIDHeaderName)) - assert.Equal(t, "0", firstLLMHdr.Get(config.DefaultSequenceNumberHeaderName), + assert.Equal(t, sessionID, firstLLMHdr.Get(config.SessionIDHeaderName)) + assert.Equal(t, "0", firstLLMHdr.Get(config.SequenceNumberHeaderName), "first LLM request must have sequence 0") secondLLMHdr := llm.headersAt(1) - assert.Equal(t, sessionID, secondLLMHdr.Get(config.DefaultSessionIDHeaderName)) - assert.Equal(t, "3", secondLLMHdr.Get(config.DefaultSequenceNumberHeaderName), + assert.Equal(t, sessionID, secondLLMHdr.Get(config.SessionIDHeaderName)) + assert.Equal(t, "3", secondLLMHdr.Get(config.SequenceNumberHeaderName), "second LLM request must have sequence 3") // -- Verify non-LLM backend has no correlation headers -- require.Equal(t, 1, other.requestCount()) otherHdr := other.headersAt(0) - assert.Empty(t, otherHdr.Get(config.DefaultSessionIDHeaderName)) - assert.Empty(t, otherHdr.Get(config.DefaultSequenceNumberHeaderName)) + assert.Empty(t, otherHdr.Get(config.SessionIDHeaderName)) + assert.Empty(t, otherHdr.Get(config.SequenceNumberHeaderName)) // -- Verify the gap reveals intermediate activity -- // The gap between the two LLM sequence numbers (0 and 3) means // that sequence numbers 1 and 2 were consumed by non-LLM // activity, matching audit events 1 (non-LLM allowed) and 2 // (denied). - firstLLMSeq := *events[0].SequenceNumber - secondLLMSeq := *events[3].SequenceNumber + firstLLMSeq := events[0].SequenceNumber + secondLLMSeq := events[3].SequenceNumber gap := secondLLMSeq - firstLLMSeq - 1 - assert.Equal(t, uint64(2), gap, + assert.Equal(t, int32(2), gap, "gap between LLM requests should reveal 2 intermediate events") } @@ -372,7 +361,6 @@ func TestIntegration_SequenceGapRevealsAgenticLoop(t *testing.T) { require.NoError(t, err) aud := &capturingAuditor{} - seq := &audit.SequenceCounter{} pt := NewProxyTest(t, WithCertManager(t.TempDir()), @@ -384,11 +372,8 @@ func TestIntegration_SequenceGapRevealsAgenticLoop(t *testing.T) { Domain: llmURL.Hostname(), Path: "/v1/*", }}, - SessionIDHeaderName: config.DefaultSessionIDHeaderName, - SequenceNumberHeaderName: config.DefaultSequenceNumberHeaderName, }), WithSessionID(sessionID), - WithSequenceCounter(seq), WithAuditor(aud), ).Start() defer pt.Stop() @@ -412,24 +397,24 @@ func TestIntegration_SequenceGapRevealsAgenticLoop(t *testing.T) { // Verify LLM sequence headers. require.Equal(t, 2, llm.requestCount()) - assert.Equal(t, "0", llm.headersAt(0).Get(config.DefaultSequenceNumberHeaderName)) - assert.Equal(t, "4", llm.headersAt(1).Get(config.DefaultSequenceNumberHeaderName)) + assert.Equal(t, "0", llm.headersAt(0).Get(config.SequenceNumberHeaderName)) + assert.Equal(t, "4", llm.headersAt(1).Get(config.SequenceNumberHeaderName)) // The gap between sequence numbers 0 and 4 is 3, matching the // three tool-use requests in between. events := aud.getRequests() require.Len(t, events, 5) - firstLLMSeq := *events[0].SequenceNumber - secondLLMSeq := *events[4].SequenceNumber + firstLLMSeq := events[0].SequenceNumber + secondLLMSeq := events[4].SequenceNumber gap := secondLLMSeq - firstLLMSeq - 1 - assert.Equal(t, uint64(3), gap, + assert.Equal(t, int32(3), gap, "gap between prompts should equal number of tool-use requests") // Verify the intermediate events are the tool-use requests. for i := 1; i <= 3; i++ { require.NotNil(t, events[i].SequenceNumber) - assert.Equal(t, uint64(i), *events[i].SequenceNumber) + assert.Equal(t, int32(i), events[i].SequenceNumber) assert.True(t, events[i].Allowed) } } @@ -445,8 +430,8 @@ func TestIntegration_SpoofedHeadersOverwrittenWithCorrectSequence(t *testing.T) req, err := http.NewRequest(http.MethodPost, s.llmBackend.server.URL+"/v1/messages", nil) require.NoError(t, err) - req.Header.Set(config.DefaultSessionIDHeaderName, "spoofed-session") - req.Header.Set(config.DefaultSequenceNumberHeaderName, "9999") + req.Header.Set(config.SessionIDHeaderName, "spoofed-session") + req.Header.Set(config.SequenceNumberHeaderName, "9999") resp, err := s.pt.proxyClient.Do(req) require.NoError(t, err) @@ -456,16 +441,16 @@ func TestIntegration_SpoofedHeadersOverwrittenWithCorrectSequence(t *testing.T) // Backend received real values, not spoofed. require.Equal(t, 1, s.llmBackend.requestCount()) hdr := s.llmBackend.headersAt(0) - assert.Equal(t, sessionID, hdr.Get(config.DefaultSessionIDHeaderName)) - assert.Equal(t, "0", hdr.Get(config.DefaultSequenceNumberHeaderName)) + assert.Equal(t, sessionID, hdr.Get(config.SessionIDHeaderName)) + assert.Equal(t, "0", hdr.Get(config.SequenceNumberHeaderName)) // Audit event agrees with header. events := s.auditor.getRequests() require.Len(t, events, 1) require.NotNil(t, events[0].SequenceNumber) assert.Equal(t, - strconv.FormatUint(*events[0].SequenceNumber, 10), - hdr.Get(config.DefaultSequenceNumberHeaderName), + strconv.Itoa(int(events[0].SequenceNumber)), + hdr.Get(config.SequenceNumberHeaderName), ) } @@ -487,10 +472,8 @@ func TestIntegration_DisabledCorrelationNoHeadersNoPreallocatedSequence(t *testi WithAllowedDomain(llmURL.Hostname()), // Correlation disabled; no sequence counter. WithSessionCorrelation(config.SessionCorrelationConfig{ - Enabled: false, - InjectTargets: []config.InjectTarget{{Domain: llmURL.Hostname()}}, - SessionIDHeaderName: config.DefaultSessionIDHeaderName, - SequenceNumberHeaderName: config.DefaultSequenceNumberHeaderName, + Enabled: false, + InjectTargets: []config.InjectTarget{{Domain: llmURL.Hostname()}}, }), WithSessionID("should-not-appear"), // Explicitly do NOT set WithSequenceCounter; seqCounter is nil. @@ -506,14 +489,14 @@ func TestIntegration_DisabledCorrelationNoHeadersNoPreallocatedSequence(t *testi // No correlation headers. require.Equal(t, 1, llm.requestCount()) hdr := llm.headersAt(0) - assert.Empty(t, hdr.Get(config.DefaultSessionIDHeaderName)) - assert.Empty(t, hdr.Get(config.DefaultSequenceNumberHeaderName)) + assert.Empty(t, hdr.Get(config.SessionIDHeaderName)) + assert.Empty(t, hdr.Get(config.SequenceNumberHeaderName)) // Audit event recorded but without a pre-allocated sequence // number (nil), because no SequenceCounter was provided. events := aud.getRequests() require.Len(t, events, 1) - assert.Nil(t, events[0].SequenceNumber, + assert.Equal(t, int32(0), events[0].SequenceNumber, "no sequence counter means no pre-allocated sequence number") } @@ -547,17 +530,17 @@ func TestIntegration_ConcurrentRequestsUniqueSequenceNumbers(t *testing.T) { require.Len(t, events, numRequests) // Collect all sequence numbers and verify uniqueness. - seen := make(map[uint64]bool, numRequests) + seen := make(map[int32]bool, numRequests) for i, ev := range events { require.NotNil(t, ev.SequenceNumber, "event %d: sequence number must not be nil", i) - assert.False(t, seen[*ev.SequenceNumber], - "event %d: duplicate sequence number %d", i, *ev.SequenceNumber) - seen[*ev.SequenceNumber] = true + assert.False(t, seen[ev.SequenceNumber], + "event %d: duplicate sequence number %d", i, ev.SequenceNumber) + seen[ev.SequenceNumber] = true } // The set should be exactly {0, 1, ..., numRequests-1}. - for i := uint64(0); i < numRequests; i++ { + for i := int32(0); i < numRequests; i++ { assert.True(t, seen[i], "sequence number %d is missing from the set", i) } @@ -567,11 +550,11 @@ func TestIntegration_ConcurrentRequestsUniqueSequenceNumbers(t *testing.T) { headerSeqs := make(map[string]bool, numRequests) for i := 0; i < numRequests; i++ { hdr := s.llmBackend.headersAt(i) - seqStr := hdr.Get(config.DefaultSequenceNumberHeaderName) + seqStr := hdr.Get(config.SequenceNumberHeaderName) assert.NotEmpty(t, seqStr, "request %d: sequence header must be set", i) headerSeqs[seqStr] = true } - for i := uint64(0); i < numRequests; i++ { + for i := int32(0); i < numRequests; i++ { assert.True(t, headerSeqs[fmt.Sprintf("%d", i)], "header sequence number %d is missing", i) } From f1726b5a25085d18cc30040e4792dec7b564d0b8 Mon Sep 17 00:00:00 2001 From: Sas Swart Date: Wed, 13 May 2026 12:28:24 +0000 Subject: [PATCH 3/8] refactor(proxy): clean up integration test naming and style - Remove '// ---------- Integration Tests ----------' section separator - Rename 'hdr'/'Hdr' variables to 'header'/'Header' for clarity - Rename 'llm'/'llmBackend' to 'injectBackend'/'inject'/'backend' to reflect the actual concept (inject target) rather than a specific use case (LLM) - Update comments to match the new naming Generated by Coder Agents --- ...xy_session_correlation_integration_test.go | 226 +++++++++--------- 1 file changed, 112 insertions(+), 114 deletions(-) diff --git a/proxy/proxy_session_correlation_integration_test.go b/proxy/proxy_session_correlation_integration_test.go index 4000270..b2206fb 100644 --- a/proxy/proxy_session_correlation_integration_test.go +++ b/proxy/proxy_session_correlation_integration_test.go @@ -55,17 +55,17 @@ func (m *multiRequestCapturingBackend) headersAt(i int) http.Header { // counter. Tests build one via newSessionCorrelationIntegrationSetup // and tear it down with stop. type sessionCorrelationIntegrationSetup struct { - pt *ProxyTest - auditor *capturingAuditor - seq *audit.SequenceCounter - llmBackend *multiRequestCapturingBackend - otherBackend *multiRequestCapturingBackend + pt *ProxyTest + auditor *capturingAuditor + seq *audit.SequenceCounter + injectBackend *multiRequestCapturingBackend + otherBackend *multiRequestCapturingBackend } func (s *sessionCorrelationIntegrationSetup) stop() { s.pt.Stop() - if s.llmBackend != nil { - s.llmBackend.close() + if s.injectBackend != nil { + s.injectBackend.close() } if s.otherBackend != nil { s.otherBackend.close() @@ -74,17 +74,16 @@ func (s *sessionCorrelationIntegrationSetup) stop() { // newSessionCorrelationIntegrationSetup builds a proxy that allows // traffic to two httptest backends: one that matches an inject target -// (simulating an LLM provider) and one that does not (simulating a -// generic allowed domain like github.com). Both backends capture all -// received request headers. A capturingAuditor records every audit -// event for later inspection. +// and one that does not (simulating a generic allowed domain like +// github.com). Both backends capture all received request headers. +// A capturingAuditor records every audit event for later inspection. func newSessionCorrelationIntegrationSetup(t *testing.T, sessionID string) *sessionCorrelationIntegrationSetup { t.Helper() - llm := newMultiRequestCapturingBackend() + inject := newMultiRequestCapturingBackend() other := newMultiRequestCapturingBackend() - llmURL, err := url.Parse(llm.server.URL) + injectURL, err := url.Parse(inject.server.URL) require.NoError(t, err) otherURL, err := url.Parse(other.server.URL) @@ -94,18 +93,18 @@ func newSessionCorrelationIntegrationSetup(t *testing.T, sessionID string) *sess seq := &audit.SequenceCounter{} // Both httptest backends resolve to 127.0.0.1, so a domain-only - // inject target would match both. We use a path glob on the LLM - // paths (/v1/*) to limit header injection to LLM requests. + // inject target would match both. We use a path glob on the + // inject-target paths (/v1/*) to limit header injection. pt := NewProxyTest(t, WithCertManager(t.TempDir()), // Allow both backends. - WithAllowedDomain(llmURL.Hostname()), + WithAllowedDomain(injectURL.Hostname()), WithAllowedDomain(otherURL.Hostname()), - // Only requests matching the LLM path receive headers. + // Only requests matching the inject-target path receive headers. WithSessionCorrelation(config.SessionCorrelationConfig{ Enabled: true, InjectTargets: []config.InjectTarget{{ - Domain: llmURL.Hostname(), + Domain: injectURL.Hostname(), Path: "/v1/*", }}, }), @@ -114,16 +113,14 @@ func newSessionCorrelationIntegrationSetup(t *testing.T, sessionID string) *sess ).Start() return &sessionCorrelationIntegrationSetup{ - pt: pt, - auditor: aud, - seq: seq, - llmBackend: llm, - otherBackend: other, + pt: pt, + auditor: aud, + seq: seq, + injectBackend: inject, + otherBackend: other, } } -// ---------- Integration Tests ---------- - // TestIntegration_LLMRequestAuditAndHeadersAgree verifies the core // correlation invariant: when an allowed request hits an inject target, // the sequence number in the audit event equals the sequence number in @@ -133,7 +130,7 @@ func TestIntegration_LLMRequestAuditAndHeadersAgree(t *testing.T) { s := newSessionCorrelationIntegrationSetup(t, sessionID) defer s.stop() - resp, err := s.pt.proxyClient.Get(s.llmBackend.server.URL + "/v1/messages") + resp, err := s.pt.proxyClient.Get(s.injectBackend.server.URL + "/v1/messages") require.NoError(t, err) defer resp.Body.Close() //nolint:errcheck require.Equal(t, http.StatusOK, resp.StatusCode) @@ -146,15 +143,15 @@ func TestIntegration_LLMRequestAuditAndHeadersAgree(t *testing.T) { assert.Equal(t, int32(0), events[0].SequenceNumber) // Forwarded headers. - require.Equal(t, 1, s.llmBackend.requestCount()) - hdr := s.llmBackend.headersAt(0) - assert.Equal(t, sessionID, hdr.Get(config.SessionIDHeaderName)) - assert.Equal(t, "0", hdr.Get(config.SequenceNumberHeaderName)) + require.Equal(t, 1, s.injectBackend.requestCount()) + header := s.injectBackend.headersAt(0) + assert.Equal(t, sessionID, header.Get(config.SessionIDHeaderName)) + assert.Equal(t, "0", header.Get(config.SequenceNumberHeaderName)) // The two must agree. assert.Equal(t, strconv.Itoa(int(events[0].SequenceNumber)), - hdr.Get(config.SequenceNumberHeaderName), + header.Get(config.SequenceNumberHeaderName), "audit event and forwarded header must carry the same sequence number", ) } @@ -181,10 +178,10 @@ func TestIntegration_NonLLMRequestAuditedWithoutHeaders(t *testing.T) { // No correlation headers on the backend. require.Equal(t, 1, s.otherBackend.requestCount()) - hdr := s.otherBackend.headersAt(0) - assert.Empty(t, hdr.Get(config.SessionIDHeaderName), + header := s.otherBackend.headersAt(0) + assert.Empty(t, header.Get(config.SessionIDHeaderName), "non-inject-target requests must not carry session ID header") - assert.Empty(t, hdr.Get(config.SequenceNumberHeaderName), + assert.Empty(t, header.Get(config.SequenceNumberHeaderName), "non-inject-target requests must not carry sequence number header") } @@ -194,8 +191,8 @@ func TestIntegration_NonLLMRequestAuditedWithoutHeaders(t *testing.T) { func TestIntegration_DeniedRequestAuditedNeverForwarded(t *testing.T) { // Create a setup with a custom deny-all proxy, but keep the same // pattern of shared sequence counter and auditor. - llm := newMultiRequestCapturingBackend() - defer llm.close() + backend := newMultiRequestCapturingBackend() + defer backend.close() aud := &capturingAuditor{} @@ -211,7 +208,7 @@ func TestIntegration_DeniedRequestAuditedNeverForwarded(t *testing.T) { ).Start() defer pt.Stop() - resp, err := pt.proxyClient.Get(llm.server.URL + "/exfil") + resp, err := pt.proxyClient.Get(backend.server.URL + "/exfil") require.NoError(t, err) defer resp.Body.Close() //nolint:errcheck require.Equal(t, http.StatusForbidden, resp.StatusCode) @@ -224,7 +221,7 @@ func TestIntegration_DeniedRequestAuditedNeverForwarded(t *testing.T) { assert.Equal(t, int32(0), events[0].SequenceNumber) // Backend never hit. - assert.Equal(t, 0, llm.requestCount(), + assert.Equal(t, 0, backend.requestCount(), "denied requests must not be forwarded to the backend") } @@ -238,14 +235,14 @@ func TestIntegration_DeniedRequestAuditedNeverForwarded(t *testing.T) { func TestIntegration_MixedRequestsSequenceOrdering(t *testing.T) { const sessionID = "mixed-test-session" - // Two allowed backends (LLM and "github"), one denied domain. - llm := newMultiRequestCapturingBackend() - defer llm.close() + // Two allowed backends (inject target and "github"), one denied domain. + inject := newMultiRequestCapturingBackend() + defer inject.close() other := newMultiRequestCapturingBackend() defer other.close() - llmURL, err := url.Parse(llm.server.URL) + injectURL, err := url.Parse(inject.server.URL) require.NoError(t, err) otherURL, err := url.Parse(other.server.URL) @@ -255,13 +252,13 @@ func TestIntegration_MixedRequestsSequenceOrdering(t *testing.T) { pt := NewProxyTest(t, WithCertManager(t.TempDir()), - WithAllowedDomain(llmURL.Hostname()), + WithAllowedDomain(injectURL.Hostname()), WithAllowedDomain(otherURL.Hostname()), - // Only LLM is an inject target. + // Only the inject backend is an inject target. WithSessionCorrelation(config.SessionCorrelationConfig{ Enabled: true, InjectTargets: []config.InjectTarget{{ - Domain: llmURL.Hostname(), + Domain: injectURL.Hostname(), Path: "/v1/*", }}, }), @@ -270,13 +267,13 @@ func TestIntegration_MixedRequestsSequenceOrdering(t *testing.T) { ).Start() defer pt.Stop() - // Request 0: LLM (allowed, inject target). - resp, err := pt.proxyClient.Get(llm.server.URL + "/v1/messages") + // Request 0: inject target (allowed, headers injected). + resp, err := pt.proxyClient.Get(inject.server.URL + "/v1/messages") require.NoError(t, err) resp.Body.Close() //nolint:errcheck require.Equal(t, http.StatusOK, resp.StatusCode) - // Request 1: non-LLM (allowed, no inject). + // Request 1: non-inject-target (allowed, no headers). resp, err = pt.proxyClient.Get(other.server.URL + "/coder/coder") require.NoError(t, err) resp.Body.Close() //nolint:errcheck @@ -288,8 +285,8 @@ func TestIntegration_MixedRequestsSequenceOrdering(t *testing.T) { resp.Body.Close() //nolint:errcheck require.Equal(t, http.StatusForbidden, resp.StatusCode) - // Request 3: LLM again. - resp, err = pt.proxyClient.Get(llm.server.URL + "/v1/messages") + // Request 3: inject target again. + resp, err = pt.proxyClient.Get(inject.server.URL + "/v1/messages") require.NoError(t, err) resp.Body.Close() //nolint:errcheck require.Equal(t, http.StatusOK, resp.StatusCode) @@ -308,53 +305,54 @@ func TestIntegration_MixedRequestsSequenceOrdering(t *testing.T) { "event %d: wrong allowed flag", i) } - // -- Verify LLM backend headers -- - require.Equal(t, 2, llm.requestCount(), - "LLM backend should have received exactly two requests") + // -- Verify inject-target backend headers -- + require.Equal(t, 2, inject.requestCount(), + "inject-target backend should have received exactly two requests") - firstLLMHdr := llm.headersAt(0) - assert.Equal(t, sessionID, firstLLMHdr.Get(config.SessionIDHeaderName)) - assert.Equal(t, "0", firstLLMHdr.Get(config.SequenceNumberHeaderName), - "first LLM request must have sequence 0") + firstInjectHeader := inject.headersAt(0) + assert.Equal(t, sessionID, firstInjectHeader.Get(config.SessionIDHeaderName)) + assert.Equal(t, "0", firstInjectHeader.Get(config.SequenceNumberHeaderName), + "first inject-target request must have sequence 0") - secondLLMHdr := llm.headersAt(1) - assert.Equal(t, sessionID, secondLLMHdr.Get(config.SessionIDHeaderName)) - assert.Equal(t, "3", secondLLMHdr.Get(config.SequenceNumberHeaderName), - "second LLM request must have sequence 3") + secondInjectHeader := inject.headersAt(1) + assert.Equal(t, sessionID, secondInjectHeader.Get(config.SessionIDHeaderName)) + assert.Equal(t, "3", secondInjectHeader.Get(config.SequenceNumberHeaderName), + "second inject-target request must have sequence 3") - // -- Verify non-LLM backend has no correlation headers -- + // -- Verify non-inject-target backend has no correlation headers -- require.Equal(t, 1, other.requestCount()) - otherHdr := other.headersAt(0) - assert.Empty(t, otherHdr.Get(config.SessionIDHeaderName)) - assert.Empty(t, otherHdr.Get(config.SequenceNumberHeaderName)) + otherHeader := other.headersAt(0) + assert.Empty(t, otherHeader.Get(config.SessionIDHeaderName)) + assert.Empty(t, otherHeader.Get(config.SequenceNumberHeaderName)) // -- Verify the gap reveals intermediate activity -- - // The gap between the two LLM sequence numbers (0 and 3) means - // that sequence numbers 1 and 2 were consumed by non-LLM - // activity, matching audit events 1 (non-LLM allowed) and 2 - // (denied). - firstLLMSeq := events[0].SequenceNumber - secondLLMSeq := events[3].SequenceNumber - gap := secondLLMSeq - firstLLMSeq - 1 + // The gap between the two inject-target sequence numbers (0 and 3) + // means that sequence numbers 1 and 2 were consumed by + // non-inject-target activity, matching audit events 1 + // (non-inject-target allowed) and 2 (denied). + firstInjectSeq := events[0].SequenceNumber + secondInjectSeq := events[3].SequenceNumber + gap := secondInjectSeq - firstInjectSeq - 1 assert.Equal(t, int32(2), gap, - "gap between LLM requests should reveal 2 intermediate events") + "gap between inject-target requests should reveal 2 intermediate events") } -// TestIntegration_SequenceGapRevealsAgenticLoop sends two LLM requests -// with several non-LLM requests in between, simulating an agentic loop -// where the model triggers tool-use HTTP calls between prompts. The -// test verifies that the gap in LLM sequence numbers precisely -// reflects the count of intermediate boundary events. +// TestIntegration_SequenceGapRevealsAgenticLoop sends two inject-target +// requests with several non-inject-target requests in between, +// simulating an agentic loop where the model triggers tool-use HTTP +// calls between prompts. The test verifies that the gap in +// inject-target sequence numbers precisely reflects the count of +// intermediate boundary events. func TestIntegration_SequenceGapRevealsAgenticLoop(t *testing.T) { const sessionID = "agentic-loop-session" - llm := newMultiRequestCapturingBackend() - defer llm.close() + inject := newMultiRequestCapturingBackend() + defer inject.close() other := newMultiRequestCapturingBackend() defer other.close() - llmURL, err := url.Parse(llm.server.URL) + injectURL, err := url.Parse(inject.server.URL) require.NoError(t, err) otherURL, err := url.Parse(other.server.URL) @@ -364,12 +362,12 @@ func TestIntegration_SequenceGapRevealsAgenticLoop(t *testing.T) { pt := NewProxyTest(t, WithCertManager(t.TempDir()), - WithAllowedDomain(llmURL.Hostname()), + WithAllowedDomain(injectURL.Hostname()), WithAllowedDomain(otherURL.Hostname()), WithSessionCorrelation(config.SessionCorrelationConfig{ Enabled: true, InjectTargets: []config.InjectTarget{{ - Domain: llmURL.Hostname(), + Domain: injectURL.Hostname(), Path: "/v1/*", }}, }), @@ -378,8 +376,8 @@ func TestIntegration_SequenceGapRevealsAgenticLoop(t *testing.T) { ).Start() defer pt.Stop() - // First LLM prompt (seq 0). - resp, err := pt.proxyClient.Get(llm.server.URL + "/v1/messages") + // First inject-target request (seq 0). + resp, err := pt.proxyClient.Get(inject.server.URL + "/v1/messages") require.NoError(t, err) resp.Body.Close() //nolint:errcheck @@ -390,24 +388,24 @@ func TestIntegration_SequenceGapRevealsAgenticLoop(t *testing.T) { resp.Body.Close() //nolint:errcheck } - // Second LLM prompt (seq 4). - resp, err = pt.proxyClient.Get(llm.server.URL + "/v1/messages") + // Second inject-target request (seq 4). + resp, err = pt.proxyClient.Get(inject.server.URL + "/v1/messages") require.NoError(t, err) resp.Body.Close() //nolint:errcheck - // Verify LLM sequence headers. - require.Equal(t, 2, llm.requestCount()) - assert.Equal(t, "0", llm.headersAt(0).Get(config.SequenceNumberHeaderName)) - assert.Equal(t, "4", llm.headersAt(1).Get(config.SequenceNumberHeaderName)) + // Verify inject-target sequence headers. + require.Equal(t, 2, inject.requestCount()) + assert.Equal(t, "0", inject.headersAt(0).Get(config.SequenceNumberHeaderName)) + assert.Equal(t, "4", inject.headersAt(1).Get(config.SequenceNumberHeaderName)) // The gap between sequence numbers 0 and 4 is 3, matching the // three tool-use requests in between. events := aud.getRequests() require.Len(t, events, 5) - firstLLMSeq := events[0].SequenceNumber - secondLLMSeq := events[4].SequenceNumber - gap := secondLLMSeq - firstLLMSeq - 1 + firstInjectSeq := events[0].SequenceNumber + secondInjectSeq := events[4].SequenceNumber + gap := secondInjectSeq - firstInjectSeq - 1 assert.Equal(t, int32(3), gap, "gap between prompts should equal number of tool-use requests") @@ -428,7 +426,7 @@ func TestIntegration_SpoofedHeadersOverwrittenWithCorrectSequence(t *testing.T) s := newSessionCorrelationIntegrationSetup(t, sessionID) defer s.stop() - req, err := http.NewRequest(http.MethodPost, s.llmBackend.server.URL+"/v1/messages", nil) + req, err := http.NewRequest(http.MethodPost, s.injectBackend.server.URL+"/v1/messages", nil) require.NoError(t, err) req.Header.Set(config.SessionIDHeaderName, "spoofed-session") req.Header.Set(config.SequenceNumberHeaderName, "9999") @@ -439,10 +437,10 @@ func TestIntegration_SpoofedHeadersOverwrittenWithCorrectSequence(t *testing.T) require.Equal(t, http.StatusOK, resp.StatusCode) // Backend received real values, not spoofed. - require.Equal(t, 1, s.llmBackend.requestCount()) - hdr := s.llmBackend.headersAt(0) - assert.Equal(t, sessionID, hdr.Get(config.SessionIDHeaderName)) - assert.Equal(t, "0", hdr.Get(config.SequenceNumberHeaderName)) + require.Equal(t, 1, s.injectBackend.requestCount()) + header := s.injectBackend.headersAt(0) + assert.Equal(t, sessionID, header.Get(config.SessionIDHeaderName)) + assert.Equal(t, "0", header.Get(config.SequenceNumberHeaderName)) // Audit event agrees with header. events := s.auditor.getRequests() @@ -450,7 +448,7 @@ func TestIntegration_SpoofedHeadersOverwrittenWithCorrectSequence(t *testing.T) require.NotNil(t, events[0].SequenceNumber) assert.Equal(t, strconv.Itoa(int(events[0].SequenceNumber)), - hdr.Get(config.SequenceNumberHeaderName), + header.Get(config.SequenceNumberHeaderName), ) } @@ -459,21 +457,21 @@ func TestIntegration_SpoofedHeadersOverwrittenWithCorrectSequence(t *testing.T) // not inject headers and does not pre-allocate sequence numbers (the // auditor falls back to its own counter instead). func TestIntegration_DisabledCorrelationNoHeadersNoPreallocatedSequence(t *testing.T) { - llm := newMultiRequestCapturingBackend() - defer llm.close() + backend := newMultiRequestCapturingBackend() + defer backend.close() - llmURL, err := url.Parse(llm.server.URL) + backendURL, err := url.Parse(backend.server.URL) require.NoError(t, err) aud := &capturingAuditor{} pt := NewProxyTest(t, WithCertManager(t.TempDir()), - WithAllowedDomain(llmURL.Hostname()), + WithAllowedDomain(backendURL.Hostname()), // Correlation disabled; no sequence counter. WithSessionCorrelation(config.SessionCorrelationConfig{ Enabled: false, - InjectTargets: []config.InjectTarget{{Domain: llmURL.Hostname()}}, + InjectTargets: []config.InjectTarget{{Domain: backendURL.Hostname()}}, }), WithSessionID("should-not-appear"), // Explicitly do NOT set WithSequenceCounter; seqCounter is nil. @@ -481,16 +479,16 @@ func TestIntegration_DisabledCorrelationNoHeadersNoPreallocatedSequence(t *testi ).Start() defer pt.Stop() - resp, err := pt.proxyClient.Get(llm.server.URL + "/v1/messages") + resp, err := pt.proxyClient.Get(backend.server.URL + "/v1/messages") require.NoError(t, err) defer resp.Body.Close() //nolint:errcheck require.Equal(t, http.StatusOK, resp.StatusCode) // No correlation headers. - require.Equal(t, 1, llm.requestCount()) - hdr := llm.headersAt(0) - assert.Empty(t, hdr.Get(config.SessionIDHeaderName)) - assert.Empty(t, hdr.Get(config.SequenceNumberHeaderName)) + require.Equal(t, 1, backend.requestCount()) + header := backend.headersAt(0) + assert.Empty(t, header.Get(config.SessionIDHeaderName)) + assert.Empty(t, header.Get(config.SequenceNumberHeaderName)) // Audit event recorded but without a pre-allocated sequence // number (nil), because no SequenceCounter was provided. @@ -516,7 +514,7 @@ func TestIntegration_ConcurrentRequestsUniqueSequenceNumbers(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - resp, err := s.pt.proxyClient.Get(s.llmBackend.server.URL + "/v1/messages") + resp, err := s.pt.proxyClient.Get(s.injectBackend.server.URL + "/v1/messages") assert.NoError(t, err) if resp != nil { resp.Body.Close() //nolint:errcheck @@ -546,11 +544,11 @@ func TestIntegration_ConcurrentRequestsUniqueSequenceNumbers(t *testing.T) { } // Every header should also carry a matching sequence number. - require.Equal(t, numRequests, s.llmBackend.requestCount()) + require.Equal(t, numRequests, s.injectBackend.requestCount()) headerSeqs := make(map[string]bool, numRequests) for i := 0; i < numRequests; i++ { - hdr := s.llmBackend.headersAt(i) - seqStr := hdr.Get(config.SequenceNumberHeaderName) + header := s.injectBackend.headersAt(i) + seqStr := header.Get(config.SequenceNumberHeaderName) assert.NotEmpty(t, seqStr, "request %d: sequence header must be set", i) headerSeqs[seqStr] = true } From 4db7db6d2c3a8e5a2d0a230bc7ccb603317c0319 Mon Sep 17 00:00:00 2001 From: Sas Swart Date: Wed, 13 May 2026 12:34:02 +0000 Subject: [PATCH 4/8] refactor(proxy): improve comments and structure in session correlation integration tests - Enhanced comments for clarity regarding the purpose of `injectBackend` and `otherBackend`. - Removed unnecessary comments to streamline the test code. - Adjusted formatting for consistency and readability in the `sessionCorrelationIntegrationSetup` struct. These changes aim to improve the maintainability and understanding of the integration tests related to session correlation. --- .../proxy_session_correlation_integration_test.go | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/proxy/proxy_session_correlation_integration_test.go b/proxy/proxy_session_correlation_integration_test.go index b2206fb..174e012 100644 --- a/proxy/proxy_session_correlation_integration_test.go +++ b/proxy/proxy_session_correlation_integration_test.go @@ -55,11 +55,16 @@ func (m *multiRequestCapturingBackend) headersAt(i int) http.Header { // counter. Tests build one via newSessionCorrelationIntegrationSetup // and tear it down with stop. type sessionCorrelationIntegrationSetup struct { - pt *ProxyTest - auditor *capturingAuditor - seq *audit.SequenceCounter + pt *ProxyTest + auditor *capturingAuditor + seq *audit.SequenceCounter + // llmBackend expects headers to be injected as these requests are + // expected to be seen by the AI Gateway and then correlated back + // to the audit event injectBackend *multiRequestCapturingBackend - otherBackend *multiRequestCapturingBackend + // otherBackend does not expect headers to be injected as these + // requests should not be routed through the AI Gateway. + otherBackend *multiRequestCapturingBackend } func (s *sessionCorrelationIntegrationSetup) stop() { @@ -135,7 +140,6 @@ func TestIntegration_LLMRequestAuditAndHeadersAgree(t *testing.T) { defer resp.Body.Close() //nolint:errcheck require.Equal(t, http.StatusOK, resp.StatusCode) - // Audit event. events := s.auditor.getRequests() require.Len(t, events, 1) require.True(t, events[0].Allowed) @@ -148,7 +152,6 @@ func TestIntegration_LLMRequestAuditAndHeadersAgree(t *testing.T) { assert.Equal(t, sessionID, header.Get(config.SessionIDHeaderName)) assert.Equal(t, "0", header.Get(config.SequenceNumberHeaderName)) - // The two must agree. assert.Equal(t, strconv.Itoa(int(events[0].SequenceNumber)), header.Get(config.SequenceNumberHeaderName), From b84170e622c96526d8f0d0bd7c463ed01e049412 Mon Sep 17 00:00:00 2001 From: Sas Swart Date: Wed, 13 May 2026 12:42:13 +0000 Subject: [PATCH 5/8] refactor(proxy): rename setup struct, fix disabled-correlation test - Rename sessionCorrelationIntegrationSetup to correlationTestEnv and newSessionCorrelationIntegrationSetup to newCorrelationTestEnv for brevity and clarity. - Rename TestIntegration_DisabledCorrelationNoHeadersNoPreallocatedSequence to TestIntegration_DisabledCorrelationNoHeaders. The sequence counter is a value type on the proxy server and always increments regardless of the correlation setting, so the previous name and assertions about 'no pre-allocated sequence number' were misleading. The test now focuses on what actually differs: no headers are injected. - Remove misleading 'auditor falls back to its own counter' comment. Generated by Coder Agents --- ...xy_session_correlation_integration_test.go | 66 +++++++++---------- 1 file changed, 33 insertions(+), 33 deletions(-) diff --git a/proxy/proxy_session_correlation_integration_test.go b/proxy/proxy_session_correlation_integration_test.go index 174e012..17b650c 100644 --- a/proxy/proxy_session_correlation_integration_test.go +++ b/proxy/proxy_session_correlation_integration_test.go @@ -50,24 +50,24 @@ func (m *multiRequestCapturingBackend) headersAt(i int) http.Header { return m.all[i].Clone() } -// sessionCorrelationIntegrationSetup holds the shared objects for an +// correlationTestEnv holds the shared objects for a session-correlation // integration test: the proxy, auditor, backend(s), and sequence -// counter. Tests build one via newSessionCorrelationIntegrationSetup -// and tear it down with stop. -type sessionCorrelationIntegrationSetup struct { +// counter. Tests build one via newCorrelationTestEnv and tear it down +// with stop. +type correlationTestEnv struct { pt *ProxyTest auditor *capturingAuditor seq *audit.SequenceCounter - // llmBackend expects headers to be injected as these requests are - // expected to be seen by the AI Gateway and then correlated back - // to the audit event + // injectBackend expects headers to be injected as these requests + // are expected to be seen by the AI Gateway and then correlated + // back to the audit event. injectBackend *multiRequestCapturingBackend // otherBackend does not expect headers to be injected as these // requests should not be routed through the AI Gateway. otherBackend *multiRequestCapturingBackend } -func (s *sessionCorrelationIntegrationSetup) stop() { +func (s *correlationTestEnv) stop() { s.pt.Stop() if s.injectBackend != nil { s.injectBackend.close() @@ -77,12 +77,12 @@ func (s *sessionCorrelationIntegrationSetup) stop() { } } -// newSessionCorrelationIntegrationSetup builds a proxy that allows -// traffic to two httptest backends: one that matches an inject target -// and one that does not (simulating a generic allowed domain like -// github.com). Both backends capture all received request headers. -// A capturingAuditor records every audit event for later inspection. -func newSessionCorrelationIntegrationSetup(t *testing.T, sessionID string) *sessionCorrelationIntegrationSetup { +// newCorrelationTestEnv builds a proxy that allows traffic to two +// httptest backends: one that matches an inject target and one that +// does not (simulating a generic allowed domain like github.com). +// Both backends capture all received request headers. A +// capturingAuditor records every audit event for later inspection. +func newCorrelationTestEnv(t *testing.T, sessionID string) *correlationTestEnv { t.Helper() inject := newMultiRequestCapturingBackend() @@ -117,7 +117,7 @@ func newSessionCorrelationIntegrationSetup(t *testing.T, sessionID string) *sess WithAuditor(aud), ).Start() - return &sessionCorrelationIntegrationSetup{ + return &correlationTestEnv{ pt: pt, auditor: aud, seq: seq, @@ -132,7 +132,7 @@ func newSessionCorrelationIntegrationSetup(t *testing.T, sessionID string) *sess // the forwarded header. func TestIntegration_LLMRequestAuditAndHeadersAgree(t *testing.T) { const sessionID = "e5f6a7b8-0000-0000-0000-000000000000" - s := newSessionCorrelationIntegrationSetup(t, sessionID) + s := newCorrelationTestEnv(t, sessionID) defer s.stop() resp, err := s.pt.proxyClient.Get(s.injectBackend.server.URL + "/v1/messages") @@ -164,7 +164,7 @@ func TestIntegration_LLMRequestAuditAndHeadersAgree(t *testing.T) { // audited (with a sequence number) but does NOT receive correlation // headers. func TestIntegration_NonLLMRequestAuditedWithoutHeaders(t *testing.T) { - s := newSessionCorrelationIntegrationSetup(t, "test-session") + s := newCorrelationTestEnv(t, "test-session") defer s.stop() resp, err := s.pt.proxyClient.Get(s.otherBackend.server.URL + "/pulls") @@ -426,7 +426,7 @@ func TestIntegration_SequenceGapRevealsAgenticLoop(t *testing.T) { // sequence number, and the audit event still agrees with the header. func TestIntegration_SpoofedHeadersOverwrittenWithCorrectSequence(t *testing.T) { const sessionID = "real-session-uuid" - s := newSessionCorrelationIntegrationSetup(t, sessionID) + s := newCorrelationTestEnv(t, sessionID) defer s.stop() req, err := http.NewRequest(http.MethodPost, s.injectBackend.server.URL+"/v1/messages", nil) @@ -455,11 +455,13 @@ func TestIntegration_SpoofedHeadersOverwrittenWithCorrectSequence(t *testing.T) ) } -// TestIntegration_DisabledCorrelationNoHeadersNoPreallocatedSequence -// verifies that when session correlation is disabled, the proxy does -// not inject headers and does not pre-allocate sequence numbers (the -// auditor falls back to its own counter instead). -func TestIntegration_DisabledCorrelationNoHeadersNoPreallocatedSequence(t *testing.T) { +// TestIntegration_DisabledCorrelationNoHeaders verifies that when +// session correlation is disabled, the proxy does not inject +// correlation headers even for requests that match an inject target. +// Note: the sequence counter is a value type on the proxy server and +// always increments regardless of the correlation setting, so we only +// assert on the absence of headers here. +func TestIntegration_DisabledCorrelationNoHeaders(t *testing.T) { backend := newMultiRequestCapturingBackend() defer backend.close() @@ -471,13 +473,11 @@ func TestIntegration_DisabledCorrelationNoHeadersNoPreallocatedSequence(t *testi pt := NewProxyTest(t, WithCertManager(t.TempDir()), WithAllowedDomain(backendURL.Hostname()), - // Correlation disabled; no sequence counter. WithSessionCorrelation(config.SessionCorrelationConfig{ Enabled: false, InjectTargets: []config.InjectTarget{{Domain: backendURL.Hostname()}}, }), WithSessionID("should-not-appear"), - // Explicitly do NOT set WithSequenceCounter; seqCounter is nil. WithAuditor(aud), ).Start() defer pt.Stop() @@ -487,18 +487,18 @@ func TestIntegration_DisabledCorrelationNoHeadersNoPreallocatedSequence(t *testi defer resp.Body.Close() //nolint:errcheck require.Equal(t, http.StatusOK, resp.StatusCode) - // No correlation headers. + // No correlation headers injected. require.Equal(t, 1, backend.requestCount()) header := backend.headersAt(0) - assert.Empty(t, header.Get(config.SessionIDHeaderName)) - assert.Empty(t, header.Get(config.SequenceNumberHeaderName)) + assert.Empty(t, header.Get(config.SessionIDHeaderName), + "session ID header must not be injected when correlation is disabled") + assert.Empty(t, header.Get(config.SequenceNumberHeaderName), + "sequence number header must not be injected when correlation is disabled") - // Audit event recorded but without a pre-allocated sequence - // number (nil), because no SequenceCounter was provided. + // Request is still audited. events := aud.getRequests() require.Len(t, events, 1) - assert.Equal(t, int32(0), events[0].SequenceNumber, - "no sequence counter means no pre-allocated sequence number") + require.True(t, events[0].Allowed) } // TestIntegration_ConcurrentRequestsUniqueSequenceNumbers sends @@ -509,7 +509,7 @@ func TestIntegration_ConcurrentRequestsUniqueSequenceNumbers(t *testing.T) { const sessionID = "concurrent-session" const numRequests = 10 - s := newSessionCorrelationIntegrationSetup(t, sessionID) + s := newCorrelationTestEnv(t, sessionID) defer s.stop() var wg sync.WaitGroup From 49321bddcea1557c9105edb7885f0f90cd122dd8 Mon Sep 17 00:00:00 2001 From: Sas Swart Date: Wed, 13 May 2026 13:28:15 +0000 Subject: [PATCH 6/8] fix(proxy): sync with main branch InjectTarget changes PR #201 was merged to main, replacing config.InjectTarget struct with []string rule specs (rulesengine syntax). Update the integration test file to use the []string format, and sync config/, proxy/proxy.go, and other test files from main to fix the build. Generated by Coder Agents --- ...xy_session_correlation_integration_test.go | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/proxy/proxy_session_correlation_integration_test.go b/proxy/proxy_session_correlation_integration_test.go index 17b650c..5baef04 100644 --- a/proxy/proxy_session_correlation_integration_test.go +++ b/proxy/proxy_session_correlation_integration_test.go @@ -108,10 +108,7 @@ func newCorrelationTestEnv(t *testing.T, sessionID string) *correlationTestEnv { // Only requests matching the inject-target path receive headers. WithSessionCorrelation(config.SessionCorrelationConfig{ Enabled: true, - InjectTargets: []config.InjectTarget{{ - Domain: injectURL.Hostname(), - Path: "/v1/*", - }}, + InjectTargets: []string{"domain=" + injectURL.Hostname() + " path=/v1/*"}, }), WithSessionID(sessionID), WithAuditor(aud), @@ -204,7 +201,7 @@ func TestIntegration_DeniedRequestAuditedNeverForwarded(t *testing.T) { // No allowed domains: deny everything. WithSessionCorrelation(config.SessionCorrelationConfig{ Enabled: true, - InjectTargets: []config.InjectTarget{{Domain: "anything.example.com"}}, + InjectTargets: []string{"domain=anything.example.com"}, }), WithSessionID("test-session"), WithAuditor(aud), @@ -260,10 +257,7 @@ func TestIntegration_MixedRequestsSequenceOrdering(t *testing.T) { // Only the inject backend is an inject target. WithSessionCorrelation(config.SessionCorrelationConfig{ Enabled: true, - InjectTargets: []config.InjectTarget{{ - Domain: injectURL.Hostname(), - Path: "/v1/*", - }}, + InjectTargets: []string{"domain=" + injectURL.Hostname() + " path=/v1/*"}, }), WithSessionID(sessionID), WithAuditor(aud), @@ -369,10 +363,7 @@ func TestIntegration_SequenceGapRevealsAgenticLoop(t *testing.T) { WithAllowedDomain(otherURL.Hostname()), WithSessionCorrelation(config.SessionCorrelationConfig{ Enabled: true, - InjectTargets: []config.InjectTarget{{ - Domain: injectURL.Hostname(), - Path: "/v1/*", - }}, + InjectTargets: []string{"domain=" + injectURL.Hostname() + " path=/v1/*"}, }), WithSessionID(sessionID), WithAuditor(aud), @@ -475,7 +466,7 @@ func TestIntegration_DisabledCorrelationNoHeaders(t *testing.T) { WithAllowedDomain(backendURL.Hostname()), WithSessionCorrelation(config.SessionCorrelationConfig{ Enabled: false, - InjectTargets: []config.InjectTarget{{Domain: backendURL.Hostname()}}, + InjectTargets: []string{"domain=" + backendURL.Hostname()}, }), WithSessionID("should-not-appear"), WithAuditor(aud), From f578460773fd4f6ed1453e43822fb0372ed10b50 Mon Sep 17 00:00:00 2001 From: Sas Swart Date: Wed, 13 May 2026 13:31:34 +0000 Subject: [PATCH 7/8] make fmt --- proxy/proxy_session_correlation_integration_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/proxy/proxy_session_correlation_integration_test.go b/proxy/proxy_session_correlation_integration_test.go index 5baef04..3f36bb1 100644 --- a/proxy/proxy_session_correlation_integration_test.go +++ b/proxy/proxy_session_correlation_integration_test.go @@ -107,7 +107,7 @@ func newCorrelationTestEnv(t *testing.T, sessionID string) *correlationTestEnv { WithAllowedDomain(otherURL.Hostname()), // Only requests matching the inject-target path receive headers. WithSessionCorrelation(config.SessionCorrelationConfig{ - Enabled: true, + Enabled: true, InjectTargets: []string{"domain=" + injectURL.Hostname() + " path=/v1/*"}, }), WithSessionID(sessionID), @@ -256,7 +256,7 @@ func TestIntegration_MixedRequestsSequenceOrdering(t *testing.T) { WithAllowedDomain(otherURL.Hostname()), // Only the inject backend is an inject target. WithSessionCorrelation(config.SessionCorrelationConfig{ - Enabled: true, + Enabled: true, InjectTargets: []string{"domain=" + injectURL.Hostname() + " path=/v1/*"}, }), WithSessionID(sessionID), @@ -362,7 +362,7 @@ func TestIntegration_SequenceGapRevealsAgenticLoop(t *testing.T) { WithAllowedDomain(injectURL.Hostname()), WithAllowedDomain(otherURL.Hostname()), WithSessionCorrelation(config.SessionCorrelationConfig{ - Enabled: true, + Enabled: true, InjectTargets: []string{"domain=" + injectURL.Hostname() + " path=/v1/*"}, }), WithSessionID(sessionID), From a9d59ec67b877527bad759a50722e82a6af99d11 Mon Sep 17 00:00:00 2001 From: Sas Swart Date: Thu, 14 May 2026 17:24:27 +0000 Subject: [PATCH 8/8] test(proxy): use Given/When/Then style comments in integration tests --- ...xy_session_correlation_integration_test.go | 81 ++++++++++--------- 1 file changed, 44 insertions(+), 37 deletions(-) diff --git a/proxy/proxy_session_correlation_integration_test.go b/proxy/proxy_session_correlation_integration_test.go index 3f36bb1..c21908c 100644 --- a/proxy/proxy_session_correlation_integration_test.go +++ b/proxy/proxy_session_correlation_integration_test.go @@ -128,27 +128,31 @@ func newCorrelationTestEnv(t *testing.T, sessionID string) *correlationTestEnv { // the sequence number in the audit event equals the sequence number in // the forwarded header. func TestIntegration_LLMRequestAuditAndHeadersAgree(t *testing.T) { + // Given: a proxy with session correlation enabled and an inject-target backend. const sessionID = "e5f6a7b8-0000-0000-0000-000000000000" s := newCorrelationTestEnv(t, sessionID) defer s.stop() + // When: a single request is sent to the inject-target backend. resp, err := s.pt.proxyClient.Get(s.injectBackend.server.URL + "/v1/messages") require.NoError(t, err) defer resp.Body.Close() //nolint:errcheck require.Equal(t, http.StatusOK, resp.StatusCode) + // Then: the audit event records the correct sequence number. events := s.auditor.getRequests() require.Len(t, events, 1) require.True(t, events[0].Allowed) require.NotNil(t, events[0].SequenceNumber) assert.Equal(t, int32(0), events[0].SequenceNumber) - // Forwarded headers. + // Then: the forwarded request carries the session ID and sequence number headers. require.Equal(t, 1, s.injectBackend.requestCount()) header := s.injectBackend.headersAt(0) assert.Equal(t, sessionID, header.Get(config.SessionIDHeaderName)) assert.Equal(t, "0", header.Get(config.SequenceNumberHeaderName)) + // Then: the audit event and forwarded header agree on the sequence number. assert.Equal(t, strconv.Itoa(int(events[0].SequenceNumber)), header.Get(config.SequenceNumberHeaderName), @@ -161,22 +165,24 @@ func TestIntegration_LLMRequestAuditAndHeadersAgree(t *testing.T) { // audited (with a sequence number) but does NOT receive correlation // headers. func TestIntegration_NonLLMRequestAuditedWithoutHeaders(t *testing.T) { + // Given: a proxy with session correlation enabled and a non-inject-target backend. s := newCorrelationTestEnv(t, "test-session") defer s.stop() + // When: a request is sent to the non-inject-target backend. resp, err := s.pt.proxyClient.Get(s.otherBackend.server.URL + "/pulls") require.NoError(t, err) defer resp.Body.Close() //nolint:errcheck require.Equal(t, http.StatusOK, resp.StatusCode) - // Audit event recorded. + // Then: an audit event is recorded with a sequence number. events := s.auditor.getRequests() require.Len(t, events, 1) require.True(t, events[0].Allowed) require.NotNil(t, events[0].SequenceNumber) assert.Equal(t, int32(0), events[0].SequenceNumber) - // No correlation headers on the backend. + // Then: no correlation headers are present on the forwarded request. require.Equal(t, 1, s.otherBackend.requestCount()) header := s.otherBackend.headersAt(0) assert.Empty(t, header.Get(config.SessionIDHeaderName), @@ -189,8 +195,7 @@ func TestIntegration_NonLLMRequestAuditedWithoutHeaders(t *testing.T) { // request denied by the rules engine is audited (consuming a sequence // number) but is never forwarded to any backend. func TestIntegration_DeniedRequestAuditedNeverForwarded(t *testing.T) { - // Create a setup with a custom deny-all proxy, but keep the same - // pattern of shared sequence counter and auditor. + // Given: a proxy with no allowed domains (deny-all configuration). backend := newMultiRequestCapturingBackend() defer backend.close() @@ -198,7 +203,6 @@ func TestIntegration_DeniedRequestAuditedNeverForwarded(t *testing.T) { pt := NewProxyTest(t, WithCertManager(t.TempDir()), - // No allowed domains: deny everything. WithSessionCorrelation(config.SessionCorrelationConfig{ Enabled: true, InjectTargets: []string{"domain=anything.example.com"}, @@ -208,19 +212,20 @@ func TestIntegration_DeniedRequestAuditedNeverForwarded(t *testing.T) { ).Start() defer pt.Stop() + // When: a request is sent to a domain that is not allowed. resp, err := pt.proxyClient.Get(backend.server.URL + "/exfil") require.NoError(t, err) defer resp.Body.Close() //nolint:errcheck require.Equal(t, http.StatusForbidden, resp.StatusCode) - // Audit event recorded. + // Then: an audit event is recorded with the denied flag and a sequence number. events := aud.getRequests() require.Len(t, events, 1) require.False(t, events[0].Allowed) require.NotNil(t, events[0].SequenceNumber) assert.Equal(t, int32(0), events[0].SequenceNumber) - // Backend never hit. + // Then: the backend never receives the request. assert.Equal(t, 0, backend.requestCount(), "denied requests must not be forwarded to the backend") } @@ -235,7 +240,7 @@ func TestIntegration_DeniedRequestAuditedNeverForwarded(t *testing.T) { func TestIntegration_MixedRequestsSequenceOrdering(t *testing.T) { const sessionID = "mixed-test-session" - // Two allowed backends (inject target and "github"), one denied domain. + // Given: a proxy with an inject-target backend and a non-inject-target backend. inject := newMultiRequestCapturingBackend() defer inject.close() @@ -254,7 +259,6 @@ func TestIntegration_MixedRequestsSequenceOrdering(t *testing.T) { WithCertManager(t.TempDir()), WithAllowedDomain(injectURL.Hostname()), WithAllowedDomain(otherURL.Hostname()), - // Only the inject backend is an inject target. WithSessionCorrelation(config.SessionCorrelationConfig{ Enabled: true, InjectTargets: []string{"domain=" + injectURL.Hostname() + " path=/v1/*"}, @@ -264,31 +268,30 @@ func TestIntegration_MixedRequestsSequenceOrdering(t *testing.T) { ).Start() defer pt.Stop() - // Request 0: inject target (allowed, headers injected). + // When: an inject-target, non-inject-target, denied, and inject-target + // request are sent in sequence. resp, err := pt.proxyClient.Get(inject.server.URL + "/v1/messages") require.NoError(t, err) resp.Body.Close() //nolint:errcheck require.Equal(t, http.StatusOK, resp.StatusCode) - // Request 1: non-inject-target (allowed, no headers). resp, err = pt.proxyClient.Get(other.server.URL + "/coder/coder") require.NoError(t, err) resp.Body.Close() //nolint:errcheck require.Equal(t, http.StatusOK, resp.StatusCode) - // Request 2: denied (nothing is allowed for evil.example.com). resp, err = pt.proxyClient.Get("http://evil.example.com/exfil") require.NoError(t, err) resp.Body.Close() //nolint:errcheck require.Equal(t, http.StatusForbidden, resp.StatusCode) - // Request 3: inject target again. resp, err = pt.proxyClient.Get(inject.server.URL + "/v1/messages") require.NoError(t, err) resp.Body.Close() //nolint:errcheck require.Equal(t, http.StatusOK, resp.StatusCode) - // -- Verify audit events -- + // Then: all four requests produce audit events with monotonically + // increasing sequence numbers. events := aud.getRequests() require.Len(t, events, 4, "expected exactly four audit events") @@ -302,7 +305,8 @@ func TestIntegration_MixedRequestsSequenceOrdering(t *testing.T) { "event %d: wrong allowed flag", i) } - // -- Verify inject-target backend headers -- + // Then: the inject-target backend receives correlation headers with + // the correct sequence numbers. require.Equal(t, 2, inject.requestCount(), "inject-target backend should have received exactly two requests") @@ -316,17 +320,14 @@ func TestIntegration_MixedRequestsSequenceOrdering(t *testing.T) { assert.Equal(t, "3", secondInjectHeader.Get(config.SequenceNumberHeaderName), "second inject-target request must have sequence 3") - // -- Verify non-inject-target backend has no correlation headers -- + // Then: the non-inject-target backend receives no correlation headers. require.Equal(t, 1, other.requestCount()) otherHeader := other.headersAt(0) assert.Empty(t, otherHeader.Get(config.SessionIDHeaderName)) assert.Empty(t, otherHeader.Get(config.SequenceNumberHeaderName)) - // -- Verify the gap reveals intermediate activity -- - // The gap between the two inject-target sequence numbers (0 and 3) - // means that sequence numbers 1 and 2 were consumed by - // non-inject-target activity, matching audit events 1 - // (non-inject-target allowed) and 2 (denied). + // Then: the gap between inject-target sequence numbers (0 and 3) + // reveals 2 intermediate events (non-inject-target allowed and denied). firstInjectSeq := events[0].SequenceNumber secondInjectSeq := events[3].SequenceNumber gap := secondInjectSeq - firstInjectSeq - 1 @@ -343,6 +344,7 @@ func TestIntegration_MixedRequestsSequenceOrdering(t *testing.T) { func TestIntegration_SequenceGapRevealsAgenticLoop(t *testing.T) { const sessionID = "agentic-loop-session" + // Given: a proxy with an inject-target and a non-inject-target backend. inject := newMultiRequestCapturingBackend() defer inject.close() @@ -370,30 +372,29 @@ func TestIntegration_SequenceGapRevealsAgenticLoop(t *testing.T) { ).Start() defer pt.Stop() - // First inject-target request (seq 0). + // When: an inject-target request, three tool-use requests to the + // non-inject-target backend, and another inject-target request are + // sent in sequence. resp, err := pt.proxyClient.Get(inject.server.URL + "/v1/messages") require.NoError(t, err) resp.Body.Close() //nolint:errcheck - // Agentic loop: three tool-use HTTP calls. for _, p := range []string{"/coder/coder", "/coder/coder/issues", "/coder/coder/pulls"} { resp, err = pt.proxyClient.Get(other.server.URL + p) require.NoError(t, err) resp.Body.Close() //nolint:errcheck } - // Second inject-target request (seq 4). resp, err = pt.proxyClient.Get(inject.server.URL + "/v1/messages") require.NoError(t, err) resp.Body.Close() //nolint:errcheck - // Verify inject-target sequence headers. + // Then: the inject-target headers show a gap from sequence 0 to 4. require.Equal(t, 2, inject.requestCount()) assert.Equal(t, "0", inject.headersAt(0).Get(config.SequenceNumberHeaderName)) assert.Equal(t, "4", inject.headersAt(1).Get(config.SequenceNumberHeaderName)) - // The gap between sequence numbers 0 and 4 is 3, matching the - // three tool-use requests in between. + // Then: the gap of 3 matches the three intermediate tool-use requests. events := aud.getRequests() require.Len(t, events, 5) @@ -403,7 +404,7 @@ func TestIntegration_SequenceGapRevealsAgenticLoop(t *testing.T) { assert.Equal(t, int32(3), gap, "gap between prompts should equal number of tool-use requests") - // Verify the intermediate events are the tool-use requests. + // Then: the intermediate audit events correspond to the tool-use requests. for i := 1; i <= 3; i++ { require.NotNil(t, events[i].SequenceNumber) assert.Equal(t, int32(i), events[i].SequenceNumber) @@ -416,10 +417,12 @@ func TestIntegration_SequenceGapRevealsAgenticLoop(t *testing.T) { // the proxy replaces them with the real session ID and the real // sequence number, and the audit event still agrees with the header. func TestIntegration_SpoofedHeadersOverwrittenWithCorrectSequence(t *testing.T) { + // Given: a proxy with session correlation enabled. const sessionID = "real-session-uuid" s := newCorrelationTestEnv(t, sessionID) defer s.stop() + // When: a request is sent with spoofed correlation headers. req, err := http.NewRequest(http.MethodPost, s.injectBackend.server.URL+"/v1/messages", nil) require.NoError(t, err) req.Header.Set(config.SessionIDHeaderName, "spoofed-session") @@ -430,13 +433,13 @@ func TestIntegration_SpoofedHeadersOverwrittenWithCorrectSequence(t *testing.T) defer resp.Body.Close() //nolint:errcheck require.Equal(t, http.StatusOK, resp.StatusCode) - // Backend received real values, not spoofed. + // Then: the backend receives the real values, not the spoofed ones. require.Equal(t, 1, s.injectBackend.requestCount()) header := s.injectBackend.headersAt(0) assert.Equal(t, sessionID, header.Get(config.SessionIDHeaderName)) assert.Equal(t, "0", header.Get(config.SequenceNumberHeaderName)) - // Audit event agrees with header. + // Then: the audit event agrees with the forwarded header. events := s.auditor.getRequests() require.Len(t, events, 1) require.NotNil(t, events[0].SequenceNumber) @@ -453,6 +456,7 @@ func TestIntegration_SpoofedHeadersOverwrittenWithCorrectSequence(t *testing.T) // always increments regardless of the correlation setting, so we only // assert on the absence of headers here. func TestIntegration_DisabledCorrelationNoHeaders(t *testing.T) { + // Given: a proxy with session correlation disabled. backend := newMultiRequestCapturingBackend() defer backend.close() @@ -473,12 +477,13 @@ func TestIntegration_DisabledCorrelationNoHeaders(t *testing.T) { ).Start() defer pt.Stop() + // When: a request is sent that would match an inject target. resp, err := pt.proxyClient.Get(backend.server.URL + "/v1/messages") require.NoError(t, err) defer resp.Body.Close() //nolint:errcheck require.Equal(t, http.StatusOK, resp.StatusCode) - // No correlation headers injected. + // Then: no correlation headers are injected on the forwarded request. require.Equal(t, 1, backend.requestCount()) header := backend.headersAt(0) assert.Empty(t, header.Get(config.SessionIDHeaderName), @@ -486,7 +491,7 @@ func TestIntegration_DisabledCorrelationNoHeaders(t *testing.T) { assert.Empty(t, header.Get(config.SequenceNumberHeaderName), "sequence number header must not be injected when correlation is disabled") - // Request is still audited. + // Then: the request is still audited. events := aud.getRequests() require.Len(t, events, 1) require.True(t, events[0].Allowed) @@ -500,9 +505,11 @@ func TestIntegration_ConcurrentRequestsUniqueSequenceNumbers(t *testing.T) { const sessionID = "concurrent-session" const numRequests = 10 + // Given: a proxy with session correlation enabled. s := newCorrelationTestEnv(t, sessionID) defer s.stop() + // When: multiple requests are sent concurrently to the inject-target backend. var wg sync.WaitGroup for i := 0; i < numRequests; i++ { wg.Add(1) @@ -517,11 +524,11 @@ func TestIntegration_ConcurrentRequestsUniqueSequenceNumbers(t *testing.T) { } wg.Wait() - // Every request should have been audited. + // Then: every request is audited. events := s.auditor.getRequests() require.Len(t, events, numRequests) - // Collect all sequence numbers and verify uniqueness. + // Then: each audit event has a unique sequence number. seen := make(map[int32]bool, numRequests) for i, ev := range events { require.NotNil(t, ev.SequenceNumber, @@ -531,13 +538,13 @@ func TestIntegration_ConcurrentRequestsUniqueSequenceNumbers(t *testing.T) { seen[ev.SequenceNumber] = true } - // The set should be exactly {0, 1, ..., numRequests-1}. + // Then: the sequence numbers form a dense set {0, 1, ..., numRequests-1}. for i := int32(0); i < numRequests; i++ { assert.True(t, seen[i], "sequence number %d is missing from the set", i) } - // Every header should also carry a matching sequence number. + // Then: every forwarded request header carries a matching sequence number. require.Equal(t, numRequests, s.injectBackend.requestCount()) headerSeqs := make(map[string]bool, numRequests) for i := 0; i < numRequests; i++ {