diff --git a/apps/desktop/src/main/services/ai/aiIntegrationService.ts b/apps/desktop/src/main/services/ai/aiIntegrationService.ts index f1efbfb8e..7836c2188 100644 --- a/apps/desktop/src/main/services/ai/aiIntegrationService.ts +++ b/apps/desktop/src/main/services/ai/aiIntegrationService.ts @@ -17,6 +17,7 @@ import { getAvailableModels, getLocalProviderDefaultEndpoint, listModelDescriptorsForProvider, + LOCAL_PROVIDER_LABELS, replaceDynamicLocalModelDescriptors, resolveModelAlias, enrichModelRegistry, @@ -380,12 +381,6 @@ function redactDetectedAuth( return redacted; } -const LOCAL_PROVIDER_LABELS: Record = { - ollama: "Ollama", - lmstudio: "LM Studio", - vllm: "vLLM", -}; - function apiProviderLabel(provider: string): string { const labels: Record = { anthropic: "Anthropic", diff --git a/apps/desktop/src/main/services/chat/agentChatService.ts b/apps/desktop/src/main/services/chat/agentChatService.ts index a37fd2baa..d4e407ffc 100644 --- a/apps/desktop/src/main/services/chat/agentChatService.ts +++ b/apps/desktop/src/main/services/chat/agentChatService.ts @@ -121,6 +121,7 @@ import { getModelById, getAvailableModels as getRegistryModels, listModelDescriptorsForProvider, + LOCAL_PROVIDER_LABELS, MODEL_REGISTRY, pickDefaultCursorDescriptorFromCliList, replaceDynamicLocalModelDescriptors, @@ -4294,12 +4295,6 @@ export function createAgentChatService(args: { // Unified session support — for API-key / local models using streamText + universal tools. // CLI-wrapped models fall through to the existing Claude/Codex runtimes. - const LOCAL_PROVIDER_LABELS: Record = { - ollama: "Ollama", - lmstudio: "LM Studio", - vllm: "vLLM", - }; - const discoveredLocalModelToDescriptor = (model: DiscoveredLocalModel): ModelDescriptor => createDynamicLocalModelDescriptor(model.provider, model.modelId, { ...(model.displayName ? { displayName: model.displayName } : {}), diff --git a/apps/desktop/src/renderer/components/chat/AgentChatPane.tsx b/apps/desktop/src/renderer/components/chat/AgentChatPane.tsx index 1ff5bb6ad..86eb1d5cc 100644 --- a/apps/desktop/src/renderer/components/chat/AgentChatPane.tsx +++ b/apps/desktop/src/renderer/components/chat/AgentChatPane.tsx @@ -30,9 +30,13 @@ import { } from "../../../shared/types"; import { parseAgentChatTranscript } from "../../../shared/chatTranscript"; import { + LOCAL_PROVIDER_LABELS, MODEL_REGISTRY, + getLocalModelIdTail, getLocalProviderDefaultEndpoint, getModelById, + getModelDescriptorForPermissionMode, + parseLocalProviderFromModelId, resolveModelDescriptorForProvider, type LocalProviderFamily, type ModelDescriptor, @@ -67,30 +71,17 @@ const LEGACY_PROVIDER_KEY = "ade.chat.lastProvider"; const LEGACY_MODEL_KEY_PREFIX = "ade.chat.lastModel"; const COMPUTER_USE_SNAPSHOT_COOLDOWN_MS = 750; -const LOCAL_PROVIDER_LABELS: Record = { - ollama: "Ollama", - lmstudio: "LM Studio", - vllm: "vLLM", -}; type AiStatusSnapshot = AiSettingsStatus & { runtimeConnections?: Record; }; -function getLocalProviderFromModelId(modelId: string): LocalProviderFamily | null { - const provider = String(modelId ?? "").trim().split("/", 1)[0]?.toLowerCase(); - if (provider === "ollama" || provider === "lmstudio" || provider === "vllm") { - return provider; - } - return null; -} - function formatLocalModelLabel(modelId: string): string { - const provider = getLocalProviderFromModelId(modelId); + const provider = parseLocalProviderFromModelId(modelId); if (!provider) { return getModelById(modelId)?.displayName ?? modelId; } - const tail = String(modelId ?? "").trim().slice(provider.length + 1).trim(); + const tail = getLocalModelIdTail(modelId, provider); return tail.length ? tail : modelId; } @@ -103,6 +94,60 @@ function recommendedUnifiedPermissionModeForModel( : null; } +function shouldResetUnifiedPermissionForModelSwitch( + previous: ModelDescriptor | null | undefined, + next: ModelDescriptor | null | undefined, +): boolean { + const prevRec = recommendedUnifiedPermissionModeForModel(previous); + const nextRec = recommendedUnifiedPermissionModeForModel(next); + if (prevRec == null && nextRec == null) return false; + return prevRec !== nextRec; +} + +type LocalRuntimeNoticeShape = { + tone: "success" | "warning"; + title: string; + message: string; +}; + +function LocalRuntimeNoticeBlock(props: { + notice: LocalRuntimeNoticeShape; + endpoint?: string | null; + /** `inline` = text only (inside a parent runtime card). */ + variant?: "card" | "inline"; +}) { + const { notice, endpoint, variant = "card" } = props; + const isCard = variant === "card"; + return ( +
+
+ {notice.title} +
+
+ {notice.message} +
+ {endpoint ? ( + + {endpoint} + + ) : null} +
+ ); +} + export function resolveChatSessionProfile(_computerUsePolicy: ComputerUsePolicy): AgentChatSessionProfile { return "workflow"; } @@ -629,6 +674,7 @@ export function AgentChatPane({ const [codexSandbox, setCodexSandbox] = useState(initialNativeControls.codexSandbox); const [codexConfigSource, setCodexConfigSource] = useState(initialNativeControls.codexConfigSource); const [unifiedPermissionMode, setUnifiedPermissionMode] = useState(initialNativeControls.unifiedPermissionMode); + const prevModelDescRef = useRef(undefined); const [cursorModeId, setCursorModeId] = useState(initialNativeControls.cursorModeId); const [cursorConfigValues, setCursorConfigValues] = useState>(initialNativeControls.cursorConfigValues); const [computerUsePolicy, setComputerUsePolicy] = useState(createDefaultComputerUsePolicy()); @@ -728,7 +774,7 @@ export function AgentChatPane({ const localRuntimeState = useMemo(() => { const provider = selectedModelDesc?.authTypes.includes("local") ? (selectedModelDesc.family as LocalProviderFamily) - : getLocalProviderFromModelId(modelId); + : parseLocalProviderFromModelId(modelId); if (!provider) return null; const runtimeConnection = aiStatus?.runtimeConnections?.[provider] ?? null; const detectedEntry = aiStatus?.detectedAuth?.find( @@ -790,9 +836,61 @@ export function AgentChatPane({ return { tone: "success" as const, title: `${localRuntimeState.label} runtime`, - message: `${localRuntimeState.label} is connected at ${localRuntimeState.endpoint} with ${localRuntimeState.modelIds.length} loaded model${localRuntimeState.modelIds.length === 1 ? "" : "s"}${localRuntimeState.health ? ` (${localRuntimeState.health})` : ""}.`, + message: `${localRuntimeState.label} is connected with ${localRuntimeState.modelIds.length} loaded model${localRuntimeState.modelIds.length === 1 ? "" : "s"}${localRuntimeState.health ? ` (${localRuntimeState.health})` : ""}.`, }; }, [localRuntimeState, modelId, selectedModelDesc?.displayName]); + + const cliRuntimeBlocked = Boolean( + selectedSessionId + && activeProviderConnection + && !activeProviderConnection.runtimeAvailable + && (activeProviderConnection.blocker || activeProviderConnection.provider === "cursor"), + ); + const cliRuntimeTitle = activeProviderConnection?.provider === "claude" + ? "Claude runtime" + : activeProviderConnection?.provider === "cursor" + ? "Cursor runtime" + : "Codex runtime"; + const cliRuntimeBody = activeProviderConnection?.blocker + ?? (activeProviderConnection?.provider === "cursor" + ? "Cursor agent is not available. Ensure Cursor is installed and the agent is enabled." + : null); + + const mergedRuntimeBanner = useMemo(() => { + if (!cliRuntimeBlocked && !localRuntimeNotice) return null; + if (cliRuntimeBlocked && localRuntimeNotice) { + return { + kind: "merged" as const, + cliTitle: cliRuntimeTitle, + cliBody: cliRuntimeBody ?? "", + localNotice: localRuntimeNotice, + localEndpoint: localRuntimeState?.endpoint, + }; + } + if (cliRuntimeBlocked) { + return { + kind: "cli-only" as const, + cliTitle: cliRuntimeTitle, + cliBody: cliRuntimeBody ?? "", + }; + } + return { + kind: "local-only" as const, + localNotice: localRuntimeNotice!, + localEndpoint: localRuntimeState?.endpoint, + }; + }, [ + cliRuntimeBlocked, + cliRuntimeBody, + cliRuntimeTitle, + localRuntimeNotice, + localRuntimeState?.endpoint, + ]); + + useEffect(() => { + prevModelDescRef.current = getModelDescriptorForPermissionMode(modelId); + }, [modelId]); + const surfaceMode = presentation?.mode ?? "standard"; const identitySessionSettingsBusy = isPersistentIdentitySurface && sessionMutationKind !== null; @@ -1698,32 +1796,41 @@ export function AgentChatPane({ }; }, [currentNativeControls]); const buildModelSelectionSnapshot = useCallback((nextModelId: string) => { + const previousDesc = prevModelDescRef.current; const nextDesc = getModelById(nextModelId); + const nextPermissionDesc = getModelDescriptorForPermissionMode(nextModelId); const nextProvider = resolveChatRuntimeProvider(nextDesc); const nextModel = nextProvider === "unified" ? nextModelId : runtimeFacingModelId(nextDesc, nextModelId); const tiers = nextDesc?.reasoningTiers ?? []; const preferred = readLastUsedReasoningEffort({ laneId, modelId: nextModelId }); const nextReasoningEffort = selectReasoningEffort({ tiers, preferred }); + const nextRec = recommendedUnifiedPermissionModeForModel(nextPermissionDesc); return { nextDesc, nextModelId, nextModel, nextProvider, nextReasoningEffort, - nextUnifiedPermissionMode: recommendedUnifiedPermissionModeForModel(nextDesc), + nextUnifiedPermissionMode: nextRec, + resetUnifiedPermissionToDefault: shouldResetUnifiedPermissionForModelSwitch(previousDesc, nextPermissionDesc), }; }, [laneId]); const applyModelSelectionSnapshot = useCallback((snapshot: { nextModelId: string; nextReasoningEffort: string | null; nextUnifiedPermissionMode?: AgentChatUnifiedPermissionMode | null; + resetUnifiedPermissionToDefault?: boolean; }) => { setModelId(snapshot.nextModelId); setReasoningEffort(snapshot.nextReasoningEffort); - if (snapshot.nextUnifiedPermissionMode) { - setUnifiedPermissionMode(snapshot.nextUnifiedPermissionMode); + const nextUnified = snapshot.nextUnifiedPermissionMode ?? null; + const targetUnified = snapshot.resetUnifiedPermissionToDefault + ? (nextUnified ?? initialNativeControls.unifiedPermissionMode) + : nextUnified; + if (targetUnified != null) { + setUnifiedPermissionMode(targetUnified); } - }, []); + }, [initialNativeControls.unifiedPermissionMode]); const notifySessionCreated = useCallback((session: AgentChatSession) => { if (!onSessionCreated) return; void Promise.resolve(onSessionCreated(session)).catch((err) => { console.error("notifySessionCreated failed:", err); }); @@ -1736,11 +1843,12 @@ export function AgentChatPane({ if (!laneId) return null; const createPromise = (async () => { const desc = getModelById(modelId); + const permissionDesc = getModelDescriptorForPermissionMode(modelId); const provider = resolveChatRuntimeProvider(desc); const model = provider === "unified" ? modelId : runtimeFacingModelId(desc, modelId); const sessionProfile = resolveChatSessionProfile(computerUsePolicy); const harnessPermissionMode = provider === "unified" - ? recommendedUnifiedPermissionModeForModel(desc) + ? recommendedUnifiedPermissionModeForModel(permissionDesc) : null; const nativeControlPayload = harnessPermissionMode ? { @@ -2433,15 +2541,15 @@ export function AgentChatPane({ } setSessionMutationKind("model"); - const nextNativeControlPayload = snapshot.nextUnifiedPermissionMode + const nextUnifiedForPayload = snapshot.resetUnifiedPermissionToDefault + ? (snapshot.nextUnifiedPermissionMode ?? initialNativeControls.unifiedPermissionMode) + : snapshot.nextUnifiedPermissionMode; + const nextNativeControlPayload = snapshot.nextProvider === "unified" && nextUnifiedForPayload != null ? { - ...summarizeNativeControls(snapshot.nextProvider, { + ...summarizeNativeControls("unified", { ...currentNativeControls, - unifiedPermissionMode: snapshot.nextUnifiedPermissionMode, + unifiedPermissionMode: nextUnifiedForPayload, }), - ...(snapshot.nextProvider === "cursor" - ? { cursorConfigValues: currentNativeControls.cursorConfigValues } - : {}), } : buildNativeControlPayload(snapshot.nextProvider); void window.ade.agentChat.updateSession({ @@ -2550,47 +2658,44 @@ export function AgentChatPane({ {error} ) : null} - {selectedSessionId && !activeProviderConnection?.runtimeAvailable && (activeProviderConnection?.blocker || activeProviderConnection?.provider === "cursor") ? ( + {mergedRuntimeBanner?.kind === "cli-only" ? (
- {activeProviderConnection.provider === "claude" - ? "Claude runtime" - : activeProviderConnection.provider === "cursor" - ? "Cursor runtime" - : "Codex runtime"} + {mergedRuntimeBanner.cliTitle}
- {activeProviderConnection.blocker || "Cursor agent is not available. Ensure Cursor is installed and the agent is enabled."} + {mergedRuntimeBanner.cliBody}
) : null} - - {localRuntimeNotice ? ( -
-
- {localRuntimeNotice.title} + {mergedRuntimeBanner?.kind === "local-only" ? ( + + ) : null} + {mergedRuntimeBanner?.kind === "merged" ? ( +
+
+ Runtime status
-
- {localRuntimeNotice.message} +
+
+
+ {mergedRuntimeBanner.cliTitle} +
+
+ {mergedRuntimeBanner.cliBody} +
+
+
+ +
- {localRuntimeState ? ( - - {localRuntimeState.endpoint} - - ) : null}
) : null} diff --git a/apps/desktop/src/renderer/components/settings/ProvidersSection.tsx b/apps/desktop/src/renderer/components/settings/ProvidersSection.tsx index 41734d982..7b934afec 100644 --- a/apps/desktop/src/renderer/components/settings/ProvidersSection.tsx +++ b/apps/desktop/src/renderer/components/settings/ProvidersSection.tsx @@ -8,7 +8,14 @@ import type { AiSettingsStatus, ProjectConfigSnapshot, } from "../../../shared/types"; -import { getLocalProviderDefaultEndpoint, getModelById, type LocalProviderFamily } from "../../../shared/modelRegistry"; +import { + getLocalModelIdTail, + getLocalProviderDefaultEndpoint, + getModelById, + LOCAL_PROVIDER_LABELS, + parseLocalProviderFromModelId, + type LocalProviderFamily, +} from "../../../shared/modelRegistry"; import { ArrowsClockwise, CheckCircle, @@ -72,10 +79,6 @@ const LOCAL_PROVIDER_SPECS: Array<{ { provider: "vllm", label: "vLLM", description: "OpenAI-compatible local server" }, ]; -const LOCAL_PROVIDER_LABELS: Record = Object.fromEntries( - LOCAL_PROVIDER_SPECS.map((entry) => [entry.provider, entry.label]), -) as Record; - const API_KEY_PROVIDERS: Array<{ provider: string; label: string; @@ -184,14 +187,13 @@ function buildCliMessage(tool: (typeof CLI_TOOLS)[number], connection: AiProvide function formatLocalModelLabel(modelId: string): string { const descriptor = getModelById(modelId); if (descriptor) return descriptor.displayName; - const raw = String(modelId ?? "").trim(); - const prefix = raw.split("/", 1)[0]?.toLowerCase(); - if (prefix === "ollama" || prefix === "lmstudio" || prefix === "vllm") { - const tail = raw.slice(prefix.length + 1).trim(); - const providerLabel = LOCAL_PROVIDER_SPECS.find((entry) => entry.provider === prefix)?.label ?? prefix; - return tail.length ? `${tail} (${providerLabel})` : raw; + const provider = parseLocalProviderFromModelId(modelId); + if (provider) { + const tail = getLocalModelIdTail(modelId, provider); + const brand = LOCAL_PROVIDER_LABELS[provider]; + return tail.length ? `${tail} (${brand})` : String(modelId ?? "").trim(); } - return raw; + return String(modelId ?? "").trim(); } const AUTH_ERROR_SIGNALS = [ diff --git a/apps/desktop/src/renderer/components/shared/UnifiedModelSelector.tsx b/apps/desktop/src/renderer/components/shared/UnifiedModelSelector.tsx index 8b289ae13..062ee112a 100644 --- a/apps/desktop/src/renderer/components/shared/UnifiedModelSelector.tsx +++ b/apps/desktop/src/renderer/components/shared/UnifiedModelSelector.tsx @@ -2,7 +2,10 @@ import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { createPortal } from "react-dom"; import { AnimatePresence, motion } from "motion/react"; import { + LOCAL_PROVIDER_LABELS, MODEL_REGISTRY, + getLocalModelIdTail, + parseLocalProviderFromModelId, resolveModelDescriptor, type ModelDescriptor, } from "../../../shared/modelRegistry"; @@ -43,11 +46,6 @@ type UnifiedModelSelectorProps = { }; const SOURCE_KEYS: SourceSectionKey[] = ["subscription", "api", "local"]; -const LOCAL_PROVIDER_LABELS: Record = { - ollama: "Ollama", - lmstudio: "LM Studio", - vllm: "vLLM", -}; const selectCls = cn( "h-8 rounded-lg border border-white/[0.08] bg-white/[0.04] px-2 font-sans text-[11px] text-fg/70", @@ -67,21 +65,6 @@ function providerAccent(family: string, fallback?: string): string { return PROVIDER_BADGE_COLORS[family] ?? fallback ?? "#A78BFA"; } -function getLocalProviderFromModelId(modelId: string): "ollama" | "lmstudio" | "vllm" | null { - const provider = String(modelId ?? "").trim().split("/", 1)[0]?.toLowerCase(); - if (provider === "ollama" || provider === "lmstudio" || provider === "vllm") { - return provider; - } - return null; -} - -function getLocalModelShortLabel(modelId: string): string { - const provider = getLocalProviderFromModelId(modelId); - if (!provider) return modelId; - const tail = String(modelId ?? "").trim().slice(provider.length + 1).trim(); - return tail.length ? tail : modelId; -} - function subsectionTabTitle(sub: ModelSubsection): string { return sub.label.trim() || "Models"; } @@ -133,10 +116,10 @@ function createUnknownModelPlaceholder(modelId: string): ModelDescriptor { isCliWrapped: true, }; } - const localProvider = getLocalProviderFromModelId(modelId); + const localProvider = parseLocalProviderFromModelId(modelId); if (localProvider) { - const providerLabel = LOCAL_PROVIDER_LABELS[localProvider]; - const shortId = getLocalModelShortLabel(modelId); + const shortId = getLocalModelIdTail(modelId, localProvider) || modelId; + const brand = LOCAL_PROVIDER_LABELS[localProvider]; return { id: modelId, shortId, @@ -148,11 +131,11 @@ function createUnknownModelPlaceholder(modelId: string): ModelDescriptor { capabilities: { tools: false, vision: false, reasoning: false, streaming: true }, color: PROVIDER_BADGE_COLORS[localProvider] ?? "#64748B", sdkProvider: "@ai-sdk/openai-compatible", - sdkModelId: modelId, + sdkModelId: shortId, isCliWrapped: false, discoverySource: localProvider === "lmstudio" ? "lmstudio-openai" : localProvider, harnessProfile: "guarded", - aliases: [providerLabel], + aliases: [brand], }; } return { diff --git a/apps/desktop/src/renderer/lib/modelOptions.ts b/apps/desktop/src/renderer/lib/modelOptions.ts index d52e08681..c6e5fc1f3 100644 --- a/apps/desktop/src/renderer/lib/modelOptions.ts +++ b/apps/desktop/src/renderer/lib/modelOptions.ts @@ -1,42 +1,33 @@ import type { AiModelDescriptor, AiRuntimeConnectionStatus, AiSettingsStatus, ModelId } from "../../shared/types"; -import { MODEL_REGISTRY, getModelById, type LocalProviderFamily, type ModelDescriptor } from "../../shared/modelRegistry"; +import { + LOCAL_PROVIDER_LABELS, + MODEL_REGISTRY, + getLocalModelIdTail, + getModelById, + isLocalProviderFamily, + parseLocalProviderFromModelId, + type ModelDescriptor, +} from "../../shared/modelRegistry"; function normalizeAuthProvider(provider: string | undefined): string { return String(provider ?? "").trim().toLowerCase(); } -const LOCAL_PROVIDER_LABELS: Record = { - ollama: "Ollama", - lmstudio: "LM Studio", - vllm: "vLLM", -}; - -function getLocalProviderFromModelId(modelId: string): LocalProviderFamily | null { - const provider = String(modelId ?? "") - .trim() - .split("/", 1)[0] - ?.toLowerCase(); - if (provider === "ollama" || provider === "lmstudio" || provider === "vllm") { - return provider; - } - return null; -} - function getLocalModelLabel(modelId: string): string { - const provider = getLocalProviderFromModelId(modelId); + const provider = parseLocalProviderFromModelId(modelId); if (!provider) return modelId; - const tail = String(modelId ?? "").trim().slice(provider.length + 1).trim(); + const tail = getLocalModelIdTail(modelId, provider); return tail.length ? tail : modelId; } function buildFallbackModelOption(modelId: string): AiModelDescriptor { - const provider = getLocalProviderFromModelId(modelId); + const provider = parseLocalProviderFromModelId(modelId); if (provider) { - const providerLabel = LOCAL_PROVIDER_LABELS[provider]; + const pLabel = LOCAL_PROVIDER_LABELS[provider]; return { id: modelId, label: getLocalModelLabel(modelId), - description: `${providerLabel} local model`, + description: `${pLabel} local model`, }; } return { @@ -104,7 +95,7 @@ function hasDynamicLocalModelIdsForProvider( runtimeConnections: Record | undefined, ): boolean { const normalizedProvider = normalizeAuthProvider(provider); - if (normalizedProvider !== "ollama" && normalizedProvider !== "lmstudio" && normalizedProvider !== "vllm") { + if (!isLocalProviderFamily(normalizedProvider)) { return false; } const prefix = `${normalizedProvider}/`; @@ -168,7 +159,7 @@ export function deriveConfiguredModelIds( for (const [provider, connection] of Object.entries(runtimeConnections ?? {})) { const normalizedProvider = normalizeAuthProvider(provider); - if (normalizedProvider !== "ollama" && normalizedProvider !== "lmstudio" && normalizedProvider !== "vllm") { + if (!isLocalProviderFamily(normalizedProvider)) { continue; } if (!hasDynamicLocalModelIdsForProvider(normalizedProvider, status.availableModelIds, runtimeConnections)) { @@ -210,7 +201,7 @@ export function deriveConfiguredModelOptions( const descriptor = getModelById(modelId); return descriptor ? [descriptorToModelOption(descriptor)] - : getLocalProviderFromModelId(modelId) + : parseLocalProviderFromModelId(modelId) ? [buildFallbackModelOption(modelId)] : []; }); @@ -224,6 +215,6 @@ export function includeSelectedModelOption( if (!modelId.length || options.some((option) => option.id === modelId)) return options; const descriptor = getModelById(modelId); if (descriptor) return [descriptorToModelOption(descriptor), ...options]; - if (getLocalProviderFromModelId(modelId)) return [buildFallbackModelOption(modelId), ...options]; + if (parseLocalProviderFromModelId(modelId)) return [buildFallbackModelOption(modelId), ...options]; return options; } diff --git a/apps/desktop/src/shared/modelRegistry.test.ts b/apps/desktop/src/shared/modelRegistry.test.ts index 4fdbf920a..fa446ba97 100644 --- a/apps/desktop/src/shared/modelRegistry.test.ts +++ b/apps/desktop/src/shared/modelRegistry.test.ts @@ -4,6 +4,7 @@ import { getAvailableModels, getDefaultModelDescriptor, getModelById, + getModelDescriptorForPermissionMode, getRuntimeModelRefForDescriptor, listModelDescriptorsForProvider, MODEL_REGISTRY, @@ -76,6 +77,19 @@ describe("modelRegistry", () => { expect(resolveModelDescriptor("nonexistent/model-id")).toBeUndefined(); }); + it("getModelDescriptorForPermissionMode matches getModelById for known locals", () => { + const id = "ollama/qwen2.5-coder:32b"; + expect(getModelDescriptorForPermissionMode(id)).toEqual(getModelById(id)); + }); + + it("getModelDescriptorForPermissionMode yields guarded local for ollama/auto when getModelById is undefined", () => { + expect(getModelById("ollama/auto")).toBeUndefined(); + const perm = getModelDescriptorForPermissionMode("ollama/auto"); + expect(perm?.family).toBe("ollama"); + expect(perm?.harnessProfile).toBe("guarded"); + expect(perm?.authTypes).toContain("local"); + }); + it("resolves gpt-5.4 shortId to the API-key variant, not the codex variant", () => { const resolved = resolveModelAlias("gpt-5.4"); expect(resolved).toBeTruthy(); diff --git a/apps/desktop/src/shared/modelRegistry.ts b/apps/desktop/src/shared/modelRegistry.ts index b7c407cdd..508d4bf44 100644 --- a/apps/desktop/src/shared/modelRegistry.ts +++ b/apps/desktop/src/shared/modelRegistry.ts @@ -82,7 +82,8 @@ export function isModelProviderGroup(value: string | null | undefined): value is const ALL_CAPS: ModelCapabilities = { tools: true, vision: true, reasoning: true, streaming: true }; const NO_REASONING: ModelCapabilities = { tools: true, vision: true, reasoning: false, streaming: true }; const BASIC_CAPS: ModelCapabilities = { tools: true, vision: false, reasoning: false, streaming: true }; -const LOCAL_PROVIDER_LABELS: Record = { +/** Human-readable names for Ollama / LM Studio / vLLM (shared across main, renderer, and MCP). */ +export const LOCAL_PROVIDER_LABELS: Record = { ollama: "Ollama", lmstudio: "LM Studio", vllm: "vLLM", @@ -753,10 +754,38 @@ export function validateModelRegistry(models: ModelDescriptor[] = MODEL_REGISTRY validateModelRegistry(); rebuildIndexes(); -function isLocalProviderFamily(value: string): value is LocalProviderFamily { +export function isLocalProviderFamily(value: string): value is LocalProviderFamily { return value === "ollama" || value === "lmstudio" || value === "vllm"; } +/** First path segment of `provider/modelId` when it is a known local provider. */ +export function parseLocalProviderFromModelId(modelId: string): LocalProviderFamily | null { + const provider = String(modelId ?? "").trim().split("/", 1)[0]?.toLowerCase() ?? ""; + return isLocalProviderFamily(provider) ? provider : null; +} + +/** Model name segment after `provider/` for local refs; empty string if missing. */ +export function getLocalModelIdTail(modelId: string, provider: LocalProviderFamily): string { + return String(modelId ?? "").trim().slice(provider.length + 1).trim(); +} + +/** + * Descriptor for unified permission / harness decisions when the registry has no row yet. + * `getModelById` returns undefined for refs such as `ollama/auto`; this still returns a + * guarded local descriptor so the UI matches main-process harness behavior. + */ +export function getModelDescriptorForPermissionMode(modelId: string): ModelDescriptor | undefined { + const resolved = getModelById(modelId); + if (resolved) return resolved; + const provider = parseLocalProviderFromModelId(modelId); + if (!provider) return undefined; + const tail = getLocalModelIdTail(modelId, provider); + if (!tail.length || tail === "auto") { + return createDynamicLocalModelDescriptor(provider, "auto", { harnessProfile: "guarded" }); + } + return createDynamicLocalModelDescriptor(provider, tail); +} + function parseDynamicLocalModelRef(modelRef: string): { provider: LocalProviderFamily; modelId: string } | null { const normalized = modelRef.trim(); if (!normalized.length) return null;