diff --git a/internal/guard/init.go b/internal/guard/init.go new file mode 100644 index 000000000..813ef7f11 --- /dev/null +++ b/internal/guard/init.go @@ -0,0 +1,51 @@ +package guard + +import ( + "context" + "fmt" + + "github.com/github/gh-aw-mcpg/internal/difc" + "github.com/github/gh-aw-mcpg/internal/logger" +) + +var logGuardInit = logger.New("guard:init") + +// RunLabelAgent executes the standard LabelAgent initialization pipeline: +// 1. Calls the guard's LabelAgent method with the provided pre-built payload. +// 2. Validates the result is non-nil. +// 3. Applies the returned agent labels to agentLabels. +// 4. Parses and returns the effective enforcement mode. +// +// It returns the effective enforcement mode, the raw LabelAgentResult (so +// callers can inspect NormalizedPolicy, DIFCMode, etc.), and any error. +// On error, defaultMode is returned unchanged so the caller's mode is unaffected. +func RunLabelAgent( + ctx context.Context, + g Guard, + payload interface{}, + backend BackendCaller, + caps *difc.Capabilities, + agentLabels *difc.AgentLabels, + defaultMode difc.EnforcementMode, +) (difc.EnforcementMode, *LabelAgentResult, error) { + logGuardInit.Printf("Calling LabelAgent: guard=%s", g.Name()) + + result, err := g.LabelAgent(ctx, payload, backend, caps) + if err != nil { + logGuardInit.Printf("LabelAgent failed: guard=%s, error=%v", g.Name(), err) + return defaultMode, nil, fmt.Errorf("LabelAgent failed: %w", err) + } + if result == nil { + logGuardInit.Printf("LabelAgent returned nil result: guard=%s", g.Name()) + return defaultMode, nil, fmt.Errorf("LabelAgent returned nil result") + } + + mode, err := ApplyLabelAgentResult(result, agentLabels, defaultMode) + if err != nil { + logGuardInit.Printf("LabelAgent result invalid: guard=%s, error=%v", g.Name(), err) + return defaultMode, nil, fmt.Errorf("LabelAgent result invalid: %w", err) + } + + logGuardInit.Printf("LabelAgent completed: guard=%s, mode=%s", g.Name(), mode) + return mode, result, nil +} diff --git a/internal/guard/init_test.go b/internal/guard/init_test.go new file mode 100644 index 000000000..c0d0e1ddf --- /dev/null +++ b/internal/guard/init_test.go @@ -0,0 +1,133 @@ +package guard + +import ( + "context" + "errors" + "testing" + + "github.com/github/gh-aw-mcpg/internal/difc" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// runLabelAgentStubGuard is a minimal Guard implementation used to test RunLabelAgent. +type runLabelAgentStubGuard struct { + name string + labelAgentResult *LabelAgentResult + labelAgentErr error +} + +func (g *runLabelAgentStubGuard) Name() string { return g.name } + +func (g *runLabelAgentStubGuard) LabelAgent(_ context.Context, _ interface{}, _ BackendCaller, _ *difc.Capabilities) (*LabelAgentResult, error) { + return g.labelAgentResult, g.labelAgentErr +} + +func (g *runLabelAgentStubGuard) LabelResource(_ context.Context, _ string, _ interface{}, _ BackendCaller, _ *difc.Capabilities) (*difc.LabeledResource, difc.OperationType, error) { + return difc.NewLabeledResource("stub"), difc.OperationRead, nil +} + +func (g *runLabelAgentStubGuard) LabelResponse(_ context.Context, _ string, _ interface{}, _ BackendCaller, _ *difc.Capabilities) (difc.LabeledData, error) { + return nil, nil +} + +// noopRunBackendCaller is a BackendCaller that always returns nil. +type noopRunBackendCaller struct{} + +func (n *noopRunBackendCaller) CallTool(_ context.Context, _ string, _ interface{}) (interface{}, error) { + return nil, nil +} + +func TestRunLabelAgent_Success(t *testing.T) { + g := &runLabelAgentStubGuard{ + name: "test-guard", + labelAgentResult: &LabelAgentResult{ + Agent: AgentLabelsPayload{Secrecy: []string{"private:org/repo"}, Integrity: []string{"approved"}}, + DIFCMode: difc.ModeFilter, + }, + } + caps := difc.NewCapabilities() + agentLabels := difc.NewAgentRegistryWithDefaults(nil, nil).GetOrCreate("test-agent") + defaultMode := difc.EnforcementStrict + + mode, result, err := RunLabelAgent(context.Background(), g, map[string]interface{}{"policy": "test"}, &noopRunBackendCaller{}, caps, agentLabels, defaultMode) + + require.NoError(t, err) + require.NotNil(t, result) + assert.Equal(t, difc.EnforcementFilter, mode, "mode should be overridden by guard response") + assert.Equal(t, difc.ModeFilter, result.DIFCMode) +} + +func TestRunLabelAgent_GuardError(t *testing.T) { + g := &runLabelAgentStubGuard{ + name: "error-guard", + labelAgentErr: errors.New("wasm runtime error"), + } + caps := difc.NewCapabilities() + agentLabels := difc.NewAgentRegistryWithDefaults(nil, nil).GetOrCreate("test-agent") + defaultMode := difc.EnforcementFilter + + mode, result, err := RunLabelAgent(context.Background(), g, nil, &noopRunBackendCaller{}, caps, agentLabels, defaultMode) + + require.Error(t, err) + assert.Contains(t, err.Error(), "LabelAgent failed") + assert.Contains(t, err.Error(), "wasm runtime error") + assert.Nil(t, result) + assert.Equal(t, defaultMode, mode, "defaultMode should be returned on error") +} + +func TestRunLabelAgent_NilResult(t *testing.T) { + g := &runLabelAgentStubGuard{ + name: "nil-result-guard", + labelAgentResult: nil, + } + caps := difc.NewCapabilities() + agentLabels := difc.NewAgentRegistryWithDefaults(nil, nil).GetOrCreate("test-agent") + defaultMode := difc.EnforcementStrict + + mode, result, err := RunLabelAgent(context.Background(), g, nil, &noopRunBackendCaller{}, caps, agentLabels, defaultMode) + + require.Error(t, err) + assert.Contains(t, err.Error(), "LabelAgent returned nil result") + assert.Nil(t, result) + assert.Equal(t, defaultMode, mode) +} + +func TestRunLabelAgent_InvalidDIFCMode(t *testing.T) { + g := &runLabelAgentStubGuard{ + name: "bad-mode-guard", + labelAgentResult: &LabelAgentResult{ + Agent: AgentLabelsPayload{Secrecy: []string{}, Integrity: []string{}}, + DIFCMode: "not-a-real-mode", + }, + } + caps := difc.NewCapabilities() + agentLabels := difc.NewAgentRegistryWithDefaults(nil, nil).GetOrCreate("test-agent") + defaultMode := difc.EnforcementFilter + + mode, result, err := RunLabelAgent(context.Background(), g, nil, &noopRunBackendCaller{}, caps, agentLabels, defaultMode) + + require.Error(t, err) + assert.Contains(t, err.Error(), "LabelAgent result invalid") + assert.Nil(t, result) + assert.Equal(t, defaultMode, mode) +} + +func TestRunLabelAgent_EmptyDIFCModePreservesDefault(t *testing.T) { + g := &runLabelAgentStubGuard{ + name: "empty-mode-guard", + labelAgentResult: &LabelAgentResult{ + Agent: AgentLabelsPayload{Secrecy: []string{}, Integrity: []string{}}, + DIFCMode: "", // empty → preserve defaultMode + }, + } + caps := difc.NewCapabilities() + agentLabels := difc.NewAgentRegistryWithDefaults(nil, nil).GetOrCreate("test-agent") + defaultMode := difc.EnforcementStrict + + mode, result, err := RunLabelAgent(context.Background(), g, nil, &noopRunBackendCaller{}, caps, agentLabels, defaultMode) + + require.NoError(t, err) + require.NotNil(t, result) + assert.Equal(t, defaultMode, mode, "empty DIFCMode should preserve defaultMode") +} diff --git a/internal/httputil/httputil.go b/internal/httputil/httputil.go index 2c55c2a98..00ab405c0 100644 --- a/internal/httputil/httputil.go +++ b/internal/httputil/httputil.go @@ -33,6 +33,16 @@ func WriteJSONResponse(w http.ResponseWriter, statusCode int, body interface{}) } } +// WriteErrorResponse writes a JSON error response with a consistent +// {"error": code, "message": message} shape. Both the server and proxy packages +// should use this helper so that API consumers always receive the same error shape. +func WriteErrorResponse(w http.ResponseWriter, statusCode int, code, message string) { + WriteJSONResponse(w, statusCode, map[string]string{ + "error": code, + "message": message, + }) +} + // IsTransientHTTPError returns true for status codes that indicate a temporary // server-side condition (rate-limiting or transient failure) worth retrying. func IsTransientHTTPError(statusCode int) bool { diff --git a/internal/httputil/httputil_test.go b/internal/httputil/httputil_test.go index 0754bfdc9..3843c9003 100644 --- a/internal/httputil/httputil_test.go +++ b/internal/httputil/httputil_test.go @@ -310,3 +310,31 @@ func TestApplyGitHubAPIHeaders(t *testing.T) { assert.Equal(t, GitHubUserAgent, req.Header.Get("User-Agent")) }) } + +func TestWriteErrorResponse(t *testing.T) { + tests := []struct { + name string + statusCode int + code string + message string + }{ + {name: "400", statusCode: http.StatusBadRequest, code: "bad_request", message: "malformed input"}, + {name: "403", statusCode: http.StatusForbidden, code: "difc_forbidden", message: "DIFC policy violation"}, + {name: "500", statusCode: http.StatusInternalServerError, code: "internal_error", message: "unexpected error"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + WriteErrorResponse(w, tt.statusCode, tt.code, tt.message) + + assert.Equal(t, tt.statusCode, w.Code) + assert.Equal(t, "application/json", w.Header().Get("Content-Type")) + + var body map[string]string + require.NoError(t, json.NewDecoder(w.Body).Decode(&body)) + assert.Equal(t, tt.code, body["error"]) + assert.Equal(t, tt.message, body["message"]) + }) + } +} diff --git a/internal/proxy/handler.go b/internal/proxy/handler.go index eb0478bbc..a84bf617b 100644 --- a/internal/proxy/handler.go +++ b/internal/proxy/handler.go @@ -26,10 +26,10 @@ import ( var logHandler = logger.New("proxy:handler") // writeDIFCForbidden writes a 403 JSON response for DIFC policy violations. +// Uses the shared WriteErrorResponse helper so that the response shape is consistent +// with all other error responses in the gateway ({"error": ..., "message": ...}). func writeDIFCForbidden(w http.ResponseWriter, message string) { - httputil.WriteJSONResponse(w, http.StatusForbidden, map[string]string{ - "message": message, - }) + httputil.WriteErrorResponse(w, http.StatusForbidden, "difc_forbidden", message) } // proxyHandler implements http.Handler and runs the DIFC pipeline on proxied requests. @@ -157,12 +157,12 @@ func (h *proxyHandler) handleWithDIFC(w http.ResponseWriter, r *http.Request, pa } // **Phase 0: Get agent labels** - agentLabels := s.agentRegistry.GetOrCreate("proxy") + agentLabels := s.AgentRegistry.GetOrCreate("proxy") logHandler.Printf("[DIFC] Phase 0: agent secrecy=%v integrity=%v", agentLabels.GetSecrecyTags(), agentLabels.GetIntegrityTags()) // **Phase 1: Guard labels the resource** - resource, operation, err := s.guard.LabelResource(ctx, toolName, args, backend, s.capabilities) + resource, operation, err := s.guard.LabelResource(ctx, toolName, args, backend, s.Capabilities) if err != nil { logHandler.Printf("[DIFC] Phase 1 failed: %v", err) // On labeling failure, fail closed to prevent enforcement bypass @@ -175,7 +175,7 @@ func (h *proxyHandler) handleWithDIFC(w http.ResponseWriter, r *http.Request, pa resource.Secrecy.Label.GetTags(), resource.Integrity.Label.GetTags()) // **Phase 2: Coarse-grained access check** - evalResult := s.evaluator.Evaluate(agentLabels.Secrecy, agentLabels.Integrity, resource, operation) + evalResult := s.Evaluator.Evaluate(agentLabels.Secrecy, agentLabels.Integrity, resource, operation) if !evalResult.IsAllowed() { if difc.ShouldBypassCoarseDeny(operation) { @@ -235,7 +235,7 @@ func (h *proxyHandler) handleWithDIFC(w http.ResponseWriter, r *http.Request, pa ctx = guard.SetRequestStateInContext(ctx, map[string]interface{}{ "tool_args": args, }) - labeledData, err := s.guard.LabelResponse(ctx, toolName, responseData, backend, s.capabilities) + labeledData, err := s.guard.LabelResponse(ctx, toolName, responseData, backend, s.Capabilities) if err != nil { logHandler.Printf("[DIFC] Phase 4 failed: %v", err) // On labeling failure, use coarse-grained result @@ -252,7 +252,7 @@ func (h *proxyHandler) handleWithDIFC(w http.ResponseWriter, r *http.Request, pa var useOriginalBody bool // GraphQL responses need original format preserved if labeledData != nil { if collection, ok := labeledData.(*difc.CollectionLabeledData); ok { - filtered := s.evaluator.FilterCollection( + filtered := s.Evaluator.FilterCollection( agentLabels.Secrecy, agentLabels.Integrity, collection, operation) logHandler.Printf("[DIFC] Phase 5: %d/%d items accessible", @@ -266,7 +266,7 @@ func (h *proxyHandler) handleWithDIFC(w http.ResponseWriter, r *http.Request, pa } // Strict mode: block entire response if any item filtered - if difc.ShouldBlockFilteredResponse(s.enforcementMode, filtered.GetFilteredCount()) { + if difc.ShouldBlockFilteredResponse(s.Mode, filtered.GetFilteredCount()) { logHandler.Printf("[DIFC] STRICT: blocking response — %d filtered items", filtered.GetFilteredCount()) writeDIFCForbidden(w, fmt.Sprintf("DIFC policy violation: %d of %d items not accessible", filtered.GetFilteredCount(), filtered.TotalCount)) @@ -318,7 +318,7 @@ func (h *proxyHandler) handleWithDIFC(w http.ResponseWriter, r *http.Request, pa } // **Phase 6: Label accumulation (propagate mode)** - if labeledData != nil && difc.ShouldAccumulateReadLabels(operation, s.enforcementMode) { + if labeledData != nil && difc.ShouldAccumulateReadLabels(operation, s.Mode) { overall := labeledData.Overall() agentLabels.AccumulateFromRead(overall) logHandler.Printf("[DIFC] Phase 6: accumulated labels") diff --git a/internal/proxy/handler_difc_test.go b/internal/proxy/handler_difc_test.go index 8bc3263e7..fd3b22b40 100644 --- a/internal/proxy/handler_difc_test.go +++ b/internal/proxy/handler_difc_test.go @@ -61,14 +61,16 @@ func publicResource() *difc.LabeledResource { func newTestServerWithStub(t *testing.T, upstreamURL string, g *stubGuard, mode difc.EnforcementMode) *Server { t.Helper() return &Server{ - guard: g, - evaluator: difc.NewEvaluatorWithMode(mode), - agentRegistry: difc.NewAgentRegistryWithDefaults(nil, nil), - capabilities: difc.NewCapabilities(), + guard: g, + DIFCComponents: difc.DIFCComponents{ + Mode: mode, + Evaluator: difc.NewEvaluatorWithMode(mode), + AgentRegistry: difc.NewAgentRegistryWithDefaults(nil, nil), + Capabilities: difc.NewCapabilities(), + }, githubAPIURL: upstreamURL, httpClient: &http.Client{}, guardInitialized: true, - enforcementMode: mode, } } @@ -78,14 +80,16 @@ func newTestServerWithPrivateAgent(t *testing.T, upstreamURL string, g *stubGuar t.Helper() reg := difc.NewAgentRegistryWithDefaults([]difc.Tag{"private:test-org/test-repo"}, nil) return &Server{ - guard: g, - evaluator: difc.NewEvaluatorWithMode(mode), - agentRegistry: reg, - capabilities: difc.NewCapabilities(), + guard: g, + DIFCComponents: difc.DIFCComponents{ + Mode: mode, + Evaluator: difc.NewEvaluatorWithMode(mode), + AgentRegistry: reg, + Capabilities: difc.NewCapabilities(), + }, githubAPIURL: upstreamURL, httpClient: &http.Client{}, guardInitialized: true, - enforcementMode: mode, } } @@ -494,14 +498,16 @@ func TestHandleWithDIFC_PropagateMode_AccumulatesLabels(t *testing.T) { } reg := difc.NewAgentRegistryWithDefaults(nil, nil) s := &Server{ - guard: g, - evaluator: difc.NewEvaluatorWithMode(difc.EnforcementPropagate), - agentRegistry: reg, - capabilities: difc.NewCapabilities(), + guard: g, + DIFCComponents: difc.DIFCComponents{ + Mode: difc.EnforcementPropagate, + Evaluator: difc.NewEvaluatorWithMode(difc.EnforcementPropagate), + AgentRegistry: reg, + Capabilities: difc.NewCapabilities(), + }, githubAPIURL: upstream.URL, httpClient: &http.Client{}, guardInitialized: true, - enforcementMode: difc.EnforcementPropagate, } h := &proxyHandler{server: s} diff --git a/internal/proxy/handler_test.go b/internal/proxy/handler_test.go index f8d7484cf..6c031d7fb 100644 --- a/internal/proxy/handler_test.go +++ b/internal/proxy/handler_test.go @@ -21,14 +21,16 @@ import ( func newTestServer(t *testing.T, upstreamURL string) *Server { t.Helper() return &Server{ - guard: guard.NewNoopGuard(), - evaluator: difc.NewEvaluatorWithMode(difc.EnforcementFilter), - agentRegistry: difc.NewAgentRegistryWithDefaults(nil, nil), - capabilities: difc.NewCapabilities(), + guard: guard.NewNoopGuard(), + DIFCComponents: difc.DIFCComponents{ + Mode: difc.EnforcementFilter, + Evaluator: difc.NewEvaluatorWithMode(difc.EnforcementFilter), + AgentRegistry: difc.NewAgentRegistryWithDefaults(nil, nil), + Capabilities: difc.NewCapabilities(), + }, githubAPIURL: upstreamURL, httpClient: &http.Client{}, guardInitialized: true, - enforcementMode: difc.EnforcementFilter, } } @@ -510,8 +512,8 @@ func TestHandleWithDIFC_StrictModeBlocksFilteredItems(t *testing.T) { // strict mode with NoopGuard (no labels set) — evaluator allows, no items filtered s := newTestServer(t, upstream.URL) - s.enforcementMode = difc.EnforcementStrict - s.evaluator = difc.NewEvaluatorWithMode(difc.EnforcementStrict) + s.Mode = difc.EnforcementStrict + s.Evaluator = difc.NewEvaluatorWithMode(difc.EnforcementStrict) h := &proxyHandler{server: s} req := httptest.NewRequest(http.MethodGet, "/search/issues?q=test", nil) diff --git a/internal/proxy/init_guard_policy_test.go b/internal/proxy/init_guard_policy_test.go index 9838bc912..1176d69ec 100644 --- a/internal/proxy/init_guard_policy_test.go +++ b/internal/proxy/init_guard_policy_test.go @@ -49,13 +49,15 @@ func defaultLabelAgentStub(difcMode string, secrecy, integrity []string) *labelA // newTestServerForInitGuardPolicy creates a minimal proxy.Server for testing initGuardPolicy. func newTestServerForInitGuardPolicy(g guard.Guard, mode difc.EnforcementMode) *Server { return &Server{ - guard: g, - evaluator: difc.NewEvaluatorWithMode(mode), - agentRegistry: difc.NewAgentRegistryWithDefaults(nil, nil), - capabilities: difc.NewCapabilities(), - githubAPIURL: "https://api.github.com", - httpClient: &http.Client{}, - enforcementMode: mode, + guard: g, + DIFCComponents: difc.DIFCComponents{ + Mode: mode, + Evaluator: difc.NewEvaluatorWithMode(mode), + AgentRegistry: difc.NewAgentRegistryWithDefaults(nil, nil), + Capabilities: difc.NewCapabilities(), + }, + githubAPIURL: "https://api.github.com", + httpClient: &http.Client{}, } } @@ -143,7 +145,7 @@ func TestInitGuardPolicy_SuccessWithNoLabels(t *testing.T) { require.NoError(t, err) assert.True(t, s.guardInitialized) - assert.Equal(t, difc.EnforcementFilter, s.enforcementMode) + assert.Equal(t, difc.EnforcementFilter, s.Mode) } // TestInitGuardPolicy_SuccessAppliesAgentLabels verifies that secrecy and integrity tags @@ -157,7 +159,7 @@ func TestInitGuardPolicy_SuccessAppliesAgentLabels(t *testing.T) { require.NoError(t, err) assert.True(t, s.guardInitialized) - labels := s.agentRegistry.GetOrCreate("proxy") + labels := s.AgentRegistry.GetOrCreate("proxy") require.NotNil(t, labels) assert.Contains(t, labels.GetSecrecyTags(), difc.Tag("private:org/repo"), "secrecy tag must be applied") assert.Contains(t, labels.GetIntegrityTags(), difc.Tag("approved"), "integrity tag must be applied") @@ -174,7 +176,7 @@ func TestInitGuardPolicy_DIFCModeOverride(t *testing.T) { require.NoError(t, err) assert.True(t, s.guardInitialized) - assert.Equal(t, difc.EnforcementStrict, s.enforcementMode, + assert.Equal(t, difc.EnforcementStrict, s.Mode, "guard response DIFCMode must override the server's enforcement mode") } @@ -190,7 +192,7 @@ func TestInitGuardPolicy_InvalidDIFCModeError(t *testing.T) { assert.Contains(t, err.Error(), "invalid difc_mode") assert.False(t, s.guardInitialized, "guard must not be marked initialized when DIFCMode is invalid") - assert.Equal(t, difc.EnforcementFilter, s.enforcementMode, + assert.Equal(t, difc.EnforcementFilter, s.Mode, "enforcement mode must remain unchanged when DIFCMode is invalid") } @@ -206,7 +208,7 @@ func TestInitGuardPolicy_EmptyDIFCModePreservesMode(t *testing.T) { assert.True(t, s.guardInitialized) // Empty DIFCMode causes the mode-override block to be skipped entirely, so the // server's initial strict mode is preserved. - assert.Equal(t, difc.EnforcementStrict, s.enforcementMode) + assert.Equal(t, difc.EnforcementStrict, s.Mode) } // TestInitGuardPolicy_LegacyAllowOnlyKey verifies that a policy using the legacy @@ -271,7 +273,7 @@ func TestInitGuardPolicy_MultipleSecrecyTags(t *testing.T) { require.NoError(t, err) assert.True(t, s.guardInitialized) - labels := s.agentRegistry.GetOrCreate("proxy") + labels := s.AgentRegistry.GetOrCreate("proxy") require.NotNil(t, labels) for _, tag := range secrecy { assert.Contains(t, labels.GetSecrecyTags(), difc.Tag(tag), "secrecy tag %q must be applied", tag) diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index da8b4e0c0..0bb1b4472 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -38,10 +38,8 @@ const ( // It loads the same WASM guard used by the MCP gateway and runs the 6-phase // DIFC pipeline on every proxied response. type Server struct { - guard guard.Guard - evaluator *difc.Evaluator - agentRegistry *difc.AgentRegistry - capabilities *difc.Capabilities + guard guard.Guard + difc.DIFCComponents githubToken string githubAPIURL string // upstream base URL (no trailing slash) @@ -50,7 +48,6 @@ type Server struct { // guardInitialized tracks whether LabelAgent has been called guardInitialized bool - enforcementMode difc.EnforcementMode } // Config holds the configuration for creating a proxy Server. @@ -116,13 +113,10 @@ func New(ctx context.Context, cfg Config) (*Server, error) { logProxy.Printf("WASM guard loaded successfully") s := &Server{ - guard: g, - evaluator: difcComponents.Evaluator, - agentRegistry: difcComponents.AgentRegistry, - capabilities: difcComponents.Capabilities, - githubToken: cfg.GitHubToken, - githubAPIURL: apiURL, - enforcementMode: difcComponents.Mode, + guard: g, + DIFCComponents: difcComponents, + githubToken: cfg.GitHubToken, + githubAPIURL: apiURL, httpClient: &http.Client{ Timeout: 60 * time.Second, Transport: &http.Transport{ @@ -179,31 +173,22 @@ func (s *Server) initGuardPolicy(ctx context.Context, policyJSON string, trusted logProxy.Printf("Calling LabelAgent to initialize agent labels from guard") backend := &restBackendCaller{server: s} - result, err := s.guard.LabelAgent(ctx, payload, backend, s.capabilities) + agentLabels := s.AgentRegistry.GetOrCreate("proxy") + newMode, result, err := guard.RunLabelAgent(ctx, s.guard, payload, backend, s.Capabilities, agentLabels, s.Mode) if err != nil { - return fmt.Errorf("LabelAgent failed: %w", err) - } - if result == nil { - return fmt.Errorf("LabelAgent returned nil result") - } - - // Apply agent labels and parse enforcement mode from guard response - agentLabels := s.agentRegistry.GetOrCreate("proxy") - newMode, err := guard.ApplyLabelAgentResult(result, agentLabels, s.enforcementMode) - if err != nil { - return fmt.Errorf("LabelAgent result invalid: %w", err) + return err } logProxy.Printf("Agent labels applied: secrecy=%v, integrity=%v", result.Agent.Secrecy, result.Agent.Integrity) if result.DIFCMode != "" { - logProxy.Printf("Enforcement mode overridden by guard response: %s → %s", s.enforcementMode, newMode) - s.enforcementMode = newMode - s.evaluator.SetMode(newMode) + logProxy.Printf("Enforcement mode overridden by guard response: %s → %s", s.Mode, newMode) + s.Mode = newMode + s.Evaluator.SetMode(newMode) } s.guardInitialized = true logProxy.Printf("Guard initialized: mode=%s, secrecy=%v, integrity=%v", - s.enforcementMode, result.Agent.Secrecy, result.Agent.Integrity) + s.Mode, result.Agent.Secrecy, result.Agent.Integrity) return nil } diff --git a/internal/server/call_backend_tool_difc_test.go b/internal/server/call_backend_tool_difc_test.go index 5e924b5d2..5c7c80d7c 100644 --- a/internal/server/call_backend_tool_difc_test.go +++ b/internal/server/call_backend_tool_difc_test.go @@ -789,7 +789,7 @@ func TestCallBackendTool_Phase6_PropagateModeAccumulatesLabels(t *testing.T) { assert.False(result.IsError) // Agent labels should now contain the resource's secrecy tag (propagate mode). - agentLabels, ok := us.agentRegistry.Get(agentID) + agentLabels, ok := us.AgentRegistry.Get(agentID) require.True(ok, "agent should exist in registry after call") secrecyTags := agentLabels.GetSecrecyTags() assert.Contains(secrecyTags, difc.Tag("private:org/repo"), diff --git a/internal/server/ensure_guard_initialized_test.go b/internal/server/ensure_guard_initialized_test.go index 5ef4d4208..265439133 100644 --- a/internal/server/ensure_guard_initialized_test.go +++ b/internal/server/ensure_guard_initialized_test.go @@ -75,11 +75,14 @@ func newMinimalUnifiedServer(cfg *config.Config) *UnifiedServer { } } return &UnifiedServer{ - cfg: cfg, - sessions: make(map[string]*Session), - agentRegistry: difc.NewAgentRegistryWithDefaults(nil, nil), - capabilities: difc.NewCapabilities(), - evaluator: difc.NewEvaluatorWithMode(difcMode), + cfg: cfg, + sessions: make(map[string]*Session), + DIFCComponents: difc.DIFCComponents{ + Mode: difcMode, + AgentRegistry: difc.NewAgentRegistryWithDefaults(nil, nil), + Capabilities: difc.NewCapabilities(), + Evaluator: difc.NewEvaluatorWithMode(difcMode), + }, } } @@ -192,7 +195,7 @@ func TestEnsureGuardInitialized_LabelAgentError(t *testing.T) { _, err := us.ensureGuardInitialized(context.Background(), "session-err", "server1", g, &noopBackendCaller{}) require.Error(t, err) - assert.Contains(t, err.Error(), "label_agent failed") + assert.Contains(t, err.Error(), "LabelAgent failed") assert.Contains(t, err.Error(), "backend unreachable") } @@ -213,7 +216,7 @@ func TestEnsureGuardInitialized_LabelAgentNilResult(t *testing.T) { _, err := us.ensureGuardInitialized(context.Background(), "session-nil", "server1", g, &noopBackendCaller{}) require.Error(t, err) - assert.Contains(t, err.Error(), "label_agent returned nil result") + assert.Contains(t, err.Error(), "LabelAgent returned nil result") } // TestEnsureGuardInitialized_DIFCModeEmpty verifies that when LabelAgent returns an @@ -345,7 +348,7 @@ func TestEnsureGuardInitialized_LabelsAddedToRegistry(t *testing.T) { require.NoError(t, err) - agentLabels, ok := us.agentRegistry.Get("labeled-agent-id") + agentLabels, ok := us.AgentRegistry.Get("labeled-agent-id") require.True(t, ok, "agent labels should be registered") assert.Contains(t, agentLabels.GetSecrecyTags(), difc.Tag("private:org")) assert.Contains(t, agentLabels.GetIntegrityTags(), difc.Tag("merged")) @@ -542,7 +545,7 @@ func TestEnsureGuardInitialized_LabelsMergedAcrossGuards(t *testing.T) { _, err = us.ensureGuardInitialized(ctx, "session-union", "server-B", gB, &noopBackendCaller{}) require.NoError(t, err) - agentLabels, ok := us.agentRegistry.Get(agentID) + agentLabels, ok := us.AgentRegistry.Get(agentID) require.True(t, ok) tags := agentLabels.GetSecrecyTags() assert.Contains(t, tags, difc.Tag("tag-A"), "tag-A from guard-A should be present") diff --git a/internal/server/guard_init.go b/internal/server/guard_init.go index eafcebf90..76f8f7f3c 100644 --- a/internal/server/guard_init.go +++ b/internal/server/guard_init.go @@ -297,7 +297,7 @@ func (us *UnifiedServer) ensureGuardInitialized( g guard.Guard, backendCaller guard.BackendCaller, ) (difc.EnforcementMode, error) { - defaultMode := us.evaluator.GetMode() + defaultMode := us.Evaluator.GetMode() policy, source, err := us.resolveGuardPolicy(serverID) if err != nil { @@ -341,14 +341,18 @@ func (us *UnifiedServer) ensureGuardInitialized( logger.LogInfoWithServer(serverID, "difc", "Initializing guard session state: session=%s, policy_source=%s", sessionID, source) logger.LogInfoWithServer(serverID, "difc", "Calling label_agent: session=%s, guard=%s, policy=%s", sessionID, g.Name(), string(policyJSON)) - labelAgentResult, err := g.LabelAgent(ctx, labelAgentPayload, backendCaller, us.capabilities) + + agentID := guard.GetAgentIDFromContext(ctx) + + // Merge labels into existing agent (union semantics). + // Multiple guards may contribute labels for the same agent; each guard's + // label_agent output is additive so that later guards do not overwrite + // labels set by earlier ones. + agentLabels := us.AgentRegistry.GetOrCreate(agentID) + mode, labelAgentResult, err := guard.RunLabelAgent(ctx, g, labelAgentPayload, backendCaller, us.Capabilities, agentLabels, defaultMode) if err != nil { logger.LogErrorWithServer(serverID, "difc", "label_agent failed: session=%s, guard=%s, error=%v", sessionID, g.Name(), err) - return defaultMode, fmt.Errorf("label_agent failed: %w", err) - } - if labelAgentResult == nil { - logger.LogErrorWithServer(serverID, "difc", "label_agent returned nil result: session=%s, guard=%s", sessionID, g.Name()) - return defaultMode, fmt.Errorf("label_agent returned nil result") + return defaultMode, err } logger.LogMarshaledForDebug( labelAgentResult, @@ -360,18 +364,6 @@ func (us *UnifiedServer) ensureGuardInitialized( }, ) - agentID := guard.GetAgentIDFromContext(ctx) - - // Merge labels into existing agent (union semantics). - // Multiple guards may contribute labels for the same agent; each guard's - // label_agent output is additive so that later guards do not overwrite - // labels set by earlier ones. - agentLabels := us.agentRegistry.GetOrCreate(agentID) - mode, err := guard.ApplyLabelAgentResult(labelAgentResult, agentLabels, defaultMode) - if err != nil { - return defaultMode, fmt.Errorf("label_agent result invalid: %w", err) - } - us.sessionMu.Lock() session = us.sessions[sessionID] normalizedPolicy := config.NormalizeScopeKind(labelAgentResult.NormalizedPolicy) diff --git a/internal/server/http_helpers.go b/internal/server/http_helpers.go index f1336ebf0..d27d64b8e 100644 --- a/internal/server/http_helpers.go +++ b/internal/server/http_helpers.go @@ -49,10 +49,7 @@ func logRuntimeError(errorType, detail string, r *http.Request, serverName *stri // All HTTP error paths in the server package should use this helper to ensure // clients always receive application/json rather than text/plain. func writeErrorResponse(w http.ResponseWriter, statusCode int, code, message string) { - httputil.WriteJSONResponse(w, statusCode, map[string]string{ - "error": code, - "message": message, - }) + httputil.WriteErrorResponse(w, statusCode, code, message) } // rejectRequest logs a structured error, records a runtime error, and writes an diff --git a/internal/server/label_agent_test.go b/internal/server/label_agent_test.go index 527d6a605..537d4f40b 100644 --- a/internal/server/label_agent_test.go +++ b/internal/server/label_agent_test.go @@ -167,7 +167,7 @@ func TestCallBackendTool_LabelAgentInitializationCached(t *testing.T) { customGuard.mu.Unlock() assert.Equal(1, calls, "label_agent should run once per session/server policy") - agentLabels, ok := us.agentRegistry.Get("session-123") + agentLabels, ok := us.AgentRegistry.Get("session-123") require.True(ok) assert.Contains(agentLabels.GetSecrecyTags(), difc.Tag("policy-secret")) assert.Contains(agentLabels.GetIntegrityTags(), difc.Tag("policy-integrity")) @@ -293,7 +293,7 @@ func TestCallBackendTool_LabelAgentInitializationFromServerGuardPolicies(t *test customGuard.mu.Unlock() assert.Equal(1, calls, "label_agent should run once per session/server policy from guard-policies") - agentLabels, ok := us.agentRegistry.Get("session-456") + agentLabels, ok := us.AgentRegistry.Get("session-456") require.True(ok) assert.Contains(agentLabels.GetSecrecyTags(), difc.Tag("policy-secret")) assert.Contains(agentLabels.GetIntegrityTags(), difc.Tag("policy-integrity")) diff --git a/internal/server/unified.go b/internal/server/unified.go index fb5da8266..d09d25309 100644 --- a/internal/server/unified.go +++ b/internal/server/unified.go @@ -99,10 +99,8 @@ type UnifiedServer struct { // DIFC components guardRegistry *guard.Registry - agentRegistry *difc.AgentRegistry - capabilities *difc.Capabilities - evaluator *difc.Evaluator - enableDIFC bool // When true, DIFC enforcement and session requirement are enabled + difc.DIFCComponents + enableDIFC bool // When true, DIFC enforcement and session requirement are enabled // Configuration reference for guard loading cfg *config.Config @@ -162,11 +160,9 @@ func NewUnified(ctx context.Context, cfg *config.Config) (*UnifiedServer, error) circuitBreakers: buildCircuitBreakers(cfg), // Initialize DIFC components - guardRegistry: guard.NewRegistry(), - agentRegistry: difcComponents.AgentRegistry, - capabilities: difcComponents.Capabilities, - evaluator: difcComponents.Evaluator, - cfg: cfg, // Store config for guard loading + guardRegistry: guard.NewRegistry(), + DIFCComponents: difcComponents, + cfg: cfg, // Store config for guard loading // Cache tracer at construction to avoid calling otel.Tracer on every request. tracer: tracing.Tracer(), @@ -480,7 +476,7 @@ func (us *UnifiedServer) callBackendTool(ctx context.Context, serverID, toolName // **Phase 0: Extract agent ID and get/create agent labels** agentID := guard.GetAgentIDFromContext(ctx) - agentLabels := us.agentRegistry.GetOrCreate(agentID) + agentLabels := us.AgentRegistry.GetOrCreate(agentID) logUnified.Printf("[DIFC] Agent %s | Secrecy: %v | Integrity: %v", agentID, agentLabels.GetSecrecyTags(), agentLabels.GetIntegrityTags()) @@ -496,7 +492,7 @@ func (us *UnifiedServer) callBackendTool(ctx context.Context, serverID, toolName }) // **Phase 1: Guard labels the resource** - resource, operation, err := g.LabelResource(ctx, toolName, args, backendCaller, us.capabilities) + resource, operation, err := g.LabelResource(ctx, toolName, args, backendCaller, us.Capabilities) if err != nil { logger.LogWarn("difc", "Guard labeling failed: %v", err) httpStatusCode = 500 @@ -586,7 +582,7 @@ func (us *UnifiedServer) callBackendTool(ctx context.Context, serverID, toolName var labeledData difc.LabeledData if shouldCallLabelResponse { - labeledData, err = g.LabelResponse(ctx, toolName, backendResult, backendCaller, us.capabilities) + labeledData, err = g.LabelResponse(ctx, toolName, backendResult, backendCaller, us.Capabilities) if err != nil { logger.LogWarn("difc", "Response labeling failed: %v", err) httpStatusCode = 500 diff --git a/internal/server/write_sink_guard_test.go b/internal/server/write_sink_guard_test.go index da2272e00..f706aa26a 100644 --- a/internal/server/write_sink_guard_test.go +++ b/internal/server/write_sink_guard_test.go @@ -235,7 +235,7 @@ func TestWriteSinkGuard_AllowsWriteAfterGitHubRead(t *testing.T) { assert.False(result1.IsError, "GitHub read should succeed") // Verify agent acquired tags - agentLabels, ok := us.agentRegistry.Get(sessionID) + agentLabels, ok := us.AgentRegistry.Get(sessionID) require.True(ok, "agent should be registered after GitHub read") assert.Contains(agentLabels.GetSecrecyTags(), difc.Tag("private:github/gh-aw*")) assert.Contains(agentLabels.GetIntegrityTags(), difc.Tag("approved:github/gh-aw*"))