Skip to content

.NET: [Bug]: Approval middleware surfaces non-approval functions (GetDateTime) as approval requests #6264

@helloxubo

Description

@helloxubo

Description

We register GetDateTime as a normal tool on the agent, and register GetWeather as a per-request tool wrapped in ApprovalRequiredAIFunction so that it requires approval. However, when ConsolePromptingApprovalMiddleware inspects the run result for a prompt that asks for both the time and the weather, it receives ToolApprovalRequestContent entries for both GetDateTime and GetWeather. Only GetWeather should require approval.

Code Sample

csharp
// Decorate the chat client with two logging middlewares (A and B) and disable the model's
// "thinking" mode by patching the raw OpenAI request options created for every call.
chatClient = chatClient
    .AsBuilder()
    .Use(getResponseFunc: ChatClientMiddlewareA, getStreamingResponseFunc: null)
    .Use(getResponseFunc: ChatClientMiddlewareB, getStreamingResponseFunc: null)
    .ConfigureOptions(chatOptions =>
    {
        chatOptions.RawRepresentationFactory = _ =>
        {
            var completionOptions = new OpenAI.Chat.ChatCompletionOptions();
#pragma warning disable SCME0001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed.
            completionOptions.Patch.Set("$thinking.type"u8, "disabled");
#pragma warning restore SCME0001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed.
            return completionOptions;
        };
    })
    .Build();

// Base agent. GetDateTime is registered here as an ordinary tool (no approval wrapper).
var originalAgent = chatClient
    .AsBuilder()
    .BuildAIAgent(
        instructions: "You are an AI assistant that helps people find information.",
        tools: [AIFunctionFactory.Create(GetDateTime, name: nameof(GetDateTime))]);

// Wrap the agent with the function-invocation and agent-run middlewares, in registration order.
var middlewareAgent = originalAgent
    .AsBuilder()
    .Use(FunctionCallMiddlewareA)
    .Use(FunctionCallMiddlewareB)
    .Use(PIIMiddleware, null)
    .Use(GuardrailMiddleware, null)
    .Use(PerRequestFunctionCallingMiddleware)
    .Use(ConsolePromptingApprovalMiddleware, null)
    .Build();

// Per-run options: GetWeather is added only for this run, wrapped in ApprovalRequiredAIFunction.
var optionsWithApproval = new ChatClientAgentRunOptions(new()
{
    Tools = [new ApprovalRequiredAIFunction(AIFunctionFactory.Create(GetWeather, name: nameof(GetWeather)))],
})
{
    // Only a non-streaming per-request middleware is supplied; it is relied on to handle
    // streaming calls as well.
    ChatClientFactory = baseClient => baseClient
        .AsBuilder()
        .Use(PerRequestChatClientMiddleware, null)
        .Build(),
};

// The prompt asks for both the current time (GetDateTime) and the weather (GetWeather).
var response = await middlewareAgent.RunAsync(
    "What's the current time and the weather in Seattle?",
    options: optionsWithApproval,
    cancellationToken: default);


async Task<ChatResponse> ChatClientMiddlewareA(IEnumerable<ChatMessage> message, ChatOptions? options, IChatClient innerChatClient, CancellationToken cancellationToken)
{
    // Chat-client middleware "A": logs before and after delegating to the inner client and
    // returns the inner client's response unchanged.
    Console.WriteLine("Chat Client Middleware A - Pre-Chat");

    ChatResponse chatResponse = await innerChatClient.GetResponseAsync(message, options, cancellationToken);

    Console.WriteLine("Chat Client Middleware A - Post-Chat");
    return chatResponse;
}

async Task<ChatResponse> ChatClientMiddlewareB(IEnumerable<ChatMessage> message, ChatOptions? options, IChatClient innerChatClient, CancellationToken cancellationToken)
{
    // Chat-client middleware "B": logs before and after delegating to the inner client and
    // returns the inner client's response unchanged.
    Console.WriteLine("Chat Client Middleware B - Pre-Chat");

    ChatResponse chatResponse = await innerChatClient.GetResponseAsync(message, options, cancellationToken);

    Console.WriteLine("Chat Client Middleware B - Post-Chat");
    return chatResponse;
}

async ValueTask<object?> FunctionCallMiddlewareA(AIAgent agent, FunctionInvocationContext context, Func<FunctionInvocationContext, CancellationToken, ValueTask<object?>> next, CancellationToken cancellationToken)
{
    // Function-invocation middleware "A": logs the function name around the call to `next`
    // and passes the invocation result through unchanged. `agent` is not used.
    Console.WriteLine($"Function Name: {context.Function.Name} - Middleware A Pre-Invoke");

    object? invocationResult = await next(context, cancellationToken);

    // The name is read again after `next`, exactly like before the call.
    Console.WriteLine($"Function Name: {context.Function.Name} - Middleware A Post-Invoke");
    return invocationResult;
}

async ValueTask<object?> FunctionCallMiddlewareB(AIAgent agent, FunctionInvocationContext context, Func<FunctionInvocationContext, CancellationToken, ValueTask<object?>> next, CancellationToken cancellationToken)
{
    // Function-invocation middleware "B": logs the function name around the call to `next`
    // and passes the invocation result through unchanged. `agent` is not used.
    Console.WriteLine($"Function Name: {context.Function.Name} - Middleware B Pre-Invoke");

    object? invocationResult = await next(context, cancellationToken);

    // The name is read again after `next`, exactly like before the call.
    Console.WriteLine($"Function Name: {context.Function.Name} - Middleware B Post-Invoke");
    return invocationResult;
}

async Task<AgentResponse> PIIMiddleware(IEnumerable<ChatMessage> messages, AgentSession? session, AgentRunOptions? options, AIAgent innerAgent, CancellationToken cancellationToken)
{
    // Agent-run middleware that passes the text of every input and output message through
    // FilterPii. Only the textual part of a message is filtered: non-text contents (function
    // calls, function results, approval requests/responses, ...) are carried over untouched.
    // Rebuilding a message from `m.Text` alone would silently drop all of those contents.

    // Redact PII information from input messages
    var filteredMessages = FilterMessages(messages);
    Console.WriteLine("Pii Middleware - Filtered Messages Pre-Run");

    var response = await innerAgent.RunAsync(filteredMessages, session, options, cancellationToken).ConfigureAwait(false);

    // Redact PII information from output messages
    response.Messages = FilterMessages(response.Messages);

    Console.WriteLine("Pii Middleware - Filtered Messages Post-Run");

    return response;

    // Filters every message; the result is materialized so it is enumerated only once.
    static IList<ChatMessage> FilterMessages(IEnumerable<ChatMessage> messages)
    {
        return messages.Select(FilterMessage).ToList();
    }

    // Applies FilterPii to the message's combined text. If nothing changes, the original
    // message is returned as-is. Otherwise, its text contents are replaced by one filtered
    // TextContent, while every non-text content and the identifying metadata are preserved.
    static ChatMessage FilterMessage(ChatMessage message)
    {
        string text = message.Text;
        string filtered = FilterPii(text);
        if (string.Equals(filtered, text, StringComparison.Ordinal))
        {
            return message;
        }

        List<AIContent> contents = [new TextContent(filtered), .. message.Contents.Where(c => c is not TextContent)];

        // RawRepresentation is deliberately not copied: it would still hold the unredacted original.
        return new ChatMessage(message.Role, contents)
        {
            AuthorName = message.AuthorName,
            MessageId = message.MessageId,
            AdditionalProperties = message.AdditionalProperties,
        };
    }

    static string FilterPii(string content)
    {
        // Placeholder: no PII redaction is implemented yet, so the text is returned unchanged.
        return content;
    }
}

async Task<AgentResponse> GuardrailMiddleware(IEnumerable<ChatMessage> messages, AgentSession? session, AgentRunOptions? options, AIAgent innerAgent, CancellationToken cancellationToken)
{
    // Agent-run middleware that replaces the text of any input or output message containing a
    // forbidden keyword with a redaction marker. Non-text contents (function calls, function
    // results, approval requests/responses, ...) are preserved. Rebuilding a message from
    // `m.Text` alone would silently drop all of those contents.

    // Redact keywords from input messages
    var filteredMessages = FilterMessages(messages);

    Console.WriteLine("Guardrail Middleware - Filtered messages Pre-Run");

    // Proceed with the agent run
    var response = await innerAgent.RunAsync(filteredMessages, session, options, cancellationToken);

    // Redact keywords from output messages
    response.Messages = FilterMessages(response.Messages);

    Console.WriteLine("Guardrail Middleware - Filtered messages Post-Run");

    return response;

    // Filters every message; the result is materialized so it is enumerated only once.
    static List<ChatMessage> FilterMessages(IEnumerable<ChatMessage> messages)
    {
        return messages.Select(FilterMessage).ToList();
    }

    // Checks the message's combined text. An unaffected message is returned as-is. Otherwise,
    // the text contents are replaced by the redaction marker, while non-text contents and the
    // identifying metadata are preserved.
    static ChatMessage FilterMessage(ChatMessage message)
    {
        string text = message.Text;
        string filtered = FilterContent(text);
        if (string.Equals(filtered, text, StringComparison.Ordinal))
        {
            return message;
        }

        List<AIContent> contents = [new TextContent(filtered), .. message.Contents.Where(c => c is not TextContent)];

        // RawRepresentation is deliberately not copied: it would still hold the forbidden original text.
        return new ChatMessage(message.Role, contents)
        {
            AuthorName = message.AuthorName,
            MessageId = message.MessageId,
            AdditionalProperties = message.AdditionalProperties,
        };
    }

    // Returns the redaction marker if any forbidden keyword occurs (case-insensitive), else the input.
    static string FilterContent(string content)
    {
        foreach (var keyword in new[] { "harmful", "illegal", "violence" })
        {
            if (content.Contains(keyword, StringComparison.OrdinalIgnoreCase))
            {
                return "[REDACTED: Forbidden content]";
            }
        }

        return content;
    }
}

async Task<ChatResponse> PerRequestChatClientMiddleware(IEnumerable<ChatMessage> message, ChatOptions? options, IChatClient innerChatClient, CancellationToken cancellationToken)
{
    // Per-request chat-client middleware: logs before and after delegating to the inner
    // client and returns the inner client's response unchanged.
    Console.WriteLine("Per-Request Chat Client Middleware - Pre-Chat");

    ChatResponse chatResponse = await innerChatClient.GetResponseAsync(message, options, cancellationToken);

    Console.WriteLine("Per-Request Chat Client Middleware - Post-Chat");
    return chatResponse;
}

async ValueTask<object?> PerRequestFunctionCallingMiddleware(AIAgent agent, FunctionInvocationContext context, Func<FunctionInvocationContext, CancellationToken, ValueTask<object?>> next, CancellationToken cancellationToken)
{
    // Function-invocation middleware: logs the agent id and the function name around the call
    // to `next`, then passes the invocation result through unchanged.
    Console.WriteLine($"Agent Id: {agent.Id}");
    Console.WriteLine($"Function Name: {context.Function.Name} - Per-Request Pre-Invoke");

    object? invocationResult = await next(context, cancellationToken);

    Console.WriteLine($"Function Name: {context.Function.Name} - Per-Request Post-Invoke");
    return invocationResult;
}

async Task<AgentResponse> ConsolePromptingApprovalMiddleware(IEnumerable<ChatMessage> messages, AgentSession? session, AgentRunOptions? options, AIAgent innerAgent, CancellationToken cancellationToken)
{
    // Agent-run middleware that resolves tool approval requests interactively.
    // It runs the inner agent and prompts on the console for every ToolApprovalRequestContent
    // in the response. The answers go back to the inner agent as user messages, and this
    // repeats until a response contains no further approval requests.
    //
    // NOTE(review): this loop prompts for EVERY approval request it receives. If the agent emits
    // requests for tools that were never wrapped in ApprovalRequiredAIFunction (e.g. GetDateTime
    // requested in the same turn as GetWeather), they are prompted for as well. This middleware
    // cannot tell them apart without access to the agent's tool list — confirm the framework's
    // batching behavior for mixed approval/non-approval calls.
    AgentResponse response = await innerAgent.RunAsync(messages, session, options, cancellationToken);

    // For simplicity, we are assuming here that only function approvals are pending.
    List<ToolApprovalRequestContent> approvalRequests = GetApprovalRequests(response);

    while (approvalRequests.Count > 0)
    {
        // Ask the user to approve each request, then pass the answers back to the agent.
        List<ChatMessage> approvalResponses = approvalRequests.ConvertAll(PromptForApproval);

        response = await innerAgent.RunAsync(approvalResponses, session, options, cancellationToken);

        approvalRequests = GetApprovalRequests(response);
    }

    return response;

    // Collects every pending approval request across all messages of a response.
    static List<ToolApprovalRequestContent> GetApprovalRequests(AgentResponse agentResponse) =>
        agentResponse.Messages.SelectMany(m => m.Contents).OfType<ToolApprovalRequestContent>().ToList();

    // Prompts on the console and wraps the user's Y/N answer in a user message.
    static ChatMessage PromptForApproval(ToolApprovalRequestContent request)
    {
        // Pattern-match instead of a hard cast: a non-function tool call previously threw InvalidCastException.
        string toolName = request.ToolCall is FunctionCallContent functionCall
            ? functionCall.Name
            : request.ToolCall?.ToString() ?? "<unknown>";

        Console.WriteLine($"The agent would like to invoke the following function, please reply Y to approve: Name {toolName}");

        // Surrounding whitespace is ignored; anything other than "Y" (any case), including EOF, is a rejection.
        bool approved = string.Equals(Console.ReadLine()?.Trim(), "Y", StringComparison.OrdinalIgnoreCase);
        return new ChatMessage(ChatRole.User, [request.CreateResponse(approved)]);
    }
}

[Description("The current datetime offset.")]
static string GetDateTime()
{
    // Local date/time with its UTC offset, formatted with the current culture's default format.
    DateTimeOffset now = DateTimeOffset.Now;
    return now.ToString();
}

static string GetWeather([Description("The location to get the weather for.")] string location)
{
    // Canned forecast: every location reports the same conditions.
    return string.Concat("The weather in ", location, " is cloudy with a high of 15°C.");
}

Error Messages / Stack Traces

Package Versions

1.8.0

.NET Version

.NET 10

Additional Context

No response

Metadata

Metadata

Assignees

No one assigned

    Type

    No fields configured for Bug.

    Projects

    Status
    No status

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions