diff --git a/go/adk/pkg/agent/agent.go b/go/adk/pkg/agent/agent.go index 7b07e903e..1aaa21af4 100644 --- a/go/adk/pkg/agent/agent.go +++ b/go/adk/pkg/agent/agent.go @@ -48,7 +48,8 @@ func CreateGoogleADKAgentWithSubagentSessionIDs(ctx context.Context, agentConfig return nil, nil, fmt.Errorf("agent config is required") } - toolsets := mcp.CreateToolsets(ctx, agentConfig.HttpTools, agentConfig.SseTools) + propagateToken := strings.ToLower(os.Getenv("KAGENT_PROPAGATE_TOKEN")) == "true" + toolsets := mcp.CreateToolsets(ctx, agentConfig.HttpTools, agentConfig.SseTools, propagateToken) subagentSessionIDs := make(map[string]string) var remoteAgentTools []tool.Tool @@ -57,7 +58,7 @@ func CreateGoogleADKAgentWithSubagentSessionIDs(ctx context.Context, agentConfig log.Info("Skipping remote agent with empty URL", "name", remoteAgent.Name) continue } - remoteTool, sessionID, err := tools.NewKAgentRemoteA2ATool(remoteAgent.Name, remoteAgent.Description, remoteAgent.Url, nil, remoteAgent.Headers) + remoteTool, sessionID, err := tools.NewKAgentRemoteA2ATool(remoteAgent.Name, remoteAgent.Description, remoteAgent.Url, nil, remoteAgent.Headers, propagateToken) if err != nil { return nil, nil, fmt.Errorf("failed to create remote A2A tool for %s: %w", remoteAgent.Name, err) } diff --git a/go/adk/pkg/constants/const.go b/go/adk/pkg/constants/const.go new file mode 100644 index 000000000..2926e96e4 --- /dev/null +++ b/go/adk/pkg/constants/const.go @@ -0,0 +1,7 @@ +package constants + +const ( + // A2A call context's NewRequestMeta normalizes header names to lowercase. + // This is why we use "authorization" instead of "Authorization". + AuthorizationHeader = "authorization" +) diff --git a/go/adk/pkg/mcp/registry.go b/go/adk/pkg/mcp/registry.go index 5f31ed58a..1dd2a4d04 100644 --- a/go/adk/pkg/mcp/registry.go +++ b/go/adk/pkg/mcp/registry.go @@ -11,6 +11,7 @@ import ( "github.com/a2aproject/a2a-go/a2asrv" "github.com/go-logr/logr" + "github.com/kagent-dev/kagent/go/adk/pkg/constants" "github.com/kagent-dev/kagent/go/api/adk" mcpsdk "github.com/modelcontextprotocol/go-sdk/mcp" "google.golang.org/adk/tool" @@ -62,6 +63,7 @@ type mcpServerParams struct { URL string Headers map[string]string AllowedHeaders []string // header names to forward from incoming request + PropagateToken bool // when true, Authorization is forwarded independently of AllowedHeaders ServerType string // "http" or "sse" Timeout *float64 SseReadTimeout *float64 @@ -73,7 +75,11 @@ type mcpServerParams struct { // CreateToolsets creates toolsets from all configured HTTP and SSE MCP servers, // returning the accumulated toolsets. Errors on individual servers are logged // and skipped. -func CreateToolsets(ctx context.Context, httpTools []adk.HttpMcpServerConfig, sseTools []adk.SseMcpServerConfig) []tool.Toolset { +// +// When propagateToken is true, Authorization is forwarded to every MCP server +// independently of AllowedHeaders, mirroring the Python ADKTokenPropagationPlugin +// behaviour triggered by KAGENT_PROPAGATE_TOKEN. +func CreateToolsets(ctx context.Context, httpTools []adk.HttpMcpServerConfig, sseTools []adk.SseMcpServerConfig, propagateToken bool) []tool.Toolset { log := logr.FromContextOrDiscard(ctx) var toolsets []tool.Toolset @@ -83,6 +89,7 @@ func CreateToolsets(ctx context.Context, httpTools []adk.HttpMcpServerConfig, ss URL: httpTool.Params.Url, Headers: httpTool.Params.Headers, AllowedHeaders: httpTool.AllowedHeaders, + PropagateToken: propagateToken, ServerType: "http", Timeout: httpTool.Params.Timeout, SseReadTimeout: httpTool.Params.SseReadTimeout, @@ -103,6 +110,7 @@ func CreateToolsets(ctx context.Context, httpTools []adk.HttpMcpServerConfig, ss URL: sseTool.Params.Url, Headers: sseTool.Params.Headers, AllowedHeaders: sseTool.AllowedHeaders, + PropagateToken: propagateToken, ServerType: "sse", Timeout: sseTool.Params.Timeout, SseReadTimeout: sseTool.Params.SseReadTimeout, @@ -200,11 +208,12 @@ func createTransport(ctx context.Context, params mcpServerParams) (mcpsdk.Transp } var httpTransport http.RoundTripper = baseTransport - if len(params.Headers) > 0 || len(params.AllowedHeaders) > 0 { + if len(params.Headers) > 0 || len(params.AllowedHeaders) > 0 || params.PropagateToken { httpTransport = &headerRoundTripper{ base: baseTransport, headers: params.Headers, allowedHeaders: params.AllowedHeaders, + propagateToken: params.PropagateToken, } } @@ -230,30 +239,41 @@ func createTransport(ctx context.Context, params mcpServerParams) (mcpsdk.Transp } // headerRoundTripper wraps an http.RoundTripper to add custom headers to all -// requests. It supports two sources of headers: -// - headers: static key/value pairs configured on the MCP server spec -// - allowedHeaders: header names to forward from the incoming A2A request; -// values are read on each call via allowedRequestHeaders directly from the -// A2A CallContext that is already present in the Go context. -// -// Static headers take precedence: if an allowed header has the same name as a -// static header, the static value wins. +// requests. It supports three sources of headers, applied in this order so that +// higher-priority sources win on collision: +// 1. propagateToken: when true, Authorization is read from the incoming A2A +// CallContext and forwarded unconditionally (independent of allowedHeaders). +// 2. allowedHeaders: explicit per-header forwarding from the A2A CallContext. +// 3. headers: static key/value pairs configured on the MCP server spec (highest +// priority — always wins). type headerRoundTripper struct { base http.RoundTripper headers map[string]string allowedHeaders []string // header names (case-insensitive) to forward from A2A context + propagateToken bool // when true, Authorization is forwarded independently } func (rt *headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { req = req.Clone(req.Context()) - // Forward allowed headers from the incoming A2A request first so that - // static headers can override them if there is a name collision. + // When KAGENT_PROPAGATE_TOKEN is set, forward Authorization from the incoming + // A2A request independently of allowedHeaders. + if rt.propagateToken { + if callCtx, ok := a2asrv.CallContextFrom(req.Context()); ok { + if meta := callCtx.RequestMeta(); meta != nil { + if vals, ok := meta.Get(constants.AuthorizationHeader); ok && len(vals) > 0 && vals[0] != "" { + req.Header.Set(constants.AuthorizationHeader, vals[0]) + } + } + } + } + + // Forward explicitly allowed headers from the incoming A2A request. for k, v := range allowedRequestHeaders(req.Context(), rt.allowedHeaders) { req.Header.Set(k, v) } - // Apply static headers (override any dynamic ones with the same name). + // Apply static headers last — they take precedence over all dynamic sources. for key, value := range rt.headers { req.Header.Set(key, value) } diff --git a/go/adk/pkg/mcp/registry_test.go b/go/adk/pkg/mcp/registry_test.go index f871dd149..7a0cdc0d3 100644 --- a/go/adk/pkg/mcp/registry_test.go +++ b/go/adk/pkg/mcp/registry_test.go @@ -239,6 +239,73 @@ func TestAllowedRequestHeaders_MultiValueFirstWins(t *testing.T) { } } +// TestPropagateToken_ForwardsAuthorizationToMCP verifies that when propagateToken +// is set on headerRoundTripper, the Authorization header from the incoming A2A +// CallContext is forwarded to the outbound MCP request independently of allowedHeaders. +func TestPropagateToken_ForwardsAuthorizationToMCP(t *testing.T) { + t.Parallel() + var capturedAuth string + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedAuth = r.Header.Get("Authorization") + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + ctx := a2aCtx(map[string][]string{ + "Authorization": {"Bearer propagated-token"}, + }) + + rt := &headerRoundTripper{ + base: http.DefaultTransport, + propagateToken: true, + } + + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, srv.URL, nil) + resp, err := rt.RoundTrip(req) + if err != nil { + t.Fatalf("RoundTrip failed: %v", err) + } + resp.Body.Close() + + if capturedAuth != "Bearer propagated-token" { + t.Errorf("Authorization: got %q, want %q", capturedAuth, "Bearer propagated-token") + } +} + +// TestPropagateToken_DoesNotForwardWhenDisabled verifies that when propagateToken +// is false, the Authorization header is not forwarded unless listed in allowedHeaders. +func TestPropagateToken_DoesNotForwardWhenDisabled(t *testing.T) { + t.Parallel() + var capturedAuth string + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedAuth = r.Header.Get("Authorization") + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + ctx := a2aCtx(map[string][]string{ + "Authorization": {"Bearer propagated-token"}, + }) + + rt := &headerRoundTripper{ + base: http.DefaultTransport, + propagateToken: false, + } + + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, srv.URL, nil) + resp, err := rt.RoundTrip(req) + if err != nil { + t.Fatalf("RoundTrip failed: %v", err) + } + resp.Body.Close() + + if capturedAuth != "" { + t.Errorf("Authorization should not be forwarded when propagateToken=false, got %q", capturedAuth) + } +} + // TestAllowedRequestHeaders_ReturnsNilWhenNoMatches verifies that the helper returns // nil rather than an empty map when the allowed list has entries but none of them // appear in the request metadata. diff --git a/go/adk/pkg/tools/remote_a2a_tool.go b/go/adk/pkg/tools/remote_a2a_tool.go index 9fc79b641..87afe88a8 100644 --- a/go/adk/pkg/tools/remote_a2a_tool.go +++ b/go/adk/pkg/tools/remote_a2a_tool.go @@ -11,7 +11,9 @@ import ( a2atype "github.com/a2aproject/a2a-go/a2a" "github.com/a2aproject/a2a-go/a2aclient" "github.com/a2aproject/a2a-go/a2aclient/agentcard" + "github.com/a2aproject/a2a-go/a2asrv" "github.com/kagent-dev/kagent/go/adk/pkg/a2a" + "github.com/kagent-dev/kagent/go/adk/pkg/constants" "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" "google.golang.org/adk/tool" "google.golang.org/adk/tool/functiontool" @@ -32,6 +34,30 @@ func (u *userIDForwardingInterceptor) Before(ctx context.Context, req *a2aclient return ctx, nil } +// authzForwardingInterceptor forwards the Authorization header from the +// incoming A2A request context to outbound sub-agent A2A calls. +type authzForwardingInterceptor struct { + a2aclient.PassthroughInterceptor +} + +func (a *authzForwardingInterceptor) Before(ctx context.Context, req *a2aclient.Request) (context.Context, error) { + callCtx, ok := a2asrv.CallContextFrom(ctx) + if !ok { + return ctx, nil + } + meta := callCtx.RequestMeta() + if meta == nil { + return ctx, nil + } + if len(req.Meta.Get(constants.AuthorizationHeader)) > 0 { + return ctx, nil + } + if vals, ok := meta.Get(constants.AuthorizationHeader); ok && len(vals) > 0 && vals[0] != "" { + req.Meta.Append(constants.AuthorizationHeader, vals[0]) + } + return ctx, nil +} + // remoteA2AInput is the typed argument for the remote A2A function tool. type remoteA2AInput struct { Request string `json:"request"` @@ -40,11 +66,12 @@ type remoteA2AInput struct { // remoteA2AState holds the mutable state for one remote A2A agent connection. // All external interaction goes through the tool.Tool returned by NewKAgentRemoteA2ATool. type remoteA2AState struct { - name string - description string - baseURL string - httpClient *http.Client - extraHeaders map[string]string + name string + description string + baseURL string + httpClient *http.Client + extraHeaders map[string]string + propagateToken bool a2aClient *a2aclient.Client agentCard *a2atype.AgentCard @@ -62,18 +89,19 @@ type remoteA2AState struct { // The agent card is fetched lazily from baseURL/.well-known/agent.json. // If httpClient is nil, a default client is created. The client's transport is // wrapped with otelhttp to propagate W3C trace context to subagents. -func NewKAgentRemoteA2ATool(name, description, baseURL string, httpClient *http.Client, extraHeaders map[string]string) (tool.Tool, string, error) { +func NewKAgentRemoteA2ATool(name, description, baseURL string, httpClient *http.Client, extraHeaders map[string]string, propagateToken bool) (tool.Tool, string, error) { if httpClient == nil { httpClient = &http.Client{} } httpClient = withOTelTransport(httpClient) state := &remoteA2AState{ - name: name, - description: description, - baseURL: baseURL, - httpClient: httpClient, - extraHeaders: extraHeaders, - lastContextID: a2atype.NewContextID(), + name: name, + description: description, + baseURL: baseURL, + httpClient: httpClient, + extraHeaders: extraHeaders, + propagateToken: propagateToken, + lastContextID: a2atype.NewContextID(), } ft, err := functiontool.New(functiontool.Config{ Name: name, @@ -119,10 +147,14 @@ func (s *remoteA2AState) ensureClient(ctx context.Context) (*a2aclient.Client, e for k, v := range s.extraHeaders { meta.Append(k, v) } - opts = append(opts, a2aclient.WithInterceptors( + interceptors := []a2aclient.CallInterceptor{ a2aclient.NewStaticCallMetaInjector(meta), &userIDForwardingInterceptor{}, - )) + } + if s.propagateToken { + interceptors = append(interceptors, &authzForwardingInterceptor{}) + } + opts = append(opts, a2aclient.WithInterceptors(interceptors...)) client, err := a2aclient.NewFromCard(ctx, card, opts...) if err != nil { diff --git a/python/packages/kagent-adk/src/kagent/adk/_remote_a2a_tool.py b/python/packages/kagent-adk/src/kagent/adk/_remote_a2a_tool.py index ff73d0c6d..3cbaabcb2 100644 --- a/python/packages/kagent-adk/src/kagent/adk/_remote_a2a_tool.py +++ b/python/packages/kagent-adk/src/kagent/adk/_remote_a2a_tool.py @@ -13,7 +13,7 @@ import logging import uuid -from typing import Any, Optional, Protocol, runtime_checkable +from typing import Any, Callable, Optional, Protocol, runtime_checkable from urllib.parse import urlparse import httpx @@ -58,21 +58,28 @@ _USER_ID_CONTEXT_KEY = "x-user-id" _SOURCE_HEADER = "x-kagent-source" _SOURCE_SUBAGENT = "agent" +_HEADERS_STATE_KEY = "headers" +_EXTRA_HEADERS_CONTEXT_KEY = "_a2a_extra_headers" class _SubagentInterceptor(ClientCallInterceptor): """ - Injects the authenticated user's ID as an ``x-user-id`` HTTP header and + Injects the authenticated user's ID as an ``x-user-id`` HTTP header, marks the request as originating from an agent call via - ``x-kagent-source: agent`` on every outgoing A2A request. + ``x-kagent-source: agent``, and forwards any pre-computed propagation + headers stored in the call context state under ``_EXTRA_HEADERS_CONTEXT_KEY``. """ async def intercept(self, method_name, request_payload, http_kwargs, agent_card, context): headers = dict(http_kwargs.get("headers", {})) - # Always mark requests from a parent agent tool as subagent-originated headers[_SOURCE_HEADER] = _SOURCE_SUBAGENT - if context and _USER_ID_CONTEXT_KEY in context.state: - headers["x-user-id"] = context.state[_USER_ID_CONTEXT_KEY] + + if context: + if _USER_ID_CONTEXT_KEY in context.state: + headers["x-user-id"] = context.state[_USER_ID_CONTEXT_KEY] + extra = context.state.get(_EXTRA_HEADERS_CONTEXT_KEY) + if extra: + headers.update(extra) http_kwargs["headers"] = headers return request_payload, http_kwargs @@ -140,10 +147,12 @@ def __init__( description: str, agent_card_url: str, httpx_client: Optional[httpx.AsyncClient] = None, + header_provider: Optional[Callable[[Optional[ReadonlyContext]], dict[str, str]]] = None, ) -> None: super().__init__(name=name, description=description) self._agent_card_url = agent_card_url self._httpx_client = httpx_client + self._header_provider = header_provider self._a2a_client: Optional[A2AClient] = None self._agent_card: Optional[AgentCard] = None # Pre-generate context_id for UI session polling @@ -206,6 +215,14 @@ def _get_declaration(self) -> genai_types.FunctionDeclaration: ), ) + def _build_call_context(self, tool_context: ToolContext) -> ClientCallContext: + state: dict[str, Any] = {_USER_ID_CONTEXT_KEY: tool_context.session.user_id} + if self._header_provider: + extra_headers = self._header_provider(tool_context) + if extra_headers: + state[_EXTRA_HEADERS_CONTEXT_KEY] = extra_headers + return ClientCallContext(state=state) + async def run_async(self, *, args: dict[str, Any], tool_context: ToolContext) -> Any: """Execute the remote agent tool. @@ -239,7 +256,7 @@ async def _handle_first_call(self, args: dict[str, Any], tool_context: ToolConte # Forward the authenticated user ID so the subagent session is scoped # to the same user as the parent agent session. - call_context = ClientCallContext(state={_USER_ID_CONTEXT_KEY: tool_context.session.user_id}) + call_context = self._build_call_context(tool_context) task: Optional[Task] = None try: @@ -381,7 +398,7 @@ async def _handle_resume(self, tool_context: ToolContext) -> Any: ) client = await self._ensure_client() - call_context = ClientCallContext(state={_USER_ID_CONTEXT_KEY: tool_context.session.user_id}) + call_context = self._build_call_context(tool_context) task: Optional[Task] = None try: async for response in client.send_message(request=decision_message, context=call_context): @@ -449,6 +466,7 @@ def __init__( description: str, agent_card_url: str, httpx_client: httpx.AsyncClient, + header_provider: Optional[Callable[[Optional[ReadonlyContext]], dict[str, str]]] = None, ) -> None: super().__init__() self._httpx_client = httpx_client @@ -457,6 +475,7 @@ def __init__( description=description, agent_card_url=agent_card_url, httpx_client=httpx_client, + header_provider=header_provider, ) @property diff --git a/python/packages/kagent-adk/src/kagent/adk/cli.py b/python/packages/kagent-adk/src/kagent/adk/cli.py index b73f4ce32..e32d0aacb 100644 --- a/python/packages/kagent-adk/src/kagent/adk/cli.py +++ b/python/packages/kagent-adk/src/kagent/adk/cli.py @@ -24,7 +24,7 @@ kagent_url_override = os.getenv("KAGENT_URL") sts_well_known_uri = os.getenv("STS_WELL_KNOWN_URI") -propagate_token = os.getenv("KAGENT_PROPAGATE_TOKEN") +propagate_token = os.getenv("KAGENT_PROPAGATE_TOKEN", "").lower() == "true" uvicorn_log_level = os.getenv("UVICORN_LOG_LEVEL", os.getenv("LOG_LEVEL", "info")).lower() @@ -79,7 +79,7 @@ def static( plugins.append(LLMPassthroughPlugin()) def root_agent_factory() -> BaseAgent: - root_agent = agent_config.to_agent(app_cfg.name, sts_integration) + root_agent = agent_config.to_agent(app_cfg.name, sts_integration, propagate_token) maybe_add_skills_with_config(root_agent, agent_config) @@ -218,7 +218,7 @@ async def test_agent(agent_config: AgentConfig, agent_card: AgentCard, task: str plugins = [sts_integration] def root_agent_factory() -> BaseAgent: - root_agent = agent_config.to_agent(app_cfg.name, sts_integration) + root_agent = agent_config.to_agent(app_cfg.name, sts_integration, propagate_token) maybe_add_skills_with_config(root_agent, agent_config) return root_agent diff --git a/python/packages/kagent-adk/src/kagent/adk/types.py b/python/packages/kagent-adk/src/kagent/adk/types.py index 96c9ee6f0..5e2f4a97a 100644 --- a/python/packages/kagent-adk/src/kagent/adk/types.py +++ b/python/packages/kagent-adk/src/kagent/adk/types.py @@ -298,7 +298,9 @@ class AgentConfig(BaseModel): network: NetworkConfig | None = None context_config: ContextConfig | None = None - def to_agent(self, name: str, sts_integration: Optional[ADKTokenPropagationPlugin] = None) -> Agent: + def to_agent( + self, name: str, sts_integration: Optional[ADKTokenPropagationPlugin] = None, propagate_token: bool = False + ) -> Agent: if name is None or not str(name).strip(): raise ValueError("Agent name must be a non-empty string.") tools: list[ToolUnion] = [] @@ -400,12 +402,16 @@ async def rewrite_url_to_proxy(request: httpx.Request) -> None: timeout=timeout, ) + a2a_header_provider = None + if propagate_token: + a2a_header_provider = create_header_provider(allowed_headers=["authorization"]) tools.append( KAgentRemoteA2AToolset( name=remote_agent.name, description=remote_agent.description, agent_card_url=f"{remote_agent.url}{AGENT_CARD_WELL_KNOWN_PATH}", httpx_client=client, + header_provider=a2a_header_provider, ) ) diff --git a/python/packages/kagent-adk/tests/unittests/test_remote_a2a_tool.py b/python/packages/kagent-adk/tests/unittests/test_remote_a2a_tool.py index 3185ad8e2..8682741bb 100644 --- a/python/packages/kagent-adk/tests/unittests/test_remote_a2a_tool.py +++ b/python/packages/kagent-adk/tests/unittests/test_remote_a2a_tool.py @@ -26,6 +26,7 @@ KAgentRemoteA2ATool, KAgentRemoteA2AToolset, SubagentSessionProvider, + _SubagentInterceptor, ) # --------------------------------------------------------------------------- @@ -139,6 +140,44 @@ def _approval_ctx(confirmed: bool, payload: dict | None = None, **kwargs) -> Moc return MockToolContext(tool_confirmation=confirmation, **kwargs) +# --------------------------------------------------------------------------- +# _SubagentInterceptor header propagation tests +# --------------------------------------------------------------------------- + + +class TestSubagentInterceptorHeaderPropagation: + """Tests for header propagation in _SubagentInterceptor via context state.""" + + async def _call_intercept(self, interceptor, state: dict) -> dict: + from a2a.client.middleware import ClientCallContext + + ctx = ClientCallContext(state=state) + _, http_kwargs = await interceptor.intercept( + method_name="message/send", + request_payload={}, + http_kwargs={}, + agent_card=None, + context=ctx, + ) + return http_kwargs.get("headers", {}) + + async def test_forwards_extra_headers_from_context_state(self): + interceptor = _SubagentInterceptor() + headers = await self._call_intercept( + interceptor, + state={"x-user-id": "user1", "_a2a_extra_headers": {"authorization": "Bearer test-jwt"}}, + ) + assert headers.get("authorization") == "Bearer test-jwt" + + async def test_no_extra_headers_without_state_key(self): + interceptor = _SubagentInterceptor() + headers = await self._call_intercept( + interceptor, + state={"x-user-id": "user1", "authorization": "Bearer test-jwt"}, + ) + assert "authorization" not in headers + + # --------------------------------------------------------------------------- # First-call tests # ---------------------------------------------------------------------------