diff --git a/README.md b/README.md index 716e88297..5db8cf49d 100644 --- a/README.md +++ b/README.md @@ -103,11 +103,14 @@ Restricts which repositories a guard allows and at what integrity level: **`trusted-users`** *(optional)* — Array of GitHub usernames whose content is unconditionally elevated to `approved` integrity. Useful for granting specific external contributors (e.g., trusted open-source maintainers) the same treatment as repository members, without lowering `min-integrity` globally. Uses `max(base, approved)` so it never lowers integrity. Does not override `blocked-users`. +**`tool-call-limits`** *(optional)* — Map of tool names to per-session call limits enforced by the gateway before the backend is invoked. Positive values hard-limit that tool for the session, while `0` or an omitted entry leaves the tool unlimited. + ```json "guard-policies": { "allow-only": { "repos": ["myorg/*"], "min-integrity": "approved", + "tool-call-limits": {"issue_read": 1}, "blocked-users": ["spam-bot", "compromised-user"], "approval-labels": ["human-reviewed", "safe-for-agent"], "trusted-users": ["alice", "trusted-contributor"] diff --git a/docs/CONFIGURATION.md b/docs/CONFIGURATION.md index de8edb10a..834525c74 100644 --- a/docs/CONFIGURATION.md +++ b/docs/CONFIGURATION.md @@ -251,6 +251,8 @@ min-integrity = "unapproved" - **`trusted-users`** *(optional)*: Array of GitHub usernames whose content is unconditionally elevated to `approved` integrity. Useful for granting specific external contributors the same treatment as repository members without lowering `min-integrity` globally. Uses `max(base, approved)` so it never lowers integrity. Does not override `blocked-users`. +- **`tool-call-limits`** *(optional)*: Map of tool names to per-session call limits enforced by the gateway. Positive values cap how many times that tool may be called in one session; `0` or an omitted entry leaves the tool unlimited. + - **Meaning**: Restricts the GitHub MCP server to only access specified repositories. Tools like `get_file_contents`, `search_code`, etc. will only work on allowed repositories. Attempts to access other repositories will be denied by the guard policy. ### write-sink (output servers) diff --git a/internal/config/guard_policy.go b/internal/config/guard_policy.go index 1d94cb877..9968c97c7 100644 --- a/internal/config/guard_policy.go +++ b/internal/config/guard_policy.go @@ -38,33 +38,35 @@ type WriteSinkPolicy struct { // AllowOnlyPolicy configures scope and minimum required integrity. type AllowOnlyPolicy struct { - Repos interface{} `toml:"repos" json:"repos"` - MinIntegrity string `toml:"min-integrity" json:"min-integrity"` - BlockedUsers []string `toml:"blocked-users" json:"blocked-users,omitempty"` - ApprovalLabels []string `toml:"approval-labels" json:"approval-labels,omitempty"` - TrustedUsers []string `toml:"trusted-users" json:"trusted-users,omitempty"` - EndorsementReactions []string `toml:"endorsement-reactions" json:"endorsement-reactions,omitempty"` - DisapprovalReactions []string `toml:"disapproval-reactions" json:"disapproval-reactions,omitempty"` - DisapprovalIntegrity string `toml:"disapproval-integrity" json:"disapproval-integrity,omitempty"` - EndorserMinIntegrity string `toml:"endorser-min-integrity" json:"endorser-min-integrity,omitempty"` - PromotionLabel string `toml:"promotion-label" json:"promotion-label,omitempty"` - DemotionLabel string `toml:"demotion-label" json:"demotion-label,omitempty"` + Repos interface{} `toml:"repos" json:"repos"` + MinIntegrity string `toml:"min-integrity" json:"min-integrity"` + ToolCallLimits map[string]int `toml:"tool-call-limits" json:"tool-call-limits,omitempty"` + BlockedUsers []string `toml:"blocked-users" json:"blocked-users,omitempty"` + ApprovalLabels []string `toml:"approval-labels" json:"approval-labels,omitempty"` + TrustedUsers []string `toml:"trusted-users" json:"trusted-users,omitempty"` + EndorsementReactions []string `toml:"endorsement-reactions" json:"endorsement-reactions,omitempty"` + DisapprovalReactions []string `toml:"disapproval-reactions" json:"disapproval-reactions,omitempty"` + DisapprovalIntegrity string `toml:"disapproval-integrity" json:"disapproval-integrity,omitempty"` + EndorserMinIntegrity string `toml:"endorser-min-integrity" json:"endorser-min-integrity,omitempty"` + PromotionLabel string `toml:"promotion-label" json:"promotion-label,omitempty"` + DemotionLabel string `toml:"demotion-label" json:"demotion-label,omitempty"` } // NormalizedGuardPolicy is a canonical policy representation for caching and observability. type NormalizedGuardPolicy struct { - ScopeKind string `json:"scope_kind"` - ScopeValues []string `json:"scope_values,omitempty"` - MinIntegrity string `json:"min-integrity"` - BlockedUsers []string `json:"blocked-users,omitempty"` - ApprovalLabels []string `json:"approval-labels,omitempty"` - TrustedUsers []string `json:"trusted-users,omitempty"` - EndorsementReactions []string `json:"endorsement-reactions,omitempty"` - DisapprovalReactions []string `json:"disapproval-reactions,omitempty"` - DisapprovalIntegrity string `json:"disapproval-integrity,omitempty"` - EndorserMinIntegrity string `json:"endorser-min-integrity,omitempty"` - PromotionLabel string `json:"promotion-label,omitempty"` - DemotionLabel string `json:"demotion-label,omitempty"` + ScopeKind string `json:"scope_kind"` + ScopeValues []string `json:"scope_values,omitempty"` + MinIntegrity string `json:"min-integrity"` + ToolCallLimits map[string]int `json:"tool-call-limits,omitempty"` + BlockedUsers []string `json:"blocked-users,omitempty"` + ApprovalLabels []string `json:"approval-labels,omitempty"` + TrustedUsers []string `json:"trusted-users,omitempty"` + EndorsementReactions []string `json:"endorsement-reactions,omitempty"` + DisapprovalReactions []string `json:"disapproval-reactions,omitempty"` + DisapprovalIntegrity string `json:"disapproval-integrity,omitempty"` + EndorserMinIntegrity string `json:"endorser-min-integrity,omitempty"` + PromotionLabel string `json:"promotion-label,omitempty"` + DemotionLabel string `json:"demotion-label,omitempty"` } func (p *GuardPolicy) UnmarshalJSON(data []byte) error { @@ -144,6 +146,10 @@ func (p *AllowOnlyPolicy) UnmarshalJSON(data []byte) error { if err := json.Unmarshal(value, &p.MinIntegrity); err != nil { return fmt.Errorf("invalid allow-only.min-integrity: %w", err) } + case "tool-call-limits": + if err := json.Unmarshal(value, &p.ToolCallLimits); err != nil { + return fmt.Errorf("invalid allow-only.tool-call-limits: %w", err) + } case "blocked-users": if err := json.Unmarshal(value, &p.BlockedUsers); err != nil { return fmt.Errorf("invalid allow-only.blocked-users: %w", err) @@ -198,17 +204,18 @@ func (p *AllowOnlyPolicy) UnmarshalJSON(data []byte) error { func (p AllowOnlyPolicy) MarshalJSON() ([]byte, error) { type serializedAllowOnly struct { - Repos interface{} `json:"repos"` - MinIntegrity string `json:"min-integrity"` - BlockedUsers []string `json:"blocked-users,omitempty"` - ApprovalLabels []string `json:"approval-labels,omitempty"` - TrustedUsers []string `json:"trusted-users,omitempty"` - EndorsementReactions []string `json:"endorsement-reactions,omitempty"` - DisapprovalReactions []string `json:"disapproval-reactions,omitempty"` - DisapprovalIntegrity string `json:"disapproval-integrity,omitempty"` - EndorserMinIntegrity string `json:"endorser-min-integrity,omitempty"` - PromotionLabel string `json:"promotion-label,omitempty"` - DemotionLabel string `json:"demotion-label,omitempty"` + Repos interface{} `json:"repos"` + MinIntegrity string `json:"min-integrity"` + ToolCallLimits map[string]int `json:"tool-call-limits,omitempty"` + BlockedUsers []string `json:"blocked-users,omitempty"` + ApprovalLabels []string `json:"approval-labels,omitempty"` + TrustedUsers []string `json:"trusted-users,omitempty"` + EndorsementReactions []string `json:"endorsement-reactions,omitempty"` + DisapprovalReactions []string `json:"disapproval-reactions,omitempty"` + DisapprovalIntegrity string `json:"disapproval-integrity,omitempty"` + EndorserMinIntegrity string `json:"endorser-min-integrity,omitempty"` + PromotionLabel string `json:"promotion-label,omitempty"` + DemotionLabel string `json:"demotion-label,omitempty"` } return json.Marshal(serializedAllowOnly(p)) diff --git a/internal/config/guard_policy_test.go b/internal/config/guard_policy_test.go index 648ffd1af..db84a3e45 100644 --- a/internal/config/guard_policy_test.go +++ b/internal/config/guard_policy_test.go @@ -680,6 +680,13 @@ func TestAllowOnlyPolicyUnmarshalJSON(t *testing.T) { assert.Equal(t, []string{"evil-bot", "bad-actor"}, p.BlockedUsers) }, }, + { + name: "tool-call-limits parsed correctly", + json: `{"repos":"public","min-integrity":"none","tool-call-limits":{"issue_read":1,"list_issues":2}}`, + check: func(t *testing.T, p *AllowOnlyPolicy) { + assert.Equal(t, map[string]int{"issue_read": 1, "list_issues": 2}, p.ToolCallLimits) + }, + }, { name: "approval-labels parsed correctly", json: `{"repos":"public","min-integrity":"none","approval-labels":["approved","human-reviewed"]}`, @@ -829,6 +836,21 @@ func TestAllowOnlyPolicyMarshalJSON(t *testing.T) { assert.Contains(t, jsonStr, `"human-reviewed"`) }) + t.Run("tool-call-limits is included when set", func(t *testing.T) { + policy := AllowOnlyPolicy{ + Repos: "public", + MinIntegrity: "none", + ToolCallLimits: map[string]int{"issue_read": 1}, + } + + data, err := json.Marshal(policy) + require.NoError(t, err) + + jsonStr := string(data) + assert.Contains(t, jsonStr, `"tool-call-limits"`) + assert.Contains(t, jsonStr, `"issue_read"`) + }) + t.Run("nil blocked-users and approval-labels are omitted", func(t *testing.T) { policy := AllowOnlyPolicy{ Repos: "public", @@ -966,6 +988,16 @@ func TestValidateGuardPolicy(t *testing.T) { require.NoError(t, err) }) + t.Run("zero tool-call-limit is treated as unlimited", func(t *testing.T) { + policy := &GuardPolicy{AllowOnly: &AllowOnlyPolicy{ + Repos: "all", + MinIntegrity: "none", + ToolCallLimits: map[string]int{"issue_read": 0}, + }} + err := ValidateGuardPolicy(policy) + require.NoError(t, err) + }) + t.Run("invalid policy returns error", func(t *testing.T) { policy := &GuardPolicy{AllowOnly: &AllowOnlyPolicy{ Repos: "all", @@ -974,6 +1006,17 @@ func TestValidateGuardPolicy(t *testing.T) { err := ValidateGuardPolicy(policy) require.Error(t, err) }) + + t.Run("negative tool-call-limit returns error", func(t *testing.T) { + policy := &GuardPolicy{AllowOnly: &AllowOnlyPolicy{ + Repos: "all", + MinIntegrity: "none", + ToolCallLimits: map[string]int{"issue_read": -1}, + }} + err := ValidateGuardPolicy(policy) + require.Error(t, err) + assert.ErrorContains(t, err, `allow-only.tool-call-limits["issue_read"] must be >= 0`) + }) } // TestIsScopeTokenChar tests valid and invalid characters for scope tokens. @@ -1004,6 +1047,17 @@ func TestNormalizeGuardPolicyReactionEndorsement(t *testing.T) { assert.Equal(t, []string{"THUMBS_UP", "HEART"}, got.EndorsementReactions) }) + t.Run("tool-call-limits propagated to normalized policy", func(t *testing.T) { + policy := &GuardPolicy{AllowOnly: &AllowOnlyPolicy{ + Repos: "public", + MinIntegrity: "approved", + ToolCallLimits: map[string]int{"issue_read": 1, "list_issues": 0}, + }} + got, err := NormalizeGuardPolicy(policy) + require.NoError(t, err) + assert.Equal(t, map[string]int{"issue_read": 1, "list_issues": 0}, got.ToolCallLimits) + }) + t.Run("disapproval-reactions propagated and normalized to uppercase", func(t *testing.T) { policy := &GuardPolicy{AllowOnly: &AllowOnlyPolicy{ Repos: "public", diff --git a/internal/config/guard_policy_unmarshal_coverage_test.go b/internal/config/guard_policy_unmarshal_coverage_test.go index 593ca2554..e3dbb9817 100644 --- a/internal/config/guard_policy_unmarshal_coverage_test.go +++ b/internal/config/guard_policy_unmarshal_coverage_test.go @@ -157,6 +157,11 @@ func TestAllowOnlyPolicyUnmarshalJSON_FieldErrorPaths(t *testing.T) { json: `{"repos": "all", "min-integrity": "none", "blocked-users": "notanarray"}`, wantErr: "invalid allow-only.blocked-users", }, + { + name: "tool-call-limits field invalid JSON type", + json: `{"repos": "all", "min-integrity": "none", "tool-call-limits": "notamap"}`, + wantErr: "invalid allow-only.tool-call-limits", + }, { name: "approval-labels field invalid JSON type", json: `{"repos": "all", "min-integrity": "none", "approval-labels": 42}`, @@ -416,6 +421,7 @@ func TestAllowOnlyPolicyUnmarshalJSON_FullRoundTrip(t *testing.T) { BlockedUsers: []string{"bad-actor"}, ApprovalLabels: []string{"approved"}, TrustedUsers: []string{"contractor"}, + ToolCallLimits: map[string]int{"issue_read": 1}, EndorsementReactions: []string{"THUMBS_UP"}, DisapprovalReactions: []string{"THUMBS_DOWN"}, DisapprovalIntegrity: "none", @@ -434,6 +440,7 @@ func TestAllowOnlyPolicyUnmarshalJSON_FullRoundTrip(t *testing.T) { assert.Equal(t, original.BlockedUsers, parsed.BlockedUsers) assert.Equal(t, original.ApprovalLabels, parsed.ApprovalLabels) assert.Equal(t, original.TrustedUsers, parsed.TrustedUsers) + assert.Equal(t, original.ToolCallLimits, parsed.ToolCallLimits) assert.Equal(t, original.EndorsementReactions, parsed.EndorsementReactions) assert.Equal(t, original.DisapprovalReactions, parsed.DisapprovalReactions) assert.Equal(t, original.DisapprovalIntegrity, parsed.DisapprovalIntegrity) diff --git a/internal/config/guard_policy_validation.go b/internal/config/guard_policy_validation.go index 7c4f916c8..525cef236 100644 --- a/internal/config/guard_policy_validation.go +++ b/internal/config/guard_policy_validation.go @@ -108,6 +108,11 @@ func NormalizeGuardPolicy(policy *GuardPolicy) (*NormalizedGuardPolicy, error) { var err error + normalized.ToolCallLimits, err = normalizeToolCallLimits(policy.AllowOnly.ToolCallLimits) + if err != nil { + return nil, err + } + // Validate and normalize blocked-users, approval-labels, trusted-users. // Dedup uses lowercased keys; original trimmed values are stored. normalized.BlockedUsers, err = normalizeStringSlice("blocked-users", policy.AllowOnly.BlockedUsers, strings.ToLower, false) @@ -332,3 +337,22 @@ func normalizeStringSlice(field string, input []string, caseNorm func(string) st } return out, nil } + +func normalizeToolCallLimits(input map[string]int) (map[string]int, error) { + if len(input) == 0 { + return nil, nil + } + + out := make(map[string]int, len(input)) + for toolName, limit := range input { + toolName = strings.TrimSpace(toolName) + if toolName == "" { + return nil, fmt.Errorf("allow-only.tool-call-limits keys must not be empty") + } + if limit < 0 { + return nil, fmt.Errorf("allow-only.tool-call-limits[%q] must be >= 0", toolName) + } + out[toolName] = limit + } + return out, nil +} diff --git a/internal/server/call_backend_tool_difc_test.go b/internal/server/call_backend_tool_difc_test.go index 1e88c38c3..791bf815b 100644 --- a/internal/server/call_backend_tool_difc_test.go +++ b/internal/server/call_backend_tool_difc_test.go @@ -820,3 +820,156 @@ func TestCallBackendTool_GuardInitError(t *testing.T) { require.Error(err) assert.ErrorContains(err, "guard session initialization failed") } + +func TestCallBackendTool_ToolCallLimitEnforcedPerSession(t *testing.T) { + require := require.New(t) + assert := assert.New(t) + + backendCalls := 0 + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var req map[string]interface{} + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + method, _ := req["method"].(string) + w.Header().Set("Content-Type", "application/json") + switch method { + case "initialize": + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "jsonrpc": "2.0", "id": req["id"], + "result": map[string]interface{}{ + "protocolVersion": "2024-11-05", + "capabilities": map[string]interface{}{}, + "serverInfo": map[string]interface{}{"name": "test-backend", "version": "1.0"}, + }, + }) + case "tools/list": + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "jsonrpc": "2.0", "id": req["id"], + "result": map[string]interface{}{ + "tools": []map[string]interface{}{ + { + "name": "issue_read", + "description": "test tool", + "inputSchema": map[string]interface{}{"type": "object", "properties": map[string]interface{}{}}, + }, + }, + }, + }) + case "tools/call": + backendCalls++ + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "jsonrpc": "2.0", "id": req["id"], + "result": map[string]interface{}{ + "content": []map[string]interface{}{{"type": "text", "text": "tool result"}}, + "isError": false, + }, + }) + } + })) + defer backend.Close() + + g := &difcTestGuard{name: "difc-tool-call-limit-guard"} + us := makeUnifiedWithGuard(t, "difc-tool-call-limit-type", g, backend, "strict") + us.cfg.GuardPolicy.AllowOnly.ToolCallLimits = map[string]int{"issue_read": 2} + defer us.Close() + + result, _, err := us.callBackendTool(callCtx("session-limit-a"), "test-server", "issue_read", nil) + require.NotNil(result) + require.NoError(err) + assert.False(result.IsError) + + result, _, err = us.callBackendTool(callCtx("session-limit-a"), "test-server", "issue_read", nil) + require.NotNil(result) + require.NoError(err) + assert.False(result.IsError) + + result, _, err = us.callBackendTool(callCtx("session-limit-a"), "test-server", "issue_read", nil) + require.NotNil(result) + require.Error(err) + assert.True(result.IsError) + assert.Contains(result.Content[0].(*sdk.TextContent).Text, `tool call limit reached for "issue_read" (max: 2)`) + assert.Equal(2, backendCalls, "over-limit call must not reach the backend") + + result, _, err = us.callBackendTool(callCtx("session-limit-b"), "test-server", "issue_read", nil) + require.NotNil(result) + require.NoError(err) + assert.False(result.IsError) + assert.Equal(3, backendCalls, "a new session must get a fresh per-tool budget") +} + +func TestCallBackendTool_ToolCallLimitZeroOrAbsentIsUnlimited(t *testing.T) { + require := require.New(t) + assert := assert.New(t) + + backendCalls := 0 + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var req map[string]interface{} + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + method, _ := req["method"].(string) + w.Header().Set("Content-Type", "application/json") + switch method { + case "initialize": + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "jsonrpc": "2.0", "id": req["id"], + "result": map[string]interface{}{ + "protocolVersion": "2024-11-05", + "capabilities": map[string]interface{}{}, + "serverInfo": map[string]interface{}{"name": "test-backend", "version": "1.0"}, + }, + }) + case "tools/list": + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "jsonrpc": "2.0", "id": req["id"], + "result": map[string]interface{}{ + "tools": []map[string]interface{}{ + { + "name": "zero_limit_tool", + "description": "zero limit tool", + "inputSchema": map[string]interface{}{"type": "object", "properties": map[string]interface{}{}}, + }, + { + "name": "unlisted_tool", + "description": "unlisted tool", + "inputSchema": map[string]interface{}{"type": "object", "properties": map[string]interface{}{}}, + }, + }, + }, + }) + case "tools/call": + backendCalls++ + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "jsonrpc": "2.0", "id": req["id"], + "result": map[string]interface{}{ + "content": []map[string]interface{}{{"type": "text", "text": "tool result"}}, + "isError": false, + }, + }) + } + })) + defer backend.Close() + + g := &difcTestGuard{name: "difc-zero-limit-guard"} + us := makeUnifiedWithGuard(t, "difc-zero-limit-type", g, backend, "strict") + us.cfg.GuardPolicy.AllowOnly.ToolCallLimits = map[string]int{"zero_limit_tool": 0} + defer us.Close() + + for i := 0; i < 3; i++ { + result, _, err := us.callBackendTool(callCtx("session-unlimited"), "test-server", "zero_limit_tool", nil) + require.NotNil(result) + require.NoError(err) + assert.False(result.IsError) + } + for i := 0; i < 2; i++ { + result, _, err := us.callBackendTool(callCtx("session-unlimited"), "test-server", "unlisted_tool", nil) + require.NotNil(result) + require.NoError(err) + assert.False(result.IsError) + } + + assert.Equal(5, backendCalls, "zero or absent limits must not block tool calls") +} diff --git a/internal/server/guard_init.go b/internal/server/guard_init.go index 996470c10..9b0afe907 100644 --- a/internal/server/guard_init.go +++ b/internal/server/guard_init.go @@ -382,12 +382,17 @@ func (us *UnifiedServer) ensureGuardInitialized( if session.GuardInit == nil { session.GuardInit = make(map[string]*GuardSessionState) } + var toolCallLimits map[string]int + if policy.AllowOnly != nil { + toolCallLimits = copyToolCallLimits(policy.AllowOnly.ToolCallLimits) + } session.GuardInit[serverID] = &GuardSessionState{ Initialized: true, PolicyHash: policyHash, PolicySource: source, DIFCMode: mode, NormalizedPolicy: normalizedPolicy, + ToolCallLimits: toolCallLimits, } us.sessionMu.Unlock() @@ -397,6 +402,20 @@ func (us *UnifiedServer) ensureGuardInitialized( return mode, nil } +// copyToolCallLimits returns a defensive copy of tool-call-limits so per-session +// counters cannot be affected by later config mutations. Keys are trimmed of +// surrounding whitespace to match the normalization applied during validation. +func copyToolCallLimits(input map[string]int) map[string]int { + if len(input) == 0 { + return nil + } + out := make(map[string]int, len(input)) + for toolName, limit := range input { + out[strings.TrimSpace(toolName)] = limit + } + return out +} + // getTrustedBots returns the configured list of additional trusted bot usernames, // or nil if none are configured. func (us *UnifiedServer) getTrustedBots() []string { diff --git a/internal/server/unified.go b/internal/server/unified.go index 5ff9b6278..5d6fe2242 100644 --- a/internal/server/unified.go +++ b/internal/server/unified.go @@ -50,6 +50,9 @@ type GuardSessionState struct { PolicySource string DIFCMode difc.EnforcementMode NormalizedPolicy map[string]interface{} + ToolCallLimits map[string]int + ToolCallCounts map[string]int + CallCountMu sync.Mutex } // ServerStatus represents the health status of a backend server @@ -377,7 +380,8 @@ func (us *UnifiedServer) callBackendTool(ctx context.Context, serverID, toolName oteltrace.WithSpanKind(oteltrace.SpanKindInternal), ) // httpStatusCode tracks the conceptual HTTP status of the proxied response (spec §4.1.3.6). - // It starts at 200 and is updated to 500 (error) or 403 (access denied) before each exit. + // It starts at 200 and is updated to 500 (error), 403 (access denied), or 429 (budget + // exhaustion) before each exit. httpStatusCode := 200 defer func() { toolSpan.SetAttributes(semconv.HTTPResponseStatusCodeKey.Int(httpStatusCode)) @@ -414,6 +418,12 @@ func (us *UnifiedServer) callBackendTool(ctx context.Context, serverID, toolName httpStatusCode = 500 return mcp.NewErrorCallToolResult(fmt.Errorf("guard session initialization failed: %w", err)) } + if err := us.enforceToolCallLimit(sessionID, serverID, toolName); err != nil { + httpStatusCode = 429 + toolSpan.RecordError(err) + toolSpan.SetStatus(codes.Error, "tool call limit reached") + return mcp.NewErrorCallToolResult(err) + } requestEvaluator := difc.NewEvaluatorWithMode(enforcementMode) @@ -635,6 +645,40 @@ func (us *UnifiedServer) callBackendTool(ctx context.Context, serverID, toolName return callResult, finalResult, nil } +// enforceToolCallLimit applies the configured per-session budget for toolName on +// the given server, incrementing the call counter for in-budget attempts and +// returning an error without incrementing when the session has exhausted its limit. +func (us *UnifiedServer) enforceToolCallLimit(sessionID, serverID, toolName string) error { + us.sessionMu.RLock() + session := us.sessions[sessionID] + var state *GuardSessionState + if session != nil { + state = session.GuardInit[serverID] + } + us.sessionMu.RUnlock() + + if state == nil || len(state.ToolCallLimits) == 0 { + return nil + } + + state.CallCountMu.Lock() + defer state.CallCountMu.Unlock() + + limit, ok := state.ToolCallLimits[toolName] + if !ok || limit == 0 { + return nil + } + if state.ToolCallCounts == nil { + state.ToolCallCounts = make(map[string]int) + } + + if state.ToolCallCounts[toolName] >= limit { + return fmt.Errorf("tool call limit reached for %q (max: %d)", toolName, limit) + } + state.ToolCallCounts[toolName]++ + return nil +} + // Run starts the unified MCP server on the specified transport func (us *UnifiedServer) Run(transport sdk.Transport) error { logger.LogInfo("startup", "Starting unified MCP server...")