From d0f206f2c670646fd638e7f64232d5c88d3de236 Mon Sep 17 00:00:00 2001 From: Peter Ibekwe Date: Tue, 9 Jun 2026 15:01:44 -0700 Subject: [PATCH 1/5] declarative workflow approval flow fix --- .../ObjectModel/InvokeFunctionToolExecutor.cs | 66 +++- .../InvokeFunctionToolExecutorTest.cs | 355 ++++++++++++++++++ 2 files changed, 419 insertions(+), 2 deletions(-) diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/ObjectModel/InvokeFunctionToolExecutor.cs b/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/ObjectModel/InvokeFunctionToolExecutor.cs index 6ca429c648a..7de174404ac 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/ObjectModel/InvokeFunctionToolExecutor.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/ObjectModel/InvokeFunctionToolExecutor.cs @@ -13,6 +13,7 @@ using Microsoft.Agents.AI.Workflows.Declarative.PowerFx; using Microsoft.Agents.ObjectModel; using Microsoft.Extensions.AI; +using Microsoft.Extensions.Logging; using Microsoft.Shared.Diagnostics; namespace Microsoft.Agents.AI.Workflows.Declarative.ObjectModel; @@ -27,6 +28,13 @@ internal sealed class InvokeFunctionToolExecutor( WorkflowFormulaState state) : DeclarativeActionExecutor(model, state) { + private const string ApprovalSnapshotStateKey = nameof(_approvalSnapshot); + + /// + /// Snapshot of evaluated parameters at approval-request time. + /// + private ApprovalSnapshot? _approvalSnapshot; + /// /// Step identifiers for the function tool invocation workflow. /// @@ -69,6 +77,10 @@ public static class Steps // If approval is required, add user input request content if (requireApproval) { + // Snapshot the evaluated parameters. + // If state mutates during the approval window, the approved values are used on resume. + this._approvalSnapshot = new ApprovalSnapshot(functionName, arguments); + requestMessage.Contents.Add(new ToolApprovalRequestContent(this.Id, functionCall)); } @@ -155,6 +167,31 @@ public async ValueTask CaptureResponseAsync( // Completes the action after processing the function result. await context.RaiseCompletionEventAsync(this.Model, cancellationToken).ConfigureAwait(false); + + // Clear the approval snapshot after the action completes so a subsequent + // execution of the same executor instance doesn't reuse stale data. + this._approvalSnapshot = null; + await context.QueueStateUpdateAsync(ApprovalSnapshotStateKey, null, null, cancellationToken).ConfigureAwait(false); + } + + /// + /// + /// Persists the approval snapshot to workflow state so it survives checkpoint/restore cycles. + /// + protected override async ValueTask OnCheckpointingAsync(IWorkflowContext context, CancellationToken cancellationToken = default) + { + await context.QueueStateUpdateAsync(ApprovalSnapshotStateKey, this._approvalSnapshot, null, cancellationToken).ConfigureAwait(false); + await base.OnCheckpointingAsync(context, cancellationToken).ConfigureAwait(false); + } + + /// + /// + /// Restores the approval snapshot from workflow state after a checkpoint restore. + /// + protected override async ValueTask OnCheckpointRestoredAsync(IWorkflowContext context, CancellationToken cancellationToken = default) + { + await base.OnCheckpointRestoredAsync(context, cancellationToken).ConfigureAwait(false); + this._approvalSnapshot = await context.ReadStateAsync(ApprovalSnapshotStateKey, null, cancellationToken).ConfigureAwait(false); } /// @@ -262,7 +299,24 @@ private string GetFunctionName() => private async ValueTask InvokeRegisteredFunctionAsync(CancellationToken cancellationToken) { - string functionName = this.GetFunctionName(); + string functionName; + Dictionary? arguments; + + if (this._approvalSnapshot is { } snapshot) + { + // Use the snapshot captured at approval-request time so we invoke exactly what + // the user approved, even if Power Fx state has mutated during the approval window. + functionName = snapshot.FunctionName; + arguments = snapshot.Arguments; + } + else + { + // Fallback for checkpoints created before approval snapshots were introduced. + this.Logger.LogWarning("Approval snapshot missing for '{ActionId}'; falling back to expression re-evaluation.", this.Id); + functionName = this.GetFunctionName(); + arguments = this.GetArguments(); + } + AIFunction? function = agentProvider.Functions?.FirstOrDefault( f => string.Equals(f.Name, functionName, StringComparison.Ordinal)); @@ -275,7 +329,6 @@ private string GetFunctionName() => }; } - Dictionary? arguments = this.GetArguments(); AIFunctionArguments? functionArguments = arguments is null ? null : new AIFunctionArguments(arguments); object? result; @@ -341,4 +394,13 @@ private bool GetAutoSendValue() return result; } + + /// + /// Stores the evaluated parameters at approval-request time so that + /// uses the values the user reviewed, + /// even if mutates during the approval window. + /// + internal sealed record ApprovalSnapshot( + string FunctionName, + Dictionary? Arguments); } diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/ObjectModel/InvokeFunctionToolExecutorTest.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/ObjectModel/InvokeFunctionToolExecutorTest.cs index b00339ea3b5..845f2a18718 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/ObjectModel/InvokeFunctionToolExecutorTest.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/ObjectModel/InvokeFunctionToolExecutorTest.cs @@ -1,11 +1,21 @@ // Copyright (c) Microsoft. All rights reserved. +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using System.Text.Json; +using System.Threading; using System.Threading.Tasks; using Microsoft.Agents.AI.Workflows.Declarative.Events; +using Microsoft.Agents.AI.Workflows.Declarative.Kit; using Microsoft.Agents.AI.Workflows.Declarative.ObjectModel; using Microsoft.Agents.AI.Workflows.Declarative.PowerFx; using Microsoft.Agents.ObjectModel; using Microsoft.Extensions.AI; +using Microsoft.PowerFx.Types; +using Moq; +using ApprovalSnapshot = Microsoft.Agents.AI.Workflows.Declarative.ObjectModel.InvokeFunctionToolExecutor.ApprovalSnapshot; namespace Microsoft.Agents.AI.Workflows.Declarative.UnitTests.ObjectModel; @@ -261,6 +271,323 @@ public async Task InvokeFunctionToolCaptureResponseWithMultipleFunctionResultsAs #endregion + #region Approval Snapshot Security Tests + + /// + /// Verifies that mutating the function-name variable after approval does not change + /// which function is actually invoked. The originally-approved name must be used. + /// + [Fact] + public async Task InvokeFunctionToolCaptureResponseUsesApprovedFunctionNameNotMutatedAsync() + { + // Arrange + const string ApprovedFunctionName = "safe_readonly_query"; + const string MutatedFunctionName = "dangerous_admin_tool"; + + this.State.Set("TargetFunction", FormulaValue.New(ApprovedFunctionName)); + this.State.InitializeSystem(); + this.State.Bind(); + + InvokeFunctionTool model = this.CreateModelWithVariableFunctionName( + displayName: nameof(InvokeFunctionToolCaptureResponseUsesApprovedFunctionNameNotMutatedAsync), + variableName: "TargetFunction"); + + string? capturedFunctionName = null; + TestFunctionAgentProvider testAgentProvider = new( + [ + AIFunctionFactory.Create(() => "safe-result", name: ApprovedFunctionName), + AIFunctionFactory.Create(() => "dangerous-result", name: MutatedFunctionName), + ], + onInvoke: name => capturedFunctionName = name); + InvokeFunctionToolExecutor action = new(model, testAgentProvider, this.State); + + // Act - trigger ExecuteAsync to store the approval snapshot + Mock mockContext = CreateMockWorkflowContext(); + await action.HandleAsync(new ActionExecutorResult(action.Id), mockContext.Object, CancellationToken.None); + + // Simulate parallel branch mutating state during the approval window + this.State.Set("TargetFunction", FormulaValue.New(MutatedFunctionName)); + this.State.Bind(); + + // User clicks approve (they saw "safe_readonly_query" in the approval UI) + ExternalInputResponse response = CreateApprovalResponse(action.Id, approved: true); + + // Resume after approval + await action.CaptureResponseAsync(mockContext.Object, response, CancellationToken.None); + + // Assert - the originally-approved function must be invoked, not the mutated one + Assert.NotNull(capturedFunctionName); + Assert.Equal(ApprovedFunctionName, capturedFunctionName); + } + + /// + /// Verifies that mutating an argument variable after approval does not change + /// the arguments actually passed to the invoked function. + /// + [Fact] + public async Task InvokeFunctionToolCaptureResponseUsesApprovedArgumentsNotMutatedAsync() + { + // Arrange + const string FunctionName = "process_query"; + const string ArgumentKey = "query"; + const string ApprovedQuery = "SELECT * FROM users LIMIT 10"; + const string MutatedQuery = "DROP TABLE users CASCADE; --"; + + this.State.Set("SqlQuery", FormulaValue.New(ApprovedQuery)); + this.State.InitializeSystem(); + this.State.Bind(); + + InvokeFunctionTool model = this.CreateModelWithVariableArgument( + displayName: nameof(InvokeFunctionToolCaptureResponseUsesApprovedArgumentsNotMutatedAsync), + functionName: FunctionName, + argumentKey: ArgumentKey, + variableName: "SqlQuery"); + + AIFunctionArguments? capturedArguments = null; + TestFunctionAgentProvider testAgentProvider = new( + [AIFunctionFactory.Create((string query) => $"executed:{query}", name: FunctionName)], + onInvokeArguments: args => capturedArguments = args); + InvokeFunctionToolExecutor action = new(model, testAgentProvider, this.State); + + // Act - trigger ExecuteAsync to store the approval snapshot + Mock mockContext = CreateMockWorkflowContext(); + await action.HandleAsync(new ActionExecutorResult(action.Id), mockContext.Object, CancellationToken.None); + + // Simulate parallel branch mutating state during the approval window + this.State.Set("SqlQuery", FormulaValue.New(MutatedQuery)); + this.State.Bind(); + + // User clicks approve + ExternalInputResponse response = CreateApprovalResponse(action.Id, approved: true); + + // Resume after approval + await action.CaptureResponseAsync(mockContext.Object, response, CancellationToken.None); + + // Assert - the originally-approved argument must be used, not the mutated one + Assert.NotNull(capturedArguments); + Assert.Equal(ApprovedQuery, capturedArguments[ArgumentKey]?.ToString()); + } + + /// + /// Verifies that the approval snapshot survives a checkpoint/restore cycle. + /// After restore, the originally-approved function must still be used even if state was mutated. + /// + [Fact] + public async Task InvokeFunctionToolCaptureResponseUsesSnapshotAfterCheckpointRestoreAsync() + { + // Arrange + const string ApprovedFunctionName = "safe_readonly_query"; + const string MutatedFunctionName = "dangerous_admin_tool"; + + this.State.Set("TargetFunction", FormulaValue.New(ApprovedFunctionName)); + this.State.InitializeSystem(); + this.State.Bind(); + + InvokeFunctionTool model = this.CreateModelWithVariableFunctionName( + displayName: nameof(InvokeFunctionToolCaptureResponseUsesSnapshotAfterCheckpointRestoreAsync), + variableName: "TargetFunction"); + + string? capturedFunctionName = null; + TestFunctionAgentProvider testAgentProvider = new( + [ + AIFunctionFactory.Create(() => "safe-result", name: ApprovedFunctionName), + AIFunctionFactory.Create(() => "dangerous-result", name: MutatedFunctionName), + ], + onInvoke: name => capturedFunctionName = name); + InvokeFunctionToolExecutor action = new(model, testAgentProvider, this.State); + + // Act - trigger ExecuteAsync to store the approval snapshot + Mock mockContext = CreateMockWorkflowContextWithStateStore(); + await action.HandleAsync(new ActionExecutorResult(action.Id), mockContext.Object, CancellationToken.None); + + // Simulate checkpoint: persist to state store + await InvokeProtectedMethodAsync(action, "OnCheckpointingAsync", mockContext.Object, CancellationToken.None); + + // Simulate restore on a "new" executor instance by clearing the in-memory field via reflection + // (In production, a new executor instance would be created with _approvalSnapshot == null) + typeof(InvokeFunctionToolExecutor) + .GetField("_approvalSnapshot", BindingFlags.NonPublic | BindingFlags.Instance)! + .SetValue(action, null); + + // Restore from state store + await InvokeProtectedMethodAsync(action, "OnCheckpointRestoredAsync", mockContext.Object, CancellationToken.None); + + // Mutate state after restore (simulating parallel branch) + this.State.Set("TargetFunction", FormulaValue.New(MutatedFunctionName)); + this.State.Bind(); + + // User clicks approve + ExternalInputResponse response = CreateApprovalResponse(action.Id, approved: true); + + // Resume after approval + await action.CaptureResponseAsync(mockContext.Object, response, CancellationToken.None); + + // Assert - the originally-approved function must be invoked, not the mutated one + Assert.NotNull(capturedFunctionName); + Assert.Equal(ApprovedFunctionName, capturedFunctionName); + } + + /// + /// Verifies that the approval snapshot is cleared after a completed approval cycle, + /// both in-memory and in the persisted state store. This prevents stale data from + /// influencing a subsequent execution of the same executor instance. + /// + [Fact] + public async Task InvokeFunctionToolCaptureResponseClearsSnapshotAfterCompletionAsync() + { + // Arrange + const string FunctionName = "any_function"; + + this.State.InitializeSystem(); + this.State.Bind(); + + InvokeFunctionTool model = this.CreateModel( + displayName: nameof(InvokeFunctionToolCaptureResponseClearsSnapshotAfterCompletionAsync), + functionName: FunctionName, + requireApproval: true); + + TestFunctionAgentProvider testAgentProvider = new( + [AIFunctionFactory.Create(() => "result", name: FunctionName)]); + InvokeFunctionToolExecutor action = new(model, testAgentProvider, this.State); + + // Act - run the full approval cycle + Dictionary stateStore = []; + Mock mockContext = CreateMockWorkflowContextWithStateStore(stateStore); + await action.HandleAsync(new ActionExecutorResult(action.Id), mockContext.Object, CancellationToken.None); + + // Sanity: snapshot was captured + FieldInfo snapshotField = typeof(InvokeFunctionToolExecutor) + .GetField("_approvalSnapshot", BindingFlags.NonPublic | BindingFlags.Instance)!; + Assert.NotNull(snapshotField.GetValue(action)); + + ExternalInputResponse response = CreateApprovalResponse(action.Id, approved: true); + await action.CaptureResponseAsync(mockContext.Object, response, CancellationToken.None); + + // Assert - both in-memory field and persisted state are cleared + Assert.Null(snapshotField.GetValue(action)); + Assert.True(stateStore.ContainsKey("_approvalSnapshot")); + Assert.Null(stateStore["_approvalSnapshot"]); + } + + private static ExternalInputResponse CreateApprovalResponse(string actionId, bool approved) + { + FunctionCallContent functionCall = new(callId: actionId, name: "ignored"); + ToolApprovalRequestContent approvalRequest = new(actionId, functionCall); + ToolApprovalResponseContent approvalResponse = approvalRequest.CreateResponse(approved); + return new ExternalInputResponse(new ChatMessage(ChatRole.User, [approvalResponse])); + } + + private static Mock CreateMockWorkflowContext() + { + Mock mockContext = new(); + mockContext.Setup(c => c.AddEventAsync(It.IsAny(), It.IsAny())) + .Returns(default(ValueTask)); + mockContext.Setup(c => c.QueueStateUpdateAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) + .Returns(default(ValueTask)); + mockContext.Setup(c => c.SendMessageAsync(It.IsAny(), It.IsAny(), It.IsAny())) + .Returns(default(ValueTask)); + return mockContext; + } + + /// + /// Creates a mock workflow context that actually stores state values (for checkpoint/restore tests). + /// Optionally accepts an externally-owned dictionary so callers can inspect the persisted state. + /// + private static Mock CreateMockWorkflowContextWithStateStore(Dictionary? stateStore = null) + { + stateStore ??= []; + Mock mockContext = new(); + mockContext.Setup(c => c.AddEventAsync(It.IsAny(), It.IsAny())) + .Returns(default(ValueTask)); + mockContext.Setup(c => c.QueueStateUpdateAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) + .Callback((key, value, _, _) => stateStore[key] = value) + .Returns(default(ValueTask)); + mockContext.Setup(c => c.SendMessageAsync(It.IsAny(), It.IsAny(), It.IsAny())) + .Returns(default(ValueTask)); + mockContext.Setup(c => c.ReadStateAsync(It.IsAny(), It.IsAny(), It.IsAny())) + .Returns((key, _, _) => + new ValueTask(stateStore.TryGetValue(key, out object? val) ? val as ApprovalSnapshot : null)); + mockContext.Setup(c => c.ReadStateKeysAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new HashSet()); + return mockContext; + } + + /// + /// Invokes a protected method on the executor via reflection (for testing checkpoint hooks). + /// + private static async ValueTask InvokeProtectedMethodAsync(InvokeFunctionToolExecutor action, string methodName, IWorkflowContext context, CancellationToken cancellationToken) + { + MethodInfo method = typeof(InvokeFunctionToolExecutor) + .GetMethod(methodName, BindingFlags.NonPublic | BindingFlags.Instance)!; + ValueTask result = (ValueTask)method.Invoke(action, [context, cancellationToken])!; + await result.ConfigureAwait(false); + } + + /// + /// Minimal concrete that exposes an injected + /// registry and records which function got invoked. + /// Used by the framework-invoke approval branch (InvokeRegisteredFunctionAsync). + /// + private sealed class TestFunctionAgentProvider : ResponseAgentProvider + { + private readonly Action? _onInvoke; + private readonly Action? _onInvokeArguments; + + public TestFunctionAgentProvider( + IEnumerable functions, + Action? onInvoke = null, + Action? onInvokeArguments = null) + { + this._onInvoke = onInvoke; + this._onInvokeArguments = onInvokeArguments; + this.Functions = functions.Select(f => (AIFunction)new RecordingAIFunction(f, this)).ToList(); + } + + internal void RecordInvocation(string name, AIFunctionArguments? arguments) + { + this._onInvoke?.Invoke(name); + if (arguments is not null) + { + this._onInvokeArguments?.Invoke(arguments); + } + } + + public override Task CreateConversationAsync(CancellationToken cancellationToken = default) => + throw new NotSupportedException(); + + public override Task CreateMessageAsync(string conversationId, ChatMessage conversationMessage, CancellationToken cancellationToken = default) => + throw new NotSupportedException(); + + public override Task GetMessageAsync(string conversationId, string messageId, CancellationToken cancellationToken = default) => + throw new NotSupportedException(); + + public override IAsyncEnumerable InvokeAgentAsync( + string agentId, string? agentVersion, string? conversationId, + IEnumerable? messages, IDictionary? inputArguments, + CancellationToken cancellationToken = default) => + throw new NotSupportedException(); + + public override IAsyncEnumerable GetMessagesAsync( + string conversationId, int? limit = null, string? after = null, string? before = null, + bool newestFirst = false, CancellationToken cancellationToken = default) => + throw new NotSupportedException(); + + private sealed class RecordingAIFunction(AIFunction inner, TestFunctionAgentProvider owner) : AIFunction + { + public override string Name => inner.Name; + public override string Description => inner.Description; + public override JsonElement JsonSchema => inner.JsonSchema; + + protected override ValueTask InvokeCoreAsync(AIFunctionArguments arguments, CancellationToken cancellationToken) + { + owner.RecordInvocation(inner.Name, arguments); + return inner.InvokeAsync(arguments, cancellationToken); + } + } + } + + #endregion + #region Helper Methods private async Task ExecuteTestAsync(InvokeFunctionTool model) @@ -318,5 +645,33 @@ private InvokeFunctionTool CreateModel( return AssignParent(builder); } + private InvokeFunctionTool CreateModelWithVariableFunctionName(string displayName, string variableName) + { + InvokeFunctionTool.Builder builder = new() + { + Id = this.CreateActionId(), + DisplayName = this.FormatDisplayName(displayName), + FunctionName = new StringExpression.Builder( + StringExpression.Variable(PropertyPath.TopicVariable(variableName))), + RequireApproval = new BoolExpression.Builder(BoolExpression.Literal(true)), + }; + return AssignParent(builder); + } + + private InvokeFunctionTool CreateModelWithVariableArgument( + string displayName, string functionName, string argumentKey, string variableName) + { + InvokeFunctionTool.Builder builder = new() + { + Id = this.CreateActionId(), + DisplayName = this.FormatDisplayName(displayName), + FunctionName = new StringExpression.Builder(StringExpression.Literal(functionName)), + RequireApproval = new BoolExpression.Builder(BoolExpression.Literal(true)), + }; + builder.Arguments.Add(argumentKey, + ValueExpression.Variable(PropertyPath.TopicVariable(variableName))); + return AssignParent(builder); + } + #endregion } From 7cf215b8e9e72a472595995f7c6c9cb9482e0f59 Mon Sep 17 00:00:00 2001 From: Peter Ibekwe Date: Tue, 9 Jun 2026 17:37:48 -0700 Subject: [PATCH 2/5] Update mcp handler cache construction --- .../DefaultMcpToolHandler.cs | 78 ++++++-- .../DefaultMcpToolHandlerTests.cs | 183 ++++++++++++++++++ 2 files changed, 243 insertions(+), 18 deletions(-) diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative.Mcp/DefaultMcpToolHandler.cs b/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative.Mcp/DefaultMcpToolHandler.cs index c133b38bbdd..46e405244a9 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative.Mcp/DefaultMcpToolHandler.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative.Mcp/DefaultMcpToolHandler.cs @@ -2,10 +2,10 @@ using System; using System.Collections.Generic; -using System.Globalization; using System.IO; using System.Linq; using System.Net.Http; +using System.Security.Cryptography; using System.Text; using System.Text.Json; using System.Threading; @@ -39,7 +39,7 @@ public sealed class DefaultMcpToolHandler : IMcpToolHandler, IAsyncDisposable private static readonly JsonWriterOptions s_toolListJsonWriterOptions = new() { Indented = true }; private readonly Func>? _httpClientProvider; - private readonly Dictionary _clients = []; + private readonly Dictionary<(string Url, string Label, string Connection, string HeadersHash), McpClient> _clients = []; private readonly Dictionary _ownedHttpClients = []; private readonly SemaphoreSlim _clientLock = new(1, 1); @@ -66,16 +66,15 @@ public async Task InvokeToolAsync( string? connectionName, CancellationToken cancellationToken = default) { - // TODO: Handle connectionName and server label appropriately when Hosted scenario supports them. For now, ignore if (IsListToolsToolName(toolName)) { ThrowIfListToolsArgumentsSpecified(arguments); - McpClient listToolsClient = await this.GetOrCreateClientAsync(serverUrl, serverLabel, headers, cancellationToken).ConfigureAwait(false); + McpClient listToolsClient = await this.GetOrCreateClientAsync(serverUrl, serverLabel, headers, connectionName, cancellationToken).ConfigureAwait(false); IList tools = await listToolsClient.ListToolsAsync(cancellationToken: cancellationToken).ConfigureAwait(false); return CreateListToolsResultContent(tools.Select(tool => tool.ProtocolTool)); } - McpClient client = await this.GetOrCreateClientAsync(serverUrl, serverLabel, headers, cancellationToken).ConfigureAwait(false); + McpClient client = await this.GetOrCreateClientAsync(serverUrl, serverLabel, headers, connectionName, cancellationToken).ConfigureAwait(false); McpServerToolResultContent resultContent = new(Guid.NewGuid().ToString()); @@ -145,10 +144,11 @@ private async Task GetOrCreateClientAsync( string serverUrl, string? serverLabel, IDictionary? headers, + string? connectionName, CancellationToken cancellationToken) { - string normalizedUrl = serverUrl.Trim().ToUpperInvariant(); - string clientCacheKey = $"{normalizedUrl}|{ComputeHeadersHash(headers)}"; + string trimmedUrl = serverUrl.Trim(); + var clientCacheKey = BuildCacheKey(trimmedUrl, serverLabel, connectionName, headers); await this._clientLock.WaitAsync(cancellationToken).ConfigureAwait(false); try @@ -158,7 +158,7 @@ private async Task GetOrCreateClientAsync( return existingClient; } - McpClient newClient = await this.CreateClientAsync(serverUrl, serverLabel, headers, normalizedUrl, cancellationToken).ConfigureAwait(false); + McpClient newClient = await this.CreateClientAsync(serverUrl, serverLabel, headers, trimmedUrl, cancellationToken).ConfigureAwait(false); this._clients[clientCacheKey] = newClient; return newClient; } @@ -168,6 +168,19 @@ private async Task GetOrCreateClientAsync( } } + /// + /// Builds the per-client cache key as a 4-tuple of + /// (trimmed serverUrl, serverLabel, connectionName, headers hash). All four components + /// participate so that callers using different labels/connections/headers receive + /// distinct instances even when targeting the same URL. + /// + internal static (string Url, string Label, string Connection, string HeadersHash) BuildCacheKey( + string trimmedUrl, + string? serverLabel, + string? connectionName, + IDictionary? headers) => + (trimmedUrl, serverLabel ?? string.Empty, connectionName ?? string.Empty, ComputeHeadersHash(headers)); + private async Task CreateClientAsync( string serverUrl, string? serverLabel, @@ -185,7 +198,12 @@ private async Task CreateClientAsync( if (httpClient is null && !this._ownedHttpClients.TryGetValue(httpClientCacheKey, out httpClient)) { - httpClient = new HttpClient(); + // Disable cookies so handler-level state (cookie jar) cannot cross the cache-key + // isolation boundary established by GetOrCreateClientAsync. The actual MCP auth + // travels via AdditionalHeaders (set per-transport below), not session cookies. + // CheckCertificateRevocationList satisfies CA5399 since we're explicitly constructing the handler. + HttpClientHandler handler = new() { UseCookies = false, CheckCertificateRevocationList = true }; + httpClient = new HttpClient(handler); this._ownedHttpClients[httpClientCacheKey] = httpClient; } @@ -202,26 +220,50 @@ private async Task CreateClientAsync( return await McpClient.CreateAsync(transport, cancellationToken: cancellationToken).ConfigureAwait(false); } - private static string ComputeHeadersHash(IDictionary? headers) + /// + /// Computes a deterministic, order-independent hash of the header set. + /// Header names are lower-cased for case-insensitive matching (RFC 7230 §3.2). + /// Header values remain case-sensitive (RFC 7235 — credentials are case-sensitive). + /// +#pragma warning disable CA1308 // RFC 7230 §3.2 requires lower-cased header names for case-insensitive comparison; CA1308's uppercase preference does not apply here + internal static string ComputeHeadersHash(IDictionary? headers) { if (headers is null || headers.Count == 0) { return string.Empty; } - // Build a deterministic, sorted representation of the headers - // Within a single process lifetime, the hashcodes are consistent. - // This will ensure that the same set of headers always produces the same hash, regardless of order. - SortedDictionary sorted = new(headers.ToDictionary(h => h.Key.ToUpperInvariant(), h => h.Value.ToUpperInvariant())); - int hashCode = 17; + // Sort by lower-cased key for deterministic ordering, preserving value case. + SortedDictionary sorted = new(StringComparer.Ordinal); + foreach (KeyValuePair header in headers) + { + sorted[header.Key.ToLowerInvariant()] = header.Value; + } + + StringBuilder payload = new(); foreach (KeyValuePair kvp in sorted) { - hashCode = (hashCode * 31) + StringComparer.OrdinalIgnoreCase.GetHashCode(kvp.Key); - hashCode = (hashCode * 31) + StringComparer.OrdinalIgnoreCase.GetHashCode(kvp.Value); + payload.Append(kvp.Key).Append(':').Append(kvp.Value).Append('\n'); + } + + byte[] inputBytes = Encoding.UTF8.GetBytes(payload.ToString()); +#if NET5_0_OR_GREATER + byte[] hashBytes = SHA256.HashData(inputBytes); +#else + using SHA256 sha256 = SHA256.Create(); + byte[] hashBytes = sha256.ComputeHash(inputBytes); +#endif + + // Convert to hex string (compatible with net472/netstandard2.0) + StringBuilder hex = new(hashBytes.Length * 2); + foreach (byte b in hashBytes) + { + hex.Append(b.ToString("X2", System.Globalization.CultureInfo.InvariantCulture)); } - return hashCode.ToString(CultureInfo.InvariantCulture); + return hex.ToString(); } +#pragma warning restore CA1308 private static void ThrowIfListToolsArgumentsSpecified(IDictionary? arguments) { diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.Mcp.UnitTests/DefaultMcpToolHandlerTests.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.Mcp.UnitTests/DefaultMcpToolHandlerTests.cs index 1327c3df48f..1470e4f0f56 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.Mcp.UnitTests/DefaultMcpToolHandlerTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.Mcp.UnitTests/DefaultMcpToolHandlerTests.cs @@ -321,6 +321,189 @@ await handler.InvokeToolAsync( #endregion + #region ComputeHeadersHash Tests + + [Fact] + public void ComputeHeadersHash_WithNullHeaders_ReturnsEmptyString() + { + // Act + string result = DefaultMcpToolHandler.ComputeHeadersHash(null); + + // Assert + result.Should().BeEmpty(); + } + + [Fact] + public void ComputeHeadersHash_WithEmptyHeaders_ReturnsEmptyString() + { + // Act + string result = DefaultMcpToolHandler.ComputeHeadersHash(new Dictionary()); + + // Assert + result.Should().BeEmpty(); + } + + [Fact] + public void ComputeHeadersHash_SameHeadersDifferentOrder_ReturnsSameHash() + { + // Arrange + Dictionary headers1 = new() + { + ["Authorization"] = "Bearer token123", + ["X-Custom"] = "value1" + }; + Dictionary headers2 = new() + { + ["X-Custom"] = "value1", + ["Authorization"] = "Bearer token123" + }; + + // Act + string hash1 = DefaultMcpToolHandler.ComputeHeadersHash(headers1); + string hash2 = DefaultMcpToolHandler.ComputeHeadersHash(headers2); + + // Assert + hash1.Should().Be(hash2); + } + + [Fact] + public void ComputeHeadersHash_SameKeysDifferentCaseKeys_ReturnsSameHash() + { + // Arrange — RFC 7230: header names are case-insensitive + Dictionary headers1 = new() { ["Authorization"] = "Bearer token" }; + Dictionary headers2 = new() { ["authorization"] = "Bearer token" }; + + // Act + string hash1 = DefaultMcpToolHandler.ComputeHeadersHash(headers1); + string hash2 = DefaultMcpToolHandler.ComputeHeadersHash(headers2); + + // Assert + hash1.Should().Be(hash2); + } + + [Fact] + public void ComputeHeadersHash_SameKeysDifferentCaseValues_ReturnsDifferentHash() + { + // Arrange — RFC 7235: credentials are case-sensitive + Dictionary headers1 = new() { ["Authorization"] = "Bearer ABC" }; + Dictionary headers2 = new() { ["Authorization"] = "Bearer abc" }; + + // Act + string hash1 = DefaultMcpToolHandler.ComputeHeadersHash(headers1); + string hash2 = DefaultMcpToolHandler.ComputeHeadersHash(headers2); + + // Assert + hash1.Should().NotBe(hash2); + } + + [Fact] + public void ComputeHeadersHash_DifferentHeaders_ReturnsDifferentHash() + { + // Arrange + Dictionary headers1 = new() { ["Authorization"] = "Bearer token1" }; + Dictionary headers2 = new() { ["Authorization"] = "Bearer token2" }; + + // Act + string hash1 = DefaultMcpToolHandler.ComputeHeadersHash(headers1); + string hash2 = DefaultMcpToolHandler.ComputeHeadersHash(headers2); + + // Assert + hash1.Should().NotBe(hash2); + } + + #endregion + + #region Cache Key Discrimination Tests + + // These tests exercise BuildCacheKey directly because the integration path + // (InvokeToolAsync against a fake server) doesn't surface cache-hit behavior + // without standing up a real MCP server — McpClient.CreateAsync fails before + // _clients[key] = newClient runs, so nothing ever gets cached. + // Tuple equality on the returned 4-tuple verifies that the dimensions + // collectively discriminate cache entries. + + [Fact] + public void BuildCacheKey_SameInputs_ReturnsEqualKeys() + { + // Arrange + Dictionary headers = new() { ["Authorization"] = "Bearer token" }; + + // Act + var key1 = DefaultMcpToolHandler.BuildCacheKey("http://localhost/mcp", "label", "conn", headers); + var key2 = DefaultMcpToolHandler.BuildCacheKey("http://localhost/mcp", "label", "conn", headers); + + // Assert + key1.Should().Be(key2); + } + + [Fact] + public void BuildCacheKey_DifferentConnectionName_ReturnsDifferentKeys() + { + // Act + var key1 = DefaultMcpToolHandler.BuildCacheKey("http://localhost/mcp", "label", "connection-a", null); + var key2 = DefaultMcpToolHandler.BuildCacheKey("http://localhost/mcp", "label", "connection-b", null); + + // Assert + key1.Should().NotBe(key2); + key1.Connection.Should().Be("connection-a"); + key2.Connection.Should().Be("connection-b"); + } + + [Fact] + public void BuildCacheKey_DifferentServerLabel_ReturnsDifferentKeys() + { + // Act + var key1 = DefaultMcpToolHandler.BuildCacheKey("http://localhost/mcp", "label-a", null, null); + var key2 = DefaultMcpToolHandler.BuildCacheKey("http://localhost/mcp", "label-b", null, null); + + // Assert + key1.Should().NotBe(key2); + key1.Label.Should().Be("label-a"); + key2.Label.Should().Be("label-b"); + } + + [Fact] + public void BuildCacheKey_CaseSensitiveUrlPath_ReturnsDifferentKeys() + { + // Arrange — RFC 3986: URL path is case-sensitive + // Act + var key1 = DefaultMcpToolHandler.BuildCacheKey("http://localhost/Tools", null, null, null); + var key2 = DefaultMcpToolHandler.BuildCacheKey("http://localhost/tools", null, null, null); + + // Assert + key1.Should().NotBe(key2); + } + + [Fact] + public void BuildCacheKey_HeaderValuesCaseSensitive_ReturnsDifferentKeys() + { + // Arrange — RFC 7235: credentials are case-sensitive + Dictionary headers1 = new() { ["Authorization"] = "Bearer ABC" }; + Dictionary headers2 = new() { ["Authorization"] = "Bearer abc" }; + + // Act + var key1 = DefaultMcpToolHandler.BuildCacheKey("http://localhost/mcp", null, null, headers1); + var key2 = DefaultMcpToolHandler.BuildCacheKey("http://localhost/mcp", null, null, headers2); + + // Assert — header value case must propagate into the cache key + key1.Should().NotBe(key2); + key1.HeadersHash.Should().NotBe(key2.HeadersHash); + } + + [Fact] + public void BuildCacheKey_NullLabelAndConnection_NormalizesToEmptyString() + { + // Act + var key = DefaultMcpToolHandler.BuildCacheKey("http://localhost/mcp", null, null, null); + + // Assert — verifies null-safety contract callers rely on + key.Label.Should().BeEmpty(); + key.Connection.Should().BeEmpty(); + key.HeadersHash.Should().BeEmpty(); + } + + #endregion + #region Reserved Tools/List Tests [Fact] From ea8c35296caaef55f1c69b2b969240beed0db38b Mon Sep 17 00:00:00 2001 From: Peter Ibekwe Date: Tue, 9 Jun 2026 18:03:29 -0700 Subject: [PATCH 3/5] fix method argument. --- .../DefaultMcpToolHandler.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative.Mcp/DefaultMcpToolHandler.cs b/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative.Mcp/DefaultMcpToolHandler.cs index 46e405244a9..8b0a410d8bd 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative.Mcp/DefaultMcpToolHandler.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative.Mcp/DefaultMcpToolHandler.cs @@ -158,7 +158,7 @@ private async Task GetOrCreateClientAsync( return existingClient; } - McpClient newClient = await this.CreateClientAsync(serverUrl, serverLabel, headers, trimmedUrl, cancellationToken).ConfigureAwait(false); + McpClient newClient = await this.CreateClientAsync(trimmedUrl, serverLabel, headers, trimmedUrl, cancellationToken).ConfigureAwait(false); this._clients[clientCacheKey] = newClient; return newClient; } From fe04a3ac769d287b72483c6deedcc99037f03128 Mon Sep 17 00:00:00 2001 From: Peter Ibekwe <109177538+peibekwe@users.noreply.github.com> Date: Wed, 10 Jun 2026 08:49:32 -0700 Subject: [PATCH 4/5] Update dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/ObjectModel/InvokeFunctionToolExecutor.cs Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- .../ObjectModel/InvokeFunctionToolExecutor.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/ObjectModel/InvokeFunctionToolExecutor.cs b/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/ObjectModel/InvokeFunctionToolExecutor.cs index 7de174404ac..8554b151821 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/ObjectModel/InvokeFunctionToolExecutor.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/ObjectModel/InvokeFunctionToolExecutor.cs @@ -329,7 +329,7 @@ private string GetFunctionName() => }; } - AIFunctionArguments? functionArguments = arguments is null ? null : new AIFunctionArguments(arguments); +AIFunctionArguments? functionArguments = arguments is null ? null : new AIFunctionArguments(arguments.NormalizePortableValues()); object? result; try From b4db4c8f02078263a6b403e32d6d769a2a9dc3c9 Mon Sep 17 00:00:00 2001 From: Peter Ibekwe Date: Wed, 10 Jun 2026 09:09:03 -0700 Subject: [PATCH 5/5] Fix identation --- .../ObjectModel/InvokeFunctionToolExecutor.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/ObjectModel/InvokeFunctionToolExecutor.cs b/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/ObjectModel/InvokeFunctionToolExecutor.cs index 8554b151821..08d57a6b6e6 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/ObjectModel/InvokeFunctionToolExecutor.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/ObjectModel/InvokeFunctionToolExecutor.cs @@ -329,7 +329,7 @@ private string GetFunctionName() => }; } -AIFunctionArguments? functionArguments = arguments is null ? null : new AIFunctionArguments(arguments.NormalizePortableValues()); + AIFunctionArguments? functionArguments = arguments is null ? null : new AIFunctionArguments(arguments.NormalizePortableValues()); object? result; try