Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
import com.microsoft.semantickernel.agents.AgentThread;
import com.microsoft.semantickernel.agents.KernelAgent;
import com.microsoft.semantickernel.builders.SemanticKernelBuilder;
import com.microsoft.semantickernel.functionchoice.AutoFunctionChoiceBehavior;
import com.microsoft.semantickernel.orchestration.InvocationContext;
import com.microsoft.semantickernel.orchestration.InvocationReturnMode;
import com.microsoft.semantickernel.orchestration.PromptExecutionSettings;
import com.microsoft.semantickernel.orchestration.ToolCallBehavior;
import com.microsoft.semantickernel.semanticfunctions.KernelArguments;
import com.microsoft.semantickernel.semanticfunctions.PromptTemplate;
import com.microsoft.semantickernel.semanticfunctions.PromptTemplateConfig;
Expand Down Expand Up @@ -62,7 +62,7 @@ private ChatCompletionAgent(
@Override
public Mono<List<AgentResponseItem<ChatMessageContent<?>>>> invokeAsync(
List<ChatMessageContent<?>> messages,
AgentThread thread,
@Nullable AgentThread thread,
@Nullable AgentInvokeOptions options
) {
return ensureThreadExistsWithMessagesAsync(messages, thread, ChatHistoryAgentThread::new)
Expand All @@ -76,22 +76,20 @@ public Mono<List<AgentResponseItem<ChatMessageContent<?>>>> invokeAsync(
// Invoke the agent with the chat history
return internalInvokeAsync(
history,
agentThread,
options
)
.flatMapMany(Flux::fromIterable)
// notify on the new thread instance
.concatMap(agentMessage -> this.notifyThreadOfNewMessageAsync(agentThread, agentMessage).thenReturn(agentMessage))
.collectList()
.map(chatMessageContents ->
chatMessageContents.stream()
.map(message -> new AgentResponseItem<ChatMessageContent<?>>(message, agentThread))
.collect(Collectors.toList())
.map(message -> new AgentResponseItem<ChatMessageContent<?>>(message, agentThread))
.collect(Collectors.toList())
);
});
}

private Mono<List<ChatMessageContent<?>>> internalInvokeAsync(
ChatHistory history,
AgentThread thread,
@Nullable AgentInvokeOptions options
) {
if (options == null) {
Expand Down Expand Up @@ -144,6 +142,20 @@ private Mono<List<ChatMessageContent<?>>> internalInvokeAsync(
// Add the chat history to the new chat
chat.addAll(history);

// Retrieve the chat message contents asynchronously and notify the thread
if (shouldNotifyFunctionCalls(agentInvocationContext)) {
// Notify all messages including function calls
return chatCompletionService.getChatMessageContentsAsync(chat, kernel, agentInvocationContext)
.flatMapMany(Flux::fromIterable)
.concatMap(message -> notifyThreadOfNewMessageAsync(thread, message).thenReturn(message))
// Filter out function calls and their results
.filter(message -> message.getContent() != null && message.getAuthorRole() != AuthorRole.TOOL)
.collect(Collectors.toList());
}

// Return chat completion messages without notifying the thread
// We shouldn't add the function call content to the thread, since
// we don't know if the user will execute the call. They should add it themselves.
return chatCompletionService.getChatMessageContentsAsync(chat, kernel, agentInvocationContext);
}
);
Expand All @@ -153,6 +165,22 @@ private Mono<List<ChatMessageContent<?>>> internalInvokeAsync(
}
}

boolean shouldNotifyFunctionCalls(InvocationContext invocationContext) {
if (invocationContext == null) {
return false;
}

if (invocationContext.getFunctionChoiceBehavior() != null && invocationContext.getFunctionChoiceBehavior() instanceof AutoFunctionChoiceBehavior) {
return ((AutoFunctionChoiceBehavior) invocationContext.getFunctionChoiceBehavior()).isAutoInvoke();
}

if (invocationContext.getToolCallBehavior() != null) {
return invocationContext.getToolCallBehavior().isAutoInvokeAllowed();
}

return false;
}


@Override
public Mono<Void> notifyThreadOfNewMessageAsync(AgentThread thread, ChatMessageContent<?> message) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,25 @@
public class ChatHistoryAgentThread extends BaseAgentThread {
private ChatHistory chatHistory;

/**
* Constructor for ChatHistoryAgentThread.
*
*/
public ChatHistoryAgentThread() {
this(UUID.randomUUID().toString(), new ChatHistory());
}

/**
* Constructor for com.microsoft.semantickernel.agents.chatcompletion.ChatHistoryAgentThread.
* Constructor for ChatHistoryAgentThread.
*
* @param chatHistory The chat history.
*/
public ChatHistoryAgentThread(@Nullable ChatHistory chatHistory) {
this(UUID.randomUUID().toString(), chatHistory);
}

/**
* Constructor for ChatHistoryAgentThread.
*
* @param id The ID of the thread.
* @param chatHistory The chat history.
Expand All @@ -31,6 +44,8 @@ public ChatHistoryAgentThread(String id, @Nullable ChatHistory chatHistory) {
this.chatHistory = chatHistory != null ? chatHistory : new ChatHistory();
}



/**
* Get the chat history.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,17 @@
import com.azure.core.credential.KeyCredential;
import com.microsoft.semantickernel.Kernel;
import com.microsoft.semantickernel.agents.AgentInvokeOptions;
import com.microsoft.semantickernel.agents.AgentThread;
import com.microsoft.semantickernel.agents.chatcompletion.ChatCompletionAgent;
import com.microsoft.semantickernel.agents.chatcompletion.ChatHistoryAgentThread;
import com.microsoft.semantickernel.aiservices.openai.chatcompletion.OpenAIChatCompletion;
import com.microsoft.semantickernel.contextvariables.ContextVariableTypeConverter;
import com.microsoft.semantickernel.contextvariables.ContextVariableTypes;
import com.microsoft.semantickernel.functionchoice.FunctionChoiceBehavior;
import com.microsoft.semantickernel.implementation.templateengine.tokenizer.DefaultPromptTemplate;
import com.microsoft.semantickernel.orchestration.InvocationContext;
import com.microsoft.semantickernel.orchestration.PromptExecutionSettings;
import com.microsoft.semantickernel.orchestration.ToolCallBehavior;
import com.microsoft.semantickernel.plugin.KernelPluginFactory;
import com.microsoft.semantickernel.samples.plugins.github.GitHubModel;
import com.microsoft.semantickernel.samples.plugins.github.GitHubPlugin;
Expand Down Expand Up @@ -105,7 +108,7 @@ public static void main(String[] args) {
)
).build();

ChatHistoryAgentThread agentThread = new ChatHistoryAgentThread();
AgentThread agentThread = new ChatHistoryAgentThread();
Scanner scanner = new Scanner(System.in);

while (true) {
Expand All @@ -118,22 +121,19 @@ public static void main(String[] args) {

var message = new ChatMessageContent<>(AuthorRole.USER, input);
KernelArguments arguments = KernelArguments.builder()
.withVariable("now", System.currentTimeMillis())
.build();
.withVariable("now", System.currentTimeMillis())
.build();

var response = agent.invokeAsync(
List.of(message),
agentThread,
AgentInvokeOptions.builder()
.withKernel(kernel)
.withKernelArguments(arguments)
.build()
).block();

var lastResponse = response.get(response.size() - 1);
message,
agentThread,
AgentInvokeOptions.builder()
.withKernelArguments(arguments)
.build()
).block().get(0);

System.out.println("> " + lastResponse.getMessage());
agentThread = (ChatHistoryAgentThread) lastResponse.getThread();
System.out.println("> " + response.getMessage());
agentThread = response.getThread();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import javax.annotation.Nullable;

/**
* Interface for a semantic kernel agent.
*/
Expand All @@ -43,6 +45,36 @@ public interface Agent {
*/
String getDescription();

/**
* Invokes the agent with the given message.
*
* @param message The message to process
* @return A Mono containing the agent response
*/
Mono<List<AgentResponseItem<ChatMessageContent<?>>>> invokeAsync(@Nullable ChatMessageContent<?> message);

/**
* Invokes the agent with the given message and thread.
*
* @param message The message to process
* @param thread The agent thread to use
* @return A Mono containing the agent response
*/
Mono<List<AgentResponseItem<ChatMessageContent<?>>>> invokeAsync(@Nullable ChatMessageContent<?> message,
@Nullable AgentThread thread);

/**
* Invokes the agent with the given message, thread, and options.
*
* @param message The message to process
* @param thread The agent thread to use
* @param options The options for invoking the agent
* @return A Mono containing the agent response
*/
Mono<List<AgentResponseItem<ChatMessageContent<?>>>> invokeAsync(@Nullable ChatMessageContent<?> message,
@Nullable AgentThread thread,
@Nullable AgentInvokeOptions options);

/**
* Invoke the agent with the given chat history.
*
Expand All @@ -51,7 +83,9 @@ public interface Agent {
* @param options The options for invoking the agent
* @return A Mono containing the agent response
*/
Mono<List<AgentResponseItem<ChatMessageContent<?>>>> invokeAsync(List<ChatMessageContent<?>> messages, AgentThread thread, AgentInvokeOptions options);
Mono<List<AgentResponseItem<ChatMessageContent<?>>>> invokeAsync(List<ChatMessageContent<?>> messages,
@Nullable AgentThread thread,
@Nullable AgentInvokeOptions options);

/**
* Notifies the agent of a new message.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,13 @@
*/
public class AgentInvokeOptions {

@Nullable
private final KernelArguments kernelArguments;
@Nullable
private final Kernel kernel;
@Nullable
private final String additionalInstructions;
@Nullable
private final InvocationContext invocationContext;

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import javax.annotation.Nullable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -114,7 +116,6 @@ public PromptTemplate getTemplate() {
return template;
}


/**
* Merges the provided arguments with the current arguments.
* Provided arguments will override the current arguments.
Expand Down Expand Up @@ -167,4 +168,27 @@ protected <T extends AgentThread> Mono<T> ensureThreadExistsWithMessagesAsync(Li
.then(Mono.just((T) newThread));
});
}

@Override
public Mono<List<AgentResponseItem<ChatMessageContent<?>>>> invokeAsync(@Nullable ChatMessageContent<?> message) {
return invokeAsync(message, null, null);
}

@Override
public Mono<List<AgentResponseItem<ChatMessageContent<?>>>> invokeAsync(@Nullable ChatMessageContent<?> message,
@Nullable AgentThread thread) {
return invokeAsync(message, thread, null);
}

@Override
public Mono<List<AgentResponseItem<ChatMessageContent<?>>>> invokeAsync(
@Nullable ChatMessageContent<?> message,
@Nullable AgentThread thread,
@Nullable AgentInvokeOptions options) {
ArrayList<ChatMessageContent<?>> messages = new ArrayList<>();
if (message != null) {
messages.add(message);
}
return invokeAsync(messages, thread, options);
}
}