Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions internal/guard/init.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package guard

import (
"context"
"fmt"

"github.com/github/gh-aw-mcpg/internal/difc"
"github.com/github/gh-aw-mcpg/internal/logger"
)

var logGuardInit = logger.New("guard:init")

// RunLabelAgent executes the standard LabelAgent initialization pipeline:
// 1. Calls the guard's LabelAgent method with the provided pre-built payload.
// 2. Validates the result is non-nil.
// 3. Applies the returned agent labels to agentLabels.
// 4. Parses and returns the effective enforcement mode.
//
// It returns the effective enforcement mode, the raw LabelAgentResult (so
// callers can inspect NormalizedPolicy, DIFCMode, etc.), and any error.
// On error, defaultMode is returned unchanged so the caller's mode is unaffected.
func RunLabelAgent(
ctx context.Context,
g Guard,
payload interface{},
backend BackendCaller,
caps *difc.Capabilities,
agentLabels *difc.AgentLabels,
defaultMode difc.EnforcementMode,
) (difc.EnforcementMode, *LabelAgentResult, error) {
logGuardInit.Printf("Calling LabelAgent: guard=%s", g.Name())

result, err := g.LabelAgent(ctx, payload, backend, caps)
if err != nil {
logGuardInit.Printf("LabelAgent failed: guard=%s, error=%v", g.Name(), err)
return defaultMode, nil, fmt.Errorf("LabelAgent failed: %w", err)
}
if result == nil {
logGuardInit.Printf("LabelAgent returned nil result: guard=%s", g.Name())
return defaultMode, nil, fmt.Errorf("LabelAgent returned nil result")
}

mode, err := ApplyLabelAgentResult(result, agentLabels, defaultMode)
if err != nil {
logGuardInit.Printf("LabelAgent result invalid: guard=%s, error=%v", g.Name(), err)
return defaultMode, nil, fmt.Errorf("LabelAgent result invalid: %w", err)
}

logGuardInit.Printf("LabelAgent completed: guard=%s, mode=%s", g.Name(), mode)
return mode, result, nil
}
133 changes: 133 additions & 0 deletions internal/guard/init_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
package guard

import (
"context"
"errors"
"testing"

"github.com/github/gh-aw-mcpg/internal/difc"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// runLabelAgentStubGuard is a minimal Guard implementation used to test RunLabelAgent.
type runLabelAgentStubGuard struct {
name string
labelAgentResult *LabelAgentResult
labelAgentErr error
}

func (g *runLabelAgentStubGuard) Name() string { return g.name }

func (g *runLabelAgentStubGuard) LabelAgent(_ context.Context, _ interface{}, _ BackendCaller, _ *difc.Capabilities) (*LabelAgentResult, error) {
return g.labelAgentResult, g.labelAgentErr
}

func (g *runLabelAgentStubGuard) LabelResource(_ context.Context, _ string, _ interface{}, _ BackendCaller, _ *difc.Capabilities) (*difc.LabeledResource, difc.OperationType, error) {
return difc.NewLabeledResource("stub"), difc.OperationRead, nil
}

func (g *runLabelAgentStubGuard) LabelResponse(_ context.Context, _ string, _ interface{}, _ BackendCaller, _ *difc.Capabilities) (difc.LabeledData, error) {
return nil, nil
}

// noopRunBackendCaller is a BackendCaller that always returns nil.
type noopRunBackendCaller struct{}

func (n *noopRunBackendCaller) CallTool(_ context.Context, _ string, _ interface{}) (interface{}, error) {
return nil, nil
}

func TestRunLabelAgent_Success(t *testing.T) {
g := &runLabelAgentStubGuard{
name: "test-guard",
labelAgentResult: &LabelAgentResult{
Agent: AgentLabelsPayload{Secrecy: []string{"private:org/repo"}, Integrity: []string{"approved"}},
DIFCMode: difc.ModeFilter,
},
}
caps := difc.NewCapabilities()
agentLabels := difc.NewAgentRegistryWithDefaults(nil, nil).GetOrCreate("test-agent")
defaultMode := difc.EnforcementStrict

mode, result, err := RunLabelAgent(context.Background(), g, map[string]interface{}{"policy": "test"}, &noopRunBackendCaller{}, caps, agentLabels, defaultMode)

require.NoError(t, err)
require.NotNil(t, result)
assert.Equal(t, difc.EnforcementFilter, mode, "mode should be overridden by guard response")
assert.Equal(t, difc.ModeFilter, result.DIFCMode)
}

func TestRunLabelAgent_GuardError(t *testing.T) {
g := &runLabelAgentStubGuard{
name: "error-guard",
labelAgentErr: errors.New("wasm runtime error"),
}
caps := difc.NewCapabilities()
agentLabels := difc.NewAgentRegistryWithDefaults(nil, nil).GetOrCreate("test-agent")
defaultMode := difc.EnforcementFilter

mode, result, err := RunLabelAgent(context.Background(), g, nil, &noopRunBackendCaller{}, caps, agentLabels, defaultMode)

require.Error(t, err)
assert.Contains(t, err.Error(), "LabelAgent failed")
assert.Contains(t, err.Error(), "wasm runtime error")
assert.Nil(t, result)
assert.Equal(t, defaultMode, mode, "defaultMode should be returned on error")
}

func TestRunLabelAgent_NilResult(t *testing.T) {
g := &runLabelAgentStubGuard{
name: "nil-result-guard",
labelAgentResult: nil,
}
caps := difc.NewCapabilities()
agentLabels := difc.NewAgentRegistryWithDefaults(nil, nil).GetOrCreate("test-agent")
defaultMode := difc.EnforcementStrict

mode, result, err := RunLabelAgent(context.Background(), g, nil, &noopRunBackendCaller{}, caps, agentLabels, defaultMode)

require.Error(t, err)
assert.Contains(t, err.Error(), "LabelAgent returned nil result")
assert.Nil(t, result)
assert.Equal(t, defaultMode, mode)
}

func TestRunLabelAgent_InvalidDIFCMode(t *testing.T) {
g := &runLabelAgentStubGuard{
name: "bad-mode-guard",
labelAgentResult: &LabelAgentResult{
Agent: AgentLabelsPayload{Secrecy: []string{}, Integrity: []string{}},
DIFCMode: "not-a-real-mode",
},
}
caps := difc.NewCapabilities()
agentLabels := difc.NewAgentRegistryWithDefaults(nil, nil).GetOrCreate("test-agent")
defaultMode := difc.EnforcementFilter

mode, result, err := RunLabelAgent(context.Background(), g, nil, &noopRunBackendCaller{}, caps, agentLabels, defaultMode)

require.Error(t, err)
assert.Contains(t, err.Error(), "LabelAgent result invalid")
assert.Nil(t, result)
assert.Equal(t, defaultMode, mode)
}

func TestRunLabelAgent_EmptyDIFCModePreservesDefault(t *testing.T) {
g := &runLabelAgentStubGuard{
name: "empty-mode-guard",
labelAgentResult: &LabelAgentResult{
Agent: AgentLabelsPayload{Secrecy: []string{}, Integrity: []string{}},
DIFCMode: "", // empty → preserve defaultMode
},
}
caps := difc.NewCapabilities()
agentLabels := difc.NewAgentRegistryWithDefaults(nil, nil).GetOrCreate("test-agent")
defaultMode := difc.EnforcementStrict

mode, result, err := RunLabelAgent(context.Background(), g, nil, &noopRunBackendCaller{}, caps, agentLabels, defaultMode)

require.NoError(t, err)
require.NotNil(t, result)
assert.Equal(t, defaultMode, mode, "empty DIFCMode should preserve defaultMode")
}
10 changes: 10 additions & 0 deletions internal/httputil/httputil.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,16 @@ func WriteJSONResponse(w http.ResponseWriter, statusCode int, body interface{})
}
}

// WriteErrorResponse writes a JSON error response with a consistent
// {"error": code, "message": message} shape. Both the server and proxy packages
// should use this helper so that API consumers always receive the same error shape.
func WriteErrorResponse(w http.ResponseWriter, statusCode int, code, message string) {
WriteJSONResponse(w, statusCode, map[string]string{
"error": code,
"message": message,
})
}

// IsTransientHTTPError returns true for status codes that indicate a temporary
// server-side condition (rate-limiting or transient failure) worth retrying.
func IsTransientHTTPError(statusCode int) bool {
Expand Down
28 changes: 28 additions & 0 deletions internal/httputil/httputil_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -310,3 +310,31 @@ func TestApplyGitHubAPIHeaders(t *testing.T) {
assert.Equal(t, GitHubUserAgent, req.Header.Get("User-Agent"))
})
}

func TestWriteErrorResponse(t *testing.T) {
tests := []struct {
name string
statusCode int
code string
message string
}{
{name: "400", statusCode: http.StatusBadRequest, code: "bad_request", message: "malformed input"},
{name: "403", statusCode: http.StatusForbidden, code: "difc_forbidden", message: "DIFC policy violation"},
{name: "500", statusCode: http.StatusInternalServerError, code: "internal_error", message: "unexpected error"},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
WriteErrorResponse(w, tt.statusCode, tt.code, tt.message)

assert.Equal(t, tt.statusCode, w.Code)
assert.Equal(t, "application/json", w.Header().Get("Content-Type"))

var body map[string]string
require.NoError(t, json.NewDecoder(w.Body).Decode(&body))
assert.Equal(t, tt.code, body["error"])
assert.Equal(t, tt.message, body["message"])
})
}
}
20 changes: 10 additions & 10 deletions internal/proxy/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ import (
var logHandler = logger.New("proxy:handler")

// writeDIFCForbidden writes a 403 JSON response for DIFC policy violations.
// Uses the shared WriteErrorResponse helper so that the response shape is consistent
// with all other error responses in the gateway ({"error": ..., "message": ...}).
func writeDIFCForbidden(w http.ResponseWriter, message string) {
httputil.WriteJSONResponse(w, http.StatusForbidden, map[string]string{
"message": message,
})
httputil.WriteErrorResponse(w, http.StatusForbidden, "difc_forbidden", message)
}

// proxyHandler implements http.Handler and runs the DIFC pipeline on proxied requests.
Expand Down Expand Up @@ -157,12 +157,12 @@ func (h *proxyHandler) handleWithDIFC(w http.ResponseWriter, r *http.Request, pa
}

// **Phase 0: Get agent labels**
agentLabels := s.agentRegistry.GetOrCreate("proxy")
agentLabels := s.AgentRegistry.GetOrCreate("proxy")
logHandler.Printf("[DIFC] Phase 0: agent secrecy=%v integrity=%v",
agentLabels.GetSecrecyTags(), agentLabels.GetIntegrityTags())

// **Phase 1: Guard labels the resource**
resource, operation, err := s.guard.LabelResource(ctx, toolName, args, backend, s.capabilities)
resource, operation, err := s.guard.LabelResource(ctx, toolName, args, backend, s.Capabilities)
if err != nil {
logHandler.Printf("[DIFC] Phase 1 failed: %v", err)
// On labeling failure, fail closed to prevent enforcement bypass
Expand All @@ -175,7 +175,7 @@ func (h *proxyHandler) handleWithDIFC(w http.ResponseWriter, r *http.Request, pa
resource.Secrecy.Label.GetTags(), resource.Integrity.Label.GetTags())

// **Phase 2: Coarse-grained access check**
evalResult := s.evaluator.Evaluate(agentLabels.Secrecy, agentLabels.Integrity, resource, operation)
evalResult := s.Evaluator.Evaluate(agentLabels.Secrecy, agentLabels.Integrity, resource, operation)

if !evalResult.IsAllowed() {
if difc.ShouldBypassCoarseDeny(operation) {
Expand Down Expand Up @@ -235,7 +235,7 @@ func (h *proxyHandler) handleWithDIFC(w http.ResponseWriter, r *http.Request, pa
ctx = guard.SetRequestStateInContext(ctx, map[string]interface{}{
"tool_args": args,
})
labeledData, err := s.guard.LabelResponse(ctx, toolName, responseData, backend, s.capabilities)
labeledData, err := s.guard.LabelResponse(ctx, toolName, responseData, backend, s.Capabilities)
if err != nil {
logHandler.Printf("[DIFC] Phase 4 failed: %v", err)
// On labeling failure, use coarse-grained result
Expand All @@ -252,7 +252,7 @@ func (h *proxyHandler) handleWithDIFC(w http.ResponseWriter, r *http.Request, pa
var useOriginalBody bool // GraphQL responses need original format preserved
if labeledData != nil {
if collection, ok := labeledData.(*difc.CollectionLabeledData); ok {
filtered := s.evaluator.FilterCollection(
filtered := s.Evaluator.FilterCollection(
agentLabels.Secrecy, agentLabels.Integrity, collection, operation)

logHandler.Printf("[DIFC] Phase 5: %d/%d items accessible",
Expand All @@ -266,7 +266,7 @@ func (h *proxyHandler) handleWithDIFC(w http.ResponseWriter, r *http.Request, pa
}

// Strict mode: block entire response if any item filtered
if difc.ShouldBlockFilteredResponse(s.enforcementMode, filtered.GetFilteredCount()) {
if difc.ShouldBlockFilteredResponse(s.Mode, filtered.GetFilteredCount()) {
logHandler.Printf("[DIFC] STRICT: blocking response — %d filtered items", filtered.GetFilteredCount())
writeDIFCForbidden(w, fmt.Sprintf("DIFC policy violation: %d of %d items not accessible",
filtered.GetFilteredCount(), filtered.TotalCount))
Expand Down Expand Up @@ -318,7 +318,7 @@ func (h *proxyHandler) handleWithDIFC(w http.ResponseWriter, r *http.Request, pa
}

// **Phase 6: Label accumulation (propagate mode)**
if labeledData != nil && difc.ShouldAccumulateReadLabels(operation, s.enforcementMode) {
if labeledData != nil && difc.ShouldAccumulateReadLabels(operation, s.Mode) {
overall := labeledData.Overall()
agentLabels.AccumulateFromRead(overall)
logHandler.Printf("[DIFC] Phase 6: accumulated labels")
Expand Down
36 changes: 21 additions & 15 deletions internal/proxy/handler_difc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,16 @@ func publicResource() *difc.LabeledResource {
func newTestServerWithStub(t *testing.T, upstreamURL string, g *stubGuard, mode difc.EnforcementMode) *Server {
t.Helper()
return &Server{
guard: g,
evaluator: difc.NewEvaluatorWithMode(mode),
agentRegistry: difc.NewAgentRegistryWithDefaults(nil, nil),
capabilities: difc.NewCapabilities(),
guard: g,
DIFCComponents: difc.DIFCComponents{
Mode: mode,
Evaluator: difc.NewEvaluatorWithMode(mode),
AgentRegistry: difc.NewAgentRegistryWithDefaults(nil, nil),
Capabilities: difc.NewCapabilities(),
},
githubAPIURL: upstreamURL,
httpClient: &http.Client{},
guardInitialized: true,
enforcementMode: mode,
}
}

Expand All @@ -78,14 +80,16 @@ func newTestServerWithPrivateAgent(t *testing.T, upstreamURL string, g *stubGuar
t.Helper()
reg := difc.NewAgentRegistryWithDefaults([]difc.Tag{"private:test-org/test-repo"}, nil)
return &Server{
guard: g,
evaluator: difc.NewEvaluatorWithMode(mode),
agentRegistry: reg,
capabilities: difc.NewCapabilities(),
guard: g,
DIFCComponents: difc.DIFCComponents{
Mode: mode,
Evaluator: difc.NewEvaluatorWithMode(mode),
AgentRegistry: reg,
Capabilities: difc.NewCapabilities(),
},
githubAPIURL: upstreamURL,
httpClient: &http.Client{},
guardInitialized: true,
enforcementMode: mode,
}
}

Expand Down Expand Up @@ -494,14 +498,16 @@ func TestHandleWithDIFC_PropagateMode_AccumulatesLabels(t *testing.T) {
}
reg := difc.NewAgentRegistryWithDefaults(nil, nil)
s := &Server{
guard: g,
evaluator: difc.NewEvaluatorWithMode(difc.EnforcementPropagate),
agentRegistry: reg,
capabilities: difc.NewCapabilities(),
guard: g,
DIFCComponents: difc.DIFCComponents{
Mode: difc.EnforcementPropagate,
Evaluator: difc.NewEvaluatorWithMode(difc.EnforcementPropagate),
AgentRegistry: reg,
Capabilities: difc.NewCapabilities(),
},
githubAPIURL: upstream.URL,
httpClient: &http.Client{},
guardInitialized: true,
enforcementMode: difc.EnforcementPropagate,
}
h := &proxyHandler{server: s}

Expand Down
Loading
Loading