From 55ebb03faf6f75ff433b99b7ad3d7e644da5d63b Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Sat, 4 Apr 2026 20:35:30 +0000 Subject: [PATCH 1/2] Polish local model chat UI and share provider helpers Export LOCAL_PROVIDER_LABELS and local model id parsers from modelRegistry; dedupe labels in agentChatService and aiIntegrationService. Merge CLI and local runtime banners in AgentChatPane, reset unified permission mode when harness recommendations change, and fix local placeholder sdkModelId. Settings and model options use the shared helpers. Co-authored-by: Arul Sharma --- .../main/services/ai/aiIntegrationService.ts | 7 +- .../main/services/chat/agentChatService.ts | 7 +- .../components/chat/AgentChatPane.tsx | 215 +++++++++++++----- .../components/settings/ProvidersSection.tsx | 26 ++- .../shared/UnifiedModelSelector.tsx | 33 +-- apps/desktop/src/renderer/lib/modelOptions.ts | 45 ++-- apps/desktop/src/shared/modelRegistry.ts | 16 +- 7 files changed, 214 insertions(+), 135 deletions(-) diff --git a/apps/desktop/src/main/services/ai/aiIntegrationService.ts b/apps/desktop/src/main/services/ai/aiIntegrationService.ts index f1efbfb8e..7836c2188 100644 --- a/apps/desktop/src/main/services/ai/aiIntegrationService.ts +++ b/apps/desktop/src/main/services/ai/aiIntegrationService.ts @@ -17,6 +17,7 @@ import { getAvailableModels, getLocalProviderDefaultEndpoint, listModelDescriptorsForProvider, + LOCAL_PROVIDER_LABELS, replaceDynamicLocalModelDescriptors, resolveModelAlias, enrichModelRegistry, @@ -380,12 +381,6 @@ function redactDetectedAuth( return redacted; } -const LOCAL_PROVIDER_LABELS: Record = { - ollama: "Ollama", - lmstudio: "LM Studio", - vllm: "vLLM", -}; - function apiProviderLabel(provider: string): string { const labels: Record = { anthropic: "Anthropic", diff --git a/apps/desktop/src/main/services/chat/agentChatService.ts b/apps/desktop/src/main/services/chat/agentChatService.ts index a37fd2baa..d4e407ffc 100644 --- a/apps/desktop/src/main/services/chat/agentChatService.ts +++ b/apps/desktop/src/main/services/chat/agentChatService.ts @@ -121,6 +121,7 @@ import { getModelById, getAvailableModels as getRegistryModels, listModelDescriptorsForProvider, + LOCAL_PROVIDER_LABELS, MODEL_REGISTRY, pickDefaultCursorDescriptorFromCliList, replaceDynamicLocalModelDescriptors, @@ -4294,12 +4295,6 @@ export function createAgentChatService(args: { // Unified session support — for API-key / local models using streamText + universal tools. // CLI-wrapped models fall through to the existing Claude/Codex runtimes. - const LOCAL_PROVIDER_LABELS: Record = { - ollama: "Ollama", - lmstudio: "LM Studio", - vllm: "vLLM", - }; - const discoveredLocalModelToDescriptor = (model: DiscoveredLocalModel): ModelDescriptor => createDynamicLocalModelDescriptor(model.provider, model.modelId, { ...(model.displayName ? { displayName: model.displayName } : {}), diff --git a/apps/desktop/src/renderer/components/chat/AgentChatPane.tsx b/apps/desktop/src/renderer/components/chat/AgentChatPane.tsx index 1ff5bb6ad..c96fd34c4 100644 --- a/apps/desktop/src/renderer/components/chat/AgentChatPane.tsx +++ b/apps/desktop/src/renderer/components/chat/AgentChatPane.tsx @@ -30,9 +30,12 @@ import { } from "../../../shared/types"; import { parseAgentChatTranscript } from "../../../shared/chatTranscript"; import { + LOCAL_PROVIDER_LABELS, MODEL_REGISTRY, + getLocalModelIdTail, getLocalProviderDefaultEndpoint, getModelById, + parseLocalProviderFromModelId, resolveModelDescriptorForProvider, type LocalProviderFamily, type ModelDescriptor, @@ -67,30 +70,17 @@ const LEGACY_PROVIDER_KEY = "ade.chat.lastProvider"; const LEGACY_MODEL_KEY_PREFIX = "ade.chat.lastModel"; const COMPUTER_USE_SNAPSHOT_COOLDOWN_MS = 750; -const LOCAL_PROVIDER_LABELS: Record = { - ollama: "Ollama", - lmstudio: "LM Studio", - vllm: "vLLM", -}; type AiStatusSnapshot = AiSettingsStatus & { runtimeConnections?: Record; }; -function getLocalProviderFromModelId(modelId: string): LocalProviderFamily | null { - const provider = String(modelId ?? "").trim().split("/", 1)[0]?.toLowerCase(); - if (provider === "ollama" || provider === "lmstudio" || provider === "vllm") { - return provider; - } - return null; -} - function formatLocalModelLabel(modelId: string): string { - const provider = getLocalProviderFromModelId(modelId); + const provider = parseLocalProviderFromModelId(modelId); if (!provider) { return getModelById(modelId)?.displayName ?? modelId; } - const tail = String(modelId ?? "").trim().slice(provider.length + 1).trim(); + const tail = getLocalModelIdTail(modelId, provider); return tail.length ? tail : modelId; } @@ -103,6 +93,60 @@ function recommendedUnifiedPermissionModeForModel( : null; } +function shouldResetUnifiedPermissionForModelSwitch( + previous: ModelDescriptor | null | undefined, + next: ModelDescriptor | null | undefined, +): boolean { + const prevRec = recommendedUnifiedPermissionModeForModel(previous); + const nextRec = recommendedUnifiedPermissionModeForModel(next); + if (prevRec == null && nextRec == null) return false; + return prevRec !== nextRec; +} + +type LocalRuntimeNoticeShape = { + tone: "success" | "warning"; + title: string; + message: string; +}; + +function LocalRuntimeNoticeBlock(props: { + notice: LocalRuntimeNoticeShape; + endpoint?: string | null; + /** `inline` = text only (inside a parent runtime card). */ + variant?: "card" | "inline"; +}) { + const { notice, endpoint, variant = "card" } = props; + const isCard = variant === "card"; + return ( +
+
+ {notice.title} +
+
+ {notice.message} +
+ {endpoint ? ( + + {endpoint} + + ) : null} +
+ ); +} + export function resolveChatSessionProfile(_computerUsePolicy: ComputerUsePolicy): AgentChatSessionProfile { return "workflow"; } @@ -629,6 +673,7 @@ export function AgentChatPane({ const [codexSandbox, setCodexSandbox] = useState(initialNativeControls.codexSandbox); const [codexConfigSource, setCodexConfigSource] = useState(initialNativeControls.codexConfigSource); const [unifiedPermissionMode, setUnifiedPermissionMode] = useState(initialNativeControls.unifiedPermissionMode); + const prevModelDescRef = useRef(undefined); const [cursorModeId, setCursorModeId] = useState(initialNativeControls.cursorModeId); const [cursorConfigValues, setCursorConfigValues] = useState>(initialNativeControls.cursorConfigValues); const [computerUsePolicy, setComputerUsePolicy] = useState(createDefaultComputerUsePolicy()); @@ -728,7 +773,7 @@ export function AgentChatPane({ const localRuntimeState = useMemo(() => { const provider = selectedModelDesc?.authTypes.includes("local") ? (selectedModelDesc.family as LocalProviderFamily) - : getLocalProviderFromModelId(modelId); + : parseLocalProviderFromModelId(modelId); if (!provider) return null; const runtimeConnection = aiStatus?.runtimeConnections?.[provider] ?? null; const detectedEntry = aiStatus?.detectedAuth?.find( @@ -790,9 +835,61 @@ export function AgentChatPane({ return { tone: "success" as const, title: `${localRuntimeState.label} runtime`, - message: `${localRuntimeState.label} is connected at ${localRuntimeState.endpoint} with ${localRuntimeState.modelIds.length} loaded model${localRuntimeState.modelIds.length === 1 ? "" : "s"}${localRuntimeState.health ? ` (${localRuntimeState.health})` : ""}.`, + message: `${localRuntimeState.label} is connected with ${localRuntimeState.modelIds.length} loaded model${localRuntimeState.modelIds.length === 1 ? "" : "s"}${localRuntimeState.health ? ` (${localRuntimeState.health})` : ""}.`, }; }, [localRuntimeState, modelId, selectedModelDesc?.displayName]); + + const cliRuntimeBlocked = Boolean( + selectedSessionId + && activeProviderConnection + && !activeProviderConnection.runtimeAvailable + && (activeProviderConnection.blocker || activeProviderConnection.provider === "cursor"), + ); + const cliRuntimeTitle = activeProviderConnection?.provider === "claude" + ? "Claude runtime" + : activeProviderConnection?.provider === "cursor" + ? "Cursor runtime" + : "Codex runtime"; + const cliRuntimeBody = activeProviderConnection?.blocker + ?? (activeProviderConnection?.provider === "cursor" + ? "Cursor agent is not available. Ensure Cursor is installed and the agent is enabled." + : null); + + const mergedRuntimeBanner = useMemo(() => { + if (!cliRuntimeBlocked && !localRuntimeNotice) return null; + if (cliRuntimeBlocked && localRuntimeNotice) { + return { + kind: "merged" as const, + cliTitle: cliRuntimeTitle, + cliBody: cliRuntimeBody ?? "", + localNotice: localRuntimeNotice, + localEndpoint: localRuntimeState?.endpoint, + }; + } + if (cliRuntimeBlocked) { + return { + kind: "cli-only" as const, + cliTitle: cliRuntimeTitle, + cliBody: cliRuntimeBody ?? "", + }; + } + return { + kind: "local-only" as const, + localNotice: localRuntimeNotice!, + localEndpoint: localRuntimeState?.endpoint, + }; + }, [ + cliRuntimeBlocked, + cliRuntimeBody, + cliRuntimeTitle, + localRuntimeNotice, + localRuntimeState?.endpoint, + ]); + + useEffect(() => { + prevModelDescRef.current = getModelById(modelId); + }, [modelId]); + const surfaceMode = presentation?.mode ?? "standard"; const identitySessionSettingsBusy = isPersistentIdentitySurface && sessionMutationKind !== null; @@ -1698,32 +1795,39 @@ export function AgentChatPane({ }; }, [currentNativeControls]); const buildModelSelectionSnapshot = useCallback((nextModelId: string) => { + const previousDesc = prevModelDescRef.current; const nextDesc = getModelById(nextModelId); const nextProvider = resolveChatRuntimeProvider(nextDesc); const nextModel = nextProvider === "unified" ? nextModelId : runtimeFacingModelId(nextDesc, nextModelId); const tiers = nextDesc?.reasoningTiers ?? []; const preferred = readLastUsedReasoningEffort({ laneId, modelId: nextModelId }); const nextReasoningEffort = selectReasoningEffort({ tiers, preferred }); + const nextRec = recommendedUnifiedPermissionModeForModel(nextDesc); return { nextDesc, nextModelId, nextModel, nextProvider, nextReasoningEffort, - nextUnifiedPermissionMode: recommendedUnifiedPermissionModeForModel(nextDesc), + nextUnifiedPermissionMode: nextRec, + resetUnifiedPermissionToDefault: shouldResetUnifiedPermissionForModelSwitch(previousDesc, nextDesc), }; }, [laneId]); const applyModelSelectionSnapshot = useCallback((snapshot: { nextModelId: string; nextReasoningEffort: string | null; nextUnifiedPermissionMode?: AgentChatUnifiedPermissionMode | null; + resetUnifiedPermissionToDefault?: boolean; }) => { setModelId(snapshot.nextModelId); setReasoningEffort(snapshot.nextReasoningEffort); + if (snapshot.resetUnifiedPermissionToDefault) { + setUnifiedPermissionMode(initialNativeControls.unifiedPermissionMode); + } if (snapshot.nextUnifiedPermissionMode) { setUnifiedPermissionMode(snapshot.nextUnifiedPermissionMode); } - }, []); + }, [initialNativeControls.unifiedPermissionMode]); const notifySessionCreated = useCallback((session: AgentChatSession) => { if (!onSessionCreated) return; void Promise.resolve(onSessionCreated(session)).catch((err) => { console.error("notifySessionCreated failed:", err); }); @@ -2433,15 +2537,15 @@ export function AgentChatPane({ } setSessionMutationKind("model"); - const nextNativeControlPayload = snapshot.nextUnifiedPermissionMode + const nextUnifiedForPayload = snapshot.resetUnifiedPermissionToDefault + ? (snapshot.nextUnifiedPermissionMode ?? initialNativeControls.unifiedPermissionMode) + : snapshot.nextUnifiedPermissionMode; + const nextNativeControlPayload = snapshot.nextProvider === "unified" && nextUnifiedForPayload != null ? { - ...summarizeNativeControls(snapshot.nextProvider, { + ...summarizeNativeControls("unified", { ...currentNativeControls, - unifiedPermissionMode: snapshot.nextUnifiedPermissionMode, + unifiedPermissionMode: nextUnifiedForPayload, }), - ...(snapshot.nextProvider === "cursor" - ? { cursorConfigValues: currentNativeControls.cursorConfigValues } - : {}), } : buildNativeControlPayload(snapshot.nextProvider); void window.ade.agentChat.updateSession({ @@ -2550,47 +2654,44 @@ export function AgentChatPane({ {error} ) : null} - {selectedSessionId && !activeProviderConnection?.runtimeAvailable && (activeProviderConnection?.blocker || activeProviderConnection?.provider === "cursor") ? ( + {mergedRuntimeBanner?.kind === "cli-only" ? (
- {activeProviderConnection.provider === "claude" - ? "Claude runtime" - : activeProviderConnection.provider === "cursor" - ? "Cursor runtime" - : "Codex runtime"} + {mergedRuntimeBanner.cliTitle}
- {activeProviderConnection.blocker || "Cursor agent is not available. Ensure Cursor is installed and the agent is enabled."} + {mergedRuntimeBanner.cliBody}
) : null} - - {localRuntimeNotice ? ( -
-
- {localRuntimeNotice.title} + {mergedRuntimeBanner?.kind === "local-only" ? ( + + ) : null} + {mergedRuntimeBanner?.kind === "merged" ? ( +
+
+ Runtime status
-
- {localRuntimeNotice.message} +
+
+
+ {mergedRuntimeBanner.cliTitle} +
+
+ {mergedRuntimeBanner.cliBody} +
+
+
+ +
- {localRuntimeState ? ( - - {localRuntimeState.endpoint} - - ) : null}
) : null} diff --git a/apps/desktop/src/renderer/components/settings/ProvidersSection.tsx b/apps/desktop/src/renderer/components/settings/ProvidersSection.tsx index 41734d982..7b934afec 100644 --- a/apps/desktop/src/renderer/components/settings/ProvidersSection.tsx +++ b/apps/desktop/src/renderer/components/settings/ProvidersSection.tsx @@ -8,7 +8,14 @@ import type { AiSettingsStatus, ProjectConfigSnapshot, } from "../../../shared/types"; -import { getLocalProviderDefaultEndpoint, getModelById, type LocalProviderFamily } from "../../../shared/modelRegistry"; +import { + getLocalModelIdTail, + getLocalProviderDefaultEndpoint, + getModelById, + LOCAL_PROVIDER_LABELS, + parseLocalProviderFromModelId, + type LocalProviderFamily, +} from "../../../shared/modelRegistry"; import { ArrowsClockwise, CheckCircle, @@ -72,10 +79,6 @@ const LOCAL_PROVIDER_SPECS: Array<{ { provider: "vllm", label: "vLLM", description: "OpenAI-compatible local server" }, ]; -const LOCAL_PROVIDER_LABELS: Record = Object.fromEntries( - LOCAL_PROVIDER_SPECS.map((entry) => [entry.provider, entry.label]), -) as Record; - const API_KEY_PROVIDERS: Array<{ provider: string; label: string; @@ -184,14 +187,13 @@ function buildCliMessage(tool: (typeof CLI_TOOLS)[number], connection: AiProvide function formatLocalModelLabel(modelId: string): string { const descriptor = getModelById(modelId); if (descriptor) return descriptor.displayName; - const raw = String(modelId ?? "").trim(); - const prefix = raw.split("/", 1)[0]?.toLowerCase(); - if (prefix === "ollama" || prefix === "lmstudio" || prefix === "vllm") { - const tail = raw.slice(prefix.length + 1).trim(); - const providerLabel = LOCAL_PROVIDER_SPECS.find((entry) => entry.provider === prefix)?.label ?? prefix; - return tail.length ? `${tail} (${providerLabel})` : raw; + const provider = parseLocalProviderFromModelId(modelId); + if (provider) { + const tail = getLocalModelIdTail(modelId, provider); + const brand = LOCAL_PROVIDER_LABELS[provider]; + return tail.length ? `${tail} (${brand})` : String(modelId ?? "").trim(); } - return raw; + return String(modelId ?? "").trim(); } const AUTH_ERROR_SIGNALS = [ diff --git a/apps/desktop/src/renderer/components/shared/UnifiedModelSelector.tsx b/apps/desktop/src/renderer/components/shared/UnifiedModelSelector.tsx index 8b289ae13..062ee112a 100644 --- a/apps/desktop/src/renderer/components/shared/UnifiedModelSelector.tsx +++ b/apps/desktop/src/renderer/components/shared/UnifiedModelSelector.tsx @@ -2,7 +2,10 @@ import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { createPortal } from "react-dom"; import { AnimatePresence, motion } from "motion/react"; import { + LOCAL_PROVIDER_LABELS, MODEL_REGISTRY, + getLocalModelIdTail, + parseLocalProviderFromModelId, resolveModelDescriptor, type ModelDescriptor, } from "../../../shared/modelRegistry"; @@ -43,11 +46,6 @@ type UnifiedModelSelectorProps = { }; const SOURCE_KEYS: SourceSectionKey[] = ["subscription", "api", "local"]; -const LOCAL_PROVIDER_LABELS: Record = { - ollama: "Ollama", - lmstudio: "LM Studio", - vllm: "vLLM", -}; const selectCls = cn( "h-8 rounded-lg border border-white/[0.08] bg-white/[0.04] px-2 font-sans text-[11px] text-fg/70", @@ -67,21 +65,6 @@ function providerAccent(family: string, fallback?: string): string { return PROVIDER_BADGE_COLORS[family] ?? fallback ?? "#A78BFA"; } -function getLocalProviderFromModelId(modelId: string): "ollama" | "lmstudio" | "vllm" | null { - const provider = String(modelId ?? "").trim().split("/", 1)[0]?.toLowerCase(); - if (provider === "ollama" || provider === "lmstudio" || provider === "vllm") { - return provider; - } - return null; -} - -function getLocalModelShortLabel(modelId: string): string { - const provider = getLocalProviderFromModelId(modelId); - if (!provider) return modelId; - const tail = String(modelId ?? "").trim().slice(provider.length + 1).trim(); - return tail.length ? tail : modelId; -} - function subsectionTabTitle(sub: ModelSubsection): string { return sub.label.trim() || "Models"; } @@ -133,10 +116,10 @@ function createUnknownModelPlaceholder(modelId: string): ModelDescriptor { isCliWrapped: true, }; } - const localProvider = getLocalProviderFromModelId(modelId); + const localProvider = parseLocalProviderFromModelId(modelId); if (localProvider) { - const providerLabel = LOCAL_PROVIDER_LABELS[localProvider]; - const shortId = getLocalModelShortLabel(modelId); + const shortId = getLocalModelIdTail(modelId, localProvider) || modelId; + const brand = LOCAL_PROVIDER_LABELS[localProvider]; return { id: modelId, shortId, @@ -148,11 +131,11 @@ function createUnknownModelPlaceholder(modelId: string): ModelDescriptor { capabilities: { tools: false, vision: false, reasoning: false, streaming: true }, color: PROVIDER_BADGE_COLORS[localProvider] ?? "#64748B", sdkProvider: "@ai-sdk/openai-compatible", - sdkModelId: modelId, + sdkModelId: shortId, isCliWrapped: false, discoverySource: localProvider === "lmstudio" ? "lmstudio-openai" : localProvider, harnessProfile: "guarded", - aliases: [providerLabel], + aliases: [brand], }; } return { diff --git a/apps/desktop/src/renderer/lib/modelOptions.ts b/apps/desktop/src/renderer/lib/modelOptions.ts index d52e08681..c6e5fc1f3 100644 --- a/apps/desktop/src/renderer/lib/modelOptions.ts +++ b/apps/desktop/src/renderer/lib/modelOptions.ts @@ -1,42 +1,33 @@ import type { AiModelDescriptor, AiRuntimeConnectionStatus, AiSettingsStatus, ModelId } from "../../shared/types"; -import { MODEL_REGISTRY, getModelById, type LocalProviderFamily, type ModelDescriptor } from "../../shared/modelRegistry"; +import { + LOCAL_PROVIDER_LABELS, + MODEL_REGISTRY, + getLocalModelIdTail, + getModelById, + isLocalProviderFamily, + parseLocalProviderFromModelId, + type ModelDescriptor, +} from "../../shared/modelRegistry"; function normalizeAuthProvider(provider: string | undefined): string { return String(provider ?? "").trim().toLowerCase(); } -const LOCAL_PROVIDER_LABELS: Record = { - ollama: "Ollama", - lmstudio: "LM Studio", - vllm: "vLLM", -}; - -function getLocalProviderFromModelId(modelId: string): LocalProviderFamily | null { - const provider = String(modelId ?? "") - .trim() - .split("/", 1)[0] - ?.toLowerCase(); - if (provider === "ollama" || provider === "lmstudio" || provider === "vllm") { - return provider; - } - return null; -} - function getLocalModelLabel(modelId: string): string { - const provider = getLocalProviderFromModelId(modelId); + const provider = parseLocalProviderFromModelId(modelId); if (!provider) return modelId; - const tail = String(modelId ?? "").trim().slice(provider.length + 1).trim(); + const tail = getLocalModelIdTail(modelId, provider); return tail.length ? tail : modelId; } function buildFallbackModelOption(modelId: string): AiModelDescriptor { - const provider = getLocalProviderFromModelId(modelId); + const provider = parseLocalProviderFromModelId(modelId); if (provider) { - const providerLabel = LOCAL_PROVIDER_LABELS[provider]; + const pLabel = LOCAL_PROVIDER_LABELS[provider]; return { id: modelId, label: getLocalModelLabel(modelId), - description: `${providerLabel} local model`, + description: `${pLabel} local model`, }; } return { @@ -104,7 +95,7 @@ function hasDynamicLocalModelIdsForProvider( runtimeConnections: Record | undefined, ): boolean { const normalizedProvider = normalizeAuthProvider(provider); - if (normalizedProvider !== "ollama" && normalizedProvider !== "lmstudio" && normalizedProvider !== "vllm") { + if (!isLocalProviderFamily(normalizedProvider)) { return false; } const prefix = `${normalizedProvider}/`; @@ -168,7 +159,7 @@ export function deriveConfiguredModelIds( for (const [provider, connection] of Object.entries(runtimeConnections ?? {})) { const normalizedProvider = normalizeAuthProvider(provider); - if (normalizedProvider !== "ollama" && normalizedProvider !== "lmstudio" && normalizedProvider !== "vllm") { + if (!isLocalProviderFamily(normalizedProvider)) { continue; } if (!hasDynamicLocalModelIdsForProvider(normalizedProvider, status.availableModelIds, runtimeConnections)) { @@ -210,7 +201,7 @@ export function deriveConfiguredModelOptions( const descriptor = getModelById(modelId); return descriptor ? [descriptorToModelOption(descriptor)] - : getLocalProviderFromModelId(modelId) + : parseLocalProviderFromModelId(modelId) ? [buildFallbackModelOption(modelId)] : []; }); @@ -224,6 +215,6 @@ export function includeSelectedModelOption( if (!modelId.length || options.some((option) => option.id === modelId)) return options; const descriptor = getModelById(modelId); if (descriptor) return [descriptorToModelOption(descriptor), ...options]; - if (getLocalProviderFromModelId(modelId)) return [buildFallbackModelOption(modelId), ...options]; + if (parseLocalProviderFromModelId(modelId)) return [buildFallbackModelOption(modelId), ...options]; return options; } diff --git a/apps/desktop/src/shared/modelRegistry.ts b/apps/desktop/src/shared/modelRegistry.ts index b7c407cdd..cdf6840ec 100644 --- a/apps/desktop/src/shared/modelRegistry.ts +++ b/apps/desktop/src/shared/modelRegistry.ts @@ -82,7 +82,8 @@ export function isModelProviderGroup(value: string | null | undefined): value is const ALL_CAPS: ModelCapabilities = { tools: true, vision: true, reasoning: true, streaming: true }; const NO_REASONING: ModelCapabilities = { tools: true, vision: true, reasoning: false, streaming: true }; const BASIC_CAPS: ModelCapabilities = { tools: true, vision: false, reasoning: false, streaming: true }; -const LOCAL_PROVIDER_LABELS: Record = { +/** Human-readable names for Ollama / LM Studio / vLLM (shared across main, renderer, and MCP). */ +export const LOCAL_PROVIDER_LABELS: Record = { ollama: "Ollama", lmstudio: "LM Studio", vllm: "vLLM", @@ -753,10 +754,21 @@ export function validateModelRegistry(models: ModelDescriptor[] = MODEL_REGISTRY validateModelRegistry(); rebuildIndexes(); -function isLocalProviderFamily(value: string): value is LocalProviderFamily { +export function isLocalProviderFamily(value: string): value is LocalProviderFamily { return value === "ollama" || value === "lmstudio" || value === "vllm"; } +/** First path segment of `provider/modelId` when it is a known local provider. */ +export function parseLocalProviderFromModelId(modelId: string): LocalProviderFamily | null { + const provider = String(modelId ?? "").trim().split("/", 1)[0]?.toLowerCase() ?? ""; + return isLocalProviderFamily(provider) ? provider : null; +} + +/** Model name segment after `provider/` for local refs; empty string if missing. */ +export function getLocalModelIdTail(modelId: string, provider: LocalProviderFamily): string { + return String(modelId ?? "").trim().slice(provider.length + 1).trim(); +} + function parseDynamicLocalModelRef(modelRef: string): { provider: LocalProviderFamily; modelId: string } | null { const normalized = modelRef.trim(); if (!normalized.length) return null; From be3a62e656f910029ab4a120df533e7bfc1ae3f8 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Sun, 5 Apr 2026 02:30:54 +0000 Subject: [PATCH 2/2] Address PR review: local permission descriptors and unified mode setter Add getModelDescriptorForPermissionMode for harness decisions when getModelById is undefined (e.g. ollama/auto). Use it in AgentChatPane for prevModelRef, model switch snapshots, and session create. Consolidate applyModelSelectionSnapshot to a single setUnifiedPermissionMode call. Co-authored-by: Arul Sharma --- .../components/chat/AgentChatPane.tsx | 22 +++++++++++-------- apps/desktop/src/shared/modelRegistry.test.ts | 14 ++++++++++++ apps/desktop/src/shared/modelRegistry.ts | 17 ++++++++++++++ 3 files changed, 44 insertions(+), 9 deletions(-) diff --git a/apps/desktop/src/renderer/components/chat/AgentChatPane.tsx b/apps/desktop/src/renderer/components/chat/AgentChatPane.tsx index c96fd34c4..86eb1d5cc 100644 --- a/apps/desktop/src/renderer/components/chat/AgentChatPane.tsx +++ b/apps/desktop/src/renderer/components/chat/AgentChatPane.tsx @@ -35,6 +35,7 @@ import { getLocalModelIdTail, getLocalProviderDefaultEndpoint, getModelById, + getModelDescriptorForPermissionMode, parseLocalProviderFromModelId, resolveModelDescriptorForProvider, type LocalProviderFamily, @@ -887,7 +888,7 @@ export function AgentChatPane({ ]); useEffect(() => { - prevModelDescRef.current = getModelById(modelId); + prevModelDescRef.current = getModelDescriptorForPermissionMode(modelId); }, [modelId]); const surfaceMode = presentation?.mode ?? "standard"; @@ -1797,12 +1798,13 @@ export function AgentChatPane({ const buildModelSelectionSnapshot = useCallback((nextModelId: string) => { const previousDesc = prevModelDescRef.current; const nextDesc = getModelById(nextModelId); + const nextPermissionDesc = getModelDescriptorForPermissionMode(nextModelId); const nextProvider = resolveChatRuntimeProvider(nextDesc); const nextModel = nextProvider === "unified" ? nextModelId : runtimeFacingModelId(nextDesc, nextModelId); const tiers = nextDesc?.reasoningTiers ?? []; const preferred = readLastUsedReasoningEffort({ laneId, modelId: nextModelId }); const nextReasoningEffort = selectReasoningEffort({ tiers, preferred }); - const nextRec = recommendedUnifiedPermissionModeForModel(nextDesc); + const nextRec = recommendedUnifiedPermissionModeForModel(nextPermissionDesc); return { nextDesc, nextModelId, @@ -1810,7 +1812,7 @@ export function AgentChatPane({ nextProvider, nextReasoningEffort, nextUnifiedPermissionMode: nextRec, - resetUnifiedPermissionToDefault: shouldResetUnifiedPermissionForModelSwitch(previousDesc, nextDesc), + resetUnifiedPermissionToDefault: shouldResetUnifiedPermissionForModelSwitch(previousDesc, nextPermissionDesc), }; }, [laneId]); const applyModelSelectionSnapshot = useCallback((snapshot: { @@ -1821,11 +1823,12 @@ export function AgentChatPane({ }) => { setModelId(snapshot.nextModelId); setReasoningEffort(snapshot.nextReasoningEffort); - if (snapshot.resetUnifiedPermissionToDefault) { - setUnifiedPermissionMode(initialNativeControls.unifiedPermissionMode); - } - if (snapshot.nextUnifiedPermissionMode) { - setUnifiedPermissionMode(snapshot.nextUnifiedPermissionMode); + const nextUnified = snapshot.nextUnifiedPermissionMode ?? null; + const targetUnified = snapshot.resetUnifiedPermissionToDefault + ? (nextUnified ?? initialNativeControls.unifiedPermissionMode) + : nextUnified; + if (targetUnified != null) { + setUnifiedPermissionMode(targetUnified); } }, [initialNativeControls.unifiedPermissionMode]); const notifySessionCreated = useCallback((session: AgentChatSession) => { @@ -1840,11 +1843,12 @@ export function AgentChatPane({ if (!laneId) return null; const createPromise = (async () => { const desc = getModelById(modelId); + const permissionDesc = getModelDescriptorForPermissionMode(modelId); const provider = resolveChatRuntimeProvider(desc); const model = provider === "unified" ? modelId : runtimeFacingModelId(desc, modelId); const sessionProfile = resolveChatSessionProfile(computerUsePolicy); const harnessPermissionMode = provider === "unified" - ? recommendedUnifiedPermissionModeForModel(desc) + ? recommendedUnifiedPermissionModeForModel(permissionDesc) : null; const nativeControlPayload = harnessPermissionMode ? { diff --git a/apps/desktop/src/shared/modelRegistry.test.ts b/apps/desktop/src/shared/modelRegistry.test.ts index 4fdbf920a..fa446ba97 100644 --- a/apps/desktop/src/shared/modelRegistry.test.ts +++ b/apps/desktop/src/shared/modelRegistry.test.ts @@ -4,6 +4,7 @@ import { getAvailableModels, getDefaultModelDescriptor, getModelById, + getModelDescriptorForPermissionMode, getRuntimeModelRefForDescriptor, listModelDescriptorsForProvider, MODEL_REGISTRY, @@ -76,6 +77,19 @@ describe("modelRegistry", () => { expect(resolveModelDescriptor("nonexistent/model-id")).toBeUndefined(); }); + it("getModelDescriptorForPermissionMode matches getModelById for known locals", () => { + const id = "ollama/qwen2.5-coder:32b"; + expect(getModelDescriptorForPermissionMode(id)).toEqual(getModelById(id)); + }); + + it("getModelDescriptorForPermissionMode yields guarded local for ollama/auto when getModelById is undefined", () => { + expect(getModelById("ollama/auto")).toBeUndefined(); + const perm = getModelDescriptorForPermissionMode("ollama/auto"); + expect(perm?.family).toBe("ollama"); + expect(perm?.harnessProfile).toBe("guarded"); + expect(perm?.authTypes).toContain("local"); + }); + it("resolves gpt-5.4 shortId to the API-key variant, not the codex variant", () => { const resolved = resolveModelAlias("gpt-5.4"); expect(resolved).toBeTruthy(); diff --git a/apps/desktop/src/shared/modelRegistry.ts b/apps/desktop/src/shared/modelRegistry.ts index cdf6840ec..508d4bf44 100644 --- a/apps/desktop/src/shared/modelRegistry.ts +++ b/apps/desktop/src/shared/modelRegistry.ts @@ -769,6 +769,23 @@ export function getLocalModelIdTail(modelId: string, provider: LocalProviderFami return String(modelId ?? "").trim().slice(provider.length + 1).trim(); } +/** + * Descriptor for unified permission / harness decisions when the registry has no row yet. + * `getModelById` returns undefined for refs such as `ollama/auto`; this still returns a + * guarded local descriptor so the UI matches main-process harness behavior. + */ +export function getModelDescriptorForPermissionMode(modelId: string): ModelDescriptor | undefined { + const resolved = getModelById(modelId); + if (resolved) return resolved; + const provider = parseLocalProviderFromModelId(modelId); + if (!provider) return undefined; + const tail = getLocalModelIdTail(modelId, provider); + if (!tail.length || tail === "auto") { + return createDynamicLocalModelDescriptor(provider, "auto", { harnessProfile: "guarded" }); + } + return createDynamicLocalModelDescriptor(provider, tail); +} + function parseDynamicLocalModelRef(modelRef: string): { provider: LocalProviderFamily; modelId: string } | null { const normalized = modelRef.trim(); if (!normalized.length) return null;