From e45c7b194ff4fcd8080db131947e86ac773e30d8 Mon Sep 17 00:00:00 2001 From: JM Huibonhoa Date: Tue, 12 May 2026 17:38:29 -0400 Subject: [PATCH 01/10] fix: KAGENT_PROPAGATE_TOKEN now propagates access token subagents and mcp tools for both the go and python adk runtimes Signed-off-by: JM Huibonhoa --- go/adk/pkg/agent/agent.go | 5 +- go/adk/pkg/mcp/registry.go | 48 +++++++++---- go/adk/pkg/mcp/registry_test.go | 67 +++++++++++++++++++ go/adk/pkg/tools/remote_a2a_tool.go | 58 ++++++++++++---- .../src/kagent/adk/_remote_a2a_tool.py | 43 ++++++++++-- .../tests/unittests/test_remote_a2a_tool.py | 39 +++++++++++ 6 files changed, 226 insertions(+), 34 deletions(-) diff --git a/go/adk/pkg/agent/agent.go b/go/adk/pkg/agent/agent.go index 7b07e903e..f779648e6 100644 --- a/go/adk/pkg/agent/agent.go +++ b/go/adk/pkg/agent/agent.go @@ -48,7 +48,8 @@ func CreateGoogleADKAgentWithSubagentSessionIDs(ctx context.Context, agentConfig return nil, nil, fmt.Errorf("agent config is required") } - toolsets := mcp.CreateToolsets(ctx, agentConfig.HttpTools, agentConfig.SseTools) + propagateToken := os.Getenv("KAGENT_PROPAGATE_TOKEN") != "" + toolsets := mcp.CreateToolsets(ctx, agentConfig.HttpTools, agentConfig.SseTools, propagateToken) subagentSessionIDs := make(map[string]string) var remoteAgentTools []tool.Tool @@ -57,7 +58,7 @@ func CreateGoogleADKAgentWithSubagentSessionIDs(ctx context.Context, agentConfig log.Info("Skipping remote agent with empty URL", "name", remoteAgent.Name) continue } - remoteTool, sessionID, err := tools.NewKAgentRemoteA2ATool(remoteAgent.Name, remoteAgent.Description, remoteAgent.Url, nil, remoteAgent.Headers) + remoteTool, sessionID, err := tools.NewKAgentRemoteA2ATool(remoteAgent.Name, remoteAgent.Description, remoteAgent.Url, nil, remoteAgent.Headers, propagateToken) if err != nil { return nil, nil, fmt.Errorf("failed to create remote A2A tool for %s: %w", remoteAgent.Name, err) } diff --git a/go/adk/pkg/mcp/registry.go b/go/adk/pkg/mcp/registry.go index 5f31ed58a..0aa140922 100644 --- a/go/adk/pkg/mcp/registry.go +++ b/go/adk/pkg/mcp/registry.go @@ -62,6 +62,7 @@ type mcpServerParams struct { URL string Headers map[string]string AllowedHeaders []string // header names to forward from incoming request + PropagateToken bool // when true, Authorization is forwarded independently of AllowedHeaders ServerType string // "http" or "sse" Timeout *float64 SseReadTimeout *float64 @@ -73,7 +74,11 @@ type mcpServerParams struct { // CreateToolsets creates toolsets from all configured HTTP and SSE MCP servers, // returning the accumulated toolsets. Errors on individual servers are logged // and skipped. -func CreateToolsets(ctx context.Context, httpTools []adk.HttpMcpServerConfig, sseTools []adk.SseMcpServerConfig) []tool.Toolset { +// +// When propagateToken is true, Authorization is forwarded to every MCP server +// independently of AllowedHeaders, mirroring the Python ADKTokenPropagationPlugin +// behaviour triggered by KAGENT_PROPAGATE_TOKEN. +func CreateToolsets(ctx context.Context, httpTools []adk.HttpMcpServerConfig, sseTools []adk.SseMcpServerConfig, propagateToken bool) []tool.Toolset { log := logr.FromContextOrDiscard(ctx) var toolsets []tool.Toolset @@ -83,6 +88,7 @@ func CreateToolsets(ctx context.Context, httpTools []adk.HttpMcpServerConfig, ss URL: httpTool.Params.Url, Headers: httpTool.Params.Headers, AllowedHeaders: httpTool.AllowedHeaders, + PropagateToken: propagateToken, ServerType: "http", Timeout: httpTool.Params.Timeout, SseReadTimeout: httpTool.Params.SseReadTimeout, @@ -103,6 +109,7 @@ func CreateToolsets(ctx context.Context, httpTools []adk.HttpMcpServerConfig, ss URL: sseTool.Params.Url, Headers: sseTool.Params.Headers, AllowedHeaders: sseTool.AllowedHeaders, + PropagateToken: propagateToken, ServerType: "sse", Timeout: sseTool.Params.Timeout, SseReadTimeout: sseTool.Params.SseReadTimeout, @@ -200,11 +207,13 @@ func createTransport(ctx context.Context, params mcpServerParams) (mcpsdk.Transp } var httpTransport http.RoundTripper = baseTransport - if len(params.Headers) > 0 || len(params.AllowedHeaders) > 0 { + if len(params.Headers) > 0 || len(params.AllowedHeaders) > 0 || params.PropagateToken { httpTransport = &headerRoundTripper{ base: baseTransport, headers: params.Headers, allowedHeaders: params.AllowedHeaders, + propagateToken: params.PropagateToken, + log: log, } } @@ -230,30 +239,43 @@ func createTransport(ctx context.Context, params mcpServerParams) (mcpsdk.Transp } // headerRoundTripper wraps an http.RoundTripper to add custom headers to all -// requests. It supports two sources of headers: -// - headers: static key/value pairs configured on the MCP server spec -// - allowedHeaders: header names to forward from the incoming A2A request; -// values are read on each call via allowedRequestHeaders directly from the -// A2A CallContext that is already present in the Go context. -// -// Static headers take precedence: if an allowed header has the same name as a -// static header, the static value wins. +// requests. It supports three sources of headers, applied in this order so that +// higher-priority sources win on collision: +// 1. propagateToken: when true, Authorization is read from the incoming A2A +// CallContext and forwarded unconditionally (independent of allowedHeaders). +// 2. allowedHeaders: explicit per-header forwarding from the A2A CallContext. +// 3. headers: static key/value pairs configured on the MCP server spec (highest +// priority — always wins). type headerRoundTripper struct { base http.RoundTripper headers map[string]string allowedHeaders []string // header names (case-insensitive) to forward from A2A context + propagateToken bool // when true, Authorization is forwarded independently + log logr.Logger } func (rt *headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { req = req.Clone(req.Context()) - // Forward allowed headers from the incoming A2A request first so that - // static headers can override them if there is a name collision. + // When KAGENT_PROPAGATE_TOKEN is set, forward Authorization from the incoming + // A2A request independently of allowedHeaders. + if rt.propagateToken { + if callCtx, ok := a2asrv.CallContextFrom(req.Context()); ok { + if meta := callCtx.RequestMeta(); meta != nil { + if vals, ok := meta.Get("authorization"); ok && len(vals) > 0 && vals[0] != "" { + req.Header.Set("authorization", vals[0]) + rt.log.Info("forwarding authorization header to MCP server", "url", req.URL.String()) + } + } + } + } + + // Forward explicitly allowed headers from the incoming A2A request. for k, v := range allowedRequestHeaders(req.Context(), rt.allowedHeaders) { req.Header.Set(k, v) } - // Apply static headers (override any dynamic ones with the same name). + // Apply static headers last — they take precedence over all dynamic sources. for key, value := range rt.headers { req.Header.Set(key, value) } diff --git a/go/adk/pkg/mcp/registry_test.go b/go/adk/pkg/mcp/registry_test.go index f871dd149..7a0cdc0d3 100644 --- a/go/adk/pkg/mcp/registry_test.go +++ b/go/adk/pkg/mcp/registry_test.go @@ -239,6 +239,73 @@ func TestAllowedRequestHeaders_MultiValueFirstWins(t *testing.T) { } } +// TestPropagateToken_ForwardsAuthorizationToMCP verifies that when propagateToken +// is set on headerRoundTripper, the Authorization header from the incoming A2A +// CallContext is forwarded to the outbound MCP request independently of allowedHeaders. +func TestPropagateToken_ForwardsAuthorizationToMCP(t *testing.T) { + t.Parallel() + var capturedAuth string + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedAuth = r.Header.Get("Authorization") + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + ctx := a2aCtx(map[string][]string{ + "Authorization": {"Bearer propagated-token"}, + }) + + rt := &headerRoundTripper{ + base: http.DefaultTransport, + propagateToken: true, + } + + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, srv.URL, nil) + resp, err := rt.RoundTrip(req) + if err != nil { + t.Fatalf("RoundTrip failed: %v", err) + } + resp.Body.Close() + + if capturedAuth != "Bearer propagated-token" { + t.Errorf("Authorization: got %q, want %q", capturedAuth, "Bearer propagated-token") + } +} + +// TestPropagateToken_DoesNotForwardWhenDisabled verifies that when propagateToken +// is false, the Authorization header is not forwarded unless listed in allowedHeaders. +func TestPropagateToken_DoesNotForwardWhenDisabled(t *testing.T) { + t.Parallel() + var capturedAuth string + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedAuth = r.Header.Get("Authorization") + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + ctx := a2aCtx(map[string][]string{ + "Authorization": {"Bearer propagated-token"}, + }) + + rt := &headerRoundTripper{ + base: http.DefaultTransport, + propagateToken: false, + } + + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, srv.URL, nil) + resp, err := rt.RoundTrip(req) + if err != nil { + t.Fatalf("RoundTrip failed: %v", err) + } + resp.Body.Close() + + if capturedAuth != "" { + t.Errorf("Authorization should not be forwarded when propagateToken=false, got %q", capturedAuth) + } +} + // TestAllowedRequestHeaders_ReturnsNilWhenNoMatches verifies that the helper returns // nil rather than an empty map when the allowed list has entries but none of them // appear in the request metadata. diff --git a/go/adk/pkg/tools/remote_a2a_tool.go b/go/adk/pkg/tools/remote_a2a_tool.go index 9fc79b641..f0c448a1a 100644 --- a/go/adk/pkg/tools/remote_a2a_tool.go +++ b/go/adk/pkg/tools/remote_a2a_tool.go @@ -11,6 +11,7 @@ import ( a2atype "github.com/a2aproject/a2a-go/a2a" "github.com/a2aproject/a2a-go/a2aclient" "github.com/a2aproject/a2a-go/a2aclient/agentcard" + "github.com/a2aproject/a2a-go/a2asrv" "github.com/kagent-dev/kagent/go/adk/pkg/a2a" "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" "google.golang.org/adk/tool" @@ -32,6 +33,29 @@ func (u *userIDForwardingInterceptor) Before(ctx context.Context, req *a2aclient return ctx, nil } +// authzForwardingInterceptor forwards the Authorization header from the +// incoming A2A request context to outbound sub-agent A2A calls. +type authzForwardingInterceptor struct { + a2aclient.PassthroughInterceptor + name string +} + +func (a *authzForwardingInterceptor) Before(ctx context.Context, req *a2aclient.Request) (context.Context, error) { + callCtx, ok := a2asrv.CallContextFrom(ctx) + if !ok { + return ctx, nil + } + meta := callCtx.RequestMeta() + if meta == nil { + return ctx, nil + } + if vals, ok := meta.Get("authorization"); ok && len(vals) > 0 && vals[0] != "" { + req.Meta.Append("authorization", vals[0]) + slog.Info("forwarding authorization header to sub-agent A2A call", "tool", a.name) + } + return ctx, nil +} + // remoteA2AInput is the typed argument for the remote A2A function tool. type remoteA2AInput struct { Request string `json:"request"` @@ -40,11 +64,12 @@ type remoteA2AInput struct { // remoteA2AState holds the mutable state for one remote A2A agent connection. // All external interaction goes through the tool.Tool returned by NewKAgentRemoteA2ATool. type remoteA2AState struct { - name string - description string - baseURL string - httpClient *http.Client - extraHeaders map[string]string + name string + description string + baseURL string + httpClient *http.Client + extraHeaders map[string]string + propagateToken bool a2aClient *a2aclient.Client agentCard *a2atype.AgentCard @@ -62,18 +87,19 @@ type remoteA2AState struct { // The agent card is fetched lazily from baseURL/.well-known/agent.json. // If httpClient is nil, a default client is created. The client's transport is // wrapped with otelhttp to propagate W3C trace context to subagents. -func NewKAgentRemoteA2ATool(name, description, baseURL string, httpClient *http.Client, extraHeaders map[string]string) (tool.Tool, string, error) { +func NewKAgentRemoteA2ATool(name, description, baseURL string, httpClient *http.Client, extraHeaders map[string]string, propagateToken bool) (tool.Tool, string, error) { if httpClient == nil { httpClient = &http.Client{} } httpClient = withOTelTransport(httpClient) state := &remoteA2AState{ - name: name, - description: description, - baseURL: baseURL, - httpClient: httpClient, - extraHeaders: extraHeaders, - lastContextID: a2atype.NewContextID(), + name: name, + description: description, + baseURL: baseURL, + httpClient: httpClient, + extraHeaders: extraHeaders, + propagateToken: propagateToken, + lastContextID: a2atype.NewContextID(), } ft, err := functiontool.New(functiontool.Config{ Name: name, @@ -119,10 +145,14 @@ func (s *remoteA2AState) ensureClient(ctx context.Context) (*a2aclient.Client, e for k, v := range s.extraHeaders { meta.Append(k, v) } - opts = append(opts, a2aclient.WithInterceptors( + interceptors := []a2aclient.CallInterceptor{ a2aclient.NewStaticCallMetaInjector(meta), &userIDForwardingInterceptor{}, - )) + } + if s.propagateToken { + interceptors = append(interceptors, &authzForwardingInterceptor{name: s.name}) + } + opts = append(opts, a2aclient.WithInterceptors(interceptors...)) client, err := a2aclient.NewFromCard(ctx, card, opts...) if err != nil { diff --git a/python/packages/kagent-adk/src/kagent/adk/_remote_a2a_tool.py b/python/packages/kagent-adk/src/kagent/adk/_remote_a2a_tool.py index ff73d0c6d..48becb8e6 100644 --- a/python/packages/kagent-adk/src/kagent/adk/_remote_a2a_tool.py +++ b/python/packages/kagent-adk/src/kagent/adk/_remote_a2a_tool.py @@ -12,6 +12,7 @@ """ import logging +import os import uuid from typing import Any, Optional, Protocol, runtime_checkable from urllib.parse import urlparse @@ -58,6 +59,9 @@ _USER_ID_CONTEXT_KEY = "x-user-id" _SOURCE_HEADER = "x-kagent-source" _SOURCE_SUBAGENT = "agent" +_HEADERS_STATE_KEY = "headers" +_AUTHORIZATION_CONTEXT_KEY = "authorization" +_PROPAGATE_TOKEN = bool(os.getenv("KAGENT_PROPAGATE_TOKEN")) class _SubagentInterceptor(ClientCallInterceptor): @@ -65,14 +69,25 @@ class _SubagentInterceptor(ClientCallInterceptor): Injects the authenticated user's ID as an ``x-user-id`` HTTP header and marks the request as originating from an agent call via ``x-kagent-source: agent`` on every outgoing A2A request. + + When ``propagate_token`` is True, also forwards the Authorization header + from the call context state to the sub-agent. """ + def __init__(self, propagate_token: bool = False) -> None: + self._propagate_token = propagate_token + async def intercept(self, method_name, request_payload, http_kwargs, agent_card, context): headers = dict(http_kwargs.get("headers", {})) # Always mark requests from a parent agent tool as subagent-originated headers[_SOURCE_HEADER] = _SOURCE_SUBAGENT - if context and _USER_ID_CONTEXT_KEY in context.state: - headers["x-user-id"] = context.state[_USER_ID_CONTEXT_KEY] + + if context: + if _USER_ID_CONTEXT_KEY in context.state: + headers["x-user-id"] = context.state[_USER_ID_CONTEXT_KEY] + if self._propagate_token and _AUTHORIZATION_CONTEXT_KEY in context.state: + headers["authorization"] = context.state["authorization"] + logger.info("forwarding authorization header to sub-agent A2A call") http_kwargs["headers"] = headers return request_payload, http_kwargs @@ -140,10 +155,12 @@ def __init__( description: str, agent_card_url: str, httpx_client: Optional[httpx.AsyncClient] = None, + propagate_token: bool = False, ) -> None: super().__init__(name=name, description=description) self._agent_card_url = agent_card_url self._httpx_client = httpx_client + self._propagate_token = propagate_token self._a2a_client: Optional[A2AClient] = None self._agent_card: Optional[AgentCard] = None # Pre-generate context_id for UI session polling @@ -188,7 +205,7 @@ async def _ensure_client(self) -> A2AClient: factory = A2AClientFactory(config=config) self._a2a_client = factory.create( self._agent_card, - interceptors=[_SubagentInterceptor()], + interceptors=[_SubagentInterceptor(propagate_token=self._propagate_token)], ) return self._a2a_client @@ -239,7 +256,14 @@ async def _handle_first_call(self, args: dict[str, Any], tool_context: ToolConte # Forward the authenticated user ID so the subagent session is scoped # to the same user as the parent agent session. - call_context = ClientCallContext(state={_USER_ID_CONTEXT_KEY: tool_context.session.user_id}) + call_context_state: dict[str, Any] = {_USER_ID_CONTEXT_KEY: tool_context.session.user_id} + if self._propagate_token: + incoming = tool_context.session.state.get(_HEADERS_STATE_KEY) or {} + if isinstance(incoming, dict): + auth = incoming.get("authorization") or incoming.get("Authorization") + if auth: + call_context_state[_AUTHORIZATION_CONTEXT_KEY] = auth + call_context = ClientCallContext(state=call_context_state) task: Optional[Task] = None try: @@ -381,7 +405,14 @@ async def _handle_resume(self, tool_context: ToolContext) -> Any: ) client = await self._ensure_client() - call_context = ClientCallContext(state={_USER_ID_CONTEXT_KEY: tool_context.session.user_id}) + call_context_state: dict[str, Any] = {_USER_ID_CONTEXT_KEY: tool_context.session.user_id} + if self._propagate_token: + incoming = tool_context.session.state.get(_HEADERS_STATE_KEY) or {} + if isinstance(incoming, dict): + auth = incoming.get("authorization") or incoming.get("Authorization") + if auth: + call_context_state["authorization"] = auth + call_context = ClientCallContext(state=call_context_state) task: Optional[Task] = None try: async for response in client.send_message(request=decision_message, context=call_context): @@ -449,6 +480,7 @@ def __init__( description: str, agent_card_url: str, httpx_client: httpx.AsyncClient, + propagate_token: bool = _PROPAGATE_TOKEN, ) -> None: super().__init__() self._httpx_client = httpx_client @@ -457,6 +489,7 @@ def __init__( description=description, agent_card_url=agent_card_url, httpx_client=httpx_client, + propagate_token=propagate_token, ) @property diff --git a/python/packages/kagent-adk/tests/unittests/test_remote_a2a_tool.py b/python/packages/kagent-adk/tests/unittests/test_remote_a2a_tool.py index 3185ad8e2..ecf5cd5d2 100644 --- a/python/packages/kagent-adk/tests/unittests/test_remote_a2a_tool.py +++ b/python/packages/kagent-adk/tests/unittests/test_remote_a2a_tool.py @@ -26,6 +26,7 @@ KAgentRemoteA2ATool, KAgentRemoteA2AToolset, SubagentSessionProvider, + _SubagentInterceptor, ) # --------------------------------------------------------------------------- @@ -139,6 +140,44 @@ def _approval_ctx(confirmed: bool, payload: dict | None = None, **kwargs) -> Moc return MockToolContext(tool_confirmation=confirmation, **kwargs) +# --------------------------------------------------------------------------- +# _SubagentInterceptor propagate_token tests +# --------------------------------------------------------------------------- + + +class TestSubagentInterceptorPropagateToken: + """Tests for Authorization header propagation in _SubagentInterceptor.""" + + async def _call_intercept(self, interceptor, state: dict) -> dict: + from a2a.client.middleware import ClientCallContext + + ctx = ClientCallContext(state=state) + _, http_kwargs = await interceptor.intercept( + method_name="message/send", + request_payload={}, + http_kwargs={}, + agent_card=None, + context=ctx, + ) + return http_kwargs.get("headers", {}) + + async def test_forwards_auth_when_propagate_token_enabled(self): + interceptor = _SubagentInterceptor(propagate_token=True) + headers = await self._call_intercept( + interceptor, + state={"x-user-id": "user1", "authorization": "Bearer test-jwt"}, + ) + assert headers.get("authorization") == "Bearer test-jwt" + + async def test_does_not_forward_auth_when_propagate_token_disabled(self): + interceptor = _SubagentInterceptor(propagate_token=False) + headers = await self._call_intercept( + interceptor, + state={"x-user-id": "user1", "authorization": "Bearer test-jwt"}, + ) + assert "authorization" not in headers + + # --------------------------------------------------------------------------- # First-call tests # --------------------------------------------------------------------------- From 0372da875a36632de009be76ded7d432861b42f9 Mon Sep 17 00:00:00 2001 From: JM Huibonhoa Date: Tue, 12 May 2026 18:43:54 -0400 Subject: [PATCH 02/10] style: remove logs and use constants Signed-off-by: JM Huibonhoa --- go/adk/pkg/constants/const.go | 7 +++++++ go/adk/pkg/mcp/registry.go | 6 +++--- go/adk/pkg/tools/remote_a2a_tool.go | 6 +++--- .../packages/kagent-adk/src/kagent/adk/_remote_a2a_tool.py | 1 - 4 files changed, 13 insertions(+), 7 deletions(-) create mode 100644 go/adk/pkg/constants/const.go diff --git a/go/adk/pkg/constants/const.go b/go/adk/pkg/constants/const.go new file mode 100644 index 000000000..2926e96e4 --- /dev/null +++ b/go/adk/pkg/constants/const.go @@ -0,0 +1,7 @@ +package constants + +const ( + // A2A call context's NewRequestMeta normalizes header names to lowercase. + // This is why we use "authorization" instead of "Authorization". + AuthorizationHeader = "authorization" +) diff --git a/go/adk/pkg/mcp/registry.go b/go/adk/pkg/mcp/registry.go index 0aa140922..701776667 100644 --- a/go/adk/pkg/mcp/registry.go +++ b/go/adk/pkg/mcp/registry.go @@ -11,6 +11,7 @@ import ( "github.com/a2aproject/a2a-go/a2asrv" "github.com/go-logr/logr" + "github.com/kagent-dev/kagent/go/adk/pkg/constants" "github.com/kagent-dev/kagent/go/api/adk" mcpsdk "github.com/modelcontextprotocol/go-sdk/mcp" "google.golang.org/adk/tool" @@ -262,9 +263,8 @@ func (rt *headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, erro if rt.propagateToken { if callCtx, ok := a2asrv.CallContextFrom(req.Context()); ok { if meta := callCtx.RequestMeta(); meta != nil { - if vals, ok := meta.Get("authorization"); ok && len(vals) > 0 && vals[0] != "" { - req.Header.Set("authorization", vals[0]) - rt.log.Info("forwarding authorization header to MCP server", "url", req.URL.String()) + if vals, ok := meta.Get(constants.AuthorizationHeader); ok && len(vals) > 0 && vals[0] != "" { + req.Header.Set(constants.AuthorizationHeader, vals[0]) } } } diff --git a/go/adk/pkg/tools/remote_a2a_tool.go b/go/adk/pkg/tools/remote_a2a_tool.go index f0c448a1a..210705307 100644 --- a/go/adk/pkg/tools/remote_a2a_tool.go +++ b/go/adk/pkg/tools/remote_a2a_tool.go @@ -13,6 +13,7 @@ import ( "github.com/a2aproject/a2a-go/a2aclient/agentcard" "github.com/a2aproject/a2a-go/a2asrv" "github.com/kagent-dev/kagent/go/adk/pkg/a2a" + "github.com/kagent-dev/kagent/go/adk/pkg/constants" "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" "google.golang.org/adk/tool" "google.golang.org/adk/tool/functiontool" @@ -49,9 +50,8 @@ func (a *authzForwardingInterceptor) Before(ctx context.Context, req *a2aclient. if meta == nil { return ctx, nil } - if vals, ok := meta.Get("authorization"); ok && len(vals) > 0 && vals[0] != "" { - req.Meta.Append("authorization", vals[0]) - slog.Info("forwarding authorization header to sub-agent A2A call", "tool", a.name) + if vals, ok := meta.Get(constants.AuthorizationHeader); ok && len(vals) > 0 && vals[0] != "" { + req.Meta.Append(constants.AuthorizationHeader, vals[0]) } return ctx, nil } diff --git a/python/packages/kagent-adk/src/kagent/adk/_remote_a2a_tool.py b/python/packages/kagent-adk/src/kagent/adk/_remote_a2a_tool.py index 48becb8e6..d5c532d42 100644 --- a/python/packages/kagent-adk/src/kagent/adk/_remote_a2a_tool.py +++ b/python/packages/kagent-adk/src/kagent/adk/_remote_a2a_tool.py @@ -87,7 +87,6 @@ async def intercept(self, method_name, request_payload, http_kwargs, agent_card, headers["x-user-id"] = context.state[_USER_ID_CONTEXT_KEY] if self._propagate_token and _AUTHORIZATION_CONTEXT_KEY in context.state: headers["authorization"] = context.state["authorization"] - logger.info("forwarding authorization header to sub-agent A2A call") http_kwargs["headers"] = headers return request_payload, http_kwargs From 1870b95d2537773f2049d875977138b127aa9877 Mon Sep 17 00:00:00 2001 From: JM Huibonhoa Date: Wed, 13 May 2026 14:34:07 -0400 Subject: [PATCH 03/10] fix: address pr feedback Signed-off-by: JM Huibonhoa --- go/adk/pkg/mcp/registry.go | 2 -- go/adk/pkg/tools/remote_a2a_tool.go | 3 +++ .../kagent-adk/src/kagent/adk/_remote_a2a_tool.py | 8 ++++---- python/packages/kagent-adk/src/kagent/adk/cli.py | 2 +- 4 files changed, 8 insertions(+), 7 deletions(-) diff --git a/go/adk/pkg/mcp/registry.go b/go/adk/pkg/mcp/registry.go index 701776667..1dd2a4d04 100644 --- a/go/adk/pkg/mcp/registry.go +++ b/go/adk/pkg/mcp/registry.go @@ -214,7 +214,6 @@ func createTransport(ctx context.Context, params mcpServerParams) (mcpsdk.Transp headers: params.Headers, allowedHeaders: params.AllowedHeaders, propagateToken: params.PropagateToken, - log: log, } } @@ -252,7 +251,6 @@ type headerRoundTripper struct { headers map[string]string allowedHeaders []string // header names (case-insensitive) to forward from A2A context propagateToken bool // when true, Authorization is forwarded independently - log logr.Logger } func (rt *headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { diff --git a/go/adk/pkg/tools/remote_a2a_tool.go b/go/adk/pkg/tools/remote_a2a_tool.go index 210705307..aefec7746 100644 --- a/go/adk/pkg/tools/remote_a2a_tool.go +++ b/go/adk/pkg/tools/remote_a2a_tool.go @@ -50,6 +50,9 @@ func (a *authzForwardingInterceptor) Before(ctx context.Context, req *a2aclient. if meta == nil { return ctx, nil } + if len(req.Meta.Get(constants.AuthorizationHeader)) > 0 { + return ctx, nil + } if vals, ok := meta.Get(constants.AuthorizationHeader); ok && len(vals) > 0 && vals[0] != "" { req.Meta.Append(constants.AuthorizationHeader, vals[0]) } diff --git a/python/packages/kagent-adk/src/kagent/adk/_remote_a2a_tool.py b/python/packages/kagent-adk/src/kagent/adk/_remote_a2a_tool.py index d5c532d42..0e63444e8 100644 --- a/python/packages/kagent-adk/src/kagent/adk/_remote_a2a_tool.py +++ b/python/packages/kagent-adk/src/kagent/adk/_remote_a2a_tool.py @@ -61,7 +61,7 @@ _SOURCE_SUBAGENT = "agent" _HEADERS_STATE_KEY = "headers" _AUTHORIZATION_CONTEXT_KEY = "authorization" -_PROPAGATE_TOKEN = bool(os.getenv("KAGENT_PROPAGATE_TOKEN")) +_PROPAGATE_TOKEN = os.getenv("KAGENT_PROPAGATE_TOKEN", "").lower() == "true" class _SubagentInterceptor(ClientCallInterceptor): @@ -86,7 +86,7 @@ async def intercept(self, method_name, request_payload, http_kwargs, agent_card, if _USER_ID_CONTEXT_KEY in context.state: headers["x-user-id"] = context.state[_USER_ID_CONTEXT_KEY] if self._propagate_token and _AUTHORIZATION_CONTEXT_KEY in context.state: - headers["authorization"] = context.state["authorization"] + headers[_AUTHORIZATION_CONTEXT_KEY] = context.state[_AUTHORIZATION_CONTEXT_KEY] http_kwargs["headers"] = headers return request_payload, http_kwargs @@ -154,7 +154,7 @@ def __init__( description: str, agent_card_url: str, httpx_client: Optional[httpx.AsyncClient] = None, - propagate_token: bool = False, + propagate_token: bool = _PROPAGATE_TOKEN, ) -> None: super().__init__(name=name, description=description) self._agent_card_url = agent_card_url @@ -410,7 +410,7 @@ async def _handle_resume(self, tool_context: ToolContext) -> Any: if isinstance(incoming, dict): auth = incoming.get("authorization") or incoming.get("Authorization") if auth: - call_context_state["authorization"] = auth + call_context_state[_AUTHORIZATION_CONTEXT_KEY] = auth call_context = ClientCallContext(state=call_context_state) task: Optional[Task] = None try: diff --git a/python/packages/kagent-adk/src/kagent/adk/cli.py b/python/packages/kagent-adk/src/kagent/adk/cli.py index b73f4ce32..67b1441ca 100644 --- a/python/packages/kagent-adk/src/kagent/adk/cli.py +++ b/python/packages/kagent-adk/src/kagent/adk/cli.py @@ -24,7 +24,7 @@ kagent_url_override = os.getenv("KAGENT_URL") sts_well_known_uri = os.getenv("STS_WELL_KNOWN_URI") -propagate_token = os.getenv("KAGENT_PROPAGATE_TOKEN") +propagate_token = os.getenv("KAGENT_PROPAGATE_TOKEN", "").lower() == "true" uvicorn_log_level = os.getenv("UVICORN_LOG_LEVEL", os.getenv("LOG_LEVEL", "info")).lower() From 8c8122d9cbc6d1152a3c1b5d11a38128ff8fc484 Mon Sep 17 00:00:00 2001 From: JM Huibonhoa Date: Wed, 13 May 2026 16:17:51 -0400 Subject: [PATCH 04/10] style: refactor call context building into helper func Signed-off-by: JM Huibonhoa --- .../src/kagent/adk/_remote_a2a_tool.py | 28 ++++++++----------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/python/packages/kagent-adk/src/kagent/adk/_remote_a2a_tool.py b/python/packages/kagent-adk/src/kagent/adk/_remote_a2a_tool.py index 0e63444e8..4f66924fe 100644 --- a/python/packages/kagent-adk/src/kagent/adk/_remote_a2a_tool.py +++ b/python/packages/kagent-adk/src/kagent/adk/_remote_a2a_tool.py @@ -222,6 +222,16 @@ def _get_declaration(self) -> genai_types.FunctionDeclaration: ), ) + def _build_call_context(self, tool_context: ToolContext) -> ClientCallContext: + state: dict[str, Any] = {_USER_ID_CONTEXT_KEY: tool_context.session.user_id} + if self._propagate_token: + incoming = tool_context.session.state.get(_HEADERS_STATE_KEY) or {} + if isinstance(incoming, dict): + auth = incoming.get("authorization") or incoming.get("Authorization") + if auth: + state[_AUTHORIZATION_CONTEXT_KEY] = auth + return ClientCallContext(state=state) + async def run_async(self, *, args: dict[str, Any], tool_context: ToolContext) -> Any: """Execute the remote agent tool. @@ -255,14 +265,7 @@ async def _handle_first_call(self, args: dict[str, Any], tool_context: ToolConte # Forward the authenticated user ID so the subagent session is scoped # to the same user as the parent agent session. - call_context_state: dict[str, Any] = {_USER_ID_CONTEXT_KEY: tool_context.session.user_id} - if self._propagate_token: - incoming = tool_context.session.state.get(_HEADERS_STATE_KEY) or {} - if isinstance(incoming, dict): - auth = incoming.get("authorization") or incoming.get("Authorization") - if auth: - call_context_state[_AUTHORIZATION_CONTEXT_KEY] = auth - call_context = ClientCallContext(state=call_context_state) + call_context = self._build_call_context(tool_context) task: Optional[Task] = None try: @@ -404,14 +407,7 @@ async def _handle_resume(self, tool_context: ToolContext) -> Any: ) client = await self._ensure_client() - call_context_state: dict[str, Any] = {_USER_ID_CONTEXT_KEY: tool_context.session.user_id} - if self._propagate_token: - incoming = tool_context.session.state.get(_HEADERS_STATE_KEY) or {} - if isinstance(incoming, dict): - auth = incoming.get("authorization") or incoming.get("Authorization") - if auth: - call_context_state[_AUTHORIZATION_CONTEXT_KEY] = auth - call_context = ClientCallContext(state=call_context_state) + call_context = self._build_call_context(tool_context) task: Optional[Task] = None try: async for response in client.send_message(request=decision_message, context=call_context): From c6fe25a83ef14b89809dc163178c1276114e357e Mon Sep 17 00:00:00 2001 From: JM Huibonhoa Date: Wed, 13 May 2026 16:22:51 -0400 Subject: [PATCH 05/10] fix: address pr feedback Signed-off-by: JM Huibonhoa --- go/adk/pkg/agent/agent.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go/adk/pkg/agent/agent.go b/go/adk/pkg/agent/agent.go index f779648e6..452eaad3e 100644 --- a/go/adk/pkg/agent/agent.go +++ b/go/adk/pkg/agent/agent.go @@ -48,7 +48,7 @@ func CreateGoogleADKAgentWithSubagentSessionIDs(ctx context.Context, agentConfig return nil, nil, fmt.Errorf("agent config is required") } - propagateToken := os.Getenv("KAGENT_PROPAGATE_TOKEN") != "" + propagateToken := strings.ToLower(os.Getenv("KAGENT_PROPAGATE_TOKEN")) != "true" toolsets := mcp.CreateToolsets(ctx, agentConfig.HttpTools, agentConfig.SseTools, propagateToken) subagentSessionIDs := make(map[string]string) From 022ee091eb48582cc2f9b235c9692977a1268be3 Mon Sep 17 00:00:00 2001 From: JM Huibonhoa Date: Wed, 13 May 2026 16:26:44 -0400 Subject: [PATCH 06/10] fix: address pr feedback Signed-off-by: JM Huibonhoa --- go/adk/pkg/agent/agent.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go/adk/pkg/agent/agent.go b/go/adk/pkg/agent/agent.go index 452eaad3e..1aaa21af4 100644 --- a/go/adk/pkg/agent/agent.go +++ b/go/adk/pkg/agent/agent.go @@ -48,7 +48,7 @@ func CreateGoogleADKAgentWithSubagentSessionIDs(ctx context.Context, agentConfig return nil, nil, fmt.Errorf("agent config is required") } - propagateToken := strings.ToLower(os.Getenv("KAGENT_PROPAGATE_TOKEN")) != "true" + propagateToken := strings.ToLower(os.Getenv("KAGENT_PROPAGATE_TOKEN")) == "true" toolsets := mcp.CreateToolsets(ctx, agentConfig.HttpTools, agentConfig.SseTools, propagateToken) subagentSessionIDs := make(map[string]string) From a2033957af07005aee303801563d5bf62ec35d2e Mon Sep 17 00:00:00 2001 From: JM Huibonhoa Date: Wed, 13 May 2026 16:29:08 -0400 Subject: [PATCH 07/10] refactor: clean up unused struct fields Signed-off-by: JM Huibonhoa --- go/adk/pkg/tools/remote_a2a_tool.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/go/adk/pkg/tools/remote_a2a_tool.go b/go/adk/pkg/tools/remote_a2a_tool.go index aefec7746..87afe88a8 100644 --- a/go/adk/pkg/tools/remote_a2a_tool.go +++ b/go/adk/pkg/tools/remote_a2a_tool.go @@ -38,7 +38,6 @@ func (u *userIDForwardingInterceptor) Before(ctx context.Context, req *a2aclient // incoming A2A request context to outbound sub-agent A2A calls. type authzForwardingInterceptor struct { a2aclient.PassthroughInterceptor - name string } func (a *authzForwardingInterceptor) Before(ctx context.Context, req *a2aclient.Request) (context.Context, error) { @@ -153,7 +152,7 @@ func (s *remoteA2AState) ensureClient(ctx context.Context) (*a2aclient.Client, e &userIDForwardingInterceptor{}, } if s.propagateToken { - interceptors = append(interceptors, &authzForwardingInterceptor{name: s.name}) + interceptors = append(interceptors, &authzForwardingInterceptor{}) } opts = append(opts, a2aclient.WithInterceptors(interceptors...)) From d0948bcbb242fe22999b25e39093851ef7b8d1d7 Mon Sep 17 00:00:00 2001 From: JM Huibonhoa Date: Wed, 13 May 2026 16:50:29 -0400 Subject: [PATCH 08/10] fix: address pr feedback Signed-off-by: JM Huibonhoa --- .../packages/kagent-adk/src/kagent/adk/_remote_a2a_tool.py | 6 ++---- python/packages/kagent-adk/src/kagent/adk/cli.py | 4 ++-- python/packages/kagent-adk/src/kagent/adk/types.py | 3 ++- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/python/packages/kagent-adk/src/kagent/adk/_remote_a2a_tool.py b/python/packages/kagent-adk/src/kagent/adk/_remote_a2a_tool.py index 4f66924fe..f6f2d3490 100644 --- a/python/packages/kagent-adk/src/kagent/adk/_remote_a2a_tool.py +++ b/python/packages/kagent-adk/src/kagent/adk/_remote_a2a_tool.py @@ -12,7 +12,6 @@ """ import logging -import os import uuid from typing import Any, Optional, Protocol, runtime_checkable from urllib.parse import urlparse @@ -61,7 +60,6 @@ _SOURCE_SUBAGENT = "agent" _HEADERS_STATE_KEY = "headers" _AUTHORIZATION_CONTEXT_KEY = "authorization" -_PROPAGATE_TOKEN = os.getenv("KAGENT_PROPAGATE_TOKEN", "").lower() == "true" class _SubagentInterceptor(ClientCallInterceptor): @@ -154,7 +152,7 @@ def __init__( description: str, agent_card_url: str, httpx_client: Optional[httpx.AsyncClient] = None, - propagate_token: bool = _PROPAGATE_TOKEN, + propagate_token: bool = False, ) -> None: super().__init__(name=name, description=description) self._agent_card_url = agent_card_url @@ -475,7 +473,7 @@ def __init__( description: str, agent_card_url: str, httpx_client: httpx.AsyncClient, - propagate_token: bool = _PROPAGATE_TOKEN, + propagate_token: bool = False, ) -> None: super().__init__() self._httpx_client = httpx_client diff --git a/python/packages/kagent-adk/src/kagent/adk/cli.py b/python/packages/kagent-adk/src/kagent/adk/cli.py index 67b1441ca..e32d0aacb 100644 --- a/python/packages/kagent-adk/src/kagent/adk/cli.py +++ b/python/packages/kagent-adk/src/kagent/adk/cli.py @@ -79,7 +79,7 @@ def static( plugins.append(LLMPassthroughPlugin()) def root_agent_factory() -> BaseAgent: - root_agent = agent_config.to_agent(app_cfg.name, sts_integration) + root_agent = agent_config.to_agent(app_cfg.name, sts_integration, propagate_token) maybe_add_skills_with_config(root_agent, agent_config) @@ -218,7 +218,7 @@ async def test_agent(agent_config: AgentConfig, agent_card: AgentCard, task: str plugins = [sts_integration] def root_agent_factory() -> BaseAgent: - root_agent = agent_config.to_agent(app_cfg.name, sts_integration) + root_agent = agent_config.to_agent(app_cfg.name, sts_integration, propagate_token) maybe_add_skills_with_config(root_agent, agent_config) return root_agent diff --git a/python/packages/kagent-adk/src/kagent/adk/types.py b/python/packages/kagent-adk/src/kagent/adk/types.py index 96c9ee6f0..df0176f0c 100644 --- a/python/packages/kagent-adk/src/kagent/adk/types.py +++ b/python/packages/kagent-adk/src/kagent/adk/types.py @@ -298,7 +298,7 @@ class AgentConfig(BaseModel): network: NetworkConfig | None = None context_config: ContextConfig | None = None - def to_agent(self, name: str, sts_integration: Optional[ADKTokenPropagationPlugin] = None) -> Agent: + def to_agent(self, name: str, sts_integration: Optional[ADKTokenPropagationPlugin] = None, propagate_token: bool = False) -> Agent: if name is None or not str(name).strip(): raise ValueError("Agent name must be a non-empty string.") tools: list[ToolUnion] = [] @@ -406,6 +406,7 @@ async def rewrite_url_to_proxy(request: httpx.Request) -> None: description=remote_agent.description, agent_card_url=f"{remote_agent.url}{AGENT_CARD_WELL_KNOWN_PATH}", httpx_client=client, + propagate_token=propagate_token, ) ) From d483e2156a0afe42d16968ec0dd273b3a6ad412a Mon Sep 17 00:00:00 2001 From: JM Huibonhoa Date: Wed, 13 May 2026 16:52:10 -0400 Subject: [PATCH 09/10] style: lint issues Signed-off-by: JM Huibonhoa --- python/packages/kagent-adk/src/kagent/adk/types.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/packages/kagent-adk/src/kagent/adk/types.py b/python/packages/kagent-adk/src/kagent/adk/types.py index df0176f0c..c9a4c074f 100644 --- a/python/packages/kagent-adk/src/kagent/adk/types.py +++ b/python/packages/kagent-adk/src/kagent/adk/types.py @@ -298,7 +298,9 @@ class AgentConfig(BaseModel): network: NetworkConfig | None = None context_config: ContextConfig | None = None - def to_agent(self, name: str, sts_integration: Optional[ADKTokenPropagationPlugin] = None, propagate_token: bool = False) -> Agent: + def to_agent( + self, name: str, sts_integration: Optional[ADKTokenPropagationPlugin] = None, propagate_token: bool = False + ) -> Agent: if name is None or not str(name).strip(): raise ValueError("Agent name must be a non-empty string.") tools: list[ToolUnion] = [] From b7732d944fe450526d9655f685d521783b894926 Mon Sep 17 00:00:00 2001 From: JM Huibonhoa Date: Wed, 13 May 2026 17:37:32 -0400 Subject: [PATCH 10/10] refactor: pass HeaderProvider to to KAgentRemoteA2AToolset instead of propagate_token bool Signed-off-by: JM Huibonhoa --- .../src/kagent/adk/_remote_a2a_tool.py | 41 ++++++++----------- .../kagent-adk/src/kagent/adk/types.py | 5 ++- .../tests/unittests/test_remote_a2a_tool.py | 16 ++++---- 3 files changed, 29 insertions(+), 33 deletions(-) diff --git a/python/packages/kagent-adk/src/kagent/adk/_remote_a2a_tool.py b/python/packages/kagent-adk/src/kagent/adk/_remote_a2a_tool.py index f6f2d3490..3cbaabcb2 100644 --- a/python/packages/kagent-adk/src/kagent/adk/_remote_a2a_tool.py +++ b/python/packages/kagent-adk/src/kagent/adk/_remote_a2a_tool.py @@ -13,7 +13,7 @@ import logging import uuid -from typing import Any, Optional, Protocol, runtime_checkable +from typing import Any, Callable, Optional, Protocol, runtime_checkable from urllib.parse import urlparse import httpx @@ -59,32 +59,27 @@ _SOURCE_HEADER = "x-kagent-source" _SOURCE_SUBAGENT = "agent" _HEADERS_STATE_KEY = "headers" -_AUTHORIZATION_CONTEXT_KEY = "authorization" +_EXTRA_HEADERS_CONTEXT_KEY = "_a2a_extra_headers" class _SubagentInterceptor(ClientCallInterceptor): """ - Injects the authenticated user's ID as an ``x-user-id`` HTTP header and + Injects the authenticated user's ID as an ``x-user-id`` HTTP header, marks the request as originating from an agent call via - ``x-kagent-source: agent`` on every outgoing A2A request. - - When ``propagate_token`` is True, also forwards the Authorization header - from the call context state to the sub-agent. + ``x-kagent-source: agent``, and forwards any pre-computed propagation + headers stored in the call context state under ``_EXTRA_HEADERS_CONTEXT_KEY``. """ - def __init__(self, propagate_token: bool = False) -> None: - self._propagate_token = propagate_token - async def intercept(self, method_name, request_payload, http_kwargs, agent_card, context): headers = dict(http_kwargs.get("headers", {})) - # Always mark requests from a parent agent tool as subagent-originated headers[_SOURCE_HEADER] = _SOURCE_SUBAGENT if context: if _USER_ID_CONTEXT_KEY in context.state: headers["x-user-id"] = context.state[_USER_ID_CONTEXT_KEY] - if self._propagate_token and _AUTHORIZATION_CONTEXT_KEY in context.state: - headers[_AUTHORIZATION_CONTEXT_KEY] = context.state[_AUTHORIZATION_CONTEXT_KEY] + extra = context.state.get(_EXTRA_HEADERS_CONTEXT_KEY) + if extra: + headers.update(extra) http_kwargs["headers"] = headers return request_payload, http_kwargs @@ -152,12 +147,12 @@ def __init__( description: str, agent_card_url: str, httpx_client: Optional[httpx.AsyncClient] = None, - propagate_token: bool = False, + header_provider: Optional[Callable[[Optional[ReadonlyContext]], dict[str, str]]] = None, ) -> None: super().__init__(name=name, description=description) self._agent_card_url = agent_card_url self._httpx_client = httpx_client - self._propagate_token = propagate_token + self._header_provider = header_provider self._a2a_client: Optional[A2AClient] = None self._agent_card: Optional[AgentCard] = None # Pre-generate context_id for UI session polling @@ -202,7 +197,7 @@ async def _ensure_client(self) -> A2AClient: factory = A2AClientFactory(config=config) self._a2a_client = factory.create( self._agent_card, - interceptors=[_SubagentInterceptor(propagate_token=self._propagate_token)], + interceptors=[_SubagentInterceptor()], ) return self._a2a_client @@ -222,12 +217,10 @@ def _get_declaration(self) -> genai_types.FunctionDeclaration: def _build_call_context(self, tool_context: ToolContext) -> ClientCallContext: state: dict[str, Any] = {_USER_ID_CONTEXT_KEY: tool_context.session.user_id} - if self._propagate_token: - incoming = tool_context.session.state.get(_HEADERS_STATE_KEY) or {} - if isinstance(incoming, dict): - auth = incoming.get("authorization") or incoming.get("Authorization") - if auth: - state[_AUTHORIZATION_CONTEXT_KEY] = auth + if self._header_provider: + extra_headers = self._header_provider(tool_context) + if extra_headers: + state[_EXTRA_HEADERS_CONTEXT_KEY] = extra_headers return ClientCallContext(state=state) async def run_async(self, *, args: dict[str, Any], tool_context: ToolContext) -> Any: @@ -473,7 +466,7 @@ def __init__( description: str, agent_card_url: str, httpx_client: httpx.AsyncClient, - propagate_token: bool = False, + header_provider: Optional[Callable[[Optional[ReadonlyContext]], dict[str, str]]] = None, ) -> None: super().__init__() self._httpx_client = httpx_client @@ -482,7 +475,7 @@ def __init__( description=description, agent_card_url=agent_card_url, httpx_client=httpx_client, - propagate_token=propagate_token, + header_provider=header_provider, ) @property diff --git a/python/packages/kagent-adk/src/kagent/adk/types.py b/python/packages/kagent-adk/src/kagent/adk/types.py index c9a4c074f..5e2f4a97a 100644 --- a/python/packages/kagent-adk/src/kagent/adk/types.py +++ b/python/packages/kagent-adk/src/kagent/adk/types.py @@ -402,13 +402,16 @@ async def rewrite_url_to_proxy(request: httpx.Request) -> None: timeout=timeout, ) + a2a_header_provider = None + if propagate_token: + a2a_header_provider = create_header_provider(allowed_headers=["authorization"]) tools.append( KAgentRemoteA2AToolset( name=remote_agent.name, description=remote_agent.description, agent_card_url=f"{remote_agent.url}{AGENT_CARD_WELL_KNOWN_PATH}", httpx_client=client, - propagate_token=propagate_token, + header_provider=a2a_header_provider, ) ) diff --git a/python/packages/kagent-adk/tests/unittests/test_remote_a2a_tool.py b/python/packages/kagent-adk/tests/unittests/test_remote_a2a_tool.py index ecf5cd5d2..8682741bb 100644 --- a/python/packages/kagent-adk/tests/unittests/test_remote_a2a_tool.py +++ b/python/packages/kagent-adk/tests/unittests/test_remote_a2a_tool.py @@ -141,12 +141,12 @@ def _approval_ctx(confirmed: bool, payload: dict | None = None, **kwargs) -> Moc # --------------------------------------------------------------------------- -# _SubagentInterceptor propagate_token tests +# _SubagentInterceptor header propagation tests # --------------------------------------------------------------------------- -class TestSubagentInterceptorPropagateToken: - """Tests for Authorization header propagation in _SubagentInterceptor.""" +class TestSubagentInterceptorHeaderPropagation: + """Tests for header propagation in _SubagentInterceptor via context state.""" async def _call_intercept(self, interceptor, state: dict) -> dict: from a2a.client.middleware import ClientCallContext @@ -161,16 +161,16 @@ async def _call_intercept(self, interceptor, state: dict) -> dict: ) return http_kwargs.get("headers", {}) - async def test_forwards_auth_when_propagate_token_enabled(self): - interceptor = _SubagentInterceptor(propagate_token=True) + async def test_forwards_extra_headers_from_context_state(self): + interceptor = _SubagentInterceptor() headers = await self._call_intercept( interceptor, - state={"x-user-id": "user1", "authorization": "Bearer test-jwt"}, + state={"x-user-id": "user1", "_a2a_extra_headers": {"authorization": "Bearer test-jwt"}}, ) assert headers.get("authorization") == "Bearer test-jwt" - async def test_does_not_forward_auth_when_propagate_token_disabled(self): - interceptor = _SubagentInterceptor(propagate_token=False) + async def test_no_extra_headers_without_state_key(self): + interceptor = _SubagentInterceptor() headers = await self._call_intercept( interceptor, state={"x-user-id": "user1", "authorization": "Bearer test-jwt"},