Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions internal/config/rules/rules_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,15 +180,15 @@ func TestPositiveInteger(t *testing.T) {
fieldName: "payloadSizeThreshold",
jsonPath: "gateway.payloadSizeThreshold",
shouldErr: true,
errMsg: "payloadSizeThreshold must be a positive integer, got 0",
errMsg: "payloadSizeThreshold must be a positive integer (>= 1), got 0",
},
{
name: "negative value rejected",
value: -1,
fieldName: "payload_size_threshold",
jsonPath: "gateway.payload_size_threshold",
shouldErr: true,
errMsg: "payload_size_threshold must be a positive integer, got -1",
errMsg: "payload_size_threshold must be a positive integer (>= 1), got -1",
},
}

Expand Down
18 changes: 3 additions & 15 deletions internal/difc/path_labels.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"strings"

"github.com/github/gh-aw-mcpg/internal/logger"
"github.com/github/gh-aw-mcpg/internal/mcpresult"
)

var logPathLabels = logger.New("difc:path_labels")
Expand Down Expand Up @@ -105,21 +106,8 @@ func unwrapMCPResponse(data interface{}) (interface{}, bool) {
return data, false
}

content, ok := dataMap["content"].([]interface{})
if !ok || len(content) == 0 {
return data, false
}

first, ok := content[0].(map[string]interface{})
if !ok {
return data, false
}

// Check if this looks like an MCP text content item
textType, hasType := first["type"].(string)
text, hasText := first["text"].(string)

if !hasType || textType != "text" || !hasText {
text := mcpresult.ExtractTextContent(dataMap)
if text == "" {
return data, false
}

Expand Down
17 changes: 17 additions & 0 deletions internal/difc/path_labels_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,23 @@ func TestUnwrapMCPResponse(t *testing.T) {
assert.False(t, isMCPWrapped, "Should not unwrap non-JSON text")
assert.Equal(t, mcpData, unwrapped)
})

t.Run("unwraps BuildMCPTextResponse output", func(t *testing.T) {
mcpData := map[string]interface{}{
"content": []map[string]interface{}{
{"type": "text", "text": `{"items":[{"id":1}]}`},
},
}

unwrapped, isMCPWrapped := unwrapMCPResponse(mcpData)

assert.True(t, isMCPWrapped, "Should unwrap []map-based MCP responses")
unwrappedMap, ok := unwrapped.(map[string]interface{})
require.True(t, ok, "Unwrapped should be a map")
items, ok := unwrappedMap["items"].([]interface{})
require.True(t, ok, "Should have items array")
assert.Len(t, items, 1, "Should have 1 item")
})
}

func TestPathLabeledData_MCPWrappedResponse(t *testing.T) {
Expand Down
8 changes: 8 additions & 0 deletions internal/mcp/tool_result.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"

"github.com/github/gh-aw-mcpg/internal/logger"
"github.com/github/gh-aw-mcpg/internal/mcpresult"
sdk "github.com/modelcontextprotocol/go-sdk/mcp"
)

Expand Down Expand Up @@ -188,6 +189,13 @@ func NewErrorCallToolResult(err error) (*sdk.CallToolResult, interface{}, error)
}, nil, err
}

// ExtractTextContentFromResult returns the concatenated text from text content
// items in a raw MCP tool result map. Content items with a missing "type" are
// treated as text items for compatibility with older callers and tests.
func ExtractTextContentFromResult(result map[string]interface{}) string {
return mcpresult.ExtractTextContent(result)
}

// BuildMCPTextResponse returns a raw MCP response map with a single text content item.
func BuildMCPTextResponse(text string) map[string]interface{} {
return map[string]interface{}{
Expand Down
39 changes: 39 additions & 0 deletions internal/mcp/tool_result_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,45 @@ func TestBuildMCPTextResponse(t *testing.T) {
assert.Equal(text, content[0]["text"])
}

func TestExtractTextContentFromResult(t *testing.T) {
t.Run("extracts concatenated text from interface slice", func(t *testing.T) {
result := map[string]interface{}{
"content": []interface{}{
map[string]interface{}{"type": "text", "text": "hello "},
map[string]interface{}{"type": "text", "text": "world"},
},
}

assert.Equal(t, "hello world", ExtractTextContentFromResult(result))
})

t.Run("extracts text from BuildMCPTextResponse output", func(t *testing.T) {
assert.Equal(t, "hello", ExtractTextContentFromResult(BuildMCPTextResponse("hello")))
})

t.Run("treats missing type as text for compatibility", func(t *testing.T) {
result := map[string]interface{}{
"content": []interface{}{
map[string]interface{}{"text": "rate limit exceeded"},
},
}

assert.Equal(t, "rate limit exceeded", ExtractTextContentFromResult(result))
})

t.Run("skips non-text items and malformed entries", func(t *testing.T) {
result := map[string]interface{}{
"content": []interface{}{
"not-a-map",
map[string]interface{}{"type": "image", "text": "ignored"},
map[string]interface{}{"type": "text", "text": "kept"},
},
}

assert.Equal(t, "kept", ExtractTextContentFromResult(result))
})
}

// BenchmarkConvertToCallToolResult_TextContent benchmarks the common case:
// a map[string]interface{} with text content items (fast path).
func BenchmarkConvertToCallToolResult_TextContent(b *testing.B) {
Expand Down
50 changes: 50 additions & 0 deletions internal/mcpresult/text_content.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package mcpresult

import "strings"

// ExtractTextContent returns the concatenated text from text content items in a
// raw MCP tool result map. Content items with a missing "type" are treated as
// text items for compatibility with older callers and tests.
func ExtractTextContent(result map[string]interface{}) string {
contentVal, hasContent := result["content"]
if !hasContent || contentVal == nil {
return ""
}

var items []map[string]interface{}
switch v := contentVal.(type) {
case []interface{}:
items = make([]map[string]interface{}, 0, len(v))
for _, item := range v {
ci, ok := item.(map[string]interface{})
if !ok {
continue
}
items = append(items, ci)
}
case []map[string]interface{}:
items = v
default:
return ""
}

var text strings.Builder
for _, item := range items {
itemType, _ := item["type"].(string)
switch itemType {
case "", "text":
// keep
case "image", "audio", "resource":
continue
default:
// Unknown types are treated as text for compatibility with ConvertToCallToolResult.
}
Comment thread
lpcox marked this conversation as resolved.
itemText, _ := item["text"].(string)
if itemText == "" {
continue
}
text.WriteString(itemText)
}

return text.String()
}
7 changes: 7 additions & 0 deletions internal/server/circuit_breaker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"testing"
"time"

"github.com/github/gh-aw-mcpg/internal/mcp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -419,6 +420,12 @@ func TestExtractRateLimitErrorText(t *testing.T) {
}
assert.Equal(t, "rate limit exceeded", extractRateLimitErrorText(result))
})

t.Run("extracts text from BuildMCPTextResponse output", func(t *testing.T) {
t.Parallel()
result := mcp.BuildMCPTextResponse("API rate limit exceeded")
assert.Equal(t, "API rate limit exceeded", extractRateLimitErrorText(result))
})
}

// TestCircuitBreakerState_String verifies the string representation of each circuit breaker state.
Expand Down
2 changes: 1 addition & 1 deletion internal/server/http_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ func withResponseLogging(handler http.Handler) http.Handler {
handler.ServeHTTP(lw, r)
if len(lw.Body()) > 0 {
sanitizedBody := sanitize.SanitizeString(string(lw.Body()))
logHelpers.Printf("[%s] %s %s - Status: %d, Response: %s", r.RemoteAddr, r.Method, r.URL.Path, lw.StatusCode(), sanitizedBody)
logHelpers.Printf("[%s] %s %s - Status: %d, Response: %s", r.RemoteAddr, r.Method, r.URL.Path, lw.StatusCode, sanitizedBody)
}
})
}
Expand Down
8 changes: 4 additions & 4 deletions internal/server/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ func WithSDKLogging(handler http.Handler, mode string) http.Handler {
if jsonrpcResp.Error != nil {
// Error response - this is what we're particularly interested in
logSDK.Printf("<<< SDK Response [%s] ERROR status=%d duration=%v",
mode, lw.StatusCode(), duration)
mode, lw.StatusCode, duration)
logSDK.Printf(" JSON-RPC Error: code=%d message=%q",
jsonrpcResp.Error.Code, jsonrpcResp.Error.Message)

Expand Down Expand Up @@ -225,7 +225,7 @@ func WithSDKLogging(handler http.Handler, mode string) http.Handler {
} else {
// Success response
logSDK.Printf("<<< SDK Response [%s] SUCCESS status=%d duration=%v",
mode, lw.StatusCode(), duration)
mode, lw.StatusCode, duration)
logSDK.Printf(" JSON-RPC Response id=%v has result=%v",
jsonrpcResp.ID, jsonrpcResp.Result != nil)

Expand All @@ -236,7 +236,7 @@ func WithSDKLogging(handler http.Handler, mode string) http.Handler {
} else {
// Could be SSE stream or other format
logSDK.Printf("<<< SDK Response [%s] status=%d duration=%v (non-JSON or stream)",
mode, lw.StatusCode(), duration)
mode, lw.StatusCode, duration)
if len(responseBody) < 500 {
logSDK.Printf(" Raw response: %s", string(responseBody))
} else {
Expand All @@ -245,7 +245,7 @@ func WithSDKLogging(handler http.Handler, mode string) http.Handler {
}
} else {
logSDK.Printf("<<< SDK Response [%s] status=%d duration=%v (empty body)",
mode, lw.StatusCode(), duration)
mode, lw.StatusCode, duration)
}
})
}
32 changes: 9 additions & 23 deletions internal/server/rate_limit.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"strconv"
"strings"
"time"

"github.com/github/gh-aw-mcpg/internal/mcp"
)

// extractRateLimitErrorText extracts the text content from a raw tool result
Expand All @@ -15,16 +17,8 @@ func extractRateLimitErrorText(result interface{}) string {
logCircuitBreaker.Print("extractRateLimitErrorText: result is not a map, using default message")
return "rate limit exceeded"
}
contents, _ := m["content"].([]interface{})
logCircuitBreaker.Printf("extractRateLimitErrorText: scanning %d content items for error text", len(contents))
for _, c := range contents {
cm, ok := c.(map[string]interface{})
if !ok {
continue
}
if text, ok := cm["text"].(string); ok && text != "" {
return text
}
if text := mcp.ExtractTextContentFromResult(m); text != "" {
return text
}
logCircuitBreaker.Print("extractRateLimitErrorText: no text content found, using default message")
return "rate limit exceeded"
Expand All @@ -49,19 +43,11 @@ func isRateLimitToolResult(result interface{}) (bool, time.Time) {
return false, time.Time{}
}

contents, _ := m["content"].([]interface{})
logCircuitBreaker.Printf("Inspecting error tool result for rate limit: contentItems=%d", len(contents))
for _, c := range contents {
cm, ok := c.(map[string]interface{})
if !ok {
continue
}
text, _ := cm["text"].(string)
if isRateLimitText(text) {
resetAt := parseRateLimitResetFromText(text)
logCircuitBreaker.Printf("Rate limit detected in tool result: hasResetAt=%v", !resetAt.IsZero())
return true, resetAt
}
text := mcp.ExtractTextContentFromResult(m)
if isRateLimitText(text) {
resetAt := parseRateLimitResetFromText(text)
logCircuitBreaker.Printf("Rate limit detected in tool result: hasResetAt=%v", !resetAt.IsZero())
return true, resetAt
}
return false, time.Time{}
}
Expand Down
6 changes: 0 additions & 6 deletions internal/server/response_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,3 @@ func (w *responseWriter) Body() []byte {
logResponseWriter.Printf("Retrieving captured body: %d bytes", len(bodyBytes))
return bodyBytes
}

// StatusCode returns the captured HTTP status code
func (w *responseWriter) StatusCode() int {
logResponseWriter.Printf("Retrieving captured status code: %d", w.BaseResponseWriter.StatusCode)
return w.BaseResponseWriter.StatusCode
}
4 changes: 2 additions & 2 deletions internal/server/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func TestLoggingResponseWriter_WriteHeader(t *testing.T) {
lw.WriteHeader(tt.statusCode)

// Verify status code is captured
assert.Equal(t, tt.wantStatusCode, lw.StatusCode(), "Status code should be captured")
assert.Equal(t, tt.wantStatusCode, lw.StatusCode, "Status code should be captured")
})
}
}
Expand Down Expand Up @@ -134,7 +134,7 @@ func TestLoggingResponseWriter_DefaultStatusCode(t *testing.T) {
lw.Write([]byte("test"))

// Default status code should be 200
assert.Equal(t, http.StatusOK, lw.StatusCode(), "Default status code should be 200")
assert.Equal(t, http.StatusOK, lw.StatusCode, "Default status code should be 200")
}

// TestCreateHTTPServerForMCP_OAuth tests OAuth discovery endpoint
Expand Down
Loading