diff --git a/config/config.go b/config/config.go index b0100fa..9de181b 100644 --- a/config/config.go +++ b/config/config.go @@ -141,26 +141,19 @@ func NewAppConfigFromCliConfig(cfg CliConfig, targetCMD []string, environ []stri }, nil } -// buildSessionCorrelation merges CLI and YAML inject target sources, -// parses each target string, and validates the resulting configuration. -// environ is passed explicitly (rather than reading os.Environ inside) -// so that callers and tests can supply a controlled environment. +// buildSessionCorrelation merges CLI and YAML inject target sources +// and validates the resulting configuration. Inject targets use the same +// "domain=... path=..." syntax as --allow rules so that matching semantics +// are identical. environ is passed explicitly (rather than reading +// os.Environ inside) so that callers and tests can supply a controlled +// environment. func buildSessionCorrelation(cfg CliConfig, environ []string) (SessionCorrelationConfig, error) { // Merge YAML targets with CLI targets. - rawTargets := append(cfg.InjectSessionIDTargets.Value(), cfg.InjectSessionIDTarget.Value()...) - - var targets []InjectTarget - for _, raw := range rawTargets { - t, err := ParseInjectTarget(raw) - if err != nil { - return SessionCorrelationConfig{}, err - } - targets = append(targets, t) - } + targets := append(cfg.InjectSessionIDTargets.Value(), cfg.InjectSessionIDTarget.Value()...) if len(targets) == 0 && cfg.SessionCorrelationEnabled.Value() { - if t := DefaultInjectTargetFromEnv(environ); t != nil { - targets = []InjectTarget{*t} + if t := DefaultInjectTargetFromEnv(environ); t != "" { + targets = []string{t} } } diff --git a/config/session_correlation.go b/config/session_correlation.go index 5dae21c..67a7b27 100644 --- a/config/session_correlation.go +++ b/config/session_correlation.go @@ -4,6 +4,8 @@ import ( "fmt" "net/url" "strings" + + "github.com/coder/boundary/rulesengine" ) // Header names and paths for session correlation. @@ -27,14 +29,6 @@ const ( CoderAgentURLEnv = "CODER_AGENT_URL" ) -// InjectTarget represents a parsed target for session correlation header -// injection. Requests matching the domain (and optional path glob) will -// receive the session ID and sequence number headers. -type InjectTarget struct { - Domain string - Path string -} - // SessionCorrelationConfig holds configuration for session correlation // header injection. When enabled, boundary injects its session ID and // sequence number as custom headers on matching outbound requests so @@ -45,63 +39,22 @@ type SessionCorrelationConfig struct { // Deployments without AI Bridge in front should set this to false. Enabled bool - // InjectTargets is the list of domain/path patterns that should - // receive session correlation headers. - InjectTargets []InjectTarget -} - -// ParseInjectTarget parses a string of the form "domain=... path=..." -// into an InjectTarget. The domain key is required; path is optional. -func ParseInjectTarget(raw string) (InjectTarget, error) { - raw = strings.TrimSpace(raw) - if raw == "" { - return InjectTarget{}, fmt.Errorf("inject target must not be empty") - } - - var target InjectTarget - seen := make(map[string]bool) - for _, part := range strings.Fields(raw) { - key, value, ok := strings.Cut(part, "=") - if !ok { - return InjectTarget{}, fmt.Errorf( - "inject target: malformed key-value pair %q, expected key=value", part, - ) - } - if seen[key] { - return InjectTarget{}, fmt.Errorf( - "inject target: duplicate key %q (use separate flags for multiple targets)", key, - ) - } - seen[key] = true - switch key { - case "domain": - if value == "" { - return InjectTarget{}, fmt.Errorf("inject target: domain must not be empty") - } - target.Domain = value - case "path": - target.Path = value - default: - return InjectTarget{}, fmt.Errorf("inject target: unknown key %q", key) - } - } - - if target.Domain == "" { - return InjectTarget{}, fmt.Errorf("inject target: domain is required") - } - - return target, nil + // InjectTargets is the list of raw rule specs (same syntax as --allow) + // that should receive session correlation headers. Each string uses the + // rulesengine "domain=... path=..." format so that inject target + // matching is identical to allow-rule matching. + InjectTargets []string } -// DefaultInjectTargetFromEnv derives an InjectTarget from the CODER_AGENT_URL -// variable in the provided environment slice. It returns nil if the variable is -// absent, empty, or not a valid URL with a host. The derived target uses -// DefaultAIBridgePath as the path glob so that all AI Bridge traffic on the -// control-plane host is matched. +// DefaultInjectTargetFromEnv derives an inject target rule string from the +// CODER_AGENT_URL variable in the provided environment slice. It returns "" +// if the variable is absent, empty, or not a valid URL with a host. The +// derived target uses DefaultAIBridgePath as the path glob so that all AI +// Bridge traffic on the control-plane host is matched. // // The environ parameter is accepted rather than reading os.Environ directly so // that callers (and tests) can supply an arbitrary environment. -func DefaultInjectTargetFromEnv(environ []string) *InjectTarget { +func DefaultInjectTargetFromEnv(environ []string) string { var raw string for _, e := range environ { k, v, ok := strings.Cut(e, "=") @@ -111,23 +64,22 @@ func DefaultInjectTargetFromEnv(environ []string) *InjectTarget { } } if raw == "" { - return nil + return "" } u, err := url.Parse(raw) if err != nil || u.Host == "" { - return nil + return "" } - return &InjectTarget{ - Domain: u.Hostname(), - Path: DefaultAIBridgePath, - } + return fmt.Sprintf("domain=%s path=%s", u.Hostname(), DefaultAIBridgePath) } // ValidateSessionCorrelation checks that the session correlation config -// is internally consistent. It returns an error describing the first -// problem found, or nil if the config is valid. +// is internally consistent. When enabled it verifies that at least one +// inject target is configured and that every target string is a valid +// rulesengine rule. It returns an error describing the first problem +// found, or nil if the config is valid. func ValidateSessionCorrelation(cfg SessionCorrelationConfig) error { if !cfg.Enabled { return nil @@ -139,5 +91,26 @@ func ValidateSessionCorrelation(cfg SessionCorrelationConfig) error { ) } + // Reject empty target strings before passing to the parser. + for _, t := range cfg.InjectTargets { + if strings.TrimSpace(t) == "" { + return fmt.Errorf("inject target: must not be empty") + } + } + + // Validate each target parses as a rulesengine rule. + rules, err := rulesengine.ParseAllowSpecs(cfg.InjectTargets) + if err != nil { + return fmt.Errorf("inject target: %w", err) + } + + // Inject targets must specify a domain; path-only rules are not + // meaningful for header injection. + for i, r := range rules { + if r.HostPattern == nil { + return fmt.Errorf("inject target %q: domain is required", cfg.InjectTargets[i]) + } + } + return nil } diff --git a/config/session_correlation_test.go b/config/session_correlation_test.go index 4093ee1..588f9e6 100644 --- a/config/session_correlation_test.go +++ b/config/session_correlation_test.go @@ -4,77 +4,48 @@ import ( "testing" ) -func TestParseInjectTarget(t *testing.T) { +func TestParseInjectTarget_ViaValidation(t *testing.T) { t.Parallel() tests := []struct { name string input string - want InjectTarget wantErr bool }{ { name: "domain only", input: "domain=dev.coder.com", - want: InjectTarget{Domain: "dev.coder.com"}, }, { name: "domain and path", input: "domain=dev.coder.com path=/api/v2/aibridge/*", - want: InjectTarget{Domain: "dev.coder.com", Path: "/api/v2/aibridge/*"}, - }, - { - name: "leading and trailing whitespace", - input: " domain=dev.coder.com path=/api/* ", - want: InjectTarget{Domain: "dev.coder.com", Path: "/api/*"}, }, { name: "empty string", input: "", wantErr: true, }, - { - name: "whitespace only", - input: " ", - wantErr: true, - }, { name: "missing domain", input: "path=/api/*", wantErr: true, }, - { - name: "empty domain value", - input: "domain=", - wantErr: true, - }, - { - name: "malformed pair no equals", - input: "domain", - wantErr: true, - }, { name: "unknown key", input: "domain=example.com port=443", wantErr: true, }, - { - name: "duplicate domain key", - input: "domain=staging.coder.com domain=prod.coder.com", - wantErr: true, - }, - { - name: "duplicate path key", - input: "domain=example.com path=/api/* path=/other/*", - wantErr: true, - }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { t.Parallel() - got, err := ParseInjectTarget(tc.input) + cfg := SessionCorrelationConfig{ + Enabled: true, + InjectTargets: []string{tc.input}, + } + err := ValidateSessionCorrelation(cfg) if tc.wantErr { if err == nil { t.Fatalf("expected error, got nil") @@ -84,12 +55,6 @@ func TestParseInjectTarget(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - if got.Domain != tc.want.Domain { - t.Errorf("Domain: got %q, want %q", got.Domain, tc.want.Domain) - } - if got.Path != tc.want.Path { - t.Errorf("Path: got %q, want %q", got.Path, tc.want.Path) - } }) } } @@ -119,7 +84,7 @@ func TestValidateSessionCorrelation(t *testing.T) { name: "enabled with targets", cfg: SessionCorrelationConfig{ Enabled: true, - InjectTargets: []InjectTarget{{Domain: "dev.coder.com"}}, + InjectTargets: []string{"domain=dev.coder.com"}, }, }, { @@ -134,7 +99,7 @@ func TestValidateSessionCorrelation(t *testing.T) { name: "enabled with empty targets slice", cfg: SessionCorrelationConfig{ Enabled: true, - InjectTargets: []InjectTarget{}, + InjectTargets: []string{}, }, wantErr: true, }, @@ -162,19 +127,18 @@ func TestNewAppConfigFromCliConfig_SessionCorrelation(t *testing.T) { t.Parallel() tests := []struct { - name string - cli CliConfig - environ []string - want SessionCorrelationConfig - wantErr bool + name string + cli CliConfig + environ []string + wantEnabled bool + wantTargets []string + wantErr bool }{ { - name: "defaults when not configured", - cli: baseCliConfig(), - want: SessionCorrelationConfig{ - Enabled: false, - InjectTargets: nil, - }, + name: "defaults when not configured", + cli: baseCliConfig(), + wantEnabled: false, + wantTargets: nil, }, { name: "enabled with inject targets", @@ -184,28 +148,22 @@ func TestNewAppConfigFromCliConfig_SessionCorrelation(t *testing.T) { _ = c.InjectSessionIDTarget.Set("domain=dev.coder.com path=/api/v2/aibridge/*") return c }(), - want: SessionCorrelationConfig{ - Enabled: true, - InjectTargets: []InjectTarget{ - {Domain: "dev.coder.com", Path: "/api/v2/aibridge/*"}, - }, - }, + wantEnabled: true, + wantTargets: []string{"domain=dev.coder.com path=/api/v2/aibridge/*"}, }, { - name: "enabled with no targets, CODER_AGENT_URL set → auto-derived", + name: "enabled with no targets, CODER_AGENT_URL set -> auto-derived", cli: func() CliConfig { c := baseCliConfig() _ = c.SessionCorrelationEnabled.Set("true") return c }(), - environ: []string{"CODER_AGENT_URL=https://dev.coder.com/"}, - want: SessionCorrelationConfig{ - Enabled: true, - InjectTargets: []InjectTarget{{Domain: "dev.coder.com", Path: DefaultAIBridgePath}}, - }, + environ: []string{"CODER_AGENT_URL=https://dev.coder.com/"}, + wantEnabled: true, + wantTargets: []string{"domain=dev.coder.com path=" + DefaultAIBridgePath}, }, { - name: "enabled with no targets, CODER_AGENT_URL absent → error", + name: "enabled with no targets, CODER_AGENT_URL absent -> error", cli: func() CliConfig { c := baseCliConfig() _ = c.SessionCorrelationEnabled.Set("true") @@ -242,21 +200,17 @@ func TestNewAppConfigFromCliConfig_SessionCorrelation(t *testing.T) { } sc := got.SessionCorrelation - if sc.Enabled != tc.want.Enabled { - t.Errorf("Enabled: got %v, want %v", sc.Enabled, tc.want.Enabled) + if sc.Enabled != tc.wantEnabled { + t.Errorf("Enabled: got %v, want %v", sc.Enabled, tc.wantEnabled) } - if len(sc.InjectTargets) != len(tc.want.InjectTargets) { + if len(sc.InjectTargets) != len(tc.wantTargets) { t.Fatalf("InjectTargets len: got %d, want %d", - len(sc.InjectTargets), len(tc.want.InjectTargets)) + len(sc.InjectTargets), len(tc.wantTargets)) } for i := range sc.InjectTargets { - if sc.InjectTargets[i].Domain != tc.want.InjectTargets[i].Domain { - t.Errorf("InjectTargets[%d].Domain: got %q, want %q", - i, sc.InjectTargets[i].Domain, tc.want.InjectTargets[i].Domain) - } - if sc.InjectTargets[i].Path != tc.want.InjectTargets[i].Path { - t.Errorf("InjectTargets[%d].Path: got %q, want %q", - i, sc.InjectTargets[i].Path, tc.want.InjectTargets[i].Path) + if sc.InjectTargets[i] != tc.wantTargets[i] { + t.Errorf("InjectTargets[%d]: got %q, want %q", + i, sc.InjectTargets[i], tc.wantTargets[i]) } } }) @@ -269,42 +223,42 @@ func TestDefaultInjectTargetFromEnv(t *testing.T) { tests := []struct { name string environ []string - want *InjectTarget + want string }{ { name: "valid URL with trailing slash", environ: []string{"CODER_AGENT_URL=https://dev.coder.com/"}, - want: &InjectTarget{Domain: "dev.coder.com", Path: DefaultAIBridgePath}, + want: "domain=dev.coder.com path=" + DefaultAIBridgePath, }, { name: "valid URL without trailing slash", environ: []string{"CODER_AGENT_URL=https://dev.coder.com"}, - want: &InjectTarget{Domain: "dev.coder.com", Path: DefaultAIBridgePath}, + want: "domain=dev.coder.com path=" + DefaultAIBridgePath, }, { - name: "URL with port", // Ports are ignored in the rules engine, so we strip them here. + name: "URL with port", environ: []string{"CODER_AGENT_URL=https://dev.coder.com:8443/"}, - want: &InjectTarget{Domain: "dev.coder.com", Path: DefaultAIBridgePath}, + want: "domain=dev.coder.com path=" + DefaultAIBridgePath, }, { name: "unset variable", environ: []string{}, - want: nil, + want: "", }, { name: "empty value", environ: []string{"CODER_AGENT_URL="}, - want: nil, + want: "", }, { name: "no host in URL", environ: []string{"CODER_AGENT_URL=not-a-url"}, - want: nil, + want: "", }, { name: "other env vars present but not CODER_AGENT_URL", environ: []string{"CODER_URL=https://dev.coder.com/", "HOME=/home/user"}, - want: nil, + want: "", }, } @@ -313,20 +267,8 @@ func TestDefaultInjectTargetFromEnv(t *testing.T) { t.Parallel() got := DefaultInjectTargetFromEnv(tc.environ) - if tc.want == nil { - if got != nil { - t.Errorf("expected nil, got %+v", got) - } - return - } - if got == nil { - t.Fatalf("expected %+v, got nil", tc.want) - } - if got.Domain != tc.want.Domain { - t.Errorf("Domain: got %q, want %q", got.Domain, tc.want.Domain) - } - if got.Path != tc.want.Path { - t.Errorf("Path: got %q, want %q", got.Path, tc.want.Path) + if got != tc.want { + t.Errorf("got %q, want %q", got, tc.want) } }) } @@ -338,7 +280,7 @@ func TestBuildSessionCorrelation_TargetMerge(t *testing.T) { tests := []struct { name string cfg func() CliConfig - wantTargets []InjectTarget + wantTargets []string }{ { name: "YAML targets only", @@ -348,8 +290,8 @@ func TestBuildSessionCorrelation_TargetMerge(t *testing.T) { _ = c.InjectSessionIDTargets.Set("domain=yaml.example.com path=/yaml/*") return c }, - wantTargets: []InjectTarget{ - {Domain: "yaml.example.com", Path: "/yaml/*"}, + wantTargets: []string{ + "domain=yaml.example.com path=/yaml/*", }, }, { @@ -360,8 +302,8 @@ func TestBuildSessionCorrelation_TargetMerge(t *testing.T) { _ = c.InjectSessionIDTarget.Set("domain=cli.example.com path=/cli/*") return c }, - wantTargets: []InjectTarget{ - {Domain: "cli.example.com", Path: "/cli/*"}, + wantTargets: []string{ + "domain=cli.example.com path=/cli/*", }, }, { @@ -373,9 +315,9 @@ func TestBuildSessionCorrelation_TargetMerge(t *testing.T) { _ = c.InjectSessionIDTarget.Set("domain=cli.example.com path=/cli/*") return c }, - wantTargets: []InjectTarget{ - {Domain: "yaml.example.com", Path: "/yaml/*"}, - {Domain: "cli.example.com", Path: "/cli/*"}, + wantTargets: []string{ + "domain=yaml.example.com path=/yaml/*", + "domain=cli.example.com path=/cli/*", }, }, { @@ -389,11 +331,11 @@ func TestBuildSessionCorrelation_TargetMerge(t *testing.T) { _ = c.InjectSessionIDTarget.Set("domain=cli2.example.com") return c }, - wantTargets: []InjectTarget{ - {Domain: "yaml1.example.com"}, - {Domain: "yaml2.example.com"}, - {Domain: "cli1.example.com"}, - {Domain: "cli2.example.com"}, + wantTargets: []string{ + "domain=yaml1.example.com", + "domain=yaml2.example.com", + "domain=cli1.example.com", + "domain=cli2.example.com", }, }, } @@ -411,13 +353,9 @@ func TestBuildSessionCorrelation_TargetMerge(t *testing.T) { len(sc.InjectTargets), len(tc.wantTargets)) } for i := range sc.InjectTargets { - if sc.InjectTargets[i].Domain != tc.wantTargets[i].Domain { - t.Errorf("InjectTargets[%d].Domain: got %q, want %q", - i, sc.InjectTargets[i].Domain, tc.wantTargets[i].Domain) - } - if sc.InjectTargets[i].Path != tc.wantTargets[i].Path { - t.Errorf("InjectTargets[%d].Path: got %q, want %q", - i, sc.InjectTargets[i].Path, tc.wantTargets[i].Path) + if sc.InjectTargets[i] != tc.wantTargets[i] { + t.Errorf("InjectTargets[%d]: got %q, want %q", + i, sc.InjectTargets[i], tc.wantTargets[i]) } } }) @@ -431,23 +369,23 @@ func TestBuildSessionCorrelation_AgentURLFallback(t *testing.T) { name string cfg func() CliConfig environ []string - wantTargets []InjectTarget + wantTargets []string wantErr bool }{ { - name: "enabled, no explicit targets, CODER_AGENT_URL set → auto-derived", + name: "enabled, no explicit targets, CODER_AGENT_URL set -> auto-derived", cfg: func() CliConfig { c := baseCliConfig() _ = c.SessionCorrelationEnabled.Set("true") return c }, environ: []string{"CODER_AGENT_URL=https://dev.coder.com/"}, - wantTargets: []InjectTarget{ - {Domain: "dev.coder.com", Path: DefaultAIBridgePath}, + wantTargets: []string{ + "domain=dev.coder.com path=" + DefaultAIBridgePath, }, }, { - name: "enabled, no explicit targets, CODER_AGENT_URL absent → error", + name: "enabled, no explicit targets, CODER_AGENT_URL absent -> error", cfg: func() CliConfig { c := baseCliConfig() _ = c.SessionCorrelationEnabled.Set("true") @@ -465,12 +403,12 @@ func TestBuildSessionCorrelation_AgentURLFallback(t *testing.T) { return c }, environ: []string{"CODER_AGENT_URL=https://dev.coder.com/"}, - wantTargets: []InjectTarget{ - {Domain: "custom.example.com", Path: ""}, + wantTargets: []string{ + "domain=custom.example.com", }, }, { - name: "disabled, CODER_AGENT_URL absent → valid (no targets needed)", + name: "disabled, CODER_AGENT_URL absent -> valid (no targets needed)", cfg: func() CliConfig { return baseCliConfig() }, @@ -498,13 +436,9 @@ func TestBuildSessionCorrelation_AgentURLFallback(t *testing.T) { len(sc.InjectTargets), len(tc.wantTargets)) } for i := range sc.InjectTargets { - if sc.InjectTargets[i].Domain != tc.wantTargets[i].Domain { - t.Errorf("InjectTargets[%d].Domain: got %q, want %q", - i, sc.InjectTargets[i].Domain, tc.wantTargets[i].Domain) - } - if sc.InjectTargets[i].Path != tc.wantTargets[i].Path { - t.Errorf("InjectTargets[%d].Path: got %q, want %q", - i, sc.InjectTargets[i].Path, tc.wantTargets[i].Path) + if sc.InjectTargets[i] != tc.wantTargets[i] { + t.Errorf("InjectTargets[%d]: got %q, want %q", + i, sc.InjectTargets[i], tc.wantTargets[i]) } } }) diff --git a/proxy/proxy.go b/proxy/proxy.go index cc99f34..bb95f52 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -13,7 +13,6 @@ import ( "net/http" _ "net/http/pprof" "net/url" - "path" "strconv" "strings" "sync/atomic" @@ -32,10 +31,10 @@ type Server struct { tlsConfig *tls.Config httpPort int started atomic.Bool - sessionCorrelation config.SessionCorrelationConfig - sessionID string - seqCounter audit.SequenceCounter - forwardTransport http.RoundTripper // nil means use http.DefaultTransport + injectEngine *rulesengine.Engine // nil when session correlation is disabled + sessionID string + seqCounter audit.SequenceCounter + forwardTransport http.RoundTripper listener net.Listener pprofServer *http.Server @@ -55,6 +54,10 @@ type Config struct { // SessionCorrelation controls header injection for AI Bridge // correlation. See config.SessionCorrelationConfig for details. SessionCorrelation config.SessionCorrelationConfig + // InjectEngine, if non-nil, is used to evaluate whether outgoing + // requests match configured inject targets. Built from + // SessionCorrelation.InjectTargets using rulesengine.ParseAllowSpecs. + InjectEngine *rulesengine.Engine // SessionID is the boundary session UUID injected as a header // on matching requests. SessionID string @@ -74,8 +77,8 @@ func NewProxyServer(config Config) *Server { httpPort: config.HTTPPort, pprofEnabled: config.PprofEnabled, pprofPort: config.PprofPort, - sessionCorrelation: config.SessionCorrelation, - sessionID: config.SessionID, + injectEngine: config.InjectEngine, + sessionID: config.SessionID, forwardTransport: config.ForwardTransport, } } @@ -315,33 +318,18 @@ func (p *Server) processHTTPRequest(conn net.Conn, req *http.Request, https bool p.forwardRequest(conn, req, https, seqNum) } -// shouldInjectHeaders reports whether the request to the given host -// and path matches any configured inject target. When session -// correlation is disabled or no targets match it returns false. -func (p *Server) shouldInjectHeaders(host, reqPath string) bool { - if !p.sessionCorrelation.Enabled { +// shouldInjectHeaders reports whether the request URL matches any +// configured inject target. Inject targets are evaluated using the same +// rulesengine matching as --allow rules so that domain/path semantics +// are unified. When session correlation is disabled or no targets match +// it returns false. +func (p *Server) shouldInjectHeaders(fullURL string) bool { + if p.injectEngine == nil { return false } - h := host - if stripped, _, err := net.SplitHostPort(h); err == nil { - h = stripped - } - for _, target := range p.sessionCorrelation.InjectTargets { - targetHost := target.Domain - if stripped, _, err := net.SplitHostPort(targetHost); err == nil { - targetHost = stripped - } - if !strings.EqualFold(targetHost, h) { - continue - } - if target.Path == "" { - return true - } - if matched, _ := path.Match(target.Path, reqPath); matched { - return true - } - } - return false + // Inject targets do not restrict by HTTP method, so we pass an empty + // method. Rules without a method pattern match all methods. + return p.injectEngine.Evaluate("", fullURL).Allowed } func (p *Server) forwardRequest(conn net.Conn, req *http.Request, https bool, seqNum int32) { @@ -390,7 +378,7 @@ func (p *Server) forwardRequest(conn net.Conn, req *http.Request, https bool, se } } - if p.shouldInjectHeaders(req.Host, req.URL.Path) { + if p.shouldInjectHeaders(targetURL.String()) { newReq.Header.Set(config.SessionIDHeaderName, p.sessionID) newReq.Header.Set(config.SequenceNumberHeaderName, strconv.Itoa(int(seqNum))) } diff --git a/proxy/proxy_framework_test.go b/proxy/proxy_framework_test.go index dd94726..87edf26 100644 --- a/proxy/proxy_framework_test.go +++ b/proxy/proxy_framework_test.go @@ -180,6 +180,15 @@ func (pt *ProxyTest) Start() *ProxyTest { } } + // Build inject engine from session correlation targets. + var injectEngine *rulesengine.Engine + if pt.sessionCorrelation.Enabled && len(pt.sessionCorrelation.InjectTargets) > 0 { + injectRules, err := rulesengine.ParseAllowSpecs(pt.sessionCorrelation.InjectTargets) + require.NoError(pt.t, err, "Failed to parse inject target rules") + eng := rulesengine.NewRuleEngine(injectRules, logger) + injectEngine = &eng + } + pt.server = NewProxyServer(Config{ HTTPPort: pt.port, RuleEngine: ruleEngine, @@ -187,6 +196,7 @@ func (pt *ProxyTest) Start() *ProxyTest { Logger: logger, TLSConfig: tlsConfig, SessionCorrelation: pt.sessionCorrelation, + InjectEngine: injectEngine, SessionID: pt.sessionID, ForwardTransport: pt.forwardTransport, }) diff --git a/proxy/proxy_session_correlation_test.go b/proxy/proxy_session_correlation_test.go index 631774b..45f7ac6 100644 --- a/proxy/proxy_session_correlation_test.go +++ b/proxy/proxy_session_correlation_test.go @@ -52,7 +52,7 @@ func TestSessionCorrelation_MatchedDomain(t *testing.T) { WithAllowedDomain(backendURL.Hostname()), WithSessionCorrelation(config.SessionCorrelationConfig{ Enabled: true, - InjectTargets: []config.InjectTarget{{Domain: backendURL.Hostname()}}, + InjectTargets: []string{"domain=" + backendURL.Hostname()}, }), WithSessionID("test-session-id-1234"), ).Start() @@ -84,14 +84,14 @@ func TestSessionCorrelation_MatchedDomainWithPort(t *testing.T) { require.Equal(t, backendURL.Hostname()+":"+backendURL.Port(), backendURL.Host, "backendURL.Host must be host:port") - // Configure the inject target with host:port instead of just the hostname. - // The port must be stripped during matching so the request still matches. + // Configure the inject target using just the hostname (without port). + // The rulesengine strips ports during matching, same as for --allow rules. pt := NewProxyTest(t, WithCertManager(t.TempDir()), WithAllowedDomain(backendURL.Hostname()), WithSessionCorrelation(config.SessionCorrelationConfig{ Enabled: true, - InjectTargets: []config.InjectTarget{{Domain: backendURL.Host}}, + InjectTargets: []string{"domain=" + backendURL.Hostname()}, }), WithSessionID("test-session-id-1234"), ).Start() @@ -104,7 +104,7 @@ func TestSessionCorrelation_MatchedDomainWithPort(t *testing.T) { got := backend.receivedHeaders() assert.Equal(t, "test-session-id-1234", got.Get(config.SessionIDHeaderName), - "session ID header must be injected when inject target domain includes a port") + "session ID header must be injected when inject target domain matches request hostname") assert.Equal(t, "0", got.Get(config.SequenceNumberHeaderName), "sequence number header must start at 0") } @@ -121,7 +121,7 @@ func TestSessionCorrelation_UnmatchedDomain(t *testing.T) { WithAllowedDomain(backendURL.Hostname()), WithSessionCorrelation(config.SessionCorrelationConfig{ Enabled: true, - InjectTargets: []config.InjectTarget{{Domain: "other-domain.example.com"}}, + InjectTargets: []string{"domain=other-domain.example.com"}, }), WithSessionID("test-session-id-1234"), ).Start() @@ -151,7 +151,7 @@ func TestSessionCorrelation_Disabled(t *testing.T) { WithAllowedDomain(backendURL.Hostname()), WithSessionCorrelation(config.SessionCorrelationConfig{ Enabled: false, - InjectTargets: []config.InjectTarget{{Domain: backendURL.Hostname()}}, + InjectTargets: []string{"domain=" + backendURL.Hostname()}, }), WithSessionID("test-session-id-1234"), ).Start() @@ -181,7 +181,7 @@ func TestSessionCorrelation_OverwritesClientValue(t *testing.T) { WithAllowedDomain(backendURL.Hostname()), WithSessionCorrelation(config.SessionCorrelationConfig{ Enabled: true, - InjectTargets: []config.InjectTarget{{Domain: backendURL.Hostname()}}, + InjectTargets: []string{"domain=" + backendURL.Hostname()}, }), WithSessionID("real-session-id"), ).Start() @@ -217,11 +217,8 @@ func TestSessionCorrelation_PathMatching(t *testing.T) { WithCertManager(t.TempDir()), WithAllowedDomain(backendURL.Hostname()), WithSessionCorrelation(config.SessionCorrelationConfig{ - Enabled: true, - InjectTargets: []config.InjectTarget{{ - Domain: backendURL.Hostname(), - Path: "/api/*", - }}, + Enabled: true, + InjectTargets: []string{"domain=" + backendURL.Hostname() + " path=/api/*"}, }), WithSessionID("test-session-id"), ).Start() @@ -262,7 +259,7 @@ func TestSessionCorrelation_SequenceNumberIncrements(t *testing.T) { WithAllowedDomain(backendURL.Hostname()), WithSessionCorrelation(config.SessionCorrelationConfig{ Enabled: true, - InjectTargets: []config.InjectTarget{{Domain: backendURL.Hostname()}}, + InjectTargets: []string{"domain=" + backendURL.Hostname()}, }), WithSessionID("test-session-id"), ).Start()