Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions internal/mcp/pagination.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ const paginateAllMaxPages = 100
// paginateAll collects all items across paginated SDK list calls.
// It returns an error if the backend returns more than paginateAllMaxPages pages,
// protecting against runaway backends.
// Keep loop-protection invariants aligned with internal/testutil/mcptest/validator.go:paginate.
func paginateAll[T any](
serverID string,
itemKind string,
Expand Down
7 changes: 4 additions & 3 deletions internal/proxy/collaborator_permission.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"net/http"

"github.com/github/gh-aw-mcpg/internal/logger"
"github.com/github/gh-aw-mcpg/internal/strutil"
)

var logCollab = logger.New("proxy:collaborator_permission")
Expand All @@ -17,9 +18,9 @@ var logCollab = logger.New("proxy:collaborator_permission")
// It returns the (possibly partial) values even on error so that callers can
// include them in diagnostic log messages.
func ParseCollaboratorPermissionArgs(argsMap map[string]interface{}) (owner, repo, username string, err error) {
owner, _ = argsMap["owner"].(string)
repo, _ = argsMap["repo"].(string)
username, _ = argsMap["username"].(string)
owner = strutil.GetStringFromMap(argsMap, "owner")
repo = strutil.GetStringFromMap(argsMap, "repo")
username = strutil.GetStringFromMap(argsMap, "username")
if owner == "" || repo == "" || username == "" {
logCollab.Printf("ParseCollaboratorPermissionArgs: missing required fields: owner=%q, repo=%q, username=%q", owner, repo, username)
err = fmt.Errorf("get_collaborator_permission: missing owner/repo/username")
Expand Down
12 changes: 4 additions & 8 deletions internal/proxy/graphql.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,14 +147,10 @@ func extractOwnerRepo(variables map[string]interface{}, query string) (string, s

// Try variables first
if variables != nil {
if v, ok := variables["owner"].(string); ok {
owner = v
}
if v, ok := variables["name"].(string); ok {
repo = v
}
if v, ok := variables["repo"].(string); ok && repo == "" {
repo = v
owner = strutil.GetStringFromMap(variables, "owner")
repo = strutil.GetStringFromMap(variables, "name")
if repo == "" {
repo = strutil.GetStringFromMap(variables, "repo")
}
logGraphQL.Printf("extractOwnerRepo: from variables: owner=%q repo=%q", owner, repo)
}
Expand Down
7 changes: 4 additions & 3 deletions internal/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/github/gh-aw-mcpg/internal/httputil"
"github.com/github/gh-aw-mcpg/internal/logger"
"github.com/github/gh-aw-mcpg/internal/mcp"
"github.com/github/gh-aw-mcpg/internal/strutil"
"github.com/github/gh-aw-mcpg/internal/tracing"
)

Expand Down Expand Up @@ -215,9 +216,9 @@ type restBackendCaller struct {
// from tool arguments, accepting either string or float64 JSON number inputs for
// the identifier.
func extractOwnerRepoNumber(argsMap map[string]interface{}, ownerKey, repoKey, numberKey, toolName string) (owner, repo, number string, err error) {
owner, _ = argsMap[ownerKey].(string)
repo, _ = argsMap[repoKey].(string)
number, _ = argsMap[numberKey].(string)
owner = strutil.GetStringFromMap(argsMap, ownerKey)
repo = strutil.GetStringFromMap(argsMap, repoKey)
number = strutil.GetStringFromMap(argsMap, numberKey)
if number == "" {
if n, ok := argsMap[numberKey].(float64); ok {
number = fmt.Sprintf("%d", int(n))
Expand Down
94 changes: 0 additions & 94 deletions internal/server/circuit_breaker.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ package server

import (
"fmt"
"strconv"
"strings"
"sync"
"time"

Expand Down Expand Up @@ -246,95 +244,3 @@ func buildCircuitBreakers(cfg *config.Config) map[string]*circuitBreaker {
logCircuitBreaker.Printf("buildCircuitBreakers: created %d circuit breakers", len(cbs))
return cbs
}

// extractRateLimitErrorText extracts the text content from a raw tool result
// that has been identified as a rate-limit error. Returns the original backend
// message so agents see the actual upstream error rather than a synthetic one.
func extractRateLimitErrorText(result interface{}) string {
m, ok := result.(map[string]interface{})
if !ok {
return "rate limit exceeded"
}
contents, _ := m["content"].([]interface{})
for _, c := range contents {
cm, ok := c.(map[string]interface{})
if !ok {
continue
}
if text, ok := cm["text"].(string); ok && text != "" {
return text
}
}
return "rate limit exceeded"
}

// isRateLimitToolResult reports whether a raw tool call result indicates
// a rate-limit error from the GitHub MCP server. It inspects the `isError`
// flag and the text content for well-known rate-limit phrases.
//
// The GitHub MCP server returns rate-limit errors as:
//
// {"content":[{"type":"text","text":"... 403 API rate limit exceeded ..."}],"isError":true}
func isRateLimitToolResult(result interface{}) (bool, time.Time) {
m, ok := result.(map[string]interface{})
if !ok {
return false, time.Time{}
}

// Only inspect error results.
isErr, _ := m["isError"].(bool)
if !isErr {
return false, time.Time{}
}

contents, _ := m["content"].([]interface{})
for _, c := range contents {
cm, ok := c.(map[string]interface{})
if !ok {
continue
}
text, _ := cm["text"].(string)
if isRateLimitText(text) {
resetAt := parseRateLimitResetFromText(text)
logCircuitBreaker.Printf("Rate limit detected in tool result: hasResetAt=%v", !resetAt.IsZero())
return true, resetAt
}
}
return false, time.Time{}
}

// isRateLimitText returns true when the message indicates a GitHub rate-limit error.
func isRateLimitText(text string) bool {
lower := strings.ToLower(text)
return strings.Contains(lower, "rate limit exceeded") ||
(strings.Contains(lower, "rate limit") && strings.Contains(lower, "403")) ||
strings.Contains(lower, "api rate limit") ||
strings.Contains(lower, "secondary rate limit") ||
strings.Contains(lower, "too many requests")
}

// parseRateLimitResetFromText attempts to extract a reset timestamp from the
// rate-limit error text. The GitHub MCP server includes messages like
// "API rate limit exceeded [rate reset in 42s]".
// Returns zero time when the value cannot be parsed or is 0 seconds.
func parseRateLimitResetFromText(text string) time.Time {
// Look for "[rate reset in Ns]" pattern.
lower := strings.ToLower(text)
idx := strings.Index(lower, "rate reset in ")
if idx < 0 {
return time.Time{}
}
rest := text[idx+len("rate reset in "):]
// Find the first non-digit character.
end := strings.IndexAny(rest, "s])")
if end < 0 {
return time.Time{}
}
secs, err := strconv.ParseInt(strings.TrimSpace(rest[:end]), 10, 64)
if err != nil || secs <= 0 {
return time.Time{}
}
resetAt := time.Now().Add(time.Duration(secs) * time.Second)
logCircuitBreaker.Printf("Parsed rate limit reset time from text: resetIn=%ds, resetAt=%s", secs, resetAt.UTC().Format(time.RFC3339))
return resetAt
}
99 changes: 99 additions & 0 deletions internal/server/rate_limit.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package server

import (
"strconv"
"strings"
"time"
)

// extractRateLimitErrorText extracts the text content from a raw tool result
// that has been identified as a rate-limit error. Returns the original backend
// message so agents see the actual upstream error rather than a synthetic one.
func extractRateLimitErrorText(result interface{}) string {
m, ok := result.(map[string]interface{})
if !ok {
return "rate limit exceeded"
}
contents, _ := m["content"].([]interface{})
for _, c := range contents {
cm, ok := c.(map[string]interface{})
if !ok {
continue
}
if text, ok := cm["text"].(string); ok && text != "" {
return text
}
}
return "rate limit exceeded"
}

// isRateLimitToolResult reports whether a raw tool call result indicates
// a rate-limit error from the GitHub MCP server. It inspects the `isError`
// flag and the text content for well-known rate-limit phrases.
//
// The GitHub MCP server returns rate-limit errors as:
//
// {"content":[{"type":"text","text":"... 403 API rate limit exceeded ..."}],"isError":true}
func isRateLimitToolResult(result interface{}) (bool, time.Time) {
m, ok := result.(map[string]interface{})
if !ok {
return false, time.Time{}
}

// Only inspect error results.
isErr, _ := m["isError"].(bool)
if !isErr {
return false, time.Time{}
}

contents, _ := m["content"].([]interface{})
for _, c := range contents {
cm, ok := c.(map[string]interface{})
if !ok {
continue
}
text, _ := cm["text"].(string)
if isRateLimitText(text) {
resetAt := parseRateLimitResetFromText(text)
logCircuitBreaker.Printf("Rate limit detected in tool result: hasResetAt=%v", !resetAt.IsZero())
return true, resetAt
}
}
return false, time.Time{}
}

// isRateLimitText returns true when the message indicates a GitHub rate-limit error.
func isRateLimitText(text string) bool {
lower := strings.ToLower(text)
return strings.Contains(lower, "rate limit exceeded") ||
(strings.Contains(lower, "rate limit") && strings.Contains(lower, "403")) ||
strings.Contains(lower, "api rate limit") ||
strings.Contains(lower, "secondary rate limit") ||
strings.Contains(lower, "too many requests")
}

// parseRateLimitResetFromText attempts to extract a reset timestamp from the
// rate-limit error text. The GitHub MCP server includes messages like
// "API rate limit exceeded [rate reset in 42s]".
// Returns zero time when the value cannot be parsed or is 0 seconds.
func parseRateLimitResetFromText(text string) time.Time {
// Look for "[rate reset in Ns]" pattern.
lower := strings.ToLower(text)
idx := strings.Index(lower, "rate reset in ")
if idx < 0 {
return time.Time{}
}
rest := text[idx+len("rate reset in "):]
// Find the first non-digit character.
end := strings.IndexAny(rest, "s])")
if end < 0 {
return time.Time{}
}
secs, err := strconv.ParseInt(strings.TrimSpace(rest[:end]), 10, 64)
if err != nil || secs <= 0 {
return time.Time{}
}
resetAt := time.Now().Add(time.Duration(secs) * time.Second)
logCircuitBreaker.Printf("Parsed rate limit reset time from text: resetIn=%ds, resetAt=%s", secs, resetAt.UTC().Format(time.RFC3339))
return resetAt
}
8 changes: 8 additions & 0 deletions internal/strutil/map.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package strutil

// GetStringFromMap returns the string value for key when it is present and typed as string.
// It returns an empty string for missing keys, nil maps, and non-string values.
func GetStringFromMap(m map[string]interface{}, key string) string {
v, _ := m[key].(string)
return v
}
35 changes: 35 additions & 0 deletions internal/strutil/map_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package strutil

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestGetStringFromMap(t *testing.T) {
t.Parallel()

t.Run("returns string value when present", func(t *testing.T) {
t.Parallel()
m := map[string]interface{}{"owner": "octo"}
assert.Equal(t, "octo", GetStringFromMap(m, "owner"))
})

t.Run("returns empty string for missing key", func(t *testing.T) {
t.Parallel()
m := map[string]interface{}{"owner": "octo"}
assert.Equal(t, "", GetStringFromMap(m, "repo"))
})

t.Run("returns empty string for non-string value", func(t *testing.T) {
t.Parallel()
m := map[string]interface{}{"number": float64(1)}
assert.Equal(t, "", GetStringFromMap(m, "number"))
})

t.Run("returns empty string for nil map", func(t *testing.T) {
t.Parallel()
var m map[string]interface{}
assert.Equal(t, "", GetStringFromMap(m, "owner"))
})
}
1 change: 1 addition & 0 deletions internal/testutil/mcptest/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ func NewValidatorClient(ctx context.Context, transport sdk.Transport) (*Validato
// paginate collects all pages from a paginated MCP list call.
// fetch is called with a cursor (empty string for the first page) and returns the items,
// the next cursor (empty when done), and any error.
// Keep loop-protection invariants aligned with internal/mcp/pagination.go:paginateAll.
func paginate[T any](ctx context.Context, fetch func(ctx context.Context, cursor string) ([]T, string, error)) ([]T, error) {
var all []T
var cursor string
Expand Down
Loading