diff --git a/src/main/presenter/llmProviderPresenter/index.ts b/src/main/presenter/llmProviderPresenter/index.ts index a427a0e27..e2725000d 100644 --- a/src/main/presenter/llmProviderPresenter/index.ts +++ b/src/main/presenter/llmProviderPresenter/index.ts @@ -297,7 +297,8 @@ export class LLMProviderPresenter implements ILlmProviderPresenter { modelId: string, eventId: string, temperature: number = 0.6, - maxTokens: number = 4096 + maxTokens: number = 4096, + enabledMcpTools?: string[] ): AsyncGenerator { console.log(`[Agent Loop] Starting agent loop for event: ${eventId} with model: ${modelId}`) if (!this.canStartNewStream()) { @@ -371,7 +372,7 @@ export class LLMProviderPresenter implements ILlmProviderPresenter { try { console.log(`[Agent Loop] Iteration ${toolCallCount + 1} for event: ${eventId}`) - const mcpTools = await presenter.mcpPresenter.getAllToolDefinitions() + const mcpTools = await presenter.mcpPresenter.getAllToolDefinitions(enabledMcpTools) // Call the provider's core stream method, expecting LLMCoreStreamEvent const stream = provider.coreStream( conversationMessages, @@ -591,9 +592,9 @@ export class LLMProviderPresenter implements ILlmProviderPresenter { toolCallCount++ // Find the tool definition to get server info - const toolDef = (await presenter.mcpPresenter.getAllToolDefinitions()).find( - (t) => t.function.name === toolCall.name - ) + const toolDef = ( + await presenter.mcpPresenter.getAllToolDefinitions(enabledMcpTools) + ).find((t) => t.function.name === toolCall.name) if (!toolDef) { console.error(`Tool definition not found for ${toolCall.name}. Skipping execution.`) diff --git a/src/main/presenter/mcpPresenter/index.ts b/src/main/presenter/mcpPresenter/index.ts index e23fedd0a..704bcecbc 100644 --- a/src/main/presenter/mcpPresenter/index.ts +++ b/src/main/presenter/mcpPresenter/index.ts @@ -410,11 +410,10 @@ export class McpPresenter implements IMCPPresenter { // 通知渲染进程服务器已停止 eventBus.send(MCP_EVENTS.SERVER_STOPPED, SendTarget.ALL_WINDOWS, serverName) } - - async getAllToolDefinitions(): Promise { + async getAllToolDefinitions(enabledMcpTools?: string[]): Promise { const enabled = await this.configPresenter.getMcpEnabled() if (enabled) { - return this.toolManager.getAllToolDefinitions() + return await this.toolManager.getAllToolDefinitions(enabledMcpTools) } return [] } diff --git a/src/main/presenter/mcpPresenter/toolManager.ts b/src/main/presenter/mcpPresenter/toolManager.ts index 35abd4879..1f984e075 100644 --- a/src/main/presenter/mcpPresenter/toolManager.ts +++ b/src/main/presenter/mcpPresenter/toolManager.ts @@ -43,10 +43,17 @@ export class ToolManager { public async getRunningClients(): Promise { return this.serverManager.getRunningClients() } - // 获取所有工具定义 - public async getAllToolDefinitions(): Promise { + public async getAllToolDefinitions(enabledTools?: string[]): Promise { if (this.cachedToolDefinitions !== null && this.cachedToolDefinitions.length > 0) { + if (enabledTools) { + const enabledSet = new Set(enabledTools) + return this.cachedToolDefinitions.filter((toolDef) => { + const finalName = toolDef.function.name + const originalName = this.toolNameToTargetMap?.get(finalName)?.originalName || finalName + return enabledSet.has(finalName) || enabledSet.has(originalName) + }) + } return this.cachedToolDefinitions } @@ -200,6 +207,16 @@ export class ToolManager { // 缓存结果并返回 this.cachedToolDefinitions = results console.info(`Cached ${results.length} final tool definitions and populated target map.`) + + if (enabledTools && enabledTools.length > 0) { + const enabledSet = new Set(enabledTools) + return this.cachedToolDefinitions.filter((toolDef) => { + const finalName = toolDef.function.name + const originalName = this.toolNameToTargetMap?.get(finalName)?.originalName || finalName + return enabledSet.has(finalName) || enabledSet.has(originalName) + }) + } + return this.cachedToolDefinitions } diff --git a/src/main/presenter/sqlitePresenter/tables/conversations.ts b/src/main/presenter/sqlitePresenter/tables/conversations.ts index b037326be..e1ea7a337 100644 --- a/src/main/presenter/sqlitePresenter/tables/conversations.ts +++ b/src/main/presenter/sqlitePresenter/tables/conversations.ts @@ -17,6 +17,16 @@ type ConversationRow = { artifacts: number is_new: number is_pinned: number + enabled_mcp_tools: string | null +} + +// 解析 JSON 字段 +function getJsonField(val: string | null | undefined, fallback: T): T { + try { + return val ? JSON.parse(val) : fallback + } catch { + return fallback + } } export class ConversationsTable extends BaseTable { @@ -46,7 +56,6 @@ export class ConversationsTable extends BaseTable { CREATE INDEX idx_conversations_pinned ON conversations(is_pinned); ` } - getMigrationSQL(version: number): string | null { if (version === 1) { return ` @@ -67,11 +76,17 @@ export class ConversationsTable extends BaseTable { UPDATE conversations SET artifacts = 0; ` } + if (version === 3) { + return ` + ALTER TABLE conversations ADD COLUMN enabled_mcp_tools TEXT DEFAULT '[]'; + ` + } + return null } getLatestVersion(): number { - return 2 + return 3 } async create(title: string, settings: Partial = {}): Promise { @@ -89,9 +104,10 @@ export class ConversationsTable extends BaseTable { model_id, is_new, artifacts, - is_pinned + is_pinned, + enabled_mcp_tools ) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ? ,?) `) const conv_id = nanoid() const now = Date.now() @@ -108,7 +124,8 @@ export class ConversationsTable extends BaseTable { settings.modelId || 'gpt-4', 1, settings.artifacts || 0, - 0 // Default is_pinned to 0 + 0, // Default is_pinned to 0 + settings.enabledMcpTools ? JSON.stringify(settings.enabledMcpTools) : '[]' ) return conv_id } @@ -130,7 +147,8 @@ export class ConversationsTable extends BaseTable { model_id as modelId, is_new, artifacts, - is_pinned + is_pinned, + enabled_mcp_tools FROM conversations WHERE conv_id = ? ` @@ -155,7 +173,8 @@ export class ConversationsTable extends BaseTable { maxTokens: result.maxTokens, providerId: result.providerId, modelId: result.modelId, - artifacts: result.artifacts as 0 | 1 + artifacts: result.artifacts as 0 | 1, + enabledMcpTools: getJsonField(result.enabled_mcp_tools, []) } } } @@ -208,8 +227,11 @@ export class ConversationsTable extends BaseTable { updates.push('artifacts = ?') params.push(data.settings.artifacts) } + if (data.settings.enabledMcpTools !== undefined) { + updates.push('enabled_mcp_tools = ?') + params.push(JSON.stringify(data.settings.enabledMcpTools)) + } } - if (updates.length > 0 || data.updatedAt) { updates.push('updated_at = ?') params.push(data.updatedAt || Date.now()) @@ -252,7 +274,8 @@ export class ConversationsTable extends BaseTable { model_id as modelId, is_new, artifacts, - is_pinned + is_pinned, + enabled_mcp_tools FROM conversations ORDER BY updated_at DESC LIMIT ? OFFSET ? @@ -276,7 +299,8 @@ export class ConversationsTable extends BaseTable { maxTokens: row.maxTokens, providerId: row.providerId, modelId: row.modelId, - artifacts: row.artifacts as 0 | 1 + artifacts: row.artifacts as 0 | 1, + enabledMcpTools: getJsonField(row.enabled_mcp_tools, []) } })) } diff --git a/src/main/presenter/threadPresenter/index.ts b/src/main/presenter/threadPresenter/index.ts index 15b2881e7..889559b84 100644 --- a/src/main/presenter/threadPresenter/index.ts +++ b/src/main/presenter/threadPresenter/index.ts @@ -706,7 +706,6 @@ export class ThreadPresenter implements IThreadPresenter { return conversation } - async createConversation( title: string, settings: Partial = {}, @@ -1467,16 +1466,17 @@ export class ThreadPresenter implements IThreadPresenter { providerId: currentProviderId, modelId: currentModelId, temperature: currentTemperature, - maxTokens: currentMaxTokens + maxTokens: currentMaxTokens, + enabledMcpTools: crrentEnabledMcpTools } = currentConversation.settings - const stream = this.llmProviderPresenter.startStreamCompletion( currentProviderId, // 使用最新的设置 finalContent, currentModelId, // 使用最新的设置 state.message.id, currentTemperature, // 使用最新的设置 - currentMaxTokens // 使用最新的设置 + currentMaxTokens, // 使用最新的设置 + crrentEnabledMcpTools ) for await (const event of stream) { const msg = event.data @@ -1574,7 +1574,7 @@ export class ThreadPresenter implements IThreadPresenter { this.throwIfCancelled(state.message.id) // 7. 准备提示内容 - const { providerId, modelId, temperature, maxTokens } = conversation.settings + const { providerId, modelId, temperature, maxTokens, enabledMcpTools } = conversation.settings const modelConfig = this.configPresenter.getModelConfig(modelId, providerId) const { finalContent, promptTokens } = await this.preparePromptContent( @@ -1641,7 +1641,8 @@ export class ThreadPresenter implements IThreadPresenter { modelId, state.message.id, temperature, - maxTokens + maxTokens, + enabledMcpTools ) for await (const event of stream) { const msg = event.data @@ -1789,7 +1790,7 @@ export class ThreadPresenter implements IThreadPresenter { finalContent: ChatMessage[] promptTokens: number }> { - const { systemPrompt, contextLength, artifacts } = conversation.settings + const { systemPrompt, contextLength, artifacts, enabledMcpTools } = conversation.settings const searchPrompt = searchResults ? generateSearchPrompt(userContent, searchResults) : '' const enrichedUserMessage = @@ -1801,7 +1802,7 @@ export class ThreadPresenter implements IThreadPresenter { const searchPromptTokens = searchPrompt ? approximateTokenSize(searchPrompt ?? '') : 0 const systemPromptTokens = systemPrompt ? approximateTokenSize(systemPrompt ?? '') : 0 const userMessageTokens = approximateTokenSize(userContent + enrichedUserMessage) - const mcpTools = await presenter.mcpPresenter.getAllToolDefinitions() + const mcpTools = await presenter.mcpPresenter.getAllToolDefinitions(enabledMcpTools) const mcpToolsTokens = mcpTools.reduce( (acc, tool) => acc + approximateTokenSize(JSON.stringify(tool)), 0 @@ -3049,7 +3050,7 @@ export class ThreadPresenter implements IThreadPresenter { throw new Error(errorMsg) } - const { providerId, modelId, temperature, maxTokens } = conversation.settings + const { providerId, modelId, temperature, maxTokens, enabledMcpTools } = conversation.settings const modelConfig = this.configPresenter.getModelConfig(modelId, providerId) if (!modelConfig) { @@ -3102,7 +3103,8 @@ export class ThreadPresenter implements IThreadPresenter { modelId, messageId, temperature, - maxTokens + maxTokens, + enabledMcpTools ) for await (const event of stream) { diff --git a/src/renderer/src/components/NewThread.vue b/src/renderer/src/components/NewThread.vue index 45383f030..645680f2c 100644 --- a/src/renderer/src/components/NewThread.vue +++ b/src/renderer/src/components/NewThread.vue @@ -366,7 +366,8 @@ const handleSend = async (content: UserMessageContent) => { temperature: temperature.value, contextLength: contextLength.value, maxTokens: maxTokens.value, - artifacts: artifacts.value as 0 | 1 + artifacts: artifacts.value as 0 | 1, + enabledMcpTools: chatStore.chatConfig.enabledMcpTools }) console.log('threadId', threadId, activeModel.value) chatStore.sendMessage(content) diff --git a/src/renderer/src/components/mcpToolsList.vue b/src/renderer/src/components/mcpToolsList.vue index fe0fbe717..120bb02f9 100644 --- a/src/renderer/src/components/mcpToolsList.vue +++ b/src/renderer/src/components/mcpToolsList.vue @@ -9,16 +9,17 @@ import { Button } from './ui/button' import { Switch } from './ui/switch' import { Badge } from './ui/badge' import { useLanguageStore } from '@/stores/language' +import { useChatStore } from '@/stores/chat' const { t } = useI18n() const mcpStore = useMcpStore() const langStore = useLanguageStore() +const chatStore = useChatStore() // 计算属性 const isLoading = computed(() => mcpStore.toolsLoading) const isError = computed(() => mcpStore.toolsError) const errorMessage = computed(() => mcpStore.toolsErrorMessage) -const toolCount = computed(() => mcpStore.toolCount) const hasTools = computed(() => mcpStore.hasTools) const mcpEnabled = computed(() => mcpStore.mcpEnabled) @@ -31,9 +32,39 @@ const getTools = (serverName: string) => { return mcpStore.tools.filter((tool) => tool.server.name === serverName) } +// 获取每个mcp服务的可用工具数量 +const getEnabledToolCountByServer = (serverName: string) => { + const enabledTools = chatStore.chatConfig.enabledMcpTools ?? [] + const serverTools = mcpStore.tools.filter((tool) => tool.server.name === serverName) + return serverTools.filter((tool) => enabledTools.includes(tool.function.name)).length +} + +// 获取可用工具总数 +const getTotalEnabledToolCount = () => { + const enabledMcpTools = chatStore.chatConfig.enabledMcpTools || [] + const filterList = mcpStore.tools.filter((item) => enabledMcpTools.includes(item.function.name)) + return filterList.length +} + +// 处理单个服务开关状态变化 const onServerToggle = (serverName: string) => { mcpStore.toggleServer(serverName) } + +// 处理单个工具开关状态变化 +const handleToolEnabledChange = (isEnabled: boolean, functionName: string) => { + const currentTools = chatStore.chatConfig.enabledMcpTools || [] + const updatedTools = isEnabled + ? Array.from(new Set([...currentTools, functionName])) + : currentTools.filter((name) => name !== functionName) + chatStore.updateChatConfig({ enabledMcpTools: updatedTools }) +} + +// 获取单个工具开关状态 +const isEnabled = (functionName: string): boolean => { + return chatStore.chatConfig.enabledMcpTools?.includes(functionName) ?? false +} + // 获取内置服务器的本地化名称和描述 const getLocalizedServerName = (serverName: string) => { return t(`mcp.inmemory.${serverName}.name`, serverName) @@ -76,7 +107,7 @@ onMounted(async () => { v-if="hasTools && !isLoading && !isError" :class="{ 'text-muted-foreground': !mcpEnabled, 'text-white': mcpEnabled }" class="text-sm" - >{{ toolCount }}{{ getTotalEnabledToolCount() }} @@ -84,7 +115,9 @@ onMounted(async () => {

{{ t('mcp.tools.disabled') }}

{{ t('mcp.tools.loading') }}

{{ t('mcp.tools.error') }}

-

{{ t('mcp.tools.available', { count: toolCount }) }}

+

+ {{ t('mcp.tools.available', { count: getTotalEnabledToolCount() }) }} +

{{ t('mcp.tools.none') }}

@@ -144,16 +177,28 @@ onMounted(async () => { variant="outline" class="flex items-center gap-1 mr-2 text-xs" > - {{ getTools(server.name).length }} + {{ getEnabledToolCountByServer(server.name) }}
{{ tool.function.name }}
+ +
+
+ {{ t('mcp.tools.empty') }}
diff --git a/src/renderer/src/stores/chat.ts b/src/renderer/src/stores/chat.ts index e098d475c..f4ede3950 100644 --- a/src/renderer/src/stores/chat.ts +++ b/src/renderer/src/stores/chat.ts @@ -56,7 +56,8 @@ export const useChatStore = defineStore('chat', () => { maxTokens: 8000, providerId: '', modelId: '', - artifacts: 0 + artifacts: 0, + enabledMcpTools: [] }) // Deeplink 消息缓存 diff --git a/src/renderer/src/stores/mcp.ts b/src/renderer/src/stores/mcp.ts index a7a4dd484..c544e3b21 100644 --- a/src/renderer/src/stores/mcp.ts +++ b/src/renderer/src/stores/mcp.ts @@ -4,6 +4,7 @@ import { usePresenter } from '@/composables/usePresenter' import { MCP_EVENTS } from '@/events' import { useI18n } from 'vue-i18n' import { useThrottleFn } from '@vueuse/core' +import { useChatStore } from './chat' import type { McpClient, MCPConfig, @@ -29,6 +30,7 @@ interface MCPToolCallResult { } export const useMcpStore = defineStore('mcp', () => { + const chatStore = useChatStore() const { t } = useI18n() // 获取MCP相关的presenter const mcpPresenter = usePresenter('mcpPresenter') @@ -163,8 +165,22 @@ export const useMcpStore = defineStore('mcp', () => { try { serverStatuses.value[serverName] = await mcpPresenter.isServerRunning(serverName) if (config.value.mcpEnabled && !noRefresh) { - loadTools() - loadClients() + await _loadTools() + await _loadClients() + } + // 根据服务器的状态,关闭或者开启该服务器的所有工具 + const isRunning = serverStatuses.value[serverName] || false + const currentTools = chatStore.chatConfig.enabledMcpTools || [] + if (isRunning) { + const serverTools = tools.value + .filter((tool) => tool.server.name === serverName) + .map((tool) => tool.function.name) + const mergedTools = Array.from(new Set([...currentTools, ...serverTools])) + chatStore.updateChatConfig({ enabledMcpTools: mergedTools }) + } else { + const allServerToolNames = tools.value.map((tool) => tool.function.name) + const filteredTools = currentTools.filter((name) => allServerToolNames.includes(name)) + chatStore.updateChatConfig({ enabledMcpTools: filteredTools }) } } catch (error) { console.error(t('mcp.errors.getServerStatusFailed', { serverName }), error) @@ -498,6 +514,11 @@ export const useMcpStore = defineStore('mcp', () => { await loadTools() await loadClients() } + + // 如果是新建会话页面,则缓存已激活工具名称 + if (!chatStore.getActiveThreadId()) { + chatStore.chatConfig.enabledMcpTools = tools.value.map((item) => item.function.name) + } } // 立即初始化 diff --git a/src/shared/presenter.d.ts b/src/shared/presenter.d.ts index d14d8c54a..784da13d8 100644 --- a/src/shared/presenter.d.ts +++ b/src/shared/presenter.d.ts @@ -502,7 +502,8 @@ export interface ILlmProviderPresenter { modelId: string, eventId: string, temperature?: number, - maxTokens?: number + maxTokens?: number, + enabledMcpTools?: string[] ): AsyncGenerator generateCompletion( providerId: string, @@ -534,6 +535,7 @@ export type CONVERSATION_SETTINGS = { providerId: string modelId: string artifacts: 0 | 1 + enabledMcpTools?: string[] } export type CONVERSATION = {