diff --git a/README.md b/README.md index d0f0033a10..8d2ee72939 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ | **Web Search** | 内置网页搜索工具, 支持 bing 和 brave 搜索 | [文档](https://ccb.agent-aura.top/docs/features/web-browser-tool) | | **Poor Mode** | 穷鬼模式,关闭记忆提取和键入建议,大幅度减少并发请求 | /poor 可以开关 | | **Channels 频道通知** | MCP 服务器推送外部消息到会话(飞书/Slack/Discord/微信等),`--channels plugin:name@marketplace` 启用 | [文档](https://ccb.agent-aura.top/docs/features/channels) | -| **自定义模型供应商** | OpenAI/Anthropic/Gemini/Grok 兼容 (`/login`) | [文档](https://ccb.agent-aura.top/docs/features/all-features-guide) | +| **自定义模型供应商** | OpenAI/Anthropic/Gemini/Grok/Ollama 兼容 (`/login`) | [文档](https://ccb.agent-aura.top/docs/features/all-features-guide) / [Ollama](docs/features/ollama-provider.md) | | Voice Mode | 语音输入,支持豆包语言输入(`/voice doubao`) | [文档](https://ccb.agent-aura.top/docs/features/voice-mode) | | Computer Use | 屏幕截图、键鼠控制 | [文档](https://ccb.agent-aura.top/docs/features/computer-use) | | Chrome Use | 浏览器自动化、表单填写、数据抓取 | [自托管](https://ccb.agent-aura.top/docs/features/chrome-use-mcp) [原生版](https://ccb.agent-aura.top/docs/features/claude-in-chrome-mcp) | diff --git a/docs/features/ollama-provider.md b/docs/features/ollama-provider.md new file mode 100644 index 0000000000..0d401976c1 --- /dev/null +++ b/docs/features/ollama-provider.md @@ -0,0 +1,71 @@ +# Ollama Native Provider + +Claude Code Best supports Ollama through the native Ollama API, not the +OpenAI-compatible endpoint. This lets Cloud and local Ollama use the same +request shape for chat, tool calling, thinking, model discovery, and web +utilities. + +## Configure + +Use `/login` and choose `Ollama`, or configure these environment variables: + +```bash +CLAUDE_CODE_USE_OLLAMA=1 +OLLAMA_API_KEY=ollama_api_key +OLLAMA_BASE_URL=https://ollama.com/api +OLLAMA_DEFAULT_HAIKU_MODEL=qwen3:cloud +OLLAMA_DEFAULT_SONNET_MODEL=qwen3-coder +OLLAMA_DEFAULT_OPUS_MODEL=glm-4.7:cloud +``` + +`OLLAMA_API_KEY` is required for direct Ollama Cloud API access. It is not +required for local Ollama. For local Ollama, set: + +```bash +OLLAMA_BASE_URL=http://localhost:11434/api +``` + +If `OLLAMA_BASE_URL` is omitted, Claude Code Best uses +`https://ollama.com/api`. + +## Model Mapping + +Ollama model routing uses the same three Claude model families shown by +`/model`: + +- `OLLAMA_DEFAULT_HAIKU_MODEL` +- `OLLAMA_DEFAULT_SONNET_MODEL` +- `OLLAMA_DEFAULT_OPUS_MODEL` + +There is no global `OLLAMA_MODEL` override. This keeps Ollama behavior aligned +with other third-party providers, where Haiku/Sonnet/Opus can map to different +backend models. + +When a direct Ollama model name is selected from `/model` or `--model`, it is +sent to Ollama unchanged. When a Claude family model is selected, Claude Code +Best maps it through the matching `OLLAMA_DEFAULT_*_MODEL` variable. If no +family mapping is configured, the fallback is `qwen3-coder`. + +## Supported Features + +- Native `POST /api/chat` streaming +- Ollama tool calling through `tools` +- Ollama thinking through `think` +- Native `POST /api/web_search` +- Native `POST /api/web_fetch` +- Dynamic context length discovery through `POST /api/show` +- Local Ollama and Ollama Cloud through the same provider + +The provider reads model context length from `model_info.*.context_length` or +the `num_ctx` parameter returned by `/api/show`, then uses that value to choose +the request output limit. + +## Known Limits + +Ollama does not expose an official Cloud quota or remaining-balance API in the +documented native API. Claude Code Best therefore does not show Ollama Cloud +remaining quota. + +Anthropic-only server tools are not sent directly to Ollama. Web search and web +fetch are handled client-side through Ollama's native web APIs when the Ollama +provider is active. diff --git a/packages/@ant/model-provider/src/index.ts b/packages/@ant/model-provider/src/index.ts index 6f2b1a56ce..318310c36f 100644 --- a/packages/@ant/model-provider/src/index.ts +++ b/packages/@ant/model-provider/src/index.ts @@ -26,6 +26,17 @@ export * from './types/index.js' export { resolveOpenAIModel } from './providers/openai/modelMapping.js' export { resolveGrokModel } from './providers/grok/modelMapping.js' export { resolveGeminiModel } from './providers/gemini/modelMapping.js' +export { resolveOllamaModel } from './providers/ollama/modelMapping.js' +export { anthropicMessagesToOllama } from './providers/ollama/convertMessages.js' +export { anthropicToolsToOllama } from './providers/ollama/convertTools.js' +export { adaptOllamaStreamToAnthropic } from './providers/ollama/streamAdapter.js' +export type { + OllamaChatChunk, + OllamaChatRequest, + OllamaMessage, + OllamaTool, + OllamaToolCall, +} from './providers/ollama/types.js' // Gemini provider utilities export { anthropicMessagesToGemini } from './providers/gemini/convertMessages.js' diff --git a/packages/@ant/model-provider/src/providers/ollama/__tests__/convertMessages.test.ts b/packages/@ant/model-provider/src/providers/ollama/__tests__/convertMessages.test.ts new file mode 100644 index 0000000000..b340c38c62 --- /dev/null +++ b/packages/@ant/model-provider/src/providers/ollama/__tests__/convertMessages.test.ts @@ -0,0 +1,67 @@ +import { describe, expect, test } from 'bun:test' +import { anthropicMessagesToOllama } from '../convertMessages.js' +import type { SystemPrompt } from '../../../types/systemPrompt.js' + +describe('anthropicMessagesToOllama', () => { + test('converts system, text, tool use, and tool result messages', () => { + const result = anthropicMessagesToOllama( + [ + { + type: 'user', + message: { + role: 'user', + content: [{ type: 'text', text: 'weather?' }], + }, + } as any, + { + type: 'assistant', + message: { + role: 'assistant', + content: [ + { + type: 'tool_use', + id: 'toolu_1', + name: 'get_weather', + input: { city: 'Paris' }, + }, + ], + }, + } as any, + { + type: 'user', + message: { + role: 'user', + content: [ + { + type: 'tool_result', + tool_use_id: 'toolu_1', + content: 'sunny', + }, + ], + }, + } as any, + ], + ['You are concise.'] as unknown as SystemPrompt, + ) + + expect(result).toEqual([ + { role: 'system', content: 'You are concise.' }, + { role: 'user', content: 'weather?' }, + { + role: 'assistant', + content: '', + tool_calls: [ + { + type: 'function', + function: { + index: 0, + name: 'get_weather', + arguments: { city: 'Paris' }, + }, + }, + ], + }, + { role: 'tool', tool_name: 'get_weather', content: 'sunny' }, + ]) + }) +}) diff --git a/packages/@ant/model-provider/src/providers/ollama/__tests__/convertTools.test.ts b/packages/@ant/model-provider/src/providers/ollama/__tests__/convertTools.test.ts new file mode 100644 index 0000000000..b969c21e88 --- /dev/null +++ b/packages/@ant/model-provider/src/providers/ollama/__tests__/convertTools.test.ts @@ -0,0 +1,119 @@ +import { describe, expect, test } from 'bun:test' +import { anthropicToolsToOllama } from '../convertTools.js' + +describe('anthropicToolsToOllama', () => { + test('converts basic tools to Ollama function tools', () => { + const tools = [ + { + type: 'custom', + name: 'bash', + description: 'Run a bash command', + input_schema: { + type: 'object', + properties: { command: { type: 'string' } }, + required: ['command'], + }, + }, + ] + + expect(anthropicToolsToOllama(tools as any)).toEqual([ + { + type: 'function', + function: { + name: 'bash', + description: 'Run a bash command', + parameters: { + type: 'object', + properties: { command: { type: 'string' } }, + required: ['command'], + }, + }, + }, + ]) + }) + + test('keeps WebFetch parameters in Ollama-compatible schema subset', () => { + const tools = [ + { + type: 'custom', + name: 'WebFetch', + description: 'Fetch a URL', + input_schema: { + $schema: 'https://json-schema.org/draft/2020-12/schema', + type: 'object', + properties: { + url: { + type: 'string', + format: 'uri', + description: 'The URL to fetch content from', + }, + prompt: { + type: 'string', + description: 'The prompt to run on the fetched content', + }, + }, + required: ['url', 'prompt'], + additionalProperties: false, + }, + }, + ] + + expect( + anthropicToolsToOllama(tools as any)[0]?.function.parameters, + ).toEqual({ + type: 'object', + properties: { + url: { + type: 'string', + description: 'The URL to fetch content from', + }, + prompt: { + type: 'string', + description: 'The prompt to run on the fetched content', + }, + }, + required: ['url', 'prompt'], + }) + }) + + test('converts const and strips unsupported schema keywords recursively', () => { + const tools = [ + { + type: 'custom', + name: 'complex', + description: 'Complex schema', + input_schema: { + type: 'object', + patternProperties: { + '^x-': { type: 'string' }, + }, + properties: { + mode: { const: 'strict' }, + metadata: { + type: 'object', + additionalProperties: { type: 'string' }, + propertyNames: { pattern: '^[a-z]+$' }, + }, + }, + required: ['mode'], + }, + }, + ] + + expect( + anthropicToolsToOllama(tools as any)[0]?.function.parameters, + ).toEqual({ + type: 'object', + properties: { + mode: { + type: 'string', + enum: ['strict'], + }, + metadata: { + type: 'object', + }, + }, + required: ['mode'], + }) + }) +}) diff --git a/packages/@ant/model-provider/src/providers/ollama/__tests__/modelMapping.test.ts b/packages/@ant/model-provider/src/providers/ollama/__tests__/modelMapping.test.ts new file mode 100644 index 0000000000..f5ced94db0 --- /dev/null +++ b/packages/@ant/model-provider/src/providers/ollama/__tests__/modelMapping.test.ts @@ -0,0 +1,57 @@ +import { afterAll, beforeEach, describe, expect, test } from 'bun:test' +import { resolveOllamaModel } from '../modelMapping.js' + +const envKeys = [ + 'OLLAMA_MODEL', + 'OLLAMA_DEFAULT_HAIKU_MODEL', + 'OLLAMA_DEFAULT_SONNET_MODEL', + 'OLLAMA_DEFAULT_OPUS_MODEL', + 'ANTHROPIC_DEFAULT_SONNET_MODEL', +] as const + +const savedEnv: Record = {} + +for (const key of envKeys) { + savedEnv[key] = process.env[key] +} + +beforeEach(() => { + for (const key of envKeys) { + delete process.env[key] + } +}) + +afterAll(() => { + for (const key of envKeys) { + if (savedEnv[key] === undefined) { + delete process.env[key] + } else { + process.env[key] = savedEnv[key] + } + } +}) + +describe('resolveOllamaModel', () => { + test('keeps direct Ollama model names selected from /model', () => { + expect(resolveOllamaModel('qwen3-coder')).toBe('qwen3-coder') + expect(resolveOllamaModel('glm-4.7:cloud')).toBe('glm-4.7:cloud') + }) + + test('maps Claude family model ids to Ollama defaults', () => { + process.env.OLLAMA_DEFAULT_SONNET_MODEL = 'qwen3-coder' + + expect(resolveOllamaModel('claude-sonnet-4-6')).toBe('qwen3-coder') + }) + + test('does not fall back to Anthropic model env vars for Ollama', () => { + process.env.ANTHROPIC_DEFAULT_SONNET_MODEL = 'claude-sonnet-custom' + + expect(resolveOllamaModel('claude-sonnet-4-6')).toBe('qwen3-coder') + }) + + test('ignores legacy OLLAMA_MODEL global override', () => { + process.env.OLLAMA_MODEL = 'legacy-global-model' + + expect(resolveOllamaModel('claude-sonnet-4-6')).toBe('qwen3-coder') + }) +}) diff --git a/packages/@ant/model-provider/src/providers/ollama/__tests__/streamAdapter.test.ts b/packages/@ant/model-provider/src/providers/ollama/__tests__/streamAdapter.test.ts new file mode 100644 index 0000000000..39c87e9921 --- /dev/null +++ b/packages/@ant/model-provider/src/providers/ollama/__tests__/streamAdapter.test.ts @@ -0,0 +1,84 @@ +import { describe, expect, test } from 'bun:test' +import { adaptOllamaStreamToAnthropic } from '../streamAdapter.js' +import type { OllamaChatChunk } from '../types.js' + +async function collect(chunks: OllamaChatChunk[]) { + const events = [] + async function* stream(): AsyncGenerator { + for (const chunk of chunks) { + yield chunk + } + } + for await (const event of adaptOllamaStreamToAnthropic( + stream(), + 'qwen3-coder', + )) { + events.push(event as any) + } + return events +} + +describe('adaptOllamaStreamToAnthropic', () => { + test('streams thinking, text, tool calls, and usage', async () => { + const events = await collect([ + { message: { thinking: 'think' } }, + { message: { content: 'hello' } }, + { + message: { + tool_calls: [ + { + function: { + name: 'get_weather', + arguments: { city: 'Paris' }, + }, + }, + ], + }, + }, + { + done: true, + done_reason: 'stop', + prompt_eval_count: 12, + eval_count: 5, + }, + ]) + + expect(events[0].type).toBe('message_start') + expect( + events.some( + event => + event.type === 'content_block_delta' && + event.delta.type === 'thinking_delta' && + event.delta.thinking === 'think', + ), + ).toBe(true) + expect( + events.some( + event => + event.type === 'content_block_delta' && + event.delta.type === 'text_delta' && + event.delta.text === 'hello', + ), + ).toBe(true) + expect( + events.some( + event => + event.type === 'content_block_start' && + event.content_block.type === 'tool_use' && + event.content_block.name === 'get_weather', + ), + ).toBe(true) + + const messageDelta = events.find(event => event.type === 'message_delta') + expect(messageDelta).toBeDefined() + expect(messageDelta.delta.stop_reason).toBe('tool_use') + expect(messageDelta.usage.input_tokens).toBe(12) + expect(messageDelta.usage.output_tokens).toBe(5) + }) + + test('throws explicit errors from Ollama stream chunks', async () => { + await expect(collect([{ error: 'model not found' }])).rejects.toThrow( + 'Ollama stream error: model not found', + ) + }) +}) diff --git a/packages/@ant/model-provider/src/providers/ollama/convertMessages.ts b/packages/@ant/model-provider/src/providers/ollama/convertMessages.ts new file mode 100644 index 0000000000..9503d6b5e6 --- /dev/null +++ b/packages/@ant/model-provider/src/providers/ollama/convertMessages.ts @@ -0,0 +1,212 @@ +import type { + BetaToolResultBlockParam, + BetaToolUseBlock, +} from '@anthropic-ai/sdk/resources/beta/messages/messages.mjs' +import type { AssistantMessage, UserMessage } from '../../types/message.js' +import type { SystemPrompt } from '../../types/systemPrompt.js' +import type { OllamaMessage } from './types.js' + +function safeParseJSON(json: string | null | undefined): unknown { + if (!json) return null + try { + return JSON.parse(json) + } catch { + return null + } +} + +export function anthropicMessagesToOllama( + messages: (UserMessage | AssistantMessage)[], + systemPrompt: SystemPrompt, +): OllamaMessage[] { + const result: OllamaMessage[] = [] + const toolNamesById = new Map() + const systemText = systemPromptToText(systemPrompt) + + if (systemText) { + result.push({ role: 'system', content: systemText }) + } + + for (const msg of messages) { + if (msg.type === 'assistant') { + result.push(convertInternalAssistantMessage(msg)) + const content = msg.message.content + if (Array.isArray(content)) { + for (const block of content) { + if (typeof block !== 'string' && block.type === 'tool_use') { + toolNamesById.set(block.id, block.name) + } + } + } + continue + } + + if (msg.type === 'user') { + result.push(...convertInternalUserMessage(msg, toolNamesById)) + } + } + + return result +} + +function systemPromptToText(systemPrompt: SystemPrompt): string { + if (!systemPrompt || systemPrompt.length === 0) return '' + return systemPrompt.filter(Boolean).join('\n\n') +} + +function convertInternalUserMessage( + msg: UserMessage, + toolNamesById: ReadonlyMap, +): OllamaMessage[] { + const content = msg.message.content + if (typeof content === 'string') { + return [{ role: 'user', content }] + } + + if (!Array.isArray(content)) { + return [] + } + + const textParts: string[] = [] + const images: string[] = [] + const toolResults: OllamaMessage[] = [] + + for (const block of content) { + if (typeof block === 'string') { + textParts.push(block) + continue + } + + if (block.type === 'text') { + textParts.push(block.text) + continue + } + + if (block.type === 'tool_result') { + const toolResult = block as BetaToolResultBlockParam + toolResults.push({ + role: 'tool', + tool_name: + toolNamesById.get(toolResult.tool_use_id) ?? toolResult.tool_use_id, + content: normalizeToolResultContent(toolResult), + }) + continue + } + + if (block.type === 'image') { + const imageData = convertImageBlockToOllama( + block as unknown as Record, + ) + if (imageData) images.push(imageData) + } + } + + const messages = [...toolResults] + if (textParts.length > 0 || images.length > 0) { + messages.push({ + role: 'user', + content: textParts.join('\n'), + ...(images.length > 0 && { images }), + }) + } + return messages +} + +function convertInternalAssistantMessage(msg: AssistantMessage): OllamaMessage { + const content = msg.message.content + if (typeof content === 'string') { + return { role: 'assistant', content } + } + + if (!Array.isArray(content)) { + return { role: 'assistant', content: '' } + } + + const textParts: string[] = [] + const thinkingParts: string[] = [] + const toolCalls: OllamaMessage['tool_calls'] = [] + + for (const block of content) { + if (typeof block === 'string') { + textParts.push(block) + continue + } + + if (block.type === 'text') { + textParts.push(block.text) + continue + } + + if (block.type === 'thinking') { + thinkingParts.push(block.thinking) + continue + } + + if (block.type === 'tool_use') { + const toolUse = block as BetaToolUseBlock + toolCalls.push({ + type: 'function', + function: { + index: toolCalls.length, + name: toolUse.name, + arguments: normalizeToolUseInput(toolUse.input), + }, + }) + } + } + + return { + role: 'assistant', + content: textParts.join('\n'), + ...(thinkingParts.length > 0 && { thinking: thinkingParts.join('\n') }), + ...(toolCalls.length > 0 && { tool_calls: toolCalls }), + } +} + +function normalizeToolUseInput(input: unknown): Record { + if (typeof input === 'string') { + const parsed = safeParseJSON(input) + if (parsed && typeof parsed === 'object' && !Array.isArray(parsed)) { + return parsed as Record + } + return parsed === null ? {} : { value: parsed } + } + + if (input && typeof input === 'object' && !Array.isArray(input)) { + return input as Record + } + + return input === undefined ? {} : { value: input } +} + +function normalizeToolResultContent(block: BetaToolResultBlockParam): string { + const content = block.content + let value: string + + if (typeof content === 'string') { + value = content + } else if (Array.isArray(content)) { + value = content + .map(item => { + if (typeof item === 'string') return item + if ('text' in item) return item.text + return '' + }) + .filter(Boolean) + .join('\n') + } else { + value = '' + } + + return block.is_error ? `Error: ${value}` : value +} + +function convertImageBlockToOllama( + block: Record, +): string | undefined { + const source = block.source as Record | undefined + if (source?.type === 'base64' && typeof source.data === 'string') { + return source.data + } + return undefined +} diff --git a/packages/@ant/model-provider/src/providers/ollama/convertTools.ts b/packages/@ant/model-provider/src/providers/ollama/convertTools.ts new file mode 100644 index 0000000000..918a7c98e7 --- /dev/null +++ b/packages/@ant/model-provider/src/providers/ollama/convertTools.ts @@ -0,0 +1,192 @@ +import type { BetaToolUnion } from '@anthropic-ai/sdk/resources/beta/messages/messages.mjs' +import type { OllamaTool } from './types.js' + +const OLLAMA_JSON_SCHEMA_TYPES = new Set([ + 'string', + 'number', + 'integer', + 'boolean', + 'object', + 'array', + 'null', +]) + +export function anthropicToolsToOllama(tools: BetaToolUnion[]): OllamaTool[] { + return tools + .filter(tool => { + const toolType = (tool as unknown as { type?: string }).type + return ( + tool.type === 'custom' || !('type' in tool) || toolType !== 'server' + ) + }) + .map(tool => { + const anyTool = tool as unknown as Record + return { + type: 'function', + function: { + name: (anyTool.name as string) || '', + description: (anyTool.description as string) || '', + parameters: sanitizeOllamaFunctionParameters( + (anyTool.input_schema as Record | undefined) || { + type: 'object', + properties: {}, + }, + ), + }, + } + }) +} + +function normalizeJsonSchemaType( + value: unknown, +): string | string[] | undefined { + if (typeof value === 'string') { + return OLLAMA_JSON_SCHEMA_TYPES.has(value) ? value : undefined + } + + if (Array.isArray(value)) { + const normalized = value.filter( + (item): item is string => + typeof item === 'string' && OLLAMA_JSON_SCHEMA_TYPES.has(item), + ) + const unique = Array.from(new Set(normalized)) + if (unique.length === 0) return undefined + return unique.length === 1 ? unique[0] : unique + } + + return undefined +} + +function inferJsonSchemaTypeFromValue(value: unknown): string | undefined { + if (value === null) return 'null' + if (Array.isArray(value)) return 'array' + if (typeof value === 'string') return 'string' + if (typeof value === 'boolean') return 'boolean' + if (typeof value === 'number') { + return Number.isInteger(value) ? 'integer' : 'number' + } + if (typeof value === 'object') return 'object' + return undefined +} + +function inferJsonSchemaTypeFromEnum( + values: unknown[], +): string | string[] | undefined { + const inferred = values + .map(inferJsonSchemaTypeFromValue) + .filter((value): value is string => value !== undefined) + const unique = Array.from(new Set(inferred)) + if (unique.length === 0) return undefined + return unique.length === 1 ? unique[0] : unique +} + +function sanitizeProperties( + value: unknown, +): Record> | undefined { + if (!value || typeof value !== 'object' || Array.isArray(value)) { + return undefined + } + + const entries = Object.entries(value as Record) + .map(([key, schema]) => [key, sanitizeOllamaJsonSchema(schema)] as const) + .filter(([, schema]) => Object.keys(schema).length > 0) + + return entries.length > 0 ? Object.fromEntries(entries) : undefined +} + +function sanitizeSchemaArray( + value: unknown, +): Record[] | undefined { + if (!Array.isArray(value)) return undefined + + const sanitized = value + .map(item => sanitizeOllamaJsonSchema(item)) + .filter(item => Object.keys(item).length > 0) + + return sanitized.length > 0 ? sanitized : undefined +} + +function sanitizeOllamaJsonSchema(schema: unknown): Record { + if (!schema || typeof schema !== 'object' || Array.isArray(schema)) { + return {} + } + + const source = schema as Record + const result: Record = {} + let type = normalizeJsonSchemaType(source.type) + + if (source.const !== undefined) { + result.enum = [source.const] + type = type ?? inferJsonSchemaTypeFromValue(source.const) + } else if (Array.isArray(source.enum) && source.enum.length > 0) { + result.enum = source.enum + type = type ?? inferJsonSchemaTypeFromEnum(source.enum) + } + + if (!type) { + if (source.properties && typeof source.properties === 'object') { + type = 'object' + } else if (source.items !== undefined || source.prefixItems !== undefined) { + type = 'array' + } + } + + if (type) { + result.type = type + } + + if (typeof source.description === 'string') { + result.description = source.description + } + + const properties = sanitizeProperties(source.properties) + if (properties) { + result.properties = properties + } + + if (Array.isArray(source.required)) { + const required = source.required.filter( + (item): item is string => typeof item === 'string', + ) + if (required.length > 0) { + result.required = required + } + } + + const items = sanitizeOllamaJsonSchema(source.items) + if (Object.keys(items).length > 0) { + result.items = items + } + + const anyOf = sanitizeSchemaArray(source.anyOf ?? source.oneOf) + if (anyOf) { + result.anyOf = anyOf + } + + const allOf = sanitizeSchemaArray(source.allOf) + if (allOf) { + result.allOf = allOf + } + + return result +} + +function sanitizeOllamaFunctionParameters( + schema: unknown, +): Record { + const sanitized = sanitizeOllamaJsonSchema(schema) + if (Object.keys(sanitized).length > 0) { + if (sanitized.type !== 'object') { + return { + type: 'object', + properties: {}, + } + } + return sanitized + } + + return { + type: 'object', + properties: {}, + } +} diff --git a/packages/@ant/model-provider/src/providers/ollama/modelMapping.ts b/packages/@ant/model-provider/src/providers/ollama/modelMapping.ts new file mode 100644 index 0000000000..dcd43bb7e6 --- /dev/null +++ b/packages/@ant/model-provider/src/providers/ollama/modelMapping.ts @@ -0,0 +1,41 @@ +/** + * Resolve the Ollama model name from the selected model setting. + * + * Priority: + * 1. Direct Ollama model names selected from /model or --model + * 2. OLLAMA_DEFAULT_{FAMILY}_MODEL env var — per-family override + * 3. Fall back to qwen3-coder, a coding-oriented Ollama model name + * + * Ollama users configure model routing through per-family overrides. + * The fallback avoids sending Claude model IDs to Ollama, which would fail for + * users who only selected the provider and have not copied model aliases. + */ + +function getModelFamily(model: string): 'haiku' | 'sonnet' | 'opus' | null { + if (/haiku/i.test(model)) return 'haiku' + if (/opus/i.test(model)) return 'opus' + if (/sonnet/i.test(model)) return 'sonnet' + return null +} + +function isClaudeModelId(model: string): boolean { + return /\bclaude[-.]/i.test(model) || /\banthropic\.claude[-.]/i.test(model) +} + +export function resolveOllamaModel(selectedModel: string): string { + const cleanModel = selectedModel.replace(/\[1m\]$/, '') + if (!isClaudeModelId(cleanModel)) { + return cleanModel + } + + const family = getModelFamily(cleanModel) + + // 2. Per-family env var (OLLAMA_DEFAULT_OPUS_MODEL, etc.) + if (family) { + const ollamaEnvVar = `OLLAMA_DEFAULT_${family.toUpperCase()}_MODEL` + const ollamaOverride = process.env[ollamaEnvVar] + if (ollamaOverride) return ollamaOverride + } + + return 'qwen3-coder' +} diff --git a/packages/@ant/model-provider/src/providers/ollama/streamAdapter.ts b/packages/@ant/model-provider/src/providers/ollama/streamAdapter.ts new file mode 100644 index 0000000000..d0d3145323 --- /dev/null +++ b/packages/@ant/model-provider/src/providers/ollama/streamAdapter.ts @@ -0,0 +1,201 @@ +import type { BetaRawMessageStreamEvent } from '@anthropic-ai/sdk/resources/beta/messages/messages.mjs' +import { randomUUID } from 'crypto' +import type { OllamaChatChunk } from './types.js' + +export async function* adaptOllamaStreamToAnthropic( + stream: AsyncIterable, + model: string, +): AsyncGenerator { + const messageId = `msg_${randomUUID().replace(/-/g, '').slice(0, 24)}` + let started = false + let nextContentIndex = 0 + let openTextLikeBlock: { index: number; type: 'text' | 'thinking' } | null = + null + let sawToolUse = false + let doneReason: string | undefined + let inputTokens = 0 + let outputTokens = 0 + + for await (const chunk of stream) { + if (chunk.error) { + throw new Error(`Ollama stream error: ${chunk.error}`) + } + + inputTokens = chunk.prompt_eval_count ?? inputTokens + outputTokens = chunk.eval_count ?? outputTokens + + if (!started) { + started = true + yield { + type: 'message_start', + message: { + id: messageId, + type: 'message', + role: 'assistant', + content: [], + model, + stop_reason: null, + stop_sequence: null, + usage: { + input_tokens: inputTokens, + output_tokens: 0, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + }, + }, + } as unknown as BetaRawMessageStreamEvent + } + + const message = chunk.message + if (message?.thinking) { + if (!openTextLikeBlock || openTextLikeBlock.type !== 'thinking') { + if (openTextLikeBlock) { + yield { + type: 'content_block_stop', + index: openTextLikeBlock.index, + } as BetaRawMessageStreamEvent + } + + openTextLikeBlock = { + index: nextContentIndex++, + type: 'thinking', + } + yield { + type: 'content_block_start', + index: openTextLikeBlock.index, + content_block: { + type: 'thinking', + thinking: '', + signature: '', + }, + } as BetaRawMessageStreamEvent + } + + yield { + type: 'content_block_delta', + index: openTextLikeBlock.index, + delta: { + type: 'thinking_delta', + thinking: message.thinking, + }, + } as BetaRawMessageStreamEvent + } + + if (message?.content) { + if (!openTextLikeBlock || openTextLikeBlock.type !== 'text') { + if (openTextLikeBlock) { + yield { + type: 'content_block_stop', + index: openTextLikeBlock.index, + } as BetaRawMessageStreamEvent + } + + openTextLikeBlock = { + index: nextContentIndex++, + type: 'text', + } + yield { + type: 'content_block_start', + index: openTextLikeBlock.index, + content_block: { + type: 'text', + text: '', + }, + } as BetaRawMessageStreamEvent + } + + yield { + type: 'content_block_delta', + index: openTextLikeBlock.index, + delta: { + type: 'text_delta', + text: message.content, + }, + } as BetaRawMessageStreamEvent + } + + for (const toolCall of message?.tool_calls ?? []) { + if (openTextLikeBlock) { + yield { + type: 'content_block_stop', + index: openTextLikeBlock.index, + } as BetaRawMessageStreamEvent + openTextLikeBlock = null + } + + sawToolUse = true + const toolIndex = nextContentIndex++ + const toolId = `toolu_${randomUUID().replace(/-/g, '').slice(0, 24)}` + yield { + type: 'content_block_start', + index: toolIndex, + content_block: { + type: 'tool_use', + id: toolId, + name: toolCall.function.name || '', + input: {}, + }, + } as BetaRawMessageStreamEvent + + yield { + type: 'content_block_delta', + index: toolIndex, + delta: { + type: 'input_json_delta', + partial_json: JSON.stringify(toolCall.function.arguments ?? {}), + }, + } as BetaRawMessageStreamEvent + + yield { + type: 'content_block_stop', + index: toolIndex, + } as BetaRawMessageStreamEvent + } + + if (chunk.done) { + doneReason = chunk.done_reason + } + } + + if (!started) return + + if (openTextLikeBlock) { + yield { + type: 'content_block_stop', + index: openTextLikeBlock.index, + } as BetaRawMessageStreamEvent + } + + yield { + type: 'message_delta', + delta: { + stop_reason: mapOllamaDoneReason(doneReason, sawToolUse), + stop_sequence: null, + }, + usage: { + input_tokens: inputTokens, + output_tokens: outputTokens, + cache_read_input_tokens: 0, + cache_creation_input_tokens: 0, + }, + } as BetaRawMessageStreamEvent + + yield { + type: 'message_stop', + } as BetaRawMessageStreamEvent +} + +function mapOllamaDoneReason( + reason: string | undefined, + sawToolUse: boolean, +): string { + if (sawToolUse) return 'tool_use' + switch (reason) { + case 'length': + return 'max_tokens' + case 'stop': + case 'unload': + default: + return 'end_turn' + } +} diff --git a/packages/@ant/model-provider/src/providers/ollama/types.ts b/packages/@ant/model-provider/src/providers/ollama/types.ts new file mode 100644 index 0000000000..e9029ceba4 --- /dev/null +++ b/packages/@ant/model-provider/src/providers/ollama/types.ts @@ -0,0 +1,54 @@ +export interface OllamaMessage { + role: 'system' | 'user' | 'assistant' | 'tool' + content?: string + thinking?: string + tool_name?: string + images?: string[] + tool_calls?: OllamaToolCall[] +} + +export interface OllamaTool { + type: 'function' + function: { + name: string + description?: string + parameters: Record + } +} + +export interface OllamaToolCall { + type?: 'function' + function: { + index?: number + name: string + arguments: Record + } +} + +export interface OllamaChatRequest { + model: string + messages: OllamaMessage[] + stream: boolean + tools?: OllamaTool[] + think?: boolean | 'high' | 'medium' | 'low' + options?: { + temperature?: number + num_predict?: number + } +} + +export interface OllamaChatChunk { + error?: string + model?: string + created_at?: string + message?: { + role?: 'assistant' + content?: string + thinking?: string + tool_calls?: OllamaToolCall[] + } + done?: boolean + done_reason?: string + prompt_eval_count?: number + eval_count?: number +} diff --git a/packages/builtin-tools/src/tools/WebFetchTool/__tests__/headers.test.ts b/packages/builtin-tools/src/tools/WebFetchTool/__tests__/headers.test.ts index 20755e247c..8050385a08 100644 --- a/packages/builtin-tools/src/tools/WebFetchTool/__tests__/headers.test.ts +++ b/packages/builtin-tools/src/tools/WebFetchTool/__tests__/headers.test.ts @@ -38,10 +38,28 @@ mock.module('src/services/api/claude.js', () => ({ queryHaiku: async () => ({ message: { content: [] } }), })) +mock.module('src/services/api/ollama/client.js', () => ({ + getOllamaClient: () => ({ + webFetch: async () => + new Response( + JSON.stringify({ + title: 'Ollama Page', + content: 'Cloud content', + links: ['https://ollama.com/models'], + }), + { status: 200, statusText: 'OK' }, + ), + }), +})) + mock.module('src/utils/http.js', () => ({ getWebFetchUserAgent: () => 'TestAgent/1.0', })) +mock.module('src/utils/model/providers.js', () => ({ + getAPIProvider: () => 'ollama', +})) + mock.module('src/utils/log.ts', logMock) mock.module('src/utils/mcpOutputStorage.js', () => ({ @@ -59,6 +77,7 @@ mock.module('src/utils/settings/settings.js', () => ({ })) beforeEach(() => { + delete process.env.OLLAMA_USE_NATIVE_WEB_FETCH getMock = async () => ({ data: new TextEncoder().encode('hello').buffer, headers: { 'content-type': 'text/plain' }, @@ -120,6 +139,7 @@ describe('WebFetch response headers', () => { }) test('normalizes array content-type before cache and parsing', async () => { + process.env.OLLAMA_USE_NATIVE_WEB_FETCH = 'false' getMock = async () => ({ data: new TextEncoder().encode('plain body').buffer, headers: { 'content-type': ['text/plain', 'charset=utf-8'] }, @@ -144,4 +164,23 @@ describe('WebFetch response headers', () => { expect(result.content).toBe('plain body') expect(result.contentType).toBe('text/plain, charset=utf-8') }) + + test('uses Ollama native web_fetch when enabled', async () => { + process.env.OLLAMA_USE_NATIVE_WEB_FETCH = 'true' + + const { getURLMarkdownContent } = await import('../utils') + const result = await getURLMarkdownContent( + 'https://ollama.com', + new AbortController(), + ) + + expect('type' in result).toBe(false) + if ('type' in result) { + throw new Error('unexpected redirect result') + } + expect(result.content).toContain('# Ollama Page') + expect(result.content).toContain('Cloud content') + expect(result.content).toContain('https://ollama.com/models') + expect(result.contentType).toBe('text/markdown') + }) }) diff --git a/packages/builtin-tools/src/tools/WebFetchTool/utils.ts b/packages/builtin-tools/src/tools/WebFetchTool/utils.ts index 59a758a1ec..6ea8833018 100644 --- a/packages/builtin-tools/src/tools/WebFetchTool/utils.ts +++ b/packages/builtin-tools/src/tools/WebFetchTool/utils.ts @@ -5,8 +5,11 @@ import { logEvent, } from 'src/services/analytics/index.js' import { queryHaiku } from 'src/services/api/claude.js' +import { getOllamaClient } from 'src/services/api/ollama/client.js' import { AbortError } from 'src/utils/errors.js' +import { isEnvTruthy } from 'src/utils/envUtils.js' import { getWebFetchUserAgent } from 'src/utils/http.js' +import { getAPIProvider } from 'src/utils/model/providers.js' import { logError } from 'src/utils/log.js' import { isBinaryContentType, @@ -376,6 +379,64 @@ export type FetchedContent = { persistedSize?: number } +interface OllamaWebFetchResponse { + title?: string + content?: string + links?: string[] +} + +export function shouldUseOllamaWebFetch(): boolean { + if (process.env.OLLAMA_USE_NATIVE_WEB_FETCH !== undefined) { + return isEnvTruthy(process.env.OLLAMA_USE_NATIVE_WEB_FETCH) + } + if (getAPIProvider() !== 'ollama') return false + const baseURL = process.env.OLLAMA_BASE_URL || 'https://ollama.com/api' + try { + return new URL(baseURL).hostname === 'ollama.com' + } catch { + return false + } +} + +async function getURLMarkdownContentFromOllama( + url: string, + abortController: AbortController, +): Promise { + const client = getOllamaClient() + const response = await client.webFetch( + { url }, + { signal: abortController.signal }, + ) + if (!response.ok) { + const body = await response.text().catch(() => '') + throw new Error( + `Ollama web_fetch failed: HTTP ${response.status} ${response.statusText}${body ? `: ${body}` : ''}`, + ) + } + + const payload = (await response.json()) as OllamaWebFetchResponse + const title = payload.title?.trim() + const content = payload.content?.trim() ?? '' + const links = Array.isArray(payload.links) ? payload.links : [] + const rendered = [ + title ? `# ${title}` : '', + content, + links.length > 0 + ? `\n\nLinks:\n${links.map(link => `- ${link}`).join('\n')}` + : '', + ] + .filter(Boolean) + .join('\n\n') + + return { + bytes: Buffer.byteLength(rendered), + code: response.status, + codeText: response.statusText, + content: rendered, + contentType: 'text/markdown', + } +} + export async function getURLMarkdownContent( url: string, abortController: AbortController, @@ -384,6 +445,10 @@ export async function getURLMarkdownContent( throw new Error('Invalid URL') } + if (shouldUseOllamaWebFetch()) { + return getURLMarkdownContentFromOllama(url, abortController) + } + // Check cache (LRUCache handles TTL automatically) const cachedEntry = URL_CACHE.get(url) if (cachedEntry) { diff --git a/packages/builtin-tools/src/tools/WebSearchTool/__tests__/adapterFactory.test.ts b/packages/builtin-tools/src/tools/WebSearchTool/__tests__/adapterFactory.test.ts index 4e5353d89a..a0c903c9d9 100644 --- a/packages/builtin-tools/src/tools/WebSearchTool/__tests__/adapterFactory.test.ts +++ b/packages/builtin-tools/src/tools/WebSearchTool/__tests__/adapterFactory.test.ts @@ -1,12 +1,19 @@ import { afterEach, describe, expect, mock, test } from 'bun:test' let isFirstPartyBaseUrl = true +let apiProvider = 'firstParty' // Only mock the external dependency that controls adapter selection mock.module('src/utils/model/providers.js', () => ({ isFirstPartyAnthropicBaseUrl: () => isFirstPartyBaseUrl, - getAPIProvider: () => 'firstParty', - getAPIProviderForStatsig: () => 'firstParty', + getAPIProvider: () => apiProvider, + getAPIProviderForStatsig: () => apiProvider, +})) + +mock.module('src/services/api/ollama/client.js', () => ({ + getOllamaClient: () => ({ + webSearch: async () => new Response(JSON.stringify({ results: [] })), + }), })) const { createAdapter } = await import('../adapters/index') @@ -15,6 +22,7 @@ const originalWebSearchAdapter = process.env.WEB_SEARCH_ADAPTER afterEach(() => { isFirstPartyBaseUrl = true + apiProvider = 'firstParty' if (originalWebSearchAdapter === undefined) { delete process.env.WEB_SEARCH_ADAPTER @@ -58,4 +66,18 @@ describe('createAdapter', () => { expect(createAdapter().constructor.name).toBe('ExaSearchAdapter') }) + + test('selects the Bing adapter for third-party providers', () => { + delete process.env.WEB_SEARCH_ADAPTER + apiProvider = 'openai' + + expect(createAdapter().constructor.name).toBe('BingSearchAdapter') + }) + + test('selects the Ollama adapter for Ollama provider', () => { + delete process.env.WEB_SEARCH_ADAPTER + apiProvider = 'ollama' + + expect(createAdapter().constructor.name).toBe('OllamaSearchAdapter') + }) }) diff --git a/packages/builtin-tools/src/tools/WebSearchTool/__tests__/ollamaAdapter.test.ts b/packages/builtin-tools/src/tools/WebSearchTool/__tests__/ollamaAdapter.test.ts new file mode 100644 index 0000000000..fbe4aed302 --- /dev/null +++ b/packages/builtin-tools/src/tools/WebSearchTool/__tests__/ollamaAdapter.test.ts @@ -0,0 +1,44 @@ +import { describe, expect, mock, test } from 'bun:test' + +mock.module('src/services/api/ollama/client.js', () => ({ + getOllamaClient: () => ({ + webSearch: async () => + new Response( + JSON.stringify({ + results: [ + { + title: 'Allowed', + url: 'https://Docs.Example.com/path', + content: 'Allowed result', + }, + { + title: 'Blocked', + url: 'https://blocked.example.com/path', + content: 'Blocked result', + }, + ], + }), + ), + }), +})) + +const { OllamaSearchAdapter } = await import('../adapters/ollamaAdapter.js') + +describe('OllamaSearchAdapter', () => { + test('normalizes allowed and blocked domain filters', async () => { + const adapter = new OllamaSearchAdapter() + + const results = await adapter.search('query', { + allowedDomains: [' .EXAMPLE.com. '], + blockedDomains: ['BLOCKED.example.com'], + }) + + expect(results).toEqual([ + { + title: 'Allowed', + url: 'https://Docs.Example.com/path', + snippet: 'Allowed result', + }, + ]) + }) +}) diff --git a/packages/builtin-tools/src/tools/WebSearchTool/adapters/index.ts b/packages/builtin-tools/src/tools/WebSearchTool/adapters/index.ts index 32226ea434..4e7facfcac 100644 --- a/packages/builtin-tools/src/tools/WebSearchTool/adapters/index.ts +++ b/packages/builtin-tools/src/tools/WebSearchTool/adapters/index.ts @@ -3,11 +3,15 @@ * whether the API base URL points to Anthropic's official endpoint. */ -import { isFirstPartyAnthropicBaseUrl } from 'src/utils/model/providers.js' +import { + getAPIProvider, + isFirstPartyAnthropicBaseUrl, +} from 'src/utils/model/providers.js' import { ApiSearchAdapter } from './apiAdapter.js' import { BingSearchAdapter } from './bingAdapter.js' import { BraveSearchAdapter } from './braveAdapter.js' import { ExaSearchAdapter } from './exaAdapter.js' +import { OllamaSearchAdapter } from './ollamaAdapter.js' import type { WebSearchAdapter } from './types.js' export type { @@ -23,34 +27,34 @@ export type { * so they must fall back to the Bing scraper adapter. */ function isThirdPartyProvider(): boolean { - return !!( - process.env.CLAUDE_CODE_USE_OPENAI || - process.env.CLAUDE_CODE_USE_GEMINI || - process.env.CLAUDE_CODE_USE_GROK - ) + const provider = getAPIProvider() + return provider === 'openai' || provider === 'gemini' || provider === 'grok' } let cachedAdapter: WebSearchAdapter | null = null -let cachedAdapterKey: 'api' | 'bing' | 'brave' | 'exa' | null = null +let cachedAdapterKey: 'api' | 'bing' | 'brave' | 'exa' | 'ollama' | null = null export function createAdapter(): WebSearchAdapter { const envAdapter = process.env.WEB_SEARCH_ADAPTER // Priority: // 1. Explicit env override (WEB_SEARCH_ADAPTER=api|bing|brave) - // 2. Third-party provider (OpenAI/Gemini/Grok) → bing (no server_tools support) - // 3. First-party Anthropic API → api (server-side web search + connector_text) - // 4. Fallback → bing + // 2. Ollama provider → ollama (native /api/web_search) + // 3. Third-party provider (OpenAI/Gemini/Grok) → bing (no server_tools support) + // 4. First-party Anthropic API → api (server-side web search + connector_text) + // 5. Fallback → exa const adapterKey = envAdapter === 'api' || envAdapter === 'bing' || envAdapter === 'brave' || envAdapter === 'exa' ? envAdapter - : isThirdPartyProvider() - ? 'bing' - : isFirstPartyAnthropicBaseUrl() - ? 'api' - : 'exa' + : getAPIProvider() === 'ollama' + ? 'ollama' + : isThirdPartyProvider() + ? 'bing' + : isFirstPartyAnthropicBaseUrl() + ? 'api' + : 'exa' if (cachedAdapter && cachedAdapterKey === adapterKey) return cachedAdapter @@ -69,6 +73,11 @@ export function createAdapter(): WebSearchAdapter { cachedAdapterKey = 'exa' return cachedAdapter } + if (adapterKey === 'ollama') { + cachedAdapter = new OllamaSearchAdapter() + cachedAdapterKey = 'ollama' + return cachedAdapter + } cachedAdapter = new BingSearchAdapter() cachedAdapterKey = 'bing' diff --git a/packages/builtin-tools/src/tools/WebSearchTool/adapters/ollamaAdapter.ts b/packages/builtin-tools/src/tools/WebSearchTool/adapters/ollamaAdapter.ts new file mode 100644 index 0000000000..5e18c9bb36 --- /dev/null +++ b/packages/builtin-tools/src/tools/WebSearchTool/adapters/ollamaAdapter.ts @@ -0,0 +1,112 @@ +import { getOllamaClient } from 'src/services/api/ollama/client.js' +import { AbortError } from 'src/utils/errors.js' +import type { SearchResult, SearchOptions, WebSearchAdapter } from './types.js' + +interface OllamaWebSearchResult { + title?: string + url?: string + content?: string +} + +interface OllamaWebSearchResponse { + results?: OllamaWebSearchResult[] +} + +export class OllamaSearchAdapter implements WebSearchAdapter { + async search(query: string, options: SearchOptions): Promise { + const { signal, onProgress, allowedDomains, blockedDomains } = options + if (signal?.aborted) { + throw new AbortError() + } + + onProgress?.({ type: 'query_update', query }) + + const client = getOllamaClient() + const maxResults = Math.min(Math.max(options.numResults ?? 5, 1), 10) + const response = await client.webSearch( + { + query, + max_results: maxResults, + }, + { signal }, + ) + + if (signal?.aborted) { + throw new AbortError() + } + + if (!response.ok) { + const body = await response.text().catch(() => '') + throw new Error( + `Ollama web_search failed: HTTP ${response.status} ${response.statusText}${body ? `: ${body}` : ''}`, + ) + } + + const payload = (await response.json()) as OllamaWebSearchResponse + const results: SearchResult[] = [] + for (const result of payload.results ?? []) { + if (typeof result.url !== 'string') continue + if (!matchesDomainFilters(result.url, allowedDomains, blockedDomains)) { + continue + } + + const title = result.title?.trim() || result.url + const snippet = result.content?.trim() + results.push({ + title, + url: result.url, + ...(snippet && { snippet }), + }) + } + + onProgress?.({ + type: 'search_results_received', + resultCount: results.length, + query, + }) + + return results + } +} + +function matchesDomainFilters( + url: string, + allowedDomains?: string[], + blockedDomains?: string[], +): boolean { + try { + const hostname = new URL(url).hostname.toLowerCase() + const allowed = normalizeDomains(allowedDomains) + const blocked = normalizeDomains(blockedDomains) + if ( + allowed.length && + !allowed.some( + domain => hostname === domain || hostname.endsWith('.' + domain), + ) + ) { + return false + } + if ( + blocked.length && + blocked.some( + domain => hostname === domain || hostname.endsWith('.' + domain), + ) + ) { + return false + } + return true + } catch { + return false + } +} + +function normalizeDomains(domains?: string[]): string[] { + return (domains ?? []) + .map(domain => + domain + .trim() + .toLowerCase() + .replace(/^\.+|\.+$/g, ''), + ) + .filter(Boolean) +} diff --git a/src/commands.ts b/src/commands.ts index 33c1c75f0f..12779eb647 100644 --- a/src/commands.ts +++ b/src/commands.ts @@ -227,6 +227,7 @@ import { import rateLimitOptions from './commands/rate-limit-options/index.js' import statusline from './commands/statusline.js' import effort from './commands/effort/index.js' +import approval from './commands/approval/index.js' import stats from './commands/stats/index.js' // insights.ts is 113KB (3200 lines, includes diffLines/html rendering). Lazy // shim defers the heavy module until /insights is actually invoked. @@ -299,6 +300,7 @@ const COMMANDS = memoize((): Command[] => [ addDir, advisor, autonomy, + approval, provider, agents, branch, diff --git a/src/commands/approval/__tests__/approvalModes.test.ts b/src/commands/approval/__tests__/approvalModes.test.ts new file mode 100644 index 0000000000..cea42319b1 --- /dev/null +++ b/src/commands/approval/__tests__/approvalModes.test.ts @@ -0,0 +1,98 @@ +import { describe, expect, test } from 'bun:test' +import { + formatApprovalMode, + getApprovalModeDescriptor, + parseApprovalModeArg, +} from '../approvalModes.js' + +describe('parseApprovalModeArg', () => { + test('treats empty and status arguments as current mode requests', () => { + expect(parseApprovalModeArg('')).toEqual({ type: 'current' }) + expect(parseApprovalModeArg('status')).toEqual({ type: 'current' }) + expect(parseApprovalModeArg('show')).toEqual({ type: 'current' }) + }) + + test('treats help aliases as help requests', () => { + expect(parseApprovalModeArg('help')).toEqual({ type: 'help' }) + expect(parseApprovalModeArg('-h')).toEqual({ type: 'help' }) + expect(parseApprovalModeArg('--help')).toEqual({ type: 'help' }) + }) + + test('parses standard approval modes', () => { + expect(parseApprovalModeArg('default')).toEqual({ + type: 'mode', + mode: 'default', + }) + expect(parseApprovalModeArg('accept-edits')).toEqual({ + type: 'mode', + mode: 'acceptEdits', + }) + expect(parseApprovalModeArg('plan')).toEqual({ type: 'mode', mode: 'plan' }) + expect(parseApprovalModeArg('auto')).toEqual({ type: 'mode', mode: 'auto' }) + expect(parseApprovalModeArg('dont-ask')).toEqual({ + type: 'mode', + mode: 'dontAsk', + }) + }) + + test('maps full access aliases to bypassPermissions', () => { + expect(parseApprovalModeArg('full-access')).toEqual({ + type: 'mode', + mode: 'bypassPermissions', + }) + expect(parseApprovalModeArg('full_access')).toEqual({ + type: 'mode', + mode: 'bypassPermissions', + }) + expect(parseApprovalModeArg('bypass')).toEqual({ + type: 'mode', + mode: 'bypassPermissions', + }) + expect(parseApprovalModeArg('allow-all')).toEqual({ + type: 'mode', + mode: 'bypassPermissions', + }) + }) + + test('normalizes case, spaces, and underscores', () => { + expect(parseApprovalModeArg(' FULL_ACCESS ')).toEqual({ + type: 'mode', + mode: 'bypassPermissions', + }) + expect(parseApprovalModeArg('ACCEPT EDITS')).toEqual({ + type: 'mode', + mode: 'acceptEdits', + }) + }) + + test('reports invalid arguments', () => { + const result = parseApprovalModeArg('wat') + expect(result.type).toBe('invalid') + if (result.type === 'invalid') { + expect(result.message).toContain('Invalid approval mode: wat') + } + }) +}) + +describe('formatApprovalMode', () => { + test('uses user-facing names for regular modes', () => { + expect(formatApprovalMode('default')).toBe('Default') + expect(formatApprovalMode('plan')).toBe('Plan') + }) + + test('uses the user-facing full access name for bypassPermissions', () => { + expect(formatApprovalMode('bypassPermissions')).toBe( + 'Full access (bypassPermissions)', + ) + }) + + test('uses an explicit fallback for internal modes', () => { + expect(formatApprovalMode('bubble')).toBe('Internal/Unknown') + expect(getApprovalModeDescriptor('bubble')).toEqual({ + mode: 'bubble', + label: 'Internal/Unknown', + description: 'This approval mode is internal and not user-selectable', + aliases: ['bubble'], + }) + }) +}) diff --git a/src/commands/approval/approval.tsx b/src/commands/approval/approval.tsx new file mode 100644 index 0000000000..8f1b8b1113 --- /dev/null +++ b/src/commands/approval/approval.tsx @@ -0,0 +1,191 @@ +import * as React from 'react'; +import { Box, Text } from '@anthropic/ink'; +import type { OptionWithDescription } from '../../components/CustomSelect/select.js'; +import { Select } from '../../components/CustomSelect/select.js'; +import { + type AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS, + logEvent, +} from '../../services/analytics/index.js'; +import { useAppState, useSetAppState } from '../../state/AppState.js'; +import type { ToolPermissionContext } from '../../Tool.js'; +import type { LocalJSXCommandOnDone } from '../../types/command.js'; +import type { PermissionMode } from '../../utils/permissions/PermissionMode.js'; +import { + getAutoModeUnavailableNotification, + getAutoModeUnavailableReason, + isBypassPermissionsModeDisabled, + transitionPermissionMode, +} from '../../utils/permissions/permissionSetup.js'; +import { + APPROVAL_MODE_DESCRIPTORS, + formatApprovalMode, + getApprovalModeDescriptor, + parseApprovalModeArg, +} from './approvalModes.js'; + +type ApprovalCommandResult = { + message: string; + modeUpdate?: PermissionMode; +}; + +function getModeUnavailableMessage(mode: PermissionMode, context: ToolPermissionContext): string | undefined { + if (mode === 'bypassPermissions') { + if (isBypassPermissionsModeDisabled()) { + return 'Full access is disabled by settings or organization policy.'; + } + if (!context.isBypassPermissionsModeAvailable) { + return 'Full access is not available in this session. Start ccb with --allow-dangerously-skip-permissions to make it selectable, or --dangerously-skip-permissions to start directly in full access mode.'; + } + } + + if (mode === 'auto') { + const reason = getAutoModeUnavailableReason(); + if (reason) { + return `Auto approval is unavailable: ${getAutoModeUnavailableNotification(reason)}`; + } + } + + return undefined; +} + +function applyApprovalMode(context: ToolPermissionContext, mode: PermissionMode): ApprovalCommandResult { + const unavailableMessage = getModeUnavailableMessage(mode, context); + if (unavailableMessage) { + return { message: unavailableMessage }; + } + + if (context.mode === mode) { + return { message: `Approval mode is already ${formatApprovalMode(mode)}.` }; + } + + logEvent('tengu_approval_command', { + mode: mode as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS, + }); + + const descriptor = getApprovalModeDescriptor(mode); + return { + message: `Approval mode set to ${formatApprovalMode(mode)}: ${descriptor.description}`, + modeUpdate: mode, + }; +} + +function applyModeUpdate(context: ToolPermissionContext, mode: PermissionMode): ToolPermissionContext { + const next = transitionPermissionMode(context.mode, mode, context); + return { ...next, mode }; +} + +export function showCurrentApprovalMode(context: ToolPermissionContext): ApprovalCommandResult { + const descriptor = getApprovalModeDescriptor(context.mode); + return { + message: `Current approval mode: ${formatApprovalMode(context.mode)} (${descriptor.description})`, + }; +} + +export function executeApproval(args: string, context: ToolPermissionContext): ApprovalCommandResult { + const parsed = parseApprovalModeArg(args); + switch (parsed.type) { + case 'current': + return showCurrentApprovalMode(context); + case 'help': + return { + message: formatApprovalHelp(), + }; + case 'mode': + return applyApprovalMode(context, parsed.mode); + case 'invalid': + return { message: parsed.message }; + } +} + +function formatApprovalHelp(): string { + const aliases = APPROVAL_MODE_DESCRIPTORS.map(descriptor => descriptor.aliases[0]).join('|'); + const modes = APPROVAL_MODE_DESCRIPTORS.map( + descriptor => `- ${descriptor.aliases[0]}: ${descriptor.description}`, + ).join('\n'); + return `Usage: /approval [${aliases}]\n\nApproval modes:\n${modes}`; +} + +function ApplyApprovalAndClose({ + result, + onDone, +}: { + result: ApprovalCommandResult; + onDone: (result: string) => void; +}): React.ReactNode { + const setAppState = useSetAppState(); + const { message, modeUpdate } = result; + React.useEffect(() => { + if (modeUpdate) { + setAppState(prev => ({ + ...prev, + toolPermissionContext: applyModeUpdate(prev.toolPermissionContext, modeUpdate), + })); + } + onDone(message); + }, [setAppState, message, modeUpdate, onDone]); + return null; +} + +function ApprovalPicker({ onDone }: { onDone: (result: string) => void }): React.ReactNode { + const toolPermissionContext = useAppState(s => s.toolPermissionContext); + const setAppState = useSetAppState(); + + const options: OptionWithDescription[] = APPROVAL_MODE_DESCRIPTORS.map(descriptor => { + const unavailableMessage = getModeUnavailableMessage(descriptor.mode, toolPermissionContext); + return { + label: descriptor.label, + value: descriptor.mode, + description: + descriptor.mode === toolPermissionContext.mode + ? `${descriptor.description} (current)` + : (unavailableMessage ?? descriptor.description), + disabled: unavailableMessage !== undefined, + }; + }); + + function handleSelect(mode: PermissionMode): void { + const result = applyApprovalMode(toolPermissionContext, mode); + if (result.modeUpdate) { + setAppState(prev => ({ + ...prev, + toolPermissionContext: applyModeUpdate(prev.toolPermissionContext, result.modeUpdate!), + })); + } + onDone(result.message); + } + + return ( + + Select approval mode +