From 9334986bec857f9a9600a9a22135e0f9c314d77f Mon Sep 17 00:00:00 2001 From: Ivan Pegashev Date: Fri, 19 Apr 2024 00:50:25 +0300 Subject: [PATCH] feat(chat): add cancel button in chat to stop generation --- src/common/chat/cloudChat.ts | 4 + src/common/chat/index.ts | 2 + src/common/chat/localChat.ts | 5 +- src/common/panel/chat.ts | 72 +++++++++++---- webviews/src/App.tsx | 8 +- webviews/src/components/TextArea/index.tsx | 3 +- webviews/src/hooks/useChat.ts | 87 ++++++++++--------- .../transformCallback2AsyncGenerator.ts | 30 +++++++ webviews/src/utilities/vscode.ts | 53 +++++++++-- 9 files changed, 191 insertions(+), 73 deletions(-) create mode 100644 webviews/src/utilities/transformCallback2AsyncGenerator.ts diff --git a/src/common/chat/cloudChat.ts b/src/common/chat/cloudChat.ts index 26cce7f..7088a9f 100644 --- a/src/common/chat/cloudChat.ts +++ b/src/common/chat/cloudChat.ts @@ -11,6 +11,7 @@ import { configuration } from "../utils/configuration"; type Parameters = { temperature: number; n_predict: number; + controller?: AbortController; }; export const sendChatRequestCloud = async ( @@ -45,6 +46,9 @@ export const sendChatRequestCloud = async ( }); const stream = await model.pipe(parser).stream(messages, { + configurable: { + signal: parameters.controller, + }, maxConcurrency: 1, }); diff --git a/src/common/chat/index.ts b/src/common/chat/index.ts index f5f73ca..120c768 100644 --- a/src/common/chat/index.ts +++ b/src/common/chat/index.ts @@ -26,6 +26,7 @@ export async function* chat( history: ChatMessage[], config?: { provideHighlightedText?: boolean; + abortController: AbortController; } ) { const loggerCompletion = logCompletion(); @@ -36,6 +37,7 @@ export async function* chat( n_predict: 4096, stop: [], temperature: 0.7, + controller: config?.abortController, }; const { stopTask } = statusBar.startTask(); diff --git a/src/common/chat/localChat.ts b/src/common/chat/localChat.ts index 00bed2a..5306633 100644 --- a/src/common/chat/localChat.ts +++ b/src/common/chat/localChat.ts @@ -53,7 +53,10 @@ export async function* sendChatRequestLocal( const startTime = performance.now(); let timings; - for await (const chunk of llama(prompt, parametersForCompletion, { url })) { + for await (const chunk of llama(prompt, parametersForCompletion, { + url, + controller: parameters.controller, + })) { // @ts-ignore if (chunk.data) { // @ts-ignore diff --git a/src/common/panel/chat.ts b/src/common/panel/chat.ts index 85c400c..6f6ccf8 100644 --- a/src/common/panel/chat.ts +++ b/src/common/panel/chat.ts @@ -1,8 +1,9 @@ -import { Disposable, Webview, window, Uri } from "vscode"; +import { Disposable, Webview, Uri } from "vscode"; import * as vscode from "vscode"; import { getUri } from "../utils/getUri"; import { getNonce } from "../utils/getNonce"; import { chat } from "../chat"; +import { ChatMessage } from "../prompt/promptChat"; export type MessageType = | { @@ -13,8 +14,8 @@ export type MessageType = } | { type: "e2w-response"; + id: string; command: string; - messageId: string; done: boolean; data: any; }; @@ -22,6 +23,7 @@ export type MessageType = export class ChatPanel implements vscode.WebviewViewProvider { private disposables: Disposable[] = []; private webview: Webview | undefined; + private messageCallback: Record = {}; constructor(private readonly extensionUri: vscode.Uri) {} @@ -94,26 +96,19 @@ export class ChatPanel implements vscode.WebviewViewProvider { private setWebviewMessageListener(webview: Webview) { webview.onDidReceiveMessage( async (message: any) => { - const sendResponse = (messageToResponse: any, done: boolean) => { - this.postMessage({ - type: "e2w-response", - command: message.type, - messageId: message.messageId, - data: messageToResponse, - done: done, - }); - }; + if (message.type in this.messageCallback) { + this.messageCallback[message.type](); + return; + } const type = message.type; - const data = message.data; switch (type) { case "sendMessage": - for await (const message of chat(data, { - provideHighlightedText: true, - })) { - sendResponse(message, false); - } - sendResponse("", true); + await this.handleStartGeneration({ + chatMessage: message.data, + messageId: message.messageId, + messageType: message.type, + }); return; } }, @@ -122,6 +117,47 @@ export class ChatPanel implements vscode.WebviewViewProvider { ); } + private addMessageListener( + commandOrMessageId: string, + callback: (message: any) => void + ) { + this.messageCallback[commandOrMessageId] = callback; + } + + private async handleStartGeneration({ + messageId, + messageType, + chatMessage, + }: { + messageId: string; + messageType: string; + chatMessage: ChatMessage[]; + }) { + const sendResponse = (messageToResponse: any, done: boolean) => { + this.postMessage({ + type: "e2w-response", + id: messageId, + command: messageType, + data: messageToResponse, + done: done, + }); + }; + const abortController = new AbortController(); + + this.addMessageListener("abort-generate", () => { + abortController.abort(); + }); + + for await (const message of chat(chatMessage, { + provideHighlightedText: true, + abortController, + })) { + sendResponse(message, false); + } + + sendResponse("", true); + } + public async sendMessageToWebview( command: MessageType["command"], data: MessageType["data"] diff --git a/webviews/src/App.tsx b/webviews/src/App.tsx index cb4a71e..d2e9d0e 100644 --- a/webviews/src/App.tsx +++ b/webviews/src/App.tsx @@ -15,6 +15,7 @@ export const App = () => { input, setInput, startNewChat, + stop, } = useChat(); useMessageListener("startNewChat", () => { @@ -50,14 +51,11 @@ export const App = () => { buttonEnd={ diff --git a/webviews/src/components/TextArea/index.tsx b/webviews/src/components/TextArea/index.tsx index 004ba0f..e810aad 100644 --- a/webviews/src/components/TextArea/index.tsx +++ b/webviews/src/components/TextArea/index.tsx @@ -54,7 +54,7 @@ const TextArea = ({ onSubmit(); } - event.preventDefault(); // Prevents the addition of a new line in the text field + event.preventDefault(); } }} > @@ -63,5 +63,4 @@ const TextArea = ({ ); }; -// 24 42 61 export default TextArea; diff --git a/webviews/src/hooks/useChat.ts b/webviews/src/hooks/useChat.ts index e53be7e..d5a3f95 100644 --- a/webviews/src/hooks/useChat.ts +++ b/webviews/src/hooks/useChat.ts @@ -1,19 +1,48 @@ -import { useCallback, useState } from "react"; +import { useCallback, useRef, useState } from "react"; import { randomMessageId } from "../utilities/messageId"; import { vscode } from "../utilities/vscode"; +export type ChatMessage = { + role: string; + content: string; + chatMessageId: string; +}; + export const useChat = () => { - const [chatMessages, setChatMessages] = useState< - { - role: string; - content: string; - chatMessageId: string; - }[] - >([]); + const [chatMessages, setChatMessages] = useState([]); const [input, setInput] = useState(""); const [isLoading, setIsLoading] = useState(false); + const abortController = useRef(new AbortController()); + + const sendMessage = async (chatHistoryLocal: ChatMessage[]) => { + const messageId = randomMessageId(); + for await (const newMessage of vscode.startGeneration(chatHistoryLocal, { + signal: abortController.current.signal, + })) { + setChatMessages((chatHistoryLocal) => { + const messages = chatHistoryLocal.filter( + (message) => message.chatMessageId !== messageId + ); + + const currentChatMessage = chatHistoryLocal.find( + (message) => message.chatMessageId === messageId + ); + + return [ + ...messages, + { + role: "ai", + content: (currentChatMessage?.content || "") + newMessage, + chatMessageId: messageId, + }, + ]; + }); + } + setIsLoading(false); + }; + const handleSubmit = () => { if (isLoading) { return; @@ -21,6 +50,9 @@ export const useChat = () => { if (input === "") { return; } + if (abortController.current.signal.aborted) { + abortController.current = new AbortController(); + } setChatMessages((value) => { const messageId = randomMessageId(); @@ -41,40 +73,10 @@ export const useChat = () => { setInput(""); }; - const sendMessage = async (chatHistoryLocal: any) => { - const messageId = randomMessageId(); - await vscode.postMessageCallback( - { - type: "sendMessage", - data: chatHistoryLocal, - }, - (newMessage) => { - setChatMessages((chatHistoryLocal) => { - const messages = chatHistoryLocal.filter( - (message) => message.chatMessageId !== messageId - ); - - const currentChatMessage = chatHistoryLocal.find( - (message) => message.chatMessageId === messageId - ); - - if (newMessage.done) { - setIsLoading(false); - return chatHistoryLocal; - } - - return [ - ...messages, - { - role: "ai", - content: (currentChatMessage?.content || "") + newMessage.data, - chatMessageId: messageId, - }, - ]; - }); - } - ); - }; + const stop = useCallback(() => { + abortController.current.abort(); + setIsLoading(false); + }, [abortController]); const startNewChat = useCallback(() => { setChatMessages([]); @@ -87,5 +89,6 @@ export const useChat = () => { setInput, handleSubmit, startNewChat, + stop, }; }; diff --git a/webviews/src/utilities/transformCallback2AsyncGenerator.ts b/webviews/src/utilities/transformCallback2AsyncGenerator.ts new file mode 100644 index 0000000..a4c8fbd --- /dev/null +++ b/webviews/src/utilities/transformCallback2AsyncGenerator.ts @@ -0,0 +1,30 @@ +export class Transform { + private open = true; + private queue: T[] = []; + private resolve: (() => void) | undefined; + + async *stream(): AsyncGenerator { + this.open = true; + + while (this.open) { + if (this.queue.length) { + yield this.queue.shift()!; + continue; + } + + await new Promise((resolveLocal) => { + this.resolve = resolveLocal; + }); + } + } + + push(data: T): void { + this.queue.push(data); + this.resolve?.(); + } + + close(): void { + this.open = false; + this.resolve?.(); + } +} diff --git a/webviews/src/utilities/vscode.ts b/webviews/src/utilities/vscode.ts index dd06c15..83d35e5 100644 --- a/webviews/src/utilities/vscode.ts +++ b/webviews/src/utilities/vscode.ts @@ -1,5 +1,7 @@ import type { WebviewApi } from "vscode-webview"; import { randomMessageId } from "./messageId"; +import { ChatMessage } from "../hooks/useChat"; +import { Transform } from "./transformCallback2AsyncGenerator"; export type MessageType = | { @@ -11,7 +13,7 @@ export type MessageType = | { type: "e2w-response"; command: string; - messageId: string; + id: string; done: boolean; data: any; }; @@ -29,9 +31,9 @@ class VSCodeAPIWrapper { if ( newMessage.type === "e2w-response" && - newMessage.messageId in this.messageCallback + newMessage.id in this.messageCallback ) { - this.messageCallback[newMessage.messageId](newMessage); + this.messageCallback[newMessage.id](newMessage); return; } @@ -45,15 +47,21 @@ class VSCodeAPIWrapper { }); } - public async postMessageCallback( + public postMessageCallback( message: { type: string; data: any }, - messageCallback?: (message: any) => void + messageCallback?: (message: any) => void, + config?: { signal?: AbortSignal } ) { if (this.vsCodeApi) { const messageId = randomMessageId(); if (messageCallback) { this.addMessageListener(messageId, messageCallback); } + + config?.signal?.addEventListener("abort", () => { + this.abortOperation(messageId); + }); + this.vsCodeApi.postMessage({ ...message, messageId, @@ -62,12 +70,47 @@ class VSCodeAPIWrapper { console.log(message); } } + + public startGeneration( + chatHistory: ChatMessage[], + config?: { + signal: AbortSignal; + } + ) { + const transform = new Transform(); + this.postMessageCallback( + { + data: chatHistory, + type: "sendMessage", + }, + (message) => { + if (message.done) { + transform.close(); + } else { + transform.push(message.data); + } + }, + { + signal: config?.signal, + } + ); + + return transform.stream(); + } + public addMessageListener( commandOrMessageId: string, callback: (message: any) => void ) { this.messageCallback[commandOrMessageId] = callback; } + + private abortOperation(messageId: string) { + this.vsCodeApi?.postMessage({ + type: "abort-generate", + id: messageId, + }); + } } export const vscode = new VSCodeAPIWrapper();