From f50def8178b69a0a8c005a3e58a6d58da5dd42ac Mon Sep 17 00:00:00 2001 From: Sas Swart Date: Thu, 30 Apr 2026 12:52:47 +0000 Subject: [PATCH 01/10] feat(proxy): inject session ID and sequence number headers on matching requests When session correlation is enabled and the outgoing request matches a configured inject target, set X-Coder-Agent-Firewall-Session-Id and X-Coder-Agent-Firewall-Sequence-Number headers on the forwarded request. Any values the jailed client may have set are overwritten so the upstream always sees boundary's authoritative session ID and sequence number. The sequence number is pre-allocated before the audit event so both the audit log and the injected header carry the same value. audit.Request gains a SequenceNumber pointer field; when non-nil the socket auditor uses it instead of calling its own counter. New proxy.Config fields: SessionCorrelation, SessionID, SequenceCounter. New Server method: shouldInjectHeaders (domain + optional path glob matching). Tests cover matched domain, unmatched domain, disabled injection, client-supplied header overwrite, path glob matching, and sequence number incrementing. --- audit/request.go | 7 + audit/socket_auditor.go | 9 +- proxy/proxy.go | 102 +++++++-- proxy/proxy_framework_test.go | 59 +++-- proxy/proxy_session_correlation_test.go | 273 ++++++++++++++++++++++++ 5 files changed, 414 insertions(+), 36 deletions(-) create mode 100644 proxy/proxy_session_correlation_test.go diff --git a/audit/request.go b/audit/request.go index c6ef1b37..e23f774e 100644 --- a/audit/request.go +++ b/audit/request.go @@ -11,4 +11,11 @@ type Request struct { Host string Allowed bool Rule string // The rule that matched (if any) + + // SequenceNumber is a pre-allocated sequence number for this + // audit event. When non-nil the auditor must use this value + // instead of generating its own so that the audit log and + // any injected HTTP header carry the same number. When nil + // the auditor falls back to its internal SequenceCounter. + SequenceNumber *uint64 } diff --git a/audit/socket_auditor.go b/audit/socket_auditor.go index afb50a34..06fce950 100644 --- a/audit/socket_auditor.go +++ b/audit/socket_auditor.go @@ -81,10 +81,17 @@ func (s *SocketAuditor) AuditRequest(req Request) { httpReq.MatchedRule = req.Rule } + var seqNum uint64 + if req.SequenceNumber != nil { + seqNum = *req.SequenceNumber + } else { + seqNum = s.seq.Next() + } + log := &agentproto.BoundaryLog{ Allowed: req.Allowed, Time: timestamppb.Now(), - SequenceNumber: s.seq.Next(), + SequenceNumber: seqNum, Resource: &agentproto.BoundaryLog_HttpRequest_{HttpRequest: httpReq}, } diff --git a/proxy/proxy.go b/proxy/proxy.go index 154a6bcc..0fb5dfa7 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -13,23 +13,28 @@ import ( "net/http" _ "net/http/pprof" "net/url" + "path" "strconv" "strings" "sync/atomic" "time" "github.com/coder/boundary/audit" + "github.com/coder/boundary/config" "github.com/coder/boundary/rulesengine" ) // Server handles HTTP and HTTPS requests with rule-based filtering type Server struct { - ruleEngine rulesengine.Engine - auditor audit.Auditor - logger *slog.Logger - tlsConfig *tls.Config - httpPort int - started atomic.Bool + ruleEngine rulesengine.Engine + auditor audit.Auditor + logger *slog.Logger + tlsConfig *tls.Config + httpPort int + started atomic.Bool + sessionCorrelation config.SessionCorrelationConfig + sessionID string + seqCounter *audit.SequenceCounter listener net.Listener pprofServer *http.Server @@ -46,18 +51,30 @@ type Config struct { TLSConfig *tls.Config PprofEnabled bool PprofPort int + // SessionCorrelation controls header injection for AI Bridge + // correlation. See config.SessionCorrelationConfig for details. + SessionCorrelation config.SessionCorrelationConfig + // SessionID is the boundary session UUID injected as a header + // on matching requests. + SessionID string + // SequenceCounter provides monotonically increasing sequence + // numbers shared with the auditor so both carry the same value. + SequenceCounter *audit.SequenceCounter } // NewProxyServer creates a new proxy server instance func NewProxyServer(config Config) *Server { return &Server{ - ruleEngine: config.RuleEngine, - auditor: config.Auditor, - logger: config.Logger, - tlsConfig: config.TLSConfig, - httpPort: config.HTTPPort, - pprofEnabled: config.PprofEnabled, - pprofPort: config.PprofPort, + ruleEngine: config.RuleEngine, + auditor: config.Auditor, + logger: config.Logger, + tlsConfig: config.TLSConfig, + httpPort: config.HTTPPort, + pprofEnabled: config.PprofEnabled, + pprofPort: config.PprofPort, + sessionCorrelation: config.SessionCorrelation, + sessionID: config.SessionID, + seqCounter: config.SequenceCounter, } } @@ -276,12 +293,21 @@ func (p *Server) processHTTPRequest(conn net.Conn, req *http.Request, https bool result := p.ruleEngine.Evaluate(req.Method, fullURL) + // Pre-allocate a sequence number so the audit event and any + // injected header carry the same value. + var seqNum *uint64 + if p.seqCounter != nil { + n := p.seqCounter.Next() + seqNum = &n + } + p.auditor.AuditRequest(audit.Request{ - Method: req.Method, - URL: fullURL, - Host: req.Host, - Allowed: result.Allowed, - Rule: result.Rule, + Method: req.Method, + URL: fullURL, + Host: req.Host, + Allowed: result.Allowed, + Rule: result.Rule, + SequenceNumber: seqNum, }) if !result.Allowed { @@ -290,10 +316,36 @@ func (p *Server) processHTTPRequest(conn net.Conn, req *http.Request, https bool } // Forward request to destination - p.forwardRequest(conn, req, https) + p.forwardRequest(conn, req, https, seqNum) +} + +// shouldInjectHeaders reports whether the request to the given host +// and path matches any configured inject target. When session +// correlation is disabled or no targets match it returns false. +func (p *Server) shouldInjectHeaders(host, reqPath string) bool { + if !p.sessionCorrelation.Enabled { + return false + } + // Strip port from host for matching (e.g. "example.com:443" -> "example.com"). + h := host + if i := strings.LastIndex(h, ":"); i != -1 { + h = h[:i] + } + for _, target := range p.sessionCorrelation.InjectTargets { + if !strings.EqualFold(target.Domain, h) { + continue + } + if target.Path == "" { + return true + } + if matched, _ := path.Match(target.Path, reqPath); matched { + return true + } + } + return false } -func (p *Server) forwardRequest(conn net.Conn, req *http.Request, https bool) { +func (p *Server) forwardRequest(conn net.Conn, req *http.Request, https bool, seqNum *uint64) { // Create HTTP client client := &http.Client{ CheckRedirect: func(req *http.Request, via []*http.Request) error { @@ -338,6 +390,16 @@ func (p *Server) forwardRequest(conn net.Conn, req *http.Request, https bool) { } } + // Stamp session correlation headers on matching requests, + // overwriting any value the jailed client may have set so the + // upstream always sees boundary's ID. + if p.shouldInjectHeaders(req.Host, req.URL.Path) { + newReq.Header.Set(p.sessionCorrelation.SessionIDHeaderName, p.sessionID) + if seqNum != nil { + newReq.Header.Set(p.sessionCorrelation.SequenceNumberHeaderName, strconv.FormatUint(*seqNum, 10)) + } + } + // Make request to destination resp, err := client.Do(newReq) if err != nil { diff --git a/proxy/proxy_framework_test.go b/proxy/proxy_framework_test.go index 36a332bc..97c9975f 100644 --- a/proxy/proxy_framework_test.go +++ b/proxy/proxy_framework_test.go @@ -18,6 +18,7 @@ import ( "time" "github.com/coder/boundary/audit" + "github.com/coder/boundary/config" "github.com/coder/boundary/rulesengine" boundary_tls "github.com/coder/boundary/tls" "github.com/stretchr/testify/require" @@ -32,16 +33,19 @@ func (m *mockAuditor) AuditRequest(req audit.Request) { // ProxyTest is a high-level test framework for proxy tests type ProxyTest struct { - t *testing.T - server *Server - client *http.Client - proxyClient *http.Client - port int - useCertManager bool - configDir string - startupDelay time.Duration - allowedRules []string - auditor audit.Auditor + t *testing.T + server *Server + client *http.Client + proxyClient *http.Client + port int + useCertManager bool + configDir string + startupDelay time.Duration + allowedRules []string + auditor audit.Auditor + sessionCorrelation config.SessionCorrelationConfig + sessionID string + seqCounter *audit.SequenceCounter } // ProxyTestOption is a function that configures ProxyTest @@ -109,6 +113,28 @@ func WithAuditor(auditor audit.Auditor) ProxyTestOption { } } +// WithSessionCorrelation sets the session correlation config for the +// proxy under test. +func WithSessionCorrelation(sc config.SessionCorrelationConfig) ProxyTestOption { + return func(pt *ProxyTest) { + pt.sessionCorrelation = sc + } +} + +// WithSessionID sets the boundary session ID for the proxy under test. +func WithSessionID(id string) ProxyTestOption { + return func(pt *ProxyTest) { + pt.sessionID = id + } +} + +// WithSequenceCounter sets the sequence counter for the proxy under test. +func WithSequenceCounter(seq *audit.SequenceCounter) ProxyTestOption { + return func(pt *ProxyTest) { + pt.seqCounter = seq + } +} + // Start starts the proxy server func (pt *ProxyTest) Start() *ProxyTest { pt.t.Helper() @@ -153,11 +179,14 @@ func (pt *ProxyTest) Start() *ProxyTest { } pt.server = NewProxyServer(Config{ - HTTPPort: pt.port, - RuleEngine: ruleEngine, - Auditor: auditor, - Logger: logger, - TLSConfig: tlsConfig, + HTTPPort: pt.port, + RuleEngine: ruleEngine, + Auditor: auditor, + Logger: logger, + TLSConfig: tlsConfig, + SessionCorrelation: pt.sessionCorrelation, + SessionID: pt.sessionID, + SequenceCounter: pt.seqCounter, }) err = pt.server.Start() diff --git a/proxy/proxy_session_correlation_test.go b/proxy/proxy_session_correlation_test.go new file mode 100644 index 00000000..40a419e7 --- /dev/null +++ b/proxy/proxy_session_correlation_test.go @@ -0,0 +1,273 @@ +package proxy + +import ( + "net/http" + "net/http/httptest" + "net/url" + "sync" + "testing" + + "github.com/coder/boundary/audit" + "github.com/coder/boundary/config" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// headerCapturingBackend spins up an httptest.Server that records the +// headers it receives. Call receivedHeaders after the request to inspect +// them. +type headerCapturingBackend struct { + server *httptest.Server + mu sync.Mutex + headers http.Header +} + +func newHeaderCapturingBackend() *headerCapturingBackend { + hcb := &headerCapturingBackend{} + hcb.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + hcb.mu.Lock() + hcb.headers = r.Header.Clone() + hcb.mu.Unlock() + w.WriteHeader(http.StatusOK) + })) + return hcb +} + +func (h *headerCapturingBackend) close() { h.server.Close() } + +func (h *headerCapturingBackend) receivedHeaders() http.Header { + h.mu.Lock() + defer h.mu.Unlock() + return h.headers.Clone() +} + +func TestSessionCorrelation_MatchedDomain(t *testing.T) { + backend := newHeaderCapturingBackend() + defer backend.close() + + backendURL, err := url.Parse(backend.server.URL) + require.NoError(t, err) + + seq := &audit.SequenceCounter{} + + pt := NewProxyTest(t, + WithCertManager(t.TempDir()), + WithAllowedDomain(backendURL.Hostname()), + WithSessionCorrelation(config.SessionCorrelationConfig{ + Enabled: true, + InjectTargets: []config.InjectTarget{{Domain: backendURL.Hostname()}}, + SessionIDHeaderName: config.DefaultSessionIDHeaderName, + SequenceNumberHeaderName: config.DefaultSequenceNumberHeaderName, + }), + WithSessionID("test-session-id-1234"), + WithSequenceCounter(seq), + ).Start() + defer pt.Stop() + + resp, err := pt.proxyClient.Get(backend.server.URL + "/api/v2") + require.NoError(t, err) + defer resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusOK, resp.StatusCode) + + got := backend.receivedHeaders() + assert.Equal(t, "test-session-id-1234", got.Get(config.DefaultSessionIDHeaderName), + "session ID header must be injected on matching domain") + assert.Equal(t, "0", got.Get(config.DefaultSequenceNumberHeaderName), + "sequence number header must start at 0") +} + +func TestSessionCorrelation_UnmatchedDomain(t *testing.T) { + backend := newHeaderCapturingBackend() + defer backend.close() + + backendURL, err := url.Parse(backend.server.URL) + require.NoError(t, err) + + seq := &audit.SequenceCounter{} + + pt := NewProxyTest(t, + WithCertManager(t.TempDir()), + WithAllowedDomain(backendURL.Hostname()), + WithSessionCorrelation(config.SessionCorrelationConfig{ + Enabled: true, + InjectTargets: []config.InjectTarget{{Domain: "other-domain.example.com"}}, + SessionIDHeaderName: config.DefaultSessionIDHeaderName, + SequenceNumberHeaderName: config.DefaultSequenceNumberHeaderName, + }), + WithSessionID("test-session-id-1234"), + WithSequenceCounter(seq), + ).Start() + defer pt.Stop() + + resp, err := pt.proxyClient.Get(backend.server.URL + "/api/v2") + require.NoError(t, err) + defer resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusOK, resp.StatusCode) + + got := backend.receivedHeaders() + assert.Empty(t, got.Get(config.DefaultSessionIDHeaderName), + "session ID header must not be injected on unmatched domain") + assert.Empty(t, got.Get(config.DefaultSequenceNumberHeaderName), + "sequence number header must not be injected on unmatched domain") +} + +func TestSessionCorrelation_Disabled(t *testing.T) { + backend := newHeaderCapturingBackend() + defer backend.close() + + backendURL, err := url.Parse(backend.server.URL) + require.NoError(t, err) + + seq := &audit.SequenceCounter{} + + pt := NewProxyTest(t, + WithCertManager(t.TempDir()), + WithAllowedDomain(backendURL.Hostname()), + WithSessionCorrelation(config.SessionCorrelationConfig{ + Enabled: false, + InjectTargets: []config.InjectTarget{{Domain: backendURL.Hostname()}}, + SessionIDHeaderName: config.DefaultSessionIDHeaderName, + SequenceNumberHeaderName: config.DefaultSequenceNumberHeaderName, + }), + WithSessionID("test-session-id-1234"), + WithSequenceCounter(seq), + ).Start() + defer pt.Stop() + + resp, err := pt.proxyClient.Get(backend.server.URL + "/api/v2") + require.NoError(t, err) + defer resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusOK, resp.StatusCode) + + got := backend.receivedHeaders() + assert.Empty(t, got.Get(config.DefaultSessionIDHeaderName), + "session ID header must not be injected when correlation is disabled") + assert.Empty(t, got.Get(config.DefaultSequenceNumberHeaderName), + "sequence number header must not be injected when correlation is disabled") +} + +func TestSessionCorrelation_OverwritesClientValue(t *testing.T) { + backend := newHeaderCapturingBackend() + defer backend.close() + + backendURL, err := url.Parse(backend.server.URL) + require.NoError(t, err) + + seq := &audit.SequenceCounter{} + + pt := NewProxyTest(t, + WithCertManager(t.TempDir()), + WithAllowedDomain(backendURL.Hostname()), + WithSessionCorrelation(config.SessionCorrelationConfig{ + Enabled: true, + InjectTargets: []config.InjectTarget{{Domain: backendURL.Hostname()}}, + SessionIDHeaderName: config.DefaultSessionIDHeaderName, + SequenceNumberHeaderName: config.DefaultSequenceNumberHeaderName, + }), + WithSessionID("real-session-id"), + WithSequenceCounter(seq), + ).Start() + defer pt.Stop() + + // Send a request with client-supplied session correlation headers + // that should be overwritten by the proxy. + req, err := http.NewRequest(http.MethodGet, backend.server.URL+"/api/v2", nil) + require.NoError(t, err) + req.Header.Set(config.DefaultSessionIDHeaderName, "spoofed-session-id") + req.Header.Set(config.DefaultSequenceNumberHeaderName, "99999") + + resp, err := pt.proxyClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusOK, resp.StatusCode) + + got := backend.receivedHeaders() + assert.Equal(t, "real-session-id", got.Get(config.DefaultSessionIDHeaderName), + "proxy must overwrite client-supplied session ID header") + assert.Equal(t, "0", got.Get(config.DefaultSequenceNumberHeaderName), + "proxy must overwrite client-supplied sequence number header") +} + +func TestSessionCorrelation_PathMatching(t *testing.T) { + backend := newHeaderCapturingBackend() + defer backend.close() + + backendURL, err := url.Parse(backend.server.URL) + require.NoError(t, err) + + seq := &audit.SequenceCounter{} + + pt := NewProxyTest(t, + WithCertManager(t.TempDir()), + WithAllowedDomain(backendURL.Hostname()), + WithSessionCorrelation(config.SessionCorrelationConfig{ + Enabled: true, + InjectTargets: []config.InjectTarget{{ + Domain: backendURL.Hostname(), + Path: "/api/*", + }}, + SessionIDHeaderName: config.DefaultSessionIDHeaderName, + SequenceNumberHeaderName: config.DefaultSequenceNumberHeaderName, + }), + WithSessionID("test-session-id"), + WithSequenceCounter(seq), + ).Start() + defer pt.Stop() + + t.Run("matching path", func(t *testing.T) { + resp, err := pt.proxyClient.Get(backend.server.URL + "/api/v2") + require.NoError(t, err) + defer resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusOK, resp.StatusCode) + + got := backend.receivedHeaders() + assert.Equal(t, "test-session-id", got.Get(config.DefaultSessionIDHeaderName), + "header must be injected when path matches") + }) + + t.Run("non-matching path", func(t *testing.T) { + resp, err := pt.proxyClient.Get(backend.server.URL + "/other/path") + require.NoError(t, err) + defer resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusOK, resp.StatusCode) + + got := backend.receivedHeaders() + assert.Empty(t, got.Get(config.DefaultSessionIDHeaderName), + "header must not be injected when path does not match") + }) +} + +func TestSessionCorrelation_SequenceNumberIncrements(t *testing.T) { + backend := newHeaderCapturingBackend() + defer backend.close() + + backendURL, err := url.Parse(backend.server.URL) + require.NoError(t, err) + + seq := &audit.SequenceCounter{} + + pt := NewProxyTest(t, + WithCertManager(t.TempDir()), + WithAllowedDomain(backendURL.Hostname()), + WithSessionCorrelation(config.SessionCorrelationConfig{ + Enabled: true, + InjectTargets: []config.InjectTarget{{Domain: backendURL.Hostname()}}, + SessionIDHeaderName: config.DefaultSessionIDHeaderName, + SequenceNumberHeaderName: config.DefaultSequenceNumberHeaderName, + }), + WithSessionID("test-session-id"), + WithSequenceCounter(seq), + ).Start() + defer pt.Stop() + + for i, expected := range []string{"0", "1", "2"} { + resp, err := pt.proxyClient.Get(backend.server.URL + "/api/v2") + require.NoError(t, err) + resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusOK, resp.StatusCode) + + got := backend.receivedHeaders() + assert.Equal(t, expected, got.Get(config.DefaultSequenceNumberHeaderName), + "request %d: sequence number must be %s", i, expected) + } +} From 79d648559d166598aaf0574e222d6a79f465d8f9 Mon Sep 17 00:00:00 2001 From: Sas Swart Date: Thu, 7 May 2026 14:33:08 +0000 Subject: [PATCH 02/10] refactor(audit, proxy): change sequence number type from uint64 to int32 Updated the SequenceNumber field in the audit.Request struct and related handling in the SocketAuditor and proxy components to use int32 instead of uint64. This change ensures consistency in data types across the application. Additionally, minor adjustments were made to the session correlation tests to reflect the updated header names and improve clarity in assertions. --- audit/request.go | 2 +- audit/socket_auditor.go | 2 +- cli/cli.go | 2 +- proxy/proxy.go | 8 ++-- proxy/proxy_session_correlation_test.go | 58 ++++++++++--------------- 5 files changed, 30 insertions(+), 42 deletions(-) diff --git a/audit/request.go b/audit/request.go index e23f774e..08608ad9 100644 --- a/audit/request.go +++ b/audit/request.go @@ -17,5 +17,5 @@ type Request struct { // instead of generating its own so that the audit log and // any injected HTTP header carry the same number. When nil // the auditor falls back to its internal SequenceCounter. - SequenceNumber *uint64 + SequenceNumber *int32 } diff --git a/audit/socket_auditor.go b/audit/socket_auditor.go index 06fce950..1d69e05f 100644 --- a/audit/socket_auditor.go +++ b/audit/socket_auditor.go @@ -81,7 +81,7 @@ func (s *SocketAuditor) AuditRequest(req Request) { httpReq.MatchedRule = req.Rule } - var seqNum uint64 + var seqNum int32 if req.SequenceNumber != nil { seqNum = *req.SequenceNumber } else { diff --git a/cli/cli.go b/cli/cli.go index bed3e6e7..f43eeca2 100644 --- a/cli/cli.go +++ b/cli/cli.go @@ -195,7 +195,7 @@ func BaseCommand(version string) *serpent.Command { { Flag: "session-id-inject-target", Env: "BOUNDARY_SESSION_ID_INJECT_TARGET", - Description: `Inject target for session correlation headers. Repeat the flag once per target; each value describes exactly one target. Format: "domain= [path=]". Example: --session-id-inject-target "domain=prod.coder.com path=/api/v2/aibridge/*"`, + Description: `Inject target for session correlation headers. Repeat the flag once per target; each value describes exactly one target. Format: "domain= [path=]". Example: --session-id-inject-target "domain=prod.coder.com path=/api/v2/aibridge/*".`, Value: &cliConfig.InjectSessionIDTarget, YAML: "", // CLI only, YAML uses session_id_inject_targets. }, diff --git a/proxy/proxy.go b/proxy/proxy.go index 0fb5dfa7..59ab57bd 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -295,7 +295,7 @@ func (p *Server) processHTTPRequest(conn net.Conn, req *http.Request, https bool // Pre-allocate a sequence number so the audit event and any // injected header carry the same value. - var seqNum *uint64 + var seqNum *int32 if p.seqCounter != nil { n := p.seqCounter.Next() seqNum = &n @@ -345,7 +345,7 @@ func (p *Server) shouldInjectHeaders(host, reqPath string) bool { return false } -func (p *Server) forwardRequest(conn net.Conn, req *http.Request, https bool, seqNum *uint64) { +func (p *Server) forwardRequest(conn net.Conn, req *http.Request, https bool, seqNum *int32) { // Create HTTP client client := &http.Client{ CheckRedirect: func(req *http.Request, via []*http.Request) error { @@ -394,9 +394,9 @@ func (p *Server) forwardRequest(conn net.Conn, req *http.Request, https bool, se // overwriting any value the jailed client may have set so the // upstream always sees boundary's ID. if p.shouldInjectHeaders(req.Host, req.URL.Path) { - newReq.Header.Set(p.sessionCorrelation.SessionIDHeaderName, p.sessionID) + newReq.Header.Set(config.SessionIDHeaderName, p.sessionID) if seqNum != nil { - newReq.Header.Set(p.sessionCorrelation.SequenceNumberHeaderName, strconv.FormatUint(*seqNum, 10)) + newReq.Header.Set(config.SequenceNumberHeaderName, strconv.Itoa(int(*seqNum))) } } diff --git a/proxy/proxy_session_correlation_test.go b/proxy/proxy_session_correlation_test.go index 40a419e7..dde3a169 100644 --- a/proxy/proxy_session_correlation_test.go +++ b/proxy/proxy_session_correlation_test.go @@ -54,10 +54,8 @@ func TestSessionCorrelation_MatchedDomain(t *testing.T) { WithCertManager(t.TempDir()), WithAllowedDomain(backendURL.Hostname()), WithSessionCorrelation(config.SessionCorrelationConfig{ - Enabled: true, - InjectTargets: []config.InjectTarget{{Domain: backendURL.Hostname()}}, - SessionIDHeaderName: config.DefaultSessionIDHeaderName, - SequenceNumberHeaderName: config.DefaultSequenceNumberHeaderName, + Enabled: true, + InjectTargets: []config.InjectTarget{{Domain: backendURL.Hostname()}}, }), WithSessionID("test-session-id-1234"), WithSequenceCounter(seq), @@ -70,9 +68,9 @@ func TestSessionCorrelation_MatchedDomain(t *testing.T) { require.Equal(t, http.StatusOK, resp.StatusCode) got := backend.receivedHeaders() - assert.Equal(t, "test-session-id-1234", got.Get(config.DefaultSessionIDHeaderName), + assert.Equal(t, "test-session-id-1234", got.Get(config.SessionIDHeaderName), "session ID header must be injected on matching domain") - assert.Equal(t, "0", got.Get(config.DefaultSequenceNumberHeaderName), + assert.Equal(t, "0", got.Get(config.SequenceNumberHeaderName), "sequence number header must start at 0") } @@ -89,10 +87,8 @@ func TestSessionCorrelation_UnmatchedDomain(t *testing.T) { WithCertManager(t.TempDir()), WithAllowedDomain(backendURL.Hostname()), WithSessionCorrelation(config.SessionCorrelationConfig{ - Enabled: true, - InjectTargets: []config.InjectTarget{{Domain: "other-domain.example.com"}}, - SessionIDHeaderName: config.DefaultSessionIDHeaderName, - SequenceNumberHeaderName: config.DefaultSequenceNumberHeaderName, + Enabled: true, + InjectTargets: []config.InjectTarget{{Domain: "other-domain.example.com"}}, }), WithSessionID("test-session-id-1234"), WithSequenceCounter(seq), @@ -105,9 +101,9 @@ func TestSessionCorrelation_UnmatchedDomain(t *testing.T) { require.Equal(t, http.StatusOK, resp.StatusCode) got := backend.receivedHeaders() - assert.Empty(t, got.Get(config.DefaultSessionIDHeaderName), + assert.Empty(t, got.Get(config.SessionIDHeaderName), "session ID header must not be injected on unmatched domain") - assert.Empty(t, got.Get(config.DefaultSequenceNumberHeaderName), + assert.Empty(t, got.Get(config.SequenceNumberHeaderName), "sequence number header must not be injected on unmatched domain") } @@ -124,10 +120,8 @@ func TestSessionCorrelation_Disabled(t *testing.T) { WithCertManager(t.TempDir()), WithAllowedDomain(backendURL.Hostname()), WithSessionCorrelation(config.SessionCorrelationConfig{ - Enabled: false, - InjectTargets: []config.InjectTarget{{Domain: backendURL.Hostname()}}, - SessionIDHeaderName: config.DefaultSessionIDHeaderName, - SequenceNumberHeaderName: config.DefaultSequenceNumberHeaderName, + Enabled: false, + InjectTargets: []config.InjectTarget{{Domain: backendURL.Hostname()}}, }), WithSessionID("test-session-id-1234"), WithSequenceCounter(seq), @@ -140,9 +134,9 @@ func TestSessionCorrelation_Disabled(t *testing.T) { require.Equal(t, http.StatusOK, resp.StatusCode) got := backend.receivedHeaders() - assert.Empty(t, got.Get(config.DefaultSessionIDHeaderName), + assert.Empty(t, got.Get(config.SessionIDHeaderName), "session ID header must not be injected when correlation is disabled") - assert.Empty(t, got.Get(config.DefaultSequenceNumberHeaderName), + assert.Empty(t, got.Get(config.SequenceNumberHeaderName), "sequence number header must not be injected when correlation is disabled") } @@ -159,10 +153,8 @@ func TestSessionCorrelation_OverwritesClientValue(t *testing.T) { WithCertManager(t.TempDir()), WithAllowedDomain(backendURL.Hostname()), WithSessionCorrelation(config.SessionCorrelationConfig{ - Enabled: true, - InjectTargets: []config.InjectTarget{{Domain: backendURL.Hostname()}}, - SessionIDHeaderName: config.DefaultSessionIDHeaderName, - SequenceNumberHeaderName: config.DefaultSequenceNumberHeaderName, + Enabled: true, + InjectTargets: []config.InjectTarget{{Domain: backendURL.Hostname()}}, }), WithSessionID("real-session-id"), WithSequenceCounter(seq), @@ -173,8 +165,8 @@ func TestSessionCorrelation_OverwritesClientValue(t *testing.T) { // that should be overwritten by the proxy. req, err := http.NewRequest(http.MethodGet, backend.server.URL+"/api/v2", nil) require.NoError(t, err) - req.Header.Set(config.DefaultSessionIDHeaderName, "spoofed-session-id") - req.Header.Set(config.DefaultSequenceNumberHeaderName, "99999") + req.Header.Set(config.SessionIDHeaderName, "spoofed-session-id") + req.Header.Set(config.SequenceNumberHeaderName, "99999") resp, err := pt.proxyClient.Do(req) require.NoError(t, err) @@ -182,9 +174,9 @@ func TestSessionCorrelation_OverwritesClientValue(t *testing.T) { require.Equal(t, http.StatusOK, resp.StatusCode) got := backend.receivedHeaders() - assert.Equal(t, "real-session-id", got.Get(config.DefaultSessionIDHeaderName), + assert.Equal(t, "real-session-id", got.Get(config.SessionIDHeaderName), "proxy must overwrite client-supplied session ID header") - assert.Equal(t, "0", got.Get(config.DefaultSequenceNumberHeaderName), + assert.Equal(t, "0", got.Get(config.SequenceNumberHeaderName), "proxy must overwrite client-supplied sequence number header") } @@ -206,8 +198,6 @@ func TestSessionCorrelation_PathMatching(t *testing.T) { Domain: backendURL.Hostname(), Path: "/api/*", }}, - SessionIDHeaderName: config.DefaultSessionIDHeaderName, - SequenceNumberHeaderName: config.DefaultSequenceNumberHeaderName, }), WithSessionID("test-session-id"), WithSequenceCounter(seq), @@ -221,7 +211,7 @@ func TestSessionCorrelation_PathMatching(t *testing.T) { require.Equal(t, http.StatusOK, resp.StatusCode) got := backend.receivedHeaders() - assert.Equal(t, "test-session-id", got.Get(config.DefaultSessionIDHeaderName), + assert.Equal(t, "test-session-id", got.Get(config.SessionIDHeaderName), "header must be injected when path matches") }) @@ -232,7 +222,7 @@ func TestSessionCorrelation_PathMatching(t *testing.T) { require.Equal(t, http.StatusOK, resp.StatusCode) got := backend.receivedHeaders() - assert.Empty(t, got.Get(config.DefaultSessionIDHeaderName), + assert.Empty(t, got.Get(config.SessionIDHeaderName), "header must not be injected when path does not match") }) } @@ -250,10 +240,8 @@ func TestSessionCorrelation_SequenceNumberIncrements(t *testing.T) { WithCertManager(t.TempDir()), WithAllowedDomain(backendURL.Hostname()), WithSessionCorrelation(config.SessionCorrelationConfig{ - Enabled: true, - InjectTargets: []config.InjectTarget{{Domain: backendURL.Hostname()}}, - SessionIDHeaderName: config.DefaultSessionIDHeaderName, - SequenceNumberHeaderName: config.DefaultSequenceNumberHeaderName, + Enabled: true, + InjectTargets: []config.InjectTarget{{Domain: backendURL.Hostname()}}, }), WithSessionID("test-session-id"), WithSequenceCounter(seq), @@ -267,7 +255,7 @@ func TestSessionCorrelation_SequenceNumberIncrements(t *testing.T) { require.Equal(t, http.StatusOK, resp.StatusCode) got := backend.receivedHeaders() - assert.Equal(t, expected, got.Get(config.DefaultSequenceNumberHeaderName), + assert.Equal(t, expected, got.Get(config.SequenceNumberHeaderName), "request %d: sequence number must be %s", i, expected) } } From acc80763d5de6f0f542b4bda581aca8f7321b207 Mon Sep 17 00:00:00 2001 From: Sas Swart Date: Thu, 7 May 2026 14:57:34 +0000 Subject: [PATCH 03/10] refactor(audit, proxy): remove SequenceCounter and update SequenceNumber handling Eliminated the SequenceCounter from the SocketAuditor and related components, simplifying the sequence number management. The SequenceNumber field in the audit.Request struct is now a non-pointer int32, ensuring consistent handling across the application. Adjusted tests to reflect these changes and maintain functionality. --- audit/multi_auditor.go | 3 +-- audit/request.go | 10 ++++------ audit/socket_auditor.go | 15 ++++----------- audit/socket_auditor_test.go | 8 ++------ proxy/proxy.go | 20 ++++---------------- proxy/proxy_framework_test.go | 9 --------- proxy/proxy_session_correlation_test.go | 19 ------------------- 7 files changed, 15 insertions(+), 69 deletions(-) diff --git a/audit/multi_auditor.go b/audit/multi_auditor.go index 91607dc1..bd0789ba 100644 --- a/audit/multi_auditor.go +++ b/audit/multi_auditor.go @@ -50,8 +50,7 @@ func SetupAuditor(ctx context.Context, logger *slog.Logger, disableAuditLogs boo } agentWillProxy := !os.IsNotExist(err) if agentWillProxy { - seq := &SequenceCounter{} - socketAuditor := NewSocketAuditor(logger, logProxySocketPath, sessionID, seq) + socketAuditor := NewSocketAuditor(logger, logProxySocketPath, sessionID) go socketAuditor.Loop(ctx) auditors = append(auditors, socketAuditor) } else { diff --git a/audit/request.go b/audit/request.go index 08608ad9..d63f2a93 100644 --- a/audit/request.go +++ b/audit/request.go @@ -12,10 +12,8 @@ type Request struct { Allowed bool Rule string // The rule that matched (if any) - // SequenceNumber is a pre-allocated sequence number for this - // audit event. When non-nil the auditor must use this value - // instead of generating its own so that the audit log and - // any injected HTTP header carry the same number. When nil - // the auditor falls back to its internal SequenceCounter. - SequenceNumber *int32 + // SequenceNumber is the sequence number assigned to this audit event + // by the proxy. It is monotonically increasing within a session and + // is shared with any injected HTTP header so both carry the same value. + SequenceNumber int32 } diff --git a/audit/socket_auditor.go b/audit/socket_auditor.go index 1d69e05f..e3477d0f 100644 --- a/audit/socket_auditor.go +++ b/audit/socket_auditor.go @@ -34,9 +34,8 @@ type SocketAuditor struct { logCh chan *agentproto.BoundaryLog batchSize int batchTimerDuration time.Duration - socketPath string - sessionID uuid.UUID - seq *SequenceCounter + socketPath string + sessionID uuid.UUID droppedChannelFull atomic.Int64 droppedBatchFull atomic.Int64 @@ -48,7 +47,7 @@ type SocketAuditor struct { // NewSocketAuditor creates a new SocketAuditor that sends logs to the agent's // boundary log proxy socket after SocketAuditor.Loop is called. The socket path // is read from EnvAuditSocketPath, falling back to defaultAuditSocketPath. -func NewSocketAuditor(logger *slog.Logger, socketPath string, sessionID uuid.UUID, seq *SequenceCounter) *SocketAuditor { +func NewSocketAuditor(logger *slog.Logger, socketPath string, sessionID uuid.UUID) *SocketAuditor { // This channel buffer size intends to allow enough buffering for bursty // AI agent network requests while a batch is being sent to the workspace // agent. @@ -64,7 +63,6 @@ func NewSocketAuditor(logger *slog.Logger, socketPath string, sessionID uuid.UUI batchTimerDuration: defaultBatchTimerDuration, socketPath: socketPath, sessionID: sessionID, - seq: seq, } } @@ -81,12 +79,7 @@ func (s *SocketAuditor) AuditRequest(req Request) { httpReq.MatchedRule = req.Rule } - var seqNum int32 - if req.SequenceNumber != nil { - seqNum = *req.SequenceNumber - } else { - seqNum = s.seq.Next() - } + var seqNum int32 = req.SequenceNumber log := &agentproto.BoundaryLog{ Allowed: req.Allowed, diff --git a/audit/socket_auditor_test.go b/audit/socket_auditor_test.go index 64437810..09e359a7 100644 --- a/audit/socket_auditor_test.go +++ b/audit/socket_auditor_test.go @@ -242,7 +242,6 @@ func TestSocketAuditor_Loop_RetriesOnConnectionFailure(t *testing.T) { batchSize: defaultBatchSize, batchTimerDuration: time.Hour, // Ensure timer doesn't interfere with the test sessionID: uuid.MustParse("00000000-0000-4000-8000-000000000001"), - seq: &SequenceCounter{}, } // Set up hook to detect flush attempts @@ -363,7 +362,6 @@ func TestSocketAuditor_Loop_ReportsBatchFullDrops(t *testing.T) { batchSize: defaultBatchSize, batchTimerDuration: time.Hour, sessionID: uuid.MustParse("00000000-0000-4000-8000-000000000001"), - seq: &SequenceCounter{}, } flushed := make(chan struct{}, 4) @@ -483,7 +481,7 @@ func TestSocketAuditor_AuditRequest_SequenceNumberIncrements(t *testing.T) { auditor := setupSocketAuditor(t) for i := range 5 { - auditor.AuditRequest(Request{Method: "GET", URL: "https://example.com", Allowed: true}) + auditor.AuditRequest(Request{Method: "GET", URL: "https://example.com", Allowed: true, SequenceNumber: int32(i)}) select { case log := <-auditor.logCh: @@ -509,7 +507,7 @@ func TestSocketAuditor_Loop_FlushIncludesSessionID(t *testing.T) { // Fill a batch to trigger a flush. for i := 0; i < auditor.batchSize; i++ { - auditor.AuditRequest(Request{Method: "GET", URL: "https://example.com", Allowed: true}) + auditor.AuditRequest(Request{Method: "GET", URL: "https://example.com", Allowed: true, SequenceNumber: int32(i)}) } select { @@ -549,7 +547,6 @@ func setupSocketAuditor(t *testing.T) *SocketAuditor { batchSize: defaultBatchSize, batchTimerDuration: defaultBatchTimerDuration, sessionID: uuid.MustParse("00000000-0000-4000-8000-000000000001"), - seq: &SequenceCounter{}, } } @@ -580,7 +577,6 @@ func setupTestAuditor(t *testing.T) (*SocketAuditor, net.Conn) { batchSize: defaultBatchSize, batchTimerDuration: defaultBatchTimerDuration, sessionID: uuid.MustParse("00000000-0000-4000-8000-000000000001"), - seq: &SequenceCounter{}, } return auditor, serverConn diff --git a/proxy/proxy.go b/proxy/proxy.go index 59ab57bd..bd49e65d 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -34,7 +34,7 @@ type Server struct { started atomic.Bool sessionCorrelation config.SessionCorrelationConfig sessionID string - seqCounter *audit.SequenceCounter + seqCounter audit.SequenceCounter listener net.Listener pprofServer *http.Server @@ -57,9 +57,6 @@ type Config struct { // SessionID is the boundary session UUID injected as a header // on matching requests. SessionID string - // SequenceCounter provides monotonically increasing sequence - // numbers shared with the auditor so both carry the same value. - SequenceCounter *audit.SequenceCounter } // NewProxyServer creates a new proxy server instance @@ -74,7 +71,6 @@ func NewProxyServer(config Config) *Server { pprofPort: config.PprofPort, sessionCorrelation: config.SessionCorrelation, sessionID: config.SessionID, - seqCounter: config.SequenceCounter, } } @@ -293,13 +289,7 @@ func (p *Server) processHTTPRequest(conn net.Conn, req *http.Request, https bool result := p.ruleEngine.Evaluate(req.Method, fullURL) - // Pre-allocate a sequence number so the audit event and any - // injected header carry the same value. - var seqNum *int32 - if p.seqCounter != nil { - n := p.seqCounter.Next() - seqNum = &n - } + seqNum := p.seqCounter.Next() p.auditor.AuditRequest(audit.Request{ Method: req.Method, @@ -345,7 +335,7 @@ func (p *Server) shouldInjectHeaders(host, reqPath string) bool { return false } -func (p *Server) forwardRequest(conn net.Conn, req *http.Request, https bool, seqNum *int32) { +func (p *Server) forwardRequest(conn net.Conn, req *http.Request, https bool, seqNum int32) { // Create HTTP client client := &http.Client{ CheckRedirect: func(req *http.Request, via []*http.Request) error { @@ -395,9 +385,7 @@ func (p *Server) forwardRequest(conn net.Conn, req *http.Request, https bool, se // upstream always sees boundary's ID. if p.shouldInjectHeaders(req.Host, req.URL.Path) { newReq.Header.Set(config.SessionIDHeaderName, p.sessionID) - if seqNum != nil { - newReq.Header.Set(config.SequenceNumberHeaderName, strconv.Itoa(int(*seqNum))) - } + newReq.Header.Set(config.SequenceNumberHeaderName, strconv.Itoa(int(seqNum))) } // Make request to destination diff --git a/proxy/proxy_framework_test.go b/proxy/proxy_framework_test.go index 97c9975f..bce32e55 100644 --- a/proxy/proxy_framework_test.go +++ b/proxy/proxy_framework_test.go @@ -45,7 +45,6 @@ type ProxyTest struct { auditor audit.Auditor sessionCorrelation config.SessionCorrelationConfig sessionID string - seqCounter *audit.SequenceCounter } // ProxyTestOption is a function that configures ProxyTest @@ -128,13 +127,6 @@ func WithSessionID(id string) ProxyTestOption { } } -// WithSequenceCounter sets the sequence counter for the proxy under test. -func WithSequenceCounter(seq *audit.SequenceCounter) ProxyTestOption { - return func(pt *ProxyTest) { - pt.seqCounter = seq - } -} - // Start starts the proxy server func (pt *ProxyTest) Start() *ProxyTest { pt.t.Helper() @@ -186,7 +178,6 @@ func (pt *ProxyTest) Start() *ProxyTest { TLSConfig: tlsConfig, SessionCorrelation: pt.sessionCorrelation, SessionID: pt.sessionID, - SequenceCounter: pt.seqCounter, }) err = pt.server.Start() diff --git a/proxy/proxy_session_correlation_test.go b/proxy/proxy_session_correlation_test.go index dde3a169..a20ebae2 100644 --- a/proxy/proxy_session_correlation_test.go +++ b/proxy/proxy_session_correlation_test.go @@ -7,7 +7,6 @@ import ( "sync" "testing" - "github.com/coder/boundary/audit" "github.com/coder/boundary/config" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -48,8 +47,6 @@ func TestSessionCorrelation_MatchedDomain(t *testing.T) { backendURL, err := url.Parse(backend.server.URL) require.NoError(t, err) - seq := &audit.SequenceCounter{} - pt := NewProxyTest(t, WithCertManager(t.TempDir()), WithAllowedDomain(backendURL.Hostname()), @@ -58,7 +55,6 @@ func TestSessionCorrelation_MatchedDomain(t *testing.T) { InjectTargets: []config.InjectTarget{{Domain: backendURL.Hostname()}}, }), WithSessionID("test-session-id-1234"), - WithSequenceCounter(seq), ).Start() defer pt.Stop() @@ -81,8 +77,6 @@ func TestSessionCorrelation_UnmatchedDomain(t *testing.T) { backendURL, err := url.Parse(backend.server.URL) require.NoError(t, err) - seq := &audit.SequenceCounter{} - pt := NewProxyTest(t, WithCertManager(t.TempDir()), WithAllowedDomain(backendURL.Hostname()), @@ -91,7 +85,6 @@ func TestSessionCorrelation_UnmatchedDomain(t *testing.T) { InjectTargets: []config.InjectTarget{{Domain: "other-domain.example.com"}}, }), WithSessionID("test-session-id-1234"), - WithSequenceCounter(seq), ).Start() defer pt.Stop() @@ -114,8 +107,6 @@ func TestSessionCorrelation_Disabled(t *testing.T) { backendURL, err := url.Parse(backend.server.URL) require.NoError(t, err) - seq := &audit.SequenceCounter{} - pt := NewProxyTest(t, WithCertManager(t.TempDir()), WithAllowedDomain(backendURL.Hostname()), @@ -124,7 +115,6 @@ func TestSessionCorrelation_Disabled(t *testing.T) { InjectTargets: []config.InjectTarget{{Domain: backendURL.Hostname()}}, }), WithSessionID("test-session-id-1234"), - WithSequenceCounter(seq), ).Start() defer pt.Stop() @@ -147,8 +137,6 @@ func TestSessionCorrelation_OverwritesClientValue(t *testing.T) { backendURL, err := url.Parse(backend.server.URL) require.NoError(t, err) - seq := &audit.SequenceCounter{} - pt := NewProxyTest(t, WithCertManager(t.TempDir()), WithAllowedDomain(backendURL.Hostname()), @@ -157,7 +145,6 @@ func TestSessionCorrelation_OverwritesClientValue(t *testing.T) { InjectTargets: []config.InjectTarget{{Domain: backendURL.Hostname()}}, }), WithSessionID("real-session-id"), - WithSequenceCounter(seq), ).Start() defer pt.Stop() @@ -187,8 +174,6 @@ func TestSessionCorrelation_PathMatching(t *testing.T) { backendURL, err := url.Parse(backend.server.URL) require.NoError(t, err) - seq := &audit.SequenceCounter{} - pt := NewProxyTest(t, WithCertManager(t.TempDir()), WithAllowedDomain(backendURL.Hostname()), @@ -200,7 +185,6 @@ func TestSessionCorrelation_PathMatching(t *testing.T) { }}, }), WithSessionID("test-session-id"), - WithSequenceCounter(seq), ).Start() defer pt.Stop() @@ -234,8 +218,6 @@ func TestSessionCorrelation_SequenceNumberIncrements(t *testing.T) { backendURL, err := url.Parse(backend.server.URL) require.NoError(t, err) - seq := &audit.SequenceCounter{} - pt := NewProxyTest(t, WithCertManager(t.TempDir()), WithAllowedDomain(backendURL.Hostname()), @@ -244,7 +226,6 @@ func TestSessionCorrelation_SequenceNumberIncrements(t *testing.T) { InjectTargets: []config.InjectTarget{{Domain: backendURL.Hostname()}}, }), WithSessionID("test-session-id"), - WithSequenceCounter(seq), ).Start() defer pt.Stop() From b9900ce40095579436a1dc57ca25cf90537e910f Mon Sep 17 00:00:00 2001 From: Sas Swart Date: Thu, 7 May 2026 15:07:11 +0000 Subject: [PATCH 04/10] refactor(proxy): streamline host port stripping and remove redundant comments Updated the host port stripping logic in the shouldInjectHeaders function to utilize net.SplitHostPort for improved clarity and reliability. Removed outdated comments regarding session correlation header injection to enhance code readability. --- proxy/proxy.go | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/proxy/proxy.go b/proxy/proxy.go index bd49e65d..180d62b5 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -316,10 +316,9 @@ func (p *Server) shouldInjectHeaders(host, reqPath string) bool { if !p.sessionCorrelation.Enabled { return false } - // Strip port from host for matching (e.g. "example.com:443" -> "example.com"). h := host - if i := strings.LastIndex(h, ":"); i != -1 { - h = h[:i] + if stripped, _, err := net.SplitHostPort(h); err == nil { + h = stripped } for _, target := range p.sessionCorrelation.InjectTargets { if !strings.EqualFold(target.Domain, h) { @@ -380,9 +379,6 @@ func (p *Server) forwardRequest(conn net.Conn, req *http.Request, https bool, se } } - // Stamp session correlation headers on matching requests, - // overwriting any value the jailed client may have set so the - // upstream always sees boundary's ID. if p.shouldInjectHeaders(req.Host, req.URL.Path) { newReq.Header.Set(config.SessionIDHeaderName, p.sessionID) newReq.Header.Set(config.SequenceNumberHeaderName, strconv.Itoa(int(seqNum))) From be83d88934bccdc519985f8672470dc763c45f1c Mon Sep 17 00:00:00 2001 From: Sas Swart Date: Thu, 7 May 2026 15:07:41 +0000 Subject: [PATCH 05/10] make fmt --- audit/socket_auditor.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/audit/socket_auditor.go b/audit/socket_auditor.go index e3477d0f..6ca6f1d9 100644 --- a/audit/socket_auditor.go +++ b/audit/socket_auditor.go @@ -34,8 +34,8 @@ type SocketAuditor struct { logCh chan *agentproto.BoundaryLog batchSize int batchTimerDuration time.Duration - socketPath string - sessionID uuid.UUID + socketPath string + sessionID uuid.UUID droppedChannelFull atomic.Int64 droppedBatchFull atomic.Int64 From cab44878a59a18f1b007b07c9b360e091c6dcc20 Mon Sep 17 00:00:00 2001 From: Sas Swart Date: Thu, 7 May 2026 15:09:57 +0000 Subject: [PATCH 06/10] make lint --- audit/socket_auditor.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/audit/socket_auditor.go b/audit/socket_auditor.go index 6ca6f1d9..5ab7f820 100644 --- a/audit/socket_auditor.go +++ b/audit/socket_auditor.go @@ -79,12 +79,10 @@ func (s *SocketAuditor) AuditRequest(req Request) { httpReq.MatchedRule = req.Rule } - var seqNum int32 = req.SequenceNumber - log := &agentproto.BoundaryLog{ Allowed: req.Allowed, Time: timestamppb.Now(), - SequenceNumber: seqNum, + SequenceNumber: req.SequenceNumber, Resource: &agentproto.BoundaryLog_HttpRequest_{HttpRequest: httpReq}, } From dedcb852233762a40a464b21908a658759ab1803 Mon Sep 17 00:00:00 2001 From: Sas Swart Date: Thu, 7 May 2026 15:35:21 +0000 Subject: [PATCH 07/10] feat(proxy): enhance request handling with sequence number tracking and transport configuration Added a new test to verify that the sequence number increments correctly across different request types (HTTP and HTTPS) in the proxy. Introduced a `WithForwardTransport` option to allow tests to specify a custom HTTP transport for handling self-signed certificates. Updated the proxy server configuration to support this transport, ensuring proper handling of requests with session correlation and sequence number injection. --- proxy/proxy.go | 13 ++++- proxy/proxy_audit_test.go | 71 +++++++++++++++++++++++++ proxy/proxy_framework_test.go | 11 ++++ proxy/proxy_session_correlation_test.go | 39 ++++++++++++++ 4 files changed, 133 insertions(+), 1 deletion(-) diff --git a/proxy/proxy.go b/proxy/proxy.go index 180d62b5..cc99f34d 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -35,6 +35,7 @@ type Server struct { sessionCorrelation config.SessionCorrelationConfig sessionID string seqCounter audit.SequenceCounter + forwardTransport http.RoundTripper // nil means use http.DefaultTransport listener net.Listener pprofServer *http.Server @@ -57,6 +58,10 @@ type Config struct { // SessionID is the boundary session UUID injected as a header // on matching requests. SessionID string + // ForwardTransport, if non-nil, is used when forwarding requests to + // backend servers. Defaults to http.DefaultTransport when nil. Set in + // tests to trust self-signed backend certificates. + ForwardTransport http.RoundTripper } // NewProxyServer creates a new proxy server instance @@ -71,6 +76,7 @@ func NewProxyServer(config Config) *Server { pprofPort: config.PprofPort, sessionCorrelation: config.SessionCorrelation, sessionID: config.SessionID, + forwardTransport: config.ForwardTransport, } } @@ -321,7 +327,11 @@ func (p *Server) shouldInjectHeaders(host, reqPath string) bool { h = stripped } for _, target := range p.sessionCorrelation.InjectTargets { - if !strings.EqualFold(target.Domain, h) { + targetHost := target.Domain + if stripped, _, err := net.SplitHostPort(targetHost); err == nil { + targetHost = stripped + } + if !strings.EqualFold(targetHost, h) { continue } if target.Path == "" { @@ -340,6 +350,7 @@ func (p *Server) forwardRequest(conn net.Conn, req *http.Request, https bool, se CheckRedirect: func(req *http.Request, via []*http.Request) error { return http.ErrUseLastResponse // Don't follow redirects }, + Transport: p.forwardTransport, // nil → http.DefaultTransport } scheme := "http" diff --git a/proxy/proxy_audit_test.go b/proxy/proxy_audit_test.go index bbb71268..0e591e22 100644 --- a/proxy/proxy_audit_test.go +++ b/proxy/proxy_audit_test.go @@ -1,6 +1,7 @@ package proxy import ( + "crypto/tls" "net" "net/http" "net/http/httptest" @@ -31,6 +32,76 @@ func (c *capturingAuditor) getRequests() []audit.Request { return append([]audit.Request{}, c.requests...) } +func TestSequenceNumberIncrementsAcrossRequestTypes(t *testing.T) { + // Plain HTTP backend — used by the plain HTTP request. + httpBackend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer httpBackend.Close() + + // TLS backend — used by both the implicit-CONNECT and explicit-CONNECT + // requests. The proxy needs InsecureSkipVerify to trust its self-signed cert. + tlsBackend := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer tlsBackend.Close() + + httpBackendURL, err := url.Parse(httpBackend.URL) + require.NoError(t, err) + tlsBackendURL, err := url.Parse(tlsBackend.URL) + require.NoError(t, err) + + // TLS SNI requires a hostname, not an IP address. httptest servers bind to + // 127.0.0.1, so rewrite the host to "localhost" for all TLS connections so + // that the proxy's cert manager receives a proper SNI value. + tlsHost := "localhost:" + tlsBackendURL.Port() + tlsURL := "https://" + tlsHost + + auditor := &capturingAuditor{} + + //nolint:gosec + insecureTransport := &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + } + + pt := NewProxyTest(t, + WithCertManager(t.TempDir()), + WithAllowedDomain(httpBackendURL.Hostname()), + WithAllowedDomain("localhost"), + WithAuditor(auditor), + WithForwardTransport(insecureTransport), + ).Start() + defer pt.Stop() + + // Request 1: plain HTTP — handleHTTPConnection → processHTTPRequest(https=false) + resp, err := pt.proxyClient.Get(httpBackend.URL + "/") + require.NoError(t, err) + resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusOK, resp.StatusCode) + + // Request 2: HTTPS via implicit CONNECT — Go's transport automatically sends + // CONNECT for https:// URLs → handleCONNECTTunnel → processHTTPRequest(https=true) + resp, err = pt.proxyClient.Get(tlsURL + "/") + require.NoError(t, err) + resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusOK, resp.StatusCode) + + // Request 3: inside an explicit CONNECT tunnel — handleCONNECTTunnel → + // processHTTPRequest(https=true), driven by a manually established tunnel. + tunnel, err := pt.establishExplicitCONNECT(tlsHost) + require.NoError(t, err) + defer tunnel.close() //nolint:errcheck + _, err = tunnel.sendRequest(tlsHost, "/") + require.NoError(t, err) + + requests := auditor.getRequests() + require.Len(t, requests, 3, "expected one audit record per request") + + assert.Equal(t, int32(0), requests[0].SequenceNumber, "HTTP request must have sequence number 0") + assert.Equal(t, int32(1), requests[1].SequenceNumber, "implicit-CONNECT request must have sequence number 1") + assert.Equal(t, int32(2), requests[2].SequenceNumber, "explicit-CONNECT tunnel request must have sequence number 2") +} + func TestAuditURLIsFullyFormed_HTTP(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) diff --git a/proxy/proxy_framework_test.go b/proxy/proxy_framework_test.go index bce32e55..dd94726f 100644 --- a/proxy/proxy_framework_test.go +++ b/proxy/proxy_framework_test.go @@ -45,6 +45,7 @@ type ProxyTest struct { auditor audit.Auditor sessionCorrelation config.SessionCorrelationConfig sessionID string + forwardTransport http.RoundTripper } // ProxyTestOption is a function that configures ProxyTest @@ -127,6 +128,15 @@ func WithSessionID(id string) ProxyTestOption { } } +// WithForwardTransport sets the http.RoundTripper the proxy uses when +// forwarding requests to backends. Use in tests to trust self-signed +// backend certificates (e.g. those from httptest.NewTLSServer). +func WithForwardTransport(transport http.RoundTripper) ProxyTestOption { + return func(pt *ProxyTest) { + pt.forwardTransport = transport + } +} + // Start starts the proxy server func (pt *ProxyTest) Start() *ProxyTest { pt.t.Helper() @@ -178,6 +188,7 @@ func (pt *ProxyTest) Start() *ProxyTest { TLSConfig: tlsConfig, SessionCorrelation: pt.sessionCorrelation, SessionID: pt.sessionID, + ForwardTransport: pt.forwardTransport, }) err = pt.server.Start() diff --git a/proxy/proxy_session_correlation_test.go b/proxy/proxy_session_correlation_test.go index a20ebae2..631774b0 100644 --- a/proxy/proxy_session_correlation_test.go +++ b/proxy/proxy_session_correlation_test.go @@ -70,6 +70,45 @@ func TestSessionCorrelation_MatchedDomain(t *testing.T) { "sequence number header must start at 0") } +func TestSessionCorrelation_MatchedDomainWithPort(t *testing.T) { + backend := newHeaderCapturingBackend() + defer backend.close() + + backendURL, err := url.Parse(backend.server.URL) + require.NoError(t, err) + + // httptest.NewServer binds to a random ephemeral port, so backendURL.Host + // is always "127.0.0.1:" and backendURL.Hostname() is "127.0.0.1". + // Verify that assumption holds before relying on it below. + require.NotEmpty(t, backendURL.Port(), "httptest.Server URL must include a port") + require.Equal(t, backendURL.Hostname()+":"+backendURL.Port(), backendURL.Host, + "backendURL.Host must be host:port") + + // Configure the inject target with host:port instead of just the hostname. + // The port must be stripped during matching so the request still matches. + pt := NewProxyTest(t, + WithCertManager(t.TempDir()), + WithAllowedDomain(backendURL.Hostname()), + WithSessionCorrelation(config.SessionCorrelationConfig{ + Enabled: true, + InjectTargets: []config.InjectTarget{{Domain: backendURL.Host}}, + }), + WithSessionID("test-session-id-1234"), + ).Start() + defer pt.Stop() + + resp, err := pt.proxyClient.Get(backend.server.URL + "/api/v2") + require.NoError(t, err) + defer resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusOK, resp.StatusCode) + + got := backend.receivedHeaders() + assert.Equal(t, "test-session-id-1234", got.Get(config.SessionIDHeaderName), + "session ID header must be injected when inject target domain includes a port") + assert.Equal(t, "0", got.Get(config.SequenceNumberHeaderName), + "sequence number header must start at 0") +} + func TestSessionCorrelation_UnmatchedDomain(t *testing.T) { backend := newHeaderCapturingBackend() defer backend.close() From 5681970a8168213c08fc2edb3c28dc866190966b Mon Sep 17 00:00:00 2001 From: Sas Swart Date: Tue, 12 May 2026 10:09:46 +0000 Subject: [PATCH 08/10] refactor(proxy): unify inject target matching with rulesengine Replace the custom shouldInjectHeaders domain/path matching with the existing rulesengine engine. Inject targets now use the same "domain=... path=..." syntax and identical matching semantics as --allow rules. Changes: - config: Remove InjectTarget struct and ParseInjectTarget. Change SessionCorrelationConfig.InjectTargets from []InjectTarget to []string (raw rule specs). DefaultInjectTargetFromEnv returns a string. ValidateSessionCorrelation delegates parsing to rulesengine.ParseAllowSpecs. - config: Simplify buildSessionCorrelation to pass raw strings. - proxy: Replace sessionCorrelation field with injectEngine (*rulesengine.Engine). shouldInjectHeaders now delegates to injectEngine.Evaluate instead of hand-rolled matching. - proxy: Add InjectEngine to proxy.Config, built from SessionCorrelation.InjectTargets at construction time. - tests: Update all session correlation tests to use raw rule strings. This resolves the behavioral differences between inject targets and allow rules: case sensitivity, wildcard/subdomain support, path glob semantics, trailing-slash normalization, and input validation are now identical. --- config/config.go | 25 ++- config/session_correlation.go | 109 +++++-------- config/session_correlation_test.go | 206 ++++++++---------------- proxy/proxy.go | 54 +++---- proxy/proxy_framework_test.go | 10 ++ proxy/proxy_session_correlation_test.go | 25 ++- 6 files changed, 162 insertions(+), 267 deletions(-) diff --git a/config/config.go b/config/config.go index b0100fa4..9de181b4 100644 --- a/config/config.go +++ b/config/config.go @@ -141,26 +141,19 @@ func NewAppConfigFromCliConfig(cfg CliConfig, targetCMD []string, environ []stri }, nil } -// buildSessionCorrelation merges CLI and YAML inject target sources, -// parses each target string, and validates the resulting configuration. -// environ is passed explicitly (rather than reading os.Environ inside) -// so that callers and tests can supply a controlled environment. +// buildSessionCorrelation merges CLI and YAML inject target sources +// and validates the resulting configuration. Inject targets use the same +// "domain=... path=..." syntax as --allow rules so that matching semantics +// are identical. environ is passed explicitly (rather than reading +// os.Environ inside) so that callers and tests can supply a controlled +// environment. func buildSessionCorrelation(cfg CliConfig, environ []string) (SessionCorrelationConfig, error) { // Merge YAML targets with CLI targets. - rawTargets := append(cfg.InjectSessionIDTargets.Value(), cfg.InjectSessionIDTarget.Value()...) - - var targets []InjectTarget - for _, raw := range rawTargets { - t, err := ParseInjectTarget(raw) - if err != nil { - return SessionCorrelationConfig{}, err - } - targets = append(targets, t) - } + targets := append(cfg.InjectSessionIDTargets.Value(), cfg.InjectSessionIDTarget.Value()...) if len(targets) == 0 && cfg.SessionCorrelationEnabled.Value() { - if t := DefaultInjectTargetFromEnv(environ); t != nil { - targets = []InjectTarget{*t} + if t := DefaultInjectTargetFromEnv(environ); t != "" { + targets = []string{t} } } diff --git a/config/session_correlation.go b/config/session_correlation.go index 5dae21c5..67a7b278 100644 --- a/config/session_correlation.go +++ b/config/session_correlation.go @@ -4,6 +4,8 @@ import ( "fmt" "net/url" "strings" + + "github.com/coder/boundary/rulesengine" ) // Header names and paths for session correlation. @@ -27,14 +29,6 @@ const ( CoderAgentURLEnv = "CODER_AGENT_URL" ) -// InjectTarget represents a parsed target for session correlation header -// injection. Requests matching the domain (and optional path glob) will -// receive the session ID and sequence number headers. -type InjectTarget struct { - Domain string - Path string -} - // SessionCorrelationConfig holds configuration for session correlation // header injection. When enabled, boundary injects its session ID and // sequence number as custom headers on matching outbound requests so @@ -45,63 +39,22 @@ type SessionCorrelationConfig struct { // Deployments without AI Bridge in front should set this to false. Enabled bool - // InjectTargets is the list of domain/path patterns that should - // receive session correlation headers. - InjectTargets []InjectTarget -} - -// ParseInjectTarget parses a string of the form "domain=... path=..." -// into an InjectTarget. The domain key is required; path is optional. -func ParseInjectTarget(raw string) (InjectTarget, error) { - raw = strings.TrimSpace(raw) - if raw == "" { - return InjectTarget{}, fmt.Errorf("inject target must not be empty") - } - - var target InjectTarget - seen := make(map[string]bool) - for _, part := range strings.Fields(raw) { - key, value, ok := strings.Cut(part, "=") - if !ok { - return InjectTarget{}, fmt.Errorf( - "inject target: malformed key-value pair %q, expected key=value", part, - ) - } - if seen[key] { - return InjectTarget{}, fmt.Errorf( - "inject target: duplicate key %q (use separate flags for multiple targets)", key, - ) - } - seen[key] = true - switch key { - case "domain": - if value == "" { - return InjectTarget{}, fmt.Errorf("inject target: domain must not be empty") - } - target.Domain = value - case "path": - target.Path = value - default: - return InjectTarget{}, fmt.Errorf("inject target: unknown key %q", key) - } - } - - if target.Domain == "" { - return InjectTarget{}, fmt.Errorf("inject target: domain is required") - } - - return target, nil + // InjectTargets is the list of raw rule specs (same syntax as --allow) + // that should receive session correlation headers. Each string uses the + // rulesengine "domain=... path=..." format so that inject target + // matching is identical to allow-rule matching. + InjectTargets []string } -// DefaultInjectTargetFromEnv derives an InjectTarget from the CODER_AGENT_URL -// variable in the provided environment slice. It returns nil if the variable is -// absent, empty, or not a valid URL with a host. The derived target uses -// DefaultAIBridgePath as the path glob so that all AI Bridge traffic on the -// control-plane host is matched. +// DefaultInjectTargetFromEnv derives an inject target rule string from the +// CODER_AGENT_URL variable in the provided environment slice. It returns "" +// if the variable is absent, empty, or not a valid URL with a host. The +// derived target uses DefaultAIBridgePath as the path glob so that all AI +// Bridge traffic on the control-plane host is matched. // // The environ parameter is accepted rather than reading os.Environ directly so // that callers (and tests) can supply an arbitrary environment. -func DefaultInjectTargetFromEnv(environ []string) *InjectTarget { +func DefaultInjectTargetFromEnv(environ []string) string { var raw string for _, e := range environ { k, v, ok := strings.Cut(e, "=") @@ -111,23 +64,22 @@ func DefaultInjectTargetFromEnv(environ []string) *InjectTarget { } } if raw == "" { - return nil + return "" } u, err := url.Parse(raw) if err != nil || u.Host == "" { - return nil + return "" } - return &InjectTarget{ - Domain: u.Hostname(), - Path: DefaultAIBridgePath, - } + return fmt.Sprintf("domain=%s path=%s", u.Hostname(), DefaultAIBridgePath) } // ValidateSessionCorrelation checks that the session correlation config -// is internally consistent. It returns an error describing the first -// problem found, or nil if the config is valid. +// is internally consistent. When enabled it verifies that at least one +// inject target is configured and that every target string is a valid +// rulesengine rule. It returns an error describing the first problem +// found, or nil if the config is valid. func ValidateSessionCorrelation(cfg SessionCorrelationConfig) error { if !cfg.Enabled { return nil @@ -139,5 +91,26 @@ func ValidateSessionCorrelation(cfg SessionCorrelationConfig) error { ) } + // Reject empty target strings before passing to the parser. + for _, t := range cfg.InjectTargets { + if strings.TrimSpace(t) == "" { + return fmt.Errorf("inject target: must not be empty") + } + } + + // Validate each target parses as a rulesengine rule. + rules, err := rulesengine.ParseAllowSpecs(cfg.InjectTargets) + if err != nil { + return fmt.Errorf("inject target: %w", err) + } + + // Inject targets must specify a domain; path-only rules are not + // meaningful for header injection. + for i, r := range rules { + if r.HostPattern == nil { + return fmt.Errorf("inject target %q: domain is required", cfg.InjectTargets[i]) + } + } + return nil } diff --git a/config/session_correlation_test.go b/config/session_correlation_test.go index 4093ee1a..588f9e61 100644 --- a/config/session_correlation_test.go +++ b/config/session_correlation_test.go @@ -4,77 +4,48 @@ import ( "testing" ) -func TestParseInjectTarget(t *testing.T) { +func TestParseInjectTarget_ViaValidation(t *testing.T) { t.Parallel() tests := []struct { name string input string - want InjectTarget wantErr bool }{ { name: "domain only", input: "domain=dev.coder.com", - want: InjectTarget{Domain: "dev.coder.com"}, }, { name: "domain and path", input: "domain=dev.coder.com path=/api/v2/aibridge/*", - want: InjectTarget{Domain: "dev.coder.com", Path: "/api/v2/aibridge/*"}, - }, - { - name: "leading and trailing whitespace", - input: " domain=dev.coder.com path=/api/* ", - want: InjectTarget{Domain: "dev.coder.com", Path: "/api/*"}, }, { name: "empty string", input: "", wantErr: true, }, - { - name: "whitespace only", - input: " ", - wantErr: true, - }, { name: "missing domain", input: "path=/api/*", wantErr: true, }, - { - name: "empty domain value", - input: "domain=", - wantErr: true, - }, - { - name: "malformed pair no equals", - input: "domain", - wantErr: true, - }, { name: "unknown key", input: "domain=example.com port=443", wantErr: true, }, - { - name: "duplicate domain key", - input: "domain=staging.coder.com domain=prod.coder.com", - wantErr: true, - }, - { - name: "duplicate path key", - input: "domain=example.com path=/api/* path=/other/*", - wantErr: true, - }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { t.Parallel() - got, err := ParseInjectTarget(tc.input) + cfg := SessionCorrelationConfig{ + Enabled: true, + InjectTargets: []string{tc.input}, + } + err := ValidateSessionCorrelation(cfg) if tc.wantErr { if err == nil { t.Fatalf("expected error, got nil") @@ -84,12 +55,6 @@ func TestParseInjectTarget(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - if got.Domain != tc.want.Domain { - t.Errorf("Domain: got %q, want %q", got.Domain, tc.want.Domain) - } - if got.Path != tc.want.Path { - t.Errorf("Path: got %q, want %q", got.Path, tc.want.Path) - } }) } } @@ -119,7 +84,7 @@ func TestValidateSessionCorrelation(t *testing.T) { name: "enabled with targets", cfg: SessionCorrelationConfig{ Enabled: true, - InjectTargets: []InjectTarget{{Domain: "dev.coder.com"}}, + InjectTargets: []string{"domain=dev.coder.com"}, }, }, { @@ -134,7 +99,7 @@ func TestValidateSessionCorrelation(t *testing.T) { name: "enabled with empty targets slice", cfg: SessionCorrelationConfig{ Enabled: true, - InjectTargets: []InjectTarget{}, + InjectTargets: []string{}, }, wantErr: true, }, @@ -162,19 +127,18 @@ func TestNewAppConfigFromCliConfig_SessionCorrelation(t *testing.T) { t.Parallel() tests := []struct { - name string - cli CliConfig - environ []string - want SessionCorrelationConfig - wantErr bool + name string + cli CliConfig + environ []string + wantEnabled bool + wantTargets []string + wantErr bool }{ { - name: "defaults when not configured", - cli: baseCliConfig(), - want: SessionCorrelationConfig{ - Enabled: false, - InjectTargets: nil, - }, + name: "defaults when not configured", + cli: baseCliConfig(), + wantEnabled: false, + wantTargets: nil, }, { name: "enabled with inject targets", @@ -184,28 +148,22 @@ func TestNewAppConfigFromCliConfig_SessionCorrelation(t *testing.T) { _ = c.InjectSessionIDTarget.Set("domain=dev.coder.com path=/api/v2/aibridge/*") return c }(), - want: SessionCorrelationConfig{ - Enabled: true, - InjectTargets: []InjectTarget{ - {Domain: "dev.coder.com", Path: "/api/v2/aibridge/*"}, - }, - }, + wantEnabled: true, + wantTargets: []string{"domain=dev.coder.com path=/api/v2/aibridge/*"}, }, { - name: "enabled with no targets, CODER_AGENT_URL set → auto-derived", + name: "enabled with no targets, CODER_AGENT_URL set -> auto-derived", cli: func() CliConfig { c := baseCliConfig() _ = c.SessionCorrelationEnabled.Set("true") return c }(), - environ: []string{"CODER_AGENT_URL=https://dev.coder.com/"}, - want: SessionCorrelationConfig{ - Enabled: true, - InjectTargets: []InjectTarget{{Domain: "dev.coder.com", Path: DefaultAIBridgePath}}, - }, + environ: []string{"CODER_AGENT_URL=https://dev.coder.com/"}, + wantEnabled: true, + wantTargets: []string{"domain=dev.coder.com path=" + DefaultAIBridgePath}, }, { - name: "enabled with no targets, CODER_AGENT_URL absent → error", + name: "enabled with no targets, CODER_AGENT_URL absent -> error", cli: func() CliConfig { c := baseCliConfig() _ = c.SessionCorrelationEnabled.Set("true") @@ -242,21 +200,17 @@ func TestNewAppConfigFromCliConfig_SessionCorrelation(t *testing.T) { } sc := got.SessionCorrelation - if sc.Enabled != tc.want.Enabled { - t.Errorf("Enabled: got %v, want %v", sc.Enabled, tc.want.Enabled) + if sc.Enabled != tc.wantEnabled { + t.Errorf("Enabled: got %v, want %v", sc.Enabled, tc.wantEnabled) } - if len(sc.InjectTargets) != len(tc.want.InjectTargets) { + if len(sc.InjectTargets) != len(tc.wantTargets) { t.Fatalf("InjectTargets len: got %d, want %d", - len(sc.InjectTargets), len(tc.want.InjectTargets)) + len(sc.InjectTargets), len(tc.wantTargets)) } for i := range sc.InjectTargets { - if sc.InjectTargets[i].Domain != tc.want.InjectTargets[i].Domain { - t.Errorf("InjectTargets[%d].Domain: got %q, want %q", - i, sc.InjectTargets[i].Domain, tc.want.InjectTargets[i].Domain) - } - if sc.InjectTargets[i].Path != tc.want.InjectTargets[i].Path { - t.Errorf("InjectTargets[%d].Path: got %q, want %q", - i, sc.InjectTargets[i].Path, tc.want.InjectTargets[i].Path) + if sc.InjectTargets[i] != tc.wantTargets[i] { + t.Errorf("InjectTargets[%d]: got %q, want %q", + i, sc.InjectTargets[i], tc.wantTargets[i]) } } }) @@ -269,42 +223,42 @@ func TestDefaultInjectTargetFromEnv(t *testing.T) { tests := []struct { name string environ []string - want *InjectTarget + want string }{ { name: "valid URL with trailing slash", environ: []string{"CODER_AGENT_URL=https://dev.coder.com/"}, - want: &InjectTarget{Domain: "dev.coder.com", Path: DefaultAIBridgePath}, + want: "domain=dev.coder.com path=" + DefaultAIBridgePath, }, { name: "valid URL without trailing slash", environ: []string{"CODER_AGENT_URL=https://dev.coder.com"}, - want: &InjectTarget{Domain: "dev.coder.com", Path: DefaultAIBridgePath}, + want: "domain=dev.coder.com path=" + DefaultAIBridgePath, }, { - name: "URL with port", // Ports are ignored in the rules engine, so we strip them here. + name: "URL with port", environ: []string{"CODER_AGENT_URL=https://dev.coder.com:8443/"}, - want: &InjectTarget{Domain: "dev.coder.com", Path: DefaultAIBridgePath}, + want: "domain=dev.coder.com path=" + DefaultAIBridgePath, }, { name: "unset variable", environ: []string{}, - want: nil, + want: "", }, { name: "empty value", environ: []string{"CODER_AGENT_URL="}, - want: nil, + want: "", }, { name: "no host in URL", environ: []string{"CODER_AGENT_URL=not-a-url"}, - want: nil, + want: "", }, { name: "other env vars present but not CODER_AGENT_URL", environ: []string{"CODER_URL=https://dev.coder.com/", "HOME=/home/user"}, - want: nil, + want: "", }, } @@ -313,20 +267,8 @@ func TestDefaultInjectTargetFromEnv(t *testing.T) { t.Parallel() got := DefaultInjectTargetFromEnv(tc.environ) - if tc.want == nil { - if got != nil { - t.Errorf("expected nil, got %+v", got) - } - return - } - if got == nil { - t.Fatalf("expected %+v, got nil", tc.want) - } - if got.Domain != tc.want.Domain { - t.Errorf("Domain: got %q, want %q", got.Domain, tc.want.Domain) - } - if got.Path != tc.want.Path { - t.Errorf("Path: got %q, want %q", got.Path, tc.want.Path) + if got != tc.want { + t.Errorf("got %q, want %q", got, tc.want) } }) } @@ -338,7 +280,7 @@ func TestBuildSessionCorrelation_TargetMerge(t *testing.T) { tests := []struct { name string cfg func() CliConfig - wantTargets []InjectTarget + wantTargets []string }{ { name: "YAML targets only", @@ -348,8 +290,8 @@ func TestBuildSessionCorrelation_TargetMerge(t *testing.T) { _ = c.InjectSessionIDTargets.Set("domain=yaml.example.com path=/yaml/*") return c }, - wantTargets: []InjectTarget{ - {Domain: "yaml.example.com", Path: "/yaml/*"}, + wantTargets: []string{ + "domain=yaml.example.com path=/yaml/*", }, }, { @@ -360,8 +302,8 @@ func TestBuildSessionCorrelation_TargetMerge(t *testing.T) { _ = c.InjectSessionIDTarget.Set("domain=cli.example.com path=/cli/*") return c }, - wantTargets: []InjectTarget{ - {Domain: "cli.example.com", Path: "/cli/*"}, + wantTargets: []string{ + "domain=cli.example.com path=/cli/*", }, }, { @@ -373,9 +315,9 @@ func TestBuildSessionCorrelation_TargetMerge(t *testing.T) { _ = c.InjectSessionIDTarget.Set("domain=cli.example.com path=/cli/*") return c }, - wantTargets: []InjectTarget{ - {Domain: "yaml.example.com", Path: "/yaml/*"}, - {Domain: "cli.example.com", Path: "/cli/*"}, + wantTargets: []string{ + "domain=yaml.example.com path=/yaml/*", + "domain=cli.example.com path=/cli/*", }, }, { @@ -389,11 +331,11 @@ func TestBuildSessionCorrelation_TargetMerge(t *testing.T) { _ = c.InjectSessionIDTarget.Set("domain=cli2.example.com") return c }, - wantTargets: []InjectTarget{ - {Domain: "yaml1.example.com"}, - {Domain: "yaml2.example.com"}, - {Domain: "cli1.example.com"}, - {Domain: "cli2.example.com"}, + wantTargets: []string{ + "domain=yaml1.example.com", + "domain=yaml2.example.com", + "domain=cli1.example.com", + "domain=cli2.example.com", }, }, } @@ -411,13 +353,9 @@ func TestBuildSessionCorrelation_TargetMerge(t *testing.T) { len(sc.InjectTargets), len(tc.wantTargets)) } for i := range sc.InjectTargets { - if sc.InjectTargets[i].Domain != tc.wantTargets[i].Domain { - t.Errorf("InjectTargets[%d].Domain: got %q, want %q", - i, sc.InjectTargets[i].Domain, tc.wantTargets[i].Domain) - } - if sc.InjectTargets[i].Path != tc.wantTargets[i].Path { - t.Errorf("InjectTargets[%d].Path: got %q, want %q", - i, sc.InjectTargets[i].Path, tc.wantTargets[i].Path) + if sc.InjectTargets[i] != tc.wantTargets[i] { + t.Errorf("InjectTargets[%d]: got %q, want %q", + i, sc.InjectTargets[i], tc.wantTargets[i]) } } }) @@ -431,23 +369,23 @@ func TestBuildSessionCorrelation_AgentURLFallback(t *testing.T) { name string cfg func() CliConfig environ []string - wantTargets []InjectTarget + wantTargets []string wantErr bool }{ { - name: "enabled, no explicit targets, CODER_AGENT_URL set → auto-derived", + name: "enabled, no explicit targets, CODER_AGENT_URL set -> auto-derived", cfg: func() CliConfig { c := baseCliConfig() _ = c.SessionCorrelationEnabled.Set("true") return c }, environ: []string{"CODER_AGENT_URL=https://dev.coder.com/"}, - wantTargets: []InjectTarget{ - {Domain: "dev.coder.com", Path: DefaultAIBridgePath}, + wantTargets: []string{ + "domain=dev.coder.com path=" + DefaultAIBridgePath, }, }, { - name: "enabled, no explicit targets, CODER_AGENT_URL absent → error", + name: "enabled, no explicit targets, CODER_AGENT_URL absent -> error", cfg: func() CliConfig { c := baseCliConfig() _ = c.SessionCorrelationEnabled.Set("true") @@ -465,12 +403,12 @@ func TestBuildSessionCorrelation_AgentURLFallback(t *testing.T) { return c }, environ: []string{"CODER_AGENT_URL=https://dev.coder.com/"}, - wantTargets: []InjectTarget{ - {Domain: "custom.example.com", Path: ""}, + wantTargets: []string{ + "domain=custom.example.com", }, }, { - name: "disabled, CODER_AGENT_URL absent → valid (no targets needed)", + name: "disabled, CODER_AGENT_URL absent -> valid (no targets needed)", cfg: func() CliConfig { return baseCliConfig() }, @@ -498,13 +436,9 @@ func TestBuildSessionCorrelation_AgentURLFallback(t *testing.T) { len(sc.InjectTargets), len(tc.wantTargets)) } for i := range sc.InjectTargets { - if sc.InjectTargets[i].Domain != tc.wantTargets[i].Domain { - t.Errorf("InjectTargets[%d].Domain: got %q, want %q", - i, sc.InjectTargets[i].Domain, tc.wantTargets[i].Domain) - } - if sc.InjectTargets[i].Path != tc.wantTargets[i].Path { - t.Errorf("InjectTargets[%d].Path: got %q, want %q", - i, sc.InjectTargets[i].Path, tc.wantTargets[i].Path) + if sc.InjectTargets[i] != tc.wantTargets[i] { + t.Errorf("InjectTargets[%d]: got %q, want %q", + i, sc.InjectTargets[i], tc.wantTargets[i]) } } }) diff --git a/proxy/proxy.go b/proxy/proxy.go index cc99f34d..b7e03633 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -13,7 +13,6 @@ import ( "net/http" _ "net/http/pprof" "net/url" - "path" "strconv" "strings" "sync/atomic" @@ -32,10 +31,10 @@ type Server struct { tlsConfig *tls.Config httpPort int started atomic.Bool - sessionCorrelation config.SessionCorrelationConfig - sessionID string - seqCounter audit.SequenceCounter - forwardTransport http.RoundTripper // nil means use http.DefaultTransport + injectEngine *rulesengine.Engine // nil when session correlation is disabled + sessionID string + seqCounter audit.SequenceCounter + forwardTransport http.RoundTripper // nil means use http.DefaultTransport listener net.Listener pprofServer *http.Server @@ -55,6 +54,10 @@ type Config struct { // SessionCorrelation controls header injection for AI Bridge // correlation. See config.SessionCorrelationConfig for details. SessionCorrelation config.SessionCorrelationConfig + // InjectEngine, if non-nil, is used to evaluate whether outgoing + // requests match configured inject targets. Built from + // SessionCorrelation.InjectTargets using rulesengine.ParseAllowSpecs. + InjectEngine *rulesengine.Engine // SessionID is the boundary session UUID injected as a header // on matching requests. SessionID string @@ -74,8 +77,8 @@ func NewProxyServer(config Config) *Server { httpPort: config.HTTPPort, pprofEnabled: config.PprofEnabled, pprofPort: config.PprofPort, - sessionCorrelation: config.SessionCorrelation, - sessionID: config.SessionID, + injectEngine: config.InjectEngine, + sessionID: config.SessionID, forwardTransport: config.ForwardTransport, } } @@ -315,33 +318,18 @@ func (p *Server) processHTTPRequest(conn net.Conn, req *http.Request, https bool p.forwardRequest(conn, req, https, seqNum) } -// shouldInjectHeaders reports whether the request to the given host -// and path matches any configured inject target. When session -// correlation is disabled or no targets match it returns false. -func (p *Server) shouldInjectHeaders(host, reqPath string) bool { - if !p.sessionCorrelation.Enabled { +// shouldInjectHeaders reports whether the request URL matches any +// configured inject target. Inject targets are evaluated using the same +// rulesengine matching as --allow rules so that domain/path semantics +// are unified. When session correlation is disabled or no targets match +// it returns false. +func (p *Server) shouldInjectHeaders(fullURL string) bool { + if p.injectEngine == nil { return false } - h := host - if stripped, _, err := net.SplitHostPort(h); err == nil { - h = stripped - } - for _, target := range p.sessionCorrelation.InjectTargets { - targetHost := target.Domain - if stripped, _, err := net.SplitHostPort(targetHost); err == nil { - targetHost = stripped - } - if !strings.EqualFold(targetHost, h) { - continue - } - if target.Path == "" { - return true - } - if matched, _ := path.Match(target.Path, reqPath); matched { - return true - } - } - return false + // Inject targets do not restrict by HTTP method, so we pass an empty + // method. Rules without a method pattern match all methods. + return p.injectEngine.Evaluate("", fullURL).Allowed } func (p *Server) forwardRequest(conn net.Conn, req *http.Request, https bool, seqNum int32) { @@ -390,7 +378,7 @@ func (p *Server) forwardRequest(conn net.Conn, req *http.Request, https bool, se } } - if p.shouldInjectHeaders(req.Host, req.URL.Path) { + if p.shouldInjectHeaders(targetURL.String()) { newReq.Header.Set(config.SessionIDHeaderName, p.sessionID) newReq.Header.Set(config.SequenceNumberHeaderName, strconv.Itoa(int(seqNum))) } diff --git a/proxy/proxy_framework_test.go b/proxy/proxy_framework_test.go index dd94726f..87edf265 100644 --- a/proxy/proxy_framework_test.go +++ b/proxy/proxy_framework_test.go @@ -180,6 +180,15 @@ func (pt *ProxyTest) Start() *ProxyTest { } } + // Build inject engine from session correlation targets. + var injectEngine *rulesengine.Engine + if pt.sessionCorrelation.Enabled && len(pt.sessionCorrelation.InjectTargets) > 0 { + injectRules, err := rulesengine.ParseAllowSpecs(pt.sessionCorrelation.InjectTargets) + require.NoError(pt.t, err, "Failed to parse inject target rules") + eng := rulesengine.NewRuleEngine(injectRules, logger) + injectEngine = &eng + } + pt.server = NewProxyServer(Config{ HTTPPort: pt.port, RuleEngine: ruleEngine, @@ -187,6 +196,7 @@ func (pt *ProxyTest) Start() *ProxyTest { Logger: logger, TLSConfig: tlsConfig, SessionCorrelation: pt.sessionCorrelation, + InjectEngine: injectEngine, SessionID: pt.sessionID, ForwardTransport: pt.forwardTransport, }) diff --git a/proxy/proxy_session_correlation_test.go b/proxy/proxy_session_correlation_test.go index 631774b0..45f7ac66 100644 --- a/proxy/proxy_session_correlation_test.go +++ b/proxy/proxy_session_correlation_test.go @@ -52,7 +52,7 @@ func TestSessionCorrelation_MatchedDomain(t *testing.T) { WithAllowedDomain(backendURL.Hostname()), WithSessionCorrelation(config.SessionCorrelationConfig{ Enabled: true, - InjectTargets: []config.InjectTarget{{Domain: backendURL.Hostname()}}, + InjectTargets: []string{"domain=" + backendURL.Hostname()}, }), WithSessionID("test-session-id-1234"), ).Start() @@ -84,14 +84,14 @@ func TestSessionCorrelation_MatchedDomainWithPort(t *testing.T) { require.Equal(t, backendURL.Hostname()+":"+backendURL.Port(), backendURL.Host, "backendURL.Host must be host:port") - // Configure the inject target with host:port instead of just the hostname. - // The port must be stripped during matching so the request still matches. + // Configure the inject target using just the hostname (without port). + // The rulesengine strips ports during matching, same as for --allow rules. pt := NewProxyTest(t, WithCertManager(t.TempDir()), WithAllowedDomain(backendURL.Hostname()), WithSessionCorrelation(config.SessionCorrelationConfig{ Enabled: true, - InjectTargets: []config.InjectTarget{{Domain: backendURL.Host}}, + InjectTargets: []string{"domain=" + backendURL.Hostname()}, }), WithSessionID("test-session-id-1234"), ).Start() @@ -104,7 +104,7 @@ func TestSessionCorrelation_MatchedDomainWithPort(t *testing.T) { got := backend.receivedHeaders() assert.Equal(t, "test-session-id-1234", got.Get(config.SessionIDHeaderName), - "session ID header must be injected when inject target domain includes a port") + "session ID header must be injected when inject target domain matches request hostname") assert.Equal(t, "0", got.Get(config.SequenceNumberHeaderName), "sequence number header must start at 0") } @@ -121,7 +121,7 @@ func TestSessionCorrelation_UnmatchedDomain(t *testing.T) { WithAllowedDomain(backendURL.Hostname()), WithSessionCorrelation(config.SessionCorrelationConfig{ Enabled: true, - InjectTargets: []config.InjectTarget{{Domain: "other-domain.example.com"}}, + InjectTargets: []string{"domain=other-domain.example.com"}, }), WithSessionID("test-session-id-1234"), ).Start() @@ -151,7 +151,7 @@ func TestSessionCorrelation_Disabled(t *testing.T) { WithAllowedDomain(backendURL.Hostname()), WithSessionCorrelation(config.SessionCorrelationConfig{ Enabled: false, - InjectTargets: []config.InjectTarget{{Domain: backendURL.Hostname()}}, + InjectTargets: []string{"domain=" + backendURL.Hostname()}, }), WithSessionID("test-session-id-1234"), ).Start() @@ -181,7 +181,7 @@ func TestSessionCorrelation_OverwritesClientValue(t *testing.T) { WithAllowedDomain(backendURL.Hostname()), WithSessionCorrelation(config.SessionCorrelationConfig{ Enabled: true, - InjectTargets: []config.InjectTarget{{Domain: backendURL.Hostname()}}, + InjectTargets: []string{"domain=" + backendURL.Hostname()}, }), WithSessionID("real-session-id"), ).Start() @@ -217,11 +217,8 @@ func TestSessionCorrelation_PathMatching(t *testing.T) { WithCertManager(t.TempDir()), WithAllowedDomain(backendURL.Hostname()), WithSessionCorrelation(config.SessionCorrelationConfig{ - Enabled: true, - InjectTargets: []config.InjectTarget{{ - Domain: backendURL.Hostname(), - Path: "/api/*", - }}, + Enabled: true, + InjectTargets: []string{"domain=" + backendURL.Hostname() + " path=/api/*"}, }), WithSessionID("test-session-id"), ).Start() @@ -262,7 +259,7 @@ func TestSessionCorrelation_SequenceNumberIncrements(t *testing.T) { WithAllowedDomain(backendURL.Hostname()), WithSessionCorrelation(config.SessionCorrelationConfig{ Enabled: true, - InjectTargets: []config.InjectTarget{{Domain: backendURL.Hostname()}}, + InjectTargets: []string{"domain=" + backendURL.Hostname()}, }), WithSessionID("test-session-id"), ).Start() From 5d714115dddacb2696ce49751ac00b2e80820db4 Mon Sep 17 00:00:00 2001 From: Sas Swart Date: Tue, 12 May 2026 11:27:58 +0000 Subject: [PATCH 09/10] fix(proxy): remove duplicate forwardTransport comment The behavior is already documented on the Config.ForwardTransport field. Addresses review feedback from johnstcn. --- proxy/proxy.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/proxy/proxy.go b/proxy/proxy.go index b7e03633..bb95f523 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -34,7 +34,7 @@ type Server struct { injectEngine *rulesengine.Engine // nil when session correlation is disabled sessionID string seqCounter audit.SequenceCounter - forwardTransport http.RoundTripper // nil means use http.DefaultTransport + forwardTransport http.RoundTripper listener net.Listener pprofServer *http.Server From cd7953044a367e40f609767687adcce2c08c10bf Mon Sep 17 00:00:00 2001 From: Sas Swart Date: Wed, 13 May 2026 08:28:03 +0000 Subject: [PATCH 10/10] make fmt --- proxy/proxy.go | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/proxy/proxy.go b/proxy/proxy.go index bb95f523..229cce0d 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -25,12 +25,12 @@ import ( // Server handles HTTP and HTTPS requests with rule-based filtering type Server struct { - ruleEngine rulesengine.Engine - auditor audit.Auditor - logger *slog.Logger - tlsConfig *tls.Config - httpPort int - started atomic.Bool + ruleEngine rulesengine.Engine + auditor audit.Auditor + logger *slog.Logger + tlsConfig *tls.Config + httpPort int + started atomic.Bool injectEngine *rulesengine.Engine // nil when session correlation is disabled sessionID string seqCounter audit.SequenceCounter @@ -70,16 +70,16 @@ type Config struct { // NewProxyServer creates a new proxy server instance func NewProxyServer(config Config) *Server { return &Server{ - ruleEngine: config.RuleEngine, - auditor: config.Auditor, - logger: config.Logger, - tlsConfig: config.TLSConfig, - httpPort: config.HTTPPort, - pprofEnabled: config.PprofEnabled, - pprofPort: config.PprofPort, + ruleEngine: config.RuleEngine, + auditor: config.Auditor, + logger: config.Logger, + tlsConfig: config.TLSConfig, + httpPort: config.HTTPPort, + pprofEnabled: config.PprofEnabled, + pprofPort: config.PprofPort, injectEngine: config.InjectEngine, sessionID: config.SessionID, - forwardTransport: config.ForwardTransport, + forwardTransport: config.ForwardTransport, } }