Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions audit/multi_auditor.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,7 @@ func SetupAuditor(ctx context.Context, logger *slog.Logger, disableAuditLogs boo
}
agentWillProxy := !os.IsNotExist(err)
if agentWillProxy {
seq := &SequenceCounter{}
socketAuditor := NewSocketAuditor(logger, logProxySocketPath, sessionID, seq)
socketAuditor := NewSocketAuditor(logger, logProxySocketPath, sessionID)
go socketAuditor.Loop(ctx)
auditors = append(auditors, socketAuditor)
} else {
Expand Down
5 changes: 5 additions & 0 deletions audit/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,9 @@ type Request struct {
Host string
Allowed bool
Rule string // The rule that matched (if any)

// SequenceNumber is the sequence number assigned to this audit event
// by the proxy. It is monotonically increasing within a session and
// is shared with any injected HTTP header so both carry the same value.
SequenceNumber int32
}
6 changes: 2 additions & 4 deletions audit/socket_auditor.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ type SocketAuditor struct {
batchTimerDuration time.Duration
socketPath string
sessionID uuid.UUID
seq *SequenceCounter

droppedChannelFull atomic.Int64
droppedBatchFull atomic.Int64
Expand All @@ -48,7 +47,7 @@ type SocketAuditor struct {
// NewSocketAuditor creates a new SocketAuditor that sends logs to the agent's
// boundary log proxy socket after SocketAuditor.Loop is called. The socket path
// is read from EnvAuditSocketPath, falling back to defaultAuditSocketPath.
func NewSocketAuditor(logger *slog.Logger, socketPath string, sessionID uuid.UUID, seq *SequenceCounter) *SocketAuditor {
func NewSocketAuditor(logger *slog.Logger, socketPath string, sessionID uuid.UUID) *SocketAuditor {
// This channel buffer size intends to allow enough buffering for bursty
// AI agent network requests while a batch is being sent to the workspace
// agent.
Expand All @@ -64,7 +63,6 @@ func NewSocketAuditor(logger *slog.Logger, socketPath string, sessionID uuid.UUI
batchTimerDuration: defaultBatchTimerDuration,
socketPath: socketPath,
sessionID: sessionID,
seq: seq,
}
}

Expand All @@ -84,7 +82,7 @@ func (s *SocketAuditor) AuditRequest(req Request) {
log := &agentproto.BoundaryLog{
Allowed: req.Allowed,
Time: timestamppb.Now(),
SequenceNumber: s.seq.Next(),
SequenceNumber: req.SequenceNumber,
Resource: &agentproto.BoundaryLog_HttpRequest_{HttpRequest: httpReq},
}

Expand Down
8 changes: 2 additions & 6 deletions audit/socket_auditor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,6 @@ func TestSocketAuditor_Loop_RetriesOnConnectionFailure(t *testing.T) {
batchSize: defaultBatchSize,
batchTimerDuration: time.Hour, // Ensure timer doesn't interfere with the test
sessionID: uuid.MustParse("00000000-0000-4000-8000-000000000001"),
seq: &SequenceCounter{},
}

// Set up hook to detect flush attempts
Expand Down Expand Up @@ -363,7 +362,6 @@ func TestSocketAuditor_Loop_ReportsBatchFullDrops(t *testing.T) {
batchSize: defaultBatchSize,
batchTimerDuration: time.Hour,
sessionID: uuid.MustParse("00000000-0000-4000-8000-000000000001"),
seq: &SequenceCounter{},
}

flushed := make(chan struct{}, 4)
Expand Down Expand Up @@ -483,7 +481,7 @@ func TestSocketAuditor_AuditRequest_SequenceNumberIncrements(t *testing.T) {
auditor := setupSocketAuditor(t)

for i := range 5 {
auditor.AuditRequest(Request{Method: "GET", URL: "https://example.com", Allowed: true})
auditor.AuditRequest(Request{Method: "GET", URL: "https://example.com", Allowed: true, SequenceNumber: int32(i)})

select {
case log := <-auditor.logCh:
Expand All @@ -509,7 +507,7 @@ func TestSocketAuditor_Loop_FlushIncludesSessionID(t *testing.T) {

// Fill a batch to trigger a flush.
for i := 0; i < auditor.batchSize; i++ {
auditor.AuditRequest(Request{Method: "GET", URL: "https://example.com", Allowed: true})
auditor.AuditRequest(Request{Method: "GET", URL: "https://example.com", Allowed: true, SequenceNumber: int32(i)})
}

select {
Expand Down Expand Up @@ -549,7 +547,6 @@ func setupSocketAuditor(t *testing.T) *SocketAuditor {
batchSize: defaultBatchSize,
batchTimerDuration: defaultBatchTimerDuration,
sessionID: uuid.MustParse("00000000-0000-4000-8000-000000000001"),
seq: &SequenceCounter{},
}
}

Expand Down Expand Up @@ -580,7 +577,6 @@ func setupTestAuditor(t *testing.T) (*SocketAuditor, net.Conn) {
batchSize: defaultBatchSize,
batchTimerDuration: defaultBatchTimerDuration,
sessionID: uuid.MustParse("00000000-0000-4000-8000-000000000001"),
seq: &SequenceCounter{},
}

return auditor, serverConn
Expand Down
2 changes: 1 addition & 1 deletion cli/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ func BaseCommand(version string) *serpent.Command {
{
Flag: "session-id-inject-target",
Env: "BOUNDARY_SESSION_ID_INJECT_TARGET",
Description: `Inject target for session correlation headers. Repeat the flag once per target; each value describes exactly one target. Format: "domain=<host> [path=<glob>]". Example: --session-id-inject-target "domain=prod.coder.com path=/api/v2/aibridge/*"`,
Description: `Inject target for session correlation headers. Repeat the flag once per target; each value describes exactly one target. Format: "domain=<host> [path=<glob>]". Example: --session-id-inject-target "domain=prod.coder.com path=/api/v2/aibridge/*".`,
Value: &cliConfig.InjectSessionIDTarget,
YAML: "", // CLI only, YAML uses session_id_inject_targets.
},
Expand Down
25 changes: 9 additions & 16 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,26 +141,19 @@ func NewAppConfigFromCliConfig(cfg CliConfig, targetCMD []string, environ []stri
}, nil
}

// buildSessionCorrelation merges CLI and YAML inject target sources,
// parses each target string, and validates the resulting configuration.
// environ is passed explicitly (rather than reading os.Environ inside)
// so that callers and tests can supply a controlled environment.
// buildSessionCorrelation merges CLI and YAML inject target sources
// and validates the resulting configuration. Inject targets use the same
// "domain=... path=..." syntax as --allow rules so that matching semantics
// are identical. environ is passed explicitly (rather than reading
// os.Environ inside) so that callers and tests can supply a controlled
// environment.
func buildSessionCorrelation(cfg CliConfig, environ []string) (SessionCorrelationConfig, error) {
// Merge YAML targets with CLI targets.
rawTargets := append(cfg.InjectSessionIDTargets.Value(), cfg.InjectSessionIDTarget.Value()...)

var targets []InjectTarget
for _, raw := range rawTargets {
t, err := ParseInjectTarget(raw)
if err != nil {
return SessionCorrelationConfig{}, err
}
targets = append(targets, t)
}
targets := append(cfg.InjectSessionIDTargets.Value(), cfg.InjectSessionIDTarget.Value()...)

if len(targets) == 0 && cfg.SessionCorrelationEnabled.Value() {
if t := DefaultInjectTargetFromEnv(environ); t != nil {
targets = []InjectTarget{*t}
if t := DefaultInjectTargetFromEnv(environ); t != "" {
targets = []string{t}
}
}

Expand Down
109 changes: 41 additions & 68 deletions config/session_correlation.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"fmt"
"net/url"
"strings"

"github.com/coder/boundary/rulesengine"
)

// Header names and paths for session correlation.
Expand All @@ -27,14 +29,6 @@ const (
CoderAgentURLEnv = "CODER_AGENT_URL"
)

// InjectTarget represents a parsed target for session correlation header
// injection. Requests matching the domain (and optional path glob) will
// receive the session ID and sequence number headers.
type InjectTarget struct {
Domain string
Path string
}

// SessionCorrelationConfig holds configuration for session correlation
// header injection. When enabled, boundary injects its session ID and
// sequence number as custom headers on matching outbound requests so
Expand All @@ -45,63 +39,22 @@ type SessionCorrelationConfig struct {
// Deployments without AI Bridge in front should set this to false.
Enabled bool

// InjectTargets is the list of domain/path patterns that should
// receive session correlation headers.
InjectTargets []InjectTarget
}

// ParseInjectTarget parses a string of the form "domain=... path=..."
// into an InjectTarget. The domain key is required; path is optional.
func ParseInjectTarget(raw string) (InjectTarget, error) {
raw = strings.TrimSpace(raw)
if raw == "" {
return InjectTarget{}, fmt.Errorf("inject target must not be empty")
}

var target InjectTarget
seen := make(map[string]bool)
for _, part := range strings.Fields(raw) {
key, value, ok := strings.Cut(part, "=")
if !ok {
return InjectTarget{}, fmt.Errorf(
"inject target: malformed key-value pair %q, expected key=value", part,
)
}
if seen[key] {
return InjectTarget{}, fmt.Errorf(
"inject target: duplicate key %q (use separate flags for multiple targets)", key,
)
}
seen[key] = true
switch key {
case "domain":
if value == "" {
return InjectTarget{}, fmt.Errorf("inject target: domain must not be empty")
}
target.Domain = value
case "path":
target.Path = value
default:
return InjectTarget{}, fmt.Errorf("inject target: unknown key %q", key)
}
}

if target.Domain == "" {
return InjectTarget{}, fmt.Errorf("inject target: domain is required")
}

return target, nil
// InjectTargets is the list of raw rule specs (same syntax as --allow)
// that should receive session correlation headers. Each string uses the
// rulesengine "domain=... path=..." format so that inject target
// matching is identical to allow-rule matching.
InjectTargets []string
}

// DefaultInjectTargetFromEnv derives an InjectTarget from the CODER_AGENT_URL
// variable in the provided environment slice. It returns nil if the variable is
// absent, empty, or not a valid URL with a host. The derived target uses
// DefaultAIBridgePath as the path glob so that all AI Bridge traffic on the
// control-plane host is matched.
// DefaultInjectTargetFromEnv derives an inject target rule string from the
// CODER_AGENT_URL variable in the provided environment slice. It returns ""
// if the variable is absent, empty, or not a valid URL with a host. The
// derived target uses DefaultAIBridgePath as the path glob so that all AI
// Bridge traffic on the control-plane host is matched.
//
// The environ parameter is accepted rather than reading os.Environ directly so
// that callers (and tests) can supply an arbitrary environment.
func DefaultInjectTargetFromEnv(environ []string) *InjectTarget {
func DefaultInjectTargetFromEnv(environ []string) string {
var raw string
for _, e := range environ {
k, v, ok := strings.Cut(e, "=")
Expand All @@ -111,23 +64,22 @@ func DefaultInjectTargetFromEnv(environ []string) *InjectTarget {
}
}
if raw == "" {
return nil
return ""
}

u, err := url.Parse(raw)
if err != nil || u.Host == "" {
return nil
return ""
}

return &InjectTarget{
Domain: u.Hostname(),
Path: DefaultAIBridgePath,
}
return fmt.Sprintf("domain=%s path=%s", u.Hostname(), DefaultAIBridgePath)
}

// ValidateSessionCorrelation checks that the session correlation config
// is internally consistent. It returns an error describing the first
// problem found, or nil if the config is valid.
// is internally consistent. When enabled it verifies that at least one
// inject target is configured and that every target string is a valid
// rulesengine rule. It returns an error describing the first problem
// found, or nil if the config is valid.
func ValidateSessionCorrelation(cfg SessionCorrelationConfig) error {
if !cfg.Enabled {
return nil
Expand All @@ -139,5 +91,26 @@ func ValidateSessionCorrelation(cfg SessionCorrelationConfig) error {
)
}

// Reject empty target strings before passing to the parser.
for _, t := range cfg.InjectTargets {
if strings.TrimSpace(t) == "" {
return fmt.Errorf("inject target: must not be empty")
}
}

// Validate each target parses as a rulesengine rule.
rules, err := rulesengine.ParseAllowSpecs(cfg.InjectTargets)
if err != nil {
return fmt.Errorf("inject target: %w", err)
}

// Inject targets must specify a domain; path-only rules are not
// meaningful for header injection.
for i, r := range rules {
if r.HostPattern == nil {
return fmt.Errorf("inject target %q: domain is required", cfg.InjectTargets[i])
}
}

return nil
}
Loading
Loading