From 1c83d7e6ca211a07193a960a688592e702b0c737 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 5 Jun 2026 13:49:58 +0000 Subject: [PATCH 1/5] Initial plan From f41a7f7a669325b2ee6d0741a8ab279e68bbd9af Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 5 Jun 2026 13:56:21 +0000 Subject: [PATCH 2/5] Refactor MCP text parsing and status code access --- internal/difc/path_labels.go | 18 ++-------- internal/difc/path_labels_test.go | 17 ++++++++++ internal/mcp/tool_result.go | 8 +++++ internal/mcp/tool_result_test.go | 39 +++++++++++++++++++++ internal/mcpresult/text_content.go | 45 +++++++++++++++++++++++++ internal/server/circuit_breaker_test.go | 7 ++++ internal/server/http_helpers.go | 2 +- internal/server/middleware.go | 8 ++--- internal/server/rate_limit.go | 32 +++++------------- internal/server/response_writer.go | 6 ---- internal/server/transport_test.go | 4 +-- 11 files changed, 135 insertions(+), 51 deletions(-) create mode 100644 internal/mcpresult/text_content.go diff --git a/internal/difc/path_labels.go b/internal/difc/path_labels.go index 1cfe77f71..fc5b98d1a 100644 --- a/internal/difc/path_labels.go +++ b/internal/difc/path_labels.go @@ -7,6 +7,7 @@ import ( "strings" "github.com/github/gh-aw-mcpg/internal/logger" + "github.com/github/gh-aw-mcpg/internal/mcpresult" ) var logPathLabels = logger.New("difc:path_labels") @@ -105,21 +106,8 @@ func unwrapMCPResponse(data interface{}) (interface{}, bool) { return data, false } - content, ok := dataMap["content"].([]interface{}) - if !ok || len(content) == 0 { - return data, false - } - - first, ok := content[0].(map[string]interface{}) - if !ok { - return data, false - } - - // Check if this looks like an MCP text content item - textType, hasType := first["type"].(string) - text, hasText := first["text"].(string) - - if !hasType || textType != "text" || !hasText { + text := mcpresult.ExtractTextContent(dataMap) + if text == "" { return data, false } diff --git a/internal/difc/path_labels_test.go b/internal/difc/path_labels_test.go index 1b82a8a4b..39c340777 100644 --- a/internal/difc/path_labels_test.go +++ b/internal/difc/path_labels_test.go @@ -552,6 +552,23 @@ func TestUnwrapMCPResponse(t *testing.T) { assert.False(t, isMCPWrapped, "Should not unwrap non-JSON text") assert.Equal(t, mcpData, unwrapped) }) + + t.Run("unwraps BuildMCPTextResponse output", func(t *testing.T) { + mcpData := map[string]interface{}{ + "content": []map[string]interface{}{ + {"type": "text", "text": `{"items":[{"id":1}]}`}, + }, + } + + unwrapped, isMCPWrapped := unwrapMCPResponse(mcpData) + + assert.True(t, isMCPWrapped, "Should unwrap []map-based MCP responses") + unwrappedMap, ok := unwrapped.(map[string]interface{}) + require.True(t, ok, "Unwrapped should be a map") + items, ok := unwrappedMap["items"].([]interface{}) + require.True(t, ok, "Should have items array") + assert.Len(t, items, 1, "Should have 1 item") + }) } func TestPathLabeledData_MCPWrappedResponse(t *testing.T) { diff --git a/internal/mcp/tool_result.go b/internal/mcp/tool_result.go index 1b93b985a..b448d34fd 100644 --- a/internal/mcp/tool_result.go +++ b/internal/mcp/tool_result.go @@ -6,6 +6,7 @@ import ( "fmt" "github.com/github/gh-aw-mcpg/internal/logger" + "github.com/github/gh-aw-mcpg/internal/mcpresult" sdk "github.com/modelcontextprotocol/go-sdk/mcp" ) @@ -188,6 +189,13 @@ func NewErrorCallToolResult(err error) (*sdk.CallToolResult, interface{}, error) }, nil, err } +// ExtractTextContentFromResult returns the concatenated text from text content +// items in a raw MCP tool result map. Content items with a missing "type" are +// treated as text items for compatibility with older callers and tests. +func ExtractTextContentFromResult(result map[string]interface{}) string { + return mcpresult.ExtractTextContent(result) +} + // BuildMCPTextResponse returns a raw MCP response map with a single text content item. func BuildMCPTextResponse(text string) map[string]interface{} { return map[string]interface{}{ diff --git a/internal/mcp/tool_result_test.go b/internal/mcp/tool_result_test.go index 40b61a182..2427839d6 100644 --- a/internal/mcp/tool_result_test.go +++ b/internal/mcp/tool_result_test.go @@ -611,6 +611,45 @@ func TestBuildMCPTextResponse(t *testing.T) { assert.Equal(text, content[0]["text"]) } +func TestExtractTextContentFromResult(t *testing.T) { + t.Run("extracts concatenated text from interface slice", func(t *testing.T) { + result := map[string]interface{}{ + "content": []interface{}{ + map[string]interface{}{"type": "text", "text": "hello "}, + map[string]interface{}{"type": "text", "text": "world"}, + }, + } + + assert.Equal(t, "hello world", ExtractTextContentFromResult(result)) + }) + + t.Run("extracts text from BuildMCPTextResponse output", func(t *testing.T) { + assert.Equal(t, "hello", ExtractTextContentFromResult(BuildMCPTextResponse("hello"))) + }) + + t.Run("treats missing type as text for compatibility", func(t *testing.T) { + result := map[string]interface{}{ + "content": []interface{}{ + map[string]interface{}{"text": "rate limit exceeded"}, + }, + } + + assert.Equal(t, "rate limit exceeded", ExtractTextContentFromResult(result)) + }) + + t.Run("skips non-text items and malformed entries", func(t *testing.T) { + result := map[string]interface{}{ + "content": []interface{}{ + "not-a-map", + map[string]interface{}{"type": "image", "text": "ignored"}, + map[string]interface{}{"type": "text", "text": "kept"}, + }, + } + + assert.Equal(t, "kept", ExtractTextContentFromResult(result)) + }) +} + // BenchmarkConvertToCallToolResult_TextContent benchmarks the common case: // a map[string]interface{} with text content items (fast path). func BenchmarkConvertToCallToolResult_TextContent(b *testing.B) { diff --git a/internal/mcpresult/text_content.go b/internal/mcpresult/text_content.go new file mode 100644 index 000000000..9c97b5529 --- /dev/null +++ b/internal/mcpresult/text_content.go @@ -0,0 +1,45 @@ +package mcpresult + +import "strings" + +// ExtractTextContent returns the concatenated text from text content items in a +// raw MCP tool result map. Content items with a missing "type" are treated as +// text items for compatibility with older callers and tests. +func ExtractTextContent(result map[string]interface{}) string { + contentVal, hasContent := result["content"] + if !hasContent || contentVal == nil { + return "" + } + + var items []map[string]interface{} + switch v := contentVal.(type) { + case []interface{}: + items = make([]map[string]interface{}, 0, len(v)) + for _, item := range v { + ci, ok := item.(map[string]interface{}) + if !ok { + continue + } + items = append(items, ci) + } + case []map[string]interface{}: + items = v + default: + return "" + } + + var text strings.Builder + for _, item := range items { + itemType, _ := item["type"].(string) + if itemType != "" && itemType != "text" { + continue + } + itemText, _ := item["text"].(string) + if itemText == "" { + continue + } + text.WriteString(itemText) + } + + return text.String() +} diff --git a/internal/server/circuit_breaker_test.go b/internal/server/circuit_breaker_test.go index a0a222bd9..7d8a22a24 100644 --- a/internal/server/circuit_breaker_test.go +++ b/internal/server/circuit_breaker_test.go @@ -4,6 +4,7 @@ import ( "testing" "time" + "github.com/github/gh-aw-mcpg/internal/mcp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -419,6 +420,12 @@ func TestExtractRateLimitErrorText(t *testing.T) { } assert.Equal(t, "rate limit exceeded", extractRateLimitErrorText(result)) }) + + t.Run("extracts text from BuildMCPTextResponse output", func(t *testing.T) { + t.Parallel() + result := mcp.BuildMCPTextResponse("API rate limit exceeded") + assert.Equal(t, "API rate limit exceeded", extractRateLimitErrorText(result)) + }) } // TestCircuitBreakerState_String verifies the string representation of each circuit breaker state. diff --git a/internal/server/http_helpers.go b/internal/server/http_helpers.go index 365c6b6af..3feb62f34 100644 --- a/internal/server/http_helpers.go +++ b/internal/server/http_helpers.go @@ -57,7 +57,7 @@ func withResponseLogging(handler http.Handler) http.Handler { handler.ServeHTTP(lw, r) if len(lw.Body()) > 0 { sanitizedBody := sanitize.SanitizeString(string(lw.Body())) - logHelpers.Printf("[%s] %s %s - Status: %d, Response: %s", r.RemoteAddr, r.Method, r.URL.Path, lw.StatusCode(), sanitizedBody) + logHelpers.Printf("[%s] %s %s - Status: %d, Response: %s", r.RemoteAddr, r.Method, r.URL.Path, lw.StatusCode, sanitizedBody) } }) } diff --git a/internal/server/middleware.go b/internal/server/middleware.go index cb33d0881..718ad50c7 100644 --- a/internal/server/middleware.go +++ b/internal/server/middleware.go @@ -184,7 +184,7 @@ func WithSDKLogging(handler http.Handler, mode string) http.Handler { if jsonrpcResp.Error != nil { // Error response - this is what we're particularly interested in logSDK.Printf("<<< SDK Response [%s] ERROR status=%d duration=%v", - mode, lw.StatusCode(), duration) + mode, lw.StatusCode, duration) logSDK.Printf(" JSON-RPC Error: code=%d message=%q", jsonrpcResp.Error.Code, jsonrpcResp.Error.Message) @@ -225,7 +225,7 @@ func WithSDKLogging(handler http.Handler, mode string) http.Handler { } else { // Success response logSDK.Printf("<<< SDK Response [%s] SUCCESS status=%d duration=%v", - mode, lw.StatusCode(), duration) + mode, lw.StatusCode, duration) logSDK.Printf(" JSON-RPC Response id=%v has result=%v", jsonrpcResp.ID, jsonrpcResp.Result != nil) @@ -236,7 +236,7 @@ func WithSDKLogging(handler http.Handler, mode string) http.Handler { } else { // Could be SSE stream or other format logSDK.Printf("<<< SDK Response [%s] status=%d duration=%v (non-JSON or stream)", - mode, lw.StatusCode(), duration) + mode, lw.StatusCode, duration) if len(responseBody) < 500 { logSDK.Printf(" Raw response: %s", string(responseBody)) } else { @@ -245,7 +245,7 @@ func WithSDKLogging(handler http.Handler, mode string) http.Handler { } } else { logSDK.Printf("<<< SDK Response [%s] status=%d duration=%v (empty body)", - mode, lw.StatusCode(), duration) + mode, lw.StatusCode, duration) } }) } diff --git a/internal/server/rate_limit.go b/internal/server/rate_limit.go index a7463d8eb..9ad5bf661 100644 --- a/internal/server/rate_limit.go +++ b/internal/server/rate_limit.go @@ -4,6 +4,8 @@ import ( "strconv" "strings" "time" + + "github.com/github/gh-aw-mcpg/internal/mcp" ) // extractRateLimitErrorText extracts the text content from a raw tool result @@ -15,16 +17,8 @@ func extractRateLimitErrorText(result interface{}) string { logCircuitBreaker.Print("extractRateLimitErrorText: result is not a map, using default message") return "rate limit exceeded" } - contents, _ := m["content"].([]interface{}) - logCircuitBreaker.Printf("extractRateLimitErrorText: scanning %d content items for error text", len(contents)) - for _, c := range contents { - cm, ok := c.(map[string]interface{}) - if !ok { - continue - } - if text, ok := cm["text"].(string); ok && text != "" { - return text - } + if text := mcp.ExtractTextContentFromResult(m); text != "" { + return text } logCircuitBreaker.Print("extractRateLimitErrorText: no text content found, using default message") return "rate limit exceeded" @@ -49,19 +43,11 @@ func isRateLimitToolResult(result interface{}) (bool, time.Time) { return false, time.Time{} } - contents, _ := m["content"].([]interface{}) - logCircuitBreaker.Printf("Inspecting error tool result for rate limit: contentItems=%d", len(contents)) - for _, c := range contents { - cm, ok := c.(map[string]interface{}) - if !ok { - continue - } - text, _ := cm["text"].(string) - if isRateLimitText(text) { - resetAt := parseRateLimitResetFromText(text) - logCircuitBreaker.Printf("Rate limit detected in tool result: hasResetAt=%v", !resetAt.IsZero()) - return true, resetAt - } + text := mcp.ExtractTextContentFromResult(m) + if isRateLimitText(text) { + resetAt := parseRateLimitResetFromText(text) + logCircuitBreaker.Printf("Rate limit detected in tool result: hasResetAt=%v", !resetAt.IsZero()) + return true, resetAt } return false, time.Time{} } diff --git a/internal/server/response_writer.go b/internal/server/response_writer.go index 7d7e7f88c..641c99857 100644 --- a/internal/server/response_writer.go +++ b/internal/server/response_writer.go @@ -46,9 +46,3 @@ func (w *responseWriter) Body() []byte { logResponseWriter.Printf("Retrieving captured body: %d bytes", len(bodyBytes)) return bodyBytes } - -// StatusCode returns the captured HTTP status code -func (w *responseWriter) StatusCode() int { - logResponseWriter.Printf("Retrieving captured status code: %d", w.BaseResponseWriter.StatusCode) - return w.BaseResponseWriter.StatusCode -} diff --git a/internal/server/transport_test.go b/internal/server/transport_test.go index 0b506bf3d..ab5fda154 100644 --- a/internal/server/transport_test.go +++ b/internal/server/transport_test.go @@ -53,7 +53,7 @@ func TestLoggingResponseWriter_WriteHeader(t *testing.T) { lw.WriteHeader(tt.statusCode) // Verify status code is captured - assert.Equal(t, tt.wantStatusCode, lw.StatusCode(), "Status code should be captured") + assert.Equal(t, tt.wantStatusCode, lw.StatusCode, "Status code should be captured") }) } } @@ -134,7 +134,7 @@ func TestLoggingResponseWriter_DefaultStatusCode(t *testing.T) { lw.Write([]byte("test")) // Default status code should be 200 - assert.Equal(t, http.StatusOK, lw.StatusCode(), "Default status code should be 200") + assert.Equal(t, http.StatusOK, lw.StatusCode, "Default status code should be 200") } // TestCreateHTTPServerForMCP_OAuth tests OAuth discovery endpoint From e87673a4ed385e05f92c5ba3c997a2256f8f6797 Mon Sep 17 00:00:00 2001 From: Landon Cox Date: Fri, 5 Jun 2026 08:01:59 -0700 Subject: [PATCH 3/5] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- internal/mcpresult/text_content.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/internal/mcpresult/text_content.go b/internal/mcpresult/text_content.go index 9c97b5529..94e97d0f2 100644 --- a/internal/mcpresult/text_content.go +++ b/internal/mcpresult/text_content.go @@ -31,8 +31,13 @@ func ExtractTextContent(result map[string]interface{}) string { var text strings.Builder for _, item := range items { itemType, _ := item["type"].(string) - if itemType != "" && itemType != "text" { + switch itemType { + case "", "text": + // keep + case "image", "audio", "resource": continue + default: + // Unknown types are treated as text for compatibility with ConvertToCallToolResult. } itemText, _ := item["text"].(string) if itemText == "" { From bb780f7a1495b5ec56e814a1bf4b5ff62734fffb Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 5 Jun 2026 15:12:34 +0000 Subject: [PATCH 4/5] Fix PR merge lint failure in strutil tests --- internal/strutil/util_test.go | 88 +++++++++++++++++++++++++++++++++++ 1 file changed, 88 insertions(+) diff --git a/internal/strutil/util_test.go b/internal/strutil/util_test.go index 7b0721e00..47cbfa31a 100644 --- a/internal/strutil/util_test.go +++ b/internal/strutil/util_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestGetStringFromMap(t *testing.T) { @@ -33,3 +34,90 @@ func TestGetStringFromMap(t *testing.T) { assert.Equal(t, "", GetStringFromMap(m, "owner")) }) } + +func TestCopyTrimmedStringIntMap_UtilCompat(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input map[string]int + want map[string]int + }{ + { + name: "nil input returns nil", + input: nil, + want: nil, + }, + { + name: "empty map returns nil", + input: map[string]int{}, + want: nil, + }, + { + name: "single entry copied", + input: map[string]int{"get_file": 10}, + want: map[string]int{"get_file": 10}, + }, + { + name: "multiple entries copied", + input: map[string]int{"get_file": 10, "create_issue": 5, "search_code": 100}, + want: map[string]int{"get_file": 10, "create_issue": 5, "search_code": 100}, + }, + { + name: "leading and trailing spaces trimmed from keys", + input: map[string]int{" get_file ": 10}, + want: map[string]int{"get_file": 10}, + }, + { + name: "tab characters trimmed from keys", + input: map[string]int{"\tsearch_code\t": 5}, + want: map[string]int{"search_code": 5}, + }, + { + name: "zero limit values are preserved", + input: map[string]int{"get_file": 0}, + want: map[string]int{"get_file": 0}, + }, + { + name: "negative limit values are preserved", + input: map[string]int{"get_file": -1}, + want: map[string]int{"get_file": -1}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := CopyTrimmedStringIntMap(tt.input) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestCopyTrimmedStringIntMap_DefensiveCopy(t *testing.T) { + t.Parallel() + + original := map[string]int{"get_file": 10, "create_issue": 5} + copied := CopyTrimmedStringIntMap(original) + require.NotNil(t, copied) + + copied["get_file"] = 999 + copied["new_tool"] = 1 + + assert.Equal(t, 10, original["get_file"], "mutation of copy must not affect original") + assert.NotContains(t, original, "new_tool", "new keys in copy must not appear in original") +} + +func TestCopyTrimmedStringIntMap_OriginalMutationDoesNotAffectCopy(t *testing.T) { + t.Parallel() + + original := map[string]int{"get_file": 10} + result := CopyTrimmedStringIntMap(original) + require.NotNil(t, result) + + original["get_file"] = 999 + original["new_tool"] = 1 + + assert.Equal(t, 10, result["get_file"], "mutation of original must not affect copy") + assert.NotContains(t, result, "new_tool", "new keys in original must not appear in copy") +} From 474770e9d198b3d2d697e41aa045c0b0309f3ddd Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 5 Jun 2026 15:35:48 +0000 Subject: [PATCH 5/5] Fix merged rules test expectations --- internal/config/rules/rules_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/config/rules/rules_test.go b/internal/config/rules/rules_test.go index ebf870dbc..289286a5a 100644 --- a/internal/config/rules/rules_test.go +++ b/internal/config/rules/rules_test.go @@ -180,7 +180,7 @@ func TestPositiveInteger(t *testing.T) { fieldName: "payloadSizeThreshold", jsonPath: "gateway.payloadSizeThreshold", shouldErr: true, - errMsg: "payloadSizeThreshold must be a positive integer, got 0", + errMsg: "payloadSizeThreshold must be a positive integer (>= 1), got 0", }, { name: "negative value rejected", @@ -188,7 +188,7 @@ func TestPositiveInteger(t *testing.T) { fieldName: "payload_size_threshold", jsonPath: "gateway.payload_size_threshold", shouldErr: true, - errMsg: "payload_size_threshold must be a positive integer, got -1", + errMsg: "payload_size_threshold must be a positive integer (>= 1), got -1", }, }