diff --git a/agents/semantickernel-agents-core/src/main/java/com/microsoft/semantickernel/agents/chatcompletion/ChatCompletionAgent.java b/agents/semantickernel-agents-core/src/main/java/com/microsoft/semantickernel/agents/chatcompletion/ChatCompletionAgent.java index 7985f5ad..776ed6a5 100644 --- a/agents/semantickernel-agents-core/src/main/java/com/microsoft/semantickernel/agents/chatcompletion/ChatCompletionAgent.java +++ b/agents/semantickernel-agents-core/src/main/java/com/microsoft/semantickernel/agents/chatcompletion/ChatCompletionAgent.java @@ -6,10 +6,10 @@ import com.microsoft.semantickernel.agents.AgentThread; import com.microsoft.semantickernel.agents.KernelAgent; import com.microsoft.semantickernel.builders.SemanticKernelBuilder; +import com.microsoft.semantickernel.functionchoice.AutoFunctionChoiceBehavior; import com.microsoft.semantickernel.orchestration.InvocationContext; import com.microsoft.semantickernel.orchestration.InvocationReturnMode; import com.microsoft.semantickernel.orchestration.PromptExecutionSettings; -import com.microsoft.semantickernel.orchestration.ToolCallBehavior; import com.microsoft.semantickernel.semanticfunctions.KernelArguments; import com.microsoft.semantickernel.semanticfunctions.PromptTemplate; import com.microsoft.semantickernel.semanticfunctions.PromptTemplateConfig; @@ -62,7 +62,7 @@ private ChatCompletionAgent( @Override public Mono>>> invokeAsync( List> messages, - AgentThread thread, + @Nullable AgentThread thread, @Nullable AgentInvokeOptions options ) { return ensureThreadExistsWithMessagesAsync(messages, thread, ChatHistoryAgentThread::new) @@ -76,22 +76,20 @@ public Mono>>> invokeAsync( // Invoke the agent with the chat history return internalInvokeAsync( history, + agentThread, options ) - .flatMapMany(Flux::fromIterable) - // notify on the new thread instance - .concatMap(agentMessage -> this.notifyThreadOfNewMessageAsync(agentThread, agentMessage).thenReturn(agentMessage)) - .collectList() .map(chatMessageContents -> chatMessageContents.stream() - .map(message -> new AgentResponseItem>(message, agentThread)) - .collect(Collectors.toList()) + .map(message -> new AgentResponseItem>(message, agentThread)) + .collect(Collectors.toList()) ); }); } private Mono>> internalInvokeAsync( ChatHistory history, + AgentThread thread, @Nullable AgentInvokeOptions options ) { if (options == null) { @@ -144,6 +142,20 @@ private Mono>> internalInvokeAsync( // Add the chat history to the new chat chat.addAll(history); + // Retrieve the chat message contents asynchronously and notify the thread + if (shouldNotifyFunctionCalls(agentInvocationContext)) { + // Notify all messages including function calls + return chatCompletionService.getChatMessageContentsAsync(chat, kernel, agentInvocationContext) + .flatMapMany(Flux::fromIterable) + .concatMap(message -> notifyThreadOfNewMessageAsync(thread, message).thenReturn(message)) + // Filter out function calls and their results + .filter(message -> message.getContent() != null && message.getAuthorRole() != AuthorRole.TOOL) + .collect(Collectors.toList()); + } + + // Return chat completion messages without notifying the thread + // We shouldn't add the function call content to the thread, since + // we don't know if the user will execute the call. They should add it themselves. return chatCompletionService.getChatMessageContentsAsync(chat, kernel, agentInvocationContext); } ); @@ -153,6 +165,22 @@ private Mono>> internalInvokeAsync( } } + boolean shouldNotifyFunctionCalls(InvocationContext invocationContext) { + if (invocationContext == null) { + return false; + } + + if (invocationContext.getFunctionChoiceBehavior() != null && invocationContext.getFunctionChoiceBehavior() instanceof AutoFunctionChoiceBehavior) { + return ((AutoFunctionChoiceBehavior) invocationContext.getFunctionChoiceBehavior()).isAutoInvoke(); + } + + if (invocationContext.getToolCallBehavior() != null) { + return invocationContext.getToolCallBehavior().isAutoInvokeAllowed(); + } + + return false; + } + @Override public Mono notifyThreadOfNewMessageAsync(AgentThread thread, ChatMessageContent message) { diff --git a/agents/semantickernel-agents-core/src/main/java/com/microsoft/semantickernel/agents/chatcompletion/ChatHistoryAgentThread.java b/agents/semantickernel-agents-core/src/main/java/com/microsoft/semantickernel/agents/chatcompletion/ChatHistoryAgentThread.java index 1a68f8c4..753759c1 100644 --- a/agents/semantickernel-agents-core/src/main/java/com/microsoft/semantickernel/agents/chatcompletion/ChatHistoryAgentThread.java +++ b/agents/semantickernel-agents-core/src/main/java/com/microsoft/semantickernel/agents/chatcompletion/ChatHistoryAgentThread.java @@ -16,12 +16,25 @@ public class ChatHistoryAgentThread extends BaseAgentThread { private ChatHistory chatHistory; + /** + * Constructor for ChatHistoryAgentThread. + * + */ public ChatHistoryAgentThread() { this(UUID.randomUUID().toString(), new ChatHistory()); } /** - * Constructor for com.microsoft.semantickernel.agents.chatcompletion.ChatHistoryAgentThread. + * Constructor for ChatHistoryAgentThread. + * + * @param chatHistory The chat history. + */ + public ChatHistoryAgentThread(@Nullable ChatHistory chatHistory) { + this(UUID.randomUUID().toString(), chatHistory); + } + + /** + * Constructor for ChatHistoryAgentThread. * * @param id The ID of the thread. * @param chatHistory The chat history. @@ -31,6 +44,8 @@ public ChatHistoryAgentThread(String id, @Nullable ChatHistory chatHistory) { this.chatHistory = chatHistory != null ? chatHistory : new ChatHistory(); } + + /** * Get the chat history. * diff --git a/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/syntaxexamples/agents/CompletionAgent.java b/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/syntaxexamples/agents/CompletionAgent.java index 16941ee4..4c39645c 100644 --- a/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/syntaxexamples/agents/CompletionAgent.java +++ b/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/syntaxexamples/agents/CompletionAgent.java @@ -6,14 +6,17 @@ import com.azure.core.credential.KeyCredential; import com.microsoft.semantickernel.Kernel; import com.microsoft.semantickernel.agents.AgentInvokeOptions; +import com.microsoft.semantickernel.agents.AgentThread; import com.microsoft.semantickernel.agents.chatcompletion.ChatCompletionAgent; import com.microsoft.semantickernel.agents.chatcompletion.ChatHistoryAgentThread; import com.microsoft.semantickernel.aiservices.openai.chatcompletion.OpenAIChatCompletion; import com.microsoft.semantickernel.contextvariables.ContextVariableTypeConverter; +import com.microsoft.semantickernel.contextvariables.ContextVariableTypes; import com.microsoft.semantickernel.functionchoice.FunctionChoiceBehavior; import com.microsoft.semantickernel.implementation.templateengine.tokenizer.DefaultPromptTemplate; import com.microsoft.semantickernel.orchestration.InvocationContext; import com.microsoft.semantickernel.orchestration.PromptExecutionSettings; +import com.microsoft.semantickernel.orchestration.ToolCallBehavior; import com.microsoft.semantickernel.plugin.KernelPluginFactory; import com.microsoft.semantickernel.samples.plugins.github.GitHubModel; import com.microsoft.semantickernel.samples.plugins.github.GitHubPlugin; @@ -105,7 +108,7 @@ public static void main(String[] args) { ) ).build(); - ChatHistoryAgentThread agentThread = new ChatHistoryAgentThread(); + AgentThread agentThread = new ChatHistoryAgentThread(); Scanner scanner = new Scanner(System.in); while (true) { @@ -118,22 +121,19 @@ public static void main(String[] args) { var message = new ChatMessageContent<>(AuthorRole.USER, input); KernelArguments arguments = KernelArguments.builder() - .withVariable("now", System.currentTimeMillis()) - .build(); + .withVariable("now", System.currentTimeMillis()) + .build(); var response = agent.invokeAsync( - List.of(message), - agentThread, - AgentInvokeOptions.builder() - .withKernel(kernel) - .withKernelArguments(arguments) - .build() - ).block(); - - var lastResponse = response.get(response.size() - 1); + message, + agentThread, + AgentInvokeOptions.builder() + .withKernelArguments(arguments) + .build() + ).block().get(0); - System.out.println("> " + lastResponse.getMessage()); - agentThread = (ChatHistoryAgentThread) lastResponse.getThread(); + System.out.println("> " + response.getMessage()); + agentThread = response.getThread(); } } } diff --git a/semantickernel-api/src/main/java/com/microsoft/semantickernel/agents/Agent.java b/semantickernel-api/src/main/java/com/microsoft/semantickernel/agents/Agent.java index 3a82550c..2277d522 100644 --- a/semantickernel-api/src/main/java/com/microsoft/semantickernel/agents/Agent.java +++ b/semantickernel-api/src/main/java/com/microsoft/semantickernel/agents/Agent.java @@ -17,6 +17,8 @@ import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import javax.annotation.Nullable; + /** * Interface for a semantic kernel agent. */ @@ -43,6 +45,36 @@ public interface Agent { */ String getDescription(); + /** + * Invokes the agent with the given message. + * + * @param message The message to process + * @return A Mono containing the agent response + */ + Mono>>> invokeAsync(@Nullable ChatMessageContent message); + + /** + * Invokes the agent with the given message and thread. + * + * @param message The message to process + * @param thread The agent thread to use + * @return A Mono containing the agent response + */ + Mono>>> invokeAsync(@Nullable ChatMessageContent message, + @Nullable AgentThread thread); + + /** + * Invokes the agent with the given message, thread, and options. + * + * @param message The message to process + * @param thread The agent thread to use + * @param options The options for invoking the agent + * @return A Mono containing the agent response + */ + Mono>>> invokeAsync(@Nullable ChatMessageContent message, + @Nullable AgentThread thread, + @Nullable AgentInvokeOptions options); + /** * Invoke the agent with the given chat history. * @@ -51,7 +83,9 @@ public interface Agent { * @param options The options for invoking the agent * @return A Mono containing the agent response */ - Mono>>> invokeAsync(List> messages, AgentThread thread, AgentInvokeOptions options); + Mono>>> invokeAsync(List> messages, + @Nullable AgentThread thread, + @Nullable AgentInvokeOptions options); /** * Notifies the agent of a new message. diff --git a/semantickernel-api/src/main/java/com/microsoft/semantickernel/agents/AgentInvokeOptions.java b/semantickernel-api/src/main/java/com/microsoft/semantickernel/agents/AgentInvokeOptions.java index 3fb4cba1..559a062c 100644 --- a/semantickernel-api/src/main/java/com/microsoft/semantickernel/agents/AgentInvokeOptions.java +++ b/semantickernel-api/src/main/java/com/microsoft/semantickernel/agents/AgentInvokeOptions.java @@ -13,9 +13,13 @@ */ public class AgentInvokeOptions { + @Nullable private final KernelArguments kernelArguments; + @Nullable private final Kernel kernel; + @Nullable private final String additionalInstructions; + @Nullable private final InvocationContext invocationContext; /** diff --git a/semantickernel-api/src/main/java/com/microsoft/semantickernel/agents/KernelAgent.java b/semantickernel-api/src/main/java/com/microsoft/semantickernel/agents/KernelAgent.java index 71e81951..4d30569f 100644 --- a/semantickernel-api/src/main/java/com/microsoft/semantickernel/agents/KernelAgent.java +++ b/semantickernel-api/src/main/java/com/microsoft/semantickernel/agents/KernelAgent.java @@ -11,6 +11,8 @@ import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import javax.annotation.Nullable; +import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -114,7 +116,6 @@ public PromptTemplate getTemplate() { return template; } - /** * Merges the provided arguments with the current arguments. * Provided arguments will override the current arguments. @@ -167,4 +168,27 @@ protected Mono ensureThreadExistsWithMessagesAsync(Li .then(Mono.just((T) newThread)); }); } + + @Override + public Mono>>> invokeAsync(@Nullable ChatMessageContent message) { + return invokeAsync(message, null, null); + } + + @Override + public Mono>>> invokeAsync(@Nullable ChatMessageContent message, + @Nullable AgentThread thread) { + return invokeAsync(message, thread, null); + } + + @Override + public Mono>>> invokeAsync( + @Nullable ChatMessageContent message, + @Nullable AgentThread thread, + @Nullable AgentInvokeOptions options) { + ArrayList> messages = new ArrayList<>(); + if (message != null) { + messages.add(message); + } + return invokeAsync(messages, thread, options); + } }