diff --git a/internal/mcp/pagination.go b/internal/mcp/pagination.go index d040d2f51..72557c87c 100644 --- a/internal/mcp/pagination.go +++ b/internal/mcp/pagination.go @@ -72,6 +72,7 @@ const paginateAllMaxPages = 100 // paginateAll collects all items across paginated SDK list calls. // It returns an error if the backend returns more than paginateAllMaxPages pages, // protecting against runaway backends. +// Keep loop-protection invariants aligned with internal/testutil/mcptest/validator.go:paginate. func paginateAll[T any]( serverID string, itemKind string, diff --git a/internal/proxy/collaborator_permission.go b/internal/proxy/collaborator_permission.go index 9ae5b7ff3..3fdecb513 100644 --- a/internal/proxy/collaborator_permission.go +++ b/internal/proxy/collaborator_permission.go @@ -8,6 +8,7 @@ import ( "net/http" "github.com/github/gh-aw-mcpg/internal/logger" + "github.com/github/gh-aw-mcpg/internal/strutil" ) var logCollab = logger.New("proxy:collaborator_permission") @@ -17,9 +18,9 @@ var logCollab = logger.New("proxy:collaborator_permission") // It returns the (possibly partial) values even on error so that callers can // include them in diagnostic log messages. func ParseCollaboratorPermissionArgs(argsMap map[string]interface{}) (owner, repo, username string, err error) { - owner, _ = argsMap["owner"].(string) - repo, _ = argsMap["repo"].(string) - username, _ = argsMap["username"].(string) + owner = strutil.GetStringFromMap(argsMap, "owner") + repo = strutil.GetStringFromMap(argsMap, "repo") + username = strutil.GetStringFromMap(argsMap, "username") if owner == "" || repo == "" || username == "" { logCollab.Printf("ParseCollaboratorPermissionArgs: missing required fields: owner=%q, repo=%q, username=%q", owner, repo, username) err = fmt.Errorf("get_collaborator_permission: missing owner/repo/username") diff --git a/internal/proxy/graphql.go b/internal/proxy/graphql.go index c26ba70c7..afba6f137 100644 --- a/internal/proxy/graphql.go +++ b/internal/proxy/graphql.go @@ -147,14 +147,10 @@ func extractOwnerRepo(variables map[string]interface{}, query string) (string, s // Try variables first if variables != nil { - if v, ok := variables["owner"].(string); ok { - owner = v - } - if v, ok := variables["name"].(string); ok { - repo = v - } - if v, ok := variables["repo"].(string); ok && repo == "" { - repo = v + owner = strutil.GetStringFromMap(variables, "owner") + repo = strutil.GetStringFromMap(variables, "name") + if repo == "" { + repo = strutil.GetStringFromMap(variables, "repo") } logGraphQL.Printf("extractOwnerRepo: from variables: owner=%q repo=%q", owner, repo) } diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 1a3efec7c..4aa566dfa 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -21,6 +21,7 @@ import ( "github.com/github/gh-aw-mcpg/internal/httputil" "github.com/github/gh-aw-mcpg/internal/logger" "github.com/github/gh-aw-mcpg/internal/mcp" + "github.com/github/gh-aw-mcpg/internal/strutil" "github.com/github/gh-aw-mcpg/internal/tracing" ) @@ -215,9 +216,9 @@ type restBackendCaller struct { // from tool arguments, accepting either string or float64 JSON number inputs for // the identifier. func extractOwnerRepoNumber(argsMap map[string]interface{}, ownerKey, repoKey, numberKey, toolName string) (owner, repo, number string, err error) { - owner, _ = argsMap[ownerKey].(string) - repo, _ = argsMap[repoKey].(string) - number, _ = argsMap[numberKey].(string) + owner = strutil.GetStringFromMap(argsMap, ownerKey) + repo = strutil.GetStringFromMap(argsMap, repoKey) + number = strutil.GetStringFromMap(argsMap, numberKey) if number == "" { if n, ok := argsMap[numberKey].(float64); ok { number = fmt.Sprintf("%d", int(n)) diff --git a/internal/server/circuit_breaker.go b/internal/server/circuit_breaker.go index 9b4116288..89b725402 100644 --- a/internal/server/circuit_breaker.go +++ b/internal/server/circuit_breaker.go @@ -2,8 +2,6 @@ package server import ( "fmt" - "strconv" - "strings" "sync" "time" @@ -246,95 +244,3 @@ func buildCircuitBreakers(cfg *config.Config) map[string]*circuitBreaker { logCircuitBreaker.Printf("buildCircuitBreakers: created %d circuit breakers", len(cbs)) return cbs } - -// extractRateLimitErrorText extracts the text content from a raw tool result -// that has been identified as a rate-limit error. Returns the original backend -// message so agents see the actual upstream error rather than a synthetic one. -func extractRateLimitErrorText(result interface{}) string { - m, ok := result.(map[string]interface{}) - if !ok { - return "rate limit exceeded" - } - contents, _ := m["content"].([]interface{}) - for _, c := range contents { - cm, ok := c.(map[string]interface{}) - if !ok { - continue - } - if text, ok := cm["text"].(string); ok && text != "" { - return text - } - } - return "rate limit exceeded" -} - -// isRateLimitToolResult reports whether a raw tool call result indicates -// a rate-limit error from the GitHub MCP server. It inspects the `isError` -// flag and the text content for well-known rate-limit phrases. -// -// The GitHub MCP server returns rate-limit errors as: -// -// {"content":[{"type":"text","text":"... 403 API rate limit exceeded ..."}],"isError":true} -func isRateLimitToolResult(result interface{}) (bool, time.Time) { - m, ok := result.(map[string]interface{}) - if !ok { - return false, time.Time{} - } - - // Only inspect error results. - isErr, _ := m["isError"].(bool) - if !isErr { - return false, time.Time{} - } - - contents, _ := m["content"].([]interface{}) - for _, c := range contents { - cm, ok := c.(map[string]interface{}) - if !ok { - continue - } - text, _ := cm["text"].(string) - if isRateLimitText(text) { - resetAt := parseRateLimitResetFromText(text) - logCircuitBreaker.Printf("Rate limit detected in tool result: hasResetAt=%v", !resetAt.IsZero()) - return true, resetAt - } - } - return false, time.Time{} -} - -// isRateLimitText returns true when the message indicates a GitHub rate-limit error. -func isRateLimitText(text string) bool { - lower := strings.ToLower(text) - return strings.Contains(lower, "rate limit exceeded") || - (strings.Contains(lower, "rate limit") && strings.Contains(lower, "403")) || - strings.Contains(lower, "api rate limit") || - strings.Contains(lower, "secondary rate limit") || - strings.Contains(lower, "too many requests") -} - -// parseRateLimitResetFromText attempts to extract a reset timestamp from the -// rate-limit error text. The GitHub MCP server includes messages like -// "API rate limit exceeded [rate reset in 42s]". -// Returns zero time when the value cannot be parsed or is 0 seconds. -func parseRateLimitResetFromText(text string) time.Time { - // Look for "[rate reset in Ns]" pattern. - lower := strings.ToLower(text) - idx := strings.Index(lower, "rate reset in ") - if idx < 0 { - return time.Time{} - } - rest := text[idx+len("rate reset in "):] - // Find the first non-digit character. - end := strings.IndexAny(rest, "s])") - if end < 0 { - return time.Time{} - } - secs, err := strconv.ParseInt(strings.TrimSpace(rest[:end]), 10, 64) - if err != nil || secs <= 0 { - return time.Time{} - } - resetAt := time.Now().Add(time.Duration(secs) * time.Second) - logCircuitBreaker.Printf("Parsed rate limit reset time from text: resetIn=%ds, resetAt=%s", secs, resetAt.UTC().Format(time.RFC3339)) - return resetAt -} diff --git a/internal/server/rate_limit.go b/internal/server/rate_limit.go new file mode 100644 index 000000000..354798707 --- /dev/null +++ b/internal/server/rate_limit.go @@ -0,0 +1,99 @@ +package server + +import ( + "strconv" + "strings" + "time" +) + +// extractRateLimitErrorText extracts the text content from a raw tool result +// that has been identified as a rate-limit error. Returns the original backend +// message so agents see the actual upstream error rather than a synthetic one. +func extractRateLimitErrorText(result interface{}) string { + m, ok := result.(map[string]interface{}) + if !ok { + return "rate limit exceeded" + } + contents, _ := m["content"].([]interface{}) + for _, c := range contents { + cm, ok := c.(map[string]interface{}) + if !ok { + continue + } + if text, ok := cm["text"].(string); ok && text != "" { + return text + } + } + return "rate limit exceeded" +} + +// isRateLimitToolResult reports whether a raw tool call result indicates +// a rate-limit error from the GitHub MCP server. It inspects the `isError` +// flag and the text content for well-known rate-limit phrases. +// +// The GitHub MCP server returns rate-limit errors as: +// +// {"content":[{"type":"text","text":"... 403 API rate limit exceeded ..."}],"isError":true} +func isRateLimitToolResult(result interface{}) (bool, time.Time) { + m, ok := result.(map[string]interface{}) + if !ok { + return false, time.Time{} + } + + // Only inspect error results. + isErr, _ := m["isError"].(bool) + if !isErr { + return false, time.Time{} + } + + contents, _ := m["content"].([]interface{}) + for _, c := range contents { + cm, ok := c.(map[string]interface{}) + if !ok { + continue + } + text, _ := cm["text"].(string) + if isRateLimitText(text) { + resetAt := parseRateLimitResetFromText(text) + logCircuitBreaker.Printf("Rate limit detected in tool result: hasResetAt=%v", !resetAt.IsZero()) + return true, resetAt + } + } + return false, time.Time{} +} + +// isRateLimitText returns true when the message indicates a GitHub rate-limit error. +func isRateLimitText(text string) bool { + lower := strings.ToLower(text) + return strings.Contains(lower, "rate limit exceeded") || + (strings.Contains(lower, "rate limit") && strings.Contains(lower, "403")) || + strings.Contains(lower, "api rate limit") || + strings.Contains(lower, "secondary rate limit") || + strings.Contains(lower, "too many requests") +} + +// parseRateLimitResetFromText attempts to extract a reset timestamp from the +// rate-limit error text. The GitHub MCP server includes messages like +// "API rate limit exceeded [rate reset in 42s]". +// Returns zero time when the value cannot be parsed or is 0 seconds. +func parseRateLimitResetFromText(text string) time.Time { + // Look for "[rate reset in Ns]" pattern. + lower := strings.ToLower(text) + idx := strings.Index(lower, "rate reset in ") + if idx < 0 { + return time.Time{} + } + rest := text[idx+len("rate reset in "):] + // Find the first non-digit character. + end := strings.IndexAny(rest, "s])") + if end < 0 { + return time.Time{} + } + secs, err := strconv.ParseInt(strings.TrimSpace(rest[:end]), 10, 64) + if err != nil || secs <= 0 { + return time.Time{} + } + resetAt := time.Now().Add(time.Duration(secs) * time.Second) + logCircuitBreaker.Printf("Parsed rate limit reset time from text: resetIn=%ds, resetAt=%s", secs, resetAt.UTC().Format(time.RFC3339)) + return resetAt +} diff --git a/internal/strutil/map.go b/internal/strutil/map.go new file mode 100644 index 000000000..3a5407fad --- /dev/null +++ b/internal/strutil/map.go @@ -0,0 +1,8 @@ +package strutil + +// GetStringFromMap returns the string value for key when it is present and typed as string. +// It returns an empty string for missing keys, nil maps, and non-string values. +func GetStringFromMap(m map[string]interface{}, key string) string { + v, _ := m[key].(string) + return v +} diff --git a/internal/strutil/map_test.go b/internal/strutil/map_test.go new file mode 100644 index 000000000..7b0721e00 --- /dev/null +++ b/internal/strutil/map_test.go @@ -0,0 +1,35 @@ +package strutil + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestGetStringFromMap(t *testing.T) { + t.Parallel() + + t.Run("returns string value when present", func(t *testing.T) { + t.Parallel() + m := map[string]interface{}{"owner": "octo"} + assert.Equal(t, "octo", GetStringFromMap(m, "owner")) + }) + + t.Run("returns empty string for missing key", func(t *testing.T) { + t.Parallel() + m := map[string]interface{}{"owner": "octo"} + assert.Equal(t, "", GetStringFromMap(m, "repo")) + }) + + t.Run("returns empty string for non-string value", func(t *testing.T) { + t.Parallel() + m := map[string]interface{}{"number": float64(1)} + assert.Equal(t, "", GetStringFromMap(m, "number")) + }) + + t.Run("returns empty string for nil map", func(t *testing.T) { + t.Parallel() + var m map[string]interface{} + assert.Equal(t, "", GetStringFromMap(m, "owner")) + }) +} diff --git a/internal/testutil/mcptest/validator.go b/internal/testutil/mcptest/validator.go index ed32b9687..a82e14d3d 100644 --- a/internal/testutil/mcptest/validator.go +++ b/internal/testutil/mcptest/validator.go @@ -43,6 +43,7 @@ func NewValidatorClient(ctx context.Context, transport sdk.Transport) (*Validato // paginate collects all pages from a paginated MCP list call. // fetch is called with a cursor (empty string for the first page) and returns the items, // the next cursor (empty when done), and any error. +// Keep loop-protection invariants aligned with internal/mcp/pagination.go:paginateAll. func paginate[T any](ctx context.Context, fetch func(ctx context.Context, cursor string) ([]T, string, error)) ([]T, error) { var all []T var cursor string