diff --git a/audit/multi_auditor.go b/audit/multi_auditor.go index 91607dc..bd0789b 100644 --- a/audit/multi_auditor.go +++ b/audit/multi_auditor.go @@ -50,8 +50,7 @@ func SetupAuditor(ctx context.Context, logger *slog.Logger, disableAuditLogs boo } agentWillProxy := !os.IsNotExist(err) if agentWillProxy { - seq := &SequenceCounter{} - socketAuditor := NewSocketAuditor(logger, logProxySocketPath, sessionID, seq) + socketAuditor := NewSocketAuditor(logger, logProxySocketPath, sessionID) go socketAuditor.Loop(ctx) auditors = append(auditors, socketAuditor) } else { diff --git a/audit/request.go b/audit/request.go index c6ef1b3..d63f2a9 100644 --- a/audit/request.go +++ b/audit/request.go @@ -11,4 +11,9 @@ type Request struct { Host string Allowed bool Rule string // The rule that matched (if any) + + // SequenceNumber is the sequence number assigned to this audit event + // by the proxy. It is monotonically increasing within a session and + // is shared with any injected HTTP header so both carry the same value. + SequenceNumber int32 } diff --git a/audit/socket_auditor.go b/audit/socket_auditor.go index afb50a3..5ab7f82 100644 --- a/audit/socket_auditor.go +++ b/audit/socket_auditor.go @@ -36,7 +36,6 @@ type SocketAuditor struct { batchTimerDuration time.Duration socketPath string sessionID uuid.UUID - seq *SequenceCounter droppedChannelFull atomic.Int64 droppedBatchFull atomic.Int64 @@ -48,7 +47,7 @@ type SocketAuditor struct { // NewSocketAuditor creates a new SocketAuditor that sends logs to the agent's // boundary log proxy socket after SocketAuditor.Loop is called. The socket path // is read from EnvAuditSocketPath, falling back to defaultAuditSocketPath. -func NewSocketAuditor(logger *slog.Logger, socketPath string, sessionID uuid.UUID, seq *SequenceCounter) *SocketAuditor { +func NewSocketAuditor(logger *slog.Logger, socketPath string, sessionID uuid.UUID) *SocketAuditor { // This channel buffer size intends to allow enough buffering for bursty // AI agent network requests while a batch is being sent to the workspace // agent. @@ -64,7 +63,6 @@ func NewSocketAuditor(logger *slog.Logger, socketPath string, sessionID uuid.UUI batchTimerDuration: defaultBatchTimerDuration, socketPath: socketPath, sessionID: sessionID, - seq: seq, } } @@ -84,7 +82,7 @@ func (s *SocketAuditor) AuditRequest(req Request) { log := &agentproto.BoundaryLog{ Allowed: req.Allowed, Time: timestamppb.Now(), - SequenceNumber: s.seq.Next(), + SequenceNumber: req.SequenceNumber, Resource: &agentproto.BoundaryLog_HttpRequest_{HttpRequest: httpReq}, } diff --git a/audit/socket_auditor_test.go b/audit/socket_auditor_test.go index 6443781..09e359a 100644 --- a/audit/socket_auditor_test.go +++ b/audit/socket_auditor_test.go @@ -242,7 +242,6 @@ func TestSocketAuditor_Loop_RetriesOnConnectionFailure(t *testing.T) { batchSize: defaultBatchSize, batchTimerDuration: time.Hour, // Ensure timer doesn't interfere with the test sessionID: uuid.MustParse("00000000-0000-4000-8000-000000000001"), - seq: &SequenceCounter{}, } // Set up hook to detect flush attempts @@ -363,7 +362,6 @@ func TestSocketAuditor_Loop_ReportsBatchFullDrops(t *testing.T) { batchSize: defaultBatchSize, batchTimerDuration: time.Hour, sessionID: uuid.MustParse("00000000-0000-4000-8000-000000000001"), - seq: &SequenceCounter{}, } flushed := make(chan struct{}, 4) @@ -483,7 +481,7 @@ func TestSocketAuditor_AuditRequest_SequenceNumberIncrements(t *testing.T) { auditor := setupSocketAuditor(t) for i := range 5 { - auditor.AuditRequest(Request{Method: "GET", URL: "https://example.com", Allowed: true}) + auditor.AuditRequest(Request{Method: "GET", URL: "https://example.com", Allowed: true, SequenceNumber: int32(i)}) select { case log := <-auditor.logCh: @@ -509,7 +507,7 @@ func TestSocketAuditor_Loop_FlushIncludesSessionID(t *testing.T) { // Fill a batch to trigger a flush. for i := 0; i < auditor.batchSize; i++ { - auditor.AuditRequest(Request{Method: "GET", URL: "https://example.com", Allowed: true}) + auditor.AuditRequest(Request{Method: "GET", URL: "https://example.com", Allowed: true, SequenceNumber: int32(i)}) } select { @@ -549,7 +547,6 @@ func setupSocketAuditor(t *testing.T) *SocketAuditor { batchSize: defaultBatchSize, batchTimerDuration: defaultBatchTimerDuration, sessionID: uuid.MustParse("00000000-0000-4000-8000-000000000001"), - seq: &SequenceCounter{}, } } @@ -580,7 +577,6 @@ func setupTestAuditor(t *testing.T) (*SocketAuditor, net.Conn) { batchSize: defaultBatchSize, batchTimerDuration: defaultBatchTimerDuration, sessionID: uuid.MustParse("00000000-0000-4000-8000-000000000001"), - seq: &SequenceCounter{}, } return auditor, serverConn diff --git a/cli/cli.go b/cli/cli.go index bed3e6e..f43eeca 100644 --- a/cli/cli.go +++ b/cli/cli.go @@ -195,7 +195,7 @@ func BaseCommand(version string) *serpent.Command { { Flag: "session-id-inject-target", Env: "BOUNDARY_SESSION_ID_INJECT_TARGET", - Description: `Inject target for session correlation headers. Repeat the flag once per target; each value describes exactly one target. Format: "domain= [path=]". Example: --session-id-inject-target "domain=prod.coder.com path=/api/v2/aibridge/*"`, + Description: `Inject target for session correlation headers. Repeat the flag once per target; each value describes exactly one target. Format: "domain= [path=]". Example: --session-id-inject-target "domain=prod.coder.com path=/api/v2/aibridge/*".`, Value: &cliConfig.InjectSessionIDTarget, YAML: "", // CLI only, YAML uses session_id_inject_targets. }, diff --git a/config/config.go b/config/config.go index b0100fa..9de181b 100644 --- a/config/config.go +++ b/config/config.go @@ -141,26 +141,19 @@ func NewAppConfigFromCliConfig(cfg CliConfig, targetCMD []string, environ []stri }, nil } -// buildSessionCorrelation merges CLI and YAML inject target sources, -// parses each target string, and validates the resulting configuration. -// environ is passed explicitly (rather than reading os.Environ inside) -// so that callers and tests can supply a controlled environment. +// buildSessionCorrelation merges CLI and YAML inject target sources +// and validates the resulting configuration. Inject targets use the same +// "domain=... path=..." syntax as --allow rules so that matching semantics +// are identical. environ is passed explicitly (rather than reading +// os.Environ inside) so that callers and tests can supply a controlled +// environment. func buildSessionCorrelation(cfg CliConfig, environ []string) (SessionCorrelationConfig, error) { // Merge YAML targets with CLI targets. - rawTargets := append(cfg.InjectSessionIDTargets.Value(), cfg.InjectSessionIDTarget.Value()...) - - var targets []InjectTarget - for _, raw := range rawTargets { - t, err := ParseInjectTarget(raw) - if err != nil { - return SessionCorrelationConfig{}, err - } - targets = append(targets, t) - } + targets := append(cfg.InjectSessionIDTargets.Value(), cfg.InjectSessionIDTarget.Value()...) if len(targets) == 0 && cfg.SessionCorrelationEnabled.Value() { - if t := DefaultInjectTargetFromEnv(environ); t != nil { - targets = []InjectTarget{*t} + if t := DefaultInjectTargetFromEnv(environ); t != "" { + targets = []string{t} } } diff --git a/config/session_correlation.go b/config/session_correlation.go index 5dae21c..67a7b27 100644 --- a/config/session_correlation.go +++ b/config/session_correlation.go @@ -4,6 +4,8 @@ import ( "fmt" "net/url" "strings" + + "github.com/coder/boundary/rulesengine" ) // Header names and paths for session correlation. @@ -27,14 +29,6 @@ const ( CoderAgentURLEnv = "CODER_AGENT_URL" ) -// InjectTarget represents a parsed target for session correlation header -// injection. Requests matching the domain (and optional path glob) will -// receive the session ID and sequence number headers. -type InjectTarget struct { - Domain string - Path string -} - // SessionCorrelationConfig holds configuration for session correlation // header injection. When enabled, boundary injects its session ID and // sequence number as custom headers on matching outbound requests so @@ -45,63 +39,22 @@ type SessionCorrelationConfig struct { // Deployments without AI Bridge in front should set this to false. Enabled bool - // InjectTargets is the list of domain/path patterns that should - // receive session correlation headers. - InjectTargets []InjectTarget -} - -// ParseInjectTarget parses a string of the form "domain=... path=..." -// into an InjectTarget. The domain key is required; path is optional. -func ParseInjectTarget(raw string) (InjectTarget, error) { - raw = strings.TrimSpace(raw) - if raw == "" { - return InjectTarget{}, fmt.Errorf("inject target must not be empty") - } - - var target InjectTarget - seen := make(map[string]bool) - for _, part := range strings.Fields(raw) { - key, value, ok := strings.Cut(part, "=") - if !ok { - return InjectTarget{}, fmt.Errorf( - "inject target: malformed key-value pair %q, expected key=value", part, - ) - } - if seen[key] { - return InjectTarget{}, fmt.Errorf( - "inject target: duplicate key %q (use separate flags for multiple targets)", key, - ) - } - seen[key] = true - switch key { - case "domain": - if value == "" { - return InjectTarget{}, fmt.Errorf("inject target: domain must not be empty") - } - target.Domain = value - case "path": - target.Path = value - default: - return InjectTarget{}, fmt.Errorf("inject target: unknown key %q", key) - } - } - - if target.Domain == "" { - return InjectTarget{}, fmt.Errorf("inject target: domain is required") - } - - return target, nil + // InjectTargets is the list of raw rule specs (same syntax as --allow) + // that should receive session correlation headers. Each string uses the + // rulesengine "domain=... path=..." format so that inject target + // matching is identical to allow-rule matching. + InjectTargets []string } -// DefaultInjectTargetFromEnv derives an InjectTarget from the CODER_AGENT_URL -// variable in the provided environment slice. It returns nil if the variable is -// absent, empty, or not a valid URL with a host. The derived target uses -// DefaultAIBridgePath as the path glob so that all AI Bridge traffic on the -// control-plane host is matched. +// DefaultInjectTargetFromEnv derives an inject target rule string from the +// CODER_AGENT_URL variable in the provided environment slice. It returns "" +// if the variable is absent, empty, or not a valid URL with a host. The +// derived target uses DefaultAIBridgePath as the path glob so that all AI +// Bridge traffic on the control-plane host is matched. // // The environ parameter is accepted rather than reading os.Environ directly so // that callers (and tests) can supply an arbitrary environment. -func DefaultInjectTargetFromEnv(environ []string) *InjectTarget { +func DefaultInjectTargetFromEnv(environ []string) string { var raw string for _, e := range environ { k, v, ok := strings.Cut(e, "=") @@ -111,23 +64,22 @@ func DefaultInjectTargetFromEnv(environ []string) *InjectTarget { } } if raw == "" { - return nil + return "" } u, err := url.Parse(raw) if err != nil || u.Host == "" { - return nil + return "" } - return &InjectTarget{ - Domain: u.Hostname(), - Path: DefaultAIBridgePath, - } + return fmt.Sprintf("domain=%s path=%s", u.Hostname(), DefaultAIBridgePath) } // ValidateSessionCorrelation checks that the session correlation config -// is internally consistent. It returns an error describing the first -// problem found, or nil if the config is valid. +// is internally consistent. When enabled it verifies that at least one +// inject target is configured and that every target string is a valid +// rulesengine rule. It returns an error describing the first problem +// found, or nil if the config is valid. func ValidateSessionCorrelation(cfg SessionCorrelationConfig) error { if !cfg.Enabled { return nil @@ -139,5 +91,26 @@ func ValidateSessionCorrelation(cfg SessionCorrelationConfig) error { ) } + // Reject empty target strings before passing to the parser. + for _, t := range cfg.InjectTargets { + if strings.TrimSpace(t) == "" { + return fmt.Errorf("inject target: must not be empty") + } + } + + // Validate each target parses as a rulesengine rule. + rules, err := rulesengine.ParseAllowSpecs(cfg.InjectTargets) + if err != nil { + return fmt.Errorf("inject target: %w", err) + } + + // Inject targets must specify a domain; path-only rules are not + // meaningful for header injection. + for i, r := range rules { + if r.HostPattern == nil { + return fmt.Errorf("inject target %q: domain is required", cfg.InjectTargets[i]) + } + } + return nil } diff --git a/config/session_correlation_test.go b/config/session_correlation_test.go index 4093ee1..588f9e6 100644 --- a/config/session_correlation_test.go +++ b/config/session_correlation_test.go @@ -4,77 +4,48 @@ import ( "testing" ) -func TestParseInjectTarget(t *testing.T) { +func TestParseInjectTarget_ViaValidation(t *testing.T) { t.Parallel() tests := []struct { name string input string - want InjectTarget wantErr bool }{ { name: "domain only", input: "domain=dev.coder.com", - want: InjectTarget{Domain: "dev.coder.com"}, }, { name: "domain and path", input: "domain=dev.coder.com path=/api/v2/aibridge/*", - want: InjectTarget{Domain: "dev.coder.com", Path: "/api/v2/aibridge/*"}, - }, - { - name: "leading and trailing whitespace", - input: " domain=dev.coder.com path=/api/* ", - want: InjectTarget{Domain: "dev.coder.com", Path: "/api/*"}, }, { name: "empty string", input: "", wantErr: true, }, - { - name: "whitespace only", - input: " ", - wantErr: true, - }, { name: "missing domain", input: "path=/api/*", wantErr: true, }, - { - name: "empty domain value", - input: "domain=", - wantErr: true, - }, - { - name: "malformed pair no equals", - input: "domain", - wantErr: true, - }, { name: "unknown key", input: "domain=example.com port=443", wantErr: true, }, - { - name: "duplicate domain key", - input: "domain=staging.coder.com domain=prod.coder.com", - wantErr: true, - }, - { - name: "duplicate path key", - input: "domain=example.com path=/api/* path=/other/*", - wantErr: true, - }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { t.Parallel() - got, err := ParseInjectTarget(tc.input) + cfg := SessionCorrelationConfig{ + Enabled: true, + InjectTargets: []string{tc.input}, + } + err := ValidateSessionCorrelation(cfg) if tc.wantErr { if err == nil { t.Fatalf("expected error, got nil") @@ -84,12 +55,6 @@ func TestParseInjectTarget(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - if got.Domain != tc.want.Domain { - t.Errorf("Domain: got %q, want %q", got.Domain, tc.want.Domain) - } - if got.Path != tc.want.Path { - t.Errorf("Path: got %q, want %q", got.Path, tc.want.Path) - } }) } } @@ -119,7 +84,7 @@ func TestValidateSessionCorrelation(t *testing.T) { name: "enabled with targets", cfg: SessionCorrelationConfig{ Enabled: true, - InjectTargets: []InjectTarget{{Domain: "dev.coder.com"}}, + InjectTargets: []string{"domain=dev.coder.com"}, }, }, { @@ -134,7 +99,7 @@ func TestValidateSessionCorrelation(t *testing.T) { name: "enabled with empty targets slice", cfg: SessionCorrelationConfig{ Enabled: true, - InjectTargets: []InjectTarget{}, + InjectTargets: []string{}, }, wantErr: true, }, @@ -162,19 +127,18 @@ func TestNewAppConfigFromCliConfig_SessionCorrelation(t *testing.T) { t.Parallel() tests := []struct { - name string - cli CliConfig - environ []string - want SessionCorrelationConfig - wantErr bool + name string + cli CliConfig + environ []string + wantEnabled bool + wantTargets []string + wantErr bool }{ { - name: "defaults when not configured", - cli: baseCliConfig(), - want: SessionCorrelationConfig{ - Enabled: false, - InjectTargets: nil, - }, + name: "defaults when not configured", + cli: baseCliConfig(), + wantEnabled: false, + wantTargets: nil, }, { name: "enabled with inject targets", @@ -184,28 +148,22 @@ func TestNewAppConfigFromCliConfig_SessionCorrelation(t *testing.T) { _ = c.InjectSessionIDTarget.Set("domain=dev.coder.com path=/api/v2/aibridge/*") return c }(), - want: SessionCorrelationConfig{ - Enabled: true, - InjectTargets: []InjectTarget{ - {Domain: "dev.coder.com", Path: "/api/v2/aibridge/*"}, - }, - }, + wantEnabled: true, + wantTargets: []string{"domain=dev.coder.com path=/api/v2/aibridge/*"}, }, { - name: "enabled with no targets, CODER_AGENT_URL set → auto-derived", + name: "enabled with no targets, CODER_AGENT_URL set -> auto-derived", cli: func() CliConfig { c := baseCliConfig() _ = c.SessionCorrelationEnabled.Set("true") return c }(), - environ: []string{"CODER_AGENT_URL=https://dev.coder.com/"}, - want: SessionCorrelationConfig{ - Enabled: true, - InjectTargets: []InjectTarget{{Domain: "dev.coder.com", Path: DefaultAIBridgePath}}, - }, + environ: []string{"CODER_AGENT_URL=https://dev.coder.com/"}, + wantEnabled: true, + wantTargets: []string{"domain=dev.coder.com path=" + DefaultAIBridgePath}, }, { - name: "enabled with no targets, CODER_AGENT_URL absent → error", + name: "enabled with no targets, CODER_AGENT_URL absent -> error", cli: func() CliConfig { c := baseCliConfig() _ = c.SessionCorrelationEnabled.Set("true") @@ -242,21 +200,17 @@ func TestNewAppConfigFromCliConfig_SessionCorrelation(t *testing.T) { } sc := got.SessionCorrelation - if sc.Enabled != tc.want.Enabled { - t.Errorf("Enabled: got %v, want %v", sc.Enabled, tc.want.Enabled) + if sc.Enabled != tc.wantEnabled { + t.Errorf("Enabled: got %v, want %v", sc.Enabled, tc.wantEnabled) } - if len(sc.InjectTargets) != len(tc.want.InjectTargets) { + if len(sc.InjectTargets) != len(tc.wantTargets) { t.Fatalf("InjectTargets len: got %d, want %d", - len(sc.InjectTargets), len(tc.want.InjectTargets)) + len(sc.InjectTargets), len(tc.wantTargets)) } for i := range sc.InjectTargets { - if sc.InjectTargets[i].Domain != tc.want.InjectTargets[i].Domain { - t.Errorf("InjectTargets[%d].Domain: got %q, want %q", - i, sc.InjectTargets[i].Domain, tc.want.InjectTargets[i].Domain) - } - if sc.InjectTargets[i].Path != tc.want.InjectTargets[i].Path { - t.Errorf("InjectTargets[%d].Path: got %q, want %q", - i, sc.InjectTargets[i].Path, tc.want.InjectTargets[i].Path) + if sc.InjectTargets[i] != tc.wantTargets[i] { + t.Errorf("InjectTargets[%d]: got %q, want %q", + i, sc.InjectTargets[i], tc.wantTargets[i]) } } }) @@ -269,42 +223,42 @@ func TestDefaultInjectTargetFromEnv(t *testing.T) { tests := []struct { name string environ []string - want *InjectTarget + want string }{ { name: "valid URL with trailing slash", environ: []string{"CODER_AGENT_URL=https://dev.coder.com/"}, - want: &InjectTarget{Domain: "dev.coder.com", Path: DefaultAIBridgePath}, + want: "domain=dev.coder.com path=" + DefaultAIBridgePath, }, { name: "valid URL without trailing slash", environ: []string{"CODER_AGENT_URL=https://dev.coder.com"}, - want: &InjectTarget{Domain: "dev.coder.com", Path: DefaultAIBridgePath}, + want: "domain=dev.coder.com path=" + DefaultAIBridgePath, }, { - name: "URL with port", // Ports are ignored in the rules engine, so we strip them here. + name: "URL with port", environ: []string{"CODER_AGENT_URL=https://dev.coder.com:8443/"}, - want: &InjectTarget{Domain: "dev.coder.com", Path: DefaultAIBridgePath}, + want: "domain=dev.coder.com path=" + DefaultAIBridgePath, }, { name: "unset variable", environ: []string{}, - want: nil, + want: "", }, { name: "empty value", environ: []string{"CODER_AGENT_URL="}, - want: nil, + want: "", }, { name: "no host in URL", environ: []string{"CODER_AGENT_URL=not-a-url"}, - want: nil, + want: "", }, { name: "other env vars present but not CODER_AGENT_URL", environ: []string{"CODER_URL=https://dev.coder.com/", "HOME=/home/user"}, - want: nil, + want: "", }, } @@ -313,20 +267,8 @@ func TestDefaultInjectTargetFromEnv(t *testing.T) { t.Parallel() got := DefaultInjectTargetFromEnv(tc.environ) - if tc.want == nil { - if got != nil { - t.Errorf("expected nil, got %+v", got) - } - return - } - if got == nil { - t.Fatalf("expected %+v, got nil", tc.want) - } - if got.Domain != tc.want.Domain { - t.Errorf("Domain: got %q, want %q", got.Domain, tc.want.Domain) - } - if got.Path != tc.want.Path { - t.Errorf("Path: got %q, want %q", got.Path, tc.want.Path) + if got != tc.want { + t.Errorf("got %q, want %q", got, tc.want) } }) } @@ -338,7 +280,7 @@ func TestBuildSessionCorrelation_TargetMerge(t *testing.T) { tests := []struct { name string cfg func() CliConfig - wantTargets []InjectTarget + wantTargets []string }{ { name: "YAML targets only", @@ -348,8 +290,8 @@ func TestBuildSessionCorrelation_TargetMerge(t *testing.T) { _ = c.InjectSessionIDTargets.Set("domain=yaml.example.com path=/yaml/*") return c }, - wantTargets: []InjectTarget{ - {Domain: "yaml.example.com", Path: "/yaml/*"}, + wantTargets: []string{ + "domain=yaml.example.com path=/yaml/*", }, }, { @@ -360,8 +302,8 @@ func TestBuildSessionCorrelation_TargetMerge(t *testing.T) { _ = c.InjectSessionIDTarget.Set("domain=cli.example.com path=/cli/*") return c }, - wantTargets: []InjectTarget{ - {Domain: "cli.example.com", Path: "/cli/*"}, + wantTargets: []string{ + "domain=cli.example.com path=/cli/*", }, }, { @@ -373,9 +315,9 @@ func TestBuildSessionCorrelation_TargetMerge(t *testing.T) { _ = c.InjectSessionIDTarget.Set("domain=cli.example.com path=/cli/*") return c }, - wantTargets: []InjectTarget{ - {Domain: "yaml.example.com", Path: "/yaml/*"}, - {Domain: "cli.example.com", Path: "/cli/*"}, + wantTargets: []string{ + "domain=yaml.example.com path=/yaml/*", + "domain=cli.example.com path=/cli/*", }, }, { @@ -389,11 +331,11 @@ func TestBuildSessionCorrelation_TargetMerge(t *testing.T) { _ = c.InjectSessionIDTarget.Set("domain=cli2.example.com") return c }, - wantTargets: []InjectTarget{ - {Domain: "yaml1.example.com"}, - {Domain: "yaml2.example.com"}, - {Domain: "cli1.example.com"}, - {Domain: "cli2.example.com"}, + wantTargets: []string{ + "domain=yaml1.example.com", + "domain=yaml2.example.com", + "domain=cli1.example.com", + "domain=cli2.example.com", }, }, } @@ -411,13 +353,9 @@ func TestBuildSessionCorrelation_TargetMerge(t *testing.T) { len(sc.InjectTargets), len(tc.wantTargets)) } for i := range sc.InjectTargets { - if sc.InjectTargets[i].Domain != tc.wantTargets[i].Domain { - t.Errorf("InjectTargets[%d].Domain: got %q, want %q", - i, sc.InjectTargets[i].Domain, tc.wantTargets[i].Domain) - } - if sc.InjectTargets[i].Path != tc.wantTargets[i].Path { - t.Errorf("InjectTargets[%d].Path: got %q, want %q", - i, sc.InjectTargets[i].Path, tc.wantTargets[i].Path) + if sc.InjectTargets[i] != tc.wantTargets[i] { + t.Errorf("InjectTargets[%d]: got %q, want %q", + i, sc.InjectTargets[i], tc.wantTargets[i]) } } }) @@ -431,23 +369,23 @@ func TestBuildSessionCorrelation_AgentURLFallback(t *testing.T) { name string cfg func() CliConfig environ []string - wantTargets []InjectTarget + wantTargets []string wantErr bool }{ { - name: "enabled, no explicit targets, CODER_AGENT_URL set → auto-derived", + name: "enabled, no explicit targets, CODER_AGENT_URL set -> auto-derived", cfg: func() CliConfig { c := baseCliConfig() _ = c.SessionCorrelationEnabled.Set("true") return c }, environ: []string{"CODER_AGENT_URL=https://dev.coder.com/"}, - wantTargets: []InjectTarget{ - {Domain: "dev.coder.com", Path: DefaultAIBridgePath}, + wantTargets: []string{ + "domain=dev.coder.com path=" + DefaultAIBridgePath, }, }, { - name: "enabled, no explicit targets, CODER_AGENT_URL absent → error", + name: "enabled, no explicit targets, CODER_AGENT_URL absent -> error", cfg: func() CliConfig { c := baseCliConfig() _ = c.SessionCorrelationEnabled.Set("true") @@ -465,12 +403,12 @@ func TestBuildSessionCorrelation_AgentURLFallback(t *testing.T) { return c }, environ: []string{"CODER_AGENT_URL=https://dev.coder.com/"}, - wantTargets: []InjectTarget{ - {Domain: "custom.example.com", Path: ""}, + wantTargets: []string{ + "domain=custom.example.com", }, }, { - name: "disabled, CODER_AGENT_URL absent → valid (no targets needed)", + name: "disabled, CODER_AGENT_URL absent -> valid (no targets needed)", cfg: func() CliConfig { return baseCliConfig() }, @@ -498,13 +436,9 @@ func TestBuildSessionCorrelation_AgentURLFallback(t *testing.T) { len(sc.InjectTargets), len(tc.wantTargets)) } for i := range sc.InjectTargets { - if sc.InjectTargets[i].Domain != tc.wantTargets[i].Domain { - t.Errorf("InjectTargets[%d].Domain: got %q, want %q", - i, sc.InjectTargets[i].Domain, tc.wantTargets[i].Domain) - } - if sc.InjectTargets[i].Path != tc.wantTargets[i].Path { - t.Errorf("InjectTargets[%d].Path: got %q, want %q", - i, sc.InjectTargets[i].Path, tc.wantTargets[i].Path) + if sc.InjectTargets[i] != tc.wantTargets[i] { + t.Errorf("InjectTargets[%d]: got %q, want %q", + i, sc.InjectTargets[i], tc.wantTargets[i]) } } }) diff --git a/proxy/proxy.go b/proxy/proxy.go index 154a6bc..229cce0 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -19,17 +19,22 @@ import ( "time" "github.com/coder/boundary/audit" + "github.com/coder/boundary/config" "github.com/coder/boundary/rulesengine" ) // Server handles HTTP and HTTPS requests with rule-based filtering type Server struct { - ruleEngine rulesengine.Engine - auditor audit.Auditor - logger *slog.Logger - tlsConfig *tls.Config - httpPort int - started atomic.Bool + ruleEngine rulesengine.Engine + auditor audit.Auditor + logger *slog.Logger + tlsConfig *tls.Config + httpPort int + started atomic.Bool + injectEngine *rulesengine.Engine // nil when session correlation is disabled + sessionID string + seqCounter audit.SequenceCounter + forwardTransport http.RoundTripper listener net.Listener pprofServer *http.Server @@ -46,18 +51,35 @@ type Config struct { TLSConfig *tls.Config PprofEnabled bool PprofPort int + // SessionCorrelation controls header injection for AI Bridge + // correlation. See config.SessionCorrelationConfig for details. + SessionCorrelation config.SessionCorrelationConfig + // InjectEngine, if non-nil, is used to evaluate whether outgoing + // requests match configured inject targets. Built from + // SessionCorrelation.InjectTargets using rulesengine.ParseAllowSpecs. + InjectEngine *rulesengine.Engine + // SessionID is the boundary session UUID injected as a header + // on matching requests. + SessionID string + // ForwardTransport, if non-nil, is used when forwarding requests to + // backend servers. Defaults to http.DefaultTransport when nil. Set in + // tests to trust self-signed backend certificates. + ForwardTransport http.RoundTripper } // NewProxyServer creates a new proxy server instance func NewProxyServer(config Config) *Server { return &Server{ - ruleEngine: config.RuleEngine, - auditor: config.Auditor, - logger: config.Logger, - tlsConfig: config.TLSConfig, - httpPort: config.HTTPPort, - pprofEnabled: config.PprofEnabled, - pprofPort: config.PprofPort, + ruleEngine: config.RuleEngine, + auditor: config.Auditor, + logger: config.Logger, + tlsConfig: config.TLSConfig, + httpPort: config.HTTPPort, + pprofEnabled: config.PprofEnabled, + pprofPort: config.PprofPort, + injectEngine: config.InjectEngine, + sessionID: config.SessionID, + forwardTransport: config.ForwardTransport, } } @@ -276,12 +298,15 @@ func (p *Server) processHTTPRequest(conn net.Conn, req *http.Request, https bool result := p.ruleEngine.Evaluate(req.Method, fullURL) + seqNum := p.seqCounter.Next() + p.auditor.AuditRequest(audit.Request{ - Method: req.Method, - URL: fullURL, - Host: req.Host, - Allowed: result.Allowed, - Rule: result.Rule, + Method: req.Method, + URL: fullURL, + Host: req.Host, + Allowed: result.Allowed, + Rule: result.Rule, + SequenceNumber: seqNum, }) if !result.Allowed { @@ -290,15 +315,30 @@ func (p *Server) processHTTPRequest(conn net.Conn, req *http.Request, https bool } // Forward request to destination - p.forwardRequest(conn, req, https) + p.forwardRequest(conn, req, https, seqNum) +} + +// shouldInjectHeaders reports whether the request URL matches any +// configured inject target. Inject targets are evaluated using the same +// rulesengine matching as --allow rules so that domain/path semantics +// are unified. When session correlation is disabled or no targets match +// it returns false. +func (p *Server) shouldInjectHeaders(fullURL string) bool { + if p.injectEngine == nil { + return false + } + // Inject targets do not restrict by HTTP method, so we pass an empty + // method. Rules without a method pattern match all methods. + return p.injectEngine.Evaluate("", fullURL).Allowed } -func (p *Server) forwardRequest(conn net.Conn, req *http.Request, https bool) { +func (p *Server) forwardRequest(conn net.Conn, req *http.Request, https bool, seqNum int32) { // Create HTTP client client := &http.Client{ CheckRedirect: func(req *http.Request, via []*http.Request) error { return http.ErrUseLastResponse // Don't follow redirects }, + Transport: p.forwardTransport, // nil → http.DefaultTransport } scheme := "http" @@ -338,6 +378,11 @@ func (p *Server) forwardRequest(conn net.Conn, req *http.Request, https bool) { } } + if p.shouldInjectHeaders(targetURL.String()) { + newReq.Header.Set(config.SessionIDHeaderName, p.sessionID) + newReq.Header.Set(config.SequenceNumberHeaderName, strconv.Itoa(int(seqNum))) + } + // Make request to destination resp, err := client.Do(newReq) if err != nil { diff --git a/proxy/proxy_audit_test.go b/proxy/proxy_audit_test.go index bbb7126..0e591e2 100644 --- a/proxy/proxy_audit_test.go +++ b/proxy/proxy_audit_test.go @@ -1,6 +1,7 @@ package proxy import ( + "crypto/tls" "net" "net/http" "net/http/httptest" @@ -31,6 +32,76 @@ func (c *capturingAuditor) getRequests() []audit.Request { return append([]audit.Request{}, c.requests...) } +func TestSequenceNumberIncrementsAcrossRequestTypes(t *testing.T) { + // Plain HTTP backend — used by the plain HTTP request. + httpBackend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer httpBackend.Close() + + // TLS backend — used by both the implicit-CONNECT and explicit-CONNECT + // requests. The proxy needs InsecureSkipVerify to trust its self-signed cert. + tlsBackend := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer tlsBackend.Close() + + httpBackendURL, err := url.Parse(httpBackend.URL) + require.NoError(t, err) + tlsBackendURL, err := url.Parse(tlsBackend.URL) + require.NoError(t, err) + + // TLS SNI requires a hostname, not an IP address. httptest servers bind to + // 127.0.0.1, so rewrite the host to "localhost" for all TLS connections so + // that the proxy's cert manager receives a proper SNI value. + tlsHost := "localhost:" + tlsBackendURL.Port() + tlsURL := "https://" + tlsHost + + auditor := &capturingAuditor{} + + //nolint:gosec + insecureTransport := &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + } + + pt := NewProxyTest(t, + WithCertManager(t.TempDir()), + WithAllowedDomain(httpBackendURL.Hostname()), + WithAllowedDomain("localhost"), + WithAuditor(auditor), + WithForwardTransport(insecureTransport), + ).Start() + defer pt.Stop() + + // Request 1: plain HTTP — handleHTTPConnection → processHTTPRequest(https=false) + resp, err := pt.proxyClient.Get(httpBackend.URL + "/") + require.NoError(t, err) + resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusOK, resp.StatusCode) + + // Request 2: HTTPS via implicit CONNECT — Go's transport automatically sends + // CONNECT for https:// URLs → handleCONNECTTunnel → processHTTPRequest(https=true) + resp, err = pt.proxyClient.Get(tlsURL + "/") + require.NoError(t, err) + resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusOK, resp.StatusCode) + + // Request 3: inside an explicit CONNECT tunnel — handleCONNECTTunnel → + // processHTTPRequest(https=true), driven by a manually established tunnel. + tunnel, err := pt.establishExplicitCONNECT(tlsHost) + require.NoError(t, err) + defer tunnel.close() //nolint:errcheck + _, err = tunnel.sendRequest(tlsHost, "/") + require.NoError(t, err) + + requests := auditor.getRequests() + require.Len(t, requests, 3, "expected one audit record per request") + + assert.Equal(t, int32(0), requests[0].SequenceNumber, "HTTP request must have sequence number 0") + assert.Equal(t, int32(1), requests[1].SequenceNumber, "implicit-CONNECT request must have sequence number 1") + assert.Equal(t, int32(2), requests[2].SequenceNumber, "explicit-CONNECT tunnel request must have sequence number 2") +} + func TestAuditURLIsFullyFormed_HTTP(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) diff --git a/proxy/proxy_framework_test.go b/proxy/proxy_framework_test.go index 36a332b..87edf26 100644 --- a/proxy/proxy_framework_test.go +++ b/proxy/proxy_framework_test.go @@ -18,6 +18,7 @@ import ( "time" "github.com/coder/boundary/audit" + "github.com/coder/boundary/config" "github.com/coder/boundary/rulesengine" boundary_tls "github.com/coder/boundary/tls" "github.com/stretchr/testify/require" @@ -32,16 +33,19 @@ func (m *mockAuditor) AuditRequest(req audit.Request) { // ProxyTest is a high-level test framework for proxy tests type ProxyTest struct { - t *testing.T - server *Server - client *http.Client - proxyClient *http.Client - port int - useCertManager bool - configDir string - startupDelay time.Duration - allowedRules []string - auditor audit.Auditor + t *testing.T + server *Server + client *http.Client + proxyClient *http.Client + port int + useCertManager bool + configDir string + startupDelay time.Duration + allowedRules []string + auditor audit.Auditor + sessionCorrelation config.SessionCorrelationConfig + sessionID string + forwardTransport http.RoundTripper } // ProxyTestOption is a function that configures ProxyTest @@ -109,6 +113,30 @@ func WithAuditor(auditor audit.Auditor) ProxyTestOption { } } +// WithSessionCorrelation sets the session correlation config for the +// proxy under test. +func WithSessionCorrelation(sc config.SessionCorrelationConfig) ProxyTestOption { + return func(pt *ProxyTest) { + pt.sessionCorrelation = sc + } +} + +// WithSessionID sets the boundary session ID for the proxy under test. +func WithSessionID(id string) ProxyTestOption { + return func(pt *ProxyTest) { + pt.sessionID = id + } +} + +// WithForwardTransport sets the http.RoundTripper the proxy uses when +// forwarding requests to backends. Use in tests to trust self-signed +// backend certificates (e.g. those from httptest.NewTLSServer). +func WithForwardTransport(transport http.RoundTripper) ProxyTestOption { + return func(pt *ProxyTest) { + pt.forwardTransport = transport + } +} + // Start starts the proxy server func (pt *ProxyTest) Start() *ProxyTest { pt.t.Helper() @@ -152,12 +180,25 @@ func (pt *ProxyTest) Start() *ProxyTest { } } + // Build inject engine from session correlation targets. + var injectEngine *rulesengine.Engine + if pt.sessionCorrelation.Enabled && len(pt.sessionCorrelation.InjectTargets) > 0 { + injectRules, err := rulesengine.ParseAllowSpecs(pt.sessionCorrelation.InjectTargets) + require.NoError(pt.t, err, "Failed to parse inject target rules") + eng := rulesengine.NewRuleEngine(injectRules, logger) + injectEngine = &eng + } + pt.server = NewProxyServer(Config{ - HTTPPort: pt.port, - RuleEngine: ruleEngine, - Auditor: auditor, - Logger: logger, - TLSConfig: tlsConfig, + HTTPPort: pt.port, + RuleEngine: ruleEngine, + Auditor: auditor, + Logger: logger, + TLSConfig: tlsConfig, + SessionCorrelation: pt.sessionCorrelation, + InjectEngine: injectEngine, + SessionID: pt.sessionID, + ForwardTransport: pt.forwardTransport, }) err = pt.server.Start() diff --git a/proxy/proxy_session_correlation_test.go b/proxy/proxy_session_correlation_test.go new file mode 100644 index 0000000..45f7ac6 --- /dev/null +++ b/proxy/proxy_session_correlation_test.go @@ -0,0 +1,278 @@ +package proxy + +import ( + "net/http" + "net/http/httptest" + "net/url" + "sync" + "testing" + + "github.com/coder/boundary/config" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// headerCapturingBackend spins up an httptest.Server that records the +// headers it receives. Call receivedHeaders after the request to inspect +// them. +type headerCapturingBackend struct { + server *httptest.Server + mu sync.Mutex + headers http.Header +} + +func newHeaderCapturingBackend() *headerCapturingBackend { + hcb := &headerCapturingBackend{} + hcb.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + hcb.mu.Lock() + hcb.headers = r.Header.Clone() + hcb.mu.Unlock() + w.WriteHeader(http.StatusOK) + })) + return hcb +} + +func (h *headerCapturingBackend) close() { h.server.Close() } + +func (h *headerCapturingBackend) receivedHeaders() http.Header { + h.mu.Lock() + defer h.mu.Unlock() + return h.headers.Clone() +} + +func TestSessionCorrelation_MatchedDomain(t *testing.T) { + backend := newHeaderCapturingBackend() + defer backend.close() + + backendURL, err := url.Parse(backend.server.URL) + require.NoError(t, err) + + pt := NewProxyTest(t, + WithCertManager(t.TempDir()), + WithAllowedDomain(backendURL.Hostname()), + WithSessionCorrelation(config.SessionCorrelationConfig{ + Enabled: true, + InjectTargets: []string{"domain=" + backendURL.Hostname()}, + }), + WithSessionID("test-session-id-1234"), + ).Start() + defer pt.Stop() + + resp, err := pt.proxyClient.Get(backend.server.URL + "/api/v2") + require.NoError(t, err) + defer resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusOK, resp.StatusCode) + + got := backend.receivedHeaders() + assert.Equal(t, "test-session-id-1234", got.Get(config.SessionIDHeaderName), + "session ID header must be injected on matching domain") + assert.Equal(t, "0", got.Get(config.SequenceNumberHeaderName), + "sequence number header must start at 0") +} + +func TestSessionCorrelation_MatchedDomainWithPort(t *testing.T) { + backend := newHeaderCapturingBackend() + defer backend.close() + + backendURL, err := url.Parse(backend.server.URL) + require.NoError(t, err) + + // httptest.NewServer binds to a random ephemeral port, so backendURL.Host + // is always "127.0.0.1:" and backendURL.Hostname() is "127.0.0.1". + // Verify that assumption holds before relying on it below. + require.NotEmpty(t, backendURL.Port(), "httptest.Server URL must include a port") + require.Equal(t, backendURL.Hostname()+":"+backendURL.Port(), backendURL.Host, + "backendURL.Host must be host:port") + + // Configure the inject target using just the hostname (without port). + // The rulesengine strips ports during matching, same as for --allow rules. + pt := NewProxyTest(t, + WithCertManager(t.TempDir()), + WithAllowedDomain(backendURL.Hostname()), + WithSessionCorrelation(config.SessionCorrelationConfig{ + Enabled: true, + InjectTargets: []string{"domain=" + backendURL.Hostname()}, + }), + WithSessionID("test-session-id-1234"), + ).Start() + defer pt.Stop() + + resp, err := pt.proxyClient.Get(backend.server.URL + "/api/v2") + require.NoError(t, err) + defer resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusOK, resp.StatusCode) + + got := backend.receivedHeaders() + assert.Equal(t, "test-session-id-1234", got.Get(config.SessionIDHeaderName), + "session ID header must be injected when inject target domain matches request hostname") + assert.Equal(t, "0", got.Get(config.SequenceNumberHeaderName), + "sequence number header must start at 0") +} + +func TestSessionCorrelation_UnmatchedDomain(t *testing.T) { + backend := newHeaderCapturingBackend() + defer backend.close() + + backendURL, err := url.Parse(backend.server.URL) + require.NoError(t, err) + + pt := NewProxyTest(t, + WithCertManager(t.TempDir()), + WithAllowedDomain(backendURL.Hostname()), + WithSessionCorrelation(config.SessionCorrelationConfig{ + Enabled: true, + InjectTargets: []string{"domain=other-domain.example.com"}, + }), + WithSessionID("test-session-id-1234"), + ).Start() + defer pt.Stop() + + resp, err := pt.proxyClient.Get(backend.server.URL + "/api/v2") + require.NoError(t, err) + defer resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusOK, resp.StatusCode) + + got := backend.receivedHeaders() + assert.Empty(t, got.Get(config.SessionIDHeaderName), + "session ID header must not be injected on unmatched domain") + assert.Empty(t, got.Get(config.SequenceNumberHeaderName), + "sequence number header must not be injected on unmatched domain") +} + +func TestSessionCorrelation_Disabled(t *testing.T) { + backend := newHeaderCapturingBackend() + defer backend.close() + + backendURL, err := url.Parse(backend.server.URL) + require.NoError(t, err) + + pt := NewProxyTest(t, + WithCertManager(t.TempDir()), + WithAllowedDomain(backendURL.Hostname()), + WithSessionCorrelation(config.SessionCorrelationConfig{ + Enabled: false, + InjectTargets: []string{"domain=" + backendURL.Hostname()}, + }), + WithSessionID("test-session-id-1234"), + ).Start() + defer pt.Stop() + + resp, err := pt.proxyClient.Get(backend.server.URL + "/api/v2") + require.NoError(t, err) + defer resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusOK, resp.StatusCode) + + got := backend.receivedHeaders() + assert.Empty(t, got.Get(config.SessionIDHeaderName), + "session ID header must not be injected when correlation is disabled") + assert.Empty(t, got.Get(config.SequenceNumberHeaderName), + "sequence number header must not be injected when correlation is disabled") +} + +func TestSessionCorrelation_OverwritesClientValue(t *testing.T) { + backend := newHeaderCapturingBackend() + defer backend.close() + + backendURL, err := url.Parse(backend.server.URL) + require.NoError(t, err) + + pt := NewProxyTest(t, + WithCertManager(t.TempDir()), + WithAllowedDomain(backendURL.Hostname()), + WithSessionCorrelation(config.SessionCorrelationConfig{ + Enabled: true, + InjectTargets: []string{"domain=" + backendURL.Hostname()}, + }), + WithSessionID("real-session-id"), + ).Start() + defer pt.Stop() + + // Send a request with client-supplied session correlation headers + // that should be overwritten by the proxy. + req, err := http.NewRequest(http.MethodGet, backend.server.URL+"/api/v2", nil) + require.NoError(t, err) + req.Header.Set(config.SessionIDHeaderName, "spoofed-session-id") + req.Header.Set(config.SequenceNumberHeaderName, "99999") + + resp, err := pt.proxyClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusOK, resp.StatusCode) + + got := backend.receivedHeaders() + assert.Equal(t, "real-session-id", got.Get(config.SessionIDHeaderName), + "proxy must overwrite client-supplied session ID header") + assert.Equal(t, "0", got.Get(config.SequenceNumberHeaderName), + "proxy must overwrite client-supplied sequence number header") +} + +func TestSessionCorrelation_PathMatching(t *testing.T) { + backend := newHeaderCapturingBackend() + defer backend.close() + + backendURL, err := url.Parse(backend.server.URL) + require.NoError(t, err) + + pt := NewProxyTest(t, + WithCertManager(t.TempDir()), + WithAllowedDomain(backendURL.Hostname()), + WithSessionCorrelation(config.SessionCorrelationConfig{ + Enabled: true, + InjectTargets: []string{"domain=" + backendURL.Hostname() + " path=/api/*"}, + }), + WithSessionID("test-session-id"), + ).Start() + defer pt.Stop() + + t.Run("matching path", func(t *testing.T) { + resp, err := pt.proxyClient.Get(backend.server.URL + "/api/v2") + require.NoError(t, err) + defer resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusOK, resp.StatusCode) + + got := backend.receivedHeaders() + assert.Equal(t, "test-session-id", got.Get(config.SessionIDHeaderName), + "header must be injected when path matches") + }) + + t.Run("non-matching path", func(t *testing.T) { + resp, err := pt.proxyClient.Get(backend.server.URL + "/other/path") + require.NoError(t, err) + defer resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusOK, resp.StatusCode) + + got := backend.receivedHeaders() + assert.Empty(t, got.Get(config.SessionIDHeaderName), + "header must not be injected when path does not match") + }) +} + +func TestSessionCorrelation_SequenceNumberIncrements(t *testing.T) { + backend := newHeaderCapturingBackend() + defer backend.close() + + backendURL, err := url.Parse(backend.server.URL) + require.NoError(t, err) + + pt := NewProxyTest(t, + WithCertManager(t.TempDir()), + WithAllowedDomain(backendURL.Hostname()), + WithSessionCorrelation(config.SessionCorrelationConfig{ + Enabled: true, + InjectTargets: []string{"domain=" + backendURL.Hostname()}, + }), + WithSessionID("test-session-id"), + ).Start() + defer pt.Stop() + + for i, expected := range []string{"0", "1", "2"} { + resp, err := pt.proxyClient.Get(backend.server.URL + "/api/v2") + require.NoError(t, err) + resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusOK, resp.StatusCode) + + got := backend.receivedHeaders() + assert.Equal(t, expected, got.Get(config.SequenceNumberHeaderName), + "request %d: sequence number must be %s", i, expected) + } +}