Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions internal/cmd/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
//
// This keeps the env-var name co-located with the flag declaration.
//
// Exception: getDefaultDIFCMode() in flags_difc.go is kept as a named helper
// because it contains validation logic beyond a simple env lookup.
// Exception: difc.DefaultEnforcementMode() is kept as a named helper because
// it contains validation logic beyond a simple env lookup.
//
// When adding a new flag with an environment variable override:
// 1. Use envutil.GetEnv* directly in the RegisterFlag call.
Expand Down
52 changes: 1 addition & 51 deletions internal/cmd/flags_difc.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,11 @@ package cmd
// DIFC (Decentralized Information Flow Control) related flags

import (
"fmt"
"os"
"strings"

"github.com/github/gh-aw-mcpg/internal/config"
"github.com/github/gh-aw-mcpg/internal/difc"
"github.com/github/gh-aw-mcpg/internal/envutil"
"github.com/github/gh-aw-mcpg/internal/strutil"
"github.com/spf13/cobra"
)

Expand All @@ -30,7 +27,7 @@ const containerGuardWasmPath = "/guards/github/00-github-guard.wasm"

func init() {
RegisterFlag(func(cmd *cobra.Command) {
cmd.Flags().StringVar(&difcMode, "guards-mode", getDefaultDIFCMode(), "Guards enforcement mode: strict (deny violations), filter (remove denied tools), or propagate (auto-adjust agent labels on reads)")
cmd.Flags().StringVar(&difcMode, "guards-mode", difc.DefaultEnforcementMode(), "Guards enforcement mode: strict (deny violations), filter (remove denied tools), or propagate (auto-adjust agent labels on reads)")
cmd.Flags().StringVar(&difcSinkServerIDs, "guards-sink-server-ids", envutil.GetEnvString("MCP_GATEWAY_GUARDS_SINK_SERVER_IDS", ""), "Comma-separated server IDs whose RPC JSONL logs should include agent secrecy/integrity tag snapshots")
cmd.Flags().StringVar(&guardPolicyJSON, "guard-policy-json", envutil.GetEnvString(config.EnvGuardPolicyJSON, ""), "Guard policy JSON (e.g. {\"allow-only\":{\"repos\":\"public\",\"min-integrity\":\"none\"}})")
cmd.Flags().BoolVar(&allowOnlyPublic, "allowonly-scope-public", envutil.GetEnvBool(config.EnvAllowOnlyScopePublic, false), "Use public AllowOnly scope")
Expand Down Expand Up @@ -74,50 +71,3 @@ func resolveGuardPolicyOverride(cmd *cobra.Command) (*config.GuardPolicy, string
allowOnlyMinInt,
)
}

// getDefaultDIFCMode returns the default guards mode, checking MCP_GATEWAY_GUARDS_MODE
// environment variable first, then falling back to the hardcoded default (strict)
func getDefaultDIFCMode() string {
if envMode := os.Getenv("MCP_GATEWAY_GUARDS_MODE"); envMode != "" {
mode := strings.ToLower(envMode)
if _, err := difc.ParseEnforcementMode(mode); err == nil {
debugLog.Printf("Guards mode set from MCP_GATEWAY_GUARDS_MODE: %s", mode)
return mode
}
debugLog.Printf("MCP_GATEWAY_GUARDS_MODE value %q is invalid, falling back to default: %s", envMode, difc.ModeStrict)
}
return difc.ModeStrict
}

// validateDIFCModeFlag validates the value of the --guards-mode CLI flag.
func validateDIFCModeFlag(mode string) error {
if _, err := difc.ParseEnforcementMode(mode); err != nil {
return fmt.Errorf("invalid --guards-mode flag: %w", err)
}
return nil
}

func parseDIFCSinkServerIDs(input string) ([]string, error) {
if strings.TrimSpace(input) == "" {
return nil, nil
}

debugLog.Printf("Parsing DIFC sink server IDs: input=%q", input)

parts := strings.Split(input, ",")
validated := make([]string, 0, len(parts))
for _, part := range parts {
value := strings.TrimSpace(part)
if value == "" {
continue
}
if strings.ContainsAny(value, " \t\n\r") {
return nil, fmt.Errorf("invalid guards sink server ID %q: whitespace is not allowed", value)
}
validated = append(validated, value)
}

result := strutil.DeduplicateStrings(validated, false)
debugLog.Printf("Parsed %d unique DIFC sink server IDs: %v", len(result), result)
return result, nil
}
80 changes: 1 addition & 79 deletions internal/cmd/flags_difc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,58 +69,6 @@ func TestValidateDIFCMode(t *testing.T) {
}
}

func TestGetDefaultDIFCMode(t *testing.T) {
tests := []struct {
name string
envValue string
want string
}{
{
name: "no env var returns strict",
envValue: "",
want: "strict",
},
{
name: "env var strict",
envValue: "strict",
want: "strict",
},
{
name: "env var filter",
envValue: "filter",
want: "filter",
},
{
name: "env var propagate",
envValue: "propagate",
want: "propagate",
},
{
name: "env var FILTER uppercase",
envValue: "FILTER",
want: "filter",
},
{
name: "env var invalid falls back to strict",
envValue: "invalid",
want: "strict",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.envValue != "" {
t.Setenv("MCP_GATEWAY_GUARDS_MODE", tt.envValue)
} else {
t.Setenv("MCP_GATEWAY_GUARDS_MODE", "")
}

got := getDefaultDIFCMode()
assert.Equal(t, tt.want, got)
})
}
}

func TestValidDIFCModes(t *testing.T) {
require := require.New(t)

Expand All @@ -136,32 +84,6 @@ func TestValidDIFCModes(t *testing.T) {
require.Len(difc.ValidModes, 3, "should only have 3 valid modes")
}

func TestValidateDIFCModeFlag(t *testing.T) {
tests := []struct {
name string
mode string
wantErr bool
}{
{name: "strict valid", mode: "strict", wantErr: false},
{name: "filter valid", mode: "filter", wantErr: false},
{name: "propagate valid", mode: "propagate", wantErr: false},
{name: "empty defaults to strict", mode: "", wantErr: false},
{name: "invalid mode", mode: "bogus", wantErr: true},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateDIFCModeFlag(tt.mode)
if tt.wantErr {
require.Error(t, err, "expected error for mode %q", tt.mode)
assert.ErrorContains(t, err, "invalid --guards-mode flag")
} else {
assert.NoError(t, err, "unexpected error for mode %q", tt.mode)
}
})
}
}

func TestParseDIFCSinkServerIDs(t *testing.T) {
tests := []struct {
name string
Expand Down Expand Up @@ -218,7 +140,7 @@ func TestParseDIFCSinkServerIDs(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := parseDIFCSinkServerIDs(tt.input)
result, err := difc.ParseSinkServerIDs(tt.input)
if tt.wantErr {
require.Error(t, err)
return
Expand Down
4 changes: 2 additions & 2 deletions internal/cmd/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,8 @@ func runProxy(cmd *cobra.Command, args []string) error {

logProxyCmd.Printf("Starting proxy: listen=%s, guard=%s, mode=%s, tls=%v", proxyListen, proxyGuardWasm, proxyDIFCMode, proxyTLS)

if err := validateDIFCModeFlag(proxyDIFCMode); err != nil {
return err
if _, err := difc.ParseEnforcementMode(proxyDIFCMode); err != nil {
return fmt.Errorf("invalid --guards-mode flag: %w", err)
}

// Initialize loggers
Expand Down
6 changes: 3 additions & 3 deletions internal/cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,8 @@ func run(cmd *cobra.Command, args []string) error {
}

// Validate guards mode before applying
if err := validateDIFCModeFlag(difcMode); err != nil {
return err
if _, err := difc.ParseEnforcementMode(difcMode); err != nil {
return fmt.Errorf("invalid --guards-mode flag: %w", err)
}

// Apply command-line flags to config
Expand All @@ -237,7 +237,7 @@ func run(cmd *cobra.Command, args []string) error {
logger.StartupInfo("MCP_GATEWAY_GUARDS_SINK_SERVER_IDS=%q", envSinkServerIDs)
}

resolvedSinkServerIDs, err := parseDIFCSinkServerIDs(difcSinkServerIDs)
resolvedSinkServerIDs, err := difc.ParseSinkServerIDs(difcSinkServerIDs)
if err != nil {
return fmt.Errorf("invalid --guards-sink-server-ids value: %w", err)
}
Expand Down
8 changes: 8 additions & 0 deletions internal/config/guard_policy_integrity_levels.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package config

var allowedGuardPolicyIntegrityLevels = []string{
IntegrityNone,
IntegrityUnapproved,
IntegrityApproved,
IntegrityMerged,
}
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
package guard
package config

import (
"encoding/json"
"fmt"
)

// PolicyToMap converts a policy value to a generic map through a JSON roundtrip.
// It returns an error if the value cannot be serialized or does not decode to a
// JSON object.
func PolicyToMap(policy interface{}) (map[string]interface{}, error) {
// GuardPolicyToMap converts a policy value to a generic map through a JSON
// roundtrip. It returns an error if the value cannot be serialized or does not
// decode to a JSON object.
func GuardPolicyToMap(policy interface{}) (map[string]interface{}, error) {
if policy == nil {
return nil, fmt.Errorf("policy is required")
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package guard
package config

import (
"math"
Expand All @@ -8,7 +8,7 @@ import (
"github.com/stretchr/testify/require"
)

func TestPolicyToMap(t *testing.T) {
func TestGuardPolicyToMap(t *testing.T) {
t.Run("returns deep copy for map input", func(t *testing.T) {
policy := map[string]interface{}{
"allow-only": map[string]interface{}{
Expand All @@ -17,7 +17,7 @@ func TestPolicyToMap(t *testing.T) {
},
}

payload, err := PolicyToMap(policy)
payload, err := GuardPolicyToMap(policy)
require.NoError(t, err)
require.NotNil(t, payload)

Expand All @@ -31,19 +31,19 @@ func TestPolicyToMap(t *testing.T) {
})

t.Run("nil policy returns error", func(t *testing.T) {
_, err := PolicyToMap(nil)
_, err := GuardPolicyToMap(nil)
require.Error(t, err)
assert.ErrorContains(t, err, "policy is required")
})

t.Run("non-object policy returns error", func(t *testing.T) {
_, err := PolicyToMap([]string{"not-an-object"})
_, err := GuardPolicyToMap([]string{"not-an-object"})
require.Error(t, err)
assert.ErrorContains(t, err, "policy must decode to a JSON object")
})

t.Run("unmarshalable policy returns error", func(t *testing.T) {
_, err := PolicyToMap(math.NaN())
_, err := GuardPolicyToMap(math.NaN())
require.Error(t, err)
assert.ErrorContains(t, err, "failed to serialize policy")
})
Expand Down
3 changes: 1 addition & 2 deletions internal/config/guard_policy_parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"strings"

"github.com/github/gh-aw-mcpg/internal/envutil"
"github.com/github/gh-aw-mcpg/internal/guard"
)

// Environment variable names for guard policy configuration.
Expand Down Expand Up @@ -153,7 +152,7 @@ func BuildAllowOnlyPolicy(public bool, owner, repo, minIntegrity string) (*Guard
return nil, fmt.Errorf("min-integrity is required")
}
if !hasIntegrity {
return nil, fmt.Errorf("min-integrity must be one of: %s", strings.Join(guard.AllowedIntegrityLevels, ", "))
return nil, fmt.Errorf("min-integrity must be one of: %s", strings.Join(allowedGuardPolicyIntegrityLevels, ", "))
}

var repos interface{}
Expand Down
3 changes: 1 addition & 2 deletions internal/config/guard_policy_parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"strings"
"testing"

"github.com/github/gh-aw-mcpg/internal/guard"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -560,7 +559,7 @@ func TestBuildAllowOnlyPolicy_InvalidIntegrityErrorListsCanonicalValues(t *testi
got, err := BuildAllowOnlyPolicy(true, "", "", "superstrict")
require.Nil(t, got)
require.EqualError(t, err,
fmt.Sprintf("min-integrity must be one of: %s", strings.Join(guard.AllowedIntegrityLevels, ", ")))
fmt.Sprintf("min-integrity must be one of: %s", strings.Join(allowedGuardPolicyIntegrityLevels, ", ")))
}

// TestParsePolicyMap_LegacyMinIntegrityTakesPrecedence verifies that
Expand Down
8 changes: 3 additions & 5 deletions internal/config/guard_policy_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ import (
"fmt"
"sort"
"strings"

"github.com/github/gh-aw-mcpg/internal/guard"
)

// ValidateGuardPolicy validates AllowOnly or WriteSink policy input.
Expand Down Expand Up @@ -99,7 +97,7 @@ func NormalizeGuardPolicy(policy *GuardPolicy) (*NormalizedGuardPolicy, error) {

integrity := strings.ToLower(strings.TrimSpace(policy.AllowOnly.MinIntegrity))
if _, ok := validMinIntegrityValues[integrity]; !ok {
return nil, fmt.Errorf("allow-only.min-integrity must be one of: %s", strings.Join(guard.AllowedIntegrityLevels, ", "))
return nil, fmt.Errorf("allow-only.min-integrity must be one of: %s", strings.Join(allowedGuardPolicyIntegrityLevels, ", "))
}

normalized := &NormalizedGuardPolicy{MinIntegrity: integrity}
Expand Down Expand Up @@ -143,7 +141,7 @@ func NormalizeGuardPolicy(policy *GuardPolicy) (*NormalizedGuardPolicy, error) {
// uses Rust-side default of "none" when endorsement/disapproval is evaluated).
if v := strings.ToLower(strings.TrimSpace(policy.AllowOnly.DisapprovalIntegrity)); v != "" {
if _, ok := validMinIntegrityValues[v]; !ok {
return nil, fmt.Errorf("allow-only.disapproval-integrity must be one of: %s", strings.Join(guard.AllowedIntegrityLevels, ", "))
return nil, fmt.Errorf("allow-only.disapproval-integrity must be one of: %s", strings.Join(allowedGuardPolicyIntegrityLevels, ", "))
}
normalized.DisapprovalIntegrity = v
}
Expand All @@ -152,7 +150,7 @@ func NormalizeGuardPolicy(policy *GuardPolicy) (*NormalizedGuardPolicy, error) {
// uses Rust-side default of "approved" when evaluating reactor eligibility).
if v := strings.ToLower(strings.TrimSpace(policy.AllowOnly.EndorserMinIntegrity)); v != "" {
if _, ok := validMinIntegrityValues[v]; !ok {
return nil, fmt.Errorf("allow-only.endorser-min-integrity must be one of: %s", strings.Join(guard.AllowedIntegrityLevels, ", "))
return nil, fmt.Errorf("allow-only.endorser-min-integrity must be one of: %s", strings.Join(allowedGuardPolicyIntegrityLevels, ", "))
}
normalized.EndorserMinIntegrity = v
}
Expand Down
15 changes: 15 additions & 0 deletions internal/difc/evaluator.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package difc

import (
"fmt"
"os"
"strings"

"github.com/github/gh-aw-mcpg/internal/logger"
Expand Down Expand Up @@ -89,6 +90,20 @@ func ParseEnforcementMode(s string) (EnforcementMode, error) {
}
}

// DefaultEnforcementMode returns the default guards mode, checking
// MCP_GATEWAY_GUARDS_MODE first and falling back to strict.
func DefaultEnforcementMode() string {
if envMode := os.Getenv("MCP_GATEWAY_GUARDS_MODE"); envMode != "" {
mode := strings.ToLower(envMode)
if _, err := ParseEnforcementMode(mode); err == nil {
logEvaluator.Printf("Guards mode set from MCP_GATEWAY_GUARDS_MODE: %s", mode)
return mode
}
logEvaluator.Printf("MCP_GATEWAY_GUARDS_MODE value %q is invalid, falling back to default: %s", envMode, ModeStrict)
}
return ModeStrict
}

// DIFCComponents holds the set of DIFC objects needed by a server or proxy.
type DIFCComponents struct {
Mode EnforcementMode
Expand Down
Loading
Loading