Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions go/adk/pkg/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ func CreateGoogleADKAgentWithSubagentSessionIDs(ctx context.Context, agentConfig
return nil, nil, fmt.Errorf("agent config is required")
}

toolsets := mcp.CreateToolsets(ctx, agentConfig.HttpTools, agentConfig.SseTools)
propagateToken := strings.ToLower(os.Getenv("KAGENT_PROPAGATE_TOKEN")) == "true"
toolsets := mcp.CreateToolsets(ctx, agentConfig.HttpTools, agentConfig.SseTools, propagateToken)
subagentSessionIDs := make(map[string]string)

var remoteAgentTools []tool.Tool
Expand All @@ -57,7 +58,7 @@ func CreateGoogleADKAgentWithSubagentSessionIDs(ctx context.Context, agentConfig
log.Info("Skipping remote agent with empty URL", "name", remoteAgent.Name)
continue
}
remoteTool, sessionID, err := tools.NewKAgentRemoteA2ATool(remoteAgent.Name, remoteAgent.Description, remoteAgent.Url, nil, remoteAgent.Headers)
remoteTool, sessionID, err := tools.NewKAgentRemoteA2ATool(remoteAgent.Name, remoteAgent.Description, remoteAgent.Url, nil, remoteAgent.Headers, propagateToken)
if err != nil {
return nil, nil, fmt.Errorf("failed to create remote A2A tool for %s: %w", remoteAgent.Name, err)
}
Expand Down
7 changes: 7 additions & 0 deletions go/adk/pkg/constants/const.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package constants

const (
// A2A call context's NewRequestMeta normalizes header names to lowercase.
// This is why we use "authorization" instead of "Authorization".
AuthorizationHeader = "authorization"
)
46 changes: 33 additions & 13 deletions go/adk/pkg/mcp/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

"github.com/a2aproject/a2a-go/a2asrv"
"github.com/go-logr/logr"
"github.com/kagent-dev/kagent/go/adk/pkg/constants"
"github.com/kagent-dev/kagent/go/api/adk"
mcpsdk "github.com/modelcontextprotocol/go-sdk/mcp"
"google.golang.org/adk/tool"
Expand Down Expand Up @@ -62,6 +63,7 @@ type mcpServerParams struct {
URL string
Headers map[string]string
AllowedHeaders []string // header names to forward from incoming request
PropagateToken bool // when true, Authorization is forwarded independently of AllowedHeaders
ServerType string // "http" or "sse"
Timeout *float64
SseReadTimeout *float64
Expand All @@ -73,7 +75,11 @@ type mcpServerParams struct {
// CreateToolsets creates toolsets from all configured HTTP and SSE MCP servers,
// returning the accumulated toolsets. Errors on individual servers are logged
// and skipped.
func CreateToolsets(ctx context.Context, httpTools []adk.HttpMcpServerConfig, sseTools []adk.SseMcpServerConfig) []tool.Toolset {
//
// When propagateToken is true, Authorization is forwarded to every MCP server
// independently of AllowedHeaders, mirroring the Python ADKTokenPropagationPlugin
// behaviour triggered by KAGENT_PROPAGATE_TOKEN.
func CreateToolsets(ctx context.Context, httpTools []adk.HttpMcpServerConfig, sseTools []adk.SseMcpServerConfig, propagateToken bool) []tool.Toolset {
log := logr.FromContextOrDiscard(ctx)
var toolsets []tool.Toolset

Expand All @@ -83,6 +89,7 @@ func CreateToolsets(ctx context.Context, httpTools []adk.HttpMcpServerConfig, ss
URL: httpTool.Params.Url,
Headers: httpTool.Params.Headers,
AllowedHeaders: httpTool.AllowedHeaders,
PropagateToken: propagateToken,
ServerType: "http",
Timeout: httpTool.Params.Timeout,
SseReadTimeout: httpTool.Params.SseReadTimeout,
Expand All @@ -103,6 +110,7 @@ func CreateToolsets(ctx context.Context, httpTools []adk.HttpMcpServerConfig, ss
URL: sseTool.Params.Url,
Headers: sseTool.Params.Headers,
AllowedHeaders: sseTool.AllowedHeaders,
PropagateToken: propagateToken,
ServerType: "sse",
Timeout: sseTool.Params.Timeout,
SseReadTimeout: sseTool.Params.SseReadTimeout,
Expand Down Expand Up @@ -200,11 +208,12 @@ func createTransport(ctx context.Context, params mcpServerParams) (mcpsdk.Transp
}

var httpTransport http.RoundTripper = baseTransport
if len(params.Headers) > 0 || len(params.AllowedHeaders) > 0 {
if len(params.Headers) > 0 || len(params.AllowedHeaders) > 0 || params.PropagateToken {
httpTransport = &headerRoundTripper{
base: baseTransport,
headers: params.Headers,
allowedHeaders: params.AllowedHeaders,
propagateToken: params.PropagateToken,
}
}

Expand All @@ -230,30 +239,41 @@ func createTransport(ctx context.Context, params mcpServerParams) (mcpsdk.Transp
}

// headerRoundTripper wraps an http.RoundTripper to add custom headers to all
// requests. It supports two sources of headers:
// - headers: static key/value pairs configured on the MCP server spec
// - allowedHeaders: header names to forward from the incoming A2A request;
// values are read on each call via allowedRequestHeaders directly from the
// A2A CallContext that is already present in the Go context.
//
// Static headers take precedence: if an allowed header has the same name as a
// static header, the static value wins.
// requests. It supports three sources of headers, applied in this order so that
// higher-priority sources win on collision:
// 1. propagateToken: when true, Authorization is read from the incoming A2A
// CallContext and forwarded unconditionally (independent of allowedHeaders).
// 2. allowedHeaders: explicit per-header forwarding from the A2A CallContext.
// 3. headers: static key/value pairs configured on the MCP server spec (highest
// priority — always wins).
type headerRoundTripper struct {
base http.RoundTripper
headers map[string]string
allowedHeaders []string // header names (case-insensitive) to forward from A2A context
propagateToken bool // when true, Authorization is forwarded independently
}

func (rt *headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
req = req.Clone(req.Context())

// Forward allowed headers from the incoming A2A request first so that
// static headers can override them if there is a name collision.
// When KAGENT_PROPAGATE_TOKEN is set, forward Authorization from the incoming
// A2A request independently of allowedHeaders.
if rt.propagateToken {
if callCtx, ok := a2asrv.CallContextFrom(req.Context()); ok {
if meta := callCtx.RequestMeta(); meta != nil {
if vals, ok := meta.Get(constants.AuthorizationHeader); ok && len(vals) > 0 && vals[0] != "" {
req.Header.Set(constants.AuthorizationHeader, vals[0])
}
}
}
}
Comment thread
jmhbh marked this conversation as resolved.

// Forward explicitly allowed headers from the incoming A2A request.
for k, v := range allowedRequestHeaders(req.Context(), rt.allowedHeaders) {
req.Header.Set(k, v)
}

// Apply static headers (override any dynamic ones with the same name).
// Apply static headers last — they take precedence over all dynamic sources.
for key, value := range rt.headers {
req.Header.Set(key, value)
}
Expand Down
67 changes: 67 additions & 0 deletions go/adk/pkg/mcp/registry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,73 @@ func TestAllowedRequestHeaders_MultiValueFirstWins(t *testing.T) {
}
}

// TestPropagateToken_ForwardsAuthorizationToMCP verifies that when propagateToken
// is set on headerRoundTripper, the Authorization header from the incoming A2A
// CallContext is forwarded to the outbound MCP request independently of allowedHeaders.
func TestPropagateToken_ForwardsAuthorizationToMCP(t *testing.T) {
t.Parallel()
var capturedAuth string

srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedAuth = r.Header.Get("Authorization")
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()

ctx := a2aCtx(map[string][]string{
"Authorization": {"Bearer propagated-token"},
})

rt := &headerRoundTripper{
base: http.DefaultTransport,
propagateToken: true,
}

req, _ := http.NewRequestWithContext(ctx, http.MethodGet, srv.URL, nil)
resp, err := rt.RoundTrip(req)
if err != nil {
t.Fatalf("RoundTrip failed: %v", err)
}
resp.Body.Close()

if capturedAuth != "Bearer propagated-token" {
t.Errorf("Authorization: got %q, want %q", capturedAuth, "Bearer propagated-token")
}
}

// TestPropagateToken_DoesNotForwardWhenDisabled verifies that when propagateToken
// is false, the Authorization header is not forwarded unless listed in allowedHeaders.
func TestPropagateToken_DoesNotForwardWhenDisabled(t *testing.T) {
t.Parallel()
var capturedAuth string

srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedAuth = r.Header.Get("Authorization")
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()

ctx := a2aCtx(map[string][]string{
"Authorization": {"Bearer propagated-token"},
})

rt := &headerRoundTripper{
base: http.DefaultTransport,
propagateToken: false,
}

req, _ := http.NewRequestWithContext(ctx, http.MethodGet, srv.URL, nil)
resp, err := rt.RoundTrip(req)
if err != nil {
t.Fatalf("RoundTrip failed: %v", err)
}
resp.Body.Close()

if capturedAuth != "" {
t.Errorf("Authorization should not be forwarded when propagateToken=false, got %q", capturedAuth)
}
}

// TestAllowedRequestHeaders_ReturnsNilWhenNoMatches verifies that the helper returns
// nil rather than an empty map when the allowed list has entries but none of them
// appear in the request metadata.
Expand Down
60 changes: 46 additions & 14 deletions go/adk/pkg/tools/remote_a2a_tool.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ import (
a2atype "github.com/a2aproject/a2a-go/a2a"
"github.com/a2aproject/a2a-go/a2aclient"
"github.com/a2aproject/a2a-go/a2aclient/agentcard"
"github.com/a2aproject/a2a-go/a2asrv"
"github.com/kagent-dev/kagent/go/adk/pkg/a2a"
"github.com/kagent-dev/kagent/go/adk/pkg/constants"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
"google.golang.org/adk/tool"
"google.golang.org/adk/tool/functiontool"
Expand All @@ -32,6 +34,30 @@ func (u *userIDForwardingInterceptor) Before(ctx context.Context, req *a2aclient
return ctx, nil
}

// authzForwardingInterceptor forwards the Authorization header from the
// incoming A2A request context to outbound sub-agent A2A calls.
type authzForwardingInterceptor struct {
a2aclient.PassthroughInterceptor
}

func (a *authzForwardingInterceptor) Before(ctx context.Context, req *a2aclient.Request) (context.Context, error) {
callCtx, ok := a2asrv.CallContextFrom(ctx)
if !ok {
return ctx, nil
}
meta := callCtx.RequestMeta()
if meta == nil {
return ctx, nil
}
if len(req.Meta.Get(constants.AuthorizationHeader)) > 0 {
return ctx, nil
}
if vals, ok := meta.Get(constants.AuthorizationHeader); ok && len(vals) > 0 && vals[0] != "" {
req.Meta.Append(constants.AuthorizationHeader, vals[0])
Comment thread
jmhbh marked this conversation as resolved.
}
return ctx, nil
}

// remoteA2AInput is the typed argument for the remote A2A function tool.
type remoteA2AInput struct {
Request string `json:"request"`
Expand All @@ -40,11 +66,12 @@ type remoteA2AInput struct {
// remoteA2AState holds the mutable state for one remote A2A agent connection.
// All external interaction goes through the tool.Tool returned by NewKAgentRemoteA2ATool.
type remoteA2AState struct {
name string
description string
baseURL string
httpClient *http.Client
extraHeaders map[string]string
name string
description string
baseURL string
httpClient *http.Client
extraHeaders map[string]string
propagateToken bool

a2aClient *a2aclient.Client
agentCard *a2atype.AgentCard
Expand All @@ -62,18 +89,19 @@ type remoteA2AState struct {
// The agent card is fetched lazily from baseURL/.well-known/agent.json.
// If httpClient is nil, a default client is created. The client's transport is
// wrapped with otelhttp to propagate W3C trace context to subagents.
func NewKAgentRemoteA2ATool(name, description, baseURL string, httpClient *http.Client, extraHeaders map[string]string) (tool.Tool, string, error) {
func NewKAgentRemoteA2ATool(name, description, baseURL string, httpClient *http.Client, extraHeaders map[string]string, propagateToken bool) (tool.Tool, string, error) {
if httpClient == nil {
httpClient = &http.Client{}
}
httpClient = withOTelTransport(httpClient)
state := &remoteA2AState{
name: name,
description: description,
baseURL: baseURL,
httpClient: httpClient,
extraHeaders: extraHeaders,
lastContextID: a2atype.NewContextID(),
name: name,
description: description,
baseURL: baseURL,
httpClient: httpClient,
extraHeaders: extraHeaders,
propagateToken: propagateToken,
lastContextID: a2atype.NewContextID(),
}
ft, err := functiontool.New(functiontool.Config{
Name: name,
Expand Down Expand Up @@ -119,10 +147,14 @@ func (s *remoteA2AState) ensureClient(ctx context.Context) (*a2aclient.Client, e
for k, v := range s.extraHeaders {
meta.Append(k, v)
}
opts = append(opts, a2aclient.WithInterceptors(
interceptors := []a2aclient.CallInterceptor{
a2aclient.NewStaticCallMetaInjector(meta),
&userIDForwardingInterceptor{},
))
}
if s.propagateToken {
interceptors = append(interceptors, &authzForwardingInterceptor{})
}
opts = append(opts, a2aclient.WithInterceptors(interceptors...))

client, err := a2aclient.NewFromCard(ctx, card, opts...)
if err != nil {
Expand Down
Loading
Loading