Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

using System;
using System.Collections.Generic;
using System.Globalization;
using System.IO;
using System.Linq;
using System.Net.Http;
using System.Security.Cryptography;
using System.Text;
using System.Text.Json;
using System.Threading;
Expand Down Expand Up @@ -39,7 +39,7 @@ public sealed class DefaultMcpToolHandler : IMcpToolHandler, IAsyncDisposable
private static readonly JsonWriterOptions s_toolListJsonWriterOptions = new() { Indented = true };

private readonly Func<string, CancellationToken, Task<HttpClient?>>? _httpClientProvider;
private readonly Dictionary<string, McpClient> _clients = [];
private readonly Dictionary<(string Url, string Label, string Connection, string HeadersHash), McpClient> _clients = [];
private readonly Dictionary<string, HttpClient> _ownedHttpClients = [];
private readonly SemaphoreSlim _clientLock = new(1, 1);

Expand All @@ -66,16 +66,15 @@ public async Task<McpServerToolResultContent> InvokeToolAsync(
string? connectionName,
CancellationToken cancellationToken = default)
{
// TODO: Handle connectionName and server label appropriately when Hosted scenario supports them. For now, ignore
if (IsListToolsToolName(toolName))
{
ThrowIfListToolsArgumentsSpecified(arguments);
McpClient listToolsClient = await this.GetOrCreateClientAsync(serverUrl, serverLabel, headers, cancellationToken).ConfigureAwait(false);
McpClient listToolsClient = await this.GetOrCreateClientAsync(serverUrl, serverLabel, headers, connectionName, cancellationToken).ConfigureAwait(false);
IList<McpClientTool> tools = await listToolsClient.ListToolsAsync(cancellationToken: cancellationToken).ConfigureAwait(false);
return CreateListToolsResultContent(tools.Select(tool => tool.ProtocolTool));
}

McpClient client = await this.GetOrCreateClientAsync(serverUrl, serverLabel, headers, cancellationToken).ConfigureAwait(false);
McpClient client = await this.GetOrCreateClientAsync(serverUrl, serverLabel, headers, connectionName, cancellationToken).ConfigureAwait(false);

McpServerToolResultContent resultContent = new(Guid.NewGuid().ToString());

Expand Down Expand Up @@ -145,10 +144,11 @@ private async Task<McpClient> GetOrCreateClientAsync(
string serverUrl,
string? serverLabel,
IDictionary<string, string>? headers,
string? connectionName,
CancellationToken cancellationToken)
{
string normalizedUrl = serverUrl.Trim().ToUpperInvariant();
string clientCacheKey = $"{normalizedUrl}|{ComputeHeadersHash(headers)}";
string trimmedUrl = serverUrl.Trim();
var clientCacheKey = BuildCacheKey(trimmedUrl, serverLabel, connectionName, headers);

await this._clientLock.WaitAsync(cancellationToken).ConfigureAwait(false);
try
Expand All @@ -158,7 +158,7 @@ private async Task<McpClient> GetOrCreateClientAsync(
return existingClient;
}

McpClient newClient = await this.CreateClientAsync(serverUrl, serverLabel, headers, normalizedUrl, cancellationToken).ConfigureAwait(false);
McpClient newClient = await this.CreateClientAsync(trimmedUrl, serverLabel, headers, trimmedUrl, cancellationToken).ConfigureAwait(false);
this._clients[clientCacheKey] = newClient;
return newClient;
}
Expand All @@ -168,6 +168,19 @@ private async Task<McpClient> GetOrCreateClientAsync(
}
}

/// <summary>
/// Builds the per-client cache key as a 4-tuple of
/// (trimmed serverUrl, serverLabel, connectionName, headers hash). All four components
/// participate so that callers using different labels/connections/headers receive
/// distinct <see cref="McpClient"/> instances even when targeting the same URL.
/// </summary>
internal static (string Url, string Label, string Connection, string HeadersHash) BuildCacheKey(
string trimmedUrl,
string? serverLabel,
string? connectionName,
IDictionary<string, string>? headers) =>
(trimmedUrl, serverLabel ?? string.Empty, connectionName ?? string.Empty, ComputeHeadersHash(headers));

private async Task<McpClient> CreateClientAsync(
string serverUrl,
string? serverLabel,
Expand All @@ -185,7 +198,12 @@ private async Task<McpClient> CreateClientAsync(

if (httpClient is null && !this._ownedHttpClients.TryGetValue(httpClientCacheKey, out httpClient))
{
httpClient = new HttpClient();
// Disable cookies so handler-level state (cookie jar) cannot cross the cache-key
// isolation boundary established by GetOrCreateClientAsync. The actual MCP auth
// travels via AdditionalHeaders (set per-transport below), not session cookies.
// CheckCertificateRevocationList satisfies CA5399 since we're explicitly constructing the handler.
HttpClientHandler handler = new() { UseCookies = false, CheckCertificateRevocationList = true };
httpClient = new HttpClient(handler);
this._ownedHttpClients[httpClientCacheKey] = httpClient;
}

Expand All @@ -202,26 +220,50 @@ private async Task<McpClient> CreateClientAsync(
return await McpClient.CreateAsync(transport, cancellationToken: cancellationToken).ConfigureAwait(false);
}

private static string ComputeHeadersHash(IDictionary<string, string>? headers)
/// <summary>
/// Computes a deterministic, order-independent hash of the header set.
/// Header names are lower-cased for case-insensitive matching (RFC 7230 §3.2).
/// Header values remain case-sensitive (RFC 7235 — credentials are case-sensitive).
/// </summary>
#pragma warning disable CA1308 // RFC 7230 §3.2 requires lower-cased header names for case-insensitive comparison; CA1308's uppercase preference does not apply here
internal static string ComputeHeadersHash(IDictionary<string, string>? headers)
{
if (headers is null || headers.Count == 0)
{
return string.Empty;
}

// Build a deterministic, sorted representation of the headers
// Within a single process lifetime, the hashcodes are consistent.
// This will ensure that the same set of headers always produces the same hash, regardless of order.
SortedDictionary<string, string> sorted = new(headers.ToDictionary(h => h.Key.ToUpperInvariant(), h => h.Value.ToUpperInvariant()));
int hashCode = 17;
// Sort by lower-cased key for deterministic ordering, preserving value case.
SortedDictionary<string, string> sorted = new(StringComparer.Ordinal);
foreach (KeyValuePair<string, string> header in headers)
{
sorted[header.Key.ToLowerInvariant()] = header.Value;
}

StringBuilder payload = new();
foreach (KeyValuePair<string, string> kvp in sorted)
{
hashCode = (hashCode * 31) + StringComparer.OrdinalIgnoreCase.GetHashCode(kvp.Key);
hashCode = (hashCode * 31) + StringComparer.OrdinalIgnoreCase.GetHashCode(kvp.Value);
payload.Append(kvp.Key).Append(':').Append(kvp.Value).Append('\n');
}

byte[] inputBytes = Encoding.UTF8.GetBytes(payload.ToString());
#if NET5_0_OR_GREATER
byte[] hashBytes = SHA256.HashData(inputBytes);
#else
using SHA256 sha256 = SHA256.Create();
byte[] hashBytes = sha256.ComputeHash(inputBytes);
#endif

// Convert to hex string (compatible with net472/netstandard2.0)
StringBuilder hex = new(hashBytes.Length * 2);
foreach (byte b in hashBytes)
{
hex.Append(b.ToString("X2", System.Globalization.CultureInfo.InvariantCulture));
}

return hashCode.ToString(CultureInfo.InvariantCulture);
return hex.ToString();
}
#pragma warning restore CA1308

private static void ThrowIfListToolsArgumentsSpecified(IDictionary<string, object?>? arguments)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
using Microsoft.Agents.AI.Workflows.Declarative.PowerFx;
using Microsoft.Agents.ObjectModel;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.Logging;
using Microsoft.Shared.Diagnostics;

namespace Microsoft.Agents.AI.Workflows.Declarative.ObjectModel;
Expand All @@ -27,6 +28,13 @@ internal sealed class InvokeFunctionToolExecutor(
WorkflowFormulaState state) :
DeclarativeActionExecutor<InvokeFunctionTool>(model, state)
{
private const string ApprovalSnapshotStateKey = nameof(_approvalSnapshot);

/// <summary>
/// Snapshot of evaluated parameters at approval-request time.
/// </summary>
private ApprovalSnapshot? _approvalSnapshot;

/// <summary>
/// Step identifiers for the function tool invocation workflow.
/// </summary>
Expand Down Expand Up @@ -69,6 +77,10 @@ public static class Steps
// If approval is required, add user input request content
if (requireApproval)
{
// Snapshot the evaluated parameters.
// If state mutates during the approval window, the approved values are used on resume.
this._approvalSnapshot = new ApprovalSnapshot(functionName, arguments);

requestMessage.Contents.Add(new ToolApprovalRequestContent(this.Id, functionCall));
}

Expand Down Expand Up @@ -155,6 +167,31 @@ public async ValueTask CaptureResponseAsync(

// Completes the action after processing the function result.
await context.RaiseCompletionEventAsync(this.Model, cancellationToken).ConfigureAwait(false);

// Clear the approval snapshot after the action completes so a subsequent
// execution of the same executor instance doesn't reuse stale data.
this._approvalSnapshot = null;
await context.QueueStateUpdateAsync<ApprovalSnapshot?>(ApprovalSnapshotStateKey, null, null, cancellationToken).ConfigureAwait(false);
}

/// <inheritdoc/>
/// <remarks>
/// Persists the approval snapshot to workflow state so it survives checkpoint/restore cycles.
/// </remarks>
protected override async ValueTask OnCheckpointingAsync(IWorkflowContext context, CancellationToken cancellationToken = default)
{
await context.QueueStateUpdateAsync(ApprovalSnapshotStateKey, this._approvalSnapshot, null, cancellationToken).ConfigureAwait(false);
await base.OnCheckpointingAsync(context, cancellationToken).ConfigureAwait(false);
}

/// <inheritdoc/>
/// <remarks>
/// Restores the approval snapshot from workflow state after a checkpoint restore.
/// </remarks>
protected override async ValueTask OnCheckpointRestoredAsync(IWorkflowContext context, CancellationToken cancellationToken = default)
{
await base.OnCheckpointRestoredAsync(context, cancellationToken).ConfigureAwait(false);
this._approvalSnapshot = await context.ReadStateAsync<ApprovalSnapshot>(ApprovalSnapshotStateKey, null, cancellationToken).ConfigureAwait(false);
}

/// <summary>
Expand Down Expand Up @@ -262,7 +299,24 @@ private string GetFunctionName() =>

private async ValueTask<FunctionResultContent?> InvokeRegisteredFunctionAsync(CancellationToken cancellationToken)
{
string functionName = this.GetFunctionName();
string functionName;
Dictionary<string, object?>? arguments;

if (this._approvalSnapshot is { } snapshot)
{
// Use the snapshot captured at approval-request time so we invoke exactly what
// the user approved, even if Power Fx state has mutated during the approval window.
functionName = snapshot.FunctionName;
arguments = snapshot.Arguments;
}
else
{
// Fallback for checkpoints created before approval snapshots were introduced.
this.Logger.LogWarning("Approval snapshot missing for '{ActionId}'; falling back to expression re-evaluation.", this.Id);
functionName = this.GetFunctionName();
arguments = this.GetArguments();
}

AIFunction? function = agentProvider.Functions?.FirstOrDefault(
f => string.Equals(f.Name, functionName, StringComparison.Ordinal));

Expand All @@ -275,8 +329,7 @@ private string GetFunctionName() =>
};
}

Dictionary<string, object?>? arguments = this.GetArguments();
AIFunctionArguments? functionArguments = arguments is null ? null : new AIFunctionArguments(arguments);
AIFunctionArguments? functionArguments = arguments is null ? null : new AIFunctionArguments(arguments.NormalizePortableValues());

object? result;
try
Expand Down Expand Up @@ -341,4 +394,13 @@ private bool GetAutoSendValue()

return result;
}

/// <summary>
/// Stores the evaluated parameters at approval-request time so that
/// <see cref="CaptureResponseAsync"/> uses the values the user reviewed,
/// even if <see cref="WorkflowFormulaState"/> mutates during the approval window.
/// </summary>
internal sealed record ApprovalSnapshot(
string FunctionName,
Dictionary<string, object?>? Arguments);
}
Loading
Loading