diff --git a/internal/config/rules/rules_test.go b/internal/config/rules/rules_test.go index ebf870dbc..289286a5a 100644 --- a/internal/config/rules/rules_test.go +++ b/internal/config/rules/rules_test.go @@ -180,7 +180,7 @@ func TestPositiveInteger(t *testing.T) { fieldName: "payloadSizeThreshold", jsonPath: "gateway.payloadSizeThreshold", shouldErr: true, - errMsg: "payloadSizeThreshold must be a positive integer, got 0", + errMsg: "payloadSizeThreshold must be a positive integer (>= 1), got 0", }, { name: "negative value rejected", @@ -188,7 +188,7 @@ func TestPositiveInteger(t *testing.T) { fieldName: "payload_size_threshold", jsonPath: "gateway.payload_size_threshold", shouldErr: true, - errMsg: "payload_size_threshold must be a positive integer, got -1", + errMsg: "payload_size_threshold must be a positive integer (>= 1), got -1", }, } diff --git a/internal/difc/path_labels.go b/internal/difc/path_labels.go index 1cfe77f71..fc5b98d1a 100644 --- a/internal/difc/path_labels.go +++ b/internal/difc/path_labels.go @@ -7,6 +7,7 @@ import ( "strings" "github.com/github/gh-aw-mcpg/internal/logger" + "github.com/github/gh-aw-mcpg/internal/mcpresult" ) var logPathLabels = logger.New("difc:path_labels") @@ -105,21 +106,8 @@ func unwrapMCPResponse(data interface{}) (interface{}, bool) { return data, false } - content, ok := dataMap["content"].([]interface{}) - if !ok || len(content) == 0 { - return data, false - } - - first, ok := content[0].(map[string]interface{}) - if !ok { - return data, false - } - - // Check if this looks like an MCP text content item - textType, hasType := first["type"].(string) - text, hasText := first["text"].(string) - - if !hasType || textType != "text" || !hasText { + text := mcpresult.ExtractTextContent(dataMap) + if text == "" { return data, false } diff --git a/internal/difc/path_labels_test.go b/internal/difc/path_labels_test.go index 1b82a8a4b..39c340777 100644 --- a/internal/difc/path_labels_test.go +++ b/internal/difc/path_labels_test.go @@ -552,6 +552,23 @@ func TestUnwrapMCPResponse(t *testing.T) { assert.False(t, isMCPWrapped, "Should not unwrap non-JSON text") assert.Equal(t, mcpData, unwrapped) }) + + t.Run("unwraps BuildMCPTextResponse output", func(t *testing.T) { + mcpData := map[string]interface{}{ + "content": []map[string]interface{}{ + {"type": "text", "text": `{"items":[{"id":1}]}`}, + }, + } + + unwrapped, isMCPWrapped := unwrapMCPResponse(mcpData) + + assert.True(t, isMCPWrapped, "Should unwrap []map-based MCP responses") + unwrappedMap, ok := unwrapped.(map[string]interface{}) + require.True(t, ok, "Unwrapped should be a map") + items, ok := unwrappedMap["items"].([]interface{}) + require.True(t, ok, "Should have items array") + assert.Len(t, items, 1, "Should have 1 item") + }) } func TestPathLabeledData_MCPWrappedResponse(t *testing.T) { diff --git a/internal/mcp/tool_result.go b/internal/mcp/tool_result.go index 1b93b985a..b448d34fd 100644 --- a/internal/mcp/tool_result.go +++ b/internal/mcp/tool_result.go @@ -6,6 +6,7 @@ import ( "fmt" "github.com/github/gh-aw-mcpg/internal/logger" + "github.com/github/gh-aw-mcpg/internal/mcpresult" sdk "github.com/modelcontextprotocol/go-sdk/mcp" ) @@ -188,6 +189,13 @@ func NewErrorCallToolResult(err error) (*sdk.CallToolResult, interface{}, error) }, nil, err } +// ExtractTextContentFromResult returns the concatenated text from text content +// items in a raw MCP tool result map. Content items with a missing "type" are +// treated as text items for compatibility with older callers and tests. +func ExtractTextContentFromResult(result map[string]interface{}) string { + return mcpresult.ExtractTextContent(result) +} + // BuildMCPTextResponse returns a raw MCP response map with a single text content item. func BuildMCPTextResponse(text string) map[string]interface{} { return map[string]interface{}{ diff --git a/internal/mcp/tool_result_test.go b/internal/mcp/tool_result_test.go index 40b61a182..2427839d6 100644 --- a/internal/mcp/tool_result_test.go +++ b/internal/mcp/tool_result_test.go @@ -611,6 +611,45 @@ func TestBuildMCPTextResponse(t *testing.T) { assert.Equal(text, content[0]["text"]) } +func TestExtractTextContentFromResult(t *testing.T) { + t.Run("extracts concatenated text from interface slice", func(t *testing.T) { + result := map[string]interface{}{ + "content": []interface{}{ + map[string]interface{}{"type": "text", "text": "hello "}, + map[string]interface{}{"type": "text", "text": "world"}, + }, + } + + assert.Equal(t, "hello world", ExtractTextContentFromResult(result)) + }) + + t.Run("extracts text from BuildMCPTextResponse output", func(t *testing.T) { + assert.Equal(t, "hello", ExtractTextContentFromResult(BuildMCPTextResponse("hello"))) + }) + + t.Run("treats missing type as text for compatibility", func(t *testing.T) { + result := map[string]interface{}{ + "content": []interface{}{ + map[string]interface{}{"text": "rate limit exceeded"}, + }, + } + + assert.Equal(t, "rate limit exceeded", ExtractTextContentFromResult(result)) + }) + + t.Run("skips non-text items and malformed entries", func(t *testing.T) { + result := map[string]interface{}{ + "content": []interface{}{ + "not-a-map", + map[string]interface{}{"type": "image", "text": "ignored"}, + map[string]interface{}{"type": "text", "text": "kept"}, + }, + } + + assert.Equal(t, "kept", ExtractTextContentFromResult(result)) + }) +} + // BenchmarkConvertToCallToolResult_TextContent benchmarks the common case: // a map[string]interface{} with text content items (fast path). func BenchmarkConvertToCallToolResult_TextContent(b *testing.B) { diff --git a/internal/mcpresult/text_content.go b/internal/mcpresult/text_content.go new file mode 100644 index 000000000..94e97d0f2 --- /dev/null +++ b/internal/mcpresult/text_content.go @@ -0,0 +1,50 @@ +package mcpresult + +import "strings" + +// ExtractTextContent returns the concatenated text from text content items in a +// raw MCP tool result map. Content items with a missing "type" are treated as +// text items for compatibility with older callers and tests. +func ExtractTextContent(result map[string]interface{}) string { + contentVal, hasContent := result["content"] + if !hasContent || contentVal == nil { + return "" + } + + var items []map[string]interface{} + switch v := contentVal.(type) { + case []interface{}: + items = make([]map[string]interface{}, 0, len(v)) + for _, item := range v { + ci, ok := item.(map[string]interface{}) + if !ok { + continue + } + items = append(items, ci) + } + case []map[string]interface{}: + items = v + default: + return "" + } + + var text strings.Builder + for _, item := range items { + itemType, _ := item["type"].(string) + switch itemType { + case "", "text": + // keep + case "image", "audio", "resource": + continue + default: + // Unknown types are treated as text for compatibility with ConvertToCallToolResult. + } + itemText, _ := item["text"].(string) + if itemText == "" { + continue + } + text.WriteString(itemText) + } + + return text.String() +} diff --git a/internal/server/circuit_breaker_test.go b/internal/server/circuit_breaker_test.go index a0a222bd9..7d8a22a24 100644 --- a/internal/server/circuit_breaker_test.go +++ b/internal/server/circuit_breaker_test.go @@ -4,6 +4,7 @@ import ( "testing" "time" + "github.com/github/gh-aw-mcpg/internal/mcp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -419,6 +420,12 @@ func TestExtractRateLimitErrorText(t *testing.T) { } assert.Equal(t, "rate limit exceeded", extractRateLimitErrorText(result)) }) + + t.Run("extracts text from BuildMCPTextResponse output", func(t *testing.T) { + t.Parallel() + result := mcp.BuildMCPTextResponse("API rate limit exceeded") + assert.Equal(t, "API rate limit exceeded", extractRateLimitErrorText(result)) + }) } // TestCircuitBreakerState_String verifies the string representation of each circuit breaker state. diff --git a/internal/server/http_helpers.go b/internal/server/http_helpers.go index 365c6b6af..3feb62f34 100644 --- a/internal/server/http_helpers.go +++ b/internal/server/http_helpers.go @@ -57,7 +57,7 @@ func withResponseLogging(handler http.Handler) http.Handler { handler.ServeHTTP(lw, r) if len(lw.Body()) > 0 { sanitizedBody := sanitize.SanitizeString(string(lw.Body())) - logHelpers.Printf("[%s] %s %s - Status: %d, Response: %s", r.RemoteAddr, r.Method, r.URL.Path, lw.StatusCode(), sanitizedBody) + logHelpers.Printf("[%s] %s %s - Status: %d, Response: %s", r.RemoteAddr, r.Method, r.URL.Path, lw.StatusCode, sanitizedBody) } }) } diff --git a/internal/server/middleware.go b/internal/server/middleware.go index cb33d0881..718ad50c7 100644 --- a/internal/server/middleware.go +++ b/internal/server/middleware.go @@ -184,7 +184,7 @@ func WithSDKLogging(handler http.Handler, mode string) http.Handler { if jsonrpcResp.Error != nil { // Error response - this is what we're particularly interested in logSDK.Printf("<<< SDK Response [%s] ERROR status=%d duration=%v", - mode, lw.StatusCode(), duration) + mode, lw.StatusCode, duration) logSDK.Printf(" JSON-RPC Error: code=%d message=%q", jsonrpcResp.Error.Code, jsonrpcResp.Error.Message) @@ -225,7 +225,7 @@ func WithSDKLogging(handler http.Handler, mode string) http.Handler { } else { // Success response logSDK.Printf("<<< SDK Response [%s] SUCCESS status=%d duration=%v", - mode, lw.StatusCode(), duration) + mode, lw.StatusCode, duration) logSDK.Printf(" JSON-RPC Response id=%v has result=%v", jsonrpcResp.ID, jsonrpcResp.Result != nil) @@ -236,7 +236,7 @@ func WithSDKLogging(handler http.Handler, mode string) http.Handler { } else { // Could be SSE stream or other format logSDK.Printf("<<< SDK Response [%s] status=%d duration=%v (non-JSON or stream)", - mode, lw.StatusCode(), duration) + mode, lw.StatusCode, duration) if len(responseBody) < 500 { logSDK.Printf(" Raw response: %s", string(responseBody)) } else { @@ -245,7 +245,7 @@ func WithSDKLogging(handler http.Handler, mode string) http.Handler { } } else { logSDK.Printf("<<< SDK Response [%s] status=%d duration=%v (empty body)", - mode, lw.StatusCode(), duration) + mode, lw.StatusCode, duration) } }) } diff --git a/internal/server/rate_limit.go b/internal/server/rate_limit.go index a7463d8eb..9ad5bf661 100644 --- a/internal/server/rate_limit.go +++ b/internal/server/rate_limit.go @@ -4,6 +4,8 @@ import ( "strconv" "strings" "time" + + "github.com/github/gh-aw-mcpg/internal/mcp" ) // extractRateLimitErrorText extracts the text content from a raw tool result @@ -15,16 +17,8 @@ func extractRateLimitErrorText(result interface{}) string { logCircuitBreaker.Print("extractRateLimitErrorText: result is not a map, using default message") return "rate limit exceeded" } - contents, _ := m["content"].([]interface{}) - logCircuitBreaker.Printf("extractRateLimitErrorText: scanning %d content items for error text", len(contents)) - for _, c := range contents { - cm, ok := c.(map[string]interface{}) - if !ok { - continue - } - if text, ok := cm["text"].(string); ok && text != "" { - return text - } + if text := mcp.ExtractTextContentFromResult(m); text != "" { + return text } logCircuitBreaker.Print("extractRateLimitErrorText: no text content found, using default message") return "rate limit exceeded" @@ -49,19 +43,11 @@ func isRateLimitToolResult(result interface{}) (bool, time.Time) { return false, time.Time{} } - contents, _ := m["content"].([]interface{}) - logCircuitBreaker.Printf("Inspecting error tool result for rate limit: contentItems=%d", len(contents)) - for _, c := range contents { - cm, ok := c.(map[string]interface{}) - if !ok { - continue - } - text, _ := cm["text"].(string) - if isRateLimitText(text) { - resetAt := parseRateLimitResetFromText(text) - logCircuitBreaker.Printf("Rate limit detected in tool result: hasResetAt=%v", !resetAt.IsZero()) - return true, resetAt - } + text := mcp.ExtractTextContentFromResult(m) + if isRateLimitText(text) { + resetAt := parseRateLimitResetFromText(text) + logCircuitBreaker.Printf("Rate limit detected in tool result: hasResetAt=%v", !resetAt.IsZero()) + return true, resetAt } return false, time.Time{} } diff --git a/internal/server/response_writer.go b/internal/server/response_writer.go index 7d7e7f88c..641c99857 100644 --- a/internal/server/response_writer.go +++ b/internal/server/response_writer.go @@ -46,9 +46,3 @@ func (w *responseWriter) Body() []byte { logResponseWriter.Printf("Retrieving captured body: %d bytes", len(bodyBytes)) return bodyBytes } - -// StatusCode returns the captured HTTP status code -func (w *responseWriter) StatusCode() int { - logResponseWriter.Printf("Retrieving captured status code: %d", w.BaseResponseWriter.StatusCode) - return w.BaseResponseWriter.StatusCode -} diff --git a/internal/server/transport_test.go b/internal/server/transport_test.go index 0b506bf3d..ab5fda154 100644 --- a/internal/server/transport_test.go +++ b/internal/server/transport_test.go @@ -53,7 +53,7 @@ func TestLoggingResponseWriter_WriteHeader(t *testing.T) { lw.WriteHeader(tt.statusCode) // Verify status code is captured - assert.Equal(t, tt.wantStatusCode, lw.StatusCode(), "Status code should be captured") + assert.Equal(t, tt.wantStatusCode, lw.StatusCode, "Status code should be captured") }) } } @@ -134,7 +134,7 @@ func TestLoggingResponseWriter_DefaultStatusCode(t *testing.T) { lw.Write([]byte("test")) // Default status code should be 200 - assert.Equal(t, http.StatusOK, lw.StatusCode(), "Default status code should be 200") + assert.Equal(t, http.StatusOK, lw.StatusCode, "Default status code should be 200") } // TestCreateHTTPServerForMCP_OAuth tests OAuth discovery endpoint